In [None]:
import os
import wfdb
import numpy as np
import pandas as pd
from scipy.signal import butter, filtfilt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random

In [10]:
# =========================
# 1. CONFIG
# =========================
MITDB_PATH = "mit-bih-arrhythmia-database-1.0.0"
WINDOW_SIZE = 180     # samples before & after R-peak
SAMPLE_RATE = 360
BATCH_SIZE = 128
EPOCHS = 20
LEARNING_RATE = 1e-3
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

<torch._C.Generator at 0x2570a4fb050>

In [11]:
# =========================
# BANDPASS FILTER (0.5–40 Hz)
# =========================
def bandpass_filter(signal, lowcut=0.5, highcut=40, fs=SAMPLE_RATE, order=3):
    nyq = 0.5 * fs
    b, a = butter(order, [lowcut/nyq, highcut/nyq], btype='band')
    return filtfilt(b, a, signal)

In [12]:
# =========================
# DOWNLOAD DATASET
# =========================
def download_mitbih():
    os.makedirs(MITDB_PATH, exist_ok=True)
    records = wfdb.get_record_list("mitdb")
    print(f"Found {len(records)} records.")
    return records

In [13]:
# =========================
# EXTRACT BEATS
# =========================
def extract_beats(records):
    beats, labels = [], []
    for rec in tqdm(records, desc="Extracting beats"):
        record = wfdb.rdrecord(rec, pn_dir="mitdb")
        ann = wfdb.rdann(rec, 'atr', pn_dir="mitdb")
        sig = bandpass_filter(record.p_signal[:, 0])  # MLII lead filtered

        for idx, sym in zip(ann.sample, ann.symbol):
            if sym not in ['N', 'L', 'R', 'A', 'V', 'F']:
                continue
            start, end = idx - WINDOW_SIZE, idx + WINDOW_SIZE
            if start < 0 or end > len(sig):
                continue
            segment = sig[start:end]
            segment = (segment - np.mean(segment)) / (np.std(segment) + 1e-8)
            beats.append(segment)
            labels.append(sym)
    beats, labels = np.array(beats), np.array(labels)
    print(f"Extracted {len(beats)} beats.")
    return beats, labels

In [14]:
# =========================
# DATA AUGMENTATION
# =========================
def augment_signal(x):
    if random.random() < 0.3:
        noise = np.random.normal(0, 0.02, x.shape)
        x = x + noise
    if random.random() < 0.3:
        drift = 0.05 * np.sin(2 * np.pi * np.linspace(0, 1, len(x)) * random.uniform(0.1, 0.3))
        x = x + drift
    if random.random() < 0.3:
        scale = random.uniform(0.9, 1.1)
        x = x * scale
    return x.astype(np.float32)

In [15]:
# =========================
# DATASET CLASS
# =========================
class ECGDataset(Dataset):
    def __init__(self, X, y, augment=False):
        self.X, self.y = X, y
        self.augment = augment

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        x = self.X[idx]
        if self.augment:
            x = augment_signal(x)
        x = torch.tensor(x, dtype=torch.float32).unsqueeze(0)
        y = torch.tensor(self.y[idx], dtype=torch.long)
        return x, y

In [16]:
# =========================
# 1D CNN MODEL
# =========================
class CNN1D(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(1, 32, 7, padding=3), nn.BatchNorm1d(32), nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(32, 64, 5, padding=2), nn.BatchNorm1d(64), nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(64, 128, 3, padding=1), nn.BatchNorm1d(128), nn.ReLU(),
            nn.AdaptiveAvgPool1d(1)
        )
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.net(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)


In [17]:
# =========================
# TRAIN FUNCTION
# =========================
def train_model(model, train_loader, val_loader, class_weights):
    criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_weights, dtype=torch.float).to(DEVICE))
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    for epoch in range(EPOCHS):
        model.train()
        running_loss, correct, total = 0, 0, 0
        for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            x, y = x.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            pred = out.argmax(1)
            correct += (pred == y).sum().item()
            total += y.size(0)
        acc = correct / total
        print(f"Train Loss: {running_loss/len(train_loader):.4f} | Acc: {acc:.4f}")

        # Validation
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(DEVICE), y.to(DEVICE)
                out = model(x)
                pred = out.argmax(1)
                correct += (pred == y).sum().item()
                total += y.size(0)
        print(f"Validation Accuracy: {correct/total:.4f}")

In [18]:
# =========================
# MAIN
# =========================
if __name__ == "__main__":
    records = download_mitbih()
    X, y = extract_beats(records)
    le = LabelEncoder()
    y_enc = le.fit_transform(y)

    X_train, X_val, y_train, y_val = train_test_split(X, y_enc, test_size=0.2, random_state=SEED, stratify=y_enc)
    class_counts = np.bincount(y_enc)
    class_weights = np.max(class_counts) / (class_counts + 1e-8)
    print("Class weights:", class_weights)

    train_ds = ECGDataset(X_train, y_train, augment=True)
    val_ds = ECGDataset(X_val, y_val, augment=False)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

    model = CNN1D(num_classes=len(le.classes_)).to(DEVICE)
    print(f"Training on {DEVICE} for classes: {le.classes_}")
    train_model(model, train_loader, val_loader, class_weights)

    # Evaluation
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for x, y in val_loader:
            x = x.to(DEVICE)
            out = model(x)
            preds = out.argmax(1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(y.numpy())

    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds, target_names=le.classes_))
    print("\nConfusion Matrix:")
    print(confusion_matrix(all_labels, all_preds))
    torch.save(model.state_dict(), "cnn1d_mitbih_augmented.pth")
    print("✅ Saved model: cnn1d_mitbih_augmented.pth")

