In [1]:
!pip -q install wfdb

import os, json, ast, math, numpy as np, pandas as pd, wfdb, tensorflow as tf, re
from tensorflow import keras
from tensorflow.keras import layers, models, optimizers
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from collections import Counter

print(tf.__version__)

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.8/163.8 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25h

2025-09-13 17:43:50.864656: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1757785431.073524      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1757785431.134393      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


2.18.0


In [2]:
CANDIDATES = [
    "/kaggle/input/ptbxl-ekg/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3"
]
PTBXL_ROOT = next((p for p in CANDIDATES if os.path.exists(p)), None)

CSV_DB = os.path.join(PTBXL_ROOT, "ptbxl_database.csv")
CSV_SCP = os.path.join(PTBXL_ROOT, "scp_statements.csv")

assert os.path.exists(CSV_DB) and os.path.exists(CSV_SCP), "Missing required CSVs."

In [3]:
# build target from scp

TARGETS = [
    'AFib','AFlutter','RBBB','LBBB','ST elevation','ST depression','AV block','MI',
    'WPW','PVC','Idioventricular rhythm','Junctional rhythm','Fusion','Normal'
]
CLS2IDX = {c:i for i,c in enumerate(TARGETS)}

CSV_DB = os.path.join(PTBXL_ROOT, "ptbxl_database.csv")
df = pd.read_csv(CSV_DB)

def parse_scp_codes(s):
    if isinstance(s, dict): return s
    try: return json.loads(s)
    except Exception: return ast.literal_eval(s)

df["scp_codes_dict"] = df["scp_codes"].apply(parse_scp_codes)

# ---- Canonical mapping & regex helpers ----
DIRECT_MAP = {
    # conduction blocks
    "RBBB": "RBBB",
    "LBBB": "LBBB",
    "MI": "MI",
    "WPW": "WPW",
    # rhythms
    "AFIB": "AFib",
    "AFLT": "AFlutter", "AFL": "AFlutter",
    "IVR": "Idioventricular rhythm",
    "JR": "Junctional rhythm", "JER": "Junctional rhythm", "JRN": "Junctional rhythm",
    # ectopy
    "PVC": "PVC", "VPC": "PVC", "VPB": "PVC",
    # fusion beats (naming can vary)
    "FUSION": "Fusion", "FUS": "Fusion",
    # ST changes (common shorthand)
    "STE": "ST elevation", "STD": "ST depression",
}

# Regex buckets for families of codes
AV_BLOCK_PATTERNS = [
    r"^AVB$", r"^IAVB$", r"^AVB1$", r"^AVB2$", r"^AVB3$",
    r"^MOBITZ(_?I|_?II)?$", r"^MOBITZ1$", r"^MOBITZ2$", r"^HB(1|2|3)?$",
]
ST_ELEV_PATTERNS = [r"^STE(MI)?$", r"^ST[_\- ]?ELEV(ATION)?$"]
ST_DEPR_PATTERNS = [r"^STD$", r"^ST[_\- ]?DEPR(ESSION)?$"]

def match_any(pats, code):
    return any(re.match(p, code) for p in pats)

def code_to_target(code):
    c = code.upper()
    if c in DIRECT_MAP:
        return DIRECT_MAP[c]
    if match_any(AV_BLOCK_PATTERNS, c):
        return "AV block"
    if match_any(ST_ELEV_PATTERNS, c):
        return "ST elevation"
    if match_any(ST_DEPR_PATTERNS, c):
        return "ST depression"
    return None

# ---- Weighted vote per record ----
unmapped_seen = set()

def pick_target(code_dict):
    # accumulate weights per target
    scores = {t: 0.0 for t in TARGETS}
    has_norm = False
    for code, w in code_dict.items():
        tgt = code_to_target(str(code))
        if tgt is None:
            if str(code).upper() == "NORM":
                has_norm = True
            else:
                unmapped_seen.add(str(code).upper())
            continue
        scores[tgt] += float(w)

    # winner by max score; fallback to Normal if nothing matched but NORM present
    best_tgt, best_w = max(scores.items(), key=lambda kv: kv[1])
    
    # Prefer MI if tied with STE (common co-annotation; MI should win)
    if best_w > 0 and scores["MI"] == best_w and best_tgt == "ST elevation":
        best_tgt = "MI"
    
    if best_w > 0:
        return best_tgt
    if has_norm:
        return "Normal"
    return None


