In [1]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import random
from glob import glob
from tqdm import tqdm
from scipy.io import loadmat

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import os, sys
from typing import *
import torch
import random
import copy
import warnings

In [2]:
data_dir = "/media/mountHDD3/data_storage/biomedical_data/ecg_data/SPH"
print(os.listdir(data_dir))

['metadata.csv', 'data_df.csv', 'data_df_no1.csv', 'records']


In [3]:
main_df = pd.read_csv(data_dir + "/data_df.csv")
main_df.shape

(20792, 4)

In [4]:
single_fns = main_df["File name"].values.tolist()
single_mat_paths = [data_dir + f"/records/{x}.h5" for x in single_fns]

In [5]:
ratio = [0.9, 0.1]

train_index = int(len(single_mat_paths)*ratio[0])

train_mat_paths = single_mat_paths[:train_index]
valid_mat_paths = single_mat_paths[train_index:]


train_label = main_df
valid_label = main_df

In [6]:
import h5py

class HeartData(Dataset):
    def __init__(self, data_paths):
        self.data_paths = data_paths
        random.shuffle(self.data_paths)

    def __getitem__(self, idx):
        data_path = self.data_paths[idx]
        a = h5py.File(data_path, 'r')
        data_h5 = a['ecg']
        data = np.array(data_h5)
        clip_data = data[:, 0:3000]

        filename = data_path.split("/")[-1].split(".")[0]
        label = main_df[main_df["File name"] == filename]["New Label"].values.item()

        torch_data = torch.from_numpy(clip_data)

        return torch_data.float(), label

    def __len__(self):
        return len(self.data_paths)

In [7]:
train_ds = HeartData(train_mat_paths)
valid_ds = HeartData(valid_mat_paths)

In [8]:
train_dl = DataLoader(train_ds, batch_size = 64, shuffle = True, pin_memory = True, num_workers = 48)
valid_dl = DataLoader(valid_ds, batch_size = 64, shuffle = True, pin_memory = True, num_workers = 48)

In [9]:
class Auto_Encoder(nn.Module):
    def __init__(self):
        super(Auto_Encoder, self).__init__()

        self.enc = nn.Sequential(    
            # 12 x 3000
            nn.Conv1d(12, 64, 4, 2, 1),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),
            # 64 x 1500
            nn.Conv1d(64, 64, 3, 2, 1),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),
            # 64 x 750
            nn.Conv1d(64, 128, 3, 2, 1),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),
            # 128 x 375
            nn.Conv1d(128, 128, 3, 2, 1),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),  
            # 128 x 188
        )
        
        self.cls = nn.Sequential(
            nn.Dropout(p=0.2, inplace=False),
            nn.Linear(128, out_features=34, bias=True)
        )

        self.dec = nn.Sequential(
            # 128 x 188
            nn.ConvTranspose1d(128, 128, 3, 2, 1),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),
            # 128 x 375
            nn.ConvTranspose1d(128, 64, 2, 2),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),
            # 64 x 750
            nn.ConvTranspose1d(64, 64, 3, 2, 1, 1),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2),
            # nn.ReLU(), #Laeky relu 0.2
            # 64 x 1500
            nn.ConvTranspose1d(64, 12, 4, 2, 1),
            # 12 x 3000
            nn.Sigmoid()
        )


    def forward(self, x):
        enc = self.enc(x) # 128 x 7 x 7 
        cls = self.cls(enc.mean(dim=-1))
        dec = self.dec(enc)
        
        return dec, cls

In [10]:
epoch = 100
lr = 0.001
best_acc = 0
best_ep = 0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu", index = 1)
model = Auto_Encoder().to(device)
optimizer = Adam(model.parameters(), lr=lr)
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=epoch)
loss_fn_cls = nn.CrossEntropyLoss()
loss_fn_sig = nn.MSELoss()

