## **ECG Arrhythmia Classification with CNN and Interactive Dashboard**

Electrocardiography (ECG) is a non-invasive technique that records the electrical activity of the heart over time. The ECG waveform reflects the coordinated depolarization and repolarization of cardiac muscle cells, mediated by the heart’s conduction system.

### **Anatomical and Physiological Basis**

The heart’s conduction system ensures rhythmic and synchronized contractions:

* **Sinoatrial (SA) Node** – The natural pacemaker, located in the right atrium, initiates the electrical impulse.
* **Atrial Muscle** – Conducts the impulse across both atria, producing the **P wave** (atrial depolarization).
* **Atrioventricular (AV) Node** – Delays the impulse to allow ventricular filling, seen in the **PR segment**.
* **Bundle of His & Bundle Branches** – Transmit the signal through the interventricular septum.
* **Purkinje Fibers** – Rapidly deliver the impulse to ventricular myocardium, generating the **QRS complex** (ventricular depolarization) followed by the **T wave** (ventricular repolarization).

### **Arrhythmias and ECG Changes**

Arrhythmias arise when impulse generation or conduction is altered. The beat types most relevant here are:

* **Normal beat (N)** – Regular SA node rhythm with intact conduction.
* **Ventricular ectopic beat (VEB)** – Premature ventricular depolarization from an abnormal focus in the ventricles, often producing a wide QRS complex.
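* **Supraventricular ectopic beat (SVEB)** – A premature beat originating above the ventricles (atria or AV junction); the P wave is abnormal, early, or hidden, while the QRS complex usually remains narrow.
* **Fusion beat (F)** – A hybrid waveform produced when a normal impulse and a ventricular ectopic impulse activate the ventricles at nearly the same time.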

These morphological differences are directly tied to the anatomical site of origin, making ECG classification both clinically relevant and physiologically interpretable.
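The sketch below makes these morphological differences concrete without downloading any data. It builds toy beats as sums of Gaussian bumps for the P, QRS, and T waves and samples them on the same 0.2 s + 0.4 s window around the R peak that `preprocess_ecg` uses later. All timings, widths, and amplitudes are illustrative assumptions, not values measured from MIT-BIH, and `toy_beat` is a hypothetical helper. The VEB-like beat has no preceding P wave, a wider main deflection, and a T wave opposite in polarity to the QRS, as is often seen with ventricular ectopy.

```python
import numpy as np

FS = 360  # Hz, the MIT-BIH sampling rate


def gaussian(t, center, sigma, amplitude):
    """Gaussian bump centred at `center` seconds with standard deviation `sigma`."""
    return amplitude * np.exp(-0.5 * ((t - center) / sigma) ** 2)


def toy_beat(qrs_sigma=0.010, with_p_wave=True, t_amplitude=0.3,
             window_pre=0.2, window_post=0.4):
    """Illustrative beat from -window_pre to +window_post seconds around the R peak.

    Timings, widths, and amplitudes are assumptions chosen to look ECG-like,
    not values measured from MIT-BIH.
    """
    pre = int(window_pre * FS)
    n = pre + int(window_post * FS)
    t = (np.arange(n) - pre) / FS                   # seconds relative to the R peak
    beat = gaussian(t, 0.0, qrs_sigma, 1.0)         # main QRS deflection
    beat += gaussian(t, 0.25, 0.040, t_amplitude)   # T wave
    if with_p_wave:
        beat += gaussian(t, -0.16, 0.020, 0.15)     # P wave
    return t, beat


def width_above_half_max(beat):
    """Time (s) during which the beat exceeds half of its maximum value."""
    return np.count_nonzero(beat > 0.5 * beat.max()) / FS


t, normal = toy_beat()
# VEB-like: no preceding P wave, wider QRS, T wave opposite to the QRS
_, veb_like = toy_beat(qrs_sigma=0.030, with_p_wave=False, t_amplitude=-0.3)

print(len(normal))  # 216 samples, matching the 0.2 s + 0.4 s window at 360 Hz
print(f"normal QRS half-max width: {1000 * width_above_half_max(normal):.0f} ms")
print(f"VEB-like QRS half-max width: {1000 * width_above_half_max(veb_like):.0f} ms")
```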

### **Project Objective**

In this project, we develop a 1D **Convolutional Neural Network (CNN)** to classify individual ECG beats using the reference annotations of the MIT-BIH Arrhythmia Database. Rather than relying on hand-measured intervals, the model is meant to learn morphological features directly from the raw beat window. Examples include P-wave shape, QRS width, and ST-T segment variations, which correspond to underlying conduction abnormalities.

To make the results accessible and interpretable, we integrate the trained model into a **Streamlit-based interactive dashboard** that allows users to:

* Upload ECG files or explore sample beats
* View the raw waveform and detected beats
* See classification results with confidence scores
* Explore explainability visualizations (e.g., Grad-CAM) mapping model attention to specific waveform regions
* Connect waveform changes to anatomical and physiological causes

This combination of deep learning, interactive visualization, and anatomical context is meant to make the model's outputs easier to relate to clinical reasoning.
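For the dashboard's "classification results with confidence scores", one simple approach is to rank the model's softmax outputs for a beat. The sketch below makes two assumptions. First, the class names are available in the same order as the model's output units, as in the `data/processed/label_classes.npy` file written by the training cell. Second, `top_k_predictions` is a hypothetical helper, and the probabilities in the example are made up for illustration.

```python
import numpy as np


def top_k_predictions(probs, class_names, k=3):
    """Return the k most likely (class_name, probability) pairs for one beat.

    `probs` is one row of softmax output and `class_names` lists the class
    names in the same order as the model's output units.
    """
    probs = np.asarray(probs, dtype=float)
    if probs.shape != (len(class_names),):
        raise ValueError("probs must have exactly one entry per class")
    order = np.argsort(probs)[::-1][:k]  # indices by descending probability
    return [(str(class_names[i]), float(probs[i])) for i in order]


# Made-up softmax output for a 4-class model, for illustration only
classes = np.array(["N", "S", "V", "F"])
probs = np.array([0.10, 0.05, 0.80, 0.05])
for name, p in top_k_predictions(probs, classes, k=2):
    print(f"{name}: {p:.1%}")
```

In the Streamlit app, the returned pairs could be rendered as a small table next to the waveform. Keeping this logic outside the UI code also makes it easy to test.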

---

In [14]:
# ================================================================
# 1. Create folder structure
# ================================================================
import os

folders = [
    "data/raw",
    "data/processed",
    "models",
    "src/data",
    "src/models",
    "app"
]
for f in folders:
    os.makedirs(f, exist_ok=True)

print("Folder structure ready.")


# ================================================================
# 2. Download MIT-BIH Arrhythmia Database (selected records)
# ================================================================
import wfdb

record_ids = [
    "100", "101", "102", "103", "104", "105", "106", "107", "108",
    "109", "111", "112", "113", "114", "115", "116", "117", "118",
    "119", "121", "122", "123", "124", "200"
]

for rec in record_ids:
    rec_path = os.path.join("data/raw", rec)
    if not os.path.exists(rec_path):
        print(f"Downloading record {rec}...")
        wfdb.dl_database("mitdb", rec_path, records=[rec])
print("MIT-BIH records downloaded.")




Folder structure ready.
MIT-BIH records downloaded.
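The download cell only checks whether a record's folder exists. An interrupted download can therefore leave an incomplete folder that is silently skipped on the next run. The small check below lists what is missing. It assumes each record needs the header (`.hea`), signal (`.dat`), and reference annotation (`.atr`) files that `wfdb.rdrecord` and `wfdb.rdann(..., "atr")` read. `missing_record_files` is a hypothetical helper, not part of `wfdb`.

```python
import os

# Files assumed necessary per record: header, signal, reference annotations
REQUIRED_EXTENSIONS = (".hea", ".dat", ".atr")


def missing_record_files(raw_dir, record_ids):
    """Map each record ID to the required files missing under raw_dir/<id>/."""
    missing = {}
    for rec in record_ids:
        absent = [
            rec + ext
            for ext in REQUIRED_EXTENSIONS
            if not os.path.isfile(os.path.join(raw_dir, rec, rec + ext))
        ]
        if absent:
            missing[rec] = absent
    return missing


# Usage after the download cell:
# problems = missing_record_files("data/raw", record_ids)
# print(problems or "All records complete.")
```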


### **1. Data Preparation and Labelling**

In [15]:
# ================================================================
# 3. Preprocessing functions
# ================================================================
import os
from collections import Counter

import numpy as np
import wfdb
from scipy.signal import butter, filtfilt
from sklearn.utils import resample

def bandpass_filter(signal, fs, lowcut=0.5, highcut=40.0, order=4):
    """Zero-phase Butterworth band-pass filter (default 0.5-40 Hz) for an ECG signal."""
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    return filtfilt(b, a, signal)


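def preprocess_ecg(record_path, ann_path, window_pre=0.2, window_post=0.4):
    """Load one record, band-pass filter it, and cut a fixed window around each annotation.

    Each segment spans `window_pre` s before to `window_post` s after the annotated
    sample (72 + 144 = 216 samples at the MIT-BIH rate of 360 Hz).
    Returns (beats, labels) as numpy arrays.
    """
    try:
        # Load ECG
        rec = wfdb.rdrecord(record_path)
        sig = rec.p_signal[:, 0]  # first channel (MLII in most MIT-BIH records)
        fs = rec.fs

        # Filter
        sig_filtered = bandpass_filter(sig, fs)

        # Reference annotations: sample indices and symbols.
        # NOTE: ann.symbol also contains non-beat annotations (e.g. '+' rhythm
        # change, '~' signal-quality change); they are not filtered out here.
        ann = wfdb.rdann(ann_path, "atr")
        r_locs = ann.sample
        labels = ann.symbol

        # Segment a fixed window around each annotation
        beats, beat_labels = [], []
        wp = int(window_pre * fs)
        ws = int(window_post * fs)
        for r, lbl in zip(r_locs, labels):
            start = r - wp
            end = r + ws
            # Skip windows that would run past either end of the record
            if start >= 0 and end <= len(sig_filtered):
                beats.append(sig_filtered[start:end])
                beat_labels.append(lbl)

        return np.array(beats), np.array(beat_labels)
    except Exception as e:
        print(f" Error in {record_path}: {e}")
        return np.array([]), np.array([])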


# ================================================================
# 4. Batch process all records
# ================================================================
all_beats, all_labels = [], []

for rec in record_ids:
    rec_path = os.path.join("data/raw", rec, rec)
    beats, labels = preprocess_ecg(rec_path, rec_path)
    if beats.size > 0:
        all_beats.append(beats)
        all_labels.append(labels)
        print(f" {rec}: {beats.shape[0]} beats extracted")

# Concatenate
all_beats = np.vstack(all_beats)
all_labels = np.concatenate(all_labels)

print(f"\n Total beats extracted: {all_beats.shape[0]}")
print(f"Label distribution: {Counter(all_labels)}")


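# ================================================================
# 5. Balance dataset
# ================================================================
# NOTE: every class is downsampled (without replacement) to the size of the
# rarest class, so the rarest annotation symbol sets the size of the dataset.
unique_labels = np.unique(all_labels)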
balanced_beats, balanced_labels = [], []

min_count = min([np.sum(all_labels == lbl) for lbl in unique_labels])
print(f"Balancing dataset to {min_count} samples per class...")

for lbl in unique_labels:
    beats_lbl = all_beats[all_labels == lbl]
    resampled_beats = resample(beats_lbl, replace=False, n_samples=min_count, random_state=42)
    balanced_beats.append(resampled_beats)
    balanced_labels.extend([lbl] * min_count)

balanced_beats = np.vstack(balanced_beats)
balanced_labels = np.array(balanced_labels)

print(f"Final balanced dataset: {balanced_beats.shape[0]} beats")
print(f"Final label distribution: {Counter(balanced_labels)}")


# ================================================================
# Save dataset
# ================================================================
np.save("data/processed/beats.npy", balanced_beats)
np.save("data/processed/labels.npy", balanced_labels)

print("Data saved in data/processed/ (beats.npy & labels.npy)")


 100: 2272 beats extracted
 101: 1873 beats extracted
 102: 2191 beats extracted
 103: 2089 beats extracted
 104: 2309 beats extracted
 105: 2690 beats extracted
 106: 2098 beats extracted
 107: 2139 beats extracted
 108: 1823 beats extracted
 109: 2533 beats extracted
 111: 2132 beats extracted
 112: 2548 beats extracted
 113: 1794 beats extracted
 114: 1889 beats extracted
 115: 1960 beats extracted
 116: 2420 beats extracted
 117: 1538 beats extracted
 118: 2299 beats extracted
 119: 2093 beats extracted
 121: 1874 beats extracted
 122: 2477 beats extracted
 123: 1517 beats extracted
 124: 1633 beats extracted
 200: 2790 beats extracted

 Total beats extracted: 50981
Label distribution: Counter({np.str_('N'): 33277, np.str_('/'): 5485, np.str_('L'): 4614, np.str_('R'): 3695, np.str_('V'): 2172, np.str_('f'): 722, np.str_('+'): 355, np.str_('~'): 321, np.str_('A'): 185, np.str_('|'): 51, np.str_('J'): 31, np.str_('Q'): 25, np.str_('x'): 21, np.str_('F'): 15, np.str_('j'): 6, np.str_(
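The label distribution above mixes beat annotations with non-beat annotations. Examples are `+` (rhythm change), `~` (signal-quality change), and `|` (isolated QRS-like artifact), none of which describes a heartbeat. The N/SVEB/VEB/F beat types introduced earlier are commonly obtained by grouping MIT-BIH beat symbols according to the AAMI EC57 convention. The sketch below applies that grouping as an assumption; it is not what the pipeline above does. It also drops the paced/unclassifiable Q group by default, as many studies do. `to_aami` is a hypothetical helper intended for the `all_beats` / `all_labels` arrays before balancing.

```python
import numpy as np

# Assumed AAMI EC57-style grouping of MIT-BIH beat symbols. Symbols not listed
# (e.g. '+', '~', '|', 'x') are non-beat annotations and are discarded.
AAMI_CLASSES = {
    "N": ["N", "L", "R", "e", "j"],   # normal, bundle branch block, escape beats
    "SVEB": ["A", "a", "J", "S"],     # supraventricular (atrial/junctional) premature beats
    "VEB": ["V", "E"],                # premature ventricular contraction, ventricular escape
    "F": ["F"],                       # fusion of ventricular and normal beat
    "Q": ["/", "f", "Q"],             # paced, fusion of paced and normal, unclassifiable
}
SYMBOL_TO_AAMI = {sym: cls for cls, syms in AAMI_CLASSES.items() for sym in syms}


def to_aami(beats, labels, drop_classes=("Q",)):
    """Keep beat annotations only, relabelled with AAMI class names.

    Beats whose AAMI class is in `drop_classes` are removed as well.
    Returns (filtered_beats, aami_labels) as numpy arrays.
    """
    beats = np.asarray(beats)
    mapped = [SYMBOL_TO_AAMI.get(str(s)) for s in labels]
    drop = set(drop_classes)
    keep = np.array([m is not None and m not in drop for m in mapped], dtype=bool)
    aami_labels = np.array([m for m in mapped if m is not None and m not in drop], dtype=str)
    return beats[keep], aami_labels


# Usage with the arrays built above (before the balancing step):
# aami_beats, aami_labels = to_aami(all_beats, all_labels)
# print(Counter(aami_labels))
```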

### **2. CNN Model Definition and Training**

In [None]:
# ================================================================
# 6. Build and Train 1D-CNN for ECG Beat Classification
# ================================================================
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.preprocessing import LabelEncoder
import os

# ------------------------------
# Load processed dataset
# ------------------------------
X = np.load("data/processed/beats.npy")
y = np.load("data/processed/labels.npy")

# Add channel dimension for Conv1D
X = X[..., np.newaxis]

# ------------------------------
# Encode string labels -> integers
# ------------------------------
encoder = LabelEncoder()
y_encoded = encoder.fit_transform(y)   # classes are sorted, e.g. "+" -> 0, "/" -> 1, "A" -> 2

# Save mapping for later decoding
os.makedirs("data/processed", exist_ok=True)
np.save("data/processed/label_classes.npy", encoder.classes_)
print("Classes:", encoder.classes_)

# ------------------------------
# Train-validation split
# ------------------------------
X_train, X_val, y_train, y_val = train_test_split(
    X, y_encoded, stratify=y_encoded, test_size=0.2, random_state=42
)

# One-hot encode labels
num_classes = len(np.unique(y_encoded))
y_train_cat = tf.keras.utils.to_categorical(y_train, num_classes=num_classes)
y_val_cat = tf.keras.utils.to_categorical(y_val, num_classes=num_classes)

# ------------------------------
# Handle class imbalance
# ------------------------------
class_weights = compute_class_weight(
    class_weight="balanced", classes=np.unique(y_encoded), y=y_encoded
)
cw_dict = {i: w for i, w in enumerate(class_weights)}
print("Computed class weights:", cw_dict)

# ------------------------------
# Define 1D CNN model
# ------------------------------
def build_cnn(input_shape, num_classes):
    model = tf.keras.Sequential([
        # Explicit Input layer (passing input_shape to Conv1D triggers a Keras warning)
        tf.keras.Input(shape=input_shape),
        tf.keras.layers.Conv1D(32, kernel_size=7, padding="same", activation="relu"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling1D(2),

        tf.keras.layers.Conv1D(64, kernel_size=5, padding="same", activation="relu"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling1D(2),

        tf.keras.layers.Conv1D(128, kernel_size=3, padding="same", activation="relu"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.GlobalAveragePooling1D(),

        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(0.4),
        tf.keras.layers.Dense(num_classes, activation="softmax")
    ])
    return model

model = build_cnn(X_train.shape[1:], num_classes)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# ------------------------------
# Train model with callbacks
# ------------------------------
os.makedirs("models", exist_ok=True)

callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True),
    tf.keras.callbacks.ModelCheckpoint("models/cnn_ecg_best.h5", save_best_only=True)
]

history = model.fit(
    X_train, y_train_cat,
    validation_data=(X_val, y_val_cat),
    epochs=30,
    batch_size=64,
    class_weight=cw_dict,
    callbacks=callbacks,
    verbose=1
)

# ------------------------------
# Save final model
# ------------------------------
model.save("models/cnn_ecg_final.h5")
print("Model trained and saved at models/cnn_ecg_final.h5")



✅ Classes: ['+' '/' 'A' 'F' 'J' 'L' 'N' 'Q' 'R' 'V' 'a' 'f' 'j' 'x' '|' '~']
✅ Computed class weights: {0: np.float64(1.0), 1: np.float64(1.0), 2: np.float64(1.0), 3: np.float64(1.0), 4: np.float64(1.0), 5: np.float64(1.0), 6: np.float64(1.0), 7: np.float64(1.0), 8: np.float64(1.0), 9: np.float64(1.0), 10: np.float64(1.0), 11: np.float64(1.0), 12: np.float64(1.0), 13: np.float64(1.0), 14: np.float64(1.0), 15: np.float64(1.0)}




Epoch 1/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 4s 577ms/step - accuracy: 0.0132 - loss: 2.9567 - val_accuracy: 0.0500 - val_loss: 2.7718
Epoch 2/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 241ms/step - accuracy: 0.0921 - loss: 2.6934 - val_accuracy: 0.1000 - val_loss: 2.7710
Epoch 3/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 204ms/step - accuracy: 0.2105 - loss: 2.5578 - val_accuracy: 0.1500 - val_loss: 2.7697
Epoch 4/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 195ms/step - accuracy: 0.2895 - loss: 2.4341 - val_accuracy: 0.1000 - val_loss: 2.7697
Epoch 5/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 206ms/step - accuracy: 0.3026 - loss: 2.3535 - val_accuracy: 0.0500 - val_loss: 2.7697
Epoch 6/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 207ms/step - accuracy: 0.3026 - loss: 2.2695 - val_accuracy: 0.0500 - val_loss: 2.7689
Epoch 7/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 270ms/step - accuracy: 0.2500 - loss: 2.2871 - val_accuracy: 0.0500 - val_loss: 2.7682
Epoch 8/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 180ms/step - accuracy: 0.3421 - loss: 2.2181 - val_accuracy: 0.0500 - val_loss: 2.7675
Epoch 9/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 177ms/step - accuracy: 0.3947 - loss: 2.1636 - val_accuracy: 0.0500 - val_loss: 2.7668
Epoch 10/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 204ms/step - accuracy: 0.3421 - loss: 2.1360 - val_accuracy: 0.0500 - val_loss: 2.7662
Epoch 11/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 198ms/step - accuracy: 0.3026 - loss: 2.1080 - val_accuracy: 0.0500 - val_loss: 2.7657
Epoch 12/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 183ms/step - accuracy: 0.4079 - loss: 1.9815 - val_accuracy: 0.0500 - val_loss: 2.7652
Epoch 13/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 186ms/step - accuracy: 0.4211 - loss: 2.0103 - val_accuracy: 0.0500 - val_loss: 2.7647
Epoch 14/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 186ms/step - accuracy: 0.3816 - loss: 1.9823 - val_accuracy: 0.0500 - val_loss: 2.7644
Epoch 15/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 310ms/step - accuracy: 0.4605 - loss: 1.8958 - val_accuracy: 0.0500 - val_loss: 2.7638
Epoch 16/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 189ms/step - accuracy: 0.4737 - loss: 1.7895 - val_accuracy: 0.0500 - val_loss: 2.7635
Epoch 17/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 193ms/step - accuracy: 0.4342 - loss: 1.9657 - val_accuracy: 0.0500 - val_loss: 2.7631
Epoch 18/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 195ms/step - accuracy: 0.4474 - loss: 1.7517 - val_accuracy: 0.0500 - val_loss: 2.7619
Epoch 19/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 189ms/step - accuracy: 0.4737 - loss: 1.7223 - val_accuracy: 0.0500 - val_loss: 2.7607
Epoch 20/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 188ms/step - accuracy: 0.4474 - loss: 1.7765 - val_accuracy: 0.0500 - val_loss: 2.7597
Epoch 21/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 219ms/step - accuracy: 0.4605 - loss: 1.7165 - val_accuracy: 0.0500 - val_loss: 2.7589
Epoch 22/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 260ms/step - accuracy: 0.5132 - loss: 1.7086 - val_accuracy: 0.0500 - val_loss: 2.7580
Epoch 23/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 201ms/step - accuracy: 0.4605 - loss: 1.7021 - val_accuracy: 0.0500 - val_loss: 2.7570
Epoch 24/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 190ms/step - accuracy: 0.5132 - loss: 1.6351 - val_accuracy: 0.0500 - val_loss: 2.7562
Epoch 25/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 227ms/step - accuracy: 0.4211 - loss: 1.6287 - val_accuracy: 0.0500 - val_loss: 2.7554
Epoch 26/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 124ms/step - accuracy: 0.5132 - loss: 1.5722 - val_accuracy: 0.0500 - val_loss: 2.7560
Epoch 27/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 123ms/step - accuracy: 0.5921 - loss: 1.5715 - val_accuracy: 0.0500 - val_loss: 2.7568
Epoch 28/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 117ms/step - accuracy: 0.5658 - loss: 1.5284 - val_accuracy: 0.0500 - val_loss: 2.7569
Epoch 29/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 134ms/step - accuracy: 0.5789 - loss: 1.5202 - val_accuracy: 0.0500 - val_loss: 2.7565
Epoch 30/30
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 127ms/step - accuracy: 0.5526 - loss: 1.4659 - val_accuracy: 0.0500 - val_loss: 2.7561




✅ Model trained and saved at models/cnn_ecg_final.h5
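The log above shows two steps per epoch (`2/2`) at `batch_size=64`, so the model trained on at most 128 beats. Validation accuracy stays at 0.0500 from epoch 5 onward, around the 1/16 ≈ 0.06 expected from guessing among 16 classes. The reason is that balancing every class down to its rarest member (six or fewer examples, per the label counts) discards nearly all of the 50,981 extracted beats. The sketch below shows one common alternative; it is an assumption, not the method used in this notebook. It caps the largest classes at a fixed size, drops symbols too rare to learn from, and leaves the remaining imbalance to the class weights already computed in the training cell. `cap_per_class` and its `cap=3000` / `min_count=50` defaults are hypothetical choices.

```python
import numpy as np


def cap_per_class(beats, labels, cap=3000, min_count=50, seed=42):
    """Undersample classes larger than `cap`; drop classes smaller than `min_count`.

    Unlike balancing to the rarest class, every beat of a mid-sized class is
    kept, so the remaining imbalance should be handled with class weights.
    """
    rng = np.random.default_rng(seed)
    beats, labels = np.asarray(beats), np.asarray(labels)
    keep_idx = []
    for lbl in np.unique(labels):
        idx = np.flatnonzero(labels == lbl)
        if len(idx) < min_count:
            continue  # too few examples to train and validate on
        if len(idx) > cap:
            idx = rng.choice(idx, size=cap, replace=False)
        keep_idx.append(idx)
    if not keep_idx:
        return beats[:0], labels[:0]
    keep_idx = np.sort(np.concatenate(keep_idx))  # preserve original beat order
    return beats[keep_idx], labels[keep_idx]


# Usage in place of the balancing loop:
# balanced_beats, balanced_labels = cap_per_class(all_beats, all_labels)
```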

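`train_test_split` in the training cell shuffles beats from all records together. Beats from the same patient therefore usually appear in both the training and validation sets, which tends to make validation scores optimistic for unseen patients. A record-wise (inter-patient) split avoids this leakage. The sketch below assumes the batch-processing loop is extended to keep one record ID per beat, for example `np.full(len(beats), rec)` per record, collected into a hypothetical array `beat_records`. The current code does not do this. `split_by_record` is a hypothetical helper. With only a few records, some rare classes may end up absent from the validation set.

```python
import numpy as np


def split_by_record(record_of_beat, val_fraction=0.2, seed=42):
    """Return boolean (train_mask, val_mask) so that no record is in both sets.

    A random `val_fraction` of the records (at least one) is held out;
    all beats of a held-out record go to validation.
    """
    record_of_beat = np.asarray(record_of_beat)
    records = np.unique(record_of_beat)
    rng = np.random.default_rng(seed)
    n_val = max(1, int(round(val_fraction * len(records))))
    val_records = rng.choice(records, size=n_val, replace=False)
    val_mask = np.isin(record_of_beat, val_records)
    return ~val_mask, val_mask


# Usage (assuming beat_records is aligned with X and y_encoded):
# train_mask, val_mask = split_by_record(beat_records)
# X_train, X_val = X[train_mask], X[val_mask]
# y_train, y_val = y_encoded[train_mask], y_encoded[val_mask]
```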

### **3. Grad-CAM for Explainability**

In [17]:


import numpy as np
import tensorflow as tf


def grad_cam_1d(model, signal, class_index, layer_name=None):
    """
    Generate a Grad-CAM heatmap for a 1D CNN input.

    Args:
        model: Trained tf.keras model.
        signal: 1D numpy array of shape (seq_len,) or (seq_len, 1).
        class_index: Target class index for explanation.
        layer_name: Name of the conv layer to explain. If None, the last
            Conv1D layer is used (auto-generated names such as "conv1d_2"
            change when the notebook is re-run).
    Returns:
        heatmap (numpy array): Importance weights in [0, 1], aligned to signal length.
    """
    # Ensure correct shape and dtype: (1, seq_len, 1), float32
    signal = np.asarray(signal, dtype=np.float32)
    if signal.ndim == 1:
        signal = np.expand_dims(signal, axis=-1)
    signal = np.expand_dims(signal, axis=0)
    seq_len = signal.shape[1]

    # Select the target conv layer (default: last Conv1D in the model)
    if layer_name is None:
        conv_layers = [l for l in model.layers if isinstance(l, tf.keras.layers.Conv1D)]
        if not conv_layers:
            raise ValueError("Model has no Conv1D layer.")
        conv_layer = conv_layers[-1]
    else:
        conv_layer = model.get_layer(layer_name)

    # Model mapping input -> (conv feature maps, class probabilities)
    grad_model = tf.keras.models.Model(
        inputs=model.inputs,
        outputs=[conv_layer.output, model.output]
    )

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(signal)
        score = predictions[:, class_index]

    # Gradients of the class score wrt conv feature maps: (1, steps, channels)
    grads = tape.gradient(score, conv_outputs)

    # Average gradients over the time axis -> one weight per channel: (1, channels)
    weights = tf.reduce_mean(grads, axis=1)

    # Weighted sum of feature maps over channels: (1, steps)
    cam = tf.reduce_sum(conv_outputs * tf.expand_dims(weights, 1), axis=-1)

    # Keep positive evidence only, then min-max normalize to [0, 1]
    heatmap = tf.maximum(cam, 0).numpy()[0]
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)

    # Upsample from the conv layer's (pooled) length back to the signal length
    heatmap = np.interp(
        np.arange(seq_len),
        np.linspace(0, seq_len - 1, len(heatmap)),
        heatmap
    )

    return heatmap

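Grad-CAM heatmaps from this model are coarse. The last convolution sees the beat only after two `MaxPooling1D(2)` layers, so each heatmap value summarizes several input samples before being linearly interpolated back to the signal length. The sketch below traces the temporal length and receptive field of each layer in `build_cnn` using the standard recurrence r_out = r_in + (k - 1) * j_in and j_out = j_in * stride, where "same" padding does not change the receptive field. It assumes the 216-sample window produced by `preprocess_ecg` with its default 0.2 s / 0.4 s window at 360 Hz. The layer list is transcribed by hand from `build_cnn`, not read from the model.

```python
FS = 360       # Hz, MIT-BIH sampling rate
SEQ_LEN = 216  # preprocess_ecg defaults: 0.2 s before + 0.4 s after the R peak

# (description, kind, kernel_or_pool_size, stride) for the time-axis layers of build_cnn
LAYERS = [
    ("Conv1D(32, 7)", "conv", 7, 1),
    ("MaxPooling1D(2)", "pool", 2, 2),
    ("Conv1D(64, 5)", "conv", 5, 1),
    ("MaxPooling1D(2)", "pool", 2, 2),
    ("Conv1D(128, 3)", "conv", 3, 1),
]


def trace_layers(seq_len, layers):
    """Return (description, output_length, receptive_field, jump) per layer.

    `jump` is the number of input samples between adjacent outputs.
    Conv layers use padding="same" with stride 1 (length unchanged); pooling
    uses the Keras default padding="valid": length = (L - k) // stride + 1.
    """
    rows = []
    length, rf, jump = seq_len, 1, 1
    for desc, kind, k, stride in layers:
        if kind == "pool":
            length = (length - k) // stride + 1
        rf += (k - 1) * jump
        jump *= stride
        rows.append((desc, length, rf, jump))
    return rows


for desc, length, rf, jump in trace_layers(SEQ_LEN, LAYERS):
    print(f"{desc:16s} steps={length:4d}  receptive field={rf:3d} samples "
          f"({1000 * rf / FS:.0f} ms)  step spacing={jump} samples")
```

For the last convolution, this works out to 54 time steps spaced 4 samples (about 11 ms) apart, each seeing 26 input samples (about 72 ms). The interpolated heatmap is therefore best read at the level of waveform regions (P wave, QRS, T wave) rather than individual samples.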