In [8]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import random
from glob import glob
from tqdm import tqdm
from scipy.io import loadmat

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from datetime import datetime
from sklearn.metrics import classification_report
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import roc_auc_score
from sklearn import metrics
import json

In [11]:
data_dir = "/media/mountHDD3/data_storage/biomedical_data/ecg_data/SPH"
print(os.listdir(data_dir))

['metadata.csv', 'data_df.csv', 'records']


In [12]:
main_df = pd.read_csv(data_dir + "/data_df.csv")
main_df.shape

(20792, 4)

In [13]:
# main_df = pd.read_csv(data_dir + "/data_df.csv")
# main_df.shape

In [16]:
# single_fns = single_main_df["Recording"].values.tolist()
# print(len(single_fns))

In [17]:
single_fns = main_df["File name"].values.tolist()
single_mat_paths = [data_dir + f"/records/{x}.h5" for x in single_fns]

In [18]:
names = np.array(single_mat_paths)
label = np.array(main_df["New Label"].values)
print(type(label[0]))

<class 'numpy.int64'>


In [19]:
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.05, random_state=42)
splits = sss.split(names, label)

for tr_idxs, vl_idxs in splits:
    pass
    
tr_names, tr_label = names[tr_idxs].tolist(), label[tr_idxs].tolist()
vl_names, vl_label = names[vl_idxs].tolist(), label[vl_idxs].tolist()

tr_idx_0 = np.argwhere(np.array(tr_label) == 0).flatten().tolist()
tr_idx_1 = np.argwhere(np.array(tr_label) == 1).flatten().tolist()

vl_idx_0 = np.argwhere(np.array(vl_label) == 0).flatten().tolist()
vl_idx_1 = np.argwhere(np.array(vl_label) == 1).flatten().tolist()

batch_sampler = WeightedRandomSampler(weights=(0.8, 0.2), num_samples=len(tr_names))

assert len(tr_names) == len(tr_label)
assert len(vl_names) == len(vl_label)

print(f'[INFO]: Found {len(tr_names)} training samples, {len(vl_names)} validation samples')
print(f'[INFO]: {len(tr_idx_0)} 0s, {len(tr_idx_1)} 1s in train')
print(f'[INFO]: {len(vl_idx_0)} 0s, {len(vl_idx_1)} 1s in valid')

[INFO]: Found 19752 training samples, 1040 validation samples
[INFO]: 13208 0s, 366 1s in train
[INFO]: 695 0s, 19 1s in valid


In [20]:
sample_sig = torch.randn(1, 12, 32)
conv_test = nn.Conv1d(12, 12, 3, 1, 1)
print(conv_test(sample_sig).shape)

torch.Size([1, 12, 32])


In [21]:
class BasicBlock(nn.Module):
    def __init__(self, channel_num):
        super(BasicBlock, self).__init__()
        self.conv_block1 = nn.Sequential(
			nn.Conv1d(channel_num, channel_num, 3, padding=1),
			nn.BatchNorm1d(channel_num),
			nn.ReLU(),
		)
        self.conv_block2 = nn.Sequential(
			nn.Conv1d(channel_num, channel_num, 3, padding=1),
			nn.BatchNorm1d(channel_num),
		)
        self.relu = nn.ReLU()
        torch.nn.init.kaiming_normal_(self.conv_block1[0].weight)
        torch.nn.init.kaiming_normal_(self.conv_block2[0].weight)
        
    def forward(self, x):
        residual = x
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = x + residual
        out = self.relu(x)
        return out

In [22]:
test_basic_block = BasicBlock(2)
sample_sig = torch.randn(1, 2, 32)
print(test_basic_block(sample_sig).shape)

torch.Size([1, 2, 32])


In [23]:
class ResNet(nn.Module):
    def __init__(self, in_channels = 12, type = 18, num_classes = 22):
        super(ResNet, self).__init__()
        self.struc_dict = {
            18: {
                "num_channels" : [64, 128, 256, 512],
                "counts" : [2, 2, 2, 2]
            }
        }
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=2)
        torch.nn.init.kaiming_normal_(self.conv1.weight)
        self.max1 = nn.MaxPool1d(kernel_size=3, stride=2)
        self.main = nn.Sequential()
        for idx, struc in enumerate(
            zip(
                self.struc_dict[type]["num_channels"], 
                self.struc_dict[type]["counts"]
            )
        ):
            num_channel, cnt = struc
            for i in range(cnt):
                self.main.add_module(f"conv{idx+1}_{i}", BasicBlock(num_channel))
            if idx < len(self.struc_dict[type]["num_channels"]) - 1:
                self.main.add_module(f"ext_{idx}", nn.Conv1d(num_channel, self.struc_dict[type]["num_channels"][idx+1], 3, 1))
                self.main.add_module(f"extbn_{idx}", nn.BatchNorm1d(self.struc_dict[type]["num_channels"][idx+1]))
                                     
        self.avg = torch.nn.AdaptiveAvgPool1d((1))
        self.lin = nn.Linear(self.struc_dict[type]["num_channels"][-1], num_classes)
        torch.nn.init.kaiming_normal_(self.lin.weight)
    def forward(self, x):
        x = self.conv1(x)
        x = self.max1(x)
        x = self.main(x)
        x = self.avg(x)
        x = x.reshape(x.shape[0], -1)
        x = self.lin(x)
        return x

