# Diffusion Model as Representation Learner

This repository contains the official implementation of the ICCV 2023 paper:

> **Diffusion Model as Representation Learner**
> Xingyi Yang, Xinchao Wang
In this paper, we conduct an in-depth investigation of the representation power of Diffusion Probabilistic Models (DPMs), and propose a novel knowledge transfer method that leverages the knowledge acquired by generative DPMs for recognition tasks. We introduce a knowledge transfer paradigm named RepFusion. Our paradigm extracts representations at different time steps from off-the-shelf DPMs and dynamically employs them as supervision for student networks, where the optimal time step is determined through reinforcement learning.
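Below is a minimal sketch of this idea, assuming a frozen DPM teacher that exposes a hypothetical `encode(x, t)` feature extractor; the function names, the reward definition, and the training loop are illustrative assumptions, not this repository's actual API:

```python
import torch
import torch.nn.functional as F

def repfusion_step(teacher, student, policy_logits, x, optimizer):
    """One hypothetical training step: distill a DPM feature at a learned time step."""
    # Categorical policy over the diffusion time steps.
    probs = policy_logits.softmax(dim=0)
    t = torch.multinomial(probs, 1).item()

    with torch.no_grad():
        # Hypothetical teacher API: intermediate feature of x at time step t.
        z_teacher = teacher.encode(x, t)

    z_student = student(x)
    distill_loss = F.mse_loss(z_student, z_teacher)

    # REINFORCE: treat the negated distillation loss as the reward for picking t,
    # so time steps whose features the student can use well become more likely.
    reward = -distill_loss.detach()
    policy_loss = -torch.log(probs[t]) * reward

    optimizer.zero_grad()
    (distill_loss + policy_loss).backward()
    optimizer.step()
    return distill_loss.item(), t
```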
## Structure

This repository contains the code for distillation and for the three downstream tasks: image classification, segmentation, and facial landmark detection.
```
├── classification_distill/   # code for image classification and knowledge distillation
│   ├── configs/
│   │   ├── <DATASET>-<DISTILL_LOSS>/
│   │   │   └── ddpm-<BACKBONE>_<DISTILL_LOSS>.py
│   │   │       # RepFusion config on <DATASET> with <DISTILL_LOSS>
│   │   │       # as loss function and <BACKBONE> as architecture
│   │   └── baseline/
│   │       └── <BACKBONE>_<BATCHSIZE>_<DATASET>_finetune.py
│   └── mmcls/
│       └── models/
│           ├── guided_diffusion/   # code taken from the guided-diffusion repo
│           └── classifiers/
│               ├── kd.py           # distillation baselines
│               └── repfusion.py    # core code for distillation from the diffusion model
├── landmark/   # code for facial landmark detection
│   └── configs/face/2d_kpt_sview_rgb_img/topdown_heatmap/wflw/
│       ├── <BACKBONE>_wflw_256x256_baseline_<BATCHSIZE>.py
│       └── <BACKBONE>_wflw_256x256_<BATCHSIZE>_repfussion.py
└── segmentation/   # code for face parsing
    └── configs/
        └── celebahq_mask/
            ├── bisenetv1_<BACKBONE>_lr5e-3_2x8_448x448_160k_coco-celebahq_mask_baseline.py
            └── bisenetv1_<BACKBONE>_lr5e-3_2x8_448x448_160k_coco-celebahq_mask_repfusion.py
```
## Dependencies

We mainly depend on 4 packages:

- mmclassification. Please install the environment using its INSTALL guide.
- mmsegmentation. Please install the environment using its INSTALL guide.
- mmpose. Please install the environment using its INSTALL guide.
- diffusers. Install via `pip install --upgrade diffusers[torch]`, or go to the official repo for help.
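After installation, a quick sanity check (assuming the standard import names for these packages) is:

```python
# Verify the four dependencies are importable and print their versions.
import mmcls, mmseg, mmpose, diffusers

print(mmcls.__version__, mmseg.__version__, mmpose.__version__, diffusers.__version__)
```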
## Datasets

We use 4 datasets in our paper. Please put them all under `data/<DATASET>`:

- CelebAMask-HQ: please follow the guideline in the official repo.
- WFLW: please download the images from the WFLW Dataset page and the annotation files from wflw_annotations.
- TinyImageNet: please download the dataset using this script.
- CIFAR-10: `mmcls` will automatically download it for you.
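The resulting layout then looks roughly like this (the directory names below are assumptions; match them to whatever your configs expect):

```
data/
├── CelebAMask-HQ/
├── WFLW/
├── tinyimagenet/
└── cifar10/
```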
## Pre-trained DPMs

- For DPMs based on Hugging Face, the model will be downloaded automatically. Just make sure you use the correct model id.
- For the DPM on Tiny-ImageNet, we download it from the guided-diffusion repo via the weight link.
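For the Hugging Face case, loading a model by id looks like this (a minimal sketch; `google/ddpm-ema-celebahq-256` is only an example id, substitute the one your config expects):

```python
from diffusers import DDPMPipeline

# Downloads and caches the checkpoint from the Hugging Face Hub on first use.
pipeline = DDPMPipeline.from_pretrained("google/ddpm-ema-celebahq-256")
unet = pipeline.unet  # the U-Net backbone; its intermediate features can act as the teacher
```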
## Training

- We first do distillation from a trained DPM:

  ```bash
  # <CONFIG_NAME>: config path for distillation
  # <GPU_NUMS>: num of gpus for training
  cd classification_distill
  bash tools/dist_train.sh <CONFIG_NAME> <GPU_NUMS>
  ```
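  For example, a CIFAR-10 run could look like `bash tools/dist_train.sh configs/cifar10-mse/ddpm-resnet18_mse.py 8`; the config name here is hypothetical and simply follows the pattern shown in the structure above, and the GPU count of 8 is just an example.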
- Put the saved checkpoint in the config as initialization for downstream training. For example:

  ```python
  model = dict(
      ...
      backbone=dict(
          ...
          backbone_cfg=dict(
              ...
              init_cfg=dict(
                  type='Pretrained',
                  checkpoint=<CHECKPOINT_PATH>,
                  # put the distilled checkpoint here
                  prefix='student.backbone.')
          )
      ),
  )
  ```
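  The `prefix='student.backbone.'` entry makes the `Pretrained` initializer keep only the student backbone weights from the distillation checkpoint. A rough equivalent in plain PyTorch (illustration only; the checkpoint path and key layout are assumptions):

  ```python
  import torch

  ckpt = torch.load("work_dirs/repfusion/latest.pth", map_location="cpu")  # hypothetical path
  state = ckpt.get("state_dict", ckpt)

  # Keep only keys under the prefix and strip it, so the weights
  # load into a standalone backbone module.
  prefix = "student.backbone."
  backbone_state = {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
  ```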
- Do the downstream training:

  ```bash
  # <CONFIG_NAME>: config path for the downstream task
  # <GPU_NUMS>: num of gpus for training
  # <TASK_NAME>: either 'classification_distill', 'segmentation' or 'landmark'
  cd <TASK_NAME>
  bash tools/dist_train.sh <CONFIG_NAME> <GPU_NUMS>
  ```
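  For example, for face parsing (the GPU count is just an example): `cd segmentation && bash tools/dist_train.sh configs/celebahq_mask/bisenetv1_<BACKBONE>_lr5e-3_2x8_448x448_160k_coco-celebahq_mask_repfusion.py 8`.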
## Citation

```bibtex
@inproceedings{yang2023diffusion,
  author    = {Yang, Xingyi and Wang, Xinchao},
  title     = {Diffusion Model as Representation Learner},
  booktitle = {International Conference on Computer Vision (ICCV)},
  year      = {2023},
}
```