# Understanding and Implementing Wasserstein GAN (WGAN) from Scratch

Generative Adversarial Networks (GANs) are a widely used deep learning approach for generating realistic data. One variant, the Wasserstein GAN (WGAN), changes the training objective: instead of classifying samples as real or fake, the discriminator (called a critic in WGAN) assigns them a score, and both networks are trained against an estimate of the Wasserstein distance. In this blog post, we walk through the ideas behind WGAN and provide a step-by-step implementation using TensorFlow/Keras.

## Introduction to Wasserstein GAN

### The Need for Wasserstein GAN

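Traditional GANs train the discriminator with a binary cross-entropy loss. When the real and generated distributions barely overlap, this objective can leave the generator with vanishing gradients, which contributes to training instability and mode collapse. Wasserstein GAN addresses these issues by using the Wasserstein distance (also called the Earth Mover's distance) as the training objective, which stays informative even when the two distributions are far apart.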

### Wasserstein Distance

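The Wasserstein distance measures the minimum cost of transforming one probability distribution into another, where the cost is the amount of probability mass moved multiplied by the distance it travels. In the context of GANs, it changes smoothly as the generated distribution moves toward the real one, so the critic can supply the generator with usable gradients throughout training. Estimating it requires the critic to be Lipschitz continuous, which this implementation enforces by clipping the critic's weights to a small range.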
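
For intuition, the one-dimensional case can be computed directly. The sketch below assumes two equally sized samples with uniform weights, in which case the cheapest transport plan pairs the sorted values. The helper name `wasserstein_1d` is hypothetical and is not used in the implementation that follows.

```python
import numpy as np

def wasserstein_1d(samples_a, samples_b):
    """Empirical Wasserstein-1 distance between two equally sized 1-D samples.

    Assumes uniform weights; in 1-D the optimal plan pairs the sorted values.
    """
    a = np.sort(np.asarray(samples_a, dtype=float))
    b = np.sort(np.asarray(samples_b, dtype=float))
    if a.shape != b.shape:
        raise ValueError("this sketch assumes equally sized samples")
    return float(np.mean(np.abs(a - b)))

# Every point is shifted by 5, so each unit of mass travels a distance of 5
shifted = wasserstein_1d([0.0, 1.0, 3.0], [5.0, 6.0, 8.0])
# The same values in a different order need no transport at all
same = wasserstein_1d([2.0, 4.0], [4.0, 2.0])
print(shifted, same)
```

The value keeps growing with the size of the shift instead of saturating once the samples stop overlapping, which is the property WGAN relies on for useful gradients.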

## Implementation

### Loading and Preprocessing Data

We start by loading the MNIST dataset, a collection of handwritten digits, and normalizing the pixel values to the range [-1, 1].

In [None]:
import numpy as np
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Reshape, Conv2D, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import RMSprop
import tensorflow as tf
from tensorflow.keras.datasets import mnist

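# Load the MNIST training images (labels are not needed for GAN training)
(x_train, _), (_, _) = mnist.load_data()

# Add a channel axis and scale pixels from [0, 255] to [-1, 1],
# matching the range of the generator's tanh output
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32')
x_train = (x_train - 127.5) / 127.5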
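
The mapping to [-1, 1] and its inverse (needed later for display) can be checked on a few pixel values without downloading MNIST. This is a minimal sketch; `to_tanh_range` and `to_unit_range` are hypothetical helper names introduced only for illustration.

```python
import numpy as np

def to_tanh_range(pixels):
    """Map uint8 pixel values in [0, 255] to float32 values in [-1, 1]."""
    return (pixels.astype("float32") - 127.5) / 127.5

def to_unit_range(scaled):
    """Map values in [-1, 1] back to [0, 1] for display."""
    return 0.5 * scaled + 0.5

# A tiny stand-in for an image: darkest, mid-gray, and brightest pixels
sample_pixels = np.array([0, 127, 255], dtype=np.uint8)
scaled = to_tanh_range(sample_pixels)
restored = to_unit_range(scaled)
print(scaled.min(), scaled.max(), restored.min(), restored.max())
```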

### Building the Generator & Discriminator

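Next, we define the architecture of the generator and discriminator using TensorFlow/Keras. The generator creates synthetic images from random noise. The discriminator, which the WGAN literature calls the critic, assigns each real or generated image an unbounded realness score instead of a probability. The code keeps the name `discriminator`; the two terms refer to the same network.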

In [None]:
# Define the generator model
def create_generator():
    model = Sequential()
    model.add(Dense(256, use_bias=False, input_shape=(100,)))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(512, use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1024, use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(784, activation='tanh', use_bias=False))
    model.add(Reshape((28, 28, 1)))  # Reshape to image dimensions
    return model

# Define the discriminator model
def create_discriminator():
    model = Sequential()
    model.add(Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=(28, 28, 1)))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Flatten())
    model.add(Dense(1))  # Linear output: an unbounded critic score, not a probability
    return model

# Create the generator and discriminator models
generator = create_generator()
discriminator = create_discriminator()
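
As a quick sanity check on the discriminator's layer sizes, the spatial dimensions can be worked out by hand. The sketch below assumes the usual rule for `padding='same'`, where the output size is the input size divided by the stride, rounded up. The helper name is hypothetical.

```python
import math

def same_padding_output_size(input_size, stride):
    """Spatial output size of a convolution with padding='same'."""
    return math.ceil(input_size / stride)

size_after_conv1 = same_padding_output_size(28, 2)                # first Conv2D (64 filters)
size_after_conv2 = same_padding_output_size(size_after_conv1, 2)  # second Conv2D (128 filters)
flattened_features = size_after_conv2 * size_after_conv2 * 128    # input width of the final Dense(1)

# The generator's last Dense layer must produce one value per pixel
generator_outputs = 28 * 28 * 1
print(size_after_conv1, size_after_conv2, flattened_features, generator_outputs)
```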

### Training the Generator via WGAN

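The WGAN model is created by connecting the generator to the critic (the `discriminator` in the code). Inside this combined model the critic's weights are frozen, so training it updates only the generator, pushing it toward images that the critic scores as highly as real ones.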

In [None]:
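# Wasserstein loss for labels of +1 (real) and -1 (fake): minimizing it pushes
# the critic to score real images high and generated images low
def wasserstein_loss(y_true, y_pred):
    y_true = tf.cast(y_true, y_pred.dtype)
    return -tf.reduce_mean(y_true * y_pred)

# Define the Wasserstein GAN model: the generator followed by a frozen critic
def create_wgan(generator, discriminator):
    discriminator.trainable = False
    wgan_model = Sequential([generator, discriminator])
    return wgan_model


# Compile the discriminator (critic) before it is frozen inside the combined model
discriminator.compile(optimizer=RMSprop(learning_rate=0.00005), loss=wasserstein_loss)

# Create the Wasserstein GAN model; only the generator is updated through it
# Note: a model saved with this custom loss should be reloaded with compile=False
# (or with custom_objects={'wasserstein_loss': wasserstein_loss})
wgan_model = create_wgan(generator, discriminator)
wgan_model.compile(optimizer=RMSprop(learning_rate=0.00005), loss=wasserstein_loss)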
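
To see how a Wasserstein-style loss behaves with labels of +1 for real and -1 for generated images, here is a small NumPy illustration. The scores are made up and `wasserstein_loss_np` is a hypothetical helper; it is a sketch of the arithmetic, not a replacement for the Keras code.

```python
import numpy as np

def wasserstein_loss_np(labels, scores):
    """Negated mean of label * score, for labels of +1 (real) and -1 (fake)."""
    return float(-np.mean(labels * scores))

# Made-up critic scores for two real and two generated images
real_scores = np.array([2.0, 4.0])
fake_scores = np.array([-1.0, 3.0])

critic_loss_real = wasserstein_loss_np(np.ones(2), real_scores)    # -mean(real scores)
critic_loss_fake = wasserstein_loss_np(-np.ones(2), fake_scores)   # +mean(fake scores)
critic_loss = critic_loss_real + critic_loss_fake

# The generator is trained with "real" labels on generated images,
# so it lowers its loss by raising the critic's scores on them
generator_loss = wasserstein_loss_np(np.ones(2), fake_scores)

print(critic_loss, generator_loss)
```

The negated critic loss, `mean(real_scores) - mean(fake_scores)`, is the quantity the critic tries to maximize and serves as the estimate of the Wasserstein distance.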


### Training Loop

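The training loop runs for a fixed number of epochs. For each batch it updates the critic on real and generated images, clips the critic's weights to `[-clip_value, clip_value]`, and then updates the generator through the combined model.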

In [None]:
# Train the WGAN
epochs = 500
batch_size = 128
clip_value = 0.01  # Clip weights to enforce Lipschitz continuity

for epoch in range(epochs):
    for i in range(0, len(x_train) - batch_size + 1, batch_size):
        real_images = x_train[i:i + batch_size]
        noise = np.random.normal(0, 1, (batch_size, 100))
        generated_images = generator.predict(noise, verbose=0)

        real_labels = np.ones((batch_size, 1))
        fake_labels = -np.ones((batch_size, 1))

        # Train the discriminator
        d_loss_real = discriminator.train_on_batch(real_images, real_labels)
        d_loss_fake = discriminator.train_on_batch(generated_images, fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Clip discriminator weights
        for layer in discriminator.layers:
            weights = layer.get_weights()
            weights = [np.clip(w, -clip_value, clip_value) for w in weights]
            layer.set_weights(weights)

        # Train the generator
        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = wgan_model.train_on_batch(noise, real_labels)

    # Print the losses from the last batch of each epoch
    print(f'Epoch {epoch + 1}/{epochs}, Discriminator Loss: {d_loss}, Generator Loss: {g_loss}')

    # Save models every 25th epoch
    if (epoch + 1) % 25 == 0:
        generator.save(f'wgan_generator_epoch_{epoch + 1}.h5')
        discriminator.save(f'wgan_discriminator_epoch_{epoch + 1}.h5')

# Save the final generator model
generator.save('wgan_generator_final.h5')

# Save the final discriminator model (optional for evaluation purposes)
discriminator.save('wgan_discriminator_final.h5')   
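
The effect of weight clipping is easiest to see on a toy problem. The sketch below assumes a hypothetical one-weight linear critic `f(x) = w * x` on 1-D data and plain gradient ascent; it is an illustration of the clipping step only, not the Keras training loop.

```python
import numpy as np

clip_value = 0.01
learning_rate = 0.004

# Toy 1-D data: "real" samples sit to the right of "generated" samples
real_samples = np.array([1.5, 2.0, 2.5])
fake_samples = np.array([-0.5, 0.0, 0.5])

w = 0.0       # the single weight of the linear critic f(x) = w * x
history = []
for step in range(5):
    # Critic objective: mean f(real) - mean f(fake) = w * (mean(real) - mean(fake)),
    # so its gradient with respect to w is the difference of the sample means
    grad = real_samples.mean() - fake_samples.mean()
    w = w + learning_rate * grad                     # gradient ascent step
    w = float(np.clip(w, -clip_value, clip_value))   # weight clipping
    history.append(w)

print(history)
```

Without the clip, `w` would grow without bound because the objective is linear in `w`. Clipping keeps the critic's slope bounded, which is a simple (if crude) way to keep it Lipschitz continuous.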

### Generating Images

Finally, the trained generator is used to create synthetic images, and a grid of these images is displayed. The loaded discriminator then scores each generated image. Because its final layer is linear and it was trained with targets of +1 for real and -1 for generated images, its output is an unbounded score, not a probability between 0 and 1. Higher scores mean the critic considers an image more realistic and lower scores less so; the values are most meaningful when compared against the scores of real images.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model

# Epoch of the checkpoint to load. Checkpoints are saved every 25 epochs during training,
# so this must be a multiple of 25 that does not exceed the number of epochs trained.
num_epoch = 200  # Adjust as needed

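# Load the generator and discriminator models (adjust the paths to where the checkpoints were saved).
# compile=False is enough here because the models are only used for inference.
generator = load_model(f'Notebooks/Models/Generator/wgan_generator_epoch_{num_epoch}.h5', compile=False)
discriminator = load_model(f'Notebooks/Models/Discriminator/wgan_discriminator_epoch_{num_epoch}.h5', compile=False)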

# Generate a batch of random noise
batch_size = 144  # Adjust as needed
noise = np.random.normal(0, 1, (batch_size, 100))

# Generate images using the generator (tanh output, values in [-1, 1])
generated_images = generator.predict(noise, verbose=0)

# Rescale a copy to [0, 1] for display only
display_images = 0.5 * generated_images + 0.5

# Display the generated images
rows, cols = 12, 12  # rows * cols must not exceed batch_size
fig, axs = plt.subplots(rows, cols, figsize=(8, 8))
fig.suptitle('Generated Images')
idx = 0
for i in range(rows):
    for j in range(cols):
        axs[i, j].imshow(display_images[idx].reshape(28, 28), cmap='gray')
        axs[i, j].axis('off')
        idx += 1
plt.show()


# Score the generated images with the discriminator, using the generator's raw output
# (not the rescaled copy), since that is what it saw during training
discriminator_predictions = discriminator.predict(generated_images, verbose=0)

# Print discriminator predictions for each generated image
for i in range(batch_size):
    print(f"Image {i + 1} - Discriminator Prediction: {discriminator_predictions[i][0]}")
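
Individual scores are hard to read in isolation, so one option is to compare the average score on real images with the average score on generated ones. The sketch below uses made-up arrays in place of `discriminator.predict(...)` outputs, and `score_gap` is a hypothetical helper name.

```python
import numpy as np

def score_gap(real_scores, generated_scores):
    """Mean critic score on real images minus mean score on generated images."""
    return float(np.mean(real_scores) - np.mean(generated_scores))

# Made-up scores with the (n, 1) shape that a Dense(1) output would have
real_scores = np.array([[0.8], [1.2], [1.0]])
generated_scores = np.array([[0.1], [0.3], [0.2]])

gap = score_gap(real_scores, generated_scores)
print(round(gap, 3))
```

A gap that shrinks toward zero across checkpoints suggests the critic finds the generated images harder to tell apart from real ones. Both sets of scores should come from images in the same value range the critic was trained on.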


## Conclusion

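Wasserstein GAN offers a more stable training process than traditional GANs by replacing the classification loss with an estimate of the Wasserstein distance and keeping the critic Lipschitz continuous through weight clipping. With these ideas implemented in TensorFlow/Keras, we have a generator that can be trained more reliably and a critic whose scores can be used to compare generated images against real ones.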