In [None]:
import os
from datetime import datetime
from glob import glob
from typing import Callable, Dict, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, ToTensor, Normalize

In [None]:
# Root folder of the preprocessed CelebA images (loaded below with ImageFolder).
# Set the CELEBA_DIR environment variable to point at a different copy.
data_dir = os.environ.get('CELEBA_DIR', 'processed-celeba-small/processed_celeba_small/celeba')
img_channels = 3  # RGB images

Using the preprocessed dataset: the images are already cropped to 64x64x3 (64x64 pixels, 3 RGB channels), so no cropping step is needed in the transform pipeline.

In [None]:
def get_transforms(size: Optional[Tuple[int, int]] = None) -> Callable:
    """Build the image preprocessing pipeline.

    Args:
        size: optional (height, width) target size. When given, each image is
            resized to it before conversion (a no-op for images that already
            have that size). When None, images keep their native size; this
            also lets callers use ``get_transforms()`` with no argument.

    Returns:
        A ``Compose`` that turns a PIL image into a float tensor and rescales
        it from [0, 1] to [-1, 1].
    """
    # Named `pipeline` so it does not shadow the imported `transforms` module.
    pipeline = []
    if size is not None:
        pipeline.append(transforms.Resize(size))
    # ToTensor scales to [0, 1]; Normalize with mean = std = 0.5 maps that to [-1, 1],
    # the same range as the generator's Tanh output.
    pipeline += [ToTensor(), Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
    return Compose(pipeline)

In [None]:
class DatasetDirectory(Dataset):
    """Unlabelled image dataset built on torchvision's ImageFolder.

    Only the transformed image is returned by ``__getitem__``; the class label
    that ImageFolder produces is discarded.
    """

    def __init__(self, directory: str, transforms: Optional[Callable] = None, extension: str = '.jpg'):
        """
        Args:
            directory: root folder handed to ImageFolder.
            transforms: callable applied to each loaded image; defaults to
                ``get_transforms((64, 64))``.
            extension: only files whose path ends with this suffix
                (case-insensitive) are indexed. Pass '' or None to fall back
                to ImageFolder's built-in list of image extensions.
        """
        self.directory = directory
        self.extension = extension
        # get_transforms takes a target size; passing it explicitly keeps this default usable.
        self.transforms = transforms if transforms is not None else get_transforms((64, 64))
        self.dataset = ImageFolder(
            root=directory,
            transform=self.transforms,
            is_valid_file=self._make_extension_filter(extension),
        )

    @staticmethod
    def _make_extension_filter(extension: Optional[str]) -> Optional[Callable[[str], bool]]:
        """Return a path predicate matching ``extension``, or None to keep ImageFolder's default."""
        if not extension:
            return None
        suffix = extension.lower()
        return lambda path: path.lower().endswith(suffix)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> torch.Tensor:
        """Return the transformed image at ``index`` (the label is dropped)."""
        image, _label = self.dataset[index]
        return image


In [None]:
def check_dataset_outputs(dataset: Dataset, expected_len: Optional[int] = None) -> None:
    """Sanity-check one random sample of ``dataset``.

    Asserts that the dataset is non-empty and that a random sample is a
    (3, 64, 64) tensor with values in [-1, 1]. The exact-size check only runs
    when ``expected_len`` is given (e.g. 32600 for the processed CelebA
    subset), so a local copy of a different size does not fail the check.
    """
    assert len(dataset) > 0, 'The dataset is empty.'
    if expected_len is not None:
        assert len(dataset) == expected_len, f'The dataset should contain {expected_len:,} images.'
    index = np.random.randint(len(dataset))
    image = dataset[index]
    assert image.shape == torch.Size([3, 64, 64]), 'You must reshape the images to be 64x64'
    assert image.min() >= -1 and image.max() <= 1, 'The images should range between -1 and 1.'
    print('Congrats, your dataset implementation passed all the tests')


# Create the dataset
dataset = DatasetDirectory(data_dir, get_transforms((64, 64)))
check_dataset_outputs(dataset)

print(f"Actual number of images in the dataset: {len(dataset)}")
print(f"Contents of the dataset directory '{data_dir}':")
print(os.listdir(data_dir))

In [None]:
def denormalize(images):
    """Map a numpy array from [-1, 1] back to 8-bit pixel values in [0, 255].

    Values are clipped to [0, 255] before the uint8 cast, so anything outside
    [-1, 1] saturates instead of wrapping around. Fractional values are
    truncated toward zero by the cast.
    """
    return np.clip((images + 1.) / 2. * 255, 0, 255).astype(np.uint8)

# Preview the first `plot_size` training images on a 2-row grid without axis ticks.
plot_size = 20
fig, axes = plt.subplots(2, plot_size // 2, figsize=(20, 4),
                         subplot_kw={'xticks': [], 'yticks': []})
for idx, ax in enumerate(axes.flat):
    # (C, H, W) tensor -> (H, W, C) uint8 array, as imshow expects.
    img = denormalize(np.transpose(dataset[idx].numpy(), (1, 2, 0)))
    ax.imshow(img)

In [None]:
class Discriminator(Module):
    """Convolutional WGAN critic for 64x64 images: one unbounded score per image.

    Four stride-2 convolutions shrink the spatial size 64 -> 32 -> 16 -> 8 -> 4,
    then a 4x4 convolution reduces each image to a single score. There is no
    sigmoid, because the score is used directly as a Wasserstein critic value.

    Hidden blocks use InstanceNorm2d instead of BatchNorm2d. With batch norm,
    each score depends on the other images in the batch, so the per-sample
    gradient norms in the WGAN-GP penalty are meaningless (Gulrajani et al.,
    2017, recommend no batch norm in the critic). Instance norm keeps every
    sample independent.
    """

    def __init__(self, in_channels: Optional[int] = None):
        """
        Args:
            in_channels: number of input image channels; defaults to the
                notebook-level ``img_channels`` setting.
        """
        super(Discriminator, self).__init__()
        self.img_channels = img_channels if in_channels is None else in_channels
        self.blocks = nn.ModuleList([
            # The first block has no normalisation layer.
            self._block(self.img_channels, 64, normalize=False),
            self._block(64, 128),
            self._block(128, 256),
            self._block(256, 512),
        ])
        # 4x4 feature map -> 1x1 score.
        self.final_layer = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=0)

    @staticmethod
    def _block(in_ch: int, out_ch: int, normalize: bool = True) -> nn.Sequential:
        """Conv(k=4, s=2, p=1) [+ InstanceNorm] + LeakyReLU(0.2); halves height and width."""
        layers = [nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=4, stride=2, padding=1)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_ch, affine=True))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch of (B, C, 64, 64) images; returns a tensor of shape (B, 1, 1, 1)."""
        for block in self.blocks:
            x = block(x)
        x = self.final_layer(x)
        return x.view(-1, 1, 1, 1)

In [None]:
def check_discriminator(discriminator: torch.nn.Module):
    """Assert that ``discriminator`` maps one 3x64x64 image to a single score of shape (1, 1, 1, 1)."""
    images = torch.randn(1, 3, 64, 64)
    with torch.no_grad():
        score = discriminator(images)
    assert score.shape == torch.Size([1, 1, 1, 1]), 'The discriminator output should be a single score.'
    print('Congrats, your discriminator implementation passed all the tests')


# Run the check on a throwaway instance so a broken architecture fails here,
# before any training starts.
check_discriminator(Discriminator())

In [None]:
# Build an untrained critic instance (a fresh one is built again before training).
discriminator: Module = Discriminator()

In [None]:
class Generator(Module):
    """DCGAN-style generator: (B, latent_dim, 1, 1) noise -> (B, C, 64, 64) images in [-1, 1].

    A stride-1 transposed convolution projects the noise to a 4x4 map, then
    four stride-2 transposed convolutions upsample 4 -> 8 -> 16 -> 32 -> 64.
    The final Tanh bounds outputs to [-1, 1], the same range that
    ``get_transforms`` normalises the real images to.
    """

    def __init__(self, latent_dim: int, out_channels: Optional[int] = None):
        """
        Args:
            latent_dim: number of channels of the input noise tensor.
            out_channels: channels of the generated images; defaults to the
                notebook-level ``img_channels`` setting.
        """
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.img_channels = img_channels if out_channels is None else out_channels
        self.init_layers()

    def init_layers(self):
        """Build the projection layer and the four upsampling blocks."""
        # NOTE: the attribute name "inital" (sic) is kept so state_dict keys stay stable.
        self.inital = nn.Sequential(
            nn.ConvTranspose2d(self.latent_dim, out_channels=512, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.ReLU(True)
        )
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(True)
            ),
            nn.Sequential(
                nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(True)
            ),
            nn.Sequential(
                nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(True)
            ),
            nn.Sequential(
                # Use the per-instance channel count rather than the global setting.
                nn.ConvTranspose2d(in_channels=64, out_channels=self.img_channels, kernel_size=4, stride=2, padding=1),
                nn.Tanh()
            )
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Generate images from noise of shape (B, latent_dim, 1, 1)."""
        x = self.inital(x)
        for block in self.blocks:
            x = block(x)
        return x
        

