# Beat Model Training on Colab (CPSC2021)

This notebook trains a CNN model for ECG beat classification using the **CPSC2021** dataset.

**Key Features:**
- **Dataset**: CPSC2021 (The 4th China Physiological Signal Challenge 2021)
- **Resolution**: Native 200Hz
- **Classification**: Hierarchical (Sub-primitives -> Top-level)

**Annotations:**
We use the standard PhysioBank annotation codes (e.g., 'N', 'V', 'S').
See the [PhysioNet Annotations Reference](https://archive.physionet.org/physiobank/annotations.shtml) for a complete list.

**Note:** CPSC2021 is primarily for AFib detection, but we attempt to extract beat-level annotations if available.

In [3]:
!pip install tensorflow numpy pandas scipy wfdb scikit-learn pyedflib




In [3]:
# Data Setup (Google Drive ONLY)
import os
import subprocess
import wfdb
import sys
from google.colab import drive

# 1. Mount Google Drive
# This will prompt for authorization in the Colab output
drive.mount('/content/drive')

# Configuration
# Upload your CPSC2021 zip to this path in Drive BEFORE running this cell
DRIVE_ZIP_PATH = "/content/drive/MyDrive/ML/CPSC2021/Training_set_I.zip"
EXTRACT_DIR = "data/cpsc2021"

def setup_data():
    if os.path.exists(EXTRACT_DIR) and len(os.listdir(EXTRACT_DIR)) > 0:
        print("Data directory exists and is not empty. Skipping setup.")
        return

    os.makedirs("data", exist_ok=True)

    # Load from Google Drive
    if os.path.exists(DRIVE_ZIP_PATH):
        print(f"Found data in Drive: {DRIVE_ZIP_PATH}")
        print("Unzipping to local runtime...")
        # Unzip to EXTRACT_DIR
        os.makedirs(EXTRACT_DIR, exist_ok=True)
        subprocess.run(["unzip", "-q", DRIVE_ZIP_PATH, "-d", EXTRACT_DIR], check=True)
        print("Data loaded from Drive.")
    else:
        print(f"ERROR: Data not found in Drive at {DRIVE_ZIP_PATH}")
        print("Please upload 'Training_set_I.zip' to that location in your Google Drive and re-run this cell.")

setup_data()


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Data directory exists and is not empty. Skipping setup.


In [5]:
import numpy as np
import tensorflow as tf
import wfdb
import os
import scipy.io
from tensorflow.keras import layers, models

# Set seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)


In [11]:
import wfdb
# Data Loader & Processing Logic (CPSC2021 - Fixed for .hea/.dat)

# Constants
FS_NATIVE = 200        # CPSC2021 is 200Hz
WINDOW_SEC = 2.0       # 2-second window centered on beat
WINDOW_SAMPLES = int(WINDOW_SEC * FS_NATIVE) # 400 samples

# Top-Level Mapping
TOP_LEVEL_CLASSES = ['N', 'V', 'S']
TOP_LEVEL_MAP = {c: i for i, c in enumerate(TOP_LEVEL_CLASSES)}

SUB_PRIMITIVE_MAP = {} # Label -> Index
SUB_TO_TOP_MAP = {}    # SubIndex -> TopIndex

def get_top_level_category(label):
    # Heuristic mapping for common descriptors
    label = str(label).upper()
    if 'V' in label or 'PVC' in label: return 'V'
    if 'S' in label or 'SPB' in label or 'PAC' in label or 'A' in label: return 'S'
    return 'N' # Default to Normal

def create_cpsc_metadata(data_dir="data/cpsc2021/Training_set_I", max_files=None):
    # CPSC2021 uses standard WFDB format: .hea (header) and .dat (data)
    # We iterate over .hea files to find records.

    hea_files = []
    for root, dirs, files in os.walk(data_dir):
        for file in files:
            if file.endswith(".hea"):
                hea_files.append(os.path.join(root, file))

    hea_files = sorted(hea_files)

    if max_files:
        hea_files = hea_files[:max_files]

    all_beat_metadata = []
    all_symbols = [] # To build SUB_PRIMITIVE_MAP

    print(f"Scanning {len(hea_files)} records in {data_dir} for beat metadata...")

    processed_records_with_beats = 0
    records_with_errors = 0
    total_symbols_encountered = 0
    total_beats_kept_after_filter = 0
    sample_symbols_raw = []
    sample_symbols_kept = []

    for idx, hea_path in enumerate(hea_files):
        base_path = os.path.splitext(hea_path)[0]
        atr_path = base_path

        try:
            ann = wfdb.rdann(atr_path, 'atr')
            indices = ann.sample
            symbols = ann.symbol

            if not symbols:
                continue

            current_record_beats_found = 0
            current_record_beats_kept = 0

            for i, samp in enumerate(indices):
                sym = symbols[i]
                total_symbols_encountered += 1
                current_record_beats_found += 1

                # Collect a sample of raw symbols
                if len(sample_symbols_raw) < 50: # Limit sample size
                    sample_symbols_raw.append(sym)

                print(f"label: {sym}")

                if sym in ['+', '~', '"', '!', '[', ']', '|', 'x', '(', ')', 'p', 't', 'u', '`', "'"]:
                    continue

                # Collect a sample of kept symbols
                if len(sample_symbols_kept) < 50: # Limit sample size
                    sample_symbols_kept.append(sym)

                all_beat_metadata.append({'record_path': base_path, 'sample_index': samp, 'symbol': sym})
                all_symbols.append(sym)
                total_beats_kept_after_filter += 1
                current_record_beats_kept += 1

            if current_record_beats_kept > 0:
                processed_records_with_beats += 1

        except Exception as e:
            print(f"Error processing {base_path}: {e}") # Uncomment for verbose error per file
            records_with_errors += 1
            pass

        if (idx+1) % 50 == 0:
            print(f"Processed {idx+1} records for metadata. Records with beats found: {processed_records_with_beats}. Records with errors: {records_with_errors}")

    print(f"\n--- Metadata Scan Summary ---")
    print(f"Total records considered: {len(hea_files)}")
    print(f"Records with errors during wfdb.rdann: {records_with_errors}")
    print(f"Records from which at least one beat was successfully extracted: {processed_records_with_beats}")
    print(f"Total raw symbols encountered: {total_symbols_encountered}")
    print(f"Total beats kept after filtering (N, V, S-like): {total_beats_kept_after_filter}")
    if sample_symbols_raw:
        print(f"Sample of ALL raw symbols encountered (unique): {list(set(sample_symbols_raw))}")
    if sample_symbols_kept:
        print(f"Sample of KEPT symbols after filtering (unique): {list(set(sample_symbols_kept))}")


    if not all_beat_metadata:
        print("No beat metadata found after processing and filtering.")
        return [] # Return empty list

    # Dynamic Label Mapping
    unique_labels = sorted(list(set(all_symbols)))
    print(f"Found unique sub-primitives: {unique_labels}")

    global SUB_PRIMITIVE_MAP, SUB_TO_TOP_MAP
    SUB_PRIMITIVE_MAP = {l: i for i, l in enumerate(unique_labels)}

    # Build mapping to top-level
    for l, idx in SUB_PRIMITIVE_MAP.items():
        top_cat = get_top_level_category(l)
        SUB_TO_TOP_MAP[idx] = TOP_LEVEL_MAP[top_cat]

    return all_beat_metadata

def _load_and_preprocess_segment(record_path_tensor, sample_index_tensor, sub_primitive_idx_tensor):
    import wfdb # Redundant import as a safeguard against scope issues with tf.py_function
    record_path = record_path_tensor.numpy().decode('utf-8')
    sample_index = sample_index_tensor.numpy()
    sub_primitive_idx = sub_primitive_idx_tensor.numpy()

    # Load signal for the given record
    try:
        sig, fields = wfdb.rdsamp(record_path)
    except Exception:
        # Return dummy data in case of loading error
        return np.zeros((WINDOW_SAMPLES, 1), dtype=np.float32), tf.one_hot(0, len(SUB_PRIMITIVE_MAP), dtype=tf.int32)

    if sig.ndim == 1:
        signal_1d = sig.astype(np.float32)
    else:
        signal_1d = sig[:, 0].astype(np.float32) # Use Lead I

    start = sample_index - WINDOW_SAMPLES // 2
    end = start + WINDOW_SAMPLES

    if start < 0 or end > len(signal_1d):
        # Return dummy data if segment out of bounds
        return np.zeros((WINDOW_SAMPLES, 1), dtype=np.float32), tf.one_hot(0, len(SUB_PRIMITIVE_MAP), dtype=tf.int32)

    segment = signal_1d[start:end]

    # Normalize
    segment = (segment - np.mean(segment)) / (np.std(segment) + 1e-6)
    segment = segment[..., np.newaxis] # Add channel dimension (N, 400, 1)

    # One-hot encode the label
    one_hot_label = tf.one_hot(sub_primitive_idx, len(SUB_PRIMITIVE_MAP), dtype=tf.int32)

    return segment, one_hot_label

In [7]:
# Model Architecture (Updated for 200Hz Input)
def build_model(input_shape, num_classes):
    inputs = layers.Input(shape=input_shape)

    x = layers.Conv1D(64, 15, strides=2, padding='same', activation='relu')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling1D(3, strides=2, padding='same')(x)

    # Residual Blocks
    for filters in [64, 128, 256, 512]:
        shortcut = x
        x = layers.Conv1D(filters, 7, padding='same', activation='relu')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Conv1D(filters, 7, padding='same', activation='relu')(x)
        x = layers.BatchNormalization()(x)

        if shortcut.shape[-1] != filters:
            shortcut = layers.Conv1D(filters, 1, padding='same')(shortcut)

        x = layers.Add()([x, shortcut])
        x = layers.MaxPooling1D(2)(x)

    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.5)(x)

    outputs = layers.Dense(num_classes, activation='softmax')(x)

    model = models.Model(inputs=inputs, outputs=outputs, name="ECG_CPSC2021_ResNet")
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model


