In [None]:
import numpy as np
import wfdb
import pandas as pd
from glob import glob
import os
import biosppy
import pyhrv
import pyhrv.tools as tools
import json
import shap

In [2]:
import torch
import torch.nn as nn


class BasicBlock1d(nn.Module):
    """Residual block made of two kernel-size-7 1-D convolutions.

    The main path is conv -> batch norm -> ReLU -> dropout -> conv -> batch
    norm. The block input (optionally passed through ``downsample``) is added
    to the result before a final ReLU.
    """

    # Output channels are ``planes * expansion``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """Create the block.

        Args:
            inplanes: Number of input channels.
            planes: Number of output channels of both convolutions.
            stride: Stride of the first convolution.
            downsample: Optional module applied to the input so that it can
                be added to the main-path output; ``None`` uses the input
                as is.
        """
        super().__init__()
        # Attribute names and creation order are kept stable so that saved
        # state dicts and seeded initialisation stay compatible.
        self.conv1 = nn.Conv1d(
            inplanes, planes, kernel_size=7, stride=stride, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=0.2)
        self.conv2 = nn.Conv1d(
            planes, planes, kernel_size=7, stride=1, padding=3, bias=False
        )
        self.bn2 = nn.BatchNorm1d(planes)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(self.dropout(hidden)))
        shortcut = x if self.downsample is None else self.downsample(x)
        hidden += shortcut
        return self.relu(hidden)


class ResNet1d(nn.Module):
    """1-D ResNet whose pooled features are fused with an extra feature vector.

    The convolutional trunk is followed by global average and global max
    pooling. Both pooled vectors are concatenated with the ``hrv`` input and
    passed through a single linear layer that produces ``num_classes`` raw
    scores (no activation is applied here).
    """

    def __init__(self, block, layers, input_channels=12, inplanes=64,
                 num_classes=9, hrv_features=33):
        """Build the network.

        Args:
            block: Residual block class. It must expose an ``expansion``
                attribute and accept ``(inplanes, planes, stride, downsample)``.
            layers: Sequence of four ints, the number of blocks per stage.
            input_channels: Number of channels of the input signal.
            inplanes: Number of channels produced by the stem convolution.
            num_classes: Size of the output vector.
            hrv_features: Width of the ``hrv`` tensor given to ``forward``.
        """
        super().__init__()
        self.inplanes = inplanes
        self.conv1 = nn.Conv1d(input_channels, self.inplanes, kernel_size=15,
                               stride=2, padding=7, bias=False)
        self.bn1 = nn.BatchNorm1d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        # Use the block class that was passed in rather than a fixed one.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.adaptiveavgpool = nn.AdaptiveAvgPool1d(1)
        self.adaptivemaxpool = nn.AdaptiveMaxPool1d(1)
        # Factor 2: average-pooled and max-pooled features are concatenated.
        self.fc = nn.Linear(512 * block.expansion * 2 + hrv_features, num_classes)
        # NOTE(review): this dropout module is created but never used in forward().
        self.dropout = nn.Dropout(0.2)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first may change shape.

        A 1x1 convolution + batch norm shortcut is created when the first
        block changes the stride or the channel count. ``self.inplanes`` is
        updated to the stage's output channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv1d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x, hrv):
        """Return raw class scores.

        Args:
            x: Signal tensor fed to the stem ``Conv1d`` (batch, channels, steps).
            hrv: 2-D tensor (batch, hrv_features) concatenated with the pooled
                convolutional features along dim 1.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        pooled = torch.cat((self.adaptiveavgpool(x), self.adaptivemaxpool(x)), dim=1)
        pooled = pooled.view(pooled.size(0), -1)
        # The data loader yields hrv on the CPU; align device and dtype so the
        # concatenation and the linear layer do not fail on a mismatch.
        hrv = hrv.to(device=pooled.device, dtype=pooled.dtype)
        return self.fc(torch.cat((pooled, hrv), dim=1))


