In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

### Discriminator Architecture:
###### (img_channels, 64, 64) --> (64, 32, 32)
###### (64, 32, 32) --> (128, 16, 16)
###### (128, 16, 16) --> (256, 8, 8)
###### (256, 8, 8) --> (512, 4, 4)
###### (512, 4, 4) --> (1, 1, 1)

In [None]:
class Discriminator(nn.Module):
  """Convolutional real/fake classifier for 64x64 images.

  Five strided 4x4 convolutions shrink a (N, img_channels, 64, 64) batch to
  (N, 1, 1, 1); a final Sigmoid squashes each score into [0, 1]
  (0 = fake, 1 = real).

  Args:
    img_channels: number of channels in the input images.
  """

  def __init__(self, img_channels):
    super().__init__()
    slope = 0.2  # negative slope shared by every LeakyReLU

    # Stage 1: (N, img_channels, 64, 64) -> (N, 64, 32, 32).
    # No batch norm on the first layer, as described in the paper.
    stack = [
        nn.Conv2d(img_channels, 64, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(slope),
    ]

    # Stages 2-4 each halve the spatial size and double the channel count:
    # (64, 32, 32) -> (128, 16, 16) -> (256, 8, 8) -> (512, 4, 4).
    # The conv bias is disabled because BatchNorm follows immediately.
    for fan_in, fan_out in ((64, 128), (128, 256), (256, 512)):
      stack.append(
          nn.Conv2d(fan_in, fan_out, kernel_size=4, stride=2, padding=1, bias=False)
      )
      stack.append(nn.BatchNorm2d(fan_out))
      stack.append(nn.LeakyReLU(slope))

    # Stage 5: (N, 512, 4, 4) -> (N, 1, 1, 1), then Sigmoid for a probability.
    stack.append(nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=0))
    stack.append(nn.Sigmoid())

    # Unpacked into a flat Sequential so sub-modules stay indexed 0..12.
    self.discriminator = nn.Sequential(*stack)

  def forward(self, x):
    """Return the per-image real/fake probability, shape (N, 1, 1, 1) for 64x64 input."""
    scores = self.discriminator(x)
    return scores

### Generator Architecture:
###### (100, 1, 1) --> (1024, 4, 4)
###### (1024, 4, 4) --> (512, 8, 8)
###### (512, 8, 8) --> (256, 16, 16)
###### (256, 16, 16) --> (128, 32, 32)
###### (128, 32, 32) --> (img_channels, 64, 64)

In [None]:
class Generator(nn.Module):
  """Transposed-convolution generator producing 64x64 images from latent noise.

  A (N, latent_dim, 1, 1) noise batch is upsampled through five
  ConvTranspose2d layers to (N, img_channels, 64, 64). The final Tanh keeps
  pixel values in [-1, 1].

  Args:
    img_channels: number of channels in the generated images.
    latent_dim: number of channels of the input noise tensor.
  """

  def __init__(self, img_channels, latent_dim):
    super().__init__()

    # (in_channels, out_channels, stride, padding) for the four hidden stages:
    #   (latent_dim, 1, 1) -> (1024, 4, 4) -> (512, 8, 8)
    #   -> (256, 16, 16) -> (128, 32, 32)
    hidden_stages = (
        (latent_dim, 1024, 1, 0),
        (1024, 512, 2, 1),
        (512, 256, 2, 1),
        (256, 128, 2, 1),
    )

    stack = []
    for fan_in, fan_out, step, pad in hidden_stages:
      stack.append(
          nn.ConvTranspose2d(fan_in, fan_out, kernel_size=4, stride=step, padding=pad)
      )
      stack.append(nn.BatchNorm2d(fan_out))
      stack.append(nn.ReLU())

    # Output stage: (N, 128, 32, 32) -> (N, img_channels, 64, 64).
    # Tanh bounds the output to [-1, 1], the range of the normalized images.
    stack.append(
        nn.ConvTranspose2d(128, img_channels, kernel_size=4, stride=2, padding=1)
    )
    stack.append(nn.Tanh())

    # Unpacked into a flat Sequential so sub-modules stay indexed 0..13.
    self.generator = nn.Sequential(*stack)

  def forward(self, x):
    """Map latent noise of shape (N, latent_dim, 1, 1) to images in [-1, 1]."""
    images = self.generator(x)
    return images

