# Supervised GAN example for MNIST dataset
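Most of the code is the same as in gan_mnist.ipynb, except for a few modifications that feed the digit labels into the GAN: the generator is conditioned on a label, and the discriminator also classifies which digit it sees.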
Credit to the GitHub repo https://github.com/soumith/ganhacks
for its tips and techniques for training GAN models.

In [1]:
#import libraries
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Activation
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt
import sys
import numpy as np

# Adam optimizer shared by all models compiled below:
# learning rate 2e-4, beta_1 (first-moment decay) 0.5.
optimizer = Adam(2e-4, beta_1=0.5)
# number of real images (and of generated images) per training step
batch_size = 128

#generator model
def generator():
    """Build the label-conditioned generator network.

    Input: a vector of 101 values (100 random-noise values followed by one
    digit label). Three hidden blocks of
    Dense -> LeakyReLU(0.2) -> Dropout(0.2) -> BatchNormalization(momentum=0.8)
    with widths 256, 512 and 1024 are followed by a tanh Dense layer that is
    reshaped into a (28, 28, 1) image with values in [-1, 1].

    Dropout is an extra addition to the usual layout, used for regularization.

    Prints the model summary and returns the (uncompiled) Sequential model.
    """
    model = Sequential()
    for depth, width in enumerate((256, 512, 1024)):
        # only the first layer declares the input size: 100 noise values + 1 label
        first_layer_kwargs = {'input_dim': 100 + 1} if depth == 0 else {}
        model.add(Dense(width, **first_layer_kwargs))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.2))
        model.add(BatchNormalization(momentum=0.8))

    # one tanh unit per pixel, so outputs lie in [-1, 1] like the normalized training images
    model.add(Dense(28 * 28 * 1, activation='tanh'))
    model.add(Reshape((28, 28, 1)))

    model.summary()
    return model

#discriminator model
def discriminator():
    """Build the two-headed discriminator network.

    A shared MLP trunk (Flatten -> Dense 512 -> Dense 256, each Dense followed
    by LeakyReLU(0.2) and Dropout(0.2)) feeds two output heads.

    Returns:
        A functional ``Model`` that takes a (28, 28, 1) image and outputs
        ``[valid_img, valid_label]``:

        * ``valid_img``: a single sigmoid unit (0 = fake image, 1 = real image).
        * ``valid_label``: a softmax over 11 classes, where 0~9 are the digits
          and 10 means unknown/fake image.
    """
    trunk = Sequential()
    trunk.add(Flatten(input_shape=(28, 28, 1)))
    for width in (512, 256):
        trunk.add(Dense(width))
        trunk.add(LeakyReLU(alpha=0.2))
        trunk.add(Dropout(0.2))

    image = Input(shape=(28, 28, 1))
    features = trunk(image)

    # real/fake head: 0 if fake image, 1 if real image
    valid_img = Dense(1, activation='sigmoid')(features)
    # classification head: 0~9 for the digits, 10 for an unknown/fake image
    valid_label = Dense(10 + 1, activation='softmax')(features)

    return Model(inputs=image, outputs=[valid_img, valid_label])

#compile models
# D is compiled while D.trainable is still True, so D.train_on_batch updates D's weights.
# One loss per output: binary cross-entropy for the real/fake head, and sparse categorical
# cross-entropy for the 11-way label head (targets are integer class indices, not one-hot).
D = discriminator()
D.compile(loss=['binary_crossentropy','sparse_categorical_crossentropy'],
            optimizer=optimizer,
            metrics=['accuracy'])
# In this file G is only used through G.predict and inside `stacked`, so the loss and
# metric given here are never used for training.
# NOTE(review): this compile looks removable — confirm.
G = generator()
G.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

#combined model: 100 random noise values + 1 label as input -> G -> D
inputs = Input(shape=(101,))
gen = G(inputs)
# Freeze D *before* compiling `stacked`. Keras reads the trainable flag at compile time,
# so training `stacked` updates only G's weights, while D (already compiled above while
# trainable) keeps training through D.train_on_batch.
D.trainable = False
target_img,target_label = D(gen)
stacked = Model(inputs,[target_img,target_label])

#compile stacked model with the same per-output losses as D
stacked.compile(loss=['binary_crossentropy','sparse_categorical_crossentropy'],
            optimizer=optimizer,
            metrics=['accuracy'])

#load the data and scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output
(X_train, y_train), (_, _) = mnist.load_data()
# cast first: dividing the uint8 array directly would create a float64 copy (twice the memory)
X_train = X_train.astype(np.float32) / 127.5 - 1.0
#add a channel axis so images match the discriminator's input shape (28, 28, 1)
X_train = np.expand_dims(X_train, axis=3)
#1 for valid images, 0 for fake images
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
#class 10 is the discriminator's "unknown/fake" label
fake_labels = 10 * np.ones((batch_size, 1))

n_samples = X_train.shape[0]
n_batches = n_samples // batch_size

for epoch in range(50):
    # Reshuffle every epoch. Slicing X_train in a fixed order would feed identical
    # batches every epoch and always skip the same trailing n_samples % batch_size images.
    order = np.random.permutation(n_samples)
    for batch in range(n_batches):
        #get images and their digit labels for this batch
        idx = order[batch * batch_size:(batch + 1) * batch_size]
        imgs = X_train[idx]
        labels = y_train[idx].reshape(batch_size, 1)

        #generate noise and append the label to condition the generator
        noise = np.random.normal(0, 1, (batch_size, 100))
        noise = np.concatenate((noise, labels), axis=1)
        fake_imgs = G.predict(noise)

        #train discriminator on both real/fake classification and digit labels.
        # D.trainable is set to match how each model was compiled (D: trainable,
        # `stacked`: D frozen); Keras applies the flag when a model is compiled.
        D.trainable = True
        d_loss_real = D.train_on_batch(imgs, [valid, labels])  #true labels for real data
        d_loss_fake = D.train_on_batch(fake_imgs, [fake, fake_labels])  #label 10 for fake data

        #train the generator through the stacked model with fresh noise: its targets ask D
        #to call the images real and to classify them as the requested digit
        noise = np.random.normal(0, 1, (batch_size, 100))
        noise = np.concatenate((noise, labels), axis=1)
        D.trainable = False  #D's weights stay fixed during the generator's training
        g_loss = stacked.train_on_batch(noise, [valid, labels])

    print("Epoch: ", epoch)
    #losses/metrics from the last batch of the epoch; see D.metrics_names and
    #stacked.metrics_names for the meaning of each entry
    print("  D loss (real):", d_loss_real)
    print("  D loss (fake):", d_loss_fake)
    print("  G loss:", g_loss)

    #show one generated sample per digit 0~9 in a single figure to observe performance
    sample_noise = np.random.normal(0, 1, (10, 100))
    sample_noise = np.concatenate((sample_noise, np.arange(10).reshape(10, 1)), axis=1)
    samples = G.predict(sample_noise)
    fig, axes = plt.subplots(1, 10, figsize=(10, 1.5))
    for digit, ax in enumerate(axes):
        ax.imshow(samples[digit].reshape(28, 28), cmap='gray')
        ax.set_title(str(digit))
        ax.axis('off')
    plt.show()

## Results
By the end of training the results are promising for the simpler digit patterns 0–5; some of these generated digits are almost indistinguishable from real handwritten digits. With further training, the noisier outputs for 6–9 may also improve.