In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from tqdm import tqdm
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import random
import matplotlib.pyplot as plt

In [2]:
# Set device: use the first CUDA GPU when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [3]:
# Load Dataset
# CelebA aligned faces read through torchvision's ImageFolder (which expects at
# least one class sub-directory under image_dir; the class labels are ignored
# during training). Indices are shuffled with a fixed seed and split
# 10% validation / 90% training.
image_dir = "celebA/celeba/img_align_celeba"
os.makedirs("test", exist_ok=True)
os.makedirs("gan_outputs", exist_ok=True)

# Loader batch size; keep in sync with `batch_size` in the hyperparameter cell.
batch_size = 3000

# Pixels are mapped to [-1, 1] to match the generator's Tanh output. The
# single-element mean/std broadcasts over all three RGB channels.
eval_transform = transforms.Compose([
    transforms.CenterCrop(160),
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
# Training additionally applies random horizontal flips as augmentation.
transform = transforms.Compose([
    transforms.CenterCrop(160),
    transforms.Resize(64),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = datasets.ImageFolder(root=image_dir, transform=transform)
# Second view of the same files without augmentation so validation images are
# deterministic. ImageFolder lists files in sorted order, so the same index
# refers to the same file in both views.
val_base = datasets.ImageFolder(root=image_dir, transform=eval_transform)
dataset_size = len(dataset)
assert len(val_base) == dataset_size, "train/val ImageFolder views disagree"

indices = list(range(dataset_size))
random.seed(42)
random.shuffle(indices)
split = int(0.1 * dataset_size)
val_indices, train_indices = indices[:split], indices[split:]

train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(val_base, val_indices)

# persistent_workers keeps the worker processes alive between epochs instead
# of re-spawning them every time a loader is iterated.
loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True,
                     persistent_workers=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)



In [4]:
# Define Models
class Generator(nn.Module):
    """DCGAN-style generator: latent noise -> image.

    For an input shaped (N, z_dim, 1, 1), the first transposed convolution
    (kernel 4, stride 1, no padding) produces a 4x4 map. Each following
    stride-2 transposed convolution doubles the size (4 -> 8 -> 16 -> 32 -> 64),
    so the output is (N, img_channels, 64, 64) with Tanh values in [-1, 1].

    Args:
        z_dim: number of latent input channels.
        img_channels: number of output image channels.
        features_g: base width; hidden widths are 8x, 4x, 2x and 1x this value.
    """

    def __init__(self, z_dim=100, img_channels=3, features_g=64):
        super().__init__()
        widths = [features_g * mult for mult in (8, 4, 2, 1)]

        # Project the 1x1 latent into a 4x4 feature map.
        layers = [
            nn.ConvTranspose2d(z_dim, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        # Upsampling stages: halve channels, double spatial size.
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # Final upsample to image channels, squashed to [-1, 1].
        layers += [
            nn.ConvTranspose2d(widths[-1], img_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        # Single Sequential named `net` so state_dict keys stay "net.<i>.*".
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Map latent tensor ``z`` to images via the layer stack."""
        images = self.net(z)
        return images

class Discriminator(nn.Module):
    """DCGAN-style discriminator: image -> probability of being real.

    Four stride-2 convolutions halve the spatial size (64 -> 32 -> 16 -> 8 -> 4
    for a 64x64 input). A final 4x4 valid convolution then reduces the map to
    1x1, and a Sigmoid turns it into a probability.

    NOTE(review): forward flattens with view(-1), so it returns one score per
    image only for 64x64 inputs; other sizes would yield several per image.

    Args:
        img_channels: number of input image channels.
        features_d: base width; hidden widths are 1x, 2x, 4x and 8x this value.
    """

    def __init__(self, img_channels=3, features_d=64):
        super().__init__()
        widths = [features_d * mult for mult in (1, 2, 4, 8)]

        # First downsampling stage has no BatchNorm.
        layers = [
            nn.Conv2d(img_channels, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Remaining downsampling stages: double channels, halve spatial size.
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Collapse to a single channel and squash to (0, 1).
        layers += [
            nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        # Single Sequential named `net` so state_dict keys stay "net.<i>.*".
        self.net = nn.Sequential(*layers)

    def forward(self, img):
        """Return the flattened per-position probabilities for ``img``."""
        scores = self.net(img)
        return scores.view(-1)

In [5]:
# Hyperparameters
# Seed torch's RNG first. This cell runs before the models are built, so the
# seed fixes the initial weights as well as the latent noise drawn later
# (the dataset split is seeded separately via `random.seed`).
SEED = 42
torch.manual_seed(SEED)

z_dim = 100              # length of the latent vector fed to the generator
lr = 2e-4                # Adam learning rate (used for both G and D)
n_epochs = 50
batch_size = 3000        # NOTE: the DataLoaders are created in an earlier cell; keep their batch size in sync
early_stop_patience = 6  # epochs without a lower validation G-loss before stopping

In [6]:
# Instantiate both networks and move their parameters to the selected device.
G = Generator(z_dim=z_dim).to(device)
D = Discriminator().to(device)

In [7]:
# Binary cross-entropy on the discriminator's sigmoid probabilities.
criterion = nn.BCELoss()

# Adam for both networks with the same settings; beta1 is lowered from
# PyTorch's default of 0.9 to 0.5.
adam_kwargs = dict(lr=lr, betas=(0.5, 0.999))
optimizer_G = optim.Adam(G.parameters(), **adam_kwargs)
optimizer_D = optim.Adam(D.parameters(), **adam_kwargs)

# One fixed batch of 64 latent vectors, reused for the per-epoch sample grid
# so grids from different epochs are directly comparable.
fixed_noise = torch.randn(64, z_dim, 1, 1, device=device)

In [8]:
# Training loop: alternating D / G updates (non-saturating generator loss),
# a per-epoch generator "validation" loss, sample grids, checkpointing and
# early stopping.
#
# NOTE(review): the "validation" loss is the generator's BCE loss against the
# *current* discriminator on fresh noise. It never looks at the held-out real
# images, and because D keeps changing it is not a stable measure of sample
# quality, so early stopping / "best" checkpointing on it is a heuristic.
# Consider an image-quality metric (e.g. FID on val_dataset) instead.
train_g_losses = []
train_d_losses = []
val_g_losses = []
best_val_loss = float('inf')
epochs_no_improve = 0


def _train_one_epoch(epoch):
    """Run one pass over ``train_loader``.

    Each batch does one discriminator step on real + detached fake images, then
    one generator step that asks the updated D to label the same fakes as real.

    Returns:
        (mean D loss, mean G loss), averaged over batches.

    Raises:
        RuntimeError: if ``train_loader`` yields no batches.
    """
    G.train()
    d_total = g_total = 0.0
    n_batches = 0
    for real_imgs, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{n_epochs}"):
        real_imgs = real_imgs.to(device, non_blocking=True)
        # Size of *this* batch (the last one can be smaller). Kept local so the
        # global `batch_size` hyperparameter is not overwritten.
        cur_bs = real_imgs.size(0)
        real_labels = torch.ones(cur_bs, device=device)
        fake_labels = torch.zeros(cur_bs, device=device)

        # Discriminator step; fakes are detached so no gradient reaches G.
        z = torch.randn(cur_bs, z_dim, 1, 1, device=device)
        fake_imgs = G(z)
        loss_D = (criterion(D(real_imgs), real_labels)
                  + criterion(D(fake_imgs.detach()), fake_labels))
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # Generator step against the freshly updated D.
        loss_G = criterion(D(fake_imgs), real_labels)
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        d_total += loss_D.item()
        g_total += loss_G.item()
        n_batches += 1

    if n_batches == 0:
        raise RuntimeError("train_loader yielded no batches")
    return d_total / n_batches, g_total / n_batches


def _generator_val_loss():
    """Mean generator loss on fresh noise, batched like ``val_loader``.

    Draws one noise batch per ``val_loader`` batch (same sizes) and averages
    the per-batch BCE of D(G(z)) against "real" labels. Only the *size* of the
    validation split is needed, so the real validation images are not read
    from disk.

    D is left in train mode (as during training), so its BatchNorm uses batch
    statistics here and its running stats are updated. D is never switched to
    eval mode in this notebook, so those running stats are unused.

    Raises:
        RuntimeError: if the validation split is empty.
    """
    n_total = len(val_loader.dataset)
    chunk = val_loader.batch_size
    G.eval()
    batch_losses = []
    with torch.no_grad():
        for start in range(0, n_total, chunk):
            cur_bs = min(chunk, n_total - start)
            z = torch.randn(cur_bs, z_dim, 1, 1, device=device)
            output_fake = D(G(z))
            batch_losses.append(
                criterion(output_fake, torch.ones(cur_bs, device=device)).item())
    if not batch_losses:
        raise RuntimeError("validation split is empty")
    return sum(batch_losses) / len(batch_losses)


for epoch in range(n_epochs):
    avg_d_loss, avg_g_loss = _train_one_epoch(epoch)
    val_loss = _generator_val_loss()

    train_d_losses.append(avg_d_loss)
    train_g_losses.append(avg_g_loss)
    val_g_losses.append(val_loss)

    # Sample grid from the fixed noise with G in eval mode (BatchNorm running stats).
    G.eval()
    with torch.no_grad():
        samples = G(fixed_noise)
    # Invert Normalize((0.5,), (0.5,)): map Tanh's [-1, 1] to [0, 1] with a fixed
    # affine map rather than per-grid min-max stretching, so grids from
    # different epochs are colour-comparable.
    save_image(samples.mul(0.5).add(0.5).clamp(0, 1),
               f"gan_outputs/epoch_{epoch+1:03d}.png")

    print(f"Epoch [{epoch+1}/{n_epochs}]  Loss_D: {avg_d_loss:.4f}  Loss_G: {avg_g_loss:.4f}  Val_Loss_G: {val_loss:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save(G.state_dict(), "gan_outputs/best_generator.pth")
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= early_stop_patience:
            print("Early stopping triggered!")
            break

Epoch 1/50:   0%|          | 0/61 [00:00<?, ?it/s]

Epoch 1/50:   2%|▏         | 1/61 [00:14<14:40, 14.67s/it]

Epoch 1/50:   3%|▎         | 2/61 [00:15<06:11,  6.30s/it]

Epoch 1/50:   5%|▍         | 3/61 [00:15<03:30,  3.63s/it]

Epoch 1/50:   7%|▋         | 4/61 [00:16<02:15,  2.37s/it]

Epoch 1/50:   8%|▊         | 5/61 [00:21<03:14,  3.48s/it]

Epoch 1/50:  10%|▉         | 6/61 [00:21<02:16,  2.48s/it]

Epoch 1/50:  11%|█▏        | 7/61 [00:22<01:38,  1.82s/it]

Epoch 1/50:  13%|█▎        | 8/61 [00:22<01:13,  1.38s/it]

Epoch 1/50:  15%|█▍        | 9/61 [00:31<03:02,  3.51s/it]

Epoch 1/50:  16%|█▋        | 10/61 [00:31<02:17,  2.70s/it]

Epoch 1/50:  18%|█▊        | 11/61 [00:32<01:40,  2.01s/it]

Epoch 1/50:  20%|█▉        | 12/61 [00:32<01:15,  1.53s/it]

Epoch 1/50:  21%|██▏       | 13/61 [00:40<02:44,  3.43s/it]

Epoch 1/50:  23%|██▎       | 14/61 [00:41<02:07,  2.72s/it]

Epoch 1/50:  25%|██▍       | 15/61 [00:42<01:33,  2.04s/it]

Epoch 1/50:  26%|██▌       | 16/61 [00:42<01:10,  1.56s/it]

Epoch 1/50:  28%|██▊       | 17/61 [00:50<02:29,  3.40s/it]

Epoch 1/50:  30%|██▉       | 18/61 [00:51<01:54,  2.66s/it]

Epoch 1/50:  31%|███       | 19/61 [00:51<01:23,  2.00s/it]

Epoch 1/50:  33%|███▎      | 20/61 [00:52<01:02,  1.53s/it]

Epoch 1/50:  34%|███▍      | 21/61 [01:00<02:17,  3.45s/it]

Epoch 1/50:  36%|███▌      | 22/61 [01:00<01:45,  2.69s/it]

Epoch 1/50:  38%|███▊      | 23/61 [01:01<01:16,  2.02s/it]

Epoch 1/50:  39%|███▉      | 24/61 [01:01<00:57,  1.56s/it]

Epoch 1/50:  41%|████      | 25/61 [01:09<02:01,  3.38s/it]

Epoch 1/50:  43%|████▎     | 26/61 [01:10<01:35,  2.72s/it]

Epoch 1/50:  44%|████▍     | 27/61 [01:11<01:09,  2.04s/it]

Epoch 1/50:  46%|████▌     | 28/61 [01:11<00:51,  1.55s/it]

Epoch 1/50:  48%|████▊     | 29/61 [01:19<01:49,  3.42s/it]

Epoch 1/50:  49%|████▉     | 30/61 [01:20<01:25,  2.75s/it]

Epoch 1/50:  51%|█████     | 31/61 [01:21<01:01,  2.06s/it]

Epoch 1/50:  52%|█████▏    | 32/61 [01:21<00:45,  1.58s/it]

Epoch 1/50:  54%|█████▍    | 33/61 [01:29<01:35,  3.41s/it]

Epoch 1/50:  56%|█████▌    | 34/61 [01:30<01:13,  2.73s/it]

Epoch 1/50:  57%|█████▋    | 35/61 [01:30<00:53,  2.04s/it]

Epoch 1/50:  59%|█████▉    | 36/61 [01:31<00:39,  1.57s/it]

Epoch 1/50:  61%|██████    | 37/61 [01:38<01:21,  3.38s/it]

Epoch 1/50:  62%|██████▏   | 38/61 [01:39<01:00,  2.64s/it]

Epoch 1/50:  64%|██████▍   | 39/61 [01:40<00:43,  1.99s/it]

Epoch 1/50:  66%|██████▌   | 40/61 [01:40<00:31,  1.52s/it]

Epoch 1/50:  67%|██████▋   | 41/61 [01:48<01:08,  3.43s/it]

Epoch 1/50:  69%|██████▉   | 42/61 [01:49<00:52,  2.77s/it]

Epoch 1/50:  70%|███████   | 43/61 [01:50<00:37,  2.07s/it]

Epoch 1/50:  72%|███████▏  | 44/61 [01:50<00:27,  1.59s/it]

Epoch 1/50:  74%|███████▍  | 45/61 [01:57<00:52,  3.27s/it]

Epoch 1/50:  75%|███████▌  | 46/61 [01:59<00:41,  2.79s/it]

Epoch 1/50:  77%|███████▋  | 47/61 [01:59<00:29,  2.08s/it]

Epoch 1/50:  79%|███████▊  | 48/61 [02:00<00:20,  1.59s/it]

Epoch 1/50:  80%|████████  | 49/61 [02:07<00:39,  3.26s/it]

Epoch 1/50:  82%|████████▏ | 50/61 [02:09<00:31,  2.86s/it]

Epoch 1/50:  84%|████████▎ | 51/61 [02:09<00:21,  2.13s/it]

Epoch 1/50:  85%|████████▌ | 52/61 [02:10<00:14,  1.63s/it]

Epoch 1/50:  87%|████████▋ | 53/61 [02:17<00:25,  3.22s/it]

Epoch 1/50:  89%|████████▊ | 54/61 [02:18<00:19,  2.76s/it]

Epoch 1/50:  90%|█████████ | 55/61 [02:19<00:12,  2.06s/it]

Epoch 1/50:  92%|█████████▏| 56/61 [02:19<00:07,  1.58s/it]

Epoch 1/50:  93%|█████████▎| 57/61 [02:26<00:12,  3.20s/it]

Epoch 1/50:  95%|█████████▌| 58/61 [02:28<00:08,  2.75s/it]

Epoch 1/50:  97%|█████████▋| 59/61 [02:28<00:04,  2.05s/it]

Epoch 1/50:  98%|█████████▊| 60/61 [02:29<00:01,  1.56s/it]

Epoch 1/50: 100%|██████████| 61/61 [02:30<00:00,  1.45s/it]

Epoch 1/50: 100%|██████████| 61/61 [02:30<00:00,  2.47s/it]




Epoch [1/50]  Loss_D: 0.4697  Loss_G: 8.8295  Val_Loss_G: 4.1850


Epoch 2/50:   0%|          | 0/61 [00:00<?, ?it/s]

Epoch 2/50:   2%|▏         | 1/61 [00:09<09:53,  9.88s/it]

Epoch 2/50:   3%|▎         | 2/61 [00:10<04:16,  4.34s/it]

Epoch 2/50:   5%|▍         | 3/61 [00:10<02:28,  2.56s/it]

Epoch 2/50:   7%|▋         | 4/61 [00:11<01:38,  1.73s/it]

Epoch 2/50:   8%|▊         | 5/61 [00:20<04:07,  4.41s/it]

Epoch 2/50:  10%|▉         | 6/61 [00:20<02:48,  3.06s/it]

Epoch 2/50:  11%|█▏        | 7/61 [00:21<01:59,  2.21s/it]

Epoch 2/50:  13%|█▎        | 8/61 [00:21<01:26,  1.64s/it]

Epoch 2/50:  15%|█▍        | 9/61 [00:29<03:11,  3.68s/it]

Epoch 2/50:  16%|█▋        | 10/61 [00:30<02:16,  2.68s/it]

Epoch 2/50:  18%|█▊        | 11/61 [00:30<01:39,  2.00s/it]

Epoch 2/50:  20%|█▉        | 12/61 [00:31<01:15,  1.54s/it]

Epoch 2/50:  21%|██▏       | 13/61 [00:39<02:53,  3.61s/it]

Epoch 2/50:  23%|██▎       | 14/61 [00:40<02:04,  2.66s/it]

Epoch 2/50:  25%|██▍       | 15/61 [00:40<01:31,  1.99s/it]

Epoch 2/50:  26%|██▌       | 16/61 [00:40<01:08,  1.52s/it]

Epoch 2/50:  28%|██▊       | 17/61 [00:49<02:38,  3.61s/it]

Epoch 2/50:  30%|██▉       | 18/61 [00:49<01:54,  2.66s/it]

Epoch 2/50:  31%|███       | 19/61 [00:50<01:23,  1.99s/it]

Epoch 2/50:  33%|███▎      | 20/61 [00:50<01:02,  1.53s/it]

Epoch 2/50:  34%|███▍      | 21/61 [00:59<02:24,  3.61s/it]

Epoch 2/50:  36%|███▌      | 22/61 [00:59<01:43,  2.66s/it]

Epoch 2/50:  38%|███▊      | 23/61 [01:00<01:15,  1.99s/it]

Epoch 2/50:  39%|███▉      | 24/61 [01:00<00:56,  1.53s/it]

Epoch 2/50:  41%|████      | 25/61 [01:08<02:04,  3.46s/it]

Epoch 2/50:  43%|████▎     | 26/61 [01:09<01:29,  2.56s/it]

Epoch 2/50:  44%|████▍     | 27/61 [01:09<01:05,  1.92s/it]

Epoch 2/50:  46%|████▌     | 28/61 [01:09<00:48,  1.48s/it]

Epoch 2/50:  48%|████▊     | 29/61 [01:18<01:52,  3.52s/it]

Epoch 2/50:  49%|████▉     | 30/61 [01:18<01:20,  2.60s/it]

Epoch 2/50:  51%|█████     | 31/61 [01:19<00:58,  1.95s/it]

Epoch 2/50:  52%|█████▏    | 32/61 [01:19<00:43,  1.50s/it]

Epoch 2/50:  54%|█████▍    | 33/61 [01:27<01:39,  3.57s/it]

Epoch 2/50:  56%|█████▌    | 34/61 [01:28<01:11,  2.63s/it]

Epoch 2/50:  57%|█████▋    | 35/61 [01:28<00:51,  1.98s/it]

Epoch 2/50:  59%|█████▉    | 36/61 [01:29<00:37,  1.52s/it]

Epoch 2/50:  61%|██████    | 37/61 [01:37<01:24,  3.52s/it]

Epoch 2/50:  62%|██████▏   | 38/61 [01:37<00:59,  2.60s/it]

Epoch 2/50:  64%|██████▍   | 39/61 [01:38<00:43,  1.96s/it]

Epoch 2/50:  66%|██████▌   | 40/61 [01:38<00:31,  1.50s/it]

Epoch 2/50:  67%|██████▋   | 41/61 [01:46<01:09,  3.48s/it]

Epoch 2/50:  69%|██████▉   | 42/61 [01:47<00:48,  2.57s/it]

Epoch 2/50:  70%|███████   | 43/61 [01:47<00:34,  1.94s/it]

Epoch 2/50:  72%|███████▏  | 44/61 [01:48<00:25,  1.49s/it]

Epoch 2/50:  74%|███████▍  | 45/61 [01:56<00:56,  3.54s/it]

Epoch 2/50:  75%|███████▌  | 46/61 [01:57<00:39,  2.62s/it]

Epoch 2/50:  77%|███████▋  | 47/61 [01:57<00:27,  1.96s/it]

Epoch 2/50:  79%|███████▊  | 48/61 [01:57<00:19,  1.51s/it]

Epoch 2/50:  80%|████████  | 49/61 [02:06<00:43,  3.66s/it]

Epoch 2/50:  82%|████████▏ | 50/61 [02:07<00:29,  2.70s/it]

Epoch 2/50:  84%|████████▎ | 51/61 [02:07<00:20,  2.02s/it]

Epoch 2/50:  85%|████████▌ | 52/61 [02:07<00:14,  1.56s/it]

Epoch 2/50:  87%|████████▋ | 53/61 [02:16<00:28,  3.58s/it]

Epoch 2/50:  89%|████████▊ | 54/61 [02:16<00:18,  2.64s/it]

Epoch 2/50:  90%|█████████ | 55/61 [02:17<00:11,  1.98s/it]

Epoch 2/50:  92%|█████████▏| 56/61 [02:17<00:07,  1.52s/it]

Epoch 2/50:  93%|█████████▎| 57/61 [02:25<00:13,  3.35s/it]

Epoch 2/50:  95%|█████████▌| 58/61 [02:25<00:07,  2.47s/it]

Epoch 2/50:  97%|█████████▋| 59/61 [02:26<00:03,  1.86s/it]

Epoch 2/50:  98%|█████████▊| 60/61 [02:26<00:01,  1.43s/it]

Epoch 2/50: 100%|██████████| 61/61 [02:27<00:00,  1.44s/it]

Epoch 2/50: 100%|██████████| 61/61 [02:28<00:00,  2.43s/it]




Epoch [2/50]  Loss_D: 0.9594  Loss_G: 3.5042  Val_Loss_G: 1.4430


Epoch 3/50:   0%|          | 0/61 [00:00<?, ?it/s]

Epoch 3/50:   2%|▏         | 1/61 [00:10<10:13, 10.23s/it]

Epoch 3/50:   3%|▎         | 2/61 [00:10<04:23,  4.47s/it]

Epoch 3/50:   5%|▍         | 3/61 [00:11<02:33,  2.64s/it]

Epoch 3/50:   7%|▋         | 4/61 [00:11<01:40,  1.77s/it]

Epoch 3/50:   8%|▊         | 5/61 [00:19<03:53,  4.16s/it]

Epoch 3/50:  10%|▉         | 6/61 [00:20<02:39,  2.90s/it]

Epoch 3/50:  11%|█▏        | 7/61 [00:20<01:53,  2.10s/it]

Epoch 3/50:  13%|█▎        | 8/61 [00:21<01:23,  1.57s/it]

Epoch 3/50:  15%|█▍        | 9/61 [00:29<03:10,  3.66s/it]

Epoch 3/50:  16%|█▋        | 10/61 [00:30<02:16,  2.67s/it]

Epoch 3/50:  18%|█▊        | 11/61 [00:30<01:39,  2.00s/it]

Epoch 3/50:  20%|█▉        | 12/61 [00:30<01:14,  1.53s/it]

Epoch 3/50:  21%|██▏       | 13/61 [00:39<02:48,  3.52s/it]

Epoch 3/50:  23%|██▎       | 14/61 [00:39<02:01,  2.58s/it]

Epoch 3/50:  25%|██▍       | 15/61 [00:39<01:29,  1.94s/it]

Epoch 3/50:  26%|██▌       | 16/61 [00:40<01:07,  1.49s/it]

Epoch 3/50:  28%|██▊       | 17/61 [00:48<02:34,  3.52s/it]

Epoch 3/50:  30%|██▉       | 18/61 [00:49<01:51,  2.60s/it]

Epoch 3/50:  31%|███       | 19/61 [00:49<01:22,  1.95s/it]

Epoch 3/50:  33%|███▎      | 20/61 [00:49<01:01,  1.50s/it]

Epoch 3/50:  34%|███▍      | 21/61 [00:58<02:19,  3.48s/it]

Epoch 3/50:  36%|███▌      | 22/61 [00:58<01:40,  2.57s/it]

Epoch 3/50:  38%|███▊      | 23/61 [00:58<01:13,  1.94s/it]

Epoch 3/50:  39%|███▉      | 24/61 [00:59<00:55,  1.49s/it]

Epoch 3/50:  41%|████      | 25/61 [01:07<02:09,  3.59s/it]

Epoch 3/50:  43%|████▎     | 26/61 [01:08<01:32,  2.65s/it]

Epoch 3/50:  44%|████▍     | 27/61 [01:08<01:07,  1.99s/it]

Epoch 3/50:  46%|████▌     | 28/61 [01:09<00:50,  1.52s/it]

Epoch 3/50:  48%|████▊     | 29/61 [01:17<01:52,  3.52s/it]

Epoch 3/50:  49%|████▉     | 30/61 [01:17<01:20,  2.60s/it]

Epoch 3/50:  51%|█████     | 31/61 [01:18<00:58,  1.96s/it]

Epoch 3/50:  52%|█████▏    | 32/61 [01:18<00:43,  1.51s/it]

Epoch 3/50:  54%|█████▍    | 33/61 [01:26<01:37,  3.48s/it]

Epoch 3/50:  56%|█████▌    | 34/61 [01:27<01:09,  2.56s/it]

Epoch 3/50:  57%|█████▋    | 35/61 [01:27<00:50,  1.93s/it]

Epoch 3/50:  59%|█████▉    | 36/61 [01:28<00:37,  1.48s/it]

Epoch 3/50:  61%|██████    | 37/61 [01:36<01:24,  3.51s/it]

Epoch 3/50:  62%|██████▏   | 38/61 [01:36<00:59,  2.58s/it]

Epoch 3/50:  64%|██████▍   | 39/61 [01:37<00:42,  1.95s/it]

Epoch 3/50:  66%|██████▌   | 40/61 [01:37<00:31,  1.50s/it]

Epoch 3/50:  67%|██████▋   | 41/61 [01:45<01:10,  3.51s/it]

Epoch 3/50:  69%|██████▉   | 42/61 [01:46<00:49,  2.59s/it]

Epoch 3/50:  70%|███████   | 43/61 [01:46<00:34,  1.94s/it]

Epoch 3/50:  72%|███████▏  | 44/61 [01:47<00:25,  1.49s/it]

Epoch 3/50:  74%|███████▍  | 45/61 [01:55<00:55,  3.46s/it]

Epoch 3/50:  75%|███████▌  | 46/61 [01:55<00:38,  2.56s/it]

Epoch 3/50:  77%|███████▋  | 47/61 [01:56<00:26,  1.92s/it]

Epoch 3/50:  79%|███████▊  | 48/61 [01:56<00:19,  1.48s/it]

Epoch 3/50:  80%|████████  | 49/61 [02:04<00:41,  3.47s/it]

Epoch 3/50:  82%|████████▏ | 50/61 [02:05<00:28,  2.57s/it]

Epoch 3/50:  84%|████████▎ | 51/61 [02:05<00:19,  1.93s/it]

Epoch 3/50:  85%|████████▌ | 52/61 [02:06<00:13,  1.49s/it]

Epoch 3/50:  87%|████████▋ | 53/61 [02:14<00:28,  3.55s/it]

Epoch 3/50:  89%|████████▊ | 54/61 [02:14<00:18,  2.62s/it]

Epoch 3/50:  90%|█████████ | 55/61 [02:15<00:11,  1.97s/it]

Epoch 3/50:  92%|█████████▏| 56/61 [02:15<00:07,  1.52s/it]

Epoch 3/50:  93%|█████████▎| 57/61 [02:23<00:13,  3.42s/it]

Epoch 3/50:  95%|█████████▌| 58/61 [02:24<00:07,  2.52s/it]

Epoch 3/50:  97%|█████████▋| 59/61 [02:24<00:03,  1.89s/it]

Epoch 3/50:  98%|█████████▊| 60/61 [02:25<00:01,  1.45s/it]

Epoch 3/50: 100%|██████████| 61/61 [02:26<00:00,  1.43s/it]

Epoch 3/50: 100%|██████████| 61/61 [02:26<00:00,  2.40s/it]




Epoch [3/50]  Loss_D: 0.8305  Loss_G: 3.3152  Val_Loss_G: 2.7380


Epoch 4/50:   0%|          | 0/61 [00:00<?, ?it/s]

Epoch 4/50:   2%|▏         | 1/61 [00:09<09:55,  9.93s/it]

Epoch 4/50:   3%|▎         | 2/61 [00:10<04:16,  4.35s/it]

Epoch 4/50:   5%|▍         | 3/61 [00:10<02:29,  2.57s/it]

Epoch 4/50:   7%|▋         | 4/61 [00:11<01:38,  1.73s/it]

Epoch 4/50:   8%|▊         | 5/61 [00:19<03:45,  4.02s/it]

Epoch 4/50:  10%|▉         | 6/61 [00:19<02:34,  2.81s/it]

Epoch 4/50:  11%|█▏        | 7/61 [00:20<01:51,  2.07s/it]

Epoch 4/50:  13%|█▎        | 8/61 [00:20<01:22,  1.56s/it]

Epoch 4/50:  15%|█▍        | 9/61 [00:28<03:06,  3.59s/it]

Epoch 4/50:  16%|█▋        | 10/61 [00:29<02:13,  2.62s/it]

Epoch 4/50:  18%|█▊        | 11/61 [00:30<01:44,  2.09s/it]

Epoch 4/50:  20%|█▉        | 12/61 [00:30<01:17,  1.59s/it]

Epoch 4/50:  21%|██▏       | 13/61 [00:39<02:55,  3.66s/it]

Epoch 4/50:  23%|██▎       | 14/61 [00:39<02:06,  2.69s/it]

Epoch 4/50:  25%|██▍       | 15/61 [00:39<01:32,  2.01s/it]

Epoch 4/50:  26%|██▌       | 16/61 [00:40<01:09,  1.55s/it]

Epoch 4/50:  28%|██▊       | 17/61 [00:48<02:38,  3.60s/it]

Epoch 4/50:  30%|██▉       | 18/61 [00:49<01:54,  2.65s/it]

Epoch 4/50:  31%|███       | 19/61 [00:49<01:23,  1.99s/it]

Epoch 4/50:  33%|███▎      | 20/61 [00:50<01:02,  1.52s/it]

Epoch 4/50:  34%|███▍      | 21/61 [00:58<02:21,  3.55s/it]

Epoch 4/50:  36%|███▌      | 22/61 [00:58<01:42,  2.62s/it]

Epoch 4/50:  38%|███▊      | 23/61 [00:59<01:14,  1.97s/it]

Epoch 4/50:  39%|███▉      | 24/61 [00:59<00:55,  1.51s/it]

Epoch 4/50:  41%|████      | 25/61 [01:08<02:08,  3.58s/it]

Epoch 4/50:  43%|████▎     | 26/61 [01:08<01:32,  2.64s/it]

Epoch 4/50:  44%|████▍     | 27/61 [01:09<01:07,  1.98s/it]

Epoch 4/50:  46%|████▌     | 28/61 [01:09<00:50,  1.52s/it]

Epoch 4/50:  48%|████▊     | 29/61 [01:17<01:55,  3.60s/it]

Epoch 4/50:  49%|████▉     | 30/61 [01:18<01:22,  2.65s/it]

Epoch 4/50:  51%|█████     | 31/61 [01:18<00:59,  1.99s/it]

Epoch 4/50:  52%|█████▏    | 32/61 [01:19<00:44,  1.52s/it]

Epoch 4/50:  54%|█████▍    | 33/61 [01:27<01:39,  3.56s/it]

Epoch 4/50:  56%|█████▌    | 34/61 [01:28<01:10,  2.63s/it]

Epoch 4/50:  57%|█████▋    | 35/61 [01:28<00:51,  1.97s/it]

Epoch 4/50:  59%|█████▉    | 36/61 [01:28<00:37,  1.51s/it]

Epoch 4/50:  61%|██████    | 37/61 [01:37<01:25,  3.55s/it]

Epoch 4/50:  62%|██████▏   | 38/61 [01:37<01:00,  2.62s/it]

Epoch 4/50:  64%|██████▍   | 39/61 [01:38<00:43,  1.96s/it]

Epoch 4/50:  66%|██████▌   | 40/61 [01:38<00:31,  1.51s/it]

Epoch 4/50:  67%|██████▋   | 41/61 [01:47<01:12,  3.62s/it]

Epoch 4/50:  69%|██████▉   | 42/61 [01:47<00:50,  2.67s/it]

Epoch 4/50:  70%|███████   | 43/61 [01:47<00:36,  2.00s/it]

Epoch 4/50:  72%|███████▏  | 44/61 [01:48<00:26,  1.53s/it]

Epoch 4/50:  74%|███████▍  | 45/61 [01:56<00:57,  3.60s/it]

In [None]:
# Plot losses: per-epoch generator losses recorded by the training loop,
# shown inline and also saved under gan_outputs/.
fig, ax = plt.subplots(figsize=(8, 5))
train_epochs = range(1, len(train_g_losses) + 1)
val_epochs = range(1, len(val_g_losses) + 1)
ax.plot(train_epochs, train_g_losses, label='Training Generator Loss')
ax.plot(val_epochs, val_g_losses, label='Validation Generator Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Generator Training vs Validation Loss')
ax.legend()
ax.grid(True)
fig.tight_layout()
fig.savefig("gan_outputs/loss_plot.png")
plt.show()

In [None]:
# Generate final images using best generator: writes exactly N_FINAL_IMAGES
# PNGs (one image per file) to gan_outputs/generated/.
N_FINAL_IMAGES = 10000
GEN_BATCH_SIZE = 64

os.makedirs("gan_outputs/generated", exist_ok=True)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
G.load_state_dict(torch.load("gan_outputs/best_generator.pth", map_location=device))
G.eval()
with torch.no_grad():
    for start in tqdm(range(0, N_FINAL_IMAGES, GEN_BATCH_SIZE), desc="Generating final images"):
        # Shrink the last batch so exactly N_FINAL_IMAGES files are written;
        # fixed batches of 64 up to 10000 would otherwise produce 10048.
        n = min(GEN_BATCH_SIZE, N_FINAL_IMAGES - start)
        z = torch.randn(n, z_dim, 1, 1, device=device)
        # Tanh output is in [-1, 1]; invert the dataset's Normalize((0.5,), (0.5,))
        # instead of min-max stretching each image independently, which would
        # distort per-image contrast and colour.
        gen_imgs = G(z).mul(0.5).add(0.5).clamp(0, 1)
        for j, img in enumerate(gen_imgs):
            save_image(img, f"gan_outputs/generated/{start + j:05d}.png")