# Atrial Fibrillation
This notebook contains all of the code needed to set up and run SMoLK models on the atrial fibrillation detection task.

## Setup
Here, we import all of the libraries needed to run this code and set the random seed for reproducibility. The `device` variable selects where the models run: use `"cuda"` for a CUDA-enabled GPU, `"mps"` for an Apple Silicon GPU, or `"cpu"` if neither is available (this works, but is much slower).

In [None]:
import os # path/directory stuff
import pickle

# Deep learning

import torch
import torch.nn as nn
import torch.nn.functional as F

# Math
from scipy.signal import savgol_filter
import numpy as np
import random
from scipy.signal import periodogram
from utils.stats import calculate_metrics, print_table

# Data
from utils.datasets import ZhengEtAl, CinC, generate_split

# Seed every RNG the notebook touches (Python, NumPy, PyTorch CPU and CUDA) so runs are repeatable.
seed = 0
for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_rng(seed)
# Ask cuDNN for deterministic convolution algorithms.
torch.backends.cudnn.deterministic = True

#Progress bar
from tqdm import tqdm

# Device to use: prefer an NVIDIA GPU, then an Apple Silicon GPU, and fall back to the CPU,
# so the notebook runs unmodified on any machine (a hard-coded "mps" fails elsewhere).
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Load the Data

In [None]:
X, Y, class_names = CinC(base_dir="../Interpretable Arrhythmia.nosync/")

# Class 3 is the noisy class; drop those recordings before training.
not_noisy = Y != 3
X, Y = X[not_noisy], Y[not_noisy]
class_names = ["Normal", "Afib", "Other"]

# Define the model, train, and test functions

In [None]:
class LearnedFilters(nn.Module):
    """SMoLK classifier: three banks of learned 1-D kernels plus a single linear head.

    Each bank convolves the single-channel input with ``num_kernels`` filters of a fixed
    width (192, 96 and 64 samples), applies leaky ReLU and averages over time. The three
    pooled vectors are concatenated with a precomputed power spectrum and mapped to class
    logits by one linear layer. Every logit is therefore a weighted sum of pooled kernel
    responses and spectral bins.

    Args:
        num_kernels: number of filters in each bank.
        num_classes: number of output logits.
        spectrum_size: length of the power-spectrum vector passed to ``forward``. The
            default 321 is what ``scipy.signal.periodogram`` returns for a 640- or
            641-sample signal (N // 2 + 1 bins).
    """

    def __init__(self, num_kernels=24, num_classes=4, spectrum_size=321):
        super().__init__()
        self.conv1 = nn.Conv1d(1, num_kernels, 192, stride=1, bias=True)
        self.conv2 = nn.Conv1d(1, num_kernels, 96, stride=1, bias=True)
        self.conv3 = nn.Conv1d(1, num_kernels, 64, stride=1, bias=True)

        self.linear = nn.Linear(num_kernels * 3 + spectrum_size, num_classes)

    def forward(self, x, powerspectrum):
        """Return raw class logits.

        Args:
            x: signal batch of shape (batch, 1, length); length must be >= 192.
            powerspectrum: batch of shape (batch, spectrum_size).
        """
        pooled = [F.leaky_relu(conv(x)).mean(dim=-1) for conv in (self.conv1, self.conv2, self.conv3)]
        return self.linear(torch.cat([*pooled, powerspectrum], dim=1))

