In [None]:
import os
import cv2
from tensorflow.keras.preprocessing.image import array_to_img
from tensorflow.keras.callbacks import Callback
import warnings
warnings.filterwarnings('ignore')
# RMSprop is going to be the optimizer for both the generator and the discriminator
from tensorflow.keras.optimizers import RMSprop
# Binary cross entropy is going to be the loss for both 
from tensorflow.keras.losses import BinaryCrossentropy
import tensorflow as tf
import numpy as np
from tensorflow import keras
from keras import Sequential, layers
from keras.layers import Conv2D, Dense, Flatten, Reshape, LeakyReLU, Dropout, UpSampling2D
import matplotlib.pyplot as plt 
# Bringing in tensorflow datasets for fashion mnist
import tensorflow_datasets as tfds

from sklearn.model_selection import train_test_split
from keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import image_dataset_from_directory
from tensorflow.keras.models import Model
import pickle

In [None]:
# Create the distribution strategy and report how many replicas it keeps in sync.
strategy = tf.distribute.MirroredStrategy()
replica_count = strategy.num_replicas_in_sync
print(f'DEVICES AVAILABLE: {replica_count}')

In [None]:
# Load the Fashion-MNIST training split; this handle is only used to preview a few
# raw samples below (get_training_data() reloads the split for training).
ds = tfds.load('fashion_mnist', split='train')

In [None]:
# Iterator yielding one sample at a time as a dict of numpy values
# (the preview cell reads its 'image' and 'label' keys).
dataiterator = ds.as_numpy_iterator()

In [None]:
# Preview four consecutive raw samples, one per subplot, titled with their label.
fig, ax = plt.subplots(ncols=4, figsize=(20,20))
for axis in ax:
    sample = next(dataiterator)
    # Drop the size-1 axes so imshow receives a 2-D greyscale array.
    axis.imshow(np.squeeze(sample['image']), cmap="gray")
    axis.title.set_text(sample['label'])

In [None]:
def get_training_data(batch_size=256, shuffle_buffer=6000):
    """Build the shuffled, batched Fashion-MNIST training pipeline.

    Pixels are rescaled to [-1, 1] (assuming 8-bit values in [0, 255], as the
    previous ``/255`` scaling implied) so that real images share the output
    range of the generator's final ``tanh`` activation. Feeding the
    discriminator reals in [0, 1] and fakes in [-1, 1] gives it a trivial cue
    that is unrelated to image content.

    Args:
        batch_size: Number of images per batch (default 256).
        shuffle_buffer: Size of the shuffle buffer (default 6000).

    Returns:
        A ``tf.data.Dataset`` yielding float32 image batches; labels are dropped.
    """
    # Reload the dataset
    train_dataset = tfds.load('fashion_mnist', split='train')
    # Keep only the image and rescale it to [-1, 1]
    train_dataset = train_dataset.map(
        lambda x: (tf.cast(x['image'], tf.float32) - 127.5) / 127.5
    )
    # Cache the preprocessed images so the map only runs on the first pass
    train_dataset = train_dataset.cache()
    # Shuffle after caching so every epoch sees a different order
    train_dataset = train_dataset.shuffle(shuffle_buffer)
    # Group into batches of `batch_size` images
    train_dataset = train_dataset.batch(batch_size)

    return train_dataset

In [None]:
# Build the batched training pipeline that is later passed to fit().
train_data=get_training_data()

In [None]:
# Number of batches in the pipeline (i.e. training steps per epoch).
len(train_data)

In [None]:
# Sanity check: shape of a single batch drawn from the pipeline.
train_data.as_numpy_iterator().next().shape

