In [1]:
from collections import OrderedDict as odict
import torch
import torch.nn as nn
import torch.nn.functional as nnf

# To Add:

Sequences of stuff (list comprehensions)

Tensor intro

Drawing with tensors

Math = Drawing

# Sequences

## Expansion and Reduction

# Tensors

# Tensors are Drawings

# Math is Drawings

In [2]:
class Settings:
    LatentDimension = 512

## Mapping Network

The mapping network is stated to be a nonlinear function:

$$f : Z \rightarrow W$$

The authors state that this function is implemented practically as a multilayer perceptron (MLP) with 8 layers and that both spaces $Z$ and $W$ are set to be 512-dimensional.

We could state this more explicitly as:

$$ Z, W \in \mathbb{R}^{512} $$

All that this means is that both $Z$ and $W$ are vectors of real numbers that have 512 entries ( `[1.1, 2.65, 3.141, ..., 6.022]` ).

### Multilayer Perceptron

But what, exactly, is a "multilayer perceptron"?

An MLP is a very simple kind of neural network that simply takes a vector input, multiplies it with a weight matrix to get another vector, and then repeats for some number of layers. Formally:

$$ x \in \mathbb{R}^{1 \times m} $$
$$ w \in \mathbb{R}^{m \times n} $$
$$ y \in \mathbb{R}^{1 \times n} $$

This is essentially just a vector, matrix product. If $m \gt n$ then the layer will be performing data reduction, if $m \lt n$ then it will be performing data expansion. Notably, if the weight matrix $w$ is square, $x$ and $y$ will be the same dimension, and this is what is happening in the Mapping Network. There is also no mention of a nonlinearity applied to the Mapping Network in the paper, so our construction in code is very straightforward.

In [21]:
class MappingNetwork(nn.Sequential):
    def __init__(self, layer_count=8, latent_dim=512):
        super(MappingNetwork, self).__init__()

        for layer_number in range(layer_count):
            layer_name = "linear_{}".format(layer_number)
            layer = nn.Linear(latent_dim, latent_dim)
            self.add_module(layer_name, layer)

# Math Note

$$ f \sim mn $$

# Synthesis Network

The authors' diagram of the Synthesis Network shows a repeating block of upsample, convolution, noise scaling/addition, and a function that they define called `AdaIN`.
$$ W \in \mathbb{R}^n $$
$$ Y \in \mathbb{R}^{2n} $$
$$ A : W \rightarrow Y $$

$Y$ can be thought of as a style space where the scalar components are parameters that control both how strongly feature maps in $x$ are carried forward, and how much it is shifted around the style space.

$$ AdaIN(x_i, y) = y_{s, i}\frac{x_i - \mu(x_i)}{\sigma(x_i)} + y_{b, i} $$

In [2]:
class A(nn.Module):
    def __init__(self, in_features, w_dim=512):
        super(A, self).__init__()
        self.affine = nn.Linear(w_dim, 2 * in_features)
    
    def forward(self, w):
        return self.affine(w).reshape(2, -1)
        

In [3]:
class B(nn.Module):
    def __init__(self, height, width, num_features):
        super(B, self).__init__()
        self.width = width
        self.height = height
        self.num_features = num_features
        
        self.noise_image = torch.randn(1, 1, height, width)
        
        self.scaling_factors = torch.nn.Parameter(data=torch.randn(1, num_features, 1, 1), requires_grad=True)
        
    def forward(self):
        return self.scaling_factors.expand(1, -1, self.height, self.width) * self.noise_image

In [4]:
class AdaIN(nn.Module):
    def __init__(self):
        super(AdaIN, self).__init__()
        
    def forward(self, x, y):
        mu_x    = x.mean(dim=(0, 2, 3)).reshape(1, -1, 1, 1)
        sigma_x = x.std(dim=(0, 2, 3)).reshape(1, -1, 1, 1)
        
        normed_x = (x - mu_x) / sigma_x
        
        y = y.reshape(2, -1, 1, 1)
        
        return (y[0, :] * x) + y[1, :]

