# 🤪 CelebA 人脸数据集

Âú®Êú¨notebook‰∏≠ÔºåÊàë‰ª¨Â∞ÜÊºîÁ§∫ËÆ≠ÁªÉËá™Â∑±ÁöÑ Wasserstein GANÔºàWGANÔºâ‰ª•ÁîüÊàê CelebA ‰∫∫ËÑ∏Êï∞ÊçÆÁöÑÂÆåÊï¥Ê≠•È™§„ÄÇ

这段代码改编自 Aakash Kumar Nain 在 Keras 网站上提供的优秀教程 [WGAN-GP 教程](https://keras.io/examples/generative/wgan_gp/)。

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

## 0. Parameters <a name="parameters"></a>

In [None]:
# --- Data and model dimensions ---
IMAGE_SIZE = 64      # images are resized to IMAGE_SIZE x IMAGE_SIZE
CHANNELS = 3
BATCH_SIZE = 512
Z_DIM = 128          # length of the latent vector fed to the generator
NUM_FEATURES = 64

# --- Optimisation settings ---
LEARNING_RATE = 0.0002
ADAM_BETA_1 = 0.5
ADAM_BETA_2 = 0.9
EPOCHS = 200
CRITIC_STEPS = 3     # critic updates performed per generator update
GP_WEIGHT = 10.0     # multiplier applied to the gradient-penalty term
LOAD_MODEL = False

# --- Filesystem locations; the output folders are created up front ---
DATA_PATH = "./data/celeba-dataset/img_align_celeba"
OUTPUT_DIR = "./output"
MODEL_DIR = "./models"
for _folder in (OUTPUT_DIR, MODEL_DIR):
    os.makedirs(_folder, exist_ok=True)

# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

## 1. Prepare the data <a name="prepare"></a>

In [None]:
# Preprocessing applied to every image: resize to a square, convert to a
# tensor, then normalise each channel with mean 0.5 and std 0.5.
_norm_mean = [0.5] * CHANNELS
_norm_std = [0.5] * CHANNELS
transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ]
)

train_dataset = datasets.ImageFolder(root=DATA_PATH, transform=transform)

# Shuffled loader; drop_last=True discards a final batch smaller than
# BATCH_SIZE.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)

In [None]:
def show_batch(batch, nrow=8):
    """Display a batch of images as a single tiled grid.

    Args:
        batch: Image batch handed to ``utils.make_grid``.
        nrow: Number of images placed in each row of the grid.
    """
    # normalize=True lets make_grid rescale the pixel values for display.
    tiled = utils.make_grid(batch, nrow=nrow, normalize=True)
    # Move the channel axis last, which is the layout imshow expects.
    image = tiled.permute(1, 2, 0).cpu().numpy()

    plt.figure(figsize=(12, 6))
    plt.imshow(image)
    plt.axis("off")
    plt.show()

# Pull one batch from the loader, keep only the images, and preview them.
sample_batch, _labels = next(iter(train_loader))
show_batch(sample_batch)

## 2. Build the WGAN-GP <a name="build"></a>