In [None]:
def build_generator(latent_dim=28):
    """Build the DCGAN-style generator mapping latent vectors to 28x28x1 images.

    The network projects the latent vector to a 7x7x256 feature map and then
    applies three 5x5 transposed convolutions (strides 1, 2, 2). The
    intermediate stages use batch normalisation and LeakyReLU; the final layer
    uses ``tanh``, so outputs lie in [-1, 1].

    Args:
        latent_dim: Length of the input noise vector. Defaults to 28, the value
            previously hard-coded here and used by the rest of the notebook.

    Returns:
        A ``tf.keras.Sequential`` model with output shape (None, 28, 28, 1).
    """
    model = tf.keras.Sequential()

    # Project the latent vector and reshape it into a small feature map.
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(latent_dim,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # Note: None is the batch size

    # Intermediate upsampling stages: (filters, stride, expected output shape).
    stages = [
        (128, 1, (None, 7, 7, 128)),
        (64, 2, (None, 14, 14, 64)),
    ]
    for filters, stride, expected_shape in stages:
        model.add(layers.Conv2DTranspose(filters, (5, 5), strides=(stride, stride),
                                         padding='same', use_bias=False))
        assert model.output_shape == expected_shape
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU())

    # Final upsampling to a single-channel image squashed into [-1, 1].
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

In [None]:
# Stand-alone generator used for the architecture summary and the untrained preview below.
# (The models that are actually trained are rebuilt later inside strategy.scope().)
# with strategy.scope():
generator = build_generator()
generator.summary()

In [None]:
# Render the generator architecture to a PNG; the options are spelled out explicitly.
plot_options = dict(
    to_file='/kaggle/working/generator.png',
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir='TB',
    expand_nested=False,
    dpi=96,
    layer_range=None,
    show_layer_activations=False,
    show_trainable=False,
)
tf.keras.utils.plot_model(generator, **plot_options)

In [None]:
# Push four random latent vectors through the freshly built (untrained) generator.
noise = tf.random.normal((4,28))
imgs = generator.predict(noise)

# One subplot per generated image
fig, ax = plt.subplots(ncols=4, figsize=(20,20))
for idx, (axis, img) in enumerate(zip(ax, imgs)):
    # Drop the channel axis so imshow receives a 2-D array
    axis.imshow(np.squeeze(img))
    # Use the sample index as the subplot title
    axis.title.set_text(idx)

In [None]:
def build_discriminator():
    """Build the convolutional discriminator for 28x28x1 images.

    Two stride-2 5x5 convolutions (64 then 128 filters), each followed by
    LeakyReLU and Dropout(0.3), feed a Flatten and a single-unit Dense layer.
    The final Dense layer has no activation, so the model outputs raw logits
    rather than probabilities; pair it with a loss configured accordingly.

    Returns:
        A ``tf.keras.Sequential`` model producing one score per image.
    """
    conv_options = dict(strides=(2, 2), padding='same')
    stack = [
        layers.Conv2D(64, (5, 5), input_shape=[28, 28, 1], **conv_options),
        layers.LeakyReLU(),
        layers.Dropout(0.3),

        layers.Conv2D(128, (5, 5), **conv_options),
        layers.LeakyReLU(),
        layers.Dropout(0.3),

        layers.Flatten(),
        layers.Dense(1),
    ]
    return tf.keras.Sequential(stack)

In [None]:
# Stand-alone discriminator used for the architecture summary and diagram below.
discriminator = build_discriminator()
discriminator.summary()

In [None]:
# Render the discriminator architecture to a PNG; the options are spelled out explicitly.
plot_options = dict(
    to_file='/kaggle/working/discriminator.png',
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir='TB',
    expand_nested=False,
    dpi=96,
    layer_range=None,
    show_layer_activations=False,
    show_trainable=False,
)
tf.keras.utils.plot_model(discriminator, **plot_options)

