In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from tqdm import tqdm
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import autograd


  backends.update(_get_backends("networkx.backends"))


In [2]:
working_dir = os.getcwd()  # current working directory; relative data/output paths below resolve against it
print(working_dir)


/sciclone/data10/jrhee01/genVision-celebA


In [3]:
# Load Dataset
# CelebA aligned faces: center-crop to 160x160, resize to 64x64, random horizontal
# flip, then scale pixel values from [0, 1] to [-1, 1] (the generator's Tanh range).

image_dir = "celebA/celeba/img_align_celeba"
# NOTE(review): nothing visible in this notebook writes to "test" (samples go to
# wgan_outputs/) — confirm this directory is still needed.
os.makedirs("test", exist_ok=True)

transform = transforms.Compose([
    transforms.CenterCrop(160),
    transforms.Resize(64),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 on each of the three RGB channels: [0, 1] -> [-1, 1].
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# ImageFolder expects image files inside at least one sub-directory of image_dir;
# the class labels it produces are not used by the training loop.
full_dataset = datasets.ImageFolder(root=image_dir, transform=transform)

# Pinned host memory only speeds up host->GPU transfers; on a CPU-only machine it
# is pure overhead (and PyTorch warns), so enable it only when CUDA is present.
dataloader = DataLoader(
    full_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)



In [4]:
class Generator(torch.nn.Module):
    """Transposed-convolution generator mapping a latent tensor to an image.

    Args:
        z_dim: channel count of the latent input, expected as (N, z_dim, 1, 1).
        img_channels: channel count of the generated image.

    For a (N, z_dim, 1, 1) input the output is (N, img_channels, 64, 64), squashed
    into [-1, 1] by a final Tanh.
    """

    def __init__(self, z_dim=100, img_channels=3):
        super().__init__()

        widths = [z_dim, 1024, 512, 256, 128]
        layers = []
        # Four ConvTranspose -> BatchNorm -> ReLU stages. The first (stride 1, no
        # padding) lifts the 1x1 latent to 4x4; each later one (stride 2, padding 1)
        # doubles the spatial size: 4 -> 8 -> 16 -> 32.
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            is_first = stage == 0
            layers += [
                nn.ConvTranspose2d(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=4,
                    stride=1 if is_first else 2,
                    padding=0 if is_first else 1,
                ),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # Last upsampling step 32 -> 64 straight to image channels (no BatchNorm).
        layers.append(
            nn.ConvTranspose2d(
                in_channels=widths[-1],
                out_channels=img_channels,
                kernel_size=4,
                stride=2,
                padding=1,
            )
        )
        self.main_module = nn.Sequential(*layers)

        self.output = nn.Tanh()

    def forward(self, z):
        features = self.main_module(z)
        return self.output(features)


In [5]:
class Discriminator(torch.nn.Module):
    """Convolutional critic producing one unbounded score per image.

    Args:
        img_channels: channel count of the input images (3 for RGB).

    There are no normalization layers and no final activation. For a
    (N, img_channels, 64, 64) input the result has shape (N,).
    """

    def __init__(self, img_channels=3):
        super().__init__()

        widths = [img_channels, 64, 128, 256, 512]
        layers = []
        # Four bias-free stride-2 convs, each followed by LeakyReLU(0.2), halve the
        # spatial size: 64 -> 32 -> 16 -> 8 -> 4.
        for c_in, c_out in zip(widths, widths[1:]):
            layers.append(
                nn.Conv2d(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    bias=False,
                )
            )
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # A 4x4 valid convolution collapses the 512x4x4 map to a single score.
        layers.append(
            nn.Conv2d(
                in_channels=widths[-1],
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False,
            )
        )
        self.main_module = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.main_module(x)
        return scores.view(-1)


In [6]:
# Gradient Penalty function
def gradient_penalty(D, real, fake, device=None):
    """WGAN-GP gradient penalty evaluated at random real/fake interpolations.

    For every sample a coefficient epsilon ~ U[0, 1) is drawn. The critic is then
    evaluated at ``epsilon * real + (1 - epsilon) * fake``, and the penalty is the
    batch mean of ``(||grad_x D(x)||_2 - 1) ** 2``.

    Args:
        D: critic; called as ``D(x)`` on a batch with the same shape as ``real``.
        real: batch of real samples, shape (batch, ...); any rank >= 1.
        fake: batch of generated samples, same shape as ``real``. Pass a detached
            tensor if the penalty should not backpropagate into the generator.
        device: device for the random coefficients. Defaults to ``real.device``.

    Returns:
        Scalar tensor. It keeps its autograd graph (``create_graph=True``) so it
        can be added to the critic loss and backpropagated into D's parameters.
    """
    if device is None:
        device = real.device
    batch_size = real.size(0)

    # One coefficient per sample, shaped (batch, 1, ..., 1) so it broadcasts over
    # every non-batch dimension regardless of input rank (a fixed 4-D shape would
    # silently mis-broadcast against e.g. (batch, features) inputs).
    eps_shape = (batch_size,) + (1,) * (real.dim() - 1)
    # Match float input precision so half/double batches are not re-cast.
    eps_dtype = real.dtype if real.is_floating_point() else torch.get_default_dtype()
    epsilon = torch.rand(eps_shape, device=device, dtype=eps_dtype)

    # Points on the straight lines between paired real and fake samples; gradients
    # w.r.t. these points are what the penalty constrains.
    interpolated = (epsilon * real + (1 - epsilon) * fake).requires_grad_(True)
    interp_logits = D(interpolated)

    # create_graph=True keeps the gradient differentiable so the penalty itself
    # can be backpropagated (double backward) into the critic's parameters.
    gradients = autograd.grad(
        outputs=interp_logits,
        inputs=interpolated,
        grad_outputs=torch.ones_like(interp_logits),
        create_graph=True,
        retain_graph=True,
    )[0]

    # reshape rather than view: autograd does not guarantee a contiguous result.
    gradient_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
    return ((gradient_norm - 1) ** 2).mean()

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Seed torch's global RNG (CPU and all CUDA devices) before the models are built,
# so weight init, fixed_noise and later random draws repeat across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Define latent vector size (input noise dimension for generator)
z_dim = 100
# Set learning rate for both generator and discriminator optimizers
lr = 2e-4
# Number of epochs to train
n_epochs = 5
# NOTE(review): the DataLoader is constructed with its own literal batch size of 64,
# so changing this value alone does not change the training batch size.
batch_size = 64
# Initialize the Generator model and move it to the selected device
G = Generator(z_dim).to(device)
# Initialize the Discriminator (critic) model and move it to the selected device
D = Discriminator().to(device)

# Adam optimizer for the generator and discriminator.
# NOTE(review): lr=2e-4, betas=(0.5, 0.999) are the usual DCGAN settings; the WGAN-GP
# paper (Gulrajani et al., 2017) uses lr=1e-4, betas=(0.0, 0.9) — confirm intent.
optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

# Prepare a fixed random noise vector to generate consistent sample images across epochs
fixed_noise = torch.randn(64, z_dim, 1, 1, device=device)


cpu


In [8]:
# Train the WGAN-GP: d_iterations critic updates per generator update, then write a
# sample grid generated from fixed_noise to wgan_outputs/ after every epoch.

# Critic steps per generator step and gradient-penalty weight (set once, not per batch).
d_iterations = 5
lambda_gp = 10

# save_image does not create missing directories; without this the first
# end-of-epoch save raises FileNotFoundError.
os.makedirs("wgan_outputs", exist_ok=True)

for epoch in range(n_epochs):
    # Ensure training mode (BatchNorm uses batch statistics); a previously run
    # sampling cell may have left G in eval mode.
    G.train()
    D.train()

    # Epoch-level loss accumulators
    g_loss_epoch = 0.0
    d_loss_epoch = 0.0
    num_batches = 0

    for real_imgs, _ in tqdm(dataloader, desc=f"Epoch {epoch+1}/{n_epochs}"):
        real_imgs = real_imgs.to(device)
        # Actual size of this batch (the last one may be smaller); kept separate
        # from the config-cell batch_size so that setting is not overwritten.
        cur_batch_size = real_imgs.size(0)

        # ---- Critic updates ----
        # NOTE(review): the same real batch is reused for all d_iterations steps;
        # Algorithm 1 of the WGAN-GP paper draws a fresh real batch per critic step.
        for _ in range(d_iterations):
            z = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
            # The critic step never backpropagates into G, so don't build G's graph.
            with torch.no_grad():
                fake_imgs = G(z)

            output_real = D(real_imgs)
            output_fake = D(fake_imgs)

            # Wasserstein critic loss: maximize D(real) - D(fake), i.e. minimize its negation.
            loss_D = -(output_real.mean() - output_fake.mean())
            gp = gradient_penalty(D, real_imgs, fake_imgs, device)
            loss_D = loss_D + lambda_gp * gp

            optimizer_D.zero_grad()
            loss_D.backward()
            optimizer_D.step()

        # ---- Generator update ----
        z = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
        fake_imgs = G(z)
        loss_G = -D(fake_imgs).mean()  # generator tries to make D(fake) large

        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()
        # loss_G.backward() also writes into D's .grad buffers; they are cleared by
        # optimizer_D.zero_grad() before the next critic backward pass.

        g_loss_epoch += loss_G.item()
        d_loss_epoch += loss_D.item()  # loss from the last critic step of this batch
        num_batches += 1

    # Sample in eval mode (BatchNorm running statistics, as in the final generation
    # step) so the fixed-noise batch does not update G's running stats; then restore.
    G.eval()
    with torch.no_grad():
        samples = G(fixed_noise)
    G.train()
    save_image(samples, f"wgan_outputs/epoch_{epoch+1:03d}.png", normalize=True)

    # Average losses for the epoch (guarded against an empty dataloader).
    denom = max(num_batches, 1)
    print(f"Epoch [{epoch+1}/{n_epochs}]  Loss_D: {d_loss_epoch/denom:.4f}  Loss_G: {g_loss_epoch/denom:.4f}")


Epoch 1/5:   0%|          | 0/3166 [00:00<?, ?it/s]



Epoch 1/5:   0%|          | 1/3166 [00:15<13:25:03, 15.26s/it]

Epoch 1/5:   0%|          | 2/3166 [00:28<12:29:31, 14.21s/it]

Epoch 1/5:   0%|          | 3/3166 [00:42<12:11:03, 13.87s/it]

Epoch 1/5:   0%|          | 4/3166 [00:55<12:04:02, 13.74s/it]

Epoch 1/5:   0%|          | 5/3166 [01:09<12:01:17, 13.69s/it]

Epoch 1/5:   0%|          | 6/3166 [01:22<11:57:14, 13.62s/it]

Epoch 1/5:   0%|          | 7/3166 [01:36<11:53:50, 13.56s/it]

Epoch 1/5:   0%|          | 8/3166 [01:49<11:54:46, 13.58s/it]

Epoch 1/5:   0%|          | 9/3166 [02:03<11:52:03, 13.53s/it]

Epoch 1/5:   0%|          | 10/3166 [02:16<11:50:39, 13.51s/it]

Epoch 1/5:   0%|          | 11/3166 [02:30<11:51:23, 13.53s/it]

Epoch 1/5:   0%|          | 12/3166 [02:43<11:48:51, 13.48s/it]

Epoch 1/5:   0%|          | 13/3166 [02:57<11:48:25, 13.48s/it]

Epoch 1/5:   0%|          | 14/3166 [03:10<11:47:47, 13.47s/it]

Epoch 1/5:   0%|          | 15/3166 [03:24<11:47:36, 13.47s/it]

Epoch 1/5:   1%|          | 16/3166 [03:37<11:46:36, 13.46s/it]

Epoch 1/5:   1%|          | 17/3166 [03:51<11:47:22, 13.48s/it]

Epoch 1/5:   1%|          | 18/3166 [04:04<11:45:34, 13.45s/it]

Epoch 1/5:   1%|          | 19/3166 [04:17<11:46:21, 13.47s/it]

Epoch 1/5:   1%|          | 20/3166 [04:31<11:49:39, 13.53s/it]

Epoch 1/5:   1%|          | 21/3166 [04:45<11:48:40, 13.52s/it]

Epoch 1/5:   1%|          | 22/3166 [04:58<11:48:03, 13.51s/it]

Epoch 1/5:   1%|          | 23/3166 [05:12<11:49:04, 13.54s/it]

Epoch 1/5:   1%|          | 24/3166 [05:25<11:48:21, 13.53s/it]

Epoch 1/5:   1%|          | 25/3166 [05:39<11:46:29, 13.50s/it]

Epoch 1/5:   1%|          | 26/3166 [05:52<11:46:29, 13.50s/it]

Epoch 1/5:   1%|          | 27/3166 [06:06<11:46:34, 13.51s/it]

In [None]:
# Generate 10,000 Images (one PNG per image) from the trained generator.
N_GENERATED = 10_000
GEN_BATCH_SIZE = 64
gen_dir = "wgan_outputs/generated"
os.makedirs(gen_dir, exist_ok=True)

was_training = G.training
G.eval()  # BatchNorm uses its running statistics while sampling
with torch.no_grad():
    for start in tqdm(range(0, N_GENERATED, GEN_BATCH_SIZE), desc="Generating final images"):
        # Truncate the last batch so exactly N_GENERATED files are written;
        # full 64-image batches over range(0, 10000, 64) would produce 10,048.
        n_batch = min(GEN_BATCH_SIZE, N_GENERATED - start)
        z = torch.randn(n_batch, z_dim, 1, 1, device=device)
        # G ends in Tanh (outputs in [-1, 1]) and the training data was normalized
        # with mean 0.5 / std 0.5, so (x + 1) / 2 is the exact inverse. Per-image
        # min-max stretching (normalize=True) would distort each image's contrast.
        gen_imgs = (G(z) + 1) / 2
        for j in range(n_batch):
            save_image(gen_imgs[j], f"{gen_dir}/{start + j:05d}.png")
G.train(was_training)  # restore the mode G had before this cell