In [None]:
warnings.filterwarnings('ignore')
for e in range(epoch):
    model.train()
    print(f"Epoch: {e}")
    y_true_list = [] 
    pred_list = []
    batch_cnt = 0
    total_loss = 0
    correct = 0
    for batch, (train_sig, train_label) in tqdm(enumerate(train_dl)):
        batch_cnt = batch
        train_sig = train_sig.to(device)
        train_label = train_label.to(device)
        
        res_sig, pred_cls = model(train_sig)
        loss_cls = loss_fn_cls(pred_cls, train_label)
        loss_sig = loss_fn_sig(res_sig, train_sig)
        loss_tot = loss_cls + loss_sig
        
        optimizer.zero_grad()
        loss_tot.backward()
        optimizer.step()
        
        scheduler.step()
        
        total_loss += loss_tot.item()
        correct += (pred_cls.argmax(1) == train_label).type(torch.float).sum().item()
    
    total_loss /= batch_cnt
    correct /= len(train_dl.dataset)
    
    print(f"train loss: {total_loss} - train acc: {100*correct}")
    
    batch_cnt = 0
    val_total_loss = 0
    val_correct = 0
    model.eval()
    with torch.no_grad():
        for batch, (valid_sig, valid_label) in tqdm(enumerate(valid_dl)):
            batch_cnt = batch
            valid_sig = valid_sig.to(device)
            valid_label = valid_label.to(device)
            
            res_sig, pred_cls = model(valid_sig)
            loss_cls = loss_fn_cls(pred_cls, valid_label)
            loss_sig = loss_fn_sig(res_sig, valid_sig)
            loss_tot = loss_cls + loss_sig
            
            pred_pos = pred_cls.argmax(1)
            y_true_list.append(valid_label)
            pred_list.append(pred_pos)

            val_total_loss += loss_tot.item()
            val_correct += (pred_cls.argmax(1) == valid_label).type(torch.float).sum().item()
    
        val_total_loss /= batch_cnt
        val_correct /= len(valid_dl.dataset)
        if val_correct > best_acc:
            best_acc = val_correct
            best_ep = e
        
        print(f"valid loss: {val_total_loss} - valid acc: {100*val_correct}")

        
y_true = torch.cat(y_true_list).cpu().numpy()
pred = torch.cat(pred_list).cpu().numpy()

reports = classification_report(y_true, pred, output_dict=True)

print(reports)
print(f"Best acuracy: {best_acc} at epoch {best_ep}")

Epoch: 0


293it [00:06, 47.16it/s] 

train loss: 1.9791562379631278 - train acc: 67.19217614365112



33it [00:00, 48.66it/s]

valid loss: 2.2053504399955273 - valid acc: 69.375
Epoch: 1



293it [00:03, 88.31it/s] 

train loss: 1.5919805814142096 - train acc: 70.35057716973066



33it [00:00, 44.80it/s]

valid loss: 2.2181900180876255 - valid acc: 70.72115384615385
Epoch: 2



293it [00:04, 66.25it/s] 

train loss: 1.5155250540334884 - train acc: 71.44078666096622



33it [00:00, 51.25it/s]

valid loss: 2.7061340045183897 - valid acc: 60.24038461538461
Epoch: 3



293it [00:06, 45.85it/s]

train loss: 1.455223431121813 - train acc: 72.62719110731082



33it [00:00, 40.00it/s]

valid loss: 2.2795509174466133 - valid acc: 66.25
Epoch: 4



293it [00:06, 45.54it/s]

train loss: 1.4221444438173347 - train acc: 73.2150491663104



33it [00:00, 51.27it/s]

valid loss: 2.109616033732891 - valid acc: 71.58653846153847
Epoch: 5



293it [00:06, 46.13it/s]

train loss: 1.3938810057836035 - train acc: 73.74412141941



33it [00:00, 42.33it/s]

valid loss: 2.17787566781044 - valid acc: 71.875
Epoch: 6



293it [00:05, 52.51it/s]

train loss: 1.3766908431298113 - train acc: 73.995297135528



33it [00:00, 48.74it/s]

valid loss: 2.146686313673854 - valid acc: 74.13461538461539
Epoch: 7



293it [00:05, 49.85it/s]

train loss: 1.3422930193682239 - train acc: 74.5671227020094



33it [00:00, 53.76it/s]

valid loss: 2.1016198955476284 - valid acc: 74.61538461538461
Epoch: 8



293it [00:05, 53.89it/s]

train loss: 1.3452665426143229 - train acc: 74.85036340316374



33it [00:01, 27.35it/s]

valid loss: 2.1066941134631634 - valid acc: 71.875
Epoch: 9



293it [00:05, 50.76it/s]

train loss: 1.327454412024315 - train acc: 75.08016246259085