In [None]:
class FashionGAN(Model):
    """Keras model wrapping a generator/discriminator pair with a custom GAN train step.

    Label convention used throughout: real images are labelled 0 and generated
    images are labelled 1. The generator is therefore trained towards label 0.
    """

    def __init__(self, generator, discriminator, *args, latent_dim=28, **kwargs):
        """Store the two sub-models and create the running loss trackers.

        Args:
            generator: Model mapping ``(batch, latent_dim)`` noise to images.
            discriminator: Model mapping images to one score per image.
            latent_dim: Length of the noise vectors fed to the generator
                (keyword-only; defaults to 28, matching ``build_generator``).
            *args, **kwargs: Forwarded to ``keras.Model``.
        """
        # Pass through args and kwargs to base class
        super().__init__(*args, **kwargs)

        # Create attributes for gen and disc
        self.generator = generator
        self.discriminator = discriminator
        self.latent_dim = latent_dim
        # Not used by this class; kept so existing code reading them keeps working.
        self.g_losses, self.d_losses = [], []
        # Running means reported to fit() as "g_loss" and "d_loss"
        self.g_loss_tracker = keras.metrics.Mean(name = "g_loss")
        self.d_loss_tracker = keras.metrics.Mean(name = "d_loss")

    def compile(self, g_opt, d_opt, g_loss, d_loss, *args, **kwargs):
        """Attach one optimizer and one loss to each sub-model.

        The losses must match the discriminator's output: if it emits raw
        logits (no sigmoid), use a loss configured with ``from_logits=True``.

        Args:
            g_opt: Optimizer for the generator.
            d_opt: Optimizer for the discriminator.
            g_loss: Loss callable ``(y_true, y_pred)`` for the generator.
            d_loss: Loss callable ``(y_true, y_pred)`` for the discriminator.
            *args, **kwargs: Forwarded to ``keras.Model.compile``.
        """
        # Compile with base class
        super().compile(*args, **kwargs)

        # Create attributes for losses and optimizers
        self.g_opt = g_opt
        self.d_opt = d_opt
        self.g_loss = g_loss
        self.d_loss = d_loss

    def train_step(self, batch):
        """Run one discriminator update followed by one generator update.

        Args:
            batch: A batch of real images (if a tuple/list is supplied, its
                first element is taken as the images).

        Returns:
            Dict with the running means of ``d_loss`` and ``g_loss``.
        """
        # Get the data
        if isinstance(batch, (tuple, list)):
            batch = batch[0]
        real_images = batch

        # Match the number of fakes to the number of reals actually received:
        # the last batch of an epoch (and per-replica batches) can be smaller
        # than the nominal batch size.
        batch_size = tf.shape(real_images)[0]
        fake_images = self.generator(
            tf.random.normal((batch_size, self.latent_dim)), training=False
        )

        # Train the discriminator
        with tf.GradientTape() as d_tape:
            # Pass the real and fake images to the discriminator model
            yhat_real = self.discriminator(real_images, training=True)
            yhat_fake = self.discriminator(fake_images, training=True)
            yhat_realfake = tf.concat([yhat_real, yhat_fake], axis=0)

            # Create labels for real (0) and fake (1) images
            y_realfake = tf.concat([tf.zeros_like(yhat_real), tf.ones_like(yhat_fake)], axis=0)

            # Soften the labels: reals land in [0, 0.15], fakes in [0.85, 1]
            noise_real = 0.15*tf.random.uniform(tf.shape(yhat_real))
            noise_fake = -0.15*tf.random.uniform(tf.shape(yhat_fake))
            y_realfake += tf.concat([noise_real, noise_fake], axis=0)

            # Calculate loss - BINARYCROSS
            total_d_loss = self.d_loss(y_realfake, yhat_realfake)

        # Apply backpropagation - nn learn
        dgrad = d_tape.gradient(total_d_loss, self.discriminator.trainable_variables)
        self.d_opt.apply_gradients(zip(dgrad, self.discriminator.trainable_variables))

        # Train the generator
        with tf.GradientTape() as g_tape:
            # Generate some new images
            gen_images = self.generator(
                tf.random.normal((batch_size, self.latent_dim)), training=True
            )

            # Create the predicted labels
            predicted_labels = self.discriminator(gen_images, training=False)

            # Calculate loss - the generator wants its fakes scored as real (label 0)
            total_g_loss = self.g_loss(tf.zeros_like(predicted_labels), predicted_labels)

        # Apply backprop
        ggrad = g_tape.gradient(total_g_loss, self.generator.trainable_variables)
        self.g_opt.apply_gradients(zip(ggrad, self.generator.trainable_variables))

        self.g_loss_tracker.update_state(total_g_loss)
        self.d_loss_tracker.update_state(total_d_loss)
        return {
            "d_loss":self.d_loss_tracker.result(),
            "g_loss":self.g_loss_tracker.result()
        }

    @property
    def metrics(self):
        """Trackers exposed to Keras so their state is managed by the framework."""
        return [ self.d_loss_tracker, self.g_loss_tracker]

