In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision.datasets as datasets
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm

### Discriminator Architecture:
###### (3, 64, 64) --> (64, 32, 32)
###### (64, 32, 32) --> (128, 16, 16)
###### (128, 16, 16) --> (256, 8, 8)
###### (256, 8, 8) --> (512, 4, 4)
###### (512, 4, 4) --> (1, 1, 1)

In [2]:
class Discriminator_Critic(nn.Module):
  def __init__(self, img_channels):
    super(Discriminator_Critic, self).__init__()
    self.discriminator = nn.Sequential(
        nn.Conv2d(img_channels, 64, kernel_size = 4, stride = 2, padding = 1),
        nn.LeakyReLU(0.2),
        nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.2),
        nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(0.2),
        nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(512),
        nn.LeakyReLU(0.2),
        nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=0),
        nn.Sigmoid(),
    )

  def forward(self, x):
    return self.discriminator(x)

## Generator Architecture
###### (100, 1, 1) --> (1024, 4, 4)
###### (1024, 4, 4) --> (512, 8, 8)
###### (512, 8, 8) --> (256, 16, 16)
###### (256, 16, 16) --> (128, 32, 32)
###### (128, 32, 32) --> (3, 64, 64)

In [3]:
class Generator(nn.Module):
  def __init__(self, img_channels, latent_dim):
    super(Generator, self).__init__()
    self.generator = nn.Sequential(
        nn.ConvTranspose2d(latent_dim, 1024, kernel_size=4, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(1024),
        nn.ReLU(True),
        nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(512),
        nn.ReLU(True),
        nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(256),
        nn.ReLU(True),
        nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(128),
        nn.ReLU(True),
        nn.ConvTranspose2d(128, img_channels, kernel_size=4, stride=2, padding=1, bias=False),
        nn.Tanh()
    )

  def forward(self, x):
    return self.generator(x)

In [4]:
def init_weights(model):
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
      nn.init.normal_(m.weight.data, 0, 0.02)

In [5]:
def test():
  batch_size, img_channels, H, W = 32, 1, 64, 64
  x = torch.randn((batch_size, img_channels, H, W))
  discriminator = Discriminator_Critic(img_channels)
  init_weights(discriminator)
  assert discriminator(x).shape == (batch_size, 1, 1, 1)
  generator = Generator(latent_dim = 100, img_channels=img_channels)
  init_weights(generator)
  y = torch.randn((batch_size, 100, 1, 1))
  assert generator(y).shape == (batch_size, img_channels, 64, 64)

test()

In [6]:
lr = 5e-5
batch_size = 64
img_channels = 1
latent_dim = 128
num_epochs = 2
features_critic = 64
features_gen = 64
critic_iter = 5
weight_clips = 0.01

In [7]:
transformations = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(img_channels)], [0.5 for _ in range(img_channels)]
        ),
    ]
)

In [8]:
dataset = datasets.MNIST(root="dataset/", transform=transformations, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 10453804.59it/s]


Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 335131.69it/s]


Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 2693529.97it/s]


Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 8675104.17it/s]

Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw






In [9]:
generator = Generator(img_channels, latent_dim)
discriminator = Discriminator_Critic(img_channels)
init_weights(generator)
init_weights(discriminator)

In [10]:
generator_opt = optim.RMSprop(generator.parameters(), lr=lr)
discriminator_opt = optim.RMSprop(discriminator.parameters(), lr=lr)

In [None]:
ls_critic_loss = []
ls_gen_loss = []

for epoch in range(num_epochs):
    for batch_idx, (data, labels) in tqdm(enumerate(loader), total = len(loader)):
        cur_batch_size = data.shape[0]

        for _ in range(critic_iter):
            noise = torch.randn(cur_batch_size, latent_dim, 1, 1)
            fake = generator(noise)
            critic_real = discriminator(data).reshape(-1)
            critic_fake = discriminator(fake).reshape(-1)
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))
            discriminator.zero_grad()
            loss_critic.backward(retain_graph=True)
            discriminator_opt.step()

            # clip critic weights between -0.01, 0.01
            for p in discriminator.parameters():
                p.data.clamp_(-weight_clips,weight_clips)

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        gen_fake = discriminator(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        generator.zero_grad()
        loss_gen.backward()
        generator_opt.step()

        if batch_idx % 100 == 0:
            print(f"Gen Loss: {loss_gen}")
            print(f"Disc Loss: {loss_critic}")
            ls_critic_loss.append(loss_critic.item())
            ls_gen_loss.append(loss_gen.item())

  0%|          | 1/938 [00:22<5:56:21, 22.82s/it]

Gen Loss: -0.48125159740448
Disc Loss: -0.06271237134933472


  7%|▋         | 61/938 [22:49<5:25:32, 22.27s/it]