In [1]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import random
from glob import glob
from tqdm import tqdm
from scipy.io import loadmat

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from sklearn.metrics import classification_report

In [2]:
data_dir = "/media/mountHDD3/data_storage/biomedical_data/ecg_data/SPH"
print(os.listdir(data_dir))

['metadata.csv', 'data_df.csv', 'data_df_no1.csv', 'records']


In [3]:
main_df = pd.read_csv(data_dir + "/data_df.csv")
main_df.shape

(20792, 4)

In [4]:
single_fns = main_df["File name"].values.tolist()
single_mat_paths = [data_dir + f"/records/{x}.h5" for x in single_fns]

In [5]:
ratio = [0.8, 0.1]

train_index = int(len(single_mat_paths)*ratio[0])
valid_index = int(len(single_mat_paths)*(ratio[0]+ratio[1]))

train_mat_paths = single_mat_paths[:train_index]
valid_mat_paths = single_mat_paths[valid_index:]

train_label = main_df
valid_label = main_df

In [6]:
conv_block1 = nn.Sequential(
			nn.Conv1d(12, 64, 16, stride=1),
			nn.LeakyReLU(0.3),
			nn.BatchNorm1d(64),
		)

class Redu(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, padding):
        super().__init__()
        
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.kernel_size = kernel_size
        self.padding = padding
        
        self.conv1 = nn.Sequential(
            nn.MaxPool1d(4),
            nn.Conv1d(self.in_channel, self.out_channel, 1, 1)
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(self.in_channel, self.out_channel, self.kernel_size, 1),
            nn.LeakyReLU(),
            nn.BatchNorm1d(self.out_channel),
            nn.Dropout(0.2),
            nn.Conv1d(self.out_channel, self.out_channel, self.kernel_size, 4, self.padding)
        )
        self.conv3 = nn.Sequential(
            nn.LeakyReLU(),
            nn.BatchNorm1d(self.out_channel),
            nn.Dropout(0.2),
        )
    
    def forward(self, x1, x2):
        x1 = self.conv1(x1)
        x2 = self.conv2(x2)
        x_res = x1 + x2
        x3 = self.conv3(x_res)
        
        return x_res, x3

In [7]:
test = torch.randn(1, 12, 1400)
test = conv_block1(test)
model = Redu(64, 128, 3, 0)
a1, a2 = model(test, test)
print(type(a2))

<class 'torch.Tensor'>


In [8]:
class GlobalAvgPooling(nn.Module):
    def __init__(self):
        super(GlobalAvgPooling, self).__init__()

    def forward(self, x):
        return x.mean(dim=(2))

In [9]:
class LSAT(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.lstm = nn.LSTM(5, 5)
        
        self.conv1 = nn.Sequential(
            nn.Conv1d(960, 256, 1, 1),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),
        )
    
        self.conv2 = nn.Sequential(
            nn.Conv1d(256, 128, 1, 1),
            nn.BatchNorm1d(128),
            nn.Dropout(0.5),                 
        )
        
        self.attention = nn.MultiheadAttention(5, 1, batch_first=True)
        
        self.conv3 = nn.Sequential(
            nn.LeakyReLU(0.3),
            nn.BatchNorm1d(128),
            nn.Dropout(0.5),
        )
        
        self.avgpo = GlobalAvgPooling()
            
        self.conv4 = nn.Sequential(
            nn.Linear(128, 64),
            nn.LeakyReLU(0.3),
            nn.Linear(64, 32),
            nn.LeakyReLU(0.3),
            nn.Linear(32, 34),
            nn.Sigmoid(),
        )
    
    def forward(self, x):
        x1, _ = self.lstm(x)
        x1 = self.conv1(x1)
        
        x1, _ = self.lstm(x1)
        x2 = self.conv2(x1)
        
        x3, _ = self.attention(x2, x2, x2)
        x3 = self.conv3(x3)
        
        x3 = self.avgpo(x2)
        x4 = self.conv4(x3)
        
        return x4   
    

In [10]:
# test1 = torch.randn(1, 960, 1400)
# model1 = LSAT()
# a = model1(test1)
# print(a.shape)

