In [1]:
# General imports
import os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from collections import Counter
import wfdb
import pandas as pd

# Device setup
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

Using device: cpu


In [3]:
# ----------------------------
# 1. Extract beats from a record
# ----------------------------
def extract_beats(record_name: str, window: int = 360, db_path: str = "./data/mitdb/1.0.0/"):
    """
    Extract beats and annotations from a single MIT-BIH record
    """
    record = wfdb.rdrecord(f"{db_path}/{record_name}")
    annotation = wfdb.rdann(f"{db_path}/{record_name}", 'atr')
    signal = record.p_signal[:, 0]  # use lead 0
    beats, labels = [], []

    for i, ann_sample in enumerate(annotation.sample):
        if ann_sample - window < 0 or ann_sample + window >= len(signal):
            continue
        beat = signal[ann_sample - window: ann_sample + window]
        beats.append(beat)
        labels.append(annotation.symbol[i])

    return np.array(beats), np.array(labels)


# ----------------------------
# 2. Load multiple records
# ----------------------------
def load_dataset(records, window=360, db_path="data/physionet.org/files/mitdb/1.0.0/"):
    """
    Load multiple records into X (beats) and y (labels)
    """
    all_beats, all_labels = [], []
    for rec in records:
        print(f"Extracting {rec}...")
        beats, labels = extract_beats(rec, window, db_path)
        all_beats.append(beats)
        all_labels.append(labels)
    X = np.concatenate(all_beats, axis=0)
    y = np.concatenate(all_labels, axis=0)
    return X, y


# ----------------------------
# 3. Normalize beats
# ----------------------------
def normalize(X):
    """Z-score normalization per beat"""
    return (X - X.mean(axis=1, keepdims=True)) / (X.std(axis=1, keepdims=True) + 1e-6)


# ----------------------------
# 4. Train/Val/Test split
# ----------------------------
def split_data(X, y, test_size=0.2, val_size=0.1):
    """
    Stratified split into train/val/test sets
    """
    X_train, X_temp, y_train, y_temp = train_test_split(
        X, y, test_size=test_size, random_state=42, stratify=y
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=val_size, random_state=42, stratify=y_temp
    )
    return X_train, X_val, X_test, y_train, y_val, y_test


# ----------------------------
# 5. Format for CNN & LSTM
# ----------------------------
def prepare_for_cnn(X):
    """(samples, length, channels) -> (samples, 1, length)"""
    return X.reshape((X.shape[0], 1, X.shape[1]))

def prepare_for_lstm(X):
    """(samples, length, channels) -> (samples, length, 1)"""
    return X.reshape((X.shape[0], X.shape[1], 1))


In [4]:
# ----------------------------
# Filter & encode labels
# ----------------------------
def filter_and_encode_labels(y: np.ndarray, min_count: int = 5):
    """
    Filters rare labels and encodes remaining labels to integers
    """
    counter = Counter(y)
    keep = {lbl for lbl, c in counter.items() if c >= min_count}
    mask = np.array([lbl in keep for lbl in y])
    y_filtered = y[mask]
    unique = sorted(list({lbl for lbl in y_filtered}))
    label_to_int = {lbl: i for i, lbl in enumerate(unique)}
    y_encoded = np.array([label_to_int[lbl] for lbl in y_filtered])
    return y_encoded, label_to_int, mask


# ----------------------------
# ECG Dataset
# ----------------------------
class ECGDataset(Dataset):
    def __init__(self, X, y):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
    def __len__(self):
        return len(self.X)
    def __getitem__(self, idx):
        return torch.from_numpy(self.X[idx]), torch.tensor(self.y[idx], dtype=torch.long)


# ----------------------------
# Prepare data for CNN
# ----------------------------
def prepare_for_cnn(X: np.ndarray) -> np.ndarray:
    """
    Reshape X to (samples, length, 1) for collate_cnn to permute to (batch, 1, length)
    """
    return X.reshape((X.shape[0], X.shape[1], 1))  # (N, L, 1)


