In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard
import albumentations as A
import time

# Generator

In [None]:
# Channel-width multipliers, one per progressive stage. Entry 0 is the 4x4 base
# stage; the Generator/Discriminator iterate over factors[1:], multiplying the
# running channel count by each entry to get the next stage's width.
factors = [1] * 4 + [0.5] * 5

In [None]:
class WSConv2d(nn.Module):
    """Conv2d with equalized learning rate (weight scaling).

    The conv weight is initialised from N(0, 1), and at run time the *input*
    is multiplied by sqrt(gain / fan_in), with fan_in = in_channels * kernel_size**2.
    The bias is removed from the wrapped conv and added after the convolution,
    so it is not affected by the scale.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        fan_in = in_channels * kernel_size ** 2
        self.scale = (gain / fan_in) ** 0.5
        # Take ownership of the bias so the conv itself is bias-free.
        self.bias = self.conv.bias
        self.conv.bias = None

        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        convolved = self.conv(x * self.scale)
        # Broadcast the per-output-channel bias over (N, C, H, W).
        return convolved + self.bias[None, :, None, None]

class WSLinear(nn.Module):
    """Linear layer with equalized learning rate (weight scaling).

    The weight is initialised from N(0, 1); at run time the input is multiplied
    by sqrt(gain / in_channels). The bias is detached from the wrapped Linear
    and added unscaled after the matmul.
    """

    def __init__(self, in_channels, out_channels,gain=2):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)
        self.scale = (gain / in_channels) ** 0.5
        # Move the bias out of nn.Linear so the scale is not applied to it.
        self.bias = self.linear.bias
        self.linear.bias = None

        # Initialise the linear weight (standard normal) and the bias (zeros).
        nn.init.normal_(self.linear.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        projected = self.linear(x * self.scale)
        return projected + self.bias

class PixelNorm(nn.Module):
  """Rescale each sample's vector along dim=1 to unit root-mean-square.

  epsilon (1e-8) is added inside the square root to avoid division by zero.
  """
  def __init__(self):
    super().__init__()
    self.epsilon = 1e-8

  def forward(self,x):
    mean_sq = x.pow(2).mean(dim=1, keepdim=True)
    return x / (mean_sq + self.epsilon).sqrt()

class Mapping(nn.Module):
  """StyleGAN mapping network z -> w.

  PixelNorm on the latent, then WSLinear(z_dim, w_dim) followed by seven
  WSLinear(w_dim, w_dim) layers, each followed by LeakyReLU(0.2).

  Args:
    z_dim: size of the input latent z.
    w_dim: size of the output style vector w.
  """
  def __init__(self, z_dim,w_dim):
    super(Mapping, self).__init__()
    self.w_dim=w_dim
    self.initial = [PixelNorm(),
                    WSLinear(z_dim, w_dim),
                    nn.LeakyReLU(0.2)]

    self.layers = nn.ModuleList(self.initial)
    for _ in range(7):
      self.layers.append(WSLinear(w_dim,w_dim))
      self.layers.append(nn.LeakyReLU(0.2))

  def forward(self,x):
    """Map latents shaped (B, z_dim) or (B, z_dim, 1, 1) to w of shape (B, w_dim)."""
    # WSLinear multiplies along the LAST dimension, so a 4-D latent such as
    # (B, z_dim, 1, 1) used to fail with "mat1 and mat2 shapes cannot be
    # multiplied". Flatten everything after the batch dim first; PixelNorm
    # normalises over dim=1 and gives the same result for either layout.
    if x.dim() > 2:
      x = x.flatten(1)
    for layer in self.layers:
      x = layer(x)
    return x

class InjectNoise(nn.Module):
  """Add Gaussian noise scaled by a learned per-channel weight.

  One noise map of shape (N, 1, H, W) is drawn per call and shared across
  channels; the weight has shape (1, C, 1, 1) and starts at zero, so the
  layer is initially an identity.
  """
  def __init__(self,channels):
    super().__init__()
    self.weights = nn.Parameter(torch.zeros(1, channels, 1, 1), requires_grad=True)

  def forward(self,x):
    noise_shape = (x.shape[0], 1, x.shape[2], x.shape[3])
    noise = torch.randn(noise_shape, device=x.device)
    return x + self.weights * noise

class AdaIn(nn.Module):
  """Adaptive instance normalisation driven by a style vector w.

  Instance-normalises x, then applies a per-channel scale and shift, each
  produced from w by its own WSLinear(w_dim, channels).
  """
  def __init__(self,w_dim,channels):
    super().__init__()
    self.style_scale = WSLinear(w_dim,channels)
    self.style_bias = WSLinear(w_dim,channels)
    self.norm = nn.InstanceNorm2d(channels)

  def forward(self,x,w):
    normed = self.norm(x)
    # (B, C) -> (B, C, 1, 1) so the styles broadcast over the spatial dims.
    gain = self.style_scale(w)[:, :, None, None]
    shift = self.style_bias(w)[:, :, None, None]
    return gain * normed + shift

class GenBlock(nn.Module):
  """Two (WSConv2d -> InjectNoise -> LeakyReLU -> AdaIn) stages.

  Args:
    in_channels: channels of the incoming feature map (conv1 maps in -> out).
    out_channels: channels produced by both convolutions (conv2 maps out -> out).
    w_dim: size of the style vector consumed by both AdaIn layers.
    kernel_size, stride, padding, gain: forwarded to both WSConv2d layers.
      These were previously accepted but silently ignored in favour of
      hard-coded 3/1/1/2; the defaults keep existing behaviour.
  """
  def __init__(self, in_channels, out_channels, w_dim, kernel_size=3, stride=1, padding=1, gain=2):
    super(GenBlock, self).__init__()
    conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, gain=gain)
    self.conv1=WSConv2d(in_channels, out_channels, **conv_kwargs)
    self.conv2=WSConv2d(out_channels, out_channels, **conv_kwargs)
    self.noise1=InjectNoise(out_channels)
    self.noise2=InjectNoise(out_channels)
    self.adain1=AdaIn(w_dim,out_channels)
    self.adain2=AdaIn(w_dim,out_channels)
    self.leaky1=nn.LeakyReLU(0.2,inplace=True)
    self.leaky2=nn.LeakyReLU(0.2,inplace=True)

  def forward(self, x, w):
    x = self.conv1(x)
    x = self.noise1(x)
    x = self.leaky1(x)
    x = self.adain1(x,w)

    x = self.conv2(x)
    x = self.noise2(x)
    x = self.leaky2(x)
    x = self.adain2(x,w)

    return x

class Generator(nn.Module):
  """Style-based generator grown progressively (4x4 -> 4 * 2**steps).

  A learned 4x4 constant is modulated by w = mapping(z) through noise
  injection and AdaIn, then each entry of `prog_layers` upsamples 2x and
  applies a GenBlock. For steps > 0 the output fades between the newest
  to-RGB layer and the previous stage's to-RGB of the upsampled features.
  Every stage's output is passed through tanh.

  Args:
    z_dim: latent size consumed by the mapping network.
    w_dim: style vector size.
    in_channels: channel width at 4x4; later widths follow the global `factors`.
    img_channels: number of output image channels.
  """
  def __init__(self,z_dim,w_dim,in_channels,img_channels):
    super(Generator, self).__init__()
    self.const = nn.Parameter(torch.ones((1, in_channels, 4, 4)))
    self.mapping = Mapping(z_dim,w_dim)

    self.initial_noise1 = InjectNoise(in_channels)
    self.initial_adain1 = AdaIn(w_dim,in_channels)
    self.initial_conv = WSConv2d(in_channels,in_channels,1,1,0)
    self.initial_noise2 = InjectNoise(in_channels)
    self.initial_act = nn.LeakyReLU(0.2)
    self.initial_adain2 = AdaIn(w_dim,in_channels)
    self.initial_rgb = WSConv2d(in_channels,img_channels,1,1,0)

    self.prog_layers=nn.ModuleList()
    self.rgb_layers=nn.ModuleList([self.initial_rgb])

    for f in factors[1:]:
      out_channels = int(in_channels*f)
      self.prog_layers.append(
          GenBlock(in_channels,out_channels,w_dim,3,1,1)
      )
      in_channels = out_channels
      self.rgb_layers.append(
          WSConv2d(in_channels,img_channels,1,1,0)
      )

  def fade_in(self,alpha,rgb_x,rgb_up_sampled):
      """Blend new-stage RGB (weight alpha) with upsampled old-stage RGB, then tanh."""
      s = alpha*rgb_x + (1-alpha)*rgb_up_sampled
      return torch.tanh(s)  # F.tanh is deprecated

  def forward(self,z,alpha,steps):
    """Generate images at resolution 4 * 2**steps.

    Args:
      z: latent batch; assumed (B, z_dim) -- Mapping decides what it accepts.
      alpha: fade-in weight in [0, 1] for the newest stage.
      steps: progressive stage index (0 = 4x4).
    """
    w = self.mapping(z)
    # Expand the (1, C, 4, 4) constant to the batch so initial_noise1 draws an
    # independent noise map per sample instead of one map shared by the batch.
    x = self.const.expand(w.shape[0], -1, -1, -1)
    x = self.initial_adain1(self.initial_noise1(x),w)
    x = self.initial_adain2(self.initial_act(self.initial_noise2(self.initial_conv(x))),w)
    if steps==0:
      # Apply the same tanh as fade_in so every stage emits values in [-1, 1].
      return torch.tanh(self.initial_rgb(x))

    for i in range(steps):
      up_sampled = F.interpolate(x,scale_factor=2,mode='bilinear',align_corners=False)
      x = self.prog_layers[i](up_sampled,w)

    # up_sampled still holds the previous stage's features at the new resolution.
    rgb_up_sampled = self.rgb_layers[steps-1](up_sampled)
    rgb_x = self.rgb_layers[steps](x)
    return self.fade_in(alpha,rgb_x,rgb_up_sampled)

In [None]:
# Smoke test: steps=7 should yield 4 * 2**7 = 512-pixel, 3-channel images.
gen = Generator(z_dim=128,w_dim=512,in_channels=512,img_channels=3)
z = torch.randn((1,128))  # latents ~ N(0, I), matching how train_fn samples noise
with torch.no_grad():  # inference only; avoids keeping a large autograd graph
  img = gen(z,alpha=0.3,steps=7)
assert img.shape == (1, 3, 512, 512)
print(img.shape)

torch.Size([1, 3, 512, 512])


In [None]:
class CNN_Block(nn.Module):
  """Critic block: WSConv2d(in -> in) + LeakyReLU, then WSConv2d(in -> out) + LeakyReLU.

  Args:
    in_channels: channels of the incoming feature map.
    out_channels: channels produced by the second convolution.
    kernel_size, stride, padding, gain: forwarded to both WSConv2d layers.
      These were previously accepted but ignored (hard-coded 3/1/1 and the
      WSConv2d default gain); the defaults keep existing behaviour.
  """
  def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
    super(CNN_Block, self).__init__()
    conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, gain=gain)
    self.layers = nn.Sequential(
        WSConv2d(in_channels, in_channels, **conv_kwargs),
        nn.LeakyReLU(0.2),
        WSConv2d(in_channels, out_channels, **conv_kwargs),
        nn.LeakyReLU(0.2),
    )

  def forward(self, x):
    return self.layers(x)

class MiniBatchStd(nn.Module):
  """Append one feature map holding the batch-wide mean standard deviation.

  For x of shape (N, C, H, W): take the std over the batch dim, average it
  to a scalar, and concatenate it as a constant channel -> (N, C+1, H, W).
  """
  def __init__(self):
    super().__init__()

  def forward(self,x):
    batch, _, height, width = x.shape
    stat = x.std(dim=0).mean()
    stat_map = stat.expand(batch, 1, height, width)
    return torch.cat((x, stat_map), dim=1)

class Discriminator(nn.Module):
  """Progressive critic that mirrors the Generator.

  An image at resolution 4 * 2**steps is projected to features by
  rgb_layers[steps], reduced by prog_layers[steps-1], ..., prog_layers[0]
  (each followed by a 2x average-pool), and scored by `last`. The first
  reduction is faded (weight alpha) against rgb_layers[steps-1] applied to
  the 2x-downsampled raw image.

  Args:
    in_channels: channel width at the 4x4 stage.
    img_channels: number of image channels.
  """
  def __init__(self,in_channels,img_channels):
    super(Discriminator, self).__init__()

    # from-RGB stem used directly when steps == 0
    self.initial_rgb = nn.Sequential(
        WSConv2d(img_channels, in_channels, 1, 1, 0),
        nn.LeakyReLU(0.2),
        WSConv2d(in_channels, in_channels, 1, 1, 0),
        nn.LeakyReLU(0.2),
    )

    self.prog_layers = nn.ModuleList()
    self.rgb_layers = nn.ModuleList([self.initial_rgb])

    # Scoring head: minibatch-std channel, two strided convs (a 4x4 input
    # becomes 2x2 then 1x1), flatten, linear score.
    self.last = nn.Sequential(
        MiniBatchStd(),
        WSConv2d(in_channels + 1, in_channels, 3, 2, 1),
        nn.LeakyReLU(0.2),
        WSConv2d(in_channels, in_channels, 4, 2, 1),
        nn.LeakyReLU(0.2),
        nn.Flatten(),
        nn.Linear(in_channels, 1),
    )

    width = in_channels
    for factor in factors[1:]:
      narrower = int(width * factor)
      # from-RGB for the next-larger resolution, and the block that maps its
      # (narrower) features back to the current stage's width.
      self.rgb_layers.append(WSConv2d(img_channels, narrower, 1, 1, 0))
      self.prog_layers.append(CNN_Block(narrower, width, 3, 1, 1))
      width = narrower


  def fade_in(self,alpha,generated,rgb_x):
      """Blend the newest path (weight alpha) with the downsampled skip path."""
      return generated * alpha + rgb_x * (1 - alpha)

  def forward(self,x,alpha,steps):
    if steps == 0:
      return self.last(self.rgb_layers[0](x))

    # Skip path: halve the raw image and use the previous stage's from-RGB.
    skip = self.rgb_layers[steps - 1](F.avg_pool2d(x, 2))
    # Main path: newest from-RGB + newest block, then halve the resolution.
    out = F.avg_pool2d(self.prog_layers[steps - 1](self.rgb_layers[steps](x)), 2)
    out = self.fade_in(alpha, out, skip)
    for idx in reversed(range(steps - 1)):
      out = F.avg_pool2d(self.prog_layers[idx](out), 2)
    return self.last(out)

# Config and Utils

In [None]:
import torch
import random
import numpy as np
import os
import torchvision
import torch.nn as nn
from torchvision.utils import save_image
from scipy.stats import truncnorm
import cv2
import torch
from math import log2

class Config:
  """Training hyper-parameters and (Colab /content) paths, read via `config`."""
  START_TRAIN_AT_IMG_SIZE = 128
  DATASET = '/content/data_faces/'
  CHECKPOINT_GEN = "/content/generator.pth"
  CHECKPOINT_CRITIC = "/content/critic.pth"
  SAVED_PATH = "/content/saved_examples"  # output dir used by generate_examples
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
  SAVE_MODEL = True
  LOAD_MODEL = False
  LEARNING_RATE = 1e-3
  # indexed by stage = log2(img_size / 4): 4x4 -> 32, ..., 1024x1024 -> 4
  BATCH_SIZES = [32, 32, 32, 16, 16, 16, 16, 8, 4]
  CHANNELS_IMG = 3
  Z_DIM = 512  # should be 512 in original paper
  W_DIM = 512
  IN_CHANNELS = 512  # should be 512 in original paper
  CRITIC_ITERATIONS = 1
  LAMBDA_GP = 10
  PROGRESSIVE_EPOCHS = [30] * len(BATCH_SIZES)
  # Fixed latents for TensorBoard previews. Must be (N, Z_DIM): the mapping
  # network's WSLinear layers act on the last dim, so the former
  # (8, Z_DIM, 1, 1) shape raised "mat1 and mat2 shapes cannot be multiplied".
  FIXED_NOISE = torch.randn(8, Z_DIM).to(DEVICE)
  NUM_WORKERS = 4

config = Config()

def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    """WGAN-GP penalty: mean over the batch of (||d critic / d x_hat||_2 - 1)^2.

    x_hat is a per-sample random interpolation between `real` and `fake`
    (fake is detached).

    Args:
        critic: callable invoked as critic(images, alpha, train_step).
        real, fake: image batches of shape (N, C, H, W).
        alpha, train_step: forwarded unchanged to the critic.
        device: device the interpolation weights are moved to.

    Returns:
        Scalar tensor holding the penalty; the graph is kept (create_graph=True)
        so it can be back-propagated.
    """
    batch_size, _, _, _ = real.shape
    # One mixing weight per sample, broadcast over (C, H, W).
    beta = torch.rand((batch_size, 1, 1, 1)).to(device)
    mixed = real * beta + fake.detach() * (1 - beta)
    mixed.requires_grad_(True)

    scores = critic(mixed, alpha, train_step)
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=mixed,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )
    per_sample_norm = grads.view(grads.shape[0], -1).norm(2, dim=1)
    return ((per_sample_norm - 1) ** 2).mean()

def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write model and optimizer state dicts to `filename` with torch.save.

    The file holds a dict with keys "state_dict" and "optimizer", which is the
    layout load_checkpoint reads.
    """
    print("=> Saving checkpoint")
    payload = dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict())
    torch.save(payload, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr, device=None):
    """Restore model and optimizer state written by save_checkpoint.

    Args:
        checkpoint_file: path to a file holding {"state_dict", "optimizer"}.
        model: module whose parameters are overwritten.
        optimizer: optimizer whose state is overwritten.
        lr: learning rate forced onto every param group after loading.
        device: map_location for torch.load. Defaults to "cuda" when CUDA is
            available and "cpu" otherwise. The previous hard-coded "cuda"
            raised on CPU-only machines.
    """
    print("=> Loading checkpoint")
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # The restored optimizer state carries the lr it was saved with; override
    # it so the configured learning rate actually takes effect.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    Also sets PYTHONHASHSEED in the environment and switches cuDNN to
    deterministic mode with benchmarking off.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def generate_examples(gen, steps, truncation=0.7, n=100, save_dir=None):
    """Sample `n` images from `gen` at stage `steps` and save them as PNGs.

    Latents come from a normal distribution truncated to
    [-truncation, truncation] (truncation trick). Original author's note: it
    is unclear whether truncation helped; plain torch.randn also works.

    Args:
        gen: generator, called as gen(noise, alpha, steps) with alpha=1.0.
        steps: progressive stage (output resolution 4 * 2**steps).
        truncation: bound for the truncated-normal latents.
        n: number of images to write (img_0.png ... img_{n-1}.png).
        save_dir: output directory. Defaults to config.SAVED_PATH when the
            config defines it, otherwise "saved_examples". Created if missing.
    """
    if save_dir is None:
        save_dir = getattr(config, "SAVED_PATH", "saved_examples")
    os.makedirs(save_dir, exist_ok=True)

    device = next(gen.parameters()).device
    was_training = gen.training
    gen.eval()
    alpha = 1.0
    try:
        with torch.no_grad():
            for i in range(n):
                # The mapping network consumes (batch, z_dim) latents; the old
                # (1, Z_DIM, 1, 1) shape failed inside its first linear layer.
                noise = torch.tensor(
                    truncnorm.rvs(-truncation, truncation, size=(1, config.Z_DIM)),
                    device=device,
                    dtype=torch.float32,
                )
                img = gen(noise, alpha, steps)
                # assumes generator output lies in [-1, 1]; map to [0, 1] for saving
                save_image(img * 0.5 + 0.5, os.path.join(save_dir, f"img_{i}.png"))
    finally:
        # Restore the caller's mode even if sampling/saving raised.
        gen.train(was_training)


