In [None]:
# Import PyTorch/torchvision dependencies plus the project-local utils and gan_models modules
import torch
from torch import nn
import numpy as np
from torchvision.utils import save_image
import os

import utils
from gan_models import Generator, Discriminator

# Select the compute device through the project helper; the models and every
# training tensor below are moved onto this device.
# NOTE(review): presumably returns a torch.device or device string — confirm in utils.
device = utils.get_device()

# 1. Dataset

In [None]:
# Dataset configuration: mini-batch size and the image side length passed to the loader helper.
BATCH_SIZE = 64
img_size = 32

# Build the MNIST dataloader through the project helper.
dataloader = utils.get_mnist_dataloader(
    batch_size=BATCH_SIZE,
    img_size=img_size,
)

# 2. Model

In [None]:
# Model hyperparameters.
latent_dim = 100  # size of the latent (noise) input handed to the generator
channels = 1
# Channel-first shape of a square image: (channels, img_size, img_size).
img_shape = (channels,) + 2 * (img_size,)

In [None]:
# Instantiate both networks: each receives the image shape, and the generator
# additionally receives the latent dimension.
generator = Generator(
    img_shape=img_shape,
    latent_dim=latent_dim,
)
discriminator = Discriminator(
    img_shape=img_shape,
)

In [None]:
# Move both networks onto the selected device. The return values are not
# reassigned, so this relies on .to() updating the objects in place.
# NOTE(review): that holds for torch.nn.Module — confirm Generator/Discriminator subclass it.
generator.to(device)
discriminator.to(device)

# 3. Training

In [None]:
# Directory to create for this run's outputs, and the epoch interval between sample snapshots.
save_interval = 10
output_dir = "./images_gan"

# Create the output directory via the project helper.
utils.create_dir(output_dir)

In [None]:
EPOCHS = 200

# One Adam optimizer per network; the generator runs at half the discriminator's learning rate.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0001)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002)

# Binary cross-entropy between the discriminator's output and the real/fake targets.
# NOTE(review): BCELoss expects probabilities in [0, 1], so the discriminator
# presumably ends in a sigmoid and outputs shape (batch, 1) — confirm in gan_models.
criterion = nn.BCELoss()

# Mean loss per epoch for each network; one value is appended per epoch.
hist = {
    "train_G_loss": [],
    "train_D_loss": [],
}

for epoch in range(EPOCHS):
    running_G_loss = 0.0
    running_D_loss = 0.0

    for i, (imgs, _) in enumerate(dataloader):
        # Size everything from the batch itself, since a batch may be smaller than BATCH_SIZE.
        n_batch = imgs.shape[0]

        real_imgs = imgs.to(device)
        # Allocate targets directly on the device instead of on CPU followed by a copy.
        real_labels = torch.ones(n_batch, 1, device=device)
        fake_labels = torch.zeros(n_batch, 1, device=device)

        # -------------------------- Train Generator ---
        optimizer_G.zero_grad()

        # Noise input for the generator, sampled from a standard normal.
        z = torch.randn(n_batch, latent_dim, device=device)

        gen_imgs = generator(z)
        # The generator is rewarded when the discriminator labels its samples as real.
        G_loss = criterion(discriminator(gen_imgs), real_labels)
        running_G_loss += G_loss.item()

        G_loss.backward()
        optimizer_G.step()

        # -------------------------- Train Discriminator ---
        # zero_grad() also discards the discriminator gradients accumulated by G_loss.backward().
        optimizer_D.zero_grad()
        real_loss = criterion(discriminator(real_imgs), real_labels)
        # .detach() stops this loss from back-propagating into the generator.
        fake_loss = criterion(discriminator(gen_imgs.detach()), fake_labels)
        D_loss = (real_loss + fake_loss) / 2
        running_D_loss += D_loss.item()

        D_loss.backward()
        optimizer_D.step()

    # Average the per-batch losses over the number of batches in the epoch.
    epoch_G_loss = running_G_loss / len(dataloader)
    epoch_D_loss = running_D_loss / len(dataloader)

    print(f"Epoch [{epoch + 1}/{EPOCHS}], Train G Loss: {epoch_G_loss:.4f}, Train D Loss: {epoch_D_loss:.4f}")

    hist["train_G_loss"].append(epoch_G_loss)
    hist["train_D_loss"].append(epoch_D_loss)

    if epoch % save_interval == 0:
        # Save into output_dir, which was created earlier in the notebook. The previous
        # hard-coded "images/" folder is never created, so save_image would fail there.
        sample_path = os.path.join(output_dir, f"epoch_{epoch}.png")
        # gen_imgs holds the last batch generated in this epoch; save up to 25 samples as a 5-wide grid.
        save_image(gen_imgs.detach()[:25], sample_path, nrow=5, normalize=True)