In [None]:
import math

import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

# Model Implementation

In [None]:
# c : content
# s : style
# Remove the content's own feature statistics from the content and apply the
# style's feature statistics instead.
def AdaIN(c, s, eps=1e-5):
  """Adaptive instance normalization.

  Each channel of ``c`` is normalized to zero mean and unit standard deviation
  over its spatial positions. It is then rescaled by the per-channel standard
  deviation of ``s`` and shifted by the per-channel mean of ``s``. Statistics
  are computed separately for every sample and channel.

  Args:
    c: Content features of shape (N, C, *spatial), e.g. (N, C, H, W).
    s: Style features of shape (N, C, *spatial'). The spatial sizes may
      differ from ``c``, but the rank must match.
    eps: Added to both variances before the square root, so constant
      channels (zero variance) do not cause a division by zero.

  Returns:
    A tensor with the same shape as ``c``.

  Raises:
    ValueError: If ``c`` or ``s`` has fewer than 3 dimensions (no spatial axes).
  """
  if c.dim() < 3 or s.dim() < 3:
    # With no spatial axes the reduction dims would be empty, and
    # ``mean(dim=())`` would silently reduce over *all* dimensions.
    raise ValueError("AdaIN expects tensors of shape (N, C, *spatial)")

  c_dims = tuple(range(2, c.dim()))
  s_dims = tuple(range(2, s.dim()))

  # keepdim=True leaves (N, C, 1, ...) so the statistics broadcast over space.
  c_mean = c.mean(dim=c_dims, keepdim=True)
  c_std = (c - c_mean).pow(2).mean(dim=c_dims, keepdim=True).add(eps).sqrt()
  s_mean = s.mean(dim=s_dims, keepdim=True)
  s_std = (s - s_mean).pow(2).mean(dim=s_dims, keepdim=True).add(eps).sqrt()

  return s_std * ((c - c_mean) / c_std) + s_mean
  
# Pixelwise feature vector normalization, used in the generator.
# Competition between the generator and the discriminator can drive signal
# magnitudes out of control. To prevent that, the feature vector at every
# pixel is normalized after each convolution. This has little effect on
# output quality, but it stops the signal magnitude from gradually escalating
# when that would otherwise happen.
class PixelNorm(nn.Module):
  """Rescale every position's dim-1 feature vector to unit root-mean-square.

  Computes ``x * rsqrt(mean(x**2, dim=1) + epsilon)``. The mean is taken over
  dimension 1 (the channel axis when ``x`` is NCHW) and broadcast back.

  Args:
    epsilon: Small constant added under the square root to avoid dividing by zero.
  """
  def __init__(self, epsilon=1e-8):
    super().__init__()
    self.epsilon = epsilon

  def forward(self, x):
    mean_sq = x.pow(2).mean(dim=1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + self.epsilon)
    return x * inv_rms

# Equalized linear layer from ProGAN (PGGAN / PGAN).
class EqualizedLinear(nn.Module):
  """Fully connected layer with run-time weight scaling ("equalized learning rate").

  Weights are stored as N(0, 1) / c and multiplied by c / sqrt(in_dim) on every
  forward pass. The bias is multiplied by c, so ``c`` acts as a learning-rate
  multiplier for the whole layer.

  Args:
    in_dim: Size of each input sample.
    out_dim: Size of each output sample.
    bias: If True, learn an additive bias.
    bias_init: Initial value of every bias entry.
    c: Learning-rate multiplier (see above).
    activation: If True, apply leaky ReLU (slope 0.2) to the output.
      Defaults to False, which makes the layer a plain affine map. The style
      modulation needs that, and the mapping network adds its own
      ``nn.LeakyReLU`` after each layer.
  """
  def __init__(self, in_dim, out_dim, bias=True, bias_init=0, c=1, activation=False):
    super(EqualizedLinear, self).__init__()
    self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div(c))

    if bias:
      self.bias = nn.Parameter(torch.full((out_dim,), float(bias_init)))
    else:
      self.bias = None

    # He-style constant applied at run time instead of at initialization.
    self.scale = (1 / math.sqrt(in_dim)) * c
    self.c = c
    self.activation = activation

  def forward(self, x):
    # The bias shares the learning-rate multiplier with the weight.
    bias = self.bias * self.c if self.bias is not None else None
    out = F.linear(x, self.weight * self.scale, bias)
    if self.activation:
      out = F.leaky_relu(out, negative_slope=0.2)
    return out

# The synthesis network's input has a fixed size.
# With the defaults it is a learned 1 x 512 x 4 x 4 tensor.
class ConstInput(nn.Module):
  """Learned constant tensor that is tiled across the batch.

  Args:
    channel: Number of channels of the constant.
    size: Spatial height and width of the constant.
  """
  def __init__(self, channel=512, size=4):
    super().__init__()
    self.input = nn.Parameter(torch.randn(1, channel, size, size))

  def forward(self, batch_size):
    """Return the constant repeated ``batch_size`` times along dim 0."""
    reps = (batch_size, 1, 1, 1)
    return self.input.repeat(*reps)


