In [1]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import random
from glob import glob
from tqdm import tqdm
from scipy.io import loadmat

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import os, sys
from typing import *
import torch
import random
import copy
import warnings

from sklearn.metrics import classification_report

In [2]:
data_dir = "/media/mountHDD3/data_storage/biomedical_data/ecg_data/SPH"
print(os.listdir(data_dir))

['metadata.csv', 'data_df.csv', 'data_df_no1.csv', 'records']


In [3]:
main_df = pd.read_csv(data_dir + "/data_df.csv")
main_df.shape

(20792, 4)

In [4]:
single_fns = main_df["File name"].values.tolist()
single_mat_paths = [data_dir + f"/records/{x}.h5" for x in single_fns]

In [5]:
ratio = [0.9, 0.1]

train_index = int(len(single_mat_paths)*ratio[0])

train_mat_paths = single_mat_paths[:train_index]
valid_mat_paths = single_mat_paths[train_index:]

In [6]:
import h5py

class HeartData(Dataset):
    def __init__(self, data_paths):
        self.data_paths = data_paths
        random.shuffle(self.data_paths)

    def __getitem__(self, idx):
        data_path = self.data_paths[idx]
        a = h5py.File(data_path, 'r')
        data_h5 = a['ecg']
        data = np.array(data_h5)
        clip_data = data[:, 0:3000]

        filename = data_path.split("/")[-1].split(".")[0]
        label = main_df[main_df["File name"] == filename]["New Label"].values.item()

        torch_data = torch.from_numpy(clip_data)

        return torch_data.float(), label

    def __len__(self):
        return len(self.data_paths)

In [7]:
train_ds = HeartData(train_mat_paths)
valid_ds = HeartData(valid_mat_paths)

In [8]:
train_dl = DataLoader(train_ds, batch_size = 64, shuffle = True, pin_memory = True, num_workers = 48)
valid_dl = DataLoader(valid_ds, batch_size = 64, shuffle = True, pin_memory = True, num_workers = 48)

In [9]:
class Auto_Encoder(nn.Module):
    def __init__(self):
        super(Auto_Encoder, self).__init__()

        self.enc = nn.Sequential(    
            # 12 x 3000
            nn.Conv1d(12, 64, 4, 2, 1),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),
            # 64 x 1500
            nn.Conv1d(64, 64, 3, 2, 1),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),
            # 64 x 750
            nn.Conv1d(64, 128, 3, 2, 1),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),
            # 128 x 375
            nn.Conv1d(128, 128, 3, 2, 1),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),  
            # 128 x 188
        )
        
        self.cls = nn.Sequential(
            nn.Dropout(p=0.2, inplace=False),
            nn.Linear(128, out_features=34, bias=True)
        )

        self.dec = nn.Sequential(
            # 128 x 188
            nn.ConvTranspose1d(128, 128, 3, 2, 1),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),
            # 128 x 375
            nn.ConvTranspose1d(128, 64, 2, 2),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),
            # 64 x 750
            nn.ConvTranspose1d(64, 64, 3, 2, 1, 1),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2),
            # nn.ReLU(), #Laeky relu 0.2
            # 64 x 1500
            nn.ConvTranspose1d(64, 12, 4, 2, 1),
            # 12 x 3000
            nn.Sigmoid()
        )


    def forward(self, x):
        enc = self.enc(x) # 128 x 7 x 7 
        cls = self.cls(enc.mean(dim=-1))
        dec = self.dec(enc)
        
        return dec, cls

In [10]:
class FocalClassifierV0(nn.Module):
    def __init__(self, gamma=0.3): #Change gamma value here in order to acquire other results
        super().__init__()
        
        self.gamma = gamma
        self.act = nn.LogSoftmax(dim=1)

    
    def forward(self, pred, target):

        logits = self.act(pred)

        B, C = tuple(logits.size())

        entropy = torch.pow(1 - logits, self.gamma) * logits * F.one_hot(target, num_classes=C).float()

        return (-1 / B) * torch.sum(entropy)

