## Chapter 11, Example 1 (GAN on MNIST)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
import os
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
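
As a quick check of the preprocessing above, the arithmetic behind `Normalize((0.5,), (0.5,))` and the loader's batch count can be worked through in plain Python. This is only a sketch: `normalize` is a hypothetical helper that mirrors the per-pixel formula `(x - mean) / std`, and the figure of 60,000 images assumes the standard MNIST training split.

```python
import math

def normalize(x, mean=0.5, std=0.5):
    """Hypothetical stand-in for the per-pixel formula used by transforms.Normalize."""
    return (x - mean) / std

# ToTensor() scales pixels to [0, 1]; Normalize then maps that range to [-1, 1].
print(normalize(0.0), normalize(0.5), normalize(1.0))  # -1.0 0.0 1.0

# Assumption: the standard MNIST training split holds 60,000 images.
num_train_images = 60_000
batch_size = 64
steps_per_epoch = math.ceil(num_train_images / batch_size)  # the last batch is smaller
last_batch_size = num_train_images - (steps_per_epoch - 1) * batch_size
print(steps_per_epoch, last_batch_size)  # 938 32
```

The smaller final batch is why the training loop reads `batch_size` from each batch rather than assuming 64.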

## Generator and Discriminator Classes for GAN

### Generator
The `Generator` class is a neural network for generating synthetic data that mimics some distribution of real data. It takes a random noise vector as input and outputs data that should be indistinguishable from real data to the `Discriminator`.

- `noise_dim`: The dimension of the random noise vector.
- The network consists of four linear layers. The first three are followed by ReLU activations and the last by a Tanh activation. Tanh bounds the output to the range -1 to 1, which matches the real images after the `Normalize((0.5,), (0.5,))` transform applied in the data pipeline.
- The output is reshaped into the size of an MNIST image (1 x 28 x 28).

### Discriminator
The `Discriminator` class is a neural network that classifies data as real or synthetic. It takes an image as input and outputs a probability that the image is real.

- The input images are first flattened.
- The network consists of four linear layers with ReLU activations, leading to a single output through a Sigmoid activation function. The Sigmoid is used to output a probability between 0 and 1, indicating how likely it is that the input data is from the real dataset as opposed to being generated by the `Generator`.

### Forward Methods
Both classes have a `forward` method that defines how the data flows through the network:
- For the `Generator`, the `forward` method takes a noise vector `z`, passes it through the network, and reshapes the output to the size of an MNIST image.
- For the `Discriminator`, the `forward` method takes an image `x`, flattens it, and passes it through the network to obtain the probability of the image being real.

These classes are essential components of a Generative Adversarial Network (GAN), where the `Generator` learns to produce more realistic images over time, and the `Discriminator` learns to better distinguish between real and generated images.


In [None]:
# Generator
class Generator(nn.Module):
    def __init__(self, noise_dim=100):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 28*28),
            nn.Tanh()  # Tanh keeps outputs in [-1, 1], matching the normalized MNIST images
        )

    def forward(self, z):
        return self.net(z).view(-1, 1, 28, 28)

# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(28*28, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 512),
            nn.ReLU(True),
            nn.Linear(512, 256),
            nn.ReLU(True),
            nn.Linear(256, 1),
            nn.Sigmoid()  # Sigmoid activation because we need probability outputs
        )

    def forward(self, x):
        x = x.view(-1, 28*28)
        return self.net(x)
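
To get a feel for the size of the two networks defined above, their parameter counts can be tallied by hand. The sketch below uses hypothetical helpers (`linear_params`, `mlp_params`) and assumes each `nn.Linear` layer keeps its default bias term; the activation layers contribute no parameters.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features

def mlp_params(sizes):
    """Total parameters of a stack of linear layers with the given widths."""
    return sum(linear_params(a, b) for a, b in zip(sizes, sizes[1:]))

generator_sizes = [100, 256, 512, 1024, 28 * 28]   # noise_dim defaults to 100
discriminator_sizes = [28 * 28, 1024, 512, 256, 1]

print(mlp_params(generator_sizes))      # 1486352
print(mlp_params(discriminator_sizes))  # 1460225
```

Under these assumptions the two networks are of nearly equal capacity, at roughly 1.5 million parameters each.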

In [None]:
# Initialize Generator and Discriminator
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Loss and optimizer
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002)

# Noise sampler
def sample_noise(batch_size, noise_dim):
    return torch.rand(batch_size, noise_dim, device=device)*2 - 1  # Uniform noise in [-1, 1)
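
The binary cross-entropy used as `criterion` can be evaluated by hand for a few probabilities. The function `bce` below is a hypothetical pure-Python sketch of the mean-reduced formula, which is the default reduction of `nn.BCELoss`; it omits the clamping of the log terms that PyTorch applies for numerical stability.

```python
import math

def bce(predictions, targets):
    """Mean binary cross-entropy over a batch of probabilities in (0, 1)."""
    total = 0.0
    for p, y in zip(predictions, targets):
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(predictions)

# A discriminator that always answers 0.5 scores log(2), whatever the labels.
print(round(bce([0.5, 0.5], [1, 0]), 4))  # 0.6931

# Confident and correct predictions give a small loss ...
print(round(bce([0.9, 0.1], [1, 0]), 4))  # 0.1054
# ... while confident and wrong predictions give a large one.
print(round(bce([0.1, 0.9], [1, 0]), 4))  # 2.3026
```

The value log(2) ≈ 0.693 is a useful reference point when reading the training log: a per-term loss near it suggests the discriminator is close to guessing.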

## Training Loop for GAN

The training process involves alternating between updating the `Discriminator` and the `Generator`. Here's a breakdown of the training loop:

### Parameters:
- `num_epochs`: The number of times to iterate over the entire dataset.
- `noise_dim`: The dimensionality of the noise vector fed into the `Generator`.

### Training Loop:
For each epoch in `num_epochs`:
1. Iterate over the `train_loader`, which provides batches of real images and their associated labels (which are not used here).
2. For each batch:
   - The `batch_size` is determined based on the number of images in the current batch.
   - `real_labels` is a tensor of ones and `fake_labels` is a tensor of zeros; they serve as the target labels for real and fake data, respectively.

### Training the Discriminator:
- Zero the gradients of the `Discriminator` before the forward pass.
- Compute the `Discriminator`'s loss on real images (`loss_d_real`) and perform a backward pass to calculate gradients.
- Generate fake images using the `Generator` with random noise `z`.
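- Compute the `Discriminator`'s loss on the fake images (`loss_d_fake`), passing them in detached (`fake_images.detach()`) so that the backward pass computes gradients for the `Discriminator` only.
- The total loss for the `Discriminator`, `loss_d`, is the sum of `loss_d_real` and `loss_d_fake`. Because each term has already been backpropagated and the gradients accumulate, the sum itself is used only for logging.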
- Update the `Discriminator`'s weights with a step of the optimizer (`optimizer_d`).

### Training the Generator:
- Zero the gradients of the `Generator` before the forward pass.
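- Pass the same fake images through the `Discriminator`, this time without detaching them, and calculate the loss (`loss_g`) against the real labels. The loss is small when the `Discriminator` classifies the fakes as real, so minimizing it pushes the `Generator` toward more convincing images.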
- Perform a backward pass to calculate gradients for the `Generator`.
- Update the `Generator`'s weights with a step of the optimizer (`optimizer_g`).

This training loop is designed to improve the `Generator`'s ability to create images that are indistinguishable from real ones, while the `Discriminator` becomes better at telling them apart. The `Discriminator` is trained first within each batch, followed by the `Generator`, ensuring that the `Generator` uses the most up-to-date version of the `Discriminator` for its updates.
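
The label assignments described above can be traced by hand for one real and one generated image. The sketch below uses made-up discriminator outputs (`d_real` and `d_fake` are hypothetical values, not taken from a run) and holds `d_fake` fixed across both steps, whereas in the actual loop the `Discriminator` is updated before the `Generator` loss is computed.

```python
import math

# Hypothetical discriminator outputs (probability of "real") for one real
# image and one generated image; these numbers are made up for illustration.
d_real = 0.8
d_fake = 0.3

# Discriminator step: real images are scored against target 1, fakes against 0.
loss_d_real = -math.log(d_real)      # BCE with target 1
loss_d_fake = -math.log(1 - d_fake)  # BCE with target 0
loss_d = loss_d_real + loss_d_fake

# Generator step: the same fake is scored against target 1, so the loss
# shrinks as the discriminator is fooled (d_fake moves toward 1).
loss_g = -math.log(d_fake)

print(round(loss_d, 4), round(loss_g, 4))  # 0.5798 1.204
```

Raising `d_fake` lowers `loss_g` and raises `loss_d_fake`, which is the adversarial tension the two optimizers act on.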


In [1]:
# Create output directory
os.makedirs('mnist_gan_output', exist_ok=True)

# Training
num_epochs = 50
noise_dim = 100
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        batch_size = images.size(0)
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # Train Discriminator
        discriminator.zero_grad()
        outputs = discriminator(images.to(device)).view(-1, 1)
        loss_d_real = criterion(outputs, real_labels)
        loss_d_real.backward()

        z = sample_noise(batch_size, noise_dim)
        fake_images = generator(z)
        # detach() keeps this backward pass from computing gradients for the generator
        outputs = discriminator(fake_images.detach()).view(-1, 1)
        loss_d_fake = criterion(outputs, fake_labels)
        loss_d_fake.backward()

        loss_d = loss_d_real + loss_d_fake
        optimizer_d.step()

        # Train Generator
        generator.zero_grad()
        outputs = discriminator(fake_images).view(-1, 1)
        loss_g = criterion(outputs, real_labels)
        loss_g.backward()
        optimizer_g.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss D: {loss_d.item():.4f}, Loss G: {loss_g.item():.4f}')

    # Save real images
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        vutils.save_image(images, 'mnist_gan_output/real_images.png', normalize=True)

    # Save sampled images
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    vutils.save_image(fake_images, f'mnist_gan_output/fake_images_epoch_{epoch+1}.png', normalize=True)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 101804937.59it/s]


Extracting ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 103358100.53it/s]


Extracting ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 61235092.94it/s]

Extracting ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 5576852.68it/s]


Extracting ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw

Epoch [1/50], Step [100/938], Loss D: 0.4443, Loss G: 1.4049
Epoch [1/50], Step [200/938], Loss D: 0.0327, Loss G: 4.2727
Epoch [1/50], Step [300/938], Loss D: 0.0051, Loss G: 26.2855
Epoch [1/50], Step [400/938], Loss D: 0.3268, Loss G: 5.2529
Epoch [1/50], Step [500/938], Loss D: 1.2156, Loss G: 4.4732
Epoch [1/50], Step [600/938], Loss D: 2.8881, Loss G: 3.4738
Epoch [1/50], Step [700/938], Loss D: 0.3673, Loss G: 5.5105
Epoch [1/50], Step [800/938], Loss D: 0.7306, Loss G: 2.2787
Epoch [1/50], Step [900/938], Loss D: 0.5731, Loss G: 1.5423
Epoch [2/50], Step [100/938], Loss D: 0.2681, Loss G: 2.7369
Epoch [2/50], Step [200/938], Loss D: 0.3621, Loss G: 2.7035
Epoch [2/50], Step [300/938], Loss D: 0.8924, Loss G: 1.6316
Epoch [2/50], Step [400/938], Loss D: 0.3813, Loss G: 2.9162
Epoch [2/50], Step [500/938], Loss D: 0.2133, Loss G: 3.6234
Epoch [2/50], Step [600/938], Loss D: 2.0276, Loss G: 4.67