In [1]:
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
import sys

import matplotlib.pyplot as plt
import IPython.display as ipd

from tqdm import tqdm

from common import *
from dataset import ArrhythmiaDataset

from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd

from torch.utils.tensorboard import SummaryWriter


# Directory with the arrhythmia database records, relative to this notebook.
RECORD_DIR_PATH = '../data/mit-bih-arrhythmia-database-1.0.0'
# Window size handed to ArrhythmiaDataset.
# NOTE(review): presumably the number of samples taken on each side of a beat — confirm in dataset.py
WINDOW_SIZE = 540
# Beat annotation symbols kept for classification (passed as `only_include_labels`).
CLASSES = ['N', 'L', 'R', 'a', 'V', 'J', 'F']

# Classes: 'N', 'L', 'R', 'A', 'a', 'V', 'j', 'J', 'E', 'f', 'F', '[', '!', ']', '/', 'x', '|', 'Q', 'S', 'e'
# TODO: S, e - need some preprocessing, dimensions seem to be wrong in one of these
# TODO: Q - of course, quite confusing, this is the most confused beat in confusion matrices
# TODO: Manually labeled samples seem to have their R-peaks displaced relative to the original dataset; this MUST NOT happen, because the original annotations were marked and corrected by cardiologists over 30 years

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [3]:
# Randomness seed
random_seed = 1 # or any of your favorite number

# Seed every random number generator this notebook draws from. Python's own
# `random` module must be included: the class-balancing cell uses
# random.uniform() to decide which Normal beats to keep.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)

# Ask cuDNN for deterministic kernels and disable its auto-tuner so repeated
# runs select the same algorithms.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
# Build the dataset from the records under RECORD_DIR_PATH, restricted to the beat symbols in CLASSES.
# NOTE(review): include_manual_labels=True presumably merges the hand-labelled CSV files
# listed in the output below — confirm against dataset.py
dataset = ArrhythmiaDataset(RECORD_DIR_PATH, WINDOW_SIZE, only_include_labels = CLASSES, include_manual_labels = True)

# Sanity check: shape of the sample tensor and number of labels loaded.
print(dataset.data.shape)
print(len(dataset.labels))

filename='V.csv' Unique_keys=50
filename='L.csv' Unique_keys=100
filename='R.csv' Unique_keys=150
filename='J.csv' Unique_keys=200
filename='a.csv' Unique_keys=253
filename='F.csv' Unique_keys=303
filename='N.csv' Unique_keys=353
row.sample=370 manual_label_dict_sample_key=370
row.sample=662 manual_label_dict_sample_key=663
row.sample=946 manual_label_dict_sample_key=947
row.sample=1231 manual_label_dict_sample_key=1231
row.sample=1515 manual_label_dict_sample_key=1515
row.sample=1809 manual_label_dict_sample_key=1809
row.sample=2402 manual_label_dict_sample_key=2403
row.sample=2706 manual_label_dict_sample_key=2706
row.sample=2998 manual_label_dict_sample_key=2998
row.sample=3282 manual_label_dict_sample_key=3283
row.sample=3560 manual_label_dict_sample_key=3560
row.sample=3862 manual_label_dict_sample_key=3863
row.sample=4170 manual_label_dict_sample_key=4171
row.sample=4466 manual_label_dict_sample_key=4466
row.sample=4764 manual_label_dict_sample_key=4765
row.sample=5060 manual_lab

In [5]:
# Tally how many samples carry each distinct encoded label row.
labels, counts = torch.unique(dataset.labels_encoded, dim=0, return_counts=True)

for position in range(len(labels)):
    label = labels[position]
    count = counts[position]
    print(f'{dataset.get_label_from_tensor(label)}: {count}')


R: 4
N: 71
V: 19
J: 28
L: 11
F: 7
a: 22


In [6]:
# Thin out the over-represented Normal ('N') class to balance the dataset.
# Every 'N' beat survives with probability 0.33; all other beats are kept.
# One random draw is made per 'N' beat, in dataset order.
# TODO: Change this when there's more samples
normal_beat_mask = np.array(
    [beat == 'N' and not random.uniform(0, 1) < 0.33 for beat in dataset.labels],
    dtype=bool,
)  # True marks a Normal beat that is being dropped

new_labels = [
    beat
    for beat, is_dropped in zip(dataset.labels, normal_beat_mask)
    if not is_dropped
]
new_data = dataset.data[~normal_beat_mask]

