In [1]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import random
from glob import glob
from tqdm import tqdm
from scipy.io import loadmat
from scipy import signal

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from sklearn.metrics import classification_report

In [2]:
data_dir = "/media/mountHDD3/data_storage/biomedical_data/ecg_data/SPH"
print(os.listdir(data_dir))

['metadata.csv', 'data_df.csv', 'data_df_no1.csv', 'records']


In [3]:
main_df = pd.read_csv(data_dir + "/data_df_no1.csv")
main_df.shape

(20835, 4)

In [4]:
single_fns = main_df["File name"].values.tolist()
single_mat_paths = [data_dir + f"/records/{x}.h5" for x in single_fns]

In [5]:
ratio = [0.8, 0.1]

train_index = int(len(single_mat_paths)*ratio[0])
valid_index = int(len(single_mat_paths)*(ratio[0]+ratio[1]))

train_mat_paths = single_mat_paths[:train_index]
valid_mat_paths = single_mat_paths[valid_index:]

train_label = main_df
valid_label = main_df

In [6]:
import h5py

class HeartData(Dataset):
    def __init__(self, data_paths):
        self.data_paths = data_paths
        random.shuffle(self.data_paths)

    def __getitem__(self, idx):
        data_path = self.data_paths[idx]
        a = h5py.File(data_path, 'r')
        data_h5 = a['ecg']
        data = np.array(data_h5)
        clip_data = data[:, :2500]

        filename = data_path.split("/")[-1].split(".")[0]
        label = main_df[main_df["File name"] == filename]["New Label"].values.item()

        torch_data = torch.from_numpy(clip_data)

        return torch_data.float(), label

    def __len__(self):
        return len(self.data_paths)

In [7]:
train_ds = HeartData(train_mat_paths)
valid_ds = HeartData(valid_mat_paths)

