In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, utils
import numpy as np
import os

# ---- Hyperparameters ----
z_dim = 100             # length of the noise vector fed to the generator
g_feat = 128            # base channel count used by the generator
d_feat = 64             # base channel count used by the discriminator
batch_size = 128
learning_rate = 0.0002
epochs = 30

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Data ----
# Fashion-MNIST training split (downloaded into ./data on first use);
# ToTensor yields image tensors normalized to [0, 1].
transform = transforms.ToTensor()
train_dataset = datasets.FashionMNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Define the Generator network
class Generator(nn.Module):
    """Map a noise vector to a single-channel 28x28 image with pixels in [0, 1].

    The noise is projected by a linear layer into a 7x7 feature map and then
    upsampled twice with transposed convolutions (7x7 -> 14x14 -> 28x28).

    Args:
        latent_dim: Length of the input noise vector. When omitted, the
            module-level ``z_dim`` is used.
        base_channels: Base number of feature maps. When omitted, the
            module-level ``g_feat`` is used.
    """

    def __init__(self, latent_dim=None, base_channels=None):
        super().__init__()
        # Resolve the sizes once and keep them on the instance, so that
        # forward() never has to re-read module globals that a later notebook
        # cell may have rebound after this model was built.
        self.latent_dim = z_dim if latent_dim is None else latent_dim
        self.base_channels = g_feat if base_channels is None else base_channels
        wide = self.base_channels * 2

        # Project the noise into a 7x7 feature map with `wide` channels
        self.fc = nn.Linear(self.latent_dim, wide * 7 * 7)
        # Transposed convolutions to scale up the image
        self.convT1 = nn.ConvTranspose2d(
            wide, self.base_channels, kernel_size=4, stride=2, padding=1
        )  # 7x7 -> 14x14
        self.convT2 = nn.ConvTranspose2d(
            self.base_channels, 1, kernel_size=4, stride=2, padding=1
        )  # 14x14 -> 28x28
        # Batch norms (except for output layer)
        self.bn1 = nn.BatchNorm2d(wide)
        self.bn2 = nn.BatchNorm2d(self.base_channels)

    def forward(self, z):
        """Generate images from noise.

        Args:
            z: Noise tensor whose last dimension is ``latent_dim``.

        Returns:
            Tensor of shape ``[batch, 1, 28, 28]`` with values in [0, 1].
        """
        # Fully connected, then reshape into [batch, base_channels*2, 7, 7]
        x = self.fc(z)
        x = x.view(-1, self.base_channels * 2, 7, 7)
        x = F.relu(self.bn1(x))
        # First upsampling convolution
        x = F.relu(self.bn2(self.convT1(x)))
        # Second upsampling convolution; sigmoid keeps pixels in [0, 1]
        return torch.sigmoid(self.convT2(x))

# Define the Discriminator network
class Discriminator(nn.Module):
    """Score single-channel 28x28 images with a probability of being real.

    Two strided convolutions downsample the image (28x28 -> 14x14 -> 7x7),
    then a linear layer produces one logit per image, which is passed through
    a sigmoid.

    Args:
        base_channels: Base number of feature maps. When omitted, the
            module-level ``d_feat`` is used.
    """

    def __init__(self, base_channels=None):
        super().__init__()
        # Resolve the width once at construction time so the network can be
        # built at other sizes without touching module globals.
        self.base_channels = d_feat if base_channels is None else base_channels
        wide = self.base_channels * 2

        # Convolution layers to downsample the image
        self.conv1 = nn.Conv2d(
            1, self.base_channels, kernel_size=4, stride=2, padding=1
        )  # 28x28 -> 14x14
        self.conv2 = nn.Conv2d(
            self.base_channels, wide, kernel_size=4, stride=2, padding=1
        )  # 14x14 -> 7x7
        self.fc = nn.Linear(wide * 7 * 7, 1)  # Fully connected to single logit output
        # Batch norm (not applied to input layer conv1)
        self.bn2 = nn.BatchNorm2d(wide)

    def forward(self, x):
        """Return the probability that each image is real.

        Args:
            x: Image tensor of shape ``[batch, 1, 28, 28]``.

        Returns:
            Tensor of shape ``[batch, 1]`` with values in [0, 1].
        """
        # Downsample with LeakyReLU activations
        x = F.leaky_relu(self.conv1(x), negative_slope=0.2)
        x = F.leaky_relu(self.bn2(self.conv2(x)), negative_slope=0.2)
        # Flatten everything but the batch dimension; unlike view(), flatten
        # does not require the activation tensor to be contiguous.
        x = torch.flatten(x, start_dim=1)
        logit = self.fc(x)
        return torch.sigmoid(logit)  # probability of being real

# Output directory for every image written during and after training
os.makedirs("gan_outputs", exist_ok=True)

