In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import datetime
from tensorflow.keras.utils import to_categorical
import io

In [2]:
from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.layers import BatchNormalization,Dense,Conv2D,Reshape,Conv2DTranspose,ReLU,LeakyReLU,Flatten,Activation,Dropout,Input
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.losses import CategoricalCrossentropy

In [3]:
class sWGAN:
  """Weight-clipped Wasserstein GAN whose critic shares a trunk with a classifier.

  The critic/classifier network has two heads on top of a shared convolutional
  trunk: a scalar critic score and a softmax over ``no_of_classes`` classes.
  ``no_of_classes`` is the number of distinct target labels plus one; the extra
  (last) class is the label given to generated samples during critic training.

  Args:
    training_images: image array shaped (N, H, W, C), or (N, H, W) for
      single-channel data, with pixel values in [0, 255].
    targets: integer class labels, one per image. They are one-hot encoded with
      ``to_categorical``, so they are assumed to be 0..K-1 -- TODO confirm for
      datasets other than MNIST.
  """

  # Layers that belong only to the classifier head; weight clipping skips them.
  _CLASSIFIER_HEAD_LAYERS = frozenset([
      "cld1", "clbn1", "clac1", "cldr1",
      "cld2", "clbn2", "clac2", "cldr2",
      "cl_out",
  ])

  def __init__(self,training_images,targets):
    images=np.asarray(training_images)
    if images.ndim==3:
      ### Single-channel images given without a channel axis: add one
      images=images[...,np.newaxis]
    if images.ndim!=4:
      raise ValueError(f"training_images must have shape (N, H, W, C) or (N, H, W), got {images.shape}")

    self.no_of_samples=images.shape[0]
    self.height=images.shape[1]
    self.width=images.shape[2]
    self.channels=images.shape[3]
    ### Converting pixel range [0,255] -> [-1,1]; float32 matches the model dtype
    ### and avoids the float64 copy that uint8 - 127.5 would otherwise create
    self.train_data=(images.astype(np.float32)-127.5)/127.5
    self.shape=(self.height,self.width,self.channels)
    ### One extra class (the last index) is reserved for generated samples
    self.no_of_classes=len(np.unique(targets,return_counts=False))+1
    self.targets=to_categorical(targets,self.no_of_classes)
    self.target_shape=self.targets.shape
    self.noise_size=100
    self.Generator= None
    self.Critic_Classifier= None
    self.clip_value=0.01      ### Critic parameters are clipped to [-clip_value, clip_value]
    self.cross_entropy=CategoricalCrossentropy(from_logits=False)
    self.n_critic= 5          ### Critic updates per generator update

    self.gen_optimizer=RMSprop(0.00005)
    self.critic_optimizer=RMSprop(0.00005)

    ### TensorBoard logging: one run directory per instance, named by creation time
    self.current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    self.log_dir = 'logs/' + self.current_time
    self.summary_writer = tf.summary.create_file_writer(self.log_dir)

  def get_generator(self):
    """Build the generator: noise vector -> image in [-1, 1] (tanh output).

    Returns:
      An uncompiled Keras ``Sequential`` model mapping ``(noise_size,)`` inputs
      to ``(height, width, channels)`` images.

    Raises:
      ValueError: if height or width is not divisible by 4. The network starts
        at (height//4, width//4) and doubles the size twice, so other sizes
        could not reproduce the training image shape.
    """
    if self.height%4!=0 or self.width%4!=0:
      raise ValueError(f"image height and width must be divisible by 4, got {self.height}x{self.width}")

    base_h=self.height//4
    base_w=self.width//4

    w_init=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    Generator= Sequential()

    ### Input layer: projects a noise vector of length noise_size to base_h*base_w*256 units
    Generator.add(Dense(base_h*base_w*256,use_bias=False, input_shape=(self.noise_size,),kernel_initializer=w_init))
    Generator.add(BatchNormalization(momentum=0.8))
    Generator.add(ReLU())

    ### Reshaping layer to reshape into a (base_h x base_w x 256) feature map
    Generator.add(Reshape((base_h,base_w,256)))

    ### Upconv layer 1  ## The spatial size stays constant (7 x 7 x 128 for 28 x 28 inputs)
    Generator.add(Conv2DTranspose(128, (5,5), strides=(1,1),padding="same",kernel_initializer=w_init,use_bias=False))
    Generator.add(BatchNormalization(momentum=0.8))
    Generator.add(ReLU())

    ### Upconv layer 2  ## The size upsamples by 2 (14 x 14 x 64 for 28 x 28 inputs)
    Generator.add(Conv2DTranspose(64, (5,5), strides=(2,2),padding="same",use_bias=False,kernel_initializer=w_init))
    Generator.add(BatchNormalization(momentum=0.8))
    Generator.add(ReLU())

    ### Upconv layer 3  ## The size upsamples by 2 again (28 x 28 x channels for 28 x 28 inputs)
    Generator.add(Conv2DTranspose(self.channels, (5,5), strides=(2,2),padding="same",kernel_initializer=w_init,activation="tanh"))

    return Generator

  def get_critic_classifier(self):
    """Build the two-headed critic/classifier network.

    Returns:
      An uncompiled Keras ``Model`` taking images of shape ``self.shape`` and
      returning ``[critic_score, class_probabilities]``: a linear scalar score
      and a softmax over ``no_of_classes`` classes.
    """
    w_init=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    input_tensor=Input((self.shape))

    #### Shared convolutional trunk

    conv1=Conv2D(64, (5,5), strides=(2,2),padding="same",use_bias=False,kernel_initializer=w_init,name="conv1")(input_tensor)
    batch_1=BatchNormalization(momentum=0.8,name="bn1")(conv1)
    activation_1=LeakyReLU(0.2,name="ac1")(batch_1)

    conv2=Conv2D(128, (5,5), strides=(2,2),padding="same",use_bias=False,kernel_initializer=w_init,name="conv2")(activation_1)
    batch_2=BatchNormalization(momentum=0.8,name="bn2")(conv2)
    activation_2=LeakyReLU(0.2,name="ac2")(batch_2)

    flattened=Flatten()(activation_2)

    #### Classifier head (layer names must stay in sync with _CLASSIFIER_HEAD_LAYERS)

    dense_1_cl=Dense(256,kernel_initializer=w_init,name="cld1")(flattened)
    dense_1_cl=BatchNormalization(momentum=0.8,name="clbn1")(dense_1_cl)
    dense_1_cl=ReLU(name="clac1")(dense_1_cl)
    dense_1_cl=Dropout(0.2,name="cldr1")(dense_1_cl)

    dense_2_cl=Dense(128,kernel_initializer=w_init,name="cld2")(dense_1_cl)
    dense_2_cl=BatchNormalization(momentum=0.8,name="clbn2")(dense_2_cl)
    dense_2_cl=ReLU(name="clac2")(dense_2_cl)
    dense_2_cl=Dropout(0.2,name="cldr2")(dense_2_cl)

    output_cl=Dense(self.no_of_classes,kernel_initializer=w_init,activation="softmax",name="cl_out")(dense_2_cl)

    #### Critic head

    dense_1_cr=Dense(128,kernel_initializer=w_init,use_bias=False,name="crd1")(flattened)
    dense_1_cr=BatchNormalization(momentum=0.8,name="crbn1")(dense_1_cr)
    dense_1_cr=ReLU(name="crac1")(dense_1_cr)

    output_cr=Dense(1,kernel_initializer=w_init,name="cr_out")(dense_1_cr)

    model=Model(inputs=[input_tensor],outputs=[output_cr,output_cl])

    return model

  def get_all_models(self):
    """Create fresh generator and critic/classifier models and print their summaries."""
    self.Generator=self.get_generator()
    self.Critic_Classifier=self.get_critic_classifier()

    ### summary() prints by itself and returns None, so it must not be wrapped in print()
    print("############# Generator ###############")
    self.Generator.summary()

    print("############# Critic-Classifier ###############")
    self.Critic_Classifier.summary()

    return None

  def _sample_noise(self,count):
    """Draw ``count`` noise vectors uniformly from [-1, 1)."""
    return tf.random.uniform(minval=-1, maxval=1,shape=[count,self.noise_size])

  @staticmethod
  def _classification_accuracy(one_hot_labels,class_probabilities):
    """Fraction of rows whose arg-max predicted class equals the arg-max label."""
    true_classes=np.argmax(np.asarray(one_hot_labels),axis=1)
    predicted_classes=np.argmax(np.asarray(class_probabilities),axis=1)
    return float(np.mean(true_classes==predicted_classes))

  def _clip_critic_weights(self):
    """Clip trainable critic parameters to [-clip_value, clip_value].

    Layers in the classifier head are skipped. Only trainable weights are
    clipped, so BatchNormalization moving statistics are left intact.
    """
    for layer in self.Critic_Classifier.layers:
      if layer.name in self._CLASSIFIER_HEAD_LAYERS:
        continue
      for weight in layer.trainable_weights:
        weight.assign(np.clip(weight.numpy(), -self.clip_value, self.clip_value))

  def train_on_batch_disc(self,data,fake,label_data,label_fake):
    """Run one critic/classifier update on a real batch and a generated batch.

    Args:
      data: batch of real images scaled to [-1, 1].
      fake: batch of generated images.
      label_data: one-hot labels of the real images.
      label_fake: one-hot labels of the generated images (the extra class).

    Returns:
      Tuple ``(loss_critic, loss_class, loss_total, acc)``. ``loss_critic`` is
      mean(score(fake)) - mean(score(real)), ``loss_class`` is the mean of the
      real and fake cross-entropies, ``loss_total`` is their sum, and ``acc``
      is the mean of the real-batch and fake-batch classification accuracies,
      measured after the update.
    """
    with tf.GradientTape() as disc_tape:
      ### training=True so BatchNormalization uses batch statistics and Dropout
      ### is active; a call without it runs the layers in inference mode
      critic_out_real,class_out_real=self.Critic_Classifier(data,training=True)
      critic_out_fake,class_out_fake=self.Critic_Classifier(fake,training=True)
      loss_class_real=self.cross_entropy(label_data,class_out_real)
      loss_class_fake=self.cross_entropy(label_fake,class_out_fake)
      loss_class=(loss_class_real+loss_class_fake)/2
      loss_critic=tf.reduce_mean(critic_out_fake)-tf.reduce_mean(critic_out_real)
      loss_total=loss_critic+loss_class

    d_grads=disc_tape.gradient(loss_total,self.Critic_Classifier.trainable_variables)
    self.critic_optimizer.apply_gradients(zip(d_grads,self.Critic_Classifier.trainable_variables))

    ### Accuracy compares predicted and true class per sample. Rounding the
    ### softmax and comparing element-wise would reward all-zero predictions.
    _,pred_real=self.Critic_Classifier(data,training=False)
    _,pred_fake=self.Critic_Classifier(fake,training=False)
    acc_real=self._classification_accuracy(label_data,pred_real)
    acc_fake=self._classification_accuracy(label_fake,pred_fake)
    acc=(acc_real+acc_fake)/2

    return loss_critic,loss_class,loss_total,acc

  def train_on_batch_generator(self,noise):
    """Run one generator update that maximises the critic score of its samples.

    Args:
      noise: batch of noise vectors shaped (batch, noise_size).

    Returns:
      The generator loss, -mean(critic score of the generated batch).
    """
    with tf.GradientTape() as gen_tape:
      ### The generator is being optimised, so its BatchNormalization layers must
      ### run in training mode; the critic is only evaluated here
      fake_images=self.Generator(noise,training=True)
      output_cr_fake,_=self.Critic_Classifier(fake_images,training=False)
      loss_gen=-tf.reduce_mean(output_cr_fake)

    g_grads=gen_tape.gradient(loss_gen,self.Generator.trainable_variables)
    self.gen_optimizer.apply_gradients(zip(g_grads,self.Generator.trainable_variables))

    return loss_gen

  def show_images(self, rows=4, columns=4):
    """Render a grid of generated samples as a PNG image tensor.

    Each sample is titled with the class predicted by the classifier head.

    Args:
      rows: number of grid rows.
      columns: number of grid columns.

    Returns:
      A uint8 tensor shaped (1, height_px, width_px, 4) holding the rendered
      figure, suitable for ``tf.summary.image``.
    """
    z = self._sample_noise(rows*columns)
    generated_images = self.Generator(z,training=False).numpy()

    ### Classify the images while they are still in the [-1,1] range the network is trained on
    _,class_probs=self.Critic_Classifier(generated_images,training=False)
    labels=np.argmax(np.asarray(class_probs),axis=1)

    ### Min-Max scaling to convert pixels from [-1,1] -> [0,1] for display
    display_images=(generated_images - (-1))/(1 - (-1))

    figure = plt.figure(figsize=(10,10))
    ### Plotting
    for i in range(rows*columns):
      ax=figure.add_subplot(rows, columns, i+1)
      ### A title is used for the label because axis('off') hides axis labels
      ax.set_title(str(labels[i]))
      ax.axis('off')
      if self.channels==1:
        ax.imshow(display_images[i,:,:,0], cmap=plt.cm.binary)
      else:
        ax.imshow(display_images[i])
    buf = io.BytesIO()
    figure.savefig(buf, format='png')
    plt.close(figure)
    buf.seek(0)
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    image = tf.expand_dims(image, 0)
    return image

  def train(self,epochs,batch_size,steps_per_epoch=50,image_log_interval=50):
    """Train the generator and critic/classifier, logging to TensorBoard.

    Models are (re)created at the start of every call. Each step performs
    ``n_critic`` critic/classifier updates, clipping the critic weights after
    each one, followed by a single generator update.

    Args:
      epochs: number of logging epochs.
      batch_size: samples per batch; must not exceed the number of samples.
      steps_per_epoch: generator updates per epoch.
      image_log_interval: a sample grid is logged every this many epochs;
        0 or None disables image logging.

    Returns:
      Dict of per-epoch lists with keys "Critic loss", "Generator loss",
      "Total Loss", "Classifier loss" and "Classifier accuracy".

    Raises:
      ValueError: if ``steps_per_epoch`` < 1 or ``batch_size`` is larger than
        the training set.
    """
    if steps_per_epoch<1:
      raise ValueError("steps_per_epoch must be at least 1")
    if batch_size<1 or batch_size>self.no_of_samples:
      raise ValueError(f"batch_size must be between 1 and {self.no_of_samples}, got {batch_size}")

    self.get_all_models()  ### Initializing all models

    ### repeat() makes the stream endless, so training cannot run out of batches
    ### however small the dataset; shuffle reshuffles on every pass
    dataset=(tf.data.Dataset.from_tensor_slices((self.train_data,self.targets))
             .shuffle(self.no_of_samples)
             .batch(batch_size,drop_remainder=True)
             .repeat())
    batches=dataset.as_numpy_iterator()

    ### Generated samples are always labelled with the extra (last) class
    fake_labels=to_categorical(np.full((batch_size, 1), self.no_of_classes-1), num_classes=self.no_of_classes)

    gen_loss=[]
    dis_loss=[]
    t_loss=[]
    classifier_loss=[]
    classifier_acc=[]

    history={}

    for epoch in range(epochs):      ### Training epoch

      d_loss=0
      total_loss=0
      g_loss=0
      c_loss=0
      c_acc=0

      for _ in range(steps_per_epoch):

        #### Critic Training

        d_loss_n_crit=0
        c_loss_n_crit=0
        c_acc_n_crit=0
        total_loss_n_crit=0

        for _ in range(self.n_critic):

          ### Direct call instead of predict(): same output, less overhead for one batch
          generated=self.Generator(self._sample_noise(batch_size),training=False)
          (samples,labels)=next(batches)

          loss_critic,loss_class,loss_total,acc=self.train_on_batch_disc(samples,generated,labels,fake_labels)
          self._clip_critic_weights()

          d_loss_n_crit+=loss_critic
          c_loss_n_crit+=loss_class
          c_acc_n_crit+=acc
          total_loss_n_crit+=loss_total

        #### Generator Training

        loss_gen=self.train_on_batch_generator(self._sample_noise(batch_size))

        ### Accumulate the per-step means over the n_critic updates
        d_loss+=d_loss_n_crit/self.n_critic
        c_loss+=c_loss_n_crit/self.n_critic
        c_acc+=c_acc_n_crit/self.n_critic
        total_loss+=total_loss_n_crit/self.n_critic
        g_loss+=loss_gen

      d_loss/=steps_per_epoch
      g_loss/=steps_per_epoch
      total_loss/=steps_per_epoch
      c_loss/=steps_per_epoch
      c_acc/=steps_per_epoch
      dis_loss.append(d_loss)
      gen_loss.append(g_loss)
      t_loss.append(total_loss)
      classifier_loss.append(c_loss)
      classifier_acc.append(c_acc)

      print(f"ON EPOCH {epoch} Critic Loss: {d_loss}, Generator loss: {g_loss}, classifier loss: {c_loss}, total_loss: {total_loss}, classifier accuracy: {c_acc}")
      with self.summary_writer.as_default():
        tf.summary.scalar('loss/Generator', g_loss, step=epoch)
        tf.summary.scalar('loss/Critic', d_loss, step=epoch)
        tf.summary.scalar('loss/Classifier', c_loss, step=epoch)
        tf.summary.scalar('loss/Total', total_loss, step=epoch)
        tf.summary.scalar('acc/classifier', c_acc, step=epoch)

        ### Must be inside the writer context, otherwise the image summary is dropped
        if image_log_interval and epoch%image_log_interval==0:
          figs=self.show_images()
          tf.summary.image('gen_images', figs, step=epoch)

    history["Critic loss"]=dis_loss

    history["Generator loss"]=gen_loss
    history["Total Loss"]=t_loss
    history["Classifier loss"]=classifier_loss
    history["Classifier accuracy"]=classifier_acc
    return history




In [4]:
# Load MNIST; only the training split is used further down in this notebook.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(path="mnist.npz")
# Append a trailing channel axis: (N, H, W) -> (N, H, W, 1).
x_train = np.expand_dims(x_train, axis=-1)

In [5]:
# Build the trainer from the MNIST training images and their digit labels.
GAN = sWGAN(training_images=x_train, targets=y_train)

In [6]:
# Start TensorBoard inline, pointed at the 'logs' directory that sWGAN writes its run folders into.
%load_ext tensorboard
%tensorboard --logdir logs

Reusing TensorBoard on port 6006 (pid 272), started 0:48:07 ago. (Use '!kill 272' to kill it.)

<IPython.core.display.Javascript object>

In [None]:
# Train for 2000 epochs with batches of 64; the per-epoch metrics are returned as a dict of lists.
history = GAN.train(epochs=2000, batch_size=64)