def train(device, X, Y, class_weights=None, num_kernels=128, lr=0.001, batch_size=256, num_epoch=16, end_factor=0.1, use_tqdm=True):
    """Train a LearnedFilters (SMoLK) model with AdamW and a linearly decaying learning rate.

    Args:
        device: torch device string ("cuda", "mps", "cpu", ...).
        X: signals, one per row. Each row's ``periodogram(..., fs=64)`` is used as extra
            input features. NOTE(review): assumed to be a 2-D float array
            (n_samples, signal_length) whose spectra have as many bins as the model's
            linear layer expects; confirm against the loader.
        Y: integer class labels; the model gets ``max(Y) + 1`` outputs.
        class_weights: optional per-class weights for ``nn.CrossEntropyLoss``. ``None``
            trains with an unweighted loss.
        num_kernels: filters per kernel bank in ``LearnedFilters``.
        lr: initial AdamW learning rate.
        batch_size: mini-batch size. Batches are taken in the order of ``X`` with no
            per-epoch reshuffling, so shuffle beforehand if needed.
        num_epoch: number of passes over the data.
        end_factor: final LR as a fraction of ``lr``, reached on the last optimisation step.
        use_tqdm: show a progress bar over epochs with the latest batch loss.

    Returns:
        The trained ``LearnedFilters`` model, on ``device``.
    """
    # Spectra are computed once for all rows; row i of PowerSpectra pairs with X[i].
    PowerSpectra = np.float32(periodogram(X, fs=64, axis=-1)[1])

    model = LearnedFilters(num_kernels=num_kernels, num_classes=np.max(Y) + 1).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # The inner loop takes ceil(len(X) / batch_size) steps per epoch, because the last batch
    # may be partial. Floor division could under-count, so with end_factor=0 the LR hit zero
    # before training ended and the final steps did nothing.
    steps_per_epoch = (len(X) + batch_size - 1) // batch_size
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=end_factor, total_iters=num_epoch * steps_per_epoch)
    # class_weights=None is the default, so it must not reach torch.tensor().
    weight = None if class_weights is None else torch.tensor(class_weights, dtype=torch.float, device=device)
    criterion = nn.CrossEntropyLoss(weight=weight)

    pbar = tqdm(range(num_epoch)) if use_tqdm else range(num_epoch)
    for _ in pbar:
        for start in range(0, len(X), batch_size):
            stop = start + batch_size
            # Cast on the host so non-float32 inputs still match the float32 conv weights.
            data = torch.tensor(X[start:stop], dtype=torch.float32).to(device).unsqueeze(1)
            powerspectrum = torch.tensor(PowerSpectra[start:stop], dtype=torch.float32).to(device)
            target = torch.tensor(Y[start:stop]).to(device).long()

            optimizer.zero_grad()
            loss = criterion(model(data, powerspectrum), target)
            loss.backward()
            optimizer.step()
            scheduler.step()
            if use_tqdm:
                pbar.set_description(f"loss: {loss.item():.5f}")

    return model

def test(model, device, X):
    # compute power spectra for X
    PowerSpectra = []
    for i in range(0, len(X)):
        PowerSpectra.append(periodogram(X[i], fs=64)[1])
    PowerSpectra = np.float32(PowerSpectra)

    model.eval()
    with torch.no_grad():
        probs = []
        for i in range(0, len(X)):
            data = X[i]
            powerspectrum = PowerSpectra[i]
            
            data, powerspectrum = torch.tensor(data).to(device), torch.tensor(powerspectrum, dtype=torch.float32).to(device)
            data = data.unsqueeze(0).unsqueeze(1)
            powerspectrum = powerspectrum.unsqueeze(0)
            
            output = model(data, powerspectrum).softmax(dim=-1)
            probs.append(output.cpu().numpy())

    probs = np.concatenate(probs, axis=0)

    return probs

# Train the models

In [None]:
# Train one SMoLK model per cross-validation fold.
models = []
num_splits = 10
data_fraction = 1.0  # fraction of each fold's training set to keep (1.0 = all of it)
num_kernels = 128
for split in range(num_splits):
    print(f"Split {split + 1}")
    X_train, Y_train, X_test, Y_test = generate_split(X, Y, split, num_splits)

    # Shuffle once per fold, then keep only the leading `data_fraction` of it.
    order = np.random.permutation(len(X_train))
    n_keep = int(data_fraction * len(X_train))
    X_train = X_train[order][:n_keep]
    Y_train = Y_train[order][:n_keep]

    # Inverse-frequency class weights, normalised to sum to 1.
    inv_freq = 1 / np.bincount(Y_train)
    class_weights = inv_freq / inv_freq.sum()

    model = train(device, X_train, Y_train, class_weights=class_weights, num_kernels=num_kernels,
                  lr=0.1, batch_size=1024, num_epoch=512, end_factor=0.0, use_tqdm=True)
    models.append(model)

