In [1]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import random
from glob import glob
from tqdm import tqdm
from scipy.io import loadmat

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from sklearn.metrics import classification_report

In [2]:
data_dir = "/media/mountHDD3/data_storage/biomedical_data/ecg_data/SPH"
print(os.listdir(data_dir))

['metadata.csv', 'data_df.csv', 'data_df_no1.csv', 'records']


In [3]:
main_df = pd.read_csv(data_dir + "/data_df.csv")
main_df.shape

(20792, 4)

In [4]:
single_fns = main_df["File name"].values.tolist()
single_mat_paths = [data_dir + f"/records/{x}.h5" for x in single_fns]

In [5]:
ratio = [0.9, 0.1]

train_index = int(len(single_mat_paths)*ratio[0])

train_mat_paths = single_mat_paths[:train_index]
valid_mat_paths = single_mat_paths[train_index:]

train_label = main_df
valid_label = main_df

In [6]:
conv_block1 = nn.Sequential(
			nn.Conv1d(12, 64, 16, stride=1),
			nn.LeakyReLU(0.3),
			nn.BatchNorm1d(64),
		)

class Redu(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, padding):
        super().__init__()
        
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.kernel_size = kernel_size
        self.padding = padding
        
        self.conv1 = nn.Sequential(
            nn.MaxPool1d(4),
            nn.Conv1d(self.in_channel, self.out_channel, 1, 1)
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(self.in_channel, self.out_channel, self.kernel_size, 1),
            nn.LeakyReLU(),
            nn.BatchNorm1d(self.out_channel),
            nn.Dropout(0.2),
            nn.Conv1d(self.out_channel, self.out_channel, self.kernel_size, 4, self.padding)
        )
        self.conv3 = nn.Sequential(
            nn.LeakyReLU(),
            nn.BatchNorm1d(self.out_channel),
            nn.Dropout(0.2),
        )
    
    def forward(self, x1, x2):
        x1 = self.conv1(x1)
        x2 = self.conv2(x2)
        x_res = x1 + x2
        x3 = self.conv3(x_res)
        
        return x_res, x3

In [7]:
test = torch.randn(1, 12, 1400)
test = conv_block1(test)
model = Redu(64, 128, 3, 0)
a1, a2 = model(test, test)
print(type(a2))

<class 'torch.Tensor'>


In [8]:
class GlobalAvgPooling(nn.Module):
    def __init__(self):
        super(GlobalAvgPooling, self).__init__()

    def forward(self, x):
        return x.mean(dim=(2))

In [9]:
class LSAT(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.lstm = nn.LSTM(5, 5)
        
        self.conv1 = nn.Sequential(
            nn.Conv1d(960, 256, 1, 1),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),
        )
    
        self.conv2 = nn.Sequential(
            nn.Conv1d(256, 128, 1, 1),
            nn.BatchNorm1d(128),
            nn.Dropout(0.5),                 
        )
        
        self.attention = nn.MultiheadAttention(5, 1, batch_first=True)
        
        self.conv3 = nn.Sequential(
            nn.LeakyReLU(0.3),
            nn.BatchNorm1d(128),
            nn.Dropout(0.5),
        )
        
        self.avgpo = GlobalAvgPooling()
            
        self.conv4 = nn.Sequential(
            nn.Linear(128, 64),
            nn.LeakyReLU(0.3),
            nn.Linear(64, 32),
            nn.LeakyReLU(0.3),
            nn.Linear(32, 34),
            nn.Sigmoid(),
        )
    
    def forward(self, x):
        x1, _ = self.lstm(x)
        x1 = self.conv1(x1)
        
        x1, _ = self.lstm(x1)
        x2 = self.conv2(x1)
        
        x3, _ = self.attention(x2, x2, x2)
        x3 = self.conv3(x3)
        
        x3 = self.avgpo(x2)
        x4 = self.conv4(x3)
        
        return x4   
    

In [10]:
# test1 = torch.randn(1, 960, 1400)
# model1 = LSAT()
# a = model1(test1)
# print(a.shape)

