In [None]:
import torch
import torch.nn as nn
import torch.optim as optim # optimizer
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# Reproducibility: weight initialisation, DataLoader shuffling and the latent
# noise all draw from torch's global RNG, so seed it once before anything else.
SEED = 42
torch.manual_seed(SEED)

# Set device: use the GPU when one is visible to PyTorch, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Next, we define a transformation that converts an input image (for example a PIL image or a NumPy array) into a PyTorch tensor.

Specifically, `ToTensor()` converts pixel values (usually in the range [0, 255]) to floating-point values between 0 and 1. `Normalize((0.5,), (0.5,))` then shifts and rescales them to the range [-1, 1].

For example, if you have an RGB image, this transformation will convert it into a 3-channel tensor (red, green, and blue).

In [None]:
# Preprocessing applied to every EMNIST image:
#   1. ToTensor   -> float tensor with values in [0, 1]
#   2. Normalize  -> (x - mean) / std; with mean = std = 0.5 this maps [0, 1] to [-1, 1]
# EMNIST is grayscale, so mean and std are one-element tuples (a single channel).
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean=(0.5,), std=(0.5,))
transform = transforms.Compose([to_tensor, normalize])


In [None]:
# Training portion of EMNIST, downloaded into ./data on first use.
# `split` selects the EMNIST variant (e.g. 'balanced', 'letters', 'digits', ...).
train_dataset = datasets.EMNIST(
    root='./data',
    split='balanced',
    train=True,
    download=True,
    transform=transform,  # preprocessing pipeline defined in the previous cell
)

# Serve the dataset in shuffled mini-batches of 32 images
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

Downloading https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip to ./data/EMNIST/raw/gzip.zip


100%|██████████| 561753746/561753746 [00:12<00:00, 43710681.64it/s]


Extracting ./data/EMNIST/raw/gzip.zip to ./data/EMNIST/raw


Keeping your hyperparameters together in their own cell is a good habit!

In [None]:
# Hyperparameters
num_epochs = 1    # Number of full passes over the training set
latent_dim = 100  # Length of the random noise vector fed to the generator
lr = 2e-4         # Learning rate; smaller values mean smaller steps, so training takes longer

# Adam-specific coefficients (passed as `betas=(beta1, beta2)`)
# A beta closer to 1.0 makes the running average more persistent (longer memory).
beta1 = 0.5    # Exponential decay rate of the first-moment estimate (commonly set to 0.9 or higher)
beta2 = 0.999  # Exponential decay rate of the second-moment estimate

Range: **Tanh** has a range of *-1 to 1*, while **ReLU** has a range of *0 to infinity*.

Non-linearity: Both functions are **non-linear**, but Tanh is symmetric around the origin, while ReLU is NOT.

Derivative: The derivative of Tanh is at most 1 (it equals 1 only at the origin and shrinks towards 0 for large |x|), while the derivative of ReLU is either 0 or 1. (can refer to the picture below)

Vanishing Gradient: Tanh is prone to the vanishing gradient problem, which can slow down training in deep networks. ReLU is less prone to the vanishing gradient problem.

