In [1]:
import math
import pickle as pkl
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Load data
# MNIST training split. ToTensor converts each image to a float tensor in
# [0, 1]; train_gan later rescales batches to [-1, 1] before use.
transform = transforms.Compose([transforms.ToTensor()])
# download=True fetches the dataset into ./data on first run (network I/O).
train_ds = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# shuffle=True reshuffles the sample order every time the loader is iterated
# (i.e. once per epoch); the last batch may be smaller than 64.
dl = DataLoader(dataset=train_ds, shuffle=True, batch_size=64)

# Architectures
class Discriminator(nn.Module):
    """MLP discriminator mapping a batch of images to raw real/fake logits.

    Args:
        in_features: Number of features per sample after flattening
            (784 = 28 * 28 for MNIST).
        out_features: Width of the output layer (1 logit by default).
        hidden_dims: Widths of the hidden layers, in order.
        dropout: Dropout probability applied after every hidden activation.

    With the default arguments the network is 784-128-64-32-1 with the same
    module indices as before, so previously saved state_dicts still load.
    """

    def __init__(self, in_features=784, out_features=1, hidden_dims=(128, 64, 32), dropout=0.3):
        super().__init__()
        layers = []
        width = in_features
        for hidden in hidden_dims:
            # Each hidden stage is Linear -> LeakyReLU(0.2) -> Dropout.
            layers += [nn.Linear(width, hidden), nn.LeakyReLU(0.2), nn.Dropout(dropout)]
            width = hidden
        # No final sigmoid: the output is a raw logit, meant to be paired with
        # a logits-based criterion such as nn.BCEWithLogitsLoss.
        layers.append(nn.Linear(width, out_features))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten all non-batch dims and return logits of shape (batch, out_features)."""
        # reshape() rather than view(): view() raises on non-contiguous inputs
        # (e.g. transposed or sliced image batches); reshape copies only when needed.
        x = x.reshape(x.shape[0], -1)
        return self.model(x)

class Generator(nn.Module):
    """MLP generator turning latent vectors into flattened images.

    Hidden widths grow 32 -> 64 -> 128; every hidden stage is
    Linear -> LeakyReLU(0.2) -> Dropout(0.3), and a closing Tanh bounds the
    output pixels to [-1, 1].
    """

    def __init__(self, in_features=100, out_features=784):
        super().__init__()
        widths = [in_features, 32, 64, 128]
        stages = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages.extend((nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2), nn.Dropout(0.3)))
        stages.extend((nn.Linear(widths[-1], out_features), nn.Tanh()))
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Map a (batch, in_features) latent batch to (batch, out_features) samples."""
        return self.model(x)

# Loss functions
def compute_loss(logits, target_value, loss_fn, device=None):
    """Compute ``loss_fn`` of ``logits`` against a constant per-sample target.

    Use ``target_value=1.0`` to score outputs as real and ``0.0`` to score
    them as fake.

    Args:
        logits: Model outputs whose first dimension is the batch, e.g. a
            (batch, 1) tensor of discriminator logits.
        target_value: Constant label assigned to every sample.
        loss_fn: Criterion called as ``loss_fn(input, target)``.
        device: Device for the target tensor; defaults to ``logits.device``.

    Returns:
        Whatever ``loss_fn`` returns for a (batch,) input and (batch,) target.
    """
    batch_size = logits.shape[0]
    if logits.dim() > 1:
        # Collapse (batch, 1, ...) to (batch,). A bare .squeeze() would also
        # drop the batch dimension when batch_size == 1, leaving a 0-d input
        # whose shape no longer matches the (batch,) target.
        logits = logits.flatten(start_dim=1)
        if logits.shape[1] == 1:
            logits = logits.squeeze(1)
    # Use the logits' dtype so integer labels (1 / 0) or half-precision
    # logits don't produce a dtype mismatch inside the loss.
    targets = torch.full(
        (batch_size,),
        target_value,
        dtype=logits.dtype,
        device=logits.device if device is None else device,
    )
    return loss_fn(logits, targets)

# Training
def _discriminator_step(d, g, d_optim, loss_fn, real_images, device, z_size):
    """Run one discriminator update on a real batch already in [-1, 1]; return its loss."""
    d_optim.zero_grad()

    # Real images are scored against target 1.
    d_real_loss = compute_loss(d(real_images), 1.0, loss_fn, device)

    # Fresh generator samples are scored against target 0; detach() keeps
    # this update from backpropagating into the generator.
    z = torch.randn(real_images.size(0), z_size, device=device)
    fake_images = g(z).detach()
    d_fake_loss = compute_loss(d(fake_images), 0.0, loss_fn, device)

    d_loss = d_real_loss + d_fake_loss
    d_loss.backward()
    d_optim.step()
    return d_loss.item()


def _generator_step(d, g, g_optim, loss_fn, batch_size, device, z_size):
    """Run one generator update on a fresh latent batch; return its loss."""
    g_optim.zero_grad()

    # Fakes are scored against target 1 (with BCE this is the non-saturating
    # generator loss). This backward pass also writes gradients into d's
    # parameters; d_optim.zero_grad() in the next discriminator step clears
    # them, assuming d_optim owns all of d's parameters.
    z = torch.randn(batch_size, z_size, device=device)
    g_loss = compute_loss(d(g(z)), 1.0, loss_fn, device)

    g_loss.backward()
    g_optim.step()
    return g_loss.item()


