<a href="https://colab.research.google.com/github/Bustion11/NN-projects/blob/main/SGAN/StyleGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision.datasets import FashionMNIST
import torchvision.transforms as T
import os
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
import matplotlib.pyplot as plt

In [None]:
# Helper classes
class WSConv2d(nn.Module):
  """
  2D convolution with an equalised learning rate (weight scaling).

  The weights are drawn from a standard normal distribution and the input is
  multiplied by sqrt(gain / (in_channels * kernel_size ** 2)) on every
  forward pass. The bias is detached from the inner convolution and added
  afterwards, so it is not affected by the scaling.
  """
  def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
    super().__init__()
    fan_in = in_channels * kernel_size ** 2
    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
    self.scale = torch.tensor((gain / fan_in) ** 0.5)

    # Take ownership of the bias so the convolution itself runs bias-free.
    self.bias = self.conv.bias
    self.conv.bias = None

    # Initialise the layer: N(0, 1) weights, zero bias.
    nn.init.normal_(self.conv.weight)
    nn.init.zeros_(self.bias)

  def forward(self, x):
    scaled_input = x * self.scale
    per_channel_bias = self.bias[None, :, None, None]
    return self.conv(scaled_input) + per_channel_bias


class PixelNorm(nn.Module):
  """
  Pixel-wise feature normalisation.

  Divides every pixel's feature vector by the root of its mean square taken
  over the channel dimension (dim=1); epsilon keeps the division finite.
  """
  def __init__(self):
    super().__init__()
    self.epsilon = 1e-8

  def forward(self, x):
    mean_square = torch.mean(x ** 2, dim=1, keepdim=True)
    denominator = torch.sqrt(mean_square + self.epsilon)
    return x / denominator


class ConvBlock(nn.Module):
  """
  Weight-scaled 3x3 convolution -> LeakyReLU(0.2) -> optional PixelNorm.

  With use_pn=False the normalisation stage is an identity.
  """
  def __init__(self, in_channels, out_channels, use_pn=True):
    super().__init__()
    self.conv = WSConv2d(in_channels, out_channels)
    self.act = nn.LeakyReLU(0.2)
    if use_pn:
      self.norm = PixelNorm()
    else:
      self.norm = nn.Identity()

  def forward(self, x):
    features = self.conv(x)
    activated = self.act(features)
    return self.norm(activated)


class AdaIN(nn.Module):
  """
  Description: Adaptive instance normalization per each feature channel.

  Formula:
          AdaIN(x, y) = y_style*((x_i-mean(x_i))/std(x_i))+y_bias

  Mean and std are taken over the two trailing (spatial) dimensions of every
  sample and channel.

  Input: x, y_s, y_b
         x shape: (N, C, H, W);
         y_s shape: (N, C)  -- per-channel style scale;
         y_b shape: (N, C)  -- per-channel style bias;
  Output: normalized and re-styled x (N, C, H, W)
  """
  def __init__(self, eps=1e-6):
    super().__init__()
    self.eps = eps  # added to the std so constant feature maps do not divide by zero

  def forward(self, x, y_s, y_b):
    mean = torch.mean(x, dim=[-1, -2], keepdim=True)
    std = torch.std(x, dim=[-1, -2], keepdim=True)
    # Centre first, then scale. Without the parentheses around (x - mean) the
    # division binds tighter and the result is x - mean/std.
    normalized = (x - mean) / (std + self.eps)
    return y_s[:, :, None, None] * normalized + y_b[:, :, None, None]


