**Importing modules**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image
from tqdm import tqdm
import os
from PIL import Image

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


**Building the Discriminator and Generator models**

In [None]:
def initialize_weights(model):
    """Apply DCGAN-style weight initialisation to ``model`` in place.

    Convolution and transposed-convolution weights are drawn from N(0, 0.02).
    Batch-norm scale weights are drawn from N(1, 0.02) and their biases are
    set to zero, so each normalisation layer starts close to the identity
    rather than multiplying its activations by a value near zero.

    Args:
        model: Module whose sub-modules (as yielded by ``model.modules()``)
            are re-initialised.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            # BatchNorm2d(affine=False) has no weight/bias, hence the None guard.
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

class Discriminator(nn.Module):
    """Strided-convolution network that maps an image to a sigmoid score.

    The stack is: one Conv2d + LeakyReLU stem, three Conv2d/BatchNorm2d/
    LeakyReLU stages that widen the channels to ``features_d * 8``, and a
    final 4x4 Conv2d (stride 1, no padding) followed by ``Sigmoid``. Each of
    the first four convolutions uses kernel 4, stride 2, padding 1.

    Args:
        channels_img: Number of channels in the input images.
        features_d: Channel width of the first convolution; later stages use
            2x, 4x and 8x this value.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        widths = [features_d * mult for mult in (1, 2, 4, 8)]

        # Stem: no batch norm on the layer that sees raw pixels.
        layers = [
            nn.Conv2d(channels_img, widths[0], kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Down-sampling stages, each moving to the next channel width.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(self.block(c_in, c_out, 4, 2, 1))

        # Head: collapse the remaining feature map to one channel, then squash.
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=1, padding=0, bias=False))
        layers.append(nn.Sigmoid())

        self.disc = nn.Sequential(*layers)

    def block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a bias-free Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stage."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2, inplace=True)
        return nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Run ``x`` through the full convolutional stack."""
        scores = self.disc(x)
        return scores

class Generator(nn.Module):
    """Transposed-convolution network that maps a latent tensor to an image.

    Four ConvTranspose2d/BatchNorm2d/ReLU stages narrow the channels from
    ``features_g * 8`` down to ``features_g``; a final ConvTranspose2d produces
    ``channels_img`` channels and ``Tanh`` bounds the output to [-1, 1].

    Args:
        z_dim: Number of channels in the latent input.
        channels_img: Number of channels in the generated images.
        features_g: Channel width of the last hidden stage; earlier stages use
            8x, 4x and 2x this value.
    """

    def __init__(self, z_dim, channels_img, features_g):
        super().__init__()

        # (in_channels, out_channels, stride, padding) for each hidden stage;
        # every stage uses a 4x4 kernel.
        stage_specs = [
            (z_dim, features_g * 8, 1, 0),
            (features_g * 8, features_g * 4, 2, 1),
            (features_g * 4, features_g * 2, 2, 1),
            (features_g * 2, features_g, 2, 1),
        ]
        layers = [
            self.block(c_in, c_out, 4, stride, pad)
            for c_in, c_out, stride, pad in stage_specs
        ]

        # Output layer: no batch norm, Tanh to bound pixel values.
        layers.append(
            nn.ConvTranspose2d(features_g, channels_img, kernel_size=4, stride=2, padding=1, bias=False)
        )
        layers.append(nn.Tanh())

        self.gen = nn.Sequential(*layers)

    def block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a bias-free ConvTranspose2d -> BatchNorm2d -> ReLU stage."""
        upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU(True)
        return nn.Sequential(upsample, norm, activation)

    def forward(self, x):
        """Run the latent tensor ``x`` through the full stack."""
        images = self.gen(x)
        return images


**Setting up hyperparameters**

In [None]:
# --- Optimisation ---
num_epochs = 200
batch_size = 128
lr_gen = 1e-4        # Adam step size for the generator
lr_disc = 2e-5       # Adam step size for the discriminator (5x smaller than lr_gen)
smooth_label = 0.9   # target value used for "real" labels in place of 1.0

# --- Data ---
image_size = 64      # images are resized and centre-cropped to image_size x image_size
channels_img = 3

# --- Model ---
z_dim = 100          # number of channels in the latent noise tensor
features_gen = 64    # base channel width of the generator
features_disc = 64   # base channel width of the discriminator


**Prepping the dataset**

In [None]:
# Folder holding the image files to train on.
image_folder_path = "/kaggle/input/celeba-dataset/img_align_celeba/img_align_celeba"

# Fail fast with a real exception: `assert` statements are removed when Python
# runs with -O, which would let a bad path slip through to the first listdir.
if not os.path.exists(image_folder_path):
    raise FileNotFoundError(f"Path does not exist: {image_folder_path}")

# Pre-processing applied to every image before it reaches the networks.
transform = transforms.Compose([
    transforms.Resize(image_size),       # scale the shorter side to image_size
    transforms.CenterCrop(image_size),   # square crop: image_size x image_size
    transforms.ToTensor(),               # PIL image -> float tensor in [0, 1]
    transforms.Normalize([0.5] * 3, [0.5] * 3)  # per-channel (x - 0.5) / 0.5 -> [-1, 1]
])