# Replace the dataset contents in place and rebuild the encoded labels to match.
dataset.data = new_data
dataset.labels = new_labels
dataset.encode_labels()

def show_class_count(dataset: ArrhythmiaDataset):
    """Print the data tensor shape, the number of labels and the per-class sample counts.

    Args:
        dataset: object exposing `data`, `labels`, `labels_encoded` and
            `get_label_from_tensor`.
    """
    print(dataset.data.shape)
    print(len(dataset.labels))

    # Each distinct row of the encoded labels is one class.
    unique_rows, occurrences = torch.unique(dataset.labels_encoded, dim=0, return_counts=True)
    for position in range(len(unique_rows)):
        class_name = dataset.get_label_from_tensor(unique_rows[position])
        print(f'{class_name}: {occurrences[position]}')

show_class_count(dataset)

torch.Size([118, 6, 1080])
118
R: 4
N: 27
V: 19
J: 28
L: 11
F: 7
a: 22


In [7]:
def collate_fn(batch):
    """Stack a list of (waveform, label) samples into two batched tensors.

    Args:
        batch: iterable of `(waveform, label)` tensor pairs; all waveforms must
            share one shape and all labels must share one shape.

    Returns:
        A tuple `(tensors, targets)` where each element is the corresponding
        sample tensors stacked along a new leading batch dimension. Labels are
        passed through unchanged (no re-encoding happens here).
    """
    tensors, targets = [], []

    # Split the samples into parallel lists.
    for waveform, label in batch:
        tensors.append(waveform)
        targets.append(label)

    # Group each list into a single batched tensor. The former `tensors[:, :]`
    # slice was a no-op and only added a rank >= 2 requirement, so it is gone.
    return torch.stack(tensors), torch.stack(targets)


batch_size = 256

# `device` is a torch.device, and comparing it with the string "cuda" is always
# False, so the CUDA settings below were never selected. Compare the device type.
if device.type == "cuda":
    # NOTE(review): worker processes need `collate_fn` to be importable; on
    # platforms that spawn workers (Windows/macOS) a notebook-defined function
    # may not be — set num_workers = 0 there.
    num_workers = 1
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False

# Hold out 20% of the samples for evaluation.
train_dataset, test_dataset = dataset.train_test_split(0.2)

# Only the training data is shuffled; evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=pin_memory,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

print('TRAIN DATASET:')
show_class_count(train_dataset)

print('TEST DATASET:')
show_class_count(test_dataset)

TRAIN DATASET:
torch.Size([94, 6, 1080])
0
R: 3
N: 21
V: 15
J: 22
L: 9
F: 6
a: 18
TEST DATASET:
torch.Size([24, 6, 1080])
0
R: 1
N: 6
V: 4
J: 6
L: 2
F: 1
a: 4


In [8]:
class M5(nn.Module):
    """1-D convolutional classifier built from six conv -> batch-norm -> ReLU -> max-pool stages.

    The stages are followed by global average pooling over the time axis, two
    fully connected layers and a log-softmax over the classes.

    Args:
        n_input: number of input channels.
        n_output: number of classes.
        stride: stride of the first convolution only.
        n_channel: width of the first three stages; the last three use twice as many.
    """

    N_STAGES = 6

    def __init__(self, n_input=1, n_output=35, stride=1, n_channel=32):
        super().__init__()
        wide = 2 * n_channel

        # (in_channels, out_channels, conv stride, max-pool width) per stage.
        stage_specs = [
            (n_input, n_channel, stride, 2),
            (n_channel, n_channel, 1, 2),
            (n_channel, n_channel, 1, 3),
            (n_channel, wide, 1, 3),
            (wide, wide, 1, 3),
            (wide, wide, 1, 3),
        ]
        # Register the layers as conv1/bn1/pool1 ... conv6/bn6/pool6 so attribute
        # names and state_dict keys stay the same.
        for number, (c_in, c_out, conv_stride, pool_width) in enumerate(stage_specs, start=1):
            setattr(self, f'conv{number}', nn.Conv1d(c_in, c_out, kernel_size=3, stride=conv_stride))
            setattr(self, f'bn{number}', nn.BatchNorm1d(c_out))
            setattr(self, f'pool{number}', nn.MaxPool1d(pool_width))

        self.fc1 = nn.Linear(wide, n_channel)
        self.fc2 = nn.Linear(n_channel, n_output)

    def forward(self, x):
        """Return log-probabilities of shape (batch, 1, n_output) for input (batch, n_input, length)."""
        for number in range(1, self.N_STAGES + 1):
            conv = getattr(self, f'conv{number}')
            norm = getattr(self, f'bn{number}')
            pool = getattr(self, f'pool{number}')
            x = pool(F.relu(norm(conv(x))))

        # Collapse whatever is left of the time axis to a single step.
        x = F.avg_pool1d(x, x.shape[-1])
        # (batch, channels, 1) -> (batch, 1, channels) so the linear layers act on channels.
        x = x.permute(0, 2, 1)
        x = self.fc2(F.relu(self.fc1(x)))
        return F.log_softmax(x, dim=2)


