## ConvNeXt v2 for Chagas ECG detection

This notebook implements a ConvNeXt v2 model for ECG-based Chagas disease detection.  
It treats each ECG as a 12-channel time series and uses the ConvNeXt v2 architecture,
which combines depthwise convolutions, inverted bottlenecks (pointwise expansion and
projection), and layer normalization with a learnable per-channel residual scale.

Goals:  
- Leverage state-of-the-art ConvNeXt v2 architecture for improved feature extraction
- Capture complex temporal patterns in ECG signals that simpler models might miss  
- Provide a strong deep learning baseline using modern architectural innovations
- Keep preprocessing, splits, and metrics identical to prior notebooks for an
  apples-to-apples comparison

Reference: ConvNeXt V2 paper, https://arxiv.org/abs/2301.00808

## Environment setup

### Import libraries

In [1]:
import numpy as np
import tensorflow as tf
from keras import layers, models
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    average_precision_score,
    precision_recall_fscore_support,
    confusion_matrix,
    ConfusionMatrixDisplay,
)
import matplotlib.pyplot as plt
from tqdm.keras import TqdmCallback

# Fixed seed handed to Keras' global seeding helper so that reruns of the
# notebook start from the same random state.
RANDOM_STATE = 2025
tf.keras.utils.set_random_seed(RANDOM_STATE)

In [2]:
# Report the TensorFlow/Keras versions and the accelerators TensorFlow can see.
print("TensorFlow version:", tf.__version__)
print("Keras version:", tf.keras.__version__)

# Query the physical GPUs once; the result is reused for the Metal check below.
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    print(f"GPU(s) available: {len(gpus)}")
    for i, gpu in enumerate(gpus):
        print(f"  GPU {i}: {gpu}")
    # get_memory_growth() yields None when growth was never configured, which
    # reads oddly in the report; coerce it so the line shows True/False.
    memory_growth = bool(tf.config.experimental.get_memory_growth(gpus[0]))
    print(f"GPU memory growth enabled: {memory_growth}")
else:
    print("No GPU available")

# Check if running on Apple Silicon with Metal.
# The PhysicalDevice repr (e.g. "/physical_device:GPU:0") carries no backend
# name, so the device details reported by the plugin are inspected as well.
# NOTE(review): presumably tensorflow-metal reports device_name "METAL"
# (it logs "name: METAL" at start-up) -- confirm on the target machine.
metal_detected = False
for gpu in gpus:
    try:
        details = tf.config.experimental.get_device_details(gpu)
    except Exception as exc:
        # Best-effort diagnostic: say what went wrong but never abort the notebook.
        print(f"  Could not read device details for {gpu.name}: {exc}")
        details = {}
    description = f"{gpu} {details.get('device_name', '')}".lower()
    if "metal" in description:
        metal_detected = True
        break
if metal_detected:
    print("Metal Performance Shaders (MPS) detected for Apple Silicon")

# Check available devices
print("\nAvailable devices:")
for device in tf.config.list_logical_devices():
    print(f"  {device}")


TensorFlow version: 2.19.0
Keras version: 3.10.0
GPU(s) available: 1
  GPU 0: PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')
GPU memory growth enabled: None

Available devices:
  LogicalDevice(name='/device:CPU:0', device_type='CPU')
  LogicalDevice(name='/device:GPU:0', device_type='GPU')


2025-07-27 21:57:34.153258: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M4 Pro
2025-07-27 21:57:34.153290: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 48.00 GB
2025-07-27 21:57:34.153297: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 18.00 GB
I0000 00:00:1753678654.153321 1876177 pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
I0000 00:00:1753678654.153354 1876177 pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


### Load preprocessed datasets

In [3]:
# Folder holding the preprocessed .npz archives
DATA_DIR = "../data/prepared"

# Training uses the "train_bal_parts0-6_aug" archive; "train_full_parts0-6.npz"
# in the same folder is the alternative training archive.
train = np.load(DATA_DIR + "/train_bal_parts0-6_aug.npz")
val = np.load(DATA_DIR + "/val_parts0-6.npz")
test = np.load(DATA_DIR + "/test_external.npz")

# Each archive stores the signals under "X" and the labels under "y"
X_train, X_val, X_test = train["X"], val["X"], test["X"]
y_train, y_val, y_test = train["y"], val["y"], test["y"]

# Report the shape and the number of positive labels of every split
for split_label, split_X, split_y in (
    ("Train :", X_train, y_train),
    ("Val   :", X_val, y_val),
    ("Test  :", X_test, y_test),
):
    print(split_label, split_X.shape, "Positives:", split_y.sum())

Train : (17880, 2920, 12) Positives: 4470
Val   : (27873, 2920, 12) Positives: 559
Test  : (23430, 2920, 12) Positives: 1631


