In [None]:
%pip install gdown


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path


In [None]:
# Device configuration: run on the first CUDA GPU when PyTorch can see one,
# otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")


In [None]:
# Load CelebA dataset (no labels, using ImageFolder)
def load_celeba(data_dir='img-align-celeba', image_size=64, batch_size=128, num_workers=0):
    """Build a shuffling DataLoader over the images under ``data_dir``.

    ``ImageFolder`` derives (unused) class labels from sub-directories, so the
    images must sit inside at least one sub-directory of ``data_dir``.
    Each image is resized to ``image_size`` x ``image_size`` and normalized to
    [-1, 1] per channel.

    Args:
        data_dir: root directory passed to ``torchvision.datasets.ImageFolder``.
        image_size: side length (pixels) of the square output images.
        batch_size: images per batch.
        num_workers: DataLoader worker processes (0 = load in the main process).

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches.

    Raises:
        Whatever ``ImageFolder`` / ``DataLoader`` raise (e.g. a missing or empty
        directory); a short message is printed first, then the error propagates.
    """
    transform = transforms.Compose([
        # NOTE(review): a (h, w) tuple resizes to exactly that size without
        # preserving aspect ratio; if the source photos are not square, consider
        # Resize(image_size) followed by CenterCrop(image_size).
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1], matching the generator's Tanh range.
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    try:
        dataset = torchvision.datasets.ImageFolder(root=data_dir, transform=transform)
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            # Pinned host memory only helps host->CUDA copies; requesting it
            # without a GPU just emits a warning.
            pin_memory=torch.cuda.is_available(),
        )
        print(f"Loaded {len(dataset)} images from {data_dir}")
        return loader
    except Exception as e:
        # Show a readable message in the notebook output, then propagate.
        print(f"Error loading dataset: {e}")
        raise

# Load dataset
dataloader = load_celeba()


In [None]:
# Parameters
latent_dim = 100           # length of the noise vector fed to the generator
epochs = 50                # number of passes over the dataset in train_gan
image_shape = (3, 64, 64)  # (channels, height, width); the models below hard-code 3x64x64
seed = 42                  # seed for PyTorch's global RNGs

# Seed before the models are constructed so weight init, batch shuffling and
# sampled noise are reproducible from run to run.
torch.manual_seed(seed)