# One output unit per distinct label still present in the dataset; 6 input channels.
n_classes = len(set(dataset.labels))
model = M5(n_input=6, n_output=n_classes)
# Cast the parameters to float64 and move them to the selected device
# (both calls modify the module in place).
model.double().to(device)
print(model)


def count_parameters(model):
    """Return the total number of elements across the trainable parameters of `model`."""
    total = 0
    for parameter in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if parameter.requires_grad:
            total += parameter.numel()
    return total


# Report the size of the model in trainable parameters.
n = count_parameters(model)
print(f"Number of parameters: {n}")

M5(
  (conv1): Conv1d(6, 32, kernel_size=(3,), stride=(1,))
  (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(32, 32, kernel_size=(3,), stride=(1,))
  (bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv1d(32, 32, kernel_size=(3,), stride=(1,))
  (bn3): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool3): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv1d(32, 64, kernel_size=(3,), stride=(1,))
  (bn4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool4): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (conv5): Conv1d(64, 64, kernel_size=(3,), stride=(1,))
  (bn5): Bat

In [9]:
# Adam with a small L2 penalty (weight_decay) on the model parameters.
optimizer = optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.7)  # multiply the learning rate by 0.7 every 10 scheduler steps (one step per epoch)

In [10]:
def train(model, epoch, log_interval, writer: SummaryWriter):
    """Run one training epoch over the global `train_loader`.

    Relies on the notebook globals `train_loader`, `device`, `optimizer`,
    `pbar` and `pbar_update`.

    Args:
        model: network returning log-probabilities of shape (batch, 1, n_output).
        epoch: current epoch number, used for logging.
        log_interval: print the loss every `log_interval` batches.
        writer: TensorBoard writer that receives the per-batch training loss.

    Returns:
        List with the loss value of every batch in this epoch.
    """
    train_losses = []
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):

        data = data.to(device)
        target = target.to(device)

        # apply model on whole batch directly on device
        output = model(data)

        # The output is (batch, 1, n_output). Drop only the singleton middle
        # dimension: a bare squeeze() would also drop the batch dimension when
        # a batch holds a single sample, which makes nll_loss fail.
        log_probs = output.squeeze(1)
        # Targets are one-hot rows; nll_loss expects class indices.
        loss = F.nll_loss(log_probs, target.argmax(dim=1))
        batch_loss = loss.item()

        writer.add_scalar('Train loss', batch_loss, epoch * len(train_loader.dataset) + batch_idx)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # print training stats
        if batch_idx % log_interval == 0:
            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {batch_loss:.6f}")

        # update progress bar
        pbar.update(pbar_update)
        # record loss
        train_losses.append(batch_loss)
    return train_losses

In [11]:
def number_of_correct(pred, target):
    """Return how many entries of `pred` (singleton dimensions removed) equal `target`."""
    flat_pred = torch.squeeze(pred)
    matches = torch.eq(flat_pred, target)
    return matches.sum().item()


def get_likely_index(tensor):
    """Return the index of the largest value along the last dimension of `tensor`."""
    return torch.argmax(tensor, dim=-1)


