# 🎨 MNIST GAN - Handwritten Digit Generator

[![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)](https://www.python.org/downloads/)
[![TensorFlow](https://img.shields.io/badge/TensorFlow-2.13+-orange.svg)](https://tensorflow.org/)
[![License](https://img.shields.io/badge/License-MIT-green.svg)](LICENSE)

A comprehensive implementation of **Generative Adversarial Networks (GANs)** for generating handwritten digits using the MNIST dataset. This project demonstrates the power of adversarial training where two neural networks compete to create realistic digit images.

## 🚀 Project Overview

This notebook implements a GAN that learns to generate new handwritten digits that are visually similar to the MNIST training dataset. The model consists of:
- **Generator Network**: Creates fake digit images from random noise
- **Discriminator Network**: Distinguishes between real and generated images  
- **Adversarial Training**: Both networks improve through competition

## 🎯 Key Results
- Successfully generates realistic MNIST-style digits
- Progressive improvement visible across training epochs
- Automatic output organization in structured folders

---


# 🎨 Generative Adversarial Network (GAN) for MNIST Digit Generation

Welcome to this comprehensive implementation of a **Generative Adversarial Network (GAN)** that learns to generate handwritten digits similar to those in the MNIST dataset!

## 📚 Import Libraries and Dataset

This section imports all the necessary libraries and loads the MNIST dataset for training our GAN model. We'll be using:
- **TensorFlow/Keras** for deep learning operations
- **NumPy** for numerical computations  
- **Matplotlib** for visualization
- **MNIST dataset**: the 60,000-image training split of 28×28 grayscale handwritten digits

```python
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers import BatchNormalization, LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam

import ssl

# Workaround for SSL certificate errors (CERTIFICATE_VERIFY_FAILED) during the MNIST download.
# This disables HTTPS certificate verification for the whole Python process, so prefer
# fixing the local certificate store where possible.
ssl._create_default_https_context = ssl._create_unverified_context

# Only the training images are needed; labels and the test split are discarded.
(X_train, _), (_, _) = mnist.load_data()
```


## 🏗️ Build Generator Network

The **Generator** is responsible for creating fake images from random noise. It takes a 100-dimensional noise vector as input and transforms it into a 28×28 grayscale image that resembles MNIST digits.

**Architecture:**
- **Input:** 100-dimensional random noise vector
- **Hidden Layers:** Three Dense layers (256, 512, and 1024 units), each followed by LeakyReLU (α = 0.2) and Batch Normalization (momentum 0.8)
- **Output:** 784 units with tanh activation (values in [-1, 1]), reshaped to a 28×28×1 image
- **Purpose:** Learn to fool the discriminator by generating realistic-looking digits
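You can sanity-check the architecture above by working out the parameter count by hand. A `Dense` layer with `n_in` inputs and `n_out` units holds `n_in * n_out` weights plus `n_out` biases. Each `BatchNormalization` layer holds four values per feature: trainable `gamma`/`beta` plus the non-trainable moving mean and variance. The sketch below is plain Python arithmetic, and the helper names are illustrative. Assuming Keras's standard parameter bookkeeping, its totals should agree with what `generator.summary()` reports for this layer stack.

```python
# Parameter count of the generator described above (pure arithmetic, no Keras needed).

def dense_params(n_in, n_out):
    # weight matrix plus one bias per unit
    return n_in * n_out + n_out

def batchnorm_params(n_features):
    # gamma and beta (trainable) + moving mean and moving variance (non-trainable)
    return 4 * n_features

def generator_param_counts(latent_dim=100, hidden=(256, 512, 1024), out_pixels=28 * 28):
    total = 0
    non_trainable = 0
    n_in = latent_dim
    for units in hidden:
        total += dense_params(n_in, units)
        total += batchnorm_params(units)
        non_trainable += 2 * units  # moving statistics are not updated by the optimizer
        n_in = units
    total += dense_params(n_in, out_pixels)  # final tanh layer: 784 = 28 * 28 outputs
    return total, total - non_trainable, non_trainable

total, trainable, frozen = generator_param_counts()
print(f"total={total:,} trainable={trainable:,} non-trainable={frozen:,}")
# total=1,493,520 trainable=1,489,936 non-trainable=3,584
```

More than half of these parameters sit in the final 1024 → 784 layer, which is typical of fully connected generators.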



```python
def build_generator():
    """Map a 100-dim noise vector to a 28x28x1 image with pixel values in [-1, 1]."""
    model = Sequential()
    model.add(Dense(256, input_dim=100))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(784, activation='tanh'))  # 784 = 28 * 28 pixels
    model.add(Reshape((28, 28, 1)))
    return model

generator = build_generator()
```

## 🕵️ Build Discriminator Network

The **Discriminator** acts as a binary classifier that distinguishes between real MNIST images and fake images generated by the generator. It's trained to become better at detecting fake images while the generator tries to fool it.

**Architecture:**
- **Input:** 28×28×1 grayscale images (flattened to 784 features)
- **Hidden Layers:** Two Dense layers (512 and 256 units), each followed by LeakyReLU (α = 0.2); unlike the generator, no Batch Normalization
- **Output:** Single neuron with sigmoid activation (0 = fake, 1 = real)
- **Purpose:** Learn to differentiate between real and generated images
- **Optimizer:** Adam with learning rate 0.0002 and beta1 = 0.5
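The sketch below makes the data flow concrete. It is a NumPy-only forward pass with the same layer sizes: flatten to 784, Dense 512, Dense 256, then one sigmoid unit. The weights are random placeholders rather than trained values (Keras initializes `Dense` layers differently), and the function names are hypothetical. The point is only that a batch of 28×28×1 images becomes one probability per image.

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    # same negative slope as LeakyReLU(alpha=0.2) in the notebook
    return np.where(x > 0, x, alpha * x)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def init_discriminator(rng, sizes=(784, 512, 256, 1)):
    # random placeholder weights (small Gaussian) and zero biases, one pair per Dense layer
    return [(rng.normal(0.0, 0.02, (n_in, n_out)), np.zeros(n_out))
            for n_in, n_out in zip(sizes[:-1], sizes[1:])]

def discriminator_forward(params, images):
    x = images.reshape(images.shape[0], -1)  # Flatten: (batch, 28, 28, 1) -> (batch, 784)
    for w, b in params[:-1]:
        x = leaky_relu(x @ w + b)            # hidden Dense + LeakyReLU
    w, b = params[-1]
    return sigmoid(x @ w + b)                # (batch, 1): estimated probability "real"

rng = np.random.default_rng(0)
params = init_discriminator(rng)
batch = rng.uniform(-1.0, 1.0, (64, 28, 28, 1))  # same [-1, 1] range as the training data
probs = discriminator_forward(params, batch)
print(probs.shape)  # (64, 1)
```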

```python
def build_discriminator():
    """Classify a 28x28x1 image as real (output near 1) or generated (output near 0)."""
    model = Sequential()
    model.add(Flatten(input_shape=(28, 28, 1)))  # 28 * 28 * 1 -> 784 features
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation='sigmoid'))
    return model

discriminator = build_discriminator()
# Adam(learning_rate=0.0002, beta_1=0.5)
discriminator.compile(optimizer=Adam(0.0002, 0.5), loss='binary_crossentropy', metrics=['accuracy'])
```



## 🎯 Model Training & GAN Implementation

This section implements the core GAN training loop where the **Generator** and **Discriminator** compete against each other in a minimax game. The training alternates between:

### 🔄 Training Process:
1. **Train Discriminator:** 
   - Feed real MNIST images (label = 1)
   - Feed generator's fake images (label = 0)
   - Update discriminator to better distinguish real vs fake

2. **Train Generator:** 
   - Generate fake images and try to fool discriminator
   - Train generator to make discriminator classify fakes as real (label = 1)
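Both steps minimize binary cross-entropy, $\ell(y, p) = -[y \log p + (1 - y)\log(1 - p)]$, which is what `loss = 'binary_crossentropy'` selects. Training the discriminator on real images (label 1) and fakes (label 0) minimizes $-\log D(x) - \log(1 - D(G(z)))$. That is the negative of the discriminator's side of the minimax objective:

$$\min_G \max_D \; \mathbb{E}_{x}[\log D(x)] + \mathbb{E}_{z}[\log(1 - D(G(z)))]$$

Labelling fakes as real in the generator step means the generator minimizes $-\log D(G(z))$ rather than $\log(1 - D(G(z)))$. This is the commonly used "non-saturating" generator loss. The plain-Python sketch below evaluates these losses for single predicted probabilities. The function names are hypothetical, and Keras's own implementation also clips probabilities by a small epsilon, which is omitted here.

```python
import math

def bce(label, prob):
    # binary cross-entropy for a single prediction (no epsilon clipping)
    return -(label * math.log(prob) + (1 - label) * math.log(1 - prob))

def discriminator_loss(d_real, d_fake):
    # real images labelled 1, generated images labelled 0, averaged like d_loss in train_gan
    return 0.5 * (bce(1, d_real) + bce(0, d_fake))

def generator_loss(d_fake):
    # fakes labelled 1: minimizing -log D(G(z)) pushes D(G(z)) toward 1
    return bce(1, d_fake)

# A discriminator that outputs 0.5 for everything (it cannot tell real from fake)
print(round(discriminator_loss(0.5, 0.5), 4))    # 0.6931 (= ln 2)
print(round(generator_loss(0.5), 4))             # 0.6931

# A confident, correct discriminator: low D loss, high G loss
print(round(discriminator_loss(0.99, 0.01), 4))  # 0.0101
print(round(generator_loss(0.01), 4))            # 4.6052
```

The training log below shows the same pattern. At iteration 100 the discriminator is winning (D loss ≈ 0.007, G loss ≈ 4.8). Both losses then drift back toward the ≈ 0.69 region.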

### 📊 Key Features:
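- **Epochs:** 10,000 training iterations. Each "epoch" in the code is one step on a single random batch, not a full pass over the data. In total, 10,000 × 64 = 640,000 images are seen, roughly 10.7 passes over the 60,000 training images
- **Batch Size:** 64 images per batch  
- **Progress Tracking:** Loss and accuracy printed every 100 epochs
- **Image Saving:** A 5×5 grid of generated samples saved to `Output/gan_images_<epoch>.png` every 100 epochs
- **Data Normalization:** Pixels scaled from [0, 255] to [-1, 1] to match the generator's tanh output range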
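The scaling in `train_gan` maps uint8 pixels from [0, 255] to [-1, 1], matching the range of the generator's `tanh` output. `save_images` undoes it with `0.5 * x + 0.5`, which lands in [0, 1] for plotting. Below is a minimal NumPy check of that round trip. The helper names are illustrative; the formulas are the ones used in the training code.

```python
import numpy as np

def to_model_range(pixels):
    # same formula as train_gan: [0, 255] -> [-1, 1]
    return (pixels.astype(np.float32) - 127.5) / 127.5

def to_display_range(images):
    # same formula as save_images: [-1, 1] -> [0, 1]
    return 0.5 * images + 0.5

pixels = np.array([0, 128, 255], dtype=np.uint8)
scaled = to_model_range(pixels)
print(scaled)                    # approximately [-1, 0.0039, 1]
print(to_display_range(scaled))  # approximately [0, 0.502, 1], i.e. pixels / 255
```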

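```python
import os

# Freeze the discriminator inside the combined model. With tf.keras (Keras 2), `trainable`
# is read when a model is compiled: the discriminator compiled above still updates when
# trained directly, while the `gan` model compiled below treats its weights as frozen.
discriminator.trainable = False

# Combined model: noise -> generator -> discriminator -> "is it real?" probability
gan_input = Input(shape=(100,))
generated_image = generator(gan_input)
gan_output = discriminator(generated_image)

gan = Model(gan_input, gan_output)
gan.compile(optimizer=Adam(0.0002, 0.5), loss='binary_crossentropy')


def train_gan(epochs, batch_size=128):
    # Load the MNIST training images and scale pixels from [0, 255] to [-1, 1] (tanh range)
    (X_train, _), (_, _) = mnist.load_data()
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(X_train, axis=3)  # (60000, 28, 28) -> (60000, 28, 28, 1)

    real = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    # Each "epoch" here is a single training step on one random batch
    for epoch in range(epochs):
        # Train the discriminator on a real batch and a generated batch
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_images = X_train[idx]

        noise = np.random.normal(0, 1, (batch_size, 100))
        generated_images = generator.predict(noise, verbose=0)

        d_loss_real = discriminator.train_on_batch(real_images, real)
        d_loss_fake = discriminator.train_on_batch(generated_images, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)  # [loss, accuracy]

        # Train the generator: label fakes as real so it learns to fool the discriminator
        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = gan.train_on_batch(noise, real)

        if epoch % 100 == 0:
            print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}] [G loss: {g_loss}]")
            save_images(epoch)


def save_images(epoch, output_dir="Output"):
    """Save a 5x5 grid of generated digits to <output_dir>/gan_images_<epoch>.png."""
    os.makedirs(output_dir, exist_ok=True)  # create the Output folder if it doesn't exist

    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, 100))
    generated_images = generator.predict(noise, verbose=0)

    # Rescale from [-1, 1] back to [0, 1] for display
    generated_images = 0.5 * generated_images + 0.5

    fig, axs = plt.subplots(r, c)
    count = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(generated_images[count, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            count += 1

    fig.savefig(os.path.join(output_dir, f"gan_images_{epoch}.png"))
    plt.close(fig)  # release the figure so memory does not grow over many saves


train_gan(epochs=10000, batch_size=64)
```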



0 [D loss: 0.6017122268676758, acc.: 65.625] [G loss: 0.6385312676429749]
100 [D loss: 0.0074606218840926886, acc.: 100.0] [G loss: 4.790931701660156]
200 [D loss: 0.1798715591430664, acc.: 93.75] [G loss: 3.718496799468994]
300 [D loss: 0.4191024899482727, acc.: 77.34375] [G loss: 3.720974922180176]
400 [D loss: 0.731965959072113, acc.: 42.96875] [G loss: 0.7559784650802612]
500 [D loss: 0.641836404800415, acc.: 51.5625] [G loss: 0.7079087495803833]
600 [D loss: 0.6433212757110596, acc.: 55.46875] [G loss: 0.7660417556762695]
700 [D loss: 0.6234583556652069, acc.: 66.40625] [G loss: 0.749906063079834]
800 [D loss: 0.6160970330238342, acc.: 68.75] [G loss: 0.7661515474319458]
900 [D loss: 0.5807276666164398, acc.: 75.78125] [G loss: 0.8519521355628967]
1000 [D loss: 0.5694043040275574, acc.: 71.09375] [G loss: 0.8829123973846436]
1100 [D loss: 0.6146611273288727, acc.: 57.8125] [G loss: 0.9057543873786926]
1200 [D loss: 0.5757217109203339, acc.: 75.78125] [G loss: 0.9227821230888367]
1