# Data Loader

In [None]:
import zipfile
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import os

# Extract CelebA only when the dataset directory is missing or empty, so
# re-running this cell is cheap. Extract straight into config.DATASET so the
# loader reads exactly the directory populated here (the previous relative
# "data_faces/" only matched config.DATASET when the working dir was /content).
if not (os.path.isdir(config.DATASET) and os.listdir(config.DATASET)):
  with zipfile.ZipFile("/content/drive/MyDrive/celeba.zip","r") as zip_ref:
    zip_ref.extractall(config.DATASET)

def get_loader(image_size):
    """Build a shuffled DataLoader over config.DATASET at `image_size` pixels.

    Images are resized to (image_size, image_size), converted to tensors,
    randomly flipped horizontally (p=0.5), and normalised with mean = std = 0.5
    per channel. The batch size is config.BATCH_SIZES[log2(image_size / 4)].

    Returns:
        (loader, dataset)
    """
    n_ch = config.CHANNELS_IMG
    pipeline = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize([0.5] * n_ch, [0.5] * n_ch),
    ])
    stage = int(log2(image_size / 4))
    batch_size = config.BATCH_SIZES[stage]
    dataset = datasets.ImageFolder(root=config.DATASET, transform=pipeline)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
    )
    return loader, dataset

