TianrongChen/DMSB

(NeurIPS2023) DMSB: Deep Momentum Multi-Marginal Schrödinger Bridge [LINK]

Official PyTorch implementation of the paper "Deep Momentum Multi-Marginal Schrödinger Bridge (DMSB)", which introduces a new class of trajectory inference models that extend SB models to momentum dynamics and the multi-marginal case.

Connection with Vanilla Schrödinger Bridge

Example GIF

Toy Examples

Tasks (--problem-name): Results
Mixture Gaussians (gmm): [figure]
Semicircle (semicircle): [figure]
Petal (Petal): [figure]
100-Dim Single Cell RNA sequence (RNAsc): [figure]

If you find this library useful, please cite ⬇️
@article{chen2023deep,
  title={Deep Momentum Multi-Marginal Schr{\"o}dinger Bridge},
  author={Chen, Tianrong and Liu, Guan-Horng and Tao, Molei and Theodorou, Evangelos A},
  journal={arXiv preprint arXiv:2303.01751},
  year={2023}
}

Installation

(The environment may conflict with some CUDA versions. I am currently fixing this, but it should work with most CUDA versions.) This code is developed with Python 3 and PyTorch >= 1.7 (we recommend 1.8.1). First, install the dependencies with Anaconda and activate the DMSB environment with

conda env create --file requirements.yaml python=3.8
conda activate DMSB
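
To confirm that the activated environment meets the stated PyTorch minimum (>= 1.7), a small standard-library sketch may help. The helper name `meets_minimum` is hypothetical and not part of this repository; it only compares version strings, so it runs even where PyTorch is not installed.

```python
import re


def meets_minimum(version: str, minimum=(1, 7)) -> bool:
    """Return True if a version string such as '1.8.1+cu111' is >= minimum.

    Hypothetical helper for illustration; not part of the DMSB code base.
    """
    # Read only the leading 'major.minor' and ignore suffixes such as '+cu111'.
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= tuple(minimum)
```

Inside the DMSB environment, calling `meets_minimum(torch.__version__)` after `import torch` should return True if the installation matches the requirement above.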

Download the RNA-seq dataset from this repo and put it under ./data/RNAsc/ProcessedData/.
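
Before launching an RNAsc run, it can be useful to check that the dataset directory is in place. The sketch below assumes the path is relative to the repository root; the function name `rnasc_data_ready` is hypothetical, and because the file names inside the directory are not listed here, it only checks that the directory exists and is not empty.

```python
from pathlib import Path


def rnasc_data_ready(repo_root=".") -> bool:
    """Check that data/RNAsc/ProcessedData exists under repo_root and is non-empty.

    Hypothetical helper for illustration; not part of the DMSB code base.
    """
    data_dir = Path(repo_root) / "data" / "RNAsc" / "ProcessedData"
    # The expected file names are unknown here, so any entry counts as present.
    return data_dir.is_dir() and any(data_dir.iterdir())
```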

Reproducing the results in the paper

We provide checkpoints and the code for training from scratch for all datasets reported in the paper.

GMM

python main.py --problem-name gmm --dir reproduce/gmm --log-tb --gpu 1

Memo: The results in the paper should be reproduced after around 6 stages of Bregman iteration.

Petal

python main.py --problem-name petal --dir reproduce/petal --log-tb

Memo: The results in the paper should be reproduced after around 17 stages of Bregman iteration.

RNAsc

python main.py --problem-name RNAsc --dir reproduce/RNA --log-tb  --num-itr 2000
python main.py --problem-name RNAsc --dir reproduce/RNA-loo1 --log-tb  --use-amp --num-itr 2000 --LOO 1
python main.py --problem-name RNAsc --dir reproduce/RNA-loo2 --log-tb  --use-amp --num-itr 2000 --LOO 2
python main.py --problem-name RNAsc --dir reproduce/RNA-loo3 --log-tb  --use-amp --num-itr 2000 --LOO 3
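
The four RNAsc commands above differ only in the output directory and the --LOO index (presumably the held-out marginal in a leave-one-out evaluation), so they can be generated in a loop. The following is a sketch that uses only the flags shown above; the function name `rnasc_commands` is hypothetical, and the script prints the commands rather than running them.

```python
import shlex


def rnasc_commands(num_itr=2000, loo_indices=(1, 2, 3)):
    """Build the RNAsc command lines listed above as argument lists."""
    base = ["python", "main.py", "--problem-name", "RNAsc"]
    # First run: no --LOO flag, exactly as in the first command above.
    cmds = [base + ["--dir", "reproduce/RNA", "--log-tb", "--num-itr", str(num_itr)]]
    # One run per --LOO index; these runs also pass --use-amp, as above.
    for k in loo_indices:
        cmds.append(
            base
            + ["--dir", f"reproduce/RNA-loo{k}", "--log-tb", "--use-amp"]
            + ["--num-itr", str(num_itr), "--LOO", str(k)]
        )
    return cmds


if __name__ == "__main__":
    for cmd in rnasc_commands():
        # Print only; pass cmd to subprocess.run(cmd, check=True) to execute.
        print(shlex.join(cmd))
```

Keeping each command as an argument list avoids shell quoting issues if a command is later handed to `subprocess.run`.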

Where can I find the results?

The visualization results are saved in the folder /results.

The numerical values are logged to TensorBoard, and the event files are saved in the folder /runs.

The checkpoints are saved in the folder /checkpoint, and you can reload the checkpoint by:

python main.py --problem-name [problem-name] --dir [your/dir/name/for/current/run] --log-tb  --load [dir/to/checkpoints/]

The numerical results for all metrics will be displayed in the terminal as well.
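
To locate the TensorBoard event files from a script, a minimal sketch follows. It assumes the runs folder sits under the directory you launched training from, and it relies on TensorBoard's standard file name prefix `events.out.tfevents`; the function name `find_event_files` is hypothetical.

```python
from pathlib import Path


def find_event_files(log_root="runs"):
    """Return TensorBoard event files found under log_root, sorted by path."""
    root = Path(log_root)
    if not root.is_dir():
        return []
    # TensorBoard writers name their files 'events.out.tfevents.<suffix>'.
    return sorted(p for p in root.rglob("events.out.tfevents.*") if p.is_file())
```

The directory that contains the returned files can then be passed to `tensorboard --logdir`.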
