<a href="https://colab.research.google.com/github/Bustion11/NN-projects/blob/main/PGAN/ProGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import log2

# Channel-width multipliers, one per resolution level: entry k scales the base
# channel count of the feature maps at side length 4 * 2**k.
factors = [
    1,       # 4x4
    1,       # 8x8
    1,       # 16x16
    1,       # 32x32
    1 / 2,   # 64x64
    1 / 4,   # 128x128
    1 / 8,   # 256x256
    1 / 16,  # 512x512
    1 / 32,  # 1024x1024
]

In [None]:
class WSConv2d(nn.Module):
  """Convolution whose input is rescaled on every forward pass.

  The wrapped ``nn.Conv2d`` weights are drawn from a standard normal
  distribution and the input is multiplied by
  ``sqrt(gain / (in_channels * kernel_size ** 2))`` before the convolution.
  The bias lives on this module (not on the inner conv) so it is added after
  the convolution.

  Args:
    in_channels: number of input channels.
    out_channels: number of output channels.
    kernel_size: integer side length of the square kernel.
    stride: convolution stride.
    padding: convolution padding.
    gain: numerator of the scaling constant.
  """

  def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
    super().__init__()
    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)

    fan_in = in_channels * kernel_size ** 2
    self.scale = (gain / fan_in) ** 0.5

    # Re-home the bias parameter on this module and strip it from the conv.
    self.bias = self.conv.bias
    self.conv.bias = None

    # Initialise the layer: N(0, 1) weights, zero bias.
    nn.init.normal_(self.conv.weight)
    nn.init.zeros_(self.bias)

  def forward(self, x):
    scaled_input = x * self.scale
    bias = self.bias.view(1, -1, 1, 1)  # broadcast over batch and spatial dims
    return self.conv(scaled_input) + bias

In [None]:
class PixelNorm(nn.Module):
  """Normalise each pixel's feature vector by its root mean square.

  The mean of the squared activations is taken over dim 1 and a small epsilon
  is added before the square root to avoid division by zero.
  """

  def __init__(self):
    super().__init__()
    self.epsilon = 1e-8

  def forward(self, x):
    mean_square = x.pow(2).mean(dim=1, keepdim=True)
    return x / (mean_square + self.epsilon).sqrt()

In [None]:
class ConvBlock(nn.Module):
  """Two weight-scaled 3x3 convolutions, each followed by LeakyReLU(0.2).

  When ``use_pixnorm`` is true, a ``PixelNorm`` is applied after each
  activation; otherwise the activations are passed through unchanged.
  """

  def __init__(self, in_channels, out_channels, use_pixnorm=True):
    super().__init__()
    self.conv1 = WSConv2d(in_channels, out_channels)
    self.conv2 = WSConv2d(out_channels, out_channels)
    self.leaky = nn.LeakyReLU(0.2)
    self.pn = PixelNorm()
    self.use_pn = use_pixnorm

  def forward(self, x):
    # Same conv -> activation -> optional norm recipe for both layers.
    for conv in (self.conv1, self.conv2):
      x = self.leaky(conv(x))
      if self.use_pn:
        x = self.pn(x)
    return x

