In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import os
import wfdb
from tqdm import tqdm
tqdm.pandas()

import warnings
warnings.filterwarnings('ignore')

fs = 500
n_leads = 12
units = 'mV'
lead_names = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

drop_cols = ['fs', 'N_leads', 'mV', 'lead_names', 'units' ] #1 unique value per df

In [2]:
metadata_df = pd.read_csv('./metadata.csv')
metadata_df.drop(columns = metadata_df.columns.intersection(drop_cols), inplace = True)

metadata_df

Unnamed: 0,file,duration_sec,sex,Age,Age group
0,./data/A0002,10.0,Female,49,40to49
1,./data/A0016,20.0,Female,14,1to17
2,./data/A0020,17.0,Female,27,18to29
3,./data/A0029,10.0,Female,35,30to39
4,./data/A0030,12.0,Male,46,40to49
...,...,...,...,...,...
14890,./data/JS41476,10.0,Female,38,30to39
14891,./data/JS41477,10.0,Male,43,40to49
14892,./data/JS41478,10.0,Female,52,50to59
14893,./data/JS41479,10.0,Female,48,40to49


In [3]:
group_to_num = {'1to17':0,'18to29':1, '30to39':2, '40to49':3, '50to59':4, '60to69':5, '70to89':6, '90+':7 }
sex_to_num = {"Female":0, "Male":1}

group_to_num_reversed = {value: key for key,value in group_to_num.items()}
sex_to_num_reversed = {value: key for key,value in sex_to_num.items()}

group_to_num, sex_to_num

({'1to17': 0,
  '18to29': 1,
  '30to39': 2,
  '40to49': 3,
  '50to59': 4,
  '60to69': 5,
  '70to89': 6,
  '90+': 7},
 {'Female': 0, 'Male': 1})

In [4]:
metadata_df['Age group'] = metadata_df['Age group'].apply(lambda x: group_to_num[x])
metadata_df['sex'] = metadata_df['sex'].apply(lambda x: sex_to_num[x])

In [5]:
metadata_df

Unnamed: 0,file,duration_sec,sex,Age,Age group
0,./data/A0002,10.0,0,49,3
1,./data/A0016,20.0,0,14,0
2,./data/A0020,17.0,0,27,1
3,./data/A0029,10.0,0,35,2
4,./data/A0030,12.0,1,46,3
...,...,...,...,...,...
14890,./data/JS41476,10.0,0,38,2
14891,./data/JS41477,10.0,1,43,3
14892,./data/JS41478,10.0,0,52,4
14893,./data/JS41479,10.0,0,48,3


In [6]:
from sklearn.model_selection import train_test_split

train_csv, test_csv, _, _ = train_test_split(metadata_df, metadata_df['Age'], stratify = metadata_df['Age group'],
                                            test_size=0.1)
train_csv.shape, test_csv.shape

((13405, 5), (1490, 5))

In [7]:
import torch
import torchvision.transforms.functional as tf
from torch.utils.data import Dataset, DataLoader
from sklearn.utils import shuffle

class EcgDataset(Dataset):
    def __init__(self, csv):
        super().__init__()
        self.csv = shuffle(csv).values
    def __len__(self):
        return len(self.csv)
    def __getitem__(self, idx):
        record = self.csv[idx]
        file = record[0]
        signal_arr = wfdb.rdsamp(file)[0]
        age, age_group, sex = record[-2], record[-1], record[-3]
        signal_arr = torch.tensor(signal_arr.T, dtype=torch.float32)
        return signal_arr, age, age_group, sex

train_dataset = EcgDataset(train_csv)
test_dataset = EcgDataset(test_csv)

#train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
#test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

In [8]:
from torch_ecg.utils.utils_nn import adjust_cnn_filter_lengths
from torch_ecg.model_configs import ECG_CRNN_CONFIG
from torch_ecg.models.ecg_crnn import ECG_CRNN

config = adjust_cnn_filter_lengths(ECG_CRNN_CONFIG, fs=500)
# change the default CNN backbone
# bottleneck with global context attention variant of Nature Communications ResNet
config.cnn.name="resnet_nature_comm_bottle_neck_gc"


n_leads = 12
classes = [0,1,2,3,4,5,6,7]
model = ECG_CRNN(classes, n_leads, config)
#model.clf.lin_1 = torch.nn.Identity()
#model.sigmoid = torch.nn.Identity()
#model.softmax = torch.nn.Identity()

model(train_dataset[0][0].unsqueeze(0)).shape  # signal length 4000, batch size 2

