
In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt

**Make sure that either the size of the training data is divisible by batch_size, or verbose=0 is passed to the fit method. Either option avoids an unexplained error that fit raises when it calls its own callbacks. The error occurs because the last batch has a different size from the other batches, which happens when the training data size is not divisible by batch_size.**
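One way to satisfy the first condition is to drop the trailing samples that do not fill a whole batch before calling fit(). The sketch below assumes the training images are a NumPy array whose first axis indexes samples; the helper name `trim_to_batch_multiple` is hypothetical and not part of this notebook.

```python
import numpy as np

def trim_to_batch_multiple(x, batch_size):
    """Hypothetical helper: keep the leading samples whose count is a multiple of batch_size.

    The trailing len(x) % batch_size samples are dropped, so every batch
    that fit() builds has the same size.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    n_full = (x.shape[0] // batch_size) * batch_size
    return x[:n_full]

# Example: 1000 dummy 28x28x1 images with batch_size=64 keep 15 full batches
x = np.zeros((1000, 28, 28, 1), dtype="float32")
x_trimmed = trim_to_batch_multiple(x, 64)
print(x_trimmed.shape)  # (960, 28, 28, 1)
```

With a `tf.data` pipeline, `dataset.batch(batch_size, drop_remainder=True)` has the same effect of discarding the incomplete last batch.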

NSGAN ---> subclass of the Keras class Model

NSGAN(generator,discriminator,z_dim) 

It requires a generator model, a discriminator model and the latent space dimension.
The generator and discriminator passed in need not be compiled; uncompiled models also work.

However, the NSGAN model instance needs to be compiled before using the Keras fit() method or the training_loop method.


> compile( opt_gen,opt_disc)
>>> opt_gen :: Keras optimizer for the generator, used to minimize the binary_crossentropy loss.

>>> opt_disc :: Keras optimizer for the discriminator, used to minimize the binary_crossentropy loss.
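Both optimizers minimize binary cross-entropy, but with different labels. The discriminator sees real images labelled 1 and generated images labelled 0; the generator is trained through the stacked model with its generated images labelled 1, which gives the non-saturating loss -log D(G(z)) instead of the saturating log(1 - D(G(z))). The NumPy sketch below only illustrates these formulas on made-up discriminator outputs; the function names are hypothetical and not part of this notebook.

```python
import numpy as np

EPS = 1e-7  # clip probabilities away from 0 and 1 to avoid log(0)

def bce(y_true, y_pred):
    """Mean binary cross-entropy between 0/1 labels and predicted probabilities."""
    y_pred = np.clip(y_pred, EPS, 1 - EPS)
    return float(np.mean(-(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))))

def discriminator_loss(d_real, d_fake):
    # Real images are labelled 1 and generated images 0
    y = np.concatenate([np.ones_like(d_real), np.zeros_like(d_fake)])
    return bce(y, np.concatenate([d_real, d_fake]))

def generator_loss(d_fake):
    # Non-saturating loss: generated images labelled 1, i.e. the mean of -log D(G(z))
    return bce(np.ones_like(d_fake), d_fake)

d_real = np.array([0.9, 0.8])  # hypothetical D(x) on real images
d_fake = np.array([0.1, 0.2])  # hypothetical D(G(z)) on generated images
print(discriminator_loss(d_real, d_fake))
print(generator_loss(d_fake))
```

The generator loss is large when the discriminator confidently rejects the generated images, which is exactly when the generator needs strong gradients.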

> training_loop( xt , n_iter , batch_size , sampling_interval )

>>> xt :: NumPy ndarray holding the training images, i.e. the training dataset, with shape (N,H,W,C).

>>> n_iter :: Number of training iterations

>>> batch_size :: Batch size used for training

>>> sampling_interval :: Number of iterations between two rounds of printing the losses and sampling images

>>> return -> List of discriminator losses, one per iteration (the mean over the real and the generated half-batch). Every sampling_interval iterations the method also prints the iteration number, the discriminator's [loss, binary accuracy] and the stacked GAN's [loss, binary accuracy] from the generator step, and displays a 4x4 grid of sampled images.
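Each training_loop iteration draws batch_size random real images from xt, a batch of latent vectors for the generator, and the 0/1 labels for the two discriminator half-batches and the generator step. The NumPy sketch below shows only that bookkeeping, with indices drawn from the full dataset size xt.shape[0]; the helper names and the use of `np.random.default_rng` are assumptions, and no model is trained.

```python
import numpy as np

def make_iteration_batches(xt, batch_size, z_dim, rng):
    """Hypothetical helper: build the arrays one training_loop iteration feeds to Keras."""
    # Sample real images (with replacement) from the whole dataset
    ind = rng.integers(0, xt.shape[0], size=batch_size)
    real = xt[ind]
    # Latent vectors for the generator, drawn from a standard normal
    z = rng.standard_normal((batch_size, z_dim)).astype("float32")
    y_real = np.ones((batch_size, 1), dtype="float32")   # discriminator: real -> 1
    y_fake = np.zeros((batch_size, 1), dtype="float32")  # discriminator: generated -> 0
    y_gen = np.ones((batch_size, 1), dtype="float32")    # generator step: generated labelled 1
    return real, z, y_real, y_fake, y_gen

def average_metrics(w_real, w_fake):
    # train_on_batch returns [loss, acc]; average the real and generated half-batches
    return (np.asarray(w_real) + np.asarray(w_fake)) / 2

rng = np.random.default_rng(0)
xt = rng.random((100, 28, 28, 1)).astype("float32")
real, z, y_real, y_fake, y_gen = make_iteration_batches(xt, 16, 100, rng)
print(real.shape, z.shape)  # (16, 28, 28, 1) (16, 100)
print(average_metrics([0.4, 0.9], [0.6, 0.7]))
```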

> call(inputs , with_discriminator=None)

>>> inputs :: The input to the generator, i.e. a batch of latent vectors of shape (batch, z_dim).

>>> with_discriminator :: Boolean. If True, the generator outputs are further passed through the discriminator. Only valid while the discriminator exists (before RemoveDiscriminator is called).

>RemoveDiscriminator(self) :: Removes the discriminator from the model, leaving only the generator for generating images. Can be used once training is complete.

>SampleImages(self,x,y) :: Samples images from the generator. A total of x*y images are sampled and displayed in x rows and y columns.
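To lay out x*y images in x rows and y columns, image i belongs in row i // y and column i % y. As an alternative to one subplot per image, the same layout can be assembled into a single array and shown with one imshow call. The NumPy sketch below assumes images shaped (x*y, H, W) or (x*y, H, W, C); `tile_images` is a hypothetical helper, not part of this notebook.

```python
import numpy as np

def tile_images(imgs, x, y):
    """Hypothetical helper: arrange x*y images into one (x*H, y*W[, C]) mosaic.

    Image i is placed in row i // y and column i % y.
    """
    imgs = np.asarray(imgs)
    if imgs.shape[0] != x * y:
        raise ValueError("expected exactly x*y images")
    h, w = imgs.shape[1], imgs.shape[2]
    rest = imgs.shape[3:]  # () for grayscale, (C,) for colour images
    # (x*y, H, W, ...) -> (x, y, H, W, ...) -> (x, H, y, W, ...) -> (x*H, y*W, ...)
    grid = imgs.reshape((x, y, h, w) + rest)
    grid = grid.swapaxes(1, 2)
    return grid.reshape((x * h, y * w) + rest)

# Example: 6 tiny 2x2 images, image i filled with the value i, in 2 rows and 3 columns
imgs = np.arange(6).repeat(4).reshape(6, 2, 2)
mosaic = tile_images(imgs, 2, 3)
print(mosaic.shape)  # (4, 6)
# With matplotlib: plt.imshow(mosaic, cmap="gray"); plt.axis("off")
```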

Attributes:

generator :: The generator model for generating images

discriminator :: The discriminator model, a binary classifier that outputs the probability that an image is real rather than generated.

z_dim :: Latent space dimension

In [25]:
class NSGAN(keras.models.Model):
 
  def __init__(self,generator,discriminator,z_dim,**kwargs):
    super().__init__(**kwargs)
    self.with_disc = True
    self.generator = generator
    self.discriminator = discriminator
    self.z_dim = z_dim
    self.gan = keras.Sequential([generator,discriminator])
    self.gm = keras.metrics.BinaryAccuracy(name="binary_accuracy_generator")
    self.dm = keras.metrics.BinaryAccuracy(name="binary_accuracy_discriminator")
    assert ((self.gan.layers[0].trainable == True) and (self.gan.layers[1].trainable == True)),"Generator and Discriminator should be trainable"
    
  def call(self,inputs,with_discriminator=None):
    if with_discriminator and self.with_disc:
      return self.gan(inputs)      
    else:
      return self.generator(inputs)
 
  def compile(self,opt_gen,opt_disc):
    super().compile(optimizer = opt_gen,loss = "binary_crossentropy")
    self.opt_gen = opt_gen
    self.opt_disc= opt_disc
 
  def train_step(self,imgs):
    # fit() may pass the data as a tuple such as (x,) or (x, y); only the images are used
    if isinstance(imgs,tuple):
      imgs = imgs[0]

    batch_size = tf.shape(imgs)[0]

    x_ = self.generator(tf.random.normal((batch_size,self.z_dim)),training=True)
    x = tf.cast(imgs,x_.dtype)

    # Discriminator step: real images are labelled 1 and generated images 0
    with tf.GradientTape() as tape:
      xin = tf.concat([x,x_],axis=0)
      yin = tf.concat([tf.ones((batch_size,1)),tf.zeros((batch_size,1))],axis = 0)
      # Run the discriminator once and reuse its output for the loss and the metric
      preds = self.discriminator(xin,training=True)
      loss = tf.reduce_mean(keras.losses.binary_crossentropy(yin,preds))
    self.dm.update_state(yin,preds)

    # Only the discriminator's weights are updated in this step
    grads = tape.gradient(loss,self.discriminator.trainable_weights)
    self.opt_disc.apply_gradients(zip(grads,self.discriminator.trainable_weights))
 
    # Generator step with the non-saturating loss: generated images are labelled 1
    # ("real") and only the generator's weights are updated
    with tf.GradientTape() as tape:
      z = tf.random.normal((batch_size,self.z_dim))
      yin = tf.ones((batch_size,1))
      preds = self.gan(z,training=True)
      gloss = tf.reduce_mean(keras.losses.binary_crossentropy(yin,preds))
    self.gm.update_state(yin,preds)

    grads = tape.gradient(gloss,self.generator.trainable_weights)
    self.opt_gen.apply_gradients(zip(grads,self.generator.trainable_weights))
 
    return {"Discriminator_loss":loss,"Generator_loss":gloss,self.gm.name:self.gm.result(),self.dm.name:self.dm.result()}
 
  def training_loop(self,xt,n_iter,batch_size,sampling_interval):
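    # Keras records each model's `trainable` state when it is compiled, so the
    # discriminator is compiled while trainable, then frozen before the stacked
    # model is compiled; gan.train_on_batch then updates only the generator
    self.discriminator.trainable = True
    self.discriminator.compile(optimizer = self.opt_disc,loss = "binary_crossentropy",metrics=["acc"])
    self.discriminator.trainable = False
    self.gan.compile(optimizer = self.opt_gen,loss = "binary_crossentropy",metrics=["acc"])
    assert self.discriminator.trainable == False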

    losses = []
    accuracies = [] 
    for i in range(n_iter):

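      # Sample real images from the whole dataset instead of a hard-coded size of 60000
      ind = np.random.randint(xt.shape[0],size = batch_size)
      w1 = self.discriminator.train_on_batch(xt[ind],np.ones((batch_size,1)))
      z = np.random.randn(batch_size,self.z_dim).astype("float32")
      w2 = self.discriminator.train_on_batch(self.generator(z),np.zeros((batch_size,1)))
      # Average the [loss, accuracy] of the real and generated half-batches
      w = np.add(w1,w2)/2
      losses.append(w[0])
      accuracies.append(w[1])

      # Generator step: generated images are labelled as real (non-saturating loss)
      z = np.random.randn(batch_size,self.z_dim).astype("float32")
      w = self.gan.train_on_batch(z,np.ones((batch_size,1)))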

      if i%sampling_interval == 0:
        print(i,[losses[-1],accuracies[-1]],w)
        self.SampleImages(4,4)
        
    self.discriminator.trainable = True
    return losses

  def RemoveDiscriminator(self):
    if self.with_disc == True:
      del self.discriminator
      self.with_disc = False
 
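  def SampleImages(self,x,y):
    # Draw x*y latent vectors and map them to images
    imgs = self.generator(np.random.randn(x*y,self.z_dim).astype("float32")).numpy()
    # squeeze=False keeps `ax` two-dimensional even when x or y is 1
    fig,ax = plt.subplots(x,y,figsize=(y,x),squeeze=False)
    if imgs.shape[-1] == 1:
      # Drop the channel axis so imshow treats the image as grayscale
      imgs = imgs[...,0]
    for i in range(x*y):
      # Image i goes to row i//y and column i%y (x rows, y columns)
      ax[i//y,i%y].imshow(imgs[i],cmap="gray")
      ax[i//y,i%y].axis("off")
    plt.show()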

Since the Jensen-Shannon loss of the NSGAN model has no clear relation to the convergence of the GAN model, monitoring of the Jensen-Shannon loss is skipped here. However, as a rough check in the initial phase of training, neither loss should become very small, the discriminator accuracy should not fall below 0.5, and the generator accuracy should not rise above 0.5. These values can be read from the output of Keras fit() and of training_loop().
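The two accuracies follow the rule of Keras `BinaryAccuracy` with its default threshold of 0.5: the discriminator accuracy is the fraction of real and generated images classified correctly, and the generator accuracy is the fraction of generated images the discriminator labels as real. The NumPy sketch below recomputes both from hypothetical discriminator outputs and applies the rule of thumb above; the function names and the `min_loss` cut-off are illustrative assumptions, not values from this notebook.

```python
import numpy as np

def binary_accuracy(y_true, y_prob, threshold=0.5):
    # Predict 1 when the probability exceeds the threshold, then compare with the labels
    return float(np.mean((y_prob > threshold).astype(float) == y_true))

def early_training_warnings(d_real, d_fake, d_loss, g_loss, min_loss=0.05):
    """Hypothetical check of the rule of thumb; min_loss is an arbitrary cut-off."""
    y = np.concatenate([np.ones_like(d_real), np.zeros_like(d_fake)])
    disc_acc = binary_accuracy(y, np.concatenate([d_real, d_fake]))
    gen_acc = binary_accuracy(np.ones_like(d_fake), d_fake)
    issues = []
    if min(d_loss, g_loss) < min_loss:
        issues.append("a loss is close to zero")
    if disc_acc < 0.5:
        issues.append("discriminator accuracy below 0.5")
    if gen_acc > 0.5:
        issues.append("generator accuracy above 0.5")
    return disc_acc, gen_acc, issues

print(early_training_warnings(np.array([0.9, 0.7]), np.array([0.2, 0.6]), d_loss=0.6, g_loss=0.9))
# (0.75, 0.5, [])
```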