In [24]:
model = ResNet()
sample_sig = torch.randn(1, 12, 2500)
model(sample_sig).shape

torch.Size([1, 22])

In [25]:
import h5py

class ECG(Dataset):
    def __init__(self, data_paths, label_df):
        self.data_paths = data_paths
        random.shuffle(self.data_paths)
        self.label_df = label_df

    def __getitem__(self, idx):
        data_path = self.data_paths[idx]
        a = h5py.File(data_path, 'r')
        data_h5 = a['ecg']
        data = np.array(data_h5)
        clip_data = data[:, 500:3000]

        # filename = data_path.split("/")[-1].split(".")[0]
        # label = self.label_df[self.label_df["File name"] == filename]["New Label"].values.item()
        label = self.label_df[idx]

        torch_data = torch.from_numpy(clip_data)

        return torch_data.float(), label

    def __len__(self):
        return len(self.data_paths)

In [26]:
check_ds = ECG(tr_names, tr_label)
sample, lbl = check_ds[0]
print(sample.shape, lbl)

torch.Size([12, 2500]) 11


In [27]:
model(sample.unsqueeze(dim=0)).shape

torch.Size([1, 22])

In [28]:
train_ds = ECG(tr_names, tr_label)
valid_ds = ECG(vl_names, vl_label)

print(len(train_ds))
print(len(valid_ds))

19752
1040


In [29]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu", index = 0)
# batch_size = 32
batch_size = 64

traindl = DataLoader(
    train_ds,
    batch_size=batch_size, 
    shuffle=True, 
    pin_memory=True, 
    num_workers=os.cpu_count()//2
)

validdl = DataLoader(
    valid_ds,
    batch_size=1, 
    shuffle=True, 
    pin_memory=True, 
    num_workers=os.cpu_count()//2
)

print(len(traindl))
print(len(validdl))

309
1040


In [30]:
class FocalClassifierV0(nn.Module):
    def __init__(self, gamma=0.1): #Change gamma value here in order to acquire other results
        super().__init__()
        
        self.gamma = gamma
        self.act = nn.LogSoftmax(dim=1)

    
    def forward(self, pred, target):

        logits = self.act(pred)

        B, C = tuple(logits.size())

        entropy = torch.pow(1 - logits, self.gamma) * logits * F.one_hot(target, num_classes=C).float()

        return (-1 / B) * torch.sum(entropy)

focalloss_fn = FocalClassifierV0()

In [31]:
epoch = 150
lr = 0.0005
model.to(device)
optimizer = Adam(model.parameters(), lr=lr)
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=epoch)
loss_fn = nn.CrossEntropyLoss()
best_acc = 0
best_ep = 0

In [32]:
for e in range(epoch):
    model.train()
    print(f"Epoch: {e}")
    batch_cnt = 0
    total_loss = 0
    correct = 0
    for batch, (train_sig, train_label) in tqdm(enumerate(traindl)):
        batch_cnt = batch
        train_sig = train_sig.to(device)
        train_label = train_label.to(device)
        
        pred = model(train_sig)
        loss = focalloss_fn(pred, train_label)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        scheduler.step()
        
        total_loss += loss.item()
        correct += (pred.argmax(1) == train_label).type(torch.float).sum().item()
    
    total_loss /= batch_cnt
    correct /= len(traindl.dataset)
    
    print(f"train loss: {total_loss} - train acc: {100*correct}")
    
    batch_cnt = 0
    val_total_loss = 0
    val_correct = 0
    model.eval()
    with torch.no_grad():
        for batch, (valid_sig, valid_label) in tqdm(enumerate(validdl)):
            batch_cnt = batch
            valid_sig = valid_sig.to(device)
            valid_label = valid_label.to(device)
            
            pred = model(valid_sig)
            loss = loss_fn(pred, valid_label)
            
            val_total_loss += loss.item()
            val_correct += (pred.argmax(1) == valid_label).type(torch.float).sum().item()
    
        val_total_loss /= batch_cnt
        val_correct /= len(validdl.dataset)
        
        if val_correct > best_acc:
            best_acc = val_correct
            best_ep = e
        
        print(f"valid loss: {val_total_loss} - valid acc: {100*val_correct}")
        
print(f"Best acuracy: {best_acc} at epoch {best_ep}")

Epoch: 0


309it [00:42,  7.30it/s]

train loss: 1.6981357492796787 - train acc: 66.44896719319563



1040it [00:03, 301.32it/s]

valid loss: 1.4561169335369426 - valid acc: 66.82692307692307
Epoch: 1



228it [00:33,  6.90it/s]


KeyboardInterrupt: 