In [1]:
#==============================================#
#         Author: Kartikeya Sharma             #
#==============================================#

In [2]:
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision.transforms import transforms
from torch.utils.tensorboard import SummaryWriter

In [3]:
print("Cuda Available: ", torch.cuda.is_available())
print("Number of Cuda Devices:", torch.cuda.device_count())
print("Device Name:", torch.cuda.get_device_name(0))

Cuda Available:  True
Number of Cuda Devices: 1
Device Name: NVIDIA GeForce GTX 1650 Ti


In [4]:
os.getcwd()

'D:\\PyTorch Tutorials\\Implementations From Scratch\\GAN_From_Scratch'

In [21]:
class Discriminator(nn.Module):
    def __init__(self, img_size):
        super().__init__()
        self.discriminator = nn.Sequential(
            nn.Linear(in_features = img_size, out_features=128),
            nn.LeakyReLU(negative_slope = 0.2),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        
    def predict(self, x):
        return self.discriminator(x)
    
# Normalize input from 0 to 1
class Generator(nn.Module):
    def __init__(self, noise_dim, img_dim):
        super().__init__()
        self.generator = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim),
            nn.Sigmoid()
        )
    
    def predict(self, x):
        return self.generator(x)

In [22]:
# Parameters and Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lr = 3e-4
noise_dim = 16
img_dim = 28 * 28 * 1 # Which will be 784
batch_size = 32
epochs = 3

disc = Discriminator(img_dim).to(device)
gen = Generator(noise_dim, img_dim).to(device)

# Important to include batch_size and nosie_dim as we will be feeding the network a batch of batch_size
# noise vectors and generating batch_size number of images which will be compared evaluated by the discriminator.
noise = torch.randn((batch_size, noise_dim)).to(device)

In [23]:
trans = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean = (0.5,), std = (0.5,))
    ]
)

dataset = datasets.MNIST(root = f"dataset/", transform = trans, download = True)
loader = DataLoader(dataset, batch_size = batch_size, shuffle=True)
opt_disc = optim.Adam(disc.parameters(), lr = lr)
opt_gen = optim.Adam(gen.parameters(), lr = lr)

loss = nn.BCELoss()
writer_fake = SummaryWriter(f"logs/fake")
writer_real = SummaryWriter(f"logs/real")

In [24]:
step = 0

for epoch in range(epochs):
    print("Epoch:", epoch)
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, 784).to(device)
        batch_size = real.shape[0]
        ## Train Discriminator Loss Function: log(D(real)) + log(1 - D(G(noise)))
        # This has to be maximized
        noise = torch.randn(batch_size, noise_dim).to(device)
        fake = gen.predict(noise)
        disc_real = disc.predict(real).view(-1)
        lossD_real = loss(disc_real, torch.ones_like(disc_real))
        
        disc_fake = disc.predict(fake).view(-1)
        lossD_fake = loss(disc_fake, torch.zeros_like(disc_fake))
        
        lossD = (lossD_fake/2 + lossD_real/2)
        disc.zero_grad()
        lossD.backward(retain_graph = True)
        opt_disc.step()
        
        ## Train Generator Loss Function: log(D(G(noise))) which was derived after seeing the problem of saturating gradients in the original loss function
        # This loss function has to be maximized        
        ## Original Loss Function (Not to be considered): log(1 - D(G(noise))) 
        # This loss had to be minimized
        
        output = disc.predict(fake).view(-1)
        lossG = loss(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()
                
        if batch_idx == 0:
            print(f"Epoch [{epoch}/{epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}")

            with torch.no_grad():
                fake = gen.predict(noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1

Epoch: 0
Epoch [0/3] Batch 0/1875                       Loss D: 0.6248, loss G: 0.9511
Epoch: 1
Epoch [1/3] Batch 0/1875                       Loss D: 0.0024, loss G: 5.3318
Epoch: 2
Epoch [2/3] Batch 0/1875                       Loss D: 0.0004, loss G: 7.1757
