# Vanilla GAN

Originally proposed by [Goodfellow et al.](https://arxiv.org/abs/1406.2661) in their work titled Generative Adversarial Networks, this network uses a basic implementation in which the generator and discriminator models are both MLPs.

This notebook trains both networks with the Adam optimizer to play the minimax game. We showcase the approach's effectiveness on MNIST digit generation.


## Load Libraries

In [10]:
import os
import numpy as np
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets
from torch.autograd import Variable
from torch.utils.data import DataLoader

from torchvision.utils import save_image
import torchvision.transforms as transforms

## Check GPU

In [11]:
# True when PyTorch reports a usable CUDA device, False otherwise.
CUDA = torch.cuda.is_available()

## Set Parameters

In [12]:
# Image geometry
NUM_CHANNELS, IMG_DIM = 1, 28
IMG_SHAPE = (NUM_CHANNELS, IMG_DIM, IMG_DIM)  # (channels, height, width)

# Training hyper-parameters
BATCH_SIZE = 64
Z_DIM = 256  # Noise Vector Dimension
N_EPOCHS = 200
SAMPLE_INTERVAL = 400  # a sample grid is saved whenever the batch counter is a multiple of this

## Get MNIST Dataset

In [13]:
# Folder that receives the generated sample grids
os.makedirs("images", exist_ok=True)

# Preprocessing: resize to IMG_DIM, convert to a tensor, then normalize
# with mean 0.5 and std 0.5.
mnist_transform = transforms.Compose(
    [
        transforms.Resize(IMG_DIM),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# MNIST training split, downloaded under data/mnist
mnist_train = datasets.MNIST(
    "data/mnist",
    train=True,
    download=True,
    transform=mnist_transform,
)

# Shuffled mini-batches of BATCH_SIZE images
dataloader = torch.utils.data.DataLoader(
    mnist_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

## Discriminator Model

In [15]:
class Discriminator(nn.Module):
    """MLP that scores each image in a batch with a single sigmoid output.

    Every image is flattened and passed through two hidden layers
    (512 and 256 units, LeakyReLU with slope 0.2) down to one unit.
    """

    def __init__(self):
        super().__init__()

        n_pixels = int(np.prod(IMG_SHAPE))
        layers = [
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Flatten ``img`` per sample and return the model's output for it."""
        flattened = img.view(img.size(0), -1)
        return self.model(flattened)

## Generator Model

In [14]:
class Generator(nn.Module):
    """MLP that maps a noise vector of size Z_DIM to an image of shape IMG_SHAPE.

    Four Linear -> BatchNorm1d -> LeakyReLU blocks (256, 256, 512, 1024 units)
    are followed by a Linear projection to the pixel count and a Tanh.
    """

    def __init__(self):
        super().__init__()

        def gen_block(in_feat_shape, out_feat_shape):
            """Return the layers of one repeating generator block as a list."""
            return [
                nn.Linear(in_feat_shape, out_feat_shape),
                # NOTE(review): the second positional argument of BatchNorm1d
                # is `eps`, so the original `BatchNorm1d(n, 0.8)` sets eps=0.8.
                # Possibly `momentum` was intended - confirm before changing,
                # since it alters training behaviour.
                nn.BatchNorm1d(out_feat_shape, eps=0.8),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Chain the hidden blocks, feeding each block's width into the next.
        layers = []
        in_width = Z_DIM
        for out_width in (256, 256, 512, 1024):
            layers.extend(gen_block(in_width, out_width))
            in_width = out_width

        # Project to one value per pixel and squash with Tanh.
        layers.append(nn.Linear(in_width, int(np.prod(IMG_SHAPE))))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Run ``z`` through the model and reshape each row to IMG_SHAPE."""
        flat = self.model(z)
        return flat.view(flat.size(0), *IMG_SHAPE)

## Attach Loss & Optimizers

In [16]:
# Build the two networks (generator first, then discriminator)
generator, discriminator = Generator(), Discriminator()

# Binary cross-entropy is used for both the generator and discriminator losses
adversarial_loss = torch.nn.BCELoss()

# One Adam optimizer per network, sharing the same settings
adam_settings = {"lr": 0.0002, "betas": (0.5, 0.999)}
optimizer_G = torch.optim.Adam(generator.parameters(), **adam_settings)
optimizer_D = torch.optim.Adam(discriminator.parameters(), **adam_settings)

In [17]:
# When CUDA is available, move both networks and the loss onto the GPU.
if CUDA:
    for module in (generator, discriminator, adversarial_loss):
        module.cuda()

# Tensor constructor used by the training loop; CUDA variant when CUDA is set.
Tensor = torch.cuda.FloatTensor if CUDA else torch.FloatTensor

## Train GAN

In [18]:
# Alternating GAN training: for every batch the generator is updated first,
# then the discriminator is updated on the real batch and the (detached)
# generated batch. Losses are printed per batch and a 5x5 grid of generated
# samples is saved whenever the global batch counter hits SAMPLE_INTERVAL.
for epoch in range(N_EPOCHS):
    for i, (imgs, _) in enumerate(dataloader):
        batch_size = imgs.size(0)

        # Target labels: 1.0 for "real", 0.0 for "generated".
        # (No autograd.Variable wrapper: tensors do not require grad by default.)
        valid = Tensor(batch_size, 1).fill_(1.0)
        fake = Tensor(batch_size, 1).fill_(0.0)

        # Cast the real images to the tensor type used by the models
        real_imgs = imgs.type(Tensor)

        #  Train Generator
        optimizer_G.zero_grad()

        # Sample noise vector z for generator
        z = Tensor(np.random.normal(0, 1, (batch_size, Z_DIM)))

        # get generator output
        gen_imgs = generator(z)

        # Generator loss: the discriminator's output on generated images is
        # compared against the "valid" labels.
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        #  Train Discriminator
        # zero_grad also discards the gradients that g_loss.backward() left
        # on the discriminator's parameters.
        optimizer_D.zero_grad()

        # Discriminator loss over real and generated samples; detach() keeps
        # this backward pass from reaching the generator.
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        # Update Discriminator
        d_loss.backward()
        optimizer_D.step()
        print(f'Epoch: {epoch}/{N_EPOCHS}-Batch: {i}/{len(dataloader)}--D.loss:{d_loss.item():.4f},G.loss:{g_loss.item():.4f}')

        # Periodically save the first 25 generated images as a 5x5 grid
        batches_done = epoch * len(dataloader) + i
        if batches_done % SAMPLE_INTERVAL == 0:
            save_image(gen_imgs.detach()[:25], f"images/{batches_done}.png", nrow=5, normalize=True)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch: 194/200-Batch\h: 628/938--D.loss:0.3377,G.loss:1.7905
Epoch: 194/200-Batch\h: 629/938--D.loss:0.3871,G.loss:1.4324
Epoch: 194/200-Batch\h: 630/938--D.loss:0.3401,G.loss:1.6544
Epoch: 194/200-Batch\h: 631/938--D.loss:0.3443,G.loss:1.7614
Epoch: 194/200-Batch\h: 632/938--D.loss:0.3648,G.loss:1.8349
Epoch: 194/200-Batch\h: 633/938--D.loss:0.3647,G.loss:1.7287
Epoch: 194/200-Batch\h: 634/938--D.loss:0.3074,G.loss:1.5991
Epoch: 194/200-Batch\h: 635/938--D.loss:0.2957,G.loss:1.8384
Epoch: 194/200-Batch\h: 636/938--D.loss:0.3797,G.loss:2.0893
Epoch: 194/200-Batch\h: 637/938--D.loss:0.3577,G.loss:1.5704
Epoch: 194/200-Batch\h: 638/938--D.loss:0.4256,G.loss:1.4087
Epoch: 194/200-Batch\h: 639/938--D.loss:0.3773,G.loss:2.2183
Epoch: 194/200-Batch\h: 640/938--D.loss:0.3881,G.loss:1.6336
Epoch: 194/200-Batch\h: 641/938--D.loss:0.3466,G.loss:1.3488
Epoch: 194/200-Batch\h: 642/938--D.loss:0.3101,G.loss:1.4910
Epoch: 194/200-Batch