# Hyper-parameters

In [1]:
# Folder (under /content/drive/MyDrive) that receives checkpoints, pictures and the GIF
dir_name = 'MNIST_Model4'

In [2]:
n_Gen = 3             # how many generators the MAD-GAN trains
latent_dim = 256      # length of each generator's input noise vector
batch_size = 256      # samples per generator per step (each real batch holds n_Gen * batch_size images)
size_dataset = 60000  # number of MNIST training images

# Each step consumes n_Gen * batch_size real images, so this is ~one pass over the data.
steps_per_epoch = size_dataset // batch_size // n_Gen

# Adding Libraries

In [3]:
import matplotlib.pyplot as plt
from sklearn import mixture

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Input, Dense, Dropout, LeakyReLU, 
                                     ReLU, Conv2D,Conv2DTranspose, Flatten,
                                     Reshape, BatchNormalization)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.utils import plot_model
from tensorflow.keras import Model
from google.colab import output

import os
from IPython import display

In [4]:
# for saving GIF
!pip install pygifsicle
!sudo apt-get install gifsicle
import imageio
import glob
from pygifsicle import optimize
output.clear()
print("Import Done")

Import Done


# To see if we have a GPU

In [5]:
# Report how many GPUs TensorFlow sees, then which device will be used.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

on_gpu = tf.test.gpu_device_name() == '/device:GPU:0'
print("Using a GPU" if on_gpu else "Using a CPU")

Num GPUs Available:  1
Using a GPU


# Mount Google drive to save model and data

In [5]:
from google.colab import drive
drive.mount('/content/drive')

# Create <dir_name>/Pictures (and <dir_name> itself) if missing. exist_ok=True
# makes re-running this cell a no-op instead of needing separate exists() checks.
os.makedirs(f'/content/drive/MyDrive/{dir_name}/Pictures', exist_ok=True)

Mounted at /content/drive


# Producing Dataset

##### Load MNIST and convert it to stacked MNIST (3-channel images)