In [None]:
class ModelMonitor(Callback):
    """Callback that checkpoints the GAN and periodically dumps sample images.

    Every 5th epoch (starting at epoch 0) it saves both sub-models' weights,
    both optimizer configs (pickled) and both full models. Every 50th epoch it
    also writes ``num_img`` generated images as PNG files.
    """

    def __init__(self, num_img=3, latent_dim=28, output_dir='/kaggle/working/'):
        """
        Args:
            num_img: Number of preview images written on sampling epochs.
            latent_dim: Length of the noise vectors fed to the generator.
            output_dir: Directory receiving checkpoints and images. Defaults to
                the location that was previously hard-coded.
        """
        # Let the base class set up its own internal state.
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim
        self.output_dir = output_dir

    def _path(self, filename):
        """Return ``filename`` placed inside the configured output directory."""
        return os.path.join(self.output_dir, filename)

    def on_epoch_end(self, epoch, logs=None):
        """Checkpoint every 5 epochs and write preview images every 50 epochs."""
        if epoch % 5 == 0:
            # saving the model weights
            self.model.generator.save_weights(self._path("F_GANgen.weights.h5"))
            self.model.discriminator.save_weights(self._path("F_GANdis.weights.h5"))

            # saving the generator optimizer config
            gen_config = self.model.g_opt.get_config()
            with open(self._path('g_optimizer.pkl'), 'wb') as f:
                pickle.dump(gen_config, f)

            # saving the discriminator optimizer config
            des_config = self.model.d_opt.get_config()
            with open(self._path('d_optimizer.pkl'), 'wb') as f:
                pickle.dump(des_config, f)

            # saving the full generator and discriminator models
            self.model.generator.save(self._path('gen.keras'))
            self.model.discriminator.save(self._path('des.keras'))

        if epoch % 50 == 0:
            # Sample from the same normal distribution the generator is trained
            # on, with the (num_img, latent_dim) shape it was declared with.
            random_latent_vectors = tf.random.normal((self.num_img, self.latent_dim))
            generated_images = self.model.generator(
                random_latent_vectors, training=False
            ).numpy()

            for i in range(self.num_img):
                # array_to_img rescales each image to 0-255 by default, so no
                # manual multiplication by 255 is needed.
                img = array_to_img(generated_images[i])
                img.save(self._path(f'generated_img_{epoch}_{i}.png'))

In [None]:
# Build and compile the GAN inside the strategy scope so that model and
# optimizer variables are created under the distribution strategy.
with strategy.scope():
    # Define learning rates
    g_learning_rate = 1e-5
    d_learning_rate = 1e-6

    # Create RMSprop optimizers for generator and discriminator
    g_opt = RMSprop(learning_rate=g_learning_rate)
    d_opt = RMSprop(learning_rate=d_learning_rate)

    # build_discriminator() ends in Dense(1) with no activation, i.e. it emits
    # raw logits. BinaryCrossentropy defaults to from_logits=False, which would
    # treat those logits as probabilities, so the sigmoid is applied in the loss.
    g_loss = BinaryCrossentropy(from_logits=True)
    d_loss = BinaryCrossentropy(from_logits=True)

    generator = build_generator()
    discriminator = build_discriminator()
    f_GAN = FashionGAN(generator, discriminator)

    f_GAN.compile(g_opt, d_opt, g_loss, d_loss)


