## Conditional WGAN for generating handwritten digits

### Importing the libraries

In [16]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import backend
from keras.optimizers import RMSprop
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense
from keras.layers import Flatten
from keras.layers import Reshape
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import BatchNormalization
from keras.layers import LeakyReLU
from keras.layers import Dropout
from keras.layers import Embedding
from keras.layers import Concatenate
import matplotlib.pyplot as plt

### Define networks

In [17]:
# Weight constraint that keeps every weight inside the hypercube
# [-clip_value, clip_value] (the WGAN weight-clipping scheme).
class ClipConstrainer(keras.constraints.Constraint):
  """Clip each weight element-wise into ``[-clip_value, clip_value]``.

  Args:
    clip_value: half-width of the allowed box; weights outside it are clipped
      to the nearest bound after every update.
  """

  def __init__(self, clip_value):
    self.clip_value = clip_value

  def __call__(self, weights):
    bound = self.clip_value
    return backend.clip(weights, -bound, bound)

  def get_config(self):
    # Serialized so saved models can rebuild the constraint.
    config = dict(clip_value=self.clip_value)
    return config

# Custom loss implementing the Wasserstein objective via signed targets.
def wasserstein_loss(y_true, y_pred):
  """Return ``mean(y_true * y_pred)``.

  With real images labelled -1 and fakes +1, minimising this pushes critic
  scores for real images up and for fakes down; the generator, trained with
  target -1, pushes its fakes' scores up.
  """
  signed_scores = y_true * y_pred
  return backend.mean(signed_scores)

def define_critic(image_shape=(28, 28, 1), n_classes=10, clip_value=0.01, learning_rate=0.00005):
  """Build and compile the conditional WGAN critic.

  The class label is embedded, projected to one image-sized feature map and
  concatenated with the input image as an extra channel. The critic ends in a
  linear ``Dense(1)``, so it emits an unbounded real-valued score.

  Args:
    image_shape: ``(height, width, channels)`` of the input images.
    n_classes: number of distinct class labels.
    clip_value: half-width of the box the clipped kernels are held in.
    learning_rate: RMSprop learning rate.

  Returns:
    A compiled Keras model mapping ``[image, label] -> score``.
  """
  # Label branch: class id -> 50-d embedding -> one (H, W, 1) feature map.
  in_label = Input(shape=(1,))
  li = Embedding(n_classes, 50)(in_label)
  n_nodes = image_shape[0] * image_shape[1]
  li = Dense(n_nodes)(li)
  li = Reshape((image_shape[0], image_shape[1], 1))(li)
  in_image = Input(shape=image_shape)
  merge = Concatenate()([in_image, li])

  weight_init = keras.initializers.RandomNormal(stddev=0.02)
  weight_constrain = ClipConstrainer(clip_value)

  model = Conv2D(64, (4, 4), (2, 2), padding="same", kernel_initializer=weight_init, kernel_constraint=weight_constrain)(merge)
  model = LeakyReLU(0.2)(model)
  model = BatchNormalization()(model)

  model = Conv2D(64, (4, 4), (2, 2), padding="same", kernel_initializer=weight_init, kernel_constraint=weight_constrain)(model)
  model = LeakyReLU(0.2)(model)
  model = BatchNormalization()(model)

  model = Flatten()(model)
  model = Dropout(0.4)(model)
  # WGAN weight clipping bounds the critic's weights; the scoring layer is
  # clipped as well so its weights are not left unbounded.
  model = Dense(1, kernel_constraint=weight_constrain)(model)

  critic = Model([in_image, in_label], model)
  opt = RMSprop(learning_rate=learning_rate)
  # No 'accuracy' metric: the critic's scores are unbounded and the targets
  # are -1/+1, so accuracy is meaningless. It would also make train_on_batch
  # return [loss, accuracy] instead of just the loss.
  critic.compile(optimizer=opt, loss=wasserstein_loss)
  return critic

def define_generator(latent_dim, n_classes=10):
  """Build the (uncompiled) conditional generator.

  Noise is projected to 128 feature maps of size 7x7 and concatenated with a
  7x7 map derived from the class label. Two stride-2 transposed convolutions
  then upsample 7x7 -> 14x14 -> 28x28, and a final tanh convolution produces a
  single-channel image with values in [-1, 1].

  Args:
    latent_dim: size of the noise vector.
    n_classes: number of distinct class labels.

  Returns:
    A Keras model mapping ``[noise, label] -> image``.
  """
  # Label branch: class id -> 50-d embedding -> a single 7x7 feature map.
  label_input = Input(shape=(1,))
  label_branch = Embedding(n_classes, 50)(label_input)
  init = keras.initializers.RandomNormal(stddev=0.02)
  label_branch = Dense(7 * 7)(label_branch)
  label_branch = Reshape((7, 7, 1))(label_branch)

  # Noise branch: latent vector -> 128 feature maps of size 7x7.
  noise_input = Input(shape=(latent_dim,))
  x = Dense(7 * 7 * 128, kernel_initializer=init)(noise_input)
  x = LeakyReLU(0.2)(x)
  x = Reshape((7, 7, 128))(x)

  x = Concatenate()([x, label_branch])

  # Two identical upsampling stages, each doubling the spatial size.
  for _ in range(2):
    x = Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same", kernel_initializer=init)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)

  image_output = Conv2D(1, (7, 7), activation="tanh", padding="same", kernel_initializer=init)(x)
  return Model([noise_input, label_input], image_output)

