In [26]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from models import Discriminator, Generator,intialize_weights

In [28]:
# HYPERPARAMETERS
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

SEED = 42
# Seed PyTorch's RNGs (CPU and all CUDA devices). Weight init, latent noise and
# DataLoader shuffling then repeat across Restart & Run All.
# (This does not by itself force deterministic cuDNN kernels.)
torch.manual_seed(SEED)

LEARNING_RATE = 2e-4  # Adam step size; small values help keep GAN training stable
BATCH_SIZE = 128      # images per training iteration
IMAGE_SIZE = 64       # images are resized to IMAGE_SIZE x IMAGE_SIZE
CHANNELS_IMG = 1      # 1 for grayscale (MNIST), 3 for RGB datasets
Z_DIM = 100           # length of the latent vector fed to the generator
NUM_EPOCHS = 5        # full passes over the training set
FEATURES_DISC = 64    # starting feature count passed to the Discriminator constructor
FEATURES_GEN = 64     # starting feature count passed to the Generator constructor
NOISE_DIM = Z_DIM     # alias of Z_DIM; derived so the two can never disagree

# TRANSFORMS
# The name `transforms` is rebound below to the composed pipeline, because the
# dataset cell passes it as `transform=transforms`. Re-importing the module here
# is therefore what keeps this cell safe to re-run.
from torchvision import transforms

transforms = transforms.Compose(
    [
        # Explicit (h, w) so non-square datasets also come out IMAGE_SIZE x IMAGE_SIZE.
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        # PIL image -> float tensor with values in [0, 1].
        transforms.ToTensor(),
        # Per-channel (x - 0.5) / 0.5 maps [0, 1] to [-1, 1].
        # NOTE(review): presumably chosen to match a tanh generator output;
        # confirm in models.py.
        transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
    ]
)
# COMPOSE COMBINES MULTIPLE TRANSFORMS INTO A SINGLE TRANSFORM
# ADVANTAGES: ALLOWS FOR EASY APPLICATION OF MULTIPLE PREPROCESSING STEPS

In [30]:
# DATASET
# MNIST training split, preprocessed by the `transforms` pipeline built in the
# hyperparameter cell. download=True fetches the files only when they are not
# already present under `root`; existing data is reused. A fresh checkout can
# therefore Restart & Run All instead of failing with "Dataset not found".
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms, download=True)

# CREATE DATA LOADER
# Reshuffled every epoch. drop_last is left at its default (False), so the last
# batch of an epoch may be smaller than BATCH_SIZE.
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# INITIALIZE GENERATOR AND DISCRIMINATOR
# Build both networks on `device`, then apply the project's custom weight
# initialisation. `intialize_weights` is spelled as exported by models.py.
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
intialize_weights(gen)
intialize_weights(disc)



In [31]:
# LOSS
# Binary cross-entropy, shared by the discriminator and generator objectives.
# NOTE(review): nn.BCELoss expects probabilities in [0, 1]. This assumes the
# Discriminator ends in a sigmoid; confirm in models.py.
criterion = nn.BCELoss()

# OPTIMIZERS
# One Adam instance per network, both with the same hyperparameters.
# beta1 = 0.5 (PyTorch's default is 0.9) damps momentum, a common GAN setting.
opt_gen, opt_disc = (
    optim.Adam(net.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    for net in (gen, disc)
)

# TENSORBOARD
# A fixed batch of 32 latent vectors is reused at every logging step, so the
# generated samples can be compared over the course of training.
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
writer_real, writer_fake = SummaryWriter("logs/real"), SummaryWriter("logs/fake")
step = 0  # global_step counter for the logged images

In [32]:
# TRAINING LOOP
gen.train()
disc.train()

# SET GENERATOR AND DISCRIMINATOR TO TRAINING MODE
# ENABLES GRADIENT COMPUTATION AND BATCH NORMALIZATION UPDATES
# CRUCIAL FOR PROPER TRAINING OF NEURAL NETWORKS

for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        noise = torch.randn((BATCH_SIZE, Z_DIM, 1, 1)).to(device)
        fake = gen(noise)

        # ITERATE THROUGH EPOCHS AND BATCHES
        # MOVE REAL IMAGES TO DEVICE (GPU/CPU)
        # GENERATE RANDOM NOISE AND CREATE FAKE IMAGES
        # ENSURES EFFICIENT PROCESSING AND CONSISTENT TRAINING ACROSS DEVICES

        # TRAIN DISCRIMINATOR
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # COMPUTE DISCRIMINATOR LOSS FOR REAL AND FAKE IMAGES
        # USE BINARY CROSS-ENTROPY LOSS
        # DETACH FAKE IMAGES TO PREVENT GENERATOR UPDATES
        # AVERAGE REAL AND FAKE LOSSES
        # PERFORM BACKPROPAGATION AND OPTIMIZATION
        # IMPROVES DISCRIMINATOR'S ABILITY TO DISTINGUISH REAL FROM FAKE

        # TRAIN GENERATOR
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # COMPUTE GENERATOR LOSS
        # USE BINARY CROSS-ENTROPY LOSS
        # PERFORM BACKPROPAGATION AND OPTIMIZATION
        # IMPROVES GENERATOR'S ABILITY TO PRODUCE REALISTIC IMAGES

        # PRINT LOSSES AND VISUALIZE PROGRESS
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} \
                  Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)
            step += 1

        # PERIODICALLY PRINT TRAINING PROGRESS
        # GENERATE FAKE IMAGES FROM FIXED NOISE FOR CONSISTENCY
        # CREATE IMAGE GRIDS FOR REAL AND FAKE IMAGES
        # LOG IMAGES TO TENSORBOARD FOR VISUALIZATION
        # HELPS MONITOR TRAINING PROGRESS AND QUALITY OF GENERATED IMAGES

Epoch [0/5] Batch 0/469                   Loss D: 0.6961, loss G: 0.7778