# Binary cross-entropy on probabilities (the discriminator ends in a sigmoid)
criterion = nn.BCELoss()

# Both networks live on the selected device; each gets its own Adam optimizer
G = Generator().to(device)
D = Discriminator().to(device)
opt_G = optim.Adam(
    G.parameters(),
    lr=learning_rate,
    betas=(0.5, 0.999),
)
opt_D = optim.Adam(
    D.parameters(),
    lr=learning_rate,
    betas=(0.5, 0.999),
)

# 16 noise vectors held constant across epochs so sample grids are comparable
fixed_noise = torch.randn(16, z_dim, device=device)

# Adversarial training: each batch updates D once (real + fake), then G once.
for epoch in range(1, epochs + 1):
    G.train()
    D.train()
    g_loss_sum = 0.0
    d_loss_sum = 0.0
    for real_imgs, _ in train_loader:
        real_imgs = real_imgs.to(device)
        batch_size_curr = real_imgs.size(0)  # the last batch may be smaller

        # Targets, one entry per sample: 1 for real, 0 for fake
        real_labels = torch.ones(batch_size_curr, device=device)
        fake_labels = torch.zeros(batch_size_curr, device=device)

        # ---- Train Discriminator ----
        opt_D.zero_grad()
        # reshape(-1) instead of squeeze(): squeeze() would also drop the
        # batch dimension of a single-sample batch, leaving a 0-d tensor whose
        # shape no longer matches the 1-element label tensor.
        out_real = D(real_imgs).reshape(-1)      # D(real)
        loss_real = criterion(out_real, real_labels)
        # Fake images; detach so no gradient reaches G while training D
        z = torch.randn(batch_size_curr, z_dim, device=device)
        fake_imgs = G(z).detach()
        out_fake = D(fake_imgs).reshape(-1)      # D(fake)
        loss_fake = criterion(out_fake, fake_labels)
        # Total discriminator loss
        d_loss = loss_real + loss_fake
        d_loss.backward()
        opt_D.step()

        # ---- Train Generator ----
        opt_G.zero_grad()
        z = torch.randn(batch_size_curr, z_dim, device=device)
        fake_imgs = G(z)                         # new fakes for the G update
        out_fake = D(fake_imgs).reshape(-1)      # D's opinion on these fakes
        # Generator tries to fool discriminator: we want D(fake) to output 1 (real)
        g_loss = criterion(out_fake, real_labels)
        g_loss.backward()
        opt_G.step()

        # Accumulate losses for monitoring
        g_loss_sum += g_loss.item()
        d_loss_sum += d_loss.item()

    # End of epoch: report the mean per-batch losses
    avg_g_loss = g_loss_sum / len(train_loader)
    avg_d_loss = d_loss_sum / len(train_loader)
    print(f"Epoch {epoch}/{epochs} - Generator loss: {avg_g_loss:.4f}, Discriminator loss: {avg_d_loss:.4f}")

    # Save generator outputs on fixed_noise to monitor training progress
    G.eval()
    with torch.no_grad():
        fake_samples = G(fixed_noise).cpu()
    utils.save_image(fake_samples, f"gan_outputs/fixed_samples_epoch{epoch}.png", nrow=4, normalize=True)

# Post-training artifacts: a grid of fresh samples and a latent interpolation strip
G.eval()
with torch.no_grad():
    # 64 images from freshly drawn noise, laid out 8 per row
    z = torch.randn(64, z_dim, device=device)
    fake_images = G(z).cpu()
    utils.save_image(fake_images, "gan_outputs/generated_images.png", nrow=8, normalize=True)

    # Two random endpoints in latent space
    z_start = torch.randn(1, z_dim, device=device)
    z_end = torch.randn(1, z_dim, device=device)
    # Ten evenly spaced blend weights from 0 to 1
    alphas = torch.linspace(0, 1, steps=10).to(device)
    # One generator pass per blend weight, stacked along the batch dimension
    interp_images = torch.cat(
        [G((1 - alpha) * z_start + alpha * z_end).cpu() for alpha in alphas],
        dim=0,
    )
    utils.save_image(interp_images, "gan_outputs/interpolation.png", nrow=10, normalize=True)
print("GAN training complete. Check 'gan_outputs/' for generated images.")


100%|██████████| 26.4M/26.4M [00:02<00:00, 11.7MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 194kB/s]
100%|██████████| 4.42M/4.42M [00:01<00:00, 3.70MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 10.2MB/s]


Epoch 1/30 - Generator loss: 1.9982, Discriminator loss: 0.5149
Epoch 2/30 - Generator loss: 2.1473, Discriminator loss: 0.5013
Epoch 3/30 - Generator loss: 1.9226, Discriminator loss: 0.5712
Epoch 4/30 - Generator loss: 1.6871, Discriminator loss: 0.6689


KeyboardInterrupt: 