class Adaptive_style(nn.Module):
  """
  In the original paper: A

  Desc: 'Learned affine transformations' -- two independent linear layers
        that turn W into a per-channel scale and a per-channel bias.
  Input: W of shape (N, F)
  Output: (y_style, y_bias), each of shape (N, C)
  """
  def __init__(self, n_channels, features=512):
    super().__init__()
    self.n_channels = n_channels
    self.y_s = nn.Linear(features, n_channels)
    self.y_b = nn.Linear(features, n_channels)

    # Both weight matrices start from N(0, 1).
    for affine in (self.y_s, self.y_b):
      nn.init.normal_(affine.weight)

    # The scale branch is biased to 1 and the shift branch to 0.
    nn.init.ones_(self.y_s.bias)
    nn.init.zeros_(self.y_b.bias)

  def forward(self, w):
    return self.y_s(w), self.y_b(w)


class Adaptive_noise(nn.Module):
  """
  In the original paper: B

  Desc: 'Applies learned per-channel scaling factors to the noise input'.
        The factors start at zero, so the output is all zeros until they
        are changed.
  Input: noise (N, 1, H, W)
  Output: per_channel_scaled_noise (N, C, H, W)
  """
  def __init__(self, n_channels):
    super().__init__()
    initial_factors = torch.zeros(n_channels)
    self.scale = nn.Parameter(initial_factors, requires_grad=True)

  def forward(self, noise):
    # (C,) -> (1, C, 1, 1) so the factors broadcast against (N, 1, H, W).
    per_channel = self.scale.view(1, -1, 1, 1)
    return noise * per_channel

In [None]:
# Main block
class Network_block(nn.Module):
  """
  One styled synthesis layer: ConvBlock -> add scaled noise (B) -> AdaIN
  driven by the style computed from W (A).

  Forward arguments: x (feature map), w (style vector), noise.
  """
  def __init__(self, in_channels, out_channels, features=512):
    super().__init__()
    self.convolution = ConvBlock(in_channels, out_channels)
    self.B = Adaptive_noise(out_channels)
    self.A = Adaptive_style(out_channels, features)
    self.adain = AdaIN()

  def forward(self, x, w, noise):
    features = self.convolution(x)
    noisy_features = features + self.B(noise)
    style_scale, style_bias = self.A(w)
    return self.adain(noisy_features, style_scale, style_bias)

In [None]:
#Network classes
class Mapping_network(nn.Module):
  """
  MLP network that maps latent noise Z to W.

  Every layer is Linear -> ReLU -> normalization. When `normalization` is
  None an identity is used; otherwise the very same module instance that was
  passed in is placed in every layer.
  """
  def __init__(self, n_features=512, normalization=None, n_layers=8):
    super().__init__()

    def build_layer():
      tail = nn.Identity() if normalization is None else normalization
      return nn.Sequential(
          nn.Linear(n_features, n_features),
          nn.ReLU(),
          tail,
      )

    self.mlp = nn.ModuleList([build_layer() for _ in range(n_layers)])

  def forward(self, x):
    hidden = x
    for block in self.mlp:
      hidden = block(hidden)
    return hidden

# (in_channels, out_channels) of every progressive block, in the order the
# generator applies them. The trailing number is the resolution produced by
# that block: the generator starts at 4x4 and doubles once per block.
# The discriminator walks the same list in reverse.
config = [(512, 512), #8
          (512, 512), #16
          (512, 512), #32
          (512, 256), #64
          (256, 128), #128
          (128, 64)]  #256

