# Wasserstein Generative Adversarial Network with Gradient Penalty

WGAN-GP (which stands for Wasserstein Generative Adversarial Network with Gradient Penalty) improves the architecture of a WGAN (designed to counter training related issues of the basic GAN) by ensuring the critics is 1-Lipschitz without using weight clipping (which impacts the critic's ability to learn). There are 3 main changes, the WGAN-GP critic:
- Includes a gradient penalty term in the critic loss function
- Does not clip its weights
- Does not use bacth normalization layers

Let's build an WGAN-GP "from scratch" to have a better understanding!

## Hand Made WGAN-GP

In [1]:
import os
import numpy as np
import pandas as pd
import pickle

from functools import partial

import tensorflow as tf
import keras.backend as K
from keras.models import Model
from keras.layers import Input, Conv2D, Flatten, Dense, Reshape, UpSampling2D, Activation, BatchNormalization, Dropout, Conv2DTranspose, Layer, Lambda
from keras.initializers import RandomNormal
from keras.optimizers import Adam, RMSprop
from keras.callbacks import Callback, LearningRateScheduler
from keras.callbacks import ModelCheckpoint
from keras.datasets import mnist
from keras.utils import plot_model

# Clear TensorFlow session
K.clear_session()

# Disable eager execution
# from tensorflow.python.framework.ops import disable_eager_execution
# disable_eager_execution()

# Tensorflow debugging
# tf.debugging.enable_check_numerics()

import matplotlib.pyplot as plt

2023-11-27 22:10:03.616401: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-11-27 22:10:03.621548: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2023-11-27 22:10:03.685502: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-27 22:10:03.685558: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-27 22:10:03.687425: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to

### Critic

In [2]:
class Critic():

    def __init__(self,
                 input_dim, critic_n_layers, critic_conv_filters, critic_conv_kernel_sizes, critic_conv_strides,
                 critic_activation,
                 critic_batch_norm_momentum, critic_dropout_rate):
        self.input_dim = input_dim
        self.critic_n_layers = critic_n_layers
        self.critic_conv_filters = critic_conv_filters
        self.critic_conv_kernel_sizes = critic_conv_kernel_sizes
        self.critic_conv_strides = critic_conv_strides
        self.critic_activation = critic_activation
        self.critic_batch_norm_momentum = critic_batch_norm_momentum
        self.critic_dropout_rate = critic_dropout_rate
        self.weight_init = RandomNormal(mean=0., stddev=0.02)

        self.input = Input(shape=self.input_dim, name="critic_input")
        
        x = self.input
        for i in range(self.critic_n_layers):
            x = Conv2D(filters=self.critic_conv_filters[i],
                       kernel_size=self.critic_conv_kernel_sizes[i],
                       strides=self.critic_conv_kernel_sizes[i],
                       padding="same",
                       name="critic_conv_" + str(i))(x)
            
            if self.critic_batch_norm_momentum and i > 0:
                x = BatchNormalization(momentum=self.critic_batch_norm_momentum)(x)

            x = Activation(activation=self.critic_activation, name="critic_activation_" + str(i))(x)

            if self.critic_dropout_rate:
                x = Dropout(rate=self.critic_dropout_rate)(x)

        x = Flatten()(x)

        self.output = Dense(1, activation=None, kernel_initializer=self.weight_init, name="critic_output")(x)

        self.model = Model(self.input, self.output, name="critic")

    def summary(self):
        self.model.summary()

    def predict(self, x):
        return self.model.predict(x)

### Generator

In [3]:
class Generator():

    def __init__(self,
                 latent_dim, generator_initial_dense_layer_size, generator_n_layers, generator_upsamplings, generator_conv_filters, generator_conv_kernel_sizes, generator_conv_strides,
                 generator_activation,
                 generator_batch_norm_momentum, generator_dropout_rate):
        self.generator_initial_dense_layer_size = generator_initial_dense_layer_size
        self.generator_n_layers = generator_n_layers
        self.generator_upsamplings = generator_upsamplings
        self.generator_conv_filters = generator_conv_filters
        self.generator_conv_kernel_sizes = generator_conv_kernel_sizes
        self.generator_conv_strides = generator_conv_strides
        self.latent_dim = latent_dim
        self.generator_activation = generator_activation
        self.generator_batch_norm_momentum = generator_batch_norm_momentum
        self.generator_dropout_rate = generator_dropout_rate

        self.input = Input(shape=(self.latent_dim,), name="generator_input")

        x = Dense(np.prod(self.generator_initial_dense_layer_size))(self.input) # Connect the input to a dense layer

        if self.generator_batch_norm_momentum:
            x = BatchNormalization(momentum=self.generator_batch_norm_momentum)(x)

        x = Activation(self.generator_activation)(x)

        x = Reshape(self.generator_initial_dense_layer_size)(x) # Reshape latent space vector for convolutional transpose layers

        if self.generator_dropout_rate:
            x = Dropout(rate=self.generator_dropout_rate)(x)

        for i in range(self.generator_n_layers):
            if self.generator_upsamplings[i]:
                x = UpSampling2D(name="generator_up_sampling_" + str(i))(x)
                conv_layer = Conv2D(filters=self.generator_conv_filters[i],
                                    kernel_size=self.generator_conv_kernel_sizes[i],
                                    strides=self.generator_conv_strides[i],
                                    padding="same",
                                    name="generator_conv_" + str(i))
                x = conv_layer(x)
            else:
                conv_t_layer = Conv2DTranspose(filters=self.generator_conv_filters[i],
                                               kernel_size=self.generator_conv_kernel_sizes[i],
                                               strides=self.generator_conv_strides[i],
                                               padding="same",
                                               name="generator_conv_t_" + str(i))
                x = conv_t_layer(x)

            if i < self.generator_n_layers - 1:
                if self.generator_batch_norm_momentum:
                    x = BatchNormalization(momentum=self.generator_batch_norm_momentum)(x)
                x = Activation(activation="relu")(x)
            else:
                x = Activation(activation="tanh")(x)

        self.output = x

        self.model = Model(self.input, self.output, name="generator")

    def summary(self):
        self.model.summary()

    def predict(self, x):
        return self.model.predict(x)
        

### Generative Adversarial Network

In [4]:
class RandomWeightedAverage(Layer):
    def __init__(self, batch_size):
        super(RandomWeightedAverage, self).__init__()
        self.batch_size = batch_size

    def call(self, inputs):
        real_images = inputs[0]
        generated_images = inputs[1]
        alpha = tf.random.uniform(shape=[self.batch_size, 1, 1, 1])
        return (alpha * real_images) + ((1 - alpha) * generated_images)

In [6]:
class GenerativeAdversarialNetwork():

    def __init__(self,
                 input_dim, critic_n_layers, critic_conv_filters, critic_conv_kernel_sizes, critic_conv_strides,
                 critic_activation, critic_learning_rate,
                 critic_batch_norm_momentum, critic_dropout_rate,
                 latent_dim, generator_initial_dense_layer_size, generator_n_layers, generator_upsamplings, generator_conv_filters, generator_conv_kernel_sizes, generator_conv_strides,
                 generator_activation, generator_learning_rate,
                 generator_batch_norm_momentum, generator_dropout_rate,
                 batch_size, gradient_loss_weight):
        self.input_dim = input_dim
        self.critic_n_layers = critic_n_layers
        self.critic_conv_filters = critic_conv_filters
        self.critic_conv_kernel_sizes = critic_conv_kernel_sizes
        self.critic_conv_strides = critic_conv_strides
        self.critic_activation = critic_activation
        self.critic_learning_rate = critic_learning_rate
        self.critic_batch_norm_momentum = critic_batch_norm_momentum
        self.critic_dropout_rate = critic_dropout_rate
        
        self.generator_initial_dense_layer_size = generator_initial_dense_layer_size
        self.generator_n_layers = generator_n_layers
        self.generator_upsamplings = generator_upsamplings
        self.generator_conv_filters = generator_conv_filters
        self.generator_conv_kernel_sizes = generator_conv_kernel_sizes
        self.generator_conv_strides = generator_conv_strides
        self.latent_dim = latent_dim
        self.generator_activation = generator_activation
        self.generator_learning_rate = generator_learning_rate
        self.generator_batch_norm_momentum = generator_batch_norm_momentum
        self.generator_dropout_rate = generator_dropout_rate

        self.batch_size = batch_size
        self.gradient_loss_weight = gradient_loss_weight
        self.epoch = 0
        self.critic_losses = []
        self.generator_losses = []

        # Create critic
        self.critic = Critic(self.input_dim,
                                    self.critic_n_layers, self.critic_conv_filters, self.critic_conv_kernel_sizes, self.critic_conv_strides,
                                    self.critic_activation,
                                    self.critic_batch_norm_momentum, self.critic_dropout_rate)
        
        # Create Generator
        self.generator = Generator(self.latent_dim, self.generator_initial_dense_layer_size,
                                   self.generator_n_layers, self.generator_upsamplings, self.generator_conv_filters, self.generator_conv_kernel_sizes, self.generator_conv_strides,
                                   self.generator_activation,
                                   self.generator_batch_norm_momentum, self.generator_dropout_rate)

        # Define loss functions
        def wasserstein_loss(y_true, y_pred):
            return -K.mean(y_true * y_pred)
        
        def gradient_penalty_loss(y_true, y_pred, interpolated_samples):
            gradients = K.gradients(y_pred, interpolated_samples)[0] # Gradients of the predictions for the interpolated images (y_pred) with respect to the input (interpolated_samples)
            gradient_l2_norm = K.sqrt(
                K.sum(
                    K.square(gradients),
                    axis=np.arange(1, len(gradients.shape))
                )
            )
            gradient_penalty = K.square(1 - gradient_l2_norm)
            return K.mean(gradient_penalty)
        
        # Define utility function
        # @tf.function
        def random_weighted_average(inputs):
            alpha = tf.random.uniform(shape=(self.batch_size, 1, 1, 1))
            return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])
        
        # Compile model that trains the critic
        self.set_trainable(self.generator.model, False) # Freeze generator while training critic

        # Real images
        real_img = Input(shape=self.input_dim, name="adversarial_critic_real_img_input")
        validity_real = self.critic.model(real_img)

        # Generated image
        generated_img = Input(shape=self.input_dim, name="adversarial_critic_generated_img_input")
        validity_generated = self.critic.model(generated_img)

        # Weighted average between real and generated images
        interpolated_img = RandomWeightedAverage(self.batch_size)([real_img, generated_img])
        # interpolated_img = Lambda(lambda x: RandomWeightedAverage(self.batch_size)(x), name="random_weighted_images")([real_img, generated_img])
        # interpolated_img = Lambda(random_weighted_average, output_shape=lambda x: x[0], name="random_weighted_images")([real_img, generated_img])
        validity_interpolated = self.critic.model(interpolated_img)
        # validity_interpolated = CustomLambdaLayer()(self.critic.model, interpolated_img)

        self.critic_model = Model(inputs=[real_img, generated_img],
                                  outputs=[validity_real, validity_generated, validity_interpolated],
                                  name="adversarial_critic")

        # Use Python partial to provide loss function with additional 'interpolated_samples' argument
        partial_gp_loss = partial(gradient_penalty_loss,
                                  interpolated_samples=interpolated_img)
        partial_gp_loss.__name__ = "gradient_penalty" # Keras requires function names

        self.critic_model.compile(optimizer=Adam(learning_rate=self.critic_learning_rate, beta_1=0.5),
                                  loss=[wasserstein_loss, wasserstein_loss, partial_gp_loss],
                                  loss_weights=[1, 1, self.gradient_loss_weight])
        self.set_trainable(self.generator.model, True) # Unfreeze generator while training critic
        
        # Compile model that trains the generator
        self.set_trainable(self.critic.model, False)
        self.input = Input(shape=(self.latent_dim,), name="generative_adversarial_network_input")
        self.output = self.critic.model(self.generator.model(self.input))
        self.model = Model(self.input, self.output, name="generative_adversarial_network")
        self.model.compile(optimizer=RMSprop(learning_rate=self.generator_learning_rate),
                           loss=wasserstein_loss,
                           metrics=["accuracy"])
        self.set_trainable(self.critic.model, True)

    def summary(self):
        self.model.summary()

    def plot_model(self, run_folder):
        plot_model(self.model, to_file=os.path.join(run_folder ,'viz/model.png'), show_shapes=True, show_layer_names=True)
        plot_model(self.critic.model, to_file=os.path.join(run_folder ,'viz/critic.png'), show_shapes=True, show_layer_names=True)
        plot_model(self.generator.model, to_file=os.path.join(run_folder ,'viz/generator.png'), show_shapes=True, show_layer_names=True)
        plot_model(self.critic_model, to_file=os.path.join(run_folder ,'viz/adversarial_critic.png'), show_shapes=True, show_layer_names=True)

    def load_weights(self, filepath="model/weights/params.pkl"):
        self.model.load_weights(filepath)

    def set_trainable(self, model, value):
        model.trainable = value
        for l in model.layers:
            l.trainable = value

    def fit_critic(self, X_train):
        # Real images
        idx = np.random.randint(0, X_train.shape[0], self.batch_size)
        real_imgs = X_train[idx]
        valid = np.ones((self.batch_size, 1), dtype=np.float32)

        # Generated images
        noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))
        gen_imgs = self.generator.predict(noise)
        generated = -np.ones((self.batch_size, 1), dtype=np.float32)

        # Random weighted images
        random_weighted = np.zeros((self.batch_size, 1), dtype=np.float32)

        c_loss = self.critic_model.train_on_batch(x=[real_imgs, gen_imgs], y=[valid, generated, random_weighted])

        return c_loss
    
    def fit_generator(self):
        valid = np.ones((self.batch_size, 1)) # Forces the generator to produce images considered as valid by the critic
        noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))
        return self.model.train_on_batch(noise, valid)
    
    def fit(self, X_train, epochs):

        for epoch in range(self.epoch, self.epoch + epochs):
            
            # Train the critic multiple times
            for _ in range(5):
                c_loss = self.fit_critic(X_train)

            # Train the generator once
            g_loss = self.fit_generator()

            print ("%d [D loss: (%.3f)(R %.3f, F %.3f)] [D acc: (%.3f)(%.3f, %.3f)] [G loss: %.3f] [G acc: %.3f]" % (epoch,
                                                                                                                     c_loss[0], c_loss[1], c_loss[2], c_loss[3], c_loss[4], c_loss[5],
                                                                                                                     g_loss[0], g_loss[1]))

            self.critic_losses.append(c_loss)
            self.generator_losses.append(g_loss)

            self.epoch += 1

    def predict(self, x):
        return self.model.predict(x)

    def save(self, folder="model"):
        if not os.path.exists(folder):
            os.makedirs(folder)
            os.makedirs(os.path.join(folder, 'viz'))
            os.makedirs(os.path.join(folder, 'weights'))
            os.makedirs(os.path.join(folder, 'images'))

        with open(os.path.join(folder, 'weights/params.pkl'), 'wb') as f:
            pickle.dump([self.input_dim,
                         self.critic_conv_filters,
                         self.critic_conv_kernel_sizes,
                         self.critic_conv_strides,
                         self.critic_batch_norm_momentum,
                         self.critic_dropout_rate,
                         self.latent_dim,
                         self.generator_conv_filters,
                         self.generator_conv_kernel_sizes,
                         self.generator_conv_strides,
                         self.generator_batch_norm_momentum,
                         self.generator_dropout_rate], f)
        self.plot_model(folder)