def train_gan(d, g, d_optim, g_optim, loss_fn, dl, n_epochs, device, z_size=100,
              samples_path='fixed_samples.pkl'):
    """Train a GAN with alternating discriminator / generator updates per batch.

    Args:
        d: Discriminator module; moved to ``device``.
        g: Generator module taking (batch, z_size) latents; moved to ``device``.
        d_optim: Optimizer over d's parameters.
        g_optim: Optimizer over g's parameters.
        loss_fn: Criterion passed through to ``compute_loss``.
        dl: Sized iterable of (images, labels) batches. Images are assumed to
            be in [0, 1] (e.g. from ToTensor) and are rescaled to [-1, 1].
        n_epochs: Number of passes over ``dl``.
        device: Device to train on.
        z_size: Latent dimensionality fed to ``g``.
        samples_path: File the per-epoch samples are pickled to. The samples
            come from one fixed latent batch and are saved as a list of CPU
            tensors, one per epoch.

    Returns:
        ``(d_losses, g_losses)``: per-epoch mean losses as lists of floats.

    Note:
        ``g`` is left in eval mode on return (it is switched for sampling).
    """
    print(f'Training on [{device}]...')

    # Fixed latent batch so samples are comparable across epochs.
    fixed_z = torch.randn(16, z_size, device=device)
    fixed_samples = []
    d_losses, g_losses = [], []

    d.to(device)
    g.to(device)

    for epoch in range(n_epochs):
        d.train()
        g.train()
        d_running_loss = 0
        g_running_loss = 0

        for real_images, _ in dl:
            # Move to device and rescale [0, 1] -> [-1, 1].
            real_images = real_images.to(device) * 2 - 1
            batch_size = real_images.size(0)

            d_running_loss += _discriminator_step(d, g, d_optim, loss_fn, real_images, device, z_size)
            g_running_loss += _generator_step(d, g, g_optim, loss_fn, batch_size, device, z_size)

        # Per-epoch mean of the per-batch losses.
        d_losses.append(d_running_loss / len(dl))
        g_losses.append(g_running_loss / len(dl))

        print(f'Epoch [{epoch+1}/{n_epochs}] - D_loss: {d_losses[-1]:.4f}, G_loss: {g_losses[-1]:.4f}')

        # Snapshot generator output on the fixed latents (dropout disabled).
        g.eval()
        with torch.no_grad():
            fixed_samples.append(g(fixed_z).cpu())

    with open(samples_path, 'wb') as f:
        pkl.dump(fixed_samples, f)

    return d_losses, g_losses

# Setup and training
# Prefer the GPU when one is visible; train_gan moves both models onto it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Default-sized models (100-dim latent -> 784-pixel images) plus one Adam per
# model. NOTE(review): lr=2e-3 is 10x the 2e-4 commonly used for GANs (often
# with betas=(0.5, 0.999)); confirm this is intended.
d, g = Discriminator(), Generator()
d_optim, g_optim = (optim.Adam(net.parameters(), lr=0.002) for net in (d, g))

# The discriminator ends in a bare Linear, so use the logits form of BCE.
loss_fn = nn.BCEWithLogitsLoss()

# Train
n_epochs = 100
d_losses, g_losses = train_gan(
    d, g, d_optim, g_optim, loss_fn, dl, n_epochs, device,
)


100%|██████████| 9.91M/9.91M [00:00<00:00, 10.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 344kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.16MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.86MB/s]


Training on [cuda]...
Epoch [1/100] - D_loss: 1.1749, G_loss: 1.6221
Epoch [2/100] - D_loss: 1.2778, G_loss: 1.0727
Epoch [3/100] - D_loss: 0.8129, G_loss: 1.6295
Epoch [4/100] - D_loss: 1.0287, G_loss: 1.4112
Epoch [5/100] - D_loss: 1.0626, G_loss: 1.3379
Epoch [6/100] - D_loss: 1.0472, G_loss: 1.3881
Epoch [7/100] - D_loss: 1.1335, G_loss: 1.1917
Epoch [8/100] - D_loss: 1.1721, G_loss: 1.1215
Epoch [9/100] - D_loss: 1.1580, G_loss: 1.1515
Epoch [10/100] - D_loss: 1.1929, G_loss: 1.0591
Epoch [11/100] - D_loss: 1.2373, G_loss: 0.9806
Epoch [12/100] - D_loss: 1.2512, G_loss: 0.9451
Epoch [13/100] - D_loss: 1.2479, G_loss: 0.9647
Epoch [14/100] - D_loss: 1.2114, G_loss: 1.0630
Epoch [15/100] - D_loss: 1.1899, G_loss: 1.0645
Epoch [16/100] - D_loss: 1.2174, G_loss: 0.9992
Epoch [17/100] - D_loss: 1.2410, G_loss: 0.9780
Epoch [18/100] - D_loss: 1.2201, G_loss: 1.0007
Epoch [19/100] - D_loss: 1.2331, G_loss: 0.9924
Epoch [20/100] - D_loss: 1.2278, G_loss: 1.0020
Epoch [21/100] - D_loss: 1.