# Beat Model Training on Colab (CPSC2021)

This notebook trains a CNN model for ECG beat classification using the **CPSC2021** dataset.

**Key Features:**
- **Dataset**: CPSC2021 (The 4th China Physiological Signal Challenge 2021)
- **Sampling rate**: native 200 Hz (signals are not resampled)
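- **Classification**: hierarchical. The model predicts individual annotation symbols (sub-primitives), which are then mapped to the top-level classes N, V and S for evaluation.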

**Annotations:**
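Beat labels use the standard PhysioBank annotation codes (e.g., 'N', 'V', 'S'). These codes are case-sensitive: for example, 'R' is a right bundle branch block beat, while 'r' is an R-on-T premature ventricular contraction.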
See the [PhysioNet Annotations Reference](https://archive.physionet.org/physiobank/annotations.shtml) for a complete list.

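**Note:** CPSC2021 was designed for paroxysmal atrial fibrillation (AF) event detection rather than beat classification. This notebook uses the beat annotations in each record's `.atr` file where they exist and skips records that have none.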

In [None]:
# Install the notebook's dependencies (IPython shell escape, not plain Python).
# NOTE(review): pandas and pyedflib are not imported by any cell shown here; they may be removable.
!pip install tensorflow numpy pandas scipy wfdb scikit-learn pyedflib


In [None]:
# Data Setup (Google Drive ONLY)
import os
import subprocess
import sys
from google.colab import drive

# 1. Mount Google Drive
# This will prompt for authorization in the Colab output
drive.mount('/content/drive')

# Configuration
# Colab exposes the user's "My Drive" folder under <mount point>/MyDrive, so the
# archive must be uploaded to MyDrive/ML/CPSC2021/ BEFORE running this cell.
# (Adjust the part after "MyDrive/" if you keep it in a different folder.)
DRIVE_ZIP_PATH = "/content/drive/MyDrive/ML/CPSC2021/Training_set_I.zip"
# Directory on the local Colab runtime that the archive is extracted into.
EXTRACT_DIR = "data/cpsc2021"

def setup_data():
    """Extract the CPSC2021 archive from Google Drive into EXTRACT_DIR.

    Does nothing when EXTRACT_DIR already exists and contains at least one
    entry. When the archive at DRIVE_ZIP_PATH is missing, it prints upload
    instructions instead of raising. Extraction shells out to the ``unzip``
    binary with ``check=True``, so a non-zero exit raises
    ``subprocess.CalledProcessError``.
    """
    already_extracted = os.path.exists(EXTRACT_DIR) and bool(os.listdir(EXTRACT_DIR))
    if already_extracted:
        print("Data directory exists and is not empty. Skipping setup.")
        return

    os.makedirs("data", exist_ok=True)

    # Guard clause: without the archive there is nothing to extract.
    if not os.path.exists(DRIVE_ZIP_PATH):
        print(f"ERROR: Data not found in Drive at {DRIVE_ZIP_PATH}")
        print("Please upload 'Training_set_I.zip' to that location in your Google Drive and re-run this cell.")
        return

    print(f"Found data in Drive: {DRIVE_ZIP_PATH}")
    print("Unzipping to local runtime...")
    os.makedirs(EXTRACT_DIR, exist_ok=True)
    unzip_cmd = ["unzip", "-q", DRIVE_ZIP_PATH, "-d", EXTRACT_DIR]
    subprocess.run(unzip_cmd, check=True)
    print("Data loaded from Drive.")


setup_data()


In [None]:
import numpy as np
import tensorflow as tf
import wfdb
import scipy.io
from tensorflow.keras import layers, models

# Seed the global NumPy and TensorFlow random number generators so repeated
# runs of the notebook draw the same random numbers.
np.random.seed(seed=42)
tf.random.set_seed(seed=42)


In [None]:
# Data Loader & Processing Logic (CPSC2021)

# Constants
FS_NATIVE = 200                               # CPSC2021 recordings are sampled at 200 Hz
WINDOW_SEC = 2.0                              # window length (seconds) centred on each beat
WINDOW_SAMPLES = int(FS_NATIVE * WINDOW_SEC)  # 400 samples at FS_NATIVE

# Top-level classes and their integer indices: 'N' -> 0, 'V' -> 1, 'S' -> 2.
TOP_LEVEL_CLASSES = ['N', 'V', 'S']
TOP_LEVEL_MAP = dict(zip(TOP_LEVEL_CLASSES, range(len(TOP_LEVEL_CLASSES))))

# Filled in by load_cpsc_data() once the set of annotation symbols is known.
SUB_PRIMITIVE_MAP = dict()  # sub-primitive label -> sub-primitive index
SUB_TO_TOP_MAP = dict()     # sub-primitive index -> top-level index

# Exact mapping of single-character PhysioBank beat codes to top-level classes.
# The codes are case-sensitive ('r' = R-on-T PVC vs 'R' = right bundle branch
# block beat; 'E' = ventricular escape vs 'e' = atrial escape), so they are
# looked up verbatim before any case-insensitive keyword matching.
_BEAT_CODE_TO_TOP = {
    'V': 'V',  # premature ventricular contraction
    'E': 'V',  # ventricular escape beat
    'r': 'V',  # R-on-T premature ventricular contraction
    'A': 'S',  # atrial premature beat
    'a': 'S',  # aberrated atrial premature beat
    'J': 'S',  # nodal (junctional) premature beat
    'S': 'S',  # supraventricular premature or ectopic beat
}


def get_top_level_category(label):
    """Map an annotation label to one of the top-level classes 'N', 'V' or 'S'.

    Single-character PhysioBank beat codes are resolved with an exact,
    case-sensitive lookup (ventricular: V, E, r; supraventricular: A, a, J, S).
    Longer free-text labels fall back to case-insensitive keyword matching:
    SVEB -> 'S', then PVC/VEB -> 'V', then SPB/PAC/APC -> 'S'.
    Every other label defaults to 'N', because only three top-level classes
    exist. This includes normal and bundle-branch-block beats, escape beats
    other than 'E', and fusion, paced and unclassifiable beats.

    Ref: https://archive.physionet.org/physiobank/annotations.shtml

    Args:
        label: Annotation symbol or descriptor; converted with ``str()``.

    Returns:
        'N', 'V' or 'S'.
    """
    label = str(label)
    if label in _BEAT_CODE_TO_TOP:
        return _BEAT_CODE_TO_TOP[label]

    upper = label.upper()
    # Match whole keywords only: single letters such as 'A' or 'S' would also
    # hit ordinary words like "NORMAL" or "SINUS".
    if 'SVEB' in upper:  # checked first because it contains "VEB"
        return 'S'
    if 'PVC' in upper or 'VEB' in upper:
        return 'V'
    if 'SPB' in upper or 'PAC' in upper or 'APC' in upper:
        return 'S'
    return 'N'  # Default to Normal

# PhysioBank annotation symbols that mark rhythm changes, signal quality, waves,
# measurements or other events rather than individual beats. Annotations with
# these symbols are not turned into training windows.
_NON_BEAT_SYMBOLS = frozenset([
    '+', '~', '"', '!', '[', ']', '|', 'x', '(', ')', 'p', 't', 'u', '`', "'",
    '^', 's', 'T', '*', 'D', '=', '@',
])


def _find_mat_files(data_dir):
    """Return the sorted paths of every ``.mat`` file under *data_dir*, recursively."""
    return sorted(
        os.path.join(root, name)
        for root, _dirs, names in os.walk(data_dir)
        for name in names
        if name.endswith(".mat")
    )


def _extract_record_beats(base_path):
    """Cut a z-scored window of WINDOW_SAMPLES samples around each beat of one record.

    Args:
        base_path: WFDB record path without extension (the ``.hea``/``.atr``
            files share this stem with the ``.mat`` file).

    Returns:
        ``(segments, symbols)``: parallel lists of 1-D windows and their
        annotation symbols. Both are empty when the record has no readable
        annotations, its signal cannot be read, or it is not sampled at
        FS_NATIVE.
    """
    try:
        ann = wfdb.rdann(base_path, 'atr')
    except Exception:
        # No (or unreadable) .atr file: expected for records without beat labels.
        return [], []
    if not ann.symbol:
        return [], []

    try:
        sig, fields = wfdb.rdsamp(base_path)
    except Exception as e:
        print(f"Skipping {base_path}: could not read signal ({e})")
        return [], []

    if fields['fs'] != FS_NATIVE:
        # WINDOW_SAMPLES assumes FS_NATIVE; another rate would silently change
        # the window's duration, so skip the record instead of mixing time scales.
        print(f"Skipping {base_path}: fs={fields['fs']} Hz (expected {FS_NATIVE} Hz)")
        return [], []

    signal_1d = sig[:, 0]  # first channel (index 0), treated as Lead I in this notebook
    half = WINDOW_SAMPLES // 2

    segments, symbols = [], []
    for samp, sym in zip(ann.sample, ann.symbol):
        if sym in _NON_BEAT_SYMBOLS:
            continue

        start = samp - half
        end = start + WINDOW_SAMPLES
        if start < 0 or end > len(signal_1d):
            continue  # window would run past either end of the recording

        segment = signal_1d[start:end]
        if not np.all(np.isfinite(segment)):
            continue  # missing (NaN) samples would turn the whole z-score into NaN

        # Per-window z-score; the epsilon avoids division by zero on flat segments.
        segments.append((segment - np.mean(segment)) / (np.std(segment) + 1e-6))
        symbols.append(sym)

    return segments, symbols


def load_cpsc_data(data_dir="data/cpsc2021", max_files=None):
    """Load beat-centred ECG windows and labels from CPSC2021 WFDB records.

    Every ``.mat`` file under *data_dir* is scanned recursively, because the
    archive may nest records in sub-folders such as ``Training_set_I``. Its
    sibling WFDB header/annotation files are read, and one normalised window
    is cut per beat annotation.

    Side effects: rebinds the module-level ``SUB_PRIMITIVE_MAP`` (label ->
    index, labels sorted) and ``SUB_TO_TOP_MAP`` (index -> TOP_LEVEL_MAP
    index). Both are rebuilt from scratch on every successful call.

    Args:
        data_dir: Root directory to search for records.
        max_files: If truthy, only the first *max_files* records (in sorted
            path order) are scanned.

    Returns:
        ``(X, y_onehot, y_indices)``, where ``X`` has shape
        ``(N, WINDOW_SAMPLES, 1)``, ``y_onehot`` has shape
        ``(N, num_labels)`` and ``y_indices`` has shape ``(N,)``.
        Returns ``(None, None, None)`` when no beat windows were found.
    """
    global SUB_PRIMITIVE_MAP, SUB_TO_TOP_MAP

    mat_files = _find_mat_files(data_dir)
    if max_files:
        mat_files = mat_files[:max_files]

    X_all, y_sub_all = [], []
    print(f"Scanning {len(mat_files)} files in {data_dir}...")

    for file_no, mat_path in enumerate(mat_files, start=1):
        try:
            segments, symbols = _extract_record_beats(os.path.splitext(mat_path)[0])
        except Exception as e:
            # Best-effort scan: one malformed record must not abort the whole
            # load, but it is reported instead of being dropped silently.
            print(f"Skipping {mat_path}: {e}")
        else:
            X_all.extend(segments)
            y_sub_all.extend(symbols)

        if file_no % 50 == 0:
            print(f"Processed {file_no} files...")

    if not X_all:
        return None, None, None

    X = np.array(X_all)[..., np.newaxis]  # (N, WINDOW_SAMPLES, 1)

    # Dynamic label mapping
    unique_labels = sorted(set(y_sub_all))
    print(f"Found unique sub-primitives: {unique_labels}")

    SUB_PRIMITIVE_MAP = {label: i for i, label in enumerate(unique_labels)}
    # Rebuilt (not updated in place) so indices from an earlier call cannot linger.
    SUB_TO_TOP_MAP = {
        i: TOP_LEVEL_MAP[get_top_level_category(label)]
        for label, i in SUB_PRIMITIVE_MAP.items()
    }

    y_indices = np.array([SUB_PRIMITIVE_MAP[label] for label in y_sub_all])
    y_onehot = np.eye(len(unique_labels))[y_indices]

    return X, y_onehot, y_indices


In [None]:
# Model Architecture (Updated for 200Hz Input)
def build_model(input_shape, num_classes):
    """Build and compile the 1-D residual CNN used for beat classification.

    Args:
        input_shape: Shape of one example without the batch axis, e.g. the
            ``X.shape[1:]`` of the windows produced by ``load_cpsc_data``.
        num_classes: Number of units in the softmax output layer.

    Returns:
        A Keras ``Model`` named "ECG_CPSC2021_ResNet", compiled with the Adam
        optimizer, categorical cross-entropy loss and accuracy metric, so the
        targets are expected to be one-hot encoded.
    """
    def conv_bn(tensor, n_filters, kernel):
        # ReLU convolution ('same' padding) followed by batch normalisation.
        tensor = layers.Conv1D(n_filters, kernel, padding='same', activation='relu')(tensor)
        tensor = layers.BatchNormalization()(tensor)
        return tensor

    inp = layers.Input(shape=input_shape)

    # Stem: wide stride-2 convolution, then stride-2 max pooling.
    h = layers.Conv1D(64, 15, strides=2, padding='same', activation='relu')(inp)
    h = layers.BatchNormalization()(h)
    h = layers.MaxPooling1D(3, strides=2, padding='same')(h)

    # Four residual stages, each followed by max pooling with pool size 2.
    for n_filters in (64, 128, 256, 512):
        skip = h
        h = conv_bn(h, n_filters, 7)
        h = conv_bn(h, n_filters, 7)

        if skip.shape[-1] != n_filters:
            # 1x1 projection so the shortcut's channel count matches for Add.
            skip = layers.Conv1D(n_filters, 1, padding='same')(skip)

        h = layers.Add()([h, skip])
        h = layers.MaxPooling1D(2)(h)

    # Classification head.
    h = layers.GlobalAveragePooling1D()(h)
    h = layers.Dense(128, activation='relu')(h)
    h = layers.Dropout(0.5)(h)
    probs = layers.Dense(num_classes, activation='softmax')(h)

    net = models.Model(inputs=inp, outputs=probs, name="ECG_CPSC2021_ResNet")
    net.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return net


In [None]:
# Training & Evaluation Loop
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

def _stratify_or_none(labels, fraction):
    """Return *labels* if a stratified hold-out of *fraction* is feasible, else None.

    Stratified splitting needs at least two samples per class, and both sides
    of the split need at least one slot per class. Rare beat symbols often
    violate this, so the caller falls back to a plain random split instead of
    crashing.
    """
    _, counts = np.unique(labels, return_counts=True)
    n_held_out = int(np.ceil(fraction * len(labels)))
    n_kept = len(labels) - n_held_out
    if counts.min() < 2 or min(n_held_out, n_kept) < len(counts):
        print(f"Warning: some classes are too rare to stratify a {fraction:.0%} split; using a random split.")
        return None
    return labels


def train_and_eval(val_size=0.1, model_path="models/cpsc2021_beat_model.keras"):
    """Train the beat classifier on CPSC2021 and report held-out test metrics.

    The data are split 80/20 into train/test with ``random_state=42``
    (stratified by sub-primitive when feasible). A further *val_size* fraction
    of the training portion is held out as the validation set for early
    stopping. The test set therefore plays no part in choosing the restored
    weights, so the reported metrics are not optimistically biased.

    Args:
        val_size: Fraction of the training split reserved for validation.
        model_path: Where the trained Keras model is saved; its parent
            directory is created if needed.

    Returns:
        The trained model, or ``None`` when no data could be loaded.
    """
    X, y, y_indices = load_cpsc_data()

    if X is None:
        print("No valid data loaded. Please check if CPSC2021 contains beat annotations.")
        return None

    # Split indices rather than arrays so that the validation split can be taken
    # from the training indices; the test indices are identical to splitting X directly.
    all_idx = np.arange(len(X))
    train_idx, test_idx = train_test_split(
        all_idx, test_size=0.2, random_state=42,
        stratify=_stratify_or_none(y_indices, 0.2),
    )
    train_idx, val_idx = train_test_split(
        train_idx, test_size=val_size, random_state=42,
        stratify=_stratify_or_none(y_indices[train_idx], val_size),
    )

    X_train, y_train = X[train_idx], y[train_idx]
    X_val, y_val = X[val_idx], y[val_idx]
    X_test, y_test = X[test_idx], y[test_idx]

    print(f"Train Shape: {X_train.shape}, Val Shape: {X_val.shape}, Test Shape: {X_test.shape}")
    print(f"Classes Map: {SUB_PRIMITIVE_MAP}")

    model = build_model(X_train.shape[1:], y_train.shape[1])

    model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=10,
        batch_size=64,
        callbacks=[tf.keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)],
    )

    # Evaluation on the untouched test split
    print("\n--- Sub-Primitive Evaluation ---")
    y_pred = np.argmax(model.predict(X_test), axis=1)
    y_true = np.argmax(y_test, axis=1)

    sub_labels = sorted(SUB_PRIMITIVE_MAP.values())
    sub_names = [k for k, _ in sorted(SUB_PRIMITIVE_MAP.items(), key=lambda item: item[1])]
    print(classification_report(y_true, y_pred, labels=sub_labels, target_names=sub_names, zero_division=0))

    print("\n--- Top-Level Primitive Evaluation (N, V, S) ---")
    y_true_top = np.array([SUB_TO_TOP_MAP[i] for i in y_true])
    y_pred_top = np.array([SUB_TO_TOP_MAP[i] for i in y_pred])

    top_labels = sorted(TOP_LEVEL_MAP.values())
    top_names = [k for k, _ in sorted(TOP_LEVEL_MAP.items(), key=lambda item: item[1])]
    print(classification_report(y_true_top, y_pred_top, labels=top_labels, target_names=top_names, zero_division=0))

    # Don't rely on the caller having created the output directory.
    os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)
    model.save(model_path)
    return model

if __name__ == "__main__":
    # Make sure the output directory for the saved model exists.
    os.makedirs("models", exist_ok=True)
    # Only train once setup_data() has produced the extracted dataset.
    if not os.path.exists("data/cpsc2021"):
        print("Data directory not found. Please ensure download completes.")
    else:
        train_and_eval()


In [None]:
# Convert to TFLite
def convert_to_tflite(model_path="models/cpsc2021_beat_model.keras", tflite_path="models/cpsc2021_beat_model.tflite"):
    """Convert a saved Keras model to a TensorFlow Lite flatbuffer on disk.

    Best-effort: any failure while loading, converting or writing is printed
    rather than raised.

    Args:
        model_path: Path of the saved Keras model to load.
        tflite_path: Destination file for the converted model (written in
            binary mode, overwriting any existing file).
    """
    print(f"Converting {model_path} to TFLite...")
    try:
        keras_model = tf.keras.models.load_model(model_path)
        flatbuffer = tf.lite.TFLiteConverter.from_keras_model(keras_model).convert()

        with open(tflite_path, "wb") as out_file:
            out_file.write(flatbuffer)
        print(f"TFLite model saved to {tflite_path}")
    except Exception as e:
        print(f"Error converting to TFLite: {e}")

# Convert only when the training cell actually produced a saved model.
if __name__ == "__main__" and os.path.exists("models/cpsc2021_beat_model.keras"):
    convert_to_tflite()


In [None]:
# Download Models (Colab specific)
import shutil

try:
    from google.colab import files
except ImportError:
    print("Not running in Google Colab, skipping download.")
else:
    if not os.path.isdir("models"):
        print("No 'models' directory found; nothing to download.")
    else:
        print("Zipping models directory...")
        # shutil.make_archive writes models.zip from scratch. `zip -r` would
        # instead update an existing archive and keep stale entries, and a `!`
        # shell escape silently ignores failures. The archive keeps the
        # top-level "models/" folder, like `zip -r models.zip models`.
        shutil.make_archive("models", "zip", root_dir=".", base_dir="models")
        print("Downloading models.zip...")
        files.download("models.zip")
