PyTorch implementation of the ST-SAE pipeline described in "A Spatio-Temporal Sparse Autoencoder for Transferable Clinical Phenotyping and Vascular Aging" (MLHC submission).
The pipeline learns a sparse, temporally-coherent, perturbation-stable concept basis on top of frozen PaPaGei PPG embeddings, then uses those concepts to predict chronological age and compute Δage (vascular age acceleration) for downstream cardiovascular risk analysis.
```
configs/
  default.yaml             — all hyperparameters, paths, ablations, K-sweep
src/
  data/                    — MIMIC-IV / PulseDB loaders, windowing, QC, augmentation
  models/                  — PaPaGei wrapper, ST-SAE (+ Top-K variant), PCA/ICA/NMF baselines
  train/                   — loss components + training loop with decoder renormalization
  eval/                    — EV, L0, dead-feature, consistency, lag-1 autocorr,
                             seed stability, probe recovery
  aging/                   — age prediction + bias-corrected Δage
  grounding/               — pulse alignment + classical PPG morphology + concept viz
  utils/                   — patient-level splits + seeding
scripts/
  00_setup_papagei.sh      — clone + download PaPaGei
  01_extract_embeddings.py — raw PPG → cached PaPaGei embeddings
  02_train_stsae.py        — train ST-SAE (full / ablations / K-sweep)
  03_run_baselines.py      — fit PCA / ICA / NMF
  04_eval_all.py           — full eval suite, leaderboard JSON
  05_aging_analysis.py     — Phase II: age prediction + Δage
  06_waveform_grounding.py — top concept → average pulse + morphology corr
  07_make_figures.py       — Pareto, ablation bars, Δage scatter
  run_pipeline.sh          — end-to-end orchestrator
tests/
  test_synthetic.py        — pipeline smoke test, no real data needed
DATA_GUIDE.md              — how to obtain PaPaGei weights + PulseDB + MIMIC-IV
requirements.txt
```
- `conda` (Miniconda or Anaconda), `git`, `curl`
- For GPU: NVIDIA driver ≥ 530 (CUDA 12.1 wheels). Older driver? Edit `setup.sh` to use `cu118` wheels.
- ~5 GB disk for the env + PaPaGei weights (datasets are separate; see `DATA_GUIDE.md`)
```bash
git clone https://github.com/<you>/<repo>.git stsae && cd stsae
bash setup.sh         # GPU setup with CUDA 12.1
# or
bash setup.sh --cpu   # laptop / no GPU
```

This will:

- Create a conda env named `stsae` (Python 3.10)
- Install PyTorch + all dependencies
- Clone PaPaGei + download `papagei_s.pt` weights from Zenodo (~70 MB)
- Run the synthetic smoke test to verify everything works

If anything fails, the script will say so loudly; you should see `[ALL TESTS PASSED]` at the end.
Datasets are the bottleneck — PhysioNet credentialing takes 1–2 weeks. Start with PulseDB (no credentialing) while you wait for MIMIC.
Set the data paths:
```yaml
paths:
  papagei_repo: /abs/path/to/papagei-foundation-model
  papagei_weights: /abs/path/to/papagei-foundation-model/weights/papagei_s.pt
  pulsedb_root: /abs/path/to/data/PulseDB              # or
  mimic_iii_root: /abs/path/to/mimic3wdb-matched/1.0   # or
  mimic_iv_root: /abs/path/to/mimic4wdb/0.1.0
```

```bash
conda activate stsae
# Edit scripts/run_pipeline.sh: set SOURCE=mimic_iii / mimic_iv / pulsedb
bash scripts/run_pipeline.sh
```

Or run stages individually — each script has `--help`.
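The scripts presumably read these paths from `configs/default.yaml`. A minimal sketch of loading and sanity-checking the `paths:` section with PyYAML — the actual loader in `src/` may differ, and `load_paths` here is a hypothetical helper, not a repo function:

```python
import pathlib
import yaml  # PyYAML

def load_paths(config_file: str) -> dict:
    """Read the paths: section of the config and warn about entries
    that do not yet exist on disk (e.g. datasets still downloading)."""
    with open(config_file) as f:
        cfg = yaml.safe_load(f)
    paths = cfg["paths"]
    for key, value in paths.items():
        if value and not pathlib.Path(value).exists():
            print(f"warning: {key} -> {value} not found on disk")
    return paths
```

Catching a mistyped dataset root here is much cheaper than a crash midway through embedding extraction.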
| Concept-note element | Code location |
|---|---|
| L_rec, L_sp, L_div, L_stab, L_smooth | src/train/losses.py |
| Decoder unit-norm renormalization | STSAE.renormalize_decoder() (after every step) |
| Triplet (z_t, z_{t+1}, z_aug_t) batching | src/data/dataset.py:WindowTripletDataset |
| Waveform-level augmentations (re-encoded) | src/data/augmentations.py + script 01 |
| K-scaling sweep {1024,2048,4096,8192} | configs/default.yaml: d_dict_sweep |
| Five ablations | configs/default.yaml: ablations |
| Patient-level age-stratified split | src/utils/splits.py |
| Phase I clinical phenotypes | src/eval/probe.py:classification_probe (call with HTN/T2D labels) |
| Phase II Δage with bias correction | src/aging/age_pred.py:aging_pipeline |
| Concept grounding (high vs low avg pulse) | src/grounding/concept_viz.py:ground_concept |
| Morphology correlation table | correlate_concepts_with_morphology |
| Pareto frontier reporting | src/eval/reconstruction.py:pareto_frontier |
| Seed/subsample stability | src/eval/stability.py:seed_stability (Hungarian matching) |
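The bias correction behind the Phase II Δage row deserves a word: raw Δage (predicted minus chronological age) correlates with age itself through regression to the mean. One common fix, which `aging_pipeline` may or may not apply in exactly this form, residualizes the raw gap on chronological age — a NumPy sketch:

```python
import numpy as np

def bias_corrected_delta_age(age: np.ndarray, age_pred: np.ndarray) -> np.ndarray:
    """Subtract the cohort-level linear trend of the raw age gap on
    chronological age, so corrected Δage is uncorrelated with age."""
    raw = age_pred - age                      # raw vascular age acceleration
    slope, intercept = np.polyfit(age, raw, 1)  # trend fit on the cohort
    return raw - (slope * age + intercept)
```

In practice the trend should be fit on the training split only and applied to held-out subjects, to avoid leaking test-set age structure into the correction.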
- Embedding extraction (script 01) is the bottleneck. PaPaGei is ~5.7M parameters; on an A100, expect ~30k windows/min. For ~5000 PulseDB subjects × ~250 windows = ~1.25M windows ≈ 40 minutes. Cache once, reuse.
- ST-SAE training (script 02) is fast — z is only 512-d, K up to 8192, so a single epoch over 1M (z, z_next, z_aug) triplets is ~2 minutes on A100.
- Eval (script 04) is mostly probe fitting; runs in minutes.
- Total walltime end-to-end: ~3 hours on a single A100 once the data is on local disk, dominated by extraction and the K-sweep.
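The extraction estimate is easy to sanity-check; the numbers below are the document's own figures (30k windows/min on an A100 is the stated throughput, not a guarantee):

```python
subjects = 5000
windows_per_subject = 250
throughput = 30_000  # windows/min on an A100, per the estimate above

n_windows = subjects * windows_per_subject
extraction_minutes = n_windows / throughput
print(n_windows, round(extraction_minutes, 1))  # 1250000 41.7
```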
Embeddings are stored as float32 memory-mapped arrays:
- z.npy: N × 512 × 4 bytes ≈ 2.0 GB per million windows
- z_aug.npy: same
- windows.npy: N × 1250 × 4 bytes ≈ 5.0 GB per million windows

Plan disk accordingly.
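A quick way to budget disk for a given corpus size, using the per-window sizes above (float32 = 4 bytes; 512-d embeddings, 1250-sample windows):

```python
def cache_gb(n_windows: int, emb_dim: int = 512, win_samples: int = 1250) -> float:
    """Total cache size in GB: z.npy + z_aug.npy + windows.npy, all float32."""
    emb = n_windows * emb_dim * 4      # one embedding array
    win = n_windows * win_samples * 4  # raw windows
    return (2 * emb + win) / 1e9       # z + z_aug + windows

print(round(cache_gb(1_250_000), 1))  # 11.4 — GB for the PulseDB example above
```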
- `papagei_wrapper.preprocess_window` uses a scipy Butterworth filter instead of pyPPG's Chebyshev II. For exact upstream parity, install pyPPG and call `preprocessing.ppg.preprocess_one_ppg_signal` from the cloned repo. The scipy version is close enough for downstream training but may shift EV by 1–2% versus published PaPaGei numbers.
- PulseDB segments are pre-filtered, so QC will pass essentially everything. This is fine — but for raw MIMIC-IV, the QC heuristics in `src/data/windowing.py:signal_quality_mask` are not exhaustive. Consider pyPPG's signal-quality index for a more rigorous filter on noisy ICU data.
- `aging_pipeline` uses ElasticNetCV for concept selection. Reviewers might prefer a stability-selection wrapper (e.g. randomized lasso); this is easy to plug in by replacing `fit_age_predictor`.
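For reference, a zero-phase scipy Butterworth band-pass of the kind described. The cutoffs, order, and sampling rate here are typical PPG choices, not necessarily those used in `papagei_wrapper.preprocess_window`:

```python
import numpy as np
from scipy.signal import butter, filtfilt

def butterworth_bandpass(ppg: np.ndarray, fs: float = 125.0,
                         low: float = 0.5, high: float = 8.0,
                         order: int = 4) -> np.ndarray:
    """Zero-phase band-pass: removes baseline wander (< low Hz) and
    high-frequency noise (> high Hz) without phase distortion."""
    nyq = fs / 2.0
    b, a = butter(order, [low / nyq, high / nyq], btype="band")
    return filtfilt(b, a, ppg)  # forward-backward pass cancels phase shift
```

`filtfilt` doubles the effective filter order, so a 4th-order design here behaves like an 8th-order magnitude response — worth keeping in mind when matching pyPPG's Chebyshev II parameters.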
Code in src/, scripts/, configs/ is your code (set whichever license
you want). PaPaGei is BSD-3-Clause (see its repo). Datasets carry their
own DUAs — observe them.