In [1]:
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
import sys

import matplotlib.pyplot as plt
import IPython.display as ipd

from tqdm import tqdm

from common import *
from dataset import ArrhythmiaDataset

from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd

from torch.utils.tensorboard import SummaryWriter


# Settings forwarded to ArrhythmiaDataset when the dataset is built.
RECORD_DIR_PATH = '../data/mit-bih-arrhythmia-database-1.0.0'
WINDOW_SIZE = 540
MOVING_AVERAGE_RANGE = 17
INCLUDE_MANUAL_LABELS = False
SUBSET_FROM_MANUAL_LABELS = True
INCLUDE_RAW_SIGNAL = True

# Beat annotation symbols to classify: the short list when working on the
# manually-labelled subset, the full list otherwise.
if SUBSET_FROM_MANUAL_LABELS:
    CLASSES = ['N', 'L', 'R', 'a', 'V', 'J', 'F']
else:
    CLASSES = ['N', 'L', 'R', 'A', 'a', 'V', 'j', 'J', 'E', 'f', 'F', '[', '!', ']', '/', 'x', '|', 'Q']

# TODO: S, e - need some preprocessing; the dimensions seem to be wrong for one of these classes
# TODO: Q - hard to handle: this is the most frequently confused beat in the confusion matrices

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [3]:
# Randomness seed
random_seed = 1 # or any of your favorite number
# Seed every generator this notebook draws from: Python's `random`
# (used when sub-sampling normal beats), NumPy and PyTorch (CPU and CUDA).
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
# Ask cuDNN for deterministic kernels and disable its auto-tuner, which may
# otherwise pick different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
# Build the dataset from the records in RECORD_DIR_PATH using the settings
# defined in the configuration cell.
dataset = ArrhythmiaDataset(
    RECORD_DIR_PATH,
    WINDOW_SIZE,
    only_include_labels=CLASSES,
    moving_average_range=MOVING_AVERAGE_RANGE,
    include_manual_labels=INCLUDE_MANUAL_LABELS,
    subset_from_manual_labels=SUBSET_FROM_MANUAL_LABELS,
    include_raw_signal=INCLUDE_RAW_SIGNAL,
)

# Quick sanity check: tensor shape and number of labels.
print(dataset.data.shape)
print(len(dataset.labels))

filename='V.csv' Unique_keys=50
filename='L.csv' Unique_keys=100
filename='R.csv' Unique_keys=150
filename='J.csv' Unique_keys=218
filename='a.csv' Unique_keys=271
filename='F.csv' Unique_keys=321
filename='N.csv' Unique_keys=371
filename='100.atr' patient_record_number=100
beat_slice_array.shape=(38, 2, 1080) beat_slices.shape=torch.Size([38, 2, 1080])
filename='124.atr' patient_record_number=124
beat_slice_array.shape=(20, 2, 1080) beat_slices.shape=torch.Size([20, 2, 1080])
self.data.shape=torch.Size([38, 2, 1080]) beat_slices.shape=torch.Size([20, 2, 1080])
filename='111.atr' patient_record_number=111
beat_slice_array.shape=(50, 2, 1080) beat_slices.shape=torch.Size([50, 2, 1080])
self.data.shape=torch.Size([58, 2, 1080]) beat_slices.shape=torch.Size([50, 2, 1080])
filename='208.atr' patient_record_number=208
beat_slice_array.shape=(50, 2, 1080) beat_slices.shape=torch.Size([50, 2, 1080])
self.data.shape=torch.Size([108, 2, 1080]) beat_slices.shape=torch.Size([50, 2, 1080])
filenam

In [5]:
# Count how many samples carry each distinct encoded label (unique rows along dim 0).
labels, counts = torch.unique(dataset.labels_encoded, dim = 0, return_counts = True)

# Print one "<label>: <count>" line per class.
for label, count in zip(labels, counts):
    print(f'{dataset.get_label_from_tensor(label)}: {count}')


N: 49
L: 50
R: 50
a: 53
F: 50
J: 68
V: 50