df["target"] = df["scp_codes_dict"].apply(pick_target)
df = df[df["target"].notna()].copy()
df = df[df["target"].isin(TARGETS)].copy()

# Use 100 Hz files (records100) -> filename_lr
df["path"] = df["filename_hr"].apply(lambda p: os.path.join(PTBXL_ROOT, p))
df["wfdb_stem"] = df["path"].apply(lambda p: os.path.splitext(p)[0])

def has_files(stem):
    return os.path.exists(stem + ".hea") and os.path.exists(stem + ".dat")

df = df[df["wfdb_stem"].apply(has_files)].reset_index(drop=True)

# Standard PTB-XL folds: train (1..8), val (9), test (10)
train_df = df[~df["strat_fold"].isin([9, 10])].copy()
val_df   = df[df["strat_fold"] == 9].copy()
test_df  = df[df["strat_fold"] == 10].copy()

train_df["y"] = train_df["target"].map(CLS2IDX).astype(int)
val_df["y"]   = val_df["target"].map(CLS2IDX).astype(int)
test_df["y"]  = test_df["target"].map(CLS2IDX).astype(int)

num_classes = len(TARGETS)
print("Classes (fixed order):", TARGETS)
print("Counts (train/val/test):")
print(train_df["target"].value_counts().reindex(TARGETS, fill_value=0))
print(val_df["target"].value_counts().reindex(TARGETS, fill_value=0))
print(test_df["target"].value_counts().reindex(TARGETS, fill_value=0))

if unmapped_seen:
    print("\nUnmapped scp codes encountered (add to DIRECT_MAP/regex if needed):")
    print(sorted(list(unmapped_seen)))

Classes (fixed order): ['AFib', 'AFlutter', 'RBBB', 'LBBB', 'ST elevation', 'ST depression', 'AV block', 'MI', 'WPW', 'PVC', 'Idioventricular rhythm', 'Junctional rhythm', 'Fusion', 'Normal']
Counts (train/val/test):
target
AFib                        37
AFlutter                    47
RBBB                         0
LBBB                         0
ST elevation                 0
ST depression                0
AV block                     0
MI                           0
WPW                         64
PVC                        829
Idioventricular rhythm       0
Junctional rhythm            0
Fusion                       0
Normal                    7506
Name: count, dtype: int64
target
AFib                        3
AFlutter                    5
RBBB                        0
LBBB                        0
ST elevation                0
ST depression               0
AV block                    0
MI                          0
WPW                         7
PVC                       101
Idioventr

In [4]:
TARGETS = [
    'AFib','AFlutter','RBBB','LBBB','ST elevation','ST depression','AV block','MI',
    'WPW','PVC','Idioventricular rhythm','Junctional rhythm','Fusion','Normal'
]
CLS2IDX = {c:i for i,c in enumerate(TARGETS)}

CSV_DB = os.path.join(PTBXL_ROOT, "ptbxl_database.csv")
df = pd.read_csv(CSV_DB)

def parse_scp_codes(s):
    if isinstance(s, dict): return s
    try: return json.loads(s)
    except Exception: return ast.literal_eval(s)

df["scp_codes_dict"] = df["scp_codes"].apply(parse_scp_codes)

DIRECT_MAP = {
    "RBBB":"RBBB", "LBBB":"LBBB", "MI":"MI", "WPW":"WPW",
    "AFIB":"AFib",
    "AFLT":"AFlutter", "AFL":"AFlutter",
    "IVR":"Idioventricular rhythm",
    "JR":"Junctional rhythm", "JER":"Junctional rhythm", "JRN":"Junctional rhythm",
    "PVC":"PVC", "VPC":"PVC", "VPB":"PVC",
    "FUSION":"Fusion", "FUS":"Fusion",
    "STE":"ST elevation", "STD":"ST depression",

    # conduction variants
    "CRBBB":"RBBB", "IRBBB":"RBBB",
    "CLBBB":"LBBB", "ILBBB":"LBBB",

    # MI location/markers
    "AMI":"MI", "ALMI":"MI", "ASMI":"MI", "IMI":"MI", "ILMI":"MI",
    "LMI":"MI", "PMI":"MI", "IPMI":"MI", "IPLMI":"MI", "QWAVE":"MI",

    # ectopy patterns to PVC bucket
    "BIGU":"PVC", "TRIGU":"PVC",

    # benign sinus rhythms to Normal
    "SR":"Normal", "SARRH":"Normal", "STACH":"Normal", "SBRAD":"Normal",
}