class Synthesis_network(nn.Module):
  """
  Synthesis network (the image generator) with progressive growing.

  A fixed all-ones 4x4 tensor goes through the initial block (noise, AdaIN,
  one `Network_block`). Every further step doubles the resolution with
  bilinear interpolation and applies a pair of `Network_block`s, so `steps`
  produces images of size 4 * 2**steps.

  Forward pass arguments:
    W, steps, alpha, batch_size

  All intermediate tensors are created on the device of the inputs, so the
  module works on whatever device it and its inputs have been moved to.
  """
  def __init__(self, img_channels, channels=512, configuration=config):
    super().__init__()

    # Plain tensor on purpose: it is neither trained nor part of the
    # state_dict; forward() moves it to the device of W.
    self.const_noise = torch.ones(1, channels, 4, 4)

    self.initial_block = nn.ModuleList([
                                        Adaptive_noise(channels), # B
                                        Adaptive_style(channels), # A
                                        AdaIN(),
                                        Network_block(channels, channels)
    ])

    self.initial_rgb = WSConv2d(channels, img_channels, 1, padding=0)

    self.prog_blocks = nn.ModuleList([])
    self.to_rgb = nn.ModuleList([])

    for in_channels, out_channels in configuration:
      self.prog_blocks.append(
          nn.ModuleList([
                         Network_block(in_channels, out_channels),
                         Network_block(out_channels, out_channels)
                         ]
                        )
          )
      self.to_rgb.append(WSConv2d(out_channels, img_channels, 1, padding=0))

  @staticmethod
  def _noise_like(features, batch_size):
    """Sample (batch_size, 1, H, W) Gaussian noise on the device of `features`."""
    return torch.randn(batch_size, 1, features.shape[2], features.shape[3],
                       device=features.device)

  def _check_step(self, step, lowest):
    """Raise IndexError unless lowest <= step <= number of progressive blocks."""
    highest = len(self.prog_blocks)
    if not lowest <= step <= highest:
      raise IndexError(f"step must be in [{lowest}, {highest}], got {step}")

  def _fade_in(self, alpha, upscaled, generated):
    # Blend the two branches, then squash the result to (0, 1).
    return torch.sigmoid(alpha*generated + (1.0-alpha)*upscaled)

  def _initial_pass(self, W, output, batch_size):
    """Run the 4x4 stage: add noise, apply AdaIN, then one Network_block."""
    B, A, adain, net_block = self.initial_block
    # Out-of-place add so the tensor handed in by the caller is not modified.
    output = output + B(self._noise_like(output, batch_size))
    y_s, y_b = A(W)
    output = adain(output, y_s, y_b)
    return net_block(output, W, self._noise_like(output, batch_size))

  def _grow(self, features, W, block_index, batch_size):
    """Double the resolution and apply progressive block `block_index`.

    Returns (output of the first layer, output of the second layer).
    """
    upscaled = torch.nn.functional.interpolate(
        features, scale_factor=2, mode="bilinear", align_corners=False)
    layer1, layer2 = self.prog_blocks[block_index]
    first = layer1(upscaled, W, self._noise_like(upscaled, batch_size))
    second = layer2(first, W, self._noise_like(first, batch_size))
    return first, second

  def forward(self, W, steps, alpha, batch_size):
    """Generate a batch of images at resolution 4 * 2**steps.

    W: style vectors, one row per sample.
    steps: number of progressive blocks to apply (0 .. len(prog_blocks)).
    alpha: fade-in weight of the newest layer.
    batch_size: number of images; expected to equal W.shape[0].
    Raises IndexError if `steps` is out of range.
    """
    self._check_step(steps, lowest=0)

    # Batchify the constant input on the device the styles live on.
    output = self.const_noise.to(W.device).expand(batch_size, -1, -1, -1)

    # Initial block
    output = self._initial_pass(W, output, batch_size)

    if steps == 0:
      return torch.sigmoid(self.initial_rgb(output))

    # Main block
    for step in range(steps):
      upscaled, output = self._grow(output, W, step, batch_size)

    # NOTE(review): the fade-in blends the outputs of the two layers of the
    # newest block, not the up-sampled image of the previous resolution.
    # Kept as in the original -- confirm this is intended.
    rgb_upscaled = self.to_rgb[steps-1](upscaled)
    rgb_output = self.to_rgb[steps-1](output)

    return self._fade_in(alpha, rgb_upscaled, rgb_output)

  def manual_forward(self, x, W, step, alpha, batch_size, to_rgb=False):
    """Run a single stage of the network on a caller-supplied tensor.

    step == 0 runs the initial 4x4 stage on `x`; step == k >= 1 runs
    progressive block k-1, i.e. the block that `forward` applies last for
    steps=k. `x` is presumably the feature map produced by the previous
    stage -- verify against the caller. It is never modified in place.
    With to_rgb=True the matching RGB convolution is applied to the result.
    Raises IndexError if `step` is out of range.
    """
    self._check_step(step, lowest=0)

    if step == 0:
      output = self._initial_pass(W, x, batch_size)
      return self.initial_rgb(output) if to_rgb else output

    # Same indexing as forward(): stage k owns prog_blocks[k-1] / to_rgb[k-1].
    upscaled, output = self._grow(x, W, step - 1, batch_size)

    rgb_upscaled = self.to_rgb[step - 1](upscaled) if to_rgb else upscaled
    rgb_output = self.to_rgb[step - 1](output) if to_rgb else output

    return self._fade_in(alpha, rgb_upscaled, rgb_output)

