In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision

In [2]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.utils.data as DataLoader
from torch.utils.tensorboard import SummaryWriter

In [3]:
class Discrimiator(nn.Module):
  """Fully connected discriminator that scores flattened images.

  The misspelled class name is kept on purpose: other cells in this notebook
  instantiate the model under this exact name.

  Args:
    img_dim: Number of features in one flattened input image.
    hidden_dim: Width of the single hidden layer. Defaults to 128, the value
      that used to be hard-coded.
  """

  def __init__(self, img_dim, hidden_dim=128):
    super().__init__()
    # The attribute name `disc` determines the state_dict key prefix, so it
    # must stay stable for saved checkpoints to load.
    self.disc = nn.Sequential(
        nn.Linear(img_dim, hidden_dim),
        nn.LeakyReLU(0.1),
        nn.Linear(hidden_dim, 1),
        # Squash the single logit to a value between 0 and 1.
        nn.Sigmoid(),
    )

  def forward(self, x):
    """Return one score per input row; the last dimension of the result is 1."""
    return self.disc(x)

class Generator(nn.Module):
  """Fully connected generator that maps a noise vector to a flattened image.

  Args:
    z_dim: Number of features in the input noise vector.
    img_dim: Number of features in one flattened output image.
    hidden_dim: Width of the single hidden layer. Defaults to 256, the value
      that used to be hard-coded.
  """

  def __init__(self, z_dim, img_dim, hidden_dim=256):
    super().__init__()
    # The attribute name `gen` determines the state_dict key prefix, so it
    # must stay stable for saved checkpoints to load.
    self.gen = nn.Sequential(
        nn.Linear(z_dim, hidden_dim),
        nn.LeakyReLU(0.1),
        nn.Linear(hidden_dim, img_dim),
        # Tanh bounds every output value to [-1, 1], the same range the
        # training images have after Normalize((0.5,), (0.5,)).
        nn.Tanh(),
    )

  def forward(self, x):
    """Return one generated flattened image per row of noise in `x`."""
    return self.gen(x)

# Configuration shared by the cells below. The lowercase names are kept
# because later cells read them by these exact names.

# Seed PyTorch's random number generator so that a fresh
# "Restart Kernel -> Run All" starts from the same random state each time.
seed = 42
torch.manual_seed(seed)

lr = 3e-4            # learning rate used for both Adam optimizers
z_dim = 64           # size of the generator's input noise vector
image_dim = 28 * 28  # flattened image size (784)
batch_size = 32
num_epochs = 50
device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [4]:
disc = Discrimiator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)

# One fixed batch of noise, reused at every logging step so the generated
# image grids can be compared across epochs.
fixed_noise = torch.randn(batch_size, z_dim, device=device)

# Preprocessing pipeline: convert to a tensor, then normalize with mean 0.5
# and std 0.5.
# The name `transforms` is deliberately rebound to the pipeline because the
# dataset cell passes `transforms` as its transform. The transform classes are
# reached through `torchvision.transforms` rather than the `transforms` alias,
# so re-running this cell no longer fails with
# "'Compose' object has no attribute 'Compose'".
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)

In [5]:
# MNIST training split, downloaded into ./dataset/ when it is not already
# there. `transforms` here is the preprocessing pipeline object built in the
# previous cell, not the torchvision module.
dataset = datasets.MNIST(
    root="dataset/",
    train=True,
    transform=transforms,
    download=True,
)

In [6]:
# Note: in this notebook `DataLoader` is an alias for the `torch.utils.data`
# module, so the class is `DataLoader.DataLoader`.
# Batches are reshuffled every epoch.
loader = DataLoader.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [7]:
# One Adam optimizer per network, both using the shared learning rate.
opt_disc = optim.Adam(params=disc.parameters(), lr=lr)
opt_gen = optim.Adam(params=gen.parameters(), lr=lr)

# Binary cross-entropy on the discriminator's sigmoid outputs.
criterion = nn.BCELoss()

# Separate TensorBoard runs so fake and real image grids show side by side.
writer_fake = SummaryWriter("GAN_MNIST/fake")
writer_real = SummaryWriter("GAN_MNIST/real")

# Global step for the image summaries; the training loop increments it each
# time it logs a pair of grids.
step = 0

In [8]:
# GAN training loop: for every batch, take one discriminator step and then one
# generator step. On the first batch of each epoch, print the losses and log a
# grid of generated and real images to TensorBoard.
for epoch in range(num_epochs):
  for batch_idx, (real, _) in enumerate(loader):
    # Flatten each image into one row of `image_dim` features.
    real = real.view(-1, image_dim).to(device)
    # Read the batch length from the tensor itself, under a loop-local name,
    # so the global `batch_size` setting is not overwritten for other cells.
    cur_batch_size = real.shape[0]
    noise = torch.randn(cur_batch_size, z_dim, device=device)

    # --- Train the discriminator ---
    fake = gen(noise)
    disc_real = disc(real).view(-1)
    # BCE is -[y*ln(x) + (1-y)*ln(1-x)]. With y=1 for real images and y=0 for
    # fakes, the two terms together are the negative of the discriminator's
    # objective, so minimizing this loss maximizes that objective.
    lossD_real = criterion(disc_real, torch.ones_like(disc_real))
    # detach() stops this backward pass at the generator's output: the
    # discriminator step needs no generator gradients, and the generator's
    # graph stays intact for its own step below without retain_graph=True.
    disc_fake = disc(fake.detach()).view(-1)
    lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    lossD = (lossD_real + lossD_fake) / 2
    disc.zero_grad()
    lossD.backward()
    opt_disc.step()

    # --- Train the generator ---
    # The original objective is to minimize log(1 - D(G(z))), but its
    # gradients are weak early in training, so we instead maximize
    # log(D(G(z))) by labeling the fakes as real.
    # The fakes are scored again because the discriminator was just updated.
    output = disc(fake).view(-1)
    lossG = criterion(output, torch.ones_like(output))
    gen.zero_grad()
    lossG.backward()
    opt_gen.step()

    if batch_idx == 0:
      print(
          f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
          Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
      )

      with torch.no_grad():
        # Use the fixed noise so the grids are comparable between epochs.
        # A separate name keeps the training batch `fake` from being rebound.
        fixed_fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
        data = real.reshape(-1, 1, 28, 28)
        img_grid_fake = torchvision.utils.make_grid(fixed_fake, normalize=True)
        img_grid_real = torchvision.utils.make_grid(data, normalize=True)

        writer_fake.add_image(
          "Mnist Fake Images", img_grid_fake, global_step=step
        )
        writer_real.add_image(
          "Mnist Real Images", img_grid_real, global_step=step
        )
        step += 1

Epoch [0/50] Batch 0/1875           Loss D: 0.6924, loss G: 0.7461
Epoch [1/50] Batch 0/1875           Loss D: 0.3047, loss G: 1.6017
Epoch [2/50] Batch 0/1875           Loss D: 0.5663, loss G: 0.8626
Epoch [3/50] Batch 0/1875           Loss D: 0.5267, loss G: 1.1182
Epoch [4/50] Batch 0/1875           Loss D: 0.8383, loss G: 0.7230
Epoch [5/50] Batch 0/1875           Loss D: 0.6173, loss G: 1.0907
Epoch [6/50] Batch 0/1875           Loss D: 0.9232, loss G: 0.6570
Epoch [7/50] Batch 0/1875           Loss D: 0.7700, loss G: 0.7874
Epoch [8/50] Batch 0/1875           Loss D: 0.5480, loss G: 1.3472
Epoch [9/50] Batch 0/1875           Loss D: 0.9052, loss G: 0.7682
Epoch [10/50] Batch 0/1875           Loss D: 0.7633, loss G: 0.9711
Epoch [11/50] Batch 0/1875           Loss D: 0.4703, loss G: 1.1517
Epoch [12/50] Batch 0/1875           Loss D: 0.5269, loss G: 1.1476
Epoch [13/50] Batch 0/1875           Loss D: 0.5305, loss G: 1.1745
Epoch [14/50] Batch 0/1875           Loss D: 0.7415, loss 

In [9]:
# Persist the weights of both networks. Only the state_dicts are written, so
# loading them later requires constructing the model classes first.
torch.save(obj=disc.state_dict(), f="discriminator.pth")
torch.save(obj=gen.state_dict(), f="generator.pth")
