In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found (recursively) under the Kaggle input directory.
for root, _subdirs, files in os.walk('/kaggle/input'):
    for path in (os.path.join(root, name) for name in files):
        print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np

from numpy import expand_dims
from numpy import zeros
from numpy import ones
from numpy.random import randn
from numpy.random import randint
from keras.datasets.mnist import load_data
from keras.optimizers import Adam
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Add
from keras.layers import AveragePooling2D
from keras.layers import MaxPooling2D
from keras.layers import Dropout
from keras.layers import Embedding
from keras.layers import Concatenate
from keras.layers import GaussianNoise
from keras.layers import BatchNormalization
from keras.layers import LayerNormalization

from keras.initializers import RandomNormal
import keras.backend as K
# Dataset module handed to load_samples(); its load_data() is unpacked there as
# ((train_x, train_y), (test_x, test_y)) and only the training pair is kept.
MNIST_DATA = keras.datasets.mnist

# define the standalone discriminator model
def define_discriminator(in_shape=(28,28,1), n_classes=10):
    """Build the conditional WGAN-GP critic ("discriminator").

    Label branch:
        The class label is embedded (50-d) and projected by a linear Dense layer
        to ``in_shape[0] * in_shape[1]`` values. These are reshaped into one
        extra image plane.
    Image branch:
        The label plane is concatenated to the image as an extra channel, and
        light Gaussian noise (stddev 0.01, active only while training) is added.
    Body:
        A stride-1 stem conv, then one residual block (two stride-1 convs plus a
        skip connection), then two stride-2 downsampling convs.
    Head:
        Flatten, dropout(0.3), and a single linear output unit.

    LayerNormalization is used throughout instead of BatchNormalization. The
    WGAN-GP gradient penalty is computed per sample, and batch-coupled
    normalization would interfere with it.

    Args:
        in_shape: (height, width, channels) of the image input.
        n_classes: number of distinct integer labels accepted by the embedding.

    Returns:
        An uncompiled ``keras.Model`` named ``"discriminator"``. It takes
        ``[images, labels]`` (labels shaped (batch, 1)) and returns unbounded
        critic scores of shape (batch, 1).
    """
    # --- label branch: embed the class id and turn it into one image plane ---
    in_label = Input(shape=(1,))
    li = Embedding(n_classes, 50)(in_label)
    # Linear projection to one value per pixel, then reshape to an HxWx1 plane.
    n_nodes = in_shape[0] * in_shape[1]
    li = Dense(n_nodes)(li)
    li = Reshape((in_shape[0], in_shape[1], 1))(li)

    # --- image branch ---
    in_img = Input(shape=in_shape)
    # Small-stddev normal initializer shared by the convolution kernels.
    init = RandomNormal(mean=0.0, stddev=0.02)
    # Append the label plane as an extra channel, then add light input noise.
    in_image = Concatenate()([in_img, li])
    in_image = GaussianNoise(0.01)(in_image)

    # Stem convolution at full resolution.
    fe = Conv2D(128, (3,3), strides=(1,1), padding='same', kernel_initializer=init)(in_image)
    fe = LayerNormalization()(fe)
    fe = LeakyReLU(alpha=0.2)(fe)

    # Residual block: conv-LN-LReLU, conv-LN, add the stem output back, LReLU.
    res = Conv2D(128, (3,3), strides=(1,1), padding='same', kernel_initializer=init)(fe)
    res = LayerNormalization()(res)
    res = LeakyReLU(alpha=0.2)(res)
    res = Conv2D(128, (3,3), strides=(1,1), padding='same', kernel_initializer=init)(res)
    res = LayerNormalization()(res)
    res = Add()([fe, res])
    res = LeakyReLU(alpha=0.2)(res)

    # Two stride-2 downsampling convs (28x28 -> 14x14 -> 7x7 for the default shape).
    fe = Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(res)
    fe = LayerNormalization()(fe)
    fe = LeakyReLU(alpha=0.2)(fe)
    fe = Conv2D(256, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(fe)
    fe = LayerNormalization()(fe)
    fe = LeakyReLU(alpha=0.2)(fe)

    # Head: flatten, dropout, and one linear (unbounded) critic score.
    fe = Flatten()(fe)
    fe = Dropout(0.3)(fe)
    out_layer = Dense(1)(fe)

    # Returned uncompiled; in this file the critic is trained through the
    # WGAN class's custom train_step with its own optimizer and loss.
    model = Model(inputs=[in_img, in_label], outputs=out_layer, name="discriminator")
    return model

# define the standalone generator model
def define_generator(low_res=(7,7,1), n_classes=10):
    """Build the conditional super-resolution generator.

    Label branch:
        The class label is embedded (50-d) and projected by a linear Dense layer
        to ``low_res[0] * low_res[1]`` values. These are reshaped into one extra
        plane.
    Input:
        The label plane is concatenated to the low-resolution image as an extra
        channel. Light Gaussian noise (stddev 0.01, active only while training)
        is then added.
    Body:
        A stride-1 conv, one residual block (two stride-1 convs plus a skip
        connection), and two stride-2 transposed convs.

    The output is therefore 4x the input height and width (7x7 -> 14x14 -> 28x28
    for the default ``low_res``).

    Args:
        low_res: (height, width, channels) of the low-resolution image input.
        n_classes: number of distinct integer labels accepted by the embedding.

    Returns:
        An uncompiled ``keras.Model`` named ``"generator"``. It takes
        ``[low_res_images, labels]`` and returns single-channel images of shape
        (batch, 4*H, 4*W, 1) with values in [-1, 1] (tanh output).
    """
    # --- label branch: embed the class id and turn it into one low-res plane ---
    in_label = Input(shape=(1,))
    li = Embedding(n_classes, 50)(in_label)
    # Linear projection to one value per low-res pixel, reshaped to HxWx1.
    n_nodes = low_res[0] * low_res[1]
    li = Dense(n_nodes)(li)
    li = Reshape((low_res[0], low_res[1], 1))(li)

    # --- low-resolution image input, with the label plane appended as a channel ---
    in_img = Input(shape=low_res)
    in_image = Concatenate()([in_img, li])
    input_img = GaussianNoise(0.01)(in_image)
    # Small-stddev normal initializer shared by the conv / transposed-conv kernels.
    init = RandomNormal(mean=0.0, stddev=0.02)

    # Stem convolution at the low resolution.
    fir_gen = Conv2D(128, (3,3), strides=(1,1), padding='same', kernel_initializer=init)(input_img)
    fir_gen = BatchNormalization(synchronized=False)(fir_gen)
    fir_gen = LeakyReLU(alpha=0.2)(fir_gen)

    # Residual block: conv-BN-LReLU, conv-BN, add the stem output back, LReLU.
    pre_gen = Conv2D(128, (3,3), strides=(1,1), padding='same', kernel_initializer=init)(fir_gen)
    pre_gen = BatchNormalization(synchronized=False)(pre_gen)
    pre_gen = LeakyReLU(alpha=0.2)(pre_gen)
    pre_gen = Conv2D(128, (3,3), strides=(1,1), padding='same', kernel_initializer=init)(pre_gen)
    pre_gen = BatchNormalization(synchronized=False)(pre_gen)
    pre_gen = Add()([fir_gen, pre_gen])
    pre_gen = LeakyReLU(alpha=0.2)(pre_gen)

    # 1st 2x upsampling (7x7 -> 14x14 for the default low_res).
    gen = Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(pre_gen)
    gen = BatchNormalization(synchronized=False)(gen)
    gen = LeakyReLU(alpha=0.2)(gen)
    # 2nd 2x upsampling (14x14 -> 28x28 for the default low_res).
    gen2 = Conv2DTranspose(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(gen)
    gen2 = BatchNormalization(synchronized=False)(gen2)
    gen2 = LeakyReLU(alpha=0.2)(gen2)

    # Single-channel tanh output in [-1, 1] (same range as load_samples' scaling).
    out_layer = Conv2D(1, (5,5), activation='tanh', padding='same')(gen2)
    model = Model(inputs=[in_img, in_label], outputs=out_layer, name="generator")
    return model



# define the combined generator and discriminator model, for updating the generator
#def define_gan(g_model, d_model):
	# make weights in the discriminator not trainable
	#d_model.trainable = False
	# get noise and label inputs from generator model
	#gen_noise, gen_label = g_model.input
	# get image output from the generator model
	#gen_output = g_model.output
	# connect image output and label input from generator as inputs to discriminator
	#gan_output = d_model([gen_output, gen_label])
	# define gan model as taking noise and label and outputting a classification
	#model = Model([gen_noise, gen_label], gan_output)
	# compile model
	##opt = Adam(learning_rate=0.0001, beta_1=0.2)
	#model.compile(loss='binary_crossentropy', optimizer=opt)
	#return model



class WGAN(keras.Model):
    """Conditional WGAN-GP trainer for a super-resolution generator.

    Each real batch is average-pooled twice (4x per spatial axis) to form the
    generator's low-resolution input. The generator maps it, together with the
    class labels, back to full resolution. The critic (``discriminator``)
    scores ``[images, labels]`` pairs.

    Every ``train_step`` runs ``d_steps`` critic updates, each using the
    gradient penalty weighted by ``gp_weight``. It then runs one generator
    update.

    Args:
        discriminator: critic model, called as ``discriminator([images, labels])``.
        generator: model called as ``generator([low_res_images, labels])``.
        Dsteps: number of critic updates per generator update (must be >= 1).
        gp_weight: weight of the gradient-penalty term in the critic loss.

    Raises:
        ValueError: if ``Dsteps`` is smaller than 1.
    """

    def __init__(self, discriminator, generator, Dsteps=5, gp_weight=10.0):
        super(WGAN, self).__init__()
        # Without at least one critic step, train_step would have no d_loss to report.
        if Dsteps < 1:
            raise ValueError("Dsteps must be >= 1, got %r" % (Dsteps,))
        self.discriminator = discriminator
        self.generator = generator
        self.d_steps = Dsteps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        """Store the optimizers and loss callables used by ``train_step``.

        ``d_loss_fn`` is called as ``d_loss_fn(real_img=..., fake_img=...)`` with
        critic scores. ``g_loss_fn`` is called with the critic scores of
        generated images.
        """
        super(WGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def generate_low_res_samples(self, image, labels):
        """Downsample ``image`` 4x per spatial axis with two 2x2/stride-2 average pools.

        Uses the stateless ``tf.nn.avg_pool2d`` op with 'SAME' padding. This
        matches ``AveragePooling2D(padding='same')`` without building a new
        layer object on every call inside the traced ``train_step``.

        Args:
            image: image batch, assumed NHWC (channels-last).
            labels: passed through unchanged.

        Returns:
            ``(low_res_images, labels)``.
        """
        low = image
        for _ in range(2):
            low = tf.nn.avg_pool2d(low, ksize=2, strides=2, padding='SAME')
        return low, labels

    def gradient_penalty(self, batch_size, real_images, fake_images, real_labels):
        """Return the WGAN-GP penalty ``mean((||grad_x D(x_hat, y)||_2 - 1)^2)``.

        ``x_hat`` is a random interpolate between each real image and its fake
        counterpart. Every sample gets its OWN coefficient drawn from U[0, 1),
        as in Gulrajani et al. (2017), Algorithm 1.

        Args:
            batch_size: accepted for backward compatibility. The batch size is
                read from ``real_images`` instead, so partial batches work and
                no module-level ``batch_size`` is needed.
            real_images: real image batch, assumed rank 4 (NHWC).
            fake_images: fake image batch with the same shape as ``real_images``.
            real_labels: labels fed to the critic together with the interpolates.

        Returns:
            Scalar penalty tensor.
        """
        n = tf.shape(real_images)[0]
        # One interpolation coefficient per sample, broadcast over H, W, C.
        alpha = tf.random.uniform(shape=[n, 1, 1, 1], minval=0., maxval=1.,
                                  dtype=real_images.dtype)
        interpolated = real_images + alpha * (fake_images - real_images)

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            pred = self.discriminator([interpolated, real_labels], training=True)

        grads = gp_tape.gradient(pred, interpolated)
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        return tf.reduce_mean((norm - 1.0) ** 2)

    def train_step(self, real_data):
        """Run ``d_steps`` critic updates followed by one generator update.

        Args:
            real_data: ``(images, labels)`` batch (as delivered by ``fit(x, y)``).
                Any extra elements, such as sample weights, are ignored.

        Returns:
            ``{"d_loss": last critic loss incl. penalty, "g_loss": generator loss}``.

        Raises:
            ValueError: if ``real_data`` is not an ``(images, labels)`` pair.
        """
        if not isinstance(real_data, (tuple, list)) or len(real_data) < 2:
            raise ValueError(
                "WGAN.train_step expects (images, labels) batches; call fit(x, y).")
        real_images, real_labels = real_data[0], real_data[1]

        # The low-res inputs depend only on the real batch: compute them once.
        lowres_images, lowres_labels = self.generate_low_res_samples(real_images, real_labels)
        batch_size = tf.shape(real_images)[0]

        # --- critic updates ---
        for _ in range(self.d_steps):
            with tf.GradientTape() as tape:
                fake_images = self.generator([lowres_images, lowres_labels], training=True)
                fake_logits = self.discriminator([fake_images, lowres_labels], training=True)
                real_logits = self.discriminator([real_images, real_labels], training=True)
                d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
                # Computed inside the outer tape so its gradient reaches the critic weights.
                gp = self.gradient_penalty(batch_size, real_images, fake_images, real_labels)
                d_loss = d_cost + gp * self.gp_weight
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            self.d_optimizer.apply_gradients(zip(d_gradient, self.discriminator.trainable_variables))

        # --- generator update ---
        with tf.GradientTape() as tape:
            generated_images = self.generator([lowres_images, lowres_labels], training=True)
            gen_img_logits = self.discriminator([generated_images, lowres_labels], training=True)
            g_loss = self.g_loss_fn(gen_img_logits)
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(zip(gen_gradient, self.generator.trainable_variables))
        return {"d_loss": d_loss, "g_loss": g_loss}


# load  mnist images
def load_samples(MNIST_DATA):
    """Load the training split of a dataset and prepare it for GAN training.

    Args:
        MNIST_DATA: object whose ``load_data()`` returns
            ``((train_x, train_y), (test_x, test_y))``; only the training pair is kept.

    Returns:
        ``[images, labels]``:
            images: float32, with a trailing channel axis, rescaled from
                [0, 255] to [-1, 1].
            labels: the training labels with a trailing singleton axis added.
    """
    (images, labels), (_, _) = MNIST_DATA.load_data()
    # Trailing singleton axis on the labels, trailing channel axis on the images.
    labels = expand_dims(labels, axis=-1)
    images = expand_dims(images, axis=-1).astype('float32')
    # Map pixel values from [0, 255] onto [-1, 1].
    images = (images - 127.5) / 127.5
    return [images, labels]

#def generate_low_res_samples():
 #     [images, labels]= 
 #     Avgpool= AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')
 #     ix = randint(0, images.shape[0], batch_size)
 #     # select images
 #     random_images= images[ix]
 #     labels=labels[ix]
 #     low1=Avgpool(random_images)
 #     low2=Avgpool(low1)
      #print(low2.shape)
 #     return low2,labels

class GANMonitor(keras.callbacks.Callback):
    """Keras callback that saves the generator at the end of every epoch.

    Files are named ``'WGAN_GP%.1f.keras' % epoch``. For example, the first
    epoch (Keras epoch index 0) is saved as ``WGAN_GP0.0.keras``.

    Args:
        num_img: stored for compatibility; not used by this callback.
        generator: model to save. When omitted, the callback saves
            ``self.model.generator``, i.e. the ``generator`` attribute of the
            model being fitted (the WGAN wrapper in this file). It no longer
            reads a module-level global.
    """

    def __init__(self, num_img=6, generator=None):
        # The base initializer sets up the state Keras expects every callback to have.
        super().__init__()
        self.num_img = num_img
        self._generator = generator

    def on_epoch_end(self, epoch, logs=None):
        """Save the generator to ``WGAN_GP<epoch>.keras`` (epoch formatted with one decimal)."""
        generator = self._generator if self._generator is not None else self.model.generator
        generator.save('WGAN_GP%.1f.keras' % epoch)

# Callback that saves the generator at the end of every epoch (num_img is stored but unused).
cbk = GANMonitor(num_img=3)
# Load the MNIST training split: images scaled to [-1, 1] with a channel axis,
# labels with a trailing singleton axis.
org_dataset = load_samples(MNIST_DATA)
# Build the generator (default 7x7x1 low-res image + label -> 28x28x1 image)
# and the critic (28x28x1 image + label -> one unbounded score).
g_model=define_generator()
d_model= define_discriminator()

# Adam optimizers for the generator (G) and the discriminator/critic (D); both use
# lr=2e-4, beta_1=0.5, beta_2=0.9.
generator_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)
discriminator_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)

#Loss functions for G and D without the Gradient penalty
def discriminator_loss(real_img, fake_img):
    """Wasserstein critic loss without the gradient penalty.

    Args:
        real_img: critic scores for real images.
        fake_img: critic scores for generated images.

    Returns:
        ``mean(fake scores) - mean(real scores)``; the critic minimizes this.
    """
    mean_real = tf.reduce_mean(real_img)
    mean_fake = tf.reduce_mean(fake_img)
    return mean_fake - mean_real

def generator_loss(fake_img):
    """Wasserstein generator loss: the negated mean critic score of generated images."""
    mean_fake = tf.reduce_mean(fake_img)
    return -mean_fake

# Wrap both models in the WGAN-GP trainer: 3 critic updates per generator update,
# default gradient-penalty weight.
wgan = WGAN(discriminator=d_model, generator=g_model, Dsteps=3,)

# Attach the optimizers and the penalty-free Wasserstein losses defined above
# (the gradient penalty is added inside WGAN.train_step).
wgan.compile(d_optimizer=discriminator_optimizer, g_optimizer=generator_optimizer, g_loss_fn=generator_loss, d_loss_fn=discriminator_loss,)
epoch =50
total_samples=org_dataset[0].shape[0]
batch_size=256
# Floor division: steps_per_epoch counts only full batches of batch_size samples.
steps_per_epoch=total_samples//batch_size
# Start training; each batch reaches WGAN.train_step, which unpacks it as (images, labels).
wgan.fit(org_dataset[0],org_dataset[1], batch_size=batch_size, epochs=epoch,callbacks=[cbk],steps_per_epoch=steps_per_epoch)