## Modeling

### Build ConvNeXt v2 model

In [4]:
class LayerScale(layers.Layer):
    """Learnable per-feature scaling of the last input axis.

    ``build`` creates a trainable vector with one entry per feature of the
    last axis, every entry starting at ``init_value``; ``call`` multiplies the
    input by that vector.

    Args:
        init_value: Initial value of every entry of the scale vector.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, init_value=1e-6, **kwargs):
        super().__init__(**kwargs)
        # Kept on the instance so that get_config() can report it.
        self.init_value = init_value

    def build(self, input_shape):
        """Create the scale vector sized after the last input dimension."""
        n_features = input_shape[-1]
        constant_init = tf.keras.initializers.Constant(self.init_value)
        self.scale = self.add_weight(
            name="scale",
            shape=(n_features,),
            initializer=constant_init,
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        """Return ``x`` multiplied by the learned scale (broadcast on the last axis)."""
        scaled = x * self.scale
        return scaled

    def get_config(self):
        """Return the base layer config extended with ``init_value``."""
        return {**super().get_config(), "init_value": self.init_value}


class DropPath(layers.Layer):
    """Stochastic depth: randomly zero the whole input of individual samples.

    When ``training`` is truthy and ``drop_rate`` is non-zero, one random
    keep/drop value is drawn per entry of the first axis and broadcast over
    the remaining axes. Kept samples are divided by the keep probability
    ``1 - drop_rate``. Otherwise the input is returned untouched.

    Args:
        drop_rate: Probability of dropping a sample's input.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, drop_rate=0.0, **kwargs):
        super().__init__(**kwargs)
        self.drop_rate = drop_rate

    def call(self, x, training=None):
        """Apply the per-sample mask during training; identity otherwise."""
        if training and self.drop_rate != 0.0:
            survive_prob = 1 - self.drop_rate
            # Mask shape is (batch, 1, ..., 1) so it broadcasts over x.
            n_trailing_axes = len(x.shape) - 1
            mask_shape = (tf.shape(x)[0],) + (1,) * n_trailing_axes
            noise = tf.random.uniform(mask_shape, 0, 1)
            # floor(p + U[0, 1)) equals 1 with probability p and 0 otherwise.
            keep_mask = tf.floor(survive_prob + noise)
            return (x / survive_prob) * keep_mask
        return x

    def get_config(self):
        """Return the base layer config extended with ``drop_rate``."""
        return {**super().get_config(), "drop_rate": self.drop_rate}


def convnext_block_1d(x, dim, drop_path_rate=0.0):
    """Residual ConvNeXt-style block for 1D signals.

    The residual branch applies, in order: a depthwise convolution (kernel 7),
    layer normalization, a pointwise expansion to ``4 * dim`` channels, GELU,
    a pointwise projection back to ``dim`` channels, ``LayerScale`` and, when
    ``drop_path_rate`` is positive, ``DropPath``. The branch output is added
    to the block input.

    NOTE(review): the ConvNeXt V2 paper linked at the top of the notebook
    appears to use a GRN layer instead of LayerScale; confirm whether the
    V1-style block here is intentional.

    Args:
        x: Input tensor; its channel count must equal ``dim`` for the final
            addition to be valid.
        dim: Number of channels produced by the block.
        drop_path_rate: Drop probability handed to ``DropPath``.

    Returns:
        The sum of the input and the residual branch output.
    """
    # Layers are constructed in the order they are applied.
    residual_branch = [
        layers.DepthwiseConv1D(
            kernel_size=7, padding="same", depth_multiplier=1, use_bias=False
        ),
        layers.LayerNormalization(epsilon=1e-6),
        layers.Conv1D(4 * dim, kernel_size=1),
        layers.Activation("gelu"),
        layers.Conv1D(dim, kernel_size=1),
        LayerScale(init_value=1e-6),
    ]
    if drop_path_rate > 0:
        residual_branch.append(DropPath(drop_rate=drop_path_rate))

    branch_out = x
    for layer in residual_branch:
        branch_out = layer(branch_out)

    return layers.Add()([x, branch_out])


