In [1]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import random
from glob import glob
from tqdm import tqdm
from scipy.io import loadmat

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from sklearn.metrics import classification_report

In [2]:
data_dir = "/media/mountHDD3/data_storage/biomedical_data/ecg_data/SPH"
print(os.listdir(data_dir))

['metadata.csv', 'data_df.csv', 'data_df_no1.csv', 'records']


In [3]:
main_df = pd.read_csv(data_dir + "/data_df.csv")
main_df.shape

(20792, 4)

In [4]:
single_fns = main_df["File name"].values.tolist()
single_mat_paths = [data_dir + f"/records/{x}.h5" for x in single_fns]

In [5]:
ratio = [0.9, 0.1]

train_index = int(len(single_mat_paths)*ratio[0])

train_mat_paths = single_mat_paths[:train_index]
valid_mat_paths = single_mat_paths[train_index:]

train_label = main_df
valid_label = main_df

In [6]:
conv_block1 = nn.Sequential(
			nn.Conv1d(12, 64, 16, stride=1),
			nn.LeakyReLU(0.3),
			nn.BatchNorm1d(64),
		)

class Redu(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, padding):
        super().__init__()
        
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.kernel_size = kernel_size
        self.padding = padding
        
        self.conv1 = nn.Sequential(
            nn.MaxPool1d(4),
            nn.Conv1d(self.in_channel, self.out_channel, 1, 1)
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(self.in_channel, self.out_channel, self.kernel_size, 1),
            nn.LeakyReLU(),
            nn.BatchNorm1d(self.out_channel),
            nn.Dropout(0.2),
            nn.Conv1d(self.out_channel, self.out_channel, self.kernel_size, 4, self.padding)
        )
        self.conv3 = nn.Sequential(
            nn.LeakyReLU(),
            nn.BatchNorm1d(self.out_channel),
            nn.Dropout(0.2),
        )
    
    def forward(self, x1, x2):
        x1 = self.conv1(x1)
        x2 = self.conv2(x2)
        x_res = x1 + x2
        x3 = self.conv3(x_res)
        
        return x_res, x3

In [7]:
test = torch.randn(1, 12, 1400)
test = conv_block1(test)
model = Redu(64, 128, 3, 0)
a1, a2 = model(test, test)
print(type(a2))

<class 'torch.Tensor'>


In [8]:
class GlobalAvgPooling(nn.Module):
    def __init__(self):
        super(GlobalAvgPooling, self).__init__()

    def forward(self, x):
        return x.mean(dim=(2))

In [9]:
class LSAT(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.lstm = nn.LSTM(5, 5)
        
        self.conv1 = nn.Sequential(
            nn.Conv1d(960, 256, 1, 1),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),
        )
    
        self.conv2 = nn.Sequential(
            nn.Conv1d(256, 128, 1, 1),
            nn.BatchNorm1d(128),
            nn.Dropout(0.5),                 
        )
        
        self.attention = nn.MultiheadAttention(5, 1, batch_first=True)
        
        self.conv3 = nn.Sequential(
            nn.LeakyReLU(0.3),
            nn.BatchNorm1d(128),
            nn.Dropout(0.5),
        )
        
        self.avgpo = GlobalAvgPooling()
            
        self.conv4 = nn.Sequential(
            nn.Linear(128, 64),
            nn.LeakyReLU(0.3),
            nn.Linear(64, 32),
            nn.LeakyReLU(0.3),
            nn.Linear(32, 34),
            nn.Sigmoid(),
        )
    
    def forward(self, x):
        x1, _ = self.lstm(x)
        x1 = self.conv1(x1)
        
        x1, _ = self.lstm(x1)
        x2 = self.conv2(x1)
        
        x3, _ = self.attention(x2, x2, x2)
        x3 = self.conv3(x3)
        
        x3 = self.avgpo(x2)
        x4 = self.conv4(x3)
        
        return x4   
    

In [10]:
# test1 = torch.randn(1, 960, 1400)
# model1 = LSAT()
# a = model1(test1)
# print(a.shape)

In [11]:
import h5py

class HeartData(Dataset):
    def __init__(self, data_paths):
        self.data_paths = data_paths
        random.shuffle(self.data_paths)

    def __getitem__(self, idx):
        data_path = self.data_paths[idx]
        a = h5py.File(data_path, 'r')
        data_h5 = a['ecg']
        data = np.array(data_h5)
        clip_data = data[:, -1400:]

        filename = data_path.split("/")[-1].split(".")[0]
        label = main_df[main_df["File name"] == filename]["New Label"].values.item()

        torch_data = torch.from_numpy(clip_data)

        return torch_data.float(), label

    def __len__(self):
        return len(self.data_paths)

In [12]:
train_ds = HeartData(train_mat_paths)
valid_ds = HeartData(valid_mat_paths)

