In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [None]:
class discriminator(nn.Module):
    """Single-hidden-layer MLP that scores flattened images as real vs. fake.

    Args:
        img_dim: number of input features (the flattened image size).

    The final Sigmoid squashes the score into (0, 1), so the output can be fed
    directly to ``nn.BCELoss``.
    """

    def __init__(self, img_dim):
        super().__init__()
        hidden_units = 128
        # The attribute name `disc` and the layer order are part of the
        # checkpoint format (state_dict keys disc.0.*, disc.2.*), so keep them.
        self.disc = nn.Sequential(
            nn.Linear(in_features=img_dim, out_features=hidden_units),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(in_features=hidden_units, out_features=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """For x of shape (N, img_dim), return an (N, 1) tensor of scores in (0, 1)."""
        scores = self.disc(x)
        return scores

In [None]:
class generator(nn.Module):
    """Single-hidden-layer MLP that maps latent noise to a flattened image.

    Args:
        z_dim: size of the latent noise vector fed to the network.
        img_dim: number of output features (the flattened image size).

    The final Tanh keeps every output pixel in (-1, 1).
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        hidden_units = 128
        # Keep the attribute name `gen` and layer order so state_dict keys
        # (gen.0.*, gen.2.*) and parameter-initialisation order are unchanged.
        self.gen = nn.Sequential(
            nn.Linear(in_features=z_dim, out_features=hidden_units),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(in_features=hidden_units, out_features=img_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        """For noise x of shape (N, z_dim), return an (N, img_dim) tensor in (-1, 1)."""
        images = self.gen(x)
        return images

In [None]:
# Training configuration.
# is_available must be *called*: the bare function object is always truthy,
# which would select 'cuda' even on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
SEED = 42
torch.manual_seed(SEED)  # make weight init, noise and shuffling reproducible
lr = 3e-4  # Adam learning rate shared by both networks
z_dim = 64  # latent noise size; other values to try: 32, 128, 256
image_dim = 28 * 28 * 1  # flattened single-channel 28x28 image: 784
BATCH_SIZE = 32
EPOCHS = 50

In [None]:
# Build the networks, evaluation noise, data pipeline, optimizers, loss and
# TensorBoard writers used by the training loop.
disc = discriminator(image_dim).to(device)
gen = generator(z_dim, image_dim).to(device)
# Noise reused at every logging step so generated samples are comparable over time.
fixed_noise = torch.randn((BATCH_SIZE, z_dim), device=device)
# ToTensor yields pixels in [0, 1]; Normalize(mean=0.5, std=0.5) maps them to
# [-1, 1], matching the generator's Tanh output range.
transform_pipeline = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # one-element tuples: one channel
    ]
)

dataset = datasets.MNIST(root='dataset/', transform=transform_pipeline, download=True)

loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
# The discriminator already ends in Sigmoid, so plain BCELoss is the matching criterion.
criterion = nn.BCELoss()
writer_fake = SummaryWriter("runs/GAN_MNIST/fake")
writer_real = SummaryWriter("runs/GAN_MNIST/real")
step = 0  # TensorBoard global step; the training loop increments it each time it logs images


In [None]:
# Train the GAN. For each batch: one discriminator update, then one generator
# update. On the first batch of every epoch, print the losses and write image
# grids (fixed-noise samples and the current real batch) to TensorBoard.
for epoch in range(EPOCHS):
  for batch_idx, (real, _) in enumerate(loader): # real = images, _ = labels (unused)
    # Flatten each image into an image_dim-long vector for the MLP discriminator.
    real = real.view(-1, image_dim).to(device)
    batch_size = real.shape[0]

    ### Discriminator: maximize log(D(real)) + log(1 - D(G(z))),
    ### i.e. minimize BCE(D(real), 1) + BCE(D(G(z)), 0).
    ### z is random noise given to the generator to produce fake images; it is
    ### sized to the actual batch so a short final batch still matches `real`.
    noise = torch.randn(batch_size, z_dim, device=device)
    fake = gen(noise)
    disc_real = disc(real).view(-1) # D(real)
    lossD_real = criterion(disc_real, torch.ones_like(disc_real)) # -log(D(real))

    # detach(): the discriminator step must not backpropagate into the generator.
    disc_fake = disc(fake.detach()).view(-1) # D(G(z))
    lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake)) # -log(1 - D(G(z)))

    lossD = (lossD_real + lossD_fake) / 2
    # Also clears the grads the previous generator step left on disc's parameters.
    disc.zero_grad()
    lossD.backward()
    opt_disc.step()

    ### Generator: min log(1 - D(G(z))) <-> max log(D(G(z))),
    ### implemented as BCE(D(G(z)), 1) (the non-saturating form).
    output = disc(fake).view(-1)
    lossG = criterion(output, torch.ones_like(output))
    gen.zero_grad()
    lossG.backward()
    opt_gen.step()

    if batch_idx == 0:
      print(
          f"Epoch [{epoch}/{EPOCHS}] Batch {batch_idx}/{len(loader)} "
          f"Loss D: {lossD.item():.4f}, loss G: {lossG.item():.4f}"
      )

      with torch.no_grad():
        # Separate name so the training tensor `fake` is not shadowed.
        fake_fixed = gen(fixed_noise).reshape(-1, 1, 28, 28)
        data = real.reshape(-1, 1, 28, 28)
        img_grid_fake = torchvision.utils.make_grid(fake_fixed, normalize=True)
        img_grid_real = torchvision.utils.make_grid(data, normalize=True)

        writer_fake.add_image(
            "Mnist Fake Images", img_grid_fake, global_step=step
        )
        writer_real.add_image(
            "Mnist Real Images", img_grid_real, global_step=step
        )
        step += 1

# Make sure all logged images are written to disk once training finishes.
writer_fake.flush()
writer_real.flush()

Epoch [0/50] Batch 0/1875 \Loss D: 0.5763, loss G: 0.7134
Epoch [1/50] Batch 0/1875 \Loss D: 0.2884, loss G: 1.7365
Epoch [2/50] Batch 0/1875 \Loss D: 0.3590, loss G: 1.2555
Epoch [3/50] Batch 0/1875 \Loss D: 0.2401, loss G: 2.0769
Epoch [4/50] Batch 0/1875 \Loss D: 0.6208, loss G: 0.9404
Epoch [5/50] Batch 0/1875 \Loss D: 0.3656, loss G: 1.7556
Epoch [6/50] Batch 0/1875 \Loss D: 0.7192, loss G: 0.8101
Epoch [7/50] Batch 0/1875 \Loss D: 0.4161, loss G: 1.5343
Epoch [8/50] Batch 0/1875 \Loss D: 0.7290, loss G: 1.1284
Epoch [9/50] Batch 0/1875 \Loss D: 0.6449, loss G: 1.3920
Epoch [10/50] Batch 0/1875 \Loss D: 0.8622, loss G: 1.0441
Epoch [11/50] Batch 0/1875 \Loss D: 0.3679, loss G: 2.0514
Epoch [12/50] Batch 0/1875 \Loss D: 0.5871, loss G: 1.2284
Epoch [13/50] Batch 0/1875 \Loss D: 0.5649, loss G: 1.4849
Epoch [14/50] Batch 0/1875 \Loss D: 0.8450, loss G: 0.9330
Epoch [15/50] Batch 0/1875 \Loss D: 0.7086, loss G: 0.9336
Epoch [16/50] Batch 0/1875 \Loss D: 0.6517, loss G: 1.1530
Epoch [