def build_convnext_v2_model(
    seq_len=2920, n_ch=12, epochs=300, batch_size=128, n_train_samples=None
):
    """Build and compile the 1D ConvNeXt-style binary classifier for ECG signals.

    Architecture: a strided convolutional stem, four stages of residual blocks
    (3, 3, 9 and 3 blocks with 96, 192, 384 and 768 channels), a strided
    downsampling convolution between stages, and a pooled sigmoid head. The
    drop-path rate grows linearly from 0 to 0.4 over the blocks.

    The optimizer is AdamW driven by a cosine-decay schedule. The schedule
    length is derived from ``epochs``, ``batch_size`` and ``n_train_samples``
    so that it spans the whole training run. The previous hard-coded batch
    size of 32 did not match the batch size of 128 used for training, which
    made the schedule roughly four times too long, so the learning rate never
    decayed to its floor.

    Args:
        seq_len: Number of time steps per input signal.
        n_ch: Number of input channels.
        epochs: Planned number of training epochs; the default mirrors the
            ``EPOCHS`` constant of the training cell.
        batch_size: Training batch size; the default mirrors the ``BATCH``
            constant of the training cell.
        n_train_samples: Number of training samples. When ``None`` the length
            of the notebook-level ``X_train`` array is used.

    Returns:
        The compiled Keras model. The model summary is printed as a side effect.

    Raises:
        ValueError: If ``epochs`` or ``batch_size`` is not positive.
    """
    if epochs <= 0 or batch_size <= 0:
        raise ValueError("epochs and batch_size must be positive")
    if n_train_samples is None:
        n_train_samples = len(X_train)

    # Input and stem
    inputs = layers.Input(shape=(seq_len, n_ch))
    x = layers.Conv1D(96, kernel_size=4, strides=4, padding="same")(inputs)
    x = layers.LayerNormalization(epsilon=1e-6)(x)

    # Four stages with different channel dimensions
    channels = [96, 192, 384, 768]
    blocks_per_stage = [3, 3, 9, 3]
    drop_path_rates = np.linspace(0, 0.4, sum(blocks_per_stage))

    block_idx = 0
    for stage, (dim, num_blocks) in enumerate(zip(channels, blocks_per_stage)):
        # Apply blocks for current stage
        for _ in range(num_blocks):
            x = convnext_block_1d(x, dim, drop_path_rates[block_idx])
            block_idx += 1

        # Downsample between stages (except last)
        if stage < len(channels) - 1:
            x = layers.LayerNormalization(epsilon=1e-6)(x)
            x = layers.Conv1D(
                channels[stage + 1], kernel_size=2, strides=2, padding="same"
            )(x)

    # Classification head
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.Dropout(0.1)(x)
    x = layers.Dense(1, activation="sigmoid")(x)

    # Create and compile model
    model = models.Model(inputs, x, name="convnext_v2_ecg")

    # Cosine decay learning rate.
    # One optimizer step is taken per batch and the last, partial batch of an
    # epoch still counts as a step, hence the ceiling division.
    steps_per_epoch = -(-n_train_samples // batch_size)
    total_steps = max(1, epochs * steps_per_epoch)
    cosine_decay = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=4e-3, decay_steps=total_steps, alpha=1e-6
    )

    model.compile(
        optimizer=tf.keras.optimizers.AdamW(
            learning_rate=cosine_decay, weight_decay=0.05
        ),
        loss="binary_crossentropy",
        metrics=[
            tf.keras.metrics.AUC(name="auroc", curve="ROC"),
            tf.keras.metrics.AUC(name="auprc", curve="PR"),
        ],
    )

    model.summary(line_length=80)
    return model


# Build and compile the network with the default input shape; the builder also
# prints the model summary.
model = build_convnext_v2_model()

In [None]:
import os

# Checkpoints are written below this folder; create it up front.
os.makedirs("../models", exist_ok=True)

# Convert data to float32 for training.
# copy=False returns the existing array when it is already float32, so the
# multi-gigabyte splits are not duplicated in memory needlessly.
X_train_tf = X_train.astype("float32", copy=False)
X_val_tf = X_val.astype("float32", copy=False)
X_test_tf = X_test.astype("float32", copy=False)

# Training parameters
# NOTE: the learning-rate schedule length is fixed inside
# build_convnext_v2_model(); these values should match what it assumes.
EPOCHS = 300
BATCH = 128

# Callbacks
# Stop once validation AUROC has not improved for 50 epochs and roll the model
# back to the best weights seen.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_auroc", mode="max", patience=50, restore_best_weights=True
)
# Persist only the weights of the best epoch (by validation AUROC).
best_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "../models/convnext_v2_best.weights.h5",
    monitor="val_auroc",
    mode="max",
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)
callbacks = [early_stopping, best_checkpoint, TqdmCallback(verbose=1)]

# Train the model (Keras' own progress output is silenced; tqdm reports instead)
history = model.fit(
    X_train_tf,
    y_train,
    validation_data=(X_val_tf, y_val),
    epochs=EPOCHS,
    batch_size=BATCH,
    callbacks=callbacks,
    verbose=0,
)

0epoch [00:00, ?epoch/s]

0batch [00:00, ?batch/s]

2025-07-27 21:58:19.024554: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


### Evaluation

