In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

from model import Discriminator, Generator

Hyperparameters:

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64
image_dim = 28 * 28 * 1
batch_size = 32
num_epochs = 200

Models:

In [3]:
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)

Fixed noise for comparison:

In [4]:
fixed_noise = torch.rand((batch_size, z_dim)).to(device)

Data:

In [5]:
transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ]
)

In [6]:
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

Optimizer, Loss and Metric:

In [7]:
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()

Tensorboard stuff:

In [8]:
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")
step = 0

Training:

In [9]:
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # keep the batch size the same, but flatten the images
        real = real.view(-1, 784).to(device)
        batch_size = real.shape[0]

        # Train Discriminator: max log(D(real)) + log(1 - D(G(z)))
        noise = torch.rand(batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        # BCE: -wn * (y * log(x) + (1 - y) * log(1 - x))
        # if we set y to 1, we get -wn * log(x)
        # and maximizing log(x) is the same as minimizing -log(x)
        # and we normally want to minimize the loss function
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake).view(-1)
        # BCE: -wn * (y * log(x) + (1 - y) * log(1 - x))
        # if we set y to 0, we get -wn * log(1 - x)
        # and maximizing log(1 - x) is the same as minimizing -log(1 - x)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        # retain_graph=True because we want to reuse the fake variable again which is part of the computational graph
        lossD.backward(retain_graph=True)
        opt_disc.step()

        # Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                  Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image("Mnist Fake Images", img_grid_fake, global_step=step)
                writer_real.add_image("Mnist Real Images", img_grid_real, global_step=step)
                step += 1

Epoch [0/200] Batch 0/1875                   Loss D: 0.6808, loss G: 0.6816
Epoch [1/200] Batch 0/1875                   Loss D: 0.0116, loss G: 4.7049
Epoch [2/200] Batch 0/1875                   Loss D: 0.0005, loss G: 7.4484
Epoch [3/200] Batch 0/1875                   Loss D: 0.0003, loss G: 8.2157
Epoch [4/200] Batch 0/1875                   Loss D: 0.0000, loss G: 11.0840
Epoch [5/200] Batch 0/1875                   Loss D: 0.0000, loss G: 12.8336
Epoch [6/200] Batch 0/1875                   Loss D: 0.0000, loss G: 13.6879
Epoch [7/200] Batch 0/1875                   Loss D: 0.0000, loss G: 15.0445
Epoch [8/200] Batch 0/1875                   Loss D: 0.0000, loss G: 17.0174
Epoch [9/200] Batch 0/1875                   Loss D: 0.0000, loss G: 17.8752
Epoch [10/200] Batch 0/1875                   Loss D: 0.0000, loss G: 16.8189
Epoch [11/200] Batch 0/1875                   Loss D: 0.0000, loss G: 18.2589
Epoch [12/200] Batch 0/1875                   Loss D: 0.0000, loss G: 19.4139