In [None]:
class Generator(nn.Module):
  """Progressively grown image generator.

  ``initial`` maps a ``z_dim x 1 x 1`` latent to an ``in_channels x 4 x 4``
  feature map. Each entry of ``prog_blocks`` is applied after a 2x
  nearest-neighbour upsample, and ``rgb_layers[k]`` converts the features of
  resolution level ``k`` to ``img_channels`` image channels.

  Args:
    z_dim: number of latent channels.
    in_channels: base channel count, scaled per level by ``factors``.
    img_channels: number of channels in the generated image.
  """

  def __init__(self, z_dim, in_channels, img_channels=3):
    super().__init__()
    self.initial = nn.Sequential(
        PixelNorm(),
        nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),  # 1x1 -> 4x4
        nn.LeakyReLU(0.2),
        WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
        nn.LeakyReLU(0.2),
        PixelNorm(),
    )

    self.initial_rgb = WSConv2d(in_channels, img_channels, kernel_size=1, stride=1, padding=0)
    self.prog_blocks = nn.ModuleList()
    self.rgb_layers = nn.ModuleList([self.initial_rgb])

    # One block (plus its to-RGB layer) per pair of adjacent channel factors.
    for in_factor, out_factor in zip(factors, factors[1:]):
      block_in = int(in_channels * in_factor)
      block_out = int(in_channels * out_factor)

      self.prog_blocks.append(ConvBlock(block_in, block_out))
      self.rgb_layers.append(WSConv2d(block_out, img_channels, kernel_size=1, stride=1, padding=0))

  def _fade_in(self, alpha, upscaled, generated):
    """Blend the new block's image with the upscaled previous one, then tanh."""
    blended = alpha * generated + (1 - alpha) * upscaled
    return torch.tanh(blended)

  def forward(self, x, alpha, steps):
    """Generate images after ``steps`` doublings of the 4x4 base resolution.

    With ``steps == 0`` the 4x4 RGB projection is returned directly (no tanh
    is applied on that path). Otherwise the output is the tanh of the
    ``alpha``-weighted mix of the last block's RGB image and the RGB image of
    the upsampled input to that block.
    """
    out = self.initial(x)

    if steps == 0:
      return self.initial_rgb(out)

    for level in range(steps):
      upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
      out = self.prog_blocks[level](upscaled)

    # `upscaled` is the input of the last block, i.e. the previous level's
    # features at the new resolution.
    previous_rgb = self.rgb_layers[steps - 1](upscaled)
    current_rgb = self.rgb_layers[steps](out)
    return self._fade_in(alpha, previous_rgb, current_rgb)

In [None]:
# Smoke test: run three latent vectors through four growth steps
# (4x4 -> 64x64) with the newest block fully faded in.
model = Generator(z_dim=512, in_channels=512)
example = torch.randn(3, 512, 1, 1)
print(example.requires_grad)
example = model(example, alpha=1, steps=4)
print(example.shape)
print(example.requires_grad)

False
torch.Size([3, 3, 64, 64])
True