In [None]:
# Train the GAN; ModelMonitor checkpoints every 5 epochs and writes sample images every 50.
# hist.history holds the per-epoch "d_loss" / "g_loss" values returned by train_step.
hist = f_GAN.fit(train_data, epochs=2000, callbacks=[ModelMonitor()])

In [None]:
def PlotAccuracy(net, title="RMSProp Discriminator and Generator Loss"):
    """Plot the discriminator and generator loss curves on a single axes.

    Despite its name this plots losses, not accuracy; the name is kept so
    existing calls keep working. The earlier version drew the same two curves
    a second time under an "Adam" title even though only one training run
    (with RMSprop) exists, so that duplicate panel has been removed.

    Args:
        net: Mapping with 'd_loss' and 'g_loss' sequences (e.g. ``hist.history``).
        title: Title of the plot.
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(net['d_loss'], label='Discriminator Loss')
    ax.plot(net['g_loss'], label='Generator Loss')
    ax.set_title(title)
    ax.set_ylabel('Loss')
    ax.set_xlabel('epochs')
    ax.legend()


PlotAccuracy(hist.history)

In [None]:
# Persist the per-epoch loss history so the curves can be re-plotted without retraining.
# NOTE(review): the file is named "Adamaxhistory" but the optimizers compiled above are
# RMSprop - consider renaming (and updating any reader of this file) to avoid confusion.
with open("/kaggle/working/Adamaxhistory.pkl","wb") as f:
    pickle.dump(hist.history,f)

In [None]:
# to be run after the model is trained
imgs = generator(tf.random.normal((16, 28)), training=False)

fig, ax = plt.subplots(ncols=4, nrows=4, figsize=(10,10))
for r in range(4):
    for c in range(4):
        # Row-major index so each of the 16 samples is shown exactly once
        # (the previous (r+1)*(c+1)-1 repeated some samples and skipped others).
        # The channel axis is squeezed, as in the untrained preview above.
        ax[r][c].imshow(np.squeeze(imgs[r*4 + c]))


In [None]:
# Finally, save the trained generator and discriminator as full Keras models.

generator.save("/kaggle/working/final_generator_model.keras")
discriminator.save("/kaggle/working/final_descriminator_model.keras")

### To resume training from the saved weights, uncomment and run the following cells


In [None]:
# with strategy.scope():
#     gen, des = build_generator(), build_discriminator()
#     gen.load_weights("/kaggle/input/model-attributes/F_GANgen.weights.h5")
#     des.load_weights("/kaggle/input/model-attributes/F_GANdis.weights.h5")
#     newModel = FashionGAN(gen,des)
# #     print(newModel.d_losses,newModel.g_losses)
#     # loading the saved optimizer configs
#     with open("/kaggle/input/model-attributes/g_optimizer (1).pkl", "rb") as fp:
#         g = pickle.load(fp)
#     with open("/kaggle/input/model-attributes/d_optimizer (1).pkl", "rb") as fp:
#         d = pickle.load(fp)
        
# #     newModel.load_weights('/kaggle/input/old-modeld/F_GAN.weights.h5')
    
#     # rebuilding the optimizers from the saved configs (NOTE: training above used RMSprop, and Adamax is not imported)
#     new_gopt = Adamax().from_config(g)
#     new_dopt = Adamax().from_config(d)
#     g_loss = BinaryCrossentropy()
#     d_loss = BinaryCrossentropy()
#     newModel.compile(new_gopt, new_dopt,g_loss,d_loss) 
    

In [None]:
# hist = newModel.fit(train_data, epochs = 6, callbacks=[ModelMonitor()])


In [None]:
# plt.plot(hist.history["loss"])
# hist.history

In [None]:
# imgs = gen(tf.random.normal((16, 28)))

# fig, ax = plt.subplots(ncols=4, nrows=4, figsize=(10,10))
# for r in range(4): 
#     for c in range(4): 
#         ax[r][c].imshow(imgs[(r+1)*(c+1)-1])