def resnet18(**kwargs):
    """Build a ``ResNet1d`` of ``BasicBlock1d`` blocks with stage depths 2-2-2-2.

    Keyword arguments are forwarded unchanged to the ``ResNet1d`` constructor.
    """
    stage_depths = [2, 2, 2, 2]
    return ResNet1d(BasicBlock1d, stage_depths, **kwargs)


def resnet34(**kwargs):
    """Build a ``ResNet1d`` of ``BasicBlock1d`` blocks with stage depths 3-4-6-3.

    Keyword arguments are forwarded unchanged to the ``ResNet1d`` constructor.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet1d(BasicBlock1d, stage_depths, **kwargs)

In [3]:
import os
import torch
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
import wfdb


def scaling(X, sigma=0.1):
    """Multiply every column of ``X`` by its own random factor.

    One factor per column is drawn from a normal distribution with mean 1.0
    and standard deviation ``sigma``; the same factor is applied to every row
    of that column. ``X`` itself is not modified.
    """
    n_rows, n_cols = X.shape[0], X.shape[1]
    factors = np.random.normal(loc=1.0, scale=sigma, size=(1, n_cols))
    # Repeat the single row of factors so it matches X row for row.
    per_sample = np.repeat(factors, n_rows, axis=0)
    return X * per_sample


def shift(sig, interval=20):
    """Return a copy of ``sig`` with a random constant offset added per column.

    For every column an integer is drawn uniformly from
    ``[-interval, interval)`` and divided by 1000 before being added to that
    column.

    Unlike an in-place ``+=`` on the argument, this neither modifies the
    caller's array nor fails on integer input: integer arrays are promoted
    to float, floating arrays keep their dtype.

    Args:
        sig: 2-D array-like; columns are shifted independently.
        interval: Half-width of the integer offset range.

    Returns:
        A new array with the same shape as ``sig``.
    """
    shifted = np.array(sig, copy=True)
    if not np.issubdtype(shifted.dtype, np.floating):
        shifted = shifted.astype(float)
    offsets = np.random.randint(-interval, interval, size=shifted.shape[1])
    # Broadcasts one offset per column over all rows.
    shifted += offsets / 1000
    return shifted


def transform(sig, train=False):
    """Optionally augment ``sig`` with ``scaling`` and then ``shift``.

    When ``train`` is false the input is returned untouched. Otherwise each
    augmentation is applied, in that order, only if a fresh standard-normal
    draw exceeds 0.5.
    """
    if not train:
        return sig
    # NOTE(review): randn() > 0.5 fires with probability ~0.31, not 0.5;
    # confirm whether a uniform draw (np.random.rand) was intended.
    for augment in (scaling, shift):
        if np.random.randn() > 0.5:
            sig = augment(sig)
    return sig


class ECGDataset(Dataset):
    """Dataset yielding ``(ecg, hrv_features, labels)`` tensors per record.

    Records are listed in ``label_csv`` (columns ``patient_id``, ``fold`` and
    one column per entry of ``classes``) and restricted to the given folds.
    Signals are read with ``wfdb.rdsamp`` from ``data_dir``; per-record
    features come from the ``hrv_features`` CSV, which must contain a
    ``patient_id`` column.
    """

    # Fixed number of time steps per sample (30 s at 500 Hz).
    SEQ_LEN = 15000

    def __init__(self, phase, data_dir, hrv_features, label_csv, folds, leads):
        """Load label and feature tables.

        Args:
            phase: ``'train'`` enables the random augmentation in
                ``__getitem__``; any other value disables it.
            data_dir: Directory containing the WFDB records.
            hrv_features: Path to the CSV of per-record features.
            label_csv: Path to the CSV with labels and fold assignment.
            folds: Fold ids to keep.
            leads: ``'all'`` or a collection of lead names to keep.
        """
        super().__init__()
        self.phase = phase
        self.data_dir = data_dir

        df = pd.read_csv(label_csv)
        df = df[df['fold'].isin(folds)]
        self.labels = df

        df_hrv = pd.read_csv(hrv_features)
        # Infinite feature values are replaced by a large finite constant.
        df_hrv.replace([np.inf, -np.inf], 100000, inplace=True)
        df_hrv = df_hrv.merge(df[['patient_id', 'fold']], on='patient_id')
        # Drops the first column and the trailing 'fold' column added by the
        # merge (presumably the first column is patient_id - TODO confirm).
        self.hrv_features = df_hrv[df_hrv.columns[1:-1]].to_numpy()
        # The merged rows follow the order of the HRV CSV, not of self.labels,
        # so features are looked up by patient id rather than by position.
        self._hrv_row = {pid: pos for pos, pid in enumerate(df_hrv['patient_id'])}

        self.leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
        selected = self.leads if leads == 'all' else leads
        # Indices come out in the canonical order of self.leads.
        self.use_leads = np.where(np.isin(self.leads, selected))[0]
        self.nleads = len(self.use_leads)
        self.classes = ['SNR', 'AF', 'IAVB', 'LBBB', 'RBBB', 'PAC', 'PVC', 'STD', 'STE']
        self.n_classes = len(self.classes)
        self.data_dict = {}
        self.label_dict = {}

    def _labels_for(self, patient_id, row):
        """Return the float32 label vector of ``row``, cached per patient."""
        labels = self.label_dict.get(patient_id)
        # Compare with None: the truth value of a multi-element array is
        # ambiguous and raises ValueError.
        if labels is None:
            labels = row[self.classes].to_numpy(dtype=np.float32)
            self.label_dict[patient_id] = labels
        return labels

    def _hrv_for(self, patient_id):
        """Return the feature row belonging to ``patient_id``.

        Raises:
            KeyError: If the feature CSV has no row for this patient.
        """
        try:
            return self.hrv_features[self._hrv_row[patient_id]]
        except KeyError:
            raise KeyError(f'no HRV features found for patient {patient_id!r}') from None

    def __getitem__(self, index: int):
        """Return ``(ecg, hrv, labels)`` for the record at ``index``.

        ``ecg`` has shape ``(nleads, SEQ_LEN)``: the last ``SEQ_LEN`` samples
        of the record, left-padded with zeros when the record is shorter.
        """
        row = self.labels.iloc[index]
        patient_id = row['patient_id']
        ecg_data, _ = wfdb.rdsamp(os.path.join(self.data_dir, patient_id))
        ecg_data = transform(ecg_data, self.phase == 'train')
        ecg_data = ecg_data[-self.SEQ_LEN:, self.use_leads]
        result = np.zeros((self.SEQ_LEN, self.nleads))
        # Right-align the signal; also well defined for an empty record.
        result[self.SEQ_LEN - ecg_data.shape[0]:, :] = ecg_data
        labels = self._labels_for(patient_id, row)
        hrv = self._hrv_for(patient_id)
        return (
            torch.from_numpy(result.transpose()).float(),
            torch.from_numpy(hrv).float(),
            torch.from_numpy(labels).float(),
        )

    def __len__(self):
        """Return the number of records in the selected folds."""
        return len(self.labels)

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from tqdm import tqdm
import numpy as np
from utils import cal_f1s, cal_aucs, split_data

def train(model, dataloader, criterion, epoch, scheduler, optimizer, device):
    """Run one optimisation pass over ``dataloader`` and print the summed loss.

    Args:
        model: Module called as ``model(data, hrv)``.
        dataloader: Iterable of ``(data, hrv, labels)`` batches.
        criterion: Loss called as ``criterion(output, labels)``.
        epoch: Epoch number, used for logging only.
        scheduler: Accepted for interface compatibility; it is not stepped here.
        optimizer: Optimiser updating the model parameters.
        device: Device on which every batch tensor is placed.
    """
    print('Training epoch %d:' % epoch)
    model.train()
    running_loss = 0
    for data, hrv, labels in tqdm(dataloader):
        # All three inputs must live on the model's device; leaving hrv on the
        # CPU breaks the feature concatenation when training on CUDA.
        data, hrv, labels = data.to(device), hrv.to(device), labels.to(device)
        optimizer.zero_grad()
        output = model(data, hrv)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # scheduler.step() is intentionally left disabled, as before.
    print('Loss: %.4f' % running_loss)
    

def evaluate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` and print loss, F1 and (maybe) AUC.

    Predictions are the sigmoid of the model output. Per-class F1 scores come
    from ``cal_f1s``. Reads the module-level names ``phase`` and
    ``model_path``: when ``phase == 'train'`` and the mean F1 is positive, the
    model weights are written to ``model_path``; otherwise per-class AUCs
    from ``cal_aucs`` are printed.

    NOTE(review): with ``phase == 'train'`` the checkpoint is overwritten on
    every call with a positive mean F1, not only when the score improves.

    Args:
        model: Module called as ``model(data, hrv)``.
        dataloader: Iterable of ``(data, hrv, labels)`` batches.
        criterion: Loss called as ``criterion(output, labels)``.
        device: Device on which every batch tensor is placed.
    """
    print('Validating...')
    model.eval()
    running_loss = 0
    output_list, labels_list = [], []
    # No gradients are needed for evaluation; skipping them saves memory.
    with torch.no_grad():
        for data, hrv, labels in tqdm(dataloader):
            data, hrv, labels = data.to(device), hrv.to(device), labels.to(device)
            output = model(data, hrv)
            running_loss += criterion(output, labels).item()
            output_list.append(torch.sigmoid(output).cpu().numpy())
            labels_list.append(labels.cpu().numpy())
    print('Loss: %.4f' % running_loss)
    y_trues = np.vstack(labels_list)
    y_scores = np.vstack(output_list)
    f1s = cal_f1s(y_trues, y_scores)
    avg_f1 = np.mean(f1s)
    print('F1s:', f1s)
    print('Avg F1: %.4f' % avg_f1)
    if phase == 'train' and avg_f1 > 0:
        torch.save(model.state_dict(), model_path)
    else:
        aucs = cal_aucs(y_trues, y_scores)
        avg_auc = np.mean(aucs)
        print('AUCs:', aucs)
        print('Avg AUC: %.4f' % avg_auc)


