AutoGAN-Distiller: Searching to Compress Generative Adversarial Networks
Yonggan Fu, Wuyang Chen, Haotao Wang, Haoran Li, Yingyan Lin, Zhangyang Wang
Accepted at ICML 2020 [Paper Link].
We propose the AutoGAN-Distiller (AGD) framework, one of the first AutoML frameworks dedicated to GAN compression and among the earliest works to explore AutoML for GANs.
- AGD is established on a specifically designed search space of efficient generator building blocks, leveraging knowledge from state-of-the-art GANs for different tasks.
- It performs differentiable neural architecture search under the target compression ratio (computational resource constraint), which preserves the original GAN generation quality via the guidance of knowledge distillation.
- We demonstrate AGD on two representative mobile-based GAN applications: unpaired image translation (using a CycleGAN), and super resolution (using an encoder-decoder GAN).
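The search objective described above (knowledge-distillation guidance under a compute-budget constraint) can be illustrated as a toy scalar objective. This is a minimal sketch, not the paper's exact loss: the function name, the L1-style distillation distance, and the soft FLOPs penalty are all illustrative assumptions.

```python
# Toy sketch of a distillation-plus-budget search objective (illustrative only;
# not AGD's exact formulation).

def agd_objective(student_out, teacher_out, flops, budget, lam=1.0):
    """Distillation distance to the teacher's output plus a soft penalty
    when the candidate generator exceeds the target compute budget."""
    # L1-style distillation distance between generator outputs
    distill = sum(abs(s - t) for s, t in zip(student_out, teacher_out)) / len(student_out)
    # soft penalty for violating the FLOPs budget (compression-ratio constraint)
    penalty = max(0.0, flops - budget)
    return distill + lam * penalty
```

Minimizing this trades off matching the teacher against staying under the budget; `lam` controls how hard the constraint is enforced.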
Unpaired Image Translation
- horse2zebra, zebra2horse, summer2winter, winter2summer: Unpaired-dataset

Super Resolution
- Training (DIV2K+Flickr2K): SR-training-dataset
- Evaluation (Set5, Set14, BSD100, Urban100): SR-eval-dataset
AGD_ST and AGD_SR contain the source code for the unpaired image translation task and the super resolution task, respectively. The code for pretrain, search, train from scratch, and eval lives in the `search` subdirectory; we take AGD_ST/search as the running example. All configurations for pretrain, search, train from scratch, and eval are kept in three config files (config_eval.py covers eval). Please specify the target dataset `C.dataset` and change the dataset path `C.dataset_path` in the three config files to the real paths on your PC.
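For illustration, the dataset options in a config file would be set along these lines. The values are placeholders; only the option names `C.dataset` and `C.dataset_path` come from this README.

```python
# Placeholder values -- point these at the real datasets on your PC.
C.dataset = 'horse2zebra'                         # target dataset name
C.dataset_path = '/path/to/datasets/horse2zebra'  # dataset location on disk
```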
See env.yml for the complete conda environment. Create and activate a new conda environment:

```shell
conda env create -f env.yml
conda activate pytorch
```

In particular, if the thop package encounters version conflicts, pin the thop version:

```shell
pip install thop==0.0.31.post1912272122
```
Step 1: Pretrain the Supernet
- Switch to the search directory (AGD_ST/search in this example).
- Set `C.pretrain = True` in the config file.
- Start to pretrain.
- The checkpoints during pretraining are saved under the checkpoint directory (the path later loaded via `C.pretrain` in Step 2).
Step 2: Search
- Set `C.pretrain = 'ckpt/pretrain'` in the config file to load the pretrained supernet.
- Start to search.
Step 3: Train the derived network from scratch
- Set `C.load_path = 'ckpt/search'` in the config file to load the searched architecture.
- Start to train from scratch.
Step 4: Eval
- Set `C.load_path = 'ckpt/search'` and `C.ckpt = 'ckpt/finetune/weights.pt'` in the config file.
- Start to evaluate on the testing dataset.
The result images are saved at
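The per-step launch commands are elided in the steps above. As a sketch, assuming one launcher script per stage inside AGD_ST/search (the script names below are guesses, not confirmed by this README; substitute the actual scripts in the repo), the four steps chain together roughly as:

```shell
cd AGD_ST/search

# Step 1: pretrain the supernet (with C.pretrain = True)
python train_search.py

# Step 2: search (with C.pretrain = 'ckpt/pretrain')
python train_search.py

# Step 3: train the derived network from scratch (C.load_path = 'ckpt/search')
python train.py

# Step 4: evaluate (C.ckpt = 'ckpt/finetune/weights.pt')
python eval.py
```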
Two differences in the Super Resolution task
- Please download the checkpoint of the original ESRGAN (teacher model) from pretrained ESRGAN and move it to the corresponding directory.
- Step 3 is split into two stages, i.e., first pretrain the derived architecture with only the content loss, then finetune it with the perceptual loss:
  - Pretrain: set `C.pretrain = True` in the config file.
  - Finetune: set `C.pretrain = 'ckpt/finetune_pretrain/weights.pt'` in the config file.
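As a sketch, the two stages of the split step 3 toggle the same option. The option values are taken from the text above; the config file they belong to is left unspecified, as in the original.

```python
# Stage 3a: pretrain the derived network with the content loss only
C.pretrain = True

# Stage 3b: finetune with the perceptual loss, loading the stage-3a weights
C.pretrain = 'ckpt/finetune_pretrain/weights.pt'
```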
Pretrained models are provided at pretrained AGD.
To evaluate the pretrained models, please copy the network architecture definition and pretrained weights to the corresponding directories:
```shell
cp arch.pt ckpt/search/
cp weights.pt ckpt/finetune/
```
then do the evaluation following Step 4.
Our Related Work
Please also check our concurrent work on a unified optimization framework combining model distillation, channel pruning and quantization for GAN compression: