AmbientGAN: Generative models from lossy measurements
This repository provides code to reproduce results from the paper AmbientGAN: Generative models from lossy measurements.
The training setup is as in the following diagram:
Here are a few example results:
A few more samples from AmbientGAN models trained with 1-D projections:
The rest of the README describes how to reproduce the results.
Requirements
- Python 2.7
- TensorFlow >= 1.4.0
For pip installation, use:
$ pip install -r requirements.txt
Get the data
- MNIST data is automatically downloaded
- Get the CelebA dataset here and put the JPEG files in
- Get the CIFAR-10 Python data from here and put it in
Get inference models
We need inference models for computing the inception score.
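The inception score is commonly defined as the exponential of the average KL divergence between each sample's predicted class distribution p(y|x) and the marginal p(y) over all samples. The sketch below assumes the class probabilities have already been produced by the inference model (the MNIST classifier or the Inception network); the function name `inception_score` is hypothetical and this is not the repository's own implementation. Some implementations also average the score over several splits of the samples, which this sketch omits.

```python
import numpy as np


def inception_score(probs, eps=1e-12):
    """Inception score from an (N, K) array of class probabilities p(y|x).

    IS = exp( mean over x of KL( p(y|x) || p(y) ) ),
    where p(y) is the marginal class distribution over the N samples.
    """
    probs = np.asarray(probs, dtype=np.float64)
    # Marginal p(y), shape (1, K), so it broadcasts against every row.
    marginal = probs.mean(axis=0, keepdims=True)
    # Per-sample KL divergence; eps guards against log(0).
    kl = np.sum(probs * (np.log(probs + eps) - np.log(marginal + eps)), axis=1)
    return float(np.exp(kl.mean()))
```

Confident, evenly spread predictions over K classes give a score close to K, while identical predictions for every sample give a score of 1.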
For MNIST, you can train your own by running:
$ cd ./src/mnist/inf
$ python train.py
[TODO]: Provide a pretrained model.
The Inception model used for CIFAR-10 is downloaded automatically.
Create experiment scripts
$ ./create_scripts/create_scripts.sh
This will create scripts for all the experiments in the paper.
[Optional] If you want to run only a subset of experiments, you can define the grid in
./create_scripts/DATASET_NAME/grid_*.sh; if you wish to tweak many parameters, you can change
./create_scripts/DATASET_NAME/base_script.sh. Then run
./create_scripts/create_scripts.sh as above to create the corresponding scripts (remember to remove any previous files from
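Conceptually, a grid lists candidate values for each hyperparameter, and one experiment is produced for every combination. The sketch below illustrates that expansion with `itertools.product`; it assumes the grid can be represented as a Python dict, and the function `expand_grid` and the hyperparameter names in the example are hypothetical, not the contents of the repository's `grid_*.sh` files.

```python
import itertools


def expand_grid(grid):
    """Return one dict per combination of values in `grid`.

    `grid` maps a hyperparameter name to a list of candidate values.
    Keys are sorted so the output order is deterministic.
    """
    keys = sorted(grid)
    return [dict(zip(keys, values))
            for values in itertools.product(*(grid[k] for k in keys))]


# Hypothetical grid: these names are illustrative, not the repo's flags.
grid = {
    "learning_rate": [1e-4, 2e-4],
    "batch_size": [32, 64],
}
for config in expand_grid(grid):
    print(config)
```

The number of experiments is the product of the list lengths, which is why restricting the grid is the easiest way to run only a subset.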
We provide scripts to train on multiple GPUs in parallel. For example, if you wish to use 4 GPUs, you can run:
$ ./run_scripts/run_sequentially_parallel.sh "0 1 2 3"
This will start 4 GNU screen sessions. The program in each session repeatedly acquires an experiment from
./scripts/ and runs it, one at a time. Each experiment run saves samples, checkpoints, etc. to
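One way several workers can share a queue of scripts without running any script twice is to have each worker claim a script by atomically renaming it into a separate directory: on POSIX, `os.rename` is atomic when both paths are on the same filesystem, so if two workers race for the same file, exactly one rename succeeds. The sketch below is a hypothetical illustration of that idea, not how `run_sequentially_parallel.sh` is actually implemented; the function names and the `claimed_dir` layout are assumptions.

```python
import os
import subprocess


def _ensure_dir(path):
    """Create `path` if needed, tolerating another worker creating it first."""
    try:
        os.makedirs(path)
    except OSError:
        if not os.path.isdir(path):
            raise


def claim_next_script(queue_dir, claimed_dir):
    """Atomically move one script from queue_dir into claimed_dir.

    Returns the claimed path, or None when the queue is empty.
    """
    _ensure_dir(claimed_dir)
    for name in sorted(os.listdir(queue_dir)):
        src = os.path.join(queue_dir, name)
        dst = os.path.join(claimed_dir, name)
        try:
            os.rename(src, dst)
        except OSError:
            continue  # another worker claimed this script first
        return dst
    return None


def run_worker(gpu_id, queue_dir, claimed_dir):
    """Run claimed scripts one at a time on one GPU until the queue is empty."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu_id))
    while True:
        script = claim_next_script(queue_dir, claimed_dir)
        if script is None:
            break
        subprocess.call(["bash", script], env=env)
```

Starting one such worker per GPU id gives the same "one experiment at a time per GPU" behaviour described above.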
See results as you train
You can see samples for each experiment in
EXPT_DIR is determined by the hyperparameters of the experiment; see
./src/commons/dir_def.py for how this is done.
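Deriving the output directory from the hyperparameters means each configuration writes to its own, predictable location. The sketch below shows one simple way to do this by turning sorted `key_value` pairs into path components; it is a hypothetical illustration, not the logic in `dir_def.py`, and the function name `expt_dir` and the example hyperparameters are assumptions.

```python
import os


def expt_dir(base_dir, hparams):
    """Return a path under base_dir that encodes every hyperparameter.

    Keys are sorted, so the same settings always map to the same path
    regardless of the order in which the dict was built.
    """
    parts = ["{}_{}".format(key, hparams[key]) for key in sorted(hparams)]
    return os.path.join(base_dir, *parts)


# Hypothetical hyperparameters, for illustration only.
# On POSIX this prints expts/batch_size_64/learning_rate_0.0002
print(expt_dir("expts", {"learning_rate": 0.0002, "batch_size": 64}))
```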
To aggregate the results, run:
$ python src/aggregator_mnist.py
$ python src/aggregator_cifar.py
This will create pickle files in
./results/ with the relevant data stored in pandas DataFrames.
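If you want to inspect the aggregated data outside the notebooks, the pickles can be loaded with `pandas.read_pickle`. The sketch below assumes each pickle file holds a DataFrame and uses a `.pkl` extension; both the extension and the helper name `load_results` are assumptions, not something the aggregator scripts are documented to guarantee.

```python
import glob
import os

import pandas as pd


def load_results(results_dir):
    """Load every pickled DataFrame in results_dir and stack them into one."""
    # Assumption: result files end in .pkl; adjust the pattern if they differ.
    paths = sorted(glob.glob(os.path.join(results_dir, "*.pkl")))
    frames = [pd.read_pickle(path) for path in paths]
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)


# Example usage:
# df = load_results("./results/")
```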
Now use the IPython notebooks
./plotting_cifar.ipynb to get the relevant plots. The generated plots are also saved to
./results/plots/ (make sure this directory exists).