if __name__ == "__main__":
    # Run configuration. `phase` and `model_path` are also read as globals
    # by evaluate().
    data_dir = 'ecg-diagnosis/data/CPSC'
    leads = 'all'
    phase = 'train'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_path = 'models/ECG_HRV_resnet.pth'

    # 'all' means the 12 standard leads; otherwise `leads` is expected to be
    # a list of lead names (e.g. the result of args.leads.split(',')).
    nleads = 12 if leads == 'all' else len(leads)
    n_epochs = 40
    label_csv = 'labelx.csv'
    hrv_features = 'mean_hrv_features.csv'

    train_folds, val_folds, test_folds = split_data()

    train_dataset = ECGDataset('train', data_dir, hrv_features, label_csv, train_folds, leads)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)
    val_dataset = ECGDataset('val', data_dir, hrv_features, label_csv, val_folds, leads)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=True)
    test_dataset = ECGDataset('test', data_dir, hrv_features, label_csv, test_folds, leads)
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=True)
    model = resnet34(input_channels=nleads).to(device)
    #model = ECG_HRVCONVNet(input_leads=nleads, num_classes=9).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 10, gamma=0.1)
    criterion = nn.BCEWithLogitsLoss()

    if phase == 'train':
        # evaluate() writes the checkpoint to model_path; create its directory
        # up front so the first save cannot fail after a full epoch of work.
        os.makedirs(os.path.dirname(model_path) or '.', exist_ok=True)
        # if resume:
        #     model.load_state_dict(torch.load(model_path, map_location=device))
        for epoch in range(1, n_epochs+1):
            train(model, train_loader, criterion, epoch, scheduler, optimizer, device)
            evaluate(model, val_loader, criterion, device)

    else:
        model.load_state_dict(torch.load(model_path, map_location=device))
        evaluate(model, test_loader, criterion, device)