33it [00:00, 36.03it/s]

valid loss: 2.2020047195255756 - valid acc: 68.75
Epoch: 10



293it [00:06, 44.49it/s]

train loss: 1.3100688204373399 - train acc: 75.43822146216331



33it [00:00, 39.88it/s]

valid loss: 2.0877885408699512 - valid acc: 72.59615384615384
Epoch: 11



293it [00:06, 45.35it/s]

train loss: 1.2960280298370204 - train acc: 75.74818298418128



33it [00:00, 34.86it/s]

valid loss: 2.3172562774270773 - valid acc: 63.125
Epoch: 12



293it [00:07, 41.03it/s]

train loss: 1.2880591216356787 - train acc: 76.09020949123557



33it [00:00, 52.13it/s]

valid loss: 2.2608738392591476 - valid acc: 70.86538461538461
Epoch: 13



293it [00:06, 47.46it/s]

train loss: 1.2784640801279512 - train acc: 75.88713125267208



33it [00:00, 51.31it/s]

valid loss: 2.138380127027631 - valid acc: 67.9326923076923
Epoch: 14



293it [00:06, 44.91it/s]

train loss: 1.258454201972648 - train acc: 76.6673792218897



33it [00:00, 35.29it/s]

valid loss: 2.059717705473304 - valid acc: 73.5576923076923
Epoch: 15



293it [00:06, 46.48it/s] 

train loss: 1.2546118841799971 - train acc: 76.58187259512613



33it [00:00, 43.30it/s]

valid loss: 2.44201735034585 - valid acc: 58.12500000000001
Epoch: 16



293it [00:04, 65.63it/s] 

train loss: 1.2390846671306923 - train acc: 77.02543822146217



33it [00:00, 40.50it/s]

valid loss: 2.6802459321916103 - valid acc: 49.519230769230774
Epoch: 17



293it [00:04, 62.97it/s] 

train loss: 1.2402369016653871 - train acc: 76.68341171440787



33it [00:00, 43.81it/s]

valid loss: 2.0663449270650744 - valid acc: 77.45192307692308
Epoch: 18



293it [00:04, 66.51it/s] 

train loss: 1.2274861129587644 - train acc: 77.2392047883711



33it [00:00, 42.83it/s]

valid loss: 2.14088898897171 - valid acc: 76.39423076923076
Epoch: 19



293it [00:03, 86.64it/s] 

train loss: 1.2259962836357012 - train acc: 77.0307823856349



33it [00:00, 41.80it/s]

valid loss: 2.401182634755969 - valid acc: 69.1826923076923
Epoch: 20



293it [00:04, 72.70it/s] 

train loss: 1.2081279633184 - train acc: 77.61864044463445



33it [00:00, 38.13it/s]

valid loss: 2.0370675772428513 - valid acc: 74.66346153846153
Epoch: 21



293it [00:04, 61.14it/s] 

train loss: 1.2165423258118433 - train acc: 77.41021804189825



33it [00:00, 38.70it/s]

valid loss: 2.283893369138241 - valid acc: 71.0576923076923
Epoch: 22



293it [00:04, 70.66it/s] 

train loss: 1.2107710294323424 - train acc: 77.59191962377085



33it [00:01, 30.79it/s]

valid loss: 2.0436701104044914 - valid acc: 77.98076923076923
Epoch: 23



293it [00:04, 59.90it/s] 

train loss: 1.2010662613870347 - train acc: 78.11564771269774



33it [00:00, 42.69it/s]

valid loss: 2.184230636805296 - valid acc: 73.70192307692308
Epoch: 24



293it [00:04, 62.79it/s]

train loss: 1.1899263389306525 - train acc: 78.05686190679778



33it [00:00, 50.19it/s] 

valid loss: 2.217981733381748 - valid acc: 73.17307692307692
Epoch: 25



293it [00:05, 58.60it/s] 

train loss: 1.1902865371475482 - train acc: 78.01945275758871



33it [00:00, 44.14it/s]

valid loss: 2.263791171833873 - valid acc: 65.33653846153847
Epoch: 26



293it [00:05, 50.56it/s] 

train loss: 1.1919815789346826 - train acc: 78.03548525010689



33it [00:00, 49.57it/s]

valid loss: 2.1438838746398687 - valid acc: 75.14423076923077
Epoch: 27



