In [1]:
import torch
import torch.nn as nn  
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torchvision.utils as vutils
from tqdm import tqdm
import nbimporter

In [2]:
# Transforming the dataset
class DataLoaderBuilder():
    """Build a DataLoader over an ``ImageFolder`` image dataset.

    Each image is resized to ``image_size``, converted to a tensor by
    ``transforms.ToTensor`` and normalized per channel with ``mean``/``std``.
    With the defaults (0.5 / 0.5) pixel values end up in [-1, 1], matching a
    Tanh-output generator.

    Args:
        root_path: Root directory passed to ``datasets.ImageFolder``.
        image_size: Target size passed to ``transforms.Resize``.
        batch_size: Batch size of the returned DataLoader.
        mean: Normalization mean. A single number is repeated for each of the
            three RGB channels; a sequence is used as given.
        std: Normalization std, same conventions as ``mean``.
        shuffle: Whether the DataLoader shuffles the dataset each epoch.
    """

    def __init__(self, root_path, image_size=(64, 64), batch_size=64, mean=0.5, std=0.5, shuffle=True):
        self.root_path = root_path
        self.image_size = image_size
        self.batch_size = batch_size
        self.mean = mean
        self.std = std
        self.shuffle = shuffle

    @staticmethod
    def _per_channel(value, channels=3):
        """Expand a scalar to ``channels`` copies; turn a sequence into a list unchanged."""
        if isinstance(value, (int, float)):
            return [float(value)] * channels
        return list(value)

    def get_dataloader(self):
        """Create the ImageFolder dataset and wrap it in a DataLoader.

        Returns:
            A ``DataLoader`` yielding ``(images, labels)`` batches.
        """
        transform = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            # Use the configured statistics instead of hard-coded constants.
            transforms.Normalize(self._per_channel(self.mean), self._per_channel(self.std)),
        ])

        dataset = datasets.ImageFolder(root=self.root_path, transform=transform)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=self.shuffle)

In [3]:
class Critic(nn.Module):
    """WGAN critic: maps an image batch to one unbounded score per image.

    Input:  N x input_dim x 64 x 64
    Output: N x 1 x 1 x 1 (no final activation, as a Wasserstein critic requires)

    Four stride-2 convolutions halve the spatial size (64 -> 32 -> 16 -> 8 -> 4)
    while doubling the channel width; a final 4x4 convolution collapses 4x4 to 1x1.
    """

    def __init__(self, input_dim, feature_dim):
        super().__init__()

        stages = []
        channels_in = input_dim
        for width_mult in (1, 2, 4, 8):
            channels_out = feature_dim * width_mult
            stages.append(nn.Conv2d(channels_in, channels_out, kernel_size=4, stride=2, padding=1))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
            channels_in = channels_out

        # 4x4 -> 1x1 scoring layer
        stages.append(nn.Conv2d(channels_in, 1, kernel_size=4, stride=1, padding=0))

        self.model = nn.Sequential(*stages)

    def forward(self, x):
        return self.model(x)

In [4]:
import torch
import torch.nn as nn

class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise to 64x64 images.

    Input:  N x latent_dim x 1 x 1
    Output: N x output_channels x 64 x 64, values in [-1, 1] (Tanh)

    Spatial size grows 1 -> 4 -> 8 -> 16 -> 32 -> 64 while the channel width
    shrinks from feature_dim * 16 down to output_channels.
    """

    def __init__(self, latent_dim, output_channels, feature_dim):
        super().__init__()

        def up_block(c_in, c_out, stride, padding):
            # Transposed conv -> BatchNorm -> ReLU
            return [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=stride, padding=padding),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]

        widths = [feature_dim * mult for mult in (16, 8, 4, 2)]

        # 1x1 latent -> 4x4
        stages = up_block(latent_dim, widths[0], stride=1, padding=0)
        # 4x4 -> 8x8 -> 16x16 -> 32x32
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.extend(up_block(c_in, c_out, stride=2, padding=1))
        # 32x32 -> 64x64 image; Tanh bounds outputs to [-1, 1]
        stages.append(nn.ConvTranspose2d(widths[-1], output_channels, kernel_size=4, stride=2, padding=1))
        stages.append(nn.Tanh())

        self.model = nn.Sequential(*stages)

    def forward(self, x):
        return self.model(x)

In [5]:
def initialize_weights(model):
    """Apply DCGAN-style initialization in place to every submodule of ``model``.

    - Conv2d / ConvTranspose2d: weights ~ N(0, 0.02); biases keep their default init.
    - BatchNorm2d: weights ~ N(1, 0.02), biases set to 0.
    - Any other module type is left untouched.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    for module in model.modules():
        if isinstance(module, conv_types):
            nn.init.normal_(module.weight.data, mean=0.0, std=0.02)
            continue
        if isinstance(module, nn.BatchNorm2d):
            nn.init.normal_(module.weight.data, mean=1.0, std=0.02)
            nn.init.constant_(module.bias.data, val=0)


In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Hyperparameters: declared once, before anything uses them ----
seed = 42
dataset_path = "../celeb_a_dataset"
image_size = (64, 64)
latent_dim = 100
output_channels = 3
feature_dim = 64
batch_size = 64
weight_clipping = 0.01
lr_C = 5e-5
lr_G = 5e-5