In [None]:
def check_generator(generator: torch.nn.Module, latent_dim: int):
    """Assert that ``generator`` maps one (1, latent_dim, 1, 1) noise vector to a (1, 3, 64, 64) image."""
    latent_vector = torch.randn(1, latent_dim, 1, 1)
    with torch.no_grad():
        image = generator(latent_vector)
    assert image.shape == torch.Size([1, 3, 64, 64]), 'The generator should output a 64x64x3 images.'
    print('Congrats, your generator implementation passed all the tests')


# Run the check on a throwaway instance; any latent size works for a shape check.
_check_latent_dim = 128
check_generator(Generator(_check_latent_dim), _check_latent_dim)

In [None]:
# Number of channels of the (B, latent_dim, 1, 1) noise fed to the generator.
latent_dim: int = 128
generator: Module = Generator(latent_dim=latent_dim)

In [None]:
def create_optimizers(generator: Module, discriminator: Module, lr=0.0001, beta1: float = 0.5, beta2: float = 0.999):
    """Create one Adam optimizer per network, sharing lr and (beta1, beta2).

    Returns:
        (g_optimizer, d_optimizer): the generator's optimizer first.
    """
    adam_kwargs = dict(lr=lr, betas=(beta1, beta2))
    return (
        optim.Adam(generator.parameters(), **adam_kwargs),
        optim.Adam(discriminator.parameters(), **adam_kwargs),
    )