# Test the models via Cross Validation
In this section, we evaluate each fold's model against the test fold that was held out during its training.

In [None]:
# Test models: score each fold's model on the fold it was held out from.
probs = []
ground_truth = []
for split in range(num_splits):
    print(f"Split {split + 1}")
    _, _, X_test, Y_test = generate_split(X, Y, split, num_splits)
    with torch.no_grad():
        Y_prob = torch.tensor(test(models[split], device, X_test)).cpu().numpy()
    probs.append(Y_prob)
    ground_truth.append(Y_test)

# Collect per-fold metrics, then stack them into arrays with one entry per fold.
metric_lists = ([], [], [], [])
for split in range(num_splits):
    sen, spec, auc, f1 = calculate_metrics(ground_truth[split], probs[split], num_classes=Y.max() + 1)
    for bucket, value in zip(metric_lists, (sen, spec, auc, f1)):
        bucket.append(value)
sensitivities, specificities, AUCs, F1s = (np.array(values) for values in metric_lists)

print_table(sensitivities, specificities, AUCs, class_names)
print(f"F1: {F1s.mean():.3f} ± {F1s.std():.3f}")

# Test on Holdout Set

In [None]:
X, Y, class_names = ZhengEtAl(base_dir="../Interpretable Arrhythmia.nosync/")

# The loader emits more classes than we trained on. Collapse them onto the 3-class
# scheme: loader label 2 -> Normal (0), label 3 -> Afib (1), everything else -> Other (2).
class_names = ["Normal", "Afib", "Other"]
new_mapping = [2, 2, 0, 1, 2, 2, 2]
Y = np.array([new_mapping[label] for label in Y])

In [None]:
# Test models: run every fold's model on the full external holdout set.
probs = []
ground_truth = []
for split in range(num_splits):
    print(f"Split {split + 1}")
    with torch.no_grad():
        Y_prob = torch.tensor(test(models[split], device, X)).cpu().numpy()
    probs.append(Y_prob)
    ground_truth.append(Y)

In [None]:
# Collect per-fold metrics, then stack them into arrays with one entry per fold.
metric_lists = ([], [], [], [])
for split in range(num_splits):
    sen, spec, auc, f1 = calculate_metrics(ground_truth[split], probs[split], num_classes=Y.max() + 1)
    for bucket, value in zip(metric_lists, (sen, spec, auc, f1)):
        bucket.append(value)
sensitivities, specificities, AUCs, F1s = (np.array(values) for values in metric_lists)

print_table(sensitivities, specificities, AUCs, class_names)
print(f"F1: {F1s.mean():.3f} ± {F1s.std():.3f}")

# ResNet Model
Now, let's compare against a 1-D ResNet baseline.

In [None]:
from utils.misc import train, test # NOTE: THIS OVERWRITES THE PREVIOUS train AND test FUNCTIONS

# code is ported from https://github.com/antonior92/automatic-ecg-diagnosis/tree/master?tab=readme-ov-file