torch.Size([1, 8])

In [9]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

In [10]:
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from copy import deepcopy

best_acc = 0
best_model = None

for epoch in range(3):
    total_loss = 0.0
    total_acc = 0.0
    train_dataset = EcgDataset(train_csv)
    model.train()
    for signal, age, age_group, sex in tqdm(train_dataset):
        optimizer.zero_grad()
        signal = signal.unsqueeze(0)
        outputs = model(signal)
        loss = criterion(outputs, torch.tensor([age_group]))
        loss.backward()
        optimizer.step()

        group_pred = outputs.squeeze().argmax()
        total_loss += loss.item()
        total_acc += (age_group==group_pred)
    print(f"train loss: {total_loss/len(train_dataset)}" )
    print(f"train accuracy: {total_acc/len(train_dataset)}" )
    
    model.eval()
    total_loss = 0.0
    y_true = []
    y_pred = []
    for signal, age, age_group, sex in tqdm(test_dataset):
        signal = signal.unsqueeze(0)
        outputs = model(signal)
        loss = criterion(outputs, torch.tensor([age_group]))
        group_pred = outputs.squeeze().argmax().item()
        total_loss += loss.item()
        y_true.append(age_group)
        y_pred.append(group_pred)
    print("TEST:")
    print(f"test loss: {total_loss/len(test_dataset)}" )
    print(confusion_matrix(y_true, y_pred))
    print(classification_report(y_true, y_pred, zero_division=1))
    
    test_acc = accuracy_score(y_true, y_pred)
    if test_acc>best_acc:
        best_acc = test_acc
        best_model = deepcopy(model)
        best_model.save("best_model.bin", train_config = config)

 70%|████████████████████████████████████████████████████▌                      | 9403/13405 [2:05:21<53:21,  1.25it/s]


KeyboardInterrupt: 

In [12]:
model.save("best_model.bin", train_config = config)

In [14]:
best_model = deepcopy(model)
best_model.eval()

ECG_CRNN(
  (cnn): ResNet(
    (input_stem): ResNetStem(
      (conv_0): Conv_Bn_Activation(
        (conv1d): Conv1d(12, 64, kernel_size=(17,), stride=(1,), padding=(8,))
        (batch_norm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation_ReLU): ReLU()
      )
    )
    (ResNetBottleNeck_0_0): ResNetBottleNeck(
      (shortcut): DownSample(
        (down_sample): Conv1d(64, 512, kernel_size=(1,), stride=(4,), bias=False)
        (batch_normalization): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (main_stream): Sequential(
        (cba_head): Conv_Bn_Activation(
          (conv1d): Conv1d(64, 128, kernel_size=(1,), stride=(1,), bias=False)
          (batch_norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activation_ReLU): ReLU(inplace=True)
        )
        (dropout_0): Dropout(p=0.2, inplace=False)
        (cba_neck): Conv_Bn_Activa

In [15]:
model.eval()
total_loss = 0.0
y_true = []
y_pred = []
for signal, age, age_group, sex in tqdm(test_dataset):
    signal = signal.unsqueeze(0)
    outputs = model(signal)
    loss = criterion(outputs, torch.tensor([age_group]))
    group_pred = outputs.squeeze().argmax().item()
    total_loss += loss.item()
    y_true.append(age_group)
    y_pred.append(group_pred)
print("TEST:")
print(f"test loss: {total_loss/len(test_dataset)}" )
print(confusion_matrix(y_true, y_pred))
print(classification_report(y_true, y_pred, zero_division=1))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1490/1490 [10:11<00:00,  2.44it/s]

TEST:
test loss: nan
[[ 33   0   0   0   0   0   0   0]
 [154   0   0   0   0   0   0   0]
 [186   0   0   0   0   0   0   0]
 [280   0   0   0   0   0   0   0]
 [343   0   0   0   0   0   0   0]
 [290   0   0   0   0   0   0   0]
 [203   0   0   0   0   0   0   0]
 [  1   0   0   0   0   0   0   0]]
              precision    recall  f1-score   support

           0       0.02      1.00      0.04        33
           1       1.00      0.00      0.00       154
           2       1.00      0.00      0.00       186
           3       1.00      0.00      0.00       280
           4       1.00      0.00      0.00       343
           5       1.00      0.00      0.00       290
           6       1.00      0.00      0.00       203
           7       1.00      0.00      0.00         1

    accuracy                           0.02      1490
   macro avg       0.88      0.12      0.01      1490
weighted avg       0.98      0.02      0.00      1490




