# Training GAN on MNIST dataset.

The following steps are carried out to reduce the size of the GAN model.


1. We use TensorFlow Lite ('tflite') so the model can run on mobile devices with faster inference.

2. The model is later quantized and retrained with quantization-aware training. (Model Pruning)

3. The generated images can be found in the 'quantized_images' folder.

4. For better-quality results, train the model for 100k epochs.

5. Quantization-aware training can suffer from mode collapse. This can be mitigated by modeling the latent vector as in StyleGAN.

In [None]:
from keras.datasets import mnist
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras

In [None]:
# Input image geometry: 28x28 pixels with a single channel.
# Keeping the images small keeps training time and resource use low.
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)

In [None]:
def build_generator():
    """Build the generator network (latent noise in, one image out).

    The model takes a 1-D noise vector of length 100, projects it to a
    28x28x1 map with a Dense layer, runs it through two
    Conv2D (12 filters, 3x3, ReLU) + 2x2 max-pooling stages, flattens the
    result and maps it through a tanh Dense layer that is reshaped to the
    module-level ``img_shape``.

    Returns:
        The uncompiled ``keras.Sequential`` generator model.
    """
    latent_shape = (100,)  # 1-D latent vector / noise of length 100
    layers = keras.layers

    stack = [
        layers.InputLayer(input_shape=latent_shape),
        # Project the noise to 784 values and view them as a 28x28x1 map.
        layers.Dense(784),
        layers.Reshape(target_shape=(28, 28, 1)),
        layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        # One tanh unit per output pixel, then reshape to the image shape.
        layers.Dense(np.prod(img_shape), activation='tanh'),
        layers.Reshape(img_shape),
    ]

    return keras.Sequential(stack)

In [None]:
def build_discriminator():
    """Build the discriminator network (image in, single sigmoid score out).

    The image of shape ``img_shape`` is flattened and passed through two
    Dense layers (512 and 256 units), each followed by LeakyReLU with
    alpha=0.2, and a final 1-unit sigmoid Dense layer.

    Returns:
        The uncompiled ``keras.Sequential`` discriminator model.
    """
    net = keras.Sequential()

    net.add(keras.layers.Flatten(input_shape=img_shape))
    # Two hidden Dense + LeakyReLU stages, widest first.
    for units in (512, 256):
        net.add(keras.layers.Dense(units))
        net.add(keras.layers.LeakyReLU(alpha=0.2))
    net.add(keras.layers.Dense(1, activation='sigmoid'))

    return net

def train(epochs, batch_size=128, save_interval=50):
    """Alternately train the discriminator and the generator on MNIST.

    Uses the notebook-level globals ``generator``, ``discriminator`` and
    ``combined`` (noise -> generator -> discriminator), plus ``save_imgs``.

    Each iteration:
      1. trains the discriminator on half a batch of real images (label 1)
         and half a batch of generated images (label 0), as two separate
         batches;
      2. trains the generator through ``combined`` on a full batch of noise
         with label 1, i.e. it is rewarded when the discriminator scores
         its output as real.

    Args:
        epochs: Number of training iterations (one batch each).
        batch_size: Batch size for the generator step; the discriminator
            sees ``batch_size // 2`` real and ``batch_size // 2`` fake
            images per iteration. Must be at least 2.
        save_interval: A sample image grid is saved via ``save_imgs``
            whenever ``epoch % save_interval == 0``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 2 (the half batch
            would be empty).
    """
    if batch_size < 2:
        raise ValueError("batch_size must be at least 2, got %r" % (batch_size,))

    latent_dim = 100  # must match the generator's input size

    # Load the dataset (labels and the test split are not used).
    (X_train, _), (_, _) = mnist.load_data()

    # Convert to float and rescale to [-1, 1], the range of the generator's
    # tanh output.
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5

    # Add a channels dimension: the generator and discriminator work on
    # 28x28x1 images.
    X_train = np.expand_dims(X_train, axis=3)

    half_batch = batch_size // 2

    # The label arrays never change, so build them once. They are shaped
    # (n, 1) to match the discriminator's single-unit output.
    real_labels = np.ones((half_batch, 1))
    fake_labels = np.zeros((half_batch, 1))
    valid_y = np.ones((batch_size, 1))

    for epoch in range(epochs):
        # ---- Discriminator step ----
        idx = np.random.randint(0, X_train.shape[0], half_batch)
        imgs = X_train[idx]

        noise = np.random.normal(0, 1, (half_batch, latent_dim))

        # Generate a half batch of fake images. predict_on_batch runs this
        # single small batch directly, avoiding predict()'s per-call setup
        # and progress output inside the training loop.
        gen_imgs = generator.predict_on_batch(noise)

        # Train on real and fake images separately (the original author
        # notes that separate batches train more effectively).
        d_loss_real = discriminator.train_on_batch(imgs, real_labels)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake_labels)

        # Average the [loss, accuracy] pairs from the real and fake batches.
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---- Generator step ----
        # Train through the combined model against "real" labels.
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = combined.train_on_batch(noise, valid_y)

        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

        # If at save interval => save generated image samples
        if epoch % save_interval == 0:
            save_imgs(epoch)