class SGConv2d(nn.Module):
  """Modulated (optionally demodulated) 2-D convolution in the style of StyleGAN2.

  One learned kernel is scaled per sample, per input channel, by a style
  vector (modulation). With ``demodulate=True``, each resulting output filter
  is renormalized to unit L2 norm. All per-sample kernels are applied in a
  single call by folding the batch into the channel axis and running a
  grouped convolution.

  Args:
    in_channel: Number of input channels.
    out_channel: Number of output channels.
    kernel_size: Size of the square kernel. Padding is ``kernel_size // 2``,
      so odd sizes preserve the spatial size.
    style_dim: Size of the style vectors passed to ``forward``.
    demodulate: Whether to renormalize each modulated filter.
    upsample: If True, the input is bilinearly upsampled by 2 before the convolution.
    blur_filter: Stored but currently unused. NOTE(review): StyleGAN2 applies
      this FIR filter around the upsampling step; not implemented here yet.
  """
  def __init__(self, in_channel, out_channel, kernel_size=3, style_dim=512, demodulate=True, upsample=False, blur_filter=(1, 3, 3, 1)):
    super(SGConv2d, self).__init__()
    self.eps = 1e-8
    self.kernel_size = kernel_size
    self.in_ch = in_channel
    self.out_ch = out_channel
    self.upsample = upsample
    self.blur_filter = tuple(blur_filter)

    if upsample:
      self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

    # Equalized-learning-rate scale for the conv weight.
    fan_in = in_channel * kernel_size ** 2
    self.scale = 1 / math.sqrt(fan_in)
    self.padding = kernel_size // 2

    # Leading singleton dim so the weight broadcasts against per-sample styles.
    self.weight = nn.Parameter(torch.randn(1, out_channel, in_channel, kernel_size, kernel_size))

    # bias_init=1 keeps the initial per-channel style scale around 1.
    self.modulation = EqualizedLinear(style_dim, in_channel, bias_init=1)
    self.demodulate = demodulate

  def forward(self, x, style):
    """Convolve ``x`` with kernels modulated by ``style``.

    Args:
      x: Input features of shape (batch, in_channel, H, W).
      style: Style vectors of shape (batch, style_dim).

    Returns:
      A tensor of shape (batch, out_channel, H', W'). For odd ``kernel_size``,
      H' = 2H and W' = 2W when ``upsample`` is set, otherwise H' = H and W' = W.
    """
    if self.upsample:
      # Upsample first so the convolution runs at the target resolution.
      x = self.up(x)
    batch, _, height, width = x.size()

    # Per-sample input-channel scales: (batch, 1, in_ch, 1, 1).
    style = self.modulation(style).view(batch, 1, self.in_ch, 1, 1)
    # (1, out, in, k, k) * (batch, 1, in, 1, 1) -> (batch, out, in, k, k)
    weight = self.scale * self.weight * style

    if self.demodulate:
      # Normalize each output filter of each sample to unit L2 norm.
      demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
      weight = weight * demod.view(batch, self.out_ch, 1, 1, 1)

    # Fold the batch into channels so one grouped conv applies each sample's kernel.
    weight = weight.view(batch * self.out_ch, self.in_ch, self.kernel_size, self.kernel_size)
    x = x.reshape(1, batch * self.in_ch, height, width)
    out = F.conv2d(x, weight, padding=self.padding, groups=batch)

    _, _, out_h, out_w = out.size()
    return out.view(batch, self.out_ch, out_h, out_w)
    




class StyleGAN(nn.Module):
  """StyleGAN generator (work in progress).

  Only the mapping network is wired into ``forward`` so far. The synthesis
  network's channel table and learned constant input are built here but are
  not used yet, so ``forward`` returns the mapped latent ``w``.

  Args:
    latent_dim: Size of the latent code ``z`` and of the style vector ``w``.
    label_dim: Size of an optional conditioning label. The label is
      concatenated to ``z`` along dim 1 before the mapping network. 0 means
      unconditional.
    n_mapping: Number of fully connected layers in the mapping network.
  """
  def __init__(self, latent_dim=512, label_dim=0, n_mapping=8):
    super(StyleGAN, self).__init__()
    self.latent_dim = latent_dim
    self.label_dim = label_dim

    # Mapping Network ---------------------------------------------
    # PixelNorm on the input, then n_mapping x (EqualizedLinear + LeakyReLU).
    self.layers = [PixelNorm()]
    in_dim = latent_dim + label_dim
    for _ in range(n_mapping):
      self.layers.append(EqualizedLinear(in_dim, latent_dim))
      self.layers.append(nn.LeakyReLU(0.2))
      in_dim = latent_dim

    # z (+ label) -> w (style)
    self.mapping = nn.Sequential(*self.layers)
    # -------------------------------------------------------------

    # Synthesis Network -------------------------------------------
    channel_multiplier = 2
    # Feature-map channels per output resolution.
    self.channels = {
      4: 512,
      8: 512,
      16: 512,
      32: 512,
      64: 256 * channel_multiplier,
      128: 128 * channel_multiplier,
      256: 64 * channel_multiplier,
      512: 32 * channel_multiplier,
      1024: 16 * channel_multiplier,
    }

    self.constInput = ConstInput()
    # TODO: build the per-resolution SGConv2d / toRGB layers.
    # -------------------------------------------------------------

  def forward(self, z, label=None):
    """Map a latent code, optionally with a label, to the style vector ``w``.

    Args:
      z: Latent codes of shape (batch, latent_dim).
      label: Optional conditioning vectors of shape (batch, label_dim).
        Required if and only if the model was built with ``label_dim > 0``.

    Returns:
      w: Output of the mapping network, shape (batch, latent_dim).

    Raises:
      ValueError: If ``label`` is given without ``label_dim``, or the reverse.
    """
    # Mapping Network ---------------------------------------------
    if label is None:
      if self.label_dim > 0:
        raise ValueError("label is required when label_dim > 0")
      x = z
    else:
      if self.label_dim == 0:
        raise ValueError("label given but model was built with label_dim=0")
      # Concatenate on the feature axis (dim 1), not the batch axis.
      x = torch.cat([z, label.to(z.dtype)], dim=1)

    w = self.mapping(x)
    # -------------------------------------------------------------

    # Synthesis Network -------------------------------------------
    # TODO: feed w through the synthesis network; until then return w.
    return w