293it [00:05, 53.54it/s]

train loss: 1.177036535045872 - train acc: 78.43095339888842



33it [00:00, 49.44it/s]

valid loss: 2.540796425193548 - valid acc: 53.46153846153846
Epoch: 28



293it [00:05, 51.55it/s]

train loss: 1.1748333777466866 - train acc: 78.50042753313382



33it [00:00, 51.14it/s]

valid loss: 2.272117040120065 - valid acc: 76.34615384615384
Epoch: 29



293it [00:05, 54.03it/s]

train loss: 1.170602997280147 - train acc: 78.42026507054297



33it [00:00, 53.20it/s]

valid loss: 2.9443919118493795 - valid acc: 76.49038461538461
Epoch: 30



293it [00:04, 68.99it/s] 

train loss: 1.1748641694653523 - train acc: 78.46836254809747



33it [00:00, 53.35it/s]

valid loss: 2.212433073669672 - valid acc: 78.26923076923077
Epoch: 31



293it [00:03, 87.94it/s] 

train loss: 1.163298785686493 - train acc: 78.79970072680634



33it [00:00, 40.83it/s]

valid loss: 2.1657086703926325 - valid acc: 76.82692307692308
Epoch: 32



293it [00:04, 60.36it/s] 

train loss: 1.162664407532509 - train acc: 78.97071398033347



33it [00:00, 34.52it/s]

valid loss: 2.0708397030830383 - valid acc: 76.39423076923076
Epoch: 33



293it [00:03, 78.67it/s] 

train loss: 1.1486506579469329 - train acc: 79.08294142796066



33it [00:00, 35.31it/s]

valid loss: 3.245583016425371 - valid acc: 38.70192307692308
Epoch: 34



293it [00:03, 87.51it/s] 

train loss: 1.1484370677642626 - train acc: 79.19516887558787



33it [00:00, 39.10it/s]

valid loss: 2.3809964936226606 - valid acc: 64.32692307692308
Epoch: 35



293it [00:03, 85.74it/s] 

train loss: 1.140346375434366 - train acc: 79.58529286019666



33it [00:00, 46.60it/s]

valid loss: 2.1309230476617813 - valid acc: 68.65384615384616
Epoch: 36



293it [00:04, 59.42it/s] 

train loss: 1.150606561401119 - train acc: 79.24326635314237



33it [00:00, 46.23it/s]

valid loss: 2.4986823480576277 - valid acc: 63.74999999999999
Epoch: 37



293it [00:03, 84.36it/s] 

train loss: 1.135467194122811 - train acc: 79.44100042753314



33it [00:00, 45.17it/s]

valid loss: 2.143334148451686 - valid acc: 77.45192307692308
Epoch: 38



293it [00:03, 86.12it/s] 

train loss: 1.1332636056085155 - train acc: 79.53719538264215



33it [00:00, 37.01it/s]

valid loss: 2.0091776847839355 - valid acc: 77.16346153846155
Epoch: 39



293it [00:04, 60.13it/s] 

train loss: 1.132152561761745 - train acc: 79.53719538264215



33it [00:00, 46.56it/s]

valid loss: 2.506824817508459 - valid acc: 57.45192307692307
Epoch: 40



293it [00:04, 72.38it/s] 

train loss: 1.1280969772232723 - train acc: 79.67614365113296



33it [00:00, 36.79it/s]


valid loss: 2.0433272356167436 - valid acc: 77.25961538461539
Epoch: 41


293it [00:03, 88.19it/s] 

train loss: 1.1212796448436502 - train acc: 79.83646857631467



33it [00:01, 30.70it/s]

valid loss: 2.151437832042575 - valid acc: 78.36538461538461
Epoch: 42



293it [00:04, 64.41it/s] 

train loss: 1.128409116133435 - train acc: 79.5692603676785



33it [00:00, 33.68it/s]

valid loss: 2.372211467474699 - valid acc: 68.5576923076923
Epoch: 43



293it [00:04, 59.53it/s] 

train loss: 1.1321521728006128 - train acc: 79.41427960666951



33it [00:00, 38.56it/s]

valid loss: 2.696209128946066 - valid acc: 60.192307692307686
Epoch: 44



293it [00:03, 86.90it/s] 

train loss: 1.107772357353609 - train acc: 80.27469003847798



