## Importing Libraries and Dependencies

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

## Creating the Discriminator and Generator Classes

In [2]:
class Discriminator(nn.Module):
    """Fully-connected real/fake scorer for flattened images.

    Layout: Linear(img_dim -> 128) -> LeakyReLU(0.1) -> Linear(128 -> 1) -> Sigmoid,
    so every output value lies in [0, 1].

    Args:
        img_dim: number of input features per sample (the flattened image size).
    """

    def __init__(self, img_dim):
        super().__init__()
        # Order matters: the Sequential indices determine state_dict keys
        # such as "disc.0.weight", so keep the layers in this exact sequence.
        layers = [
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        ]
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Score each sample in ``x``.

        Args:
            x: tensor whose last dimension is ``img_dim``.

        Returns:
            Tensor with last dimension 1 holding sigmoid scores in [0, 1].
        """
        scores = self.disc(x)
        return scores
    
class Generator(nn.Module):
    """Fully-connected generator mapping noise vectors to flattened images.

    Layout: Linear(z_dim -> 256) -> LeakyReLU(0.1) -> Linear(256 -> img_dim) -> Tanh,
    so outputs lie in [-1, 1], the same range the MNIST pixels have after
    Normalize((0.5,), (0.5,)).

    Args:
        z_dim: size of each input noise vector.
        img_dim: number of output features per sample (the flattened image size).
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        # Order matters: the Sequential indices determine state_dict keys
        # such as "gen.0.weight", so keep the layers in this exact sequence.
        layers = [
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim),
            nn.Tanh(),
        ]
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Map noise ``x`` (last dimension ``z_dim``) to images (last dimension ``img_dim``)."""
        images = self.gen(x)
        return images

### Hyperparameters, Initialization, and Training Loop

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
seed = 0  # fixed so weight init, fixed_noise and data shuffling are reproducible
# NOTE(review): 2e-3 is high for Adam in GAN training (3e-4 is a common choice);
# consider lowering it if the losses oscillate.
lr = 2e-3
z_dim = 64
img_dim = 28 * 28 * 1  # flattened 1x28x28 MNIST image
batch_size = 64
num_epochs = 50

torch.manual_seed(seed)

disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)
# The same noise is reused at every log step, so TensorBoard shows how the
# generator's output for fixed inputs evolves during training.
fixed_noise = torch.randn(batch_size, z_dim, device=device)

# Named `transform` (not `transforms`) so the torchvision.transforms module is not
# shadowed; otherwise re-running this cell fails on `transforms.Compose`.
transform = transforms.Compose(
    [
        transforms.ToTensor(),                 # pixels -> [0, 1]
        transforms.Normalize((0.5,), (0.5,)),  # [0, 1] -> [-1, 1], matching the generator's Tanh
    ]
)
dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()

writer_fake = SummaryWriter("runs/GAN_MNIST/fake")
writer_real = SummaryWriter("runs/GAN_MNIST/real")
step = 0  # TensorBoard global_step, advanced once per logged batch

for epoch in range(1, num_epochs + 1):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten each image into a vector for the fully-connected models.
        real = real.view(-1, img_dim).to(device)
        # Separate name so the global `batch_size` hyperparameter is not
        # overwritten by the (possibly smaller) final batch.
        cur_batch_size = real.shape[0]

        # --- Discriminator step: minimize -log D(real) - log(1 - D(G(z))) ---
        noise = torch.randn(cur_batch_size, z_dim, device=device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach(): this step must not backpropagate into the generator. It also means
        # the generator's graph is untouched here, so no retain_graph is needed.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        opt_disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # --- Generator step: minimize -log D(G(z)) (non-saturating loss) ---
        # Re-score `fake` with the just-updated discriminator; this pass is not
        # detached, so gradients flow back into the generator.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        opt_gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] \\ "
                f"Loss D: {lossD.item():.4f}, Loss G: {lossG.item():.4f}"
            )

            with torch.no_grad():
                fake_imgs = gen(fixed_noise).reshape(-1, 1, 28, 28)
                real_imgs = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake_imgs, normalize=True)
                img_grid_real = torchvision.utils.make_grid(real_imgs, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )

                step += 1

# Flush pending events to disk and release the event files. SummaryWriter reopens
# its file writer automatically if this cell is run again.
writer_fake.close()
writer_real.close()

Epoch [1/50] \ Loss D: 0.7269, 0.7427
Epoch [2/50] \ Loss D: 0.1916, 5.5330
Epoch [3/50] \ Loss D: 0.2375, 4.6466
Epoch [4/50] \ Loss D: 0.4133, 3.9732
Epoch [5/50] \ Loss D: 0.6609, 5.8680
Epoch [6/50] \ Loss D: 0.0745, 6.9675
Epoch [7/50] \ Loss D: 0.2593, 3.3839
Epoch [8/50] \ Loss D: 0.3779, 4.3812
Epoch [9/50] \ Loss D: 0.5799, 4.1994
Epoch [10/50] \ Loss D: 0.4051, 2.5164
Epoch [11/50] \ Loss D: 0.5031, 2.1605
Epoch [12/50] \ Loss D: 0.4063, 3.7229
Epoch [13/50] \ Loss D: 0.6045, 3.2286
Epoch [14/50] \ Loss D: 0.4953, 2.2602
Epoch [15/50] \ Loss D: 0.5121, 2.4665
Epoch [16/50] \ Loss D: 0.6417, 1.7855
Epoch [17/50] \ Loss D: 0.6818, 1.4048
Epoch [18/50] \ Loss D: 0.4629, 3.2133
Epoch [19/50] \ Loss D: 0.4479, 1.5561
Epoch [20/50] \ Loss D: 0.4977, 2.7479
Epoch [21/50] \ Loss D: 0.7136, 2.4512
Epoch [22/50] \ Loss D: 0.5992, 1.4490
Epoch [23/50] \ Loss D: 0.6875, 2.0628
Epoch [24/50] \ Loss D: 0.8031, 1.2425
Epoch [25/50] \ Loss D: 0.6229, 1.6627
Epoch [26/50] \ Loss D: 0.5265, 2.

In [4]:
# Run `tensorboard --logdir=runs` from an Anaconda Prompt or a terminal (CMD) to view the logged real and fake images in TensorBoard.