<a href="https://colab.research.google.com/github/Parshantladhar/Handwritten-Digit-Generator/blob/main/conditional_gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Conditional GAN

**Description:** Training a GAN conditioned on class labels to generate handwritten digits.

Generative Adversarial Networks (GANs) let us generate novel image data, video data,
or audio data from a random input. Typically, the random input is sampled
from a normal distribution, before going through a series of transformations that turn
it into something plausible (image, video, audio, etc.).

However, a simple [DCGAN](https://arxiv.org/abs/1511.06434) doesn't let us control
the appearance (e.g. class) of the samples we're generating. For instance,
with a GAN that generates MNIST handwritten digits, a simple DCGAN wouldn't let us
choose the class of digits we're generating.
To be able to control what we generate, we need to _condition_ the GAN output
on a semantic input, such as the class of an image.

In this example, we'll build a **Conditional GAN** that can generate MNIST handwritten
digits conditioned on a given class. Such a model can have various useful applications:

* Suppose you are dealing with an
[imbalanced image dataset](https://developers.google.com/machine-learning/data-prep/construct/sampling-splitting/imbalanced-data),
and you'd like to gather more examples for the under-represented class to balance the dataset.
Data collection can be costly. You could instead train a Conditional GAN and use
it to generate novel images for the class that needs balancing.
* Since the generator learns to associate the generated samples with the class labels,
its representations can also be used for [other downstream tasks](https://arxiv.org/abs/1809.11096).
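
As a rough illustration of the first use case, the sketch below works out how many synthetic images each class would need to match the largest class, and builds the one-hot labels that could be passed to a trained conditional generator. The toy `labels` array and every variable name here are hypothetical and not part of this example.

```python
import numpy as np

# Hypothetical labels for an imbalanced dataset in which class 2 is rare.
labels = np.array([0] * 50 + [1] * 48 + [2] * 5)
num_classes = 3

# Count the examples per class and the shortfall relative to the largest class.
counts = np.bincount(labels, minlength=num_classes)
deficits = counts.max() - counts

# One class id per image to synthesize, then the matching one-hot vectors.
requested = np.repeat(np.arange(num_classes), deficits)
requested_one_hot = np.eye(num_classes, dtype="float32")[requested]

print(counts, deficits, requested_one_hot.shape)
```

Each row of `requested_one_hot` would then be concatenated with a noise vector, in the same way the training loop in this example builds the generator input.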

The following references were used to develop this example:

* [Conditional Generative Adversarial Nets](https://arxiv.org/abs/1411.1784)
* [Lecture on Conditional Generation from Coursera](https://www.coursera.org/lecture/build-basic-generative-adversarial-networks-gans/conditional-generation-inputs-2OPrG)

If you need a refresher on GANs, you can refer to the "Generative adversarial networks"
section of
[this resource](https://livebook.manning.com/book/deep-learning-with-python-second-edition/chapter-12/r-3/232).

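This example uses the Keras 3 API (`keras.ops`, `keras.random`) on the TensorFlow backend,
so it needs a TensorFlow release that ships Keras 3 (TensorFlow 2.16 or higher). It also
needs TensorFlow Docs, which can be installed using the following command: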

In [None]:
!pip install -q git+https://github.com/tensorflow/docs

  Preparing metadata (setup.py) ... done
  Building wheel for tensorflow-docs (setup.py) ... done


## Imports

In [None]:
import keras

from keras import layers
from keras import ops
from tensorflow_docs.vis import embed
import tensorflow as tf
import numpy as np
import imageio

## Constants and hyperparameters

In [None]:
batch_size = 64  # Number of images processed together in each training step.
num_channels = 1  # MNIST images are grayscale, so they have a single channel.
num_classes = 10  # MNIST has 10 digit classes (0 through 9).
image_size = 28  # Each MNIST image is 28 x 28 pixels.
latent_dim = 128  # Dimensionality of the random noise vector fed to the generator.

## Loading the MNIST dataset and preprocessing it

In [None]:
# We'll use all the available examples from both the training and test
# sets.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_labels = np.concatenate([y_train, y_test])

# Scale the pixel values to [0, 1] range, add a channel dimension to
# the images, and one-hot encode the labels.
# Convert to float32 and scale the pixel values from [0, 255] to [0, 1].
all_digits = all_digits.astype("float32") / 255.0
# Reshape to (num_images, height, width, channels).
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
# One-hot encode the integer labels (0-9) as length-10 vectors so that they can
# be concatenated with the generator and discriminator inputs.
all_labels = keras.utils.to_categorical(all_labels, 10)

# Create a tf.data.Dataset from the NumPy arrays.
dataset = tf.data.Dataset.from_tensor_slices((all_digits, all_labels))
# Shuffle within a 1024-element buffer so the model does not see the data in a
# fixed order, then group the examples into batches of `batch_size` (64).
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

print(f"Shape of training images: {all_digits.shape}")
print(f"Shape of training labels: {all_labels.shape}")

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Shape of training images: (70000, 28, 28, 1)
Shape of training labels: (70000, 10)
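
To make the preprocessing concrete, here is a minimal sketch of what one-hot encoding does, together with the batch-count arithmetic for this dataset. It uses plain NumPy instead of `keras.utils.to_categorical`, and `toy_labels` is a hypothetical stand-in for the real label array.

```python
import math

import numpy as np

# Hypothetical toy labels; the notebook itself encodes all 70,000 MNIST labels.
toy_labels = np.array([3, 0, 9])
num_classes = 10

# Row i of the identity matrix is the one-hot vector for class i.
one_hot = np.eye(num_classes, dtype="float32")[toy_labels]
print(one_hot[0])  # 1.0 at index 3, 0.0 elsewhere

# Batches per epoch when the final, partially filled batch is kept.
num_examples, batch_size = 70000, 64
steps_per_epoch = math.ceil(num_examples / batch_size)
print(steps_per_epoch)  # 1094
```

The step count matches the `1094/1094` progress counter shown in the training output of this notebook.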


## Calculating the number of input channels for the generator and discriminator

In a regular (unconditional) GAN, we start by sampling noise (of some fixed
dimension) from a normal distribution. In our case, we also need to account
for the class labels. We will have to add the number of classes to
the input channels of the generator (noise input) as well as the discriminator
(generated image input).

In [None]:
# Generator input: the noise vector plus the one-hot class label.
generator_in_channels = latent_dim + num_classes
# Discriminator input: the image channels plus one label channel per class.
discriminator_in_channels = num_channels + num_classes
print(generator_in_channels, discriminator_in_channels)

# Feeding the class label alongside the noise (generator) and the image
# (discriminator) lets the model learn the relationship between images and
# their classes, so it can generate images for a requested class.

138 11
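
The two channel counts can be checked with a small shape experiment. This is a NumPy-only sketch under assumed toy values (a batch of 4 and random arrays in place of real noise and images); it is meant to show the shapes only, not how the training loop builds its tensors.

```python
import numpy as np

batch, latent_dim, num_classes = 4, 128, 10
image_size, num_channels = 28, 1
rng = np.random.default_rng(0)

# Generator side: append the one-hot label to the noise vector.
noise = rng.normal(size=(batch, latent_dim)).astype("float32")
labels = np.eye(num_classes, dtype="float32")[[0, 1, 2, 3]]
generator_input = np.concatenate([noise, labels], axis=1)

# Discriminator side: append one label channel per class to each image.
images = rng.random((batch, image_size, image_size, num_channels)).astype("float32")
label_maps = np.broadcast_to(
    labels[:, None, None, :], (batch, image_size, image_size, num_classes)
)
discriminator_input = np.concatenate([images, label_maps], axis=-1)

print(generator_input.shape)      # (4, 138)
print(discriminator_input.shape)  # (4, 28, 28, 11)
```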


## Creating the discriminator and generator

The model definitions (`discriminator`, `generator`, and `ConditionalGAN`) have been
adapted from [this example](https://keras.io/guides/customizing_what_happens_in_fit/).

In [None]:
# Create the discriminator.
discriminator = keras.Sequential(
    [
        keras.layers.InputLayer((28, 28, discriminator_in_channels)),
        # 64 filters of size 3x3. strides=(2, 2) moves the filter 2 pixels at a
        # time, which halves the spatial size (28 -> 14). With padding="same"
        # the output size is ceil(input / stride), not the input size.
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        # LeakyReLU adds non-linearity while keeping a small, non-zero gradient
        # for negative inputs.
        layers.LeakyReLU(negative_slope=0.2),
        # 128 filters; downsamples again (14 -> 7).
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        # Keep only the maximum value of each feature map.
        layers.GlobalMaxPooling2D(),
        # A single output unit: the raw score (logit) used to classify the
        # input as real or generated.
        layers.Dense(1),
    ],
    name="discriminator",
)
#Discriminator Summary
# In essence, the discriminator is a convolutional neural network that takes an image (and its label information) as input and outputs a single value indicating whether the image is real or fake.

# Create the generator.
generator = keras.Sequential(
    [
        keras.layers.InputLayer((generator_in_channels,)),
        # We want to generate 7 * 7 * (128 + num_classes) coefficients to reshape
        # into a 7x7x(128 + num_classes) map.
        layers.Dense(7 * 7 * generator_in_channels),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Reshape((7, 7, generator_in_channels)),
        # Transposed convolutions upsample the feature map: 7 -> 14 -> 28.
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        # Final convolution: one output channel (grayscale). The sigmoid keeps
        # the pixel values in the range [0, 1].
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)
# The generator takes a random noise vector (and a class label) as input and transforms it into a synthetic image that ideally resembles a real image from the target dataset.

# In summary, these two code blocks define the architectures of the discriminator and generator networks within the Conditional GAN. They work together in an adversarial process, where the generator tries to create realistic images, and the discriminator tries to distinguish between real and generated images.
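
The spatial sizes flowing through both networks can be traced with a little arithmetic. The sketch below assumes the usual `padding="same"` size rules (a strided convolution gives `ceil(size / stride)`, a transposed convolution gives `size * stride`). The helper names are hypothetical.

```python
import math


def conv_same_out(size, stride):
    # padding="same" with a strided convolution yields ceil(size / stride).
    return math.ceil(size / stride)


def conv_transpose_same_out(size, stride):
    # padding="same" with a transposed convolution yields size * stride.
    return size * stride


# Discriminator: two stride-2 convolutions applied to a 28x28 input.
disc_sizes = [28]
for _ in range(2):
    disc_sizes.append(conv_same_out(disc_sizes[-1], 2))

# Generator: two stride-2 transposed convolutions applied to a 7x7 map.
gen_sizes = [7]
for _ in range(2):
    gen_sizes.append(conv_transpose_same_out(gen_sizes[-1], 2))

print(disc_sizes)  # [28, 14, 7]
print(gen_sizes)   # [7, 14, 28]
```

This is why the generator starts from a 7x7 map: two doublings bring it to the 28x28 resolution of MNIST.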

## Creating a `ConditionalGAN` model

In [None]:

class ConditionalGAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.seed_generator = keras.random.SeedGenerator(1337)
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")

    @property
    def metrics(self):
        return [self.gen_loss_tracker, self.disc_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        # d_optimizer updates the discriminator's weights, g_optimizer updates
        # the generator's weights, and loss_fn measures how well each network
        # is performing.
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, data):
        # Unpack the data.
        real_images, one_hot_labels = data

        # Add dummy spatial dimensions to the labels and tile them so that every
        # pixel carries the full one-hot vector. The labels can then be
        # concatenated with the images along the channel axis. This is for the
        # discriminator.
        image_one_hot_labels = one_hot_labels[:, None, None, :]
        image_one_hot_labels = ops.tile(
            image_one_hot_labels, (1, image_size, image_size, 1)
        )

        # Sample random points in the latent space and concatenate the labels.
        # This is for the generator.
        batch_size = ops.shape(real_images)[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )
        random_vector_labels = ops.concatenate(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # Decode the noise (guided by labels) to fake images.
        generated_images = self.generator(random_vector_labels)

        # Combine them with real images. Note that we are concatenating the labels
        # with these images here.
        fake_image_and_labels = ops.concatenate(
            [generated_images, image_one_hot_labels], -1
        )
        real_image_and_labels = ops.concatenate([real_images, image_one_hot_labels], -1)
        combined_images = ops.concatenate(
            [fake_image_and_labels, real_image_and_labels], axis=0
        )

        # Assemble labels discriminating real from fake images: 1 = generated,
        # 0 = real, matching the order of `combined_images`.
        labels = ops.concatenate(
            [ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
        )

        # Train the discriminator.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space.
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )
        random_vector_labels = ops.concatenate(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # Assemble labels that say "all real images".
        misleading_labels = ops.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_vector_labels)
            fake_image_and_labels = ops.concatenate(
                [fake_images, image_one_hot_labels], -1
            )
            predictions = self.discriminator(fake_image_and_labels)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Monitor loss.
        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }
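
To see how the two losses in `train_step` relate, here is a minimal pure-Python sketch of binary cross-entropy computed from logits (the kind of `loss_fn` this notebook compiles the model with). The helper `bce_from_logits` and the logit values are hypothetical; the formula is the standard numerically stable form, written out by hand, not taken from the Keras implementation.

```python
import math


def bce_from_logits(targets, logits):
    """Mean binary cross-entropy computed directly from raw logits."""
    total = 0.0
    for z, x in zip(targets, logits):
        # Stable form of -[z * log(sigmoid(x)) + (1 - z) * log(1 - sigmoid(x))].
        total += max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


# Hypothetical discriminator logits: two generated images, then two real ones.
logits = [2.0, 1.0, -1.5, -3.0]

# Discriminator step: 1 = generated, 0 = real.
d_loss = bce_from_logits([1.0, 1.0, 0.0, 0.0], logits)

# Generator step: the generated images are labelled as real (0), so the loss
# is high whenever the discriminator correctly flags them as generated.
g_loss = bce_from_logits([0.0, 0.0], logits[:2])

print(round(d_loss, 4), round(g_loss, 4))
```

With these made-up logits the discriminator is doing well, so `d_loss` is small and `g_loss` is large; the generator update pushes the logits of generated images toward the "real" side.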


## Training the Conditional GAN

In [None]:
cond_gan = ConditionalGAN(
    discriminator=discriminator, generator=generator, latent_dim=latent_dim
)
cond_gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

cond_gan.fit(dataset, epochs=20)

Epoch 1/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 40s 26ms/step - d_loss: 0.4254 - g_loss: 1.6097

## Interpolating between classes with the trained generator

In [None]:
# We first extract the trained generator from our Conditional GAN.
trained_gen = cond_gan.generator

# Choose the total number of interpolation frames: the intermediate images
# plus 2 (the start and end images).
num_interpolation = 9  # @param {type:"integer"}

# Sample one noise vector and repeat it along the batch axis so that every
# interpolation frame shares the same noise.
interpolation_noise = keras.random.normal(shape=(1, latent_dim))
interpolation_noise = ops.repeat(interpolation_noise, repeats=num_interpolation, axis=0)
interpolation_noise = ops.reshape(interpolation_noise, (num_interpolation, latent_dim))


def interpolate_class(first_number, second_number):
    # Convert the start and end labels to one-hot encoded vectors.
    first_label = keras.utils.to_categorical([first_number], num_classes)
    second_label = keras.utils.to_categorical([second_number], num_classes)
    first_label = ops.cast(first_label, "float32")
    second_label = ops.cast(second_label, "float32")

    # Calculate the interpolation vector between the two labels.
    percent_second_label = ops.linspace(0, 1, num_interpolation)[:, None]
    percent_second_label = ops.cast(percent_second_label, "float32")
    interpolation_labels = (
        first_label * (1 - percent_second_label) + second_label * percent_second_label
    )

    # Combine the noise and the labels and run inference with the generator.
    noise_and_labels = ops.concatenate([interpolation_noise, interpolation_labels], 1)
    fake = trained_gen.predict(noise_and_labels)
    return fake


start_class = 2  # @param {type:"slider", min:0, max:9, step:1}
end_class = 6  # @param {type:"slider", min:0, max:9, step:1}

fake_images = interpolate_class(start_class, end_class)

Here, we first sample a single noise vector from a normal distribution and repeat it
`num_interpolation` times, so that every frame is generated from the same noise.
We then move linearly from the first one-hot label to the second in `num_interpolation`
steps, so each intermediate label mixes the two class identities in a different proportion.
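
The label blending can be reproduced in isolation. The following is a NumPy sketch of the same idea, using the classes 2 and 6 and 9 frames chosen in this notebook. The random generator and variable names are assumptions, and no trained model is involved.

```python
import numpy as np

num_classes, latent_dim, num_interpolation = 10, 128, 9
rng = np.random.default_rng(0)

# One noise vector, repeated along the batch axis for every frame.
noise = rng.normal(size=(1, latent_dim)).astype("float32")
repeated_noise = np.repeat(noise, num_interpolation, axis=0)

# One-hot labels for the start and end classes.
first = np.eye(num_classes, dtype="float32")[2]
second = np.eye(num_classes, dtype="float32")[6]

# Blend weights going from 0 (all first class) to 1 (all second class).
weights = np.linspace(0.0, 1.0, num_interpolation, dtype="float32")[:, None]
interpolated = first * (1.0 - weights) + second * weights

noise_and_labels = np.concatenate([repeated_noise, interpolated], axis=1)
print(noise_and_labels.shape)  # (9, 138)
print(interpolated[4])         # 0.5 at indices 2 and 6, 0.0 elsewhere
```

Every row of `interpolated` still sums to 1, so the generator always receives a valid mixture of the two classes.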

In [None]:
fake_images *= 255.0
converted_images = fake_images.astype(np.uint8)
converted_images = ops.image.resize(converted_images, (96, 96)).numpy().astype(np.uint8)
imageio.mimsave("animation.gif", converted_images[:, :, :, 0], fps=1)
embed.embed_file("animation.gif")

We can further improve the performance of this model with recipes like
[WGAN-GP](https://keras.io/examples/generative/wgan_gp).
Conditional generation is also widely used in many modern image generation architectures like
[VQ-GANs](https://arxiv.org/abs/2012.09841), [DALL-E](https://openai.com/blog/dall-e/),
etc.

You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/conditional-gan) and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/conditional-GAN).