def define_gan(critic, generator):
  """Chain the generator into the critic to form the model that trains the generator.

  Side effect: sets ``critic.trainable = False`` before compiling, so updates
  made through the combined model are meant to change only the generator.

  Args:
    critic: critic model taking ``[image, label]``.
    generator: generator model taking ``[noise, label]``.

  Returns:
    A compiled model mapping ``[noise, label] -> critic score``.
  """
  critic.trainable = False
  noise_in, label_in = generator.input
  fake_images = generator.output
  # The same label tensor conditions both the generator and the critic.
  score = critic([fake_images, label_in])
  combined = Model([noise_in, label_in], score)
  combined.compile(optimizer=RMSprop(learning_rate=0.00005), loss=wasserstein_loss)
  return combined

### Helper functions for loading and sampling data

In [18]:
def load_real_samples():
  """Load the MNIST training split and rescale it to match a tanh generator.

  Returns:
    ``[images, labels]``: ``images`` is float32 with a trailing channel axis
    added, and pixels are mapped by ``(x - 127.5) / 127.5``. Assuming uint8
    input, this maps [0, 255] to [-1, 1]. ``labels`` are returned unchanged.
  """
  train_split, _ = keras.datasets.mnist.load_data()
  images, labels = train_split
  images = images[..., np.newaxis].astype('float32')
  images = (images - 127.5) / 127.5
  return [images, labels]

def generate_real_samples(dataset, num_samples):
  """Draw ``num_samples`` random real images (with replacement) with their labels.

  Args:
    dataset: ``[images, labels]`` arrays indexed along the first axis.
    num_samples: how many samples to draw.

  Returns:
    ``([images, labels], y)`` where ``y`` is a float array of -1.0, the
    critic's target for real images.
  """
  all_images, all_labels = dataset
  picks = np.random.randint(0, all_images.shape[0], num_samples)
  targets = np.full(num_samples, -1.0)
  return [all_images[picks], all_labels[picks]], targets

def generate_latent_points(latent_dim, num_samples, n_classes = 10):
  """Sample generator inputs: standard-normal noise plus random class labels.

  Args:
    latent_dim: size of each noise vector.
    num_samples: number of (noise, label) pairs.
    n_classes: labels are drawn uniformly from ``[0, n_classes)``.

  Returns:
    ``[z, labels]`` with ``z`` of shape ``(num_samples, latent_dim)``.
  """
  noise = np.random.randn(num_samples, latent_dim)
  class_ids = np.random.randint(0, n_classes, num_samples)
  return [noise, class_ids]

def generate_fake_samples(gen, latent_dim, num_samples, n_classes=10):
  """Generate ``num_samples`` fake images with the generator.

  Args:
    gen: generator model taking ``[noise, labels]``.
    latent_dim: size of the noise vector.
    num_samples: number of images to generate.
    n_classes: labels are drawn uniformly from ``[0, n_classes)``.

  Returns:
    ``([images, labels], y)`` where ``y`` is a float array of +1.0, the
    critic's target for generated images.
  """
  z_input, label_input = generate_latent_points(latent_dim, num_samples, n_classes)
  # verbose=0: this runs several times per training batch, and predict may
  # otherwise print a progress bar on every call.
  images = gen.predict([z_input, label_input], verbose=0)
  y = np.ones(num_samples)
  return [images, label_input], y

### Define the training loop

In [19]:
from numpy.ma.core import mean
def summarize_performance(step, gen, critic, latent_dim):
  """Save a 10x10 grid of generated digits plus generator and critic checkpoints.

  Images go to ``c_eval/`` and models to ``c_model/``. Both directories are
  created if missing, and files are numbered ``step + 1``.

  Args:
    step: zero-based epoch index.
    gen: generator model.
    critic: critic model.
    latent_dim: size of the generator's noise input.
  """
  import os

  num_samples = 100
  [X, _], _ = generate_fake_samples(gen, latent_dim, num_samples)
  # Generator output is tanh in [-1, 1]; rescale to [0, 1] for display.
  X = (X + 1) / 2

  for i in range(num_samples):
    plt.subplot(10, 10, i + 1)
    plt.axis('off')
    plt.imshow(X[i, :, :, 0], cmap="gray")
  os.makedirs('c_eval', exist_ok=True)
  plot_file_name = "generated_images_%04d.png" % (step + 1)
  # os.path.join instead of a hard-coded '\\' so paths work on every OS.
  plt.savefig(os.path.join('c_eval', plot_file_name))
  plt.close()

  os.makedirs('c_model', exist_ok=True)
  gen_file_name = "gen_%04d.h5" % (step + 1)
  gen.save(os.path.join('c_model', gen_file_name))

  critic_file_name = "critic_%04d.h5" % (step + 1)
  # Save the critic itself (previously the generator was saved under this name).
  critic.save(os.path.join('c_model', critic_file_name))
  print(f">Saved {plot_file_name}, {gen_file_name} and {critic_file_name}.")

