In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image, make_grid
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, Subset
import numpy as np
import torch.nn.utils.spectral_norm as spectral_norm

In [2]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [3]:
# Directories
image_dir = "celebA/celeba/img_align_celeba"
os.makedirs("gan_outputs", exist_ok=True)

In [4]:
# Hyperparameters
z_dim = 100
lr = 2e-4
batch_size = 3000
n_epochs = 60
patience = 10

In [5]:
# Transformations
transform = transforms.Compose([
    transforms.CenterCrop(160),
    transforms.Resize(64),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [6]:
# Dataset and splits
full_dataset = datasets.ImageFolder(root=image_dir, transform=transform)
total_size = len(full_dataset)
train_size = int(0.8 * total_size)
val_size = total_size - train_size  # ensure the sum matches exactly

train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [7]:
class Generator(nn.Module):
    def __init__(self, z_dim=100, img_channels=3, features_g=64):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(z_dim, features_g * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(features_g * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(features_g * 8, features_g * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(features_g * 4, features_g * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(features_g * 2, features_g, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g),
            nn.ReLU(True),
            nn.ConvTranspose2d(features_g, img_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    def __init__(self, img_channels=3, features_d=64):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(img_channels, features_d, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(features_d, features_d * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_d * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_d * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(features_d * 4, features_d * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_d * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(features_d * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.net(img).view(-1)

In [8]:
# Initialize models and optimizers
G = Generator(z_dim).to(device)
D = Discriminator().to(device)

In [9]:
criterion = nn.BCEWithLogitsLoss()
optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

fixed_noise = torch.randn(64, z_dim, 1, 1, device=device)

In [10]:
best_val_loss = float("inf")
epochs_no_improve = 0
train_losses = []
val_losses = []

# Training loop
for epoch in range(n_epochs):
    G.train()
    D.train()
    running_loss_G = 0.0
    running_loss_D = 0.0

    print(f"Epoch [{epoch+1}/{n_epochs}]")
    for real_imgs, _ in tqdm(train_loader):
        real_imgs = real_imgs.to(device)
        batch_size = real_imgs.size(0)

        # Labels
        real_labels = torch.ones(batch_size, device=device)
        fake_labels = torch.zeros(batch_size, device=device)

        # Train Discriminator
        z = torch.randn(batch_size, z_dim, 1, 1, device=device)
        fake_imgs = G(z).detach()
        D_real = D(real_imgs)
        D_fake = D(fake_imgs)
        loss_D = criterion(D_real, real_labels) + criterion(D_fake, fake_labels)

        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # Train Generator
        z = torch.randn(batch_size, z_dim, 1, 1, device=device)
        fake_imgs = G(z)
        D_fake = D(fake_imgs)
        loss_G = criterion(D_fake, real_labels)

        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        running_loss_D += loss_D.item()
        running_loss_G += loss_G.item()

    avg_train_loss_G = running_loss_G / len(train_loader)
    avg_train_loss_D = running_loss_D / len(train_loader)
    train_losses.append((avg_train_loss_D, avg_train_loss_G))

    # Validation step
    G.eval()
    D.eval()
    val_loss_D = 0.0
    val_loss_G = 0.0
    with torch.no_grad():
        for real_imgs, _ in val_loader:
            real_imgs = real_imgs.to(device)
            batch_size = real_imgs.size(0)
            real_labels = torch.ones(batch_size, device=device)
            fake_labels = torch.zeros(batch_size, device=device)

            # Discriminator validation
            z = torch.randn(batch_size, z_dim, 1, 1, device=device)
            fake_imgs = G(z)
            D_real = D(real_imgs)
            D_fake = D(fake_imgs)
            loss_D = criterion(D_real, real_labels) + criterion(D_fake, fake_labels)

            # Generator validation
            z = torch.randn(batch_size, z_dim, 1, 1, device=device)
            fake_imgs = G(z)
            D_fake = D(fake_imgs)
            loss_G = criterion(D_fake, real_labels)

            val_loss_D += loss_D.item()
            val_loss_G += loss_G.item()

    avg_val_loss_D = val_loss_D / len(val_loader)
    avg_val_loss_G = val_loss_G / len(val_loader)
    val_losses.append((avg_val_loss_D, avg_val_loss_G))

    print(f"Train Loss D: {avg_train_loss_D:.4f}, G: {avg_train_loss_G:.4f} | Val Loss D: {avg_val_loss_D:.4f}, G: {avg_val_loss_G:.4f}")

    # Early stopping check
    if avg_val_loss_D + avg_val_loss_G < best_val_loss:
        best_val_loss = avg_val_loss_D + avg_val_loss_G
        epochs_no_improve = 0
        torch.save(G.state_dict(), "gan_outputs/best_generator.pth")
        torch.save(D.state_dict(), "gan_outputs/best_discriminator.pth")
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print("Early stopping triggered!")
            break

    # Save sample images
    with torch.no_grad():
        fake_imgs = G(fixed_noise).detach().cpu()
        grid = make_grid(fake_imgs, padding=2, normalize=True)
        save_image(grid, f"gan_outputs/epoch_{epoch+1}.png")

Epoch [1/60]


  0%|          | 0/55 [00:00<?, ?it/s]

  2%|▏         | 1/55 [00:04<04:08,  4.60s/it]

  4%|▎         | 2/55 [00:08<03:34,  4.05s/it]

  5%|▌         | 3/55 [00:11<03:07,  3.60s/it]

  7%|▋         | 4/55 [00:14<02:54,  3.43s/it]

  9%|▉         | 5/55 [00:17<02:45,  3.31s/it]

 11%|█         | 6/55 [00:20<02:38,  3.24s/it]

 13%|█▎        | 7/55 [00:23<02:34,  3.21s/it]

 15%|█▍        | 8/55 [00:26<02:28,  3.16s/it]

 16%|█▋        | 9/55 [00:30<02:24,  3.15s/it]

 18%|█▊        | 10/55 [00:33<02:19,  3.11s/it]

 20%|██        | 11/55 [00:36<02:16,  3.10s/it]

 22%|██▏       | 12/55 [00:39<02:13,  3.11s/it]

 24%|██▎       | 13/55 [00:42<02:11,  3.13s/it]

 25%|██▌       | 14/55 [00:45<02:08,  3.13s/it]

 27%|██▋       | 15/55 [00:48<02:05,  3.13s/it]

 29%|██▉       | 16/55 [00:51<02:01,  3.12s/it]

 31%|███       | 17/55 [00:55<01:59,  3.15s/it]

 33%|███▎      | 18/55 [00:58<01:55,  3.12s/it]

 35%|███▍      | 19/55 [01:01<01:52,  3.12s/it]

 36%|███▋      | 20/55 [01:04<01:48,  3.09s/it]

 38%|███▊      | 21/55 [01:07<01:46,  3.12s/it]

 40%|████      | 22/55 [01:10<01:43,  3.13s/it]

 42%|████▏     | 23/55 [01:13<01:39,  3.12s/it]

 44%|████▎     | 24/55 [01:16<01:37,  3.14s/it]

 45%|████▌     | 25/55 [01:19<01:33,  3.12s/it]

 47%|████▋     | 26/55 [01:23<01:30,  3.13s/it]

 49%|████▉     | 27/55 [01:26<01:27,  3.12s/it]

 51%|█████     | 28/55 [01:29<01:24,  3.13s/it]

 53%|█████▎    | 29/55 [01:32<01:21,  3.12s/it]

 55%|█████▍    | 30/55 [01:35<01:18,  3.14s/it]

 56%|█████▋    | 31/55 [01:38<01:14,  3.11s/it]

 58%|█████▊    | 32/55 [01:41<01:12,  3.16s/it]

 60%|██████    | 33/55 [01:45<01:09,  3.17s/it]

 62%|██████▏   | 34/55 [01:48<01:06,  3.16s/it]

 64%|██████▎   | 35/55 [01:51<01:02,  3.15s/it]

 65%|██████▌   | 36/55 [01:54<01:00,  3.18s/it]

 67%|██████▋   | 37/55 [01:57<00:56,  3.13s/it]

 69%|██████▉   | 38/55 [02:00<00:53,  3.13s/it]

 71%|███████   | 39/55 [02:03<00:50,  3.14s/it]

 73%|███████▎  | 40/55 [02:06<00:46,  3.12s/it]

 75%|███████▍  | 41/55 [02:10<00:43,  3.13s/it]

 76%|███████▋  | 42/55 [02:13<00:40,  3.14s/it]

 78%|███████▊  | 43/55 [02:16<00:37,  3.14s/it]

 80%|████████  | 44/55 [02:19<00:34,  3.11s/it]

 82%|████████▏ | 45/55 [02:22<00:31,  3.12s/it]

 84%|████████▎ | 46/55 [02:25<00:27,  3.11s/it]

 85%|████████▌ | 47/55 [02:28<00:24,  3.10s/it]

 87%|████████▋ | 48/55 [02:31<00:21,  3.10s/it]

 89%|████████▉ | 49/55 [02:35<00:18,  3.11s/it]

 91%|█████████ | 50/55 [02:38<00:15,  3.08s/it]

 93%|█████████▎| 51/55 [02:41<00:12,  3.11s/it]

 95%|█████████▍| 52/55 [02:44<00:09,  3.12s/it]

 96%|█████████▋| 53/55 [02:47<00:06,  3.12s/it]

 98%|█████████▊| 54/55 [02:50<00:03,  3.12s/it]

100%|██████████| 55/55 [02:50<00:00,  2.22s/it]

100%|██████████| 55/55 [02:50<00:00,  3.10s/it]




Train Loss D: 1.0433, G: 0.6925 | Val Loss D: 1.0099, G: 0.6931
Epoch [2/60]


  0%|          | 0/55 [00:00<?, ?it/s]

  2%|▏         | 1/55 [00:03<02:46,  3.08s/it]

  4%|▎         | 2/55 [00:06<02:42,  3.07s/it]

  5%|▌         | 3/55 [00:09<02:39,  3.06s/it]

  7%|▋         | 4/55 [00:12<02:36,  3.07s/it]

  9%|▉         | 5/55 [00:15<02:36,  3.13s/it]

 11%|█         | 6/55 [00:18<02:32,  3.11s/it]

 13%|█▎        | 7/55 [00:21<02:29,  3.11s/it]

 15%|█▍        | 8/55 [00:24<02:26,  3.12s/it]

 16%|█▋        | 9/55 [00:28<02:25,  3.15s/it]

 18%|█▊        | 10/55 [00:31<02:21,  3.14s/it]

 20%|██        | 11/55 [00:34<02:18,  3.14s/it]

 22%|██▏       | 12/55 [00:37<02:13,  3.12s/it]

 24%|██▎       | 13/55 [00:40<02:10,  3.12s/it]

 25%|██▌       | 14/55 [00:43<02:07,  3.10s/it]

 27%|██▋       | 15/55 [00:46<02:04,  3.10s/it]

 29%|██▉       | 16/55 [00:49<02:00,  3.10s/it]

 31%|███       | 17/55 [00:52<01:57,  3.09s/it]

 33%|███▎      | 18/55 [00:55<01:54,  3.09s/it]

 35%|███▍      | 19/55 [00:58<01:51,  3.09s/it]

 36%|███▋      | 20/55 [01:02<01:48,  3.09s/it]

 38%|███▊      | 21/55 [01:05<01:45,  3.09s/it]

 40%|████      | 22/55 [01:08<01:41,  3.08s/it]

 42%|████▏     | 23/55 [01:11<01:38,  3.08s/it]

 44%|████▎     | 24/55 [01:14<01:34,  3.05s/it]

 45%|████▌     | 25/55 [01:17<01:32,  3.09s/it]

 47%|████▋     | 26/55 [01:20<01:29,  3.09s/it]

 49%|████▉     | 27/55 [01:23<01:25,  3.06s/it]

 51%|█████     | 28/55 [01:26<01:22,  3.07s/it]

 53%|█████▎    | 29/55 [01:29<01:20,  3.11s/it]

 55%|█████▍    | 30/55 [01:32<01:17,  3.10s/it]

 56%|█████▋    | 31/55 [01:36<01:14,  3.09s/it]

 58%|█████▊    | 32/55 [01:39<01:11,  3.10s/it]

 60%|██████    | 33/55 [01:42<01:07,  3.06s/it]

 62%|██████▏   | 34/55 [01:45<01:04,  3.07s/it]

 64%|██████▎   | 35/55 [01:48<01:01,  3.09s/it]

 65%|██████▌   | 36/55 [01:51<00:58,  3.09s/it]

 67%|██████▋   | 37/55 [01:54<00:55,  3.08s/it]

 69%|██████▉   | 38/55 [01:57<00:52,  3.08s/it]

 71%|███████   | 39/55 [02:00<00:49,  3.08s/it]

 73%|███████▎  | 40/55 [02:03<00:46,  3.09s/it]

 75%|███████▍  | 41/55 [02:06<00:43,  3.09s/it]

 76%|███████▋  | 42/55 [02:09<00:40,  3.10s/it]

 78%|███████▊  | 43/55 [02:13<00:37,  3.09s/it]

 80%|████████  | 44/55 [02:16<00:34,  3.12s/it]

 82%|████████▏ | 45/55 [02:19<00:30,  3.08s/it]

 84%|████████▎ | 46/55 [02:22<00:27,  3.08s/it]

 85%|████████▌ | 47/55 [02:25<00:24,  3.08s/it]

 87%|████████▋ | 48/55 [02:28<00:21,  3.12s/it]

 89%|████████▉ | 49/55 [02:31<00:18,  3.10s/it]

 91%|█████████ | 50/55 [02:34<00:15,  3.08s/it]

 93%|█████████▎| 51/55 [02:37<00:12,  3.07s/it]

 95%|█████████▍| 52/55 [02:40<00:09,  3.09s/it]

 96%|█████████▋| 53/55 [02:43<00:06,  3.09s/it]

 98%|█████████▊| 54/55 [02:47<00:03,  3.08s/it]

100%|██████████| 55/55 [02:47<00:00,  3.04s/it]




Train Loss D: 1.0066, G: 0.6931 | Val Loss D: 1.0078, G: 0.6931
Epoch [3/60]


  0%|          | 0/55 [00:00<?, ?it/s]

  2%|▏         | 1/55 [00:03<02:48,  3.12s/it]

  4%|▎         | 2/55 [00:06<02:44,  3.11s/it]

  5%|▌         | 3/55 [00:09<02:44,  3.16s/it]

  7%|▋         | 4/55 [00:12<02:40,  3.15s/it]

  9%|▉         | 5/55 [00:15<02:36,  3.13s/it]

 11%|█         | 6/55 [00:18<02:33,  3.14s/it]

 13%|█▎        | 7/55 [00:21<02:28,  3.10s/it]

 15%|█▍        | 8/55 [00:25<02:26,  3.12s/it]

 16%|█▋        | 9/55 [00:28<02:23,  3.13s/it]

 18%|█▊        | 10/55 [00:31<02:20,  3.11s/it]

 20%|██        | 11/55 [00:34<02:15,  3.08s/it]

 22%|██▏       | 12/55 [00:37<02:11,  3.05s/it]

 24%|██▎       | 13/55 [00:40<02:07,  3.03s/it]

 25%|██▌       | 14/55 [00:43<02:05,  3.05s/it]

 27%|██▋       | 15/55 [00:46<02:02,  3.06s/it]

 29%|██▉       | 16/55 [00:49<01:59,  3.06s/it]

 31%|███       | 17/55 [00:52<01:56,  3.07s/it]

 33%|███▎      | 18/55 [00:55<01:55,  3.11s/it]

 35%|███▍      | 19/55 [00:58<01:50,  3.08s/it]

 36%|███▋      | 20/55 [01:01<01:46,  3.05s/it]

 38%|███▊      | 21/55 [01:04<01:45,  3.10s/it]

 40%|████      | 22/55 [01:08<01:42,  3.10s/it]

 42%|████▏     | 23/55 [01:11<01:39,  3.10s/it]

 44%|████▎     | 24/55 [01:14<01:36,  3.10s/it]

 45%|████▌     | 25/55 [01:17<01:31,  3.06s/it]

 47%|████▋     | 26/55 [01:20<01:28,  3.07s/it]

 49%|████▉     | 27/55 [01:23<01:25,  3.06s/it]

 51%|█████     | 28/55 [01:26<01:22,  3.07s/it]

 53%|█████▎    | 29/55 [01:29<01:20,  3.08s/it]

 55%|█████▍    | 30/55 [01:32<01:17,  3.09s/it]

 56%|█████▋    | 31/55 [01:35<01:13,  3.07s/it]

 58%|█████▊    | 32/55 [01:38<01:10,  3.08s/it]

 60%|██████    | 33/55 [01:41<01:07,  3.05s/it]

 62%|██████▏   | 34/55 [01:44<01:03,  3.04s/it]

 64%|██████▎   | 35/55 [01:47<01:01,  3.07s/it]

 65%|██████▌   | 36/55 [01:50<00:58,  3.06s/it]

 67%|██████▋   | 37/55 [01:53<00:54,  3.04s/it]

 69%|██████▉   | 38/55 [01:57<00:52,  3.10s/it]

 71%|███████   | 39/55 [02:00<00:49,  3.09s/it]

 73%|███████▎  | 40/55 [02:03<00:46,  3.07s/it]

 75%|███████▍  | 41/55 [02:06<00:43,  3.10s/it]

 76%|███████▋  | 42/55 [02:09<00:40,  3.09s/it]

 78%|███████▊  | 43/55 [02:12<00:36,  3.08s/it]

 80%|████████  | 44/55 [02:15<00:33,  3.05s/it]

 82%|████████▏ | 45/55 [02:18<00:30,  3.06s/it]

 84%|████████▎ | 46/55 [02:21<00:27,  3.07s/it]

 85%|████████▌ | 47/55 [02:24<00:24,  3.04s/it]

 87%|████████▋ | 48/55 [02:27<00:21,  3.04s/it]

 89%|████████▉ | 49/55 [02:30<00:18,  3.04s/it]

 91%|█████████ | 50/55 [02:33<00:15,  3.05s/it]

In [None]:
# Plot losses
train_loss_D, train_loss_G = zip(*train_losses)
val_loss_D, val_loss_G = zip(*val_losses)

plt.figure(figsize=(10, 5))
plt.plot(train_loss_D, label="Train Loss D")
plt.plot(train_loss_G, label="Train Loss G")
plt.plot(val_loss_D, label="Val Loss D")
plt.plot(val_loss_G, label="Val Loss G")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("Training and Validation Losses")
plt.savefig("gan_outputs/loss_plot.png")
plt.show()


In [None]:
# Generate final images using best model
G.load_state_dict(torch.load("gan_outputs/best_generator.pth"))
G.eval()
with torch.no_grad():
    final_noise = torch.randn(64, z_dim, 1, 1, device=device)
    generated_imgs = G(final_noise)
    save_image(make_grid(generated_imgs, normalize=True), "gan_outputs/final_generated.png")