class ResidualUnit(nn.Module):
    """1-D residual unit: [BN -> ReLU] -> Conv -> BN -> ReLU -> Dropout -> strided Conv, plus a skip path.

    The skip path is the identity, or a strided 1x1 convolution when the unit changes the
    channel count or the temporal resolution.

    Args:
        n_filters_in: input channel count.
        n_filters_out: output channel count.
        kernel_size: width of both main-path convolutions (odd or even).
        stride: downsampling factor, applied by the second conv and by the shortcut.
        dropout_rate: dropout probability between the two convolutions (0 disables it).
        preactivation: apply BN + ReLU to the input before the first conv.
        postactivation_bn: apply BN + ReLU after the residual addition. This is always done
            when ``preactivation`` is False.
        activation_function: only 'relu' is supported.

    Raises:
        NotImplementedError: for any other ``activation_function``.
    """

    def __init__(self, n_filters_in, n_filters_out, kernel_size=17, stride=1, dropout_rate=0.2,
                 preactivation=True, postactivation_bn=False, activation_function='relu'):
        super().__init__()
        self.preactivation = preactivation
        self.postactivation_bn = postactivation_bn
        self.dropout_rate = dropout_rate

        if activation_function == 'relu':
            self.activation = nn.ReLU(inplace=True)
        else:
            raise NotImplementedError("Activation function '{}' not implemented.".format(activation_function))

        self.bn1 = nn.BatchNorm1d(n_filters_in)
        self.conv1 = nn.Conv1d(n_filters_in, n_filters_out, kernel_size, stride=1, padding=kernel_size//2, bias=False)
        self.bn2 = nn.BatchNorm1d(n_filters_out)
        # This padding makes the main path's output length equal the 1x1 strided shortcut's
        # for both odd and even kernels. For even kernels it equals the former
        # kernel_size//2 - 1. For odd kernels the former value left the main path shorter than
        # the shortcut, e.g. the default kernel_size=17 with stride=1 failed at `out += identity`.
        self.conv2 = nn.Conv1d(n_filters_out, n_filters_out, kernel_size, stride=stride, padding=(kernel_size - 1) // 2, bias=False)
        self.dropout = nn.Dropout(dropout_rate)
        self.downsample = stride != 1 or n_filters_in != n_filters_out
        if self.downsample:
            self.conv_shortcut = nn.Conv1d(n_filters_in, n_filters_out, 1, stride=stride, bias=False)
        if not preactivation or postactivation_bn:
            # The post-activation path gets its own BatchNorm. Reusing bn2 would mix the running
            # statistics and affine parameters of two different positions in the unit. It is
            # only created when used, so the default configuration's state_dict is unchanged.
            self.bn3 = nn.BatchNorm1d(n_filters_out)

    def forward(self, x):
        """Apply the unit to ``x`` of shape (batch, n_filters_in, length)."""
        identity = x

        out = x
        if self.preactivation:
            out = self.activation(self.bn1(out))
        out = self.activation(self.bn2(self.conv1(out)))
        if self.dropout_rate > 0:
            out = self.dropout(out)
        out = self.conv2(out)

        if self.downsample:
            identity = self.conv_shortcut(identity)

        out += identity
        if not self.preactivation or self.postactivation_bn:
            out = self.activation(self.bn3(out))
        return out

class ResNet1D(nn.Module):
    """1-D ResNet baseline: a stem convolution, four residual units and a linear classifier.

    Args:
        n_classes: number of output logits.
        kernel_size: kernel width used by the stem and by every residual unit.
        last_layer_activation: 'sigmoid' or 'softmax'. It is stored as
            ``self.last_activation`` but NOT applied in ``forward``, which returns raw
            logits (what ``nn.CrossEntropyLoss`` expects).
        input_length: samples per input signal, used to size the final linear layer.
            The default 640 reproduces the previously hard-coded 5760 input features
            (320 channels x 18 time steps).

    Raises:
        NotImplementedError: for an unknown ``last_layer_activation``.
    """

    def __init__(self, n_classes, kernel_size=16, last_layer_activation='sigmoid', input_length=640):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 64, kernel_size, stride=1, padding=kernel_size//2, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU(inplace=True)

        self.res1 = ResidualUnit(64, 128, kernel_size=kernel_size, stride=3)
        self.res2 = ResidualUnit(128, 196, kernel_size=kernel_size, stride=3)
        self.res3 = ResidualUnit(196, 256, kernel_size=kernel_size, stride=2)
        self.res4 = ResidualUnit(256, 320, kernel_size=kernel_size, stride=2)

        self.flatten = nn.Flatten()
        # Derive the classifier width from a dummy pass rather than hard-coding it, so other
        # input lengths and kernel sizes also work.
        self.fc = nn.Linear(self._flattened_size(input_length), n_classes)

        if last_layer_activation == 'sigmoid':
            self.last_activation = nn.Sigmoid()
        elif last_layer_activation == 'softmax':
            # A module (not a lambda) keeps the model picklable, e.g. with torch.save(model).
            self.last_activation = nn.Softmax(dim=1)
        else:
            raise NotImplementedError("Last layer activation '{}' not implemented.".format(last_layer_activation))

    def _flattened_size(self, input_length):
        """Number of features reaching ``self.fc`` for one (1, 1, input_length) input."""
        was_training = self.training
        self.eval()  # eval mode: BatchNorm must not update its running stats on the dummy input
        with torch.no_grad():
            n_features = self.flatten(self._features(torch.zeros(1, 1, input_length))).shape[1]
        self.train(was_training)
        return n_features

    def _features(self, x):
        """Stem conv -> BN -> ReLU followed by the four residual units."""
        x = self.relu(self.bn1(self.conv1(x)))
        for unit in (self.res1, self.res2, self.res3, self.res4):
            x = unit(x)
        return x

    def forward(self, x):
        """Return raw class logits for ``x`` of shape (batch, 1, input_length)."""
        return self.fc(self.flatten(self._features(x)))

# Train

In [None]:
X, Y, class_names = CinC(base_dir="../Interpretable Arrhythmia.nosync/")
# Drop class 3 (the noisy class), exactly as in the SMoLK experiment above.
not_noisy = Y != 3
X, Y = X[not_noisy], Y[not_noisy]
class_names = ["Normal", "Afib", "Other"]

In [None]:
# Train one ResNet1D per cross-validation fold (train/test here come from utils.misc).
models = []
num_splits = 10
data_fraction = 1.0  # fraction of each fold's training set to keep (1.0 = all of it)
for split in range(num_splits):
    print(f"Split {split + 1}")
    X_train, Y_train, X_test, Y_test = generate_split(X, Y, split, num_splits)

    # Shuffle once per fold, then keep only the leading `data_fraction` of it.
    order = np.random.permutation(len(X_train))
    n_keep = int(data_fraction * len(X_train))
    X_train = X_train[order][:n_keep]
    Y_train = Y_train[order][:n_keep]

    # Inverse-frequency class weights, normalised to sum to 1.
    inv_freq = 1 / np.bincount(Y_train)
    class_weights = inv_freq / inv_freq.sum()

    model = ResNet1D(n_classes=len(class_names)).to(device)
    model = train(model, device, X_train, Y_train, class_weights=class_weights, lr=0.001,
                  batch_size=1024, num_epoch=512, end_factor=0.0, use_tqdm=True)
    models.append(model)

In [None]:
# Test models: score each fold's ResNet on the fold it was held out from.
probs = []
ground_truth = []
for split in range(num_splits):
    print(f"Split {split + 1}")
    _, _, X_test, Y_test = generate_split(X, Y, split, num_splits)
    with torch.no_grad():
        outputs = torch.as_tensor(test(models[split], device, X_test)).detach().cpu()
    # NOTE(review): ResNet1D.forward returns raw logits, and it is not visible here whether
    # utils.misc.test already applies softmax. Normalise only when the rows are not already
    # probability distributions, so probabilities are never softmaxed twice and logits are
    # never scored directly.
    row_sums = outputs.sum(dim=-1)
    already_probs = bool((outputs >= 0).all()) and torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-4)
    Y_prob = (outputs if already_probs else outputs.softmax(dim=-1)).numpy()

    probs.append(Y_prob)
    ground_truth.append(Y_test)

In [None]:
# Collect per-fold metrics, then stack them into arrays with one entry per fold.
metric_lists = ([], [], [], [])
for split in range(num_splits):
    sen, spec, auc, f1 = calculate_metrics(ground_truth[split], probs[split], num_classes=Y.max() + 1)
    for bucket, value in zip(metric_lists, (sen, spec, auc, f1)):
        bucket.append(value)
sensitivities, specificities, AUCs, F1s = (np.array(values) for values in metric_lists)

print_table(sensitivities, specificities, AUCs, class_names)
print(f"F1: {F1s.mean():.3f} ± {F1s.std():.3f}")

# Test on Holdout Set

In [None]:
X, Y, class_names = ZhengEtAl(base_dir="../Interpretable Arrhythmia.nosync/")

# Collapse the loader's labels onto the 3-class scheme:
# loader label 2 -> Normal (0), label 3 -> Afib (1), everything else -> Other (2).
class_names = ["Normal", "Afib", "Other"]
new_mapping = [2, 2, 0, 1, 2, 2, 2]
Y = np.array([new_mapping[label] for label in Y])

In [None]:
# Test models: run every fold's ResNet on the full external holdout set.
probs = []
ground_truth = []
for split in range(num_splits):
    print(f"Split {split + 1}")
    with torch.no_grad():
        outputs = torch.as_tensor(test(models[split], device, X)).detach().cpu()
    # NOTE(review): ResNet1D.forward returns raw logits, and it is not visible here whether
    # utils.misc.test already applies softmax. Normalise only when the rows are not already
    # probability distributions, so the metrics below always see probabilities.
    row_sums = outputs.sum(dim=-1)
    already_probs = bool((outputs >= 0).all()) and torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-4)
    Y_prob = (outputs if already_probs else outputs.softmax(dim=-1)).numpy()

    probs.append(Y_prob)
    ground_truth.append(Y)