def test(model, epoch, writer: SummaryWriter):
    """Evaluate `model` on the global `test_loader` and log the results.

    Relies on the notebook globals `test_loader`, `device`, `pbar`,
    `pbar_update` and `CLASSES`.

    Args:
        model: network returning log-probabilities of shape (batch, 1, n_output).
        epoch: current epoch number, used as the TensorBoard step.
        writer: TensorBoard writer that receives the accuracy and the
            confusion-matrix figure.

    Returns:
        Accuracy on the test set, in percent.
    """
    model.eval()
    correct = 0
    y_true = []
    y_pred = []
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for data, target in test_loader:

            data = data.to(device)
            target = target.to(device)

            output = model(data)

            # Flatten (batch, 1) to (batch,) explicitly; a bare squeeze() yields a
            # 0-d tensor for a single-sample batch, which cannot be extended from.
            pred = get_likely_index(output).reshape(-1)
            true_index = target.argmax(dim=1)
            correct += number_of_correct(pred, true_index)

            # Ground truth goes to y_true and predictions to y_pred (they were swapped).
            y_true.extend(true_index.cpu().numpy())
            y_pred.extend(pred.cpu().numpy())

            # update progress bar
            pbar.update(pbar_update)
    accuracy = 100. * correct / len(test_loader.dataset)
    writer.add_scalar('Test accuracy', accuracy, epoch)

    # Build confusion matrix (rows: true class, columns: predicted class).
    # Passing `labels` fixes the size at len(CLASSES) x len(CLASSES) even when a
    # class is absent from both the targets and the predictions.
    # NOTE(review): the tick labels assume class index i corresponds to CLASSES[i];
    # verify this against the dataset's label encoding.
    cf_matrix = confusion_matrix(y_true, y_pred, labels=list(range(len(CLASSES))))
    df_cm = pd.DataFrame(cf_matrix, index=list(CLASSES), columns=list(CLASSES))
    cf_matrix_figure, ax = plt.subplots(figsize=(12, 7))
    sn.heatmap(df_cm, annot=True, fmt='d', ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    writer.add_figure('Test confusion matrix', cf_matrix_figure, epoch)

    print(f"\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.0f}%)\n")
    return accuracy

In [12]:
writer = SummaryWriter()

log_interval = 20
n_epoch = 300

CHECKPOINT_PATH = 'checkpoint.pt'
ACCURACY_MOVING_AVERAGE_SIZE = 30  # moving average for accuracy to check if performance degraded

# Each train/test batch advances the bar by an equal share of one epoch.
pbar_update = 1 / (len(train_loader) + len(test_loader))
losses = []
accuracies = []
best_accuracy = float('-inf')

with tqdm(total=n_epoch) as pbar:
    for epoch in range(1, n_epoch + 1):
        train_losses = train(model, epoch, log_interval, writer)
        losses.extend(train_losses)

        accuracy = test(model, epoch, writer)
        accuracies.append(accuracy)
        scheduler.step()

        # Checkpoint only when the test accuracy improves, so the file always
        # holds the best weights seen so far. Saving unconditionally at the start
        # of every epoch meant early stopping merely restored the previous
        # epoch's (possibly already degraded) weights.
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, CHECKPOINT_PATH)

        # Early stopping: stop once the moving average of the accuracy drops.
        # The two windows differ by one element, so this amounts to comparing the
        # accuracy ACCURACY_MOVING_AVERAGE_SIZE epochs ago with the current one.
        if len(accuracies) >= ACCURACY_MOVING_AVERAGE_SIZE + 1:
            previous_window_mean = np.mean(accuracies[-ACCURACY_MOVING_AVERAGE_SIZE - 1:-1])
            current_window_mean = np.mean(accuracies[-ACCURACY_MOVING_AVERAGE_SIZE:])
            is_performance_degraded = previous_window_mean > current_window_mean
            if is_performance_degraded:
                # Reload the best checkpoint onto the device in use
                checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                break

# Make sure all pending TensorBoard events reach disk.
writer.flush()


# Let's plot the training loss versus the number of iteration.
# plt.plot(losses);
# plt.title("training loss");

  0%|          | 1.5/300 [00:00<02:16,  2.19it/s]


Test Epoch: 1	Accuracy: 4/24 (17%)



  1%|          | 2.5/300 [00:00<01:33,  3.19it/s]


Test Epoch: 2	Accuracy: 4/24 (17%)


Test Epoch: 3	Accuracy: 4/24 (17%)



  2%|▏         | 4.5/300 [00:01<01:09,  4.26it/s]


Test Epoch: 4	Accuracy: 2/24 (8%)



  2%|▏         | 5.5/300 [00:01<01:04,  4.58it/s]


Test Epoch: 5	Accuracy: 2/24 (8%)


Test Epoch: 6	Accuracy: 2/24 (8%)



  2%|▏         | 6.5/300 [00:01<01:01,  4.81it/s]


