In [1]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import random
from glob import glob
from tqdm import tqdm
from scipy.io import loadmat

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from sklearn.metrics import classification_report

In [2]:
data_dir = "/media/mountHDD3/data_storage/biomedical_data/ecg_data/SPH"
print(os.listdir(data_dir))

['metadata.csv', 'data_df.csv', 'data_df_no1.csv', 'records']


In [3]:
main_df = pd.read_csv(data_dir + "/data_df.csv")
main_df.shape

(20792, 4)

In [4]:
single_fns = main_df["File name"].values.tolist()
single_mat_paths = [data_dir + f"/records/{x}.h5" for x in single_fns]

In [5]:
ratio = [0.9, 0.1]

train_index = int(len(single_mat_paths)*ratio[0])

train_mat_paths = single_mat_paths[:train_index]
valid_mat_paths = single_mat_paths[train_index:]

train_label = main_df
valid_label = main_df

In [6]:
conv_block1 = nn.Sequential(
			nn.Conv1d(12, 64, 16, stride=1),
			nn.LeakyReLU(0.3),
			nn.BatchNorm1d(64),
		)

class Redu(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, padding):
        super().__init__()
        
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.kernel_size = kernel_size
        self.padding = padding
        
        self.conv1 = nn.Sequential(
            nn.MaxPool1d(4),
            nn.Conv1d(self.in_channel, self.out_channel, 1, 1)
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(self.in_channel, self.out_channel, self.kernel_size, 1),
            nn.LeakyReLU(),
            nn.BatchNorm1d(self.out_channel),
            nn.Dropout(0.2),
            nn.Conv1d(self.out_channel, self.out_channel, self.kernel_size, 4, self.padding)
        )
        self.conv3 = nn.Sequential(
            nn.LeakyReLU(),
            nn.BatchNorm1d(self.out_channel),
            nn.Dropout(0.2),
        )
    
    def forward(self, x1, x2):
        x1 = self.conv1(x1)
        x2 = self.conv2(x2)
        x_res = x1 + x2
        x3 = self.conv3(x_res)
        
        return x_res, x3

In [7]:
test = torch.randn(1, 12, 1400)
test = conv_block1(test)
model = Redu(64, 128, 3, 0)
a1, a2 = model(test, test)
print(type(a2))

<class 'torch.Tensor'>


In [8]:
class GlobalAvgPooling(nn.Module):
    def __init__(self):
        super(GlobalAvgPooling, self).__init__()

    def forward(self, x):
        return x.mean(dim=(2))

In [9]:
class LSAT(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.lstm = nn.LSTM(5, 5)
        
        self.conv1 = nn.Sequential(
            nn.Conv1d(960, 256, 1, 1),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),
        )
    
        self.conv2 = nn.Sequential(
            nn.Conv1d(256, 128, 1, 1),
            nn.BatchNorm1d(128),
            nn.Dropout(0.5),                 
        )
        
        self.attention = nn.MultiheadAttention(5, 1, batch_first=True)
        
        self.conv3 = nn.Sequential(
            nn.LeakyReLU(0.3),
            nn.BatchNorm1d(128),
            nn.Dropout(0.5),
        )
        
        self.avgpo = GlobalAvgPooling()
            
        self.conv4 = nn.Sequential(
            nn.Linear(128, 64),
            nn.LeakyReLU(0.3),
            nn.Linear(64, 32),
            nn.LeakyReLU(0.3),
            nn.Linear(32, 34),
            nn.Sigmoid(),
        )
    
    def forward(self, x):
        x1, _ = self.lstm(x)
        x1 = self.conv1(x1)
        
        x1, _ = self.lstm(x1)
        x2 = self.conv2(x1)
        
        x3, _ = self.attention(x2, x2, x2)
        x3 = self.conv3(x3)
        
        x3 = self.avgpo(x2)
        x4 = self.conv4(x3)
        
        return x4   
    

In [10]:
import h5py

