<a href="https://colab.research.google.com/github/Tarwish2005/-ECG-Arrhythmia-Detection/blob/main/Lstm%2BCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip -q install wfdb numpy scipy scikit-learn torch torchaudio torchvision matplotlib


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.2/91.2 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.8/163.8 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.0/12.0 MB[0m [31m47.8 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.3.1 which is incompatible.
cudf-cu12 25.6.0 requires pandas<2.2.4dev0,>=2.0, but you have pandas 2.3.1 which is incompatible.
dask-cudf-cu12 25.6.0 requires pandas<2.2.4dev0,>=2.0, but you have pandas 2.3.1 which is incompatible.[0m[31m
[0m

In [None]:
# If needed (e.g., in Colab):
# !pip install wfdb numpy scipy scikit-learn torch torchvision matplotlib

import os
import numpy as np
import wfdb
import matplotlib.pyplot as plt

from scipy.signal import butter, filtfilt, resample
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader


#Parameter

In [None]:
VFDB_DIR   = "/content/drive/MyDrive/ECG/VFDB"
CUVTDB_DIR = "/content/drive/MyDrive/ECG/CUVTDB"
ADB_DIR    = "/content/drive/MyDrive/ECG/MITDB"   # (use Arrhythmia DB path if different)
NSTDB_DIR  = "/content/drive/MyDrive/ECG/NSTDB"       # contains 'ma' noise
TARGET_FS  = 250
WIN_SEC    = 5
WIN_SAMPLES = TARGET_FS * WIN_SEC


#Preprocessing


In [None]:
def highpass_filter(signal, fs=250, cutoff=1.0, order=4):
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    b, a = butter(order, normal_cutoff, btype="high", analog=False)
    return filtfilt(b, a, signal)

def preprocess_ecg(signal, fs=250, target_fs=250, window_size=5):
    # 1) resample
    if fs != target_fs:
        n_samples = int(len(signal) * target_fs / fs)
        signal = resample(signal, n_samples)

    # 2) remove baseline wander
    signal = highpass_filter(signal, fs=target_fs, cutoff=1.0, order=4)

    # 3) z-score normalize
    mu, sigma = np.mean(signal), np.std(signal) + 1e-8
    signal = (signal - mu) / sigma

    # 4) segment into non-overlapping windows
    segment_length = window_size * target_fs
    n_segments = len(signal) // segment_length
    if n_segments == 0:
        return np.array([])

    segments = []
    for i in range(n_segments):
        seg = signal[i*segment_length:(i+1)*segment_length]
        if not (np.isnan(seg).any() or np.isinf(seg).any()):
            segments.append(seg)

    return np.asarray(segments)


#DATA

In [None]:
def _list_records(db_path):
    return [f.split('.')[0] for f in os.listdir(db_path) if f.endswith('.dat')]

def load_mitdb(db_path):
    X, y = [], []
    if not os.path.exists(db_path):
        print(f"Warning: {db_path} does not exist"); return X, y
    records = _list_records(db_path)
    print(f"Found {len(records)} MITDB records")

    for rec_name in records:
        try:
            rec = wfdb.rdrecord(os.path.join(db_path, rec_name))
            if rec.p_signal is not None and rec.p_signal.shape[1] > 0:
                sig = rec.p_signal[:, 0]
                segments = preprocess_ecg(sig, fs=rec.fs)
                if len(segments) > 0:
                    X.extend(segments)     # Non-VT/VF segments
                    y.extend([0]*len(segments))
        except Exception as e:
            print(f"MITDB error {rec_name}: {e}")
    print(f"MITDB: {len(X)} segments")
    return X, y

def load_vfdb(db_path):
    X, y = [], []
    if not os.path.exists(db_path):
        print(f"Warning: {db_path} does not exist"); return X, y
    records = _list_records(db_path)
    print(f"Found {len(records)} VFDB records")

    for rec_name in records:
        try:
            rec = wfdb.rdrecord(os.path.join(db_path, rec_name))
            if rec.p_signal is not None and rec.p_signal.shape[1] > 0:
                sig = rec.p_signal[:, 0]
                segments = preprocess_ecg(sig, fs=rec.fs)
                if len(segments) > 0:
                    X.extend(segments)     # VT/VF segments
                    y.extend([1]*len(segments))
        except Exception as e:
            print(f"VFDB error {rec_name}: {e}")
    print(f"VFDB: {len(X)} segments")
    return X, y

def load_cuvtdb(db_path):
    X, y = [], []
    if not os.path.exists(db_path):
        print(f"Warning: {db_path} does not exist"); return X, y
    records = _list_records(db_path)
    print(f"Found {len(records)} CUVTDB records")

    for rec_name in records:
        try:
            rec = wfdb.rdrecord(os.path.join(db_path, rec_name))
            if rec.p_signal is not None and rec.p_signal.shape[1] > 0:
                sig = rec.p_signal[:, 0]
                segments = preprocess_ecg(sig, fs=rec.fs)
                if len(segments) > 0:
                    X.extend(segments)     # VT/VF segments
                    y.extend([1]*len(segments))
        except Exception as e:
            print(f"CUVTDB error {rec_name}: {e}")
    print(f"CUVTDB: {len(X)} segments")
    return X, y

def load_nstdb_artifacts(db_path, target_fs=250, window_size=5):
    """
    Return artifact windows from NSTDB to use as 'ma' (muscle artifact) sources.
    We do NOT label them here; they’re used to synthesize the Noisy class.
    """
    MA = []
    if not os.path.exists(db_path):
        print(f"Warning: {db_path} does not exist"); return MA

    records = _list_records(db_path)
    print(f"Found {len(records)} NSTDB records (as artifact sources)")

    for rec_name in records:
        try:
            rec = wfdb.rdrecord(os.path.join(db_path, rec_name))
            if rec.p_signal is not None and rec.p_signal.shape[1] > 0:
                sig = rec.p_signal[:, 0]
                # NOTE: We do minimal preprocessing (no high-pass), since it’s artifact
                # but we standardize to target_fs and windowing to align lengths.
                if rec.fs != target_fs:
                    n_samples = int(len(sig) * target_fs / rec.fs)
                    sig = resample(sig, n_samples)
                # z-norm artifact
                sig = (sig - np.mean(sig)) / (np.std(sig) + 1e-8)

                seg_len = window_size * target_fs
                n_segments = len(sig) // seg_len
                for i in range(n_segments):
                    MA.append(sig[i*seg_len:(i+1)*seg_len])
        except Exception as e:
            print(f"NSTDB error {rec_name}: {e}")
    print(f"NSTDB artifact windows: {len(MA)}")
    return np.asarray(MA)


In [None]:

VFDB_DIR   = "/content/drive/MyDrive/ECG/VFDB"
CUVTDB_DIR = "/content/drive/MyDrive/ECG/CUVTDB"
ADB_DIR    = "/content/drive/MyDrive/ECG/MITDB"
NSTDB_DIR  = "/content/drive/MyDrive/ECG/NSTDB"

# Load base datasets
non_vf_X, non_vf_y = load_mitdb(ADB_DIR)        # class 0
vf1_X, vf1_y       = load_vfdb(VFDB_DIR)        # class 1
vf2_X, vf2_y       = load_cuvtdb(CUVTDB_DIR)    # class 1
ma_X               = load_nstdb_artifacts(NSTDB_DIR)  # unlabeled artifact windows

# Combine VT/VF
vf_X = vf1_X + vf2_X
vf_y = vf1_y + vf2_y

print(f"Non-VF segments: {len(non_vf_X)} | VF segments: {len(vf_X)} | MA windows: {len(ma_X)}")


Found 48 MITDB records
MITDB: 17328 segments
Found 22 VFDB records
VFDB: 9240 segments
Found 39 CUVTDB records
CUVTDB: 707 segments
Found 30 NSTDB records (as artifact sources)
NSTDB artifact windows: 10830
Non-VF segments: 17328 | VF segments: 9947 | MA windows: 10830


In [None]:
rng = np.random.default_rng(42)

def add_gaussian_noise(seg, scale=0.05):
    noise = rng.normal(0.0, 1.0, size=seg.shape)
    return seg + scale * noise

def synth_noisy_from_nonvf(seg, ma_bank, a=0.3, gauss_scale=0.08):
    # pick a random MA window, same length
    ma = ma_bank[rng.integers(0, len(ma_bank))]
    ma = ma[:len(seg)]
    # Non-VT/VF_noisy = Non-VT/VF_clean + a·w_n + ma
    out = seg.copy()
    out = out + a * rng.normal(0.0, 1.0, size=seg.shape)  # a·w_n
    out = out + ma                                        # + ma
    return add_gaussian_noise(out, scale=gauss_scale)     # extra small Gaussian

def augment_vf(seg, b=0.02):
    # VT/VF_aug = VT/VF_clean + b·w_n   (small)
    return add_gaussian_noise(seg, scale=b)


In [None]:
# Convert lists to arrays
non_vf_X = np.asarray(non_vf_X)
vf_X     = np.asarray(vf_X)

print("Shapes → Non-VF:", non_vf_X.shape, "| VF:", vf_X.shape)

# Create synthetic Noisy class from Non-VT/VF + Gaussian + MA
noisy_syn = []
for seg in non_vf_X:
    noisy_syn.append(synth_noisy_from_nonvf(seg, ma_X, a=0.3, gauss_scale=0.08))
noisy_syn = np.asarray(noisy_syn)

# Augment VF with light Gaussian
vf_aug = []
for seg in vf_X:
    vf_aug.append(augment_vf(seg, b=0.02))
vf_aug = np.asarray(vf_aug)

# Assemble X, y
X = np.concatenate([non_vf_X, vf_X, noisy_syn, vf_aug], axis=0)
y = np.concatenate([
    np.zeros(len(non_vf_X), dtype=int),   # class 0: Non-VF
    np.ones(len(vf_X), dtype=int),        # class 1: VF
    2*np.ones(len(noisy_syn), dtype=int), # class 2: Noisy (synthetic)
    np.ones(len(vf_aug), dtype=int),      # class 1: VF (augmented)
])

print("Final X:", X.shape, "| y counts:", dict(zip(*np.unique(y, return_counts=True))))


Shapes → Non-VF: (17328, 1250) | VF: (9947, 1250)
Final X: (54550, 1250) | y counts: {np.int64(0): np.int64(17328), np.int64(1): np.int64(19894), np.int64(2): np.int64(17328)}


In [None]:
# Split
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.30, random_state=42, stratify=y
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.50, random_state=42, stratify=y_temp
)

print(f"Train: {len(X_train)} | Val: {len(X_val)} | Test: {len(X_test)}")

# To tensors (shape: [N, 1, 1250])
X_train_t = torch.tensor(X_train, dtype=torch.float32).unsqueeze(1)
X_val_t   = torch.tensor(X_val,   dtype=torch.float32).unsqueeze(1)
X_test_t  = torch.tensor(X_test,  dtype=torch.float32).unsqueeze(1)

y_train_t = torch.tensor(y_train, dtype=torch.long)
y_val_t   = torch.tensor(y_val,   dtype=torch.long)
y_test_t  = torch.tensor(y_test,  dtype=torch.long)

train_ds = TensorDataset(X_train_t, y_train_t)
val_ds   = TensorDataset(X_val_t,   y_val_t)
test_ds  = TensorDataset(X_test_t,  y_test_t)

batch_size = 32
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=False)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, drop_last=False)
test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, drop_last=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


Train: 38185 | Val: 8182 | Test: 8183


device(type='cuda')

#Model

In [None]:
class ECG_CNN_LSTM(nn.Module):
    def __init__(self, input_length=1250, num_classes=3, hidden_size=128, lstm_layers=2):
        super().__init__()
        # CNN branch
        self.conv1 = nn.Conv1d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(32); self.pool1 = nn.MaxPool1d(2)

        self.conv2 = nn.Conv1d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(64); self.pool2 = nn.MaxPool1d(2)

        self.conv3 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(128); self.pool3 = nn.MaxPool1d(2)

        self.conv4 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm1d(256); self.pool4 = nn.MaxPool1d(2)

        self.conv5 = nn.Conv1d(256, 512, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm1d(512); self.pool5 = nn.MaxPool1d(2)

        # 1250 → 625 → 312 → 156 → 78 → 39
        self.flattened_size = 512 * 39

        # LSTM branch (bidirectional)
        self.lstm = nn.LSTM(
            input_size=1, hidden_size=hidden_size,
            num_layers=lstm_layers, batch_first=True, bidirectional=True
        )

        # Fusion + FC
        self.fc1 = nn.Linear(self.flattened_size + 2*hidden_size, 256)
        self.drop1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 128)
        self.drop2 = nn.Dropout(0.3)
        self.fc3 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        # x: [B, 1, L]
        # CNN branch
        c = self.pool1(self.relu(self.bn1(self.conv1(x))))
        c = self.pool2(self.relu(self.bn2(self.conv2(c))))
        c = self.pool3(self.relu(self.bn3(self.conv3(c))))
        c = self.pool4(self.relu(self.bn4(self.conv4(c))))
        c = self.pool5(self.relu(self.bn5(self.conv5(c))))
        c = c.view(c.size(0), -1)  # [B, 512*39]

        # LSTM branch (expects [B, L, 1])
        l = x.permute(0, 2, 1)
        l, _ = self.lstm(l)
        l = l[:, -1, :]  # last time step, [B, 2*hidden]

        # Fuse
        f = torch.cat([c, l], dim=1)
        features = self.relu(self.fc1(f))
        features = self.drop1(features)
        z = self.relu(self.fc2(features))
        z = self.drop2(z)
        logits = self.fc3(z)
        return logits, features


In [None]:
model = ECG_CNN_LSTM(input_length=X.shape[1], num_classes=3, hidden_size=128, lstm_layers=2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

epochs = 100
train_losses, val_losses = [], []
train_accs, val_accs = [], []

def accuracy_from_logits(logits, y_true):
    preds = logits.argmax(dim=1)
    return (preds == y_true).float().mean().item()

print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
for epoch in range(epochs):
    # train
    model.train()
    tloss, tcorrect, tcount = 0.0, 0, 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        logits, _ = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        tloss += loss.item()
        tcorrect += (logits.argmax(1) == yb).sum().item()
        tcount += yb.size(0)

    # val
    model.eval()
    vloss, vcorrect, vcount = 0.0, 0, 0
    with torch.no_grad():
        for xb, yb in val_loader:
            xb, yb = xb.to(device), yb.to(device)
            logits, _ = model(xb)
            loss = criterion(logits, yb)
            vloss += loss.item()
            vcorrect += (logits.argmax(1) == yb).sum().item()
            vcount += yb.size(0)

    train_losses.append(tloss/len(train_loader))
    val_losses.append(vloss/len(val_loader))
    train_accs.append(tcorrect/tcount)
    val_accs.append(vcorrect/vcount)

    if (epoch+1) % 10 == 0 or epoch == 0:
        print(f"Epoch {epoch+1:03d}/{epochs} | "
              f"Train Loss {train_losses[-1]:.4f} Acc {train_accs[-1]:.4f} | "
              f"Val Loss {val_losses[-1]:.4f} Acc {val_accs[-1]:.4f}")


Parameters: 6,265,603
Epoch 001/100 | Train Loss 0.1799 Acc 0.9369 | Val Loss 0.0495 Acc 0.9853
Epoch 010/100 | Train Loss 0.0267 Acc 0.9935 | Val Loss 0.0145 Acc 0.9958
Epoch 020/100 | Train Loss 0.0134 Acc 0.9964 | Val Loss 0.0469 Acc 0.9918
Epoch 030/100 | Train Loss 0.0127 Acc 0.9969 | Val Loss 0.0072 Acc 0.9979
Epoch 040/100 | Train Loss 0.0108 Acc 0.9973 | Val Loss 0.0203 Acc 0.9938
Epoch 050/100 | Train Loss 0.0119 Acc 0.9969 | Val Loss 0.0074 Acc 0.9974
Epoch 060/100 | Train Loss 0.0114 Acc 0.9972 | Val Loss 0.0083 Acc 0.9983


In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
plt.figure(figsize=(12,4))

plt.subplot(1,2,1)
plt.plot(train_losses, label="Train Loss")
plt.plot(val_losses, label="Val Loss")
plt.title("Loss"); plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.legend()

plt.subplot(1,2,2)
plt.plot(train_accs, label="Train Acc")
plt.plot(val_accs, label="Val Acc")
plt.title("Accuracy"); plt.xlabel("Epoch"); plt.ylabel("Accuracy"); plt.legend()

plt.tight_layout()
plt.show()
