This repository contains the full code for the CalM pipeline and for reproducing the main results of the paper.
Main stages:
- Preprocess raw session data
- Train / test the neural quantizer (NQ)
- Export tokens
- Train the dynamics transformer (DT)
- Fine-tune / evaluate DT on held-out sessions
- Train / evaluate the behavior decoding head (FT)
Install dependencies:
pip install -r requirements.txt
pip install "torch>=2.8.0" --index-url https://download.pytorch.org/whl/cu126
Tested environment:
- Python 3.11.13
- Hydra / OmegaConf
- 8 × NVIDIA A100 40GB GPUs
Most training / evaluation commands below use torchrun.
Main script:
preprocess/preprocess.py
Add raw sessions to generate curated .npz files for NQ.
python preprocess/preprocess.py
Usually edit only:
conf/nq/data/Tseng_trial.yaml
conf/nq/trainer/trainer.yaml
Outputs:
- NQ checkpoint
- reconstruction / tokenizer results
python -m task.nq_train_Tseng
python -m task.nq_test_Tseng
Usually edit only:
conf/nq/tokenize/tokenize.yaml
Outputs:
- tokenized held-in / held-out session data
python -m task.nq_test_Tseng tokenize=tokenize
Usually edit only:
conf/dt/data/Tseng_train.yaml
conf/dt/nq/nq.yaml
Outputs:
- held-in DT checkpoint
- held-in forecasting results
PYTHONPATH=. torchrun --nproc_per_node=8 task/dt_train_Tseng.py train.mode=train ...
Usually edit only:
conf/dt/data/Tseng_train.yaml
conf/dt/trainer/trainer.yaml
Outputs:
- held-out forecasting results
Fine-tuning:
PYTHONPATH=. torchrun --nproc_per_node=8 task/dt_train_Tseng.py train.mode=finetune_heldout ...
Evaluation:
PYTHONPATH=. torchrun --nproc_per_node=8 task/dt_train_Tseng.py train.mode=eval_test train.eval_target=heldout ...
Usually edit only:
conf/ft/data/tseng_behavior.yaml
conf/ft/backbone/ar_backbone.yaml
Outputs:
- held-in / held-out decoding results
Held-in training:
PYTHONPATH=. torchrun --nproc_per_node=8 task/ft_behavior_decode.py mode=train_base ...
Held-out fine-tuning:
PYTHONPATH=. torchrun --nproc_per_node=8 task/ft_behavior_decode.py mode=finetune_heldout ...
Evaluation:
PYTHONPATH=. torchrun --nproc_per_node=8 task/ft_behavior_decode.py mode=eval_only ...
For most users, only these fields need to be changed:
- dataset roots
- registry json paths
- checkpoint paths
- output / cache directories
You usually do not need to modify most model hyperparameters to reproduce the main pipeline.
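Because only a handful of path fields change between runs, it can be convenient to assemble the torchrun command from a small dictionary of overrides. The following is a minimal sketch, not part of the repository: `build_torchrun_command` is a hypothetical helper, and the paths are the same placeholders used elsewhere in this README.

```python
import shlex


def build_torchrun_command(script, overrides, nproc_per_node=8):
    """Assemble a torchrun command line from Hydra-style key=value overrides."""
    args = ["torchrun", f"--nproc_per_node={nproc_per_node}", script]
    # Each override becomes one `key=value` token, as in the commands in this README.
    args += [f"{key}={value}" for key, value in overrides.items()]
    # shlex.join quotes any token that contains shell-special characters.
    return "PYTHONPATH=. " + shlex.join(args)


# Placeholder paths: replace with your own dataset, registry, and checkpoint locations.
dt_overrides = {
    "train.mode": "train",
    "data.data_root": "/path/to/held_in_tokens",
    "data.registry_json": "/path/to/base_registry.json",
    "vq.vq_state": "/path/to/vq_checkpoint.pth",
}
command = build_torchrun_command("task/dt_train_Tseng.py", dt_overrides, nproc_per_node=4)
print(command)
```

Keeping the overrides in one dictionary makes it easier to see which fields differ between the held-in, fine-tuning, and evaluation runs.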
- NQ reconstruction / tokenizer:
  - task/nq_train_Tseng.py
  - task/nq_test_Tseng.py
- DT forecasting:
  - task/dt_train_Tseng.py
  - use train.mode=train, finetune_heldout, eval_test
- Behavior decoding:
  - task/ft_behavior_decode.py
  - use mode=train_base, mode=finetune_heldout, mode=eval_only
- Data preprocessing:
  - preprocess/preprocess.py
- Plotting / result scripts:
  - MATLAB script in plot
The project structure:
- nq/: VQ tokenizer
- dt/: Dynamics Transformer backbone
- ft/: behavior decoding downstream task
conf/nq/
  data/
    Tseng_trial.yaml
    Tseng_trial_small.yaml
  loss/
    loss.yaml
  model/
    model.yaml
  optim/
    adamw.yaml
  tokenize/
    tokenize.yaml
  trainer/
    trainer.yaml
  vq_train_Tseng.yaml
  vq_test_Tseng.yaml
dataset/
  nq_dataset.py
model/neural_quantizer/
  VQ_quantizer.py
  nq_layers.py
  nq_utility.py
task/
  nq_train_Tseng.py
  nq_test_Tseng.py
  nq_vq_configs_train_Tseng.py
  nq_vq_configs_test_Tseng.py
task/nq_train_Tseng.py
task/nq_test_Tseng.py
# example dataset config
# edit dataset root / split-specific settings here
data_root: /path/to/tseng_data
# example VQ model config
n_emb: 128
dim_emb: 512
heads: 4
enc_layers: 4
dec_layers: 4
use_gumbel: true
use_gumbel_hard: true
# example loss config
w_emb: 1.0
w_commit: 0.5
# example optimizer config
lr: 5e-4
weight_decay: 1e-3
# example trainer config
epochs: 100
compile: false
compile_dynamic: false
# example tokenize / inference config
ckpt_path: /path/to/vq_checkpoint.pth
out_root: /path/to/token_output
resume: true
chunk: 0
python -m task.nq_train_Tseng
Example with overrides:
python -m task.nq_train_Tseng \
  data=Tseng_trial_small \
  trainer.epochs=100 \
  optim.lr=5e-4
python -m task.nq_train_Tseng --cfg job --resolve
Use this to verify:
- the correct YAMLs are loaded
- CLI overrides are applied
- output paths are correct
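To make the "CLI overrides are applied" check concrete, here is a simplified stand-in for how a dotted override such as `optim.lr=1e-4` lands in a nested config. It is not Hydra's or OmegaConf's implementation: `parse_value` and `apply_overrides` are hypothetical helpers, and config-group selections such as `data=Tseng_trial_small` are deliberately left out.

```python
import copy


def parse_value(text):
    """Interpret an override value as int, then float, else keep it as a string."""
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def apply_overrides(config, overrides):
    """Return a copy of `config` with dotted `key=value` overrides applied."""
    result = copy.deepcopy(config)
    for item in overrides:
        dotted_key, raw_value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            # Walk (or create) the nested dictionaries named by the dotted key.
            node = node.setdefault(name, {})
        node[leaf] = parse_value(raw_value)
    return result


# Values taken from the example configs in this README.
base = {"trainer": {"epochs": 100}, "optim": {"lr": 5e-4, "weight_decay": 1e-3}}
resolved = apply_overrides(base, ["trainer.epochs=50", "optim.lr=1e-4"])
print(resolved)
```

Comparing `resolved` against `base` is the same check that the `--cfg job --resolve` output supports: every overridden key should show the new value, and everything else should be unchanged.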
Set the checkpoint path in conf/nq/vq_test_Tseng.yaml, then run:
python -m task.nq_test_Tseng
Or override from CLI:
python -m task.nq_test_Tseng paths.ckpt_path=/path/to/checkpoint.pth
Use the tokenize-related settings in:
conf/nq/tokenize/tokenize.yaml
Then run:
python -m task.nq_test_Tseng tokenize=tokenize
Disable Gumbel during inference / tokenization:
model:
  use_gumbel: false
  use_gumbel_hard: false
Do this in the test / tokenize config, not in the training config.
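A likely reason for this setting is that Gumbel noise makes code selection stochastic, so the same input could map to different tokens across runs. The sketch below illustrates that general effect with the standard Gumbel-max trick. It assumes nothing about the repository's quantizer internals: `gumbel_noise`, `pick_code`, and the logits are illustrative only.

```python
import math
import random


def gumbel_noise(rng):
    """Draw one standard Gumbel sample, -log(-log(u)) with u uniform in (0, 1)."""
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep both logs finite
    return -math.log(-math.log(u))


def pick_code(logits, use_gumbel, rng):
    """Return the index of the selected codebook entry."""
    if use_gumbel:
        # Gumbel-max: argmax of (logits + noise) samples from softmax(logits).
        scores = [x + gumbel_noise(rng) for x in logits]
    else:
        # Plain argmax: the same logits always give the same code.
        scores = logits
    return max(range(len(scores)), key=scores.__getitem__)


rng = random.Random(0)
logits = [0.1, 0.9, 0.5]  # hypothetical scores for three codebook entries
noisy = {pick_code(logits, True, rng) for _ in range(200)}
clean = {pick_code(logits, False, rng) for _ in range(200)}
print("distinct codes with Gumbel:", len(noisy))
print("distinct codes without Gumbel:", len(clean))
```

With the noise switched off, repeated tokenization of the same session yields identical tokens, which is what the downstream DT and FT stages need.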
python -m task.nq_train_Tseng
python -m task.nq_train_Tseng data=Tseng_trial_small
python -m task.nq_train_Tseng --cfg job --resolve
python -m task.nq_test_Tseng
python -m task.nq_test_Tseng tokenize=tokenize
conf/dt/
  data/Tseng_train.yaml
  model/model.yaml
  nq/nq.yaml
  optim/adamw.yaml
  trainer/trainer.yaml
  train_dt_Tseng.yaml
dataset/
  dt_dataset.py
model/dynamics_transformer/
  dt_layers.py
  dt_utility.py
  dual_axis_transformer.py
task/
  DT_configs.py
  dt_train_Tseng.py
task/dt_train_Tseng.py
defaults:
  - dt_train
  - data: Tseng_train
  - model: model
  - optim: adamw
  - nq@vq: nq
  - trainer@train: trainer
  - _self_
# example data config
data_root: /path/to/held_in_tokens
heldout_root: /path/to/held_out_tokens
registry_json: /path/to/base_registry.json
registry_json_heldout: /path/to/heldout_registry.json
token_key: token
vocab: 128
batch_size: 8
eval_batch_size: 8
num_workers: 8
train_split: train
val_split: val
test_split: test
# example nq config
vq_state: /path/to/vq_checkpoint.pth
# example trainer config
mode: train # train / eval_test / finetune_heldout
eval_target: base # base / heldout / both
eval_ckpt: null
init_ckpt: null
heldout_init_ckpt: null
heldout_epochs: 160
heldout_lr: 2e-4
heldout_train_mode: embed_only # embed_only / full
heldout_init_use_ema_weights: false
epochs: 160
compile: false
compile_dynamic: false
eval_detail_every: 5
PYTHONPATH=. CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 task/dt_train_Tseng.py \
train.mode=train \
train.epochs=160 \
train.compile=false \
train.compile_dynamic=false \
data.data_root=/path/to/held_in_tokens \
data.registry_json=/path/to/base_registry.json \
vq.vq_state=/path/to/vq_checkpoint.pth \
runtime.device=cuda
PYTHONPATH=. CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 task/dt_train_Tseng.py \
train.mode=eval_test \
train.eval_target=base \
train.eval_ckpt=/path/to/base_ckpt.pth \
train.compile=false \
train.compile_dynamic=false \
data.data_root=/path/to/held_in_tokens \
data.registry_json=/path/to/base_registry.json \
vq.vq_state=/path/to/vq_checkpoint.pth \
runtime.device=cuda
PYTHONPATH=. CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 task/dt_train_Tseng.py \
train.mode=finetune_heldout \
train.heldout_init_ckpt=/path/to/base_ckpt.pth \
train.heldout_epochs=160 \
train.heldout_lr=2e-4 \
train.heldout_train_mode=embed_only \
train.heldout_init_use_ema_weights=false \
train.compile=false \
train.compile_dynamic=false \
data.heldout_root=/path/to/held_out_tokens \
data.registry_json_heldout=/path/to/heldout_registry.json \
vq.vq_state=/path/to/vq_checkpoint.pth \
runtime.device=cuda
PYTHONPATH=. CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 task/dt_train_Tseng.py \
train.mode=eval_test \
train.eval_target=heldout \
train.eval_ckpt=/path/to/heldout_ft_ckpt.pth \
train.compile=false \
train.compile_dynamic=false \
data.heldout_root=/path/to/held_out_tokens \
data.registry_json_heldout=/path/to/heldout_registry.json \
vq.vq_state=/path/to/vq_checkpoint.pth \
runtime.device=cuda
PYTHONPATH=. CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 task/dt_train_Tseng.py \
train.mode=eval_test \
train.eval_target=both \
train.eval_ckpt=/path/to/heldout_ft_ckpt.pth \
train.compile=false \
train.compile_dynamic=false \
data.data_root=/path/to/held_in_tokens \
data.heldout_root=/path/to/held_out_tokens \
data.registry_json=/path/to/base_registry.json \
data.registry_json_heldout=/path/to/heldout_registry.json \
vq.vq_state=/path/to/vq_checkpoint.pth \
runtime.device=cuda
conf/ft/
  backbone/ar_backbone.yaml
  cache/feature_cache.yaml
  data/tseng_behavior.yaml
  eval/eval.yaml
  head/lowrank_uv.yaml
  train/base.yaml
  train/heldout.yaml
  ft_behavior_tseng.yaml
dataset/
  ft_behavior_dataset.py
  session_registry.py
model/behavior_decoder/
  distributed.py
  feature_cache.py
  ft_eval.py
  ft_trainer.py
  heads.py
task/
  FT_configs.py
  ft_behavior_decode.py
task/ft_behavior_decode.py
defaults:
  - ft_behavior
  - data: tseng_behavior
  - backbone: ar_backbone
  - head: lowrank_uv
  - train@train_base: base
  - train@train_heldout: heldout
  - cache: feature_cache
  - eval: eval
  - _self_
# example data config
registry_json: /path/to/global_registry.json
heldin_root: /path/to/held_in_root
heldout_root: /path/to/held_out_root
token_key: token
beh_key: behavior
vocab: 128
beh_channels: 3
beh_up: 2
batch_size: 8
num_workers: 8
# example backbone config
ar_ckpt: /path/to/backbone_ckpt.pth
prefer_ema_weights: true
compile: false
compile_dynamic: false
init_head_ckpt: null
init_head_registry_json: null
# example base training config (train_base)
enabled: true
epochs: 200
lr: 0.0036
weight_decay: 1e-2
warmup_frac: 0.1
eval_every_epochs: 5
do_test_during_train: true
# example held-out fine-tuning config (train_heldout)
enabled: false
epochs: 100
lr: 0.0072
weight_decay: 1e-3
warmup_frac: 0.1
eval_every_epochs: 5
do_test_during_train: true
ft_mode: all # all / newrows
# example eval config
eval_held_in_init: false
eval_held_out_init: false
max_eval_batches: -1
print_each_session: false
print_limit_sessions: -1
# example feature cache config
enabled: true
build: true
use: true
force: false
dir: /path/to/cache_dir
dtype: fp16
max_batches: -1
allow_mismatch: false
allow_empty_eval: true
# example top-level config
remark: ft_behavior_tseng
mode: full_pipeline # cache_only / train_base / finetune_heldout / eval_only / full_pipeline
run_name: heldout_ft
runtime:
  device: auto
  seed: 0
  tf32: true
  matmul_precision: high
  use_bf16: true
  ddp: true
  dist_backend: nccl
  dist_url: env://
  local_rank: 0
paths:
  output_root: /path/to/output_root
  save_dir: ${hydra:runtime.output_dir}
PYTHONPATH=. CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 task/ft_behavior_decode.py \
mode=train_base \
run_name=base_train \
data.registry_json=/path/to/global_registry.json \
data.heldin_root=/path/to/held_in_root \
data.heldout_root=null \
backbone.ar_ckpt=/path/to/backbone_ckpt.pth \
backbone.compile=false \
train_base.enabled=true \
train_heldout.enabled=false \
eval.eval_held_in_init=true \
eval.eval_held_out_init=false \
cache.enabled=true \
cache.build=true \
cache.use=true \
cache.dir=/path/to/base_cache \
paths.save_dir=/path/to/base_save_dir
PYTHONPATH=. CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 task/ft_behavior_decode.py \
mode=eval_only \
run_name=base_eval \
data.registry_json=/path/to/global_registry.json \
data.heldin_root=/path/to/held_in_root \
data.heldout_root=null \
backbone.ar_ckpt=/path/to/backbone_ckpt.pth \
backbone.init_head_ckpt=/path/to/best_base.pt \
train_base.enabled=false \
train_heldout.enabled=false \
eval.eval_held_in_init=true \
eval.eval_held_out_init=false \
cache.enabled=false
Edit these before running:
- data.registry_json
- data.heldin_root
- data.heldout_root
- backbone.ar_ckpt
- backbone.init_head_ckpt=/path/to/best_base.pt
- backbone.init_head_registry_json=/path/to/base_registry.json if remapping is needed
- train_base.enabled=false
- train_heldout.enabled=true
Run:
PYTHONPATH=. CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 task/ft_behavior_decode.py \
mode=finetune_heldout \
run_name=heldout_ft \
data.registry_json=/path/to/global_registry.json \
data.heldin_root=/path/to/held_in_root \
data.heldout_root=/path/to/held_out_root \
backbone.ar_ckpt=/path/to/backbone_ckpt.pth \
backbone.compile=false \
backbone.init_head_ckpt=/path/to/best_base.pt \
backbone.init_head_registry_json=/path/to/base_registry.json \
train_base.enabled=false \
train_heldout.enabled=true \
train_heldout.epochs=100 \
train_heldout.lr=0.0072 \
train_heldout.weight_decay=1e-3 \
train_heldout.ft_mode=all \
eval.eval_held_in_init=false \
eval.eval_held_out_init=true \
cache.enabled=true \
cache.build=true \
cache.use=true \
cache.dir=/path/to/heldout_cache \
paths.save_dir=/path/to/heldout_save_dir
train_heldout.ft_mode=newrows
PYTHONPATH=. CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 task/ft_behavior_decode.py \
mode=eval_only \
run_name=heldout_eval \
data.registry_json=/path/to/global_registry.json \
data.heldin_root=/path/to/held_in_root \
data.heldout_root=/path/to/held_out_root \
backbone.ar_ckpt=/path/to/backbone_ckpt.pth \
backbone.init_head_ckpt=/path/to/best_heldout.pt \
train_base.enabled=false \
train_heldout.enabled=false \
eval.eval_held_in_init=false \
eval.eval_held_out_init=true \
cache.enabled=false
The preprocessing pipeline proceeds as follows:
- Load one session
  - read the neural activity matrix from the session file
  - optionally read aligned behavior channels
- Construct trials
  - split the continuous session into fixed-length non-overlapping windows
  - each window is treated as one trial
- Split into train / val / test
  - randomly split trials with a 70 / 15 / 15 ratio
- Preprocess neural activity
  - apply causal EMA smoothing
  - compute mean and standard deviation using training trials only
  - apply the same normalization statistics to validation and test trials
- Preprocess behavior
  - keep the aligned behavior channels with the same trial boundaries
  - if enabled, compute behavior z-score statistics from the training split only
  - apply the same behavior normalization to validation and test trials
- Save curated AR-format data
  - each session is saved as one .npz file containing:
    - train_padded_X, val_padded_X, test_padded_X
    - train_padded_Y, val_padded_Y, test_padded_Y
    - train_lengths, val_lengths, test_lengths
    - train_trial_ids, val_trial_ids, test_trial_ids
    - z_mean, z_std
    - optionally beh_z_mean, beh_z_std
The final tensor layout is:
- neural data: (B, N, T)
- behavior data: (B, C, T)
where B is the number of trials, N is the number of neurons, C is the number of behavior channels, and T is the trial length.
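The neural-activity steps of the pipeline can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the code in preprocess/preprocess.py: all function names are hypothetical, the EMA coefficient `alpha=0.3` and the trial length of 10 are arbitrary, smoothing is assumed to run per trial, and nested lists stand in for the (B, N, T) arrays that the real pipeline stores in .npz files.

```python
import random
from statistics import fmean, pstdev


def make_trials(session, trial_len):
    """Cut a session (N rows of samples) into non-overlapping trials shaped (B, N, T)."""
    n_trials = len(session[0]) // trial_len  # the incomplete tail window is dropped
    return [[row[b * trial_len:(b + 1) * trial_len] for row in session]
            for b in range(n_trials)]


def split_trial_ids(n_trials, seed=0, ratios=(0.70, 0.15)):
    """Shuffle trial ids and split them 70 / 15 / 15 into train / val / test."""
    ids = list(range(n_trials))
    random.Random(seed).shuffle(ids)
    n_train = round(ratios[0] * n_trials)
    n_val = round(ratios[1] * n_trials)
    return ids[:n_train], ids[n_train:n_train + n_val], ids[n_train + n_val:]


def causal_ema(signal, alpha):
    """y[t] = alpha * x[t] + (1 - alpha) * y[t-1]; only past samples are used."""
    out = [float(signal[0])]
    for x in signal[1:]:
        out.append(alpha * x + (1 - alpha) * out[-1])
    return out


def train_stats(trials, train_ids):
    """Per-neuron mean and standard deviation from training trials only."""
    means, stds = [], []
    for n in range(len(trials[0])):
        values = [v for b in train_ids for v in trials[b][n]]
        means.append(fmean(values))
        stds.append(pstdev(values) or 1.0)  # avoid dividing by zero on flat channels
    return means, stds


def normalize(trial, means, stds):
    """Z-score one (N, T) trial with the given per-neuron statistics."""
    return [[(v - m) / s for v in row] for row, m, s in zip(trial, means, stds)]


# Synthetic session: N = 3 neurons, 200 samples each.
rng = random.Random(0)
session = [[rng.gauss(0.0, 1.0) for _ in range(200)] for _ in range(3)]
trials = make_trials(session, trial_len=10)
trials = [[causal_ema(row, alpha=0.3) for row in trial] for trial in trials]
train_ids, val_ids, test_ids = split_trial_ids(len(trials))
z_mean, z_std = train_stats(trials, train_ids)
# The same training statistics are applied to every split.
normalized = [normalize(trial, z_mean, z_std) for trial in trials]
print(len(normalized), len(normalized[0]), len(normalized[0][0]))
```

Computing `z_mean` and `z_std` from the training trials alone keeps validation and test data out of the normalization statistics, which is the point of the "training trials only" rule in the list above.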