class HeartData(Dataset):
    def __init__(self, data_paths):
        self.data_paths = data_paths
        random.shuffle(self.data_paths)

    def __getitem__(self, idx):
        data_path = self.data_paths[idx]
        a = h5py.File(data_path, 'r')
        data_h5 = a['ecg']
        data = np.array(data_h5)
        clip_data = data[:, -1400:]

        filename = data_path.split("/")[-1].split(".")[0]
        label = main_df[main_df["File name"] == filename]["New Label"].values.item()

        torch_data = torch.from_numpy(clip_data)

        return torch_data.float(), label

    def __len__(self):
        return len(self.data_paths)

In [11]:
train_ds = HeartData(train_mat_paths)
valid_ds = HeartData(valid_mat_paths)

In [12]:
train_dl = DataLoader(train_ds, batch_size = 64, shuffle = True, pin_memory = True, num_workers = 48)
valid_dl = DataLoader(valid_ds, batch_size = 64, shuffle = True, pin_memory = True, num_workers = 48)

In [13]:
class HeartDNN(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv_block1 = nn.Sequential(
			nn.Conv1d(12, 64, 16, stride=1),
			nn.LeakyReLU(0.3),
			nn.BatchNorm1d(64),
		)
        self.model_redu11 = Redu(64, 128, 3, 0)
        self.model_redu12 = Redu(128, 192, 3, 0)
        self.model_redu13 = Redu(192, 256, 3, 0)
        self.model_redu14 = Redu(256, 320, 3, 0)
        
        self.model_redu21 = Redu(64, 128, 5, 2)
        self.model_redu22 = Redu(128, 192, 5, 2)
        self.model_redu23 = Redu(192, 256, 5, 2)
        self.model_redu24 = Redu(256, 320, 5, 2)
        
        self.model_redu31 = Redu(64, 128, 7, 4)
        self.model_redu32 = Redu(128, 192, 7, 4)
        self.model_redu33 = Redu(192, 256, 7, 4)
        self.model_redu34 = Redu(256, 320, 7, 4)
        
        self.model_LSAT = LSAT()
    
    
    def forward(self, x):  
    
        data_conv = self.conv_block1(x)

        data_x11, data_x12 = self.model_redu11(data_conv, data_conv)
        data_x13, data_x14 = self.model_redu12(data_x11, data_x12)
        data_x15, data_x16 = self.model_redu13(data_x13, data_x14)
        data_x17, data_x18 = self.model_redu14(data_x15, data_x16)

        
        data_x21, data_x22 = self.model_redu21(data_conv, data_conv)
        data_x23, data_x24 = self.model_redu22(data_x21, data_x22)
        data_x25, data_x26 = self.model_redu23(data_x23, data_x24)
        data_x27, data_x28 = self.model_redu24(data_x25, data_x26)

        
        data_x31, data_x32 = self.model_redu31(data_conv, data_conv)
        data_x33, data_x34 = self.model_redu32(data_x31, data_x32)
        data_x35, data_x36 = self.model_redu33(data_x33, data_x34)
        data_x37, data_x38 = self.model_redu34(data_x35, data_x36)
        
        data = torch.cat((data_x18, data_x28, data_x38), dim = 1)
   
        out = self.model_LSAT(data)

        return out

In [14]:
# test = torch.randn(1, 12, 1400)
model = HeartDNN()

In [15]:
class FocalClassifierV0(nn.Module):
    def __init__(self, gamma=0.3): #Change gamma value here in order to acquire other results
        super().__init__()
        
        self.gamma = gamma
        self.act = nn.LogSoftmax(dim=1)

    
    def forward(self, pred, target):

        logits = self.act(pred)

        B, C = tuple(logits.size())

        entropy = torch.pow(1 - logits, self.gamma) * logits * F.one_hot(target, num_classes=C).float()

        return (-1 / B) * torch.sum(entropy)

focalloss_fn = FocalClassifierV0()

In [16]:
epoch = 150
lr = 0.0005
best_acc = 0
best_ep = 0
# class_la = [1,2,3,4,5,6,7,8,9]
# for i in range (len(class_la)):
#     class_la[i] = str(class_la[i])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu", index = 0)
# device = torch.device("cpu", index = 0)
model.to(device)
optimizer = Adam(model.parameters(), lr=lr, betas = (0.9, 0.999), weight_decay = 0.001)
loss_fn = nn.CrossEntropyLoss()

In [None]:
for e in range(epoch):
    model.train()
    print(f"Epoch: {e}")
    y_true_list = [] 
    pred_list = []
    batch_cnt = 0
    total_loss = 0
    correct = 0
    for batch, (train_sig, train_label) in tqdm(enumerate(train_dl)):
        batch_cnt = batch
        train_sig = train_sig.to(device)
        train_label = train_label.to(device)
        
        pred = model(train_sig)
        loss = focalloss_fn(pred, train_label)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
                
        total_loss += loss.item()
        correct += (pred.argmax(1) == train_label).type(torch.float).sum().item()
    
    total_loss /= batch_cnt
    correct /= len(train_dl.dataset)
    
    print(f"train loss: {total_loss} - train acc: {100*correct}")
    
    batch_cnt = 0
    val_total_loss = 0
    val_correct = 0
    model.eval()
    with torch.no_grad():
        for batch, (valid_sig, valid_label) in tqdm(enumerate(valid_dl)):
            batch_cnt = batch
            valid_sig = valid_sig.to(device)
            valid_label = valid_label.to(device)
            
            pred = model(valid_sig)
            
            pred_pos = pred.argmax(1)
            y_true_list.append(valid_label)
            pred_list.append(pred_pos)
            
            loss = loss_fn(pred, valid_label)
            
            val_total_loss += loss.item()
            val_correct += (pred.argmax(1) == valid_label).type(torch.float).sum().item()
    
        val_total_loss /= batch_cnt
        val_correct /= len(valid_dl.dataset)
        if val_correct > best_acc:
            best_acc = val_correct
            best_ep = e
        
        print(f"valid loss: {val_total_loss} - valid acc: {100*val_correct}")
        
y_true = torch.cat(y_true_list).cpu().numpy()
pred = torch.cat(pred_list).cpu().numpy()

reports = classification_report(y_true, pred, output_dict=True) 

print(reports)
print(f"Best acuracy: {best_acc} at epoch {best_ep}")

Epoch: 0


293it [00:32,  8.91it/s]

train loss: 4.65677594811949 - train acc: 58.497221034630186



33it [00:01, 22.73it/s]

valid loss: 2.9621065855026245 - valid acc: 66.00961538461539
Epoch: 1



293it [00:32,  8.95it/s]

train loss: 4.181484277934244 - train acc: 66.76998717400599



33it [00:01, 17.71it/s]

valid loss: 2.898215599358082 - valid acc: 66.0576923076923
Epoch: 2



293it [00:37,  7.80it/s]

train loss: 4.131686246558411 - train acc: 66.35848653270628



33it [00:01, 17.24it/s]

valid loss: 3.016777843236923 - valid acc: 64.66346153846155
Epoch: 3



293it [00:38,  7.57it/s]

train loss: 4.116657210539465 - train acc: 65.24690038477982



33it [00:02, 15.16it/s]

valid loss: 2.880805768072605 - valid acc: 59.27884615384615
Epoch: 4



293it [00:37,  7.81it/s]

train loss: 4.111311799042846 - train acc: 65.01710132535271



33it [00:02, 16.15it/s]

valid loss: 2.8542170003056526 - valid acc: 65.86538461538461
Epoch: 5



293it [00:37,  7.89it/s]

train loss: 4.106285821085107 - train acc: 64.9690038477982



33it [00:02, 16.10it/s]

valid loss: 2.8804452046751976 - valid acc: 57.59615384615384
Epoch: 6



293it [00:35,  8.21it/s]

train loss: 4.1012799290761555 - train acc: 65.06519880290722



33it [00:02, 15.93it/s]

valid loss: 2.8995436802506447 - valid acc: 62.21153846153846
Epoch: 7



293it [00:34,  8.40it/s]

train loss: 4.098745752687323 - train acc: 65.6316802052159



33it [00:01, 17.39it/s]

valid loss: 3.0175649225711823 - valid acc: 6.298076923076923
Epoch: 8



293it [00:29,  9.90it/s]

train loss: 4.09050930689459 - train acc: 67.51817015818726



33it [00:02, 15.91it/s]

valid loss: 2.889956869184971 - valid acc: 66.77884615384615
Epoch: 9



293it [00:25, 11.30it/s]

train loss: 4.083007860673617 - train acc: 68.36254809747754



33it [00:02, 15.95it/s]

valid loss: 2.81116783618927 - valid acc: 67.11538461538461
Epoch: 10



293it [00:33,  8.83it/s]

train loss: 4.071922956264182 - train acc: 68.9236853356135



33it [00:01, 23.05it/s]

valid loss: 2.9651056677103043 - valid acc: 66.0576923076923
Epoch: 11



293it [00:36,  7.96it/s]

train loss: 4.072496624842082 - train acc: 68.66716545532279



33it [00:01, 21.67it/s]

valid loss: 2.983697257936001 - valid acc: 66.0576923076923
Epoch: 12



293it [00:35,  8.23it/s]

train loss: 4.0679198764774895 - train acc: 68.93971782813168



33it [00:01, 19.69it/s]

valid loss: 2.8227678537368774 - valid acc: 68.3173076923077
Epoch: 13



293it [00:32,  8.93it/s]

train loss: 4.063525304402391 - train acc: 69.04125694741342



33it [00:01, 21.25it/s]

valid loss: 2.9751628562808037 - valid acc: 39.95192307692308
Epoch: 14



293it [00:34,  8.51it/s]

train loss: 4.059550690324339 - train acc: 69.15348439504062



33it [00:01, 24.74it/s]

valid loss: 2.832390807569027 - valid acc: 67.78846153846155
Epoch: 15



293it [00:35,  8.35it/s]

train loss: 4.0612085494276595 - train acc: 69.04125694741342



33it [00:01, 18.68it/s]

valid loss: 2.8821246922016144 - valid acc: 66.25
Epoch: 16



293it [00:34,  8.45it/s]

train loss: 4.057000051622522 - train acc: 69.2870884993587



33it [00:01, 26.06it/s]

valid loss: 2.833011195063591 - valid acc: 66.0576923076923
Epoch: 17



293it [00:35,  8.34it/s]

train loss: 4.057166893188268 - train acc: 69.16417272338606



33it [00:01, 18.94it/s]

valid loss: 2.821120209991932 - valid acc: 67.54807692307693
Epoch: 18



293it [00:32,  8.91it/s]

train loss: 4.0531594247034155 - train acc: 69.32449764856776



33it [00:01, 17.09it/s]

valid loss: 3.0222503691911697 - valid acc: 5.1923076923076925
Epoch: 19



293it [00:32,  9.12it/s]

train loss: 4.0525557129350425 - train acc: 69.49551090209492



33it [00:01, 21.06it/s]

valid loss: 2.8391196206212044 - valid acc: 67.11538461538461
Epoch: 20



293it [00:33,  8.66it/s]

train loss: 4.048359603097994 - train acc: 69.47947840957674



33it [00:01, 22.03it/s]

valid loss: 2.963358595967293 - valid acc: 64.42307692307693
Epoch: 21



293it [00:34,  8.47it/s]

train loss: 4.050135393665261 - train acc: 69.3993159469859



33it [00:02, 14.60it/s]

valid loss: 2.8397749438881874 - valid acc: 67.0673076923077
Epoch: 22



293it [00:34,  8.51it/s]

train loss: 4.0456890831254935 - train acc: 69.27640017101325



33it [00:01, 24.61it/s]

valid loss: 2.8406432569026947 - valid acc: 69.1826923076923
Epoch: 23



293it [00:37,  7.82it/s]

train loss: 4.046583612487741 - train acc: 69.47947840957674



33it [00:01, 16.64it/s]

valid loss: 2.82575536519289 - valid acc: 67.40384615384616
Epoch: 24



