## **ECG Arrhythmia Classification with CNN and Interactive Dashboard**

Electrocardiography (ECG) is a non-invasive technique that records the electrical activity of the heart over time. The ECG waveform reflects the coordinated depolarization and repolarization of cardiac muscle cells, mediated by the heart’s conduction system.

### **Anatomical and Physiological Basis**

The heart’s conduction system ensures rhythmic and synchronized contractions:

* **Sinoatrial (SA) Node** – The natural pacemaker, located in the right atrium, initiates the electrical impulse.
* **Atrial Muscle** – Conducts the impulse across both atria, producing the **P wave** (atrial depolarization).
* **Atrioventricular (AV) Node** – Delays the impulse to allow ventricular filling, seen in the **PR segment**.
* **Bundle of His & Bundle Branches** – Transmit the signal through the interventricular septum.
* **Purkinje Fibers** – Rapidly deliver the impulse to ventricular myocardium, generating the **QRS complex** (ventricular depolarization) followed by the **T wave** (ventricular repolarization).

### **Arrhythmias and ECG Changes**

Arrhythmias occur when the impulse generation or conduction pathway is altered:

* **Normal beat (N)** – Regular SA node rhythm with intact conduction.
* **Ventricular ectopic beat (VEB)** – Premature ventricular depolarization from an abnormal focus in the ventricles, often producing a wide QRS complex.
* **Supraventricular ectopic beat (SVEB)** – Originates above the ventricles (atria or AV node); typically shows altered P-wave morphology with a narrow QRS complex.
* **Fusion beat (F)** – A hybrid waveform from simultaneous normal and ectopic activation.

These morphological differences are directly tied to the anatomical site of origin, making ECG classification both clinically relevant and physiologically interpretable.
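
The link between beat class and site of origin can be kept in a small lookup table, for example to show explanatory text next to a prediction. The sketch below is illustrative only: `BEAT_CLASS_INFO` and `describe_beat` are hypothetical names, and the descriptions simply restate the list above.

```python
# Hypothetical lookup table restating the beat classes listed above.
BEAT_CLASS_INFO = {
    "N": {
        "name": "Normal beat",
        "origin": "SA node with intact conduction",
        "ecg_sign": "regular rhythm",
    },
    "VEB": {
        "name": "Ventricular ectopic beat",
        "origin": "abnormal focus in the ventricles",
        "ecg_sign": "premature beat, often with a wide QRS complex",
    },
    "SVEB": {
        "name": "Supraventricular ectopic beat",
        "origin": "atria or AV node",
        "ecg_sign": "altered P-wave morphology with a narrow QRS",
    },
    "F": {
        "name": "Fusion beat",
        "origin": "simultaneous normal and ectopic activation",
        "ecg_sign": "hybrid waveform",
    },
}


def describe_beat(code):
    """Return a one-line description for a beat class code."""
    info = BEAT_CLASS_INFO.get(code)
    if info is None:
        return f"{code}: no description available"
    return (
        f"{info['name']} ({code}): origin {info['origin']}; "
        f"typical ECG sign: {info['ecg_sign']}"
    )


print(describe_beat("VEB"))
```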

### **Project Objective**

In this project, we develop a **Convolutional Neural Network (CNN)** model to classify ECG beats from the MIT-BIH Arrhythmia Database into arrhythmia classes. Rather than relying on hand-crafted features, the model learns morphological features such as P-wave shape, QRS width, and ST-T segment variations directly from the waveform; these features correspond to underlying conduction abnormalities.

To make the results accessible and interpretable, we integrate the trained model into a **Streamlit-based interactive dashboard** that allows users to:

* Upload ECG files or explore sample beats
* View the raw waveform and detected beats
* See classification results with confidence scores
* Explore explainability visualizations (e.g., Grad-CAM) mapping model attention to specific waveform regions
* Connect waveform changes to anatomical and physiological causes

This combination of deep learning, interactive visualization, and anatomical context bridges the gap between machine intelligence and clinical reasoning.
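
As an illustration of the "classification results with confidence scores" item, the sketch below ranks a softmax output by confidence so a dashboard could display the top few classes. `rank_predictions` is a hypothetical helper, and the probabilities are made up for the example; they are not model output.

```python
def rank_predictions(probs, class_names, top_k=3):
    """Return the top_k (class name, confidence) pairs, highest confidence first."""
    if len(probs) != len(class_names):
        raise ValueError("probs and class_names must have the same length")
    ranked = sorted(zip(class_names, probs), key=lambda pair: pair[1], reverse=True)
    return [(name, float(p)) for name, p in ranked[:top_k]]


# Made-up softmax output for one beat, for illustration only.
example_probs = [0.05, 0.80, 0.10, 0.03, 0.02]
class_names = ["Normal", "VEB", "SVEB", "Fusion", "Unknown"]

for name, p in rank_predictions(example_probs, class_names):
    print(f"{name:8s} {p:6.1%}")
```

In a Streamlit app, the returned pairs could be passed to a table or bar-chart widget; that wiring is left out here.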

---

In [None]:
# ================================================================
# 📦 1. Install dependencies
# ================================================================
!pip install wfdb streamlit plotly tensorflow scikit-learn neurokit2 tf-explain -q

# ================================================================
# 📂 2. Create folder structure
# ================================================================
import os

folders = [
    "data/raw",
    "data/processed",
    "models",
    "src/data",
    "src/models",
    "app"
]
for f in folders:
    os.makedirs(f, exist_ok=True)

print("✅ Folder structure ready.")

# ================================================================
# 📥 3. Download MIT-BIH Arrhythmia Database (records 100-104 for demo)
# ================================================================
import wfdb

record_ids = ["100", "101", "102", "103", "104"]
for rec in record_ids:
    wfdb.dl_database(
        "mitdb",
        os.path.join("data/raw", rec),
        records=[rec]
    )
print("✅ MIT-BIH sample records downloaded.")

# ================================================================
# 🛠 4. Preprocessing functions
# ================================================================
import numpy as np
import neurokit2 as nk
from scipy.signal import butter, filtfilt

def bandpass_filter(signal, fs, lowcut=0.5, highcut=40.0, order=4):
    """Bandpass filter for ECG signal."""
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    return filtfilt(b, a, signal)

def preprocess_ecg(record_path):
    """Load, filter, detect R-peaks, and segment beats."""
    # Load ECG
    rec = wfdb.rdrecord(record_path)
    sig = rec.p_signal[:, 0]  # first channel
    fs = rec.fs

    # Filter
    sig_filtered = bandpass_filter(sig, fs)

    # R-peak detection
    _, rpeaks = nk.ecg_peaks(sig_filtered, sampling_rate=fs)

    # Segment beats
    beats = []
    window_pre = int(0.2 * fs)   # 200 ms before R-peak
    window_post = int(0.4 * fs)  # 400 ms after R-peak
    for r in rpeaks['ECG_R_Peaks']:
        start = r - window_pre
        end = r + window_post
        if start >= 0 and end < len(sig_filtered):
            beats.append(sig_filtered[start:end])
    beats = np.array(beats)

    return beats, sig_filtered, fs, rpeaks['ECG_R_Peaks']

# Test with one record
beats, sig_filt, fs, rlocs = preprocess_ecg("data/raw/100/100")
print(f"Extracted beats: {beats.shape}, Sampling rate: {fs} Hz")


✅ Folder structure ready.
Generating record list for: 100
Generating list of all files for: 100
Downloading files...
Finished downloading files
Generating record list for: 101
Generating list of all files for: 101
Downloading files...
Finished downloading files
Generating record list for: 102
Generating list of all files for: 102
Downloading files...
Finished downloading files
Generating record list for: 103
Generating list of all files for: 103
Downloading files...
Finished downloading files
Generating record list for: 104
Generating list of all files for: 104
Downloading files...
Finished downloading files
✅ MIT-BIH sample records downloaded.
Extracted beats: (2270, 216), Sampling rate: 360 Hz
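
The beat length of 216 samples follows from the window settings: 200 ms before and 400 ms after each R-peak at 360 Hz. The sketch below reproduces that arithmetic and the edge handling on a synthetic signal. `beat_window` and `segment_beats` are hypothetical helper names, and the peak positions are made up.

```python
import numpy as np


def beat_window(fs, pre_s=0.2, post_s=0.4):
    """Samples before and after the R-peak, plus the total window length."""
    pre = int(pre_s * fs)
    post = int(post_s * fs)
    return pre, post, pre + post


def segment_beats(signal, peaks, fs):
    """Cut fixed-length windows around each peak, skipping peaks too close to an edge."""
    pre, post, _ = beat_window(fs)
    windows = [
        signal[p - pre:p + post]
        for p in peaks
        if p - pre >= 0 and p + post < len(signal)
    ]
    return np.array(windows)


pre, post, total = beat_window(360)
print(pre, post, total)  # 72 144 216

# Synthetic stand-in for a filtered ECG: 10 s of zeros with made-up peak positions.
fs = 360
signal = np.zeros(10 * fs)
peaks = [50, 400, 1200, 3590]  # the first and last are too close to the edges
beats_demo = segment_beats(signal, peaks, fs)
print(beats_demo.shape)  # (2, 216)
```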


### **1. Data Preparation and Labelling**

In [None]:
# ================================================================
# 📥 5. Beat extraction with labels
# ================================================================
import wfdb

# Mapping from MIT-BIH annotation symbols to simplified class labels
# N = Normal, V = Ventricular ectopic, S = Supraventricular ectopic, F = Fusion, Q = Unknown
symbol_to_class = {
    'N': 0,  # Normal
    'L': 0, 'R': 0, 'e': 0, 'j': 0,
    'V': 1, 'E': 1,
    'S': 2, 'A': 2, 'a': 2, 'J': 2,
    'F': 3,
    '/': 4, 'Q': 4, '?': 4
}

class_names = ["Normal", "VEB", "SVEB", "Fusion", "Unknown"]

def extract_beats_with_labels(record_path):
    # Load ECG signal and annotations
    rec = wfdb.rdrecord(record_path)
    ann = wfdb.rdann(record_path, 'atr')
    sig = rec.p_signal[:, 0]
    fs = rec.fs

    # Filter signal
    sig_filtered = bandpass_filter(sig, fs)

    beats, labels = [], []
    window_pre = int(0.2 * fs)
    window_post = int(0.4 * fs)

    for idx, sym in zip(ann.sample, ann.symbol):
        if sym in symbol_to_class:
            start = idx - window_pre
            end = idx + window_post
            if start >= 0 and end < len(sig_filtered):
                beats.append(sig_filtered[start:end])
                labels.append(symbol_to_class[sym])

    return np.array(beats), np.array(labels)

# Process all demo records
all_beats, all_labels = [], []
for rec in record_ids:
    beats, labels = extract_beats_with_labels(f"data/raw/{rec}/{rec}")
    all_beats.append(beats)
    all_labels.append(labels)

all_beats = np.vstack(all_beats)
all_labels = np.concatenate(all_labels)

# Save processed dataset
np.save("data/processed/beats.npy", all_beats)
np.save("data/processed/labels.npy", all_labels)

print(f"✅ Saved {all_beats.shape[0]} beats. Shape per beat: {all_beats.shape[1]} samples")


✅ Saved 9913 beats. Shape per beat: 216 samples
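
Before training, it helps to check how many beats fall into each class, since a class may be rare or missing entirely in a small subset of records. The sketch below counts labels per class and computes inverse-frequency weights with the formula `n_samples / (n_present_classes * count)`, which is the heuristic behind scikit-learn's `"balanced"` option. `class_distribution` and `balanced_weights` are hypothetical helpers, and the labels are made up.

```python
from collections import Counter


def class_distribution(labels, class_names):
    """Count beats per class name, including classes with zero beats."""
    counts = Counter(int(label) for label in labels)
    return {class_names[i]: counts.get(i, 0) for i in range(len(class_names))}


def balanced_weights(labels):
    """Inverse-frequency weights keyed by class id, for classes present in labels."""
    counts = Counter(int(label) for label in labels)
    n_samples, n_present = len(labels), len(counts)
    return {c: n_samples / (n_present * cnt) for c, cnt in sorted(counts.items())}


# Made-up labels for illustration: classes 2 and 3 never occur.
toy_labels = [0] * 8 + [1] * 2 + [4] * 10
names = ["Normal", "VEB", "SVEB", "Fusion", "Unknown"]

print(class_distribution(toy_labels, names))
print(balanced_weights(toy_labels))
```

Note that the weights are keyed by the actual class id, so an absent class leaves a gap in the keys instead of shifting the remaining weights.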


### **2. CNN Model Definition & Training**

In [None]:
# ================================================================
# 🤖 6. Build and train 1D-CNN
# ================================================================
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

# Load data
X = np.load("data/processed/beats.npy")
y = np.load("data/processed/labels.npy")

# Add channel dimension
X = X[..., np.newaxis]

# Train/val split
X_train, X_val, y_train, y_val = train_test_split(X, y, stratify=y, test_size=0.2, random_state=42)

# One-hot encode labels
y_train_cat = tf.keras.utils.to_categorical(y_train, num_classes=len(class_names))
y_val_cat = tf.keras.utils.to_categorical(y_val, num_classes=len(class_names))

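# Compute class weights for imbalance, using training labels only.
# Key each weight by its actual class id: np.unique skips classes that are
# absent from the data, so enumerate() would shift the mapping.
present_classes = np.unique(y_train)
class_weights = compute_class_weight(class_weight="balanced", classes=present_classes, y=y_train)
cw_dict = {i: 1.0 for i in range(len(class_names))}  # absent classes keep a neutral weight
cw_dict.update({int(c): float(w) for c, w in zip(present_classes, class_weights)})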

# CNN model
def build_cnn(input_shape, num_classes):
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=input_shape),
        tf.keras.layers.Conv1D(32, kernel_size=7, padding="same", activation="relu"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling1D(2),

        tf.keras.layers.Conv1D(64, kernel_size=5, padding="same", activation="relu"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling1D(2),

        tf.keras.layers.Conv1D(128, kernel_size=3, padding="same", activation="relu"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.GlobalAveragePooling1D(),

        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(0.4),
        tf.keras.layers.Dense(num_classes, activation="softmax")
    ])
    return model

model = build_cnn(X_train.shape[1:], len(class_names))
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

# Train
history = model.fit(
    X_train, y_train_cat,
    validation_data=(X_val, y_val_cat),
    epochs=20,
    batch_size=64,
    class_weight=cw_dict
)

# Save model
model.save("models/cnn_ecg.h5")
print("✅ Model trained and saved.")



Epoch 1/20
124/124 ━━━━━━━━━━━━━━━━━━━━ 27s 134ms/step - accuracy: 0.6720 - loss: 1.8247 - val_accuracy: 0.3459 - val_loss: 2.2321
Epoch 2/20
124/124 ━━━━━━━━━━━━━━━━━━━━ 12s 70ms/step - accuracy: 0.7723 - loss: 0.4849 - val_accuracy: 0.3459 - val_loss: 5.2172
Epoch 3/20
124/124 ━━━━━━━━━━━━━━━━━━━━ 7s 59ms/step - accuracy: 0.8374 - loss: 0.2750 - val_accuracy: 0.3459 - val_loss: 5.3485
Epoch 4/20
124/124 ━━━━━━━━━━━━━━━━━━━━ 10s 59ms/step - accuracy: 0.8152 - loss: 0.2791 - val_accuracy: 0.3485 - val_loss: 3.5498
Epoch 5/20
124/124 ━━━━━━━━━━━━━━━━━━━━ 11s 69ms/step - accuracy: 0.8660 - loss: 0.3257 - val_accuracy: 0.5618 - val_loss: 0.9041
Epoch 6/20
124/124 ━━━━━━━━━━━━━━━━━━━━ 8s 66ms/step - accuracy: 0.8633 - loss: 0.1922 - val_accuracy: 0.8008 - val_loss: 0.4817
Epoch 7/20
[output truncated]



✅ Model trained and saved.
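
Overall accuracy can hide poor performance on rare classes when the labels are imbalanced, so it is worth inspecting a confusion matrix and per-class recall as well. The sketch below does this with NumPy on made-up labels. `confusion_matrix` and `per_class_recall` are hypothetical helpers written for this example, not the scikit-learn functions.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes):
    """Rows are true classes, columns are predicted classes."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


def per_class_recall(cm):
    """Recall per class; NaN where a class has no true samples."""
    support = cm.sum(axis=1)
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(support > 0, np.diag(cm) / support, np.nan)


# Made-up labels: a classifier that mostly predicts class 4.
y_true_demo = [0, 0, 0, 0, 1, 1, 4, 4, 4, 4]
y_pred_demo = [4, 4, 4, 4, 4, 1, 4, 4, 4, 4]

cm_demo = confusion_matrix(y_true_demo, y_pred_demo, num_classes=5)
recall_demo = per_class_recall(cm_demo)
accuracy_demo = np.trace(cm_demo) / cm_demo.sum()

print(cm_demo)
print("accuracy:", accuracy_demo)
print("recall per class:", recall_demo)
```

With the trained model, the predicted labels would come from something like `model.predict(X_val).argmax(axis=1)`, compared against `y_val`.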


### **3. Grad-CAM for Explainability**

In [None]:
# ================================================================
# 🔍 7. Grad-CAM implementation for 1D CNN
# ================================================================
def grad_cam_1d(model, signal, class_index):
    """Return a Grad-CAM heatmap in [0, 1] over the last Conv1D layer's time axis."""
    # Use the last Conv1D layer, which still has a time axis to localize on.
    last_conv = [l for l in model.layers if isinstance(l, tf.keras.layers.Conv1D)][-1]
    grad_model = tf.keras.models.Model(model.inputs, [last_conv.output, model.output])

    with tf.GradientTape() as tape:
        inputs = tf.cast(tf.expand_dims(signal, axis=0), tf.float32)  # (1, time, 1)
        conv_outputs, predictions = grad_model(inputs)
        loss = predictions[:, class_index]

    grads = tape.gradient(loss, conv_outputs)          # (1, time, channels)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1))  # (channels,)

    # Weight each channel by its pooled gradient, then average over channels.
    weighted = conv_outputs[0] * pooled_grads
    heatmap = tf.reduce_mean(weighted, axis=-1).numpy()
    heatmap = np.maximum(heatmap, 0)
    max_val = heatmap.max()
    return heatmap / max_val if max_val > 0 else heatmap
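
If the heatmap is taken from the last Conv1D layer of the network above, it has 54 points, because two `MaxPooling1D(2)` layers shorten the 216-sample input by a factor of four. To overlay it on the beat, it needs to be stretched back to the input length. The sketch below does this with linear interpolation. `upsample_heatmap` is a hypothetical helper, and the coarse heatmap is a placeholder ramp, not real model output.

```python
import numpy as np


def upsample_heatmap(heatmap, target_len):
    """Linearly interpolate a coarse heatmap onto target_len samples."""
    heatmap = np.asarray(heatmap, dtype=float)
    src = np.linspace(0.0, 1.0, num=len(heatmap))
    dst = np.linspace(0.0, 1.0, num=target_len)
    return np.interp(dst, src, heatmap)


# Placeholder heatmap with 54 points (216 samples / 2 / 2).
coarse = np.linspace(0.0, 1.0, 54)
fine = upsample_heatmap(coarse, 216)
print(fine.shape)  # (216,)
```

The upsampled array can then be plotted against the same sample axis as the beat, for example as a color overlay in Plotly.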
