In [None]:
import importlib.util
import os
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from torch import optim

import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torchvision.utils as vutils

# accimage is an optional, faster image decoder. Only switch torchvision to it
# when the package is actually installed; otherwise keep the default PIL
# backend. With the 'accimage' backend selected, torchvision's default image
# loader imports accimage and would fail on the first data access if missing.
if importlib.util.find_spec("accimage") is not None:
    torchvision.set_image_backend('accimage')

### Custom seed

In [None]:
seed = 42

# Seed every RNG source this notebook imports so runs are reproducible:
# Python's `random`, NumPy's legacy global generator, and torch
# (torch.manual_seed seeds the CPU and all CUDA devices).
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed);  # trailing `;` hides the returned Generator repr

### Constants

In [None]:
# Root directory for dataset. Set the CELEBA_ROOT environment variable to point
# at your own copy instead of editing this machine-specific default path.
data_root = Path(
    os.environ.get("CELEBA_ROOT", "/home/hugo/ai/datasets/celeba-aligned-cropped")
)

# Number of worker processes for the dataloader
workers = 8

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized and
# center-cropped to this size by the dataset's transform pipeline.
# NOTE: the Generator/Discriminator layer stacks are written for 64 x 64 images.
image_size = 64

# Number of channels in the training images. For color images this is 3
num_channels = 3

# Size of z latent vector (i.e. size of generator input)
z_latent_size = 100

# Size of feature maps in generator
gen_feat_size = 64

# Size of feature maps in discriminator
disc_feat_size = 64

# Number of training epochs
num_epochs = 5

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Run on the GPU when CUDA is available; set to False to force the CPU
use_gpu = True

### Load data

#### Create the dataset

In [None]:
class Dataset(dsets.ImageFolder):
    """ImageFolder variant for an unlabeled image directory.

    ImageFolder normally treats each sub-directory of ``root`` as a class.
    Here the whole tree is exposed as a single pseudo-class named ``""``
    (presumably torchvision joins ``root`` with each class name, and
    ``os.path.join(root, "")`` is ``root`` itself), so every image is
    collected with the integer label ``0``.
    """

    # Override DatasetFolder's find_classes method since
    # we don't want any classes here
    def find_classes(self, directory: str):
        # Must return (list of class names, {class name: integer index}).
        # Returning None / a string index leaves `self.classes` as None and
        # turns every target into the string "" instead of an int.
        return [""], {"": 0}

# Preprocessing: Resize(int) scales the shorter edge to `image_size`, the centre
# square is cropped, pixels become a float tensor in [0, 1], and
# Normalize(0.5, 0.5) maps them to [-1, 1], matching the generator's Tanh range.
dataset = Dataset(
    data_root.as_posix(),
    transform=transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
        ]
    ),
)

#### Create the dataloader

In [None]:
# Shuffled mini-batches of `batch_size` images, loaded by `workers` worker processes.
dataloader = DataLoader(dataset, num_workers=workers, shuffle=True, batch_size=batch_size)

#### Select the device we want to run on

In [None]:
cuda_available = torch.cuda.is_available()
print("Cuda is available! 🥳" if cuda_available else "❌ Cuda is unavailable")

# Use the first GPU only when CUDA exists *and* the `use_gpu` flag allows it;
# otherwise fall back to the CPU.
device = torch.device("cuda:0" if cuda_available and use_gpu else "cpu")

#### Plot some images!

In [None]:
# Grab one shuffled batch and show (up to) its first 64 images as a grid.
real_batch = next(iter(dataloader))

# Build the grid on the CPU: the images are only displayed, so copying the
# whole batch to `device` and straight back would be wasted transfer.
# normalize=True min-max rescales the [-1, 1] pixels into [0, 1] for display.
training_grid = vutils.make_grid(real_batch[0][:64], padding=2, normalize=True)

plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
# make_grid returns (C, H, W); imshow expects (H, W, C).
plt.imshow(training_grid.permute(1, 2, 0).numpy())
plt.show()

### Model implementations

#### Weights initialization

In the DCGAN paper, the authors specify that all weights should be initialized from a normal distribution with mean 0 and standard deviation 0.02.