AV_BLOCK_PATTERNS = [
    r"^(?:AVB|IAVB)$",
    r"^AVB_?[123]$",
    r"^[123]AVB$",
    r"^MOBITZ(?:_?I|_?II)?$",
    r"^HB[123]?$",
]

# ST-elevation family: STE*, all injury currents INJ*, early repolarization (EL)
ST_ELEV_PATTERNS = [
    r"^STE(?:_|MI)?$",
    r"^ST[ _\-]?ELEV(?:ATION)?$",
    r"^INJ[A-Z]*$",         # INJAL/INJAS/INJIL/INJIN/INJLA...
    r"^EL$",                # early repolarization
]

# ST-depression / ischemia family: STD*, ISC*, NST (nonspecific ST-T),
# T-wave inversion/low T, digoxin effect, strain (SEHYP)
ST_DEPR_PATTERNS = [
    r"^STD_?$",
    r"^ST[ _\-]?DEPR(?:ESSION)?$",
    r"^ISC[A-Z_]*$",        # ISCAL/ISCAN/ISCAS/ISCIL/ISCIN/ISCLA/ISC_
    r"^NST_?$",             # non-specific ST-T changes
    r"^INVT$",              # T-wave inversion
    r"^LOWT$",              # low T
    r"^DIG$",               # digoxin effect
    r"^SEHYP$",             # secondary ST-T due to hypertrophy
    r"^VCLVH$",             # voltage criteria LVH often with ST-T strain
]

def match_any(pats, code):
    return any(re.match(p, code) for p in pats)

def code_to_target(code):
    c = str(code).upper()
    # direct hits first
    if c in DIRECT_MAP:
        return DIRECT_MAP[c]
    # regex families next
    if match_any(AV_BLOCK_PATTERNS, c):
        return "AV block"
    if match_any(ST_ELEV_PATTERNS, c):
        return "ST elevation"
    if match_any(ST_DEPR_PATTERNS, c):
        return "ST depression"
    # otherwise no mapping
    return None

# ---- Weighted vote per record ----
unmapped_seen = set()

def pick_target(code_dict):
    # accumulate weights per target
    scores = {t: 0.0 for t in TARGETS}
    has_norm = False
    for code, w in code_dict.items():
        tgt = code_to_target(str(code))
        if tgt is None:
            if str(code).upper() == "NORM":
                has_norm = True
            else:
                unmapped_seen.add(str(code).upper())
            continue
        scores[tgt] += float(w)

    # winner by max score; fallback to Normal if nothing matched but NORM present
    best_tgt = max(scores.items(), key=lambda kv: kv[1])
    if best_tgt[1] > 0:
        return best_tgt[0]
    if has_norm:
        return "Normal"
    return None

df["target"] = df["scp_codes_dict"].apply(pick_target)
df = df[df["target"].notna()].copy()
df = df[df["target"].isin(TARGETS)].copy()

# Use 100 Hz files (records100) -> filename_lr
df["path"] = df["filename_lr"].apply(lambda p: os.path.join(PTBXL_ROOT, p))
df["wfdb_stem"] = df["path"].apply(lambda p: os.path.splitext(p)[0])

def has_files(stem):
    return os.path.exists(stem + ".hea") and os.path.exists(stem + ".dat")

df = df[df["wfdb_stem"].apply(has_files)].reset_index(drop=True)

# Standard PTB-XL folds: train (1..8), val (9), test (10)
train_df = df[~df["strat_fold"].isin([9, 10])].copy()
val_df   = df[df["strat_fold"] == 9].copy()
test_df  = df[df["strat_fold"] == 10].copy()

train_df["y"] = train_df["target"].map(CLS2IDX).astype(int)
val_df["y"]   = val_df["target"].map(CLS2IDX).astype(int)
test_df["y"]  = test_df["target"].map(CLS2IDX).astype(int)

num_classes = len(TARGETS)
print("Classes (fixed order):", TARGETS)
print("Counts (train/val/test):")
print(train_df["target"].value_counts().reindex(TARGETS, fill_value=0))
print(val_df["target"].value_counts().reindex(TARGETS, fill_value=0))
print(test_df["target"].value_counts().reindex(TARGETS, fill_value=0))

if unmapped_seen:
    print("\nUnmapped scp codes encountered (add to DIRECT_MAP/regex if needed):")
    print(sorted(list(unmapped_seen)))