In [None]:
class Discriminator(nn.Module):
  """
  Progressive-growing critic that mirrors the generator's configuration.

  `blocks[i]` / `from_rgb[i]` are ordered from the highest resolution to the
  lowest; the last `from_rgb` entry feeds the final 4x4 block. An input of
  size 4 * 2**steps is reduced to one raw score per sample.
  """
  def __init__(self, img_channels, channels=512, configuration=config):
    super().__init__()
    self.dim = channels

    self.blocks, self.from_rgb = nn.ModuleList([]), nn.ModuleList([])
    self.leaky = nn.LeakyReLU(0.2)
    self.pool = nn.AvgPool2d(2, 2)

    # Walk the generator config backwards and swap each pair: what the
    # generator emits at a resolution is what the critic receives there.
    for gen_in, gen_out in configuration[::-1]:
      in_channels, out_channels = gen_out, gen_in
      self.blocks.append(nn.Sequential(
          WSConv2d(in_channels, out_channels),
          self.leaky,
          WSConv2d(out_channels, out_channels),
          self.leaky
      ))

      self.from_rgb.append(WSConv2d(img_channels, in_channels, 1, 1, 0))

    final_rgb = WSConv2d(img_channels, self.dim, 1, 1, 0)
    self.final_block = nn.Sequential(
        WSConv2d(self.dim+1, self.dim, 3, 1, 1),  # +1: minibatch-std channel
        self.leaky,
        WSConv2d(self.dim, self.dim, 4, 1, 0),    # 4x4 -> 1x1
        self.leaky,
        WSConv2d(self.dim, 1, 1, 1, 0)
    )

    self.from_rgb.append(final_rgb)

  def _fade_in(self, alpha, downscaled, out):
    return alpha*out + (1-alpha)*downscaled

  def minibatch_std(self, x):
    """Append one channel holding the mean of the per-feature batch std."""
    if x.shape[0] > 1:
      statistic = torch.std(x, dim=0).mean()
    else:
      # The unbiased std of a single sample is NaN, which would poison every
      # score; a lone sample has no batch variation, so use 0 instead.
      statistic = x.new_zeros(())
    batch_statistics = statistic.repeat(x.shape[0], 1, x.shape[2], x.shape[3])

    return torch.cat([x, batch_statistics], dim=1)

  def forward(self, x, steps, alpha):
    """Score a batch of images.

    x: images of shape (N, img_channels, 4 * 2**steps, 4 * 2**steps).
    steps: progressive step the images were generated at.
    alpha: fade-in weight of the highest-resolution block.
    Returns raw scores of shape (N, 1).
    Raises IndexError if `steps` is outside [0, len(self.blocks)].
    """
    if not 0 <= steps <= len(self.blocks):
      # A larger value would silently wrap around through negative indexing.
      raise IndexError(
          f"steps must be in [0, {len(self.blocks)}], got {steps}")

    cur_step = len(self.blocks) - steps
    out = self.leaky(self.from_rgb[cur_step](x))

    if steps == 0:
      # 4x4 input: nothing to fade in, go straight to the final block.
      out = self.minibatch_std(out)
      return self.final_block(out).reshape(out.shape[0], -1)

    # Fade between the newest block and a plain down-sampling of the input.
    downscaled = self.leaky(self.from_rgb[cur_step + 1](self.pool(x)))
    out = self.pool(self.blocks[cur_step](out))
    out = self._fade_in(alpha, downscaled, out)

    for step in range(cur_step+1, len(self.blocks)):
      out = self.blocks[step](out)
      out = self.pool(out)

    out = self.minibatch_std(out)
    return self.final_block(out).reshape(out.shape[0], -1)


