In [9]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable


import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision import datasets

In [2]:
# --- optimisation schedule ---
num_eps = 50           # number of passes over the dataloader
bsize = 32             # mini-batch size handed to the DataLoader
lrate = 0.001          # learning rate shared by both Adam optimizers

# --- model / data geometry ---
lat_dimension = 64     # length of the noise vector fed to the generator
image_sz = 64          # images are resized to image_sz x image_sz
chnls = 3              # number of image channels

# --- logging ---
logging_intv = 200     # log losses and save samples every this many batches

In [3]:
class GANGenerator(nn.Module):
    """Convolutional generator that turns latent vectors into images.

    A linear layer projects the latent vector to a 128-channel feature map of
    side ``img_size // 4``; two (upsample x2 -> 3x3 conv -> batch-norm ->
    LeakyReLU) stages then bring it to side ``4 * (img_size // 4)``, and a
    final 3x3 conv + tanh produces ``channels`` output planes in [-1, 1].

    Args:
        latent_dim: length of the input noise vector. Defaults to the
            notebook-level ``lat_dimension``.
        img_size: target image side length. Defaults to the notebook-level
            ``image_sz``. Should be a multiple of 4, otherwise the output side
            is ``4 * (img_size // 4)`` rather than ``img_size``.
        channels: number of output image channels. Defaults to the
            notebook-level ``chnls``.
    """

    def __init__(self, latent_dim=None, img_size=None, channels=None):
        super(GANGenerator, self).__init__()
        # Fall back to the notebook configuration so GANGenerator() keeps working.
        latent_dim = lat_dimension if latent_dim is None else latent_dim
        img_size = image_sz if img_size is None else img_size
        channels = chnls if channels is None else channels

        self.inp_sz = img_size // 4
        self.lin = nn.Sequential(nn.Linear(latent_dim, 128 * self.inp_sz ** 2))
        self.bn1 = nn.BatchNorm2d(128)
        self.up1 = nn.Upsample(scale_factor=2)
        self.cn1 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
        # The second positional argument of BatchNorm2d is `eps`, not
        # `momentum`; passing 0.8 positionally made the layer divide by
        # sqrt(var + 0.8). Pass it by keyword so eps keeps its default.
        # (momentum only affects the running statistics used in eval mode.)
        self.bn2 = nn.BatchNorm2d(128, momentum=0.8)
        self.rl1 = nn.LeakyReLU(0.2, inplace=True)
        self.up2 = nn.Upsample(scale_factor=2)
        self.cn2 = nn.Conv2d(128, 64, 3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64, momentum=0.8)
        self.rl2 = nn.LeakyReLU(0.2, inplace=True)
        self.cn3 = nn.Conv2d(64, channels, 3, stride=1, padding=1)
        self.act = nn.Tanh()

    def forward(self, x):
        """Generate images from a batch of latent vectors.

        Args:
            x: tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, channels, 4 * inp_sz, 4 * inp_sz)`` with
            values in [-1, 1].
        """
        x = self.lin(x)
        # reshape the flat projection into a 128-channel feature map
        x = x.view(x.shape[0], 128, self.inp_sz, self.inp_sz)
        x = self.bn1(x)
        x = self.up1(x)
        x = self.cn1(x)
        x = self.bn2(x)
        x = self.rl1(x)
        x = self.up2(x)
        x = self.cn2(x)
        x = self.bn3(x)
        x = self.rl2(x)
        x = self.cn3(x)
        out = self.act(x)
        return out

