In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision import transforms
# from torch.utils.tensorboard import SummaryWriter
import torchvision

In [9]:
class Discriminator(nn.Module):
    """Fully connected binary classifier that scores flattened images.

    Args:
        img_dim: number of input features, i.e. the length of one flattened image.
    """

    def __init__(self, img_dim):
        super().__init__()
        hidden_dim = 128
        layers = [
            nn.Linear(img_dim, hidden_dim),
            # LeakyReLU(x) = max(0.1 * x, x) for the negative slope used here
            nn.LeakyReLU(0.1),
            # one output unit: we only predict real (target 1) vs fake (target 0)
            nn.Linear(hidden_dim, 1),
            # sigmoid scales the score to lie between 0 and 1
            nn.Sigmoid(),
        ]
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Run `x` through the layer stack and return one score per row."""
        scores = self.disc(x)
        return scores

In [10]:
class Generator(nn.Module):
    """Fully connected generator that maps latent noise to a flattened image.

    Args:
        z_dim: dimension of the latent noise the generator uses to create new samples.
        img_dim: number of output features, i.e. the length of one flattened image.
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        hidden_dim = 256
        layers = [
            nn.Linear(z_dim, hidden_dim),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_dim, img_dim),
            # tanh(x) = (e^x - e^-x) / (e^x + e^-x) keeps every output between
            # -1 and 1, which is how the input images are going to be scaled
            nn.Tanh(),
        ]
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Run the noise batch `x` through the layer stack and return the result."""
        generated = self.gen(x)
        return generated

In [11]:
# Hyperparameters - GANs are very sensitive to these choices
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

lr = 3e-4              # learning rate
z_dim = 64             # dimension of the latent noise vector
img_dim = 28 * 28 * 1  # flattened image size: 28 * 28 * 1 = 784 features
batch_size = 32
epochs = 50

cuda


In [12]:
# Build both networks on the selected device.
disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)

# One fixed batch of latent noise, reused whenever samples are rendered, so the
# saved grids show how the generator improves over time on the same inputs.
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

# Bound to its own name on purpose: assigning the pipeline to `transforms`
# would shadow the imported torchvision `transforms` module, and running this
# cell a second time would then fail with an AttributeError.
mnist_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 rescales ToTensor's [0, 1] pixels to [-1, 1], the
        # range of the generator's Tanh output. These are scaling constants,
        # not the empirical MNIST mean / standard deviation.
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
dataset = datasets.MNIST(root="dataset/", transform=mnist_transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Separate Adam optimizers so each network is updated only by its own loss.
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()

In [13]:
# setup tensorboard
# writer_fake = SummaryWriter(f"runs/GAN_MNSIT/fake")
# writer_real = SummaryWriter(f"runs/GAN_MNSIT/real")
# step = 0
import os

# Output folders for the image grids saved during training. makedirs also
# creates the parent 'runs' directory, and exist_ok=True keeps this cell
# re-runnable instead of raising FileExistsError once the folders exist.
os.makedirs('runs/real', exist_ok=True)
os.makedirs('runs/fake', exist_ok=True)

$\text{BCE Loss} = -w_n\left[\,y_n \log(x_n) + (1 - y_n)\log(1 - x_n)\,\right]$

In [14]:
# Alternating GAN training: each batch performs one discriminator update
# followed by one generator update.
for epoch in range(epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten every image into a vector of img_dim features.
        real = real.view(-1, img_dim).to(device)
        # Read the batch size from the data rather than the config, so a batch
        # that is smaller than batch_size still gets matching noise.
        cur_batch_size = real.shape[0]

        ### Train the discriminator: max log(D(real)) + log(1 - D(G(noise)))
        noise = torch.randn(cur_batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        # With target y = 1 the BCE term reduces to -log(D(real)); minimizing
        # it maximizes log(D(real)).
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach() cuts the graph at the generator output: the discriminator
        # step needs no generator gradients, so nothing is backpropagated
        # through the generator and retain_graph=True is not required.
        disc_fake = disc(fake.detach()).view(-1)
        # With target y = 0 the BCE term reduces to -log(1 - D(G(noise))).
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ### Train the generator: min log(1 - D(G(z))) saturates the gradients
        ### (slower training), so the equivalent max log(D(G(z))) is used.
        # `fake` (not detached) is scored again by the just-updated
        # discriminator, so gradients flow back into the generator.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        ### Progress logging: once per epoch, on the first batch
        if batch_idx == 0:
            print(
                f"Epoch[{epoch}/{epochs}] LossD:{lossD.item():.4}, LossG:{lossG.item():.4}"
            )

        ### Save sample grids every 10th epoch, on the first batch
        if (epoch + 1) % 10 == 0 and batch_idx == 0:
            with torch.no_grad():
                # Separate names so the training tensors above are not overwritten.
                fake_samples = gen(fixed_noise).reshape(-1, 1, 28, 28)
                real_samples = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake_samples, normalize=True)
                img_grid_real = torchvision.utils.make_grid(real_samples, normalize=True)
                # save the images as png
                torchvision.utils.save_image(img_grid_fake, f"runs/fake/{epoch+1}.png")
                torchvision.utils.save_image(img_grid_real, f"runs/real/{epoch+1}.png")

            # writer_fake.add_image(
            #     "MNIST Fake", img_grid_fake, global_step=step
            # )
            # writer_real.add_image(
            #     "MNIST Real", img_grid_real, global_step=step
            # )

Epoch[0/50] LossD:0.7176, LossG:0.748
Epoch[1/50] LossD:0.3233, LossG:1.533
Epoch[2/50] LossD:0.7922, LossG:0.8395
Epoch[3/50] LossD:1.127, LossG:0.4905
Epoch[4/50] LossD:0.665, LossG:0.9109
Epoch[5/50] LossD:0.752, LossG:0.7621
Epoch[6/50] LossD:0.8487, LossG:0.7157
Epoch[7/50] LossD:0.6298, LossG:0.8206
Epoch[8/50] LossD:0.5119, LossG:1.16
Epoch[9/50] LossD:0.4942, LossG:1.479
Epoch[10/50] LossD:0.8003, LossG:0.8304
Epoch[11/50] LossD:0.9386, LossG:0.5095
Epoch[12/50] LossD:0.5659, LossG:1.171
Epoch[13/50] LossD:0.928, LossG:0.8423
Epoch[14/50] LossD:0.758, LossG:0.9671
Epoch[15/50] LossD:0.8722, LossG:1.1
Epoch[16/50] LossD:0.4591, LossG:1.513
Epoch[17/50] LossD:0.8951, LossG:0.739
Epoch[18/50] LossD:0.7035, LossG:1.302
Epoch[19/50] LossD:0.7011, LossG:0.916
Epoch[20/50] LossD:0.6532, LossG:0.8266
Epoch[21/50] LossD:0.7117, LossG:0.6831
Epoch[22/50] LossD:0.6569, LossG:1.014
Epoch[23/50] LossD:0.4741, LossG:1.294
Epoch[24/50] LossD:0.5496, LossG:1.186
Epoch[25/50] LossD:0.5502, Loss

In [15]:
# !tensorboard --logdir=runs/ --host=localhost --port=8888