In [7]:
def dataset1_func(random_state = None):
    """Load the MNIST training set and turn it into 3-channel "stacked MNIST".

    Each output image stacks three digits as channels: the original image and
    the images at the same index in two independently shuffled copies of the
    training set.

    Args:
        random_state: optional integer used to seed the two channel shuffles
            (seeds ``random_state`` and ``random_state + 1``). When None
            (default) the fixed seeds 10 and 20 are used, as before.

    Returns:
        float32 tensor of shape (num_images, 28, 28, 3) scaled to [-1, 1].
    """
    (train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()

    train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
    train_images = (train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]

    # Previously random_state was accepted but ignored; it now controls the
    # shuffles, while the default keeps the original hard-coded seeds.
    if random_state is None:
        seed1, seed2 = 10, 20
    else:
        seed1, seed2 = random_state, random_state + 1

    # Convert to stacked-MNIST: channel 0 is the original digit, channels 1 and
    # 2 come from shuffled copies so each image holds three (usually different) digits.
    t1 = tf.random.shuffle(train_images, seed = seed1)
    t2 = tf.random.shuffle(train_images, seed = seed2)
    return tf.concat([train_images, t1, t2], axis=-1)

# Some Functions

##### Noise Generator



In [8]:
from tensorflow_probability import distributions as tfd

def generate_latent_points(latent_dim, batch_size, n_Gen):
    """Sample one batch of latent vectors for every generator.

    Vectors are drawn from a multivariate normal with zero mean and unit
    diagonal scale of dimension `latent_dim`.

    Args:
        latent_dim: dimensionality of each latent vector.
        batch_size: number of vectors drawn per generator (a Python int or a
            scalar tensor, as passed from MADGAN.train_step).
        n_Gen: number of generators, i.e. how many batches are returned.

    Returns:
        list of `n_Gen` sample tensors, one per generator.
    """
    standard_normal = tfd.MultivariateNormalDiag(
        loc=[0] * latent_dim,
        scale_diag=[1.0] * latent_dim,
    )
    return [standard_normal.sample(batch_size) for _ in range(n_Gen)]

##### Loss function for the generators, based on the MAD-GAN paper

In [9]:
def Generators_loss_function(y_true, y_pred):
    """Generator loss: batch mean of -log(probability of the "real" class).

    Args:
        y_true: unused; present only to match the Keras loss signature.
        y_pred: discriminator softmax output whose last column is the "real"
            class (MADGAN labels real samples with index n_Gen, the last one).

    Returns:
        The mean of -log(y_pred[:, -1] + 1e-15) over the batch.
    """
    eps = 1e-15  # keeps log() finite when the real-class probability is 0
    p_real = y_pred[:, -1]
    return tf.reduce_mean(-tf.math.log(p_real + eps), axis=-1)

##### A callback that runs at the end of each epoch to plot and save the results

In [10]:
class GANMonitor1(tf.keras.callbacks.Callback):
    """Keras callback that plots generator samples next to real samples.

    Every `plot_freq` epochs it draws a 4-row grid. Row 1 shows each sample as
    an RGB image; rows 2-4 show its three channels individually in grayscale.
    The first `n_Gen * num_img` columns hold generator outputs, grouped by
    generator, and the last `num_img` columns hold real samples. The figure is
    saved as /content/drive/MyDrive/<dir_name>/Pictures/image_at_epoch_XXXX.png.
    """

    def __init__(self, random_latent_vectors, data, plot_freq = 5, num_img = 3, latent_dim = 128, n_Gen = 6, dir_name = 'Model'):
        """
        Args:
            random_latent_vectors: per-generator latent batches (indexed by
                generator); each must contain at least `num_img` vectors.
            data: real samples scaled to [-1, 1]; the first `num_img` are plotted.
            plot_freq: plot every `plot_freq` epochs.
            num_img: number of samples shown per generator and of real samples.
            latent_dim: stored for reference; not used when plotting.
            n_Gen: number of generators plotted (self.model.generators[0..n_Gen-1]).
            dir_name: folder under /content/drive/MyDrive/ holding Pictures/.
        """
        super().__init__()
        # Keep exactly the real samples that will be plotted. A hard-coded
        # data[0:2] raised IndexError whenever num_img > 2 (including the default 3).
        self.data = data[:num_img]
        self.random_latent_vectors = random_latent_vectors
        self.plot_freq = plot_freq
        self.num_img = num_img
        self.latent_dim = latent_dim
        self.n_Gen = n_Gen
        self.dir_name = dir_name

    @staticmethod
    def _plot_sample(n_rows, n_cols, col, sample, title):
        """Draw one [-1, 1]-scaled 3-channel sample down grid column `col` (1-based)."""
        # Row 1: the sample as an RGB image (rescaled to [0, 1]).
        plt.subplot(n_rows, n_cols, col)
        plt.title(title)
        plt.imshow((sample * 127.5 + 127.5) / 255, aspect = 'equal')
        plt.axis('off')
        # Rows 2-4: each channel as a grayscale image in [0, 255].
        for channel in range(3):
            plt.subplot(n_rows, n_cols, col + (channel + 1) * n_cols)
            plt.imshow(sample[:, :, channel] * 127.5 + 127.5, cmap = 'gray', aspect = 'equal', vmin=0, vmax=255)
            plt.axis('off')

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.plot_freq != 0:
            return

        # The grid is derived from n_Gen and num_img instead of being hard-coded
        # to 8 columns; for n_Gen=3, num_img=2 it reproduces the previous layout.
        n_rows = 4
        n_cols = (self.n_Gen + 1) * self.num_img
        fig = plt.figure(figsize=(12, 6))
        fig.suptitle(f'Epoch {(epoch + 1):04}')

        # Real samples occupy the right-most num_img columns.
        real_offset = self.n_Gen * self.num_img
        for i in range(self.num_img):
            self._plot_sample(n_rows, n_cols, real_offset + i + 1, self.data[i], f'Real {i+1}')

        for g in range(self.n_Gen):
            generated_samples = self.model.generators[g](self.random_latent_vectors[g])
            for i in range(self.num_img):
                self._plot_sample(n_rows, n_cols, g * self.num_img + i + 1, generated_samples[i], f'Gen {g + 1}')

        plt.subplots_adjust(hspace = 0.05, wspace = 0.05)
        plt.savefig(f'/content/drive/MyDrive/{self.dir_name}/Pictures/image_at_epoch_{(epoch + 1):04}.png', dpi=200, format="png")

        # To show the plots in colab replace the next line with plt.show()
        plt.close(fig)

# Defining Discriminator Model

In [11]:
def define_discriminator(n_Gen, input_shape=(28, 28, 3)):
    """Build the MAD-GAN discriminator as an (n_Gen + 1)-way classifier.

    Architecture: two stride-2 5x5 Conv2D blocks (64 and 128 filters, each
    followed by LeakyReLU and Dropout(0.3)), then Flatten and a softmax Dense
    layer with n_Gen + 1 outputs.

    Args:
        n_Gen: number of generators; the output has one class per generator
            plus one class for real samples.
        input_shape: image shape (height, width, channels); defaults to the
            stacked-MNIST shape used in this notebook.

    Returns:
        Functional `Model` named "Discriminator".
    """
    inp = Input(shape=input_shape)

    # The input shape comes from the Input layer; an `input_shape` kwarg on
    # this Conv2D was redundant in the functional API and has been dropped.
    x = Conv2D(64, (5, 5), strides=(2, 2), padding='same')(inp)
    x = LeakyReLU()(x)
    x = Dropout(0.3)(x)

    x = Conv2D(128, (5, 5), strides=(2, 2), padding='same')(x)
    x = LeakyReLU()(x)
    x = Dropout(0.3)(x)

    x = Flatten()(x)
    out = Dense(n_Gen + 1, activation = 'softmax')(x)

    model = Model(inp, out, name="Discriminator")
    return model

# Defining Generators Model

In [12]:
def define_generators(n_Gen, latent_dim):
    """Build `n_Gen` generators that share every layer except the output layer.

    Shared trunk (the same layer objects are called by every generator, so
    their weights and BatchNormalization statistics are shared):
        Dense(7*7*256) -> BN -> LeakyReLU -> Reshape to (7, 7, 256)
        -> Conv2DTranspose(128, stride 1) -> BN -> LeakyReLU   # (7, 7, 128)
        -> Conv2DTranspose(64, stride 2)  -> BN -> LeakyReLU   # (14, 14, 64)
    Per-generator head:
        -> Conv2DTranspose(3, stride 2, tanh)                  # (28, 28, 3)

    Args:
        n_Gen: number of generators to build.
        latent_dim: length of each generator's latent input vector.

    Returns:
        list of n_Gen `Model`s named "generator0", "generator1", ...
    """
    dens = Dense(7*7*256, use_bias=False)
    batchnorm0 = BatchNormalization()
    rel0 = LeakyReLU()
    # Dense emits 7*7*256 features, so the channel count must be 256.
    # Reshape([7, 7, latent_dim]) only worked when latent_dim happened to be 256.
    reshape0 = Reshape([7, 7, 256])

    con2dt1 = Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)
    batchnorm1 = BatchNormalization()
    rel1 = LeakyReLU()

    con2dt2 = Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)
    batchnorm2 = BatchNormalization()
    rel2 = LeakyReLU()

    models = []
    for g in range(n_Gen):
        # Inputs stay declared as float64, as in the original interface.
        noise_in = Input(shape=(latent_dim,), dtype = tf.float64, name=f"input_{g}")
        x = rel0(batchnorm0(dens(noise_in)))
        x = reshape0(x)
        x = rel1(batchnorm1(con2dt1(x)))
        x = rel2(batchnorm2(con2dt2(x)))

        # Private output layer: a fresh Conv2DTranspose per generator.
        x = Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')(x)

        models.append(Model(noise_in, x, name = f"generator{g}"))
    return models