sensitivities = []
specificities = []
AUCs = []
F1s = []

for split in range(num_splits):
    sen, spec, auc, f1 = calculate_metrics(ground_truth[split], probs[split], num_classes=Y.max() + 1)
    sensitivities.append(sen)
    specificities.append(spec)
    AUCs.append(auc)
    F1s.append(f1)

sensitivities = np.array(sensitivities)
specificities = np.array(specificities)
AUCs = np.array(AUCs)
F1s = np.array(F1s)

print_table(sensitivities, specificities, AUCs, class_names)
print(f"F1: {F1s.mean():.3f} ± {F1s.std():.3f}")

# Interpretability

In [None]:
from utils.datasets import ZhengEtAl_AVB

# Train

In [None]:
X, Y, class_names = ZhengEtAl_AVB(base_dir="../Interpretable Arrhythmia.nosync/")
# Drop every example labelled 2 and keep only the first two class names.
# NOTE(review): assumes the remaining labels are exactly 0 and 1; confirm against the loader.
keep = Y != 2
X, Y = X[keep], Y[keep]
class_names = class_names[:2]

In [None]:
# Shuffle X and Y together, then use the first 80% for training and the remaining 20% for testing.
p = np.random.permutation(len(X))
split = 0.8  # fraction of examples used for training
X, Y = X[p], Y[p]
n_train = int(len(X) * split)
X_train, Y_train = X[:n_train], Y[:n_train]
X_test, Y_test = X[n_train:], Y[n_train:]

