**Importing modules and setting hyperparameters**

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import save_image
import os
import numpy as np 


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
SEED = 42          # torch.manual_seed seeds the RNGs on all devices (CPU and CUDA)
BATCH_SIZE = 128
IMG_HEIGHT = 64    # spatial size used by the transforms and both networks
IMG_WIDTH = 64
CHANNELS = 3       # RGB
LATENT_DIM = 100   # length of the generator's input noise vector
EPOCHS = 150

# Seed before any model is built or data is shuffled, so a fresh "Run All" is repeatable
# (cuDNN kernels may still introduce small nondeterminism on GPU).
torch.manual_seed(SEED)

sample_dir = './output_samples'   # preview image grids written during training
model_dir = './output_models'     # generator/discriminator checkpoints
os.makedirs(sample_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

**Preparing the dataset**

In [None]:
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(int(IMG_HEIGHT * 1.1)),  
    transforms.CenterCrop((IMG_HEIGHT, IMG_WIDTH)),  
    transforms.RandomHorizontalFlip(p=0.5),  
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

dataset = datasets.ImageFolder(root="/kaggle/input/celeba-dataset", transform=transform)

dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=16,
    pin_memory=True,
    prefetch_factor=4,
    persistent_workers=True
)


**Creating Generator**

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator: latent noise vector -> image with values in [-1, 1].

    A bias-free linear layer projects the noise to a 256-channel feature map at
    1/8 of the output resolution. Three stride-2 transposed convolutions then
    double the spatial size each time (256 -> 128 -> 64 -> ``channels``), and a
    final Tanh bounds the output to [-1, 1].

    Args:
        latent_dim: length of the input noise vector. Defaults to ``LATENT_DIM``.
        channels: number of output image channels. Defaults to ``CHANNELS``.
        img_height: output height, must be divisible by 8. Defaults to ``IMG_HEIGHT``.
        img_width: output width, must be divisible by 8. Defaults to ``IMG_WIDTH``.

    Raises:
        ValueError: if ``img_height`` or ``img_width`` is not a multiple of 8.
    """

    def __init__(self, latent_dim=None, channels=None, img_height=None, img_width=None):
        super(Generator, self).__init__()
        # Fall back to the notebook's config constants, read at construction time.
        latent_dim = LATENT_DIM if latent_dim is None else latent_dim
        channels = CHANNELS if channels is None else channels
        img_height = IMG_HEIGHT if img_height is None else img_height
        img_width = IMG_WIDTH if img_width is None else img_width
        if img_height % 8 or img_width % 8:
            raise ValueError(
                f"Generator needs image sizes divisible by 8, got {img_height}x{img_width}"
            )
        # Three 2x upsamplings below, so start at 1/8 resolution (8x8 for 64x64 images).
        base_h, base_w = img_height // 8, img_width // 8
        base_features = 256 * base_h * base_w

        # Attribute name and layer order are kept so existing state_dict checkpoints still load.
        self.model = nn.Sequential(
            nn.Linear(latent_dim, base_features, bias=False),
            nn.BatchNorm1d(base_features),
            nn.ReLU(True),
            nn.Unflatten(1, (256, base_h, base_w)),

            # ConvTranspose2d(k=5, s=2, p=2, output_padding=1): H_out = 2 * H_in.
            nn.ConvTranspose2d(256, 128, 5, 2, 2, output_padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.ConvTranspose2d(128, 64, 5, 2, 2, output_padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            nn.ConvTranspose2d(64, channels, 5, 2, 2, output_padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        """Map noise ``z`` of shape (batch, latent_dim) to images (batch, channels, H, W)."""
        return self.model(z)


**Creating Discriminator**

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator returning one raw logit per image (no sigmoid).

    Two stride-2 convolutions (each followed by LeakyReLU(0.2) and Dropout(0.3))
    reduce the spatial size, then a linear layer maps the flattened features to
    a single logit.

    Args:
        channels: number of input image channels. Defaults to ``CHANNELS``.
        img_height: input image height. Defaults to ``IMG_HEIGHT``.
        img_width: input image width. Defaults to ``IMG_WIDTH``.
    """

    def __init__(self, channels=None, img_height=None, img_width=None):
        super(Discriminator, self).__init__()
        # Fall back to the notebook's config constants, read at construction time.
        channels = CHANNELS if channels is None else channels
        img_height = IMG_HEIGHT if img_height is None else img_height
        img_width = IMG_WIDTH if img_width is None else img_width

        # Conv2d(k=5, s=2, p=2) outputs floor((size - 1) / 2) + 1 == ceil(size / 2).
        # Plain `size // 4` is only correct when size is a multiple of 4.
        feat_h = (((img_height + 1) // 2) + 1) // 2
        feat_w = (((img_width + 1) // 2) + 1) // 2

        # Attribute name and layer order are kept so existing state_dict checkpoints still load.
        self.model = nn.Sequential(
            nn.Conv2d(channels, 64, 5, 2, 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            nn.Conv2d(64, 128, 5, 2, 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            nn.Flatten(),
            nn.Linear(128 * feat_h * feat_w, 1)
        )

    def forward(self, x):
        """Return logits of shape (batch, 1) for images ``x`` of shape (batch, channels, H, W)."""
        return self.model(x)


**Model initialization, optimizers, and loss**

In [None]:
# The Discriminator ends in a plain Linear layer (raw logits), so the sigmoid is
# folded into the loss for numerical stability.
criterion = nn.BCEWithLogitsLoss()

# Per-network learning rates and the Adam momentum terms shared by both optimizers.
LR_G = 0.0002
LR_D = 0.0002
ADAM_BETAS = (0.5, 0.999)

# Build the generator first, then the discriminator, so parameter initialisation
# draws from the RNG in a fixed order.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

generator_optimizer = optim.Adam(generator.parameters(), lr=LR_G, betas=ADAM_BETAS)
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=LR_D, betas=ADAM_BETAS)



**Helper functions**

In [None]:
def generator_loss(fake_output):
    """Non-saturating generator loss: score the discriminator's logits on fakes against "real" labels.

    Args:
        fake_output: discriminator logits for generated images.

    Returns:
        The notebook's ``criterion`` (BCEWithLogitsLoss, mean reduction) applied to
        ``fake_output`` with all-ones targets.
    """
    # ones_like already inherits device and dtype from the logits, so the targets
    # always live next to them instead of on the global `device`.
    labels = torch.ones_like(fake_output)
    return criterion(fake_output, labels)

def discriminator_loss(real_output, fake_output):
    """Discriminator loss: real logits scored against 1s plus fake logits scored against 0s.

    Args:
        real_output: discriminator logits for real images.
        fake_output: discriminator logits for generated images.

    Returns:
        Sum of the two ``criterion`` terms (each mean-reduced by BCEWithLogitsLoss).
    """
    # *_like inherits device and dtype from each logits tensor; no global `device` needed.
    real_labels = torch.ones_like(real_output)
    fake_labels = torch.zeros_like(fake_output)
    real_loss = criterion(real_output, real_labels)
    fake_loss = criterion(fake_output, fake_labels)
    return real_loss + fake_loss

def save_generated_images(images, epoch):
    """Write ``images`` as a 4-column PNG grid into ``sample_dir``.

    Args:
        images: batch of generated images, expected in the Tanh range [-1, 1].
        epoch: epoch number; zero-padded to three digits in the file name.
    """
    target_path = os.path.join(sample_dir, f'epoch_{epoch:03d}.png')
    # Shift [-1, 1] back onto the [0, 1] range save_image writes as pixel intensities.
    save_image(images.add(1).div(2), target_path, nrow=4)


**Training**

In [None]:
# ---- Training state ----
# Hand-rolled plateau scheduler keyed on the epoch-mean generator loss, modelled on
# torch.optim.lr_scheduler.ReduceLROnPlateau (mode="min", threshold=0).
best_loss = np.inf   # lowest epoch-mean generator loss seen so far
patience = 4         # consecutive non-improving epochs before the LR is reduced
cooldown = 0         # epochs left during which non-improving epochs are ignored
factor = 0.5         # multiplicative LR decay per reduction
min_lr = 1e-6        # floor applied to every param group's LR
wait = 0             # consecutive non-improving epochs counted so far

# Draw the preview noise once so every saved grid shows the same latent vectors.
fixed_noise = torch.randn(16, LATENT_DIM, device=device)

generator.train()
discriminator.train()

for epoch in range(1, EPOCHS + 1):
    gen_loss_epoch, disc_loss_epoch = 0.0, 0.0
    batches = 0

    for real_images, _ in dataloader:
        real_images = real_images.to(device)
        batch_size = real_images.size(0)
        noise = torch.randn(batch_size, LATENT_DIM, device=device)

        # ---- Train Discriminator ----
        discriminator_optimizer.zero_grad()
        fake_images = generator(noise)
        real_output = discriminator(real_images)
        # detach(): the discriminator step must not push gradients into the generator.
        fake_output = discriminator(fake_images.detach())
        disc_loss = discriminator_loss(real_output, fake_output)
        disc_loss.backward()
        discriminator_optimizer.step()

        # ---- Train Generator ----
        # Re-score the same fakes with the just-updated discriminator. Gradients this
        # leaves on the discriminator's parameters are cleared by its zero_grad() above
        # on the next iteration.
        generator_optimizer.zero_grad()
        fake_output = discriminator(fake_images)
        gen_loss = generator_loss(fake_output)
        gen_loss.backward()
        generator_optimizer.step()

        gen_loss_epoch += gen_loss.item()
        disc_loss_epoch += disc_loss.item()
        batches += 1

    if batches == 0:
        raise RuntimeError("dataloader yielded no batches; check the dataset path and batch settings")
    gen_loss_epoch /= batches
    disc_loss_epoch /= batches

    print(f"Epoch {epoch:03d} - Gen Loss: {gen_loss_epoch:.4f}, Disc Loss: {disc_loss_epoch:.4f}")

    # ---- Save generated images every 10 epochs ----
    if epoch % 10 == 0:
        with torch.no_grad():
            # The generator is still in train mode, so its BatchNorm layers normalise
            # with this preview batch's statistics (and update their running stats).
            preview_images = generator(fixed_noise)
            save_generated_images(preview_images, epoch)

    # ---- Model checkpointing ----
    # NOTE(review): in a GAN a lower generator loss does not imply better samples;
    # "best" here only means "lowest epoch-mean generator loss so far".
    if gen_loss_epoch < best_loss:
        best_loss = gen_loss_epoch
        torch.save(generator.state_dict(), os.path.join(model_dir, 'generator_best.pth'))
        torch.save(discriminator.state_dict(), os.path.join(model_dir, 'discriminator_best.pth'))
        wait = 0
    else:
        wait += 1

    # ---- ReduceLROnPlateau ----
    if cooldown > 0:
        # Recently reduced: ignore non-improving epochs until the cooldown expires.
        cooldown -= 1
        wait = 0
    elif wait >= patience:
        # Scale each param group's own LR so differing LR_G / LR_D keep their ratio.
        reduced = False
        for optimizer in (generator_optimizer, discriminator_optimizer):
            for param_group in optimizer.param_groups:
                new_lr = max(param_group['lr'] * factor, min_lr)
                if new_lr < param_group['lr']:
                    param_group['lr'] = new_lr
                    reduced = True
        if reduced:
            print(
                "ReduceLROnPlateau: Reducing LR to "
                f"G={generator_optimizer.param_groups[0]['lr']:.8f}, "
                f"D={discriminator_optimizer.param_groups[0]['lr']:.8f}"
            )
            cooldown = patience // 2
            wait = 0

    torch.save(generator.state_dict(), os.path.join(model_dir, 'generator_latest.pth'))
    torch.save(discriminator.state_dict(), os.path.join(model_dir, 'discriminator_latest.pth'))
