Learning Disentangled Representation by Exploiting Pretrained Generative Models: A Contrastive Learning View
Xuanchi Ren*, Tao Yang*, Yuwang Wang and Wenjun Zeng
ICLR 2022
* indicates equal contribution
✅ Update StyleGAN2
✅ Update SNGAN
🔲 Update VAE
🔲 Update Glow
✅ Evaluation
In this repo, we propose an unsupervised and model-agnostic method: Disentanglement via Contrast (DisCo) in the Variation Space. DisCo discovers disentangled directions in the latent space and extracts disentangled representations from images via contrastive learning. It achieves state-of-the-art disentanglement given pretrained non-disentangled generative models, including GANs, VAEs, and Flows.
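At a high level, DisCo samples latent codes, shifts them along candidate directions, and contrasts the resulting variations in image feature space (the Variation Space): variations produced by the same direction are pulled together, and variations produced by different directions are pushed apart. The snippet below is only a minimal, illustrative sketch of this idea, not the code in this repository; the generator `G`, the direction matrix `A`, the encoder, the shift scale, and the exact loss form are placeholder assumptions.

```python
# Minimal, illustrative sketch of contrasting variations in the Variation Space.
# NOT the implementation in this repo: `G` (frozen pretrained generator),
# `A` (K learnable latent directions), `encoder`, the shift scale, and the
# loss form are placeholder assumptions.
import torch
import torch.nn.functional as F

def variation_contrastive_loss(G, A, encoder, batch_size=32, shift_scale=3.0, tau=0.5):
    K, latent_dim = A.shape
    z = torch.randn(batch_size, latent_dim)        # anchor latent codes
    d = torch.randint(K, (batch_size,))            # one direction index per sample
    z_shift = z + shift_scale * A[d]               # traverse along the chosen direction
    # Variation = difference of encoder features between shifted and original images
    v = encoder(G(z_shift)) - encoder(G(z))
    v = F.normalize(v, dim=1)
    sim = v @ v.t() / tau                          # pairwise similarities
    sim.fill_diagonal_(float('-inf'))              # exclude self-pairs
    pos = (d.unsqueeze(0) == d.unsqueeze(1))       # same direction -> positive pair
    pos.fill_diagonal_(False)
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)
    has_pos = pos.any(dim=1)                       # skip samples with no positive in batch
    pos_f = pos[has_pos].float()
    loss = -(log_prob[has_pos] * pos_f).sum(dim=1) / pos_f.sum(dim=1)
    return loss.mean()
```

Refer to `train.py` for the actual objective and training options.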
NOTE: The following results are obtained in a completely unsupervised manner. More results (including VAE and Flow) are presented in the Appendix.
Discovered directions (qualitative results):
- FFHQ (StyleGAN2): Pose, Smile, Race, Oldness, Overexposure, Hair
- Shapes3D (StyleGAN2): Wall Color, Floor Color, Object Color, Pose
- Car3D (StyleGAN2): Azimuth, Yaw
- Anime (SNGAN): Pose, Naturalness, Glasses, Tone
NOTE: DisCo achieves state-of-the-art disentanglement on Shapes3D, Car3D, and MPI3D in terms of the MIG and DCI metrics.
- NVIDIA GPU + CUDA CuDNN
- Python 3
- Clone the repository:
git clone https://github.com/xrenaa/DisCo.git
cd DisCo
- Dependencies (To Do): We recommend running this repository in an Anaconda environment. A tentative setup is sketched below.
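For example, a minimal environment could be created as follows (the environment name and versions are placeholders, not pinned by this repo):

```bash
# Placeholder environment; adjust the Python/PyTorch versions to your CUDA setup
conda create -n disco python=3.7 -y
conda activate disco
pip install torch torchvision   # plus the remaining dependencies once they are listed
```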
- Docker: Alternatively, you can use Docker to run the code. We provide the image `thomasyt/gan-disc` for easy use; an example invocation is sketched below.
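The mount point and the entry command below are assumptions, not taken from the image's documentation:

```bash
# Pull the provided image and run the repo inside it with GPU access
docker pull thomasyt/gan-disc
docker run --gpus all -it --rm \
  -v "$(pwd)":/workspace/DisCo -w /workspace/DisCo \
  thomasyt/gan-disc bash
```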
Please download the pre-trained models from the following links and put them in the corresponding paths.
Path | Description |
---|---|
shapes3d_StyleGAN | StyleGAN2 models pretrained on Shapes3D, checkpoints `0.pt` to `4.pt`. Put them under `./pretrained_weights/shapes3d/`. |
cars3d_StyleGAN | StyleGAN2 models pretrained on Cars3D, checkpoints `0.pt` to `4.pt`. Put them under `./pretrained_weights/cars3d/`. |
mpi3d_StyleGAN | StyleGAN2 models pretrained on MPI3D, checkpoints `0.pt` to `4.pt`. Put them under `./pretrained_weights/mpi3d/`. |
shapes3d_VAE | VAE models pretrained on Shapes3D, checkpoints `VAE_0` to `VAE_4`. Put them under `./pretrained_weights/shapes3d/`. |
cars3d_VAE | VAE models pretrained on Cars3D, checkpoints `VAE_0` to `VAE_4`. Put them under `./pretrained_weights/cars3d/`. |
mpi3d_VAE | VAE models pretrained on MPI3D, checkpoints `VAE_0` to `VAE_4`. Put them under `./pretrained_weights/mpi3d/`. |
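After downloading, the weights directory should look roughly like this (layout inferred from the table above):

```
pretrained_weights/
├── shapes3d/
│   ├── 0.pt ... 4.pt      # StyleGAN2 checkpoints
│   └── VAE_0 ... VAE_4    # VAE checkpoints
├── cars3d/
│   ├── 0.pt ... 4.pt
│   └── VAE_0 ... VAE_4
└── mpi3d/
    ├── 0.pt ... 4.pt
    └── VAE_0 ... VAE_4
```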
For SNGAN, you can run the following command to download the weights for MNIST and Anime:
python ./pretrained_weights/download.py
To train the models, make sure you have downloaded the required pretrained weights and put them in the correct paths. To train with a StyleGAN2 generator:
python train.py \
--G stylegan \
--dataset 0 \
--exp_name your_name \
--B 32 \
--N 32 \
--K 64
For `--dataset`, you can choose `0` for Shapes3D, `1` for MPI3D, or `2` for Cars3D.

To train with an SNGAN generator:
python train.py \
--G sngan \
--dataset 5 \
--exp_name your_name \
--B 32 \
--N 32 \
--K 64
For `--dataset`, you can choose `5` for MNIST or `6` for Anime.
- Dependencies: For evaluation, you will need `tensorflow` and `gin-config`.
- Download the datasets (except for Shapes3D):
cd data
./dlib_download_data.sh
For Shapes3D, you will first need to download the data from Google Cloud Storage: follow this link and download the file `3dshapes.h5`, then put it under the `data` directory.
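If you have `gsutil` installed, a command-line download along these lines may also work (the bucket path is an assumption based on the public 3D Shapes release):

```bash
# Assumed public bucket for the 3D Shapes dataset
gsutil cp gs://3d-shapes/3dshapes.h5 ./data/3dshapes.h5
```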
- Run the evaluation:
python evaluate.py --dataset 0 --exp_name your_name
For `--dataset`, you can choose `0` for Shapes3D, `1` for MPI3D, or `2` for Cars3D (evaluation is only supported on these datasets). The results will be saved in the same directory as the checkpoint.
ProgGAN and BigGAN are based on: https://github.com/anvoynov/GANLatentDiscovery.
StyleGAN2 is based on: https://github.com/rosinality/stylegan2-pytorch.
Disentanglement metrics are based on: https://github.com/google-research/disentanglement_lib.
@inproceedings{ren2022DisCo,
title = {Learning Disentangled Representation by Exploiting Pretrained Generative Models: A Contrastive Learning View},
author = {Xuanchi Ren and Tao Yang and Yuwang Wang and Wenjun Zeng},
booktitle = {ICLR},
year = {2022}
}