In [6]:
# Drop some Normal beats to balance classes.
# Each normal ('N') beat is kept with this probability; 1.0 keeps all of them.
NORMAL_BEAT_KEEP_PROBABILITY = 1.0

# True marks a sample that will be dropped. Start with every normal beat
# marked, then clear the mark for the ones that are kept.
normal_beat_mask = np.array(dataset.labels) == 'N'
for idx, is_normal in enumerate(normal_beat_mask):
    # The random draw only happens for normal beats (short-circuit `and`).
    if is_normal and random.uniform(0, 1) < NORMAL_BEAT_KEEP_PROBABILITY:
        normal_beat_mask[idx] = False

keep_mask = ~normal_beat_mask
new_labels = [label for label, keep in zip(dataset.labels, keep_mask) if keep]
new_data = dataset.data[keep_mask]

# Replace the dataset contents and rebuild the encoded labels to match.
dataset.data = new_data
dataset.labels = new_labels
dataset.encode_labels()

def show_class_count(dataset: ArrhythmiaDataset):
    """Print the data tensor shape, the number of labels and one
    ``<label>: <count>`` line per distinct encoded label of ``dataset``."""
    print(dataset.data.shape)
    print(len(dataset.labels))

    # Unique rows of the encoded labels together with how often each occurs.
    unique_rows, occurrences = torch.unique(dataset.labels_encoded, return_counts = True, dim = 0)
    for row_idx in range(len(unique_rows)):
        label = unique_rows[row_idx]
        count = occurrences[row_idx]
        print(f'{dataset.get_label_from_tensor(label)}: {count}')

# Print the dataset shape, label count and per-class counts after the filtering above.
show_class_count(dataset)

torch.Size([370, 2, 1080])
370
N: 49
L: 50
R: 50
a: 53
F: 50
J: 68
V: 50


In [7]:
def collate_fn(batch):
    """Collate ``(waveform, one-hot label)`` pairs into two batched tensors.

    Args:
        batch: iterable of ``(waveform, label)`` tensor pairs; all waveforms
            must share one shape and all labels must share one shape.

    Returns:
        ``(tensors, targets)`` where each is the corresponding items stacked
        along a new leading batch dimension.

    Raises:
        RuntimeError: from ``torch.stack`` when the batch is empty or the
            shapes differ.
    """
    # Materialise once so one-shot iterables can be traversed twice.
    pairs = list(batch)

    # Group the per-sample tensors into batched tensors.
    tensors = torch.stack([waveform for waveform, _ in pairs])
    targets = torch.stack([label for _, label in pairs])

    return tensors, targets


batch_size = 16

# `device` is a torch.device object; it never compares equal to the plain
# string "cuda", so test its `.type` to actually enable the GPU settings.
if device.type == "cuda":
    num_workers = 1
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False

# Hold out 20% of the samples for evaluation.
train_dataset, test_dataset = dataset.train_test_split(0.2)

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=pin_memory,
)
# Evaluation keeps a fixed order and keeps the final, possibly smaller, batch.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

print('TRAIN DATASET:')
show_class_count(train_dataset)

print('TEST DATASET:')
show_class_count(test_dataset)

TRAIN DATASET:
torch.Size([296, 2, 1080])
0
N: 39
L: 40
R: 40
a: 42
F: 40
J: 55
V: 40
TEST DATASET:
torch.Size([74, 2, 1080])
0
N: 10
L: 10
R: 10
a: 11
F: 10
J: 13
V: 10