In [8]:
class Teacher_Network(nn.Module):
    def __init__(self):
        super().__init__()

        # DownBlock 1
        self.down_block11 = nn.Sequential(
            nn.Conv1d(12, 64, 3, 1),
            nn.BatchNorm1d(64),
            nn.ReLU()
        )

        self.down_block12 = nn.Sequential(
            nn.Conv1d(64, 64, 3, 1),
            nn.BatchNorm1d(64),
            nn.ReLU()
        )

        self.mp = nn.MaxPool1d(2, stride=2)
    
        # DownBlock 2
        self.down_block21 = nn.Sequential(
            nn.Conv1d(64, 128, 3, 1, 2),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )

        self.down_block22 = nn.Sequential(
            nn.Conv1d(128, 128, 3, 1),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )
        
        # DownBlock 3
        self.down_block31 = nn.Sequential(
            nn.Conv1d(128, 256, 3, 1, 2),
            nn.BatchNorm1d(256),
            nn.ReLU()
        )

        self.down_block32 = nn.Sequential(
            nn.Conv1d(256, 256, 3, 1),
            nn.BatchNorm1d(256),
            nn.ReLU()
        )
        
        # DownBlock 4
        self.down_block41 = nn.Sequential(
            nn.Conv1d(256, 512, 3, 1, 2),
            nn.BatchNorm1d(512),
            nn.ReLU()
        )

        self.down_block42 = nn.Sequential(
            nn.Conv1d(512, 512, 3, 1),
            nn.BatchNorm1d(512),
            nn.ReLU()
        )

        # DownBlock 5
        self.down_block51 = nn.Sequential(
            nn.Conv1d(512, 1024, 3, 1, 2),
            nn.BatchNorm1d(1024),
            nn.ReLU()
        )

        self.down_block52 = nn.Sequential(
            nn.Conv1d(1024, 1024, 3, 1),
            nn.BatchNorm1d(1024),
            nn.ReLU()
        )

        # UpBlock 1
        self.up_block11 = nn.Upsample(scale_factor=2)
        self.up_block12 = nn.ConvTranspose1d(1024, 512, 3, 1, 1)

        self.relu1 = nn.ReLU()
        self.output1 = nn.Sequential(
            nn.Conv1d(512,512,1,1),
            nn.Sigmoid()
        )
        
        self.up_block13 = nn.Sequential(
            nn.Conv1d(512, 512, 3, 1,1),
            nn.BatchNorm1d(512),
            nn.ReLU()
        )

        # UpBlock 2
        self.up_block21 = nn.Upsample(scale_factor=2)
        self.up_block22 = nn.ConvTranspose1d(512, 256, 3, 1, 1)

        self.relu2 = nn.ReLU()
        self.output2 = nn.Sequential(
            nn.Conv1d(256,256,1,1),
            nn.Sigmoid()
        )

        self.up_block23 = nn.Sequential(
            nn.Conv1d(256, 256, 3, 1, 1),
            nn.BatchNorm1d(256),
            nn.ReLU()
        )

        # UpBlock 3
        self.up_block31 = nn.Upsample(scale_factor=2)
        self.up_block32 = nn.ConvTranspose1d(256, 128, 3, 1, 1)

        self.relu3 = nn.ReLU()
        self.output3 = nn.Sequential(
            nn.Conv1d(128,128,1,1),
            nn.Sigmoid()
        )

        self.up_block33 = nn.Sequential(
            nn.Conv1d(128, 128, 3, 1,1),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )

        # UpBlock 4
        self.up_block41 = nn.Upsample(scale_factor=2)
        self.up_block42 = nn.ConvTranspose1d(128, 64, 3, 1, 1)

        self.relu4 = nn.ReLU()
        self.output4 = nn.Sequential(
            nn.Conv1d(64,64,1,1),
            nn.Sigmoid()
        )
        
        self.up_block43 = nn.Sequential(
            nn.Conv1d(64, 64, 3, 1,1),
            nn.BatchNorm1d(64),
            nn.ReLU()
        )

        # OutBlock
        self.out = nn.Conv1d(64, 3, 1)
    
    def forward(self, x):
        x11 = self.down_block11(x)
        x12 = self.down_block12(x11)
        x13 = self.mp(x12)

        # x11 = self.down_block11(x)
        # x12 = self.down_block12(x11)
        # x13 = self.mp(x11+x12)
        
        x21 = self.down_block21(x13)
        x22 = self.down_block22(x21)
        x23 = self.mp(x22)
        
        x31 = self.down_block31(x23)
        x32 = self.down_block32(x31)
        x33 = self.mp(x32)
        
        x41 = self.down_block41(x33)
        x42 = self.down_block42(x41)
        x43 = self.mp(x42)
        
        x5 = self.down_block51(x43)
        x5 = self.down_block52(x5)

        x61 = self.up_block11(x5)
        x62 = self.up_block12(x61)
        x6 = self.relu1(x42 + x62)
        x6 = self.output1(x6)
        x6_att = x6 * x62
        x6 = self.up_block13(x6_att)
        x6 = self.up_block13(x6)
        
        x71 = self.up_block21(x6)
        x72 = self.up_block22(x71)
        x7 = self.relu2(x32 + x72)
        x7 = self.output2(x7)
        x7_att = x7 * x72
        x7 = self.up_block23(x7_att)
        x7 = self.up_block23(x7)

        x81 = self.up_block31(x7)
        x82 = self.up_block32(x81)
        x8 = self.relu3(x22 + x82)
        x8 = self.output3(x8)
        x8_att = x8 * x82
        x8 = self.up_block33(x8_att)
        x8 = self.up_block33(x8)

        x91 = self.up_block41(x8)
        x92 = self.up_block42(x91)
        x9 = self.relu4(x12 + x92)
        x9 = self.output4(x9)
        x9_att = x9 * x92
        x9 = self.up_block43(x9_att)
        x9 = self.up_block43(x9)

        x10 = self.out(x9)
        
        return x7_att, x8_att, x9_att

In [9]:
# test1 = torch.randn(1, 12, 2500)
# model1 = Teacher_Network()
# a = model1(test1)
# print(a.shape)

In [10]:
# train_ds = HeartData(train_mat_paths)
# valid_ds = HeartData(valid_mat_paths)

In [11]:
train_dl = DataLoader(train_ds, batch_size = 64, shuffle = True, pin_memory = True, num_workers = 48)
valid_dl = DataLoader(valid_ds, batch_size = 64, shuffle = True, pin_memory = True, num_workers = 48)

