<span style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;">An Exception was encountered at '<a href="#papermill-error-cell">In [8]</a>'.</span>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from tqdm import tqdm
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import random
import matplotlib.pyplot as plt

In [2]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
# Load Dataset
image_dir = "celebA/celeba/img_align_celeba"
os.makedirs("test", exist_ok=True)
os.makedirs("gan_outputs", exist_ok=True)

transform = transforms.Compose([
    transforms.CenterCrop(160),
    transforms.Resize(64),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = datasets.ImageFolder(root=image_dir, transform=transform)
dataset_size = len(dataset)
indices = list(range(dataset_size))
random.seed(42)
random.shuffle(indices)
split = int(0.1 * dataset_size)
val_indices, train_indices = indices[:split], indices[split:]

train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)

train_loader = DataLoader(train_dataset, batch_size=3000, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=3000, shuffle=False, num_workers=4, pin_memory=True)



In [4]:
# Define Models
class Generator(nn.Module):
    def __init__(self, z_dim=100, img_channels=3, features_g=64):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(z_dim, features_g * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(features_g * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(features_g * 8, features_g * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(features_g * 4, features_g * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(features_g * 2, features_g, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g),
            nn.ReLU(True),
            nn.ConvTranspose2d(features_g, img_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    def __init__(self, img_channels=3, features_d=64):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(img_channels, features_d, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(features_d, features_d * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_d * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_d * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(features_d * 4, features_d * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_d * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(features_d * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.net(img).view(-1)

In [5]:
# Hyperparameters
z_dim = 100
lr = 2e-4
n_epochs = 50
batch_size = 3000
early_stop_patience = 6

In [6]:
G = Generator(z_dim).to(device)
D = Discriminator().to(device)

In [7]:
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

fixed_noise = torch.randn(64, z_dim, 1, 1, device=device)

<span id="papermill-error-cell" style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;">Execution using papermill encountered an exception here and stopped:</span>

In [8]:
train_g_losses = []
val_g_losses = []
best_val_loss = float('inf')
epochs_no_improve = 0

# Training Loop
for epoch in range(n_epochs):
    g_loss_epoch = 0.0
    d_loss_epoch = 0.0
    num_batches = 0

    G.train()
    for real_imgs, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{n_epochs}"):
        real_imgs = real_imgs.to(device)

        # Train Discriminator
        z = torch.randn(batch_size, z_dim, 1, 1, device=device)
        fake_imgs = G(z)

        real_labels = torch.ones(batch_size, device=device)
        fake_labels = torch.zeros(batch_size, device=device)

        output_real = D(real_imgs)
        output_fake = D(fake_imgs.detach())
        loss_D = criterion(output_real, real_labels) + criterion(output_fake, fake_labels)

        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # Train Generator
        output_fake = D(fake_imgs)
        loss_G = criterion(output_fake, real_labels)

        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        g_loss_epoch += loss_G.item()
        d_loss_epoch += loss_D.item()
        num_batches += 1

    # Validation step
    G.eval()
    val_loss = 0.0
    with torch.no_grad():
        for real_imgs, _ in val_loader:
            real_imgs = real_imgs.to(device)
            batch_size = real_imgs.size(0)
            z = torch.randn(batch_size, z_dim, 1, 1, device=device)
            fake_imgs = G(z)
            output_fake = D(fake_imgs)
            loss = criterion(output_fake, torch.ones(batch_size, device=device))
            val_loss += loss.item()

    val_loss /= len(val_loader)
    avg_g_loss = g_loss_epoch / num_batches

    train_g_losses.append(avg_g_loss)
    val_g_losses.append(val_loss)

    with torch.no_grad():
        samples = G(fixed_noise)
        save_image(samples, f"gan_outputs/epoch_{epoch+1:03d}.png", normalize=True)

    print(f"Epoch [{epoch+1}/{n_epochs}]  Loss_D: {d_loss_epoch/num_batches:.4f}  Loss_G: {avg_g_loss:.4f}  Val_Loss_G: {val_loss:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save(G.state_dict(), "gan_outputs/best_generator.pth")
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= early_stop_patience:
            print("Early stopping triggered!")
            break

Epoch 1/50:   0%|          | 0/61 [00:00<?, ?it/s]

Epoch 1/50:   2%|▏         | 1/61 [00:14<14:29, 14.49s/it]

Epoch 1/50:   3%|▎         | 2/61 [00:14<06:06,  6.22s/it]

Epoch 1/50:   5%|▍         | 3/61 [00:15<03:27,  3.58s/it]

Epoch 1/50:   7%|▋         | 4/61 [00:15<02:13,  2.35s/it]

Epoch 1/50:   8%|▊         | 5/61 [00:21<03:11,  3.43s/it]

Epoch 1/50:  10%|▉         | 6/61 [00:21<02:19,  2.53s/it]

Epoch 1/50:  11%|█▏        | 7/61 [00:22<01:39,  1.85s/it]

Epoch 1/50:  13%|█▎        | 8/61 [00:22<01:14,  1.40s/it]

Epoch 1/50:  15%|█▍        | 9/61 [00:30<03:02,  3.50s/it]

Epoch 1/50:  16%|█▋        | 10/61 [00:31<02:15,  2.65s/it]

Epoch 1/50:  18%|█▊        | 11/61 [00:32<01:38,  1.98s/it]

Epoch 1/50:  20%|█▉        | 12/61 [00:32<01:14,  1.51s/it]

Epoch 1/50:  21%|██▏       | 13/61 [00:40<02:47,  3.49s/it]

Epoch 1/50:  23%|██▎       | 14/61 [00:41<02:02,  2.60s/it]

Epoch 1/50:  25%|██▍       | 15/61 [00:41<01:29,  1.95s/it]

Epoch 1/50:  26%|██▌       | 16/61 [00:42<01:07,  1.50s/it]

Epoch 1/50:  28%|██▊       | 17/61 [00:50<02:33,  3.48s/it]

Epoch 1/50:  30%|██▉       | 18/61 [00:50<01:54,  2.66s/it]

Epoch 1/50:  31%|███       | 19/61 [00:51<01:23,  2.00s/it]

Epoch 1/50:  33%|███▎      | 20/61 [00:51<01:02,  1.53s/it]

Epoch 1/50:  34%|███▍      | 21/61 [00:59<02:15,  3.40s/it]

Epoch 1/50:  36%|███▌      | 22/61 [01:00<01:40,  2.58s/it]

Epoch 1/50:  38%|███▊      | 23/61 [01:00<01:13,  1.94s/it]

Epoch 1/50:  39%|███▉      | 24/61 [01:01<00:55,  1.49s/it]

Epoch 1/50:  41%|████      | 25/61 [01:09<02:04,  3.46s/it]

Epoch 1/50:  43%|████▎     | 26/61 [01:09<01:31,  2.61s/it]

Epoch 1/50:  44%|████▍     | 27/61 [01:10<01:06,  1.96s/it]

Epoch 1/50:  46%|████▌     | 28/61 [01:10<00:49,  1.51s/it]

Epoch 1/50:  48%|████▊     | 29/61 [01:18<01:51,  3.48s/it]

Epoch 1/50:  49%|████▉     | 30/61 [01:19<01:20,  2.61s/it]

Epoch 1/50:  51%|█████     | 31/61 [01:19<00:58,  1.96s/it]

Epoch 1/50:  52%|█████▏    | 32/61 [01:20<00:43,  1.50s/it]

Epoch 1/50:  54%|█████▍    | 33/61 [01:28<01:35,  3.42s/it]

Epoch 1/50:  56%|█████▌    | 34/61 [01:28<01:11,  2.65s/it]

Epoch 1/50:  57%|█████▋    | 35/61 [01:29<00:51,  1.99s/it]

Epoch 1/50:  59%|█████▉    | 36/61 [01:29<00:38,  1.53s/it]

Epoch 1/50:  61%|██████    | 37/61 [01:37<01:21,  3.39s/it]

Epoch 1/50:  62%|██████▏   | 38/61 [01:38<00:59,  2.60s/it]

Epoch 1/50:  64%|██████▍   | 39/61 [01:38<00:43,  1.96s/it]

Epoch 1/50:  66%|██████▌   | 40/61 [01:39<00:31,  1.50s/it]

Epoch 1/50:  67%|██████▋   | 41/61 [01:47<01:09,  3.48s/it]

Epoch 1/50:  69%|██████▉   | 42/61 [01:48<00:50,  2.67s/it]

Epoch 1/50:  70%|███████   | 43/61 [01:48<00:36,  2.00s/it]

Epoch 1/50:  72%|███████▏  | 44/61 [01:49<00:26,  1.54s/it]

Epoch 1/50:  74%|███████▍  | 45/61 [01:56<00:54,  3.40s/it]

Epoch 1/50:  75%|███████▌  | 46/61 [01:57<00:40,  2.67s/it]

Epoch 1/50:  77%|███████▋  | 47/61 [01:58<00:27,  2.00s/it]

Epoch 1/50:  79%|███████▊  | 48/61 [01:58<00:19,  1.53s/it]

Epoch 1/50:  80%|████████  | 49/61 [02:06<00:40,  3.41s/it]

Epoch 1/50:  82%|████████▏ | 50/61 [02:07<00:29,  2.69s/it]

Epoch 1/50:  84%|████████▎ | 51/61 [02:07<00:20,  2.02s/it]

Epoch 1/50:  85%|████████▌ | 52/61 [02:08<00:13,  1.55s/it]

Epoch 1/50:  87%|████████▋ | 53/61 [02:15<00:26,  3.31s/it]

Epoch 1/50:  89%|████████▊ | 54/61 [02:17<00:19,  2.73s/it]

Epoch 1/50:  90%|█████████ | 55/61 [02:17<00:12,  2.05s/it]

Epoch 1/50:  92%|█████████▏| 56/61 [02:18<00:07,  1.57s/it]

Epoch 1/50:  93%|█████████▎| 57/61 [02:25<00:13,  3.29s/it]

Epoch 1/50:  95%|█████████▌| 58/61 [02:26<00:07,  2.61s/it]

Epoch 1/50:  97%|█████████▋| 59/61 [02:26<00:03,  1.95s/it]

Epoch 1/50:  98%|█████████▊| 60/61 [02:27<00:01,  1.49s/it]

Epoch 1/50:  98%|█████████▊| 60/61 [02:28<00:02,  2.47s/it]




ValueError: Using a target size (torch.Size([3000])) that is different to the input size (torch.Size([2340])) is deprecated. Please ensure they have the same size.

In [None]:
# Plot losses
plt.figure(figsize=(8, 5))
plt.plot(range(1, len(train_g_losses) + 1), train_g_losses, label='Training Generator Loss')
plt.plot(range(1, len(val_g_losses) + 1), val_g_losses, label='Validation Generator Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Generator Training vs Validation Loss')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig("gan_outputs/loss_plot.png")
plt.show()

In [None]:
# Generate final images using best generator
os.makedirs("gan_outputs/generated", exist_ok=True)
G.load_state_dict(torch.load("gan_outputs/best_generator.pth"))
G.eval()
with torch.no_grad():
    for i in tqdm(range(0, 10000, 64), desc="Generating final images"):
        z = torch.randn(64, z_dim, 1, 1, device=device)
        gen_imgs = G(z)
        for j in range(gen_imgs.size(0)):
            save_image(gen_imgs[j], f"gan_outputs/generated/{i + j:05d}.png", normalize=True)