In [11]:
import h5py

class HeartData(Dataset):
    def __init__(self, data_paths):
        self.data_paths = data_paths
        random.shuffle(self.data_paths)

    def __getitem__(self, idx):
        data_path = self.data_paths[idx]
        a = h5py.File(data_path, 'r')
        data_h5 = a['ecg']
        data = np.array(data_h5)
        clip_data = data[:, -1400:]

        filename = data_path.split("/")[-1].split(".")[0]
        label = main_df[main_df["File name"] == filename]["New Label"].values.item()

        torch_data = torch.from_numpy(clip_data)

        return torch_data.float(), label

    def __len__(self):
        return len(self.data_paths)

In [12]:
train_ds = HeartData(train_mat_paths)
valid_ds = HeartData(valid_mat_paths)

In [13]:
train_dl = DataLoader(train_ds, batch_size = 64, shuffle = True, pin_memory = True, num_workers = 48)
valid_dl = DataLoader(valid_ds, batch_size = 64, shuffle = True, pin_memory = True, num_workers = 48)

In [14]:
class HeartDNN(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv_block1 = nn.Sequential(
			nn.Conv1d(12, 64, 16, stride=1),
			nn.LeakyReLU(0.3),
			nn.BatchNorm1d(64),
		)
        self.model_redu11 = Redu(64, 128, 3, 0)
        self.model_redu12 = Redu(128, 192, 3, 0)
        self.model_redu13 = Redu(192, 256, 3, 0)
        self.model_redu14 = Redu(256, 320, 3, 0)
        
        self.model_redu21 = Redu(64, 128, 5, 2)
        self.model_redu22 = Redu(128, 192, 5, 2)
        self.model_redu23 = Redu(192, 256, 5, 2)
        self.model_redu24 = Redu(256, 320, 5, 2)
        
        self.model_redu31 = Redu(64, 128, 7, 4)
        self.model_redu32 = Redu(128, 192, 7, 4)
        self.model_redu33 = Redu(192, 256, 7, 4)
        self.model_redu34 = Redu(256, 320, 7, 4)
        
        self.model_LSAT = LSAT()
    
    
    def forward(self, x):  
    
        data_conv = self.conv_block1(x)

        data_x11, data_x12 = self.model_redu11(data_conv, data_conv)
        data_x13, data_x14 = self.model_redu12(data_x11, data_x12)
        data_x15, data_x16 = self.model_redu13(data_x13, data_x14)
        data_x17, data_x18 = self.model_redu14(data_x15, data_x16)

        
        data_x21, data_x22 = self.model_redu21(data_conv, data_conv)
        data_x23, data_x24 = self.model_redu22(data_x21, data_x22)
        data_x25, data_x26 = self.model_redu23(data_x23, data_x24)
        data_x27, data_x28 = self.model_redu24(data_x25, data_x26)

        
        data_x31, data_x32 = self.model_redu31(data_conv, data_conv)
        data_x33, data_x34 = self.model_redu32(data_x31, data_x32)
        data_x35, data_x36 = self.model_redu33(data_x33, data_x34)
        data_x37, data_x38 = self.model_redu34(data_x35, data_x36)
        
        data = torch.cat((data_x18, data_x28, data_x38), dim = 1)
   
        out = self.model_LSAT(data)

        return out

In [15]:
# test = torch.randn(1, 12, 1400)
model = HeartDNN()

In [16]:
epoch = 150
lr = 0.0005
best_acc = 0
best_ep = 0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu", index = 0)
model.to(device)
optimizer = Adam(model.parameters(), lr=lr, betas = (0.9, 0.999), weight_decay = 0.001)
# scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=epoch)
loss_fn = nn.CrossEntropyLoss()

In [None]:
for e in range(epoch):
    model.train()
    print(f"Epoch: {e}")
    y_true_list = [] 
    pred_list = []
    batch_cnt = 0
    total_loss = 0
    correct = 0
    correct = 0
    for batch, (train_sig, train_label) in tqdm(enumerate(train_dl)):
        batch_cnt = batch
        train_sig = train_sig.to(device)
        train_label = train_label.to(device)
        
        pred = model(train_sig)
        loss = loss_fn(pred, train_label)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
#         scheduler.step()
        
        total_loss += loss.item()
        correct += (pred.argmax(1) == train_label).type(torch.float).sum().item()
    
    total_loss /= batch_cnt
    correct /= len(train_dl.dataset)
    
    print(f"train loss: {total_loss} - train acc: {100*correct}")
    
    batch_cnt = 0
    val_total_loss = 0
    val_correct = 0
    model.eval()
    with torch.no_grad():
        for batch, (valid_sig, valid_label) in tqdm(enumerate(valid_dl)):
            batch_cnt = batch
            valid_sig = valid_sig.to(device)
            valid_label = valid_label.to(device)
            
            pred = model(valid_sig)

            pred_pos = pred.argmax(1)
            y_true_list.append(valid_label)
            pred_list.append(pred_pos)
            
            loss = loss_fn(pred, valid_label)
            
            val_total_loss += loss.item()
            val_correct += (pred.argmax(1) == valid_label).type(torch.float).sum().item()
            
        val_total_loss /= batch_cnt
        val_correct /= len(valid_dl.dataset)
        if val_correct > best_acc:
            best_acc = val_correct
            best_ep = e
        
        print(f"valid loss: {val_total_loss} - valid acc: {100*val_correct}")
        
y_true = torch.cat(y_true_list).cpu().numpy()
pred = torch.cat(pred_list).cpu().numpy()

reports = classification_report(y_true, pred, output_dict=True) 

print(reports)
print(f"Best acuracy: {best_acc} at epoch {best_ep}")

Epoch: 0


260it [00:18, 14.00it/s]

train loss: 3.05023235980148 - train acc: 58.04124331148921



33it [00:01, 26.70it/s]

valid loss: 2.978894926607609 - valid acc: 66.0576923076923
Epoch: 1



260it [00:27,  9.61it/s]

train loss: 2.8525908398352073 - train acc: 67.0233872422293



33it [00:01, 23.24it/s]

valid loss: 3.0162433385849 - valid acc: 66.0576923076923
Epoch: 2



260it [00:29,  8.93it/s]

train loss: 2.8060603436355884 - train acc: 67.20375157818795



33it [00:01, 17.87it/s]

valid loss: 2.9976572319865227 - valid acc: 66.00961538461539
Epoch: 3



260it [00:30,  8.56it/s]


train loss: 2.7891141134799677 - train acc: 68.3460590392593


33it [00:02, 15.55it/s]

valid loss: 2.9935514256358147 - valid acc: 54.42307692307692
Epoch: 4



260it [00:31,  8.14it/s]


train loss: 2.7826221983405155 - train acc: 68.88715204713522


33it [00:01, 17.20it/s]

valid loss: 2.932927630841732 - valid acc: 65.86538461538461
Epoch: 5



260it [00:33,  7.72it/s]

train loss: 2.7817676941860596 - train acc: 69.14567426200927



33it [00:02, 15.60it/s]

valid loss: 2.865449905395508 - valid acc: 67.98076923076923
Epoch: 6



260it [00:32,  7.90it/s]

train loss: 2.781544505859434 - train acc: 69.56652437924608



33it [00:01, 17.41it/s]

valid loss: 2.93757014721632 - valid acc: 66.0576923076923
Epoch: 7



260it [00:33,  7.80it/s]

train loss: 2.7807835888218237 - train acc: 69.42824505501112



33it [00:02, 15.87it/s]

valid loss: 3.0246489197015762 - valid acc: 5.480769230769231
Epoch: 8



260it [00:33,  7.66it/s]

train loss: 2.770455012450347 - train acc: 69.03745566043408



33it [00:02, 14.85it/s]

valid loss: 2.846823461353779 - valid acc: 66.29807692307692
Epoch: 9



260it [00:34,  7.47it/s]

train loss: 2.758750863977381 - train acc: 69.12162568388143



33it [00:01, 29.42it/s]

valid loss: 2.923106335103512 - valid acc: 66.77884615384615
Epoch: 10



260it [00:36,  7.17it/s]

train loss: 2.7573479814418955 - train acc: 69.65069440269343



33it [00:01, 25.80it/s]

valid loss: 2.9753245636820793 - valid acc: 62.64423076923077
Epoch: 11



260it [00:35,  7.22it/s]

train loss: 2.7577343631435083 - train acc: 69.13966211747731



33it [00:01, 19.16it/s]

valid loss: 2.856489859521389 - valid acc: 69.42307692307692
Epoch: 12



260it [00:33,  7.79it/s]

train loss: 2.754647210758165 - train acc: 68.98935850417844



33it [00:02, 16.44it/s]

valid loss: 2.8678596317768097 - valid acc: 66.20192307692308
Epoch: 13



260it [00:31,  8.32it/s]

train loss: 2.7521525220981435 - train acc: 69.00739493777431



33it [00:02, 15.16it/s]

valid loss: 2.8740261271595955 - valid acc: 66.00961538461539
Epoch: 14



260it [00:28,  9.07it/s]

train loss: 2.749606530178468 - train acc: 69.03745566043408



33it [00:01, 18.98it/s]

valid loss: 2.8739308565855026 - valid acc: 66.15384615384615
Epoch: 15



260it [00:27,  9.46it/s]

train loss: 2.7488716707266434 - train acc: 69.59658510190584



33it [00:01, 17.11it/s]

valid loss: 2.9858867302536964 - valid acc: 6.298076923076923
Epoch: 16



260it [00:23, 10.91it/s]

train loss: 2.7470651602653002 - train acc: 69.25990500811639



33it [00:01, 16.55it/s]

valid loss: 2.8715673834085464 - valid acc: 66.4423076923077
Epoch: 17



260it [00:24, 10.73it/s]

train loss: 2.744913451920145 - train acc: 68.80298202368785



33it [00:02, 13.60it/s]

valid loss: 2.886696547269821 - valid acc: 66.34615384615384
Epoch: 18



260it [00:28,  9.05it/s]

train loss: 2.743116674275932 - train acc: 69.03745566043408



33it [00:02, 13.93it/s]

valid loss: 2.9413651898503304 - valid acc: 66.20192307692308
Epoch: 19



260it [00:29,  8.77it/s]

train loss: 2.742647270438294 - train acc: 69.60860939096976



33it [00:01, 18.66it/s]

valid loss: 2.851193517446518 - valid acc: 66.53846153846153
Epoch: 20



260it [00:29,  8.80it/s]

train loss: 2.7412986976299507 - train acc: 69.54247580111826



33it [00:01, 17.88it/s]

valid loss: 2.826848916709423 - valid acc: 68.3173076923077
Epoch: 21



260it [00:28,  9.06it/s]

train loss: 2.73989530419751 - train acc: 69.43425719954308



33it [00:01, 21.68it/s]

valid loss: 2.906074158847332 - valid acc: 67.5
Epoch: 22



260it [00:30,  8.52it/s]

train loss: 2.74096696257131 - train acc: 69.62063368003368



33it [00:01, 20.60it/s]

valid loss: 2.8242237344384193 - valid acc: 69.8076923076923
Epoch: 23



260it [00:31,  8.16it/s]

train loss: 2.7382801398347243 - train acc: 69.42223291047917



33it [00:01, 21.54it/s]

valid loss: 2.920686259865761 - valid acc: 67.54807692307693
Epoch: 24



260it [00:31,  8.16it/s]

train loss: 2.7381232429195093 - train acc: 70.05350808633439



33it [00:01, 22.50it/s]

valid loss: 2.8495705351233482 - valid acc: 65.8173076923077
Epoch: 25



260it [00:34,  7.59it/s]

train loss: 2.7378066230464624 - train acc: 69.51241507845849



33it [00:01, 27.03it/s]

valid loss: 2.8474792391061783 - valid acc: 65.8173076923077
Epoch: 26



200it [00:25,  7.47it/s]