# Defining MADGAN Class for training via keras

In [13]:
class MADGAN(tf.keras.Model):
    """MAD-GAN: one discriminator trained against `n_Gen` generators.

    The discriminator is an (n_Gen + 1)-way classifier: class g (0 <= g < n_Gen)
    means "produced by generator g" and the last class (index n_Gen) means
    "real". Each generator is trained to push its samples into the real class.
    """

    def __init__(self, discriminator, generators, latent_dim, n_Gen):
        """
        Args:
            discriminator: model mapping images to n_Gen + 1 class probabilities.
            generators: list of n_Gen generator models.
            latent_dim: length of each generator's latent input vector.
            n_Gen: number of generators.
        """
        super(MADGAN, self).__init__()
        self.discriminator = discriminator
        self.generators = generators
        self.latent_dim = latent_dim
        self.n_Gen = n_Gen

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        """
        Args:
            d_optimizer: optimizer for the discriminator.
            g_optimizer: list of optimizers; g_optimizer[g] updates generator g.
            d_loss_fn: loss(labels, predictions) for the discriminator.
            g_loss_fn: loss(labels, predictions) for each generator.
        """
        super(MADGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def train_step(self, data):
        """One discriminator update, then one update per generator.

        Args:
            data: batch of real images. Each generator produces
                batch_size // n_Gen fake images per step.

        Returns:
            dict with "g_loss{g}" for every generator, followed by "d_loss".
        """
        X = data
        n_classes = self.n_Gen + 1

        # Get the batch size and how many fakes each generator contributes
        batch_size = tf.shape(X)[0]
        per_gen = batch_size // self.n_Gen

        # Sample random points in the latent space
        random_latent_vectors = generate_latent_points(self.latent_dim, per_gen, self.n_Gen)

        # Sub-models are called with training=True throughout. Without it,
        # `training` resolves to inference mode inside a custom train_step, so
        # Dropout was disabled and BatchNormalization used moving statistics
        # that were never updated.
        x_generator = [self.generators[g](random_latent_vectors[g], training=True)
                       for g in range(self.n_Gen)]

        # Combine them with real samples
        combined_samples = tf.concat(x_generator + [X], axis=0)

        # Labels: class g for generator g's samples, class n_Gen for real samples
        labels = tf.concat(
            [tf.one_hot(g * tf.ones(per_gen, dtype=tf.int32), n_classes) for g in range(self.n_Gen)]
            + [tf.one_hot(self.n_Gen * tf.ones(batch_size, dtype=tf.int32), n_classes)],
            axis=0,
        )

        # Add random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform(shape = tf.shape(labels), minval = -1, maxval = 1)

        #######################
        # Train Discriminator #
        #######################
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_samples, training=True)
            d_loss = self.d_loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))

        #######################
        #   Train Generators  #
        #######################

        # Labels that say "all real samples"
        misleading_labels = tf.one_hot(self.n_Gen * tf.ones(per_gen, dtype=tf.int32), n_classes)

        # Only generator g's weights are updated here, never the discriminator's.
        g_loss_list = []
        for g in range(self.n_Gen):
            with tf.GradientTape() as tape:
                fake = self.generators[g](random_latent_vectors[g], training=True)
                predictions = self.discriminator(fake, training=True)
                g_loss = self.g_loss_fn(misleading_labels, predictions)
            grads = tape.gradient(g_loss, self.generators[g].trainable_weights)
            self.g_optimizer[g].apply_gradients(zip(grads, self.generators[g].trainable_weights))
            g_loss_list.append(g_loss)

        logs = {f"g_loss{g}": g_loss_list[g] for g in range(self.n_Gen)}
        logs["d_loss"] = d_loss
        return logs