In [None]:
# Number of image channels; passed to the generator and critic constructors
# and used for the per-channel normalisation in get_loader.
IMG_CHANNELS=1

In [None]:
# Random batch of 5 style vectors (width 512) for a quick forward-pass check.
W = torch.randn(5, 512).to(DEVICE)

In [None]:
# Smoke-test instances. They must live on the same device as the inputs
# (W is moved to DEVICE), otherwise the forward pass fails on a CUDA machine.
generator = Synthesis_network(IMG_CHANNELS).to(DEVICE)
critic = Discriminator(IMG_CHANNELS).to(DEVICE)

In [None]:
# Smoke-test settings: step 4 corresponds to 4 * 2**4 = 64x64 images,
# alpha=1 weights the newest layer fully, and batch_size matches the 5 rows of W.
step=4
alpha=1
batch_size=5

In [None]:
# Run the generator and feed its output straight into the critic.
generated_image = generator(W, step, alpha, batch_size)
scores = critic(generated_image, step, alpha)

In [None]:
# Display channel 0 of the last sample (index 4) of the batch.
generated_image[4][0]

In [None]:
# Display the critic scores flattened to one value per sample.
scores.reshape(-1)

tensor([-0.0100, -0.0100,  0.0005,  0.0029, -0.0100],
       grad_fn=<ReshapeAliasBackward0>)

In [None]:
def get_loader(steps, batch_size=8):
  """
  Build a shuffled FashionMNIST loader for one progressive step.

  Images are resized to 4 * 2**steps, converted to tensors and normalised
  with mean 0.5 / std 0.5 per channel. The dataset is downloaded to "/data"
  if it is not there yet.

  Returns: (dataloader, dataset)
  """
  img_size = 4 * 2**steps
  transform = T.Compose([
      T.Resize(img_size),
      T.ToTensor(),
      T.Normalize([0.5] * IMG_CHANNELS, [0.5] * IMG_CHANNELS),
  ])
  data = FashionMNIST(root="/data", transform=transform, download=True)

  dataloader = torch.utils.data.DataLoader(
      data,
      batch_size=batch_size,
      shuffle=True,
      # os.cpu_count() returns None when the count cannot be determined;
      # fall back to loading in the main process.
      num_workers=os.cpu_count() or 0,
      pin_memory=True)
  return dataloader, data

In [None]:
def calculate_penalty(real, critic, step, alpha):
  """
  Gradient penalty on real samples.

  Computes the gradient of the critic scores with respect to the real images
  and returns the batch mean of its squared L2 norm. The result stays
  attached to the graph (create_graph=True) so it can be back-propagated
  into the critic.

  real: batch of real images; the caller's tensor is left untouched.
  critic: callable invoked as critic(real, step, alpha).
  Returns: scalar tensor.
  """
  # Fresh leaf tensor: gradients are tracked without touching the caller's
  # tensor (replaces the deprecated torch.autograd.Variable).
  real = real.detach().requires_grad_(True)
  real_scores = critic(real, step, alpha)
  (gradients,) = torch.autograd.grad(
        inputs=real,
        outputs=real_scores,
        # ones_like already matches the device and dtype of the scores.
        grad_outputs=torch.ones_like(real_scores),
        create_graph=True,
        retain_graph=True
    )
  gradients = gradients.reshape(real.shape[0], -1)

  # Squared L2 norm per sample, without the intermediate square root.
  return gradients.pow(2).sum(dim=1).mean()