In [None]:
""" Training of ProGAN using WGAN-GP loss"""

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from math import log2
from tqdm import tqdm

# Let cuDNN autotune conv algorithms for fixed input sizes. The attribute is
# `benchmark`; the former `benchmarks` only created an unused attribute.
torch.backends.cudnn.benchmark = True

# Print losses occasionally and print to tensorboard
def plot_to_tensorboard(
    writer, loss_critic, loss_gen, real, fake, tensorboard_step
):
    """Log critic/generator losses and real/fake image grids to TensorBoard.

    Args:
        writer: SummaryWriter to log to.
        loss_critic: critic loss value.
        loss_gen: generator loss value (previously accepted but never logged).
        real, fake: image batches; only the first 8 images of each are plotted.
        tensorboard_step: global step used for every logged item.
    """
    writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)
    writer.add_scalar("Loss Gen", loss_gen, global_step=tensorboard_step)

    with torch.no_grad():
        # take out (up to) 8 examples to plot
        img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
        writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
        writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)


def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
    tensorboard_step,
    writer,
    scaler_gen,
    scaler_critic,
):
    """Run one epoch of WGAN-GP training at progressive stage `step`.

    Per batch: one critic update (Wasserstein loss + LAMBDA_GP * gradient
    penalty + 0.001 * E[critic(real)^2]), then one generator update, both via
    AMP grad scalers. `alpha` grows by batch_size / (0.5 * epochs * len(dataset)),
    i.e. it reaches 1 after half of config.PROGRESSIVE_EPOCHS[step] epochs.

    Returns:
        (tensorboard_step, alpha) to carry into the next epoch.
    """
    # The mapping network needs (N, Z_DIM) latents; flatten any trailing
    # singleton dims so a (N, Z_DIM, 1, 1) FIXED_NOISE no longer crashes the
    # preview at batch 0 ("mat1 and mat2 shapes cannot be multiplied").
    fixed_noise = config.FIXED_NOISE.reshape(config.FIXED_NOISE.shape[0], -1)

    loop = tqdm(loader, leave=True)
    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(config.DEVICE)
        cur_batch_size = real.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)] <-> min -E[critic(real)] + E[critic(fake)]
        # which is equivalent to minimizing the negative of the expression
        noise = torch.randn(cur_batch_size, config.Z_DIM).to(config.DEVICE)

        with torch.cuda.amp.autocast():
            fake = gen(noise, alpha, step)
            critic_real = critic(real, alpha, step)
            critic_fake = critic(fake.detach(), alpha, step)
            gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + config.LAMBDA_GP * gp
                + (0.001 * torch.mean(critic_real ** 2))  # drift term keeps critic outputs near 0
            )

        opt_critic.zero_grad()
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        # `fake` still carries its graph back into gen (the critic step only used fake.detach()).
        with torch.cuda.amp.autocast():
            gen_fake = critic(fake, alpha, step)
            loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        # Update alpha and ensure less than 1
        alpha += cur_batch_size / (
            (config.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)

        if batch_idx % 500 == 0:
            with torch.no_grad():
                fixed_fakes = gen(fixed_noise, alpha, step) * 0.5 + 0.5
            plot_to_tensorboard(
                writer,
                loss_critic.item(),
                loss_gen.item(),
                real.detach(),
                fixed_fakes.detach(),
                tensorboard_step,
            )
            tensorboard_step += 1

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
            loss_gen=loss_gen.item(),
        )

    return tensorboard_step, alpha


