In [1]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import random
from glob import glob
from tqdm import tqdm
from scipy.io import loadmat

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import os, sys
from typing import *
import torch
import random
import copy
import warnings

from sklearn.metrics import classification_report

In [2]:
# print(os.getcwd())

# for i in range (3):
#     os.chdir("..")
    
# main_data_dir = os.getcwd() + "/Data set"

In [3]:
data_dir = "/media/mountHDD3/data_storage/biomedical_data/ecg_data/SPH"
print(os.listdir(data_dir))

['metadata.csv', 'data_df.csv', 'data_df_no1.csv', 'records']


In [4]:
main_df = pd.read_csv(data_dir + "/data_df.csv")
main_df.shape

(20792, 4)

In [5]:
single_fns = main_df["File name"].values.tolist()
single_mat_paths = [data_dir + f"/records/{x}.h5" for x in single_fns]

In [6]:
# ratio = [0.9, 0.1]

# train_index = int(len(single_mat_paths)*ratio[0])

# train_mat_paths = single_mat_paths[:train_index]
# valid_mat_paths = single_mat_paths[train_index:]

# train_label = main_df
# valid_label = main_df

In [7]:
# single_mat_paths = [data_dir + f"/alldata/{x}.mat" for x in single_fns]
# print(len(single_mat_paths))

In [8]:
ratio = [0.8, 0.1]

train_index = int(len(single_mat_paths)*ratio[0])
valid_index = int(len(single_mat_paths)*(ratio[0]+ratio[1]))

train_mat_paths = single_mat_paths[:train_index]
valid_mat_paths = single_mat_paths[valid_index:]

train_label = main_df
valid_label = main_df

In [9]:
import h5py

class HeartData(Dataset):
    def __init__(self, data_paths):
        self.data_paths = data_paths
        random.shuffle(self.data_paths)

    def __getitem__(self, idx):
        data_path = self.data_paths[idx]
        a = h5py.File(data_path, 'r')
        data_h5 = a['ecg']
        data = np.array(data_h5)
        clip_data = data[:, 0:3000]

        filename = data_path.split("/")[-1].split(".")[0]
        label = main_df[main_df["File name"] == filename]["New Label"].values.item()

        torch_data = torch.from_numpy(clip_data)

        return torch_data.float(), label

    def __len__(self):
        return len(self.data_paths)

In [10]:
train_ds = HeartData(train_mat_paths)
valid_ds = HeartData(valid_mat_paths)

In [11]:
train_dl = DataLoader(train_ds, batch_size = 64, shuffle = True, pin_memory = True, num_workers = 48)
valid_dl = DataLoader(valid_ds, batch_size = 64, shuffle = True, pin_memory = True, num_workers = 48)

In [12]:
class Auto_Encoder(nn.Module):
    def __init__(self):
        super(Auto_Encoder, self).__init__()

        self.enc = nn.Sequential(    
            # 12 x 3000
            nn.Conv1d(12, 64, 4, 2, 1),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),
            # 64 x 1500
            nn.Conv1d(64, 64, 3, 2, 1),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),
            # 64 x 750
            nn.Conv1d(64, 128, 3, 2, 1),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),
            # 128 x 375
            nn.Conv1d(128, 128, 3, 2, 1),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),  
            # 128 x 188
        )
        
        self.cls = nn.Sequential(
            nn.Dropout(p=0.2, inplace=False),
            nn.Linear(128, out_features=34, bias=True)
        )

        self.dec = nn.Sequential(
            # 128 x 188
            nn.ConvTranspose1d(128, 128, 3, 2, 1),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),
            # 128 x 375
            nn.ConvTranspose1d(128, 64, 2, 2),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),
            # 64 x 750
            nn.ConvTranspose1d(64, 64, 3, 2, 1, 1),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2),
            # nn.ReLU(), #Laeky relu 0.2
            # 64 x 1500
            nn.ConvTranspose1d(64, 12, 4, 2, 1),
            # 12 x 3000
            nn.Sigmoid()
        )


    def forward(self, x):
        enc = self.enc(x) # 128 x 7 x 7 
        cls = self.cls(enc.mean(dim=-1))
        dec = self.dec(enc)
        
        return dec, cls

