#### Program 10:
##### Objective:
  Implement a Generative Adversarial Network (GAN) on the MNIST dataset using the PyTorch framework.
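For reference, a GAN trains a generator $G$ against a discriminator $D$ using the standard minimax objective (Goodfellow et al., 2014). In this program $z$ is a 100-dimensional standard normal noise vector (`torch.randn(batch_size, 100)`). The losses computed with `nn.BCELoss` below correspond to the following batch-averaged expressions. The discriminator loss averages its real and fake terms. The generator uses the common "non-saturating" variant, which labels fakes as real, instead of minimizing $\log(1 - D(G(z)))$ directly:

$$
\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim \mathcal{N}(0, I)}\big[\log\big(1 - D(G(z))\big)\big]
$$

$$
\mathcal{L}_D = -\tfrac{1}{2}\Big(\log D(x) + \log\big(1 - D(G(z))\big)\Big), \qquad \mathcal{L}_G = -\log D(G(z))
$$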

Tasks:
  - Define a GAN architecture.
  - Preprocess the MNIST dataset.
  - Define the model train function.
  - Train the model using suitable criterion and optimizer.
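For the preprocessing task, there is a range mismatch to consider. The generator's final `nn.Tanh()` produces values in $[-1, 1]$, but `transforms.ToTensor()` scales MNIST pixels from 0..255 to $[0, 1]$. A common fix is to append `transforms.Normalize((0.5,), (0.5,))`, which computes $(x - 0.5)/0.5$ and so places real images in the same range as generated ones. The sketch below only reproduces that arithmetic in plain Python. The helper names `to_tensor_scale` and `normalize` are hypothetical stand-ins, not torchvision functions:

```python
def to_tensor_scale(pixel: int) -> float:
    """Hypothetical stand-in for ToTensor's scaling of a uint8 pixel (0..255) to [0, 1]."""
    return pixel / 255.0


def normalize(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Hypothetical stand-in for Normalize((mean,), (std,)): computes (x - mean) / std."""
    return (x - mean) / std


# Black, mid-grey and white pixels end up at roughly -1, 0 and 1, matching Tanh's range
for p in (0, 128, 255):
    x = to_tensor_scale(p)
    print(p, round(x, 4), round(normalize(x), 4))
```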


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

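# ToTensor scales pixels to [0, 1]; Normalize maps them to [-1, 1] to match the generator's Tanh output
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Use only the first 1,000 training images to keep training fast
subset = Subset(dataset, range(1000))
dataloader = DataLoader(subset, batch_size=10, shuffle=True)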

# Generator: maps a 100-dim noise vector to a flattened 28x28 image in [-1, 1]
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 28*28),
            nn.Tanh()
        )

    def forward(self, x):
        return self.gen(x)

# Discriminator: maps a flattened 28x28 image to the probability that it is real
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(28*28, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.disc(x)

generator = Generator()
discriminator = Discriminator()
criterion = nn.BCELoss()
optim_gen = optim.Adam(generator.parameters(), lr=2e-4)
optim_disc = optim.Adam(discriminator.parameters(), lr=2e-4)

def train(num_epochs):
    for epoch in range(num_epochs):
        generator.train()
        discriminator.train()
        for real, _ in dataloader:
            real = real.view(-1, 28*28)
            batch_size = real.size(0)

            # Train Discriminator: real images -> label 1, generated images -> label 0
            noise = torch.randn(batch_size, 100)
            fake = generator(noise)
            disc_real = discriminator(real)
            loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
            # detach() so the discriminator loss does not backpropagate into the generator
            disc_fake = discriminator(fake.detach())
            loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            loss_disc = (loss_disc_real + loss_disc_fake) / 2

            # Backprop
            optim_disc.zero_grad()
            loss_disc.backward()
            optim_disc.step()

            # Train Generator: push the discriminator to label fakes as real (non-saturating loss)
            noise = torch.randn(batch_size, 100)
            fake = generator(noise)
            disc_fake = discriminator(fake)
            loss_gen = criterion(disc_fake, torch.ones_like(disc_fake))

            # Backprop
            optim_gen.zero_grad()
            loss_gen.backward()
            optim_gen.step()

        # Reported losses are from the last mini-batch of the epoch, not an epoch average
        print(f'Epoch {epoch+1}, Loss D: {loss_disc.item():.4f}, Loss G: {loss_gen.item():.4f}')

train(15)


Epoch 1, Loss D: 0.5309, Loss G: 1.1956
Epoch 2, Loss D: 0.2890, Loss G: 2.1580
Epoch 3, Loss D: 0.4902, Loss G: 1.6557
Epoch 4, Loss D: 0.4337, Loss G: 1.2396
Epoch 5, Loss D: 0.4178, Loss G: 1.6804
Epoch 6, Loss D: 0.4102, Loss G: 1.5597
Epoch 7, Loss D: 0.3869, Loss G: 1.3843
Epoch 8, Loss D: 0.4317, Loss G: 1.5284
Epoch 9, Loss D: 0.3976, Loss G: 1.3018
Epoch 10, Loss D: 0.2943, Loss G: 1.4342
Epoch 11, Loss D: 0.3589, Loss G: 1.1268
Epoch 12, Loss D: 0.4174, Loss G: 0.9415
Epoch 13, Loss D: 0.5020, Loss G: 0.7116
Epoch 14, Loss D: 0.4995, Loss G: 0.6887
Epoch 15, Loss D: 0.4828, Loss G: 0.7063
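These printed losses can be compared against reference values for binary cross-entropy. At the theoretical equilibrium, the discriminator outputs 0.5 for every input, so both losses as computed in this program equal $\ln 2 \approx 0.6931$. A very confident discriminator drives its own loss toward 0 and the generator's loss up. The plain-Python sketch below mirrors `nn.BCELoss` for single scalar predictions. The helper names are illustrative and not part of PyTorch. Because each printed value comes from one mini-batch, the logged numbers are noisy. A generator loss near $\ln 2$ does not by itself show that the samples look like digits.

```python
import math


def bce(p: float, target: float) -> float:
    """Binary cross-entropy for one prediction p in (0, 1), the per-element form nn.BCELoss uses."""
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))


def disc_loss(d_real: float, d_fake: float) -> float:
    # Mirrors the program: average of the real (target 1) and fake (target 0) terms
    return (bce(d_real, 1.0) + bce(d_fake, 0.0)) / 2


def gen_loss(d_fake: float) -> float:
    # Non-saturating generator loss: fakes are scored against target 1
    return bce(d_fake, 1.0)


# Equilibrium: D(x) = D(G(z)) = 0.5
print(round(disc_loss(0.5, 0.5), 4), round(gen_loss(0.5), 4))  # 0.6931 0.6931

# Confident discriminator: D(x) = 0.9, D(G(z)) = 0.1
print(round(disc_loss(0.9, 0.1), 4), round(gen_loss(0.1), 4))  # 0.1054 2.3026
```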
