VainF/Diff-Pruning

Diff-Pruning: Structural Pruning for Diffusion Models

Update

Check out our latest work, DeepCache, a training-free and almost lossless method for accelerating diffusion models. It can be viewed as a special pruning technique that dynamically drops deep layers and runs only the shallow ones during inference.

Introduction

Structural Pruning for Diffusion Models [arxiv]
Gongfan Fang, Xinyin Ma, Xinchao Wang
National University of Singapore

This work presents Diff-Pruning, an efficient structural pruning method for diffusion models. Our empirical assessment highlights two primary features:

  1. Efficiency: It achieves an approximately 50% reduction in FLOPs at only 10% to 20% of the original training cost;
  2. Consistency: The pruned diffusion models inherently preserve generative behavior congruent with the pre-trained ones.

Supported Methods

  • Magnitude Pruning
  • Random Pruning
  • Taylor Pruning
  • Diff-Pruning (a Taylor-based method proposed in our paper)
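The methods above differ mainly in how they score each channel before the lowest-scoring ones are removed. As a rough illustration (not the repository's code), the sketch below scores the output channels of one layer, with weights and gradients given as hypothetical nested lists: magnitude pruning uses the L1 norm of the weights, while Taylor pruning uses one common first-order form, the sum of |weight × gradient|, which requires gradients computed on real data. Diff-Pruning is described as Taylor-based; its choice of which diffusion timesteps contribute gradients is not reproduced here.

```python
import random


def random_importance(weights, seed=0):
    """Random scores: a data-free baseline."""
    rng = random.Random(seed)
    return [rng.random() for _ in weights]


def magnitude_importance(weights):
    """L1 norm of each output channel's weights; needs no data."""
    return [sum(abs(w) for w in channel) for channel in weights]


def taylor_importance(weights, grads):
    """First-order Taylor score sum_i |w_i * dL/dw_i| per output channel.

    Estimates how much the loss would change if the channel were removed,
    so the gradients must come from a forward/backward pass on real data.
    """
    return [sum(abs(w * g) for w, g in zip(ch_w, ch_g))
            for ch_w, ch_g in zip(weights, grads)]


def channels_to_prune(scores, ratio):
    """Indices of the lowest-scoring channels for a given pruning ratio."""
    n_prune = int(len(scores) * ratio)
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    return sorted(order[:n_prune])


# Hypothetical layer with 4 output channels, 2 weights each.
weights = [[0.1, -0.2], [1.0, 0.5], [0.05, 0.0], [-0.7, 0.3]]
grads = [[1.0, 1.0], [0.0, 0.01], [2.0, 2.0], [0.5, -0.5]]
print(channels_to_prune(magnitude_importance(weights), 0.5))      # [0, 2]
print(channels_to_prune(taylor_importance(weights, grads), 0.5))  # [1, 2]
```

On this toy layer the two criteria pick different channels: channel 1 has large weights but almost no gradient, so the Taylor score treats it as unimportant.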

TODO List

  • Support more diffusion models from Diffusers
  • Upload checkpoints of pruned models
  • Training scripts for CelebA-HQ, LSUN Church & LSUN Bedroom
  • Align the performance with the DDIM Repo.

Our Exp Code (Unorganized)

Pruning with DDIM codebase

This example shows how to prune a DDPM model pre-trained on CIFAR-10 using the DDIM codebase. Since Huggingface Diffusers does not support skip_type='quad' in DDIM, you may get slightly worse FID scores with Diffusers for both the pre-trained model (FID=4.5) and the pruned model (FID=5.6). We are working on implementing the quad strategy for Diffusers. For reproducibility, we provide our original (but unorganized) experiment code for the paper in ddpm_exp.

cd ddpm_exp
# Prune & Finetune
bash scripts/simple_cifar_our.sh 0.05 # the pre-trained model and data will be automatically prepared
# Sampling
bash scripts/sample_cifar_ddpm_pruning.sh run/finetune_simple_v2/cifar10_ours_T=0.05.pth/logs/post_training/ckpt_100000.pth run/sample

For FID computation, please refer to the "4. FID Score" step below.

Output:

Found 49984 files.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:49<00:00,  7.97it/s]
FID:  5.242662673752534
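The skip_type option mentioned above controls which timesteps a DDIM sampler visits. The sketch below is our pure-Python reading of the two strategies in the DDIM codebase (which uses np.linspace for "quad"), so floating-point rounding at the boundaries may differ slightly from the original; the function name and defaults are illustrative assumptions.

```python
import math


def ddim_timesteps(num_timesteps=1000, steps=100, skip_type="uniform"):
    """Timesteps visited by a DDIM sampler with `steps` denoising steps."""
    if skip_type == "uniform":
        # Evenly spaced: 0, k, 2k, ... with k = T // steps
        skip = num_timesteps // steps
        return list(range(0, num_timesteps, skip))
    if skip_type == "quad":
        # Squares of evenly spaced points in [0, sqrt(0.8 * T)], so the
        # schedule is denser near t = 0 and sparser at high noise levels.
        end = math.sqrt(num_timesteps * 0.8)
        step = end / (steps - 1)
        return [int((i * step) ** 2) for i in range(steps)]
    raise ValueError(f"unknown skip_type: {skip_type}")


uniform = ddim_timesteps(skip_type="uniform")
quad = ddim_timesteps(skip_type="quad")
```

With 1,000 training steps and 100 sampling steps, "uniform" visits 0, 10, ..., 990, whereas "quad" stops at about t = 800 and spends more of its budget at low noise levels.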

Pruning with LDM codebase

Please check ldm_exp/run.sh for an example of pruning a pre-trained LDM model on ImageNet. This codebase is still unorganized. We will clean it up in the future.

Pruning with Huggingface Diffusers

The following pipeline prunes a pre-trained DDPM on CIFAR-10 with Huggingface Diffusers.

0. Requirements, Data and Pretrained Model

  • Requirements
pip install -r requirements.txt
  • Data

Download and extract CIFAR-10 images to data/cifar10_images for training and evaluation.

python tools/extract_cifar10.py --output data
  • Pretrained Models

The following script will download an official DDPM model and convert it to the Huggingface Diffusers format. You can find the converted model at pretrained/ddpm_ema_cifar10. It is an EMA version of google/ddpm-cifar10-32.

bash tools/convert_cifar10_ddpm_ema.sh

(Optional) You can also download a pre-converted model using wget

wget https://github.com/VainF/Diff-Pruning/releases/download/v0.0.1/ddpm_ema_cifar10.zip
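The converted checkpoint is an EMA (exponential moving average) copy of the weights. As a minimal sketch, assuming parameters are a flat list of floats, an EMA update looks like this; ema_update is a hypothetical name, and decay=0.9999 is a typical value (to our knowledge, the one used for the original DDPM).

```python
def ema_update(ema_params, params, decay=0.9999):
    """In-place update: ema <- decay * ema + (1 - decay) * param."""
    for i, p in enumerate(params):
        ema_params[i] = decay * ema_params[i] + (1.0 - decay) * p
    return ema_params


# Usage: call once after every optimizer step on the raw weights.
# A small decay is used here only to make the effect visible.
ema = [0.0, 0.0]
for step_params in ([1.0, 2.0], [1.0, 2.0]):
    ema_update(ema, step_params, decay=0.5)
# ema is now [0.75, 1.5]
```

Because the EMA weights average over many training steps, they usually sample better than the raw weights, which is why EMA checkpoints are the usual starting point.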

1. Pruning

Create a pruned model at run/pruned/ddpm_cifar10_pruned

bash scripts/prune_ddpm_cifar10.sh 0.3  # pruning ratio = 30%
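To get a feel for what a pruning ratio means for compute, the sketch below assumes (as a simplification) that the ratio is applied to both the input and output channels of an internal 3x3 convolution, so its cost scales roughly with (1 - ratio)^2. The layer shape and rounding rule are hypothetical; Torch-Pruning's actual rounding and the set of layers it prunes may differ.

```python
def pruned_width(channels, ratio):
    """Channels kept after removing `ratio` of them (at least one)."""
    return max(1, round(channels * (1.0 - ratio)))


def conv_macs(c_in, c_out, kernel, height, width):
    """Multiply-accumulates of a stride-1 'same' conv on an HxW map."""
    return c_in * c_out * kernel * kernel * height * width


# Hypothetical internal 3x3 conv: 128 -> 128 channels on a 16x16 feature map.
kept = pruned_width(128, 0.3)  # 90 channels remain
ratio = conv_macs(kept, kept, 3, 16, 16) / conv_macs(128, 128, 3, 16, 16)
# ratio is about 0.49: both the input and output widths shrink by ~30%
```

Layers where only one side is pruned (for example, the first and last convolutions, whose image channels are fixed) shrink only linearly, so the whole-model FLOPs reduction depends on the architecture.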

2. Finetuning (Post-Training)

Finetune the model and save it at run/finetuned/ddpm_cifar10_pruned_post_training

bash scripts/finetune_ddpm_cifar10.sh
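Finetuning (post-training) recovers the quality lost by pruning. We assume it optimizes the standard DDPM ε-prediction objective; the sketch below shows a one-sample estimate of that loss on a flat vector with a linear β schedule (1e-4 to 0.02 over 1,000 steps). The helper names are hypothetical, and model stands in for the pruned UNet.

```python
import math
import random


def linear_alpha_bars(T=1000, start=1e-4, end=0.02):
    """Cumulative products alpha_bar_t of a linear beta schedule."""
    out, prod = [], 1.0
    for t in range(T):
        beta = start + (end - start) * t / (T - 1)
        prod *= 1.0 - beta
        out.append(prod)
    return out


def ddpm_loss(model, x0, alpha_bars, rng):
    """One-sample Monte Carlo estimate of the epsilon-prediction loss."""
    t = rng.randrange(len(alpha_bars))         # t ~ Uniform{0, ..., T-1}
    noise = [rng.gauss(0.0, 1.0) for _ in x0]  # eps ~ N(0, I)
    a = alpha_bars[t]
    # Forward process q(x_t | x_0): x_t = sqrt(a) * x0 + sqrt(1 - a) * eps
    xt = [math.sqrt(a) * x + math.sqrt(1.0 - a) * n for x, n in zip(x0, noise)]
    pred = model(xt, t)  # the network predicts the added noise
    return sum((p - n) ** 2 for p, n in zip(pred, noise)) / len(x0)


def zero_model(xt, t):
    """Placeholder network that always predicts zero noise."""
    return [0.0] * len(xt)


loss = ddpm_loss(zero_model, [0.5, -0.25, 1.0], linear_alpha_bars(), random.Random(0))
```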

3. Sampling

Pruned: Sample and save images to run/sample/ddpm_cifar10_pruned

bash scripts/sample_ddpm_cifar10_pruned.sh

Pretrained: Sample and save images to run/sample/ddpm_cifar10_pretrained

bash scripts/sample_ddpm_cifar10_pretrained.sh
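Both sampling scripts draw images by running a reverse diffusion process. As a hedged illustration only (the scripts may use a different sampler, such as DDIM), the sketch below implements plain DDPM ancestral sampling on a flat vector, assuming model(x, t) predicts the added noise and using the σ_t² = β_t variance choice. All names are hypothetical.

```python
import math
import random


def linear_betas(T=1000, start=1e-4, end=0.02):
    """Linear noise schedule beta_1..beta_T."""
    return [start + (end - start) * t / (T - 1) for t in range(T)]


def ddpm_sample(model, dim, betas, rng):
    """Ancestral DDPM sampling for a flat vector of `dim` values."""
    alphas = [1.0 - b for b in betas]
    alpha_bars, prod = [], 1.0
    for a in alphas:
        prod *= a
        alpha_bars.append(prod)

    x = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # x_T ~ N(0, I)
    for t in reversed(range(len(betas))):
        eps = model(x, t)
        coef = betas[t] / math.sqrt(1.0 - alpha_bars[t])
        # Posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        mean = [(xi - coef * ei) / math.sqrt(alphas[t]) for xi, ei in zip(x, eps)]
        if t > 0:
            sigma = math.sqrt(betas[t])
            x = [m + sigma * rng.gauss(0.0, 1.0) for m in mean]
        else:
            x = mean  # no noise is added at the final step
    return x


def zero_model(x, t):
    """Placeholder noise predictor; a real one would be the (pruned) UNet."""
    return [0.0] * len(x)


sample = ddpm_sample(zero_model, dim=4, betas=linear_betas(T=50), rng=random.Random(0))
```

Since a pruned model only replaces the noise predictor, the same loop serves both the pruned and the pre-trained checkpoints, which is what makes their samples directly comparable.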

4. FID Score

The fid_score.py script is modified from https://github.com/mseitzer/pytorch-fid.

# pre-compute the stats of CIFAR-10 dataset
python fid_score.py --save-stats data/cifar10_images run/fid_stats_cifar10.npz --device cuda:0 --batch-size 256
# Compute the FID score of sampled images
python fid_score.py run/sample/ddpm_cifar10_pruned run/fid_stats_cifar10.npz --device cuda:0 --batch-size 256
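For reference, pytorch-fid computes the Fréchet distance between two Gaussians fitted to Inception-v3 pool features: (μ_r, Σ_r) for the real CIFAR-10 images, which the --save-stats call caches in run/fid_stats_cifar10.npz, and (μ_g, Σ_g) for the sampled images. Lower is better.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```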

5. (Optional) Distributed Training and Sampling with Accelerate

This project supports distributed training and sampling.

python -m torch.distributed.launch --nproc_per_node=8 --master_port 22222 --use_env <ddpm_sample.py|ddpm_train.py> ...

A multi-processing example can be found at scripts/sample_ddpm_cifar10_pretrained_distributed.sh.
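When sampling with several processes, each rank must generate a disjoint subset of images. The sketch below shows one way to split the work; shard_indices is a hypothetical helper, not a function from this repository. It relies on torch.distributed.launch exporting RANK and WORLD_SIZE to each process and falls back to a single process otherwise.

```python
import os


def shard_indices(n_total, world_size, rank):
    """Contiguous, near-equal slice of sample indices owned by `rank`."""
    base, extra = divmod(n_total, world_size)
    # The first `extra` ranks each take one additional index.
    start = rank * base + min(rank, extra)
    size = base + (1 if rank < extra else 0)
    return range(start, start + size)


# Defaults make the sketch work in a single, non-distributed process too.
rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
my_indices = shard_indices(50000, world_size, rank)
```

Using each global index as that image's random seed can keep the generated set largely independent of the number of GPUs.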

Prune Pre-trained DPMs from HuggingFace Diffusers

Example: google/ddpm-ema-bedroom-256

python ddpm_prune.py \
--dataset "<path/to/imagefolder>" \
--model_path google/ddpm-ema-bedroom-256 \
--save_path run/pruned/ddpm_ema_bedroom_256_pruned \
--pruning_ratio 0.05 \
--pruner "<random|magnitude|reinit|taylor|diff-pruning>" \
--batch_size 4 \
--thr 0.05 \
--device cuda:0

The --dataset and --thr arguments are only used by the taylor and diff-pruning pruners.

Example: CompVis/ldm-celebahq-256

python ldm_prune.py \
--model_path CompVis/ldm-celebahq-256 \
--save_path run/pruned/ldm_celeba_pruned \
--pruning_ratio 0.05 \
--pruner "<random|magnitude|reinit>" \
--device cuda:0 \
--batch_size 4

Results

  • DDPM on Cifar-10, CelebA and LSUN
  • Conditional LDM on ImageNet-1K 256

We also have some results on Conditional LDM for ImageNet-1K 256x256, where we finetune a pruned LDM for only 4 epochs. We will release the training script soon.

Acknowledgement

This project is heavily based on Diffusers, Torch-Pruning, and pytorch-fid. Our experiments were conducted with the ddim and LDM codebases.

Citation

If you find this work helpful, please cite:

@inproceedings{fang2023structural,
  title={Structural pruning for diffusion models},
  author={Gongfan Fang and Xinyin Ma and Xinchao Wang},
  booktitle={Advances in Neural Information Processing Systems},
  year={2023},
}
@inproceedings{fang2023depgraph,
  title={Depgraph: Towards any structural pruning},
  author={Fang, Gongfan and Ma, Xinyin and Song, Mingli and Mi, Michael Bi and Wang, Xinchao},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={16091--16101},
  year={2023}
}