In [13]:
class FocalClassifierV0(nn.Module):
    def __init__(self, gamma=0.3): #Change gamma value here in order to acquire other results
        super().__init__()
        
        self.gamma = gamma
        self.act = nn.LogSoftmax(dim=1)

    
    def forward(self, pred, target):

        logits = self.act(pred)

        B, C = tuple(logits.size())

        entropy = torch.pow(1 - logits, self.gamma) * logits * F.one_hot(target, num_classes=C).float()

        return (-1 / B) * torch.sum(entropy)

focalloss_fn = FocalClassifierV0()

In [14]:
epoch = 150
lr = 0.005
best_acc = 0
best_ep = 0
# class_la = [1,2,3,4,5,6,7,8,9]
# for i in range (len(class_la)):
#     class_la[i] = str(class_la[i])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu", index = 0)
# device = torch.device("cpu", index = 0)

model = Auto_Encoder().to(device)
optimizer = Adam(model.parameters(), lr=lr)
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=epoch)
loss_fn_cls = nn.CrossEntropyLoss()
loss_fn_sig = nn.MSELoss()

In [None]:
warnings.filterwarnings('ignore')
for e in range(epoch):
    model.train()
    print(f"Epoch: {e}")
    y_true_list = [] 
    pred_list = []
    batch_cnt = 0
    total_loss = 0
    correct = 0
    for batch, (train_sig, train_label) in tqdm(enumerate(train_dl)):
        batch_cnt = batch
        train_sig = train_sig.to(device)
        train_label = train_label.to(device)
        
        res_sig, pred_cls = model(train_sig)
        loss_cls = focalloss_fn(pred_cls, train_label)
        loss_sig = loss_fn_sig(res_sig, train_sig)
        loss_tot = loss_cls + loss_sig
        
        optimizer.zero_grad()
        loss_tot.backward()
        optimizer.step()
        
        scheduler.step()
        
        total_loss += loss_tot.item()
        correct += (pred_cls.argmax(1) == train_label).type(torch.float).sum().item()
    
    total_loss /= batch_cnt
    correct /= len(train_dl.dataset)
    
    print(f"train loss: {total_loss} - train acc: {100*correct}")
    
    batch_cnt = 0
    val_total_loss = 0
    val_correct = 0
    model.eval()
    with torch.no_grad():
        for batch, (valid_sig, valid_label) in tqdm(enumerate(valid_dl)):
            batch_cnt = batch
            valid_sig = valid_sig.to(device)
            valid_label = valid_label.to(device)
            
            res_sig, pred_cls = model(valid_sig)
            loss_cls = loss_fn_cls(pred_cls, valid_label)
            loss_sig = loss_fn_sig(res_sig, valid_sig)
            loss_tot = loss_cls + loss_sig
            
            pred_pos = pred_cls.argmax(1)
            y_true_list.append(valid_label)
            pred_list.append(pred_pos)

            val_total_loss += loss_tot.item()
            val_correct += (pred_cls.argmax(1) == valid_label).type(torch.float).sum().item()
    
        val_total_loss /= batch_cnt
        val_correct /= len(valid_dl.dataset)
        if val_correct > best_acc:
            best_acc = val_correct
            best_ep = e
        
        print(f"valid loss: {val_total_loss} - valid acc: {100*val_correct}")

        
y_true = torch.cat(y_true_list).cpu().numpy()
pred = torch.cat(pred_list).cpu().numpy()

reports = classification_report(y_true, pred, output_dict=True)

print(reports)
print(f"Best acuracy: {best_acc} at epoch {best_ep}")

# 0.8091346153846154 at epoch 104

Epoch: 0


260it [00:23, 11.09it/s]

train loss: 2.363102183839069 - train acc: 68.77893344556003



33it [00:01, 26.08it/s]

valid loss: 2.8282896764576435 - valid acc: 43.55769230769231
Epoch: 1



260it [00:10, 24.87it/s]

train loss: 2.038209446838924 - train acc: 71.68881139902604



33it [00:01, 26.84it/s]