# PixelNorm
From the [Progressive Growing of GANs paper](https://arxiv.org/pdf/1710.10196.pdf), section 4.2, the authors detail the per-pixel normalization function as:

$$ b_{x, y} = \frac{a_{x, y}}{\sqrt{\frac{1}{n}\Sigma_{j=0}^{n-1}{(a^{j}_{x, y})^2 + \epsilon}}} $$

where $\epsilon = 10^{-8}$ and $n$ is the number of feature maps.

In [5]:
class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm, self).__init__()
        self.epsilon = 10 ** -8
        
    def forward(self, x):
        n, c, h, w = x.shape
        d = x.pow(2)
        d = d.sum(dim=(1)) + self.epsilon
        d = d.mul(1 / c)
        d = d.sqrt()
        d = d.unsqueeze(1)

        return x / d

In [6]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        
        self.conv = nn.Conv2d(in_channels, out_channels, (3, 3), 1, 1)
        self.conv.weight.data.normal_(0, 1)
        self.conv.bias.data.fill_(0)
        
        self.norm = PixelNorm()
        self.act = nn.LeakyReLU(negative_slope=0.2)
    
    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        
        return x     
        

In [7]:
class SynthesisBlock(nn.Module):
    def __init__(self, in_channels, out_channels, height, width, w_dim=512):
        super(SynthesisBlock, self).__init__()
        
        self.upsample = nn.UpsamplingBilinear2d((height, width))
        self.conv0 =    ConvBlock(in_channels, out_channels)
        self.b0 =       B(height, width, out_channels)
        self.a0 =       A(out_channels, w_dim=w_dim)
        self.adain0 =   AdaIN()
        
        self.conv1 =    ConvBlock(out_channels, out_channels)
        self.b1 =       B(height, width, out_channels)
        self.a1 =       A(out_channels, w_dim=w_dim)
        self.adain1 =   AdaIN()
    
    def forward(self, tensor_dict):

        x = tensor_dict["x"]
        w = tensor_dict["w"]
        
        x = self.upsample(x)
        x = self.conv0(x)
        x = x + self.b0()
        y = self.a0(w)
        x = self.adain0(x, y)
        
        x = self.conv1(x)
        x = x + self.b1()
        y = self.a1(w)
        x = self.adain1(x, y)
        
        return {"x": x, "w": w}

In [14]:
class InputBlock(nn.Module):
    def __init__(self, in_channels, out_channels, height, width, w_dim=512):
        super(InputBlock, self).__init__()
        
        self.conv0 =    ConvBlock(in_channels, out_channels)
        self.b0 =       B(height, width, out_channels)
        self.a0 =       A(out_channels, w_dim=w_dim)
        self.adain0 =   AdaIN()
        
        self.b1 =       B(height, width, out_channels)
        self.a1 =       A(out_channels, w_dim=w_dim)
        self.adain1 =   AdaIN()
    
    def forward(self, tensor_dict):

        x = tensor_dict["x"]
        w = tensor_dict["w"]
        
        x = self.conv0(x)
        x = x + self.b0()
        y = self.a0(w)
        x = self.adain0(x, y)

        x = x + self.b1()
        y = self.a1(w)
        x = self.adain1(x, y)
        
        return {"x": x, "w": w}

In [15]:
class OutputBlock(nn.Module):
    def __init__(self, input_channels):
        super(OutputBlock, self).__init__()
        
        self.to_rgb = nn.Conv2d(input_channels, 3, (1, 1), 1, 1)
    
    def forward(self, tensor_dict):
        x = tensor_dict["x"]
        w = tensor_dict["w"]
        
        x = self.to_rgb(x)
        
        return {"x": x, "w": w}