In [None]:
# For CPU
def train_loop(alpha, step, loader, dataset,
               generator, style_generator, critic,
               opt_generator, opt_s_generator, opt_critic, loss_fn):
  """
  Train critic, generator and mapping network for one pass over `loader`.

  For every batch the critic is updated first (real vs. fake loss plus a
  gradient penalty on the real images), then the synthesis and mapping
  networks are updated together from the generator loss.

  alpha: current fade-in weight; it is increased after every batch and
         capped at 1.
  step: progressive step, forwarded to generator and critic.
  Returns: the updated alpha.
  """
  generator.train()
  style_generator.train()
  critic.train()
  for batch_idx, (x, _) in enumerate(loader):
    # Sample Z
    x = x.to(DEVICE)
    cur_batch_size = x.shape[0]
    Z = torch.randn(cur_batch_size, 512).to(DEVICE)

    W = style_generator(Z)
    fake = generator(W, step, alpha, cur_batch_size)

    # Train critic
    real_scores = critic(x, step, alpha).reshape(-1)
    real_loss = loss_fn(real_scores, torch.ones_like(real_scores))

    # Detached: the critic update must not back-propagate into the
    # generator, and the generator graph stays intact for its own update.
    fake_scores = critic(fake.detach(), step, alpha).reshape(-1)
    fake_loss = loss_fn(fake_scores, torch.zeros_like(fake_scores))

    penalty = calculate_penalty(x.detach(), critic, step, alpha)

    loss_critic = (real_loss+fake_loss)/2 + 10/2 * penalty

    if batch_idx % 200 == 0:
      print("Debug: ")
      print("Alpha: ", alpha)
      print("Critic loss: ", loss_critic.item())
      print("Penalty: ", penalty.item())

    opt_critic.zero_grad()
    loss_critic.backward()
    opt_critic.step()

    # Train generator (synthesis + mapping network)
    scores = critic(fake, step, alpha).reshape(-1)
    loss_generator = loss_fn(scores, torch.ones_like(scores))

    if batch_idx % 200 == 0:
      print("Generator loss:", loss_generator.item())

    # Clear both optimizers BEFORE the backward pass and step both after it;
    # zeroing the mapping network's gradients right before its step would
    # throw the update away.
    opt_generator.zero_grad()
    opt_s_generator.zero_grad()
    loss_generator.backward()
    opt_generator.step()
    opt_s_generator.step()

    # NOTE(review): the batch size cancels out, so alpha grows by
    # 2 / len(dataset) per batch -- confirm this schedule is intended.
    alpha += x.shape[0] / ((x.shape[0]*0.5)*len(dataset))
    alpha = min(alpha, 1)

  return alpha

In [None]:
# Fresh models for training, all moved to DEVICE.
generator = Synthesis_network(IMG_CHANNELS).to(DEVICE)
style_generator = Mapping_network().to(DEVICE)
critic = Discriminator(IMG_CHANNELS).to(DEVICE)

In [None]:
# One Adam optimizer per network; the mapping network uses a learning rate
# 100 times smaller than the other two.
opt_critic = torch.optim.Adam(critic.parameters(), lr=2e-3)
opt_generator = torch.optim.Adam(generator.parameters(), lr=2e-3)
opt_style_gen = torch.optim.Adam(style_generator.parameters(), lr=2e-3*0.01)


# The critic outputs raw scores, so the sigmoid is folded into the loss.
loss_fn = nn.BCEWithLogitsLoss()

In [None]:
# Initial fade-in weight: the newest layer starts with almost no influence.
alpha=1e-5

In [None]:
# Trial run: 5 passes over the data at step 2 (16x16 images). The printed
# "Step" is the pass counter, not the progressive step.
loader, dataset = get_loader(2)
for _ in range(5):
  print("Step: ", _+1)
  alpha = train_loop(alpha, 2, loader, dataset, generator, style_generator,
                      critic, opt_generator, opt_style_gen, opt_critic, loss_fn)

