# Generative Adversarial Networks

In [None]:
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers.core import Activation
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.layers.core import Flatten
from keras.optimizers import SGD
import numpy as np
from PIL import Image
import argparse
import math
import tqdm
import os

def generator_model():
    """Build the generator network mapping a 100-d noise vector to an image.

    Two dense layers expand the noise to 128 * 7 * 7 units, which are reshaped
    to a (128, 7, 7) feature map and then upsampled twice by a factor of 2,
    each time followed by a 5x5 'same' convolution. The final tanh bounds the
    output pixel values to [-1, 1].

    NOTE(review): the (128, 7, 7) reshape and the resulting 1 x 28 x 28 output
    assume Keras is configured for channels-first image ordering -- confirm.
    """
    net = Sequential()
    stack = [
        # Dense projection of the noise vector.
        Dense(input_dim=100, output_dim=1024),
        Activation('tanh'),
        Dense(128 * 7 * 7),
        BatchNormalization(),
        Activation('tanh'),
        # Reinterpret the flat vector as a 128-channel 7 x 7 feature map.
        Reshape((128, 7, 7), input_shape=(128 * 7 * 7,)),
        UpSampling2D(size=(2, 2)),      # 7 x 7 -> 14 x 14
        Convolution2D(64, 5, 5, border_mode='same'),
        Activation('tanh'),
        UpSampling2D(size=(2, 2)),      # 14 x 14 -> 28 x 28
        Convolution2D(1, 5, 5, border_mode='same'),
        Activation('tanh'),
    ]
    for layer in stack:
        net.add(layer)
    return net

def discriminator_model():
    """Build the discriminator network scoring an image as real or generated.

    Two convolution / tanh / 2x2 max-pool stages are followed by a flatten, a
    1024-unit tanh dense layer and a single sigmoid unit, so the output is a
    value in (0, 1).

    NOTE(review): the declared input shape (1, 28, 28) assumes Keras is
    configured for channels-first image ordering -- confirm.
    """
    net = Sequential()
    stack = [
        # 'same' padding keeps the spatial size; pooling then halves it.
        Convolution2D(64, 5, 5, border_mode='same', input_shape=(1, 28, 28)),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        # No border_mode given here, so Keras' default padding applies.
        Convolution2D(128, 5, 5),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(1024),
        Activation('tanh'),
        # Single probability-like output for binary cross-entropy.
        Dense(1),
        Activation('sigmoid'),
    ]
    for layer in stack:
        net.add(layer)
    return net

def generator_containing_discriminator(generator, discriminator):
    """Stack ``generator`` and ``discriminator`` into one noise -> score model.

    Side effect: sets ``discriminator.trainable = False`` on the passed-in
    discriminator. This is intended so that training the returned model only
    updates the generator; callers that also train the discriminator on its
    own must set the flag back to True (and compile accordingly).
    """
    # Freeze the discriminator before it becomes part of the combined model.
    discriminator.trainable = False
    stacked = Sequential()
    for sub_model in (generator, discriminator):
        stacked.add(sub_model)
    return stacked

