In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [7]:
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import periodogram, resample
import wfdb
from sklearn.metrics import confusion_matrix, roc_auc_score, f1_score
from tqdm import tqdm
import matplotlib.pyplot as plt  # Optional, for visualization
import pickle  # For saving loss history and models

# ===========================
# Set Seed for Reproducibility
# ===========================
def set_seed(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

set_seed(0)

# ===========================
# Configuration
# ===========================
DATASET_PATH = "/content/drive/My Drive/mit-bih-arrhythmia-database-1.0.0"
SAMPLE_RATE = 360  # Original sampling rate for MIT-BIH
TARGET_SAMPLE_RATE = 64  # Downsampled rate as per user's code

# Updated Window Size: 3.0 seconds to match kernel size of 192
WINDOW_SIZE_SEC = 3.0  # Window size in seconds around R-peak
WINDOW_SIZE = int(WINDOW_SIZE_SEC * TARGET_SAMPLE_RATE)  # 192 samples

OVERLAP = 0  # No overlap
NUM_SPLITS = 10
DATA_FRACTION = 1.0
NUM_KERNELS = 128
LEARNING_RATE = 0.001
BATCH_SIZE = 256
NUM_EPOCHS = 16
END_FACTOR = 0.1
USE_TQDM = False  # Disable tqdm during training
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ===========================
# Utility Functions
# ===========================

def load_record(record_name, path, target_fs=TARGET_SAMPLE_RATE):
    """
    Load a single record, resample, and extract R-peaks and annotations.
    """
    record = wfdb.rdrecord(os.path.join(path, record_name))
    annotation = wfdb.rdann(os.path.join(path, record_name), 'atr')

    # Select first channel (usually MLII)
    signal = record.p_signal[:,0]

    # Resample signal
    num_samples = int(len(signal) * target_fs / SAMPLE_RATE)
    signal_resampled = resample(signal, num_samples)

    # Adjust R-peak locations after resampling
    r_peaks = (annotation.sample * target_fs) // SAMPLE_RATE
    r_peaks = r_peaks[r_peaks < len(signal_resampled)]

    # Extract annotations
    annotations = annotation.symbol
    return signal_resampled, r_peaks, annotations

def extract_windows(signal, r_peaks, annotations, window_size=WINDOW_SIZE):
    """
    Extract windows around R-peaks and assign labels.
    """
    half_window = window_size // 2
    windows = []
    labels = []

    for peak, symbol in zip(r_peaks, annotations):
        start = peak - half_window
        end = peak + half_window
        if start < 0 or end > len(signal):
            continue  # Skip if window is out of bounds
        window = signal[start:end]
        windows.append(window)
        labels.append(symbol)
    return np.array(windows), np.array(labels)

def map_labels(labels):
    """
    Map original MIT-BIH labels to desired classes:
    - "Normal" (N, L, R, e, j)
    - "Afib" (A, a, F, J, S)
    - "Other" (all other classes)
    """
    normal = ['N', 'L', 'R', 'e', 'j']  # Including some variants
    afib = ['A', 'a', 'F', 'J', 'S']
    mapped_labels = []
    for label in labels:
        if label in normal:
            mapped_labels.append(0)  # Normal
        elif label in afib:
            mapped_labels.append(1)  # Afib
        else:
            mapped_labels.append(2)  # Other
    return np.array(mapped_labels)

def load_dataset(path):
    """
    Load all records, extract windows and labels.
    """
    # List of all record names in the dataset
    records = [f.split('.')[0] for f in os.listdir(path) if f.endswith('.dat')]

    all_windows = []
    all_labels = []

    for record in tqdm(records, desc="Loading Records"):
        signal, r_peaks, annotations = load_record(record, path)
        windows, labels = extract_windows(signal, r_peaks, annotations)
        windows = windows[:, :WINDOW_SIZE]  # Ensure consistent window size
        all_windows.append(windows)
        all_labels.append(labels)

    all_windows = np.concatenate(all_windows, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    mapped_labels = map_labels(all_labels)

    # Exclude "Other" class if needed
    # For demonstration, we'll keep all classes
    return all_windows, mapped_labels

def generate_splits(X, Y, num_splits=NUM_SPLITS):
    """
    Generate cross-validation splits.
    """
    splits = []
    unique_classes = np.unique(Y)
    for split in range(num_splits):
        # Simple random split; for better stratification, implement stratified splits
        indices = np.arange(len(X))
        np.random.shuffle(indices)
        train_size = int(0.8 * len(X))
        train_idx = indices[:train_size]
        test_idx = indices[train_size:]
        splits.append((train_idx, test_idx))
    return splits

def calculate_metrics(y_true, y_prob, num_classes=3):
    """
    Calculate sensitivity, specificity, AUC, and F1 scores.
    """
    y_pred = np.argmax(y_prob, axis=1)
    sensitivities = []
    specificities = []
    AUCs = []
    F1s = []

    for cls in range(num_classes):
        # Binary labels for the current class
        true_binary = (y_true == cls).astype(int)
        pred_binary = (y_pred == cls).astype(int)

        # Compute confusion matrix
        cm = confusion_matrix(true_binary, pred_binary).ravel()
        if len(cm) == 4:
            tn, fp, fn, tp = cm
        elif len(cm) == 2:
            tn, tp = cm
            fp, fn = 0, 0
        elif len(cm) == 1:
            tp = cm[0]
            tn, fp, fn = 0, 0, 0
        else:
            tn, fp, fn, tp = 0, 0, 0, 0

        # Sensitivity (Recall)
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        sensitivities.append(sensitivity)
        specificities.append(specificity)

        # AUC
        try:
            auc = roc_auc_score(true_binary, y_prob[:, cls])
        except ValueError:
            auc = 0  # If only one class is present in y_true, AUC is not defined
        AUCs.append(auc)

        # F1 Score
        f1 = f1_score(true_binary, pred_binary, zero_division=0)
        F1s.append(f1)

    return sensitivities, specificities, AUCs, F1s

def print_table(sensitivities, specificities, AUCs, class_names):
    """
    Print a table of metrics.
    """
    print(f"{'Class':<10}{'Sensitivity':<15}{'Specificity':<15}{'AUC':<10}")
    for i, class_name in enumerate(class_names):
        print(f"{class_name:<10}{sensitivities[i]:<15.3f}{specificities[i]:<15.3f}{AUCs[i]:<10.3f}")

# ===========================
# Dataset Class
# ===========================
from torch.utils.data import Dataset, DataLoader

class ArrhythmiaDataset(Dataset):
    def __init__(self, X, Y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.Y = torch.tensor(Y, dtype=torch.long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]

# ===========================
# Model Definition
# ===========================

class LearnedFilters(nn.Module):
    """
    SMoLK Model: Learned Filters for Arrhythmia Classification

    This model applies multiple convolutional filters of varying kernel sizes
    to the input signal, extracts features by averaging the activations, and
    combines them with power spectrum features for classification.
    """
    def __init__(self, num_kernels=24, num_classes=3):
        super(LearnedFilters, self).__init__()
        self.conv1 = nn.Conv1d(1, num_kernels, 192, stride=1, bias=True)
        self.conv2 = nn.Conv1d(1, num_kernels, 96, stride=1, bias=True)
        self.conv3 = nn.Conv1d(1, num_kernels, 64, stride=1, bias=True)
        self.linear = nn.Linear(num_kernels*3 + 321, num_classes)  # 321 is the size of the power spectrum

    def forward(self, x, powerspectrum):
        c1 = F.leaky_relu(self.conv1(x)).mean(dim=-1)
        c2 = F.leaky_relu(self.conv2(x)).mean(dim=-1)
        c3 = F.leaky_relu(self.conv3(x)).mean(dim=-1)
        aggregate = torch.cat([c1, c2, c3, powerspectrum], dim=1)
        aggregate = self.linear(aggregate)
        return aggregate

# ===========================
# Training and Testing Functions
# ===========================

def compute_power_spectra(X_batch):
    """
    Compute power spectra for a batch of samples using periodogram.
    """
    PowerSpectra = []
    for i in range(len(X_batch)):
        f, Pxx = periodogram(X_batch[i], fs=TARGET_SAMPLE_RATE)
        # To match size 321, interpolate or truncate
        if len(Pxx) < 321:
            Pxx = np.pad(Pxx, (0, 321 - len(Pxx)), 'constant')
        else:
            Pxx = Pxx[:321]
        PowerSpectra.append(Pxx)
    return np.array(PowerSpectra).astype(np.float32)

def train_model(model, optimizer, scheduler, criterion, dataloader, device):
    """
    Train the model for one epoch.
    """
    model.train()
    running_loss = 0.0
    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Compute power spectra
        powerspectrum = compute_power_spectra(data.cpu().numpy())
        powerspectrum = torch.tensor(powerspectrum, dtype=torch.float32).to(device)
        data = data.unsqueeze(1)  # Add channel dimension
        output = model(data, powerspectrum)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        scheduler.step()
        running_loss += loss.item() * data.size(0)
    epoch_loss = running_loss / len(dataloader.dataset)
    return epoch_loss

def evaluate_model(model, dataloader, device):
    """
    Evaluate the model and return probabilities and ground truth.
    """
    model.eval()
    probs = []
    ground_truth = []
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            powerspectrum = compute_power_spectra(data.cpu().numpy())
            powerspectrum = torch.tensor(powerspectrum, dtype=torch.float32).to(device)
            data = data.unsqueeze(1)  # Add channel dimension
            output = model(data, powerspectrum).softmax(dim=-1)
            probs.append(output.cpu().numpy())
            ground_truth.append(target.cpu().numpy())
    probs = np.concatenate(probs, axis=0)
    ground_truth = np.concatenate(ground_truth, axis=0)
    return probs, ground_truth

# ===========================
# Main Execution
# ===========================

def main():
    # Create directories to save models and loss history
    os.makedirs("saved_models", exist_ok=True)
    os.makedirs("loss_history", exist_ok=True)

    # Load and preprocess the dataset
    print("Loading and preprocessing the dataset...")
    X, Y = load_dataset(DATASET_PATH)
    print(f"Total samples: {len(X)}")
    print(f"Class distribution: {np.bincount(Y)}")

    # Generate cross-validation splits
    splits = generate_splits(X, Y, NUM_SPLITS)

    models = []
    loss_histories = []
    for split in range(NUM_SPLITS):
        print(f"\n=== Split {split + 1}/{NUM_SPLITS} ===")
        train_idx, test_idx = splits[split]
        X_train, Y_train = X[train_idx], Y[train_idx]
        X_test, Y_test = X[test_idx], Y[test_idx]

        # Shuffle training data
        p = np.random.permutation(len(X_train))
        X_train, Y_train = X_train[p], Y_train[p]

        # Use data fraction
        X_train = X_train[:int(DATA_FRACTION * len(X_train))]
        Y_train = Y_train[:int(DATA_FRACTION * len(Y_train))]

        # Compute class weights
        class_counts = np.bincount(Y_train)
        class_weights = 1. / class_counts
        class_weights = class_weights / class_weights.sum()
        class_weights = torch.tensor(class_weights, dtype=torch.float32).to(DEVICE)

        # Create DataLoader
        train_dataset = ArrhythmiaDataset(X_train, Y_train)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

        # Initialize the model
        model = LearnedFilters(num_kernels=NUM_KERNELS, num_classes=3).to(DEVICE)

        # Define optimizer, scheduler, and loss function
        optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
        scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=END_FACTOR, total_iters=NUM_EPOCHS*len(train_loader))
        criterion = nn.CrossEntropyLoss(weight=class_weights)

        # Training loop
        loss_history = []
        for epoch in range(NUM_EPOCHS):
            print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
            epoch_loss = train_model(model, optimizer, scheduler, criterion, train_loader, DEVICE)
            loss_history.append(epoch_loss)
            if not USE_TQDM:
                print(f"Epoch Loss: {epoch_loss:.4f}")

        # Append the trained model and loss history
        models.append(model)
        loss_histories.append(loss_history)

        # Save the trained model
        model_path = f"saved_models/model_split_{split+1}.pt"
        torch.save(model.state_dict(), model_path)
        print(f"Saved model to {model_path}")

        # Save the loss history
        loss_path = f"loss_history/loss_split_{split+1}.pkl"
        with open(loss_path, 'wb') as f:
            pickle.dump(loss_history, f)
        print(f"Saved loss history to {loss_path}")

    # Cross-Validation Evaluation
    print("\n=== Cross-Validation Evaluation ===")
    sensitivities = []
    specificities = []
    AUCs = []
    F1s = []
    class_names = ["Normal", "Afib", "Other"]

    for split in range(NUM_SPLITS):
        print(f"\n--- Evaluating Split {split + 1} ---")
        train_idx, test_idx = splits[split]
        X_test, Y_test = X[test_idx], Y[test_idx]

        test_dataset = ArrhythmiaDataset(X_test, Y_test)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

        model = models[split]
        model_path = f"saved_models/model_split_{split+1}.pt"
        model.load_state_dict(torch.load(model_path))
        model.to(DEVICE)

        probs, ground_truth = evaluate_model(model, test_loader, DEVICE)

        sen, spec, auc, f1 = calculate_metrics(ground_truth, probs, num_classes=3)
        sensitivities.append(sen)
        specificities.append(spec)
        AUCs.append(auc)
        F1s.append(f1)

    sensitivities = np.array(sensitivities)
    specificities = np.array(specificities)
    AUCs = np.array(AUCs)
    F1s = np.array(F1s)

    print("\n=== Cross-Validation Results ===")
    print_table(sensitivities.mean(axis=0), specificities.mean(axis=0), AUCs.mean(axis=0), class_names)
    print(f"F1 Score: {F1s.mean():.3f} ± {F1s.std():.3f}")

    # ===========================
    # Holdout Set Evaluation
    # ===========================
    print("\n=== Holdout Set Evaluation ===")
    # For demonstration, we'll use the last split as holdout
    holdout_split = NUM_SPLITS - 1
    train_idx, holdout_idx = splits[holdout_split]
    X_holdout, Y_holdout = X[holdout_idx], Y[holdout_idx]
    holdout_dataset = ArrhythmiaDataset(X_holdout, Y_holdout)
    holdout_loader = DataLoader(holdout_dataset, batch_size=BATCH_SIZE, shuffle=False)

    probs = []
    ground_truth = []
    for split in range(NUM_SPLITS):
        print(f"\n--- Evaluating Model {split + 1} on Holdout Set ---")
        model = models[split]
        model_path = f"saved_models/model_split_{split+1}.pt"
        model.load_state_dict(torch.load(model_path))
        model.to(DEVICE)

        prob, gt = evaluate_model(model, holdout_loader, DEVICE)
        probs.append(prob)
        ground_truth.append(gt)

    sensitivities = []
    specificities = []
    AUCs = []
    F1s = []

    for split in range(NUM_SPLITS):
        sen, spec, auc, f1 = calculate_metrics(ground_truth[split], probs[split], num_classes=3)
        sensitivities.append(sen)
        specificities.append(spec)
        AUCs.append(auc)
        F1s.append(f1)

    sensitivities = np.array(sensitivities)
    specificities = np.array(specificities)
    AUCs = np.array(AUCs)
    F1s = np.array(F1s)

    print("\n=== Holdout Set Results ===")
    print_table(sensitivities.mean(axis=0), specificities.mean(axis=0), AUCs.mean(axis=0), class_names)
    print(f"F1 Score: {F1s.mean():.3f} ± {F1s.std():.3f}")

if __name__ == "__main__":
    main()


Loading and preprocessing the dataset...


Loading Records: 100%|██████████| 48/48 [00:07<00:00,  6.30it/s]


Total samples: 112419
Class distribution: [90477  3581 18361]

=== Split 1/10 ===
Epoch 1/16
Epoch Loss: 0.6057
Epoch 2/16
Epoch Loss: 0.4054
Epoch 3/16
Epoch Loss: 0.3383
Epoch 4/16
Epoch Loss: 0.2988
Epoch 5/16
Epoch Loss: 0.2705
Epoch 6/16
Epoch Loss: 0.2490
Epoch 7/16
Epoch Loss: 0.2333
Epoch 8/16
Epoch Loss: 0.2168
Epoch 9/16
Epoch Loss: 0.2056
Epoch 10/16
Epoch Loss: 0.1927
Epoch 11/16
Epoch Loss: 0.1822
Epoch 12/16
Epoch Loss: 0.1756
Epoch 13/16
Epoch Loss: 0.1679
Epoch 14/16
Epoch Loss: 0.1614
Epoch 15/16
Epoch Loss: 0.1571
Epoch 16/16
Epoch Loss: 0.1536
Saved model to saved_models/model_split_1.pt
Saved loss history to loss_history/loss_split_1.pkl

=== Split 2/10 ===
Epoch 1/16
Epoch Loss: 0.6148
Epoch 2/16
Epoch Loss: 0.4056
Epoch 3/16
Epoch Loss: 0.3356
Epoch 4/16
Epoch Loss: 0.2984
Epoch 5/16
Epoch Loss: 0.2689
Epoch 6/16
Epoch Loss: 0.2471
Epoch 7/16
Epoch Loss: 0.2350
Epoch 8/16
Epoch Loss: 0.2177
Epoch 9/16
Epoch Loss: 0.2049
Epoch 10/16
Epoch Loss: 0.1965
Epoch 11/16
E

  model.load_state_dict(torch.load(model_path))



--- Evaluating Split 2 ---


  model.load_state_dict(torch.load(model_path))



--- Evaluating Split 3 ---


  model.load_state_dict(torch.load(model_path))



--- Evaluating Split 4 ---


  model.load_state_dict(torch.load(model_path))



--- Evaluating Split 5 ---


  model.load_state_dict(torch.load(model_path))



--- Evaluating Split 6 ---


  model.load_state_dict(torch.load(model_path))



--- Evaluating Split 7 ---


  model.load_state_dict(torch.load(model_path))



--- Evaluating Split 8 ---


  model.load_state_dict(torch.load(model_path))



--- Evaluating Split 9 ---


  model.load_state_dict(torch.load(model_path))



--- Evaluating Split 10 ---


  model.load_state_dict(torch.load(model_path))



=== Cross-Validation Results ===
Class     Sensitivity    Specificity    AUC       
Normal    0.939          0.957          0.988     
Afib      0.869          0.965          0.972     
Other     0.947          0.977          0.993     
F1 Score: 0.825 ± 0.165

=== Holdout Set Evaluation ===

--- Evaluating Model 1 on Holdout Set ---


  model.load_state_dict(torch.load(model_path))



--- Evaluating Model 2 on Holdout Set ---

--- Evaluating Model 3 on Holdout Set ---

--- Evaluating Model 4 on Holdout Set ---

--- Evaluating Model 5 on Holdout Set ---

--- Evaluating Model 6 on Holdout Set ---

--- Evaluating Model 7 on Holdout Set ---

--- Evaluating Model 8 on Holdout Set ---

--- Evaluating Model 9 on Holdout Set ---

--- Evaluating Model 10 on Holdout Set ---

=== Holdout Set Results ===
Class     Sensitivity    Specificity    AUC       
Normal    0.939          0.967          0.991     
Afib      0.917          0.965          0.984     
Other     0.955          0.978          0.995     
F1 Score: 0.832 ± 0.160