focalloss_fn = FocalClassifierV0()

In [11]:
epoch = 150
lr = 0.001
best_acc = 0
best_ep = 0


device = torch.device("cuda" if torch.cuda.is_available() else "cpu", index = 1)
model = Auto_Encoder().to(device)
optimizer = Adam(model.parameters(), lr=lr)
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=epoch)
loss_fn_cls = nn.CrossEntropyLoss()
loss_fn_sig = nn.MSELoss()

In [None]:
warnings.filterwarnings('ignore')
for e in range(epoch):
    model.train()
    print(f"Epoch: {e}")
    y_true_list = [] 
    pred_list = []
    batch_cnt = 0
    total_loss = 0
    correct = 0
    for batch, (train_sig, train_label) in tqdm(enumerate(train_dl)):
        batch_cnt = batch
        train_sig = train_sig.to(device)
        train_label = train_label.to(device)
        
        res_sig, pred_cls = model(train_sig)
        loss_cls = focalloss_fn(pred_cls, train_label)
        loss_sig = loss_fn_sig(res_sig, train_sig)
        loss_tot = loss_cls + loss_sig
        
        optimizer.zero_grad()
        loss_tot.backward()
        optimizer.step()
        
        scheduler.step()
        
        total_loss += loss_tot.item()
        correct += (pred_cls.argmax(1) == train_label).type(torch.float).sum().item()
    
    total_loss /= batch_cnt
    correct /= len(train_dl.dataset)
    
    print(f"train loss: {total_loss} - train acc: {100*correct}")
    
    batch_cnt = 0
    val_total_loss = 0
    val_correct = 0
    model.eval()
    with torch.no_grad():
        for batch, (valid_sig, valid_label) in tqdm(enumerate(valid_dl)):
            batch_cnt = batch
            valid_sig = valid_sig.to(device)
            valid_label = valid_label.to(device)
            
            res_sig, pred_cls = model(valid_sig)
            loss_cls = loss_fn_cls(pred_cls, valid_label)
            loss_sig = loss_fn_sig(res_sig, valid_sig)
            loss_tot = loss_cls + loss_sig
            
            pred_pos = pred_cls.argmax(1)
            y_true_list.append(valid_label)
            pred_list.append(pred_pos)

            val_total_loss += loss_tot.item()
            val_correct += (pred_cls.argmax(1) == valid_label).type(torch.float).sum().item()
    
        val_total_loss /= batch_cnt
        val_correct /= len(valid_dl.dataset)
        if val_correct > best_acc:
            best_acc = val_correct
            best_ep = e
        
        print(f"valid loss: {val_total_loss} - valid acc: {100*val_correct}")

        
y_true = torch.cat(y_true_list).cpu().numpy()
pred = torch.cat(pred_list).cpu().numpy()

reports = classification_report(y_true, pred, output_dict=True)

print(reports)
print(f"Best acuracy: {best_acc} at epoch {best_ep}")

Epoch: 0


293it [00:08, 35.79it/s]

train loss: 2.62947101339902 - train acc: 67.2509619495511



33it [00:00, 47.47it/s] 

valid loss: 2.3540140986442566 - valid acc: 68.99038461538461
Epoch: 1



293it [00:06, 46.24it/s]

train loss: 2.0624672272433973 - train acc: 70.35057716973066



33it [00:00, 54.48it/s]

valid loss: 2.1150878090411425 - valid acc: 71.15384615384616
Epoch: 2



293it [00:04, 59.81it/s] 

train loss: 1.9444456284176814 - train acc: 71.67058572039333



33it [00:01, 32.39it/s]

valid loss: 2.1561268232762814 - valid acc: 70.72115384615385
Epoch: 3



293it [00:04, 64.47it/s]

train loss: 1.862086401410299 - train acc: 72.61115861479264