In [12]:
class Student_Network(nn.Module):
    def __init__(self):
        super().__init__()

        # Start
        self.conv_block1 = nn.Sequential(
            nn.Conv1d(12, 64, 3, 1, 0, 2),
            nn.BatchNorm1d(64),
            nn.ReLU()
        )
        
        # ConvBlock 1
        self.conv_block11 = nn.Sequential(
            nn.Conv1d(64, 64, 3, 1, 1),
            nn.BatchNorm1d(64),
            nn.ReLU()
        )
    
        self.conv_block12 = nn.Sequential(
            nn.Conv1d(64, 64, 3, 1, 1),
            nn.BatchNorm1d(64),
            nn.ReLU()
        )
        
        # ConvBlock 2
        self.conv_block21 = nn.Sequential(
            nn.Conv1d(64, 128, 3, 1, 2, 2),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )
    
        self.conv_block22 = nn.Sequential(
            nn.Conv1d(128, 128, 3, 1, 1),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )
    
        # Attention 1
        self.att_block11 = nn.Sequential(
            nn.Conv1d(64, 64, 1, 1),
            nn.BatchNorm1d(64)
        )
    
        # self.att_block12 = nn.Sequential(
        #     nn.Conv1d(128, 128, 1, 1),
        #     nn.BatchNorm1d(128)
        # )
            
        self.att_block13 = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(64, 128, 3, 1, 1),
            nn.BatchNorm1d(128),
            nn.Sigmoid()
        )

        self.att_block12 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ConvTranspose1d(128, 64, 3, 1, 1),
            nn.BatchNorm1d(64)
        )
        
        # ConvBlock 3
        self.conv_block31 = nn.Sequential(
            nn.Conv1d(128, 256, 3, 1, 1),
            nn.BatchNorm1d(256),
            nn.ReLU()
        )
    
        self.conv_block32 = nn.Sequential(
            nn.Conv1d(256, 256, 3, 1, 1),
            nn.BatchNorm1d(256),
            nn.ReLU()
        )
    
        # Attention 2
        self.att_block21 = nn.Sequential(
            nn.Conv1d(128, 128, 1, 1),
            nn.BatchNorm1d(128)
        )

        self.att_block22 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ConvTranspose1d(256, 128, 3, 1, 1),
            nn.BatchNorm1d(128)
        )
            
        self.att_block23 = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(128, 256, 3, 1, 1),
            nn.BatchNorm1d(256),
            nn.Sigmoid()
        )
    
        # ConvBlock 4
        self.conv_block41 = nn.Sequential(
            nn.Conv1d(256, 512, 3, 1, 1),
            nn.BatchNorm1d(512),
            nn.ReLU()
        )
    
        self.conv_block42 = nn.Sequential(
            nn.Conv1d(512, 512, 3, 1, 1),
            nn.BatchNorm1d(512),
            nn.ReLU()
        )
    
        # Attention 3
        self.att_block31 = nn.Sequential(
            nn.Conv1d(256, 256, 1, 1),
            nn.BatchNorm1d(256)
        )

        self.att_block32 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ConvTranspose1d(512, 256, 3, 1, 1),
            nn.BatchNorm1d(256)
        )
  
        self.att_block33 = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(256, 512, 3, 1, 1),
            nn.BatchNorm1d(512),
            nn.Sigmoid()
        )

        self.mp = nn.MaxPool1d(2, stride=2)
        # self.linear = nn.Linear(1)

        # Out
        self.out1 = nn.AdaptiveAvgPool1d(1)
        self.out2 = nn.Linear(512,31)
    
    def forward(self, x):
        x1 = self.conv_block1(x)
        x1 = self.conv_block11(x1)
        x1 = self.conv_block12(x1)
        x1 = self.conv_block11(x1)
        x1 = self.conv_block12(x1)
        x1 = self.conv_block11(x1)
        x1 = self.conv_block12(x1)

        x2 = self.conv_block21(x1)
        # print(x2.shape)
        x2 = self.conv_block22(x2)
        # print(x2.shape)
        x2 = self.conv_block22(x2)
        x2 = self.conv_block22(x2)
        x2 = self.conv_block22(x2)
        x2 = self.conv_block22(x2)
        x2 = self.conv_block22(x2)
        x2 = self.conv_block22(x2)
        # print(x2.shape)
        x2 = self.mp(x2)
        # print(x2.shape)

        x_att11 = self.att_block11(x1)
        # x_att12 = self.att_block12(x2)
        x_att12 = self.att_block12(x2)
        # print(x_att11.shape)
        # print(x_att12.shape)
        x_att1 = x_att11 + x_att12
        x_att1 = self.att_block13(x_att1)
        x_att1 = self.mp(x_att1)
        # print(x_att1.shape)
        x_att1 = x_att1 * x2
        # print(x_att1.shape)

        x3 = self.conv_block31(x_att1)
        # print(x3.shape)
        x3 = self.conv_block32(x3)
        x3 = self.conv_block32(x3)
        x3 = self.conv_block32(x3)
        x3 = self.conv_block32(x3)
        x3 = self.conv_block32(x3)
        x3 = self.conv_block32(x3)
        x3 = self.conv_block32(x3)
        x3 = self.conv_block32(x3)
        x3 = self.conv_block32(x3)
        x3 = self.conv_block32(x3)
        x3 = self.conv_block32(x3)
        x3 = self.mp(x3)
        # print(x3.shape)

        x_att21 = self.att_block21(x2)
        # print(x_att21.shape)
        x_att22 = self.att_block22(x3)
        # print(x_att22.shape)
        x_att2 = x_att21 + x_att22
        x_att2 = self.att_block23(x_att2)
        x_att2 = self.mp(x_att2)
        x_att2 = x_att2 * x3

        x4 = self.conv_block41(x_att2)
        # print(x4.shape)
        x4 = self.conv_block42(x4)
        x4 = self.conv_block42(x4)
        x4 = self.conv_block42(x4)
        x4 = self.conv_block42(x4)
        x4 = self.conv_block42(x4)
        x4 = self.mp(x4)
        # print(x4.shape)

        x_att31 = self.att_block31(x3)
        # print(x_att31.shape)
        x_att32 = self.att_block32(x4)
        # print(x_att32.shape)
        x_att3 = x_att31 + x_att32
        x_att3 = self.att_block33(x_att3)
        x_att3 = self.mp(x_att3)
        # print(x_att3.shape)
        x_att3 = x_att3 * x4

        x_out = self.out1(x_att3)
        x_out = self.out2(x_out.squeeze(-1))

        return x_out