valid loss: 2.2859735507518053 - valid acc: 70.96153846153847
Epoch: 2



260it [00:08, 30.03it/s]

train loss: 1.9759728411449888 - train acc: 71.96537004749595



33it [00:01, 22.28it/s]

valid loss: 2.330646963790059 - valid acc: 72.25961538461539
Epoch: 3



260it [00:11, 21.69it/s]

train loss: 1.9081818203668337 - train acc: 72.66879095773461



33it [00:01, 26.58it/s]

valid loss: 2.4533389415591955 - valid acc: 63.99038461538461
Epoch: 4



260it [00:12, 21.33it/s]

train loss: 1.8377804450086646 - train acc: 73.56460049299585



33it [00:01, 31.74it/s]

valid loss: 2.9914537705481052 - valid acc: 39.08653846153846
Epoch: 5



260it [00:11, 23.29it/s]

train loss: 1.8037025452120423 - train acc: 73.76901340708231



33it [00:01, 30.58it/s]

valid loss: 2.23190444894135 - valid acc: 72.54807692307692
Epoch: 6



260it [00:10, 24.75it/s]

train loss: 1.7576426111132941 - train acc: 74.17783923525522



33it [00:01, 26.98it/s]

valid loss: 2.429736217483878 - valid acc: 72.35576923076923
Epoch: 7



260it [00:11, 23.64it/s]

train loss: 1.7134450802931915 - train acc: 75.01953946972885



33it [00:01, 30.80it/s]

valid loss: 2.3322845436632633 - valid acc: 73.50961538461539
Epoch: 8



260it [00:10, 24.21it/s]

train loss: 1.701278892493156 - train acc: 75.31413455179462



33it [00:01, 28.91it/s]

valid loss: 2.3057487420737743 - valid acc: 62.88461538461539
Epoch: 9



260it [00:10, 24.03it/s]

train loss: 1.7066725025305878 - train acc: 75.30211026273072



33it [00:01, 25.32it/s]

valid loss: 2.087089642882347 - valid acc: 73.26923076923076
Epoch: 10



260it [00:11, 23.01it/s]

train loss: 1.6816533916705363 - train acc: 75.85522755967054



33it [00:01, 28.85it/s]

valid loss: 2.167366983368993 - valid acc: 76.25
Epoch: 11



260it [00:09, 27.99it/s]

train loss: 1.646031306747304 - train acc: 76.02957975109722



33it [00:01, 27.66it/s]

valid loss: 2.1340427435934544 - valid acc: 76.49038461538461
Epoch: 12



260it [00:11, 22.32it/s]

train loss: 1.62232351694328 - train acc: 76.64281849335659



33it [00:01, 26.89it/s]

valid loss: 2.000857012346387 - valid acc: 73.41346153846153
Epoch: 13



260it [00:10, 24.65it/s]

train loss: 1.6063179348426435 - train acc: 76.87127998557085



33it [00:01, 30.06it/s]

valid loss: 2.140078794211149 - valid acc: 74.95192307692308
Epoch: 14



260it [00:08, 31.56it/s]

train loss: 1.5871475329730502 - train acc: 76.70293993867612



33it [00:01, 26.38it/s]

valid loss: 2.6330064237117767 - valid acc: 50.76923076923077
Epoch: 15



260it [00:12, 21.25it/s]

train loss: 1.5723793994505892 - train acc: 77.34623940359526



33it [00:01, 29.58it/s]

valid loss: 2.127269886434078 - valid acc: 76.15384615384615
Epoch: 16



260it [00:11, 22.43it/s]

train loss: 1.5746377000017056 - train acc: 77.27409366921181



33it [00:00, 40.30it/s]

valid loss: 2.363322051241994 - valid acc: 67.35576923076924
Epoch: 17



260it [00:09, 27.75it/s]

train loss: 1.5700437974285435 - train acc: 77.25004509108399



33it [00:01, 26.13it/s]

valid loss: 2.043008567765355 - valid acc: 75.38461538461539
Epoch: 18



260it [00:10, 24.30it/s]

train loss: 1.5531448820843199 - train acc: 77.35225154812721



33it [00:01, 26.24it/s]