# ----------------------------
# Collate function
# ----------------------------
def collate_cnn(batch):
    """
    Convert list of (x, y) to (batch, 1, L) for Conv1d
    """
    xs, ys = zip(*batch)
    xs = torch.stack(xs)          # (batch, L, 1)
    xs = xs.permute(0, 2, 1)      # (batch, 1, L) âœ… channels first
    ys = torch.stack([y for y in ys])
    return xs, ys


def collate_lstm(batch):
    xs, ys = zip(*batch)
    xs = torch.stack(xs)
    ys = torch.stack(ys)
    return xs, ys


# ----------------------------
# Checkpoint functions
# ----------------------------
def save_checkpoint(state, path):
    torch.save(state, path)

def load_checkpoint(path: str, device: str = "cpu"):
    import torch
    # safe load with full checkpoint (not weights-only)
    return torch.load(path, map_location=device, weights_only=False)



In [5]:
# Simple CNN
class Simple1DCNN(nn.Module):
    def __init__(self, seq_len, n_classes):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(1, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.AdaptiveMaxPool1d(1)
        )
        self.fc = nn.Linear(64, n_classes)

    def forward(self, x):
        x = self.conv(x).squeeze(-1)
        return self.fc(x)


# Simple LSTM
class SimpleLSTM(nn.Module):
    def __init__(self, seq_len, n_classes, hidden_size=64, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, n_classes)

    def forward(self, x):
        out, _ = self.lstm(x)
        out = out[:, -1, :]  # last timestep
        return self.fc(out)


In [6]:
# MIT-BIH record splits
train_records = ["100","101","102","103","104","105","106","107","108","109",
                 "111","112","113","114","115","116","117","118","119","121","122","123","124"]
test_records = ["200","201","202","203","205","207","208","209","210","212",
                "213","214","215","219","220","221","222","223","228","230","231","232","233","234"]

# Load and normalize
X_train_raw, y_train_raw = load_dataset(train_records)
X_test_raw, y_test_raw = load_dataset(test_records)
X = np.concatenate([X_train_raw, X_test_raw])
y = np.concatenate([y_train_raw, y_test_raw])
X = normalize(X)

# Filter & encode
y, label_map, mask = filter_and_encode_labels(y, min_count=20)
X = X[mask]

# Split
X_train, X_val, X_test, y_train, y_val, y_test = split_data(X, y)

# Format for models
X_train_cnn = prepare_for_cnn(X_train)
X_val_cnn = prepare_for_cnn(X_val)
X_test_cnn = prepare_for_cnn(X_test)

X_train_lstm = prepare_for_lstm(X_train)
X_val_lstm = prepare_for_lstm(X_val)
X_test_lstm = prepare_for_lstm(X_test)

print("Shapes:")
print("CNN Train:", X_train_cnn.shape, "LSTM Train:", X_train_lstm.shape)
print("Num classes:", len(label_map))


Extracting 100...
Extracting 101...
Extracting 102...
Extracting 103...
Extracting 104...
Extracting 105...
Extracting 106...
Extracting 107...
Extracting 108...
Extracting 109...
Extracting 111...
Extracting 112...
Extracting 113...
Extracting 114...
Extracting 115...
Extracting 116...
Extracting 117...
Extracting 118...
Extracting 119...
Extracting 121...
Extracting 122...
Extracting 123...
Extracting 124...
Extracting 200...
Extracting 201...
Extracting 202...
Extracting 203...
Extracting 205...
Extracting 207...
Extracting 208...
Extracting 209...
Extracting 210...
Extracting 212...
Extracting 213...
Extracting 214...
Extracting 215...
Extracting 219...
Extracting 220...
Extracting 221...
Extracting 222...
Extracting 223...
Extracting 228...
Extracting 230...
Extracting 231...
Extracting 232...
Extracting 233...
Extracting 234...
Shapes:
CNN Train: (88136, 720, 1) LSTM Train: (88136, 720, 1)
Num classes: 19