def combine_images(generated_images):
    """Tile a batch of images into a single 2-D mosaic.

    Parameters
    ----------
    generated_images: numpy array
        images shaped (num, channels, height, width); only channel 0 of each
        image is used.

    Returns
    -------
    numpy array
        2-D array of shape (rows * height, cols * width), where
        cols = floor(sqrt(num)) and rows = ceil(num / cols). Images are placed
        row-major; unused grid cells stay zero. The dtype matches the input.

    Raises
    ------
    ValueError
        if ``generated_images`` contains no images.
    """
    num = generated_images.shape[0]
    if num == 0:
        # Without this guard the grid width would be 0 and the ceil division
        # below would fail with an uninformative ZeroDivisionError.
        raise ValueError("combine_images needs at least one image")
    cols = int(math.sqrt(num))
    rows = -(-num // cols)  # integer ceiling division, no float round-off
    tile_h, tile_w = generated_images.shape[2:4]
    mosaic = np.zeros((rows * tile_h, cols * tile_w),
                      dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        r, c = divmod(index, cols)
        mosaic[r * tile_h:(r + 1) * tile_h, c * tile_w:(c + 1) * tile_w] = \
            img[0, :, :]
    return mosaic

def normalize_and_save_images(generated_images, epoch, index, path):
    """Min-max rescale an image array to 0..255 and save it as a PNG.

    The array is squeezed (size-1 axes dropped), linearly mapped so its
    minimum becomes 0 and its maximum 255, cast to uint8, and written to
    ``path + "epoch_<epoch>_batch_<index>.png"``.

    Parameters
    ----------
    generated_images: numpy array
        image data; ``train`` passes a (1, H, W) array, which squeezes to 2-D.
    epoch: int
        epoch number, used in the file name
    index: int
        mini-batch number, used in the file name
    path: str
        string prefix for the file name; it is concatenated, not joined, so a
        directory must end with a path separator.
    """
    image = np.squeeze(generated_images)
    low = image.min()
    span = image.max() - low
    if span > 0:
        image = (image - low) / span * 255.0
    else:
        # A constant image would divide by zero and yield NaNs, whose uint8
        # cast is undefined; save an all-black image instead.
        image = np.zeros(image.shape)
    filename = "epoch_" + str(epoch) + "_batch_" + str(index) + ".png"
    Image.fromarray(image.astype(np.uint8)).save(path + filename)

def train(X_train,
          saving_path=os.getcwd() + "/",
          batch_size=100,
          learning_rate=0.0005,
          momentum=0.9,
          nesterov=True,
          n_epochs=100):
    """Train a generator/discriminator pair adversarially on ``X_train``.

    Each epoch shuffles the data and, for every full mini batch, performs one
    discriminator update on real + generated images followed by one generator
    update through the frozen discriminator. Every 100th mini batch a mosaic
    of the generated images is saved under ``saving_path`` and both losses
    are printed. After training, the weights are saved to the files
    ``generator`` and ``discriminator`` in the current working directory.

    Parameters
    ----------
    X_train: numpy array
        training examples shaped num_images x 1 x 28 x 28, matching the
        discriminator's declared input shape; should be scaled to [-1, 1] to
        match the generator's tanh output
    saving_path: str
        prefix (normally a directory ending in "/") for the generated image
        files; the default is the working directory at definition time
    batch_size: int
        number of real images (and of generated images) per mini batch
    learning_rate: float
        SGD learning rate
    momentum: float
        SGD momentum parameter
    nesterov: bool
        whether to use Nesterov momentum or not
    n_epochs: int
        number of passes through the dataset
    """
    # Build the models. generator_containing_discriminator freezes the
    # discriminator (trainable = False). Keras reads ``trainable`` when it
    # builds a model's training function, so the combined model is compiled
    # while D is frozen, and D is unfrozen *before* its own compile --
    # otherwise the discriminator would be compiled with no trainable weights.
    discriminator = discriminator_model()
    generator = generator_model()
    discriminator_on_generator = \
        generator_containing_discriminator(generator, discriminator)

    # Honour the caller's ``nesterov`` flag (it used to be ignored).
    optim = SGD(lr=learning_rate, momentum=momentum, nesterov=nesterov)

    generator.compile(loss='binary_crossentropy', optimizer=optim)
    discriminator_on_generator.compile(loss='binary_crossentropy',
                                       optimizer=optim)
    discriminator.trainable = True
    discriminator.compile(loss='binary_crossentropy', optimizer=optim)

    n_batches = X_train.shape[0] // batch_size
    # Discriminator targets: the real half of each batch is labelled 1, the
    # generated half 0 (matches the concatenation order below).
    d_targets = np.array([1] * batch_size + [0] * batch_size)
    # Generator targets: G is rewarded when D labels its output as real.
    g_targets = np.array([1] * batch_size)

    for epoch in range(n_epochs):
        print("Epoch is", epoch)
        print("Number of batches", n_batches)

        # Visit the training data in a fresh random order every epoch.
        select = np.random.permutation(X_train.shape[0])

        # One discriminator step and one generator step per mini batch.
        for index in range(n_batches):
            # Sample fake images from uniform noise (100 = generator input_dim).
            noise = np.random.uniform(low=-1, high=1, size=(batch_size, 100))
            generated_images = generator.predict(noise, verbose=0)

            batch_select = select[index * batch_size:(index + 1) * batch_size]
            real_images = X_train[batch_select]

            # Discriminator step on real followed by generated images.
            X = np.concatenate((real_images, generated_images))
            d_loss = discriminator.train_on_batch(X, d_targets)

            # Generator step through the combined model with D kept fixed.
            noise = np.random.uniform(low=-1, high=1, size=(batch_size, 100))
            discriminator.trainable = False
            g_loss = discriminator_on_generator.train_on_batch(noise, g_targets)
            discriminator.trainable = True

            # Save a mosaic of generated images every 100th mini batch.
            if index % 100 == 0:
                normalize_and_save_images(
                    np.expand_dims(combine_images(generated_images), axis=0),
                    epoch, index, saving_path)
                print(60 * "-")
                print("d_loss : %f" % d_loss)
                print("g_loss : %f" % g_loss)

    # Save weights (second positional argument: overwrite existing files).
    generator.save_weights('generator', True)
    discriminator.save_weights('discriminator', True)

## Load the MNIST data

In [None]:
from keras.datasets import mnist

# Load MNIST, then map pixel values linearly so that 0 -> -1 and 255 -> 1,
# matching the generator's tanh output range.
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
# Insert a singleton channel axis after the sample axis: (N, H, W) -> (N, 1, H, W).
X_train = np.expand_dims(X_train, axis=1)

## GAN on MNIST

In [None]:
train(X_train, batch_size=100)  # GAN on the full MNIST training set

## GAN on the MNIST images of the digit 5

In [None]:
# Keep only the training images labelled as the digit 5 and train on them.
X_five = X_train[y_train == 5]
train(X_five, batch_size=100)