
ST-SAE: Spatio-Temporal Sparse Autoencoder for PPG Foundation Models


PyTorch implementation of the ST-SAE pipeline described in "A Spatio-Temporal Sparse Autoencoder for Transferable Clinical Phenotyping and Vascular Aging" (MLHC submission).

The pipeline learns a sparse, temporally-coherent, perturbation-stable concept basis on top of frozen PaPaGei PPG embeddings, then uses those concepts to predict chronological age and compute Δage (vascular age acceleration) for downstream cardiovascular risk analysis.


What's in here

configs/        default.yaml — all hyperparameters, paths, ablations, K-sweep
src/
  data/         MIMIC-III / MIMIC-IV / PulseDB loaders, windowing, QC, augmentation
  models/       PaPaGei wrapper, ST-SAE (+ Top-K variant), PCA/ICA/NMF baselines
  train/        Loss components + training loop with decoder renormalization
  eval/         EV, L0, dead-feature, consistency, lag-1 autocorr,
                seed-stability, probe recovery
  aging/        Age prediction + bias-corrected Δage
  grounding/    Pulse alignment + classical PPG morphology + concept viz
  utils/        Patient-level splits + seeding
scripts/
  00_setup_papagei.sh          — clone + download PaPaGei
  01_extract_embeddings.py     — raw PPG → cached PaPaGei embeddings
  02_train_stsae.py            — train ST-SAE (full / ablations / K-sweep)
  03_run_baselines.py          — fit PCA / ICA / NMF
  04_eval_all.py               — full eval suite, leaderboard JSON
  05_aging_analysis.py         — Phase II: age prediction + Δage
  06_waveform_grounding.py     — top concept → average pulse + morphology corr
  07_make_figures.py           — Pareto, ablation bars, Δage scatter
  run_pipeline.sh              — end-to-end orchestrator
tests/
  test_synthetic.py            — pipeline smoke test, no real data needed
DATA_GUIDE.md                  — how to obtain PaPaGei weights + PulseDB + MIMIC-IV
requirements.txt

Quickstart

0. Prerequisites

  • conda (Miniconda or Anaconda)
  • git, curl
  • For GPU: NVIDIA driver ≥ 530 (CUDA 12.1 wheels). Older driver? Edit setup.sh to use cu118 wheels.
  • ~5 GB disk for the env + PaPaGei weights (datasets are separate, see DATA_GUIDE.md)

1. One-shot install

git clone https://github.com/<you>/<repo>.git stsae && cd stsae
bash setup.sh                  # GPU setup with CUDA 12.1
# or
bash setup.sh --cpu            # laptop / no GPU

This will:

  • Create a conda env named stsae (Python 3.10)
  • Install PyTorch + all dependencies
  • Clone PaPaGei + download papagei_s.pt weights from Zenodo (~70 MB)
  • Run the synthetic smoke test to verify everything works

If anything fails, the script will say so loudly. You should see [ALL TESTS PASSED] at the end.

2. Read DATA_GUIDE.md

Datasets are the bottleneck — PhysioNet credentialing takes 1–2 weeks. Start with PulseDB (no credentialing) while you wait for MIMIC.

3. Edit configs/default.yaml

Set the data paths:

paths:
  papagei_repo:    /abs/path/to/papagei-foundation-model
  papagei_weights: /abs/path/to/papagei-foundation-model/weights/papagei_s.pt
  pulsedb_root:    /abs/path/to/data/PulseDB                    # or
  mimic_iii_root:  /abs/path/to/mimic3wdb-matched/1.0           # or
  mimic_iv_root:   /abs/path/to/mimic4wdb/0.1.0
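Before launching the pipeline it is worth a quick sanity check that the config parses and the paths resolve. A minimal sketch, assuming a standard PyYAML load of the snippet above (the YAML is inlined here so the example runs stand-alone; in practice you would read configs/default.yaml):

```python
import pathlib
import yaml  # PyYAML

# Inlined stand-in for configs/default.yaml with the same structure.
cfg = yaml.safe_load("""
paths:
  papagei_weights: /abs/path/to/papagei-foundation-model/weights/papagei_s.pt
  pulsedb_root:    /abs/path/to/data/PulseDB
""")

# Warn about any path that does not exist yet (dataset roots may be
# filled in later, e.g. while waiting for PhysioNet credentialing).
for key, value in cfg["paths"].items():
    p = pathlib.Path(value)
    if not p.exists():
        print(f"[warn] {key} does not exist yet: {p}")
```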

4. Run the pipeline

conda activate stsae
# Edit scripts/run_pipeline.sh: set SOURCE=mimic_iii / mimic_iv / pulsedb
bash scripts/run_pipeline.sh

Or run stages individually — each script has --help.


Reproducing the design choices

Concept-note element                         Code location
L_rec, L_sp, L_div, L_stab, L_smooth         src/train/losses.py
Decoder unit-norm renormalization            STSAE.renormalize_decoder() (after every step)
Triplet (z_t, z_{t+1}, z_aug_t) batching     src/data/dataset.py:WindowTripletDataset
Waveform-level augmentations (re-encoded)    src/data/augmentations.py + script 01
K-scaling sweep {1024, 2048, 4096, 8192}     configs/default.yaml: d_dict_sweep
Five ablations                               configs/default.yaml: ablations
Patient-level age-stratified split           src/utils/splits.py
Phase I clinical phenotypes                  src/eval/probe.py:classification_probe (call with HTN/T2D labels)
Phase II Δage with bias correction           src/aging/age_pred.py:aging_pipeline
Concept grounding (high vs low avg pulse)    src/grounding/concept_viz.py:ground_concept
Morphology correlation table                 correlate_concepts_with_morphology
Pareto frontier reporting                    src/eval/reconstruction.py:pareto_frontier
Seed/subsample stability                     src/eval/stability.py:seed_stability (Hungarian matching)
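For reference, decoder unit-norm renormalization amounts to rescaling every dictionary atom to unit L2 norm after each optimizer step. A NumPy sketch of the idea, assuming atoms are stored as rows of the decoder matrix (the repo's STSAE.renormalize_decoder() presumably does the equivalent in-place on the torch parameter):

```python
import numpy as np

def renormalize_decoder(W_dec: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Rescale each decoder atom (row of a K x d matrix) to unit L2 norm.

    Keeping atoms on the unit sphere prevents the sparsity penalty from
    being gamed by shrinking codes while inflating decoder weights.
    """
    norms = np.linalg.norm(W_dec, axis=1, keepdims=True)
    return W_dec / np.maximum(norms, eps)

# K = 8192 atoms over the 512-d PaPaGei embedding space.
W = renormalize_decoder(np.random.randn(8192, 512))
```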

Scaling notes

  • Embedding extraction (script 01) is the bottleneck. PaPaGei is ~5.7M parameters; on an A100, expect ~30k windows/min. For ~5000 PulseDB subjects × ~250 windows = ~1.25M windows ≈ 40 minutes. Cache once, reuse.
  • ST-SAE training (script 02) is fast — z is only 512-d, K up to 8192, so a single epoch over 1M (z, z_next, z_aug) triplets is ~2 minutes on A100.
  • Eval (script 04) is mostly probe fitting; runs in minutes.
  • Total walltime end-to-end: ~3 hours on a single A100 once the data is on local disk, dominated by extraction and the K-sweep.

Memory

Embeddings are stored as float32 memory-mapped arrays:

  • z.npy: N × 512 × 4 = 2.0 GB per million windows
  • z_aug.npy: same
  • windows.npy: N × 1250 × 4 = 5.0 GB per million windows

Plan disk accordingly.
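The figures above are just float32 arithmetic (4 bytes per element); a quick sketch for budgeting disk at an arbitrary window count (the function name is illustrative, not part of the repo):

```python
def cache_gb(n_windows: int, dim: int, itemsize: int = 4) -> float:
    """Disk footprint of an N x dim float32 array, in GB (10^9 bytes)."""
    return n_windows * dim * itemsize / 1e9

n = 1_250_000  # e.g. ~5000 PulseDB subjects x ~250 windows
print(f"z.npy       {cache_gb(n, 512):.2f} GB")   # PaPaGei embeddings
print(f"z_aug.npy   {cache_gb(n, 512):.2f} GB")   # augmented embeddings
print(f"windows.npy {cache_gb(n, 1250):.2f} GB")  # raw 1250-sample windows
```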

Known caveats / things to verify

  • papagei_wrapper.preprocess_window uses a SciPy Butterworth filter instead of pyPPG's Chebyshev II. For exact upstream parity, install pyPPG and call preprocessing.ppg.preprocess_one_ppg_signal from the cloned repo. The SciPy version is close enough for downstream training but may shift EV by 1–2% versus published PaPaGei numbers.
  • PulseDB segments are pre-filtered, so QC will pass essentially everything. This is fine — but for MIMIC-IV raw, the QC heuristics in src/data/windowing.py:signal_quality_mask are not exhaustive. Consider pyPPG's signal-quality index for a more rigorous filter on noisy ICU data.
  • aging_pipeline uses ElasticNetCV for concept selection. Reviewers might prefer a stability-selection wrapper (e.g. randomized lasso); easy to plug in by replacing fit_age_predictor.
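This README does not spell out the bias correction itself. One standard recipe from the brain-age literature (shown here as an assumption about what aging_pipeline does; verify against src/aging/age_pred.py) is to regress predicted age on chronological age in a held-out set and take the residual as Δage:

```python
import numpy as np

def bias_corrected_delta_age(age_true, age_pred):
    """Δage as the residual of a linear fit of predicted on true age.

    Uncorrected age predictors overestimate young subjects and
    underestimate old ones; regressing out that slope removes the
    spurious correlation between Δage and chronological age.
    Common recipe, not necessarily the exact one in the repo.
    """
    slope, intercept = np.polyfit(age_true, age_pred, deg=1)
    expected = slope * np.asarray(age_true) + intercept
    return np.asarray(age_pred) - expected

# Toy example: predictions with a systematic compressive bias.
rng = np.random.default_rng(0)
age = rng.uniform(30, 80, size=500)
pred = 0.7 * age + 18 + rng.normal(0, 3, size=500)
delta = bias_corrected_delta_age(age, pred)
```

By construction (OLS residuals are orthogonal to the regressor), the corrected Δage is uncorrelated with chronological age, which is the property downstream risk analyses rely on.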

License

Code in src/, scripts/, configs/ is your code (set whichever license you want). PaPaGei is BSD-3-Clause (see its repo). Datasets carry their own DUAs — observe them.