# Seed torch's global RNG so weight init, noise sampling and shuffling are reproducible.
torch.manual_seed(seed)

dataloader = DataLoaderBuilder(
    root_path=dataset_path, image_size=image_size, batch_size=batch_size
).get_dataloader()

G = Generator(latent_dim, output_channels, feature_dim).to(device)
C = Critic(output_channels, feature_dim).to(device)
initialize_weights(G)
initialize_weights(C)

# Per-step loss histories, appended to by the training cell.
C_loss = []
G_loss = []

# RMSprop, as used in the WGAN paper.
C_optimizer = optim.RMSprop(C.parameters(), lr=lr_C)
G_optimizer = optim.RMSprop(G.parameters(), lr=lr_G)

In [None]:
epochs = 65
n_critic = 5      # critic updates per generator update
save_every = 5    # checkpoint interval in epochs (the final epoch is always saved too)

# generate_images() switches G to eval mode; put both networks back in training
# mode so BatchNorm uses batch statistics if this cell is re-run afterwards.
G.train()
C.train()

for epoch in tqdm(range(epochs)):
    for real_images, _ in dataloader:
        real_images = real_images.to(device)
        # Local name, so the config-cell `batch_size` is not overwritten;
        # the last batch of an epoch may be smaller.
        cur_batch = real_images.size(0)

        # ---- Train critic ----
        # NOTE(review): the same real batch is reused for all n_critic iterations,
        # whereas Algorithm 1 of the WGAN paper samples a fresh real batch each time.
        for _ in range(n_critic):
            latent_vector = torch.randn((cur_batch, latent_dim, 1, 1), device=device)
            fake_images = G(latent_vector).detach()  # no gradients into G during the critic step

            critic_real = C(real_images).view(-1)
            critic_fake = C(fake_images).view(-1)

            # Minimizing this maximizes E[C(real)] - E[C(fake)].
            loss_critic = torch.mean(critic_fake) - torch.mean(critic_real)

            C_optimizer.zero_grad()
            loss_critic.backward()
            C_optimizer.step()

            # Weight clipping keeps the critic (approximately) Lipschitz.
            with torch.no_grad():
                for param in C.parameters():
                    param.clamp_(-weight_clipping, weight_clipping)

            C_loss.append(loss_critic.item())

        # ---- Train generator ----
        latent_vector = torch.randn((cur_batch, latent_dim, 1, 1), device=device)
        fake_images = G(latent_vector)
        critic_fake = C(fake_images).view(-1)
        loss_G = -torch.mean(critic_fake)

        G_optimizer.zero_grad()
        loss_G.backward()
        G_optimizer.step()

        G_loss.append(loss_G.item())

    # Checkpoint periodically AND after the last epoch; otherwise the final
    # epochs (61-64 with the defaults) would never be written to disk.
    is_last_epoch = epoch == epochs - 1
    if epoch % save_every == 0 or is_last_epoch:
        torch.save(G.state_dict(), "wgan_generator.pth")
        torch.save(C.state_dict(), "wgan_critic.pth")
        # tqdm.write keeps the progress bar intact, unlike print.
        tqdm.write(f"Saved model weights at epoch {epoch}.")

In [None]:
def generate_images(generator, latent_dim, num_images, device, save_path=None, nrow=8):
    """Sample images from ``generator``, show them as a grid, and optionally save the grid.

    Args:
        generator: Module taking noise of shape (num_images, latent_dim, 1, 1);
            expected to output images in [-1, 1] (e.g. via a final Tanh).
        latent_dim: Channel size of the noise vector.
        num_images: Number of images to sample.
        device: Device on which the noise is created.
        save_path: If given, the image grid is also written to this file.
        nrow: Number of images per row in the grid.

    The generator's train/eval mode is restored afterwards, so calling this
    between training epochs does not leave BatchNorm in eval mode.
    """
    was_training = generator.training
    generator.eval()
    try:
        noise = torch.randn((num_images, latent_dim, 1, 1), device=device)
        with torch.no_grad():  # inference only
            fake_images = generator(noise)
    finally:
        generator.train(was_training)

    # Map [-1, 1] -> [0, 1] exactly once; the clamp guards against small overshoot.
    fake_images = ((fake_images + 1) / 2).clamp(0, 1)

    # normalize=False: the images are already in [0, 1]. normalize=True would
    # min-max rescale the batch a second time and distort contrast.
    grid = vutils.make_grid(fake_images, nrow=nrow, normalize=False)

    if save_path is not None:
        vutils.save_image(grid, save_path)

    plt.figure(figsize=(12, 12))
    plt.imshow(grid.permute(1, 2, 0).cpu())  # CHW -> HWC for matplotlib
    plt.axis("off")
    plt.title("Generated Images")
    plt.show()

    
# Example usage: sample `num_images` images from the trained generator `G`.
# Reuses `latent_dim` from the configuration cell so the noise always matches G's input.
num_images = 64
generate_images(G, latent_dim, num_images, device, save_path="generated_images.png")
