<a href="https://colab.research.google.com/github/Ayan-Vishwakarma/Keras-Implementation-of-Dense-and-DC-NSGAN-WGAN-WGANGP-etc/blob/main/WGAN_GP_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt

# WGAN_GP Keras Model for faster development

Only the discriminator and generator models need to be constructed. The model takes care of the whole WGAN_GP training process.

WGAN_GP -----> Subclass of Keras Model class.

Requires a trainable generator model, a trainable discriminator model, and the dimension of the latent space on which the generator acts.

> fit() 

>>> Keras Model fit function; it requires the images whose distribution is to be learnt. Everything else works as in the standard fit.

**Important: Here batch_size must equal (mini_batch_size x n_critic), because each batch is divided into n_critic mini-batches and the critic is trained once on each of them, n_critic times in total. For this reason, the batch_size passed to fit should be divisible by n_critic for best memory utilization.
Preferred n_critic values are 2, 4, 8, ... and preferred mini_batch_size values are 16, 32, 64, ...  
This poses another problem: for larger images this approach is restricted to very small mini_batch_size values, since AutoGraph does not seem to let outer variables use graph tensors, which restricts fine-grained control. However, train_step can easily be modified to take some permutations of the batch to increase mini_batch_size.**
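
The batch-size rule above can be sketched in plain NumPy. This is only an illustration of the slicing that train_step performs; the helper names `split_for_critic` and `unused_samples` are hypothetical and not part of the notebook.

```python
import numpy as np

def split_for_critic(batch, n_critic):
    """Split one batch into n_critic equal mini-batches (same slicing as train_step)."""
    mini = len(batch) // n_critic
    return [batch[i * mini:(i + 1) * mini] for i in range(n_critic)]

def unused_samples(batch_size, n_critic):
    """Number of samples that are never sliced when batch_size is not divisible."""
    return batch_size - n_critic * (batch_size // n_critic)

mini_batch_size, n_critic = 32, 4
batch_size = mini_batch_size * n_critic  # the value to pass to fit()

# Dummy batch of 28x28 grayscale images, shape (N,H,W,C)
batch = np.zeros((batch_size, 28, 28, 1), dtype=np.float32)
chunks = split_for_critic(batch, n_critic)

# A batch of 130 images with n_critic = 4 leaves 2 images unused
leftover = unused_samples(130, 4)
```

With a divisible batch_size every image is used exactly once per step; otherwise the trailing images are loaded into memory but never reach the critic.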

> training_loop( xt , n_iter , batch_size , sampling_interval )

>>> xt :: The NumPy ndarray in which the images are stored, i.e., the training dataset. shape = (N,H,W,C)
>>> n_iter :: The number of iterations to train
>>> batch_size :: The batch size used for training
>>> sampling_interval :: The interval after which images are sampled
>>> return -> EM loss and GP loss after each iteration
>> Note: training_loop returns the original loss, while fit returns the estimated EM distance, i.e., the negative of the loss returned by training_loop.

> compile( opt_gen , opt_disc , opt_GP , n_critic , lmbda=10 )

>>> It requires 3 different optimizers: one for the generator model, one for the discriminator model for reducing the EM distance, and a third for the discriminator for reducing the gradient penalty.

>>> n_critic is the number of times the critic is trained per generator update.

>>> lmbda is the factor by which GP_loss is multiplied.

**Here 2 different optimizers are required for the discriminator model.**
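
The gradient penalty that opt_GP minimizes can be illustrated without TensorFlow. The sketch below assumes a hypothetical linear critic D(x) = sum(w * x), chosen only because its gradient with respect to the input is known analytically (it equals w); the notebook itself obtains this gradient with nested tf.GradientTape calls on the real discriminator.

```python
import numpy as np

rng = np.random.default_rng(0)
lmbda = 10.0  # same default as compile(..., lmbda=10)

# Stand-ins for a mini-batch of real and generated images, shape (N,H,W,C)
x_real = rng.normal(size=(8, 4, 4, 1))
x_fake = rng.normal(size=(8, 4, 4, 1))

# Random interpolation between real and generated samples, one t per sample
t = rng.uniform(size=(8, 1, 1, 1))
x_hat = t * x_real + (1 - t) * x_fake

# Hypothetical linear critic: D(x) = sum(w * x), so dD/dx = w for every sample
w = np.full((4, 4, 1), 0.5)
grad = np.broadcast_to(w, x_hat.shape)

# Per-sample Euclidean norm of the gradient over the H, W and C axes
norms = np.sqrt(np.sum(grad ** 2, axis=(1, 2, 3)))

# Penalize the deviation of the gradient norm from 1
gp_loss = lmbda * np.mean((norms - 1.0) ** 2)
```

Here every gradient norm is 2, so the penalty is lmbda * (2 - 1)^2 = 10. A critic whose gradient norm is exactly 1 at the interpolated points would incur no penalty.
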
> call(inputs , with_discriminator=None)

>>> inputs :: The input given to the generator. This input is drawn from the latent space.

>>> with_discriminator :: Boolean. If true, the outputs are further passed to the discriminator. Only valid if the discriminator exists.

> RemoveDiscriminator(self) :: Removes the discriminator from the model, leaving only the generator model for generating images. Can be used when training is complete.

> SampleImages(self,x,y) :: Samples images from the generator. A total of x multiplied by y images are sampled and displayed in x rows and y columns.

> EM(labels,logits) :: Computes the EM loss from the binary labels, which mark the class each sample belongs to (1 for real, 0 for generated), and the discriminator outputs. Its negative is the estimate of the EM distance.
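
As a rough illustration of what EM computes, here is a NumPy transcription of the same formula on a tiny hand-made batch. The function name `em_loss` and the sample values are made up for this example.

```python
import numpy as np

def em_loss(labels, logits, eps=1e-10):
    """Mean critic output on generated samples minus mean critic output on real samples."""
    real_mean = np.sum(labels * logits) / (np.sum(labels) + eps)
    fake_mean = np.sum((1.0 - labels) * logits) / (np.sum(1.0 - labels) + eps)
    return fake_mean - real_mean

# Two real samples (label 1) followed by two generated samples (label 0)
labels = np.array([[1.0], [1.0], [0.0], [0.0]])
logits = np.array([[2.0], [4.0], [-1.0], [1.0]])  # critic outputs

loss = em_loss(labels, logits)  # real mean is 3, generated mean is 0
em_distance = -loss             # the quantity that fit reports as EM_Distance
```

The critic lowers this loss by scoring real images higher than generated ones, which is why the negative of the loss serves as the distance estimate.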

Attributes:

generator :: The generator model for generating images

discriminator :: The critic model, which estimates the EM distance between the current generator's distribution and the real distribution.

z_dim :: Latent space dimension

In [171]:
class WGAN_GP(keras.models.Model):
 
  def __init__(self,generator,discriminator,z_dim,**kwargs):
    super().__init__(**kwargs)
    self.with_disc = True
    self.generator = generator
    self.discriminator = discriminator
    self.z_dim = z_dim
    self.gan = keras.Sequential([generator,discriminator])
    assert ((self.gan.layers[0].trainable == True) and (self.gan.layers[1].trainable == True)),"Generator and Discriminator should be trainable"
    
  def call(self,inputs,with_discriminator=None):
    if with_discriminator and self.with_disc:
      return self.gan(inputs)      
    else:
      return self.generator(inputs)
 
  def compile(self,opt_gen,opt_disc,opt_GP,n_critic,lmbda=10):
    super().compile(optimizer = opt_gen,loss = self.EM)
    self.opt_gen = opt_gen
    self.opt_disc= opt_disc
    self.opt_GP = opt_GP
    self.n_critic = n_critic
    assert ( lmbda > 0 ),"lambda value should be strictly greater than 0"
    self.lmbda = lmbda
 
  def train_step(self,imgs):
    if isinstance(imgs,tuple):
      # fit may pass (x, y) tuples; only the images are needed
      imgs = imgs[0]
    batch_size =  (tf.shape(imgs)[0] // self.n_critic)
    
    for i in range(self.n_critic):
      x = imgs[i*batch_size:(i+1)*batch_size]

      x_ = self.generator(tf.random.normal((batch_size,self.z_dim)))
 
      with tf.GradientTape(watch_accessed_variables=True) as tape:
        xin = tf.concat([x,x_],axis=0)
        yin = tf.concat([tf.ones((batch_size,1)),tf.zeros((batch_size,1))],axis = 0)
        loss = self.EM(yin,self.discriminator(xin))
 
      grads = tape.gradient(loss,self.discriminator.trainable_weights)#
      self.opt_disc.apply_gradients(zip(grads,self.discriminator.trainable_weights))#
      del xin,yin,grads
 
      t = tf.random.uniform((batch_size,1,1,1))
      xche = t*x + (1-t)*x_
 
      with tf.GradientTape(watch_accessed_variables=True,persistent = True) as Tape:
        Tape.watch(xche)
        with tf.GradientTape(watch_accessed_variables=True,persistent=True) as tape:
          tape.watch(xche)
          y =  self.discriminator(xche)
        Dx = tape.gradient(y,xche)
        GP_loss = self.lmbda * tf.math.reduce_mean((tf.math.reduce_euclidean_norm(Dx,axis=[1,2,3]) - 1.)**2)
 
      grads = Tape.gradient(GP_loss,self.discriminator.trainable_weights,unconnected_gradients=tf.UnconnectedGradients.ZERO)#
 
      del tape,Tape,xche,y,Dx,x,x_
      self.opt_GP.apply_gradients(zip(grads,self.discriminator.trainable_weights))#
 
    with tf.GradientTape(watch_accessed_variables=True) as tape:
      xin = tf.random.normal(tf.stack([batch_size,self.z_dim],axis=0))
      yin = tf.ones((batch_size,1))
      gloss = self.EM(yin,self.gan(xin))
    
    grads = tape.gradient(gloss,self.generator.trainable_weights)
    self.opt_gen.apply_gradients(zip(grads,self.generator.trainable_weights))
 
    return {"EM_Distance":-loss,"GP_loss":GP_loss,"generative_loss":gloss}

  def training_loop(self,xt,n_iter,batch_size,sampling_interval):
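    # Compile the critic while it is trainable and keep a handle on its weights.
    self.discriminator.compile( optimizer = self.opt_disc,loss = self.EM)
    disc_weight = self.discriminator.trainable_weights
    # Freeze the critic before compiling the stacked model (the usual Keras GAN ordering),
    # so that the generator update through self.gan leaves the critic weights untouched.
    self.discriminator.trainable = False
    assert self.discriminator.trainable == False
    self.gan.compile( optimizer = self.opt_gen,loss = self.EM)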

    label_discriminator = np.concatenate([np.ones(batch_size),np.zeros(batch_size)],axis=0)
    label_generator = np.ones(batch_size)
    losses = []
    
    for i in range(n_iter):
      for j in range(self.n_critic):
          
        x = xt[np.random.randint(len(xt),size= batch_size)]
        x_ = self.generator(np.random.randn(batch_size,self.z_dim))
        loss = self.discriminator.train_on_batch(np.concatenate([x,x_],axis=0),label_discriminator)
        
        t = tf.random.uniform((batch_size,1,1,1))
        xche = t*x + (1-t)*x_
        
        with tf.GradientTape(watch_accessed_variables=True,persistent = True) as Tape:
          xche = tf.Variable(xche)
          Tape.watch(xche)
          with tf.GradientTape(watch_accessed_variables=True,persistent=True) as tape:
            tape.watch(xche)
            y =  self.discriminator(xche)
          Dx = tape.gradient(y,xche)
          GP_loss = self.lmbda * tf.math.reduce_mean((tf.math.reduce_euclidean_norm(Dx,axis=[1,2,3]) - 1.)**2)
        
        losses.append([loss,GP_loss.numpy()])
        grads = Tape.gradient(GP_loss,disc_weight,unconnected_gradients=tf.UnconnectedGradients.ZERO)

        del tape,Tape,xche,y,Dx,GP_loss
        self.opt_GP.apply_gradients(zip(grads,disc_weight))
        
      gloss = self.gan.train_on_batch(np.random.randn(batch_size,self.z_dim),label_generator)
      
      if i % sampling_interval == 0:
        print(i,losses[-1],gloss)
        self.SampleImages(4,4)

    self.discriminator.trainable = True
    return losses
 
  def RemoveDiscriminator(self):
    if self.with_disc == True:
      del self.discriminator
      self.with_disc = False
 
  def SampleImages(self,x,y):
    imgs = self.generator(np.random.randn(x*y,self.z_dim))
    # squeeze=False keeps ax two-dimensional even when x or y is 1
    fig,ax = plt.subplots(x,y,figsize=(y,x),squeeze=False)
    if imgs.shape[-1] == 1:
      # drop the channel axis so imshow receives (H,W) grayscale images
      imgs = tf.reshape(imgs,(imgs.shape[0],imgs.shape[1],imgs.shape[2]))
    for i in range(x*y):
      # row index is i//y and column index is i%y in a grid of x rows and y columns
      ax[i//y,i%y].imshow(imgs[i],cmap="gray")
 
  def EM(self,labels,logits):
    return ( -tf.reduce_sum(labels * logits) / (tf.reduce_sum(labels) + 1e-10)  + tf.reduce_sum((1. - labels) * (logits)) / (tf.reduce_sum((1. - labels)) + 1e-10))

#### Monitoring Losses

MonitorEMloss class ----->
 Keras callback for monitoring loss and sampling images if needed.

> MonitorEMloss( model,sampling_interval,batch_size,imgs,sample_images) 

1.   model :: The model to which this callback is applied.
2.   sampling_interval :: The number of training steps between loss measurements.
3.   batch_size :: The number of real and fake images sampled for calculating the loss.
4.   imgs :: Training data images.
5.   sample_images :: Boolean. Whether images are also displayed while monitoring the loss.

Since the train_step used by model.fit is wrapped in @tf.function, the graph that is created makes it harder to save the losses outside the graph. That is why a callback is used here instead of saving the loss directly inside the model.
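
The counter logic of the callback can be checked in isolation. The sketch below reproduces only the counting from on_train_batch_begin in plain Python; `firing_steps` is a hypothetical helper written for this illustration.

```python
def firing_steps(n_batches, sampling_interval):
    """Return the 1-based batch numbers at which the callback would record a loss."""
    counter, fired = 0, []
    for step in range(1, n_batches + 1):
        counter += 1
        # same test as the callback: strictly greater than the interval
        if counter > sampling_interval:
            counter = 0
            fired.append(step)
    return fired

steps = firing_steps(12, 3)
```

Because the comparison is strict, a sampling_interval of 3 records a loss on every 4th batch (batches 4, 8 and 12 here), not every 3rd.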

In [159]:
class MonitorEMloss(keras.callbacks.Callback):
  def __init__(self,model,sampling_interval,batch_size,imgs,sample_images=False,**kwargs):
    self.model = model
    self.sampling_interval = sampling_interval
    self.counter = 0
    self.losses = []
    self.batch_size = batch_size
    self.imgs = imgs
    self.sample_images = sample_images
 
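  def on_train_batch_begin(self,batch,logs=None):
    self.counter += 1
    if self.counter > self.sampling_interval :
      self.counter = 0
      model = self.model
      # real images drawn at random from the training data, plus freshly generated fakes
      real = self.imgs[np.random.randint(len(self.imgs),size=self.batch_size)]
      fake = model(np.random.randn(self.batch_size,model.z_dim))
      # labels: 1 for real, 0 for fake, passed first to match EM(labels,logits)
      labels = np.concatenate([np.ones((self.batch_size,1)),np.zeros((self.batch_size,1))],axis=0).astype("float32")
      logits = model.discriminator(np.concatenate([real,fake],axis=0))
      self.losses.append(model.EM(labels,logits).numpy())
      if self.sample_images:
        model.SampleImages(4,4)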

Important Note: Train for more than 300 iterations before concluding whether the model is stabilizing with the given hyperparameters. The loss, especially the GP loss, fluctuates strongly at first but stabilizes after 300 iterations or more.

It may also happen that the model, after converging for 5000 iterations, suddenly diverges. In those cases, methods such as equalized learning rates and pixelwise normalization, which reduce the magnitude of the gradients, should be considered.
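
For reference, pixelwise normalization rescales the feature vector at each pixel to unit mean square across channels. The NumPy sketch below shows the operation on a channels-last array; it is an illustration only, the name `pixel_norm` and the value of `eps` are assumptions, and the notebook does not include such a layer.

```python
import numpy as np

def pixel_norm(x, eps=1e-8):
    """Normalize each pixel's feature vector to unit mean square over the channel axis."""
    return x / np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)

# One image with a single pixel and two channels, shape (N,H,W,C)
x = np.array([[[[3.0, 4.0]]]])
y = pixel_norm(x)

# After normalization the mean of the squared channel values is close to 1
mean_square = np.mean(y ** 2, axis=-1)
```

In a Keras generator this would typically be applied after each convolution block, for example as a small custom layer.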