TSuXinH/CalM

Quick Start: CalM with hydra

This repository contains the full code for the CalM pipeline and for reproducing the main results of the paper.

Main stages:

  1. Preprocess raw session data
  2. Train / test the neural quantizer (NQ)
  3. Export tokens
  4. Train the dynamics transformer (DT)
  5. Fine-tune / evaluate DT on held-out sessions
  6. Train / evaluate the behavior decoding head (FT)

Environment

Install dependencies:

pip install -r requirements.txt
pip install "torch>=2.8.0" --index-url https://download.pytorch.org/whl/cu126

Tested environment:

  • Python 3.11.13
  • Hydra / OmegaConf
  • 8 × NVIDIA A100 40GB GPUs

Most training / evaluation commands below use torchrun.
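
Before launching long runs, it can help to confirm that the interpreter and core packages match the tested environment. The following is a minimal sketch using only the standard library. The import names in REQUIRED_PACKAGES and the minimum Python version are assumptions derived from the list above, not something the repository ships.

```python
import importlib.util
import sys

# Import names are assumptions based on the tested environment listed above.
REQUIRED_PACKAGES = ("torch", "hydra", "omegaconf")


def check_environment(packages=REQUIRED_PACKAGES, min_python=(3, 11)):
    """Return a dict describing which requirements look satisfied."""
    report = {"python_ok": sys.version_info[:2] >= tuple(min_python)}
    for name in packages:
        # find_spec returns None when a top-level package is not importable.
        report[name] = importlib.util.find_spec(name) is not None
    return report


if __name__ == "__main__":
    for key, ok in check_environment().items():
        print(f"{key:12s} {'ok' if ok else 'MISSING'}")
```

This only checks that the packages can be found. It does not verify the CUDA build of torch or the number of visible GPUs.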

Minimal reproduction path

Step 0. Preprocess raw data

Main script: preprocess/preprocess.py

Add your raw sessions, then run the script to generate the curated .npz files used by NQ.

python preprocess/preprocess.py

Step 1. Train / test NQ

Usually edit only:

  • conf/nq/data/Tseng_trial.yaml
  • conf/nq/trainer/trainer.yaml

Outputs:

  • NQ checkpoint
  • reconstruction / tokenizer results

python -m task.nq_train_Tseng
python -m task.nq_test_Tseng

Step 2. Export tokens

Usually edit only:

  • conf/nq/tokenize/tokenize.yaml

Outputs:

  • tokenized held-in / held-out session data

python -m task.nq_test_Tseng tokenize=tokenize

Step 3. Train DT on held-in sessions

Usually edit only:

  • conf/dt/data/Tseng_train.yaml
  • conf/dt/nq/nq.yaml

Outputs:

  • held-in DT checkpoint
  • held-in forecasting results

PYTHONPATH=. torchrun --nproc_per_node=8 task/dt_train_Tseng.py train.mode=train ...

Step 4. Fine-tune / evaluate DT on held-out sessions

Usually edit only:

  • conf/dt/data/Tseng_train.yaml
  • conf/dt/trainer/trainer.yaml

Outputs:

  • held-out forecasting results

Fine-tuning:

PYTHONPATH=. torchrun --nproc_per_node=8 task/dt_train_Tseng.py train.mode=finetune_heldout ...

Evaluation:

PYTHONPATH=. torchrun --nproc_per_node=8 task/dt_train_Tseng.py train.mode=eval_test train.eval_target=heldout ...

Step 5. Train / evaluate the decoding head

Held-in training:

Usually edit only:

  • conf/ft/data/tseng_behavior.yaml
  • conf/ft/backbone/ar_backbone.yaml

Outputs:

  • held-in / held-out decoding results

PYTHONPATH=. torchrun --nproc_per_node=8 task/ft_behavior_decode.py mode=train_base ...

Held-out fine-tuning:

PYTHONPATH=. torchrun --nproc_per_node=8 task/ft_behavior_decode.py mode=finetune_heldout ...

Evaluation:

PYTHONPATH=. torchrun --nproc_per_node=8 task/ft_behavior_decode.py mode=eval_only ...

What usually needs editing

For most users, only these fields need to be changed:

  • dataset roots
  • registry json paths
  • checkpoint paths
  • output / cache directories

You usually do not need to modify most model hyperparameters to reproduce the main pipeline.
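
Because most failed runs come down to a wrong path, a small pre-flight check can save a multi-GPU launch. The sketch below is a hypothetical helper, not part of the repository. It takes a mapping from config key to path and reports the keys whose paths do not exist. It is meant for input paths (dataset roots, registries, checkpoints); output and cache directories may legitimately not exist yet.

```python
from pathlib import Path


def find_missing_paths(named_paths):
    """Return the names whose paths do not exist on disk."""
    missing = []
    for name, value in named_paths.items():
        if value is None:
            continue  # unset optional fields (null in YAML) are skipped
        if not Path(value).expanduser().exists():
            missing.append(name)
    return missing


if __name__ == "__main__":
    # Keys mirror the overrides used in this README; the values are placeholders.
    inputs = {
        "data.data_root": "/path/to/held_in_tokens",
        "data.registry_json": "/path/to/base_registry.json",
        "vq.vq_state": "/path/to/vq_checkpoint.pth",
        "train.init_ckpt": None,
    }
    for key in find_missing_paths(inputs):
        print(f"missing input: {key}")
```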

Paper result mapping

  • NQ reconstruction / tokenizer:

    • task/nq_train_Tseng.py
    • task/nq_test_Tseng.py
  • DT forecasting:

    • task/dt_train_Tseng.py
    • use train.mode=train, train.mode=finetune_heldout, or train.mode=eval_test
  • Behavior decoding:

    • task/ft_behavior_decode.py
    • use mode=train_base, mode=finetune_heldout, mode=eval_only
  • Data preprocessing:

    • preprocess/preprocess.py
  • Plotting / result scripts:

    • MATLAB scripts in plot/

Detailed guide: CalM with hydra

The project structure:

  • nq/: VQ tokenizer
  • dt/: Dynamics Transformer backbone
  • ft/: behavior decoding downstream task

1. VQ / Neural Quantizer (NQ)

Relevant files

conf/nq/
  data/
    Tseng_trial.yaml
    Tseng_trial_small.yaml
  loss/
    loss.yaml
  model/
    model.yaml
  optim/
    adamw.yaml
  tokenize/
    tokenize.yaml
  trainer/
    trainer.yaml
  vq_train_Tseng.yaml
  vq_test_Tseng.yaml

dataset/
  nq_dataset.py

model/neural_quantizer/
  VQ_quantizer.py
  nq_layers.py
  nq_utility.py

task/
  nq_train_Tseng.py
  nq_test_Tseng.py
  nq_vq_configs_train_Tseng.py
  nq_vq_configs_test_Tseng.py

1.1 Main entries

task/nq_train_Tseng.py
task/nq_test_Tseng.py

1.2 Main configs to edit

conf/nq/data/Tseng_trial.yaml

# example dataset config
# edit dataset root / split-specific settings here
data_root: /path/to/tseng_data

conf/nq/model/model.yaml

# example VQ model config
n_emb: 128
dim_emb: 512
heads: 4
enc_layers: 4
dec_layers: 4
use_gumbel: true
use_gumbel_hard: true

conf/nq/loss/loss.yaml

# example loss config
w_emb: 1.0
w_commit: 0.5
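
The key names w_emb and w_commit suggest the standard VQ-VAE objective. Assuming that is what the repository uses (this README does not state the loss), the two weights would enter as follows, where z_e(x) is the encoder output, e is the selected codebook vector, sg[·] is the stop-gradient operator, and L_rec is the reconstruction term:

```latex
\mathcal{L}
  = \mathcal{L}_{\text{rec}}
  + w_{\text{emb}} \,\bigl\lVert \operatorname{sg}[z_e(x)] - e \bigr\rVert_2^2
  + w_{\text{commit}} \,\bigl\lVert z_e(x) - \operatorname{sg}[e] \bigr\rVert_2^2
```

Under this reading, w_emb scales the term that moves codebook entries toward encoder outputs, and w_commit scales the term that keeps encoder outputs close to their assigned entry.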

conf/nq/optim/adamw.yaml

lr: 5e-4
weight_decay: 1e-3

conf/nq/trainer/trainer.yaml

epochs: 100
compile: false
compile_dynamic: false

conf/nq/tokenize/tokenize.yaml

# example tokenize / inference config
ckpt_path: /path/to/vq_checkpoint.pth
out_root: /path/to/token_output
resume: true
chunk: 0

1.3 Train neural quantizer

python -m task.nq_train_Tseng

Example with overrides:

python -m task.nq_train_Tseng \
  data=Tseng_trial_small \
  trainer.epochs=100 \
  optim.lr=5e-4

1.4 Inspect the merged Hydra config

python -m task.nq_train_Tseng --cfg job --resolve

Use this to verify:

  • the correct YAMLs are loaded
  • CLI overrides are applied
  • output paths are correct
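
The way a dotted CLI override such as trainer.epochs=100 lands in the merged config can be sketched with plain dictionaries. This is a simplified model for intuition only: it is not Hydra's implementation, the helper names are hypothetical, and it does not cover config-group selections such as data=Tseng_trial_small, which swap in a whole YAML file rather than a single value.

```python
import copy


def apply_overrides(config, overrides):
    """Return a copy of `config` with `a.b.c=value` strings applied."""
    merged = copy.deepcopy(config)
    for item in overrides:
        dotted_key, raw_value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = merged
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = _parse_scalar(raw_value)
    return merged


def _parse_scalar(text):
    """Convert a CLI string to bool/None/int/float when it looks like one."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    if lowered == "null":
        return None
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


if __name__ == "__main__":
    base = {
        "trainer": {"epochs": 100, "compile": False},
        "optim": {"lr": 5e-4, "weight_decay": 1e-3},
    }
    print(apply_overrides(base, ["trainer.epochs=10", "optim.lr=1e-4"]))
```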

1.5 Evaluate / reconstruct with a trained VQ

Set the checkpoint path in conf/nq/vq_test_Tseng.yaml, then run:

python -m task.nq_test_Tseng

Or override from CLI:

python -m task.nq_test_Tseng paths.ckpt_path=/path/to/checkpoint.pth

1.6 Export tokens for downstream DT training

Use the tokenize-related settings in:

conf/nq/tokenize/tokenize.yaml

Then run:

python -m task.nq_test_Tseng tokenize=tokenize

1.7 Important note for token export

Disable Gumbel during inference / tokenization:

model:
  use_gumbel: false
  use_gumbel_hard: false

Do this in the test / tokenize config, not in the training config.
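
The reason is reproducibility: Gumbel noise makes the chosen token a random sample, so the same input can export different tokens on different runs. The sketch below contrasts a plain argmax with the generic Gumbel-max trick using only the standard library. It illustrates the effect and is not the repository's quantizer; the function names are hypothetical.

```python
import math
import random


def argmax(values):
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def gumbel_argmax(logits, rng, temperature=1.0):
    """Sample an index via the Gumbel-max trick (stochastic)."""
    noisy = []
    for logit in logits:
        # Clamp away from 0 and 1 so both logarithms stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        gumbel = -math.log(-math.log(u))
        noisy.append((logit + gumbel) / temperature)
    return argmax(noisy)


if __name__ == "__main__":
    logits = [2.0, 1.8, 0.1]
    rng = random.Random(0)
    print("deterministic:", argmax(logits))
    print("with Gumbel noise:", [gumbel_argmax(logits, rng) for _ in range(10)])
```

When two logits are close, as in the first two entries here, the noisy version frequently picks the runner-up, which would make exported token files differ between runs.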

1.8 Common commands

Train VQ

python -m task.nq_train_Tseng

Train on a smaller dataset

python -m task.nq_train_Tseng data=Tseng_trial_small

Inspect config

python -m task.nq_train_Tseng --cfg job --resolve

Test / reconstruct

python -m task.nq_test_Tseng

Export tokens

python -m task.nq_test_Tseng tokenize=tokenize

2. Dynamics Transformer (DT)

Relevant files

conf/dt/
  data/Tseng_train.yaml
  model/model.yaml
  nq/nq.yaml
  optim/adamw.yaml
  trainer/trainer.yaml
  train_dt_Tseng.yaml

dataset/
  dt_dataset.py

model/dynamics_transformer/
  dt_layers.py
  dt_utility.py
  dual_axis_transformer.py

task/
  DT_configs.py
  dt_train_Tseng.py

2.1 Main entry

task/dt_train_Tseng.py

2.2 Required defaults in conf/dt/train_dt_Tseng.yaml

defaults:
  - dt_train
  - data: Tseng_train
  - model: model
  - optim: adamw
  - nq@vq: nq
  - trainer@train: trainer
  - _self_
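
The group@package form in this defaults list explains the prefixes used on the command line: nq@vq: nq loads conf/dt/nq/nq.yaml but places it under the vq key, and trainer@train: trainer places conf/dt/trainer/trainer.yaml under train. That is why the overrides read vq.vq_state=... and train.mode=... rather than nq.vq_state or trainer.mode. The sketch below is a hypothetical parser that only handles the entry forms shown above, to make the mapping explicit.

```python
def parse_default_entry(entry):
    """Split one defaults-list entry into (group, package, option).

    Only the forms used in this README are handled: a bare string such as
    "_self_", or a one-item mapping such as {"nq@vq": "nq"}.
    """
    if isinstance(entry, str):
        return (None, None, entry)
    (key, option), = entry.items()
    group, _, package = key.partition("@")
    # Without "@", the config is placed under the group name itself.
    return (group, package or group, option)


if __name__ == "__main__":
    defaults = [
        "dt_train",
        {"data": "Tseng_train"},
        {"model": "model"},
        {"optim": "adamw"},
        {"nq@vq": "nq"},
        {"trainer@train": "trainer"},
        "_self_",
    ]
    for entry in defaults:
        group, package, option = parse_default_entry(entry)
        if group is not None:
            print(f"conf/dt/{group}/{option}.yaml -> cfg.{package}")
```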

2.3 Main configs to edit

conf/dt/data/Tseng_train.yaml

data_root: /path/to/held_in_tokens
heldout_root: /path/to/held_out_tokens
registry_json: /path/to/base_registry.json
registry_json_heldout: /path/to/heldout_registry.json

token_key: token
vocab: 128
batch_size: 8
eval_batch_size: 8
num_workers: 8

train_split: train
val_split: val
test_split: test

conf/dt/nq/nq.yaml

vq_state: /path/to/vq_checkpoint.pth

conf/dt/trainer/trainer.yaml

mode: train                    # train / eval_test / finetune_heldout
eval_target: base              # base / heldout / both
eval_ckpt: null

init_ckpt: null
heldout_init_ckpt: null
heldout_epochs: 160
heldout_lr: 2e-4
heldout_train_mode: embed_only # embed_only / full
heldout_init_use_ema_weights: false

epochs: 160
compile: false
compile_dynamic: false
eval_detail_every: 5

2.4 Held-in training

PYTHONPATH=. CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 task/dt_train_Tseng.py \
  train.mode=train \
  train.epochs=160 \
  train.compile=false \
  train.compile_dynamic=false \
  data.data_root=/path/to/held_in_tokens \
  data.registry_json=/path/to/base_registry.json \
  vq.vq_state=/path/to/vq_checkpoint.pth \
  runtime.device=cuda

2.5 Held-in evaluation

PYTHONPATH=. CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 task/dt_train_Tseng.py \
  train.mode=eval_test \
  train.eval_target=base \
  train.eval_ckpt=/path/to/base_ckpt.pth \
  train.compile=false \
  train.compile_dynamic=false \
  data.data_root=/path/to/held_in_tokens \
  data.registry_json=/path/to/base_registry.json \
  vq.vq_state=/path/to/vq_checkpoint.pth \
  runtime.device=cuda

2.6 Held-out fine-tuning

PYTHONPATH=. CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 task/dt_train_Tseng.py \
  train.mode=finetune_heldout \
  train.heldout_init_ckpt=/path/to/base_ckpt.pth \
  train.heldout_epochs=160 \
  train.heldout_lr=2e-4 \
  train.heldout_train_mode=embed_only \
  train.heldout_init_use_ema_weights=false \
  train.compile=false \
  train.compile_dynamic=false \
  data.heldout_root=/path/to/held_out_tokens \
  data.registry_json_heldout=/path/to/heldout_registry.json \
  vq.vq_state=/path/to/vq_checkpoint.pth \
  runtime.device=cuda

2.7 Held-out evaluation

Held-out only

PYTHONPATH=. CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 task/dt_train_Tseng.py \
  train.mode=eval_test \
  train.eval_target=heldout \
  train.eval_ckpt=/path/to/heldout_ft_ckpt.pth \
  train.compile=false \
  train.compile_dynamic=false \
  data.heldout_root=/path/to/held_out_tokens \
  data.registry_json_heldout=/path/to/heldout_registry.json \
  vq.vq_state=/path/to/vq_checkpoint.pth \
  runtime.device=cuda

Base + held-out together

PYTHONPATH=. CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 task/dt_train_Tseng.py \
  train.mode=eval_test \
  train.eval_target=both \
  train.eval_ckpt=/path/to/heldout_ft_ckpt.pth \
  train.compile=false \
  train.compile_dynamic=false \
  data.data_root=/path/to/held_in_tokens \
  data.heldout_root=/path/to/held_out_tokens \
  data.registry_json=/path/to/base_registry.json \
  data.registry_json_heldout=/path/to/heldout_registry.json \
  vq.vq_state=/path/to/vq_checkpoint.pth \
  runtime.device=cuda
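
The DT commands above differ only in their overrides, so they can be generated from a dictionary when scripting several runs. The sketch below is a hypothetical convenience wrapper, not part of the repository. It only assembles the argument list using the script path and override keys shown in this section, and renders Python None and booleans the way they are written in the commands above (null, true, false).

```python
import shlex


def build_dt_command(overrides, nproc_per_node=4, script="task/dt_train_Tseng.py"):
    """Return the torchrun argv list for one DT run."""
    argv = ["torchrun", f"--nproc_per_node={nproc_per_node}", script]
    for key, value in overrides.items():
        if value is None:
            rendered = "null"
        elif isinstance(value, bool):
            rendered = "true" if value else "false"
        else:
            rendered = str(value)
        argv.append(f"{key}={rendered}")
    return argv


if __name__ == "__main__":
    argv = build_dt_command({
        "train.mode": "eval_test",
        "train.eval_target": "heldout",
        "train.eval_ckpt": "/path/to/heldout_ft_ckpt.pth",
        "train.compile": False,
        "data.heldout_root": "/path/to/held_out_tokens",
    })
    # Printed for inspection; PYTHONPATH=. still has to be set when launching.
    print(shlex.join(argv))
```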

3. Behavior Decoding (FT)

Relevant files

conf/ft/
  backbone/ar_backbone.yaml
  cache/feature_cache.yaml
  data/tseng_behavior.yaml
  eval/eval.yaml
  head/lowrank_uv.yaml
  train/base.yaml
  train/heldout.yaml
  ft_behavior_tseng.yaml

dataset/
  ft_behavior_dataset.py
  session_registry.py

model/behavior_decoder/
  distributed.py
  feature_cache.py
  ft_eval.py
  ft_trainer.py
  heads.py

task/
  FT_configs.py
  ft_behavior_decode.py

3.1 Main entry

task/ft_behavior_decode.py

3.2 Required defaults in conf/ft/ft_behavior_tseng.yaml

defaults:
  - ft_behavior
  - data: tseng_behavior
  - backbone: ar_backbone
  - head: lowrank_uv
  - train@train_base: base
  - train@train_heldout: heldout
  - cache: feature_cache
  - eval: eval
  - _self_

3.3 Main configs to edit

conf/ft/data/tseng_behavior.yaml

registry_json: /path/to/global_registry.json
heldin_root: /path/to/held_in_root
heldout_root: /path/to/held_out_root

token_key: token
beh_key: behavior
vocab: 128
beh_channels: 3
beh_up: 2
batch_size: 8
num_workers: 8

conf/ft/backbone/ar_backbone.yaml

ar_ckpt: /path/to/backbone_ckpt.pth
prefer_ema_weights: true
compile: false
compile_dynamic: false

init_head_ckpt: null
init_head_registry_json: null

conf/ft/train/base.yaml

enabled: true
epochs: 200
lr: 0.0036
weight_decay: 1e-2
warmup_frac: 0.1
eval_every_epochs: 5
do_test_during_train: true

conf/ft/train/heldout.yaml

enabled: false
epochs: 100
lr: 0.0072
weight_decay: 1e-3
warmup_frac: 0.1
eval_every_epochs: 5
do_test_during_train: true
ft_mode: all    # all / newrows

conf/ft/eval/eval.yaml

eval_held_in_init: false
eval_held_out_init: false
max_eval_batches: -1
print_each_session: false
print_limit_sessions: -1

conf/ft/cache/feature_cache.yaml

enabled: true
build: true
use: true
force: false
dir: /path/to/cache_dir
dtype: fp16
max_batches: -1
allow_mismatch: false
allow_empty_eval: true

conf/ft/ft_behavior_tseng.yaml

remark: ft_behavior_tseng
mode: full_pipeline       # cache_only / train_base / finetune_heldout / eval_only / full_pipeline
run_name: heldout_ft

runtime:
  device: auto
  seed: 0
  tf32: true
  matmul_precision: high
  use_bf16: true
  ddp: true
  dist_backend: nccl
  dist_url: env://
  local_rank: 0

paths:
  output_root: /path/to/output_root
  save_dir: ${hydra:runtime.output_dir}

3.4 Held-in head training

PYTHONPATH=. CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 task/ft_behavior_decode.py \
  mode=train_base \
  run_name=base_train \
  data.registry_json=/path/to/global_registry.json \
  data.heldin_root=/path/to/held_in_root \
  data.heldout_root=null \
  backbone.ar_ckpt=/path/to/backbone_ckpt.pth \
  backbone.compile=false \
  train_base.enabled=true \
  train_heldout.enabled=false \
  eval.eval_held_in_init=true \
  eval.eval_held_out_init=false \
  cache.enabled=true \
  cache.build=true \
  cache.use=true \
  cache.dir=/path/to/base_cache \
  paths.save_dir=/path/to/base_save_dir

3.5 Held-in head evaluation

PYTHONPATH=. CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 task/ft_behavior_decode.py \
  mode=eval_only \
  run_name=base_eval \
  data.registry_json=/path/to/global_registry.json \
  data.heldin_root=/path/to/held_in_root \
  data.heldout_root=null \
  backbone.ar_ckpt=/path/to/backbone_ckpt.pth \
  backbone.init_head_ckpt=/path/to/best_base.pt \
  train_base.enabled=false \
  train_heldout.enabled=false \
  eval.eval_held_in_init=true \
  eval.eval_held_out_init=false \
  cache.enabled=false

3.6 Held-out fine-tuning

Edit these before running:

  • data.registry_json
  • data.heldin_root
  • data.heldout_root
  • backbone.ar_ckpt
  • backbone.init_head_ckpt=/path/to/best_base.pt
  • backbone.init_head_registry_json=/path/to/base_registry.json if remapping is needed
  • train_base.enabled=false
  • train_heldout.enabled=true

Run:

PYTHONPATH=. CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 task/ft_behavior_decode.py \
  mode=finetune_heldout \
  run_name=heldout_ft \
  data.registry_json=/path/to/global_registry.json \
  data.heldin_root=/path/to/held_in_root \
  data.heldout_root=/path/to/held_out_root \
  backbone.ar_ckpt=/path/to/backbone_ckpt.pth \
  backbone.compile=false \
  backbone.init_head_ckpt=/path/to/best_base.pt \
  backbone.init_head_registry_json=/path/to/base_registry.json \
  train_base.enabled=false \
  train_heldout.enabled=true \
  train_heldout.epochs=100 \
  train_heldout.lr=0.0072 \
  train_heldout.weight_decay=1e-3 \
  train_heldout.ft_mode=all \
  eval.eval_held_in_init=false \
  eval.eval_held_out_init=true \
  cache.enabled=true \
  cache.build=true \
  cache.use=true \
  cache.dir=/path/to/heldout_cache \
  paths.save_dir=/path/to/heldout_save_dir

To train only the new held-out rows, change the ft_mode override in the command above to:

train_heldout.ft_mode=newrows

3.7 Held-out evaluation

PYTHONPATH=. CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 task/ft_behavior_decode.py \
  mode=eval_only \
  run_name=heldout_eval \
  data.registry_json=/path/to/global_registry.json \
  data.heldin_root=/path/to/held_in_root \
  data.heldout_root=/path/to/held_out_root \
  backbone.ar_ckpt=/path/to/backbone_ckpt.pth \
  backbone.init_head_ckpt=/path/to/best_heldout.pt \
  train_base.enabled=false \
  train_heldout.enabled=false \
  eval.eval_held_in_init=false \
  eval.eval_held_out_init=true \
  cache.enabled=false

Data Curation

The preprocessing pipeline runs as follows:

  1. Load one session

    • read the neural activity matrix from the session file
    • optionally read aligned behavior channels
  2. Construct trials

    • split the continuous session into fixed-length non-overlapping windows
    • each window is treated as one trial
  3. Split into train / val / test

    • randomly split trials with a 70 / 15 / 15 ratio
  4. Preprocess neural activity

    • apply causal EMA smoothing
    • compute mean and standard deviation using training trials only
    • apply the same normalization statistics to validation and test trials
  5. Preprocess behavior

    • keep the aligned behavior channels with the same trial boundaries
    • if enabled, compute behavior z-score statistics from the training split only
    • apply the same behavior normalization to validation and test trials
  6. Save curated AR-format data

    • each session is saved as one .npz file containing:
      • train_padded_X, val_padded_X, test_padded_X
      • train_padded_Y, val_padded_Y, test_padded_Y
      • train_lengths, val_lengths, test_lengths
      • train_trial_ids, val_trial_ids, test_trial_ids
      • z_mean, z_std
      • optionally beh_z_mean, beh_z_std
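
Steps 2 to 4 above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the code in preprocess/preprocess.py: the function names, the EMA coefficient alpha, the per-trial application of the EMA, and the per-neuron statistics are all assumptions chosen to match the description.

```python
import numpy as np


def make_trials(activity, trial_len):
    """Cut (N, T_total) activity into non-overlapping (B, N, trial_len) windows."""
    n_neurons, total = activity.shape
    n_trials = total // trial_len  # leftover samples at the end are dropped
    trimmed = activity[:, : n_trials * trial_len]
    return trimmed.reshape(n_neurons, n_trials, trial_len).transpose(1, 0, 2)


def split_indices(n_trials, rng, ratios=(0.70, 0.15, 0.15)):
    """Randomly assign trial indices to train / val / test."""
    order = rng.permutation(n_trials)
    n_train = int(round(ratios[0] * n_trials))
    n_val = int(round(ratios[1] * n_trials))
    return order[:n_train], order[n_train:n_train + n_val], order[n_train + n_val:]


def causal_ema(trials, alpha):
    """Smooth along time using only past samples: y[t] = a*x[t] + (1-a)*y[t-1]."""
    out = np.empty_like(trials, dtype=np.float64)
    out[..., 0] = trials[..., 0]
    for t in range(1, trials.shape[-1]):
        out[..., t] = alpha * trials[..., t] + (1.0 - alpha) * out[..., t - 1]
    return out


def zscore_from_train(train, others, eps=1e-8):
    """Per-neuron z-score with statistics from training trials only."""
    mean = train.mean(axis=(0, 2), keepdims=True)    # (1, N, 1)
    std = train.std(axis=(0, 2), keepdims=True) + eps
    normalized_others = [(x - mean) / std for x in others]
    return (train - mean) / std, normalized_others, mean, std


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    activity = rng.normal(loc=3.0, scale=2.0, size=(5, 2030))  # synthetic session
    trials = causal_ema(make_trials(activity, trial_len=100), alpha=0.3)
    tr, va, te = split_indices(len(trials), rng)
    train, (val, test), z_mean, z_std = zscore_from_train(trials[tr], [trials[va], trials[te]])
    print(train.shape, val.shape, test.shape)
```

Computing z_mean and z_std from the training trials alone keeps validation and test statistics out of the normalization, which is the point of step 4.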

The final tensor layout is:

  • neural data: (B, N, T)
  • behavior data: (B, C, T)

where B is the number of trials, N is the number of neurons, C is the number of behavior channels, and T is the trial length.
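
A curated file can be sanity-checked against this layout before training. The sketch below reads the *_padded_X and *_lengths arrays named above and reports shapes plus the fraction of padded time steps. It assumes that each entry of *_lengths is the number of valid time steps in a trial and that padding sits at the end of the time axis. The README does not spell this out, so treat it as an assumption. The helper names are hypothetical and the demo uses a synthetic in-memory file rather than real data.

```python
import io

import numpy as np


def valid_mask(lengths, max_len):
    """Boolean (B, T) mask that is True on unpadded time steps."""
    return np.arange(max_len)[None, :] < np.asarray(lengths)[:, None]


def summarize_split(npz, split):
    """Return shapes and the padded fraction for one split of a curated file."""
    x = npz[f"{split}_padded_X"]         # neural data, (B, N, T)
    lengths = npz[f"{split}_lengths"]    # assumed valid length per trial, (B,)
    mask = valid_mask(lengths, x.shape[-1])
    return {
        "trials": x.shape[0],
        "neurons": x.shape[1],
        "max_len": x.shape[2],
        "padded_fraction": float(1.0 - mask.mean()),
    }


if __name__ == "__main__":
    # Synthetic stand-in for one curated session file (not real data).
    buffer = io.BytesIO()
    np.savez(
        buffer,
        train_padded_X=np.zeros((4, 3, 10)),
        train_lengths=np.array([10, 10, 5, 5]),
    )
    buffer.seek(0)
    with np.load(buffer) as npz:
        print(summarize_split(npz, "train"))
```

For a real file, replace the buffer with the path to one of the .npz files written by the preprocessing step.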

About

CalM (published at ICML26) is a self-supervised framework for neuronal calcium traces. It combines a shared discrete tokenizer with an autoregressive transformer, enabling forecasting and behavior decoding and providing interpretability across large, multi-animal datasets.