In [13]:
train_dl = DataLoader(train_ds, batch_size = 64, shuffle = True, pin_memory = True, num_workers = 48)
valid_dl = DataLoader(valid_ds, batch_size = 64, shuffle = True, pin_memory = True, num_workers = 48)

In [14]:
class HeartDNN(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv_block1 = nn.Sequential(
			nn.Conv1d(12, 64, 16, stride=1),
			nn.LeakyReLU(0.3),
			nn.BatchNorm1d(64),
		)
        self.model_redu11 = Redu(64, 128, 3, 0)
        self.model_redu12 = Redu(128, 192, 3, 0)
        self.model_redu13 = Redu(192, 256, 3, 0)
        self.model_redu14 = Redu(256, 320, 3, 0)
        
        self.model_redu21 = Redu(64, 128, 5, 2)
        self.model_redu22 = Redu(128, 192, 5, 2)
        self.model_redu23 = Redu(192, 256, 5, 2)
        self.model_redu24 = Redu(256, 320, 5, 2)
        
        self.model_redu31 = Redu(64, 128, 7, 4)
        self.model_redu32 = Redu(128, 192, 7, 4)
        self.model_redu33 = Redu(192, 256, 7, 4)
        self.model_redu34 = Redu(256, 320, 7, 4)
        
        self.model_LSAT = LSAT()
    
    
    def forward(self, x):  
    
        data_conv = self.conv_block1(x)

        data_x11, data_x12 = self.model_redu11(data_conv, data_conv)
        data_x13, data_x14 = self.model_redu12(data_x11, data_x12)
        data_x15, data_x16 = self.model_redu13(data_x13, data_x14)
        data_x17, data_x18 = self.model_redu14(data_x15, data_x16)

        
        data_x21, data_x22 = self.model_redu21(data_conv, data_conv)
        data_x23, data_x24 = self.model_redu22(data_x21, data_x22)
        data_x25, data_x26 = self.model_redu23(data_x23, data_x24)
        data_x27, data_x28 = self.model_redu24(data_x25, data_x26)

        
        data_x31, data_x32 = self.model_redu31(data_conv, data_conv)
        data_x33, data_x34 = self.model_redu32(data_x31, data_x32)
        data_x35, data_x36 = self.model_redu33(data_x33, data_x34)
        data_x37, data_x38 = self.model_redu34(data_x35, data_x36)
        
        data = torch.cat((data_x18, data_x28, data_x38), dim = 1)
   
        out = self.model_LSAT(data)

        return out

In [15]:
# test = torch.randn(1, 12, 1400)
model = HeartDNN()

In [16]:
epoch = 150
lr = 0.0005

device = torch.device("cuda" if torch.cuda.is_available() else "cpu", index = 1)
model.to(device)
optimizer = Adam(model.parameters(), lr=lr, betas = (0.9, 0.999), weight_decay = 0.001)
# scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=epoch)
loss_fn = nn.CrossEntropyLoss()

In [None]:
for e in range(epoch):
    model.train()
    print(f"Epoch: {e}")
    batch_cnt = 0
    total_loss = 0
    correct = 0
    for batch, (train_sig, train_label) in tqdm(enumerate(train_dl)):
        batch_cnt = batch
        train_sig = train_sig.to(device)
        train_label = train_label.to(device)
        
        pred = model(train_sig)
        loss = loss_fn(pred, train_label)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
#         scheduler.step()
        
        total_loss += loss.item()
        correct += (pred.argmax(1) == train_label).type(torch.float).sum().item()
    
    total_loss /= batch_cnt
    correct /= len(train_dl.dataset)
    
    print(f"train loss: {total_loss} - train acc: {100*correct}")
    
    batch_cnt = 0
    val_total_loss = 0
    val_correct = 0
    model.eval()
    with torch.no_grad():
        for batch, (valid_sig, valid_label) in tqdm(enumerate(valid_dl)):
            batch_cnt = batch
            valid_sig = valid_sig.to(device)
            valid_label = valid_label.to(device)
            
            pred = model(valid_sig)
            loss = loss_fn(pred, valid_label)
            
            val_total_loss += loss.item()
            val_correct += (pred.argmax(1) == valid_label).type(torch.float).sum().item()
    
        val_total_loss /= batch_cnt
        val_correct /= len(valid_dl.dataset)
        
        print(f"valid loss: {val_total_loss} - valid acc: {100*val_correct}")

Epoch: 0


293it [00:24, 11.76it/s]

train loss: 3.0547534856077743 - train acc: 62.494655835827274



33it [00:01, 27.66it/s]

valid loss: 2.973953001201153 - valid acc: 66.0576923076923
Epoch: 1



293it [00:21, 13.84it/s]

train loss: 2.8839842138225085 - train acc: 66.9570329200513



33it [00:01, 30.29it/s]

valid loss: 2.9739452600479126 - valid acc: 66.0576923076923
Epoch: 2



293it [00:21, 13.69it/s]

train loss: 2.861419574855125 - train acc: 66.9570329200513



33it [00:01, 30.11it/s]


valid loss: 2.9786726012825966 - valid acc: 66.0576923076923
Epoch: 3


293it [00:21, 13.81it/s]

train loss: 2.8323355774356895 - train acc: 66.99444206926036



33it [00:01, 29.23it/s]

valid loss: 2.926537901163101 - valid acc: 66.0576923076923
Epoch: 4



293it [00:21, 13.83it/s]

train loss: 2.8186834087110544 - train acc: 67.23492945703292



33it [00:01, 29.58it/s]

valid loss: 2.903954289853573 - valid acc: 66.0576923076923
Epoch: 5



293it [00:21, 13.91it/s]

train loss: 2.796970308643498 - train acc: 67.41663103890552



33it [00:01, 28.40it/s]

valid loss: 2.9189019948244095 - valid acc: 65.96153846153847
Epoch: 6



293it [00:21, 13.64it/s]

train loss: 2.7864304462524307 - train acc: 67.32578024796922



33it [00:01, 27.29it/s]

valid loss: 2.894411876797676 - valid acc: 65.91346153846153
Epoch: 7



293it [00:21, 13.56it/s]

train loss: 2.7832928438709206 - train acc: 67.30440359127833



33it [00:01, 26.58it/s]

valid loss: 3.1826119795441628 - valid acc: 11.826923076923077
Epoch: 8



293it [00:21, 13.72it/s]

train loss: 2.782355163195362 - train acc: 67.2509619495511



33it [00:01, 25.83it/s]

valid loss: 2.8783155381679535 - valid acc: 66.10576923076923
Epoch: 9



293it [00:21, 13.58it/s]

train loss: 2.77907725148005 - train acc: 67.49144933732364



33it [00:01, 27.87it/s]

valid loss: 2.906625084578991 - valid acc: 66.20192307692308
Epoch: 10



293it [00:21, 13.63it/s]

train loss: 2.778648847586488 - train acc: 67.34715690466011



33it [00:01, 28.10it/s]

valid loss: 2.9952602088451385 - valid acc: 66.00961538461539
Epoch: 11



293it [00:21, 13.50it/s]

train loss: 2.774873088483941 - train acc: 67.0478837109876



33it [00:01, 30.41it/s]

valid loss: 2.946235813200474 - valid acc: 65.96153846153847
Epoch: 12



293it [00:21, 13.53it/s]

train loss: 2.776053675233501 - train acc: 67.81209918768705



33it [00:01, 28.52it/s]

valid loss: 2.9364047199487686 - valid acc: 66.0576923076923
Epoch: 13



293it [00:21, 13.59it/s]

train loss: 2.7713423861216193 - train acc: 67.2990594271056



33it [00:01, 27.61it/s]

valid loss: 2.9006586223840714 - valid acc: 66.25
Epoch: 14



293it [00:21, 13.58it/s]

train loss: 2.7726424491568786 - train acc: 67.44869602394186



33it [00:01, 26.64it/s]

valid loss: 2.9441621005535126 - valid acc: 66.00961538461539
Epoch: 15



293it [00:21, 13.54it/s]

train loss: 2.7694024966187674 - train acc: 67.53954681487815



33it [00:01, 27.88it/s]

valid loss: 2.9914291501045227 - valid acc: 65.91346153846153
Epoch: 16



293it [00:21, 13.72it/s]

train loss: 2.7667947672817803 - train acc: 67.79072253099615



33it [00:01, 27.53it/s]

valid loss: 2.9003865122795105 - valid acc: 67.11538461538461
Epoch: 17



293it [00:21, 13.60it/s]

train loss: 2.7635234022793704 - train acc: 68.58700299273194



33it [00:01, 27.49it/s]

valid loss: 2.855415441095829 - valid acc: 67.16346153846153
Epoch: 18



293it [00:21, 13.65it/s]

train loss: 2.7570628368691223 - train acc: 68.70991876870457



33it [00:01, 27.70it/s]

valid loss: 2.989921383559704 - valid acc: 65.67307692307692
Epoch: 19



293it [00:21, 13.67it/s]

train loss: 2.756084093492325 - train acc: 68.73129542539547



33it [00:01, 27.63it/s]

valid loss: 3.0160753577947617 - valid acc: 66.00961538461539
Epoch: 20



293it [00:21, 13.60it/s]

train loss: 2.7543287350706858 - train acc: 68.6938862761864



33it [00:01, 28.90it/s]

valid loss: 2.878653734922409 - valid acc: 66.29807692307692
Epoch: 21



293it [00:21, 13.79it/s]

train loss: 2.7546903327719807 - train acc: 68.5549380076956



33it [00:01, 24.94it/s]

valid loss: 2.8526230603456497 - valid acc: 68.75
Epoch: 22



293it [00:21, 13.70it/s]

train loss: 2.753364238836994 - train acc: 69.01453612654981



33it [00:01, 26.29it/s]

valid loss: 3.026267573237419 - valid acc: 66.00961538461539
Epoch: 23



293it [00:21, 13.57it/s]

train loss: 2.754417456992685 - train acc: 68.82214621633177



33it [00:01, 24.16it/s]

valid loss: 2.8439242020249367 - valid acc: 68.36538461538461
Epoch: 24



293it [00:21, 13.70it/s]

train loss: 2.753336977468778 - train acc: 69.00384779820436



33it [00:01, 27.24it/s]

valid loss: 2.8545842096209526 - valid acc: 66.58653846153845
Epoch: 25



293it [00:21, 13.65it/s]

train loss: 2.7532259821891785 - train acc: 68.60837964942283



33it [00:01, 29.62it/s]

valid loss: 3.0661134272813797 - valid acc: 4.134615384615384
Epoch: 26



293it [00:21, 13.68it/s]

train loss: 2.752902580450659 - train acc: 68.88093202223172



33it [00:01, 24.59it/s]

valid loss: 2.8410847410559654 - valid acc: 68.36538461538461
Epoch: 27



293it [00:21, 13.83it/s]

train loss: 2.753196949828161 - train acc: 68.81145788798632



33it [00:01, 28.76it/s]

valid loss: 2.861026383936405 - valid acc: 67.83653846153847
Epoch: 28



293it [00:21, 13.79it/s]

train loss: 2.7521980706959557 - train acc: 68.74732791791364



33it [00:01, 29.81it/s]


valid loss: 2.8805154263973236 - valid acc: 66.97115384615384
Epoch: 29


293it [00:21, 13.69it/s]

train loss: 2.7525567926772654 - train acc: 69.15348439504062



33it [00:01, 28.43it/s]

valid loss: 2.8656307756900787 - valid acc: 67.0673076923077
Epoch: 30



293it [00:21, 13.61it/s]

train loss: 2.749854563850246 - train acc: 69.11607524583155



33it [00:01, 32.00it/s]

valid loss: 2.938569165766239 - valid acc: 66.0576923076923
Epoch: 31



293it [00:21, 13.65it/s]

train loss: 2.7519464305002397 - train acc: 69.23899102180418



33it [00:01, 26.70it/s]

valid loss: 2.8344115763902664 - valid acc: 67.74038461538461
Epoch: 32



293it [00:21, 13.69it/s]

train loss: 2.750471134708352 - train acc: 69.06263360410432



33it [00:01, 29.44it/s]

valid loss: 2.88390601426363 - valid acc: 66.25
Epoch: 33



293it [00:21, 13.69it/s]

train loss: 2.7499443480413253 - train acc: 68.91299700726806



33it [00:01, 27.27it/s]

valid loss: 2.8694888651371 - valid acc: 66.73076923076923
Epoch: 34



293it [00:21, 13.67it/s]

train loss: 2.751218509184171 - train acc: 69.02522445489525



33it [00:01, 22.79it/s]

valid loss: 2.8646877855062485 - valid acc: 60.62499999999999
Epoch: 35



293it [00:21, 13.74it/s]

train loss: 2.74905917742481 - train acc: 68.80076955964087



33it [00:01, 27.84it/s]

valid loss: 2.8303080424666405 - valid acc: 68.84615384615384
Epoch: 36



293it [00:21, 13.59it/s]

train loss: 2.7511058482405257 - train acc: 68.74198375374091



33it [00:01, 29.92it/s]

valid loss: 2.9639571830630302 - valid acc: 65.86538461538461
Epoch: 37



293it [00:21, 13.71it/s]

train loss: 2.750931582222246 - train acc: 69.0359127832407



33it [00:01, 28.74it/s]

valid loss: 2.8507902324199677 - valid acc: 67.98076923076923
Epoch: 38



293it [00:21, 13.59it/s]

train loss: 2.749431851792009 - train acc: 69.21227020094058



33it [00:01, 29.42it/s]

valid loss: 2.838011473417282 - valid acc: 66.92307692307692
Epoch: 39



293it [00:21, 13.54it/s]

train loss: 2.746379018646397 - train acc: 68.87558785805899



33it [00:01, 29.45it/s]

valid loss: 2.864581637084484 - valid acc: 65.48076923076923
Epoch: 40



293it [00:21, 13.66it/s]

train loss: 2.7491372480784375 - train acc: 69.1481402308679



33it [00:01, 24.04it/s]

valid loss: 2.8414665311574936 - valid acc: 66.82692307692307
Epoch: 41



55it [00:04, 14.25it/s]