In [1]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import random
from glob import glob
from tqdm import tqdm
from scipy.io import loadmat

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import os, sys
from typing import *
import torch
import random
import copy
import warnings

In [2]:
data_dir = "/media/mountHDD3/data_storage/biomedical_data/ecg_data/SPH"
print(os.listdir(data_dir))

['metadata.csv', 'data_df.csv', 'data_df_no1.csv', 'records']


In [3]:
main_df = pd.read_csv(data_dir + "/data_df.csv")
main_df.shape

(20792, 4)

In [4]:
single_fns = main_df["File name"].values.tolist()
single_mat_paths = [data_dir + f"/records/{x}.h5" for x in single_fns]

In [5]:
ratio = [0.8, 0.1]

train_index = int(len(single_mat_paths)*ratio[0])
valid_index = int(len(single_mat_paths)*(ratio[0]+ratio[1]))

train_mat_paths = single_mat_paths[:train_index]
valid_mat_paths = single_mat_paths[valid_index:]

train_label = main_df
valid_label = main_df

In [6]:
import h5py

class HeartData(Dataset):
    def __init__(self, data_paths):
        self.data_paths = data_paths
        random.shuffle(self.data_paths)

    def __getitem__(self, idx):
        data_path = self.data_paths[idx]
        a = h5py.File(data_path, 'r')
        data_h5 = a['ecg']
        data = np.array(data_h5)
        clip_data = data[:, 0:3000]

        filename = data_path.split("/")[-1].split(".")[0]
        label = main_df[main_df["File name"] == filename]["New Label"].values.item()

        torch_data = torch.from_numpy(clip_data)

        return torch_data.float(), label

    def __len__(self):
        return len(self.data_paths)

In [7]:
train_ds = HeartData(train_mat_paths)
valid_ds = HeartData(valid_mat_paths)

# train_ds = ECG(train_data_paths, single_main_df)
# valid_ds = ECG(valid_data_paths, single_main_df)

In [8]:
train_dl = DataLoader(train_ds, batch_size = 64, shuffle = True, pin_memory = True, num_workers = 48)
valid_dl = DataLoader(valid_ds, batch_size = 64, shuffle = True, pin_memory = True, num_workers = 48)

In [9]:
class Auto_Encoder(nn.Module):
    def __init__(self):
        super(Auto_Encoder, self).__init__()

        self.enc = nn.Sequential(    
            # 12 x 3000
            nn.Conv1d(12, 64, 4, 2, 1),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),
            # 64 x 1500
            nn.Conv1d(64, 64, 3, 2, 1),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),
            # 64 x 750
            nn.Conv1d(64, 128, 3, 2, 1),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),
            # 128 x 375
            nn.Conv1d(128, 128, 3, 2, 1),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),  
            # 128 x 188
        )
        
        self.cls = nn.Sequential(
            nn.Dropout(p=0.2, inplace=False),
            nn.Linear(128, out_features=34, bias=True)
        )

        self.dec = nn.Sequential(
            # 128 x 188
            nn.ConvTranspose1d(128, 128, 3, 2, 1),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),
            # 128 x 375
            nn.ConvTranspose1d(128, 64, 2, 2),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2),
            # nn.ReLU(),
            # 64 x 750
            nn.ConvTranspose1d(64, 64, 3, 2, 1, 1),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2),
            # nn.ReLU(), #Laeky relu 0.2
            # 64 x 1500
            nn.ConvTranspose1d(64, 12, 4, 2, 1),
            # 12 x 3000
            nn.Sigmoid()
        )


    def forward(self, x):
        enc = self.enc(x) # 128 x 7 x 7 
        cls = self.cls(enc.mean(dim=-1))
        dec = self.dec(enc)
        
        return dec, cls

In [10]:
epoch = 150
lr = 0.0005
best_acc = 0
best_ep = 0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu", index = 1)
model = Auto_Encoder().to(device)
optimizer = Adam(model.parameters(), lr=lr)
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=epoch)
loss_fn_cls = nn.CrossEntropyLoss()
loss_fn_sig = nn.MSELoss()