33it [00:00, 48.52it/s]

valid loss: 2.5039987042546272 - valid acc: 58.12500000000001
Epoch: 4



293it [00:04, 67.48it/s] 

train loss: 1.809939130732458 - train acc: 73.09213339033775



33it [00:00, 48.91it/s]

valid loss: 2.0788914281874895 - valid acc: 71.53846153846153
Epoch: 5



293it [00:03, 84.26it/s] 

train loss: 1.7569357446611744 - train acc: 73.76015391192817



33it [00:00, 40.04it/s]

valid loss: 2.3798057455569506 - valid acc: 63.55769230769231
Epoch: 6



293it [00:05, 56.77it/s] 

train loss: 1.7348307748771694 - train acc: 73.96323215049166



33it [00:01, 26.41it/s]

valid loss: 2.3787407726049423 - valid acc: 63.36538461538461
Epoch: 7



293it [00:04, 70.23it/s] 

train loss: 1.7187057237510812 - train acc: 74.21440786660966



33it [00:01, 18.51it/s]

valid loss: 2.7573777679353952 - valid acc: 62.83653846153846
Epoch: 8



293it [00:03, 92.51it/s] 

train loss: 1.6785108189876765 - train acc: 74.98396750748184



33it [00:00, 56.83it/s]

valid loss: 2.127294886857271 - valid acc: 73.50961538461539
Epoch: 9



293it [00:04, 62.49it/s] 

train loss: 1.650914315287381 - train acc: 75.08550662676357



33it [00:01, 31.18it/s]

valid loss: 2.1218699365854263 - valid acc: 72.3076923076923
Epoch: 10



293it [00:04, 60.83it/s] 

train loss: 1.6347618825631598 - train acc: 75.51303976058145



33it [00:00, 51.96it/s]

valid loss: 2.5363092459738255 - valid acc: 55.144230769230774
Epoch: 11



293it [00:04, 60.47it/s] 

train loss: 1.628661958730384 - train acc: 75.49700726806327



33it [00:00, 58.40it/s]

valid loss: 2.064572250470519 - valid acc: 76.29807692307692
Epoch: 12



293it [00:06, 48.39it/s] 

train loss: 1.5992238094953641 - train acc: 75.94057289439931



33it [00:00, 42.23it/s]

valid loss: 2.073779696598649 - valid acc: 76.53846153846153
Epoch: 13



293it [00:07, 40.46it/s]

train loss: 1.5864371741062975 - train acc: 76.09020949123557



33it [00:01, 32.72it/s]

valid loss: 2.0009822817519307 - valid acc: 76.58653846153847
Epoch: 14



293it [00:04, 58.86it/s]

train loss: 1.57777888440106 - train acc: 76.10089781958102



33it [00:01, 26.43it/s]

valid loss: 2.0061372872442007 - valid acc: 76.34615384615384
Epoch: 15



293it [00:05, 52.61it/s]

train loss: 1.5601082256395522 - train acc: 76.36276186404446



33it [00:01, 30.32it/s]

valid loss: 2.1204007230699062 - valid acc: 69.1826923076923
Epoch: 16



293it [00:04, 60.86it/s] 

train loss: 1.5415429717873874 - train acc: 76.49636596836254



33it [00:01, 22.26it/s]

valid loss: 2.1815731581300497 - valid acc: 67.54807692307693
Epoch: 17



293it [00:03, 80.82it/s] 

train loss: 1.5220370572315502 - train acc: 77.00406156477126



33it [00:00, 58.50it/s]

valid loss: 2.392275758087635 - valid acc: 66.875
Epoch: 18



293it [00:04, 62.21it/s] 

train loss: 1.5298223221955234 - train acc: 76.75288584865328



33it [00:00, 48.68it/s]

valid loss: 2.458867823705077 - valid acc: 56.25
Epoch: 19



293it [00:03, 78.58it/s] 

train loss: 1.5070938404700527 - train acc: 77.1269773407439