valid loss: 1.9746236614882946 - valid acc: 72.83653846153845
Epoch: 19



260it [00:11, 22.66it/s]

train loss: 1.5318226856844765 - train acc: 78.03763602477004



33it [00:00, 40.48it/s]

valid loss: 2.4629353415220976 - valid acc: 59.03846153846154
Epoch: 20



260it [00:10, 25.75it/s]

train loss: 1.5235730949285868 - train acc: 77.91138098959898



33it [00:01, 22.26it/s]


valid loss: 2.333992687985301 - valid acc: 75.24038461538461
Epoch: 21


260it [00:12, 21.07it/s]

train loss: 1.4954555443124882 - train acc: 78.29014609511212



33it [00:01, 26.26it/s]

valid loss: 2.3693869467824697 - valid acc: 77.0673076923077
Epoch: 22



260it [00:08, 30.35it/s]

train loss: 1.4901541789065917 - train acc: 78.29014609511212



33it [00:01, 28.82it/s]

valid loss: 2.421873116865754 - valid acc: 71.0576923076923
Epoch: 23



260it [00:11, 22.20it/s]

train loss: 1.4922053512459097 - train acc: 78.63885047796549



33it [00:01, 25.66it/s]

valid loss: 2.2215776685625315 - valid acc: 75.38461538461539
Epoch: 24



260it [00:11, 22.42it/s]

train loss: 1.4939029124712853 - train acc: 78.37431611855949



33it [00:00, 37.61it/s]

valid loss: 2.014144852757454 - valid acc: 78.46153846153847
Epoch: 25



260it [00:09, 27.96it/s]

train loss: 1.478141222681318 - train acc: 78.77712980220043



33it [00:01, 31.62it/s]

valid loss: 1.9896768247708678 - valid acc: 78.89423076923077
Epoch: 26



260it [00:12, 21.21it/s]

train loss: 1.4909960986564519 - train acc: 78.66289905609331



33it [00:01, 29.05it/s]

valid loss: 2.553450735285878 - valid acc: 62.019230769230774
Epoch: 27



260it [00:09, 28.50it/s]

train loss: 1.4684261930495155 - train acc: 79.05970059520232



33it [00:01, 23.99it/s]

valid loss: 2.020404424518347 - valid acc: 76.82692307692308
Epoch: 28



260it [00:09, 26.31it/s]

train loss: 1.4410989747084246 - train acc: 79.41441712258762



33it [00:01, 22.89it/s]

valid loss: 2.061562789604068 - valid acc: 76.25
Epoch: 29



260it [00:09, 26.57it/s]

train loss: 1.4384617138093043 - train acc: 79.28214994288463



33it [00:01, 30.15it/s]

valid loss: 2.277484640479088 - valid acc: 77.01923076923077
Epoch: 30



260it [00:09, 26.55it/s]

train loss: 1.4415420378957475 - train acc: 79.48055071243913



33it [00:01, 24.78it/s]

valid loss: 4.079219937324524 - valid acc: 26.778846153846153
Epoch: 31



260it [00:10, 25.57it/s]

train loss: 1.487941875766143 - train acc: 78.64486262249744



33it [00:01, 24.50it/s]

valid loss: 2.230057622306049 - valid acc: 79.27884615384615
Epoch: 32



260it [00:12, 21.26it/s]

train loss: 1.4637942072507497 - train acc: 78.97553057175494



33it [00:01, 20.88it/s]

valid loss: 2.294026840478182 - valid acc: 74.66346153846153
Epoch: 33



260it [00:10, 25.55it/s]

train loss: 1.429444583111288 - train acc: 79.5767450249504



33it [00:01, 25.81it/s]

valid loss: 1.964141819626093 - valid acc: 77.0673076923077
Epoch: 34



260it [00:13, 19.30it/s]

train loss: 1.4208549243833108 - train acc: 79.40840497805567



33it [00:01, 23.71it/s]

valid loss: 2.1803986206650734 - valid acc: 69.1826923076923
Epoch: 35



260it [00:06, 39.53it/s]

train loss: 1.4074210086154202 - train acc: 79.67895148199364



33it [00:01, 17.97it/s]