In [None]:
import argparse
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# from resnet import resnet34
# from dataset import ECGDataset
from utils import cal_scores, find_optimal_threshold, split_data
from sklearn.metrics import roc_curve, auc

def get_thresholds(val_loader, model, hrv, device, threshold_path):
    """Return one decision threshold per output class.

    If ``threshold_path`` exists, its pickled content is returned as is.
    Otherwise the model is run over ``val_loader`` and
    ``find_optimal_threshold`` is applied to each class column of the
    sigmoid scores.

    NOTE: freshly computed thresholds are not written back to
    ``threshold_path``.

    Args:
        val_loader: Iterable of ``(data, hrv, labels)`` batches.
        model: Module called as ``model(data, hrv)``.
        hrv: Unused; the features come from each batch. Kept for
            compatibility with existing callers.
        device: Device on which the batch tensors are placed.
        threshold_path: Path of an optional pickled threshold list.

    Returns:
        The unpickled object, or a list with one threshold per class.
    """
    print('Finding optimal thresholds...')
    if os.path.exists(threshold_path):
        # SECURITY: pickle.load can execute arbitrary code; only point
        # threshold_path at files this project created itself.
        with open(threshold_path, 'rb') as fh:
            return pickle.load(fh)
    output_list, label_list = [], []
    with torch.no_grad():
        for data, batch_hrv, label in tqdm(val_loader):
            data, batch_hrv, labels = data.to(device), batch_hrv.to(device), label.to(device)
            output = torch.sigmoid(model(data, batch_hrv))
            output_list.append(output.cpu().numpy())
            label_list.append(labels.cpu().numpy())
    y_trues = np.vstack(label_list)
    y_scores = np.vstack(output_list)
    return [
        find_optimal_threshold(y_trues[:, i], y_scores[:, i])
        for i in range(y_trues.shape[1])
    ]


