In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

# Discriminator

In [19]:
class Discriminator(nn.Module):
  def __init__(self, img_dim):
    super().__init__()
    self.disc = nn.Sequential(
        nn.Linear(img_dim,128),
        nn.LeakyReLU(0.1), # LeakyReLU is better default for GAN's
        nn.Linear(128,1),
        nn.Sigmoid(),
    )

  def forward(self,X):
    return self.disc(X)

# Generator

In [20]:
class Generator(nn.Module):
  def __init__(self,z_dim,img_dim): # z_dim is the "latent noise"
    super().__init__()
    self.gen = nn.Sequential(
        nn.Linear(z_dim,256),
        nn.LeakyReLU(0.1),
        nn.Linear(256,img_dim),
        nn.Tanh(),
    )

  def forward(self,X):
    return self.gen(X)


In [21]:
# Hyperparameters
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64 # 128, 256
image_dim = 28*28*1
batch_size = 32
epochs = 50

In [22]:
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,),(0.5,))]
)

In [23]:
dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)

In [24]:
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [25]:
opt_disc = torch.optim.Adam(disc.parameters(),lr=lr)
opt_gen = torch.optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()
writer_fake = SummaryWriter(f'runs/GAN_MNIST/fake')
writer_real = SummaryWriter(f'runs/GAN_MNIST/real')
step = 0

In [26]:
for epoch in range(epochs):
  for batch_idx, (img,label) in enumerate(loader):
    img = img.view(-1,784).to(device)
    batch_size = img.shape[0]

    ### Training for Discriminator: max log(D(real)) + log(1 - D(G(z)))
    noise = torch.randn(batch_size, z_dim).to(device)
    fake = gen(noise)
    disc_real = disc(img).view(-1)
    lossD_real = criterion(disc_real, torch.ones_like(disc_real))
    disc_fake = disc(fake).view(-1)
    lossD_fake = criterion(disc_fake, torch.ones_like(disc_fake))

    lossD = (lossD_real + lossD_fake) / 2
    disc.zero_grad()
    lossD.backward(retain_graph=True)
    opt_disc.step()



    ### Train Generator min log(1 - D(G(z))) <--> max log(D(G(z)))
    output = disc(fake).view(-1)
    lossG = criterion(output, torch.ones_like(output))
    gen.zero_grad()
    lossG.backward()
    opt_gen.step()


    ### Code for tensor board
    if batch_idx == 0:
      print(
            f"Epoch [{epoch}/{epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

      with torch.no_grad():
          fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
          data = img.reshape(-1, 1, 28, 28)
          img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
          img_grid_real = torchvision.utils.make_grid(data, normalize=True)

          writer_fake.add_image(
            "Mnist Fake Images", img_grid_fake, global_step=step
            )
          writer_real.add_image(
              "Mnist Real Images", img_grid_real, global_step=step
            )
          step += 1






Epoch [0/50] Batch 0/1875                       Loss D: 0.6819, loss G: 0.6532
Epoch [1/50] Batch 0/1875                       Loss D: 0.0000, loss G: 0.0000
Epoch [2/50] Batch 0/1875                       Loss D: 0.0000, loss G: 0.0000
Epoch [3/50] Batch 0/1875                       Loss D: 0.0000, loss G: 0.0000
Epoch [4/50] Batch 0/1875                       Loss D: 0.0000, loss G: 0.0000


KeyboardInterrupt: 