In [7]:
LATENT_DIM = 100
BATCH_SIZE = 64

generative_adversarial_network = GenerativeAdversarialNetwork(input_dim=(28,28,1),
                                                              critic_n_layers=4,
                                                              critic_conv_filters=[64,64,128,128],
                                                              critic_conv_kernel_sizes=[5,5,5,5],
                                                              critic_conv_strides=[2,2,2,1],
                                                              critic_activation='relu',
                                                              critic_learning_rate=0.0008,
                                                              critic_batch_norm_momentum=None,
                                                              critic_dropout_rate=0.4,
                                                              latent_dim=LATENT_DIM,
                                                              generator_initial_dense_layer_size=(7,7,64),
                                                              generator_n_layers=4,
                                                              generator_upsamplings=[True,False,True,False],
                                                              generator_conv_filters=[128,64,64,1],
                                                              generator_conv_kernel_sizes=[5,5,5,5],
                                                              generator_conv_strides=[1,1,1,1],
                                                              generator_activation='relu',
                                                              generator_learning_rate=0.0004,
                                                              generator_batch_norm_momentum=0.9,
                                                              generator_dropout_rate=None,
                                                              batch_size=BATCH_SIZE,
                                                              gradient_loss_weight=10)

### 1
### 2
### 3


In [8]:
generative_adversarial_network.plot_model("model")