# Creating Model and training it

In [None]:
# Loading data
data = dataset1_func()

# Input pipeline. Shuffling *before* repeat with a dataset-sized buffer gives a
# full uniform shuffle that is redone every epoch. A repeat().shuffle(10 * size_dataset)
# buffer would hold ~600k 28x28x3 float32 images (~5.6 GB of RAM).
# Each batch holds n_Gen * batch_size real images; MADGAN.train_step then draws
# batch_size fake images from every generator.
dataset = (
    tf.data.Dataset.from_tensor_slices(data)
    .shuffle(size_dataset, reshuffle_each_iteration=True)
    .repeat()
    .batch(n_Gen * batch_size, drop_remainder=True)
)

# Creating Discriminator and Generator
discriminator = define_discriminator(n_Gen)
discriminator.summary()
generators = define_generators(n_Gen, latent_dim)
generators[0].summary()

# creating MADGAN
madgan = MADGAN(discriminator = discriminator, generators = generators,
                latent_dim = latent_dim, n_Gen = n_Gen)

madgan.compile(
    d_optimizer = Adam(learning_rate=2e-4, beta_1=0.5),
    # One optimizer per generator: MADGAN.train_step uses g_optimizer[g]
    g_optimizer = [Adam(learning_rate=1e-4, beta_1=0.5) for _ in range(n_Gen)],
    d_loss_fn = CategoricalCrossentropy(),
    g_loss_fn = Generators_loss_function
)