In [None]:
class Discriminator(nn.Module):
  """Progressively grown critic, built as the mirror image of the generator.

  ``prog_blocks`` and ``rgb_layers`` are stored from the highest resolution
  down to the lowest, so index ``len(prog_blocks) - steps`` selects the entry
  that handles images produced after ``steps`` doublings of the 4x4 base.
  ``rgb_layers`` has one extra trailing entry (``initial_rgb``) for 4x4 input.

  Args:
    z_dim: unused; kept so the constructor signature matches the generator's.
    in_channels: base channel count, scaled per level by ``factors``.
    img_channels: number of channels in the input images.
  """

  def __init__(self, z_dim, in_channels, img_channels=3):
    super().__init__()
    self.prog_blocks, self.rgb_layers = nn.ModuleList(), nn.ModuleList()
    self.leaky = nn.LeakyReLU(0.2)

    # Walk the channel factors backwards: high resolution (few channels) first.
    for i in range(len(factors)-1, 0, -1):
      conv_in_c = int(in_channels*factors[i])
      conv_out_c = int(in_channels*factors[i-1])
      self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c, use_pixnorm=False))
      self.rgb_layers.append(WSConv2d(img_channels, conv_in_c, kernel_size=1, stride=1, padding=0))

    self.initial_rgb = WSConv2d(img_channels, in_channels, kernel_size=1, stride=1, padding=0)
    self.rgb_layers.append(self.initial_rgb)
    self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

    # The +1 input channel is the minibatch statistic appended in forward().
    self.final_block = nn.Sequential(
        WSConv2d(in_channels+1, in_channels, kernel_size=3, stride=1, padding=1),
        nn.LeakyReLU(0.2),
        WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
        nn.LeakyReLU(0.2),
        WSConv2d(in_channels, 1, kernel_size=1, padding=0, stride=1)
    )

  def _fade_in(self, alpha, downscaled, out):
    """Linearly blend the new block's features with the downscaled shortcut."""
    return alpha*out + (1-alpha)*downscaled

  def minibatch_std(self, x):
    """Append one channel holding the mean per-feature std over the batch.

    The standard deviation is taken over dim 0 and averaged to a single
    scalar, which fills the extra channel for every sample. With a single
    sample the (unbiased) standard deviation is undefined and ``torch.std``
    yields NaN, so the statistic is set to zero in that case.
    """
    if x.shape[0] > 1:
      mean_std = torch.std(x, dim=0).mean()
    else:
      mean_std = x.new_zeros(())
    batch_statistics = mean_std.repeat(x.shape[0], 1, x.shape[2], x.shape[3])
    return torch.cat([x, batch_statistics], dim=1)

  def forward(self, x, alpha, steps):
    """Score images that were produced after ``steps`` resolution doublings.

    Args:
      x: batch of images at side length ``4 * 2**steps``.
      alpha: fade-in weight of the highest-resolution block.
      steps: resolution level, from 0 up to ``len(self.prog_blocks)``.

    Returns:
      Tensor of shape ``(batch, -1)`` holding the critic scores.

    Raises:
      IndexError: if ``steps`` is outside ``[0, len(self.prog_blocks)]``.
    """
    # Without this check a too-large `steps` would give a negative index and
    # silently select the wrong layers.
    if not 0 <= steps <= len(self.prog_blocks):
      raise IndexError(
          f"steps must be in [0, {len(self.prog_blocks)}], got {steps}"
      )

    cur_step = len(self.prog_blocks) - steps
    out = self.leaky(self.rgb_layers[cur_step](x))

    if steps == 0:
      out = self.minibatch_std(out)
      return self.final_block(out).view(out.shape[0], -1)

    # Shortcut path: pool the image first, then project with the next
    # (lower-resolution) from-RGB layer.
    downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
    # Main path: run the newest block, then pool to the same resolution.
    out = self.avg_pool(self.prog_blocks[cur_step](out))
    out = self._fade_in(alpha, downscaled, out)

    for step in range(cur_step+1, len(self.prog_blocks)):
      out = self.prog_blocks[step](out)
      out = self.avg_pool(out)

    out = self.minibatch_std(out)
    return self.final_block(out).view(out.shape[0], -1)

In [None]:
def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    """Penalise critic gradients whose per-sample L2 norm differs from 1.

    Real and (detached) fake images are mixed with one random coefficient per
    sample; the critic is evaluated on the mix and the gradient of its scores
    with respect to the mixed images is taken.

    Args:
        critic: callable invoked as ``critic(images, alpha, train_step)``.
        real: 4-D batch of real images.
        fake: batch of generated images with the same shape as ``real``.
        alpha: fade-in weight forwarded to the critic.
        train_step: resolution step forwarded to the critic.
        device: device the mixing coefficients are moved to.

    Returns:
        Scalar tensor: mean over the batch of ``(||grad||_2 - 1) ** 2``.
    """
    n_samples, n_channels, height, width = real.shape

    # One mixing coefficient per sample, copied across channels and pixels.
    mix = torch.rand((n_samples, 1, 1, 1)).repeat(1, n_channels, height, width).to(device)
    blended = real * mix + fake.detach() * (1 - mix)
    blended.requires_grad_(True)

    critic_scores = critic(blended, alpha, train_step)

    # create_graph keeps the gradient differentiable so the penalty itself
    # can be back-propagated into the critic.
    (grads,) = torch.autograd.grad(
        outputs=critic_scores,
        inputs=blended,
        grad_outputs=torch.ones_like(critic_scores),
        create_graph=True,
        retain_graph=True,
    )

    per_sample_norm = grads.view(grads.shape[0], -1).norm(2, dim=1)
    return ((per_sample_norm - 1) ** 2).mean()

In [None]:
# Download dataset
import torchvision.datasets as datasets
# Fetch StanfordCars into /data up front so later loaders can reuse the files.
DATASET = datasets.StanfordCars(root='/data', download=True)