Classes (fixed order): ['AFib', 'AFlutter', 'RBBB', 'LBBB', 'ST elevation', 'ST depression', 'AV block', 'MI', 'WPW', 'PVC', 'Idioventricular rhythm', 'Junctional rhythm', 'Fusion', 'Normal']
Counts (train/val/test):
target
AFib                        36
AFlutter                    47
RBBB                      1212
LBBB                       486
ST elevation               281
ST depression             2185
AV block                   320
MI                        2940
WPW                         58
PVC                        345
Idioventricular rhythm       0
Junctional rhythm            0
Fusion                       0
Normal                    7274
Name: count, dtype: int64
target
AFib                        3
AFlutter                    5
RBBB                      154
LBBB                       58
ST elevation               40
ST depression             266
AV block                   40
MI                        343
WPW                         7
PVC                        47
Idioventr

In [5]:
def has_files(stem):
    # WFDB path is without extension; ensure .hea exists (and implicitly .dat)
    hea = stem + ".hea"
    dat = stem + ".dat"
    return os.path.exists(hea) and os.path.exists(dat)

df["wfdb_stem"] = df["path"].apply(lambda p: os.path.splitext(p)[0])
df = df[df["wfdb_stem"].apply(has_files)].reset_index(drop=True)

In [6]:
assert "target" in df.columns, "Expected column 'target' not found. Did you run the 14-class mapping?"

train_df = df[~df["strat_fold"].isin([9, 10])].copy()
val_df   = df[df["strat_fold"] == 9].copy()
test_df  = df[df["strat_fold"] == 10].copy()

train_df.loc[:, "y"] = train_df["target"].map(CLS2IDX).astype(int)
val_df.loc[:, "y"]   = val_df["target"].map(CLS2IDX).astype(int)
test_df.loc[:, "y"]  = test_df["target"].map(CLS2IDX).astype(int)

num_classes = len(TARGETS)
print("Classes (fixed order):", TARGETS, "->", num_classes)

# Optional: show per-class counts to confirm coverage
print("\nTrain counts:")
print(train_df["target"].value_counts().reindex(TARGETS, fill_value=0))
print("\nVal counts:")
print(val_df["target"].value_counts().reindex(TARGETS, fill_value=0))
print("\nTest counts:")
print(test_df["target"].value_counts().reindex(TARGETS, fill_value=0))

Classes (fixed order): ['AFib', 'AFlutter', 'RBBB', 'LBBB', 'ST elevation', 'ST depression', 'AV block', 'MI', 'WPW', 'PVC', 'Idioventricular rhythm', 'Junctional rhythm', 'Fusion', 'Normal'] -> 14

Train counts:
target
AFib                        36
AFlutter                    47
RBBB                      1212
LBBB                       486
ST elevation               281
ST depression             2185
AV block                   320
MI                        2940
WPW                         58
PVC                        345
Idioventricular rhythm       0
Junctional rhythm            0
Fusion                       0
Normal                    7274
Name: count, dtype: int64

Val counts:
target
AFib                        3
AFlutter                    5
RBBB                      154
LBBB                       58
ST elevation               40
ST depression             266
AV block                   40
MI                        343
WPW                         7
PVC                        47


In [7]:
TARGET_FS = 500
FIXED_LEN = 5000  # 10s @ 100Hz

def read_wfdb(stem_path):
    # string without extension
    sig, meta = wfdb.rdsamp(stem_path)
    # sig: np.ndarray [T, channels], meta.fs should be 100 for filename_lr
    x = sig.astype(np.float32)
    if x.shape[0] < FIXED_LEN:
        pad = FIXED_LEN - x.shape[0]
        x = np.pad(x, ((0, pad), (0, 0)), mode="constant")
    elif x.shape[0] > FIXED_LEN:
        x = x[:FIXED_LEN]
    return x

def py_loader(stem_path):
    x = read_wfdb(stem_path.decode("utf-8"))
    return x
    