In [None]:
class Generator(nn.Module):
    """Map a latent noise vector to a single-channel image with values in [-1, 1].

    Pipeline: a linear layer expands the noise into a 128-channel feature map of
    size ``init_size x init_size``; two (Upsample x2 -> 3x3 Conv) stages then
    grow it to ``4 * init_size`` pixels per side while reducing the channels to
    one. A final ``Tanh`` bounds the output to [-1, 1], the same range the real
    images have after ``Normalize((0.5,), (0.5,))``.

    Args:
        latent_dim: Length of the input noise vector.
        init_size: Side length of the initial feature map. The default (28)
            keeps the original behaviour and yields 112x112 images; pass 7 to
            generate 28x28 images directly.

    Shape:
        Input ``(N, latent_dim)`` -> output ``(N, 1, 4 * init_size, 4 * init_size)``.
    """

    def __init__(self, latent_dim, init_size=28):
        super().__init__()
        self.init_size = init_size

        self.model = nn.Sequential(
            # Project the noise vector and reshape it into a feature map
            nn.Linear(latent_dim, 128 * init_size * init_size),
            nn.ReLU(),
            nn.Unflatten(1, (128, init_size, init_size)),
            # First upsampling stage: doubles height and width
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            # NOTE(review): in PyTorch, `momentum` is the weight given to the
            # *new* batch statistic (default 0.1). 0.78 looks like a Keras-style
            # value, where the meaning is inverted - confirm this is intended.
            nn.BatchNorm2d(64, momentum=0.78),
            nn.ReLU(),
            # Second upsampling stage: doubles height and width again
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 1, kernel_size=3, padding=1),  # 1 output channel (grayscale)
            # Squash to [-1, 1] so generated pixels live in the same range as
            # the normalised real images seen by the discriminator.
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate a batch of images from a batch of noise vectors ``z``."""
        return self.model(z)


In [None]:
class Discriminator(nn.Module):
    """Convolutional critic that scores single-channel images.

    Four convolution stages (three with stride 2, the last with stride 1) are
    followed by an element-wise Sigmoid. For a ``(N, 1, 28, 28)`` input the
    spatial size goes 28 -> 14 -> 7 -> (zero-pad) 8 -> 4 -> 4, so the result is
    a ``(N, 256, 4, 4)`` map of values in (0, 1), not one score per image.

    NOTE(review): there is no Flatten/Linear head reducing the map to a single
    logit, and Dropout + LeakyReLU sit directly in front of the Sigmoid - confirm
    this is intended. The BatchNorm `momentum` values (0.8-0.82) also look like
    Keras-style settings; in PyTorch `momentum` weights the *new* batch statistic.
    """

    def __init__(self):
        super().__init__()

        # Stage 1: 1 -> 32 channels, halves the spatial size
        layers = [
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25),
        ]

        # Stage 2: 32 -> 64 channels, halves the size, then pads one pixel on
        # the right and bottom edges
        layers += [
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ZeroPad2d((0, 1, 0, 1)),
            nn.BatchNorm2d(64, momentum=0.82),
            nn.LeakyReLU(0.25),
            nn.Dropout(0.25),
        ]

        # Stage 3: 64 -> 128 channels, halves the size
        layers += [
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128, momentum=0.82),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25),
        ]

        # Stage 4: 128 -> 256 channels, stride 1 keeps the spatial size
        layers += [
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256, momentum=0.8),
            nn.LeakyReLU(0.25),
            nn.Dropout(0.25),
        ]

        # Sigmoid squashes every activation into the open interval (0, 1),
        # read as a probability that the input is real
        layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the per-position validity map for a batch of images."""
        return self.model(img)


In [None]:
# Build both networks, then move their parameters onto the selected device
generator = Generator(latent_dim)
generator.to(device)
discriminator = Discriminator()
discriminator.to(device)

# Loss function: binary cross-entropy between predicted probabilities and 0/1 targets
adversarial_loss = nn.BCELoss()
# ALTERNATIVELY, YOU CAN USE HINGE LOSS WHICH IS PERFECT FOR GANs (my biased opinion)
def hinge_loss(real_logits, fake_logits):
    """Hinge-loss pair for GAN training.

    Args:
        real_logits: Discriminator scores for real images (unbounded).
        fake_logits: Discriminator scores for generated images (unbounded).

    Returns:
        ``(D_loss, G_loss)`` as scalar tensors, where
        ``D_loss = mean(max(0, 1 - real)) + mean(max(0, 1 + fake))`` and
        ``G_loss = -mean(fake)``.

    Note:
        The inputs must be raw scores. The Discriminator in this notebook ends
        in ``nn.Sigmoid``, so that layer has to be removed before switching
        from ``adversarial_loss`` to this function.
    """
    # relu(x) == max(0, x); torch.min(0., tensor) is not a valid call
    D_loss = torch.relu(1.0 - real_logits).mean() + torch.relu(1.0 + fake_logits).mean()
    G_loss = -fake_logits.mean()
    return D_loss, G_loss