In [None]:
from os import cpu_count
# Configuration file
# Train on the GPU when one is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Side length (pixels) at which progressive training starts.
START_TRAINING_IMG_SIZE = 4
LR = 1e-3
# Batch size per resolution level: index k is used for images of side 4 * 2**k.
BATCH_SIZES = [32, 32, 32, 16, 16, 16, 16, 8, 4]
IMG_CHANNELS = 3
Z_DIM = 256
IN_CHANNELS = 256
CRITIC_ITERATIONS = 1
# Weight of the gradient-penalty term in the critic loss.
LAMBDA_GP = 10
# Number of epochs to train at each resolution level.
PROGRESSIVE_EPOCHS = [30] * len(BATCH_SIZES)
# os.cpu_count() returns None when the count cannot be determined; fall back
# to 0 (load data in the main process) so DataLoader always gets an int.
WORKERS = cpu_count() or 0
ALPHA_GP = 10

In [None]:
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
#from tqdm import tqdm

# Enable the cuDNN autotuner. The flag is spelled `benchmark`; assigning to
# `benchmarks` only creates an unused attribute and leaves the tuner off.
torch.backends.cudnn.benchmark = True

def get_loader(img_size):
  """Build a shuffled StanfordCars loader for ``img_size`` x ``img_size`` images.

  Images are resized, converted to tensors, randomly flipped horizontally
  with probability 0.5 and normalised with mean 0.5 / std 0.5 per channel.
  The batch size is looked up in ``BATCH_SIZES`` at index
  ``log2(img_size / 4)``.

  Returns:
    A ``(loader, dataset)`` tuple.
  """
  mean = [0.5] * IMG_CHANNELS
  std = [0.5] * IMG_CHANNELS
  preprocessing = transforms.Compose([
      transforms.Resize((img_size, img_size)),
      transforms.ToTensor(),
      transforms.RandomHorizontalFlip(0.5),
      transforms.Normalize(mean, std),
  ])

  resolution_level = int(log2(img_size / 4))
  batch_size = BATCH_SIZES[resolution_level]

  dataset = datasets.StanfordCars(root="/data", transform=preprocessing)
  loader = DataLoader(
      dataset,
      batch_size=batch_size,
      shuffle=True,
      pin_memory=True,
      num_workers=WORKERS,
  )
  return loader, dataset

def train_loop(critic, generator, loader, dataset, step, alpha, opt_critic, opt_generator, scaler_generator, scaler_critic):
  """Run one pass over ``loader``, updating the critic then the generator per batch.

  Both forward passes run under ``torch.cuda.amp.autocast`` and each optimizer
  is stepped through its own gradient scaler. After every batch ``alpha`` is
  increased by ``batch_size / (0.5 * PROGRESSIVE_EPOCHS[step] * len(dataset))``
  and capped at 1, so starting near zero it reaches 1 after about half of the
  epochs scheduled for this step.

  Args:
    critic: critic network, called as ``critic(images, alpha, step)``.
    generator: generator network, called as ``generator(noise, alpha, step)``.
    loader: iterable yielding ``(images, _)`` batches.
    dataset: dataset behind ``loader``; only its length is used.
    step: current resolution step.
    alpha: current fade-in weight.
    opt_critic: optimizer for the critic.
    opt_generator: optimizer for the generator.
    scaler_generator: gradient scaler used for the generator update.
    scaler_critic: gradient scaler used for the critic update.

  Returns:
    The updated ``alpha``.
  """
  for batch_idx, (real, _) in enumerate(loader):
    real = real.to(DEVICE)
    cur_batch_size = real.shape[0]

    noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(DEVICE)

    # Train the critic
    with torch.cuda.amp.autocast():
      fake = generator(noise, alpha, step)
      scores_real = critic(real, alpha, step)
      # Detach so the critic loss does not back-propagate into the generator.
      scores_fake = critic(fake.detach(), alpha, step)

      gp = gradient_penalty(critic, real, fake, alpha, step, device=DEVICE)
      # Negated score gap + weighted gradient penalty + a small penalty on
      # the squared real scores.
      loss_critic = (
          -(torch.mean(scores_real) - torch.mean(scores_fake))
          + LAMBDA_GP*gp
          + (0.001 * torch.mean(scores_real**2))
      )
    
    opt_critic.zero_grad()
    scaler_critic.scale(loss_critic).backward()
    scaler_critic.step(opt_critic)
    scaler_critic.update()

    # Train the generator
    with torch.cuda.amp.autocast():
      # `fake` is reused un-detached, so gradients reach the generator
      # through the freshly updated critic.
      gen_fake = critic(fake, alpha, step)
      loss_gen = -torch.mean(gen_fake)
    
    opt_generator.zero_grad()
    scaler_generator.scale(loss_gen).backward()
    scaler_generator.step(opt_generator)
    scaler_generator.update()

    # Update the alpha
    alpha += cur_batch_size / ((PROGRESSIVE_EPOCHS[step]*0.5)*len(dataset))
    alpha = min(alpha, 1)

    # Log the penalty and critic loss every 500 batches.
    if batch_idx%500==0:
      print(f"GP: {gp.cpu().detach().numpy()}, loss: {loss_critic}")
    
  return alpha