In [8]:
class M5(nn.Module):
    """1-D convolutional classifier.

    Six stages of ``Conv1d(kernel_size=3) -> BatchNorm1d -> ReLU -> MaxPool1d``
    followed by a global average over the remaining length and a two-layer
    fully connected head that emits log-probabilities.

    Args:
        n_input: number of channels of the input signal.
        n_output: size of the output log-probability vector.
        stride: stride of the first convolution only.
        n_channel: channels of stages 1-3; stages 4-6 use twice as many.
    """

    def __init__(self, n_input=1, n_output=35, stride=1, n_channel=32):
        super().__init__()
        wide = 2 * n_channel

        # One row per stage: (in_channels, out_channels, conv stride, pool width).
        stage_specs = [
            (n_input, n_channel, stride, 2),
            (n_channel, n_channel, 1, 2),
            (n_channel, n_channel, 1, 3),
            (n_channel, wide, 1, 3),
            (wide, wide, 1, 3),
            (wide, wide, 1, 3),
        ]
        # Layers are registered as conv<i>, bn<i>, pool<i> in stage order.
        for idx, (c_in, c_out, conv_stride, pool_width) in enumerate(stage_specs, start=1):
            setattr(self, f"conv{idx}", nn.Conv1d(c_in, c_out, kernel_size=3, stride=conv_stride))
            setattr(self, f"bn{idx}", nn.BatchNorm1d(c_out))
            setattr(self, f"pool{idx}", nn.MaxPool1d(pool_width))

        # Classification head.
        self.fc1 = nn.Linear(wide, n_channel)
        self.fc2 = nn.Linear(n_channel, n_output)

    def forward(self, x):
        """Map a ``(batch, n_input, length)`` signal to log-probabilities of
        shape ``(batch, 1, n_output)``."""
        stages = (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
            (self.conv3, self.bn3, self.pool3),
            (self.conv4, self.bn4, self.pool4),
            (self.conv5, self.bn5, self.pool5),
            (self.conv6, self.bn6, self.pool6),
        )
        for conv, norm, pool in stages:
            x = pool(F.relu(norm(conv(x))))

        # Average over whatever length is left, leaving a single time step.
        x = F.avg_pool1d(x, x.size(-1))
        # Move channels last so the linear layers act on the feature axis.
        x = x.transpose(1, 2)
        x = self.fc2(F.relu(self.fc1(x)))
        return F.log_softmax(x, dim=2)


# Two input channels and one output per distinct label in the dataset;
# parameters are cast to float64 and moved to the selected device.
model = M5(n_input = 2, n_output = len(set(dataset.labels))).double().to(device)
print(model)