In [None]:
def generator_loss(fake_logits):
    """WGAN generator loss: the negated mean critic score of generated samples."""
    return fake_logits.mean().neg()

def discriminator_loss(real_logits, fake_logits):
    """WGAN critic loss to minimise: mean(D(fake)) - mean(D(real)).

    Minimising this raises real scores and lowers fake scores, i.e. it
    maximises the Wasserstein estimate mean(D(real)) - mean(D(fake)). This is
    the same sign convention as the critic loss in ``discriminator_step``
    (without the gradient-penalty term).
    """
    return torch.mean(fake_logits) - torch.mean(real_logits)


In [None]:

def gradient_penalty(discriminator, real_images, fake_images, device):
    """WGAN-GP gradient penalty: batch mean of (||grad_x D(x_hat)||_2 - 1)^2.

    ``x_hat`` is a random per-sample interpolation between ``real_images`` and
    ``fake_images``. The fakes are detached, so the penalty never
    back-propagates into the generator that produced them.

    Args:
        discriminator: critic mapping a batch of images to one score per image.
        real_images: batch of real images; presumably (B, C, H, W).
        fake_images: generated images with the same shape as ``real_images``.
        device: device on which the interpolation weights are drawn.

    Returns:
        Scalar tensor. It is built with ``create_graph=True``, so it can be
        back-propagated into the discriminator's parameters.
    """
    batch_size = real_images.shape[0]
    # One interpolation weight per sample, broadcast over the remaining dims.
    # It is a constant, so no gradient is needed with respect to it.
    epsilon = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolated_images = epsilon * real_images + (1 - epsilon) * fake_images.detach()
    # autograd.grad differentiates with respect to x_hat, so it must require grad.
    interpolated_images.requires_grad_(True)

    mixed_scores = discriminator(interpolated_images)

    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,   # the penalty must itself be differentiable w.r.t. D's weights
        retain_graph=True,
    )[0]

    gradient_norm = gradient.reshape(batch_size, -1).norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

In [None]:
def discriminator_step(batch_size, latent_dim, real_images, generator, discriminator, d_optimizer, lambda_gp, device):
    """Perform one WGAN-GP critic update.

    Minimises mean(D(fake)) - mean(D(real)) + lambda_gp * GP, where GP is the
    penalty returned by ``gradient_penalty``.

    Returns:
        dict with the total critic loss under 'loss' and the raw gradient
        penalty under 'gp', both as Python floats.
    """
    noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
    # The generator is not trained in this step, so don't build its autograd graph.
    # This also guarantees no gradient (including from the penalty) reaches G.
    with torch.no_grad():
        fake_images = generator(noise)

    real_logits = discriminator(real_images)
    fake_logits = discriminator(fake_images)

    gp = gradient_penalty(discriminator, real_images, fake_images, device)
    d_loss = torch.mean(fake_logits) - torch.mean(real_logits) + lambda_gp * gp

    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    return {'loss': d_loss.item(), 'gp': gp.item()}

