This work introduces the Video Diffusion Transformer (VDT), which pioneers the use of transformers in diffusion-based video generation. VDT features transformer blocks with modularized temporal and spatial attention modules, allowing each component to be optimized separately while leveraging the rich spatial-temporal representation that transformers provide.
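For illustration, here is a minimal sketch of such a decoupled block. This is our own simplified rendering, not the repository's implementation; the module names, tensor layout, and pre-norm ordering are assumptions.

```python
# Minimal sketch of a decoupled spatial-temporal transformer block.
# Assumes PyTorch >= 1.9 (for batch_first in nn.MultiheadAttention).
import torch
import torch.nn as nn

class SpatialTemporalBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.temporal_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.spatial_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, N, C) = (batch, frames, patch tokens per frame, channels)
        B, T, N, C = x.shape
        # Temporal attention: each spatial location attends across frames.
        t = self.norm1(x).permute(0, 2, 1, 3).reshape(B * N, T, C)
        t = self.temporal_attn(t, t, t, need_weights=False)[0]
        x = x + t.reshape(B, N, T, C).permute(0, 2, 1, 3)
        # Spatial attention: each frame attends across its own patches.
        s = self.norm2(x).reshape(B * T, N, C)
        s = self.spatial_attn(s, s, s, need_weights=False)[0]
        x = x + s.reshape(B, T, N, C)
        # Pointwise feed-forward network.
        return x + self.mlp(self.norm3(x))
```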
VDT offers several appealing benefits. (1) It excels at capturing temporal dependencies to produce temporally consistent video frames and can even simulate the dynamics of 3D objects over time. (2) It enables flexible conditioning by simply concatenating conditioning information in the token space, effectively unifying the video generation and prediction tasks. (3) Its modularized design facilitates a spatial-temporal decoupled training strategy, leading to improved efficiency.
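As a rough illustration of point (2), the sketch below shows token-space conditioning under our own assumptions; the helper name, tensor layout, and frame-axis concatenation are hypothetical, not the repository's exact interface. Because the transformer attends over the full token sequence, switching between generation and prediction needs no architectural change.

```python
import torch

def build_input_tokens(noisy_tokens, cond_tokens=None):
    """Concatenate conditioning tokens with noisy-frame tokens (hypothetical helper).

    noisy_tokens: (B, T, N, C) tokens of the frames being denoised.
    cond_tokens:  (B, T_cond, N, C) tokens of observed frames, or None.
    """
    if cond_tokens is None:
        return noisy_tokens  # unconditional video generation
    # Video prediction: observed frames are simply prepended in token space.
    return torch.cat([cond_tokens, noisy_tokens], dim=1)

# Example: condition 16 noisy frames on 2 observed frames.
noisy = torch.randn(2, 16, 64, 384)
cond = torch.randn(2, 2, 64, 384)
tokens = build_input_tokens(noisy, cond)  # shape: (2, 18, 64, 384)
```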
We conduct extensive experiments on video generation, prediction, and dynamics modeling (i.e., physics-based QA) tasks to demonstrate the effectiveness of VDT in various scenarios, including autonomous driving, human action, and physics-based simulation.
- Python 3, PyTorch>=1.8.0, and torchvision>=0.7.0 are required for the current codebase.
- To install the other dependencies, run:
conda env create -f environment.yml
conda activate VDT
We now provide the Physion_Collide checkpoint for conditional generation. You can download it from here.
We provide an inference script for Physion_Collide video prediction. To sample results, first download the checkpoint, then run:
python physion_sample.py --ckpt $CHECKPOINT_PATH
Our codebase builds on DiT, SlotFormer, and MVCD. We thank the authors for their nicely organized code!