In [None]:
# Build both networks and move them to the training device.
generator = Generator(Z_DIM, IN_CHANNELS, IMG_CHANNELS).to(DEVICE)
critic = Discriminator(Z_DIM, IN_CHANNELS, IMG_CHANNELS).to(DEVICE)

# One mixed-precision gradient scaler per network.
scaler_generator = torch.cuda.amp.GradScaler()
scaler_critic = torch.cuda.amp.GradScaler()

# Put both networks in training mode.
generator.train()
critic.train()

In [None]:
# Adam with identical settings for both networks: learning rate LR and
# betas (0, 0.99).
optimizer_generator = optim.Adam(generator.parameters(), lr=LR, betas=(0, 0.99))
optimizer_critic = optim.Adam(critic.parameters(), lr=LR, betas=(0, 0.99))

In [None]:
# Progressive training schedule. The outer loop runs the whole schedule five
# times; every pass restarts at START_TRAINING_IMG_SIZE.
# NOTE(review): re-running the low-resolution stages on already trained
# networks looks unusual — confirm the 5x repeat is intended.
for _ in range(5):
  # Resolution step index: 0 for 4x4, 1 for 8x8, and so on.
  step=int(log2(START_TRAINING_IMG_SIZE/4))
  for num_epochs in PROGRESSIVE_EPOCHS[step:]:
    # Restart the fade-in at (almost) zero for each new resolution.
    alpha = 1e-5
    loader, dataset = get_loader(4 * 2**step)
    print(f"Current image size: {4 * 2 ** step}")

    for epoch in range(num_epochs):
      print(f"Epoch: [{epoch+1}/{num_epochs}]")
      alpha = train_loop(critic, generator, loader, dataset, step, alpha, optimizer_critic, optimizer_generator, scaler_generator, scaler_critic)
    step += 1

# Save the weights only, and separately the whole pickled generator module.
torch.save(generator.state_dict(), 'generator_weights.pth')
torch.save(generator, 'generator.pth')

In [None]:
# Check how discriminator works
with torch.no_grad():
  # Generate four images at step 5 (128x128) with the newest block half
  # faded in, then score them with the critic at the same step.
  example = generator(torch.randn(4, Z_DIM, 1, 1).to(DEVICE), 0.5, 5)
  scores = critic(example, 1, 5)
  print(scores, "\n", scores.shape)

In [None]:
# Fixed batch of 20 latent vectors, so sample grids can be regenerated from
# the same inputs.
fixed_noise = torch.randn(20, Z_DIM, 1, 1).to(DEVICE)

In [None]:
# Generate some examples
import matplotlib.pyplot as plt

with torch.no_grad():
  # Sample in eval mode; training mode is restored after the block.
  generator.eval()
  # The generator output at step 8 goes through tanh, so it lies in [-1, 1];
  # rescale to [0, 1] for display.
  example = generator(fixed_noise, 1, 8) * 0.5 + 0.5
  print(example.shape)
  # Arrange the samples in a grid with 5 images per row.
  grid = torchvision.utils.make_grid(example.cpu().detach(), nrow=5)
  plt.figure(figsize=(20, 16))
  # make_grid returns (C, H, W); imshow expects (H, W, C).
  plt.imshow(grid.permute(1, 2, 0))

  plt.show()

generator.train()
print()