# Optimizers: one Adam instance per network, both configured identically
def _make_adam(params):
    """Return an Adam optimizer over `params` using the notebook's lr and betas."""
    return optim.Adam(params, lr=lr, betas=(beta1, beta2))


optimizer_G = _make_adam(generator.parameters())
optimizer_D = _make_adam(discriminator.parameters())


In [None]:
import torch.nn.functional as F

# Training loop
# Each iteration performs one discriminator update followed by one generator
# update on the same batch of generated images.
for epoch in range(num_epochs):
    for i, batch in enumerate(dataloader):
        # A batch is indexed as (images, ...); only the images are used here
        real_images = batch[0].to(device)
        batch_size = real_images.size(0)

        # Sample noise as generator input
        z = torch.randn(batch_size, latent_dim, device=device)

        # Generate a batch of images once; it is reused (detached) for the
        # discriminator step and (attached) for the generator step.
        fake_images = generator(z)
        # Resize only when the generator resolution differs from the real data,
        # so the discriminator always sees inputs of one size.
        if fake_images.shape[-2:] != real_images.shape[-2:]:
            fake_images = F.interpolate(fake_images,
                                        size=real_images.shape[-2:],
                                        mode='bilinear',
                                        align_corners=False)

        # ---------------------
        # Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()  # Clear gradients left over from the previous step

        # detach() stops the discriminator loss from back-propagating into the generator
        real_pred = discriminator(real_images)
        fake_pred = discriminator(fake_images.detach())

        # Adversarial ground truths are built from the predictions themselves,
        # so they always match the discriminator's output shape.
        real_loss = adversarial_loss(real_pred, torch.ones_like(real_pred))
        fake_loss = adversarial_loss(fake_pred, torch.zeros_like(fake_pred))
        d_loss = (real_loss + fake_loss) / 2

        # Backward pass and optimize
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        # Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Score the same fakes with the freshly updated discriminator; the
        # generator is rewarded when they are classified as real (target = 1).
        gen_pred = discriminator(fake_images)
        g_loss = adversarial_loss(gen_pred, torch.ones_like(gen_pred))

        # Backward pass and optimize
        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        # Progress Monitoring
        # ---------------------

        if (i + 1) % 100 == 0:
            print(
                f"Epoch [{epoch+1}/{num_epochs}] "
                f"Batch {i+1}/{len(dataloader)} "
                f"Discriminator Loss: {d_loss.item():.4f} "
                f"Generator Loss: {g_loss.item():.4f}"
            )

    # Show generated images every 10 epochs and always after the final epoch,
    # so short runs (e.g. num_epochs = 1) still produce a preview.
    if (epoch + 1) % 10 == 0 or (epoch + 1) == num_epochs:
        # eval() makes BatchNorm use its running statistics and keeps this
        # preview batch from updating them; train() restores training mode.
        generator.eval()
        with torch.no_grad():
            z = torch.randn(16, latent_dim, device=device)
            generated = generator(z).cpu()
        generator.train()

        grid = torchvision.utils.make_grid(generated, nrow=4, normalize=True)
        # make_grid returns (C, H, W); imshow expects (H, W, C)
        plt.imshow(grid.permute(1, 2, 0).numpy())
        plt.axis("off")
        plt.show()


Epoch [1/1] Batch 100/3525 Discriminator Loss: 0.7393 Generator Loss: 0.6100
Epoch [1/1] Batch 200/3525 Discriminator Loss: 0.7259 Generator Loss: 0.6205
Epoch [1/1] Batch 300/3525 Discriminator Loss: 0.7173 Generator Loss: 0.6344
Epoch [1/1] Batch 400/3525 Discriminator Loss: 0.7106 Generator Loss: 0.6445
Epoch [1/1] Batch 500/3525 Discriminator Loss: 0.7054 Generator Loss: 0.6541
Epoch [1/1] Batch 600/3525 Discriminator Loss: 0.7012 Generator Loss: 0.6632


## References
https://towardsdatascience.com/generative-adversarial-networks-gans-a-beginners-guide-f37c9f3b7817