def count_parameters(model):
    """Return the total element count of the model's parameters that have
    ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# Report how many trainable parameters the model has.
n = count_parameters(model)
print("Number of parameters: %s" % n)

M5(
  (conv1): Conv1d(2, 32, kernel_size=(3,), stride=(1,))
  (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(32, 32, kernel_size=(3,), stride=(1,))
  (bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv1d(32, 32, kernel_size=(3,), stride=(1,))
  (bn3): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool3): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv1d(32, 64, kernel_size=(3,), stride=(1,))
  (bn4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool4): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (conv5): Conv1d(64, 64, kernel_size=(3,), stride=(1,))
  (bn5): Bat

In [9]:
# Adam over all model parameters, with weight decay as regularisation.
optimizer = optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0001)
# scheduler.step() is called once per epoch in the training loop.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.7)  # multiply the learning rate by 0.7 every 10 epochs

In [10]:
def train(model, epoch, log_interval, writer: SummaryWriter):
    """Train ``model`` for one pass over ``train_loader``.

    Relies on the notebook-level ``train_loader``, ``device``, ``optimizer``,
    ``pbar`` and ``pbar_update``.

    Args:
        model: network returning log-probabilities; targets are one-hot rows.
        epoch: epoch number, used for logging and the TensorBoard step.
        log_interval: print the running loss every ``log_interval`` batches.
        writer: TensorBoard writer receiving the ``'Train loss'`` scalar.

    Returns:
        List with the loss value of every batch, in order.
    """
    train_losses = []
    model.train()
    n_batches = len(train_loader)
    for batch_idx, (data, target) in enumerate(train_loader):

        data = data.to(device)
        target = target.to(device)

        # apply transform and model on whole batch directly on device
        output = model(data)

        # negative log-likelihood for a tensor of size (batch x 1 x n_output).
        # Remove only the singleton middle axis: a bare squeeze() would also
        # drop the batch axis when the last batch holds a single sample.
        squeezed_output = output.squeeze(1)
        loss = F.nll_loss(squeezed_output, target.argmax(dim = 1))

        # Step counted in batches so consecutive points are contiguous.
        writer.add_scalar('Train loss', loss.item(), epoch * n_batches + batch_idx)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # print training stats
        if batch_idx % log_interval == 0:
            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")

        # update progress bar
        pbar.update(pbar_update)
        # record loss
        train_losses.append(loss.item())
    return train_losses

In [11]:
def number_of_correct(pred, target):
    """Return how many entries of ``pred`` (with size-1 axes removed) equal
    the corresponding entries of ``target``."""
    matches = pred.squeeze() == target
    return matches.sum().item()


def get_likely_index(tensor):
    """Return, for each element of the batch, the index of the largest
    value along the last axis."""
    return torch.argmax(tensor, dim=-1)


def test(model, epoch, writer: SummaryWriter):
    """Evaluate ``model`` on ``test_loader`` and log the results.

    Relies on the notebook-level ``test_loader``, ``device``, ``CLASSES``,
    ``pbar`` and ``pbar_update``.

    Args:
        model: network returning log-probabilities; targets are one-hot rows.
        epoch: epoch number used as the TensorBoard step.
        writer: TensorBoard writer receiving ``'Test accuracy'`` and the
            ``'Test confusion matrix'`` figure.

    Returns:
        Accuracy on the test set, in percent.
    """
    model.eval()
    correct = 0
    y_true = []
    y_pred = []
    # Evaluation only: no gradients are needed.
    with torch.no_grad():
        for data, target in test_loader:

            data = data.to(device)
            target = target.to(device)

            output = model(data)

            pred = get_likely_index(output)
            true_index = target.argmax(dim = 1)
            correct += number_of_correct(pred, true_index)

            # Ground truth comes from the targets, predictions from the model.
            # reshape(-1) keeps a 1-D array even when the batch has one sample.
            y_true.extend(true_index.cpu().numpy())
            y_pred.extend(pred.reshape(-1).cpu().numpy())

            # update progress bar
            pbar.update(pbar_update)
    accuracy = 100. * correct / len(test_loader.dataset)
    writer.add_scalar('Test accuracy', accuracy, epoch)

    # Build confusion matrix (rows: true class, columns: predicted class).
    # Passing `labels` keeps the matrix len(CLASSES) x len(CLASSES) even when
    # a class is absent from this evaluation.
    # NOTE(review): assumes encoded class index i corresponds to CLASSES[i] --
    # confirm against the dataset's label encoding.
    cf_matrix = confusion_matrix(y_true, y_pred, labels = list(range(len(CLASSES))))
    df_cm = pd.DataFrame(cf_matrix, index = [i for i in CLASSES],
                         columns = [i for i in CLASSES])
    plt.figure(figsize = (12,7))
    cf_matrix_figure = sn.heatmap(df_cm, annot=True).get_figure()
    writer.add_figure('Test confusion matrix', cf_matrix_figure, epoch)

    print(f"\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.0f}%)\n")
    return accuracy

In [12]:
writer = SummaryWriter()



log_interval = 20
n_epoch = 15

CHECKPOINT_PATH = '../models/moving_average_and_raw_signal - checkpoint.pt'
ACCURACY_MOVING_AVERAGE_SIZE = 30  # moving average for accuracy to check if performance degraded
# NOTE(review): the early-stopping check below needs
# ACCURACY_MOVING_AVERAGE_SIZE + 1 accuracies, so with n_epoch = 15 it never
# triggers in this configuration.

# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

writer.add_hparams({f'data_shape_{i}': shape for i, shape in enumerate(dataset.data.shape)} | {'data_moving_average_range': MOVING_AVERAGE_RANGE, 'data_window_size': WINDOW_SIZE, 'batch_size': batch_size, 'n_epoch': n_epoch}, {'hparam/fake_accuracy_just_to_have_any_metric': 10})

# Each processed batch (train or test) advances the bar by this fraction of
# an epoch, so one full epoch adds up to 1.
pbar_update = 1 / (len(train_loader) + len(test_loader))
losses = []
accuracies = []

with tqdm(total=n_epoch) as pbar:
    for epoch in range(1, n_epoch + 1):
        # Snapshot the state before this epoch so it can be restored if the
        # epoch turns out to degrade accuracy.
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, CHECKPOINT_PATH)

        train_losses = train(model, epoch, log_interval, writer)
        losses.extend(train_losses)

        accuracy = test(model, epoch, writer)
        accuracies.append(accuracy)
        scheduler.step()

        # Early stopping: compare the mean accuracy of the previous window
        # with the mean of the window that includes the current epoch.
        if len(accuracies) >= ACCURACY_MOVING_AVERAGE_SIZE + 1:
            is_performance_degraded = np.mean(accuracies[-ACCURACY_MOVING_AVERAGE_SIZE - 1:-1]) > np.mean(accuracies[-ACCURACY_MOVING_AVERAGE_SIZE:])
            if is_performance_degraded:
                # Reload the last non-degraded checkpoint
                checkpoint = torch.load(CHECKPOINT_PATH)
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                break

# Make sure every logged event reaches disk; the writer stays usable afterwards.
writer.flush()


# Let's plot the training loss versus the number of iteration.
# plt.plot(losses);
# plt.title("training loss");

  2%|▏         | 0.24999999999999997/15 [00:00<00:38,  2.62s/it] 



  9%|▊         | 1.2916666666666667/15 [00:01<00:11,  1.24it/s] 


Test Epoch: 1	Accuracy: 10/74 (14%)



 15%|█▌        | 2.2500000000000004/15 [00:02<00:08,  1.44it/s]


Test Epoch: 2	Accuracy: 29/74 (39%)



 22%|██▏       | 3.249999999999997/15 [00:02<00:07,  1.48it/s] 


Test Epoch: 3	Accuracy: 70/74 (95%)



 29%|██▊       | 4.2916666666666625/15 [00:03<00:07,  1.49it/s]


Test Epoch: 4	Accuracy: 73/74 (99%)



 35%|███▌      | 5.250000000000003/15 [00:04<00:06,  1.53it/s] 


Test Epoch: 5	Accuracy: 74/74 (100%)



 42%|████▏     | 6.291666666666677/15 [00:04<00:05,  1.49it/s] 


Test Epoch: 6	Accuracy: 74/74 (100%)



 48%|████▊     | 7.125000000000016/15 [00:05<00:07,  1.09it/s]


Test Epoch: 7	Accuracy: 74/74 (100%)



 55%|█████▌    | 8.250000000000018/15 [00:06<00:04,  1.45it/s] 


Test Epoch: 8	Accuracy: 74/74 (100%)



 62%|██████▏   | 9.29166666666667/15 [00:06<00:03,  1.68it/s] 


Test Epoch: 9	Accuracy: 74/74 (100%)



 69%|██████▊   | 10.291666666666655/15 [00:07<00:02,  1.83it/s]


Test Epoch: 10	Accuracy: 74/74 (100%)



 75%|███████▌  | 11.291666666666641/15 [00:07<00:01,  1.93it/s]


Test Epoch: 11	Accuracy: 74/74 (100%)



 82%|████████▏ | 12.291666666666627/15 [00:08<00:01,  1.90it/s]


Test Epoch: 12	Accuracy: 74/74 (100%)



 89%|████████▊ | 13.291666666666613/15 [00:09<00:00,  1.90it/s]


Test Epoch: 13	Accuracy: 74/74 (100%)



 95%|█████████▍| 14.249999999999932/15 [00:09<00:00,  1.69it/s]


Test Epoch: 14	Accuracy: 74/74 (100%)



100%|█████████▉| 14.999999999999922/15 [00:10<00:00,  1.50it/s]


Test Epoch: 15	Accuracy: 74/74 (100%)