In [16]:
class StyleGAN(nn.Module):
    def __init__(self, input_layer=None, layer_params=None, w_dim=512):
        super(StyleGAN, self).__init__()
        
        if input_layer == None:
            input_layer = InputBlock(512, 512, 4, 4)
        
        self.input = input_layer

        self.main = nn.Sequential()

        if layer_params == None:
            layer_params = [
                (512, 512,    8,    8, w_dim),
                (512, 512,   16,   16, w_dim),
                (512, 512,   32,   32, w_dim),
                (512, 256,   64,   64, w_dim),
                (256, 128,  128,  128, w_dim),
                (128,  64,  256,  256, w_dim),
                ( 64,  32,  512,  512, w_dim),
                ( 32,  16, 1024, 1024, w_dim),
            ]
        
        [self.main.add_module("sb_{}".format(n), SynthesisBlock(*p)) for n, p in enumerate(layer_params)]
    
        final_out_channels = layer_params[-1][1]
    
        self.output = OutputBlock(final_out_channels)
    
    def forward(self, tensor_dict):
        tensor_dict = self.input(tensor_dict)
        tensor_dict = self.main(tensor_dict)

        return self.output(tensor_dict)

In [22]:
mn = MappingNetwork()
sg = StyleGAN()

print(mn)
print(sg)

MappingNetwork(
  (linear_0): Linear(in_features=512, out_features=512, bias=True)
  (linear_1): Linear(in_features=512, out_features=512, bias=True)
  (linear_2): Linear(in_features=512, out_features=512, bias=True)
  (linear_3): Linear(in_features=512, out_features=512, bias=True)
  (linear_4): Linear(in_features=512, out_features=512, bias=True)
  (linear_5): Linear(in_features=512, out_features=512, bias=True)
  (linear_6): Linear(in_features=512, out_features=512, bias=True)
  (linear_7): Linear(in_features=512, out_features=512, bias=True)
)
StyleGAN(
  (input): InputBlock(
    (conv0): ConvBlock(
      (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): PixelNorm()
      (act): LeakyReLU(negative_slope=0.2)
    )
    (b0): B()
    (a0): A(
      (affine): Linear(in_features=512, out_features=1024, bias=True)
    )
    (adain0): AdaIN()
    (b1): B()
    (a1): A(
      (affine): Linear(in_features=512, out_features=1024, bias=True)
    )
    

In [23]:
x = torch.randn(4, 512, 4, 4, requires_grad=True)
z = torch.randn(1, 512, requires_grad=True)

w = mn(z)


tensor_dict = {"x": x, "w": w}

sg(tensor_dict)

{'x': tensor([[[[-0.0869, -0.0869, -0.0869,  ..., -0.0869, -0.0869, -0.0869],
           [-0.0869, -0.0646, -0.0790,  ..., -0.0559, -0.0668, -0.0869],
           [-0.0869, -0.0994, -0.0468,  ..., -0.0623, -0.0438, -0.0869],
           ...,
           [-0.0869, -0.0850, -0.0785,  ..., -0.0446, -0.0691, -0.0869],
           [-0.0869, -0.0574, -0.0783,  ..., -0.0681, -0.0553, -0.0869],
           [-0.0869, -0.0869, -0.0869,  ..., -0.0869, -0.0869, -0.0869]],
 
          [[-0.1199, -0.1199, -0.1199,  ..., -0.1199, -0.1199, -0.1199],
           [-0.1199, -0.1049, -0.1184,  ..., -0.1014, -0.1031, -0.1199],
           [-0.1199, -0.1421, -0.1172,  ..., -0.1218, -0.0909, -0.1199],
           ...,
           [-0.1199, -0.1453, -0.1282,  ..., -0.1063, -0.1203, -0.1199],
           [-0.1199, -0.1064, -0.1143,  ..., -0.1098, -0.1183, -0.1199],
           [-0.1199, -0.1199, -0.1199,  ..., -0.1199, -0.1199, -0.1199]],
 
          [[ 0.2024,  0.2024,  0.2024,  ...,  0.2024,  0.2024,  0.2024],
        

In [19]:
params = list(sg.parameters())

In [20]:
pc = 0
for p in params:
    s = list(p.shape)
    c = 1
    for n in s:
        c *= n
    pc += c

pc

24102067