In [None]:
def init_weights(model):
  """Apply DCGAN-style weight initialization to `model` in place.

  * Conv2d / ConvTranspose2d kernels are drawn from N(0, 0.02).
  * BatchNorm2d scales are drawn from N(1, 0.02) and shifts are set to 0.
    Drawing the scale around 0 (as for the conv kernels) would make every
    batch-norm layer start out multiplying its activations by ~0.
  * BatchNorm2d layers without affine parameters are skipped.

  Conv biases are left at their default initialization.

  Args:
    model: any nn.Module; all of its sub-modules are visited.
  """
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
      nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
      # Scale centred on 1 so the layer starts close to identity.
      nn.init.normal_(m.weight, mean=1.0, std=0.02)
      nn.init.zeros_(m.bias)

In [None]:
def test():
  """Smoke-test the output shapes of both networks on random input.

  Feeds a random single-channel 64x64 batch through a freshly initialized
  Discriminator and a random latent batch through a freshly initialized
  Generator, asserting the expected output shapes.
  """
  n, channels, side = 32, 1, 64
  z_dim = 100

  # Discriminator: (n, channels, 64, 64) -> (n, 1, 1, 1)
  images = torch.randn((n, channels, side, side))
  critic = Discriminator(channels)
  init_weights(critic)
  assert critic(images).shape == (n, 1, 1, 1)

  # Generator: (n, z_dim, 1, 1) -> (n, channels, 64, 64)
  synthesizer = Generator(latent_dim=z_dim, img_channels=channels)
  init_weights(synthesizer)
  latent = torch.randn((n, z_dim, 1, 1))
  assert synthesizer(latent).shape == (n, channels, 64, 64)

test()

In [None]:
# Model dimensions
img_channels = 1    # single-channel images (MNIST is loaded below)
latent_dim = 100    # channels of the generator's input noise
img_size = 64       # images are resized to this side length

# Optimization settings
lr = 3e-4
batch_size = 128
num_epochs = 3

In [None]:
# Binary cross-entropy on the discriminator's sigmoid output; used for both networks.
criterion = nn.BCELoss()

# Build and initialize the networks. The discriminator is created first so the
# order in which random numbers are consumed is the same on every run.
discriminator = Discriminator(img_channels)
init_weights(discriminator)

generator = Generator(img_channels=img_channels, latent_dim=latent_dim)
init_weights(generator)

# One Adam optimizer per network, with the same learning rate and betas.
opt_disc = optim.Adam(params=discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
opt_gen = optim.Adam(params=generator.parameters(), lr=lr, betas=(0.5, 0.999))

In [None]:
# Loading and transforming data.
# The pipeline is bound to `mnist_transform` rather than `transforms`:
# rebinding `transforms` would shadow the imported torchvision.transforms
# module, so re-running this cell would fail on `transforms.Compose`.
mnist_transform = transforms.Compose(
    [
        transforms.Resize(img_size),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps [0, 1] pixels to [-1, 1], matching the generator's Tanh range
        transforms.Normalize((0.5, ), (0.5, )),
    ]
)

# Downloads MNIST into dataset/ on first use.
dataset = datasets.MNIST(root="dataset/", train=True, transform=mnist_transform,
                         download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# NOTE(review): fixed_noise, the two writers and `step` are not used by any
# code visible in this notebook; presumably intended for TensorBoard image
# logging of real vs. generated samples - confirm.
fixed_noise = torch.randn(32, latent_dim, 1, 1)
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:09<00:00, 1094507.85it/s]


Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 134326.56it/s]


Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 845333.91it/s] 


Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 10595399.76it/s]

Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw






In [None]:
from tqdm import tqdm

# Loss history, sampled every 100 batches. Created once, before the loops,
# so the recorded values survive across batches and epochs.
ls_gen_loss = []
ls_disc_loss = []

# Run on whichever device the networks already live on
# (assumes generator and discriminator share a device).
device = next(discriminator.parameters()).device