def apply_thresholds(test_loader, model, hrv, device, thresholds):
    """Evaluate ``model`` on ``test_loader`` using per-class thresholds.

    Sigmoid scores are binarised with ``thresholds[i]`` for class ``i``. For
    every class a ROC curve is drawn; the figure is saved to
    ``results/ROC_ECG-HRV.png`` and shown. The per-class results of
    ``cal_scores`` are printed (labelled precision, recall, F1, AUC and
    accuracy) together with their column means, and ``plot_cm`` is called
    with the true and predicted labels.

    Args:
        test_loader: Iterable of ``(data, hrv, labels)`` batches.
        model: Module called as ``model(data, hrv)``.
        hrv: Unused; the features come from each batch. Kept for
            compatibility with existing callers.
        device: Device on which the batch tensors are placed.
        thresholds: One threshold per class, in the order of ``classes``.
    """
    output_list, label_list = [], []
    classes = ['SNR', 'AF', 'IAVB', 'LBBB', 'RBBB', 'PAC', 'PVC', 'STD', 'STE']
    with torch.no_grad():
        for data, batch_hrv, label in tqdm(test_loader):
            data, batch_hrv, labels = data.to(device), batch_hrv.to(device), label.to(device)
            output = torch.sigmoid(model(data, batch_hrv))
            output_list.append(output.cpu().numpy())
            label_list.append(labels.cpu().numpy())
    y_trues = np.vstack(label_list)
    y_scores = np.vstack(output_list)
    y_preds = []
    scores = []
    for i, threshold in enumerate(thresholds):
        y_true = y_trues[:, i]
        y_score = y_scores[:, i]
        y_pred = (y_score >= threshold).astype(int)
        y_preds.append(y_pred)
        scores.append(cal_scores(y_true, y_pred, y_score))
        fpr, tpr, _ = roc_curve(y_true, y_score)
        class_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f'{classes[i]} (AUC = {class_auc:0.2})')
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ECG-HRV ROC Curve')
    plt.legend(loc="lower right")
    # savefig does not create missing directories.
    os.makedirs('results', exist_ok=True)
    plt.savefig('results/ROC_ECG-HRV.png')
    plt.show()
    plt.close()

    y_preds = np.array(y_preds).transpose()
    scores = np.array(scores)
    print('Precisions:', scores[:, 0])
    print('Recalls:', scores[:, 1])
    print('F1s:', scores[:, 2])
    print('AUCs:', scores[:, 3])
    print('Accs:', scores[:, 4])
    print(np.mean(scores, axis=0))
    plot_cm(y_trues, y_preds)


