This repository contains the code for the ICLR 2021 paper "Counterfactual Generative Networks" by Axel Sauer and Andreas Geiger. If you want to take the CGN for a spin and generate counterfactual images, you can try out the Colab below.
If you find our code or paper useful, please cite
@inproceedings{Sauer2021ICLR,
author = {Axel Sauer and Andreas Geiger},
title = {Counterfactual Generative Networks},
booktitle = {International Conference on Learning Representations (ICLR)},
year = {2021}}

Install Anaconda (if you don't have it yet):
wget https://repo.anaconda.com/archive/Anaconda3-2020.11-Linux-x86_64.sh
bash Anaconda3-2020.11-Linux-x86_64.sh
source ~/.profile

Clone the repo and build the environment:
git clone https://github.com/autonomousvision/counterfactual_generative_networks
cd counterfactual_generative_networks
conda env create -f environment.yml
conda activate cgn

Make all scripts executable: chmod +x scripts/*. Then, download the datasets (colored MNIST, Cue-Conflict, IN-9) and the pre-trained weights (CGN, U2-Net) with the scripts below. Comment out the downloads you don't need.
./scripts/download_data.sh
./scripts/download_weights.sh

MNISTs. The main functions of this sub-repo are:
- Generating the MNIST variants
- Training a CGN
- Generating counterfactual datasets
- Training a shape classifier
We provide working configs and pre-trained weights in mnists/experiments. To train a CGN on, e.g., Wildlife MNIST, run
python mnists/train_cgn.py --cfg mnists/experiments/cgn_wildlife_MNIST/cfg.yaml

For more info, add --help. Weights and samples will be saved in mnists/experiments/.
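For orientation, the CGN composes each image from independent mechanisms: a shape mask, a foreground texture, and a background, blended analogously to alpha compositing. A minimal sketch of that composition step (tensor names and shapes are illustrative, not the repo's exact variables):

```python
import torch

def compose(mask, foreground, background):
    """Blend independent mechanisms into one image: m * f + (1 - m) * b."""
    return mask * foreground + (1 - mask) * background

# Illustrative shapes only; in the CGN these tensors come from the
# shape, texture, and background generators.
m = torch.rand(8, 1, 32, 32)    # mask in [0, 1]
fg = torch.rand(8, 3, 32, 32)   # foreground texture
bg = torch.rand(8, 3, 32, 32)   # background
x_gen = compose(m, fg, bg)      # composite image, (8, 3, 32, 32)
```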
To generate the counterfactuals for, e.g., double-colored MNIST, run
python mnists/generate_data.py \
--weight_path mnists/experiments/cgn_double_colored_MNIST/weights/ckp.pth \
--dataset double_colored_MNIST --no_cfs 10 --dataset_size 100000

Make sure that the dataset you pass matches the weights. You can adapt the weight path to use your own weights. The command above generates ten counterfactuals per shape.
The classifier training uses Tensor datasets, so you need to save the non-counterfactual datasets as tensors. For DATASET = {colored_MNIST, double_colored_MNIST, wildlife_MNIST}, run
python mnists/generate_data.py --dataset DATASET

To train, e.g., a shape classifier (invariant to foreground and background) on wildlife MNIST, run
python mnists/train_classifier.py --dataset wildlife_MNIST_counterfactual

Add --help for info on the available options and arguments. We use the same hyperparameters for all experiments.
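Since training consumes tensor datasets, a quick sanity check is to load a generated file yourself. A minimal sketch, assuming the data is stored as an (images, labels) tensor pair; the file name and layout are assumptions, so check the output of mnists/generate_data.py:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical path and layout; inspect what generate_data.py actually writes.
images, labels = torch.load('mnists/data/wildlife_MNIST_counterfactual.pth')

loader = DataLoader(TensorDataset(images, labels), batch_size=64, shuffle=True)
x, y = next(iter(loader))
print(x.shape, y.shape)  # e.g., (64, 3, 32, 32) and (64,)
```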
ImageNet. The main functions of this sub-repo are:
- Training a CGN
- Generating data (samples, interpolations, or a whole dataset)
- Training an invariant classifier ensemble
To train a CGN, run

python imagenet/train_cgn.py --model_name MODEL_NAME

The default parameters should give you satisfactory results. You can change them in imagenet/config.yml. For more info, add --help. Weights and samples will be saved in imagenet/data/MODEL_NAME.
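If you prefer to adjust the defaults from a script rather than editing the file by hand, a YAML config can be loaded and modified as below; the key name used here is hypothetical, so check imagenet/config.yml for the real schema:

```python
import yaml

# Load the training config (PyYAML).
with open('imagenet/config.yml') as f:
    cfg = yaml.safe_load(f)

cfg['MODEL_NAME'] = 'my_cgn'  # hypothetical key, for illustration only
with open('imagenet/config.yml', 'w') as f:
    yaml.safe_dump(cfg, f)
```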
Samples. To generate a dataset of counterfactual images, run
python imagenet/generate_data.py --mode random --weights_path imagenet/weights/cgn.pth \
--n_data 100 --run_name RUN_NAME \
--truncation 0.5 --batch_sz 1

The results will be saved in imagenet/data. For more info, add --help. If you want to save only masks, textures, etc., you need to change this directly in the code (see line 206 of imagenet/generate_data.py).
The labels will be stored in a CSV file. You can read them as follows:
import pandas as pd
df = pd.read_csv(path, index_col=0)  # path: the labels CSV written by generate_data.py
df = df.set_index('im_name')
shape_cls = df['shape_cls']['RUN_NAME_0000000.png']

Generating a dataset to train a classifier. Produce one dataset with --run_name train, the other
one with --run_name val. If you have several GPUs available, you can index the name, e.g., --run_name train_GPU_NUM. The class ImagenetCounterfactual will glob all these datasets and generate a single, big
training set. Make sure to set --batch_sz 1. With a larger batch size, a batch will be saved as a single png; this is useful for visualization, not for training.
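For a concrete picture of the globbing idea, here is a rough sketch of how several indexed runs can be merged into one training set; the folder layout, CSV name, and image subdirectory are assumptions, not the repo's exact structure (see ImagenetCounterfactual for the real implementation):

```python
from glob import glob

import pandas as pd
from PIL import Image
from torch.utils.data import ConcatDataset, Dataset
from torchvision import transforms

class CounterfactualRun(Dataset):
    """One generated run: a folder of pngs plus a labels CSV (layout assumed)."""
    def __init__(self, run_dir):
        df = pd.read_csv(f'{run_dir}/labels.csv', index_col=0)  # assumed name
        self.labels = df.set_index('im_name')['shape_cls']
        self.img_dir = f'{run_dir}/ims'  # assumed subdirectory
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        name = self.labels.index[i]
        img = Image.open(f'{self.img_dir}/{name}').convert('RGB')
        return self.to_tensor(img), int(self.labels.iloc[i])

# Glob every indexed run (train_0, train_1, ...) into one big training set.
runs = [CounterfactualRun(d) for d in sorted(glob('imagenet/data/train_*'))]
train_set = ConcatDataset(runs)
```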
Interpolations. To generate interpolation sheets, e.g., from a barn (425) to a whale (147), run
python imagenet/generate_data.py --mode fixed_classes \
--n_data 1 --weights_path imagenet/weights/cgn.pth --run_name barn_to_whale \
--truncation 0.3 --interp all --classes 425 425 425 --interp_cls 147 --save_noise

You can also do counterfactual interpolations, i.e., interpolate over a single mechanism only, e.g., shape, by setting --interp shape.
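Under the hood, such a sheet amounts to a linear walk between two latent codes, fed to one mechanism while the inputs of the others are held fixed. A minimal sketch (the latent dimension and the generator call are hypothetical):

```python
import torch

def interpolate(z_from, z_to, steps=8):
    """Linearly interpolate between two latent codes."""
    ts = torch.linspace(0, 1, steps).view(-1, 1)
    return (1 - ts) * z_from + ts * z_to  # (steps, latent_dim)

z_a, z_b = torch.randn(1, 128), torch.randn(1, 128)  # hypothetical dim
zs = interpolate(z_a, z_b)
# To mimic --interp shape: feed zs to the shape mechanism only, while
# texture and background keep their original inputs.
```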
Interpolation Gif. To generate a gif like in the teaser (sample an image of class $1, then interpolate to shape $2, then background $3, then shape $4, and finally back to $1), run
./scripts/generate_teaser_gif.sh 992 293 147 330

The positional arguments are the classes; see the ImageNet labels for the available options.
Training. First, you need to make sure that you have all datasets in imagenet/data/. Download ImageNet, e.g., from Kaggle, produce a counterfactual dataset (see above), and download the Cue-Conflict and BG-Challenge datasets (via the download script in scripts).
To train a classifier on a single GPU with a pre-trained Resnet-50 backbone, run
python imagenet/train_classifier.py -a resnet50 -b 32 --lr 0.001 -j 6 \
--epochs 45 --pretrained --cf_data CF_DATA_PATH --name RUN_NAME

Again, add --help for more information on the possible arguments.
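For reference, the command above fine-tunes a standard torchvision ResNet-50. A minimal sketch of that setup with a dummy batch (not the repo's exact training loop):

```python
import torch
from torchvision import models

# Standard fine-tuning pattern: start from ImageNet-pretrained weights.
model = models.resnet50(pretrained=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

model.train()
x = torch.randn(32, 3, 224, 224)   # dummy batch (-b 32)
y = torch.randint(0, 1000, (32,))  # dummy labels
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
```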
Distributed Training. To switch to multi-GPU training, run echo $CUDA_VISIBLE_DEVICES to see if the GPUs are visible. In the case of a
single node with several GPUs, you can run, e.g.,
python imagenet/train_classifier.py -a resnet50 -b 256 --lr 0.001 -j 6 \
--epochs 45 --pretrained --cf_data CF_DATA_PATH --name RUN_NAME \
--rank 0 --multiprocessing-distributed --dist-url tcp://127.0.0.1:8890 --world-size 1

If your setup differs, e.g., several GPU machines, you need to adapt the rank and world size.
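The flags map onto the standard torch.distributed pattern: one process per GPU, with the global rank derived from the node rank and the local GPU index. A simplified sketch of that pattern (not the exact logic of train_classifier.py):

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torchvision import models

def worker(gpu, gpus_per_node, node_rank, world_size):
    # Global rank = node offset + local GPU index.
    rank = node_rank * gpus_per_node + gpu
    dist.init_process_group(backend='nccl',
                            init_method='tcp://127.0.0.1:8890',
                            world_size=world_size, rank=rank)
    torch.cuda.set_device(gpu)
    model = models.resnet50().cuda(gpu)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
    # ... training loop ...

if __name__ == '__main__':
    gpus = torch.cuda.device_count()
    # Single node: node_rank 0, world size = one process per GPU.
    mp.spawn(worker, nprocs=gpus, args=(gpus, 0, gpus))
```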
Visualization. To visualize the Tensorboard outputs, run tensorboard --logdir=imagenet/runs and open the local address in your browser.
We would like to acknowledge several repos from which we use parts of code, data, or models in our implementation: