In [1]:
import cv2
import os
import random
import numpy as np

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image

from math import log2
from tqdm import tqdm

from scipy.stats import truncnorm
from IPython.display import clear_output

In [2]:
def plot_to_tensorboard(
    writer, loss_critic, loss_gen, real, fake, tensorboard_step
):
    """Log both losses and a small grid of real/fake images.

    Args:
        writer: Object exposing ``add_scalar`` and ``add_image``
            (the notebook passes a ``SummaryWriter``).
        loss_critic: Scalar critic loss to log under "Loss Critic".
        loss_gen: Scalar generator loss to log under "Loss Generator".
        real: Batch of real images; only the first 8 are plotted.
        fake: Batch of generated images; only the first 8 are plotted.
        tensorboard_step: Global step attached to every logged item.
    """
    writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)
    # The generator loss used to be accepted and then silently dropped.
    writer.add_scalar("Loss Generator", loss_gen, global_step=tensorboard_step)

    with torch.no_grad():
        # take out (up to) 8 examples to plot
        img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
        writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
        writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)


def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    """Compute the gradient penalty on random real/fake interpolations.

    Each sample is mixed as ``real * beta + fake * (1 - beta)`` with one
    uniform ``beta`` per sample; the penalty is the batch mean of
    ``(||d critic / d interpolated||_2 - 1) ** 2``.

    Args:
        critic: Callable invoked as ``critic(images, alpha, train_step)``.
        real: Real image batch; must be 4-D ``(N, C, H, W)``.
        fake: Generated batch of the same shape; it is detached here, so no
            gradient reaches the generator through the penalty.
        alpha: Fade-in factor forwarded to ``critic``.
        train_step: Progressive step forwarded to ``critic``.
        device: Device on which the mixing coefficients are created.

    Returns:
        A 0-dim tensor holding the penalty (graph retained so it can be
        back-propagated as part of the critic loss).
    """
    # Unpacking keeps the original requirement that `real` is exactly 4-D.
    batch_size, _, _, _ = real.shape
    # Sample directly on `device` and let broadcasting stretch beta over
    # C, H, W instead of materialising a full-size tensor on the host and
    # copying it across.
    beta = torch.rand((batch_size, 1, 1, 1), device=device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images, alpha, train_step)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    # Returned directly: the old local shadowed the function's own name.
    return torch.mean((gradient_norm - 1) ** 2)


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Serialize the model and optimizer state dicts to ``filename``.

    The saved object is a dict with the keys ``"state_dict"`` (model) and
    ``"optimizer"`` (optimizer).
    """
    torch.save(
        dict(
            state_dict=model.state_dict(),
            optimizer=optimizer.state_dict(),
        ),
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state, then override the learning rate.

    Args:
        checkpoint_file: Path or file-like object readable by ``torch.load``;
            must contain the keys ``"state_dict"`` and ``"optimizer"``.
        model: Module whose weights are overwritten in place.
        optimizer: Optimizer whose state is overwritten in place.
        lr: Learning rate written into every param group after loading.
    """
    print("=> Loading checkpoint")
    # Map tensors onto the device the model already lives on rather than a
    # hard-coded "cuda", which made loading fail on CPU-only machines.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        map_location = first_param.device
    else:
        map_location = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(checkpoint_file, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just have learning rate of old checkpoint
    # and it will lead to many hours of debugging \:
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

In [3]:
# Channel multipliers applied to `in_channels`, one entry per resolution stage:
# the first four stages keep the full width, then the width halves each stage.
factors = [1, 1, 1, 1] + [1 / 2 ** halvings for halvings in range(1, 6)]

In [4]:
class WSConv2d(nn.Module):
    """Conv2d whose input is multiplied by a fixed scale at call time.

    The weights are drawn from a standard normal and the input is multiplied
    by ``sqrt(gain / (in_channels * kernel_size ** 2))`` on every forward
    pass. The bias is held on this module (not on the inner conv) and added
    after the convolution, so it is not affected by the scale.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        gain=2,
    ):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        fan_in = in_channels * kernel_size ** 2
        self.scale = (gain / fan_in) ** .5

        # Move the bias parameter from the inner conv onto this module.
        self.bias = self.conv.bias
        self.conv.bias = None

        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Return ``conv(x * scale)`` plus the per-channel bias."""
        bias_map = self.bias.view(1, self.bias.shape[0], 1, 1)
        scaled_input = x * self.scale
        return self.conv(scaled_input) + bias_map
    
class PixelNorm(nn.Module):
    """Divide each pixel's feature vector by its root-mean-square over channels."""

    def __init__(self):
        super().__init__()
        # Added under the square root to avoid dividing by zero.
        self.eps = 1e-8

    def forward(self, x):
        """Normalize ``x`` along dim 1 (channels); the shape is unchanged."""
        mean_square = (x ** 2).mean(dim=1, keepdim=True)
        return x / torch.sqrt(mean_square + self.eps)
    
class ConvBlock(nn.Module):
    """Two 3x3 ``WSConv2d`` layers, each followed by LeakyReLU(0.2).

    When ``use_pixelnorm`` is true, ``PixelNorm`` is applied after each
    activation; otherwise the activations pass through unchanged.
    """

    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super().__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()
        self.use_pixelnorm = use_pixelnorm

    def forward(self, x):
        """Run conv -> LeakyReLU -> (optional PixelNorm) twice."""
        for conv in (self.conv1, self.conv2):
            x = self.leaky(conv(x))
            if self.use_pixelnorm:
                x = self.pn(x)
        return x

In [5]:
class Generator(nn.Module):
    """Progressively grown generator.

    ``first`` maps a ``(N, z_dim, 1, 1)`` latent to a 4x4 feature map. Each
    entry of ``prog_blocks`` follows a 2x nearest-neighbour upsample, and
    ``rgb_layers[k]`` converts the features available after ``k`` blocks to
    ``img_channels`` channels.

    Args:
        z_dim: Number of latent channels.
        in_channels: Feature width at 4x4; later widths are
            ``int(in_channels * factors[i])``.
        img_channels: Number of output image channels.
    """

    def __init__(self, z_dim, in_channels, img_channels=3):
        super().__init__()
        self.first = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )
        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1,
            stride=1, padding=0
        )

        self.prog_blocks = nn.ModuleList([])
        # `initial_rgb` is intentionally registered both as an attribute and
        # as rgb_layers[0]; existing checkpoints contain both key prefixes.
        self.rgb_layers = nn.ModuleList([self.initial_rgb])

        for i in range(len(factors) - 1):
            # The channel widths were previously computed twice in a row.
            conv_in_channels = int(in_channels * factors[i])
            conv_out_channels = int(in_channels * factors[i + 1])
            self.prog_blocks.append(ConvBlock(conv_in_channels, conv_out_channels))
            self.rgb_layers.append(
                WSConv2d(conv_out_channels, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        """Blend the new block's RGB with the upsampled previous RGB, then tanh."""
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, x, alpha, steps):
        """Generate images of side ``4 * 2 ** steps``.

        Args:
            x: Latent tensor fed to ``first``.
            alpha: Blend weight of the newest block (see ``fade_in``).
            steps: Number of progressive blocks to apply, in
                ``0..len(prog_blocks)``.

        Raises:
            IndexError: If ``steps`` is outside the supported range.
        """
        if not 0 <= steps <= len(self.prog_blocks):
            # A negative value previously surfaced as an UnboundLocalError
            # and a too-large one only after running every block.
            raise IndexError(
                f"steps must be in [0, {len(self.prog_blocks)}], got {steps}"
            )

        out = self.first(x)  # 4x4
        if steps == 0:
            # NOTE: this path returns the RGB projection without tanh, unlike
            # the faded path below; left as-is to keep existing behaviour.
            return self.initial_rgb(out)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
            out = self.prog_blocks[step](upscaled)

        # `upscaled` is the input of the last block, `out` is its output.
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)

In [6]:
class Critic(nn.Module):
    """Progressively grown critic.

    ``prog_blocks`` and ``rgb_layers`` are ordered from the highest
    resolution to the lowest; ``rgb_layers`` has one extra final entry
    (``end_rgb``) for 4x4 inputs. ``final_block`` reduces a 4x4 feature map
    (plus one minibatch-statistics channel) to a single score per sample.

    Args:
        z_dim: Not used by this class.
        in_channels: Feature width at 4x4; other widths are
            ``int(in_channels * factors[i])``.
        img_channels: Number of input image channels.
    """

    def __init__(self, z_dim, in_channels, img_channels=3):
        super().__init__()
        self.prog_blocks = nn.ModuleList([])
        self.rgb_layers = nn.ModuleList([])
        self.leaky = nn.LeakyReLU(.2)

        # Walk `factors` backwards so index 0 handles the largest resolution.
        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out, use_pixelnorm=False))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )

        # only for 4x4 resolution
        # Registered both as `end_rgb` and as the last rgb_layers entry;
        # existing checkpoints contain both key prefixes.
        self.end_rgb = WSConv2d(img_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.rgb_layers.append(self.end_rgb)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        # block for 4x4 resolution (+1 input channel for the minibatch statistic)
        self.final_block = nn.Sequential(
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, 1, kernel_size=1, padding=0, stride=1)
        )

    def fade_in(self, alpha, downscaled, out):
        """Blend the newest block's output with the downscaled shortcut."""
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        """Append one channel holding the mean per-feature std over the batch.

        For a batch of one sample the unbiased standard deviation is
        undefined (NaN), which would poison the score and any loss built on
        it; the statistic is defined as 0 in that case. Batches of two or
        more samples are computed exactly as before.
        """
        if x.shape[0] > 1:
            statistic = torch.std(x, dim=0).mean()
        else:
            statistic = x.new_zeros(())
        batch_statistics = statistic.repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        """Score images of side ``4 * 2 ** steps``.

        Args:
            x: Image batch.
            alpha: Blend weight of the newest block (see ``fade_in``).
            steps: Progressive step matching the input resolution.

        Returns:
            Tensor of shape ``(N, 1)`` with one score per sample.
        """
        # Blocks are stored high-res first, so a larger `steps` starts earlier.
        current_step = len(self.prog_blocks) - steps
        out = self.leaky(self.rgb_layers[current_step](x))

        if steps > 0:
            # Shortcut: pool the image and use the next-lower resolution's RGB layer.
            downscaled = self.leaky(self.rgb_layers[current_step + 1](self.avg_pool(x)))
            out = self.avg_pool(self.prog_blocks[current_step](out))
            out = self.fade_in(alpha, downscaled, out)

            for step in range(current_step + 1, len(self.prog_blocks)):
                out = self.prog_blocks[step](out)
                out = self.avg_pool(out)

        # Shared 4x4 tail (for steps == 0 this is reached directly).
        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)

In [7]:
# Shape smoke test: every supported resolution must round-trip through the
# generator and the critic.
Z_DIM = 100
IN_CHANNELS = 256
gen = Generator(Z_DIM, IN_CHANNELS, img_channels=3)
critic = Critic(Z_DIM, IN_CHANNELS, img_channels=3)

# No gradients are needed for a shape check; skipping the autograd graph keeps
# the 512x512 and 1024x1024 passes cheap.
with torch.no_grad():
    for img_size in [4, 8, 16, 32, 64, 128, 256, 512, 1024]:
        num_steps = int(log2(img_size / 4))
        # Two samples, because the critic's batch statistic needs more than one.
        x = torch.randn((2, Z_DIM, 1, 1))
        z = gen(x, .5, steps=num_steps)
        assert z.shape == (2, 3, img_size, img_size)
        out = critic(z, alpha=.5, steps=num_steps)
        assert out.shape == (2, 1)
        assert torch.isfinite(out).all()
        print(f"OK: {img_size}x{img_size}", end=" ")

OK: 4x4 OK: 8x8 OK: 16x16 OK: 32x32 OK: 64x64 OK: 128x128 OK: 256x256 OK: 512x512 OK: 1024x1024 

In [8]:
# Training configuration. Per-resolution lists are indexed by
# step = log2(image_size / 4).
START_TRAIN_AT_IMG_SIZE = 256
DATASET = '../content/drive/MyDrive/'
CHECKPOINT_GEN = "generator.pth"
CHECKPOINT_CRITIC = "critic.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAVE_MODEL = True
LOAD_MODEL = True
LR = 1e-3
BATCH_SIZES = [2048, 1024, 512, 128, 32, 16, 10, 4, 2]
CHANNELS_IMG = 3
Z_DIM = 512  # 256/512
IN_CHANNELS = 512  # 256/512
CRITIC_ITERATIONS = 1
LAMBDA_GP = 10
PROGRESSIVE_EPOCHS = [64] * len(BATCH_SIZES) # more (?)
FIXED_NOISE = torch.randn(8, Z_DIM, 1, 1).to(DEVICE)
NUM_WORKERS = 2

# The flag is spelled `benchmark`; the previous `benchmarks` assignment only
# created an unused attribute and left the cuDNN autotuner off.
torch.backends.cudnn.benchmark = True

In [9]:
def generate_examples(gen, steps, truncation=0.7, n=100):
    """Sample ``n`` images from ``gen`` and write them to ``saved/img_{i}.png``.

    Latents are drawn from a standard normal truncated to
    ``[-truncation, truncation]``. The generator is run in eval mode with
    ``alpha = 1.0`` and is switched to train mode afterwards.

    Args:
        gen: Generator called as ``gen(noise, alpha, steps)``.
        steps: Progressive step passed to the generator.
        truncation: Half-width of the truncated normal.
        n: Number of images to write.
    """
    # save_image does not create missing directories, so a fresh checkout
    # would otherwise fail on the first image.
    os.makedirs("saved", exist_ok=True)

    alpha = 1.0
    gen.eval()
    try:
        with torch.no_grad():
            for i in range(n):
                noise = torch.tensor(
                    truncnorm.rvs(-truncation, truncation, size=(1, Z_DIM, 1, 1)),
                    device=DEVICE,
                    dtype=torch.float32,
                )
                img = gen(noise, alpha, steps)
                # The saved image is shifted by img * 0.5 + 0.5 before writing.
                save_image(img*0.5+0.5, f"saved/img_{i}.png")
    finally:
        # Put the generator back in train mode even if sampling or saving fails.
        gen.train()

In [10]:
def get_loader(image_size):
    """Build the ``ImageFolder`` dataset and its ``DataLoader`` for one resolution.

    Images are resized to ``image_size`` x ``image_size``, converted to
    tensors, randomly flipped horizontally and normalized with mean 0.5 and
    std 0.5 per channel. The batch size is looked up in ``BATCH_SIZES`` by
    ``int(log2(image_size / 4))``.

    Returns:
        ``(loader, dataset)``.
    """
    resolution_index = int(log2(image_size / 4))
    batch_size = BATCH_SIZES[resolution_index]

    channel_means = [0.5] * CHANNELS_IMG
    channel_stds = [0.5] * CHANNELS_IMG
    pipeline = [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize(channel_means, channel_stds),
    ]

    dataset = datasets.ImageFolder(
        root=DATASET,
        transform=transforms.Compose(pipeline),
    )
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    return loader, dataset

In [11]:
# import matplotlib.pyplot as plt
# l,d = get_loader(4 * 2 ** step)
# a = iter(d)
# next(a)
# next(a)
# plt.imshow(next(a)[0].permute(1,2,0)*0.5+0.5)
# plt.show()

In [12]:
def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
    tensorboard_step,
    writer,
    scaler_gen,
    scaler_critic,
):
    """Run one pass over ``loader``, updating the critic and then the generator per batch.

    Args:
        critic: Critic network, called as ``critic(images, alpha, step)``.
        gen: Generator network, called as ``gen(noise, alpha, step)``.
        loader: Iterable yielding ``(real, _)`` batches; the second item is ignored.
        dataset: Only ``len(dataset)`` is used, to pace the ``alpha`` ramp.
        step: Progressive step; also indexes ``PROGRESSIVE_EPOCHS``.
        alpha: Current fade-in value; increased after every batch.
        opt_critic: Optimizer stepped for the critic loss.
        opt_gen: Optimizer stepped for the generator loss.
        tensorboard_step: Counter passed to ``plot_to_tensorboard``.
        writer: Passed through to ``plot_to_tensorboard``.
        scaler_gen: Gradient scaler used for the generator update.
        scaler_critic: Gradient scaler used for the critic update.

    Returns:
        ``(tensorboard_step, alpha)`` after the pass, so the caller can carry
        both into the next epoch.
    """
    loop = tqdm(loader, leave=True)
    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]

        # One latent per real sample in this batch.
        noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(DEVICE)

        # --- Critic update ---
        with torch.cuda.amp.autocast():
            fake = gen(noise, alpha, step)
            critic_real = critic(real, alpha, step)
            # Detached so the critic loss does not back-propagate into the generator.
            critic_fake = critic(fake.detach(), alpha, step)
            gp = gradient_penalty(critic, real, fake, alpha, step, device=DEVICE)
            # Negated score gap, plus the weighted gradient penalty, plus a
            # small (0.001) penalty on the squared scores of real images.
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + LAMBDA_GP * gp
                + (0.001 * torch.mean(critic_real ** 2))
            )

        opt_critic.zero_grad()
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()

        # --- Generator update ---
        # `fake` (not detached) is scored again, after the critic step above.
        with torch.cuda.amp.autocast():
            gen_fake = critic(fake, alpha, step)
            loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        # Ramp alpha by the fraction of samples seen relative to half of this
        # step's total (PROGRESSIVE_EPOCHS[step] * len(dataset)); capped at 1.
        alpha += cur_batch_size / (
            (PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)

        # Log every 256th batch, skipping batch 0.
        if batch_idx % 256 == 0 and batch_idx:
            with torch.no_grad():
                # Same FIXED_NOISE every time; output shifted by * 0.5 + 0.5.
                fixed_fakes = gen(FIXED_NOISE, alpha, step) * 0.5 + 0.5
            plot_to_tensorboard(
                writer,
                loss_critic.item(),
                loss_gen.item(),
                real.detach(),
                fixed_fakes.detach(),
                tensorboard_step,
            )
            tensorboard_step += 1

        # Show the latest penalty and critic loss on the progress bar.
        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )

    return tensorboard_step, alpha

In [13]:
# Build both networks on the training device.
gen = Generator(Z_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG).to(DEVICE)
critic = Critic(Z_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG).to(DEVICE)

# One Adam optimizer per network, then one gradient scaler per network
# (used for the mixed-precision updates in train_fn).
opt_gen = optim.Adam(params=gen.parameters(), lr=LR, betas=(0.0, 0.99))
opt_critic = optim.Adam(params=critic.parameters(), lr=LR, betas=(0.0, 0.99))
scaler_critic = torch.cuda.amp.GradScaler()
scaler_gen = torch.cuda.amp.GradScaler()

# TensorBoard event files go under ./logs
writer = SummaryWriter("logs")

# Optionally resume: generator first, then critic.
if LOAD_MODEL:
    for _ckpt_path, _net, _opt in (
        (CHECKPOINT_GEN, gen, opt_gen),
        (CHECKPOINT_CRITIC, critic, opt_critic),
    ):
        load_checkpoint(_ckpt_path, _net, _opt, LR)

# Both networks train from here on.
for _net in (gen, critic):
    _net.train()

tensorboard_step = 0
# Index of the resolution to start at: 4 * 2 ** step == START_TRAIN_AT_IMG_SIZE.
step = int(log2(START_TRAIN_AT_IMG_SIZE / 4))

=> Loading checkpoint
=> Loading checkpoint


In [None]:
# Train every remaining resolution, starting from the current `step`.
remaining_epochs = PROGRESSIVE_EPOCHS[step:]
for num_epochs in remaining_epochs:
    image_size = 4 * 2 ** step
    alpha = 1e-5  # start with very low alpha
    loader, dataset = get_loader(image_size)
    print(f"Image size: {image_size}, Step: {step}")

    for epoch in range(num_epochs):
        tensorboard_step, alpha = train_fn(
            critic=critic,
            gen=gen,
            loader=loader,
            dataset=dataset,
            step=step,
            alpha=alpha,
            opt_critic=opt_critic,
            opt_gen=opt_gen,
            tensorboard_step=tensorboard_step,
            writer=writer,
            scaler_gen=scaler_gen,
            scaler_critic=scaler_critic,
        )
        # Dump a few samples after every epoch.
        generate_examples(gen, step, n=4)

        # Overwrite both checkpoints after every epoch.
        if SAVE_MODEL:
            for _net, _opt, _path in (
                (gen, opt_gen, CHECKPOINT_GEN),
                (critic, opt_critic, CHECKPOINT_CRITIC),
            ):
                save_checkpoint(_net, _opt, filename=_path)

    clear_output()
    step += 1 # next img size

  0%|          | 0/500 [00:00<?, ?it/s]

Image size: 256, Step: 6


100%|██████████| 500/500 [24:03<00:00,  2.89s/it, gp=0.0673, loss_critic=1.42]  
100%|██████████| 500/500 [24:05<00:00,  2.89s/it, gp=0.0835, loss_critic=-2.6]  
100%|██████████| 500/500 [24:07<00:00,  2.89s/it, gp=0.111, loss_critic=-.15]   
100%|██████████| 500/500 [24:08<00:00,  2.90s/it, gp=0.0281, loss_critic=-6.37]  
100%|██████████| 500/500 [24:08<00:00,  2.90s/it, gp=0.155, loss_critic=-10.5]   
100%|██████████| 500/500 [24:09<00:00,  2.90s/it, gp=0.0106, loss_critic=-3.35] 
100%|██████████| 500/500 [24:10<00:00,  2.90s/it, gp=0.0122, loss_critic=-1.19] 
100%|██████████| 500/500 [24:09<00:00,  2.90s/it, gp=0.134, loss_critic=5.54]    
100%|██████████| 500/500 [24:11<00:00,  2.90s/it, gp=0.0196, loss_critic=1.69]  
100%|██████████| 500/500 [24:11<00:00,  2.90s/it, gp=0.0444, loss_critic=-1.15]  
100%|██████████| 500/500 [24:11<00:00,  2.90s/it, gp=0.00818, loss_critic=-9.45] 
100%|██████████| 500/500 [24:11<00:00,  2.90s/it, gp=0.0534, loss_critic=-7.1]  
100%|██████████| 500/50

## Links:
* https://youtu.be/nkQHASviYac
* https://research.nvidia.com/sites/default/files/pubs/2017-10_Progressive-Growing-of/karras2018iclr-paper.pdf