In [13]:
test1 = torch.randn(1, 12, 2500)
model = Student_Network()
a = model(test1)
print(a.shape)

torch.Size([1, 31])


In [14]:
epoch = 150
lr = 0.0005
best_acc = 0
best_ep = 0

# model = HeartCTAT()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu", index = 1)
model.to(device)
optimizer = Adam(model.parameters(), lr=lr, betas = (0.9, 0.999), weight_decay = 0.001)
# scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=epoch)
loss_fn = nn.CrossEntropyLoss()

In [15]:
for e in range(epoch):
    model.train()
    print(f"Epoch: {e}")
    y_true_list = [] 
    pred_list = []
    batch_cnt = 0
    total_loss = 0
    correct = 0
    correct = 0
    for batch, (train_sig, train_label) in tqdm(enumerate(train_dl)):
        batch_cnt = batch
        train_sig = train_sig.to(device)
        train_label = train_label.to(device)
        
        pred = model(train_sig)
        loss = loss_fn(pred, train_label)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
#         scheduler.step()
        
        total_loss += loss.item()
        correct += (pred.argmax(1) == train_label).type(torch.float).sum().item()
    
    total_loss /= batch_cnt
    correct /= len(train_dl.dataset)
    
    print(f"train loss: {total_loss} - train acc: {100*correct}")
    
    batch_cnt = 0
    val_total_loss = 0
    val_correct = 0
    model.eval()
    with torch.no_grad():
        for batch, (valid_sig, valid_label) in tqdm(enumerate(valid_dl)):
            batch_cnt = batch
            valid_sig = valid_sig.to(device)
            valid_label = valid_label.to(device)
            
            pred = model(valid_sig)

            pred_pos = pred.argmax(1)
            y_true_list.append(valid_label)
            pred_list.append(pred_pos)
            
            loss = loss_fn(pred, valid_label)
            
            val_total_loss += loss.item()
            val_correct += (pred.argmax(1) == valid_label).type(torch.float).sum().item()
            
        val_total_loss /= batch_cnt
        val_correct /= len(valid_dl.dataset)
        if val_correct > best_acc:
            best_acc = val_correct
            best_ep = e
        
        print(f"valid loss: {val_total_loss} - valid acc: {100*val_correct}")
        
y_true = torch.cat(y_true_list).cpu().numpy()
pred = torch.cat(pred_list).cpu().numpy()

reports = classification_report(y_true, pred, output_dict=True) 

print(reports)
print(f"Best acuracy: {best_acc} at epoch {best_ep}")

Epoch: 0


261it [01:02,  4.17it/s]

train loss: 1.301960278245119 - train acc: 68.26253899688025



33it [00:15,  2.16it/s]

valid loss: 2.6641483083367348 - valid acc: 21.16122840690979
Epoch: 1



261it [00:52,  4.95it/s]

train loss: 1.0231786700395438 - train acc: 72.10823134149268



33it [00:02, 12.99it/s]

valid loss: 3.1822210773825645 - valid acc: 10.892514395393475
Epoch: 2



205it [00:41,  4.90it/s]


KeyboardInterrupt: 