<a href="https://colab.research.google.com/github/Ayan-Vishwakarma/Keras-Implementation-of-Dense-and-DC-NSGAN-WGAN-WGANGP-etc/blob/main/ProGAN_code_mbstd.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras

### Progressive Growing GANs with Minibatch Standard Deviation for Preventing Mode Collapse

To reduce the effects of mode collapse in GANs, information about the whole batch is provided to the discriminator, so that it can discriminate not only on the basis of individual images but also on the variation across the batch of images.

In minibatch standard deviation, the standard deviation of the real/generated samples is taken along the batch axis, which yields an image whose per-pixel values are the standard deviation of that pixel across the batch. The mean of this std image is then computed, tiled into a single-channel map and concatenated as one extra channel (e.g. 3 channels become 4), so that the last channel represents the mean std across the batch.





PGGAN( initial_generator_block , initial_discriminator_block , z_dim , image_shape , trainer = "NSGAN" )

> initial_generator_block : Initial generator block [Model] which takes in a z_dim vector and returns a feature map of image_shape's height * image_shape's width * out_channels.
>> The output images pass through a tanh activation and so lie between -1 and 1. So the **training images should be rescaled to between -1 and 1 beforehand**.

>> The initial_discriminator_block should take inputs with **1 channel more than given by extract_dims** to account for the standard deviation channel. Everything else is the same. The minibatch std is fed only to the initial_discriminator_block; the functional/sequential blocks added later do not take the minibatch standard deviation into account.

> initial_discriminator_block : Initial discriminator block [Model]. Its input is image_shape's height * image_shape's width * (extract_dims + 1) and its output is a scalar.
>> For NSGAN trainer, output should have **sigmoid** activation.

>> For WGAN_GP trainer, output should **not** have any activation.

> z_dim : Latent space dimension.

> image_shape : Initial resolution of images.

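> extract_dims : Number of feature maps the from-RGB block extracts. It is not passed in; it is inferred as the number of input channels of initial_discriminator_block minus 1.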

>trainer : Training model. Either NSGAN or WGAN_GP.

Attributes:

>generator

>discriminator

>extract_dims

>z_dim

>image_shape


Based on which trainer used:

>NSGAN

>WGAN_GP

###### Auxiliary layers to help PGGAN

In [None]:
class MiniBatchStd(keras.layers.Layer):
  """Emit the minibatch standard-deviation statistic as a constant map.

  The standard deviation of ``inputs`` is taken along axis 0, averaged down
  to a single scalar, and broadcast to a tensor of shape
  ``(batch,) + image_shape``.

  Args:
    image_shape: per-sample shape of the emitted map (sequence of ints).
  """
  def __init__(self,image_shape,**kwargs):
    super().__init__(**kwargs)
    # Plain-int copy for static shape inference and serialization.
    self._static_shape = tuple(int(d) for d in image_shape)
    self.image_shape = tf.constant(image_shape,dtype=tf.int32)
  def call(self,inputs):
    # Output shape = [dynamic batch size] + image_shape.
    batch = tf.reshape(tf.shape(inputs)[0],(1,))
    shape = tf.concat([batch,self.image_shape],axis=0)
    # Std along axis 0, then the mean of all those values (a scalar).
    mean_std = tf.math.reduce_mean(tf.math.reduce_std(inputs,axis=0,keepdims=False),keepdims=False)
    return tf.ones(shape,tf.float32)*mean_std
  def compute_output_shape(self,input_shape):
    # Previously returned the int32 tensor `self.image_shape`, which lacks
    # the batch dimension and is not a shape tuple.
    return (input_shape[0],) + self._static_shape
  def get_config(self):
    # `image_shape` is a required constructor argument, so it must be in the
    # config for the layer to be re-creatable from it.
    config = super().get_config()
    config["image_shape"] = list(self._static_shape)
    return config

In [None]:
class MergeLayer(keras.layers.Layer):
  """Blend two tensors with a weight that grows on every call.

  ``call`` first advances ``alpha`` by ``1 / n_intervals`` (clipped to
  [0, 1]) and then returns ``alpha * inputs[0] + (1 - alpha) * inputs[1]``.
  ``alpha`` starts at 0, so the output starts as ``inputs[1]`` and moves
  towards ``inputs[0]``.

  NOTE: ``alpha`` advances on *every* invocation of the layer, not once per
  training step.

  Args:
    n_intervals: number of calls over which ``alpha`` goes from 0 to 1.
  """
  def __init__(self,n_intervals,**kwargs):
    super().__init__(**kwargs)
    # Python-side copy of the constructor argument, used by get_config().
    self._n_intervals_arg = float(n_intervals)
    self.alpha = tf.Variable([0.],trainable=False)
    self.n_intervals = tf.Variable([n_intervals],dtype=tf.float32,trainable=False)
  def call(self,inputs):
    step = tf.math.reciprocal(self.n_intervals)
    self.alpha.assign(tf.clip_by_value(self.alpha + step,clip_value_min=0.,clip_value_max = 1.))
    return self.alpha * inputs[0] + (1 - self.alpha) * inputs[1]
  def compute_output_shape(self,input_shape):
    # The layer takes a pair of tensors and returns one tensor; the previous
    # implementation returned the list of both input shapes.
    return input_shape[0]
  def get_config(self):
    # Previously returned tf.Variable objects plus an "alpha" key that
    # __init__ does not accept, and dropped the base-layer config.
    # `alpha` is runtime state, not a constructor argument, so it is omitted.
    config = super().get_config()
    config["n_intervals"] = self._n_intervals_arg
    return config

###### PGGAN Model

In [None]:
class PGGAN():
  """Progressively grown GAN whose discriminator receives a minibatch-std map.

  Holds a generator and a two-input discriminator (image + std map) and wraps
  them in an ``NSGAN`` or ``WGAN_GP`` trainer. ``ProgressGAN`` adds a
  higher-resolution stage blended in through ``MergeLayer``;
  ``StreamlineGAN`` rebuilds both models without the layers recorded as
  belonging to the fade-in branches.

  Attributes:
    generator, discriminator: the current Keras models.
    z_dim: latent dimension.
    image_shape: shape of the images at the current resolution.
    initial_shape: ``[H, W, 1]`` taken from the initial ``image_shape``; the
      shape of the std map fed to the discriminator's second input.
    extract_dims: number of channels produced by the current from-RGB conv.
    trainer: ``"NSGAN"`` or ``"WGAN_GP"``; the trainer object is stored in
      the attribute of the same name.
    rdl, sdl, rgl: layer names recorded by ``ProgressGAN`` and consumed by
      ``StreamlineGAN`` (only present after a successful ``ProgressGAN``).
  """
  def __init__(self,initial_generator_block,initial_discriminator_block,z_dim,image_shape,trainer="NSGAN",**kwargs):
    """Build the initial generator, discriminator and trainer.

    Args:
      initial_generator_block: model applied to a ``(z_dim,)`` input; its
        output is passed through a 1x1 tanh conv with 3 filters.
      initial_discriminator_block: model applied to the from-RGB features
        concatenated with the one-channel std map.
      z_dim: latent dimension (int).
      image_shape: initial image shape.
      trainer: ``"NSGAN"`` or ``"WGAN_GP"`` (case-insensitive).
    """
    super().__init__(**kwargs)
    self.z_dim = z_dim
    # The block's input carries one extra channel for the std map.
    self.extract_dims = initial_discriminator_block.input_shape[-1] - 1
    self.image_shape = image_shape
    self.initial_shape = list(image_shape[0:2]) + [1]

    ######################################   Initial Generator   #####################################################

    # Shape is passed as a real tuple, matching NSGAN/WGAN_GP; a bare int is
    # not accepted as a shape by every Keras version.
    xin = keras.layers.Input((z_dim,))
    x = initial_generator_block(xin)
    x = keras.layers.Conv2D(3,(1,1),(1,1),padding="same",activation="tanh")(x)
    self.generator = keras.models.Model(inputs=[xin],outputs=[x])

    ######################################   Initial Discriminator   #################################################

    x2in = keras.layers.Input(self.initial_shape)
    x1in = keras.layers.Input(image_shape)
    x1 = keras.layers.Conv2D(self.extract_dims,(3,3),(1,1),padding="same")(x1in)
    x1 = keras.layers.LeakyReLU(0.05)(x1)
    x = tf.keras.layers.Concatenate(axis=-1)([x1,x2in])
    x = initial_discriminator_block(x)
    self.discriminator = keras.models.Model(inputs=[x1in,x2in],outputs=[x])

    ###################################  Trainer  #####################################################
    self.trainer = trainer.upper()
    assert self.trainer == "NSGAN" or self.trainer == "WGAN_GP","trainer should be NSGAN or WGAN_GP"
    self._build_trainer()

  def _build_trainer(self):
    """Wrap the current generator/discriminator in a fresh trainer object."""
    if self.trainer == "NSGAN":
      self.NSGAN = NSGAN(self.generator,self.discriminator,self.z_dim,self.initial_shape)
    elif self.trainer == "WGAN_GP":
      self.WGAN_GP = WGAN_GP(self.generator,self.discriminator,self.z_dim,self.initial_shape)

  def ProgressGAN(self,add_functional_generator,add_functional_discriminator,multiplier,n_intervals,extract_dims=None):
    """Grow both models by one resolution stage with a fade-in.

    Args:
      add_functional_generator: block appended after the generator's layers
        (all but its last layer).
      add_functional_discriminator: block placed after a new from-RGB conv;
        its output channels must equal the current ``extract_dims``.
      multiplier: per-axis (height, width) upscaling factor; used for the
        pooling/upsampling layers and to scale ``image_shape``.
      n_intervals: interval count for the generator's ``MergeLayer``.
      extract_dims: unused; kept for backward compatibility.

    On any exception while building, the message is printed and the object
    is left unchanged. On success a new, not yet compiled, trainer object
    replaces the previous one.
    """
    assert self.extract_dims == add_functional_discriminator.output_shape[-1] , "The number of feature maps extracted by newly added discriminator should match with the input feature maps required by the previously trained discriminator model"

    # Everything is built into locals and committed only after both models
    # were created. The previous code mutated self.extract_dims/rdl/sdl/rgl
    # first and its rollback restored only three attributes.
    try:
      ####################################  Progressing Discriminator   ###############################################

      rdl = []
      new_extract_dims = add_functional_discriminator.input_shape[-1]

      new_image_shape = self.image_shape * np.concatenate([np.array(multiplier,dtype=np.int32),np.array([1],dtype=np.int32)],axis=0)
      xin = keras.layers.Input(new_image_shape)
      y = xin

      # Old path: downsample, then reuse layers 1-2 of the old discriminator.
      layer = keras.layers.AveragePooling2D(multiplier)
      rdl.append(layer.name)
      x = layer(xin)
      for i in self.discriminator.layers[1:3]:
        rdl.append(i.name)
        x = i(x)

      x2in = keras.layers.Input(self.initial_shape)
      rdl.append(x2in.name)

      # New path: fresh from-RGB conv followed by the added block.
      y = keras.layers.Conv2D(new_extract_dims,(3,3),(1,1),padding='same')(y)
      y = keras.layers.LeakyReLU(0.05)(y)
      y = add_functional_discriminator(y)

      # NOTE(review): the discriminator's MergeLayer gets a larger interval
      # count than the generator's, presumably because the trainers call the
      # discriminator more often per step and alpha advances on every call.
      # Confirm before changing either trainer's call pattern.
      if self.trainer == "WGAN_GP":
        layer = MergeLayer(int(n_intervals * (1+2*self.WGAN_GP.n_critic)/(self.WGAN_GP.n_critic + 1)))
      else:
        layer = MergeLayer(int((4*n_intervals)/3))
      rdl.append(layer.name)
      x = layer([y,x])

      for i in self.discriminator.layers[3:-3]:
        x = i(x)

      sdl = []
      layer = keras.layers.Concatenate(axis=-1)
      sdl.append(layer.name)
      x = layer([x,x2in])
      x = self.discriminator.layers[-1](x)

      new_discriminator = keras.models.Model(inputs=[xin,x2in],outputs = [x])

      ##########################################  Progressing Generator  ################################################

      xin = keras.layers.Input((self.z_dim,))
      x = xin
      rgl = []
      for i in self.generator.layers[1:-1]:
        x = i(x)
      y = x

      # New path: added block followed by a fresh 1x1 tanh conv.
      x = add_functional_generator(x)
      x = keras.layers.Conv2D(3,(1,1),(1,1),padding="same",activation="tanh")(x)

      # Old path: upsample, then reuse the old generator's last layer.
      layer = keras.layers.UpSampling2D(multiplier)
      y = layer(y)
      rgl.append(layer.name)
      layer = self.generator.layers[-1]
      rgl.append(layer.name)
      y = layer(y)
      layer = MergeLayer(n_intervals)
      rgl.append(layer.name)
      z = layer([x,y])

      new_generator = keras.models.Model(inputs=[xin],outputs=[z])

    except Exception as e:
      # Best-effort behavior kept from the original: report and bail out.
      print(e)
      return

    self.rdl = rdl
    self.sdl = sdl
    self.rgl = rgl
    self.extract_dims = new_extract_dims
    self.image_shape = new_image_shape
    self.discriminator = new_discriminator
    self.generator = new_generator

    ###################################  Trainer  #####################################################
    self._build_trainer()

  def StreamlineGAN(self):
    """Rebuild both models without the layers recorded by ``ProgressGAN``.

    Reads ``self.rgl``, ``self.rdl`` and ``self.sdl``, so it can only be
    called after a successful ``ProgressGAN``. A new, not yet compiled,
    trainer object replaces the previous one.
    """

    ###################################   Streamline Generator   #######################################

    xin = keras.layers.Input((self.z_dim,))
    x = xin
    for i in self.generator.layers[1:]:
      if i.name not in self.rgl:
        x = i(x)
    self.generator = keras.models.Model(inputs=[xin],outputs=[x])

    ###################################   Streamline Discriminator   ###################################

    xin = keras.layers.Input(self.image_shape)
    x = xin
    x2in = keras.layers.Input(self.initial_shape)
    for i in self.discriminator.layers[1:]:
      if i.name in self.sdl:
        # The concatenation layer takes the std map as its second input.
        x = i([x,x2in])
        continue
      if i.name not in self.rdl:
        x = i(x)
    self.discriminator = keras.models.Model(inputs=[xin,x2in],outputs=[x])


    ###################################  Trainer  #####################################################
    self._build_trainer()

## NSGAN and WGAN_GP models that take minibatch standard deviation into account.

In [None]:
class NSGAN(keras.models.Model):
  """Binary-cross-entropy GAN trainer for a two-input discriminator.

  The discriminator takes ``[images, std_map]``, where ``std_map`` is
  produced by ``MiniBatchStd(initial_shape)``. ``self.gan`` chains latent
  input -> generator -> (image, std map) -> discriminator.

  Args:
    generator: model mapping a ``(z_dim,)`` input to images.
    discriminator: model taking ``[images, std_map]``.
    z_dim: latent dimension.
    initial_shape: per-sample shape of the std map.
  """
  def __init__(self,generator,discriminator,z_dim,initial_shape,**kwargs):
    super().__init__(**kwargs)
    self.with_disc = True
    self.generator = generator
    self.discriminator = discriminator
    self.z_dim = z_dim
    xin = keras.layers.Input((z_dim,))
    x = self.generator(xin)
    x1 = x
    x1 = MiniBatchStd(initial_shape)(x1)
    y = self.discriminator([x,x1])
    self.initial_shape = initial_shape
    self.gan = keras.models.Model(inputs = [xin],outputs=[y])
    self.gm = keras.metrics.BinaryAccuracy(name="binary_accuracy_generator")
    self.dm = keras.metrics.BinaryAccuracy(name="binary_accuracy_discriminator")
    assert ((self.gan.layers[0].trainable == True) and (self.gan.layers[1].trainable == True)),"Generator and Discriminator should be trainable"
 
  def compile(self,opt_gen,opt_disc):
    """Store one optimizer for the generator and one for the discriminator."""
    super().compile(optimizer = opt_gen,loss = "binary_crossentropy")
    self.opt_gen = opt_gen
    self.opt_disc= opt_disc
 
  def train_step(self,imgs):
    """Run one discriminator update followed by one generator update.

    Args:
      imgs: batch of real images, or a tuple whose first element is that
        batch.

    Returns:
      Dict with both losses and both binary-accuracy metric results.
    """
    if isinstance(imgs,tuple):
      # Was `data[0]`: `data` is undefined here and raised NameError.
      imgs = imgs[0]
    
    batch_size = tf.shape(imgs)[0]

    x = [imgs,MiniBatchStd(self.initial_shape)(imgs)]
    x_ = self.generator(tf.random.normal((batch_size,self.z_dim)))
    x_ = [x_,MiniBatchStd(self.initial_shape)(x_)]

    # NOTE(review): each model is evaluated twice per update (metric + loss).
    # Layers with per-call state, e.g. MergeLayer's alpha, advance on every
    # evaluation, so merging these calls would change the fade-in schedule.
    with tf.GradientTape(watch_accessed_variables=True) as tape:
      # Real samples (label 1) stacked on top of generated ones (label 0).
      xin = [tf.concat([x[0],x_[0]],axis=0),tf.concat([x[1],x_[1]],axis=0)]
      yin = tf.concat([tf.ones((batch_size,1)),tf.zeros((batch_size,1))],axis = 0)
      self.dm.update_state(yin,self.discriminator(xin))
      loss = keras.losses.binary_crossentropy(yin,self.discriminator(xin))

    grads = tape.gradient(loss,self.discriminator.trainable_weights)
    self.opt_disc.apply_gradients(zip(grads,self.discriminator.trainable_weights))
    del xin,yin,grads

    with tf.GradientTape(watch_accessed_variables=True) as tape:
      xin = tf.random.normal(tf.stack([batch_size,self.z_dim],axis=0))
      # Generated samples are labelled 1 for the generator update.
      yin = tf.ones((batch_size,1))
      self.gm.update_state(yin,self.gan(xin))
      gloss = keras.losses.binary_crossentropy(yin,self.gan(xin))
    
    grads = tape.gradient(gloss,self.generator.trainable_weights)
    self.opt_gen.apply_gradients(zip(grads,self.generator.trainable_weights))
 
    return {"Discriminator_loss":loss,"Generator_loss":gloss,self.gm.name:self.gm.result(),self.dm.name:self.dm.result()}
 
  def SampleImages(self,x,y):
    """Plot ``x * y`` generated images on an ``x``-row by ``y``-column grid.

    Images are mapped from [-1, 1] to [0, 1] for display; single-channel
    images are drawn with the gray colormap.
    """
    imgs = self.generator(np.random.randn(x*y,self.z_dim))
    # squeeze=False keeps `ax` two-dimensional even when x == 1 or y == 1.
    fig,ax = plt.subplots(x,y,figsize=(y*2,x*2),squeeze=False)
    cmap = None
    if imgs.shape[-1] == 1:
      imgs = tf.reshape(imgs,(imgs.shape[0],imgs.shape[1],imgs.shape[2]))
      cmap = "gray"
    for i in range(x*y):
      # The grid is (rows=x, cols=y). The previous `ax[i%y, i//y]` indexed
      # rows with a column index and raised IndexError when x != y.
      ax[i//y,i%y].imshow(imgs[i]/2 + 0.5,cmap=cmap)

In [None]:
class WGAN_GP(keras.models.Model):
  """WGAN trainer with a separately applied gradient penalty.

  The discriminator takes ``[images, std_map]``, where ``std_map`` is
  produced by ``MiniBatchStd(initial_shape)``. ``self.gan`` chains latent
  input -> generator -> (image, std map) -> discriminator.

  Args:
    generator: model mapping a ``(z_dim,)`` input to images.
    discriminator: model taking ``[images, std_map]``.
    z_dim: latent dimension.
    initial_shape: per-sample shape of the std map.
  """
 
  def __init__(self,generator,discriminator,z_dim,initial_shape,**kwargs):
    super().__init__(**kwargs)
    self.with_disc = True
    self.generator = generator
    self.discriminator = discriminator
    self.z_dim = z_dim
    xin = keras.layers.Input((z_dim,))
    x = self.generator(xin)
    x1 = x
    x1 = MiniBatchStd(initial_shape)(x1)
    y = self.discriminator([x,x1])
    self.initial_shape = initial_shape
    self.gan = keras.models.Model(inputs = [xin],outputs=[y])
    assert ((self.gan.layers[0].trainable == True) and (self.gan.layers[1].trainable == True)),"Generator and Discriminator should be trainable"
 
  def compile(self,opt_gen,opt_disc,opt_GP,n_critic,lmbda=10):
    """Store the optimizers and set up the step counter and loss holders.

    Args:
      opt_gen: optimizer for the generator.
      opt_disc: optimizer for the discriminator's EM loss.
      opt_GP: optimizer for the gradient-penalty update.
      n_critic: number of discriminator steps per generator step.
      lmbda: gradient-penalty weight; must be > 0.
    """
    super().compile(optimizer = opt_gen,loss = self.EM)
    self.opt_gen = opt_gen
    self.opt_disc= opt_disc
    self.opt_GP = opt_GP
    self.n_critic = n_critic
    # `intervals` cycles through 1..n_critic, 0; the generator is trained
    # on the step where it equals 0.
    self.nc = tf.Variable(n_critic+1,dtype = tf.int32,trainable = False)
    self.intervals = tf.Variable(1,dtype = tf.int32,trainable = False)
    assert ( lmbda > 0 ),"lambda value should be strictly greater than 0"
    self.lmbda = lmbda
    # Last computed values, reported by train_step on every step.
    self.__d_loss__ = tf.Variable([0],dtype = tf.float32,trainable = False)
    self.__g_loss__ = tf.Variable([0],dtype = tf.float32,trainable = False)
    self.__GP_loss__ = tf.Variable([0],dtype = tf.float32,trainable = False)
 
  def train_step(self,imgs):
    """Train the discriminator or the generator, depending on the counter.

    Args:
      imgs: batch of real images, or a tuple whose first element is that
        batch.

    Returns:
      Dict with the most recently stored EM distance, GP loss and generator
      loss.
    """
    if isinstance(imgs,tuple):
      # Was `data[0]`: `data` is undefined here and raised NameError.
      imgs = imgs[0]
    batch_size =  tf.shape(imgs)[0]

    tf.cond(self.intervals != 0, lambda: self.__train_discriminator__(imgs,batch_size),lambda:self.__train_generator__(imgs,batch_size))
    self.intervals.assign((self.intervals+1)%self.nc) 
    return {"EM_Distance":-self.__d_loss__,"GP_loss":self.__GP_loss__,"generative_loss":self.__g_loss__}

  def __train_discriminator__(self,imgs,batch_size):
    """Apply one EM-loss update and one gradient-penalty update."""
    x = [imgs,MiniBatchStd(self.initial_shape)(imgs)]
    x_ = self.generator(tf.random.normal((batch_size,self.z_dim)))
    x_ = [x_,MiniBatchStd(self.initial_shape)(x_)]

    with tf.GradientTape(watch_accessed_variables=True) as tape:
      # Real samples (label 1) stacked on top of generated ones (label 0).
      xin = [tf.concat([x[0],x_[0]],axis=0),tf.concat([x[1],x_[1]],axis=0)]
      yin = tf.concat([tf.ones((batch_size,1)),tf.zeros((batch_size,1))],axis = 0)
      loss = self.EM(yin,self.discriminator(xin))

    grads = tape.gradient(loss,self.discriminator.trainable_weights)#
    self.opt_disc.apply_gradients(zip(grads,self.discriminator.trainable_weights))#
    del xin,yin,grads

    # Random per-sample interpolation between real and generated images.
    t = tf.random.uniform((batch_size,1,1,1))
    xche = t*x[0] + (1-t)*x_[0]

    # Outer tape differentiates the penalty w.r.t. the weights; inner tape
    # produces the gradient of the output w.r.t. the interpolated images.
    with tf.GradientTape(watch_accessed_variables=True,persistent = True) as Tape:
      Tape.watch(xche)
      with tf.GradientTape(watch_accessed_variables=True,persistent=True) as tape:
        tape.watch(xche)
        y =  self.discriminator([xche,MiniBatchStd(self.initial_shape)(xche)])
      Dx = tape.gradient(y,xche)
      GP_loss = self.lmbda * tf.math.reduce_mean((tf.math.reduce_euclidean_norm(Dx,axis=[1,2,3]) - 1.)**2)

    grads = Tape.gradient(GP_loss,self.discriminator.trainable_weights,unconnected_gradients=tf.UnconnectedGradients.ZERO)#

    del tape,Tape,xche,y,Dx,x,x_
    # NOTE(review): the same gradients are applied twice in a row. This looks
    # like an accidental duplicate, but removing it changes the effective
    # step size, so it is kept until confirmed.
    self.opt_GP.apply_gradients(zip(grads,self.discriminator.trainable_weights))#
    self.opt_GP.apply_gradients(zip(grads,self.discriminator.trainable_weights))
    self.__d_loss__.assign(tf.reshape(loss,self.__d_loss__.shape))
    self.__GP_loss__.assign(tf.reshape(GP_loss,self.__GP_loss__.shape))
 
  def __train_generator__(self,imgs,batch_size):
    """Apply one generator update; ``imgs`` is unused."""
    with tf.GradientTape(watch_accessed_variables=True) as tape:
      xin = tf.random.normal(tf.stack([batch_size,self.z_dim],axis=0))
      yin = tf.ones((batch_size,1))
      gloss = self.EM(yin,self.gan(xin))
    
    grads = tape.gradient(gloss,self.generator.trainable_weights)
    # NOTE(review): applied twice with the same gradients, as in
    # __train_discriminator__; kept to preserve the training dynamics.
    self.opt_gen.apply_gradients(zip(grads,self.generator.trainable_weights))
    self.opt_gen.apply_gradients(zip(grads,self.generator.trainable_weights))
    self.__g_loss__.assign(tf.reshape(gloss,self.__g_loss__.shape))
 
 
  def SampleImages(self,x,y,scale = 1.):
    """Plot ``x * y`` generated images on an ``x``-row by ``y``-column grid.

    Args:
      x: number of rows.
      y: number of columns.
      scale: figure size per grid cell.

    Images are mapped from [-1, 1] to [0, 1] for display; single-channel
    images are drawn with the gray colormap.
    """
    imgs = self.generator(np.random.randn(x*y,self.z_dim))
    # squeeze=False keeps `ax` two-dimensional even when x == 1 or y == 1.
    fig,ax = plt.subplots(x,y,figsize=(y*scale,x*scale),squeeze=False)
    cmap = None
    if imgs.shape[-1] == 1:
      imgs = tf.reshape(imgs,(imgs.shape[0],imgs.shape[1],imgs.shape[2]))
      cmap = "gray"
    for i in range(x*y):
      # The grid is (rows=x, cols=y). The previous `ax[i%y, i//y]` indexed
      # rows with a column index and raised IndexError when x != y.
      ax[i//y,i%y].imshow(imgs[i]/2 + 0.5,cmap=cmap)
 
  def EM(self,labels,logits):
    """Mean output over label-0 samples minus mean output over label-1 samples.

    The 1e-10 terms guard against division by zero when a batch contains
    only one of the two labels.
    """
    return ( -tf.reduce_sum(labels * logits) / (tf.reduce_sum(labels) + 1e-10)  + tf.reduce_sum((1. - labels) * (logits)) / (tf.reduce_sum((1. - labels)) + 1e-10))