class CelebADataset(torch.utils.data.Dataset):
    """Dataset over the image files that sit directly inside one folder.

    Files are selected by extension (``.jpg``, ``.jpeg``, ``.png``, matched
    case-insensitively) and kept in sorted order so that a given index always
    refers to the same file, regardless of the order ``os.listdir`` returns.

    Args:
        root_dir: Directory containing the image files.
        transform: Optional callable applied to each RGB ``PIL.Image`` before
            it is returned.
    """

    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir makes no ordering promise; sorting keeps indices stable.
        self.images = sorted(
            f for f in os.listdir(root_dir)
            if f.lower().endswith(self._IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of image files found in ``root_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return ``(image, 0)``.

        The second element is a placeholder label; the training loop ignores it.
        """
        img_path = os.path.join(self.root_dir, self.images[idx])
        # The context manager releases the file handle as soon as the pixel
        # data has been converted into a standalone RGB image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, 0  # dummy label

# Build the dataset over the image folder with the pre-processing pipeline.
dataset = CelebADataset(root_dir=image_folder_path, transform=transform)

# Loading options: reshuffle every epoch, four worker processes kept alive
# between epochs, four batches prefetched per worker, pinned host memory.
loader_options = dict(
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
)
loader = DataLoader(dataset, **loader_options)


**Setting up optimizers and losses**

In [None]:
# Create both networks on the target device (generator first), then apply the
# custom weight initialisation to each in the same order.
gen = Generator(z_dim, channels_img, features_gen).to(device)
disc = Discriminator(channels_img, features_disc).to(device)
for net in (gen, disc):
    initialize_weights(net)

# Replicate both networks across GPUs when more than one is visible.
multi_gpu = torch.cuda.device_count() > 1
if multi_gpu:
    gen, disc = nn.DataParallel(gen), nn.DataParallel(disc)

# Separate Adam optimisers (shared betas, different learning rates) and a
# binary cross-entropy loss on the discriminator's sigmoid output.
adam_betas = (0.5, 0.999)
opt_gen = optim.Adam(gen.parameters(), lr=lr_gen, betas=adam_betas)
opt_disc = optim.Adam(disc.parameters(), lr=lr_disc, betas=adam_betas)
criterion = nn.BCELoss()


**Training loop**

In [None]:
# Sample grids written during training go here.
os.makedirs("outputs", exist_ok=True)

for epoch in range(num_epochs):
    gen.train()
    disc.train()

    # Running sums used for the per-epoch average losses.
    total_loss_gen = 0
    total_loss_disc = 0
    batches = 0

    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for real, _ in pbar:
        real = real.to(device)
        # The final batch of an epoch may be smaller than batch_size.
        batch_size_curr = real.size(0)

        noise = torch.randn(batch_size_curr, z_dim, 1, 1, device=device)

        # Train Discriminator
        disc.zero_grad()
        label_real = torch.full((batch_size_curr,), smooth_label, device=device)
        label_fake = torch.zeros(batch_size_curr, device=device)

        output_real = disc(real).view(-1)
        loss_real = criterion(output_real, label_real)

        fake = gen(noise)
        # detach(): the discriminator update must not backpropagate into the generator.
        output_fake = disc(fake.detach()).view(-1)
        loss_fake = criterion(output_fake, label_fake)

        loss_disc = (loss_real + loss_fake) / 2
        loss_disc.backward()
        opt_disc.step()

        # Train Generator (scored by the freshly updated discriminator)
        gen.zero_grad()
        output_fake_for_gen = disc(fake).view(-1)
        # NOTE(review): the generator target reuses the smoothed real label
        # (smooth_label) rather than 1.0 - confirm this is intended.
        loss_gen = criterion(output_fake_for_gen, label_real)
        loss_gen.backward()
        opt_gen.step()

        # Read each loss once and reuse the Python floats below.
        loss_gen_value = loss_gen.item()
        loss_disc_value = loss_disc.item()

        total_loss_gen += loss_gen_value
        total_loss_disc += loss_disc_value
        batches += 1

        pbar.set_postfix(gen_loss=loss_gen_value, disc_loss=loss_disc_value)

    avg_loss_gen = total_loss_gen / batches
    avg_loss_disc = total_loss_disc / batches

    print(f"Epoch [{epoch+1}/{num_epochs}] Generator Loss: {avg_loss_gen:.4f}, Discriminator Loss: {avg_loss_disc:.4f}")

    # Save a real-vs-fake sample grid every 10 epochs and after the final epoch.
    if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
        gen.eval()
        with torch.no_grad():
            fake_images = gen(torch.randn(32, z_dim, 1, 1, device=device))
            # `real` is the last batch of this epoch, which may hold fewer than 32 images.
            grid_real = make_grid(real[:32], normalize=True, value_range=(-1,1))
            grid_fake = make_grid(fake_images, normalize=True, value_range=(-1,1))

            # If size mismatch, interpolate fake to match real
            if grid_real.shape != grid_fake.shape:
                grid_fake = F.interpolate(grid_fake.unsqueeze(0), size=grid_real.shape[1:], mode='bilinear', align_corners=False).squeeze(0)

            # Stack vertically: real grid on top, generated grid underneath.
            comparison = torch.cat((grid_real, grid_fake), dim=1)
            save_image(comparison, f"outputs/real_vs_fake_epoch_{epoch+1}.png")

        gen.train()

    # `device` falls back to CPU when CUDA is unavailable, and calling
    # torch.cuda.synchronize() without CUDA raises, so guard the call.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