def make_ds(df, batch_size=64, shuffle=False, cache=False, augment=False):
    paths = df["wfdb_stem"].values.astype("U")
    ys = df["y"].values.astype(np.int32)

    ds = tf.data.Dataset.from_tensor_slices((paths, ys))
    if shuffle:
        ds = ds.shuffle(len(df), reshuffle_each_iteration=True)

    # Map: load with numpy_function, then set shape
    def _map(path, y):
        x = tf.numpy_function(py_loader, [path], tf.float32)
        x.set_shape([FIXED_LEN, 12])     # (time, channels)
        # Optional per-sample normalization (z-score across time)
        mean = tf.reduce_mean(x, axis=0, keepdims=True)
        std  = tf.math.reduce_std(x, axis=0, keepdims=True) + 1e-6
        x = (x - mean) / std

        if augment:
            # simple jitter, small gaussian noise
            x = x + tf.random.normal(tf.shape(x), stddev=0.01)
        return x, y

    ds = ds.map(_map, num_parallel_calls=tf.data.AUTOTUNE)
    if cache:
        ds = ds.cache()
    ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return ds

In [8]:
TARGET_N = 600

dfs = []
for cls in TARGETS:
    df_c = train_df[train_df["target"] == cls]
    if len(df_c) == 0:
        # skip classes that don't exist in train split; the model can't learn them yet
        continue
    reps = max(1, int(np.ceil(TARGET_N / len(df_c))))
    df_aug = pd.concat([df_c] * reps, ignore_index=True).sample(
        TARGET_N, replace=True, random_state=42
    )
    dfs.append(df_aug)

train_df_bal = pd.concat(dfs, ignore_index=True)
train_df_bal = train_df_bal.sample(len(train_df_bal), random_state=42).reset_index(drop=True)

print("Balanced training size:", len(train_df_bal))
print("Balanced per-class counts:")
print(Counter(train_df_bal["target"]))

# Build dataset from the balanced frame
BATCH = 64  
train_ds = make_ds(train_df_bal, batch_size=BATCH, shuffle=True,  cache=True,  augment=True)
val_ds   = make_ds(val_df,      batch_size=BATCH, shuffle=False, cache=True,  augment=False)
test_ds  = make_ds(test_df,     batch_size=BATCH, shuffle=False, cache=False, augment=False)

Balanced training size: 6600
Balanced per-class counts:
Counter({'PVC': 600, 'ST elevation': 600, 'RBBB': 600, 'MI': 600, 'AFlutter': 600, 'Normal': 600, 'AV block': 600, 'ST depression': 600, 'LBBB': 600, 'WPW': 600, 'AFib': 600})


I0000 00:00:1757785605.245828      36 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13942 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5
I0000 00:00:1757785605.246582      36 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13942 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5


In [9]:
BATCH = 64
train_ds = make_ds(train_df, batch_size=BATCH, shuffle=True,  cache=True,  augment=True)
val_ds   = make_ds(val_df,   batch_size=BATCH, shuffle=False, cache=True,  augment=False)
test_ds  = make_ds(test_df,  batch_size=BATCH, shuffle=False, cache=False, augment=False)

y_train = train_df["y"].values
present = np.unique(y_train)  # only classes that actually appear

cw_vals = compute_class_weight(
    class_weight="balanced",
    classes=present,
    y=y_train
)

# Map only present classes; it's fine if some classes are missing from this dict
class_weights = {int(c): float(w) for c, w in zip(present, cw_vals)}
print("Class weights (present classes only):", class_weights)

Class weights (present classes only): {0: 38.343434343434346, 1: 29.36943907156673, 2: 1.138913891389139, 3: 2.8402543958099513, 4: 4.9123261080556455, 5: 0.631745371333472, 6: 4.3136363636363635, 7: 0.46951144094001235, 8: 23.79937304075235, 9: 4.0010540184453225, 13: 0.18976679081160797}