In [None]:
warnings.filterwarnings('ignore')
for e in range(epoch):
    model.train()
    print(f"Epoch: {e}")
    y_true_list = [] 
    pred_list = []
    batch_cnt = 0
    total_loss = 0
    correct = 0
    for batch, (train_sig, train_label) in tqdm(enumerate(train_dl)):
        batch_cnt = batch
        train_sig = train_sig.to(device)
        train_label = train_label.to(device)
        
        res_sig, pred_cls = model(train_sig)
        loss_cls = loss_fn_cls(pred_cls, train_label)
        loss_sig = loss_fn_sig(res_sig, train_sig)
        loss_tot = loss_cls + loss_sig
        
        optimizer.zero_grad()
        loss_tot.backward()
        optimizer.step()
        
        scheduler.step()
        
        total_loss += loss_tot.item()
        correct += (pred_cls.argmax(1) == train_label).type(torch.float).sum().item()
    
    total_loss /= batch_cnt
    correct /= len(train_dl.dataset)
    
    print(f"train loss: {total_loss} - train acc: {100*correct}")
    
    batch_cnt = 0
    val_total_loss = 0
    val_correct = 0
    model.eval()
    with torch.no_grad():
        for batch, (valid_sig, valid_label) in tqdm(enumerate(valid_dl)):
            batch_cnt = batch
            valid_sig = valid_sig.to(device)
            valid_label = valid_label.to(device)
            
            res_sig, pred_cls = model(valid_sig)
            loss_cls = loss_fn_cls(pred_cls, valid_label)
            loss_sig = loss_fn_sig(res_sig, valid_sig)
            loss_tot = loss_cls + loss_sig
            
            pred_pos = pred_cls.argmax(1)
            y_true_list.append(valid_label)
            pred_list.append(pred_pos)

            val_total_loss += loss_tot.item()
            val_correct += (pred_cls.argmax(1) == valid_label).type(torch.float).sum().item()
    
        val_total_loss /= batch_cnt
        val_correct /= len(valid_dl.dataset)
        if val_correct > best_acc:
            best_acc = val_correct
            best_ep = e
        
        print(f"valid loss: {val_total_loss} - valid acc: {100*val_correct}")

        
y_true = torch.cat(y_true_list).cpu().numpy()
pred = torch.cat(pred_list).cpu().numpy()

reports = classification_report(y_true, pred, output_dict=True)

print(reports)
print(f"Best acuracy: {best_acc} at epoch {best_ep}")

Epoch: 0


260it [00:04, 57.96it/s] 

train loss: 2.4658435357583537 - train acc: 65.29188961702638



33it [00:00, 42.09it/s]

valid loss: 2.533148117363453 - valid acc: 68.02884615384616
Epoch: 1



260it [00:02, 93.22it/s] 

train loss: 1.7762429405823639 - train acc: 69.11561353934948



33it [00:01, 27.10it/s]

valid loss: 2.517861757427454 - valid acc: 63.84615384615384
Epoch: 2



260it [00:04, 60.94it/s]

train loss: 1.6652212195875102 - train acc: 70.3300667348043



33it [00:00, 43.21it/s]

valid loss: 2.186696257442236 - valid acc: 69.90384615384615
Epoch: 3



260it [00:02, 88.96it/s] 

train loss: 1.6044101809442735 - train acc: 71.11765766849035



33it [00:00, 47.96it/s]

valid loss: 2.3867783062160015 - valid acc: 63.74999999999999
Epoch: 4



260it [00:03, 70.19it/s] 

train loss: 1.5607786737813913 - train acc: 71.76696927794144



33it [00:00, 56.18it/s]

valid loss: 2.622519712895155 - valid acc: 50.0
Epoch: 5



260it [00:03, 80.81it/s] 

train loss: 1.5248639588190323 - train acc: 72.38020802020081



33it [00:00, 34.15it/s]

valid loss: 2.581562675535679 - valid acc: 53.55769230769231
Epoch: 6



260it [00:04, 62.08it/s] 

train loss: 1.4880057805753584 - train acc: 73.07160464137557



33it [00:00, 37.67it/s]

valid loss: 2.170233516022563 - valid acc: 70.76923076923077
Epoch: 7



260it [00:04, 60.79it/s] 

train loss: 1.4700661932179366 - train acc: 73.33613900078157



33it [00:01, 29.20it/s]

valid loss: 2.4378597624599934 - valid acc: 57.11538461538461
Epoch: 8



260it [00:04, 52.84it/s] 

train loss: 1.4533117968143183 - train acc: 73.85318343052967



33it [00:00, 40.62it/s]

valid loss: 2.3967596627771854 - valid acc: 61.97115384615385
Epoch: 9



260it [00:04, 54.30it/s] 

train loss: 1.4347063586058304 - train acc: 74.22593639151086