Test Epoch: 7	Accuracy: 2/24 (8%)



  3%|▎         | 8.5/300 [00:02<01:01,  4.72it/s]


Test Epoch: 8	Accuracy: 2/24 (8%)



  3%|▎         | 9.5/300 [00:02<00:58,  4.97it/s]


Test Epoch: 9	Accuracy: 2/24 (8%)


Test Epoch: 10	Accuracy: 2/24 (8%)



  4%|▍         | 11.5/300 [00:02<00:54,  5.31it/s]


Test Epoch: 11	Accuracy: 2/24 (8%)



  4%|▍         | 12.5/300 [00:02<00:55,  5.19it/s]


Test Epoch: 12	Accuracy: 2/24 (8%)


Test Epoch: 13	Accuracy: 2/24 (8%)



  5%|▍         | 14.5/300 [00:03<00:53,  5.36it/s]


Test Epoch: 14	Accuracy: 2/24 (8%)



  5%|▌         | 15.5/300 [00:03<00:52,  5.40it/s]


Test Epoch: 15	Accuracy: 2/24 (8%)



  6%|▌         | 16.5/300 [00:03<00:59,  4.75it/s]


Test Epoch: 16	Accuracy: 4/24 (17%)


Test Epoch: 17	Accuracy: 4/24 (17%)



  6%|▌         | 18.5/300 [00:04<00:55,  5.03it/s]


Test Epoch: 18	Accuracy: 4/24 (17%)



  6%|▋         | 19.5/300 [00:04<00:54,  5.13it/s]


Test Epoch: 19	Accuracy: 4/24 (17%)


Test Epoch: 20	Accuracy: 4/24 (17%)



  7%|▋         | 21.5/300 [00:04<00:52,  5.34it/s]


Test Epoch: 21	Accuracy: 4/24 (17%)



  8%|▊         | 22.5/300 [00:04<00:50,  5.46it/s]


Test Epoch: 22	Accuracy: 4/24 (17%)


Test Epoch: 23	Accuracy: 4/24 (17%)



  8%|▊         | 24.5/300 [00:05<00:49,  5.55it/s]


Test Epoch: 24	Accuracy: 4/24 (17%)



  8%|▊         | 25.5/300 [00:05<00:49,  5.57it/s]


Test Epoch: 25	Accuracy: 4/24 (17%)


Test Epoch: 26	Accuracy: 4/24 (17%)



  9%|▉         | 26.5/300 [00:05<00:49,  5.57it/s]



  9%|▉         | 27.5/300 [00:05<00:58,  4.63it/s]


Test Epoch: 27	Accuracy: 4/24 (17%)


Test Epoch: 28	Accuracy: 4/24 (17%)



 10%|▉         | 29.5/300 [00:06<00:53,  5.08it/s]


Test Epoch: 29	Accuracy: 5/24 (21%)



 10%|█         | 30.5/300 [00:06<00:51,  5.26it/s]


Test Epoch: 30	Accuracy: 5/24 (21%)


Test Epoch: 31	Accuracy: 7/24 (29%)



 11%|█         | 32.5/300 [00:06<00:50,  5.30it/s]


Test Epoch: 32	Accuracy: 7/24 (29%)



 11%|█         | 33.5/300 [00:06<00:50,  5.28it/s]


Test Epoch: 33	Accuracy: 9/24 (38%)


Test Epoch: 34	Accuracy: 9/24 (38%)



 12%|█▏        | 35.5/300 [00:07<00:48,  5.40it/s]


Test Epoch: 35	Accuracy: 9/24 (38%)



 12%|█▏        | 36.5/300 [00:07<00:49,  5.32it/s]


Test Epoch: 36	Accuracy: 9/24 (38%)


Test Epoch: 37	Accuracy: 10/24 (42%)



 13%|█▎        | 38.5/300 [00:07<00:49,  5.31it/s]


Test Epoch: 38	Accuracy: 10/24 (42%)



 13%|█▎        | 39.5/300 [00:08<00:48,  5.33it/s]


Test Epoch: 39	Accuracy: 10/24 (42%)


Test Epoch: 40	Accuracy: 10/24 (42%)



 14%|█▎        | 40.5/300 [00:08<00:48,  5.38it/s]



 14%|█▍        | 41.5/300 [00:08<01:01,  4.22it/s]


Test Epoch: 41	Accuracy: 10/24 (42%)



 14%|█▍        | 42.5/300 [00:08<01:00,  4.28it/s]


