# GAN to Generate Handwritten Digits (MNIST)

In this notebook, we will build and train a Generative Adversarial Network (GAN) to create convincing images of handwritten digits.

## How does a GAN work?

A GAN consists of two neural networks that compete against each other:

1.  **The Generator**: Tries to create fake images that are as realistic as possible. It starts from random noise and learns to transform it into something that resembles the real data.
2.  **The Discriminator**: Acts as a judge. Its job is to look at an image and decide if it is real (from the training dataset) or fake (created by the generator).

The training is a zero-sum game:
- The **discriminator** improves by getting better at detecting fakes.
- The **generator** improves by getting better at "fooling" the discriminator.

When training goes well, the generator eventually becomes good enough that its images are almost indistinguishable from real ones.
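
For reference, this game is commonly written as a minimax objective. The symbols are introduced here only for illustration and are not used elsewhere in this notebook: $D(x)$ is the discriminator's estimated probability that image $x$ is real, $G(z)$ is the image the generator produces from noise $z$, and $V$ is the value both networks compete over.

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
$$

The discriminator tries to maximize $V$, while the generator tries to minimize it. The generator loss defined later in this notebook uses the common non-saturating variant: instead of minimizing $\log(1 - D(G(z)))$, the generator is trained to maximize $\log D(G(z))$.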

## 1. Environment Setup

The following cell imports the libraries used throughout the notebook: TensorFlow and Keras for the models, NumPy for array data, Matplotlib for plotting, and `os` for creating the output directory.

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
import numpy as np
import matplotlib.pyplot as plt
import os

print(f"TensorFlow Version: {tf.__version__}")

## 2. Data Loading and Preparation

The next cell handles loading and preparing our data. We will:
1.  Load the standard MNIST dataset of handwritten digits directly from TensorFlow.
2.  Display the first 100 images to visually inspect our data.
3.  Reshape and normalize the images to a format suitable for the neural network.
4.  Create an efficient TensorFlow `Dataset` object to feed the data to our model in batches during training.
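
As a small illustration of step 3, the sketch below applies the same normalization to a tiny hypothetical 2x2 array (standing in for an MNIST image) and then inverts it, which is the mapping used later when plotting generated images. The array values and variable names are assumptions chosen for the example.

```python
import numpy as np

# Hypothetical 2x2 "image" with uint8 pixel values, standing in for an MNIST digit.
pixels = np.array([[0, 64], [191, 255]], dtype=np.uint8)

# Forward mapping used before training: [0, 255] -> [-1, 1].
normalized = (pixels.astype("float32") - 127.5) / 127.5

# Inverse mapping used when plotting generated images: [-1, 1] -> [0, 255].
restored = normalized * 127.5 + 127.5

print(normalized.min(), normalized.max())  # -1.0 1.0
print(np.allclose(restored, pixels))       # True
```

Mapping to [-1, 1] rather than [0, 1] matters here because the generator's final `tanh` activation produces values in that same range.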

In [None]:
# Load the MNIST dataset. It returns training and testing sets. We only need the training images.
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

print("\nShowing the first 100 example images:")

# Create a figure to display the images, with a size of 5x5 inches.
plt.figure(figsize=(5, 5))

# Loop 100 times to show the first 100 images.
for i in range(100):
    # Create a subplot in a 10x10 grid at position i+1.
    plt.subplot(10, 10, i + 1)
    
    # Display the i-th image in grayscale.
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    
    # Remove the numbered ticks from the x and y axes for a cleaner look.
    plt.xticks([])
    plt.yticks([])
    
# Show the final figure with all the subplots.
plt.show()

# Reshape the images from (60000, 28, 28) to (60000, 28, 28, 1) because convolutional
# layers expect a channel dimension. We also convert the pixel values to float32 type.
train_images_processed = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')

# Normalize the pixel values from the range [0, 255] to [-1, 1]. This helps stabilize training.
train_images_processed = (train_images_processed - 127.5) / 127.5

# Define constants for creating the dataset.
BUFFER_SIZE = 60000 # Number of items to shuffle.
BATCH_SIZE = 256    # Number of images per training batch.

# 1. Define the generator function.
# This function simply yields the images one by one. Shuffling will be handled by tf.data.
def data_generator():
    for img in train_images_processed:
        yield img

# 2. Create the Dataset using from_generator.
# Streaming from a generator avoids copying the whole NumPy array into a single tensor up front.
# We tell TensorFlow what data type and shape to expect (the "output_signature").
train_dataset = tf.data.Dataset.from_generator(
    data_generator,
    output_signature=tf.TensorSpec(shape=(28, 28, 1), dtype=tf.float32)
)

# 3. Apply shuffle and batch operations to the new dataset.
# A shuffle buffer as large as the dataset (BUFFER_SIZE = 60000) gives a full uniform shuffle.
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

print("\nDataset ready for training (created efficiently with a generator).")

## 3. Creating the Generator Model

The generator's job is to take a random noise vector and transform it into a 28x28 image. It's essentially a reverse convolutional network.

- `def make_generator_model():`: We define a function to build our model.
- `model.add(layers.Input(shape=(100,)))`: It starts with an input layer that accepts a 100-dimension noise vector.
- `model.add(layers.Dense(...))`: A fully connected layer expands this vector into a larger block of data.
- `model.add(layers.Reshape(...))`: This reshapes the data into a small 7x7 'image' with many channels (256).
- `model.add(layers.Conv2DTranspose(...))`: These are transposed convolution layers (sometimes called deconvolutional layers). The first uses a stride of 1, so it keeps the 7x7 size while reducing the channels to 128. The next two use a stride of 2 and upsample the feature maps first to 14x14, and then to the final 28x28 size.
- `activation="tanh"`: The final layer uses a `tanh` activation to ensure the output pixel values are between -1 and 1, matching our normalized real images.
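
The shape arithmetic behind these layers can be checked by hand. The sketch below is plain Python, not Keras code, and assumes the usual rule that a transposed convolution with `padding="same"` and no explicit output padding multiplies the spatial size by its stride. The helper name `conv_transpose_same` is hypothetical.

```python
def conv_transpose_same(size, stride):
    """Output height/width of a transposed convolution with padding="same"."""
    return size * stride

dense_units = 7 * 7 * 256          # units produced by the Dense layer
size = 7                           # spatial size after Reshape((7, 7, 256))
sizes = [size]
for stride in (1, 2, 2):           # strides of the three Conv2DTranspose layers
    size = conv_transpose_same(size, stride)
    sizes.append(size)

print(dense_units)  # 12544
print(sizes)        # [7, 7, 14, 28]
```

Calling `generator.summary()` after building the model should show the same progression of output shapes.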

In [None]:
def make_generator_model():
    model = models.Sequential()
    model.add(layers.Input(shape=(100,)))
    
    model.add(layers.Dense(7 * 7 * 256, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Reshape((7, 7, 256)))
    
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1),
                                     padding="same", use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2),
                                     padding="same", use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2),
                                     padding="same", use_bias=False, activation="tanh"))

    return model
    
generator = make_generator_model()

print("Generator model created")

## 4. Creating the Discriminator Model

The discriminator is a standard Convolutional Neural Network (CNN) for image classification. Its goal is to take an image and output a single value indicating whether it thinks the image is real or fake.

- `def make_discriminator_model():`: Defines the function to build the discriminator.
- `model.add(layers.Conv2D(...))`: These are convolutional layers that process the input image, extracting features and downsampling it (making it smaller) with `strides=(2, 2)`.
- `model.add(layers.LeakyReLU())`: Unlike a plain ReLU, Leaky ReLU keeps a small, non-zero slope for negative inputs, so gradients still flow through units whose pre-activation is negative.
- `model.add(layers.Dropout(0.3))`: This layer randomly deactivates 30% of its input units during training to prevent the model from becoming too specialized to the training data (overfitting).
- `model.add(layers.Flatten())`: This layer converts the 2D feature maps from the convolutional layers into a single 1D vector.
- `model.add(layers.Dense(1))`: The final output layer has one neuron, which will produce a single logit value indicating the model's prediction (real or fake).
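
To see how the image shrinks on its way through the discriminator, the following sketch repeats the size calculation in plain Python. It assumes the usual rule that a convolution with `padding='same'` produces an output of size `ceil(input / stride)`. The helper name `conv_same` is hypothetical.

```python
import math

def conv_same(size, stride):
    """Output height/width of a convolution with padding='same'."""
    return math.ceil(size / stride)

size = 28                  # MNIST images are 28x28
sizes = [size]
for stride in (2, 2):      # strides of the two Conv2D layers
    size = conv_same(size, stride)
    sizes.append(size)

flattened = size * size * 128   # the second Conv2D layer has 128 filters
print(sizes)      # [28, 14, 7]
print(flattened)  # 6272
```

The `Flatten` layer therefore hands a 6272-element vector to the final `Dense(1)` layer.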

In [None]:
def make_discriminator_model():
    model = models.Sequential()
    # Declare the input shape with an explicit Input layer, as in the generator.
    model.add(layers.Input(shape=(28, 28, 1)))
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

discriminator = make_discriminator_model()

print("Discriminator model created")

## 5. Defining Losses and Optimizers

Here we define the functions that will measure how 'wrong' our models are (the loss) and the algorithms that will update them (the optimizers).

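- `cross_entropy`: We create an instance of `BinaryCrossentropy`, the standard loss for a binary (real/fake) classification problem. We pass `from_logits=True` because the discriminator's final `Dense(1)` layer outputs a raw logit rather than a probability.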
- `discriminator_loss`: This function calculates the discriminator's loss. It's the sum of two parts: how well it identifies real images as real (`real_loss`) and how well it identifies fake images as fake (`fake_loss`).
- `generator_loss`: This function calculates the generator's loss. It measures how well the discriminator was fooled. The generator wins (has low loss) if the discriminator classifies its fake images as real (i.e., close to 1).
- `..._optimizer`: We create an `Adam` optimizer for each network. This algorithm will be used to apply the calculated gradients and update the networks' weights.
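
To make these losses concrete, here is a minimal pure-Python sketch of binary cross-entropy computed from a single raw logit. It is an illustration under stated assumptions rather than the Keras implementation: the helper name `bce_from_logit` is hypothetical, and it scores one example at a time, whereas `BinaryCrossentropy` averages over the batch.

```python
import math

def bce_from_logit(label, logit):
    """Binary cross-entropy for one example, computed from a raw logit."""
    # Numerically stable form of -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))].
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

# Generator loss when the discriminator is unsure (logit 0 means probability 0.5).
print(round(bce_from_logit(1.0, 0.0), 4))   # 0.6931, i.e. log(2)

# Discriminator loss when it is unsure about both a real and a fake image.
d_loss = bce_from_logit(1.0, 0.0) + bce_from_logit(0.0, 0.0)
print(round(d_loss, 4))                     # 1.3863, i.e. 2 * log(2)

# A fooled discriminator (large positive logit on a fake) means a low generator loss.
print(bce_from_logit(1.0, 5.0) < bce_from_logit(1.0, -5.0))  # True
```

These values give a rough reference point when watching the losses during training: a discriminator that cannot tell real from fake sits near 2 * log(2).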

In [None]:
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

generator_optimizer = optimizers.Adam(1e-4)
discriminator_optimizer = optimizers.Adam(1e-4)

## 6. Setting up the Training Loop

This cell defines the logic for a single step of training and a helper function to visualize progress.

- `EPOCHS`, `noise_dim`, `num_examples_to_generate`: Constants for the number of training epochs, the length of the noise vector, and the number of sample images to plot.
- `seed`: We create a fixed batch of random noise. This will be fed to the generator at the end of each epoch to create sample images, allowing us to visually track its improvement over time.
- `@tf.function`: This is a decorator that converts the Python function into a high-performance TensorFlow graph, which makes training much faster.
- `def train_step(images)`: This function encapsulates one full training step for a single batch of data.
    - `with tf.GradientTape() as ...`: The `GradientTape` is a crucial tool that 'records' all operations. This allows TensorFlow to automatically calculate the gradients needed for backpropagation.
    - The code inside the `with` block generates fake images, gets the discriminator's predictions for both real and fake images, and calculates the losses for both networks.
    - `gradients_of_... = ...gradient(...)`: The tape computes the gradients of the loss with respect to each network's trainable weights.
    - `...optimizer.apply_gradients(...)`: The optimizer uses these calculated gradients to update the weights of the generator and discriminator.
- `if not os.path.exists(...)`: This checks if a directory named `gan_images` exists to save our output images.
- `os.makedirs(...)`: If the directory does not exist, it is created.
- `def generate_and_save_images(...)`: A helper function that uses the generator to create images from the fixed `seed` noise, then uses `matplotlib` to create a grid of these images and save it to a file.
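
The record-then-differentiate pattern used in `train_step` can be seen in isolation on a one-variable example. This is a minimal sketch, not part of the GAN: the variable `x`, the loss `x * x`, and the learning rate of 0.1 are all assumptions, and a manual gradient-descent step stands in for what `apply_gradients` does (the notebook itself uses Adam).

```python
import tensorflow as tf

x = tf.Variable(3.0)  # a single trainable "weight"

with tf.GradientTape() as tape:
    loss = x * x      # operations on variables are recorded by the tape

grad = tape.gradient(loss, x)   # d(x^2)/dx = 2x
print(grad.numpy())             # 6.0

# Manual gradient-descent step: x <- x - learning_rate * grad.
x.assign_sub(0.1 * grad)
print(x.numpy())                # approximately 2.4
```

In `train_step`, two tapes are opened at once so that the generator and discriminator losses can each be differentiated with respect to their own network's weights.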

In [None]:
EPOCHS = 50 # Lower this for a quicker trial run; raise it for better samples
noise_dim = 100
num_examples_to_generate = 16

seed = tf.random.normal([num_examples_to_generate, noise_dim])

@tf.function
def train_step(images):
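    # Match the noise batch to the real batch: the last batch of an epoch is smaller than BATCH_SIZE.
    noise = tf.random.normal([tf.shape(images)[0], noise_dim])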

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

if not os.path.exists('gan_images'):
    os.makedirs('gan_images')

def generate_and_save_images(model, epoch, test_input):
    predictions = model(test_input, training=False)
    
    fig = plt.figure(figsize=(4, 4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig('gan_images/image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()

## 7. Training the GAN

This is where we put everything together and start the training. The code defines a `train` function that will manage the entire process.
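
As a rough guide to how much work one epoch involves, the sketch below counts the training steps per epoch from the dataset and batch sizes used in this notebook. The variable names are chosen for the example.

```python
import math

num_images = 60000   # MNIST training set size
batch_size = 256     # BATCH_SIZE in this notebook

steps_per_epoch = math.ceil(num_images / batch_size)
last_batch_size = num_images - (steps_per_epoch - 1) * batch_size

print(steps_per_epoch)   # 235
print(last_batch_size)   # 96
```

Because 60000 is not a multiple of 256, the last batch of real images in each epoch contains only 96 images.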

In [None]:
# Import libraries for display clearing and time tracking.
from IPython import display
import time

# Define the main training function.
def train(dataset, epochs):
    # Loop for the specified number of epochs.
    for epoch in range(epochs):
        # Record the start time.
        start = time.time()

        # Loop through each batch in the dataset.
        for image_batch in dataset:
            # Perform a single training step on the batch.
            train_step(image_batch)

        # Clear the output of the cell to show the new generated images.
        display.clear_output(wait=True)
        # Generate and save sample images to visualize progress.
        generate_and_save_images(generator,
                                 epoch + 1,
                                 seed)

        # Print the time taken for the epoch to complete.
        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

    # After the final epoch, clear the output and show the final set of generated images.
    display.clear_output(wait=True)
    generate_and_save_images(generator, epochs, seed)

### Executing the Training

The final cell calls the `train` function, passing our prepared dataset and the number of epochs. This command will start the actual training process, which may take a while.

In [None]:
print("Training model ...")

train(train_dataset, EPOCHS)

## 8. Conclusion

Congratulations! You have trained a GAN from scratch.

You can look in the `gan_images` folder to see the generator's evolution. At the beginning, the images will be pure noise, but as the epochs go by, they should start to take the shape of recognizable digits.
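
One simple way to step through the saved grids in order is sketched below. The helper name `list_epoch_images` is hypothetical; it relies only on the file naming pattern used by `generate_and_save_images`.

```python
from pathlib import Path

def list_epoch_images(folder="gan_images"):
    """Return the saved sample grids in epoch order."""
    # File names are zero-padded (image_at_epoch_0001.png), so sorting
    # alphabetically also sorts them by epoch.
    return sorted(Path(folder).glob("image_at_epoch_*.png"))

for path in list_epoch_images():
    print(path.name)
```

The returned paths can then be opened one by one, or passed to an image library of your choice to assemble an animation.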

### Possible next steps:
- **Train longer**: Sample quality often keeps improving with more epochs, so try raising `EPOCHS`.
- **Adjust hyperparameters**: Try different learning rates, batch sizes, or network architectures.
- **Try with other datasets**: You can adapt this GAN to train on other datasets, such as [Fashion MNIST](https://github.com/zalandoresearch/fashion-mnist) or [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html).