In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [2]:
class Discriminator(nn.Module):
    """Two-layer MLP that scores a flattened image as real or fake.

    Architecture: Linear(img_dim -> 128) -> LeakyReLU(0.1) -> Linear(128 -> 1)
    -> Sigmoid, so every input row yields one value in (0, 1).

    Args:
        img_dim: Number of features in one flattened input image.
    """

    def __init__(self, img_dim):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(img_dim, hidden_units),
            nn.LeakyReLU(0.1),  # LeakyReLU is the customary activation in GANs
            nn.Linear(hidden_units, 1),  # single score: fake = 0, real = 1
            nn.Sigmoid(),
        ]
        # Attribute name `d` is kept so state-dict keys stay `d.<index>.*`.
        self.d = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid score, shape (N, 1), for inputs of shape (N, img_dim)."""
        score = self.d(x)
        return score

class Generator(nn.Module):
    """Two-layer MLP that maps latent noise to a flattened image.

    Architecture: Linear(z_dim -> 256) -> LeakyReLU(0.1) -> Linear(256 -> img_dim)
    -> Tanh, so every output value lies in [-1, 1].

    Args:
        z_dim: Dimension of the latent noise vector.
        img_dim: Number of features in one flattened output image.
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        hidden_units = 256
        layers = [
            nn.Linear(z_dim, hidden_units),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_units, img_dim),  # emits the artificial image
            # Tanh keeps outputs in [-1, 1]; the training images are expected
            # to be normalized to that same range.
            nn.Tanh(),
        ]
        # Attribute name `g` is kept so state-dict keys stay `g.<index>.*`.
        self.g = nn.Sequential(*layers)

    def forward(self, x):
        """Return generated samples, shape (N, img_dim), for noise of shape (N, z_dim)."""
        generated = self.g(x)
        return generated

In [6]:
# Hyperparameters
# GANs are very sensitive to the chosen hyperparameters, so everything a
# reader might want to tune is collected in this one cell.

# Seed PyTorch's global RNG before any random tensor or model is created so
# that Restart Kernel -> Run All produces the same FIXED_NOISE (and, when the
# following cells are run in order, the same initial weights) on every run.
SEED = 42
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LR = 3e-4  # popular Adam learning rate (attributed to Andrej Karpathy)
Z_DIM = 64  # latent noise dimension
IMG_DIM = 784  # flattened MNIST image: 28 * 28 * 1
BATCH_SIZE = 32
EPOCH = 10
# Fixed latent batch reused at every logging step so that the logged generator
# samples are comparable across epochs.
FIXED_NOISE = torch.randn(BATCH_SIZE, Z_DIM, device=DEVICE)

In [4]:
# Build both networks and move them to the configured device.
d = Discriminator(IMG_DIM).to(DEVICE)
g = Generator(Z_DIM, IMG_DIM).to(DEVICE)

# The pipeline gets its own name. Assigning it to `transforms` would rebind
# the imported `torchvision.transforms` module, and re-running this cell would
# then fail with AttributeError (a Compose object has no `.Compose`).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    # One-element tuples for the single grayscale channel; `(0.5)` is just the
    # float 0.5. (x - 0.5) / 0.5 rescales ToTensor's [0, 1] output to [-1, 1],
    # the range of the generator's tanh output.
    transforms.Normalize((0.5,), (0.5,)),
])

dataset = datasets.MNIST(root='./data/', transform=mnist_transform, download=True)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# One Adam optimizer per network, sharing the same learning rate.
d_optim = optim.Adam(d.parameters(), lr=LR)
g_optim = optim.Adam(g.parameters(), lr=LR)
criterion = nn.BCELoss()  # binary cross-entropy, as in the original GAN paper

# Separate TensorBoard writers for generated and real image grids.
sum_writer_fake = SummaryWriter("runs/ezGAN/fake_data")
sum_writer_real = SummaryWriter("runs/ezGAN/real_data")
step = 0  # global step counter for the logged image grids

