In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import log2

In [2]:
factors = [1, 1, 1, 1, 1/2, 1/4, 1/8, 1/16, 1/32] # Feature map for adjusting for each resolution

# Model Building

In [3]:
class WSConv2d(nn.Module): # Weight Scaled 2D
    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2
    ):
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) # Initializes the conv layer
        self.scale = (gain / (in_channels * (kernel_size ** 2))) ** 0.5 # Calculates the scaling factor
        self.bias = self.conv.bias # Sets a manual Bias
        self.conv.bias = None # Uses self.bias instead of conv.bias

        # initialize conv layer
        nn.init.normal_(self.conv.weight) # Weights initialized (Normal Distribution)
        nn.init.zeros_(self.bias) # Bias is a tensor of zeros

    def forward(self, x):
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1) # Input is scaled to prevent 'Vanishing Gradient'
                                                                                       # Bias reshaped to match output dimesnion then applies

In [4]:
class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm, self).__init__()
        self.epsilon = 1e-8 # Prevents division by zero

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon) # Normalizing the vector across all channels

In [5]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_pixelnorm=True):l
        super(ConvBlock, self).__init__()
        self.use_pn = use_pixelnorm
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels) # Allows stacking of Conv Blocks
        self.leaky = nn.LeakyReLU(0.2) # Prevents vanishing gradient (Negative values have contribution)
        self.pn = PixelNorm() 

    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.pn(x) if self.use_pn else x
        x = self.leaky(self.conv2(x))
        x = self.pn(x) if self.use_pn else x
        return x # Computes the tensor and returns it

In [6]:
class Generator(nn.Module):
    def __init__(self, z_dim, in_channels, img_channels=3): # Takes a Random Noise of z_dim
        super(Generator, self).__init__()

        self.initial = nn.Sequential( # 1x1 latent vector -> 4x4
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0), # 1x1->4x4,
            nn.LeakyReLU(0.2), # Prevents dying neurons
            WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        ) # Building board for the generator (Canvas)

        self.initial_rgb = WSConv2d( # Converts 4x4 into a RGB Image
            in_channels, img_channels, kernel_size=1, stride=1, padding=0 
        ) # Returns a low res image
        self.prog_blocks, self.rgb_layers = ( # Allows stacking of layers (Resolutions)
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]), # Converts feature maps into RGB at each step
        )

        for i in range(
            len(factors) - 1
        ):
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1]) # Each block increases resolution 
            self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c)) # Learning is done here
            self.rgb_layers.append( # Converts output into RGB
                WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled) # allows the images to Fade Into higher resolution smoothly

    def forward(self, x, alpha, steps):
        out = self.initial(x)

        if steps == 0:
            return self.initial_rgb(out) # Prevents error for 1 channel dimension (apparently works)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
            out = self.prog_blocks[step](upscaled)

        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out) # Smooth Fade in


In [7]:
class Discriminator(nn.Module):
    def __init__(self, z_dim, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        for i in range(len(factors) - 1, 0, -1): # Setup Blocks in reverse order (High Res To Low Res)
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out, use_pixelnorm=False)) # Pixel Norm is not essential for generators
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )

        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d( # Downsamples image by factor of 2 (Intermediate between blocks)
            kernel_size=2, stride=2
        )  
        
        self.final_block = nn.Sequential(
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d( # Final Output Indicates real or fake
                in_channels, 1, kernel_size=1, padding=0, stride=1
            ),
        )

    def fade_in(self, alpha, downscaled, out): # Fades in for lower res
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x): # Prevents generator mode collapse
        batch_statistics = ( # Calcs std 
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3]) # Adds std as as cahnnel
        )
        return torch.cat([x, batch_statistics], dim=1) # Generator produces varies results

    def forward(self, x, alpha, steps):
        cur_step = len(self.prog_blocks) - steps # Start RES

        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0: # Intial step adds the minibatch channel
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))

        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)

In [8]:
# Test cases to check if all fuctions run for all cases. (Check Change_logs/Sources.txt)

