## **ECG Arrhythmia Classification with CNN and Interactive Dashboard**

Electrocardiography (ECG) is a non-invasive technique that records the electrical activity of the heart over time. The ECG waveform reflects the coordinated depolarization and repolarization of cardiac muscle cells, mediated by the heart’s conduction system.

### **Anatomical and Physiological Basis**

The heart’s conduction system ensures rhythmic and synchronized contractions:

* **Sinoatrial (SA) Node** – The natural pacemaker, located in the right atrium, initiates the electrical impulse.
* **Atrial Muscle** – Conducts the impulse across both atria, producing the **P wave** (atrial depolarization).
* **Atrioventricular (AV) Node** – Delays the impulse to allow ventricular filling, seen in the **PR segment**.
* **Bundle of His & Bundle Branches** – Transmit the signal through the interventricular septum.
* **Purkinje Fibers** – Rapidly deliver the impulse to ventricular myocardium, generating the **QRS complex** (ventricular depolarization) followed by the **T wave** (ventricular repolarization).
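To relate these conduction stages to the digitized signal, the sketch below converts approximate durations into sample counts at 360 Hz, the sampling rate of the MIT-BIH Arrhythmia Database. The PR and QRS limits are common adult reference values used only as illustrative assumptions. `seconds_to_samples` is a hypothetical helper. The 0.2 s and 0.4 s entries are the beat-window lengths used later in this notebook.

```python
MITBIH_FS = 360  # MIT-BIH Arrhythmia Database sampling rate (Hz)

# Approximate durations in seconds (illustrative reference values, not project settings)
DURATIONS_S = {
    "PR interval, upper normal limit": 0.20,
    "QRS complex, upper normal limit": 0.12,
    "Beat window before the R-peak": 0.2,
    "Beat window after the R-peak": 0.4,
}


def seconds_to_samples(duration_s: float, fs: int = MITBIH_FS) -> int:
    """Convert a duration in seconds to the nearest whole number of samples."""
    return int(round(duration_s * fs))


for label, seconds in DURATIONS_S.items():
    print(f"{label}: {seconds:.2f} s = {seconds_to_samples(seconds)} samples")
```

At 360 Hz, a 0.12 s QRS complex spans about 43 samples. The full 0.6 s beat window spans 216 samples, which the notebook later resamples to a fixed 200 samples.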

### **Arrhythmias and ECG Changes**

Arrhythmias occur when the impulse generation or conduction pathway is altered:

* **Normal beat (N)** – Regular SA node rhythm with intact conduction.
* **Ventricular ectopic beat (VEB)** – Premature ventricular depolarization from an abnormal focus in the ventricles, often producing a wide QRS complex.
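* **Supraventricular ectopic beat (SVEB)** – Originates above the ventricles (atria or AV node). It typically changes or hides the P wave. The QRS complex usually stays narrow because the impulse still travels through the normal His-Purkinje pathway.
* **Fusion beat (F)** – A hybrid waveform produced when a normal impulse and a ventricular ectopic impulse activate the ventricles at the same time.
* **Unknown (Q)** – Paced or unclassifiable beats. This project groups them into a catch-all class.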

These morphological differences are directly tied to the anatomical site of origin, making ECG classification both clinically relevant and physiologically interpretable.
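Many ectopic beats, such as premature ventricular contractions and atrial premature beats, also arrive earlier than expected, so timing complements morphology. The sketch below flags a beat when its preceding RR interval is shorter than a fraction of the recent average RR interval. The helper name, the 0.8 ratio, and the four-interval history are all illustrative assumptions. This project's pipeline does not use this rule; it classifies beats from waveform windows instead.

```python
import numpy as np


def flag_premature_beats(r_peaks, ratio=0.8, history=4):
    """Return indices of R-peaks whose preceding RR interval is unusually short.

    Beat k (k >= 2) is flagged when its preceding RR interval is below
    `ratio` times the mean of up to `history` earlier RR intervals.
    `ratio` and `history` are illustrative defaults (assumptions).
    """
    r_peaks = np.asarray(r_peaks)
    rr = np.diff(r_peaks)  # rr[k - 1] is the interval ending at beat k
    flagged = []
    for k in range(1, len(rr)):
        reference = rr[max(0, k - history):k].mean()
        if rr[k] < ratio * reference:
            flagged.append(k + 1)  # rr[k] ends at R-peak index k + 1
    return flagged


# Regular rhythm at 360-sample intervals with one early beat at sample 1300
print(flag_premature_beats([0, 360, 720, 1080, 1300, 1800]))  # → [4]
```

With R-peak positions in samples, for example from `nk.ecg_peaks`, the returned indices point to candidate early beats. Whether such a beat is ventricular or supraventricular still depends on its morphology.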

### **Project Objective**

In this project, we develop a **Convolutional Neural Network (CNN)** model that classifies individual ECG beats from the MIT-BIH Arrhythmia Database into five AAMI-style classes: Normal, VEB, SVEB, Fusion, and Unknown. The model does not rely on hand-crafted measurements. Instead, it is trained to learn morphological features directly from fixed-length beat windows. Such features can include P-wave shape, QRS width, and ST-T segment variations that reflect underlying conduction abnormalities.

To make the results accessible and interpretable, we integrate the trained model into a **Streamlit-based interactive dashboard** that allows users to:

* Upload ECG files or explore sample beats
* View the raw waveform and detected beats
* See classification results with confidence scores
* Explore explainability visualizations (e.g., Grad-CAM) mapping model attention to specific waveform regions
* Connect waveform changes to anatomical and physiological causes
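The confidence scores shown in the dashboard can be read directly from the model's softmax output. The sketch below assumes a probability vector in the class order used by this notebook (`Normal`, `VEB`, `SVEB`, `Fusion`, `Unknown`). `top_k_predictions` is a hypothetical helper and is not part of Streamlit or Keras. The example probabilities are made up.

```python
import numpy as np

CLASS_NAMES = ["Normal", "VEB", "SVEB", "Fusion", "Unknown"]


def top_k_predictions(probs, class_names=CLASS_NAMES, k=3):
    """Return the k most probable classes as (name, probability) pairs."""
    probs = np.asarray(probs, dtype=float)
    if probs.shape != (len(class_names),):
        raise ValueError("expected one probability per class")
    order = np.argsort(probs)[::-1][:k]  # indices sorted by descending probability
    return [(class_names[i], float(probs[i])) for i in order]


# Made-up softmax output for one beat (illustration only)
example_probs = [0.05, 0.82, 0.08, 0.01, 0.04]
for name, p in top_k_predictions(example_probs):
    print(f"{name}: {p:.1%}")
```

Softmax outputs are not guaranteed to be calibrated. "Confidence" here therefore means the model's predicted probability, not a validated clinical certainty.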

This combination of deep learning, interactive visualization, and anatomical context aims to connect model predictions with the physiological reasoning a clinician would apply to the same waveform.

---

In [16]:
# ================================================================
# Create folder structure
# ================================================================
import os

folders = [
    "data/raw",
    "data/processed",
    "models",
    "src/data",
    "src/models",
    "app"
]
for f in folders:
    os.makedirs(f, exist_ok=True)

print("✅ Folder structure ready.")

# ================================================================
# Download a subset of the MIT-BIH Arrhythmia Database
# ================================================================
import wfdb

# Records used in this notebook: 24 of the 48 MIT-BIH records
# (100-124 plus 200; the remaining 200-series records are not downloaded here)
record_ids = [
    "100", "101", "102", "103", "104", "105", "106", "107", "108",
    "109", "111", "112", "113", "114", "115", "116", "117", "118",
    "119", "121", "122", "123", "124", "200"
]

for rec in record_ids:
    rec_path = os.path.join("data/raw", rec)
    if not os.path.exists(rec_path):
        print(f"⬇️ Downloading record {rec}...")
        wfdb.dl_database(
            "mitdb",
            rec_path,
            records=[rec]
        )
print("✅ All MIT-BIH records downloaded.")

# ================================================================
# Preprocessing functions
# ================================================================
import numpy as np
import neurokit2 as nk
from scipy.signal import butter, filtfilt

def bandpass_filter(signal, fs, lowcut=0.5, highcut=40.0, order=4):
    """Bandpass filter for ECG signal."""
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    return filtfilt(b, a, signal)

def preprocess_ecg(record_path, window_pre=0.2, window_post=0.4):
    """Load, filter, detect R-peaks, and segment beats."""
    try:
        # Load ECG
        rec = wfdb.rdrecord(record_path)
        sig = rec.p_signal[:, 0]  # first channel
        fs = rec.fs

        # Filter
        sig_filtered = bandpass_filter(sig, fs)

        # R-peak detection
        _, rpeaks = nk.ecg_peaks(sig_filtered, sampling_rate=fs)

        # Segment beats
        beats = []
        wp = int(window_pre * fs)
        ws = int(window_post * fs)
        for r in rpeaks['ECG_R_Peaks']:
            start = r - wp
            end = r + ws
            if start >= 0 and end < len(sig_filtered):
                beat = sig_filtered[start:end]
                beats.append(beat)
        beats = np.array(beats)

        return beats, sig_filtered, fs, rpeaks['ECG_R_Peaks']
    except Exception as e:
        print(f"⚠️ Error in {record_path}: {e}")
        return np.array([]), np.array([]), None, None

# ================================================================
# Batch process all records
# ================================================================
all_beats = []
for rec in record_ids:
    rec_path = os.path.join("data/raw", rec, rec)
    beats, sig_filt, fs, rlocs = preprocess_ecg(rec_path)
    if beats.size > 0:
        all_beats.append(beats)
        print(f"✅ {rec}: {beats.shape[0]} beats extracted")

# Concatenate and save
all_beats = np.vstack(all_beats)
np.save("data/processed/all_beats.npy", all_beats)

print(f"🎯 Total beats extracted: {all_beats.shape[0]}")
print("✅ Preprocessing complete. Data saved in data/processed/")


✅ Folder structure ready.
⬇️ Downloading record 123...
Generating record list for: 123
Generating list of all files for: 123
Created local base download directory: data/raw\123
Downloading files...
Finished downloading files
⬇️ Downloading record 124...
Generating record list for: 124
Generating list of all files for: 124
Created local base download directory: data/raw\124
Downloading files...
Finished downloading files
⬇️ Downloading record 200...
Generating record list for: 200
Generating list of all files for: 200
Created local base download directory: data/raw\200
Downloading files...
Finished downloading files
✅ All MIT-BIH records downloaded.
✅ 100: 2270 beats extracted
✅ 101: 1870 beats extracted
✅ 102: 2187 beats extracted
✅ 103: 2083 beats extracted
✅ 104: 2199 beats extracted
✅ 105: 2571 beats extracted
✅ 106: 2040 beats extracted
✅ 107: 2136 beats extracted
✅ 108: 1775 beats extracted
✅ 109: 2530 beats extracted
✅ 111: 2125 beats extracted
✅ 112: 2538 beats extracted
✅ 113: 
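The segmentation step in `preprocess_ecg` can also be written without a Python loop. The NumPy sketch below assumes R-peak positions in samples, and `segment_beats` is a hypothetical helper. It keeps a window only when the window lies entirely inside the signal. It uses `end <= len(signal)`, so a beat whose window ends exactly at the last sample is kept; the loop above uses `<` and drops that edge case.

```python
import numpy as np


def segment_beats(signal, r_peaks, fs, pre_s=0.2, post_s=0.4):
    """Cut fixed-length windows [r - pre, r + post) around each R-peak.

    Windows that would extend past either end of the signal are dropped.
    Returns (beats, kept_peaks); beats has shape (n_beats, pre + post).
    """
    signal = np.asarray(signal)
    r_peaks = np.asarray(r_peaks, dtype=int)
    pre = int(pre_s * fs)
    post = int(post_s * fs)

    # Keep only peaks whose full window fits inside the signal
    valid = (r_peaks - pre >= 0) & (r_peaks + post <= len(signal))
    kept = r_peaks[valid]

    # Build an (n_beats, window) index matrix and gather all windows at once
    offsets = np.arange(-pre, post)
    beats = signal[kept[:, None] + offsets[None, :]]
    return beats, kept


# 10 s of dummy signal at 360 Hz; the first and last peaks are too close to the edges
sig = np.sin(np.linspace(0, 20 * np.pi, 3600))
beats, kept = segment_beats(sig, [50, 400, 3550], fs=360)
print(beats.shape, kept)
```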

### **1. Data Preparation and Labelling**

In [17]:
# ================================================================
# 📥 5. Beat extraction with labels (Full MIT-BIH Dataset)
# ================================================================
import wfdb
import os
import numpy as np
from scipy.signal import resample_poly

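# Map MIT-BIH annotation symbols to five AAMI-style beat classes:
# 0 = Normal (N), 1 = Ventricular ectopic (VEB), 2 = Supraventricular ectopic (SVEB),
# 3 = Fusion (F), 4 = Unknown/paced (Q).
# Note: AAMI also places 'f' (fusion of paced and normal beat) in Q; it is not
# mapped here, so those beats are skipped, as are all non-beat annotations.
symbol_to_class = {
    'N': 0, 'L': 0, 'R': 0, 'e': 0, 'j': 0,  # normal, bundle branch blocks, escape beats
    'V': 1, 'E': 1,                          # premature ventricular, ventricular escape
    'S': 2, 'A': 2, 'a': 2, 'J': 2,          # supraventricular, atrial, nodal premature
    'F': 3,                                  # fusion of ventricular and normal beat
    '/': 4, 'Q': 4, '?': 4                   # paced, unclassifiable, unknown
}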

class_names = ["Normal", "VEB", "SVEB", "Fusion", "Unknown"]
TARGET_LENGTH = 200  

def extract_beats_with_labels(record_path):
    """Load ECG signal, filter, align to annotations, extract beats & labels."""
    rec = wfdb.rdrecord(record_path)
    ann = wfdb.rdann(record_path, 'atr')
    sig = rec.p_signal[:, 0]  # first channel
    fs = rec.fs

    # Filter signal
    sig_filtered = bandpass_filter(sig, fs)

    beats, labels = [], []
    window_pre = int(0.2 * fs)
    window_post = int(0.4 * fs)
    
    for idx, sym in zip(ann.sample, ann.symbol):
        if sym in symbol_to_class:
            start = idx - window_pre
            end = idx + window_post
            if start >= 0 and end < len(sig_filtered):
                beat = sig_filtered[start:end]
                # Resample to fixed length
                beat = resample_poly(beat, TARGET_LENGTH, len(beat))
                # Normalize
                beat = (beat - np.mean(beat)) / (np.std(beat) + 1e-8)
                beats.append(beat)
                labels.append(symbol_to_class[sym])

    return np.array(beats), np.array(labels)


# ================================================================
# Process the 24 downloaded MIT-BIH records
# ================================================================
all_beats, all_labels = [], []

# Same 24-record subset downloaded earlier (100-124 plus 200)
record_ids = [
    "100","101","102","103","104","105","106","107","108","109",
    "111","112","113","114","115","116","117","118","119","121",
    "122","123","124","200"
]

for rec in record_ids:
    record_path = os.path.join("data/raw", rec, rec)
    try:
        beats, labels = extract_beats_with_labels(record_path)
        all_beats.append(beats)
        all_labels.append(labels)
        print(f"✅ Processed record {rec} | {beats.shape[0]} beats")
    except Exception as e:
        print(f"⚠️ Skipped record {rec} due to error: {e}")

# Combine all into arrays
all_beats = np.vstack(all_beats)
all_labels = np.concatenate(all_labels)

# Save processed dataset
np.save("data/processed/beats.npy", all_beats)
np.save("data/processed/labels.npy", all_labels)

print(f"✅ Saved {all_beats.shape[0]} beats.")
print(f"   Each beat has {all_beats.shape[1]} samples ({TARGET_LENGTH}).")
print(f"   Class distribution: {np.bincount(all_labels)}")


✅ Processed record 100 | 2272 beats
✅ Processed record 101 | 1865 beats
✅ Processed record 102 | 2131 beats
✅ Processed record 103 | 2083 beats
✅ Processed record 104 | 1562 beats
✅ Processed record 105 | 2572 beats
✅ Processed record 106 | 2027 beats
✅ Processed record 107 | 2137 beats
✅ Processed record 108 | 1763 beats
✅ Processed record 109 | 2531 beats
✅ Processed record 111 | 2124 beats
✅ Processed record 112 | 2538 beats
✅ Processed record 113 | 1794 beats
✅ Processed record 114 | 1879 beats
✅ Processed record 115 | 1952 beats
✅ Processed record 116 | 2411 beats
✅ Processed record 117 | 1534 beats
✅ Processed record 118 | 2277 beats
✅ Processed record 119 | 1987 beats
✅ Processed record 121 | 1862 beats
✅ Processed record 122 | 2475 beats
✅ Processed record 123 | 1517 beats
✅ Processed record 124 | 1618 beats
✅ Processed record 200 | 2600 beats
✅ Saved 49511 beats.
   Each beat has 200 samples (200).
   Class distribution: [41592  2172   222    15  5510]
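Each window spans 0.6 s, or 216 samples at 360 Hz. `resample_poly` brings it to `TARGET_LENGTH = 200` before per-beat z-score normalization. The sketch below reproduces these two steps on a synthetic beat so that the shapes and statistics can be checked in isolation. The synthetic waveform, a Gaussian bump standing in for a QRS complex, is an assumption for illustration only.

```python
import numpy as np
from scipy.signal import resample_poly

FS = 360             # MIT-BIH sampling rate (Hz)
TARGET_LENGTH = 200  # same fixed length as in the notebook

# Synthetic 0.6 s beat: 0.2 s before and 0.4 s after a Gaussian "R-peak"
n = int(0.2 * FS) + int(0.4 * FS)  # 216 samples
t = np.arange(n) / FS
beat = np.exp(-((t - 0.2) ** 2) / (2 * 0.01 ** 2))

# Polyphase resampling by the rational factor 200/216
resampled = resample_poly(beat, TARGET_LENGTH, len(beat))

# Per-beat z-score; the epsilon guards against division by zero for flat segments
normalized = (resampled - resampled.mean()) / (resampled.std() + 1e-8)

print(len(beat), "->", len(normalized))
```

Per-beat z-scoring removes amplitude differences between records and leads. It also discards absolute amplitude, which some beat types may partly differ in.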


### **2. CNN Model Definition & Training**

In [18]:
# ================================================================
# 🤖 6. Build and Train 1D-CNN for ECG Beat Classification
# ================================================================
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

# ------------------------------
# 🔹 Load processed dataset
# ------------------------------
X = np.load("data/processed/beats.npy")
y = np.load("data/processed/labels.npy")

# Add channel dimension for Conv1D
X = X[..., np.newaxis]

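# Train-validation split, stratified to preserve class proportions in both sets.
# Note: this is a random beat-level (intra-patient) split, so beats from the same
# record appear in both sets and validation scores are likely optimistic.
X_train, X_val, y_train, y_val = train_test_split(
    X, y, stratify=y, test_size=0.2, random_state=42
)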

# One-hot encode labels
num_classes = len(np.unique(y))
y_train_cat = tf.keras.utils.to_categorical(y_train, num_classes=num_classes)
y_val_cat = tf.keras.utils.to_categorical(y_val, num_classes=num_classes)

# ------------------------------
# ⚖️ Handle class imbalance
# ------------------------------
class_weights = compute_class_weight(
    class_weight="balanced", classes=np.unique(y), y=y
)
cw_dict = {i: w for i, w in enumerate(class_weights)}
print("✅ Computed class weights:", cw_dict)

# ------------------------------
# 🏗️ Define 1D CNN model
# ------------------------------
def build_cnn(input_shape, num_classes):
    model = tf.keras.Sequential([
        # Explicit Input layer (passing input_shape to Conv1D triggers a Keras warning)
        tf.keras.Input(shape=input_shape),
        tf.keras.layers.Conv1D(32, kernel_size=7, padding="same", activation="relu"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling1D(2),

        tf.keras.layers.Conv1D(64, kernel_size=5, padding="same", activation="relu"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling1D(2),

        # Named so that Grad-CAM can retrieve it with model.get_layer("last_conv")
        tf.keras.layers.Conv1D(128, kernel_size=3, padding="same", activation="relu",
                               name="last_conv"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.GlobalAveragePooling1D(),

        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(0.4),
        tf.keras.layers.Dense(num_classes, activation="softmax")
    ])
    return model

model = build_cnn(X_train.shape[1:], num_classes)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss="categorical_crossentropy",
              metrics=["accuracy"])

# ------------------------------
# ⏳ Train model with callbacks
# ------------------------------
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True),
    tf.keras.callbacks.ModelCheckpoint("models/cnn_ecg_best.h5", save_best_only=True)
]

history = model.fit(
    X_train, y_train_cat,
    validation_data=(X_val, y_val_cat),
    epochs=30,
    batch_size=64,
    class_weight=cw_dict,
    callbacks=callbacks,
    verbose=1
)

# ------------------------------
# 💾 Save final model
# ------------------------------
model.save("models/cnn_ecg_final.h5")
print("✅ Model trained and saved at models/cnn_ecg_final.h5")


✅ Computed class weights: {0: np.float64(0.23807943835352952), 1: np.float64(4.55902394106814), 2: np.float64(44.604504504504504), 3: np.float64(660.1466666666666), 4: np.float64(1.7971324863883849)}
Epoch 1/30
619/619 ━━━━━━━━━━━━━━━━━━━━ 11s 15ms/step - accuracy: 0.5685 - loss: 1.4053 - val_accuracy: 0.1871 - val_loss: 1.7984
Epoch 2/30
619/619 ━━━━━━━━━━━━━━━━━━━━ 11s 17ms/step - accuracy: 0.5662 - loss: 0.9457 - val_accuracy: 0.5905 - val_loss: 0.8659
Epoch 3/30
619/619 ━━━━━━━━━━━━━━━━━━━━ 11s 17ms/step - accuracy: 0.6060 - loss: 0.9721 - val_accuracy: 0.7364 - val_loss: 0.7961
Epoch 4/30
619/619 ━━━━━━━━━━━━━━━━━━━━ 13s 22ms/step - accuracy: 0.5886 - loss: 0.9299 - val_accuracy: 0.6570 - val_loss: 0.8736
Epoch 5/30
619/619 ━━━━━━━━━━━━━━━━━━━━ 17s 27ms/step - accuracy: 0.6476 - loss: 0.7245 - val_accuracy: 0.7702 - val_loss: 0.6816
Epoch 6/30
619/619 ━━━━━━━━━━━━━━━━━━━━ 12s 19ms/step - accuracy: 0.6752 - loss: 0.6074 - val_accuracy: 0.7206 - val_loss: 0.7001
Epoch 7/30
619/619 ━━━━━━━━━━━━━━━━━━━━ 11s 17ms/step - accuracy: 0.6669 - loss: 0.8030 - val_accuracy: 0.2246 - val_loss: 1.5602
Epoch 8/30
619/619 ━━━━━━━━━━━━━━━━━━━━ 11s 17ms/step - accuracy: 0.6613 - loss: 0.6257 - val_accuracy: 0.3650 - val_loss: 1.7139
Epoch 9/30
619/619 ━━━━━━━━━━━━━━━━━━━━ 18s 14ms/step - accuracy: 0.7319 - loss: 0.6406 - val_accuracy: 0.7035 - val_loss: 0.8343
Epoch 10/30
619/619 ━━━━━━━━━━━━━━━━━━━━ 11s 17ms/step - accuracy: 0.7396 - loss: 0.5660 - val_accuracy: 0.7405 - val_loss: 0.6423
Epoch 11/30
619/619 ━━━━━━━━━━━━━━━━━━━━ 11s 17ms/step - accuracy: 0.7514 - loss: 0.5157 - val_accuracy: 0.5528 - val_loss: 1.0054
Epoch 12/30
619/619 ━━━━━━━━━━━━━━━━━━━━ 12s 20ms/step - accuracy: 0.7772 - loss: 0.4952 - val_accuracy: 0.6729 - val_loss: 0.8767
Epoch 13/30
619/619 ━━━━━━━━━━━━━━━━━━━━ 11s 18ms/step - accuracy: 0.7695 - loss: 0.4931 - val_accuracy: 0.7075 - val_loss: 0.8280
Epoch 14/30
619/619 ━━━━━━━━━━━━━━━━━━━━ 11s 18ms/step - accuracy: 0.7880 - loss: 0.4538 - val_accuracy: 0.4036 - val_loss: 1.5703
Epoch 15/30
619/619 ━━━━━━━━━━━━━━━━━━━━ 12s 19ms/step - accuracy: 0.7781 - loss: 0.7303 - val_accuracy: 0.7422 - val_loss: 0.7333

✅ Model trained and saved at models/cnn_ecg_final.h5
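The class weights printed before training follow scikit-learn's `"balanced"` rule, w_c = n_samples / (n_classes * n_c). Plugging in the class counts reported by the labelling step reproduces them. The result also shows how extreme the imbalance is: a single Fusion beat weighs about 2,770 times as much as a Normal beat.

```python
import numpy as np

# Class counts reported by the labelling step (Normal, VEB, SVEB, Fusion, Unknown)
counts = np.array([41592, 2172, 222, 15, 5510])

n_samples = counts.sum()  # 49511 beats
n_classes = len(counts)   # 5 classes

# scikit-learn's "balanced" rule: n_samples / (n_classes * count_per_class)
weights = n_samples / (n_classes * counts)

for name, w in zip(["Normal", "VEB", "SVEB", "Fusion", "Unknown"], weights):
    print(f"{name:8s} {w:10.4f}")
```

Only 15 Fusion beats exist in total, and roughly 12 of them fall in the training split. A weight this large lets a handful of examples dominate the loss of any batch that contains them, which may contribute to the unstable validation accuracy in the log above.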


### **3. Grad-CAM for Explainability**
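Grad-CAM weights each channel of a convolutional feature map by the time-averaged gradient of the target class score. It then sums the weighted channels, applies a ReLU, and upsamples the result to the input length. The NumPy sketch below shows only that combination step on made-up arrays. In the notebook, the activations and gradients would come from `tf.GradientTape` on the `last_conv` output; for a 200-sample beat after two 2x poolings, that output has 50 time steps. The helper name and the linear-interpolation upsampling are illustrative choices.

```python
import numpy as np


def grad_cam_from_arrays(activations, gradients, signal_length):
    """Combine conv activations and gradients into a 1D Grad-CAM heatmap.

    activations, gradients: arrays of shape (time_steps, channels) taken from
    the last convolutional layer for a single input.
    """
    activations = np.asarray(activations, dtype=float)
    gradients = np.asarray(gradients, dtype=float)

    # Channel importance: average gradient over the time axis
    weights = gradients.mean(axis=0)                # (channels,)

    # Weighted sum of feature maps, then ReLU to keep positive evidence only
    cam = np.maximum(activations @ weights, 0.0)    # (time_steps,)

    # Upsample from conv resolution to the input length by linear interpolation
    src = np.linspace(0.0, 1.0, num=len(cam))
    dst = np.linspace(0.0, 1.0, num=signal_length)
    heatmap = np.interp(dst, src, cam)

    # Scale to [0, 1] for plotting (all-zero maps are returned unchanged)
    peak = heatmap.max()
    return heatmap / peak if peak > 0 else heatmap


# Toy example: 50 time steps x 2 channels
acts = np.zeros((50, 2))
acts[10:20, 0] = 1.0                   # channel 0 fires on a short segment
acts[:, 1] = 1.0                       # channel 1 is active everywhere
grads = np.tile([1.0, -0.1], (50, 1))  # class score rises with ch0, falls with ch1
heatmap = grad_cam_from_arrays(acts, grads, signal_length=200)
print(heatmap.shape, int(heatmap.argmax()))
```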

In [19]:
# ================================================================
# 🔍 Grad-CAM for 1D CNN
# ================================================================
import tensorflow as tf
import numpy as np

def grad_cam_1d(model, signal, class_index, layer_name="last_conv"):
    """
    Generate Grad-CAM heatmap for 1D CNN input.
    
    Args:
        model: Trained tf.keras model.
        signal: 1D numpy array of shape (seq_len,) or (seq_len,1).
        class_index: Target class index for explanation.
        layer_name: Name of the last conv layer.
    Returns:
        heatmap (numpy array): Importance weights aligned to signal length.
    """
    # Ensure correct shape: (1, seq_len, 1)
    if signal.ndim == 1:
        signal = np.expand_dims(signal, axis=-1)
    signal = np.expand_dims(signal, axis=0)

    # Create model mapping input -> (conv outputs, predictions)
    conv_layer = model.get_layer(layer_name)
    grad_model = tf.keras.models.Model(
        inputs=model.inputs,
        outputs=[conv_layer.output, model.output]
    )

    with tf.GradientTape() as tape:
        conv_outputs, predictions =_