Test Epoch: 42	Accuracy: 10/24 (42%)



 14%|█▍        | 43.5/300 [00:09<00:56,  4.54it/s]


Test Epoch: 43	Accuracy: 14/24 (58%)


Test Epoch: 44	Accuracy: 15/24 (62%)



 15%|█▌        | 45.5/300 [00:09<00:51,  4.90it/s]


Test Epoch: 45	Accuracy: 15/24 (62%)



 16%|█▌        | 46.5/300 [00:09<00:50,  4.99it/s]


Test Epoch: 46	Accuracy: 17/24 (71%)


Test Epoch: 47	Accuracy: 19/24 (79%)



 16%|█▌        | 48.5/300 [00:09<00:50,  4.93it/s]


Test Epoch: 48	Accuracy: 19/24 (79%)



 16%|█▋        | 49.5/300 [00:10<00:52,  4.80it/s]


Test Epoch: 49	Accuracy: 19/24 (79%)



 17%|█▋        | 50.5/300 [00:10<00:50,  4.90it/s]


Test Epoch: 50	Accuracy: 19/24 (79%)



 17%|█▋        | 51.5/300 [00:10<00:52,  4.74it/s]


Test Epoch: 51	Accuracy: 18/24 (75%)



 18%|█▊        | 52.5/300 [00:10<00:51,  4.83it/s]


Test Epoch: 52	Accuracy: 18/24 (75%)


Test Epoch: 53	Accuracy: 18/24 (75%)



 18%|█▊        | 54.5/300 [00:11<00:48,  5.09it/s]


Test Epoch: 54	Accuracy: 18/24 (75%)



 18%|█▊        | 55.5/300 [00:11<00:48,  5.08it/s]


Test Epoch: 55	Accuracy: 18/24 (75%)



 19%|█▉        | 56.5/300 [00:11<00:50,  4.84it/s]


Test Epoch: 56	Accuracy: 17/24 (71%)



 19%|█▉        | 57.5/300 [00:12<01:02,  3.86it/s]


Test Epoch: 57	Accuracy: 17/24 (71%)


Test Epoch: 58	Accuracy: 17/24 (71%)



 20%|█▉        | 59.5/300 [00:12<00:53,  4.46it/s]


Test Epoch: 59	Accuracy: 16/24 (67%)



 20%|██        | 60.5/300 [00:12<00:51,  4.68it/s]


Test Epoch: 60	Accuracy: 15/24 (62%)


Test Epoch: 61	Accuracy: 15/24 (62%)



 21%|██        | 62.5/300 [00:12<00:47,  5.04it/s]


Test Epoch: 62	Accuracy: 15/24 (62%)



 21%|██        | 63.5/300 [00:13<00:45,  5.19it/s]


Test Epoch: 63	Accuracy: 15/24 (62%)


Test Epoch: 64	Accuracy: 15/24 (62%)



 22%|██▏       | 65.5/300 [00:13<00:44,  5.26it/s]


Test Epoch: 65	Accuracy: 15/24 (62%)



 22%|██▏       | 66.5/300 [00:13<00:44,  5.26it/s]


Test Epoch: 66	Accuracy: 15/24 (62%)


Test Epoch: 67	Accuracy: 15/24 (62%)



 22%|██▎       | 67.5/300 [00:13<00:43,  5.29it/s]


Test Epoch: 68	Accuracy: 15/24 (62%)



 23%|██▎       | 69.5/300 [00:14<00:44,  5.15it/s]


Test Epoch: 69	Accuracy: 15/24 (62%)



 24%|██▎       | 70.5/300 [00:14<00:43,  5.27it/s]


Test Epoch: 70	Accuracy: 15/24 (62%)


Test Epoch: 71	Accuracy: 15/24 (62%)



 24%|██▍       | 71.5/300 [00:14<00:44,  5.18it/s]


Test Epoch: 72	Accuracy: 15/24 (62%)



 24%|██▍       | 73.5/300 [00:15<00:44,  5.11it/s]


Test Epoch: 73	Accuracy: 15/24 (62%)



 25%|██▍       | 74.5/300 [00:15<00:43,  5.16it/s]


Test Epoch: 74	Accuracy: 15/24 (62%)


Test Epoch: 75	Accuracy: 15/24 (62%)



 25%|██▌       | 76.0/300 [00:15<00:45,  4.87it/s]


Test Epoch: 76	Accuracy: 15/24 (62%)