In [10]:
class MacroF1(tf.keras.metrics.Metric):
    # macro-F1 for multi-class classification
    def __init__(self, num_classes, name="macro_f1", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.cm = self.add_weight(
            name="conf_mat",
            shape=(num_classes, num_classes),
            initializer="zeros",
            dtype=tf.float32,
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, tf.int32)
        y_pred = tf.argmax(y_pred, axis=-1, output_type=tf.int32)
        cm = tf.math.confusion_matrix(
            y_true, y_pred, num_classes=self.num_classes, dtype=tf.float32
        )
        self.cm.assign_add(cm)

    def result(self):
        cm = self.cm
        tp = tf.linalg.tensor_diag_part(cm)
        fp = tf.reduce_sum(cm, axis=0) - tp
        fn = tf.reduce_sum(cm, axis=1) - tp
        precision = tf.math.divide_no_nan(tp, tp + fp)
        recall    = tf.math.divide_no_nan(tp, tp + fn)
        f1 = tf.math.divide_no_nan(2.0 * precision * recall, precision + recall)
        return tf.reduce_mean(f1)

    def reset_states(self):
        self.cm.assign(tf.zeros_like(self.cm))

In [11]:
strategy = tf.distribute.MirroredStrategy()
print("Replicas:", strategy.num_replicas_in_sync)

# ----- utility: Stochastic Depth (DropPath) -----
class StochasticDepth(layers.Layer):
    def __init__(self, drop_prob=0.0, **kwargs):
        super().__init__(**kwargs); self.drop_prob = drop_prob
    def call(self, x, training=None):
        if (not training) or self.drop_prob == 0.0:
            return x
        keep = 1.0 - self.drop_prob
        # shape: (batch, 1, 1) broadcast over time & channels
        mask = tf.random.uniform([tf.shape(x)[0], 1, 1], 0, 1) < keep
        mask = tf.cast(mask, x.dtype)
        return (x / keep) * mask

def SEBlock(x, r=16):
    c = x.shape[-1]
    s = layers.GlobalAveragePooling1D()(x)
    s = layers.Dense(max(c//r, 8), activation="relu")(s)
    s = layers.Dense(c, activation="sigmoid")(s)
    return layers.Multiply()([x, layers.Reshape((1, c))(s)])

def Bottleneck(x, filters, k=7, stride=1, se=True, drop_prob=0.0):
    """1D Bottleneck: 1x1 → kx1 → 1x1 (+ SE + DropPath)."""
    in_c = x.shape[-1]

    y = layers.Conv1D(filters, 1, strides=1, padding="same", use_bias=False)(x)
    y = layers.BatchNormalization()(y)
    y = layers.ReLU()(y)

    y = layers.Conv1D(filters, k, strides=stride, padding="same", use_bias=False)(y)
    y = layers.BatchNormalization()(y)
    y = layers.ReLU()(y)

    y = layers.Conv1D(filters * 4, 1, strides=1, padding="same", use_bias=False)(y)
    y = layers.BatchNormalization()(y)

    if se:
        y = SEBlock(y, r=16)

    # projection if needed
    if (in_c != filters * 4) or (stride != 1):
        sc = layers.Conv1D(filters * 4, 1, strides=stride, padding="same", use_bias=False)(x)
        sc = layers.BatchNormalization()(sc)
    else:
        sc = x

    y = layers.add([sc, y])
    y = StochasticDepth(drop_prob=drop_prob)(y)
    y = layers.ReLU()(y)
    return y

def make_seresnet50_1d(input_len, n_leads, num_classes, drop_path_max=0.1):
    """ResNet-50 depth: [3,4,6,3] bottlenecks with SE and DropPath."""
    # kernel plan per stage (larger early to see QRS/ST morphology @500 Hz)
    k1, k2, k3, k4 = 15, 9, 7, 3
    # channels per stage (pre-expansion)
    c1, c2, c3, c4 = 64, 128, 256, 384
    # blocks per stage
    b = [3, 4, 6, 3]
    # strides per stage (downsample at stage start except stage1)
    s = [1, 2, 2, 2]  # 5000 -> 2500 -> 1250 -> 625

    total_blocks = sum(b)
    block_id = 0

    inp = layers.Input(shape=(input_len, n_leads))
    x = layers.Conv1D(64, 15, strides=2, padding="same", use_bias=False)(inp)  # 5000 -> 2500
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # Stage 1
    for i in range(b[0]):
        dp = drop_path_max * (block_id / max(1, total_blocks-1))
        x = Bottleneck(x, c1, k=k1, stride=(s[0] if i==0 else 1), se=True, drop_prob=dp)
        block_id += 1

    # Stage 2
    for i in range(b[1]):
        dp = drop_path_max * (block_id / max(1, total_blocks-1))
        x = Bottleneck(x, c2, k=k2, stride=(s[1] if i==0 else 1), se=True, drop_prob=dp)
        block_id += 1

    # Stage 3
    for i in range(b[2]):
        dp = drop_path_max * (block_id / max(1, total_blocks-1))
        x = Bottleneck(x, c3, k=k3, stride=(s[2] if i==0 else 1), se=True, drop_prob=dp)
        block_id += 1

    # Stage 4
    for i in range(b[3]):
        dp = drop_path_max * (block_id / max(1, total_blocks-1))
        x = Bottleneck(x, c4, k=k4, stride=(s[3] if i==0 else 1), se=True, drop_prob=dp)
        block_id += 1

    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dropout(0.5)(x)
    out = layers.Dense(num_classes, activation="softmax", dtype="float32")(x)  # keep logits in fp32

    return models.Model(inp, out)


macro_f1 = MacroF1(num_classes=len(TARGETS))
with strategy.scope():
    num_classes = len(TARGETS)  # your 14 classes
    model = make_seresnet50_1d(input_len=FIXED_LEN, n_leads=12, num_classes=num_classes, drop_path_max=0.1)

    # Cosine decay w/ restarts (works well on long 500Hz sequences)
    train_source = train_df if 'train_df_bal' not in globals() else train_df_bal
    steps_per_epoch = max(1, len(train_source) // BATCH)
    lr_sched = tf.keras.optimizers.schedules.CosineDecayRestarts(
        initial_learning_rate=2e-3,  # bump a bit with mixed precision
        first_decay_steps=steps_per_epoch * 5
    )
    opt = optimizers.Adam(learning_rate=lr_sched)

    model.compile(
    optimizer=opt,  
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy", macro_f1])

model.summary()

Replicas: 2


In [None]:
LOGDIR = "/kaggle/working/tb"
os.makedirs(LOGDIR, exist_ok=True)

cbs = [
    tf.keras.callbacks.ModelCheckpoint("/kaggle/working/best.keras", monitor="val_loss", save_best_only=True),
    tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True),
    tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=4, min_lr=1e-5, verbose=1),
]

EPOCHS = 45
hist = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS, callbacks=cbs, verbose=1)


Epoch 1/45


I0000 00:00:1757785665.118460     103 cuda_dnn.cc:529] Loaded cuDNN version 90300
I0000 00:00:1757785665.121223     102 cuda_dnn.cc:529] Loaded cuDNN version 90300