In [None]:
# Inverse-frequency class weights, normalised to sum to 1.
inverse_counts = 1 / np.bincount(Y_train)
class_weights = inverse_counts / inverse_counts.sum()

# Re-defined here because `from utils.misc import train, test` in the ResNet section
# replaced the SMoLK training function with the generic one.
def train(device, X, Y, class_weights=None, num_kernels=128, lr=0.001, batch_size=256, num_epoch=16, end_factor=0.1, use_tqdm=True):
    """Train a LearnedFilters (SMoLK) model with AdamW and a linearly decaying learning rate.

    Args:
        device: torch device string ("cuda", "mps", "cpu", ...).
        X: signals, one per row. Each row's ``periodogram(..., fs=64)`` is used as extra
            input features. NOTE(review): assumed to be a 2-D float array
            (n_samples, signal_length) whose spectra have as many bins as the model's
            linear layer expects; confirm against the loader.
        Y: integer class labels; the model gets ``max(Y) + 1`` outputs.
        class_weights: optional per-class weights for ``nn.CrossEntropyLoss``. ``None``
            trains with an unweighted loss.
        num_kernels: filters per kernel bank in ``LearnedFilters``.
        lr: initial AdamW learning rate.
        batch_size: mini-batch size. Batches are taken in the order of ``X`` with no
            per-epoch reshuffling, so shuffle beforehand if needed.
        num_epoch: number of passes over the data.
        end_factor: final LR as a fraction of ``lr``, reached on the last optimisation step.
        use_tqdm: show a progress bar over epochs with the latest batch loss.

    Returns:
        The trained ``LearnedFilters`` model, on ``device``.
    """
    # Spectra are computed once for all rows; row i of PowerSpectra pairs with X[i].
    PowerSpectra = np.float32(periodogram(X, fs=64, axis=-1)[1])

    model = LearnedFilters(num_kernels=num_kernels, num_classes=np.max(Y) + 1).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # The inner loop takes ceil(len(X) / batch_size) steps per epoch, because the last batch
    # may be partial. Floor division could under-count, so with end_factor=0 the LR hit zero
    # before training ended and the final steps did nothing.
    steps_per_epoch = (len(X) + batch_size - 1) // batch_size
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=end_factor, total_iters=num_epoch * steps_per_epoch)
    # class_weights=None is the default, so it must not reach torch.tensor().
    weight = None if class_weights is None else torch.tensor(class_weights, dtype=torch.float, device=device)
    criterion = nn.CrossEntropyLoss(weight=weight)

    pbar = tqdm(range(num_epoch)) if use_tqdm else range(num_epoch)
    for _ in pbar:
        for start in range(0, len(X), batch_size):
            stop = start + batch_size
            # Cast on the host so non-float32 inputs still match the float32 conv weights.
            data = torch.tensor(X[start:stop], dtype=torch.float32).to(device).unsqueeze(1)
            powerspectrum = torch.tensor(PowerSpectra[start:stop], dtype=torch.float32).to(device)
            target = torch.tensor(Y[start:stop]).to(device).long()

            optimizer.zero_grad()
            loss = criterion(model(data, powerspectrum), target)
            loss.backward()
            optimizer.step()
            scheduler.step()
            if use_tqdm:
                pbar.set_description(f"loss: {loss.item():.5f}")

    return model