def generator_step(batch_size, latent_dim, generator, discriminator, g_optimizer, device):
    """Perform one WGAN generator update: minimise -mean(D(G(z))) for fresh noise z.

    Returns:
        dict with the generator loss under 'loss' as a Python float.
    """
    g_optimizer.zero_grad()
    z = torch.randn(batch_size, latent_dim, 1, 1, device=device)
    critic_scores = discriminator(generator(z))
    g_loss = critic_scores.mean().neg()
    g_loss.backward()
    g_optimizer.step()
    return {'loss': g_loss.item()}

In [None]:
latent_dim = 128
# Fall back to the CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
n_epochs = 4
batch_size = 64

In [None]:
print_every = 50
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
g_optimizer, d_optimizer = create_optimizers(generator, discriminator)

dataloader = DataLoader(dataset,
                        batch_size=batch_size,  # use the configured batch size
                        shuffle=True,
                        num_workers=4,
                        drop_last=True,
                        # pinned host memory only speeds up host-to-GPU copies
                        pin_memory=str(device).startswith('cuda'))


In [None]:
def display(fixed_latent_vector: torch.Tensor, plot_size: int = 16):
    """Helper function to show a grid of generated images during training.

    Note: this shadows IPython's built-in ``display`` in the notebook namespace.

    Args:
        fixed_latent_vector: despite its name, a batch of generated images in
            [-1, 1] (the training loop passes ``generator(fixed_latent_vector)``),
            assumed to be (N, C, H, W).
        plot_size: maximum number of images to draw, laid out on two rows.
    """
    # Never index past the end of a smaller batch.
    n_images = min(plot_size, fixed_latent_vector.shape[0])
    n_cols = max(1, (n_images + 1) // 2)
    fig = plt.figure(figsize=(14, 4))
    for idx in range(n_images):
        ax = fig.add_subplot(2, n_cols, idx + 1, xticks=[], yticks=[])
        img = fixed_latent_vector[idx, ...].detach().cpu().numpy()
        img = np.transpose(img, (1, 2, 0))  # (C, H, W) -> (H, W, C) for imshow
        img = denormalize(img)
        ax.imshow(img)
    plt.show()

In [None]:
# Fixed noise, reused every epoch so the displayed samples are comparable over time.
fixed_latent_vector = torch.randn(16, latent_dim, 1, 1).float().to(device)
losses = []  # (d_loss, g_loss) pairs recorded every `print_every` batches
lr = 0.0001
beta1 = 0.5
beta2 = 0.999
lambda_gp = 10    # weight of the gradient penalty in the critic loss
critic_steps = 5  # critic updates per generator update
n_epochs = 100
print_every = 10

g_optimizer, d_optimizer = create_optimizers(generator, discriminator, lr, beta1, beta2)

# Most recent generator loss; NaN until the first generator update happens.
g_loss_value = float('nan')

for epoch in range(n_epochs):
    for batch_i, real_images in enumerate(dataloader):
        real_images = real_images.to(device)
        batch_size = real_images.size(0)

        # The critic is updated on every batch...
        d_step_res = discriminator_step(batch_size, latent_dim, real_images, generator, discriminator, d_optimizer, lambda_gp, device)
        d_loss_value = d_step_res['loss']
        gp_value = d_step_res['gp']

        # ...and the generator only once every `critic_steps` batches (WGAN-GP schedule).
        if batch_i % critic_steps == 0:
            g_step_res = generator_step(batch_size, latent_dim, generator, discriminator, g_optimizer, device)
            g_loss_value = g_step_res['loss']

        if batch_i % print_every == 0:
            d = d_loss_value
            g = g_loss_value
            losses.append((d, g))
            time = str(datetime.now()).split('.')[0]
            print(f'{time} | Epoch [{epoch+1}/{n_epochs}] | Batch {batch_i}/{len(dataloader)} | d_loss: {d:.4f} | g_loss: {g:.4f}')

    # display images during training (eval mode: BatchNorm uses its running statistics)
    generator.eval()
    with torch.no_grad():
        generated_images = generator(fixed_latent_vector)
    display(generated_images)
    generator.train()

In [None]:
"""
DO NOT MODIFY ANYTHING IN THIS CELL
"""
# Plot the (d_loss, g_loss) pairs that the training loop appended to `losses`
# every `print_every` batches.
fig, ax = plt.subplots()
losses = np.array(losses)
# Column 0 holds the discriminator (critic) loss, column 1 the generator loss.
plt.plot(losses.T[0], label='Discriminator', alpha=0.5)
plt.plot(losses.T[1], label='Generator', alpha=0.5)
plt.title("Training Losses")
plt.legend()