checkpoint_filepath = f'/content/drive/MyDrive/{dir_name}/checkpoint'
num_plot_img = 2                # samples per generator (and real samples) shown in each plot
checkpoint_every_n_epochs = 15  # how often the weights are saved

# Fixed latent vectors so every plot shows the same inputs as training
# progresses; GANMonitor1 needs at least num_img vectors per generator.
random_latent_vectors = generate_latent_points(latent_dim = latent_dim, batch_size = num_plot_img, n_Gen = n_Gen)

my_callbacks = [
    # This callback is for plotting generators' output every epoch
    GANMonitor1(random_latent_vectors, data = data, plot_freq = 1, num_img = num_plot_img, latent_dim = latent_dim, n_Gen = n_Gen, dir_name = dir_name),
    # Save the weights every checkpoint_every_n_epochs epochs. An integer
    # save_freq counts *batches*, not epochs, so convert with steps_per_epoch.
    tf.keras.callbacks.ModelCheckpoint(filepath = checkpoint_filepath, save_freq = checkpoint_every_n_epochs * steps_per_epoch, save_weights_only = True),
]

# # Loading previous saved model for resume training
# if os.path.exists(checkpoint_filepath):
#     madgan.load_weights(checkpoint_filepath)

# train the model
madgan.fit(dataset, epochs = 200, steps_per_epoch = steps_per_epoch, verbose = 1, callbacks = my_callbacks)

Model: "Discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28, 3)]       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 14, 14, 64)        4864      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 14, 14, 64)        0         
_________________________________________________________________
dropout (Dropout)            (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 7, 7, 128)         204928    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 7, 7, 128)         0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 7, 7, 128)       

# Saving GIF file

In [7]:
anim_file = 'madgan2.gif'
anim_path = f'/content/drive/MyDrive/{dir_name}/{anim_file}'
hold_frames = 20  # extra copies of the last frame so the final result stays on screen longer

# Epoch numbers are zero-padded (image_at_epoch_0001.png, ...), so a plain
# lexical sort gives chronological order.
filenames = sorted(glob.glob(f'/content/drive/MyDrive/{dir_name}/Pictures/image_at_epoch_*.png'))

if not filenames:
    # Without this guard the hold loop below raised NameError on `image`.
    print(f'No epoch images found in /content/drive/MyDrive/{dir_name}/Pictures; GIF not created.')
else:
    with imageio.get_writer(anim_path, mode='I') as writer:
        for filename in filenames:
            image = imageio.imread(filename)
            writer.append_data(image)
        for _ in range(hold_frames):
            writer.append_data(image)

    # Reduce GIF size
    optimize(anim_path)