
LiuhanChen-github/VDiS


Scalable Diffusion Models with State Space Backbone (DiS)
Official PyTorch Implementation

arXiv

This repo contains PyTorch model definitions, pre-trained weights, and training/sampling code for our paper exploring diffusion models with state space backbones (DiSs). Our model treats all inputs, including the time, the condition, and the noisy image patches, as tokens and employs skip connections between shallow and deep layers. Unlike the original Mamba, which models text sequences, our SSM block processes the hidden-state sequence in both the forward and backward directions.
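The toy sketch below illustrates why scanning in both directions matters. It replaces Mamba's selective, input-dependent SSM with a fixed scalar recurrence h_t = a*h_{t-1} + b*x_t. A forward-only scan lets each token see only earlier tokens. Adding a backward scan over the reversed sequence lets every token see the whole sequence. Summing the two directions and giving each direction its own parameters are assumptions made for illustration, not a description of the actual DiS block, and the function names are hypothetical.

```python
from typing import List, Sequence, Tuple


def linear_scan(xs: Sequence[float], a: float, b: float) -> List[float]:
    """Toy scalar recurrence h_t = a * h_{t-1} + b * x_t, starting from h = 0."""
    h = 0.0
    out: List[float] = []
    for x in xs:
        h = a * h + b * x
        out.append(h)
    return out


def bidirectional_scan(
    xs: Sequence[float],
    fwd: Tuple[float, float],
    bwd: Tuple[float, float],
) -> List[float]:
    """Scan left-to-right and right-to-left with separate (a, b) parameters, then sum.

    The backward result is flipped back so that index t in both outputs
    refers to the same token before the two are combined.
    """
    forward = linear_scan(xs, *fwd)
    backward = linear_scan(list(xs)[::-1], *bwd)[::-1]
    return [f + r for f, r in zip(forward, backward)]


if __name__ == "__main__":
    tokens = [0.0, 0.0, 0.0, 1.0]  # only the last token carries a signal
    print(linear_scan(tokens, 0.5, 1.0))  # [0.0, 0.0, 0.0, 1.0]: earlier tokens never see it
    print(bidirectional_scan(tokens, (0.5, 1.0), (0.5, 1.0)))  # [0.125, 0.25, 0.5, 2.0]
```

With the backward pass, the first token picks up a decayed copy of the last token's signal, which a causal, forward-only scan cannot provide.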

[Figure: DiS framework]

1. Environments

  • Python 3.10

    • conda create -n your_env_name python=3.10
  • Requirements file

    • pip install -r requirements.txt
  • Install causal_conv1d and mamba

    • pip install -e causal_conv1d
    • pip install -e mamba
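After installation, a quick way to confirm that the packages are visible to Python is to look them up without importing them. The module names are assumptions: the mamba package is normally imported as `mamba_ssm` and causal_conv1d as `causal_conv1d`, so adjust them if the local copies in this repo differ. `check_modules` is a hypothetical helper, not part of the repo.

```python
import importlib.util
from typing import Dict, Iterable

# Assumed import names; edit this tuple if your local packages are named differently.
DEFAULT_MODULES = ("torch", "causal_conv1d", "mamba_ssm")


def check_modules(names: Iterable[str] = DEFAULT_MODULES) -> Dict[str, bool]:
    """Report which top-level modules can be found, without importing them."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, ok in check_modules().items():
        print(f"{name:15s} {'found' if ok else 'MISSING'}")
```

Finding a module this way does not prove its CUDA kernels were compiled correctly; it only catches the common case of a package missing from the active environment.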

2. Training

We provide a training script for DiS in train.py. This script can be used to train unconditional and class-conditional DiS models, and it can easily be modified to support other types of conditioning.

To launch DiS-H/2 (512x512) training in the latent space with N GPUs on one node (set --nproc_per_node to N; the example below uses 8):

torchrun --nnodes=1 --nproc_per_node=8 train.py \
--model DiS-H/2 \
--dataset-type imagenet \
--data-path imageNet1k \
--image-size 64 \
--task-type class-cond \
--num-classes 1000
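The command above passes --image-size 64 for 512x512 training. One reading, assumed here and not confirmed by this README, is that the value is the latent resolution produced by an autoencoder that downsamples by 8 (as the Stable Diffusion VAE used by DiT does), and that the /2 in DiS-H/2 is the patch size, following DiT's naming. Under those assumptions, the per-image sequence length works out as follows. `latent_token_count` and its defaults are hypothetical.

```python
from typing import Dict


def latent_token_count(
    image_size: int,
    patch_size: int,
    vae_downsample: int = 8,  # assumption: SD-style VAE, as used by DiT
    extra_tokens: int = 2,    # assumption: one time token + one class token
) -> Dict[str, int]:
    """Compute latent size, patch-token count and total sequence length."""
    if image_size % vae_downsample:
        raise ValueError("image_size must be divisible by vae_downsample")
    latent_size = image_size // vae_downsample
    if latent_size % patch_size:
        raise ValueError("latent size must be divisible by patch_size")
    per_side = latent_size // patch_size
    num_patches = per_side * per_side
    return {
        "latent_size": latent_size,
        "num_patches": num_patches,
        "sequence_length": num_patches + extra_tokens,
    }


if __name__ == "__main__":
    print(latent_token_count(512, patch_size=2))
    # {'latent_size': 64, 'num_patches': 1024, 'sequence_length': 1026}
```

Because an SSM scan is linear in sequence length, the 1024-token sequence for 512x512 costs roughly four times the 256-token sequence for 256x256, rather than the sixteen times that quadratic self-attention would imply.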

To launch DiS-S/2 (32x32) training in the pixel space with N GPUs on one node (again, set --nproc_per_node to N):

torchrun --nnodes=1 --nproc_per_node=8 train.py \
--model DiS-S/2 \
--data-path cifar10_data \
--dataset-type cifar-10 \
--image-size 32 \
--task-type uncond 
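In pixel space there is no autoencoder, so the 32x32 CIFAR-10 image itself is cut into patches. The sketch below uses a hypothetical `patchify` helper and assumes that the /2 in DiS-S/2 is the patch size, following DiT's convention. Under that assumption, a 3x32x32 image yields 16x16 = 256 tokens of 3*2*2 = 12 values each. The channel-major flattening order is an arbitrary choice. DiT-style code usually performs this step with a strided convolution, which amounts to a linear projection of such flattened patches.

```python
from typing import List

Image = List[List[List[float]]]  # indexed as [channel][row][column]


def patchify(image: Image, patch_size: int) -> List[List[float]]:
    """Split a C x H x W image into non-overlapping patch tokens in row-major order."""
    channels = len(image)
    height, width = len(image[0]), len(image[0][0])
    if height % patch_size or width % patch_size:
        raise ValueError("image height and width must be divisible by patch_size")
    tokens: List[List[float]] = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            # Flatten channel by channel, then row by row inside the patch.
            tokens.append([
                image[c][top + dy][left + dx]
                for c in range(channels)
                for dy in range(patch_size)
                for dx in range(patch_size)
            ])
    return tokens


if __name__ == "__main__":
    img = [[[0.0] * 32 for _ in range(32)] for _ in range(3)]  # dummy CIFAR-10-sized image
    toks = patchify(img, patch_size=2)
    print(len(toks), len(toks[0]))  # 256 12
```

With --task-type uncond there is no class label, so under the layout assumed earlier only a time token would be added to these 256 patch tokens.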

There are several additional options; see train.py for details. The training scripts for all experiments in our work can be found in the script directory.

For convenience, the pre-trained DiS models can be downloaded directly here as well:

| DiS Model | Image Resolution | FID-50K |
|-----------|------------------|---------|
| DiS-H/2   | 256x256          | 2.10    |
| DiS-H/2   | 512x512          | 2.88    |

3. Evaluation

We include a sample.py script that samples images from a DiS model. We also support evaluation with other metrics in the test.py script.

python sample.py \
--model DiS-L/2 \
--dataset-type imagenet \
--ckpt /path/to/model \
--image-size 256 \
--num-classes 1000 \
--cfg-scale 1.5
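--cfg-scale sets the classifier-free guidance scale s. Assuming sample.py follows the usual formulation, as in DiT, which this codebase builds on, the model predicts noise twice per step: once with the class label and once with a null label. The two predictions are combined as eps = eps_uncond + s * (eps_cond - eps_uncond), so s = 1 recovers the purely conditional prediction and s > 1 pushes samples further toward the class. In DiT the null label is an extra class index equal to num_classes (1000 for ImageNet). Whether DiS does exactly the same is an assumption. The scalar-list sketch below uses a hypothetical function name.

```python
from typing import List, Sequence


def classifier_free_guidance(
    eps_cond: Sequence[float],
    eps_uncond: Sequence[float],
    cfg_scale: float,
) -> List[float]:
    """Combine conditional and unconditional noise predictions element-wise."""
    if len(eps_cond) != len(eps_uncond):
        raise ValueError("predictions must have the same length")
    return [u + cfg_scale * (c - u) for c, u in zip(eps_cond, eps_uncond)]


if __name__ == "__main__":
    print(classifier_free_guidance([1.0, 2.0], [0.0, 0.0], cfg_scale=1.5))  # [1.5, 3.0]
```

In a real sampler the two predictions are usually computed in one batched forward pass by duplicating the noisy input and pairing one copy with the real labels and the other with null labels.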

4. BibTeX

@article{FeiDiS2024,
  title={Scalable Diffusion Models with State Space Backbone},
  author={Zhengcong Fei and Mingyuan Fan and Changqian Yu and Junshi Huang},
  year={2024},
  journal={arXiv preprint},
}

5. Acknowledgments

The codebase is based on the awesome DiT, U-ViT, and Vim repos.

About

Use Mamba for video generation.