In [4]:
class GANDiscriminator(nn.Module):
    """Convolutional discriminator that scores images as real (1) or fake (0).

    Four stride-2 3x3 conv stages (16, 32, 64, 128 channels; LeakyReLU,
    Dropout2d and, except for the first stage, batch-norm) shrink the image,
    after which a linear layer followed by a sigmoid yields one probability
    per image.

    Args:
        img_size: side length of the (square) input images. Defaults to the
            notebook-level ``image_sz``.
        channels: number of input image channels. Defaults to the
            notebook-level ``chnls``.
    """

    def __init__(self, img_size=None, channels=None):
        super(GANDiscriminator, self).__init__()
        # Fall back to the notebook configuration so GANDiscriminator() keeps working.
        img_size = image_sz if img_size is None else img_size
        channels = chnls if channels is None else channels

        def disc_module(ip_chnls, op_chnls, bnorm=True):
            """Return the layers of one down-sampling stage as a list."""
            mod = [nn.Conv2d(ip_chnls, op_chnls, 3, 2, 1),
                   nn.LeakyReLU(0.2, inplace=True),
                   nn.Dropout2d(0.25)]
            if bnorm:
                # The second positional argument of BatchNorm2d is `eps`;
                # pass 0.8 as momentum by keyword so eps keeps its default.
                mod += [nn.BatchNorm2d(op_chnls, momentum=0.8)]
            return mod

        self.disc_model = nn.Sequential(
            *disc_module(channels, 16, bnorm=False),
            *disc_module(16, 32),
            *disc_module(32, 64),
            *disc_module(64, 128),
        )

        # width and height of the down-sized image: each 3x3/stride-2/pad-1
        # conv maps a side of n to ceil(n / 2), which equals img_size // 16
        # only when img_size is a multiple of 16.
        ds_size = img_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        self.adverse_lyr = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: tensor of shape ``(batch, channels, img_size, img_size)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with sigmoid outputs in [0, 1].
        """
        x = self.disc_model(x)
        # flatten the feature maps for the linear layer
        x = x.view(x.shape[0], -1)
        out = self.adverse_lyr(x)
        return out

In [None]:
# Build the two adversaries. The generator is created first, then the
# discriminator, so parameter initialisation draws random numbers in the
# same order as before.
gen = GANGenerator()
disc = GANDiscriminator()

# Binary cross-entropy is the adversarial loss used for both networks.
adv_loss_func = nn.BCELoss()

In [6]:
# define the dataset and corresponding dataloader
# Images are resized to image_sz x image_sz, converted to tensors and
# normalised with mean 0.5 / std 0.5 per channel. One mean/std entry is given
# for each of the `chnls` channels instead of a single-element list: recent
# torchvision releases broadcast a length-1 list, but older releases iterate
# channel by channel and would normalise only the first channel.
img_transform = transforms.Compose(
    [
        transforms.Resize((image_sz, image_sz)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * chnls, [0.5] * chnls),
    ]
)
dloader = DataLoader(
    datasets.ImageFolder("./data/mnist/", transform=img_transform),
    batch_size=bsize,
    shuffle=True,
)

# define the optimization schedule for both G and D
opt_gen = torch.optim.Adam(gen.parameters(), lr=lrate)
opt_disc = torch.optim.Adam(disc.parameters(), lr=lrate)

In [None]:
# save_image does not create missing directories, so make sure the sample
# folder exists before the first logging step (batch 0) tries to write to it
sample_dir = "images_mnist"
os.makedirs(sample_dir, exist_ok=True)

for ep in range(num_eps):
    for idx, (images, _) in enumerate(dloader):
        n_imgs = images.shape[0]

        # generate ground truths for real and fake images
        # (plain tensors: the Variable wrapper is deprecated and no longer needed)
        good_img = torch.ones(n_imgs, 1)
        bad_img = torch.zeros(n_imgs, 1)

        # get a real image
        actual_images = images.float()

        # train the generator model
        opt_gen.zero_grad()

        # generate a batch of images based on random noise as input
        noise = torch.randn(n_imgs, lat_dimension)
        gen_images = gen(noise)

        # generator model optimization - how well can it fool the discriminator
        generator_loss = adv_loss_func(disc(gen_images), good_img)
        generator_loss.backward()
        opt_gen.step()

        # train the discriminator model
        # (zero_grad also discards the discriminator gradients accumulated by
        # the generator's backward pass above)
        opt_disc.zero_grad()

        # calculate discriminator loss as average of mistakes(losses) in confusing real images as fake and vice versa
        actual_image_loss = adv_loss_func(disc(actual_images), good_img)
        # detach so this loss does not backpropagate into the generator
        fake_image_loss = adv_loss_func(disc(gen_images.detach()), bad_img)
        discriminator_loss = (actual_image_loss + fake_image_loss) / 2

        # discriminator model optimization
        discriminator_loss.backward()
        opt_disc.step()

        batches_completed = ep * len(dloader) + idx
        if batches_completed % logging_intv == 0:
            # adjacent f-strings instead of a backslash continuation, which
            # embedded the next line's indentation in the printed message
            print(
                f"epoch number {ep} | batch number {idx} | "
                f"generator loss = {generator_loss.item():.6f} | "
                f"discriminator loss = {discriminator_loss.item():.6f}"
            )
            save_image(gen_images.detach()[:25], f"{sample_dir}/{batches_completed}.png", nrow=5, normalize=True)

epoch number 0 | batch number 0 | generator loss = 0.703317 | discriminator loss = 0.693092
epoch number 0 | batch number 200 | generator loss = 3.520811 | discriminator loss = 0.097897
epoch number 0 | batch number 400 | generator loss = 0.086716 | discriminator loss = 2.107819
epoch number 0 | batch number 600 | generator loss = 2.876380 | discriminator loss = 0.613213
epoch number 0 | batch number 800 | generator loss = 3.257680 | discriminator loss = 0.336500
epoch number 0 | batch number 1000 | generator loss = 1.235065 | discriminator loss = 0.663049
epoch number 0 | batch number 1200 | generator loss = 1.009370 | discriminator loss = 0.554212
epoch number 0 | batch number 1400 | generator loss = 1.388542 | discriminator loss = 0.472908
epoch number 0 | batch number 1600 | generator loss = 1.012620 | discriminator loss = 0.578827
epoch number 0 | batch number 1800 | generator loss = 1.140438 | discriminator loss = 1.165913
epoch number 0 | batch number 2000 | generator loss = 2.5

epoch number 2 | batch number 4736 | generator loss = 0.569593 | discriminator loss = 0.858146
epoch number 2 | batch number 4936 | generator loss = 1.062639 | discriminator loss = 0.519092
epoch number 2 | batch number 5136 | generator loss = 0.899736 | discriminator loss = 0.702106
epoch number 2 | batch number 5336 | generator loss = 0.830697 | discriminator loss = 0.591507
epoch number 2 | batch number 5536 | generator loss = 0.978978 | discriminator loss = 0.696458
epoch number 2 | batch number 5736 | generator loss = 1.071144 | discriminator loss = 0.854599
epoch number 2 | batch number 5936 | generator loss = 0.656397 | discriminator loss = 0.623073
epoch number 2 | batch number 6136 | generator loss = 0.798755 | discriminator loss = 0.644222
epoch number 3 | batch number 4 | generator loss = 1.033044 | discriminator loss = 0.541450
epoch number 3 | batch number 204 | generator loss = 0.994769 | discriminator loss = 0.729598
epoch number 3 | batch number 404 | generator loss = 0

epoch number 5 | batch number 3140 | generator loss = 0.676439 | discriminator loss = 0.660038
epoch number 5 | batch number 3340 | generator loss = 0.691239 | discriminator loss = 0.662879
epoch number 5 | batch number 3540 | generator loss = 0.767506 | discriminator loss = 0.611993
epoch number 5 | batch number 3740 | generator loss = 1.141797 | discriminator loss = 0.769062
epoch number 5 | batch number 3940 | generator loss = 1.065294 | discriminator loss = 0.629925
epoch number 5 | batch number 4140 | generator loss = 1.110013 | discriminator loss = 0.523894
epoch number 5 | batch number 4340 | generator loss = 0.741305 | discriminator loss = 0.424061
epoch number 5 | batch number 4540 | generator loss = 0.582408 | discriminator loss = 0.558700
epoch number 5 | batch number 4740 | generator loss = 1.102569 | discriminator loss = 0.613773
epoch number 5 | batch number 4940 | generator loss = 0.642278 | discriminator loss = 0.554472
epoch number 5 | batch number 5140 | generator los

epoch number 8 | batch number 1544 | generator loss = 1.579486 | discriminator loss = 0.797279
epoch number 8 | batch number 1744 | generator loss = 0.989263 | discriminator loss = 0.620146
epoch number 8 | batch number 1944 | generator loss = 0.853543 | discriminator loss = 0.710362
epoch number 8 | batch number 2144 | generator loss = 1.103519 | discriminator loss = 0.630420
epoch number 8 | batch number 2344 | generator loss = 0.757749 | discriminator loss = 0.805621
epoch number 8 | batch number 2544 | generator loss = 1.479676 | discriminator loss = 0.451883
epoch number 8 | batch number 2744 | generator loss = 0.855336 | discriminator loss = 0.540650