# Train a single SMoLK model on the 80% training split for the interpretability analysis.
lkm_model = train(
    device,
    X_train,
    Y_train,
    class_weights=class_weights,
    num_kernels=128,
    lr=0.1,
    batch_size=1024,
    num_epoch=512,
    end_factor=0.0,
    use_tqdm=True,
)

# Interpretation

In [None]:
# Compute power spectra for every signal in X (not just X_train). The example explained
# below is drawn from the full dataset, so its spectrum must be available.
from scipy.signal import periodogram
from tqdm import tqdm

PowerSpectra = [periodogram(signal, fs=64)[1] for signal in tqdm(X)]

In [None]:
import matplotlib.pyplot as plt
import torch.nn.functional as F

COI = 1  # class of interest: the logit being explained

# Explain a randomly chosen example whose label is the class of interest.
indices = np.arange(len(X))[Y == COI]
idx = np.random.choice(indices)
print(idx)

net = lkm_model
sd = lkm_model.state_dict()
# Read the bank size from the trained model instead of hard-coding it (128 for lkm_model above).
num_kernels = net.conv1.out_channels
coi_weights = sd['linear.weight'][COI]  # row of the linear head that produces the COI logit

x = torch.tensor(X[idx], dtype=torch.float32, device=device).unsqueeze(0).unsqueeze(0)

with torch.no_grad():
    # Activation maps of the three kernel banks, each (num_kernels, number of window positions).
    c1 = F.leaky_relu(net.conv1(x)[0])
    c2 = F.leaky_relu(net.conv2(x)[0])
    c3 = F.leaky_relu(net.conv3(x)[0])

x = x[0, 0]
# Contribution of the pooled kernel features to the COI logit (the bias is excluded).
feature_contribution = (torch.cat([c1.mean(dim=-1), c2.mean(dim=-1), c3.mean(dim=-1)]) * coi_weights[:num_kernels*3]).sum().item()
# Contribution of the spectral features: the trailing weights pair with the spectrum bins.
n_freq_bins = len(PowerSpectra[idx])
spectral_contribution = np.sum(PowerSpectra[idx] * coi_weights[-n_freq_bins:].detach().cpu().numpy())