In [11]:
import h5py

class HeartData(Dataset):
    def __init__(self, data_paths):
        self.data_paths = data_paths
        random.shuffle(self.data_paths)

    def __getitem__(self, idx):
        data_path = self.data_paths[idx]
        a = h5py.File(data_path, 'r')
        data_h5 = a['ecg']
        data = np.array(data_h5)
        clip_data = data[:, -1400:]

        filename = data_path.split("/")[-1].split(".")[0]
        label = main_df[main_df["File name"] == filename]["New Label"].values.item()

        torch_data = torch.from_numpy(clip_data)

        return torch_data.float(), label

    def __len__(self):
        return len(self.data_paths)

In [12]:
train_ds = HeartData(train_mat_paths)
valid_ds = HeartData(valid_mat_paths)

In [13]:
train_dl = DataLoader(train_ds, batch_size = 64, shuffle = True, pin_memory = True, num_workers = 48)
valid_dl = DataLoader(valid_ds, batch_size = 64, shuffle = True, pin_memory = True, num_workers = 48)
# test_dl = DataLoader(test_ds, batch_size = 64, shuffle = True, pin_memory = True, num_workers = 48)

In [14]:
class HeartDNN(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv_block1 = nn.Sequential(
			nn.Conv1d(12, 64, 16, stride=1),
			nn.LeakyReLU(0.3),
			nn.BatchNorm1d(64),
		)
        self.model_redu11 = Redu(64, 128, 3, 0)
        self.model_redu12 = Redu(128, 192, 3, 0)
        self.model_redu13 = Redu(192, 256, 3, 0)
        self.model_redu14 = Redu(256, 320, 3, 0)
        
        self.model_redu21 = Redu(64, 128, 5, 2)
        self.model_redu22 = Redu(128, 192, 5, 2)
        self.model_redu23 = Redu(192, 256, 5, 2)
        self.model_redu24 = Redu(256, 320, 5, 2)
        
        self.model_redu31 = Redu(64, 128, 7, 4)
        self.model_redu32 = Redu(128, 192, 7, 4)
        self.model_redu33 = Redu(192, 256, 7, 4)
        self.model_redu34 = Redu(256, 320, 7, 4)
        
        self.model_LSAT = LSAT()
    
    
    def forward(self, x):  
    
        data_conv = self.conv_block1(x)

        data_x11, data_x12 = self.model_redu11(data_conv, data_conv)
        data_x13, data_x14 = self.model_redu12(data_x11, data_x12)
        data_x15, data_x16 = self.model_redu13(data_x13, data_x14)
        data_x17, data_x18 = self.model_redu14(data_x15, data_x16)

        
        data_x21, data_x22 = self.model_redu21(data_conv, data_conv)
        data_x23, data_x24 = self.model_redu22(data_x21, data_x22)
        data_x25, data_x26 = self.model_redu23(data_x23, data_x24)
        data_x27, data_x28 = self.model_redu24(data_x25, data_x26)

        
        data_x31, data_x32 = self.model_redu31(data_conv, data_conv)
        data_x33, data_x34 = self.model_redu32(data_x31, data_x32)
        data_x35, data_x36 = self.model_redu33(data_x33, data_x34)
        data_x37, data_x38 = self.model_redu34(data_x35, data_x36)
        
        data = torch.cat((data_x18, data_x28, data_x38), dim = 1)
   
        out = self.model_LSAT(data)

        return out

In [15]:
# test = torch.randn(1, 12, 1400)
model = HeartDNN()

In [16]:
epoch = 150
lr = 0.0005
best_acc = 0
best_ep = 0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu", index = 0)
model.to(device)
optimizer = Adam(model.parameters(), lr=lr, betas = (0.9, 0.999), weight_decay = 0.001)
# scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=epoch)
loss_fn = nn.CrossEntropyLoss()

In [None]:
for e in range(epoch):
    model.train()
    print(f"Epoch: {e}")
    y_true_list = [] 
    pred_list = []
    batch_cnt = 0
    total_loss = 0
    correct = 0
    correct = 0
    for batch, (train_sig, train_label) in tqdm(enumerate(train_dl)):
        batch_cnt = batch
        train_sig = train_sig.to(device)
        train_label = train_label.to(device)
        
        pred = model(train_sig)
        loss = loss_fn(pred, train_label)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
#         scheduler.step()
        
        total_loss += loss.item()
        correct += (pred.argmax(1) == train_label).type(torch.float).sum().item()
    
    total_loss /= batch_cnt
    correct /= len(train_dl.dataset)
    
    print(f"train loss: {total_loss} - train acc: {100*correct}")
    
    batch_cnt = 0
    val_total_loss = 0
    val_correct = 0
    model.eval()
    with torch.no_grad():
        for batch, (valid_sig, valid_label) in tqdm(enumerate(valid_dl)):
            batch_cnt = batch
            valid_sig = valid_sig.to(device)
            valid_label = valid_label.to(device)
            
            pred = model(valid_sig)

            pred_pos = pred.argmax(1)
            y_true_list.append(valid_label)
            pred_list.append(pred_pos)
            
            loss = loss_fn(pred, valid_label)
            
            val_total_loss += loss.item()
            val_correct += (pred.argmax(1) == valid_label).type(torch.float).sum().item()
            
        val_total_loss /= batch_cnt
        val_correct /= len(valid_dl.dataset)
        if val_correct > best_acc:
            best_acc = val_correct
            best_ep = e
        
        print(f"valid loss: {val_total_loss} - valid acc: {100*val_correct}")
        
y_true = torch.cat(y_true_list).cpu().numpy()
pred = torch.cat(pred_list).cpu().numpy()

reports = classification_report(y_true, pred, output_dict=True) 

print(reports)
print(f"Best acuracy: {best_acc} at epoch {best_ep}")

Epoch: 0


293it [00:19, 14.89it/s]

train loss: 3.0055240360024857 - train acc: 58.5239418554938



33it [00:01, 22.63it/s]

valid loss: 2.9332990273833275 - valid acc: 66.0576923076923
Epoch: 1



293it [00:32,  8.97it/s]

train loss: 2.807486946452154 - train acc: 66.81274048738777



33it [00:01, 17.25it/s]

valid loss: 3.0160359516739845 - valid acc: 66.00961538461539
Epoch: 2



293it [00:31,  9.31it/s]

train loss: 2.794082199057488 - train acc: 66.01111586147927



33it [00:02, 15.40it/s]

valid loss: 2.9213756546378136 - valid acc: 56.77884615384615
Epoch: 3



293it [00:36,  8.02it/s]

train loss: 2.7903167344119453 - train acc: 65.0812312954254



33it [00:01, 17.02it/s]

valid loss: 2.889999233186245 - valid acc: 63.17307692307692
Epoch: 4



293it [00:37,  7.76it/s]

train loss: 2.7854545434860336 - train acc: 64.92625053441643



33it [00:01, 17.09it/s]

valid loss: 2.927490457892418 - valid acc: 49.85576923076923
Epoch: 5



293it [00:36,  7.93it/s]

train loss: 2.7837277348727394 - train acc: 64.85677640017101



33it [00:02, 15.69it/s]

valid loss: 2.8995534479618073 - valid acc: 58.31730769230769
Epoch: 6



293it [00:36,  7.95it/s]

train loss: 2.7756853724179202 - train acc: 64.43458743052587



33it [00:02, 15.67it/s]

valid loss: 2.962083265185356 - valid acc: 36.05769230769231
Epoch: 7



293it [00:37,  7.78it/s]

train loss: 2.77422627602538 - train acc: 64.85143223599829



33it [00:01, 19.31it/s]

valid loss: 3.0173163190484047 - valid acc: 6.346153846153846
Epoch: 8



293it [00:37,  7.73it/s]

train loss: 2.773124436809592 - train acc: 65.10795211628901



33it [00:01, 28.08it/s]

valid loss: 2.927283674478531 - valid acc: 46.875
Epoch: 9



293it [00:39,  7.45it/s]

train loss: 2.7689259836118514 - train acc: 66.94100042753314



33it [00:01, 29.88it/s]

valid loss: 2.9425619691610336 - valid acc: 55.24038461538462
Epoch: 10



293it [00:38,  7.59it/s]

train loss: 2.7630813603531825 - train acc: 67.78537836682342



33it [00:01, 22.55it/s]

valid loss: 3.023670345544815 - valid acc: 6.201923076923077
Epoch: 11



293it [00:37,  7.84it/s]

train loss: 2.7600263104046863 - train acc: 68.06327490380505



33it [00:01, 19.82it/s]

valid loss: 2.953602723777294 - valid acc: 66.0576923076923
Epoch: 12



293it [00:35,  8.18it/s]

train loss: 2.757257102286979 - train acc: 68.40530141085934



33it [00:01, 16.54it/s]

valid loss: 2.934908412396908 - valid acc: 66.0576923076923
Epoch: 13



293it [00:33,  8.73it/s]

train loss: 2.7570362981051613 - train acc: 68.15412569474134



33it [00:02, 14.99it/s]

valid loss: 2.992158427834511 - valid acc: 66.0576923076923
Epoch: 14



293it [00:33,  8.62it/s]

train loss: 2.75371585316854 - train acc: 68.35720393330483



33it [00:01, 21.58it/s]

valid loss: 3.035379745066166 - valid acc: 61.58653846153847
Epoch: 15



293it [00:34,  8.61it/s]

train loss: 2.7513400724489396 - train acc: 68.41064557503206



33it [00:01, 19.68it/s]

valid loss: 3.039498694241047 - valid acc: 5.240384615384616
Epoch: 16



293it [00:31,  9.19it/s]

train loss: 2.7548914265959232 - train acc: 68.52287302265925



33it [00:02, 13.58it/s]

valid loss: 2.9915709048509598 - valid acc: 66.49038461538461
Epoch: 17



293it [00:28, 10.11it/s]

train loss: 2.7509506140669733 - train acc: 68.13809320222317



33it [00:02, 14.13it/s]

valid loss: 2.8705310076475143 - valid acc: 67.88461538461539
Epoch: 18



293it [00:28, 10.22it/s]

train loss: 2.7496938419668644 - train acc: 68.53890551517743



33it [00:02, 15.73it/s]

valid loss: 2.9794183894991875 - valid acc: 6.6826923076923075
Epoch: 19



293it [00:31,  9.38it/s]

train loss: 2.74761028322455 - train acc: 68.56562633604103



33it [00:01, 18.13it/s]

valid loss: 3.002136141061783 - valid acc: 5.528846153846153
Epoch: 20



293it [00:30,  9.51it/s]

train loss: 2.745327882570763 - train acc: 68.06327490380505



33it [00:02, 15.25it/s]

valid loss: 2.9720089435577393 - valid acc: 65.8173076923077
Epoch: 21



293it [00:29, 10.00it/s]

train loss: 2.744758071964734 - train acc: 68.44805472424113



33it [00:02, 14.80it/s]

valid loss: 2.874861180782318 - valid acc: 66.0576923076923
Epoch: 22



293it [00:28, 10.13it/s]

train loss: 2.746199228175699 - train acc: 68.29841812740487



33it [00:01, 16.91it/s]

valid loss: 2.9929718747735023 - valid acc: 6.25
Epoch: 23



293it [00:28, 10.21it/s]

train loss: 2.745768720973028 - train acc: 68.26100897819582



33it [00:02, 15.72it/s]

valid loss: 2.9275058209896088 - valid acc: 66.15384615384615
Epoch: 24



293it [00:31,  9.16it/s]

train loss: 2.7448461586481905 - train acc: 68.29841812740487



33it [00:02, 15.88it/s]

valid loss: 2.922559641301632 - valid acc: 67.40384615384616
Epoch: 25



112it [00:15,  7.95it/s]