NeuroMamba: A State-Space Foundation Model for Functional MRI

Official implementation of "NeuroMamba: A State-Space Foundation Model for Functional MRI" (NeurIPS 2025 Workshop on Foundation Models for the Brain and Body, BrainBodyFM).

Authors: Jubin Choi¹, David Keetae Park², Junbeom Kwon³, Shinjae Yoo², Jiook Cha¹*
¹Seoul National University, ²Brookhaven National Laboratory, ³The University of Texas at Austin
*Corresponding author: connectome@snu.ac.kr

📄 Paper PDF · 🔗 OpenReview · ⬇ Pretrained weights (v1.0.0)


Table of Contents

  1. Overview
  2. Architecture
  3. Installation
  4. Pretrained Checkpoints
  5. Pre-training
  6. Downstream: HCP Sex Classification
  7. Results
  8. Adding Your Own Dataset
  9. Repository Layout
  10. Citation
  11. Acknowledgments
  12. License
  13. Contact

Overview

NeuroMamba enables direct sequence modeling of 4D whole-brain fMRI. Prior fMRI foundation models either (a) reduce fMRI to region-of-interest (ROI) time series and lose fine-grained spatial detail, or (b) use rigid grid-based hierarchical models that waste compute on ~60% non-brain background. NeuroMamba breaks this trade-off with two key ideas:

  • Adaptive background removal via patch-wise tokenization: non-brain tokens are discarded before processing, yielding a 46.5% reduction in FLOPs.
  • NeRF-style frequency-based positional encoding of continuous (x, y, z, t) patch coordinates: robust to inter-subject anatomical variation.
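
To make the second idea concrete, here is a minimal sketch of a NeRF-style frequency encoding applied to one patch coordinate. The function name `nerf_encode`, the number of frequency bands, and the normalization of coordinates to [-1, 1] are assumptions for illustration; the encoding actually used by the model is defined in the repository and may differ.

```python
import math


def nerf_encode(coord, num_bands=4):
    """Encode continuous coordinates with sin/cos at octave-spaced frequencies.

    coord: sequence of floats such as (x, y, z, t), assumed normalized to [-1, 1].
    Returns a flat list of length len(coord) * 2 * num_bands.
    """
    features = []
    for c in coord:
        for k in range(num_bands):
            freq = (2.0 ** k) * math.pi  # NeRF-style frequencies: 2^k * pi
            features.append(math.sin(freq * c))
            features.append(math.cos(freq * c))
    return features


# Hypothetical patch centre, already normalized to [-1, 1] on each axis.
enc = nerf_encode((0.0, 0.5, -0.5, 1.0), num_bands=4)
print(len(enc))  # 4 coordinates * 2 functions * 4 bands = 32
```

Because the encoding depends only on continuous coordinates, tokens keep a meaningful position even after background tokens are dropped from the sequence.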

The result: a Mamba2-based foundation model pre-trained autoregressively on >50,000 subjects (UK Biobank + ABCD + HCP), achieving state-of-the-art HCP sex classification (94.9% ACC, 98.9% AUC at 3.1M params).


Architecture

NeuroMamba Pre-training Pipeline (Figure 2)

Figure 2 of the paper. A 4D fMRI volume is divided into 6×6×6×2 patches, augmented with NeRF positional encoding, stripped of non-brain background tokens, and processed by a stack of 12 Mamba2 blocks trained with autoregressive next-token prediction (temporal-first raster order: t → z → y → x).
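
The patch grid and the background-removal step can be sketched in a few lines. This assumes T = 20 frames (the value used in the example commands in this README) and treats a patch as background when every voxel equals a fixed background value; that criterion, and the helper names `patch_grid` and `keep_brain_tokens`, are illustrative assumptions rather than the repository's implementation.

```python
import math


def patch_grid(img_size, patch_size):
    """Number of patches along each axis; sizes must divide evenly."""
    for n, p in zip(img_size, patch_size):
        if n % p != 0:
            raise ValueError(f"image size {n} is not divisible by patch size {p}")
    return tuple(n // p for n, p in zip(img_size, patch_size))


def keep_brain_tokens(patches, background=0.0):
    """Indices of patches that contain at least one non-background voxel."""
    return [
        i for i, patch in enumerate(patches)
        if any(v != background for v in patch)
    ]


grid = patch_grid((96, 96, 96, 20), (6, 6, 6, 2))
num_tokens = math.prod(grid)
print(grid, num_tokens)  # (16, 16, 16, 10) 40960

# Toy example: four flattened "patches", two of them pure background.
toy_patches = [[0.0, 0.0], [0.0, 1.3], [0.0, 0.0], [2.1, 0.4]]
kept = keep_brain_tokens(toy_patches)
print(kept)  # [1, 3]
```

Only the kept indices would be passed to the Mamba2 stack, which is where the compute saving comes from.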

Model variants (paper Appendix B):

| Variant | Parameters | embed_dim | d_state | headdim | Stable? |
|---|---|---|---|---|---|
| NeuroMamba-Small | 1.4M | 128 | 32 | 128 | |
| NeuroMamba-Base (paper main) | 3.1M | 192 | 48 | 192 | |
| NeuroMamba-Medium | 5.4M | 256 | 64 | 256 | |
| NeuroMamba-Large | 11.9M | 384 | | | ⚠️ training instability (see paper §5) |

Common across all sizes: 12 Mamba2 blocks, ngroups=1, expand=2, patch shape 6×6×6×2, input shape 96×96×96×T.


Installation

⚠️ PyTorch Lightning version: This codebase targets the pytorch_lightning 1.9.x series (verified with 1.9.5). PL 2.x removed Trainer.add_argparse_args and other 1.x APIs that project/main.py relies on — using PL ≥ 2.0 will raise AttributeError: type object 'Trainer' has no attribute 'add_argparse_args' immediately at startup. requirements.txt pins this with pytorch-lightning>=1.9.0,<2.0. If your environment already has a newer PL, force the downgrade explicitly:

pip install --force-reinstall --no-deps "pytorch-lightning==1.9.5"

# Clone
git clone https://github.com/Transconnectome/NeuroMamba.git
cd NeuroMamba

# Create venv (Python 3.10+)
python -m venv neuromamba_env
source neuromamba_env/bin/activate

# 1. Install PyTorch FIRST (platform-specific).
#    - x86 + NVIDIA (CUDA 12.x):
pip install torch torchvision torchaudio
#    - ARM64 / Blackwell (e.g. GB10, sm_121): use the matching sbsa CUDA wheel index, e.g.
#      pip install --index-url https://download.pytorch.org/whl/cu128 torch torchvision torchaudio
#    - Or skip this step entirely if you use an NGC PyTorch container.

# 2. Install the remaining dependencies
pip install -r requirements.txt

# 3. Install the Mamba CUDA kernels LAST, WITHOUT build isolation.
#    (mamba-ssm / causal-conv1d setup.py does `import torch`, which fails
#     under pip's build isolation — so they are deliberately NOT in
#     requirements.txt. Requires a CUDA-compatible GPU + nvcc.)
pip install mamba-ssm causal-conv1d --no-build-isolation

Or use the bundled environment installer (Linux/HPC) — it handles the build-isolation ordering automatically:

export NEUROMAMBA_VENV_PATH="$HOME/neuromamba_env"  # destination
bash setup_env.sh

Required env vars when running:

export PYTHONPATH=$PWD:$PYTHONPATH
export OMP_NUM_THREADS=1
export TORCH_EXTENSIONS_DIR=$PWD/deepspeed
export TRITON_CACHE_DIR="$PWD/.triton"
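
Given the PyTorch Lightning pin described above, a quick pre-flight check can save a failed launch. The following is a minimal sketch using only the standard library; the helper names are hypothetical and not part of the repository.

```python
from importlib import metadata


def pl_version_is_supported(version: str) -> bool:
    """True for 1.9.x releases, the series this codebase targets."""
    parts = version.split(".")
    return len(parts) > 1 and parts[0] == "1" and parts[1] == "9"


def check_installed_pl() -> str:
    """Return a short status string for the installed pytorch-lightning."""
    try:
        version = metadata.version("pytorch-lightning")
    except metadata.PackageNotFoundError:
        return "pytorch-lightning is not installed"
    if pl_version_is_supported(version):
        return f"pytorch-lightning {version} OK"
    return f"pytorch-lightning {version} unsupported; expected 1.9.x"


print(check_installed_pl())
```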

Quick smoke test (single-GPU)

Before launching a multi-node run, verify the install on one GPU (A5000 / A6000 / RTX 3090 / 4090, ≥24 GB VRAM, bf16-capable):

sbatch sample_scripts/lab_server_smoke_test/slurm_a5000_dummy_smoke.sh

See sample_scripts/lab_server_smoke_test/README.md for both a dummy-data smoke test (~10 min) and a real HCP fine-tune smoke test (~1–2 h).


Pretrained Checkpoints

Three pretrained backbones are released as v1.0.0 GitHub assets:

| Asset | embed_dim | d_state | headdim | Size |
|---|---|---|---|---|
| 1M_E128.pt | 128 | 32 | 128 | 25 MB |
| 3M_E192.pt (paper main) | 192 | 48 | 192 | 42 MB |
| 5M_E256.pt | 256 | 64 | 256 | 61 MB |

Loading

import torch
ckpt = torch.load("3M_E192.pt", map_location="cpu", weights_only=False)
# Flat top-level state_dict (113 keys), prefix "_forward_module.model.<...>" (DeepSpeed/DDP).
# Architecture: project/module/models/fmamba.py:FMamba
# Note: depthwise conv is named `conv_xBC` in mamba_ssm < 2.x, `conv1d` in ≥ 2.x.
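
Because every key carries the `_forward_module.model.` wrapper prefix, loading into a bare backbone typically requires stripping it first. A minimal sketch, assuming the target module's parameter names match once the prefix is removed; `strip_prefix` is a hypothetical helper and the key names in the toy dictionary are made up. The repository's own loading code may handle this differently.

```python
def strip_prefix(state_dict, prefix="_forward_module.model."):
    """Return a new dict with `prefix` removed from every key that has it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Toy stand-in for a checkpoint; real values would be tensors.
toy_ckpt = {
    "_forward_module.model.layers.0.weight": 1,
    "_forward_module.model.layers.0.bias": 2,
}
print(sorted(strip_prefix(toy_ckpt)))
# ['layers.0.bias', 'layers.0.weight']
```

With a real checkpoint, passing the cleaned dictionary to `load_state_dict(..., strict=False)` and inspecting the returned `missing_keys` and `unexpected_keys` is a reasonable way to confirm that the names line up.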

Pre-training

NeuroMamba was pre-trained on resting-state fMRI from:

  • UK Biobank (N=40,647)
  • ABCD (N=9,139)
  • HCP Young Adult (N=1,084)

All data were preprocessed with a standard pipeline (bias field correction → skull stripping → MNI alignment → crop/pad to 96×96×96) and stored as per-TR HDF5 files. Per-run global statistics are computed and applied using the z-norm + minback scaling option by default (see Adding Your Own Dataset).

An 80/10/10 split is applied across all subjects. The HCP test set is held out for downstream evaluation.

Example: pre-train NeuroMamba-Base on S1200

python project/main.py \
    --model fmamba \
    --pretraining \
    --use_autoregressive \
    --calc_loss_without_background \
    --embed_dim 192 --depth 12 \
    --patch_size 6 6 6 2 \
    --img_size 96 96 96 20 \
    --batch_size 3 \
    --gradient_accumulation_steps 1 \
    --precision bf16 \
    --dataset_name S1200 \
    --image_path /path/to/hdf5 \
    --max_epochs 40 \
    --learning_rate 1e-4 \
    --weight_decay 1e-2 \
    --accelerator gpu --devices 12 --num_nodes 32 \
    --strategy deepspeed_stage_1 \
    --loggername wandb \
    --project_name neuromamba_pretrain

See sample_scripts/pretraining/ for full PBS-ready scripts.
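
The command above fixes the global batch per optimizer step through four flags. As a sanity check, the arithmetic can be written out as follows, assuming `--batch_size` is the per-GPU micro batch (the function name is illustrative):

```python
def effective_batch_size(micro_batch, devices_per_node, num_nodes, grad_accum=1):
    """Samples consumed per optimizer step across all ranks."""
    return micro_batch * devices_per_node * num_nodes * grad_accum


# --batch_size 3 --devices 12 --num_nodes 32 --gradient_accumulation_steps 1
print(effective_batch_size(3, 12, 32))  # 1152

# -Large settings from the hyperparameter table: micro batch 2 on 48 nodes
print(effective_batch_size(2, 12, 48))  # 1152
```

When reproducing on fewer GPUs, raising `--gradient_accumulation_steps` is the usual way to keep this product constant.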

Paper hyperparameters (Appendix B):

| Hyperparameter | -Small | -Base | -Medium | -Large |
|---|---|---|---|---|
| GPU nodes (12 ranks/node) | 32 | 32 | 32 | 48 |
| DeepSpeed ZeRO stage | 1 | 1 | 1 | 3 |
| Micro batch / GPU | 3 | 3 | 3 | 2 |
| Effective batch size | 1,152 | 1,152 | 1,152 | 1,152 |
| Epochs | 40 | 40 | 40 | 40 |
| Learning rate | 1e-4 | 1e-4 | 1e-4 | 1e-4 |
| Weight decay | 1e-2 | 1e-2 | 1e-2 | 1e-2 |

LR schedule (all sizes): CosineAnnealingWarmUpRestarts, 5% warmup, min = 1% of peak.
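
To give a feel for the schedule's shape, here is a minimal single-cycle sketch: linear warmup over the first 5% of steps, then cosine decay to 1% of the peak. It assumes warmup starts near zero and that there is only one cycle (no restarts); the repository's CosineAnnealingWarmUpRestarts in lr_scheduler.py may differ in both respects.

```python
import math


def lr_at(step, total_steps, peak_lr=1e-4, warmup_frac=0.05, min_frac=0.01):
    """Learning rate at `step` (0-indexed) for warmup + single cosine cycle."""
    warmup_steps = max(1, int(total_steps * warmup_frac))
    min_lr = peak_lr * min_frac
    if step < warmup_steps:
        # Linear ramp that reaches peak_lr on the last warmup step.
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


total = 1000
for s in (0, 49, 50, 525, 999):
    print(s, f"{lr_at(s, total):.2e}")
```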

Downstream: HCP Sex Classification

Two head architectures (paper Appendix C):

  • Linear head: AdaptiveAvgPool1d(1) → Linear(embed_dim, num_classes)
  • Mamba head: 3 Mamba2 blocks (d_state = embed_dim/16) → final-token Linear(embed_dim, num_classes)

Three training modes:

| Mode | When to use |
|---|---|
| From scratch | No pretrained backbone (control) |
| Full fine-tuning | Pretrained backbone + head, all weights trainable (paper main result) |
| Frozen backbone | Pretrained backbone frozen, head only; used in the paper §4.3 background-removal ablation (Table 1) |

Example: full fine-tuning, NeuroMamba-Base + Mamba head

python project/main.py \
    --model fmamba \
    --downstream_task sex \
    --downstream_task_type classification \
    --dataset_name HCP \
    --image_path /path/to/hcp_hdf5 \
    --load_model_path /path/to/3M_E192.pt \
    --embed_dim 192 --depth 12 \
    --clf_head_version mamba \
    --patch_size 6 6 6 2 \
    --img_size 96 96 96 20 \
    --batch_size 2 \
    --precision bf16 \
    --max_epochs 30 \
    --accelerator gpu --devices 12 --num_nodes 2 \
    --strategy deepspeed_stage_1 \
    --loggername wandb \
    --project_name neuromamba_hcp_sex

To freeze the backbone and train only the head (paper §4.3 setup), add:

    --freeze_feature_extractor

To train from scratch (no pretrained weights), drop --load_model_path.

See sample_scripts/downstream/HCP/ for full PBS-ready scripts. Each evaluation in the paper is repeated three times with different random seeds; report mean ± std.
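
Aggregating the three seeded runs is a one-liner with the standard library. The sketch below uses the sample standard deviation (n − 1); whether the paper used sample or population std is not stated here, so treat that choice as an assumption. The input values are placeholders, not paper results.

```python
import statistics


def summarize(values):
    """Format a list of per-seed metrics as 'mean ± std'."""
    mean = statistics.mean(values)
    std = statistics.stdev(values)  # sample std; use pstdev for population std
    return f"{mean:.4f} ± {std:.4f}"


placeholder_auc = [0.90, 0.92, 0.94]  # made-up numbers for illustration only
print(summarize(placeholder_auc))  # 0.9200 ± 0.0200
```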


Results

Effect of background removal (paper Table 1)

Frozen backbone (NeuroMamba-Medium, 5.4M params), HCP sex:

| Head | Tokens | AUC | ACC |
|---|---|---|---|
| Linear | All tokens | 0.7432±0.00 | 0.6882±0.01 |
| Linear | Brain only | 0.7769±0.00 | 0.7156±0.01 |
| Mamba | All tokens | 0.8456±0.05 | 0.7767±0.07 |
| Mamba | Brain only | 0.8628±0.08 | 0.8104±0.09 |

Computational cost: 6.96×10¹⁷ → 3.72×10¹⁷ FLOPs/epoch (46.5% ↓).
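
The reduction follows directly from the two FLOPs figures. A short check of the arithmetic, using the rounded values quoted above:

```python
flops_all_tokens = 6.96e17   # FLOPs/epoch with all tokens
flops_brain_only = 3.72e17   # FLOPs/epoch with background tokens removed

reduction = 1.0 - flops_brain_only / flops_all_tokens
speedup = flops_all_tokens / flops_brain_only

print(f"{reduction:.2%}")  # 46.55% from the rounded figures
print(f"{speedup:.2f}x")   # 1.87x
```

The small gap to the reported 46.5% is consistent with the FLOPs figures being rounded to three significant digits.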

From-scratch vs fine-tuning (paper Table 2)

| Size | Head | From scratch (AUC / ACC) | Full FT (AUC / ACC) |
|---|---|---|---|
| 1.4M | Linear | 0.9767 / 0.9427 | 0.9840 / 0.9458 |
| 1.4M | Mamba | 0.9446 / 0.8872 | 0.9535 / 0.8781 |
| 3.1M | Linear | 0.9813 / 0.9232 | 0.9885 / 0.9455 |
| 3.1M | Mamba | 0.9825 / 0.9396 | 0.9874 / 0.9486 |
| 5.4M | Linear | 0.9865 / 0.9347 | 0.9717 / 0.9198 |
| 5.4M | Mamba | 0.9729 / 0.9175 | 0.9766 / 0.9307 |

Comparison with SOTA (paper Table 3)

| Model | Params | Split | AUC (%) | ACC (%) |
|---|---|---|---|---|
| SwiFT | 4.6M | 70/15/15 | 98.0 | 92.9 |
| NeuroSTORM | 5.0M | 70/15/15 | 97.6 | 93.3 |
| NeuroMamba | 3.1M | 80/10/10 | 98.9 | 94.9 |

Note: Paper §5 acknowledges the split difference and plans to re-implement baselines on the 80/10/10 split for a like-for-like comparison.

FLOPs measurement

The DeepSpeed flops profiler was used (paper §4.1). To reproduce:

python project/check_flops.py \
    --model fmamba \
    --embed_dim 256 --depth 12 \
    --patch_size 6 6 6 2 \
    --img_size 96 96 96 20 \
    --batch_size 1 \
    --image_path /path/to/hcp_hdf5 \
    --dataset_name HCP \
    --downstream_task sex

Adding Your Own Dataset

To pre-train or fine-tune NeuroMamba on a new fMRI dataset, follow the protocol in ADDING_YOUR_DATASET.md:

  1. Standard preprocessing (HCP-style): bias field correction → skull stripping → MNI normalization → crop/pad to 96×96×96.
  2. Per-run global statistics computation.
  3. Scaling option choice: znorm_minback (paper default), znorm_zeroback, minmax_zeroback, or minmax_minback.
  4. Per-TR HDF5 conversion.
  5. Add a dataset class entry to project/module/utils/data_module.py.

See the doc for the full step-by-step protocol with example commands.
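
The four scaling option names suggest two independent choices: how brain voxels are normalized (z-score or min-max) and what value fills the background (the minimum of the scaled brain values, or zero). The sketch below is one plausible reading of those names, not the repository's definition; `scale_run` and its arguments are hypothetical, and ADDING_YOUR_DATASET.md remains the authority.

```python
import statistics


def scale_run(voxels, brain_mask, mode="znorm_minback"):
    """Scale one run's voxels; this reading of `mode` is an assumption.

    voxels: flat list of floats; brain_mask: same-length list of bools.
    Assumes the brain voxels are not all identical (non-zero spread).
    """
    norm, back = mode.split("_")
    brain = [v for v, m in zip(voxels, brain_mask) if m]
    if norm == "znorm":
        mu, sigma = statistics.mean(brain), statistics.pstdev(brain)
        scaled = [(v - mu) / sigma for v in brain]
    elif norm == "minmax":
        lo, hi = min(brain), max(brain)
        scaled = [(v - lo) / (hi - lo) for v in brain]
    else:
        raise ValueError(f"unknown normalization: {norm}")
    if back not in ("minback", "zeroback"):
        raise ValueError(f"unknown background fill: {back}")
    fill = min(scaled) if back == "minback" else 0.0
    it = iter(scaled)
    # Brain voxels take their scaled value; background voxels take `fill`.
    return [next(it) if m else fill for m in brain_mask]


toy_voxels = [0.0, 0.0, 10.0, 20.0, 30.0]
toy_mask = [False, False, True, True, True]
print(scale_run(toy_voxels, toy_mask, "znorm_minback"))
print(scale_run(toy_voxels, toy_mask, "minmax_zeroback"))
```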


Repository Layout

NeuroMamba/
├── README.md
├── ADDING_YOUR_DATASET.md          # Step-by-step new-dataset protocol
├── CITATION.bib
├── LICENSE                         # Apache 2.0
├── requirements.txt
├── setup_env.sh                    # Optional HPC venv installer
├── paper/
│   └── NeuroMamba_NeurIPS2025_BrainBodyFM.pdf
├── assets/
│   └── architecture_overview.png   # Paper Figure 2
├── figures/                        # (additional figures, optional)
├── configs/
│   └── wandb.yaml.example
├── project/
│   ├── main.py                     # Training entry point
│   ├── check_flops.py              # FLOPs profiling (paper §4.1)
│   ├── deepspeed_profile.py        # DeepSpeed profiling utility
│   └── module/
│       ├── pl_classifier.py        # PyTorch Lightning module
│       ├── models/
│       │   ├── fmamba.py           # NeuroMamba backbone (paper §3)
│       │   ├── load_model.py
│       │   └── utils.py
│       └── utils/
│           ├── data_module.py
│           ├── data_utils.py
│           ├── augment.py
│           ├── losses.py
│           ├── lr_scheduler.py     # CosineAnnealingWarmUpRestarts
│           ├── masking_generator.py
│           ├── metrics.py
│           ├── parser.py
│           ├── seed_creation.py
│           └── data_preprocess_and_load/
│               ├── preprocessing.py        # Generic HCP-style NIFTI→TR
│               ├── preprocessing.slurm
│               ├── preprocessing_HCP_v2.py
│               ├── datasets.py
│               ├── datasets_hdf5.py
│               ├── check_and_clean_hdf5.py
│               ├── convert_data_into_hdf5_revised.py
│               └── convert_data_into_hdf5_revised_nogzip.py
└── sample_scripts/
    ├── pretraining/                # Pre-training reference scripts (Aurora)
    ├── downstream/
    │   ├── HCP/                    # Paper HCP sex/age/int experiments
    │   │                           #   (5M_*, E128_*, E192_* full / TL variants)
    │   └── 5M_fmamba_*.sh          # Quick-start single-GPU references
    └── lab_server_smoke_test/      # Generic SLURM smoke tests (any A100/Ampere)
        ├── slurm_a5000_dummy_smoke.sh      # Install sanity (no data, ~10 min)
        ├── slurm_a5000_dummy_finetune.sh   # .pt-ckpt load + freeze/head toggles
        └── slurm_a5000_hcp_finetune_smoke.sh   # Real HCP + 5-epoch fine-tune

Citation

@inproceedings{choi2025neuromamba,
  title={NeuroMamba: A State-Space Foundation Model for Functional MRI},
  author={Choi, Jubin and Park, David Keetae and Kwon, Junbeom and Yoo, Shinjae and Cha, Jiook},
  booktitle={39th Conference on Neural Information Processing Systems (NeurIPS 2025) Workshop: Foundation Models for the Brain and Body},
  year={2025},
  url={https://openreview.net/forum?id=kftg4lmQi8}
}

Acknowledgments

This work was supported by:

  • National Research Foundation of Korea (NRF) and IITP grants funded by the Korea government (MSIT)
  • Creative-Pioneering Researchers Program through Seoul National University
  • Korea Brain Research Institute (KBRI) basic research program
  • Korea Health Industry Development Institute (KHIDI), Ministry of Health and Welfare
  • Korea Basic Science Institute (NRFEC) grant
  • U.S. Department of Energy (DOE) ASCR Leadership Computing Challenge (ALCC), under award m4750-2024
  • National Energy Research Scientific Computing Center (NERSC), DOE Office of Science User Facility
  • Argonne (ALCF) and Oak Ridge (OLCF) Leadership Computing Facilities
  • BNL: U.S. DOE Office of Science, ASCR (DE-SC-0012704)

We acknowledge the National Supercomputing Center (KSC-2023-CRE-0568) for computational resources and technical support.


License

Apache License 2.0 — see LICENSE.


Contact

For questions, contact the corresponding author at connectome@snu.ac.kr.