def _spread_over_windows(activations, branch_weights, kernel_len):
    """Distribute each window's weighted activation evenly over the samples it covers.

    Equivalent to ``accum[j:j+kernel_len] += activations[i, j] * branch_weights[i] / kernel_len``
    for every kernel i and window start j. Summing over kernels first and then running a
    transposed convolution with an all-ones kernel replaces the per-element Python loop.
    Returns a vector of length ``activations.shape[1] + kernel_len - 1`` (the signal length).
    """
    per_window = branch_weights @ activations
    ones = torch.ones(1, 1, kernel_len, dtype=per_window.dtype, device=per_window.device)
    return F.conv_transpose1d(per_window.view(1, 1, -1), ones).view(-1) / kernel_len


with torch.no_grad():
    accum = (
        _spread_over_windows(c1, coi_weights[:num_kernels], net.conv1.kernel_size[0])
        + _spread_over_windows(c2, coi_weights[num_kernels:2*num_kernels], net.conv2.kernel_size[0])
        + _spread_over_windows(c3, coi_weights[2*num_kernels:3*num_kernels], net.conv3.kernel_size[0])
    )

# Trim half the widest kernel (conv1) from each end, where fewer windows overlap each sample.
trim = net.conv1.kernel_size[0] // 2
accum = accum.cpu().numpy()[trim:-trim]
x = x.cpu().numpy()[trim:-trim]

# Plot the trimmed signal as a thin black line, overlaid with points coloured by each
# sample's contribution to the COI logit (RdBu_r: red = higher, blue = lower).
fig, ax = plt.subplots(figsize=(5, 1), dpi=300)
norm = plt.Normalize(accum.min(), accum.max())
colormap = plt.get_cmap('RdBu_r')

# Upsample 100x by linear interpolation so the coloured overlay looks continuous.
num_interpolated_points = len(x) * 100
linspace = np.linspace(0, len(x) - 1, num_interpolated_points)
accum_interp = np.interp(linspace, np.arange(len(x)), accum)
x_interp = np.interp(linspace, np.arange(len(x)), x)

ax.plot(np.arange(len(x)), x, c='k', linewidth=0.5)

# Point size grows with |contribution| so influential regions stand out.
abs_contribution = np.abs(accum_interp)
point_sizes = np.interp(abs_contribution, (abs_contribution.min(), abs_contribution.max()), (0.01, 0.5))
scatter = ax.scatter(linspace, x_interp, c=accum_interp, cmap=colormap, norm=norm, s=point_sizes)

cbar = fig.colorbar(scatter, ax=ax)

# The signal's shape is what matters here, so hide the ticks.
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlim(0, len(x) - 1)

plt.show()


# Second figure: the example's power spectrum. Bar heights are the spectral power; each bar
# is coloured by that bin's contribution (COI linear weight x power) to the COI logit.
plt.figure(figsize=(5, 2), dpi=150)
plt.title(f"Frequency Spectrum, Contribution: {spectral_contribution:.2f}")

# Despite its name, `weight` holds per-bin contributions: the last 321 entries of the COI row of
# the linear layer (the weights on the power-spectrum inputs) times this example's spectrum.
weight = sd['linear.weight'][COI, -321:].cpu().numpy() * PowerSpectra[idx]

# Normalize weight for color mapping
norm = plt.Normalize(weight.min(), weight.max())
colormap = plt.get_cmap('RdBu_r')
colors = colormap(norm(weight))

# The x-axis spans 0-32 Hz: spectra were computed with periodogram(fs=64), so bins run up to
# the 32 Hz Nyquist frequency. Each bar is about twice the bin spacing wide, so adjacent bars
# overlap and read as a continuous spectrum.
plt.bar(np.linspace(0, 32, len(PowerSpectra[idx])), PowerSpectra[idx], color=colors, width=1/len(PowerSpectra[idx])*32*2)
plt.xlim(0, 32)
plt.xlabel("Frequency (Hz)")