In [9]:
generative_adversarial_network.critic.summary()

Model: "critic"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 critic_input (InputLayer)   [(None, 28, 28, 1)]       0         
                                                                 
 critic_conv_0 (Conv2D)      (None, 6, 6, 64)          1664      
                                                                 
 critic_activation_0 (Activ  (None, 6, 6, 64)          0         
 ation)                                                          
                                                                 
 dropout (Dropout)           (None, 6, 6, 64)          0         
                                                                 
 critic_conv_1 (Conv2D)      (None, 2, 2, 64)          102464    
                                                                 
 critic_activation_1 (Activ  (None, 2, 2, 64)          0         
 ation)                                                     

In [10]:
generative_adversarial_network.generator.summary()

Model: "generator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 generator_input (InputLaye  [(None, 100)]             0         
 r)                                                              
                                                                 
 dense (Dense)               (None, 3136)              316736    
                                                                 
 batch_normalization (Batch  (None, 3136)              12544     
 Normalization)                                                  
                                                                 
 activation (Activation)     (None, 3136)              0         
                                                                 
 reshape (Reshape)           (None, 7, 7, 64)          0         
                                                                 
 generator_up_sampling_0 (U  (None, 14, 14, 64)        0 

In [11]:
generative_adversarial_network.summary()

Model: "generative_adversarial_network"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 generative_adversarial_net  [(None, 100)]             0         
 work_input (InputLayer)                                         
                                                                 
 generator (Functional)      (None, 28, 28, 1)         844161    
                                                                 
 critic (Functional)         (None, 1)                 718913    
                                                                 
Total params: 1563074 (5.96 MB)
Trainable params: 1556290 (5.94 MB)
Non-trainable params: 6784 (26.50 KB)
_________________________________________________________________


In [12]:
generative_adversarial_network.critic_model.summary()

Model: "adversarial_critic"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 adversarial_critic_real_im  [(None, 28, 28, 1)]          0         []                            
 g_input (InputLayer)                                                                             
                                                                                                  
 adversarial_critic_generat  [(None, 28, 28, 1)]          0         []                            
 ed_img_input (InputLayer)                                                                        
                                                                                                  
 random_weighted_average (R  (64, 28, 28, 1)              0         ['adversarial_critic_real_img_
 andomWeightedAverage)                                              input[0][0]',

### Load Data

In [13]:
# Load MNIST dataset
mnist_dataset = mnist.load_data()
(trainset, testset) = (mnist_dataset[0], mnist_dataset[1])
(X_train, y_train) = trainset
(X_test, y_test) = testset

# Preprocess data (convert to float and scale to between 0 and 1)
X_train = X_train.astype('float32')
X_train /= 255
X_test = X_test.astype('float32')
X_test /= 255

# Preprocess data (convert to uint8)
# y_train = y_train.astype('uint8')
# y_test = y_test.astype('uint8')

### Train

In [14]:
EPOCHS = 2000

In [15]:
generative_adversarial_network.fit(X_train, EPOCHS)



TypeError: in user code:

    File "/home/fabien/anaconda3/lib/python3.9/site-packages/keras/src/engine/training.py", line 1401, in train_function  *
        return step_function(self, iterator)
    File "/tmp/ipykernel_17304/3478761015.py", line 56, in gradient_penalty_loss  *
        gradients = K.gradients(y_pred, interpolated_samples)[0] # Gradients of the predictions for the interpolated images (y_pred) with respect to the input (interpolated_samples)
    File "/home/fabien/anaconda3/lib/python3.9/site-packages/keras/src/backend.py", line 4693, in gradients  **
        return tf.compat.v1.gradients(
    File "/home/fabien/anaconda3/lib/python3.9/site-packages/keras/src/engine/keras_tensor.py", line 285, in __array__
        raise TypeError(

    TypeError: You are passing KerasTensor(type_spec=TensorSpec(shape=(64, 28, 28, 1), dtype=tf.float32, name=None), name='random_weighted_average/add:0', description="created by layer 'random_weighted_average'"), an intermediate Keras symbolic input/output, to a TF API that does not allow registering custom dispatchers, such as `tf.cond`, `tf.function`, gradient tapes, or `tf.map_fn`. Keras Functional model construction only supports TF API calls that *do* support dispatching, such as `tf.math.add` or `tf.reshape`. Other APIs cannot be called directly on symbolic Kerasinputs/outputs. You can work around this limitation by putting the operation in a custom Keras layer `call` and calling that layer on this symbolic input/output.


### Evaluate Generative Adversarial Network

In [None]:
# Plot critic losses
plt.plot([row[0] for row in generative_adversarial_network.critic_losses], label="loss")
plt.plot([row[1] for row in generative_adversarial_network.critic_losses], label="loss (real images)")
plt.plot([row[2] for row in generative_adversarial_network.critic_losses], label="loss (generated images)")
plt.legend()
plt.show()

In [None]:
# Plot critic accuracy
plt.plot([row[3] for row in generative_adversarial_network.critic_losses], label="accuracy")
plt.plot([row[4] for row in generative_adversarial_network.critic_losses], label="accuracy (real images)")
plt.plot([row[5] for row in generative_adversarial_network.critic_losses], label="accuracy (generated images)")
plt.legend()
plt.show()

In [None]:
# Plot generator loss
plt.plot([row[0] for row in generative_adversarial_network.generator_losses], label="loss")
plt.legend()
plt.show()

In [None]:
# Plot generator accuracy
plt.plot([row[1] for row in generative_adversarial_network.generator_losses], label="accuracy")
plt.legend()
plt.show()

### Save Model

In [None]:
generative_adversarial_network.save()

### Load Pre-Trained Model

In [None]:
# variational_autoencoder.load_weights()

### Predictions

In [None]:
def modifier_vector(size, dim, value):
    v = np.zeros(shape=(1,size))
    v[dim] = value
    return v

In [None]:
NB_PRED = 5

latent_vector_init = np.random.random(size=(1,LATENT_DIM))

latent_vectors = tf.data.Dataset.from_tensor_slices([latent_vector_init + modifier_vector(LATENT_DIM, 0, i) for i in range(NB_PRED)])

# list(latent_vectors.as_numpy_iterator())

In [None]:
latent_vectors

In [None]:
def l1_compare_images(img1, img2):
    return np.mean(np.abs(img1 - img2))

def find_closest(img):
    closest = None
    closest_l1 = None
    for i in X_train:
        l1 = l1_compare_images(img, i)
        if closest_l1 is None or l1 < closest_l1:
            closest = i
            closest_l1 = l1
    for i in X_test:
        l1 = l1_compare_images(img, i)
        if closest_l1 is None or l1 < closest_l1:
            closest = i
            closest_l1 = l1
    return closest, closest_l1

In [None]:
predictions = generative_adversarial_network.generator.predict(latent_vectors)

closest_in_dataset = []
latent_vectors_iterator = iter(latent_vectors)
for i in range(NB_PRED):
    closest_in_dataset.append(find_closest(predictions[i]))
    plt.figure(figsize=((5,5)))
    plt.subplot(1,2,1)
    plt.imshow(predictions[i], cmap="gray")
    plt.title(f"{latent_vectors_iterator.get_next()}")
    plt.axis("off")
    plt.subplot(1,2,2)
    plt.imshow(closest_in_dataset[i][0], cmap="gray")
    plt.title(f"l1 = {closest_in_dataset[i][1]}")
    plt.axis("off")

GANs are notoriously hard to train because:
- The losses can oscillate leading to bad results. This is called **Oscillating Loss**.
- The generator can find a small number of samples that fool the critic (called *modes*), it starts to map every point in the latent space to this observation and the gradient of the loss function collapses to 0 preventing it to learn properly (even if we train the critic in a way it is not fooled anymore by these observations, the generator has become numb and will simply find other modes). This phenomenon is know as **Mode Collapse**.
- The lack of correlation between the generator loss and the image quality makes the GAN hard to train. The generator is graded against a critic that improves overtime so the loss can increase while the image quality improves. GANs use an **Uninformative Loss**.
- GANs uses a lot of **Hyperparameters** that require fine-tuning.

In order to counter some of those effects, some improvements can be made to the basic GAN architecture. For instance the Wasserstein GAN is an attempt a improving GANs.