In [None]:

def save_imgs(epoch):
    """Save a 5x5 grid of images sampled from the global ``generator``.

    The grid is written to
    ``/content/sample_data/images/mnist_<epoch>.png``; the target directory
    is created if it does not exist yet.

    Args:
        epoch: Training iteration number, used in the output file name.
    """
    import os  # local import so this cell does not depend on another cell

    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, 100))
    # One small batch: predict_on_batch avoids predict()'s progress output.
    gen_imgs = generator.predict_on_batch(noise)

    # Rescale from the generator's tanh range [-1, 1] to [0, 1] for display.
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    # axs.flat walks the grid row by row, matching the sample order.
    for ax, sample in zip(axs.flat, gen_imgs):
        ax.imshow(sample[:, :, 0], cmap='gray')
        ax.axis('off')

    out_path = "/content/sample_data/images/mnist_%d.png" % epoch
    # savefig does not create missing directories; without this the first
    # save fails with FileNotFoundError on a fresh runtime.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path)
    # Close this specific figure so figures do not pile up across saves.
    plt.close(fig)


In [None]:

# Optimizer name passed to every compile() call below.
optimizer = 'adam'

# Discriminator: compiled on its own so train() can call
# discriminator.train_on_batch(). The accuracy metric makes train_on_batch
# return [loss, accuracy], which train() reads as d_loss[0] / d_loss[1].
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy',
    optimizer=optimizer,
    metrics=['accuracy'])

generator = build_generator()

# NOTE(review): train() only calls generator.predict() and updates the
# generator through `combined`, so this compile does not appear to be
# needed for training in this notebook.
generator.compile(loss='binary_crossentropy', optimizer=optimizer)

generator.summary() 

# Stacked model: noise z -> generator -> discriminator -> validity score.
z = keras.layers.Input(shape=(100,))   # random latent input to the generator
img = generator(z)   
# Freeze the discriminator before `combined` is compiled, so that training
# `combined` is meant to update only the generator. The discriminator itself
# was already compiled above, before this flag was changed.
discriminator.trainable = False   
valid = discriminator(img)  # discriminator's score for the generated image

combined = keras.models.Model(z, valid)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)


# Run the adversarial training; a sample grid is saved every 1000 iterations.
train(epochs=10000, batch_size=128, save_interval=1000)



In [None]:
# Export the trained generator to ./generator/; the TFLite converter cells
# further down load the model from this directory.
generator.save('./generator/')  #Test the model on GAN4_predict...

In [None]:
# Export the trained discriminator to ./discriminator/.
discriminator.save('./discriminator/')  #Test the model on GAN4_predict...

In [None]:
# Baseline conversion: load the generator saved in ./generator and convert
# it to TFLite without setting any converter optimizations.
converter = tf.lite.TFLiteConverter.from_saved_model("./generator")
tflite_model = converter.convert()

In [None]:
# Same saved generator, converted again with the converter's DEFAULT
# optimization set enabled (the post-training quantized variant).
converter = tf.lite.TFLiteConverter.from_saved_model("./generator")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()

In [None]:
# Length of the baseline converted model (presumably its size in bytes,
# since convert() returns the serialized model), for comparison below.
len(tflite_model)

In [None]:
# Length of the model converted with Optimize.DEFAULT; compare with
# len(tflite_model) to see the size reduction.
len(tflite_quant_model)

In [None]:
! pip install -q tensorflow-model-optimization

In [None]:
import tensorflow_model_optimization as tfmot

# Short alias for the tfmot entry point that prepares a Keras model for
# quantization-aware training.
quantize_model = tfmot.quantization.keras.quantize_model

# "q_aware" stands for quantization aware: wrap the trained generator and
# print the resulting layer summary.
q_aware_model = quantize_model(generator)
q_aware_model.summary()

In [None]:
# Fine-tune the quantization-aware generator adversarially.
q_aware_model.compile(loss='binary_crossentropy', optimizer=optimizer)

# Rebuild the stacked model around the quantization-aware generator:
# noise z -> q_aware_model -> discriminator -> validity score.
z = keras.layers.Input(shape=(100,))   # random latent input to the generator
img = q_aware_model(z)
valid = discriminator(img)
combined = keras.models.Model(z, valid)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

# train() and save_imgs() read the global `generator`. q_aware_model is a
# separate model returned by quantize_model, so leaving `generator` pointing
# at the float model would train the discriminator on fakes from the wrong
# network and save samples that do not come from the model being fine-tuned.
# Point `generator` at q_aware_model for this run, then restore it so the
# earlier cells can be re-run safely.
float_generator = generator
generator = q_aware_model
try:
    train(epochs=10000, batch_size=128, save_interval=1000)
finally:
    generator = float_generator

In [None]:
# Convert the quantization-aware generator directly from the in-memory Keras
# model, with the converter's DEFAULT optimization set enabled.
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_qaware_model = converter.convert()



In [None]:
# Length of the converted quantization-aware model; compare with
# len(tflite_model) and len(tflite_quant_model).
len(tflite_qaware_model)