valid loss: 2.3141849376261234 - valid acc: 76.15384615384615
Epoch: 36



260it [00:13, 19.51it/s]

train loss: 1.4028411179665894 - train acc: 79.92544940780377



33it [00:01, 26.96it/s]

valid loss: 2.047720344737172 - valid acc: 73.84615384615385
Epoch: 37



260it [00:06, 40.56it/s]

train loss: 1.4000796553711172 - train acc: 80.04569229844284



33it [00:00, 38.47it/s]


valid loss: 1.9958821553736925 - valid acc: 78.3173076923077
Epoch: 38


260it [00:12, 20.79it/s]

train loss: 1.4042188964755378 - train acc: 79.90140082967595



33it [00:01, 20.51it/s]

valid loss: 2.082120737992227 - valid acc: 78.9423076923077
Epoch: 39



260it [00:10, 25.33it/s]

train loss: 1.40357258996448 - train acc: 79.78115793903685



33it [00:01, 24.95it/s]

valid loss: 1.9999539200216532 - valid acc: 78.89423076923077
Epoch: 40



260it [00:10, 23.97it/s]

train loss: 1.3932669986629118 - train acc: 80.16593518908195



33it [00:01, 20.95it/s]

valid loss: 2.0329557731747627 - valid acc: 77.64423076923077
Epoch: 41



260it [00:09, 27.56it/s]


train loss: 1.3991840685426498 - train acc: 80.11182588829435


33it [00:01, 23.15it/s]

valid loss: 1.87350516859442 - valid acc: 78.50961538461539
Epoch: 42



260it [00:09, 28.60it/s]

train loss: 1.3822343266608632 - train acc: 80.3703481031684



33it [00:01, 20.77it/s]

valid loss: 2.24451950378716 - valid acc: 71.73076923076923
Epoch: 43



260it [00:09, 26.58it/s]

train loss: 1.3775138273892715 - train acc: 80.3042145133169



33it [00:01, 25.66it/s]

valid loss: 2.807211885228753 - valid acc: 71.20192307692308
Epoch: 44



260it [00:09, 26.14it/s]

train loss: 1.395201950459867 - train acc: 79.93747369686768



33it [00:00, 34.39it/s]

valid loss: 2.074190284125507 - valid acc: 78.07692307692308
Epoch: 45



260it [00:10, 23.93it/s]

train loss: 1.404620073822014 - train acc: 79.75710936090904



33it [00:01, 20.20it/s]

valid loss: 2.2337886784225702 - valid acc: 77.6923076923077
Epoch: 46



260it [00:12, 21.18it/s]

train loss: 1.4791924910885947 - train acc: 79.10779775145794



33it [00:01, 23.81it/s]

valid loss: 2.0081645250320435 - valid acc: 78.50961538461539
Epoch: 47



260it [00:09, 27.32it/s]

train loss: 1.4192221227990154 - train acc: 79.87134010701618



33it [00:01, 25.37it/s]

valid loss: 2.06200756970793 - valid acc: 79.66346153846153
Epoch: 48



260it [00:12, 20.09it/s]

train loss: 1.3864677533449814 - train acc: 79.78115793903685



33it [00:01, 30.50it/s]

valid loss: 1.9917515739798546 - valid acc: 79.42307692307692
Epoch: 49



260it [00:06, 37.26it/s]

train loss: 1.3662761226584093 - train acc: 80.58678530631876



33it [00:01, 21.29it/s]

valid loss: 1.982814253307879 - valid acc: 77.83653846153847
Epoch: 50



260it [00:12, 20.47it/s]

train loss: 1.3528735209615994 - train acc: 80.89340467744844



33it [00:01, 22.00it/s]

valid loss: 2.001324787735939 - valid acc: 75.67307692307692
Epoch: 51



260it [00:07, 32.83it/s]

train loss: 1.349615479067946 - train acc: 80.78518607587326



33it [00:00, 46.14it/s]

valid loss: 1.9896288756281137 - valid acc: 76.49038461538461
Epoch: 52



260it [00:11, 21.70it/s]

train loss: 1.3540619131911216 - train acc: 80.78518607587326