In [None]:
# Generator
class Generator(nn.Module):
    """Map a latent noise vector to a 3x64x64 image in [-1, 1].

    The latent code (length = module-level ``latent_dim``, read at construction
    time) is projected by a linear layer to a 512x8x8 feature map. Three
    stride-2 transposed convolutions then upsample it 8 -> 16 -> 32 -> 64 while
    narrowing channels 512 -> 256 -> 128 -> 64, each followed by BatchNorm and
    LeakyReLU(0.2). A 3x3 convolution projects to RGB and ``tanh`` bounds the
    output.
    """

    def __init__(self):
        super().__init__()
        base_ch, base_hw = 512, 8
        flat = base_ch * base_hw * base_hw
        stack = [
            nn.Linear(latent_dim, flat),
            nn.BatchNorm1d(flat),
            nn.LeakyReLU(0.2),
            nn.Unflatten(1, (base_ch, base_hw, base_hw)),
        ]
        # Each upsampling stage doubles height/width and halves the channels.
        widths = (512, 256, 128, 64)
        for ch_in, ch_out in zip(widths, widths[1:]):
            stack += [
                nn.ConvTranspose2d(ch_in, ch_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(ch_out),
                nn.LeakyReLU(0.2),
            ]
        stack += [
            nn.Conv2d(widths[-1], 3, 3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        return self.model(x)


In [None]:
# Discriminator
class Discriminator(nn.Module):
    """Score 3x64x64 images with the probability that each one is real.

    Four stride-2 4x4 convolutions shrink the input 64 -> 32 -> 16 -> 8 -> 4
    while widening channels 3 -> 64 -> 128 -> 256 -> 512; each is followed by
    LeakyReLU(0.2) and Dropout(0.3). The flattened 512x4x4 features go through
    a single linear unit and a sigmoid, giving an output of shape (batch, 1).
    """

    def __init__(self):
        super().__init__()
        channels = (3, 64, 128, 256, 512)
        blocks = []
        for ch_in, ch_out in zip(channels, channels[1:]):
            blocks.extend((
                nn.Conv2d(ch_in, ch_out, 4, stride=2, padding=1),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),
            ))
        blocks.extend((
            nn.Flatten(),
            nn.Linear(512 * 4 * 4, 1),
            nn.Sigmoid(),
        ))
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)


In [None]:
# Initialize models: construct the generator first, then the discriminator,
# and move both onto the selected device.
generator, discriminator = (net.to(device) for net in (Generator(), Discriminator()))

# Binary cross-entropy on the discriminator's sigmoid probabilities.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with lr=2e-4 and betas=(0.5, 0.999).
d_optimizer, g_optimizer = (
    optim.Adam(net.parameters(), lr=2e-4, betas=(0.5, 0.999))
    for net in (discriminator, generator)
)


In [None]:
# Training loop
def _save_sample_strip(noise, path):
    """Render ``generator(noise)`` as one row of images and save it to ``path``."""
    with torch.no_grad():
        # NOTE: the generator stays in train mode here, so BatchNorm normalizes
        # with this batch's statistics (and updates its running statistics).
        samples = generator(noise).cpu().numpy()
    # Undo the [-1, 1] normalization and convert NCHW -> NHWC for imshow.
    samples = (samples * 0.5 + 0.5).transpose(0, 2, 3, 1)
    count = len(samples)
    fig, axes = plt.subplots(1, count, figsize=(2 * count, 2), squeeze=False)
    for ax, img in zip(axes[0], samples):
        ax.imshow(img)
        ax.axis('off')
    fig.savefig(path)
    plt.close(fig)


def _save_loss_curves(d_losses, g_losses, path):
    """Plot per-epoch discriminator/generator losses on a fresh figure and save it."""
    fig, ax = plt.subplots()
    ax.plot(d_losses, label='Discriminator Loss')
    ax.plot(g_losses, label='Generator Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    fig.savefig(path)
    plt.close(fig)


def train_gan(epochs, sample_every=5):
    """Adversarially train the module-level ``generator`` and ``discriminator``.

    Reads the module-level ``dataloader``, ``criterion``, ``d_optimizer``,
    ``g_optimizer``, ``device`` and ``latent_dim``. Each step:
      1. updates the discriminator on a real batch (target 0.9, one-sided label
         smoothing) and a generated batch (target 0), averaging the two losses;
      2. updates the generator to make the discriminator output 1 on its samples.

    Every ``sample_every`` epochs and on the last epoch, a strip of 10 samples
    from a fixed noise batch is saved to ``celeba_epoch_{epoch}.png``. After
    training, the per-epoch mean losses are plotted to ``celeba_loss.png``.

    Args:
        epochs: number of full passes over ``dataloader``.
        sample_every: positive epoch interval between saved sample strips.

    Returns:
        ``(d_losses, g_losses)``: lists of per-epoch mean losses.

    Raises:
        RuntimeError: if ``dataloader`` yields no batches in an epoch.
    """
    d_losses, g_losses = [], []
    # Reuse the same latent vectors at every checkpoint so the saved strips
    # show how the generator's output for fixed inputs evolves.
    fixed_noise = torch.randn(10, latent_dim, device=device)

    for epoch in range(epochs):
        d_loss_sum, g_loss_sum = 0.0, 0.0
        steps = 0
        for real_imgs, _ in dataloader:
            real_imgs = real_imgs.to(device)
            batch_size = real_imgs.size(0)

            # --- Discriminator step ---
            d_optimizer.zero_grad()
            # One-sided label smoothing: only D's targets for real images are softened.
            real_labels = torch.full((batch_size, 1), 0.9, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)
            d_loss_real = criterion(discriminator(real_imgs), real_labels)

            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_imgs = generator(noise)
            # detach(): the discriminator update must not backpropagate into G.
            d_loss_fake = criterion(discriminator(fake_imgs.detach()), fake_labels)

            d_loss = 0.5 * (d_loss_real + d_loss_fake)
            d_loss.backward()
            d_optimizer.step()

            # --- Generator step ---
            g_optimizer.zero_grad()
            fake_output = discriminator(fake_imgs)
            # G targets 1, not the smoothed 0.9: with a 0.9 target the BCE is
            # minimized at D(G(z)) = 0.9, weakening the generator's signal.
            g_loss = criterion(fake_output, torch.ones_like(fake_output))
            g_loss.backward()
            g_optimizer.step()

            d_loss_sum += d_loss.item()
            g_loss_sum += g_loss.item()
            steps += 1

        if steps == 0:
            raise RuntimeError("dataloader yielded no batches; cannot train")

        d_loss_avg = d_loss_sum / steps
        g_loss_avg = g_loss_sum / steps
        d_losses.append(d_loss_avg)
        g_losses.append(g_loss_avg)
        print(f"Epoch {epoch}, D Loss: {d_loss_avg:.4f}, G Loss: {g_loss_avg:.4f}")

        if epoch % sample_every == 0 or epoch == epochs - 1:
            _save_sample_strip(fixed_noise, f'celeba_epoch_{epoch}.png')

    _save_loss_curves(d_losses, g_losses, 'celeba_loss.png')
    return d_losses, g_losses


In [None]:
# Debug function (optional)
def debug_shapes(dataloader, discriminator, device, layer_indices=(0, 3, 6, 9, 12, 13)):
    """Print tensor shapes as one batch flows through ``discriminator.model``.

    Takes the first batch from ``dataloader``, moves its images to ``device``
    and feeds them through the supplied discriminator's layers one at a time.
    It prints the output shape after every layer whose index is in
    ``layer_indices``. The defaults select the Conv2d, Flatten and Linear
    layers of this notebook's ``Discriminator``.

    Args:
        dataloader: iterable of ``(images, labels)`` batches; only the first is used.
        discriminator: module whose layers are iterable via ``.model`` (e.g. an
            ``nn.Sequential``); assumed to already live on ``device``.
        device: device the input batch is moved to.
        layer_indices: indices of the layers whose output shapes are printed.

    Returns:
        The discriminator output for the batch, or ``None`` if ``dataloader``
        yields nothing.
    """
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        return None
    real_imgs, _ = first_batch
    x = real_imgs.to(device)
    print("Input shape:", x.shape)
    # Walk the model that was passed in, rather than a freshly built copy, so
    # the shapes reflect it; no_grad because only shapes are needed.
    with torch.no_grad():
        for i, layer in enumerate(discriminator.model):
            x = layer(x)
            if i in layer_indices:
                print(f"Layer {i} ({layer.__class__.__name__}): {x.shape}")
    print("Output shape:", x.shape)
    return x


In [None]:
# Entry point: run the full training schedule. Inside a notebook kernel
# __name__ is "__main__", so executing this cell starts training.
if "__main__" == __name__:
    train_gan(epochs=epochs)
