In [1]:
# WGAN-GP for CelebA 64x64 Face Generation
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import os

# -----------------------
# 1. Models: Generator & Critic
# -----------------------
class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise to a 3-channel 64x64 image.

    Input ``z`` is expected as ``(batch, z_dim, 1, 1)``. The first transposed
    convolution (kernel 4, stride 1, padding 0) lifts 1x1 to 4x4; each of the
    following ones (kernel 4, stride 2, padding 1) doubles the spatial size,
    4 -> 8 -> 16 -> 32 -> 64. The closing Tanh bounds every pixel to [-1, 1].
    """

    def __init__(self, z_dim=128):
        super().__init__()
        # Project the latent vector onto a 512-channel 4x4 feature map.
        stages = [nn.ConvTranspose2d(z_dim, 512, 4, 1, 0), nn.BatchNorm2d(512), nn.ReLU()]
        # Upsampling stages: halve the channel count while doubling resolution.
        widths = (512, 256, 128, 64)
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.extend((nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.ReLU()))
        # Last upsample straight to RGB; no normalisation on the output layer.
        stages.extend((nn.ConvTranspose2d(64, 3, 4, 2, 1), nn.Tanh()))
        self.net = nn.Sequential(*stages)

    def forward(self, z):
        """Decode latent codes ``z`` into an image batch with values in [-1, 1]."""
        image = self.net(z)
        return image

class Critic(nn.Module):
    """WGAN critic: maps a 3-channel image to one unbounded real-valued score.

    Four stride-2 convolutions (kernel 4, padding 1) halve the resolution,
    64 -> 32 -> 16 -> 8 -> 4, and a final 4x4 valid convolution collapses the
    4x4 map to 1x1. There is no sigmoid: the last layer is a plain Conv2d.
    The 2nd-4th convolutions use InstanceNorm rather than BatchNorm, presumably
    because the per-sample gradient penalty should not mix statistics across
    the batch.
    """

    def __init__(self):
        super().__init__()
        blocks = [nn.Conv2d(3, 64, 4, 2, 1), nn.LeakyReLU(0.2)]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            blocks += [
                nn.Conv2d(c_in, c_out, 4, 2, 1),
                nn.InstanceNorm2d(c_out, affine=True),
                nn.LeakyReLU(0.2),
            ]
        blocks.append(nn.Conv2d(512, 1, 4, 1, 0))
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        """Return one score per image as a 1-D tensor.

        NOTE(review): yields exactly ``batch`` scores only for 64x64 inputs;
        other sizes leave a larger spatial map that ``view(-1)`` flattens.
        """
        scores = self.net(x)  # (batch, 1, 1, 1) for 64x64 input
        return scores.view(-1)

# -----------------------
# 2. Gradient Penalty
# -----------------------
def gradient_penalty(critic, real, fake, device=None):
    """Return the WGAN-GP gradient penalty for ``critic`` on real/fake interpolates.

    For each example a mixing weight ``epsilon ~ U[0, 1)`` is drawn and the
    critic is evaluated at ``real * epsilon + fake * (1 - epsilon)``. The
    penalty is the batch mean of ``(||d critic / d input||_2 - 1) ** 2``,
    with the gradient norm taken over all non-batch dimensions.

    Args:
        critic: Callable returning the scores for a batch of inputs.
        real: Batch of real samples, shape ``(batch, ...)``.
        fake: Batch of generated samples with the same shape as ``real``.
            Pass a detached tensor if the generator should not receive
            gradients through the penalty.
        device: Device on which ``epsilon`` is drawn. Defaults to
            ``real.device`` so CPU-only runs work without passing it.

    Returns:
        A scalar tensor. The gradient is computed with ``create_graph=True``
        so the penalty itself can be backpropagated into the critic.
    """
    if device is None:
        device = real.device
    # One weight per example; broadcasting spreads it over every other dim,
    # so no explicit .repeat() copy is needed and any input rank works.
    eps_shape = (real.size(0),) + (1,) * (real.dim() - 1)
    epsilon = torch.rand(eps_shape, device=device)
    interpolated = real * epsilon + fake * (1 - epsilon)
    interpolated.requires_grad_(True)

    mixed_scores = critic(interpolated)
    gradient = torch.autograd.grad(
        inputs=interpolated,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,  # keep the graph so gp.backward() reaches the critic's weights
        retain_graph=True,
    )[0]

    # Flatten per example (flatten also copes with non-contiguous gradients,
    # unlike view) and penalise deviation of the L2 norm from 1.
    gradient = gradient.flatten(start_dim=1)
    return ((gradient.norm(2, dim=1) - 1) ** 2).mean()

# -----------------------
# 3. Training Setup
# -----------------------
image_dir = "celebA/celeba/img_align_celeba"
os.makedirs("wgan_outputs", exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing pipeline; the final Normalize maps [0, 1] pixels to [-1, 1],
# the same range as the generator's Tanh output.
_preprocessing = [
    transforms.CenterCrop(160),            # keep the central 160x160 region
    transforms.Resize(64),                 # square crop -> 64x64
    transforms.RandomHorizontalFlip(),     # augmentation
    transforms.ToTensor(),                 # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5
]
transform = transforms.Compose(_preprocessing)

# NOTE(review): ImageFolder expects the images inside at least one
# sub-directory of image_dir; the class labels it yields are not used
# for training.
dataset = datasets.ImageFolder(root=image_dir, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

z_dim = 128
gen, critic = Generator(z_dim).to(device), Critic().to(device)

# Both networks share the same Adam settings: lr 1e-4, beta1 = 0, beta2 = 0.9.
_adam_settings = {"lr": 1e-4, "betas": (0.0, 0.9)}
opt_gen = torch.optim.Adam(gen.parameters(), **_adam_settings)
opt_critic = torch.optim.Adam(critic.parameters(), **_adam_settings)

# -----------------------
# 4. Training Loop
# -----------------------
critic_iters = 5
epochs = 3
lambda_gp = 10

for epoch in range(epochs):
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.to(device)
        cur_batch_size = real.size(0)

        # Train critic: `critic_iters` critic updates per generator update.
        # NOTE(review): the same real batch is reused for every critic step; the
        # WGAN-GP algorithm draws a fresh real batch per critic step — confirm intended.
        for _ in range(critic_iters):
            z = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
            # The critic update needs no gradient through the generator, so skip
            # building that autograd graph instead of building it and detaching.
            with torch.no_grad():
                fake = gen(z)
            critic_real = critic(real)
            critic_fake = critic(fake)
            gp = gradient_penalty(critic, real, fake, device=device)
            # Negated Wasserstein estimate (maximise real - fake) plus weighted gradient penalty.
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) + lambda_gp * gp

            opt_critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

        # Train generator: raise the critic's scores on fresh fakes.
        z = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
        fake = gen(z)
        loss_gen = -torch.mean(critic(fake))

        opt_gen.zero_grad()
        # This backward also accumulates gradients in the critic's parameters; they
        # are discarded by opt_critic.zero_grad() before the next critic step.
        loss_gen.backward()
        opt_gen.step()

        if batch_idx % 100 == 0:
            print(f"Epoch [{epoch+1}/{epochs}] Batch {batch_idx}/{len(dataloader)} "
                  f"Loss D: {loss_critic.item():.4f}, loss G: {loss_gen.item():.4f}")

    # Save an 8x8 sample grid each epoch. Eval mode makes BatchNorm use its running
    # statistics (as the final generation step does) and keeps the snapshot from
    # updating them; train mode is restored before the next epoch.
    gen.eval()
    with torch.no_grad():
        z = torch.randn(64, z_dim, 1, 1, device=device)
        samples = (gen(z) + 1) / 2  # Tanh range [-1, 1] -> [0, 1] for save_image
        save_image(samples, f"wgan_outputs/epoch_{epoch+1}.png", nrow=8)
    gen.train()

# -----------------------
# 5. Generate 100,000 Images
# -----------------------
num_images = 100_000
print(f"Generating {num_images:,} images...")
output_dir = "wgan_outputs/generated"
os.makedirs(output_dir, exist_ok=True)

gen.eval()  # BatchNorm uses its running statistics at inference time
batch_size = 100
# Zero-pad indices to the width of the largest one (5 digits for 100,000 images).
pad = len(str(num_images - 1))
with torch.no_grad():
    for start in range(0, num_images, batch_size):
        n = min(batch_size, num_images - start)  # last batch may be partial
        z = torch.randn(n, z_dim, 1, 1, device=device)
        fake = (gen(z).cpu() + 1) / 2  # Tanh range [-1, 1] -> [0, 1] for save_image
        for j in range(n):
            # Write into the directory created above (previously a non-existent
            # "wgan_generated/generated" path was used, so every save failed).
            save_image(fake[j], os.path.join(output_dir, f"image_{start + j:0{pad}d}.png"))

print(f"Done. {num_images:,} images saved to '{output_dir}/'")

  backends.update(_get_backends("networkx.backends"))




Epoch [1/3] Batch 0/3166 Loss D: -6.3916, loss G: 4.5593


Epoch [1/3] Batch 100/3166 Loss D: -26.6402, loss G: 21.0721


Epoch [1/3] Batch 200/3166 Loss D: -19.2689, loss G: 33.9940


Epoch [1/3] Batch 300/3166 Loss D: -18.1477, loss G: 35.9317


Epoch [1/3] Batch 400/3166 Loss D: -20.4166, loss G: 28.7506


Epoch [1/3] Batch 500/3166 Loss D: -14.0392, loss G: 37.1196


Epoch [1/3] Batch 600/3166 Loss D: -15.8908, loss G: 39.4243


Epoch [1/3] Batch 700/3166 Loss D: -15.1784, loss G: 47.3409


Epoch [1/3] Batch 800/3166 Loss D: -18.3646, loss G: 34.2869


Epoch [1/3] Batch 900/3166 Loss D: -14.4248, loss G: 49.5545


Epoch [1/3] Batch 1000/3166 Loss D: -17.2217, loss G: 42.7312


Epoch [1/3] Batch 1100/3166 Loss D: -15.5086, loss G: 43.0588


Epoch [1/3] Batch 1200/3166 Loss D: -17.2078, loss G: 44.6685


Epoch [1/3] Batch 1300/3166 Loss D: -17.3383, loss G: 52.6133


Epoch [1/3] Batch 1400/3166 Loss D: -15.2380, loss G: 51.0174


Epoch [1/3] Batch 1500/3166 Loss D: -15.6955, loss G: 52.2828


Epoch [1/3] Batch 1600/3166 Loss D: -16.0453, loss G: 61.3832


Epoch [1/3] Batch 1700/3166 Loss D: -12.6293, loss G: 56.7133


Epoch [1/3] Batch 1800/3166 Loss D: -16.7912, loss G: 50.9496


Epoch [1/3] Batch 1900/3166 Loss D: -13.0867, loss G: 55.7634


Epoch [1/3] Batch 2000/3166 Loss D: -13.1293, loss G: 56.8894


Epoch [1/3] Batch 2100/3166 Loss D: -14.0013, loss G: 48.4248


Epoch [1/3] Batch 2200/3166 Loss D: -14.9455, loss G: 63.1257


Epoch [1/3] Batch 2300/3166 Loss D: -14.7145, loss G: 52.6760


Epoch [1/3] Batch 2400/3166 Loss D: -14.3453, loss G: 58.8153


Epoch [1/3] Batch 2500/3166 Loss D: -14.6316, loss G: 48.7392


Epoch [1/3] Batch 2600/3166 Loss D: -14.5610, loss G: 69.7032


Epoch [1/3] Batch 2700/3166 Loss D: -13.8260, loss G: 64.6036


Epoch [1/3] Batch 2800/3166 Loss D: -13.5897, loss G: 52.6613


Epoch [1/3] Batch 2900/3166 Loss D: -14.3526, loss G: 56.4141


Epoch [1/3] Batch 3000/3166 Loss D: -13.5358, loss G: 57.6091


Epoch [1/3] Batch 3100/3166 Loss D: -12.3521, loss G: 59.8272


Epoch [2/3] Batch 0/3166 Loss D: -11.0298, loss G: 55.2957


Epoch [2/3] Batch 100/3166 Loss D: -14.7308, loss G: 63.5315


Epoch [2/3] Batch 200/3166 Loss D: -14.1501, loss G: 57.8809


Epoch [2/3] Batch 300/3166 Loss D: -14.6483, loss G: 48.4475