def plot_history(c_real_hist, c_fake_hist, gan_hist):
  """Draw the three loss histories on one axis and save them to loss_history.png."""
  curves = (
      (c_real_hist, "critic_real"),
      (c_fake_hist, "critic_fake"),
      (gan_hist, "GAN"),
  )
  for values, name in curves:
    plt.plot(values, label=name)
  plt.legend()
  plt.savefig("loss_history.png")
  plt.close()

def _loss_value(result):
  """Return the scalar loss from a Keras ``train_on_batch`` result.

  ``train_on_batch`` returns ``[loss, metric, ...]`` when the model was
  compiled with metrics and a bare scalar otherwise; the loss is always first.
  """
  return float(np.ravel(result)[0])


def train_gan(gen, critic, gan, dataset, latent_dim, batch_size = 64, num_epochs = 20, num_critic = 5):
  """Train the conditional WGAN.

  Every generator update is preceded by ``num_critic`` critic updates. Each
  critic update uses half a batch of real images and half a batch of
  generated ones. After each epoch the latest losses are printed and
  ``summarize_performance`` saves samples and checkpoints. At the end the
  loss curves are plotted.

  Args:
    gen: generator model ``[noise, label] -> image``.
    critic: compiled critic model ``[image, label] -> score``.
    gan: compiled generator + frozen-critic model.
    dataset: ``[images, labels]`` as returned by ``load_real_samples``.
    latent_dim: size of the generator's noise input.
    batch_size: generator batch size; critic steps use half of it per source.
    num_epochs: number of passes over the dataset.
    num_critic: critic updates per generator update.

  Returns:
    ``(c_real_hist, c_fake_hist, g_hist)``: per-batch loss lists.
  """
  # At least one batch per epoch, so small datasets still train and the
  # per-epoch report below always has a value to print.
  batches_per_epoch = max(1, dataset[0].shape[0] // batch_size)
  half_batch = batch_size // 2
  c_real_hist, c_fake_hist, g_hist = [], [], []
  for epoch in range(num_epochs):
    for _ in range(batches_per_epoch):
      cr_tmp, cf_tmp = [], []
      for _ in range(num_critic):
        # Reals are labelled -1 and fakes +1, so minimising mean(y_true * y_pred)
        # raises the critic's scores on reals and lowers them on fakes.
        [X_real, labels_real], y_real = generate_real_samples(dataset, half_batch)
        cr_tmp.append(_loss_value(critic.train_on_batch([X_real, labels_real], y_real)))
        [X_fake, labels_fake], y_fake = generate_fake_samples(gen, latent_dim, half_batch)
        cf_tmp.append(_loss_value(critic.train_on_batch([X_fake, labels_fake], y_fake)))
      # Average only the loss values, never mixing in metric entries.
      c_real_hist.append(float(np.mean(cr_tmp)))
      c_fake_hist.append(float(np.mean(cf_tmp)))
      # Generator step: target -1 asks the critic to score fakes like reals.
      X_gan, labels = generate_latent_points(latent_dim, batch_size)
      y_gan = -np.ones(batch_size)
      g_hist.append(_loss_value(gan.train_on_batch([X_gan, labels], y_gan)))
    print(f"Epoch: {epoch+1}/{num_epochs} c_real:{c_real_hist[-1]} c_fake:{c_fake_hist[-1]} GAN:{g_hist[-1]}")
    summarize_performance(epoch, gen, critic, latent_dim)
  plot_history(c_real_hist, c_fake_hist, g_hist)
  return c_real_hist, c_fake_hist, g_hist


### Initialize the models and start training

In [None]:
# Dimensionality of the generator's noise vector.
latent_dim = 10

# Build both networks (printing their summaries), then wire them into the
# combined model used for generator updates.
critic = define_critic(image_shape=(28, 28, 1), n_classes=10)
critic.summary()
generator = define_generator(latent_dim=latent_dim, n_classes=10)
generator.summary()
gan = define_gan(critic=critic, generator=generator)

# Load MNIST and run the training loop.
dataset = load_real_samples()
train_gan(gen=generator, critic=critic, gan=gan, dataset=dataset, latent_dim=latent_dim)