In [None]:
# Progressive training: steps 1..5 (8x8 up to 128x128), 10 epochs each.
step = 1
for _ in range(5):
  
  # Restart the fade-in and rebuild the loader for the new resolution.
  alpha = 1e-5
  loader, dataset = get_loader(step)
  print(f"Current image size: {4 * 2 ** step}")

  for epoch in range(10):
    print(f"Epoch: [{epoch+1}/{10}]")
    alpha = train_loop(alpha, step, loader, dataset, generator, style_generator,
                      critic, opt_generator, opt_style_gen, opt_critic, loss_fn)
  step += 1

In [None]:
# 8 random latent vectors Z (width 512) for sampling images.
noise = torch.randn(8, 512).to(DEVICE)

In [None]:
# Sample images: map Z -> W with the mapping network, then synthesise from W.
style_generator.eval()
generator.eval()
with torch.no_grad():  # inference only, no graph needed
  W = style_generator(noise)
  # The generator expects the mapped styles W, not the raw latent noise.
  img = generator(W, 1, 1, W.shape[0])
plt.imshow(img.cpu()[3][0])

In [None]:
# For GPU (mixed precision)
# NOTE(review): this loop was originally marked "IT DOES NOT WORK". The second
# backward pass through a graph whose weights had already been stepped has
# been removed, but the function has not been re-run on a GPU -- verify.
def train_loop_gpu(alpha, step, loader, dataset,
                   generator, style_generator, critic,
                   opt_generator, opt_s_generator, opt_critic, loss_fn,
                   gs_generator, gs_s_generator, gs_critic
                   ):
  """
  Mixed-precision variant of the training loop (no gradient penalty).

  gs_critic scales the critic loss. gs_generator scales the generator loss
  and steps BOTH opt_generator and opt_s_generator, because gradients can
  only be unscaled by the scaler that scaled the loss. gs_s_generator is
  kept for interface compatibility and is not used.

  Returns: the updated alpha.
  """
  for batch_idx, (x, _) in enumerate(loader):
    # Sample Z
    x = x.to(DEVICE)
    Z = torch.randn(x.shape[0], 512).to(DEVICE)

    # Train critic
    with torch.cuda.amp.autocast():
      W = style_generator(Z)
      fake = generator(W, step, alpha, x.shape[0])

      real_scores = critic(x, step, alpha)
      real_loss = loss_fn(real_scores, torch.ones_like(real_scores))

      # Detached so the critic update leaves the generator graph untouched.
      fake_scores = critic(fake.detach(), step, alpha)
      fake_loss = loss_fn(fake_scores, torch.zeros_like(fake_scores))

      loss_critic = (real_loss+fake_loss)/2

    if batch_idx % 100 == 0:
      print("Debug: ")
      print("Alpha: ", alpha)
      print("Critic loss:", loss_critic.item())

    opt_critic.zero_grad()
    gs_critic.scale(loss_critic).backward()
    gs_critic.step(opt_critic)
    gs_critic.update()

    # Train generator (synthesis + mapping network)
    with torch.cuda.amp.autocast():
      scores = critic(fake, step, alpha).reshape(-1)
      loss_generator = loss_fn(scores, torch.ones_like(scores))

    if batch_idx % 100 == 0:
      print("Generator loss:", loss_generator.item())

    # One backward pass feeds both optimizers; a second backward after
    # opt_generator has stepped would run through modified weights.
    opt_generator.zero_grad()
    opt_s_generator.zero_grad()
    gs_generator.scale(loss_generator).backward()
    gs_generator.step(opt_generator)
    gs_generator.step(opt_s_generator)
    gs_generator.update()

    alpha += x.shape[0] / ((x.shape[0]*0.5)*len(dataset))
    alpha = min(alpha, 1)

  return alpha