In [None]:
def init_weights(m: nn.Module):
    """DCGAN weight initialisation, meant to be used with ``nn.Module.apply``.

    - Conv2d / ConvTranspose2d: weight ~ N(0, 0.02).
    - BatchNorm2d: weight (scale) ~ N(1, 0.02), bias (shift) = 0.
    - Any other module is left untouched.

    Prints a short message for every module it initialises.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1.0, 0.02)
        # Zero the *bias*. Zeroing the weight instead would discard the scale
        # just drawn above and make every BatchNorm output only its bias.
        nn.init.zeros_(m.bias)
    else:
        return

    print(f"🔧 {m.__class__.__name__} weights initialized!")

#### Generator

In [None]:
class Generator(nn.Module):
    """DCGAN generator.

    Maps latent vectors of shape (N, z_size, 1, 1) to images of shape
    (N, out_channels, 64, 64) with values in [-1, 1] (Tanh output).

    Args:
        z_size: latent vector size; defaults to the notebook's ``z_latent_size``.
        feat_size: base feature-map width; defaults to ``gen_feat_size``.
        out_channels: channels of the generated image; defaults to ``num_channels``.
    """

    def __init__(self, z_size=None, feat_size=None, out_channels=None):
        super().__init__()
        z_size = z_latent_size if z_size is None else z_size
        feat_size = gen_feat_size if feat_size is None else feat_size
        out_channels = num_channels if out_channels is None else out_channels

        # Blocks are unpacked flat so module indices (main.0, main.1, ...) match
        # a hand-written Sequential, keeping state_dict keys stable.
        self.main = nn.Sequential(
            # Input is z (latent vector), going into a convolution
            *self._up_block(z_size, feat_size * 8, stride=1, padding=0),
            # State size: (feat_size * 8) x 4 x 4
            *self._up_block(feat_size * 8, feat_size * 4),
            # State size: (feat_size * 4) x 8 x 8
            *self._up_block(feat_size * 4, feat_size * 2),
            # State size: (feat_size * 2) x 16 x 16
            *self._up_block(feat_size * 2, feat_size),
            # State size: feat_size x 32 x 32
            nn.ConvTranspose2d(feat_size, out_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
            # State size: out_channels x 64 x 64
        )

    @staticmethod
    def _up_block(in_ch, out_ch, stride=2, padding=1):
        """ConvTranspose -> BatchNorm -> ReLU; BatchNorm is sized from the conv's output."""
        return (
            nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, input):
        """Run the latent batch through the upsampling stack."""
        return self.main(input)

In [None]:
# Build the generator on the selected device and apply the DCGAN weight
# initialisation; Module.apply visits every submodule and returns the module.
gen_net = Generator().to(device).apply(init_weights)

# Display the architecture
gen_net

#### Discriminator

In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator.

    Maps images of shape (N, num_channels, 64, 64) to a Sigmoid score in (0, 1)
    per image, shaped (N, 1, 1, 1).
    """

    def __init__(self):
        super().__init__()
        slope = .2

        # Input is num_channels x 64 x 64 (no BatchNorm on the first layer)
        layers = [
            nn.Conv2d(num_channels, disc_feat_size, 4, 2, 1, bias=False),
            nn.LeakyReLU(negative_slope=slope, inplace=True),
        ]

        # Three downsampling stages, each halving the spatial size and doubling
        # the channels: 64->128 (16x16), 128->256 (8x8), 256->512 (4x4) for the
        # default disc_feat_size.
        for mult in (1, 2, 4):
            in_ch = disc_feat_size * mult
            out_ch = in_ch * 2
            layers += [
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(negative_slope=slope, inplace=True),
            ]

        # State size: (disc_feat_size * 8) x 4 x 4 -> 1 x 1 x 1 score
        layers += [
            nn.Conv2d(disc_feat_size * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Score a batch of images."""
        return self.main(input)

In [None]:
# Build the discriminator on the selected device and apply the DCGAN weight
# initialisation; Module.apply visits every submodule and returns the module.
disc_net = Discriminator().to(device).apply(init_weights)

# Display the architecture
disc_net

#### Loss, fixed noise & optimizers

In [None]:
# Binary cross-entropy on the discriminator's Sigmoid scores
criterion = nn.BCELoss()

# Static eval noise: one fixed batch of 64 latent vectors, shaped like the
# generator's (N, z, 1, 1) input.
fixed_noise = torch.randn(64, z_latent_size, 1, 1, device=device)

# BCE target values: 1 for real images, 0 for generated (fake) ones
real, fake = 1., 0.

# One Adam optimizer per network, both with the DCGAN hyper-parameters
gen_optimizer, disc_optimizer = (
    optim.Adam(net.parameters(), lr=lr, betas=(beta1, 0.999))
    for net in (gen_net, disc_net)
)

### Training

In [None]:
# Training history containers (presumably filled by the training loop) and a
# global iteration counter.
img_list = []
gen_losses = []
disc_losses = []
num_iterations = 0

print("🔄 Starting training loop")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, (data, _) in enumerate(dataloader): # _ is the label vector (batched) we are ignoring
        
        # 1. Update discriminator network: maximize log(D(x)) + log(1 - D(G(z)))
        gen_net.zero_grad()
        
        # Format batch
        img_batch = data.to(device)
        