33it [00:01, 25.63it/s]

valid loss: 2.233720326796174 - valid acc: 66.4423076923077
Epoch: 10



260it [00:05, 48.19it/s] 

train loss: 1.4275999618082893 - train acc: 74.09366921180785



33it [00:00, 44.33it/s]

valid loss: 2.5600553564727306 - valid acc: 55.28846153846154
Epoch: 11



260it [00:05, 44.81it/s]

train loss: 1.4108484009978395 - train acc: 74.41231287200144



33it [00:00, 65.75it/s] 

valid loss: 1.9884002450853586 - valid acc: 74.71153846153847
Epoch: 12



260it [00:05, 44.72it/s]

train loss: 1.3950401001455242 - train acc: 74.79107797751458



33it [00:01, 30.37it/s]

valid loss: 2.229693081229925 - valid acc: 67.78846153846155
Epoch: 13



260it [00:06, 39.12it/s]

train loss: 1.384788457491223 - train acc: 75.19389166115553



33it [00:00, 52.06it/s]

valid loss: 2.176446845754981 - valid acc: 71.20192307692308
Epoch: 14



260it [00:05, 46.74it/s]

train loss: 1.3753515569399684 - train acc: 75.41634100883785



33it [00:00, 51.14it/s]

valid loss: 2.1053544003516436 - valid acc: 72.74038461538461
Epoch: 15



260it [00:03, 80.31it/s] 

train loss: 1.3654062237058366 - train acc: 75.53057175494499



33it [00:00, 33.37it/s]

valid loss: 2.3440801315009594 - valid acc: 70.67307692307693
Epoch: 16



260it [00:03, 73.82it/s] 

train loss: 1.3528312353784053 - train acc: 75.72296037996753



33it [00:01, 30.92it/s]

valid loss: 2.1030988581478596 - valid acc: 71.58653846153847
Epoch: 17



260it [00:03, 76.14it/s] 

train loss: 1.3523254559076892 - train acc: 75.82516683701076



33it [00:01, 30.69it/s]

valid loss: 2.0588407907634974 - valid acc: 75.4326923076923
Epoch: 18



260it [00:03, 76.32it/s] 

train loss: 1.3470988480740993 - train acc: 75.8191546924788



33it [00:00, 36.83it/s]

valid loss: 2.1142521277070045 - valid acc: 76.20192307692307
Epoch: 19



260it [00:03, 76.17it/s] 

train loss: 1.3293880677591419 - train acc: 76.27607767690735



33it [00:00, 40.04it/s]

valid loss: 2.0962680149823427 - valid acc: 75.96153846153845
Epoch: 20



260it [00:03, 74.96it/s] 

train loss: 1.325106294689031 - train acc: 76.37227198941862



33it [00:00, 36.25it/s]

valid loss: 2.558605959638953 - valid acc: 60.28846153846153
Epoch: 21



260it [00:03, 86.50it/s] 

train loss: 1.3106731070514812 - train acc: 76.73901280586786



33it [00:01, 19.72it/s]

valid loss: 2.1944606117904186 - valid acc: 75.0
Epoch: 22



260it [00:03, 73.12it/s] 


train loss: 1.303348236217462 - train acc: 76.8772921301028


33it [00:00, 60.77it/s]

valid loss: 2.0487102773040533 - valid acc: 74.75961538461539
Epoch: 23



260it [00:03, 75.00it/s] 

train loss: 1.3031399544601736 - train acc: 76.72097637227199



33it [00:00, 58.13it/s]

valid loss: 1.9472338641062379 - valid acc: 75.4326923076923
Epoch: 24



260it [00:03, 65.26it/s] 

train loss: 1.2997525890131254 - train acc: 76.9314014308904



33it [00:00, 57.33it/s]

valid loss: 2.3715789541602135 - valid acc: 67.54807692307693
Epoch: 25



260it [00:04, 59.86it/s] 

train loss: 1.2945726726736342 - train acc: 77.08771718872121



33it [00:00, 64.41it/s]

valid loss: 3.601867198944092 - valid acc: 25.384615384615383
Epoch: 26



260it [00:03, 68.30it/s] 

train loss: 1.3036234115311538 - train acc: 76.79312210665545



33it [00:00, 43.12it/s] 

valid loss: 2.9766896925866604 - valid acc: 42.83653846153846
Epoch: 27



260it [00:04, 53.32it/s] 

train loss: 1.2781462459950834 - train acc: 77.55065231768171