In [None]:
def train_model(model_name="cnn", epochs=10, batch_size=128, lr=1e-3):
    """
    Train and evaluate a model (CNN or LSTM)
    Returns best checkpoint path and metrics
    """
    if model_name=="cnn":
        collate_fn = collate_cnn
        X_tr, X_val_use = X_train_cnn, X_val_cnn
    else:
        collate_fn = collate_lstm
        X_tr, X_val_use = X_train_lstm, X_val_lstm

    train_ds = ECGDataset(X_tr, y_train)
    val_ds = ECGDataset(X_val_use, y_val)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    n_classes = len(label_map)
    seq_len = X.shape[1]
    model = Simple1DCNN(seq_len, n_classes).to(DEVICE) if model_name=="cnn" else SimpleLSTM(seq_len, n_classes).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    best_val_acc = 0
    best_path = f"best_{model_name}.pth"
    metrics_history = []

    for epoch in range(1, epochs+1):
        model.train()
        losses = []
        t0 = time.time()
        for xb, yb in train_loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        model.eval()
        preds, trues = [], []
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                logits = model(xb)
                preds.append(torch.argmax(logits, 1).cpu().numpy())
                trues.append(yb.cpu().numpy())
        preds = np.concatenate(preds)
        trues = np.concatenate(trues)
        val_acc = accuracy_score(trues, preds)
        metrics_history.append((epoch, np.mean(losses), val_acc))

        print(f"{model_name.upper()} Epoch {epoch}: Loss={np.mean(losses):.4f} ValAcc={val_acc:.4f} Time={time.time()-t0:.1f}s")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            save_checkpoint({
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "label_map": label_map
            }, best_path)

    return best_path, metrics_history


In [None]:
best_lstm_path, lstm_metrics = train_model("lstm", epochs=5)

# Evaluate
val_ds = ECGDataset(X_test_lstm, y_test)
val_loader = DataLoader(val_ds, batch_size=128, shuffle=False, collate_fn=collate_lstm)

model = SimpleLSTM(X.shape[1], len(label_map)).to(DEVICE)
checkpoint = load_checkpoint(best_lstm_path, DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

preds, trues = [], []
with torch.no_grad():
    for xb, yb in val_loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        logits = model(xb)
        preds.append(torch.argmax(logits, 1).cpu().numpy())
        trues.append(yb.cpu().numpy())

preds = np.concatenate(preds)
trues = np.concatenate(trues)
lstm_acc = accuracy_score(trues, preds)
print("LSTM Test Accuracy:", lstm_acc)


LSTM Epoch 1: Loss=1.2852 ValAcc=0.6911 Time=197.3s
LSTM Test Accuracy: 0.6873865698729582


In [None]:
best_cnn_path, cnn_metrics = train_model("cnn", epochs=5)

# Evaluate
val_ds = ECGDataset(X_test_cnn, y_test)
val_loader = DataLoader(val_ds, batch_size=128, shuffle=False, collate_fn=collate_cnn)

model = Simple1DCNN(X.shape[1], len(label_map)).to(DEVICE)
checkpoint = load_checkpoint(best_cnn_path, DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

preds, trues = [], []
with torch.no_grad():
    for xb, yb in val_loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        logits = model(xb)
        preds.append(torch.argmax(logits, 1).cpu().numpy())
        trues.append(yb.cpu().numpy())

preds = np.concatenate(preds)
trues = np.concatenate(trues)
cnn_acc = accuracy_score(trues, preds)
print("CNN Test Accuracy:", cnn_acc)


CNN Epoch 1: Loss=0.8500 ValAcc=0.8377 Time=155.5s
CNN Test Accuracy: 0.838929219600726


In [11]:
metrics_df = pd.DataFrame({
    "Model": ["LSTM", "CNN"],
    "Test Accuracy": [lstm_acc, cnn_acc],
    "Best Epoch Loss": [min([m[1] for m in lstm_metrics]), min([m[1] for m in cnn_metrics])],
    "Best Epoch Val Acc": [max([m[2] for m in lstm_metrics]), max([m[2] for m in cnn_metrics])]
})

metrics_df


Unnamed: 0,Model,Test Accuracy,Best Epoch Loss,Best Epoch Val Acc
0,LSTM,0.687387,1.285218,0.691125
1,CNN,0.838929,0.849978,0.837721
