<a href="https://colab.research.google.com/github/Riddick4-droid/ComputerVision-TensorFlow/blob/main/Wasserstein_GAN_GP_tensorflow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# WASSERSTEIN IMPLEMENTATION - GENERATIVE ADVERSARIAL NETWORKS WITH THE GRADIENT PENALTY

Generative Adversarial Networks (GANs) perform well on generative tasks such as image and video generation. In my previous notebook, `Generative Adversarial Network tensorflow`, I explored simple ways to implement GANs with TensorFlow, using convolutional networks, batch normalization, and so on. In this notebook I introduce a more robust variant that addresses the vanishing-gradient problem: WGAN-GP, the Wasserstein Generative Adversarial Network with Gradient Penalty. It is a variation of the standard Deep Convolutional GAN (DCGAN), and its key difference is how the discriminator (called the critic in WGAN) is trained and how its loss is computed. The problem it targets is unstable DCGAN training: vanishing and exploding gradients, and a discriminator whose output can swing sharply for small changes in the input images, so that it stops giving the generator reliable feedback. As a result, the generator struggles to produce images that resemble the training data.
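To solve this, the WGAN-GP paper introduces the gradient penalty. The Wasserstein formulation requires the critic to be 1-Lipschitz, meaning its output cannot change faster than its input: $|D(x_1) - D(x_2)| \le \lVert x_1 - x_2 \rVert$. Instead of enforcing this as a hard rule, the gradient penalty adds a term to the critic's loss that penalizes the norm of the critic's gradient with respect to its input whenever that norm moves away from 1. The penalty is evaluated at random points between real and generated samples. Because the critic's gradients stay under control, the model learns more smoothly.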
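To implement this, we remove batch normalization from the discriminator. The penalty is computed separately for each input sample, whereas batch normalization makes each sample's output depend on the rest of the batch, which would distort that per-sample gradient. The discriminator also no longer ends in a sigmoid: it outputs an unbounded score (a logit) instead of a probability. The difference between its average scores for real and generated images gives an estimate of the Earth Mover's (Wasserstein-1) distance, which is the distance between the real and generated data distributions.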
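For reference, the objectives can be written as below, following the WGAN-GP paper (Gulrajani et al., 2017); the notation is mine. $D$ is the critic, $G$ the generator, $\mathbb{P}_r$ the real data distribution, $\mathbb{P}_g$ the distribution of generated samples $\tilde{x} = G(z)$, and $\lambda$ the penalty weight (10 in the paper, `gp_weight` in the code). The critic minimizes $L_D$ and the generator minimizes $L_G$:

$$
L_D = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[D(\tilde{x})\big] - \mathbb{E}_{x \sim \mathbb{P}_r}\big[D(x)\big] + \lambda\, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big]
$$

$$
\hat{x} = \epsilon\, x + (1 - \epsilon)\, \tilde{x}, \qquad \epsilon \sim U[0, 1]
$$

$$
L_G = -\mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[D(\tilde{x})\big]
$$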

Everything else is the same as in the earlier GAN notebook, apart from a change in the generator architecture: it now uses `layers.Conv2DTranspose` instead of upsampling layers (`layers.UpSampling2D`). Transposed convolutions are a common choice in DCGAN-style generators, and the two layers differ in an important way. `UpSampling2D` only enlarges the spatial dimensions of the feature map by repeating or interpolating pixels and has no trainable weights, whereas `Conv2DTranspose` (also called a transposed convolution, or loosely a deconvolution) increases the spatial dimensions while simultaneously learning how to fill in the new pixels.
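To make the difference concrete, here is a minimal single-channel NumPy sketch (the function names are mine, not Keras APIs). `upsample_nearest` mimics nearest-neighbour upsampling; `conv_transpose_2d` is an unpadded transposed convolution whose output size is `(h - 1) * stride + k`. Keras's `padding='same'` instead crops the result to `h * stride`. With a fixed all-ones kernel the transposed convolution reproduces nearest upsampling, so upsampling can be seen as one particular kernel that a `Conv2DTranspose` layer is free to learn its way away from.

```python
import numpy as np

def upsample_nearest(x: np.ndarray, factor: int = 2) -> np.ndarray:
    """Nearest-neighbour upsampling: copy every pixel into a factor x factor block (no weights)."""
    return np.repeat(np.repeat(x, factor, axis=0), factor, axis=1)

def conv_transpose_2d(x: np.ndarray, kernel: np.ndarray, stride: int = 2) -> np.ndarray:
    """Single-channel transposed convolution without padding.

    Each input pixel adds a copy of the kernel, scaled by that pixel's value, into the output.
    In a Conv2DTranspose layer the kernel values are learned during training.
    """
    h, w = x.shape
    k = kernel.shape[0]
    out = np.zeros(((h - 1) * stride + k, (w - 1) * stride + k))
    for i in range(h):
        for j in range(w):
            out[i * stride:i * stride + k, j * stride:j * stride + k] += x[i, j] * kernel
    return out

x = np.array([[1.0, 2.0],
              [3.0, 4.0]])

# a fixed all-ones 2x2 kernel with stride 2 reproduces nearest-neighbour upsampling
same_as_upsampling = conv_transpose_2d(x, np.ones((2, 2)))
# a different kernel fills the new pixels differently: only the top-left of each 2x2 block here
corner_kernel = conv_transpose_2d(x, np.array([[1.0, 0.0],
                                              [0.0, 0.0]]))
```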

In [None]:
#creating a function to extract data into a folder
import pathlib
import os
import zipfile
from timeit import default_timer as timer
from google.colab import files
from tqdm.auto import tqdm
import shutil
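def extractor(path_to_save:str,
              name_of_zipfile:str=None,
              auto_upload:bool=False,
              get_time:bool=False):
    """Extract a zip archive into /content/<path_to_save>.

    If auto_upload is True, prompt for files with the Colab upload widget and extract every uploaded .zip;
    otherwise extract /content/<name_of_zipfile>.zip. Returns the elapsed time as a string when get_time is True.
    """
    #start the timer
    start=timer()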
    try:
        #create new folder
        root = '/content/'

        folder_path = os.path.join(root,path_to_save)

        #make it a directory
        if not os.path.exists(folder_path):
            print(f'creating new dir: {folder_path}')
            os.makedirs(folder_path,exist_ok=True)
        else:
            print(f'path {folder_path} already exists, proceeding to extraction...')

        #unzip
        if auto_upload:
            uploaded = files.upload()
            for f in tqdm(uploaded.keys(),total=len(uploaded)):
                #print the name of each uploaded file
                print('='*10)
                print(f'successfully uploaded: {f} into content')
                #handle zip files
                if f.endswith('.zip'):
                    #set zip path
                    zip_path = os.path.join('/content',f)
                    with zipfile.ZipFile(zip_path, 'r') as zipref:
                        #extract to folder_path
                        print(f'extracting content of {f} to {folder_path}')
                        zipref.extractall(folder_path)
                        print('successful!!')
                        print('extracted files: \n',zipref.namelist())
                else:
                    print(f'File {f} does not end with .zip, check again!!')
                    #remove the uploaded file (os.remove, because shutil.rmtree only deletes directories)
                    os.remove(os.path.join('/content', f))
        elif name_of_zipfile:
            with zipfile.ZipFile(f'/content/{name_of_zipfile}.zip', 'r') as zipref:
                #extract to folder_path
                print(f'extracting content of {name_of_zipfile} to {folder_path}')
                zipref.extractall(folder_path)
                print('successful!!')
                print('extracted files: \n',zipref.namelist())

    except Exception as e:
        print(f'Error encountered: {e}')
    end = timer()

    #print time
    if get_time:
        return f'Time taken: {(end-start):.5f}'
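Outside Colab, `google.colab.files` is not available. A minimal sketch of the same extraction step using only the standard library, assuming the zip file is already on disk; `extract_zip` and the file names are hypothetical:

```python
import os
import tempfile
import zipfile

def extract_zip(zip_path: str, target_dir: str) -> list:
    """Extract zip_path into target_dir (created if missing) and return the archive's member names."""
    os.makedirs(target_dir, exist_ok=True)
    if not zipfile.is_zipfile(zip_path):
        raise ValueError(f"{zip_path} is not a valid zip archive")
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(target_dir)
        return zf.namelist()

# usage: build a tiny archive in a temporary directory and extract it
with tempfile.TemporaryDirectory() as tmp:
    archive = os.path.join(tmp, "flowers.zip")
    with zipfile.ZipFile(archive, "w") as zf:
        zf.writestr("flowers/rose.txt", "placeholder")
    names = extract_zip(archive, os.path.join(tmp, "image_data"))
    extracted_ok = os.path.isfile(os.path.join(tmp, "image_data", "flowers", "rose.txt"))
```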

In [None]:
#test the function
extractor(path_to_save='image_data',name_of_zipfile=None,auto_upload=True,get_time=True)

In [None]:
#make necessary imports
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import LearningRateScheduler
from keras.callbacks import ReduceLROnPlateau
from keras.callbacks import TensorBoard
from keras import layers

In [None]:
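##first step is to configure the functions that calculate the losses
##both functions take raw critic scores (no sigmoid), not images

#generator loss based on the critic's scores for generated images:
#the generator wants these scores to be high, so it minimizes their negative mean
def generator_loss(fake_predictions_from_disc):
    return -tf.reduce_mean(fake_predictions_from_disc)

#critic (discriminator) loss based on its scores for fake and real images
def discriminator_loss(fake_predictions,real_predictions):
    real_loss = tf.reduce_mean(real_predictions)
    fake_loss = tf.reduce_mean(fake_predictions)
    #the critic minimizes mean(fake) - mean(real), i.e. it widens the gap between real and fake scores;
    #the negative of this value is its estimate of the Earth Mover's distance
    loss_diff = fake_loss - real_loss
    return loss_diff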
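To make the signs concrete, here is a NumPy mirror of the two loss functions evaluated on made-up critic scores (not outputs of the trained model):

```python
import numpy as np

def critic_loss(fake_scores: np.ndarray, real_scores: np.ndarray) -> float:
    # mean score on generated images minus mean score on real images (the critic minimizes this)
    return float(np.mean(fake_scores) - np.mean(real_scores))

def gen_loss(fake_scores: np.ndarray) -> float:
    # negative mean score on generated images (the generator minimizes this)
    return float(-np.mean(fake_scores))

real_scores = np.array([2.0, 3.0, 4.0])    # hypothetical raw critic outputs, no sigmoid
fake_scores = np.array([-2.0, -1.0, 0.0])

d_loss = critic_loss(fake_scores, real_scores)   # -1.0 - 3.0
g_loss = gen_loss(fake_scores)                    # -(-1.0)
```

The negative critic loss, 4.0 here, is roughly the critic's estimate of the Earth Mover's distance between the two batches before the gradient penalty is added: the critic tries to make it large, while the generator tries to shrink it.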

In [None]:
#build the generator with layers.Conv2DTranspose()
#notice that the DCGAN in my previous notebook used upsampling layers (layers.UpSampling2D())
#here the noise vector is grown into an image with Conv2DTranspose layers, which learn how to fill in
#the new pixels instead of only copying or interpolating neighbouring values
#(interpolation: the technique used to estimate unknown values within a given range)
import typing
def generator(random_noise_dim:int,
              activation:typing.Literal['sigmoid','tanh'],
              lora_rank=None):
    """The generator takes a noise vector of dimension N (chosen to suit the available compute)
    and uses stacked layers.Conv2DTranspose() blocks to grow it into a full 64x64x3 image
    that resembles the training data.

    Args:
        random_noise_dim: the dimension N of the input noise vector.
        activation: output activation; 'tanh' gives pixels in [-1, 1], 'sigmoid' gives pixels in [0, 1].
        lora_rank: optional LoRA rank passed to the first Dense layer (None disables LoRA).

    returns: a tf.keras.Model mapping (batch, N) noise to (batch, 64, 64, 3) images
    """
    #initialize the input layer
    inputs = layers.Input(shape=(random_noise_dim,))

    #first dense layer
    x = layers.Dense(units=4*4*512,use_bias=False,lora_rank=lora_rank)(inputs)

    #after the layer above it is best practice for GAN generators to have the resulting x reshaped

    x = layers.Reshape((4,4,512))(x)

    #dilation_rate=(1,1) means no dilation; values above 1 would widen the kernel's receptive field
    #each stride-2 Conv2DTranspose with padding='same' doubles height and width: 4x4 -> 8x8
    x = layers.Conv2DTranspose(filters=512,kernel_size=(5,5),strides=(2,2),padding='same',dilation_rate=(1,1),use_bias=False)(x)
    #apply batch normalization (kept in the generator; only the critic drops it)
    x = layers.BatchNormalization(center=True)(x)
    #apply an activation: negative values are multiplied by the negative slope factor 0.3
    x = layers.LeakyReLU(negative_slope=0.3)(x)

    #repeat these blocks: 8x8 -> 16x16
    x = layers.Conv2DTranspose(filters=256,kernel_size=(5,5),strides=(2,2),padding='same',dilation_rate=(1,1),use_bias=False)(x)
    x = layers.BatchNormalization(center=True)(x)
    x = layers.LeakyReLU(negative_slope=0.3)(x)

    #16x16 -> 32x32
    x = layers.Conv2DTranspose(filters=128,kernel_size=(5,5),strides=(2,2),padding='same',dilation_rate=(1,1),use_bias=False)(x)
    x = layers.BatchNormalization(center=True)(x)
    x = layers.LeakyReLU(negative_slope=0.3)(x)

    #32x32 -> 64x64
    x = layers.Conv2DTranspose(filters=64,kernel_size=(5,5),strides=(2,2),padding='same',dilation_rate=(1,1),use_bias=False)(x)
    x = layers.BatchNormalization(center=True)(x)
    x = layers.LeakyReLU(negative_slope=0.3)(x)

    #more blocks could be added for a deeper architecture, but that does not guarantee better performance
    #layers.Dropout() adds light regularization to the generator
    x = layers.Dropout(rate=0.02)(x)

    x = layers.Conv2D(filters=3,strides=(1,1),kernel_size=(5,5),padding='same',dilation_rate=(1,1),activation=activation)(x)

    assert x.shape == (None,64,64,3) #since we are dealing with RGB images we set the color channels to 3 so shape (none=batch_size,height,width,color_channel=3)

    #create model
    model = tf.keras.Model(inputs=inputs,outputs=x)

    return model
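The shape assert holds because each stride-2 `Conv2DTranspose` with `padding='same'` doubles the height and width. The arithmetic as a small Python sketch, assuming `random_noise_dim=100` as in the summary cell (`transposed_same_size` is a hypothetical helper):

```python
def transposed_same_size(size: int, stride: int) -> int:
    # Keras Conv2DTranspose with padding='same' (and no output_padding): output = input * stride
    return size * stride

sizes = [4]                      # after Reshape((4, 4, 512))
for _ in range(4):               # four Conv2DTranspose layers with strides=(2, 2)
    sizes.append(transposed_same_size(sizes[-1], 2))

# weights of the first Dense layer: 100 noise inputs -> 4*4*512 units, use_bias=False
dense_params = 100 * 4 * 4 * 512
```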

In [None]:
##view model summary
generator_model = generator(random_noise_dim=100,activation='tanh')

##summary
generator_model.summary(show_trainable=True,expand_nested=True)

In [None]:
##build the discriminator (the critic)
##this function builds it without batch normalization layers:
##batch normalization stabilizes training by normalizing activations over the whole batch,
##so each sample's output depends on the other samples in the batch
##in WGAN-GP the gradient penalty is computed per input sample, so each output should depend only on its own input
##in the previous notebook the discriminator output probabilities (sigmoid)
##here we want the Wasserstein loss/distance between the critic's score distributions for real and fake images,
##so instead of probabilities the critic outputs raw, unbounded scores (logits)

def discriminator(batch_images=(64,64,3)):
    """Takes a batch of images (real or generated) and returns one unbounded score (logit) per image.
    Higher scores mean the critic finds the image more like the real data; the scores are not probabilities.
    """
    inputs = layers.Input(shape=batch_images)

    x = layers.Conv2D(64,(5,5),(2,2),'same',use_bias=False)(inputs)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Conv2D(128,(5,5),(2,2),'same',use_bias=False)(x)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Conv2D(256,(5,5),(2,2),'same',use_bias=False)(x)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Conv2D(512,(5,5),(2,2),'same',use_bias=False)(x)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Flatten()(x)
    x = layers.Dropout(0.02)(x)

    x = layers.Dense(1,activation='linear')(x)

    model = tf.keras.Model(inputs=inputs,outputs=x)

    return model
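Why batch normalization clashes with a per-sample gradient penalty can be seen in a tiny NumPy illustration. It uses a simplified training-mode batch norm without learnable scale or shift (not the Keras layer): changing one sample in the batch changes the normalized output of another sample that did not change.

```python
import numpy as np

def batch_norm_train(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Simplified training-mode batch norm: normalize each feature with the batch mean and variance."""
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

batch_a = np.array([[1.0], [2.0], [3.0]])
batch_b = batch_a.copy()
batch_b[2, 0] = 30.0                       # only the third sample changes

first_in_a = batch_norm_train(batch_a)[0, 0]
first_in_b = batch_norm_train(batch_b)[0, 0]   # differs from first_in_a although sample 0 is identical
```

If some normalization is wanted in the critic, the WGAN-GP paper suggests layer normalization, which normalizes within each sample; this notebook simply leaves normalization out.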

In [None]:
build_discriminator = discriminator()

#summary
build_discriminator.summary(show_trainable=True,expand_nested=True)

In [None]:
##next I set up the gradient penalty.
##it implements the 1-Lipschitz constraint as a soft penalty rather than a hard rule.
##it takes real samples and fake samples, mixes them into random interpolated samples,
##and measures how strongly the critic's output changes in response to small changes in those inputs
##(the norm of the critic's gradient with respect to its input).
##the penalty grows as that gradient norm moves away from 1, which regularizes the critic and keeps training smooth and stable.
##think of the discriminator as the marking scheme: if it overreacts to small differences in the input,
##its marks stop being informative and the generator gets misleading feedback about how well it is doing.
##constraining the critic's gradients keeps its judgement of generated and real images grounded and fair.
##the gradient penalty is implemented inside the GAN model class, which is the main addition to the model code.

#I subclass tf.keras.models.Model to get access to Keras's attributes and methods, such as
#compile(), train_step() and fit()
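The gradient penalty needs the gradient of the critic's score with respect to its input, which the model class obtains with `tf.GradientTape`. To show just the arithmetic, here is a NumPy sketch with a tiny hypothetical tanh critic on 4-feature vectors, whose input gradient is written out by hand with the chain rule; the weights and data are random, not from the notebook. For images, `epsilon` would have shape `(batch, 1, 1, 1)` instead of `(batch, 1)`.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(8, 4))   # hypothetical critic weights: 4 input features -> 8 hidden units
v = rng.normal(size=8)        # hypothetical output weights: 8 hidden units -> 1 score

def critic(x: np.ndarray) -> np.ndarray:
    """Tiny critic: one unbounded score per sample; x has shape (batch, 4)."""
    return np.tanh(x @ W.T) @ v

def critic_input_grad(x: np.ndarray) -> np.ndarray:
    """Gradient of each sample's score with respect to that sample, shape (batch, 4)."""
    h = np.tanh(x @ W.T)
    return (v * (1.0 - h ** 2)) @ W

def gradient_penalty(real: np.ndarray, fake: np.ndarray):
    eps = rng.uniform(0.0, 1.0, size=(real.shape[0], 1))   # one mixing weight per sample
    x_hat = eps * real + (1.0 - eps) * fake                # points between real and fake samples
    norms = np.linalg.norm(critic_input_grad(x_hat), axis=1)
    return np.mean((norms - 1.0) ** 2), x_hat              # penalize norms away from 1 in either direction

real = rng.normal(loc=2.0, size=(16, 4))
fake = rng.normal(loc=-2.0, size=(16, 4))
gp, x_hat = gradient_penalty(real, fake)
```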


In [None]:
#build WGAN-GP
class WGAN_GP(tf.keras.models.Model):
    def __init__(self,
                 discriminator:tf.keras.models.Model,
                 generator:tf.keras.models.Model,
                 random_noise_dim:int,
                 gp_weight: typing.Union[float, int]=10.0,
                 extra_train_steps:int=5,
                 **kwargs
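                 )->None:
        super(WGAN_GP,self).__init__(**kwargs)
        self.discriminator=discriminator
        self.generator=generator
        self.noise_dim=random_noise_dim
        #number of critic updates per generator update (n_critic=5 in the WGAN-GP paper)
        self.extra_train_steps = extra_train_steps
        #lambda, the weight of the gradient penalty in the critic's loss (10 in the paper)
        self.gp_weight = gp_weight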

    #override compile() with custom arguments; **kwargs are forwarded to the
    #original Model.compile() method
    def compile(self,
                disc_opt:tf.keras.optimizers.Optimizer,
                gen_opt:tf.keras.optimizers.Optimizer,
                disc_loss:typing.Callable,
                gen_loss:typing.Callable,
                **kwargs):
        """Overrides compile() so it accepts an optimizer and a loss function for each network.
        Eg. wgan_gp = WGAN_GP(discriminator, generator, random_noise_dim)
            wgan_gp.compile(disc_opt, gen_opt, disc_loss, gen_loss, run_eagerly=False)
        (compile() returns None, so it cannot be chained onto the constructor)
        """
        super(WGAN_GP,self).compile(**kwargs)
        self.disc_opt = disc_opt
        self.gen_opt = gen_opt
        self.disc_loss=disc_loss
        self.gen_loss = gen_loss

    #configure the gradient penalty method
    #it is computed on every critic update, not as a separate pretraining step:
    #real and fake samples are mixed into random interpolated samples,
    #the critic's gradient with respect to those samples is measured,
    #and the resulting penalty is added to the critic's loss
    def gradient_penalty(self,
                        real_samples:tf.Tensor,
                        fake_samples:tf.Tensor,
                        training:bool=True):
        #bring both inputs to one dtype (tf.cast returns a new tensor, so the result must be assigned)
        fake_samples = tf.convert_to_tensor(fake_samples)
        real_samples = tf.cast(real_samples, fake_samples.dtype)
        #get the batch dimension of the incoming real images
        batch_size = tf.shape(real_samples)[0]

        #epsilon: one random mixing weight per sample, drawn uniformly from [0, 1)
        #shape=(batch_size,1,1,1) so it broadcasts over height, width and channels
        epsilon = tf.random.uniform(shape=[batch_size,1,1,1],minval=0,maxval=1,dtype=real_samples.dtype)

        #interpolated samples lie on straight lines between real and fake samples
        interpolated_samples = epsilon * real_samples + (1-epsilon) * fake_samples

        with tf.GradientTape() as tape:
            #interpolated_samples is a plain tensor, not a variable, so it must be watched explicitly
            tape.watch(interpolated_samples)
            #get the critic's scores (logits) for the interpolated images
            logits = self.discriminator(interpolated_samples,training=training)
        #gradient of the scores with respect to the input pixels
        grad = tape.gradient(logits,interpolated_samples)
        #per-sample L2 norm of that gradient: the local rate of change the 1-Lipschitz constraint limits
        grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grad),axis=[1,2,3]))
        #penalize deviations of the norm from 1 in either direction
        grad_penalty = tf.reduce_mean((grad_norm-1.0)**2)
        return grad_penalty

    ##add Gaussian noise to each sample before it is fed to the critic ("instance noise")
    ##note: this is an extra stabilization trick used in this notebook, not part of the WGAN-GP paper
    def add_instance_noise(self, x: tf.Tensor, stddev: float=0.1) -> tf.Tensor:
        """Adds zero-mean Gaussian noise with standard deviation `stddev` to the input tensor."""
        noise = tf.random.normal(
        tf.shape(x),       # same shape as the input to avoid shape mismatch errors
        mean=0.0,          # centered around zero
        stddev=stddev,     # noise intensity: larger values make the critic's task harder, which can aid robustness but hides detail
        dtype=x.dtype,     # match the input dtype (e.g. float16 under mixed precision)
        )
        return x + noise


    #setup the train_step method, overrides the built-in one
    def train_step(self,real_images,training:bool=True)->typing.Dict[str,tf.Tensor]:
        batch_size = tf.shape(real_images)[0]
        #collect the gradient penalties for logging
        gps = []

        #in WGAN-GP the critic is updated more often than the generator so it stays close to optimal
        for _ in range(self.extra_train_steps):
            #draw fresh noise for every critic update
            noise = tf.random.normal([batch_size,self.noise_dim], dtype=tf.float16)
            with tf.GradientTape() as disc_tape:
                #generate images with the generator
                generated_images = self.generator(noise,training=training)
                #add instance noise to the generated and the real images
                noise_to_gen_images = self.add_instance_noise(generated_images)
                noisy_real_images = self.add_instance_noise(real_images)
                #critic scores for generated and real images
                pred_for_generic = self.discriminator(noise_to_gen_images,training=training)
                pred_for_real = self.discriminator(noisy_real_images,training=training)

                #calculate the gradient penalty on the same noisy samples the critic scored
                gp = self.gradient_penalty(real_samples=noisy_real_images,
                                           fake_samples=noise_to_gen_images,
                                           training=training)
                gps.append(gp)

                #the penalty is added to the critic's loss, weighted by gp_weight (10 by default)
                disc_loss = self.disc_loss(pred_for_generic,pred_for_real) + gp * self.gp_weight

            #compute the critic's gradients from its loss
            disc_grad = disc_tape.gradient(disc_loss,self.discriminator.trainable_variables)

            #update the critic's weights with its optimizer
            self.disc_opt.apply_gradients(zip(disc_grad,self.discriminator.trainable_variables))

        #time to train the generator, again with fresh noise
        noise = tf.random.normal([batch_size,self.noise_dim], dtype=tf.float16)
        with tf.GradientTape() as gen_tape:
            #generate images with the generator
            generated_images = self.generator(noise,training=training)

            #add instance noise before passing them to the critic
            noise_to_gen_images = self.add_instance_noise(generated_images)

            #critic scores for the generated images
            fake_pred = self.discriminator(noise_to_gen_images,training=training)

            #calculate the generator loss
            #the critic's scores drive the generator, which tries to minimize this loss
            gen_loss = self.gen_loss(fake_pred)

        #calculate gradients (only the generator's weights are updated here)
        gen_grad = gen_tape.gradient(gen_loss, self.generator.trainable_variables)

        #update the generator's weights with its optimizer
        self.gen_opt.apply_gradients(zip(gen_grad,self.generator.trainable_variables))

        ##update any metrics registered through compile()
        for metric in self.metrics:
            metric.update_state(real_images,generated_images)

        #report the metrics, the last critic loss, the generator loss and the mean gradient penalty
        results = {m.name:m.result() for m in self.metrics}
        results.update({'d_loss':disc_loss,'g_loss':gen_loss,'gp':tf.reduce_mean(gps)})
        return results

In [None]:
wgan = WGAN_GP(discriminator=build_discriminator,generator=generator_model,random_noise_dim=100)

In [None]:
#print model summary
wgan.summary(show_trainable=True, expand_nested=True)

In [None]:
##this code was retrieved from pylessons.org
#it helps visualize the model's output and save it as PNG grids and a GIF
import cv2
import numpy as np
import os
import imageio
class ResultsCallback(tf.keras.callbacks.Callback):
    """A callback that saves generated images after each epoch."""
    def __init__(
            self,
            noise_dim: int,
            save_name:str,
            results_path: str,
            examples_to_generate: int=16,
            grid_size: tuple=(4, 4),
            spacing: int=5,
            gif_size: tuple=(416, 416),
            duration: float=0.1
        ):
        """ Initializes the ResultsCallback class.

        Args:
            noise_dim (int): The dimensionality of the noise vector that is inputted to the generator.
            save_name (str): The file name (without extension) of the final gif.
            results_path (str): The directory in which a 'results' subfolder is created for the outputs.
            examples_to_generate (int, optional): The number of images to generate and save. Defaults to 16.
            grid_size (tuple, optional): The size of the grid to arrange the generated images. Defaults to (4, 4).
            spacing (int, optional): The spacing between the generated images. Defaults to 5.
            gif_size (tuple, optional): The size of the gif to be generated. Defaults to (416, 416).
            duration (float, optional): The duration of each frame in the gif. Defaults to 0.1.
        """
        super(ResultsCallback, self).__init__()
        self.seed = tf.random.normal([examples_to_generate, noise_dim]) #for each image in the grid
        self.results = []
        self.results_path = results_path + '/results'
        self.grid_size = grid_size
        self.spacing = spacing
        self.gif_size = gif_size
        self.duration = duration
        self.save_name = save_name

        # create the results directory if it doesn't exist
        os.makedirs(self.results_path, exist_ok=True)

    def save_pred(self, epoch: int, results: list) -> None:
        """ Saves the generated images as a grid and as a gif.

        Args:
            epoch (int): The current epoch.
            results (list): A list of generated images.
        """
        # construct an image from generated images with spacing between them using numpy
        h, w, c = results[0].shape
        # construct grid with self.grid_size (rows, columns)
        grid = np.zeros((self.grid_size[0] * h + (self.grid_size[0] - 1) * self.spacing, self.grid_size[1] * w + (self.grid_size[1] - 1) * self.spacing, c), dtype=np.uint8)
        for i in range(self.grid_size[0]):
            for j in range(self.grid_size[1]):
                grid[i * (h + self.spacing):i * (h + self.spacing) + h, j * (w + self.spacing):j * (w + self.spacing) + w] = results[i * self.grid_size[1] + j]

        # save the image; OpenCV writes BGR channel order, so convert from RGB first
        cv2.imwrite(f'{self.results_path}/img_{epoch}.png', cv2.cvtColor(grid, cv2.COLOR_RGB2BGR))

        # keep an RGB copy resized to the gif size in memory for the final gif
        self.results.append(cv2.resize(grid, self.gif_size, interpolation=cv2.INTER_AREA))

    def on_epoch_end(self, epoch: int, logs: dict=None) -> None:
        """Executes at the end of each epoch."""
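        predictions = self.model.generator(self.seed, training=False)
        #map tanh outputs from [-1, 1] back to [0, 255], clipping in case rounding leaves that range
        predictions_uint8 = np.clip(predictions.numpy().astype(np.float32) * 127.5 + 127.5, 0, 255).astype(np.uint8)
        self.save_pred(epoch, predictions_uint8)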

    def on_train_end(self, logs=None) -> None:
        """Executes at the end of training."""
        # save the results as a gif with imageio;
        # mimsave accepts a list of numpy arrays, so the stored RGB frames can be written directly
        imageio.mimsave(self.results_path + f"/{self.save_name}.gif", self.results, duration=self.duration)
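`save_pred` tiles the 16 generated 64x64 images into a 4x4 grid with `spacing` black pixels between tiles, so with the defaults each saved PNG is 4 * 64 + 3 * 5 = 271 pixels on each side. A standalone NumPy version of that tiling for checking the arithmetic (`tile_grid` is my name, not part of the callback):

```python
import numpy as np

def tile_grid(images: np.ndarray, rows: int, cols: int, spacing: int) -> np.ndarray:
    """images: (rows*cols, h, w, c) uint8; returns one image with `spacing` black pixels between tiles."""
    _, h, w, c = images.shape
    grid = np.zeros((rows * h + (rows - 1) * spacing, cols * w + (cols - 1) * spacing, c), dtype=np.uint8)
    for i in range(rows):
        for j in range(cols):
            top, left = i * (h + spacing), j * (w + spacing)
            grid[top:top + h, left:left + w] = images[i * cols + j]
    return grid

imgs = np.full((16, 64, 64, 3), 255, dtype=np.uint8)   # 16 all-white placeholder images
grid = tile_grid(imgs, rows=4, cols=4, spacing=5)
```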

In [None]:
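#set callbacks for experiment control and tracking
#the first callback is a learning-rate scheduler
#this function decides when, and by how much, the learning rate changes
def scheduler(epoch,lr):
    """Multiplies the learning rate by exp(-0.1) on even epochs and leaves it unchanged on odd epochs."""
    if epoch % 2 == 0:
        #decay on even epochs
        return float(lr * tf.exp(-0.1))
    else:
        return float(lr)
#configure the learning rate scheduler
#caveat: Keras learning-rate callbacks modify self.model.optimizer, while WGAN_GP updates its own
#disc_opt/gen_opt, so the schedule only reaches them if one of them is also passed to Model.compile()
lr_scheduler = LearningRateScheduler(schedule=scheduler,verbose=1)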
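It is worth checking how aggressive this schedule is. A pure-Python replay of the rule, starting from the 1e-4 learning rate used in `train` (`lr_after` is a hypothetical helper, and `math.exp` stands in for `tf.exp`):

```python
import math

def scheduler(epoch: int, lr: float) -> float:
    # same rule as the notebook's scheduler: decay by exp(-0.1) on even epochs
    return lr * math.exp(-0.1) if epoch % 2 == 0 else lr

def lr_after(epochs: int, lr0: float = 1e-4) -> float:
    # LearningRateScheduler calls the schedule at the start of each epoch with 0-based epoch indices
    lr = lr0
    for epoch in range(epochs):
        lr = scheduler(epoch, lr)
    return lr

first_six = [lr_after(n) for n in range(1, 7)]
after_600 = lr_after(600)
```

Over the 600 epochs used for training, the rule decays 300 times, so the rate would end near 1e-4 * e^-30, about 9e-18, which effectively stops learning; a smaller decay factor or a decay every few epochs may be worth considering.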

In [None]:
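##next callback will be ReduceLROnPlateau
##train_step reports 'd_loss', 'g_loss' and 'gp' (there is no 'loss' key), so it has to monitor one of those;
##cooldown is a whole number of epochs. Like the scheduler, it changes self.model.optimizer, not disc_opt/gen_opt
lr_plateau = ReduceLROnPlateau(monitor='d_loss',patience=10,mode='auto',cooldown=0,min_delta=1e-4)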

In [None]:
#create a function to train the model
def train(model:tf.keras.models.Model,
          epochs:int,
          data,
          noise_dim:int,
          verbose:bool,
          channel:int,
          callbacks:typing.Union[tf.keras.callbacks.Callback, typing.List[tf.keras.callbacks.Callback]],
          save_model_name:str,
          run_eagerly:bool=False,
          optimizer:typing.Literal['Adam','SGD']='Adam',
          download_path=False):
    x_train = data

    #incoming image shape
    image_shape=(64,64,channel)

    #set model path
    model_path = os.path.join('/content/trained_model',f'{save_model_name}.h5')
    parent_dir = os.path.dirname(model_path)

    if not os.path.exists(parent_dir): # Check if parent directory exists
        print(f'path {parent_dir} does not exist, creating.....')
        os.makedirs(parent_dir,exist_ok=True)
        print(f'path {parent_dir} successfully created')
    else:
        print(f'path {parent_dir} already exists')

    # check to see if a directory with the model_path name already exists and remove it
    if os.path.isdir(model_path):
        print(f"Warning: Directory '{model_path}' exists, removing it to save the model file.")
        #use shutil.rmtree to get rid of the directory
        shutil.rmtree(model_path)

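    #configure the model (this reuses the generator and discriminator built in the cells above)
    model_used = model(discriminator=build_discriminator,generator=generator_model,random_noise_dim=noise_dim)

    #set up one optimizer per network
    #note: beta_1/beta_2 are Adam arguments, so optimizer='SGD' would fail here;
    #the WGAN-GP paper uses Adam with learning_rate=1e-4, beta_1=0.0, beta_2=0.9
    genOpt = getattr(tf.keras.optimizers,optimizer)(learning_rate=0.0001,beta_1=0.99,beta_2=0.99,weight_decay=0.01)
    discOpt = getattr(tf.keras.optimizers,optimizer)(learning_rate=0.0001,beta_1=0.99,beta_2=0.99,weight_decay=0.01)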

    #setup results callback
    res_callbacks = ResultsCallback(noise_dim=noise_dim,save_name='flowers',results_path=os.path.dirname(model_path))

    #compile
    model_used.compile(disc_loss=discriminator_loss,
            gen_loss=generator_loss,
            disc_opt=discOpt,
            gen_opt=genOpt,
            run_eagerly=run_eagerly)
    #ensure that callbacks is a list
    if not isinstance(callbacks, (list,tuple)):
        callbacks=[callbacks]

    #add the res_callbacks to the list of callbacks
    callbacks = list(callbacks)+[res_callbacks]

    # wrap the DirectoryIterator in tf.data.Dataset.from_generator so fit() can iterate over its batches
    dataset = tf.data.Dataset.from_generator(
        lambda: (x_train[i] for i in range(len(x_train))), # each x_train[i] is one batch of images
        #declare the shape and dtype of each batch (the batch dimension may vary)
        output_signature=tf.TensorSpec(shape=(None, image_shape[0], image_shape[1], image_shape[2]), dtype=tf.float16)
    ).repeat()#repeat indefinitely; steps_per_epoch below decides how many batches make up one epoch

    #fit
    results = model_used.fit(dataset,
                             epochs=epochs,
                             steps_per_epoch=20,
                             verbose=verbose,
                             callbacks=callbacks)

    #save
    model_used.save(model_path)
    print(f'Model saved to {model_path}')

    if download_path:
        #download
        files.download(model_path)

    #return
    return results

In [None]:
#tensorboard
import datetime
import tensorboard
#setup log directory
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
#configure tensorboard as callback
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
#configure callbacks as a list
callbacks = [lr_scheduler,lr_plateau,tensorboard_callback]

In [None]:
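##set mixed precision
##note: the global policy only applies to layers created after this call, so run this cell
##before building the generator and discriminator if they should compute in float16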
from keras import mixed_precision

#set policy
policy = mixed_precision.Policy('mixed_float16')

mixed_precision.set_global_policy(policy)

In [None]:
#cuda
tf.config.list_physical_devices('GPU')

In [None]:
##setup datapath
from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_gen = ImageDataGenerator(
    #preprocessing: scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output
    preprocessing_function = lambda x:(x/127.5)-1.0,
    #apply augmentation
    horizontal_flip=True
)

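#flow from the directory: the images must sit in at least one subfolder of image_data,
#and class_mode=None makes the iterator yield image batches only (no labels)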
train_data = train_gen.flow_from_directory(directory='/content/image_data',
                                           target_size=(64,64),
                                           class_mode=None,
                                           shuffle=True
                                           )
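The preprocessing maps pixel values from [0, 255] to [-1, 1], the range of the generator's tanh output, and the results callback applies the inverse mapping before saving. A quick NumPy check of both directions on a few example values:

```python
import numpy as np

pixels = np.array([0.0, 63.75, 127.5, 255.0])
scaled = pixels / 127.5 - 1.0          # preprocessing_function: [0, 255] -> [-1, 1]
restored = scaled * 127.5 + 127.5      # inverse used when saving generated images: [-1, 1] -> [0, 255]
as_uint8 = np.clip(restored, 0, 255).astype(np.uint8)   # astype truncates fractional values
```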

In [None]:
results = train(model=WGAN_GP,
                epochs=600,
                data=train_data,
                noise_dim=100,
                verbose=True,
                channel=3,
                callbacks=callbacks,
                save_model_name='wgan_model',
                run_eagerly=False,
                optimizer='Adam')

In [None]:
results.history.keys()

In [None]:
%load_ext tensorboard
%tensorboard --logdir logs

The chart above gives an overview of the model's behaviour during training. It includes the losses of both components of the GAN, `epoch_g_loss` and `epoch_d_loss`, which belong to the generator and the discriminator (critic) respectively. The generator loss follows a downward trajectory, which suggests the generator is getting better at producing images the critic scores highly. The critic minimizes its own loss, the mean score on generated images minus the mean score on real images plus the gradient penalty; the negative of that loss (without the penalty) is its running estimate of the Wasserstein distance between real and generated images, which tells the generator how far it still has to go. Because the critic keeps changing during training, these loss values are best read as trends rather than absolute scores.

Now that the model is trained on a simple dataset of flowers, let's check how quickly and how well it generates an image from a random noise vector. Remember that the goal of a GAN is to take a random noise vector, usually sampled from a normal distribution, and produce an image that mimics the data it was trained on.

In [None]:
##to perform inference, first load the saved model with the framework it was built in,
##in this case TensorFlow
import tensorflow as tf

# Assuming the WGAN_GP class is defined in the current scope or imported
#load the model from its saved path
loaded_model = tf.keras.models.load_model(
    filepath='/content/trained_model/wgan_model.h5',
    compile=True,
    custom_objects={'WGAN_GP': WGAN_GP} # map the saved class name to the class itself, not to an instance
)

In [None]:
#the error above is expected: it occurs because I did not override the base model's get_config and from_config
#methods, which serialize and deserialize the model respectively. A model has to be serialized to be saved
#and deserialized to be used again. To fix this I will override those config methods and use
#tf.keras.utils.serialize_keras_object() and tf.keras.utils.deserialize_keras_object()