if __name__ == "__main__":
    Z_DIM = 50
    IN_CHANNELS = 256
    gen = Generator(Z_DIM, IN_CHANNELS, img_channels=3)
    critic = Discriminator(Z_DIM, IN_CHANNELS, img_channels=3)

    for img_size in [4, 8, 16, 32, 64, 128, 256, 512, 1024]:
        num_steps = int(log2(img_size / 4))
        x = torch.randn((1, Z_DIM, 1, 1))
        z = gen(x, 0.5, steps=num_steps)
        assert z.shape == (1, 3, img_size, img_size)
        out = critic(z, alpha=0.5, steps=num_steps)
        assert out.shape == (1, 1)
        print(f"Success! At img size: {img_size}")

  torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])


Success! At img size: 4
Success! At img size: 8
Success! At img size: 16
Success! At img size: 32
Success! At img size: 64
Success! At img size: 128
Success! At img size: 256
Success! At img size: 512
Success! At img size: 1024


# Model Training Utils

In [9]:
import torch
import random
import numpy as np
import os
import torchvision
import torch.nn as nn
from torchvision.utils import save_image
from scipy.stats import truncnorm

In [10]:
def plot_to_tensorboard(
    writer, loss_critic, loss_gen, real, fake, tensorboard_step
):
    writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step) # Log data onto tensorboard

    with torch.no_grad(): # Create image grid
        img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
        writer.add_image("Real", img_grid_real, global_step=tensorboard_step) # Log image grids
        writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)

In [11]:
# Check WGAN-GP in change_logs/source.txt

def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"): # Critic: Score for generator
    BATCH_SIZE, C, H, W = real.shape
    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * beta + fake.detach() * (1 - beta) # Images b/w real and fake
    interpolated_images.requires_grad_(True)

    mixed_scores = critic(interpolated_images, alpha, train_step) #Computes gradient penalty

    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1) # Calculates Gradient Nomr (Magnitude penalty)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2) # Enforces  Lipschitz continuity
    return gradient_penalty

In [12]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"): # Save checkpoints to resume later (Useful for less ram computations)
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)

In [13]:
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location="cuda")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

In [15]:
def generate_examples(gen, steps, truncation=0.7, n=100): # Demo purposes
    gen.eval()
    alpha = 1.0
    for i in range(n):
        with torch.no_grad(): # Truncation for noise control
            noise = torch.tensor(truncnorm.rvs(-truncation, truncation, size=(1, config.Z_DIM, 1, 1)), device=config.DEVICE, dtype=torch.float32)
            img = gen(noise, alpha, steps)
            save_image(img*0.5+0.5, f"saved_examples/img_{i}.png")
    gen.train()

In [24]:
import cv2
import torch
from math import log2
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

# Intial Parameters
START_TRAIN_AT_IMG_SIZE = 128 # Intial Res
DATASET = 'celeba_hq_org'
CHECKPOINT_GEN = "generator.pth" # File saving location
CHECKPOINT_CRITIC = "critic.pth"
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu" # Use cuda if Nvidia GPU Available
SAVE_MODEL = True
LOAD_MODEL = False
LEARNING_RATE = 1e-3
BATCH_SIZES = [16, 16, 16, 8, 8, 8, 8, 4, 2] # Change batch_size with respect to ram needs
CHANNELS_IMG = 3
Z_DIM = 256  
IN_CHANNELS = 256 
CRITIC_ITERATIONS = 1 # 5 for WGAN performance
LAMBDA_GP = 10 # Gradient penalty weight
PROGRESSIVE_EPOCHS = [10] * len(BATCH_SIZES) # No of epoches
FIXED_NOISE = torch.randn(8, Z_DIM, 1, 1).to(DEVICE) # Noise generation
NUM_WORKERS = 4

The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.


In [17]:
batch_statistics = (torch.std(x, dim=0) + 1e-8).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])

  batch_statistics = (torch.std(x, dim=0) + 1e-8).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])


In [18]:
torch.backends.cudnn.benchmarks = True # Selects most eff algo


