Official implementation of "NeuroMamba: A State-Space Foundation Model for Functional MRI"
NeurIPS 2025 Workshop on Foundation Models for the Brain and Body (BrainBodyFM)
Authors: Jubin Choi¹, David Keetae Park², Junbeom Kwon³, Shinjae Yoo², Jiook Cha¹*
¹Seoul National University, ²Brookhaven National Laboratory, ³The University of Texas at Austin
*Corresponding author: connectome@snu.ac.kr
📄 Paper PDF · 🔗 OpenReview · ⬇ Pretrained weights (v1.0.0)
- Overview
- Architecture
- Installation
- Pretrained Checkpoints
- Pre-training
- Downstream: HCP Sex Classification
- Results
- Adding Your Own Dataset
- Repository Layout
- Citation
- Acknowledgments
- License
- Contact
NeuroMamba enables direct sequence modeling of 4D whole-brain fMRI. Prior fMRI foundation models either (a) reduce fMRI to region-of-interest (ROI) time series and lose fine-grained spatial detail, or (b) use rigid grid-based hierarchical models that waste compute on ~60% non-brain background. NeuroMamba breaks this trade-off with two key ideas:
- Adaptive background removal via patch-wise tokenization: non-brain tokens are discarded before processing, yielding a 46.5% reduction in FLOPs.
- NeRF-style frequency-based positional encoding of continuous (x, y, z, t) patch coordinates — robust to inter-subject anatomical variation.
The result: a Mamba2-based foundation model pre-trained autoregressively on >50,000 subjects (UK Biobank + ABCD + HCP), achieving state-of-the-art HCP sex classification (94.9% ACC, 98.9% AUC at 3.1M params).
Figure 2 of the paper. A 4D fMRI volume is divided into 6×6×6×2 patches, augmented with NeRF positional encoding, stripped of non-brain background tokens, and processed by a stack of 12 Mamba2 blocks trained with autoregressive next-token prediction (temporal-first raster order: t → z → y → x).
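As a rough illustration of the NeRF-style positional encoding mentioned above, the sketch below maps each normalized patch coordinate to sine/cosine features at geometrically spaced frequencies, following the standard NeRF formulation. The number of frequency bands (`num_freqs=4`), the normalization of coordinates to [-1, 1], and the function names are assumptions made for this example; the values and combination scheme used in `fmamba.py` may differ.

```python
import math

def frequency_encode(coord, num_freqs=4):
    """Encode one coordinate in [-1, 1] as [sin(2^k*pi*c), cos(2^k*pi*c)] for k < num_freqs."""
    feats = []
    for k in range(num_freqs):
        angle = (2.0 ** k) * math.pi * coord
        feats.extend([math.sin(angle), math.cos(angle)])
    return feats

def encode_patch(x, y, z, t, num_freqs=4):
    """Concatenate the encodings of the four normalized patch coordinates."""
    feats = []
    for c in (x, y, z, t):
        feats.extend(frequency_encode(c, num_freqs))
    return feats

# Hypothetical patch centre; the result has 4 coords * 2 * num_freqs features.
pe = encode_patch(0.0, 0.5, -0.5, 1.0)
print(len(pe))
```

Because the encoding depends only on continuous coordinates, it stays defined for whichever tokens remain after background removal, with no dense grid index required.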
Model variants (paper Appendix B):
| Variant | Parameters | `embed_dim` | `d_state` | `headdim` | Stable? |
|---|---|---|---|---|---|
| NeuroMamba-Small | 1.4M | 128 | 32 | 128 | ✅ |
| NeuroMamba-Base (paper main) | 3.1M | 192 | 48 | 192 | ✅ |
| NeuroMamba-Medium | 5.4M | 256 | 64 | 256 | ✅ |
| NeuroMamba-Large | 11.9M | 384 | — | — |
Common across all sizes: 12 Mamba2 blocks, ngroups=1, expand=2, patch shape 6×6×6×2, input shape 96×96×96×T.
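With a 96×96×96×20 input and 6×6×6×2 patches, the patch grid is 16×16×16×10, i.e. 40,960 tokens before background removal. The sketch below shows one way to patchify a volume and drop background tokens. It is only an illustration: the criterion used here (a patch is background when every voxel equals a background value), the row-major token order, and the function names are assumptions, not the repository's implementation. The demo uses a shorter time axis (T=4) to keep memory small.

```python
import numpy as np

def patchify(volume, patch=(6, 6, 6, 2)):
    """Split an (X, Y, Z, T) array into non-overlapping patches, one row per patch."""
    X, Y, Z, T = volume.shape
    px, py, pz, pt = patch
    v = volume.reshape(X // px, px, Y // py, py, Z // pz, pz, T // pt, pt)
    v = v.transpose(0, 2, 4, 6, 1, 3, 5, 7)  # patch-grid axes first, then within-patch axes
    return v.reshape(-1, px * py * pz * pt)

def brain_token_mask(tokens, background_value=0.0):
    """Keep a token if any voxel in its patch differs from the background value (assumed rule)."""
    return np.any(tokens != background_value, axis=1)

# Toy volume: a 48^3 "brain" cube aligned to the patch grid, T=4 frames.
vol = np.zeros((96, 96, 96, 4), dtype=np.float32)
vol[24:72, 24:72, 24:72, :] = 1.0

tokens = patchify(vol)
keep = brain_token_mask(tokens)
brain_tokens = tokens[keep]
print(tokens.shape, brain_tokens.shape)
```

In this toy case 1,024 of 8,192 tokens survive; on real data the kept fraction depends on head size and the preprocessing crop.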
⚠️ PyTorch Lightning version: This codebase targets the `pytorch_lightning 1.9.x` series (verified with `1.9.5`). PL 2.x removed `Trainer.add_argparse_args` and other 1.x APIs that `project/main.py` relies on, so using PL ≥ 2.0 will raise `AttributeError: type object 'Trainer' has no attribute 'add_argparse_args'` immediately at startup. `requirements.txt` pins this with `pytorch-lightning>=1.9.0,<2.0`. If your environment already has a newer PL, force the downgrade explicitly: `pip install --force-reinstall --no-deps "pytorch-lightning==1.9.5"`
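To catch a mismatched Lightning install before a job is queued, a small pre-flight check along these lines may help. It is not part of the repository; the helper names are hypothetical, and it treats only 1.9.x releases as supported, which is the series this README says was verified.

```python
from importlib.metadata import PackageNotFoundError, version

def is_supported_pl(version_string):
    """Return True for 1.9.x releases, the series this codebase targets."""
    return version_string.split(".")[:2] == ["1", "9"]

def check_installed_pl():
    """Raise a readable error early instead of failing inside project/main.py."""
    try:
        installed = version("pytorch-lightning")
    except PackageNotFoundError:
        raise RuntimeError("pytorch-lightning is not installed") from None
    if not is_supported_pl(installed):
        raise RuntimeError(f"pytorch-lightning {installed} found; expected 1.9.x")
    return installed

print(is_supported_pl("1.9.5"), is_supported_pl("2.1.0"))
```

Calling `check_installed_pl()` at the top of a launcher script would surface the problem with a clearer message than the `AttributeError` described above.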
# Clone
git clone https://github.com/Transconnectome/NeuroMamba.git
cd NeuroMamba
# Create venv (Python 3.10+)
python -m venv neuromamba_env
source neuromamba_env/bin/activate
# 1. Install PyTorch FIRST (platform-specific).
# - x86 + NVIDIA (CUDA 12.x):
pip install torch torchvision torchaudio
# - ARM64 / Blackwell (e.g. GB10, sm_121): use the matching sbsa CUDA wheel index, e.g.
# pip install --index-url https://download.pytorch.org/whl/cu128 torch torchvision torchaudio
# - Or skip this step entirely if you use an NGC PyTorch container.
# 2. Install the remaining dependencies
pip install -r requirements.txt
# 3. Install the Mamba CUDA kernels LAST, WITHOUT build isolation.
# (mamba-ssm / causal-conv1d setup.py does `import torch`, which fails
# under pip's build isolation — so they are deliberately NOT in
# requirements.txt. Requires a CUDA-compatible GPU + nvcc.)
pip install mamba-ssm causal-conv1d --no-build-isolation
Or use the bundled environment installer (Linux/HPC), which handles the build-isolation ordering automatically:
export NEUROMAMBA_VENV_PATH="$HOME/neuromamba_env" # destination
bash setup_env.sh
Required env vars when running:
export PYTHONPATH=$PWD:$PYTHONPATH
export OMP_NUM_THREADS=1
export TORCH_EXTENSIONS_DIR=$PWD/deepspeed
export TRITON_CACHE_DIR="$PWD/.triton"
Before launching a multi-node run, verify the install on one GPU (A5000 / A6000 / RTX 3090 / 4090, ≥24 GB VRAM, bf16-capable):
sbatch sample_scripts/lab_server_smoke_test/slurm_a5000_dummy_smoke.sh
See sample_scripts/lab_server_smoke_test/README.md for both a dummy-data smoke test (~10 min) and a real HCP fine-tune smoke test (~1–2 h).
Three pretrained backbones are released as v1.0.0 GitHub assets:
| Asset | `embed_dim` | `d_state` | `headdim` | Size |
|---|---|---|---|---|
| `1M_E128.pt` | 128 | 32 | 128 | 25 MB |
| `3M_E192.pt` (paper main) | 192 | 48 | 192 | 42 MB |
| `5M_E256.pt` | 256 | 64 | 256 | 61 MB |
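The checkpoints were saved through a DeepSpeed/DDP wrapper, so their keys carry a `_forward_module.model.` prefix (noted in the loading snippet that follows). If you want to load the weights into a bare model instance yourself, the prefix has to be removed first. The helper below is a minimal sketch of that step; the function name and the toy key names are hypothetical, and `project/module/models/load_model.py` may already handle this for you.

```python
def strip_prefix(state_dict, prefix="_forward_module.model."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

# Toy example with placeholder values instead of tensors; key names are made up.
wrapped = {
    "_forward_module.model.layers.0.weight": 1,
    "_forward_module.model.norm.bias": 2,
}
print(strip_prefix(wrapped))
```

The cleaned dictionary could then be passed to the model's `load_state_dict`, keeping in mind the `conv_xBC` / `conv1d` naming difference between `mamba_ssm` versions.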
import torch
ckpt = torch.load("3M_E192.pt", map_location="cpu", weights_only=False)
# Flat top-level state_dict (113 keys), prefix "_forward_module.model.<...>" (DeepSpeed/DDP).
# Architecture: project/module/models/fmamba.py:FMamba
# Note: depthwise conv is named `conv_xBC` in mamba_ssm < 2.x, `conv1d` in ≥ 2.x.
NeuroMamba was pre-trained on resting-state fMRI from:
- UK Biobank (N=40,647)
- ABCD (N=9,139)
- HCP Young Adult (N=1,084)
All data were preprocessed with standard pipelines (bias field correction → skull stripping → MNI alignment → crop/pad to 96×96×96) and stored as per-TR HDF5 files. Per-run global statistics are computed and applied using the z-norm + minback scaling option by default (see Adding Your Own Dataset).
The data are split 80/10/10 across all subjects. The HCP test set is held out for downstream evaluation.
python project/main.py \
--model fmamba \
--pretraining \
--use_autoregressive \
--calc_loss_without_background \
--embed_dim 192 --depth 12 \
--patch_size 6 6 6 2 \
--img_size 96 96 96 20 \
--batch_size 3 \
--gradient_accumulation_steps 1 \
--precision bf16 \
--dataset_name S1200 \
--image_path /path/to/hdf5 \
--max_epochs 40 \
--learning_rate 1e-4 \
--weight_decay 1e-2 \
--accelerator gpu --devices 12 --num_nodes 32 \
--strategy deepspeed_stage_1 \
--loggername wandb \
--project_name neuromamba_pretrain
See sample_scripts/pretraining/ for full PBS-ready scripts.
Paper hyperparameters (Appendix B):
| | -Small | -Base | -Medium | -Large |
|---|---|---|---|---|
| GPU nodes (12 ranks/node) | 32 | 32 | 32 | 48 |
| DeepSpeed ZeRO stage | 1 | 1 | 1 | 3 |
| Micro batch / GPU | 3 | 3 | 3 | 2 |
| Effective batch size | 1,152 | 1,152 | 1,152 | 1,152 |
| Epochs | 40 | 40 | 40 | 40 |
| Learning rate | 1e-4 | 1e-4 | 1e-4 | 1e-4 |
| Weight decay | 1e-2 | 1e-2 | 1e-2 | 1e-2 |
| LR schedule | CosineAnnealingWarmUpRestarts, 5% warmup, min=1% of peak |
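The effective batch size in the table follows from nodes × ranks per node × micro batch × gradient accumulation steps. The short check below reproduces the 1,152 figure for both configurations; it assumes one gradient accumulation step, as in the example pre-training command, and the function name is only illustrative.

```python
def effective_batch_size(nodes, ranks_per_node, micro_batch, grad_accum=1):
    """Global number of samples contributing to each optimizer step."""
    return nodes * ranks_per_node * micro_batch * grad_accum

print(effective_batch_size(32, 12, 3))  # Small / Base / Medium: 1152
print(effective_batch_size(48, 12, 2))  # Large: 1152
```

When reproducing on fewer GPUs, raising `--gradient_accumulation_steps` is one way to keep this product unchanged.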
Two head architectures (paper Appendix C):
- Linear head: `AdaptiveAvgPool1d(1)` → `Linear(embed_dim, num_classes)`
- Mamba head: 3 Mamba2 blocks (`d_state = embed_dim/16`) → final-token `Linear(embed_dim, num_classes)`
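The two heads differ mainly in how the token sequence is reduced to a single feature vector before the final linear layer. The NumPy sketch below contrasts mean pooling over tokens (which is what `AdaptiveAvgPool1d(1)` computes) with a final-token readout. It is a simplified stand-in: the random weights and token count are placeholders, and the three Mamba2 blocks that precede the readout in the Mamba head are omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
num_tokens, embed_dim, num_classes = 8, 192, 2

features = rng.standard_normal((num_tokens, embed_dim))   # stand-in for backbone output
W = rng.standard_normal((num_classes, embed_dim)) * 0.01  # stand-in for Linear weights
b = np.zeros(num_classes)

def linear_head(feats):
    """Average over the token axis, then apply the linear classifier."""
    pooled = feats.mean(axis=0)
    return W @ pooled + b

def final_token_readout(feats):
    """Classify from the last token only (the Mamba blocks before it are omitted here)."""
    return W @ feats[-1] + b

print(linear_head(features).shape, final_token_readout(features).shape)
```

A final-token readout relies on the preceding sequence layers to carry information forward into the last position, which is why it is paired with additional Mamba2 blocks rather than used on its own.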
Three training modes:
| Mode | When to use |
|---|---|
| From scratch | No pretrained backbone (control) |
| Full fine-tuning | Pretrained backbone + head, all weights trainable (paper main result) |
| Frozen backbone | Pretrained backbone frozen, head only — used in paper §4.3 background-removal ablation (Table 1) |
python project/main.py \
--model fmamba \
--downstream_task sex \
--downstream_task_type classification \
--dataset_name HCP \
--image_path /path/to/hcp_hdf5 \
--load_model_path /path/to/3M_E192.pt \
--embed_dim 192 --depth 12 \
--clf_head_version mamba \
--patch_size 6 6 6 2 \
--img_size 96 96 96 20 \
--batch_size 2 \
--precision bf16 \
--max_epochs 30 \
--accelerator gpu --devices 12 --num_nodes 2 \
--strategy deepspeed_stage_1 \
--loggername wandb \
--project_name neuromamba_hcp_sex
To freeze the backbone and train only the head (paper §4.3 setup), add:
--freeze_feature_extractor
To train from scratch (no pretrained weights), drop `--load_model_path`.
See sample_scripts/downstream/HCP/ for full PBS-ready scripts. Each evaluation in the paper is repeated three times with different random seeds; report mean ± std.
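For reporting, the three per-seed scores can be summarized as below. The accuracies in the example are hypothetical placeholders, and the sketch uses the sample standard deviation; whether the paper uses the sample or population form is not stated here, so treat that choice as an assumption.

```python
from statistics import mean, stdev

def summarize(scores):
    """Format repeated-run scores as 'mean ± std' using the sample standard deviation."""
    return f"{mean(scores):.4f} ± {stdev(scores):.4f}"

# Hypothetical accuracies from three seeds, for illustration only.
print(summarize([0.94, 0.95, 0.96]))
```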
Frozen backbone (NeuroMamba-Medium, 5.4M params), HCP sex:
| Head | Tokens | AUC | ACC |
|---|---|---|---|
| Linear | All tokens | 0.7432±0.00 | 0.6882±0.01 |
| Linear | Brain only | 0.7769±0.00 | 0.7156±0.01 |
| Mamba | All tokens | 0.8456±0.05 | 0.7767±0.07 |
| Mamba | Brain only | 0.8628±0.08 | 0.8104±0.09 |
Computational cost: 6.96×10¹⁷ → 3.72×10¹⁷ FLOPs/epoch (46.5% ↓).
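The reduction can be recomputed from the two FLOPs figures. With the rounded values quoted here the ratio comes out at roughly 0.4655, consistent with the reported 46.5% up to rounding of the inputs.

```python
flops_all_tokens = 6.96e17   # FLOPs/epoch with background tokens included
flops_brain_only = 3.72e17   # FLOPs/epoch after background removal

reduction = 1.0 - flops_brain_only / flops_all_tokens
print(f"{reduction:.4f}")
```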
| Size | Head | From scratch (AUC / ACC) | Full FT (AUC / ACC) |
|---|---|---|---|
| 1.4M | Linear | 0.9767 / 0.9427 | 0.9840 / 0.9458 |
| 1.4M | Mamba | 0.9446 / 0.8872 | 0.9535 / 0.8781 |
| 3.1M | Linear | 0.9813 / 0.9232 | 0.9885 / 0.9455 |
| 3.1M | Mamba | 0.9825 / 0.9396 | 0.9874 / 0.9486 |
| 5.4M | Linear | 0.9865 / 0.9347 | 0.9717 / 0.9198 |
| 5.4M | Mamba | 0.9729 / 0.9175 | 0.9766 / 0.9307 |
| Model | Params | Split | AUC | ACC |
|---|---|---|---|---|
| SwiFT | 4.6M | 70/15/15 | 98.0 | 92.9 |
| NeuroSTORM | 5.0M | 70/15/15 | 97.6 | 93.3 |
| NeuroMamba | 3.1M | 80/10/10 | 98.9 | 94.9 |
Note: Paper §5 acknowledges the split difference and plans to re-implement baselines on the 80/10/10 split for a like-for-like comparison.
The DeepSpeed flops profiler was used (paper §4.1). To reproduce:
python project/check_flops.py \
--model fmamba \
--embed_dim 256 --depth 12 \
--patch_size 6 6 6 2 \
--img_size 96 96 96 20 \
--batch_size 1 \
--image_path /path/to/hcp_hdf5 \
--dataset_name HCP \
--downstream_task sex
To pre-train or fine-tune NeuroMamba on a new fMRI dataset, follow the protocol in ADDING_YOUR_DATASET.md:
- Standard preprocessing (HCP-style): bias field correction → skull stripping → MNI normalization → crop/pad to 96×96×96.
- Per-run global statistics computation.
- Scaling option choice: `znorm_minback` (paper default), `znorm_zeroback`, `minmax_zeroback`, or `minmax_minback`.
- Per-TR HDF5 conversion.
- Add a dataset class entry to `project/module/utils/data_module.py`.
See the doc for the full step-by-step protocol with example commands.
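To make the scaling options more concrete, here is a sketch of what a z-norm scaling with the two background choices could look like. The reading of the option names is an assumption: `minback` is taken to mean that background voxels are filled with the minimum normalized brain value, and `zeroback` that they are filled with zero, with statistics computed over brain voxels of the whole run. The function name and arguments are hypothetical; consult ADDING_YOUR_DATASET.md and the preprocessing scripts for the actual behaviour.

```python
import numpy as np

def znorm_scale(run, brain_mask, background="min"):
    """Z-score a (X, Y, Z, T) run with run-wide brain statistics and fill the background.

    brain_mask is a boolean (X, Y, Z) array; background is "min" or "zero".
    """
    brain = run[brain_mask]                       # shape (n_brain_voxels, T)
    z = (run - brain.mean()) / brain.std()
    fill = z[brain_mask].min() if background == "min" else 0.0
    return np.where(brain_mask[..., None], z, fill).astype(np.float32)

# Toy run: a 2x2x2 "brain" inside a 4x4x4 volume with 3 time points.
rng = np.random.default_rng(0)
mask = np.zeros((4, 4, 4), dtype=bool)
mask[1:3, 1:3, 1:3] = True
run = rng.normal(100.0, 10.0, size=(4, 4, 4, 3))
run[~mask] = 0.0

scaled = znorm_scale(run, mask)
print(scaled[mask].mean(), scaled[mask].std())
```

Under this interpretation, filling the background with the brain minimum keeps background voxels at the low end of the intensity range after normalization instead of placing them at the brain mean.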
NeuroMamba/
├── README.md
├── ADDING_YOUR_DATASET.md              # Step-by-step new-dataset protocol
├── CITATION.bib
├── LICENSE                             # Apache 2.0
├── requirements.txt
├── setup_env.sh                        # Optional HPC venv installer
├── paper/
│   └── NeuroMamba_NeurIPS2025_BrainBodyFM.pdf
├── assets/
│   └── architecture_overview.png       # Paper Figure 2
├── figures/                            # (additional figures, optional)
├── configs/
│   └── wandb.yaml.example
├── project/
│   ├── main.py                         # Training entry point
│   ├── check_flops.py                  # FLOPs profiling (paper §4.1)
│   ├── deepspeed_profile.py            # DeepSpeed profiling utility
│   └── module/
│       ├── pl_classifier.py            # PyTorch Lightning module
│       ├── models/
│       │   ├── fmamba.py               # NeuroMamba backbone (paper §3)
│       │   ├── load_model.py
│       │   └── utils.py
│       └── utils/
│           ├── data_module.py
│           ├── data_utils.py
│           ├── augment.py
│           ├── losses.py
│           ├── lr_scheduler.py         # CosineAnnealingWarmUpRestarts
│           ├── masking_generator.py
│           ├── metrics.py
│           ├── parser.py
│           ├── seed_creation.py
│           └── data_preprocess_and_load/
│               ├── preprocessing.py    # Generic HCP-style NIFTI→TR
│               ├── preprocessing.slurm
│               ├── preprocessing_HCP_v2.py
│               ├── datasets.py
│               ├── datasets_hdf5.py
│               ├── check_and_clean_hdf5.py
│               ├── convert_data_into_hdf5_revised.py
│               └── convert_data_into_hdf5_revised_nogzip.py
└── sample_scripts/
    ├── pretraining/                    # Pre-training reference scripts (Aurora)
    ├── downstream/
    │   ├── HCP/                        # Paper HCP sex/age/int experiments
    │   │                               # (5M_*, E128_*, E192_* full / TL variants)
    │   └── 5M_fmamba_*.sh              # Quick-start single-GPU references
    └── lab_server_smoke_test/          # Generic SLURM smoke tests (any A100/Ampere)
        ├── slurm_a5000_dummy_smoke.sh          # Install sanity (no data, ~10 min)
        ├── slurm_a5000_dummy_finetune.sh       # .pt-ckpt load + freeze/head toggles
        └── slurm_a5000_hcp_finetune_smoke.sh   # Real HCP + 5-epoch fine-tune
@inproceedings{choi2025neuromamba,
title={NeuroMamba: A State-Space Foundation Model for Functional MRI},
author={Choi, Jubin and Park, David Keetae and Kwon, Junbeom and Yoo, Shinjae and Cha, Jiook},
booktitle={39th Conference on Neural Information Processing Systems (NeurIPS 2025) Workshop: Foundation Models for the Brain and Body},
year={2025},
url={https://openreview.net/forum?id=kftg4lmQi8}
}
This work was supported by:
- National Research Foundation of Korea (NRF) and IITP grants funded by the Korea government (MSIT)
- Creative-Pioneering Researchers Program through Seoul National University
- Korea Brain Research Institute (KBRI) basic research program
- Korea Health Industry Development Institute (KHIDI), Ministry of Health and Welfare
- Korea Basic Science Institute (NRFEC) grant
- U.S. Department of Energy (DOE) ASCR Leadership Computing Challenge (ALCC), under award m4750-2024
- National Energy Research Scientific Computing Center (NERSC), DOE Office of Science User Facility
- Argonne (ALCF) and Oak Ridge (OLCF) Leadership Computing Facilities
- BNL: U.S. DOE Office of Science, ASCR (DE-SC-0012704)
We acknowledge the National Supercomputing Center (KSC-2023-CRE-0568) for computational resources and technical support.
Apache License 2.0 — see LICENSE.
- Corresponding Author: Jiook Cha · connectome@snu.ac.kr
- First Author: Jubin Choi · wnqlszoq123@snu.ac.kr
- Issues: GitHub Issues