Found 48 records.


Extracting beats: 100%|██████████| 48/48 [04:35<00:00,  5.73s/it]


Extracted 100814 beats.
Class weights: [29.46229379 93.52992519  9.29389171  1.         10.33921433 10.52195259]
Training on cuda for classes: ['A' 'F' 'L' 'N' 'R' 'V']


Epoch 1/20: 100%|██████████| 631/631 [00:04<00:00, 142.24it/s]


Train Loss: 0.7845 | Acc: 0.7202
Validation Accuracy: 0.6639


Epoch 2/20: 100%|██████████| 631/631 [00:04<00:00, 134.08it/s]


Train Loss: 0.3853 | Acc: 0.8519
Validation Accuracy: 0.8050


Epoch 3/20: 100%|██████████| 631/631 [00:04<00:00, 140.51it/s]


Train Loss: 0.3239 | Acc: 0.8751
Validation Accuracy: 0.8663


Epoch 4/20: 100%|██████████| 631/631 [00:04<00:00, 138.12it/s]


Train Loss: 0.2906 | Acc: 0.8816
Validation Accuracy: 0.7567


Epoch 5/20: 100%|██████████| 631/631 [00:04<00:00, 137.61it/s]


Train Loss: 0.2587 | Acc: 0.8930
Validation Accuracy: 0.9278


Epoch 6/20: 100%|██████████| 631/631 [00:04<00:00, 136.64it/s]


Train Loss: 0.2476 | Acc: 0.8975
Validation Accuracy: 0.9209


Epoch 7/20: 100%|██████████| 631/631 [00:04<00:00, 130.73it/s]


Train Loss: 0.2403 | Acc: 0.8989
Validation Accuracy: 0.9109


Epoch 8/20: 100%|██████████| 631/631 [00:04<00:00, 140.90it/s]


Train Loss: 0.2357 | Acc: 0.9031
Validation Accuracy: 0.9134


Epoch 9/20: 100%|██████████| 631/631 [00:04<00:00, 139.12it/s]


Train Loss: 0.2127 | Acc: 0.9083
Validation Accuracy: 0.9295


Epoch 10/20: 100%|██████████| 631/631 [00:04<00:00, 140.10it/s]


Train Loss: 0.2115 | Acc: 0.9103
Validation Accuracy: 0.6949


Epoch 11/20: 100%|██████████| 631/631 [00:04<00:00, 142.93it/s]


Train Loss: 0.2066 | Acc: 0.9109
Validation Accuracy: 0.9468


Epoch 12/20: 100%|██████████| 631/631 [00:04<00:00, 132.07it/s]


Train Loss: 0.2008 | Acc: 0.9163
Validation Accuracy: 0.9254


Epoch 13/20: 100%|██████████| 631/631 [00:04<00:00, 142.33it/s]


Train Loss: 0.1920 | Acc: 0.9154
Validation Accuracy: 0.9593


Epoch 14/20: 100%|██████████| 631/631 [00:04<00:00, 140.24it/s]


Train Loss: 0.1852 | Acc: 0.9196
Validation Accuracy: 0.9317


Epoch 15/20: 100%|██████████| 631/631 [00:04<00:00, 138.91it/s]


Train Loss: 0.1843 | Acc: 0.9227
Validation Accuracy: 0.9253


Epoch 16/20: 100%|██████████| 631/631 [00:04<00:00, 140.95it/s]


Train Loss: 0.1803 | Acc: 0.9212
Validation Accuracy: 0.9631


Epoch 17/20: 100%|██████████| 631/631 [00:04<00:00, 131.34it/s]


Train Loss: 0.1711 | Acc: 0.9240
Validation Accuracy: 0.9113


Epoch 18/20: 100%|██████████| 631/631 [00:04<00:00, 142.38it/s]


Train Loss: 0.1752 | Acc: 0.9246
Validation Accuracy: 0.9316


Epoch 19/20: 100%|██████████| 631/631 [00:04<00:00, 137.19it/s]


Train Loss: 0.1640 | Acc: 0.9264
Validation Accuracy: 0.9415


Epoch 20/20: 100%|██████████| 631/631 [00:04<00:00, 138.12it/s]


Train Loss: 0.1620 | Acc: 0.9273
Validation Accuracy: 0.9064

Classification Report:
              precision    recall  f1-score   support

           A       0.39      0.89      0.55       509
           F       0.24      0.88      0.37       161
           L       0.97      0.99      0.98      1614
           N       1.00      0.88      0.94     15002
           R       0.98      0.99      0.98      1451
           V       0.70      0.97      0.81      1426

    accuracy                           0.91     20163
   macro avg       0.71      0.93      0.77     20163
weighted avg       0.95      0.91      0.92     20163


Confusion Matrix:
[[  454     2     0    40     6     7]
 [    0   142     0     5     0    14]
 [    0     0  1597     3     0    14]
 [  676   440    46 13264    22   554]
 [   15     0     0     2  1433     1]
 [    8    17     4     9     3  1385]]
✅ Saved model: cnn1d_mitbih_augmented.pth
