# Tensorflow-Generative-Model-Collections
TensorFlow implementations of various GANs and VAEs.

## Environment Setup

In [None]:
# `!export DISPLAY=:0` would run in a throwaway subshell, so the variable
# would never reach this kernel's process. Set it on os.environ instead so
# the kernel (and any child process it spawns) sees it.
# NOTE(review): presumably needed by a plotting/display backend — confirm.
import os
os.environ['DISPLAY'] = ':0'

## Imports

In [None]:
import os

## GAN Variants
from GAN import GAN
from CGAN import CGAN
from infoGAN import infoGAN
from ACGAN import ACGAN
from EBGAN import EBGAN
from WGAN import WGAN
from WGAN_GP import WGAN_GP
from DRAGAN import DRAGAN
from LSGAN import LSGAN
from BEGAN import BEGAN

## VAE Variants
from VAE import VAE
from CVAE import CVAE

from utils import show_all_variables
from utils import check_folder

import tensorflow as tf
import argparse

## Configuration

### Select the Model Type (GAN or VAE)

In [None]:
# Model variant to train. It is compared against each model class's
# `model_name` when the model is instantiated further down.
# Options: 'GAN', 'CGAN', 'infoGAN', 'ACGAN', 'EBGAN', 'BEGAN', 'WGAN',
#          'WGAN_GP', 'DRAGAN', 'LSGAN', 'VAE', 'CVAE'   (default: 'GAN')
gan_type = "GAN"

### Select the Dataset

#### MNIST

In [None]:
# Dataset to train on; handed to the model constructor as `dataset_name`.
# Options: 'mnist', 'fashion-mnist', 'celebA'   (default: 'mnist')
dataset = "mnist"

In [None]:
# Download the four MNIST archives (train/test images and labels) into
# data/mnist. wget flags: -N re-downloads only when the remote copy is newer,
# -q suppresses output, -P sets the destination directory.
# NOTE(review): this cell always fetches MNIST, whatever `dataset` is set to;
# fashion-mnist / celebA presumably must be placed under data/ by hand — confirm.
# NOTE(review): yann.lecun.com has been reported to refuse these downloads
# (HTTP 403); a mirror may be required — verify before relying on this cell.
!mkdir -p data/mnist
!wget -Nq http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz -P data/mnist
!wget -Nq http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz -P data/mnist
!wget -Nq http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz -P data/mnist
!wget -Nq http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz -P data/mnist

### Set Training Parameters

In [None]:
# Training hyper-parameters passed to the model constructor.
# z_dim is presumably the latent/noise vector size — confirm against the models.
epoch, batch_size, z_dim = 20, 64, 62

# Output directories, expressed relative to the current working directory.
checkpoint_dir, result_dir, log_dir = (
    os.path.relpath(folder) for folder in ('checkpoint', 'results', 'logs')
)

## Run the Model

In [None]:
# Build, train and visualize the model selected by `gan_type` inside a
# TF1-style graph/session.
tf.reset_default_graph()  # start from an empty default graph when re-running this cell
models = [GAN, CGAN, infoGAN, ACGAN, EBGAN, WGAN, WGAN_GP, DRAGAN,
          LSGAN, BEGAN, VAE, CVAE]
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    # Instantiate the model class whose `model_name` matches `gan_type`.
    gan = None
    for model in models:
        if gan_type == model.model_name:
            gan = model(sess,
                        epoch=epoch,
                        batch_size=batch_size,
                        z_dim=z_dim,
                        dataset_name=dataset,
                        checkpoint_dir=checkpoint_dir,
                        result_dir=result_dir,
                        log_dir=log_dir)
            break
    # Fail with a clear message instead of an AttributeError on `None` below.
    if gan is None:
        raise ValueError(" [!] There is no option for %r; choose one of %s"
                         % (gan_type, [m.model_name for m in models]))

    # build graph
    gan.build_model()

    # show network architecture (variables of the built graph)
    show_all_variables()

    # train the model in this session
    gan.train()
    print(" [*] Training finished!")

    # visualize the learned generator at the final epoch index
    # (previously referenced an undefined `args`, which raised NameError)
    gan.visualize_results(epoch - 1)
    print(" [*] Testing finished!")
    