In [7]:
# Alternating GAN training: one discriminator update followed by one generator
# update per mini-batch. Image grids are logged to TensorBoard on the first
# batch of every epoch.
for epoch in range(EPOCH):
    for idx, (real, _) in enumerate(loader):
        # Flatten each image to a vector of IMG_DIM features.
        real = real.view(-1, IMG_DIM).to(DEVICE)
        # Draw as many fakes as there are reals so both halves of the
        # discriminator loss average over the same number of samples, even if
        # the last batch is smaller than BATCH_SIZE.
        n_samples = real.size(0)

        # --- Discriminator ---
        # Objective: maximize log(D(x)) + log(1 - D(G(z))).
        z_noise = torch.randn(n_samples, Z_DIM, device=DEVICE)  # gaussian noise
        fake = g(z_noise)
        d_real = d(real).view(-1)
        loss_d_real = criterion(d_real, torch.ones_like(d_real))  # -log(D(x))
        # detach(): the discriminator update must not backpropagate into the
        # generator. This also leaves the generator's graph untouched for the
        # generator step below, so retain_graph=True is no longer needed.
        d_fake = d(fake.detach()).view(-1)
        loss_d_fake = criterion(d_fake, torch.zeros_like(d_fake))  # -log(1 - D(G(z)))
        # Average of the real and fake halves. The division by 2 only rescales
        # the discriminator loss (presumably a convention to temper D's updates
        # relative to G's — TODO confirm intent).
        loss_d = (loss_d_real + loss_d_fake) / 2
        d_optim.zero_grad()
        loss_d.backward()
        d_optim.step()

        # --- Generator ---
        # Minimizing log(1 - D(G(z))) leads to vanishing gradients; the usual
        # workaround is to maximize log(D(G(z))) instead.
        out = d(fake).view(-1)  # re-score the fakes with the updated discriminator
        loss_g = criterion(out, torch.ones_like(out))  # -log(D(G(z)))
        g_optim.zero_grad()
        loss_g.backward()
        g_optim.step()

        # --- Logging (first batch of each epoch only) ---
        if idx == 0:
            print(
                f"Epoch [{epoch}/{EPOCH}] Batch {idx}/{len(loader)} \
                      Loss D: {loss_d:.4f}, loss G: {loss_g:.4f}"
            )

            with torch.no_grad():
                # Separate names so the training tensors `fake` / `real` are
                # not overwritten by the logging code.
                fixed_fake = g(FIXED_NOISE).reshape(-1, 1, 28, 28)
                real_images = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(
                    fixed_fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(
                    real_images, normalize=True)

                sum_writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                sum_writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1

# Push any buffered events to disk so the final epoch's images are visible in
# TensorBoard as soon as training ends; the writers remain usable afterwards.
sum_writer_fake.flush()
sum_writer_real.flush()


Epoch [0/10] Batch 0/1875                       Loss D: 0.2918, loss G: 1.6669
Epoch [1/10] Batch 0/1875                       Loss D: 0.2996, loss G: 1.4723
Epoch [2/10] Batch 0/1875                       Loss D: 0.8469, loss G: 0.6616
Epoch [3/10] Batch 0/1875                       Loss D: 0.4883, loss G: 1.4582
Epoch [4/10] Batch 0/1875                       Loss D: 0.5857, loss G: 1.1705
Epoch [5/10] Batch 0/1875                       Loss D: 0.3918, loss G: 1.5993
Epoch [6/10] Batch 0/1875                       Loss D: 0.5467, loss G: 1.4018
Epoch [7/10] Batch 0/1875                       Loss D: 0.4233, loss G: 1.3390
Epoch [8/10] Batch 0/1875                       Loss D: 0.5797, loss G: 0.8857
Epoch [9/10] Batch 0/1875                       Loss D: 0.6779, loss G: 1.0047


![hehe](./asset/FP5rFJEXsAIQ8FA.jpg)