In [None]:
def keras_report(model, name, X_split, y_split, plot_cm=True):
    """
    Print classification metrics for a Keras binary classifier (sigmoid output).

    Probabilities are thresholded at 0.5 to obtain hard predictions. Accuracy,
    AUROC, AUPRC, precision, recall, specificity and F1 are printed, and the
    confusion matrix is optionally plotted.

    Args:
        model: Fitted Keras model whose ``predict`` returns one probability
            per sample.
        name (str): Name of the dataset split (e.g., 'Train', 'Validation', 'External test').
        X_split (np.ndarray): Feature matrix for the split.
        y_split (np.ndarray): True binary labels (0/1) for the split.
        plot_cm (bool): Whether to plot the confusion matrix.
    Returns:
        None. The metrics are printed and the confusion matrix is shown.
    """
    y_true = np.asarray(y_split).reshape(-1)
    # reshape(-1) rather than squeeze(): squeeze() would collapse a
    # single-sample prediction to a 0-d array and break the metrics below.
    y_prob = np.asarray(model.predict(X_split, verbose=0)).reshape(-1)
    y_pred = (y_prob >= 0.5).astype(int)

    acc = accuracy_score(y_true, y_pred)
    # AUROC is undefined when only one class is present; report nan instead of
    # letting the whole report fail.
    if np.unique(y_true).size < 2:
        auroc = float("nan")
    else:
        auroc = roc_auc_score(y_true, y_prob)
    auprc = average_precision_score(y_true, y_prob)
    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )
    # Fixing the label set keeps the matrix 2x2 even if a class is missing
    # from both the labels and the predictions, so the unpacking is safe.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    print(f"{name} metrics")
    print(f"  accuracy     {acc:.3f}")
    print(f"  AUROC        {auroc:.3f}")
    print(f"  AUPRC        {auprc:.3f}")
    print(f"  precision    {prec:.3f}")
    print(f"  recall       {rec:.3f}")
    print(f"  specificity  {specificity:.3f}")
    print(f"  F1           {f1:.3f}\n")

    if plot_cm:
        # Reuse the matrix computed above instead of recomputing it.
        ConfusionMatrixDisplay(
            cm,
            display_labels=["Neg", "Pos"],
        ).plot(cmap="Blues")
        plt.title(f"{name} confusion matrix")
        plt.show()

In [None]:
# Report metrics for every split, in the order train -> validation -> external test
for split_name, split_X, split_y in (
    ("Train", X_train_tf, y_train),
    ("Validation", X_val_tf, y_val),
    ("External test", X_test_tf, y_test),
):
    keras_report(model, split_name, split_X, split_y)

In [None]:
# Per-epoch AUROC on the training and validation data
fig, ax = plt.subplots(figsize=(6, 4))
for history_key, curve_label in (("auroc", "train AUROC"), ("val_auroc", "val AUROC")):
    ax.plot(history.history[history_key], label=curve_label)
ax.set_xlabel("epoch")
ax.set_ylabel("AUROC")
ax.legend()
ax.grid(True)
ax.set_title("Training progress")
plt.show()

In [None]:
# Save the complete model
# NOTE(review): the ".h5" suffix selects Keras' legacy HDF5 format, and the
# custom LayerScale/DropPath layers will presumably have to be supplied via
# custom_objects when the file is loaded -- confirm before relying on it.
model.save("../models/convnext_v2_complete.h5")
print("Model saved successfully!")

# Model parameter count
total_params = model.count_params()
# Count the entries of every trainable weight from its shape; this avoids the
# legacy tf.keras.backend.count_params helper.
trainable_params = sum(int(np.prod(w.shape)) for w in model.trainable_weights)
print("\nModel Parameters:")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Plot both AUROC and AUPRC training curves, one panel per metric
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
for ax, metric_key, metric_label in (
    (ax1, "auroc", "AUROC"),
    (ax2, "auprc", "AUPRC"),
):
    ax.plot(history.history[metric_key], label=f"train {metric_label}", linewidth=2)
    ax.plot(
        history.history[f"val_{metric_key}"], label=f"val {metric_label}", linewidth=2
    )
    ax.set_xlabel("Epoch")
    ax.set_ylabel(metric_label)
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_title(f"{metric_label} Training Progress")

plt.tight_layout()
plt.show()

# Print best epoch info (best = highest validation AUROC, the metric monitored
# by the training callbacks); both values are read at that same epoch.
val_auroc_history = history.history["val_auroc"]
best_epoch = int(np.argmax(val_auroc_history))
best_val_auroc = val_auroc_history[best_epoch]
best_val_auprc = history.history["val_auprc"][best_epoch]

print("\nBest Performance:")
print(f"Epoch: {best_epoch + 1}")
print(f"Validation AUROC: {best_val_auroc:.4f}")
print(f"Validation AUPRC: {best_val_auprc:.4f}")