# 🤪 Conditional WGAN-GP Training on the CelebA Faces Dataset

In this notebook, we walk through the steps required to train your own conditional generative adversarial network (Conditional GAN, CGAN) on the CelebA faces dataset.

The code is adapted from the excellent [CGAN tutorial](https://keras.io/examples/generative/conditional_gan/) by Sayak Paul.

In [None]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## 0. Parameters <a name="parameters"></a>

In [None]:
IMAGE_SIZE = 64
CHANNELS = 3
CLASSES = 2
BATCH_SIZE = 128
Z_DIM = 32
LEARNING_RATE = 5e-5
ADAM_BETA_1 = 0.5
ADAM_BETA_2 = 0.9
EPOCHS = 20
CRITIC_STEPS = 3
GP_WEIGHT = 10.0
LABEL = "Blond_Hair"
DATA_PATH = "/app/data/celeba-dataset"
LOAD_MODEL = False
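To get a feel for how much work one epoch involves, the sketch below works out the number of batches and the resulting critic and generator updates. It assumes the full aligned CelebA image set of 202,599 images and that the `DataLoader` keeps the final partial batch (the default `drop_last=False`). `updates_per_epoch` is a hypothetical helper, not part of the notebook.

```python
import math

def updates_per_epoch(num_images: int, batch_size: int, critic_steps: int) -> dict:
    """Count optimizer steps per epoch for a WGAN-GP loop with several critic steps per batch."""
    batches = math.ceil(num_images / batch_size)  # the last partial batch is kept
    return {
        "batches": batches,
        "critic_updates": batches * critic_steps,  # CRITIC_STEPS critic steps on every batch
        "generator_updates": batches,              # one generator step per batch
    }

# Assumed dataset size: the 202,599 aligned CelebA images.
print(updates_per_epoch(202_599, batch_size=128, critic_steps=3))
# {'batches': 1583, 'critic_updates': 4749, 'generator_updates': 1583}
```

With `CRITIC_STEPS = 3`, the critic receives three times as many updates as the generator, which is the usual WGAN-GP arrangement for keeping the critic close to optimal.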

## 1. Prepare the data <a name="prepare"></a>

In [None]:
attributes = pd.read_csv(os.path.join(DATA_PATH, "list_attr_celeba.csv"))
labels = attributes[LABEL].tolist()
int_labels = [1 if x == 1 else 0 for x in labels]
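CelebA stores each attribute as `1` (present) or `-1` (absent), so the comprehension above maps `-1` to class `0`. Since the two conditions need not be equally common, it can be worth checking the class balance before training. The following is a minimal stdlib sketch with hypothetical helper names, run on toy values rather than the real CSV.

```python
from collections import Counter

def to_binary(attr_values):
    """Map CelebA-style attribute values (1 / -1) to class indices (1 / 0)."""
    return [1 if v == 1 else 0 for v in attr_values]

def positive_fraction(int_labels):
    """Fraction of samples whose label is 1 (e.g. Blond_Hair present)."""
    counts = Counter(int_labels)
    return counts[1] / len(int_labels)

toy = [1, -1, -1, -1, 1, -1, -1, -1]  # toy values, not real CelebA data
binary = to_binary(toy)
print(binary)                     # [1, 0, 0, 0, 1, 0, 0, 0]
print(positive_fraction(binary))  # 0.25
```

On the real data, `positive_fraction(int_labels)` gives the share of faces with the chosen attribute; a strongly skewed value means the generator sees far fewer examples of one condition.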


In [None]:
class CelebADataset(Dataset):
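    def __init__(self, img_dir, labels, transform=None):
        self.img_dir = img_dir
        # Assumes sorted file names (000001.jpg, 000002.jpg, ...) follow the row
        # order of list_attr_celeba.csv; non-image files are skipped.
        self.img_files = sorted(
            f for f in os.listdir(img_dir) if f.lower().endswith((".jpg", ".jpeg", ".png"))
        )
        if len(self.img_files) != len(labels):
            raise ValueError(f"Found {len(self.img_files)} images but {len(labels)} labels")
        self.labels = labels
        self.transform = transform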

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        img = Image.open(img_path).convert("RGB")
        label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
        return img, label

transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)  # [-1, 1]
])

dataset = CelebADataset(
    img_dir=os.path.join(DATA_PATH, "img_align_celeba"),
    labels=int_labels,
    transform=transform
)

train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
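`Normalize([0.5]*3, [0.5]*3)` computes `(x - 0.5) / 0.5` per channel, which maps the `[0, 1]` output of `ToTensor` onto `[-1, 1]`, the same range as the generator's final `Tanh`. The inverse `(y + 1) / 2`, applied later before saving images, undoes it. A scalar sketch of the two maps (the function names are illustrative):

```python
def normalize(x, mean=0.5, std=0.5):
    """Per-channel transform applied by transforms.Normalize: (x - mean) / std."""
    return (x - mean) / std

def denormalize(y):
    """Map a value in [-1, 1] (e.g. a Tanh output) back to [0, 1] for saving."""
    return (y + 1) / 2

for pixel in (0.0, 0.25, 0.5, 1.0):
    y = normalize(pixel)
    print(pixel, "->", y, "->", denormalize(y))
```

Keeping real images and generated images in the same range matters because the critic compares them directly.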

## 2. Build the GAN <a name="build"></a>

In [None]:
class Critic(nn.Module):
    def __init__(self, img_channels, label_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(img_channels + label_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Conv2d(128, 128, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Conv2d(128, 128, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Conv2d(128, 1, 4, 1, 0)
        )

    def forward(self, x, labels):
        # Broadcast the one-hot label to the image's spatial size and append it as extra channels
        label_map = labels[:, :, None, None].repeat(1, 1, x.size(2), x.size(3))
        x = torch.cat([x, label_map], dim=1)
        return self.model(x).view(x.size(0), -1)
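Each `Conv2d` in the critic produces `floor((in + 2*padding - kernel) / stride) + 1` pixels per side (dilation 1). The sketch below traces a 64×64 input through the five convolutions, confirming that the last one collapses the map to a single score per image, and that the first layer sees `CHANNELS + CLASSES = 5` input channels because the one-hot label is broadcast to two extra constant planes. `conv_out` is a hypothetical helper.

```python
def conv_out(size, kernel, stride, padding):
    """Output side length of nn.Conv2d (dilation=1) for a square input."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) for each Conv2d in the critic, in order
critic_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 64
sizes = [size]
for k, s, p in critic_layers:
    size = conv_out(size, k, s, p)
    sizes.append(size)
print(sizes)  # [64, 32, 16, 8, 4, 1]

in_channels = 3 + 2  # RGB channels + one constant plane per class
print(in_channels)  # 5
```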


In [None]:
class Generator(nn.Module):
    def __init__(self, z_dim, label_dim, img_channels):
        super().__init__()
        self.fc = nn.Linear(z_dim + label_dim, 128*4*4)
        self.net = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(128, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(128, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(64, img_channels, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, z, labels):
        x = torch.cat([z, labels], dim=1)
        x = self.fc(x).view(-1, 128, 4, 4)
        return self.net(x)
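The generator runs the same arithmetic in reverse. For `ConvTranspose2d` with no output padding and dilation 1, the output side is `(in - 1) * stride - 2 * padding + kernel`. Starting from the 4×4 map produced by `self.fc`, four layers with kernel 4, stride 2 and padding 1 double the resolution each time, so the output matches `IMAGE_SIZE = 64`. `conv_transpose_out` is a hypothetical helper.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Output side length of nn.ConvTranspose2d (output_padding=0, dilation=1)."""
    return (size - 1) * stride - 2 * padding + kernel

size = 4  # 128 x 4 x 4 after self.fc and .view
sizes = [size]
for _ in range(4):  # four ConvTranspose2d layers with kernel 4, stride 2, padding 1
    size = conv_transpose_out(size, kernel=4, stride=2, padding=1)
    sizes.append(size)
print(sizes)  # [4, 8, 16, 32, 64]
```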

In [None]:
critic = Critic(CHANNELS, CLASSES).to(device)
generator = Generator(Z_DIM, CLASSES, CHANNELS).to(device)
print(critic)
print(generator)

In [None]:
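def gradient_penalty(critic, real, fake, labels):
    """WGAN-GP penalty: pushes the critic's gradient norm towards 1 on random
    interpolates between real and fake images."""
    batch_size = real.size(0)
    # One mixing coefficient per sample, broadcast over C, H, W
    alpha = torch.rand(batch_size, 1, 1, 1, device=real.device)
    # Detach fake so the penalty does not backpropagate into the generator
    interpolated = (alpha * real + (1 - alpha) * fake.detach()).requires_grad_(True)
    pred = critic(interpolated, labels)
    grads = torch.autograd.grad(
        outputs=pred,
        inputs=interpolated,
        grad_outputs=torch.ones_like(pred),
        create_graph=True,  # keep the graph so the penalty itself can be differentiated
    )[0]
    grads = grads.view(grads.size(0), -1)
    gp = ((grads.norm(2, dim=1) - 1) ** 2).mean()
    return gp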
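A quick way to sanity-check a gradient-penalty implementation is a critic whose gradient is known in closed form. For a linear critic `f(x) = w · x`, the gradient with respect to every input is `w`, so the penalty must equal `(‖w‖₂ - 1)²` wherever the interpolates land. The sketch below assumes PyTorch is installed and uses a hypothetical `linear_critic` in place of the conditional critic (labels omitted); otherwise it mirrors the interpolation and norm computation above.

```python
import torch

torch.manual_seed(0)

def gradient_penalty(critic, real, fake):
    """WGAN-GP penalty: mean of (||grad critic(x_hat)||_2 - 1)^2 over random interpolates."""
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    x_hat = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    pred = critic(x_hat)
    grads = torch.autograd.grad(pred, x_hat, grad_outputs=torch.ones_like(pred),
                                create_graph=True)[0]
    grads = grads.reshape(grads.size(0), -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

# Hypothetical linear critic f(x) = w . x on flattened 3x8x8 images.
w = torch.randn(3 * 8 * 8)

def linear_critic(x):
    return x.reshape(x.size(0), -1) @ w  # shape: (batch,)

real = torch.randn(16, 3, 8, 8)
fake = torch.randn(16, 3, 8, 8)
gp = gradient_penalty(linear_critic, real, fake)
expected = (w.norm() - 1) ** 2
print(gp.item(), expected.item())  # the two values should agree up to float error
```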

## 3. Train the GAN <a name="train"></a>

In [None]:
c_optimizer = torch.optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(ADAM_BETA_1, ADAM_BETA_2))
g_optimizer = torch.optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(ADAM_BETA_1, ADAM_BETA_2))

In [None]:
def train_wgan_gp(epochs):
    generator.train()
    critic.train()

    for epoch in range(epochs):
        for i, (real_imgs, labels) in enumerate(tqdm(train_loader)):
            real_imgs = real_imgs.to(device)
            labels = F.one_hot(labels, num_classes=CLASSES).float().to(device)

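            # Update the critic CRITIC_STEPS times per generator update
            for _ in range(CRITIC_STEPS):
                z = torch.randn(real_imgs.size(0), Z_DIM, device=device)
                # The critic step does not need gradients through the generator
                with torch.no_grad():
                    fake_imgs = generator(z, labels)

                real_preds = critic(real_imgs, labels)
                fake_preds = critic(fake_imgs, labels)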

                c_loss = fake_preds.mean() - real_preds.mean()
                gp = gradient_penalty(critic, real_imgs, fake_imgs, labels)
                total_c_loss = c_loss + GP_WEIGHT * gp

                c_optimizer.zero_grad()
                total_c_loss.backward()
                c_optimizer.step()

            # Update the generator once, scoring fresh fakes with the updated critic
            z = torch.randn(real_imgs.size(0), Z_DIM, device=device)
            fake_imgs = generator(z, labels)
            g_loss = -critic(fake_imgs, labels).mean()

            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

        print(f"Epoch [{epoch+1}/{epochs}] | Critic loss (last batch): {total_c_loss.item():.4f} | Generator loss (last batch): {g_loss.item():.4f}")
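In equation form, the loop above minimizes the following losses, where $C$ is the critic, $G$ the generator, $y$ the one-hot label, $\tilde{x} = G(z, y)$, $\hat{x} = \alpha x + (1 - \alpha)\tilde{x}$ with $\alpha \sim U[0, 1]$, and $\lambda$ is `GP_WEIGHT`. This is the standard WGAN-GP objective with the label passed to both networks:

$$
\mathcal{L}_C = \mathbb{E}\big[C(\tilde{x}, y)\big] - \mathbb{E}\big[C(x, y)\big] + \lambda\, \mathbb{E}\Big[\big(\lVert \nabla_{\hat{x}} C(\hat{x}, y) \rVert_2 - 1\big)^2\Big]
$$

$$
\mathcal{L}_G = -\mathbb{E}\big[C(G(z, y), y)\big]
$$

`c_loss` corresponds to the first two terms of $\mathcal{L}_C$, `gp` to the final expectation, and `g_loss` to $\mathcal{L}_G$.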

In [None]:
train_wgan_gp(EPOCHS)
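`LOAD_MODEL` is declared in the parameters cell but never read. One way to use it, sketched below under the assumption that a single checkpoint file is enough, is to save the state dicts of both networks and both optimizers after training and restore them when `LOAD_MODEL` is `True`. The helper names and the checkpoint path are hypothetical, and tiny `nn.Linear` stand-ins replace the real networks so the sketch runs on its own.

```python
import os
import tempfile

import torch
import torch.nn as nn

def save_checkpoint(path, generator, critic, g_optimizer, c_optimizer):
    """Store model and optimizer state dicts in one file."""
    torch.save({
        "generator": generator.state_dict(),
        "critic": critic.state_dict(),
        "g_optimizer": g_optimizer.state_dict(),
        "c_optimizer": c_optimizer.state_dict(),
    }, path)

def load_checkpoint(path, generator, critic, g_optimizer, c_optimizer, device="cpu"):
    """Restore the state saved by save_checkpoint into existing objects."""
    state = torch.load(path, map_location=device)
    generator.load_state_dict(state["generator"])
    critic.load_state_dict(state["critic"])
    g_optimizer.load_state_dict(state["g_optimizer"])
    c_optimizer.load_state_dict(state["c_optimizer"])

# Stand-in models; in the notebook these would be the real networks and their Adam optimizers.
g, c = nn.Linear(4, 4), nn.Linear(4, 1)
g_opt = torch.optim.Adam(g.parameters(), lr=5e-5, betas=(0.5, 0.9))
c_opt = torch.optim.Adam(c.parameters(), lr=5e-5, betas=(0.5, 0.9))

path = os.path.join(tempfile.mkdtemp(), "cgan_checkpoint.pt")  # hypothetical file name
save_checkpoint(path, g, c, g_opt, c_opt)

g2, c2 = nn.Linear(4, 4), nn.Linear(4, 1)
g2_opt = torch.optim.Adam(g2.parameters(), lr=5e-5, betas=(0.5, 0.9))
c2_opt = torch.optim.Adam(c2.parameters(), lr=5e-5, betas=(0.5, 0.9))
load_checkpoint(path, g2, c2, g2_opt, c2_opt)
```

In the notebook, `load_checkpoint` would run before `train_wgan_gp` when `LOAD_MODEL` is `True`, and `save_checkpoint` after it.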

## 4. Generate images <a name="generate"></a>

In [None]:
def generate_images(generator, num_images=10):
    generator.eval()  # BatchNorm uses running statistics at inference time
    # Reuse the same noise for both labels so the grids differ only in the condition
    z = torch.randn(num_images, Z_DIM, device=device)
    labels_0 = F.one_hot(torch.zeros(num_images, dtype=torch.long), num_classes=CLASSES).float().to(device)
    labels_1 = F.one_hot(torch.ones(num_images, dtype=torch.long), num_classes=CLASSES).float().to(device)

    with torch.no_grad():
        imgs_0 = generator(z, labels_0).cpu()
        imgs_1 = generator(z, labels_1).cpu()

    imgs_0 = (imgs_0 + 1) / 2  # map Tanh output from [-1, 1] to [0, 1]
    imgs_1 = (imgs_1 + 1) / 2

    grid_0 = utils.make_grid(imgs_0, nrow=5)  # label 0: Blond_Hair absent
    grid_1 = utils.make_grid(imgs_1, nrow=5)  # label 1: Blond_Hair present

    os.makedirs("./output", exist_ok=True)  # save_image does not create missing folders
    utils.save_image(grid_0, "./output/generated_label_0.png")
    utils.save_image(grid_1, "./output/generated_label_1.png")

generate_images(generator)
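Because `generate_images` reuses the same `z` for both labels, the two grids differ only in the condition. A natural follow-up, offered here as an untested extension, is to feed intermediate label vectors `[1 - t, t]` with the same noise. The generator was trained only on one-hot labels, so there is no guarantee that these in-between outputs are meaningful. The sketch below only builds the label batch; `interpolated_labels` is a hypothetical helper and PyTorch is assumed.

```python
import torch

def interpolated_labels(steps: int) -> torch.Tensor:
    """Return a (steps, 2) tensor whose rows move from [1, 0] (label 0) to [0, 1] (label 1)."""
    t = torch.linspace(0.0, 1.0, steps).unsqueeze(1)  # shape: (steps, 1)
    return torch.cat([1 - t, t], dim=1)

labels = interpolated_labels(5)
print(labels)
# With the trained generator, one would repeat a single noise vector for every row:
# z = torch.randn(1, Z_DIM, device=device).repeat(5, 1)
# imgs = generator(z, labels.to(device))
```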