<a href="https://colab.research.google.com/github/Ayan-Vishwakarma/Keras-Implementation-of-Dense-and-DC-NSGAN-WGAN-WGANGP-etc/blob/main/WGAN_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt

### WGAN Keras model for faster development

Only the discriminator and generator models need to be constructed. The WGAN model takes care of the rest of the training process.

WGAN -----> Subclass of the Keras Model class.

Requires a trainable generator model, a trainable discriminator model, and the dimension of the latent space on which the generator acts.

>fit()

>>>Keras Model fit function; it requires the images whose distribution is to be learnt. Everything else works as in a standard Keras model.

**Here the first n_critic batches of the given batch_size are used to train the discriminator, and the next batch is used to train the generator.**
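
The alternation described above can be sketched in plain Python. This is only an illustration of the counter logic, assuming the counter starts at 1 and wraps at n_critic + 1 as in the model's train_step; the helper name `schedule` is hypothetical and not part of the class.

```python
def schedule(n_batches, n_critic):
    """Return which network each successive batch would train."""
    interval = 1            # assumed initial counter value
    period = n_critic + 1   # n_critic critic batches followed by 1 generator batch
    plan = []
    for _ in range(n_batches):
        # A non-zero counter means a critic batch, zero means a generator batch
        plan.append("critic" if interval != 0 else "generator")
        interval = (interval + 1) % period
    return plan

plan = schedule(n_batches=10, n_critic=4)
print(plan)
```

With n_critic = 4, every fifth batch goes to the generator, so the number of generator updates per epoch is roughly the number of batches divided by n_critic + 1.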

>training_loop( xt , n_iter , batch_size , sampling_interval )

>>>xt :: The numpy ndarray in which the images are stored, i.e., the training dataset. shape = (N,H,W,C)

>>>n_iter :: The number of iterations to train for

>>>batch_size :: The batch size used for training

>>>sampling_interval :: The number of iterations between image samplings

>>>return -> The estimated EM distance after each iteration, as a list.

>>>Note: both training_loop and fit report the estimated EM distance, i.e., the negative of the critic loss.

>compile( opt_gen , opt_disc , n_critic=4 , clip_value=0.01 )

>>>It requires 2 different optimizers: one for the generator model and one for the discriminator model, which minimizes the EM loss.

>>>n_critic :: The number of critic training steps per generator training step

>>>clip_value :: After each update, the critic's weights are clipped to lie between -clip_value and clip_value. Must be strictly greater than zero.

**Here only one optimizer is required for the discriminator model, since there is no gradient penalty term.**
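
Weight clipping can be illustrated with a small NumPy sketch. It assumes the weights are plain arrays rather than Keras variables, and the helper name `clip_weights` is hypothetical; the class itself uses `tf.clip_by_value` on the critic's variables.

```python
import numpy as np

def clip_weights(weights, clip_value=0.01):
    """Clip every weight array to the range [-clip_value, clip_value]."""
    assert clip_value > 0, "clip_value should be strictly greater than zero"
    return [np.clip(w, -clip_value, clip_value) for w in weights]

# Two toy "layers" with some entries outside the allowed range
weights = [np.array([-0.5, 0.003, 0.2]), np.array([[0.01, -0.02]])]
clipped = clip_weights(weights, clip_value=0.01)
print(clipped[0])
print(clipped[1])
```

Values already inside the range are left untouched; only the out-of-range entries are moved to the nearest bound.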

>call(inputs , with_discriminator=None)

>>>inputs :: The input given to the generator. This input is drawn from the latent space.

>>>with_discriminator :: Boolean. If true, the generator outputs are further passed to the discriminator. Only valid while the discriminator exists.

>RemoveDiscriminator() :: Removes the discriminator from the model, leaving only the generator for generating images. Can be used when training is complete.

>SampleImages( x , y , scale=1. ) :: Samples images from the generator. A total of x multiplied by y images are sampled and displayed in x rows and y columns.
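
One way to place x multiplied by y images into x rows and y columns is a row-major mapping from the flat sample index to a grid cell. The helper `grid_positions` below is a hypothetical illustration, not a method of the class.

```python
def grid_positions(x, y):
    """Map sample index i to (row, column) in a grid of x rows and y columns."""
    # divmod(i, y) gives (i // y, i % y), i.e. row-major order
    return [divmod(i, y) for i in range(x * y)]

positions = grid_positions(2, 3)
print(positions)
```

The row index never exceeds x - 1 and the column index never exceeds y - 1, so the mapping also works for non-square grids.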

>EM( labels , logits ) :: Computes the EM loss given binary labels (1 for real images, 0 for generated images) and the discriminator outputs: the mean output on generated images minus the mean output on real images. Its negative is the estimated EM distance.
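
The EM loss can be checked by hand on a tiny batch. The following is a NumPy re-statement of the same formula, offered as a sketch; the function name `em_loss` and the sample numbers are assumptions made for illustration.

```python
import numpy as np

def em_loss(labels, logits, eps=1e-10):
    """Mean critic output on fakes minus mean critic output on reals."""
    labels = np.asarray(labels, dtype=np.float64)
    logits = np.asarray(logits, dtype=np.float64)
    # eps guards against division by zero when a class is absent from the batch
    real_mean = np.sum(labels * logits) / (np.sum(labels) + eps)
    fake_mean = np.sum((1.0 - labels) * logits) / (np.sum(1.0 - labels) + eps)
    return -real_mean + fake_mean

labels = [1, 1, 0, 0]            # two real images, two generated images
logits = [2.0, 4.0, 1.0, -1.0]   # critic outputs
loss = em_loss(labels, logits)
print(loss, -loss)               # loss is close to -3, estimated EM distance close to 3
```

When every label is 1, as in the generator step, the second term vanishes and the loss reduces to the negative mean critic output on the generated images.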

Attributes:

generator :: The generator model for generating images

discriminator :: The critic model which estimates the EM distance between the current generator's distribution and real distribution.

z_dim :: Latent space dimension

In [52]:
class WGAN(keras.models.Model):
 
  def __init__(self,generator,discriminator,z_dim,**kwargs):
    super().__init__(**kwargs)
    self.with_disc = True
    self.generator = generator
    self.discriminator = discriminator
    self.z_dim = z_dim
    self.gan = keras.Sequential([generator,discriminator])
    assert ((self.gan.layers[0].trainable == True) and (self.gan.layers[1].trainable == True)),"Generator and Discriminator should be trainable"
    # build() expects a full input shape, including the batch dimension
    self.generator.build((None,self.z_dim))
    self.discriminator.build(self.generator.output_shape)
    class EM(keras.losses.Loss):
      def call(self,labels,logits):
        return ( -tf.reduce_sum(labels * logits) / (tf.reduce_sum(labels) + 1e-10)  + tf.reduce_sum((1. - labels) * (logits)) / (tf.reduce_sum((1. - labels)) + 1e-10))
    self.EM = EM()
        
  def call(self,inputs,with_discriminator=None):
    if with_discriminator and self.with_disc:
      return self.gan(inputs)      
    else:
      return self.generator(inputs)
 
  def compile(self,opt_gen,opt_disc,n_critic=4,clip_value=0.01):
    super().compile(optimizer = opt_gen,loss = self.EM)
    self.opt_gen = opt_gen
    self.opt_disc= opt_disc
    self.n_critic = n_critic
    self.nc = tf.Variable(n_critic+1,dtype = tf.int32,trainable = False)
    self.intervals = tf.Variable(1,dtype = tf.int32,trainable = False)
    assert clip_value>0 , "clip_value should be strictly greater than zero"
    self.clip_value = clip_value
    self.__d_loss__ = tf.Variable([0],dtype = tf.float32,trainable = False)
    self.__g_loss__ = tf.Variable([0],dtype = tf.float32,trainable = False)
  
  def train_step(self,imgs):
    if isinstance(imgs,tuple):
      imgs = imgs[0]
    batch_size = tf.shape(imgs)[0]
    
    tf.cond(self.intervals != 0, lambda: self.__train_discriminator__(imgs,batch_size),lambda:self.__train_generator__(imgs,batch_size))
    self.intervals.assign((self.intervals+1)%self.nc) 
  
    return {"EM_Distance":-self.__d_loss__,"generative_loss":self.__g_loss__}

  def __train_discriminator__(self,imgs,batch_size):
    for i in range(self.n_critic):
      x = imgs
      x_ = self.generator(tf.random.normal((batch_size,self.z_dim)))

      with tf.GradientTape(watch_accessed_variables=True) as tape:
        xin = tf.concat([x,x_],axis=0)
        yin = tf.concat([tf.ones((batch_size,1)),tf.zeros((batch_size,1))],axis = 0)
        loss = self.EM(yin,self.discriminator(xin))

      grads = tape.gradient(loss,self.discriminator.trainable_weights)
      self.opt_disc.apply_gradients(zip(grads,self.discriminator.trainable_weights))
      # Clip only the trainable weights, as training_loop does
      for j in self.discriminator.trainable_weights:
        j.assign(tf.clip_by_value(j,clip_value_min= -self.clip_value,clip_value_max= self.clip_value))
      del xin,yin,grads
      self.__d_loss__.assign(tf.reshape(loss,self.__d_loss__.shape))

  def __train_generator__(self,imgs,batch_size):

    with tf.GradientTape(watch_accessed_variables=True) as tape:
      xin = tf.random.normal(tf.stack([batch_size,self.z_dim],axis=0))
      yin = tf.ones((batch_size,1))
      gloss = self.EM(yin,self.gan(xin))
    
    grads = tape.gradient(gloss,self.generator.trainable_weights)
    self.opt_gen.apply_gradients(zip(grads,self.generator.trainable_weights))
    self.__g_loss__.assign(tf.reshape(gloss,self.__g_loss__.shape))


  def training_loop(self,xt,n_iter,batch_size,sampling_interval):
    # Compile the critic while it is trainable, then freeze it before compiling
    # the stacked model so that gan.train_on_batch only updates the generator.
    disc_weight = self.discriminator.trainable_weights
    self.discriminator.compile( optimizer = self.opt_disc,loss = self.EM)
    self.discriminator.trainable = False
    self.gan.compile( optimizer = self.opt_gen,loss = self.EM)

    losses = []
    for i in range(n_iter):
      for j in range(self.n_critic):
        ind = np.random.randint(len(xt),size= batch_size)
        # Real images are labelled 1 and generated images 0
        fake = self.generator(np.random.randn(batch_size,self.z_dim))
        xin = np.concatenate([xt[ind],fake],axis=0)
        yin = np.concatenate([np.ones((batch_size,1)),np.zeros((batch_size,1))],axis=0)
        loss = self.discriminator.train_on_batch(xin,yin)
        for k in disc_weight:
          k.assign(tf.clip_by_value(k,clip_value_min= -self.clip_value,clip_value_max= self.clip_value))
      self.gan.train_on_batch(np.random.randn(batch_size,self.z_dim),np.ones((batch_size,1)))
      losses.append(-loss)
      if i % sampling_interval == 0:
        print(i,losses[-1])
        self.SampleImages(4,4)  
    self.discriminator.trainable = True  
    return losses
 
  def SampleImages(self,x,y,scale=1.):
    imgs = self.generator(np.random.randn(x*y,self.z_dim))
    # squeeze=False keeps ax two-dimensional even when x or y is 1
    fig,ax = plt.subplots(x,y,figsize=(y*scale,x*scale),squeeze=False)
    if imgs.shape[-1] == 1:
      # Drop the channel axis so that imshow accepts grayscale images
      imgs = tf.reshape(imgs,(imgs.shape[0],imgs.shape[1],imgs.shape[2]))
    for i in range(x*y):
      # Row-major layout: image i goes to row i//y, column i%y
      ax[i//y,i%y].imshow(imgs[i],cmap="gray")
  
  def RemoveDiscriminator(self):
    if self.with_disc == True:
      del self.discriminator
      self.with_disc = False

### MonitorEMloss

MonitorEMloss class -----> Keras callback for monitoring loss and sampling images if needed.

>MonitorEMloss( model,sampling_interval,batch_size,imgs,sample_images)

1. model :: The model to which this callback is applied.

2. sampling_interval :: The number of training steps between loss measurements.

3. batch_size :: The number of real and fake images sampled for calculating the loss.

4. imgs :: Training data images.

5. sample_images :: Boolean. Whether images are also displayed while monitoring the loss.

Since the train_step used by model.fit is compiled with @tf.function, it is hard to save the losses out of the resulting graph. That is why a callback is used here instead of saving the loss directly inside the model.
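
The idea of measuring a quantity outside the training graph at a fixed cadence can be sketched without any framework. The class `IntervalRecorder` and the `measure` callable below are hypothetical stand-ins, and the sketch measures on every `interval`-th step by its own convention; it only illustrates the counter pattern.

```python
class IntervalRecorder:
    """Call `measure` once every `interval` steps and keep the results."""

    def __init__(self, interval, measure):
        self.interval = interval
        self.measure = measure   # any zero-argument function returning a number
        self.counter = 0
        self.values = []

    def on_step(self):
        self.counter += 1
        if self.counter >= self.interval:
            self.counter = 0
            self.values.append(self.measure())

steps_seen = []
recorder = IntervalRecorder(interval=3, measure=lambda: len(steps_seen))
for step in range(10):
    steps_seen.append(step)
    recorder.on_step()
print(recorder.values)
```

Because the measurement runs in ordinary Python between training steps, its results can be appended to a list without any special handling.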

In [4]:
class MonitorEMloss(keras.callbacks.Callback):
  def __init__(self,model,sampling_interval,batch_size,imgs,sample_images=False,**kwargs):
    super().__init__()
    # set_model stores the model; fit() calls it again with the same model
    self.set_model(model)
    self.sampling_interval = sampling_interval
    self.counter = 0
    self.losses = []
    self.batch_size = batch_size
    self.imgs = imgs
    self.sample_images = sample_images
 
  def on_train_batch_begin(self,batch,logs=None):
    self.counter += 1
    if self.counter > self.sampling_interval :
      self.counter = 0
      model = self.model
      ind = np.random.randint(len(self.imgs),size=self.batch_size)
      fake = model(np.random.randn(self.batch_size,model.z_dim))
      xin = np.concatenate([self.imgs[ind],fake],axis=0)
      # float32 labels match the dtype of the discriminator outputs
      yin = np.concatenate([np.ones((self.batch_size,1),dtype="float32"),np.zeros((self.batch_size,1),dtype="float32")],axis=0)
      # EM expects (labels, logits); its negative is the estimated EM distance
      self.losses.append(-model.EM(yin,model.discriminator(xin)).numpy())
      if self.sample_images:
        model.SampleImages(4,4)

Note: The right discriminator is found by gradient descent. So the loss initially increases, and then, once the critic that best estimates the Wasserstein-1 distance has been found, the loss starts to decrease. So, train for at least 300 iterations before concluding whether or not the given hyperparameters stabilize the model.