def plot_cm(y_trues, y_preds, normalize=True, cmap=plt.cm.Blues):
    """Save one 2x2 confusion-matrix image per class to ``results/<class>.png``.

    Args:
        y_trues: 2-D array of binary ground-truth labels, one column per class.
        y_preds: 2-D array of binary predictions with the same layout.
        normalize: If true, each row of the matrix is divided by its sum
            (rows without samples stay at zero).
        cmap: Matplotlib colormap used for the image.
    """
    classes = ['SNR', 'AF', 'IAVB', 'LBBB', 'RBBB', 'PAC', 'PVC', 'STD', 'STE']
    # savefig does not create missing directories.
    os.makedirs('results', exist_ok=True)
    # Kept from the original: changes NumPy's global print precision.
    np.set_printoptions(precision=3)
    fmt = '.3f' if normalize else 'd'
    for idx, label in enumerate(classes):
        y_true = np.asarray(y_trues[:, idx]).astype(int)
        y_pred = np.asarray(y_preds[:, idx]).astype(int)
        # Fix the label set so the matrix is always 2x2, even when only one
        # value occurs for this class; the tick labels below rely on it.
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        if normalize:
            row_sums = cm.sum(axis=1, keepdims=True)
            # Guard against 0/0 for rows that contain no samples.
            cm = np.divide(cm, row_sums,
                           out=np.zeros(cm.shape, dtype=float),
                           where=row_sums != 0)
        fig, ax = plt.subplots(figsize=(4, 4))
        im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
        ax.figure.colorbar(im, ax=ax)
        ax.set(xticks=np.arange(cm.shape[1]),
               yticks=np.arange(cm.shape[0]),
               xticklabels=[0, 1], yticklabels=[0, 1],
               title=label,
               ylabel='True label',
               xlabel='Predicted label')
        plt.setp(ax.get_xticklabels(), ha="center")

        # Light text on dark cells, dark text on light cells.
        thresh = cm.max() / 2.
        for r in range(cm.shape[0]):
            for c in range(cm.shape[1]):
                ax.text(c, r, format(cm[r, c], fmt),
                        ha="center", va="center",
                        color="white" if cm[r, c] > thresh else "black")
        fig.tight_layout()
        plt.savefig(f'results/{label}.png')
        plt.close(fig)


if __name__ == "__main__":
    # Evaluation script: load the trained weights, derive per-class
    # thresholds on the validation folds and report results on the
    # validation and test folds.
    data_dir = 'ecg-diagnosis/data/CPSC'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path = 'models/ECG_HRV_resnet.pth'

    n_epochs = 40
    label_csv = 'labelx.csv'
    hrv_features = 'mean_hrv_features.csv'
    database = os.path.basename(data_dir)
    # Fall back to the default checkpoint if no path is set.
    model_path = model_path or 'models/ECG_HRV_resnet.pth'
    threshold_path = f'models/{database}-threshold.pkl'
    leads = 'all'

    # 'all' means the 12 standard leads; otherwise one channel per entry.
    nleads = 12 if leads == 'all' else len(leads)

    model = resnet34(input_channels=nleads).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    batch_size = 16
    train_folds, val_folds, test_folds = split_data(seed=42)
    train_dataset = ECGDataset(
        'train', data_dir, hrv_features, label_csv, train_folds, leads)
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True)
    val_dataset = ECGDataset(
        'val', data_dir, hrv_features, label_csv, val_folds, leads)
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True)
    test_dataset = ECGDataset(
        'test', data_dir, hrv_features, label_csv, test_folds, leads)
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True)

    # Thresholds are fitted on the validation folds (or read from
    # threshold_path when that file exists).
    thresholds = get_thresholds(
        val_loader, model, hrv_features, device, threshold_path)
    print('Thresholds:', thresholds)

    print('Results on validation data:')
    apply_thresholds(val_loader, model, hrv_features, device, thresholds)

    print('Results on test data:')
    apply_thresholds(test_loader, model, hrv_features, device, thresholds)