for epoch in range(num_epochs):
  progress = tqdm(enumerate(loader), total=len(loader),
                  desc=f"Epoch {epoch + 1}/{num_epochs}")
  # real_label is unused: GAN training here is unsupervised.
  for batch_index, (real_img, real_label) in progress:
    real_img = real_img.to(device)
    # The last batch of an epoch can be smaller than batch_size, so size the
    # noise from the actual batch to keep real/fake counts equal.
    cur_batch_size = real_img.size(0)

    ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
    ### BCE loss = -[y * log(x) + (1 - y) * log(1 - x)], so maximizing the
    ### expression above is the same as minimizing its negative.
    ### First term: log(D(x)) -> BCE against all-ones targets
    D_x = discriminator(real_img).view(-1)
    loss_d_real = criterion(D_x, torch.ones_like(D_x))

    ### Second term: log(1 - D(G(z))) -> BCE against all-zeros targets
    random_noise_z = torch.randn((cur_batch_size, latent_dim, 1, 1), device=device)
    G_z = generator(random_noise_z)
    # detach(): the discriminator update must not backpropagate into the
    # generator, and the generator graph stays intact for its own step below
    # (so no retain_graph is needed).
    D_G_z = discriminator(G_z.detach()).view(-1)
    loss_d_fake = criterion(D_G_z, torch.zeros_like(D_G_z))

    ### Total loss = average of the two terms
    loss_d = (loss_d_real + loss_d_fake) / 2
    discriminator.zero_grad()
    loss_d.backward()
    opt_disc.step()

    ### Train Generator: min log(1 - D(G(z))), which is the same as
    ### max log(D(G(z))), i.e. min -log(D(G(z))) (first term of BCE).
    ### Uses a fresh forward pass through the just-updated discriminator.
    gen_D_G_z = discriminator(G_z).view(-1)
    loss_g = criterion(gen_D_G_z, torch.ones_like(gen_D_G_z))
    generator.zero_grad()
    loss_g.backward()
    opt_gen.step()

    if batch_index % 100 == 0:
      # .item() stores plain floats instead of graph-attached tensors.
      gen_loss_value = loss_g.item()
      disc_loss_value = loss_d.item()
      print(f"Gen Loss: {gen_loss_value}")
      print(f"Disc Loss: {disc_loss_value}")
      ls_gen_loss.append(gen_loss_value)
      ls_disc_loss.append(disc_loss_value)

  0%|          | 1/469 [00:11<1:32:26, 11.85s/it]

Gen Loss: 2.2882027626037598
Disc Loss: 0.10390563309192657


 22%|██▏       | 101/469 [18:39<1:10:59, 11.57s/it]

Gen Loss: 4.632757663726807
Disc Loss: 0.00883081927895546


 43%|████▎     | 201/469 [37:09<49:31, 11.09s/it]

Gen Loss: 1.0620778799057007
Disc Loss: 0.5736242532730103


 64%|██████▍   | 301/469 [55:20<30:15, 10.81s/it]

Gen Loss: 1.8106060028076172
Disc Loss: 0.6237796545028687


 86%|████████▌ | 401/469 [1:12:58<11:46, 10.38s/it]

Gen Loss: 1.072310209274292
Disc Loss: 0.6497573852539062


100%|██████████| 469/469 [1:24:50<00:00, 10.85s/it]
  0%|          | 1/469 [00:11<1:28:16, 11.32s/it]

Gen Loss: 1.1321178674697876
Disc Loss: 0.6212272644042969


 22%|██▏       | 101/469 [19:02<1:12:02, 11.75s/it]

Gen Loss: 0.9239950776100159
Disc Loss: 0.6096317172050476


 43%|████▎     | 201/469 [38:20<50:10, 11.23s/it]

Gen Loss: 0.763047456741333
Disc Loss: 0.687522292137146


 64%|██████▍   | 301/469 [56:50<31:55, 11.40s/it]

Gen Loss: 0.6072092652320862
Disc Loss: 0.6305333375930786


 71%|███████   | 334/469 [1:03:12<25:39, 11.40s/it]