In [10]:
# Training & Evaluation Loop
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import tensorflow as tf # Ensure tf is imported here
import numpy as np # Ensure numpy is imported

def train_and_eval():
    all_beat_metadata = create_cpsc_metadata()

    if not all_beat_metadata:
        print("No valid beat metadata found. Please check data loading setup.")
        return

    # Extract sub_primitive_indices for stratification
    # Ensure SUB_PRIMITIVE_MAP is populated by create_cpsc_metadata
    sub_primitive_indices = [SUB_PRIMITIVE_MAP[beat['symbol']] for beat in all_beat_metadata]

    # Stratify split on metadata indices
    train_metadata, test_metadata, _, _ = train_test_split(
        all_beat_metadata, sub_primitive_indices, test_size=0.2, random_state=42, stratify=sub_primitive_indices
    )

    print(f"Number of training beats: {len(train_metadata)}, Number of testing beats: {len(test_metadata)}")
    print(f"Classes Map: {SUB_PRIMITIVE_MAP}")

    num_sub_classes = len(SUB_PRIMITIVE_MAP)

    # Create tf.data.Dataset for training and validation
    def generator_fn(metadata_list):
        for beat in metadata_list:
            record_path = beat['record_path']
            sample_index = beat['sample_index']
            sub_primitive_idx = SUB_PRIMITIVE_MAP[beat['symbol']]
            yield record_path, sample_index, sub_primitive_idx

    output_signature = (
        tf.TensorSpec(shape=(), dtype=tf.string), # record_path
        tf.TensorSpec(shape=(), dtype=tf.int32),  # sample_index
        tf.TensorSpec(shape=(), dtype=tf.int32)   # sub_primitive_idx
    )

    train_ds_raw = tf.data.Dataset.from_generator(
        lambda: generator_fn(train_metadata),
        output_signature=output_signature
    )
    test_ds_raw = tf.data.Dataset.from_generator(
        lambda: generator_fn(test_metadata),
        output_signature=output_signature
    )

    # Map the preprocessing function to the dataset
    # We need to wrap _load_and_preprocess_segment to make it a tf.function compatible map
    def map_fn(record_path, sample_index, sub_primitive_idx):
        segment, one_hot_label = tf.py_function(
            _load_and_preprocess_segment,
            inp=[record_path, sample_index, sub_primitive_idx],
            Tout=[tf.float32, tf.int32]
        )
        segment.set_shape([WINDOW_SAMPLES, 1])
        one_hot_label.set_shape([num_sub_classes])
        return segment, one_hot_label

    BATCH_SIZE = 64
    AUTOTUNE = tf.data.AUTOTUNE

    train_dataset = train_ds_raw.map(map_fn, num_parallel_calls=AUTOTUNE) \
                                .shuffle(buffer_size=len(train_metadata)) \
                                .batch(BATCH_SIZE) \
                                .prefetch(AUTOTUNE)
    test_dataset = test_ds_raw.map(map_fn, num_parallel_calls=AUTOTUNE) \
                               .batch(BATCH_SIZE) \
                               .prefetch(AUTOTUNE)


    # Assuming an input shape for the model
    # The actual input shape will be (WINDOW_SAMPLES, 1)
    input_shape = (WINDOW_SAMPLES, 1)
    model = build_model(input_shape, num_sub_classes)

    history = model.fit(
        train_dataset,
        validation_data=test_dataset,
        epochs=10,
        callbacks=[tf.keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)]
    )

    # Evaluation - This part needs to be adapted for tf.data.Dataset
    print("\n--- Sub-Primitive Evaluation ---")
    y_pred_prob = model.predict(test_dataset)
    y_pred = np.argmax(y_pred_prob, axis=1)

    # To get true labels, iterate through the test_dataset once
    y_true_list = []
    for _, labels in test_dataset.unbatch(): # unbatch to get individual labels
        y_true_list.append(labels.numpy())
    y_true = np.argmax(np.array(y_true_list), axis=1)


    sub_labels = sorted(SUB_PRIMITIVE_MAP.values())
    sub_names = [k for k, v in sorted(SUB_PRIMITIVE_MAP.items(), key=lambda item: item[1])]
    print(classification_report(y_true, y_pred, labels=sub_labels, target_names=sub_names, zero_division=0))

    print("\n--- Top-Level Primitive Evaluation (N, V, S) ---")
    y_true_top = np.array([SUB_TO_TOP_MAP[i] for i in y_true])
    y_pred_top = np.array([SUB_TO_TOP_MAP[i] for i in y_pred])

    top_labels = sorted(TOP_LEVEL_MAP.values())
    top_names = [k for k, v in sorted(TOP_LEVEL_MAP.items(), key=lambda item: item[1])]
    print(classification_report(y_true_top, y_pred_top, labels=top_labels, target_names=top_names, zero_division=0))

    model.save("models/cpsc2021_beat_model.keras")

