# Vanilla Generative Adversarial Networks

In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as T

In [2]:
# device
# Fail fast with a clear message when no GPU is visible: the rest of the
# notebook places the models and tensors on torch.device("cuda").
# An explicit raise (unlike `assert`) is not stripped when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError(
        "This notebook requires a CUDA-capable GPU, "
        "but torch.cuda.is_available() returned False."
    )

In [3]:
# Training Parameters
LR = 2e-4           # learning rate passed to both Adam optimizers
NUM_EPOCHS = 50     # number of full passes over the dataset
BATCH_SIZE = 64     # images per optimization step
LATENT_SIZE = 128   # width of the latent vector fed to the generator
IMG_SIZE = 28 * 28  # a 28x28 image flattened into one vector (784 values)
DEVICE = torch.device("cuda")  # everything is trained on the GPU

In [4]:
class Generator(nn.Module):
    """Fully connected generator mapping latent vectors to flattened images.

    Architecture: Linear -> LeakyReLU(0.1) -> Dropout(0.5) -> Linear -> Tanh.
    The final Tanh bounds every output value to [-1, 1].

    Args:
        latent_size: Width of the latent input vector. ``None`` (default)
            falls back to the notebook-level ``LATENT_SIZE``.
        img_size: Number of values in one flattened output image. ``None``
            (default) falls back to the notebook-level ``IMG_SIZE``.
        hidden_size: Width of the single hidden layer.
    """

    def __init__(self, latent_size=None, img_size=None, hidden_size=256):
        super().__init__()
        # Resolve the notebook constants lazily (instead of using them as
        # default argument values) so the class can also be built with
        # explicit sizes when those constants are not defined.
        if latent_size is None:
            latent_size = LATENT_SIZE
        if img_size is None:
            img_size = IMG_SIZE
        self.generator = nn.Sequential(
            nn.Linear(latent_size, hidden_size),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.5),
            nn.Linear(hidden_size, img_size),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map a batch of latent vectors ``x`` to a batch of flattened images."""
        return self.generator(x)

In [5]:
# check generator
# Smoke test: push one batch of random latent vectors through an untrained
# generator and display the size of the result.
gen = Generator()
dummy_input = torch.randn((BATCH_SIZE, LATENT_SIZE))
gen(dummy_input).size()

torch.Size([64, 784])

In [6]:
class Discriminator(nn.Module):
    """Fully connected discriminator scoring flattened images.

    Architecture: Linear -> LeakyReLU(0.1) -> Dropout(0.5) -> Linear.
    There is no final sigmoid, so the output is one raw logit per image.

    Args:
        img_size: Number of values in one flattened input image. ``None``
            (default) falls back to the notebook-level ``IMG_SIZE``.
        hidden_size: Width of the single hidden layer.
    """

    def __init__(self, img_size=None, hidden_size=256):
        super().__init__()
        # Resolve the notebook constant lazily (instead of using it as a
        # default argument value) so the class can also be built with an
        # explicit size when the constant is not defined.
        if img_size is None:
            img_size = IMG_SIZE
        self.discriminator = nn.Sequential(
            nn.Linear(img_size, hidden_size),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.5),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, x):
        """Return one logit per flattened image in the batch ``x``."""
        return self.discriminator(x)

In [7]:
# check discriminator
# Smoke test: score one batch of random "images" with an untrained
# discriminator and display the size of the result.
dis = Discriminator()
dummy_input = torch.randn((BATCH_SIZE, IMG_SIZE))
dis(dummy_input).size()

torch.Size([64, 1])

In [8]:
# dataset and dataloader
# Normalizing with mean 0.5 and std 0.5 computes (x - 0.5) / 0.5 per pixel.
transform = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
dataset = datasets.MNIST(
    root='../datasets/',
    train=True,
    transform=transform,
    download=True,
)
# drop_last=True discards a trailing incomplete batch, so every batch
# contains exactly BATCH_SIZE images.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=2,
)

In [9]:
# The two competing networks, both placed on the training device.
generator = Generator().to(DEVICE)
discriminator = Discriminator().to(DEVICE)

# One Adam optimizer per network, sharing the same learning rate.
gen_optim = optim.Adam(params=generator.parameters(), lr=LR)
dis_optim = optim.Adam(params=discriminator.parameters(), lr=LR)

# The discriminator outputs raw logits, so the sigmoid is applied inside the loss.
criterion = nn.BCEWithLogitsLoss()

# TensorBoard writer for the generated image grids.
writer = SummaryWriter(log_dir='runs/gan_mnist')

Let us look at the value function again.

$ \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(\mathbf{x})] + \mathbb{E}_{ z \sim p_{z}(z)}[\log(1 - D(G(\mathbf{z})))]$
  
We will transform the above expression into a form that is easily digestible for deep learning frameworks, which means expressing it in terms of gradient descent and binary cross-entropy.

Binary cross-entropy is defined as follows.

$ L_n = - [ y_n \cdot \log D(x_n) + (1 - y_n) \cdot \log (1 - D(G(z_n)))]$

If the discriminator faces a true image, the loss will collapse to $[-y_n \cdot \log D(x_n)]$. Using this expression in gradient descent is the same as using gradient ascent for $[y_n \cdot \log D(x_n)]$.

If the discriminator faces a fake image, the loss will collapse to $[-(1 - y_n) \cdot \log (1 - D(G(z_n)))]$. Using this expression in gradient descent is the same as using gradient ascent for $[(1 - y_n) \cdot \log (1 - D(G(z_n)))]$.

The gradient descent and cross-entropy calculation for the generator is slightly more tricky. When the discriminator assigns a high probability to a fake image, the expression $\log (1 - D(G(z_n)))$ approaches minus infinity, so in principle we want to use gradient descent on $[(1 - y_n) \cdot \log (1 - D(G(z_n)))]$ in order to trick the discriminator more often. In practice we flip the labels of the fake images (turn the 0 label into a 1 label) and minimize $-y_n \log D(G(z_n))$. This trick makes sure that the gradient is large at the beginning of training, when the generator does not yet produce convincing results.

In [10]:
# A single batch of latent vectors that is sampled once and then reused.
# Generating images from the same latent vectors at different points in
# training makes it possible to observe whether the generator improves.
fixed_noise = torch.randn(BATCH_SIZE, LATENT_SIZE, device=DEVICE)

In [11]:
# Alternating GAN training: each batch performs one discriminator step
# followed by one generator step. After every epoch the average losses are
# printed and a grid of images generated from `fixed_noise` is logged.
for epoch in range(NUM_EPOCHS):
    # per-batch losses, averaged into one number per epoch for the log line
    dis_loss_col = []
    gen_loss_col = []
    for features, _ in dataloader:
        # flatten every image into a vector of IMG_SIZE values
        real_images = features.view(-1, IMG_SIZE).to(DEVICE)
        # Size everything from the actual batch instead of BATCH_SIZE so the
        # loop keeps working if the dataloader yields a smaller last batch.
        batch_size = real_images.size(0)

        # generate fake images from a standard normal distributed latent vector
        latent_vector = torch.randn(batch_size, LATENT_SIZE, device=DEVICE)
        fake_imgs = generator(latent_vector)

        # calculate logits for true and fake images
        # (detach: the discriminator step must not backpropagate into the generator)
        fake_logits = discriminator(fake_imgs.detach())
        real_logits = discriminator(real_images)

        # calculate discriminator loss: real images -> label 1, fake images -> label 0
        dis_real_loss = criterion(real_logits, torch.ones_like(real_logits))
        dis_fake_loss = criterion(fake_logits, torch.zeros_like(fake_logits))
        dis_loss = dis_real_loss + dis_fake_loss

        # optimize the discriminator
        dis_optim.zero_grad()
        dis_loss.backward()
        dis_optim.step()

        # calculate generator loss with flipped labels: the generator is
        # rewarded when the updated discriminator scores its fakes as real
        gen_logits = discriminator(fake_imgs)
        gen_loss = criterion(gen_logits, torch.ones_like(gen_logits))

        # optimize the generator (only generator parameters are stepped here;
        # gradients left on the discriminator are cleared by its next zero_grad)
        gen_optim.zero_grad()
        gen_loss.backward()
        gen_optim.step()

        # .item() already returns a Python float, no explicit .cpu() needed
        dis_loss_col.append(dis_loss.item())
        gen_loss_col.append(gen_loss.item())

    dis_loss = sum(dis_loss_col) / len(dis_loss_col)
    gen_loss = sum(gen_loss_col) / len(gen_loss_col)

    print(f'Epoch {epoch+1}/{NUM_EPOCHS} || Discriminator Loss: {dis_loss:.4f} || Generator Loss: {gen_loss:.4f}')

    # Switch dropout off while rendering the preview: in train mode the
    # generator's Dropout layer would add fresh random noise to every grid,
    # so images from the same fixed latent vectors would not be comparable
    # across epochs.
    generator.eval()
    try:
        with torch.inference_mode():
            fake_imgs = generator(fixed_noise).view(-1, 1, 28, 28)
            grid = torchvision.utils.make_grid(fake_imgs, normalize=True)

            writer.add_image(
                "MNIST GAN Generated Images", grid, global_step=epoch+1
            )
    finally:
        # always return to train mode so the next epoch trains with dropout
        generator.train()

# make sure all logged images are written to the event file
writer.flush()


Epoch 1/50 || Discriminator Loss: 0.9665 || Generator Loss: 0.9535
Epoch 2/50 || Discriminator Loss: 0.9857 || Generator Loss: 1.0272
Epoch 3/50 || Discriminator Loss: 0.9247 || Generator Loss: 1.1063
Epoch 4/50 || Discriminator Loss: 1.0139 || Generator Loss: 1.0841
Epoch 5/50 || Discriminator Loss: 0.8948 || Generator Loss: 1.2844
Epoch 6/50 || Discriminator Loss: 0.9805 || Generator Loss: 1.2402
Epoch 7/50 || Discriminator Loss: 0.9157 || Generator Loss: 1.3692
Epoch 8/50 || Discriminator Loss: 0.8919 || Generator Loss: 1.4542
Epoch 9/50 || Discriminator Loss: 0.8678 || Generator Loss: 1.5171
Epoch 10/50 || Discriminator Loss: 0.9165 || Generator Loss: 1.4911
Epoch 11/50 || Discriminator Loss: 0.9299 || Generator Loss: 1.4868
Epoch 12/50 || Discriminator Loss: 0.9716 || Generator Loss: 1.4354
Epoch 13/50 || Discriminator Loss: 1.0068 || Generator Loss: 1.3822
Epoch 14/50 || Discriminator Loss: 1.0173 || Generator Loss: 1.3718
Epoch 15/50 || Discriminator Loss: 1.0230 || Generator Lo