def get_loader(image_size): # Fetch images
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)), # Resize as a square, consistent
            transforms.ToTensor(), # Converts to tensor E(0,1)
            transforms.RandomHorizontalFlip(p=0.5), 
            transforms.Normalize( # Normalise with mean = 0.5, std= 0.5 
                [0.5 for _ in range(CHANNELS_IMG)],
                [0.5 for _ in range(CHANNELS_IMG)],
            ),
        ]
    )
    batch_size = BATCH_SIZES[int(log2(image_size / 4))]
    dataset = datasets.ImageFolder(root=DATASET, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    return loader, dataset

In [19]:
def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
    tensorboard_step,
    writer,
    scaler_gen,
    scaler_critic,
):
    loop = tqdm(loader, leave=True)
    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)] <-> min -E[critic(real)] + E[critic(fake)]
        # which is equivalent to minimizing the negative of the expression
        noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(DEVICE)

        with torch.amp.autocast(device_type='mps'):
            fake = gen(noise, alpha, step)
            critic_real = critic(real, alpha, step)
            critic_fake = critic(fake.detach(), alpha, step)
            gp = gradient_penalty(critic, real, fake, alpha, step, device=DEVICE)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + LAMBDA_GP * gp
                + (0.001 * torch.mean(critic_real ** 2))
            )

        opt_critic.zero_grad()
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        with torch.amp.autocast(device_type='mps'):
            gen_fake = critic(fake, alpha, step)
            loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        # Update alpha and ensure less than 1
        alpha += cur_batch_size / (
            (PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)

        if batch_idx % 500 == 0:
            with torch.no_grad():
                fixed_fakes = gen(FIXED_NOISE, alpha, step) * 0.5 + 0.5
            plot_to_tensorboard(
                writer,
                loss_critic.item(),
                loss_gen.item(),
                real.detach(),
                fixed_fakes.detach(),
                tensorboard_step,
            )
            tensorboard_step += 1

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )

    return tensorboard_step, alpha

In [20]:
def main():
    # initialize gen and disc, note: discriminator should be called critic,
    # according to WGAN paper (since it no longer outputs between [0, 1])
    # but really who cares..
    gen = Generator(
        Z_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG
    ).to(DEVICE)
    critic = Discriminator(
        Z_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG
    ).to(DEVICE)

    # initialize optimizers and scalers for FP16 training
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
    opt_critic = optim.Adam(
        critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99)
    )
    scaler_critic = torch.amp.GradScaler()
    scaler_gen = torch.amp.GradScaler()

    # for tensorboard plotting
    writer = SummaryWriter(f"logs/gan1")

    if LOAD_MODEL:
        load_checkpoint(
            CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_CRITIC, critic, opt_critic, LEARNING_RATE,
        )

    gen.train()
    critic.train()

    tensorboard_step = 0
    # start at step that corresponds to img size that we set in config
    step = int(log2(START_TRAIN_AT_IMG_SIZE / 4))
    for num_epochs in PROGRESSIVE_EPOCHS[step:]:
        alpha = 1e-5  # start with very low alpha
        loader, dataset = get_loader(4 * 2 ** step)  # 4->0, 8->1, 16->2, 32->3, 64 -> 4
        print(f"Current image size: {4 * 2 ** step}")

        for epoch in range(num_epochs):
            print(f"Epoch [{epoch+1}/{num_epochs}]")
            tensorboard_step, alpha = train_fn(
                critic,
                gen,
                loader,
                dataset,
                step,
                alpha,
                opt_critic,
                opt_gen,
                tensorboard_step,
                writer,
                scaler_gen,
                scaler_critic,
            )

            if SAVE_MODEL:
                save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
                save_checkpoint(critic, opt_critic, filename=CHECKPOINT_CRITIC)

        step += 1  # progress to the next img size

In [None]:
if __name__ == "__main__":
    main()



Current image size: 128
Epoch [1/10]


 62%|███████▍    | 2331/3750 [47:46<28:54,  1.22s/it, gp=0.0392, loss_critic=-7]