33it [00:00, 57.29it/s]


valid loss: 2.4439509473741055 - valid acc: 62.16346153846154
Epoch: 28


260it [00:04, 52.73it/s] 

train loss: 1.272721554436739 - train acc: 77.55666446221367



33it [00:01, 22.00it/s]

valid loss: 2.098497238010168 - valid acc: 75.625
Epoch: 29



260it [00:05, 48.12it/s]

train loss: 1.2648028582909854 - train acc: 77.76107737630012



33it [00:00, 54.23it/s]

valid loss: 2.013218779116869 - valid acc: 76.29807692307692
Epoch: 30



260it [00:07, 34.66it/s]

train loss: 1.263376371280567 - train acc: 77.91138098959898



33it [00:00, 43.83it/s]

valid loss: 2.0062202531844378 - valid acc: 76.82692307692308
Epoch: 31



260it [00:05, 46.49it/s]

train loss: 1.2645579159950198 - train acc: 77.64684663019298



33it [00:00, 56.59it/s] 

valid loss: 2.0680897179991007 - valid acc: 72.78846153846153
Epoch: 32



260it [00:05, 49.17it/s]

train loss: 1.2607044391190223 - train acc: 77.7911380989599



33it [00:00, 49.71it/s]

valid loss: 1.9709703624248505 - valid acc: 77.78846153846153
Epoch: 33



260it [00:05, 43.67it/s]

train loss: 1.2553755179565385 - train acc: 77.89334455600313



33it [00:00, 35.64it/s]

valid loss: 2.055208148434758 - valid acc: 74.375
Epoch: 34



260it [00:04, 52.04it/s]

train loss: 1.2505605752403672 - train acc: 78.00757530211027



33it [00:00, 38.19it/s]

valid loss: 2.9558896720409393 - valid acc: 50.817307692307686
Epoch: 35



260it [00:04, 56.07it/s]

train loss: 1.2385829126742816 - train acc: 78.56069259905009



33it [00:00, 41.79it/s]

valid loss: 2.3703319542109966 - valid acc: 71.77884615384616
Epoch: 36



260it [00:05, 45.14it/s]

train loss: 1.2404166625733541 - train acc: 78.3322311068358



33it [00:00, 42.29it/s]

valid loss: 2.0186740159988403 - valid acc: 76.34615384615384
Epoch: 37



260it [00:06, 37.80it/s]

train loss: 1.2453111911832595 - train acc: 78.39235255215536



33it [00:00, 51.06it/s]

valid loss: 2.4222942385822535 - valid acc: 70.52884615384616
Epoch: 38



260it [00:05, 50.78it/s]

train loss: 1.24590398745187 - train acc: 78.16389105994108



33it [00:00, 42.75it/s]

valid loss: 2.0354508757591248 - valid acc: 76.4423076923077
Epoch: 39



260it [00:05, 49.61it/s]

train loss: 1.2349585418995743 - train acc: 78.44646185294295



33it [00:00, 38.11it/s]

valid loss: 3.337994832545519 - valid acc: 36.97115384615385
Epoch: 40



260it [00:05, 44.22it/s]

train loss: 1.2299846216288313 - train acc: 78.39235255215536



33it [00:01, 19.62it/s]

valid loss: 2.1015727631747723 - valid acc: 75.38461538461539
Epoch: 41



260it [00:05, 44.07it/s]

train loss: 1.2238580194449333 - train acc: 79.02362772801058



33it [00:01, 30.89it/s]

valid loss: 1.9598903274163604 - valid acc: 77.9326923076923
Epoch: 42



260it [00:06, 42.86it/s]

train loss: 1.2205938264209792 - train acc: 78.84326339205194



33it [00:01, 30.80it/s]

valid loss: 1.9280921397730708 - valid acc: 76.58653846153847
Epoch: 43



260it [00:04, 54.11it/s]

train loss: 1.2126497323678727 - train acc: 79.10779775145794



33it [00:01, 30.63it/s]

valid loss: 2.169828973710537 - valid acc: 76.73076923076924
Epoch: 44



260it [00:05, 44.00it/s]

train loss: 1.2131677825708647 - train acc: 78.90338483737149



33it [00:01, 22.91it/s]

valid loss: 2.8770491369068623 - valid acc: 52.16346153846154
Epoch: 45



260it [00:04, 54.20it/s] 

train loss: 1.213739941939424 - train acc: 78.99356700535081