if __name__ == "__main__":
    os.makedirs("models", exist_ok=True)
    if os.path.exists("data/cpsc2021"):
        train_and_eval()
    else:
        print("Data directory not found. Please ensure download completes.")

Scanning 719 records in data/cpsc2021/Training_set_I for beat metadata...
Error processing data/cpsc2021/Training_set_I/data_0_1: name 'wfdb' is not defined
Error processing data/cpsc2021/Training_set_I/data_0_10: name 'wfdb' is not defined
Error processing data/cpsc2021/Training_set_I/data_0_11: name 'wfdb' is not defined
Error processing data/cpsc2021/Training_set_I/data_0_12: name 'wfdb' is not defined
Error processing data/cpsc2021/Training_set_I/data_0_13: name 'wfdb' is not defined
Error processing data/cpsc2021/Training_set_I/data_0_14: name 'wfdb' is not defined
Error processing data/cpsc2021/Training_set_I/data_0_15: name 'wfdb' is not defined
Error processing data/cpsc2021/Training_set_I/data_0_2: name 'wfdb' is not defined
Error processing data/cpsc2021/Training_set_I/data_0_3: name 'wfdb' is not defined
Error processing data/cpsc2021/Training_set_I/data_0_4: name 'wfdb' is not defined
Error processing data/cpsc2021/Training_set_I/data_0_5: name 'wfdb' is not defined
Error p

In [None]:
# Convert to TFLite
def convert_to_tflite(model_path="models/cpsc2021_beat_model.keras", tflite_path="models/cpsc2021_beat_model.tflite"):
    print(f"Converting {model_path} to TFLite...")
    try:
        model = tf.keras.models.load_model(model_path)
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        tflite_model = converter.convert()

        with open(tflite_path, "wb") as f:
            f.write(tflite_model)
        print(f"TFLite model saved to {tflite_path}")
    except Exception as e:
        print(f"Error converting to TFLite: {e}")

if __name__ == "__main__":
    if os.path.exists("models/cpsc2021_beat_model.keras"):
        convert_to_tflite()


In [None]:
# Download Models (Colab specific)
try:
    from google.colab import files
    print("Zipping models directory...")
    !zip -r models.zip models
    print("Downloading models.zip...")
    files.download("models.zip")
except ImportError:
    print("Not running in Google Colab, skipping download.")