[1m238/238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m368s[0m 1s/step - accuracy: nan - loss: nan - macro_f1: 0.0777 - val_accuracy: 0.5258 - val_loss: 1.3410 - val_macro_f1: 0.0467 - learning_rate: 0.0011
Epoch 2/45
[1m238/238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m309s[0m 1s/step - accuracy: nan - loss: nan - macro_f1: 0.1052 - val_accuracy: 0.5913 - val_loss: 1.3383 - val_macro_f1: 0.1655 - learning_rate: 2.8167e-05
Epoch 3/45
[1m238/238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m305s[0m 1s/step - accuracy: nan - loss: nan - macro_f1: 0.1265 - val_accuracy: 0.7331 - val_loss: 1.1624 - val_macro_f1: 0.1614 - learning_rate: 0.0018
Epoch 4/45
[1m238/238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m305s[0m 1s/step - accuracy: nan - loss: nan - macro_f1: 0.1428 - val_accuracy: 0.6330 - val_loss: 1.0884 - val_macro_f1: 0.1873 - learning_rate: 0.0012
Epoch 5/45
[1m238/238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m305s[0m 1s/step - accuracy: nan - loss: nan 

In [None]:
test_loss, test_acc = model.evaluate(test_ds, verbose=0)
print(f"Test accuracy: {test_acc:.4f}")
CLASS_NAMES = [
    "Normal", "AFib", "AFlutter", "RBBB", "LBBB", "ST Elevation",
    "ST Depression", "AV Block", "MI", "WPW", "PVC",
    "Idioventricular", "Junctional", "Fusion"
]

np.save("/kaggle/working/label_classes.npy", np.array(CLASS_NAMES))
print("Saved classes to /kaggle/working/label_classes.npy")


# Save label encoder classes for inference
model.save("/kaggle/working/best.keras")
model.save("/kaggle/working/best.h5")
classes = np.load("/kaggle/working/label_classes.npy", allow_pickle=True)
print(classes)

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# Collect true + predicted labels
y_true, y_pred = [], []
for xb, yb in test_ds:
    probs = model.predict(xb, verbose=0)
    y_pred.extend(np.argmax(probs, axis=1))
    y_true.extend(yb.numpy())

# Compute confusion matrix
cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(CLASS_NAMES)))

# Plot
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=CLASS_NAMES)
fig, ax = plt.subplots(figsize=(10, 10))
disp.plot(ax=ax, cmap="Blues", xticks_rotation=45, colorbar=False)
plt.title("Confusion Matrix (Test Set)")
plt.show()