In [None]:
class Critic(nn.Module):
    """Convolutional critic that maps an image batch to one score per image.

    Four stride-2 convolutions widen the feature maps to 64, 128, 256 and 512
    channels, each followed by ``LeakyReLU(0.2)``; every activation except the
    first is followed by ``Dropout(0.3)``. A final 4x4 convolution reduces to
    a single channel and the result is flattened, so a 64x64 input yields an
    output of shape ``(batch, 1)``.

    Args:
        channels: Number of channels in the input images.
    """

    def __init__(self, channels=3):
        super().__init__()
        stage_widths = (64, 128, 256, 512)

        layers = []
        in_ch = channels
        for stage, out_ch in enumerate(stage_widths):
            layers.append(
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
            )
            layers.append(nn.LeakyReLU(0.2))
            # The first stage is the only one without dropout.
            if stage > 0:
                layers.append(nn.Dropout(0.3))
            in_ch = out_ch

        # Unpadded 4x4 convolution producing a single output channel.
        layers.append(nn.Conv2d(in_ch, 1, kernel_size=4, stride=1, padding=0))
        layers.append(nn.Flatten())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the critic scores for the image batch ``x``."""
        return self.model(x)


In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator mapping latent vectors to images.

    The latent vector is treated as a ``z_dim``-channel 1x1 feature map and
    upsampled by five transposed convolutions (1 -> 4 -> 8 -> 16 -> 32 -> 64
    pixels per side). Every stage except the last is followed by batch
    normalisation and ``LeakyReLU(0.2)``; the last stage ends in ``Tanh`` so
    the output values lie in ``[-1, 1]``.

    Args:
        z_dim: Length of the latent vector.
        channels: Number of channels in the generated images.
    """

    def __init__(self, z_dim=128, channels=3):
        super().__init__()
        # PyTorch defines BatchNorm momentum as the weight given to the NEW
        # batch statistic: running = (1 - momentum) * running + momentum * new.
        # A Keras-style "momentum=0.9" (weight on the OLD running value)
        # therefore corresponds to 0.1 here; passing 0.9 would make the
        # running statistics track little more than the most recent batch.
        bn_momentum = 0.1
        self.model = nn.Sequential(
            nn.ConvTranspose2d(z_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512, momentum=bn_momentum),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256, momentum=bn_momentum),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128, momentum=bn_momentum),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64, momentum=bn_momentum),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images from latent vectors ``z`` of shape ``(N, z_dim)``.

        Returns:
            Tensor of shape ``(N, channels, 64, 64)``.
        """
        # Reshape (N, z_dim) into an (N, z_dim, 1, 1) feature map so the
        # first transposed convolution can consume it.
        z = z.view(z.size(0), z.size(1), 1, 1)
        return self.model(z)

In [None]:
# Build each network, then move it onto the selected device.
critic = Critic(channels=CHANNELS)
critic = critic.to(device)

generator = Generator(z_dim=Z_DIM, channels=CHANNELS)
generator = generator.to(device)

In [None]:
def gradient_penalty(critic, real, fake):
    """Compute the WGAN-GP gradient penalty for a batch.

    Random per-sample interpolations between ``real`` and ``fake`` are scored
    by the critic, and the squared deviation of the gradient norm (taken with
    respect to the interpolated inputs) from 1 is averaged over the batch.

    Args:
        critic: Callable mapping a batch of inputs to critic scores.
        real: Batch of real samples; the first dimension is the batch.
        fake: Batch of generated samples with the same shape as ``real``.

    Returns:
        Scalar tensor holding the mean penalty. It is differentiable with
        respect to the critic's parameters, not the inputs.
    """
    # The penalty regularises the critic only, so cut any graph that leads
    # back into whatever produced the inputs (e.g. the generator).
    real = real.detach()
    fake = fake.detach()

    batch_size = real.size(0)
    # One mixing coefficient per sample, broadcast over all other dimensions.
    # It is created on the inputs' own device and dtype rather than relying
    # on a module-level global.
    alpha_shape = (batch_size,) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real.device, dtype=real.dtype)
    interpolated = (alpha * real + (1 - alpha) * fake).requires_grad_(True)

    pred = critic(interpolated)
    # create_graph=True keeps the graph of the gradient computation so the
    # penalty itself can be backpropagated (it also implies retain_graph).
    (grads,) = torch.autograd.grad(
        outputs=pred,
        inputs=interpolated,
        grad_outputs=torch.ones_like(pred),
        create_graph=True,
    )

    # Per-sample L2 norm over all non-batch dimensions.
    grad_norms = grads.reshape(batch_size, -1).norm(2, dim=1)
    return ((grad_norms - 1) ** 2).mean()

## 3. Train the GAN <a name="train"></a>

In [None]:
# The critic and the generator each get their own Adam optimiser, configured
# with the same learning rate and beta coefficients.
_adam_settings = {"lr": LEARNING_RATE, "betas": (ADAM_BETA_1, ADAM_BETA_2)}
c_optimizer = optim.Adam(critic.parameters(), **_adam_settings)
g_optimizer = optim.Adam(generator.parameters(), **_adam_settings)

def save_generated_images(generator, epoch, n_samples=10):
    """Sample images from the generator and save them as one PNG grid.

    The file is written to ``{OUTPUT_DIR}/generated_img_{epoch:03d}.png``.

    Args:
        generator: Model mapping ``(n_samples, Z_DIM)`` latent vectors to
            images.
        epoch: Epoch number used in the output file name.
        n_samples: Number of images to sample; they are laid out in one row.
    """
    z = torch.randn(n_samples, Z_DIM, device=device)

    # Sample in inference mode: no autograd graph is built, and BatchNorm
    # layers use their running statistics instead of overwriting them with
    # the statistics of this small batch. The previous mode is restored
    # afterwards so a surrounding training loop is unaffected.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            gen_imgs = generator(z).cpu()
    finally:
        generator.train(was_training)

    grid = utils.make_grid(gen_imgs, nrow=n_samples, normalize=True)
    fig = plt.figure(figsize=(12, 6))
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.axis("off")
    plt.savefig(f"{OUTPUT_DIR}/generated_img_{epoch:03d}.png")
    # Close the figure explicitly so figures do not pile up across epochs.
    plt.close(fig)

In [None]:
# WGAN-GP training: for every batch the critic is updated CRITIC_STEPS times,
# then the generator once. After each epoch the losses of the last batch are
# printed and a grid of generated samples is saved.
for epoch in range(1, EPOCHS + 1):
    for real_imgs, _ in tqdm(train_loader):
        real_imgs = real_imgs.to(device)
        batch_size = real_imgs.size(0)

        # ---- Critic updates ----
        for _step in range(CRITIC_STEPS):
            z = torch.randn(batch_size, Z_DIM, device=device)
            # The critic update never needs generator gradients. Sampling
            # under no_grad keeps both the critic loss and the gradient
            # penalty from building (and backpropagating through) the
            # generator's graph.
            with torch.no_grad():
                fake_imgs = generator(z)

            c_optimizer.zero_grad()
            real_preds = critic(real_imgs)
            fake_preds = critic(fake_imgs)
            # Wasserstein critic loss: score fakes low and reals high.
            c_loss = fake_preds.mean() - real_preds.mean()
            gp = gradient_penalty(critic, real_imgs, fake_imgs)
            total_c_loss = c_loss + GP_WEIGHT * gp
            total_c_loss.backward()
            c_optimizer.step()

        # ---- Generator update ----
        z = torch.randn(batch_size, Z_DIM, device=device)
        fake_imgs = generator(z)
        g_optimizer.zero_grad()
        # The generator tries to maximise the critic's score on its samples.
        g_loss = -critic(fake_imgs).mean()
        g_loss.backward()
        g_optimizer.step()

    print(f"Epoch [{epoch}/{EPOCHS}] c_loss: {total_c_loss.item():.4f}, g_loss: {g_loss.item():.4f}")
    save_generated_images(generator, epoch)

In [None]:
# Save the final models
# Write the state dict of each trained network to its checkpoint file.
_checkpoints = {
    f"{MODEL_DIR}/generator.pth": generator,
    f"{MODEL_DIR}/critic.pth": critic,
}
for _ckpt_path, _network in _checkpoints.items():
    torch.save(_network.state_dict(), _ckpt_path)

## Generate images

In [None]:
# %%
def generate_images(generator, n_samples=10):
    """Sample images from the generator and display them in a single row.

    Args:
        generator: Model mapping ``(n_samples, Z_DIM)`` latent vectors to
            images.
        n_samples: Number of images to sample and show.
    """
    z = torch.randn(n_samples, Z_DIM, device=device)

    # Sample in inference mode: no autograd graph is built, and BatchNorm
    # layers use their running statistics instead of the statistics of this
    # small batch. The generator's previous mode is restored afterwards.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            imgs = generator(z).cpu()
    finally:
        generator.train(was_training)

    grid = utils.make_grid(imgs, nrow=n_samples, normalize=True)
    plt.figure(figsize=(12, 6))
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.axis("off")
    plt.show()


generate_images(generator, n_samples=10)