def main():
    """Progressively train generator and critic.

    Starts at the stage whose resolution (4 * 2**step) equals
    config.START_TRAIN_AT_IMG_SIZE, runs config.PROGRESSIVE_EPOCHS[step]
    epochs per stage, then doubles the resolution. Checkpoints both models
    after every epoch when config.SAVE_MODEL is set. (The critic is the WGAN
    "discriminator"; its output is an unbounded score, not a probability.)
    """
    # Build gen before critic: construction order fixes which RNG draws each
    # model's weight initialisation consumes.
    gen = Generator(
        config.Z_DIM, config.W_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
    ).to(config.DEVICE)
    critic = Discriminator(config.IN_CHANNELS, img_channels=config.CHANNELS_IMG).to(config.DEVICE)

    adam_kwargs = dict(lr=config.LEARNING_RATE, betas=(0.0, 0.99))
    opt_gen = optim.Adam(gen.parameters(), **adam_kwargs)
    opt_critic = optim.Adam(critic.parameters(), **adam_kwargs)
    # one grad scaler per optimizer for mixed-precision updates
    scaler_critic = torch.cuda.amp.GradScaler()
    scaler_gen = torch.cuda.amp.GradScaler()

    writer = SummaryWriter("logs/gan1")

    if config.LOAD_MODEL:
        restore_plan = (
            (config.CHECKPOINT_GEN, gen, opt_gen),
            (config.CHECKPOINT_CRITIC, critic, opt_critic),
        )
        for ckpt_path, model, optimizer in restore_plan:
            load_checkpoint(ckpt_path, model, optimizer, config.LEARNING_RATE)

    gen.train()
    critic.train()

    tensorboard_step = 0
    step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
    for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
        # alpha ~ 0: the generator's fade starts almost fully on the old stage
        alpha = 1e-5
        img_size = 4 * 2 ** step
        loader, dataset = get_loader(img_size)
        print(f"Current image size: {img_size}")

        for epoch in range(num_epochs):
            print(f"Epoch [{epoch+1}/{num_epochs}]")
            tensorboard_step, alpha = train_fn(
                critic, gen, loader, dataset, step, alpha,
                opt_critic, opt_gen, tensorboard_step, writer,
                scaler_gen, scaler_critic,
            )

            if config.SAVE_MODEL:
                save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
                save_checkpoint(critic, opt_critic, filename=config.CHECKPOINT_CRITIC)

        step += 1  # next stage doubles the image size

In [None]:
# Launch progressive training from config.START_TRAIN_AT_IMG_SIZE (long-running;
# checkpoints are written after every epoch when config.SAVE_MODEL is True).
main()



Current image size: 128
Epoch [1/30]


  0%|          | 0/12663 [04:54<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x1 and 512x512)

In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir logs