
# HiMAE PVC Detection 

This notebook wires up the **HiMAE** backbone and a simple **linear probe** for PVC detection, using the code structure from the provided repo. It is designed to run end-to-end once you add your checkpoint paths and dataset files. 

Attached is a folder called `pvc/`which contains most of the source code to run the pipeline


In [13]:

# Optional: install dependencies if your environment doesn't have them
# You can safely skip or adapt this cell.
# %pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# %pip install numpy pandas scikit-learn scipy matplotlib h5py pytorch-lightning tabulate


In [14]:
import os, sys, math, json, time, pathlib, warnings
from dataclasses import dataclass

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

# Optional metrics and utilities
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, accuracy_score, roc_curve
from scipy.signal import resample as sp_resample

import h5py

warnings.filterwarnings("ignore")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [15]:

# Repo import path. Point this to the 'pvc' package root where utils/ and downstream_eval/ live.
# If you're running next to this notebook, set it to a relative path.
REPO_DIR = r"./pvc"  # change this to your local clone if needed

if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

# Imports from the repo
from utils.model_arch.himae import HiMAE  # backbone
# You can optionally use helpers if you like:
# from downstream_eval.helpers import summarize_dataset
print("Repo imported from:", REPO_DIR)


Repo imported from: ./pvc


In [16]:

# Paths & configuration — EDIT THESE
#
# HDF5 is expected to contain datasets like: 'ppg' (or 'ecg'), 'labels', and optionally 'patient_ids'.
# The PVC data in the repo's task_definition uses fs=25 Hz and 10-second windows (L = 250).
#
# If you have CSV metadata, set META_PATH as well; otherwise you can set it to None.

H5_PATH   = "./pvc_10s_synth.h5"          # TODO: point to your local H5 file
META_PATH = "./pvc_10s_synth_metadata.csv" # optional; set to None if not used
SIGNAL_KEY = "ppg"                               # or "ecg" depending on your data
LABEL_KEY  = "labels"                            # binary PVC label per segment
PID_KEY    = "patient_ids"                       # optional; used for summaries/splits if present

# HiMAE configuration. Adjust seg_len/sampling_freq to match your dataset.
CFG = dict(
    source=SIGNAL_KEY,          # 'ppg', 'ecg', or 'ppg+ecg'
    sampling_freq=25,           # Hz
    seg_len=10,                 # seconds
    model_params=dict(
        patch_len=50,           # must divide sampling_freq * seg_len; 25/50/125 are common here
    ),
)

# Checkpoint paths — provide one if you have a trained backbone and/or trained probe
BACKBONE_CKPT = "./himae_synth.ckpt"   # can be a .pt/.pth/.ckpt; leave as None if not available
LINEAR_PROBE_CKPT = None  # leave as None if you plan to train it here

BATCH_SIZE = 128
NUM_WORKERS = 1
VAL_SPLIT = 0.2   # random split fraction for validation if you don't have predefined splits
SEED = 7


In [17]:

# Build the HiMAE backbone
torch.manual_seed(SEED)
np.random.seed(SEED)

backbone = HiMAE(CFG).to(device)
backbone.eval()

# Flexible checkpoint loader that works with plain state_dict or PL checkpoints
def load_backbone_weights(model: nn.Module, ckpt_path: str):
    if ckpt_path is None or not os.path.exists(ckpt_path):
        print("No backbone checkpoint provided or path does not exist; skipping weight load.")
        return

    state = torch.load(ckpt_path, map_location="cpu")
    # Handle common checkpoint formats
    if isinstance(state, dict) and "state_dict" in state:
        sd = state["state_dict"]
        # Strip possible prefixes like "model.", "backbone.", etc.
        new_sd = {}
        for k, v in sd.items():
            nk = k
            for pref in ["model.", "backbone.", "net.", "module."]:
                if nk.startswith(pref):
                    nk = nk[len(pref):]
            new_sd[nk] = v
        missing, unexpected = model.load_state_dict(new_sd, strict=False)
    elif isinstance(state, dict):
        missing, unexpected = model.load_state_dict(state, strict=False)
    else:
        raise ValueError("Unrecognized checkpoint format")

    print("Backbone weights loaded. Missing keys:", len(missing), "| Unexpected keys:", len(unexpected))

load_backbone_weights(backbone, BACKBONE_CKPT)


Backbone weights loaded. Missing keys: 0 | Unexpected keys: 0


In [18]:

# Dataset for PVC segments stored in HDF5
class PVCH5Dataset(Dataset):
    def __init__(self, h5_path, signal_key="ppg", label_key="labels", pid_key=None,
                 target_fs=25, seg_len=10, source="ppg"):
        self.h5_path = h5_path
        self.signal_key = signal_key
        self.label_key = label_key
        self.pid_key = pid_key
        self.target_fs = target_fs
        self.seg_len = seg_len
        self.source = source

        if not os.path.exists(h5_path):
            raise FileNotFoundError(f"H5 not found at {h5_path}")

        self.h5 = h5py.File(h5_path, "r")
        self.signals = self.h5[self.signal_key][...]   # shape (N, L) or (N, 1, L)
        self.labels  = self.h5[self.label_key][...].astype(np.int64).reshape(-1)

        if self.pid_key is not None and self.pid_key in self.h5:
            self.pids = self.h5[self.pid_key][...]
        else:
            self.pids = np.arange(len(self.labels))

        # Normalize signals to (N, L)
        if self.signals.ndim == 3 and self.signals.shape[1] == 1:
            self.signals = self.signals[:, 0, :]
        elif self.signals.ndim != 2:
            raise ValueError(f"Expected signals of shape (N, L) or (N,1,L); got {self.signals.shape}")

        self.L = self.signals.shape[1]
        self.expected_L = int(self.target_fs * self.seg_len)
        print(f"Loaded H5 with {len(self)} segments. Input length={self.L}, expected={self.expected_L}")

    def __len__(self):
        return self.labels.shape[0]

    def _ensure_length(self, x):
        if x.shape[-1] == self.expected_L:
            return x
        # Simple resample if lengths mismatch
        x_res = sp_resample(x, self.expected_L)
        return x_res

    def __getitem__(self, idx):
        x = self.signals[idx].astype(np.float32)
        y = int(self.labels[idx])
        pid = self.pids[idx]
        x = self._ensure_length(x)
        # HiMAE expects (B, 1, L). We'll add channel dimension in collate
        return x, y, pid

    def close(self):
        try:
            self.h5.close()
        except Exception:
            pass

def collate_batch(batch):
    xs, ys, pids = zip(*batch)
    xs = torch.from_numpy(np.stack(xs, axis=0)).float().unsqueeze(1)  # (B,1,L)
    ys = torch.tensor(ys).long()
    return xs, ys, pids


In [19]:

# Feature extraction from the HiMAE encoder (bottleneck representation)
@torch.no_grad()
def encode_bottleneck(backbone: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # x: (B,1,L) float tensor
    # returns: (B, C, Tprime) where C is final encoder channels (256 in default config)
    backbone.eval()
    current_x = x
    for enc in backbone.encoder_layers:
        current_x = enc(current_x)
    return current_x  # bottleneck feature map

class LinearProbe(nn.Module):
    def __init__(self, in_dim=256, num_classes=1):
        super().__init__()
        self.fc = nn.Linear(in_dim, num_classes)

    def forward(self, feats: torch.Tensor):
        # feats: (B, C, T') -> global average pooling over time, then FC
        pooled = feats.mean(dim=-1)
        logits = self.fc(pooled).squeeze(-1)  # (B,)
        return logits


In [20]:

# Data module-ish helpers
def make_loaders(h5_path, cfg, batch_size=128, num_workers=4, val_split=0.2, signal_key="ppg"):
    ds = PVCH5Dataset(
        h5_path=h5_path,
        signal_key=signal_key,
        label_key=LABEL_KEY,
        pid_key=PID_KEY,
        target_fs=cfg["sampling_freq"],
        seg_len=cfg["seg_len"],
        source=cfg["source"],
    )
    N = len(ds)
    n_val = int(round(val_split * N))
    n_train = N - n_val
    g = torch.Generator().manual_seed(SEED)
    train_ds, val_ds = random_split(ds, [n_train, n_val], generator=g)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True, collate_fn=collate_batch)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True, collate_fn=collate_batch)
    return ds, train_loader, val_loader


In [21]:

# Training loop for the linear probe only (backbone frozen)
def train_linear_probe(backbone, probe, train_loader, val_loader, epochs=10, lr=1e-3, weight_decay=1e-4):
    backbone.eval()
    for p in backbone.parameters():
        p.requires_grad_(False)

    probe = probe.to(device)
    optim = torch.optim.AdamW(probe.parameters(), lr=lr, weight_decay=weight_decay)
    best_val_auc = -np.inf
    best_state = None

    for ep in range(1, epochs+1):
        probe.train()
        running_loss = 0.0
        for xb, yb, _ in train_loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True).float()
            feats = encode_bottleneck(backbone, xb)
            logits = probe(feats)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, yb)
            optim.zero_grad(set_to_none=True)
            loss.backward()
            optim.step()
            running_loss += loss.item() * xb.size(0)

        # Validation
        probe.eval()
        y_true, y_score = [], []
        with torch.no_grad():
            for xb, yb, _ in val_loader:
                xb = xb.to(device, non_blocking=True)
                yb = yb.to(device, non_blocking=True).float()
                feats = encode_bottleneck(backbone, xb)
                logits = probe(feats)
                y_true.append(yb.cpu().numpy())
                y_score.append(torch.sigmoid(logits).cpu().numpy())
        y_true = np.concatenate(y_true)
        y_score = np.concatenate(y_score)
        try:
            val_auc = roc_auc_score(y_true, y_score)
            val_ap  = average_precision_score(y_true, y_score)
        except ValueError:
            val_auc, val_ap = float("nan"), float("nan")

        print(f"Epoch {ep:02d} | train_loss={running_loss/len(train_loader.dataset):.4f} | "
              f"val_auc={val_auc:.4f} | val_ap={val_ap:.4f}")

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_state = {k: v.cpu() for k, v in probe.state_dict().items()}

    if best_state is not None:
        probe.load_state_dict(best_state)
    return probe, best_val_auc


In [22]:

# End-to-end wiring (disabled by default). Uncomment to run after editing paths.
ds, train_loader, val_loader = make_loaders(H5_PATH, CFG, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                                            val_split=VAL_SPLIT, signal_key=SIGNAL_KEY)
probe = LinearProbe(in_dim=256, num_classes=1)

probe, best_val_auc = train_linear_probe(backbone, probe, train_loader, val_loader, epochs=10, lr=1e-3)
print("Best val AUC:", best_val_auc)
torch.save(probe.state_dict(), "pvc_linear_probe.pt")
print("Saved probe to pvc_linear_probe.pt")


Loaded H5 with 55861 segments. Input length=250, expected=250
Epoch 01 | train_loss=183.0066 | val_auc=0.5735 | val_ap=0.0695
Epoch 02 | train_loss=41.4369 | val_auc=0.5714 | val_ap=0.0731
Epoch 03 | train_loss=41.6830 | val_auc=0.6352 | val_ap=0.1178
Epoch 04 | train_loss=33.4461 | val_auc=0.6954 | val_ap=0.1136
Epoch 05 | train_loss=37.2827 | val_auc=0.6582 | val_ap=0.1210
Epoch 06 | train_loss=39.2719 | val_auc=0.6746 | val_ap=0.1049
Epoch 07 | train_loss=31.3521 | val_auc=0.5214 | val_ap=0.0603
Epoch 08 | train_loss=62.3732 | val_auc=0.7663 | val_ap=0.1162
Epoch 09 | train_loss=31.3644 | val_auc=0.6081 | val_ap=0.1217
Epoch 10 | train_loss=25.5817 | val_auc=0.5083 | val_ap=0.0506
Best val AUC: 0.7662564648241595
Saved probe to pvc_linear_probe.pt


In [23]:

# Inference helper to get PVC probabilities for a dataloader
@torch.no_grad()
def infer_pvc_probs(backbone, probe, loader):
    backbone.eval()
    probe.eval()
    all_pids, all_probs, all_labels = [], [], []
    for xb, yb, pids in loader:
        xb = xb.to(device, non_blocking=True)
        feats = encode_bottleneck(backbone, xb)
        logits = probe(feats)
        probs = torch.sigmoid(logits).cpu().numpy()
        all_probs.append(probs)
        all_labels.append(np.asarray(yb))
        all_pids.extend(list(pids))
    return np.concatenate(all_probs), np.concatenate(all_labels), np.asarray(all_pids)


In [24]:

# Example: compute metrics and export predictions (disabled by default)
test_loader = val_loader  
probs, labels, pids = infer_pvc_probs(backbone, probe, test_loader)
auc = roc_auc_score(labels, probs)
ap  = average_precision_score(labels, probs)
print("Test AUC:", auc, "AP:", ap)
df = pd.DataFrame({"patient_id": pids, "label": labels, "p_pvc": probs})
df.to_csv("pvc_predictions.csv", index=False)
print("Wrote pvc_predictions.csv")


Test AUC: 0.7662564648241595 AP: 0.11621493155214786
Wrote pvc_predictions.csv