33it [00:00, 39.25it/s]

valid loss: 1.8739322647452354 - valid acc: 77.83653846153847
Epoch: 20



293it [00:04, 71.39it/s] 

train loss: 1.5016069898050126 - train acc: 77.38349722103463



33it [00:00, 51.71it/s]

valid loss: 2.027285421267152 - valid acc: 71.53846153846153
Epoch: 21



293it [00:04, 70.35it/s] 

train loss: 1.4788224755288804 - train acc: 77.6988029072253



33it [00:00, 34.05it/s]

valid loss: 1.8646070510149002 - valid acc: 77.74038461538461
Epoch: 22



293it [00:04, 72.29it/s] 

train loss: 1.4977162898811576 - train acc: 77.31936725096195



33it [00:00, 40.00it/s]

valid loss: 2.1828516013920307 - valid acc: 70.86538461538461
Epoch: 23



293it [00:04, 72.59it/s] 


train loss: 1.491130993178446 - train acc: 77.36212056434374


33it [00:01, 26.43it/s]

valid loss: 2.1432315967977047 - valid acc: 76.0576923076923
Epoch: 24



293it [00:03, 73.90it/s] 

train loss: 1.4669463124177227 - train acc: 77.6507054296708



33it [00:00, 40.52it/s]

valid loss: 1.886636182665825 - valid acc: 76.20192307692307
Epoch: 25



293it [00:03, 85.65it/s] 

train loss: 1.4517156051038063 - train acc: 78.14771269773408



33it [00:00, 55.15it/s]

valid loss: 1.9418463297188282 - valid acc: 78.3173076923077
Epoch: 26



293it [00:03, 77.17it/s] 

train loss: 1.4424519371496487 - train acc: 78.13702436938863



33it [00:00, 43.58it/s]

valid loss: 2.0097037088125944 - valid acc: 79.75961538461539
Epoch: 27



293it [00:05, 48.85it/s] 

train loss: 1.4360942582357419 - train acc: 78.27062847370672



33it [00:00, 58.33it/s]

valid loss: 1.9445739835500717 - valid acc: 78.84615384615384
Epoch: 28



293it [00:05, 55.51it/s] 

train loss: 1.4328785371290493 - train acc: 78.53783668234287



33it [00:00, 54.26it/s]

valid loss: 1.932635836303234 - valid acc: 76.15384615384615
Epoch: 29



293it [00:04, 60.95it/s] 

train loss: 1.4201037793943327 - train acc: 78.73022659256092



33it [00:00, 42.11it/s]

valid loss: 2.021815488114953 - valid acc: 77.88461538461539
Epoch: 30



293it [00:06, 48.67it/s] 

train loss: 1.418915847903245 - train acc: 78.63937580162462



33it [00:01, 22.34it/s]

valid loss: 2.410148648545146 - valid acc: 61.58653846153847
Epoch: 31



293it [00:05, 56.07it/s]

train loss: 1.4156624666632038 - train acc: 78.60731081658828



33it [00:01, 23.88it/s]

valid loss: 2.1829073075205088 - valid acc: 76.58653846153847
Epoch: 32



293it [00:04, 65.62it/s] 

train loss: 1.395212031390569 - train acc: 79.30205215904232



33it [00:00, 43.63it/s]

valid loss: 2.2681895904242992 - valid acc: 70.67307692307693
Epoch: 33



293it [00:05, 55.80it/s] 

train loss: 1.3944758789180076 - train acc: 79.07759726378795



33it [00:00, 34.92it/s]

valid loss: 2.468285508453846 - valid acc: 66.00961538461539
Epoch: 34



293it [00:06, 47.80it/s] 

train loss: 1.382510972757862 - train acc: 79.23792218896965



33it [00:00, 49.99it/s]

valid loss: 1.9797816118225455 - valid acc: 76.49038461538461
Epoch: 35



186it [00:03, 77.37it/s]