LucasYFL/Multistage_Diffusion

Improving Training Efficiency of Diffusion Models via Multi-Stage Framework and Tailored Multi-Decoder Architecture

Huijie Zhang*, Yifu Lu*, Ismail Alkhouri, Saiprasad Ravishankar, Dogyoon Song, Qing Qu.

Paper | arXiv | Website

Teaser image

Abstract.

Diffusion models, emerging as powerful deep generative tools, excel in various applications. They operate through a two-steps process: introducing noise into training samples and then employing a model to convert random noise into new samples (e.g., images). However, their remarkable generative performance is hindered by slow training and sampling. This is due to the necessity of tracking extensive forward and reverse diffusion trajectories, and employing a large model with numerous parameters across multiple timesteps (i.e., noise levels). To tackle these challenges, we present a multi-stage framework inspired by our empirical findings. These observations indicate the advantages of employing distinct parameters tailored to each timestep while retaining universal parameters shared across all time steps. Our approach involves segmenting the time interval into multiple stages where we employ custom multi-decoder U-net architecture that blends time-dependent models with a universally shared encoder. Our framework enables the efficient distribution of computational resources and mitigates inter-stage interference, which substantially improves training efficiency. Extensive numerical experiments affirm the effectiveness of our framework, showcasing significant training and sampling efficiency enhancements on three state-of-the-art diffusion models, including large-scale latent diffusion models. Furthermore, our ablation studies illustrate the impact of two important components in our framework: (i) a novel timestep clustering algorithm for stage division, and (ii) an innovative multi-decoder U-net architecture, seamlessly integrating universal and customized hyperparameters.
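The stage routing described in the abstract can be sketched as follows. This is an illustrative Python sketch, not the repository's implementation: the stage boundaries, the shared encoder, and the per-stage decoder callables are hypothetical placeholders standing in for the clustered timestep intervals and the multi-decoder U-Net.

```python
# Illustrative sketch of the multi-stage idea: the diffusion time interval is
# split into stages, and each timestep is routed to a stage-specific decoder
# while a single encoder is shared across all stages.

# hypothetical stage boundaries over a normalized time interval [0, 1)
STAGE_BOUNDARIES = [0.0, 0.3, 0.7, 1.0]  # three stages

def stage_index(t, boundaries=STAGE_BOUNDARIES):
    """Map a timestep t in [0, 1) to the index of the stage containing it."""
    for i in range(len(boundaries) - 1):
        if boundaries[i] <= t < boundaries[i + 1]:
            return i
    raise ValueError(f"t={t} lies outside [{boundaries[0]}, {boundaries[-1]})")

def denoise(x, t, shared_encoder, stage_decoders):
    """One model call: shared (universal) encoder, then the decoder for t's stage."""
    h = shared_encoder(x, t)          # parameters shared across all timesteps
    dec = stage_decoders[stage_index(t)]  # parameters tailored to this stage
    return dec(h, t)
```

In the actual framework, the boundaries come from the timestep clustering algorithm and the encoder/decoders are U-Net halves; the sketch only shows the routing logic.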

This repository is based on DPM-Solver, EDM, and LDM. We add our multi-stage strategy to all three codebases. The following sections provide detailed instructions for running each architecture.

Pretrained model

All models mentioned in the paper can be downloaded here, including the multistage DPM-Solver model for CIFAR-10, the multistage EDM models for CIFAR-10 and CelebA, and the multistage LDM model for CelebA.

DPM-Solver

Requirements

cd dpm-solver/score_sde_pytorch
conda create -n dpm python=3.8
conda activate dpm
pip install -r requirements.txt

Training

torchrun --nproc_per_node=2 --master_port=29502 main.py --config ./configs/vp/cifar10_ncsnpp_multistage_deep_continuous_v2.py --workdir exp/multistage --mode train --config.training.batch_size=128

Evaluation

torchrun --nproc_per_node=1 --master_port=29600 main_interval.py --config "configs/vp/cifar10_ncsnpp_multistage_deep_continuous_v2.py" --m1 multistage --workdir exp --eval_folder multistage/eval --config.eval.t_tuples="()" --config.eval.t_converge="(0,)" --config.eval.begin_ckpt=1 --config.eval.end_ckpt=10 --config.eval.batch_size=1024 --config.sampling.steps=20  --config.sampling.eps=1e-4 
python evaluation_fromsample.py --config "configs/vp/cifar10_ncsnpp_multistage_deep_continuous_v2.py" --workdir exp --eval_folder multistage/eval  --config.eval.begin_ckpt=1 --config.eval.end_ckpt=10 --config.eval.batch_size=1024

EDM

Requirements

conda env create -f edm/environment.yml -n multistage-edm
conda activate multistage-edm

Preparing dataset

To prepare the training and evaluation datasets, please follow the dataset preparation instructions in the EDM repository.

Training

# Train DDPM++ model for unconditional CIFAR-10 using 4 GPUs
cd edm
torchrun --standalone --nproc_per_node=4 train.py --outdir=../training-runs --data=../dataset/cifar10-32x32.zip --cond=0 --arch=ddpmpp-multistage --batch=128

Evaluation

# evaluate multistage-edm on CelebA dataset
cd edm
torchrun --standalone --nproc_per_node=1 generate.py --outdir=../fid-tmp --seeds=00000-49999 --network=../model/multistage_edm_celeba.pkl --batch=512

torchrun --standalone --nproc_per_node=1 fid.py calc --images=../fid-tmp --ref=../dataset/celebA_32_edm_fid.npz

# evaluate multistage-edm on CIFAR-10 dataset
cd edm
torchrun --standalone --nproc_per_node=1 generate.py --outdir=../fid-tmp --seeds=00000-49999 --network=../model/multistage_edm_cifar10.pkl --batch=512

torchrun --standalone --nproc_per_node=1 fid.py calc --images=../fid-tmp --ref=../dataset/cifar10-fid.npz

LDM

Requirements

cd latent-diffusion
conda create -n ldm python=3.8
conda activate ldm
pip install -r requirements.txt

Training

python main.py --base configs/celebahq-ldm-vq-4_multistage224-256-192-128.yaml -t --gpus 0,1,2,

Evaluation

python scripts/sample_diffusion.py -r DIR_TO_CKPT -l DIR_TO_RESULT_FOLDER -e 0 -c 20 --batch_size 48
fidelity --gpu 0 --fid --samples-find-deep --input1 data/celebahq/ --input2 DIR_TO_RESULT_FOLDER

Citation

@inproceedings{multistage,
title={Improving Training Efficiency of Diffusion Models via Multi-Stage Framework and Tailored Multi-Decoder Architectures},
author={Zhang, Huijie and Lu, Yifu and Alkhouri, Ismail and Ravishankar, Saiprasad and Song, Dogyoon and Qu, Qing},
booktitle={Conference on Computer Vision and Pattern Recognition 2024},
year={2024},
url={https://openreview.net/forum?id=YtptmpZQOg}
}

Acknowledgements

This repository is based on LuChengTHU/dpm-solver, NVlabs/edm and CompVis/latent-diffusion.
