# StyleGAN from scratch

## Reference

https://github.com/SiskonEmilia/StyleGAN-PyTorch

In [24]:
import sys
import math
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from tqdm import tqdm
from PIL import Image
from torchvision import datasets, utils
from torch.utils.data import Dataset, DataLoader, dataloader
from torch.multiprocessing import reductions
from multiprocessing.reduction import ForkingPickler

%matplotlib inline

In [3]:
# Keep a handle to PyTorch's stock collate function so the override below can
# delegate to it after adjusting the shared-memory flag.
default_collate_fun = dataloader.default_collate


def default_collate_fun_override(batch):
    """Collate ``batch`` with the stock collate function, shared memory disabled.

    Sets ``torch.utils.data.dataloader._use_shared_memory`` to ``False`` and then
    delegates to the saved ``default_collate_fun``. The previous code set the
    misspelled ``_use_share_memory``, an attribute nothing reads.
    NOTE(review): the upstream reference sets ``_use_shared_memory``; newer
    PyTorch releases may no longer consult this flag at all — confirm against
    the installed version.
    """
    dataloader._use_shared_memory = False
    return default_collate_fun(batch)

In [4]:
# Install the override on the dataloader module. NOTE(review): recent PyTorch
# DataLoader resolves its default collate from torch.utils.data._utils.collate,
# so this module-level patch may have no effect there — confirm.
dataloader.default_collate = default_collate_fun_override

# Remove any custom ForkingPickler reducers registered for torch storage
# classes, so pickling a storage no longer goes through those reducers
# (presumably the shared-memory path installed by torch.multiprocessing).
for storage_cls in torch._storage_classes:
    ForkingPickler._extra_reducers.pop(storage_cls, None)


In [5]:
# Scale weights at runtime with the He (Kaiming) normal constant sqrt(2 / fan_in).

class ScaleW:
    """Forward pre-hook that rescales one parameter of a module on every call.

    ``apply`` moves ``module.<name>`` to a trainable ``module.<name>_orig``
    parameter and registers this object as a forward pre-hook. Before each
    forward pass the hook writes ``<name>_orig * sqrt(2 / fan_in)`` back to
    ``module.<name>`` as a plain tensor attribute.
    """

    def __init__(self, name):
        # Name of the attribute to rescale, e.g. 'weight'.
        self.name = name

    def scale(self, module):
        """Return ``module.<name>_orig`` multiplied by sqrt(2 / fan_in)."""
        weight = getattr(module, self.name + '_orig')

        # fan_in = (number of input channels) * (elements per input channel):
        # dim 1 is the input dimension and weight[0][0] is one kernel slice
        # (a scalar for Linear, a k x k kernel for Conv2d).
        n_in = weight.data.size(1) * weight.data[0][0].numel()

        return weight * math.sqrt(2. / n_in)

    @staticmethod
    def apply(module, name):
        """Attach runtime scaling of ``module.<name>`` to ``module``.

        ``module.<name>`` is removed from the module's parameters and is only
        available again (as a scaled tensor) after the first forward pass.
        """
        hook = ScaleW(name)
        weight = getattr(module, name)
        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
        del module._parameters[name]
        module.register_forward_pre_hook(hook)

    def __call__(self, module, inputs=None):
        """Forward pre-hook: refresh ``module.<name>`` from the scaled original.

        PyTorch invokes forward pre-hooks as ``hook(module, inputs)``; the
        previous one-argument signature raised a TypeError on every forward.
        ``inputs`` is ignored, and returning None leaves them unchanged.
        """
        weight = self.scale(module)
        setattr(module, self.name, weight)

In [6]:
# Shortcut for attaching ScaleW runtime weight scaling to a module.

def quick_scale(module, name='weight'):
    """Register ``ScaleW`` on ``module.<name>`` in place and hand the module back.

    Returning the module allows construction and scaling in one expression,
    e.g. ``self.linear = quick_scale(linear)``.
    """
    scaled_module = module
    ScaleW.apply(scaled_module, name=name)
    return scaled_module

In [7]:
class ScaledLinear(nn.Module):
    """Fully connected layer with runtime weight scaling.

    The weight is drawn from N(0, 1), the bias starts at zero, and
    ``quick_scale`` rescales the weight by sqrt(2 / fan_in) on every forward.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()

        layer = nn.Linear(dim_in, dim_out)
        # Zero bias and unit-normal weight; the He constant is applied at runtime.
        layer.bias.data.zero_()
        layer.weight.data.normal_()

        self.linear = quick_scale(layer)

    def forward(self, x):
        out = self.linear(x)
        return out

In [8]:
class ScaledConv2d(nn.Module):
    """``nn.Conv2d`` with runtime weight scaling.

    Accepts the same arguments as ``nn.Conv2d``. The weight is drawn from
    N(0, 1), the bias (if any) starts at zero, and ``quick_scale`` rescales
    the weight by sqrt(2 / fan_in) on every forward.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

        conv = nn.Conv2d(*args, **kwargs)
        conv.weight.data.normal_()
        # With bias=False, conv.bias is None; only zero it when it exists.
        if conv.bias is not None:
            conv.bias.data.zero_()

        self.conv = quick_scale(conv)

    def forward(self, x):
        return self.conv(x)

In [9]:
# Normalize each vector along dim 1 (channels / latent features) to unit RMS.
class PixelNorm(nn.Module):
    """Compute ``x / sqrt(mean(x ** 2, dim=1) + eps)``.

    No mean is subtracted, so every vector along dim 1 ends up with a mean
    square of approximately 1. Works for (batch, features) latents as well as
    (batch, channels, H, W) feature maps.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        # Added to the mean square to avoid division by zero.
        self.eps = eps

    def forward(self, x):
        # keepdim=True so the RMS broadcasts back over dim 1 (the previous
        # `keep_dim` keyword does not exist and raised a TypeError).
        sigma = torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.eps)
        return x / sigma

## Learned Affine Transform `A`

<img src="https://img1.daumcdn.net/thumb/R1280x0/?scode=mtistory2&fname=https%3A%2F%2Fblog.kakaocdn.net%2Fdn%2Flcwl8%2FbtrZtqZJohK%2F2MdeTwsRpXSazVdKdzHB1K%2Fimg.png"></img>

<img src="https://img1.daumcdn.net/thumb/R1280x0/?scode=mtistory2&fname=https%3A%2F%2Fblog.kakaocdn.net%2Fdn%2FDDNSS%2FbtrZD2cAUXE%2FN7OUemuCGUwCgVMssVw1yk%2Fimg.jpg"></img>

In [10]:
# Learned affine transform "A": maps an intermediate latent w to the AdaIN
# style parameters of one layer.
# Input:  w whose last dimension is dim_latent, e.g. (batch, dim_latent).
# Output: style = [y_s, y_b] with two trailing singleton dims, e.g.
#         (batch, 2 * n_channel, 1, 1); along dim 1 the first n_channel
#         entries are the per-channel scales y_s, the rest the biases y_b.

class LearnedA(nn.Module):
    def __init__(self, dim_latent, n_channel):
        super().__init__()
        self.transform = ScaledLinear(dim_latent, n_channel * 2)
        # Bias the scale half towards 1 and the shift half towards 0.
        bias = self.transform.linear.bias.data
        bias[:n_channel] = 1
        bias[n_channel:] = 0

    def forward(self, w):
        # Append two singleton dims so the style broadcasts over H and W.
        return self.transform(w)[:, :, None, None]

In [11]:
class AdaIN(nn.Module):
    """Adaptive instance normalization.

    Instance-normalizes ``image`` and applies the per-channel affine transform
    carried by ``style``: the first half of ``style`` along dim 1 is the
    scale, the second half the shift.
    """

    def __init__(self, n_channel):
        super().__init__()
        self.norm = nn.InstanceNorm2d(n_channel)

    def forward(self, image, style):
        gain, shift = torch.chunk(style, 2, dim=1)
        normalized = self.norm(image)
        return gain * normalized + shift

## Adding Noise

<img src="https://img1.daumcdn.net/thumb/R1280x0/?scode=mtistory2&fname=https%3A%2F%2Fblog.kakaocdn.net%2Fdn%2FdQh8aw%2FbtrZD2RepLK%2FRvLKDB7gvR9y2QCYCmuLNK%2Fimg.jpg"></img>

In [12]:
class LearnedB(nn.Module):
    """Learned per-channel scaling "B" applied to injected noise.

    ``weight`` has shape (1, n_channel, 1, 1) and starts at zero, so a noise
    map with a singleton channel dim, e.g. (batch, 1, H, W), broadcasts to
    (batch, n_channel, H, W) and contributes nothing until training moves it.
    """

    def __init__(self, n_channel):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, n_channel, 1, 1))

    def forward(self, noise):
        return torch.mul(noise, self.weight)

## Synthesis network

<img src="https://img1.daumcdn.net/thumb/R1280x0/?scode=mtistory2&fname=https%3A%2F%2Fblog.kakaocdn.net%2Fdn%2FdjieWG%2FbtrZAtCfadE%2F1LvbQdbBYdRMe4qQ6qGzM0%2Fimg.jpg"></img>

### First Convolution Block

In [13]:
class FirstSynthesisBlock(nn.Module):
    """First (lowest-resolution) block of the synthesis network.

    Starts from a learned constant of shape (1, n_channel, dim_input, dim_input)
    instead of an upsampled feature map, then runs
    noise -> AdaIN -> LeakyReLU, followed by conv3x3 -> noise -> AdaIN -> LeakyReLU.
    """

    def __init__(self, n_channel, dim_latent, dim_input):
        super().__init__()

        # Learned seed, repeated for every sample in the batch.
        self.constant = nn.Parameter(torch.randn(1, n_channel, dim_input, dim_input))

        self.style1 = LearnedA(dim_latent, n_channel)
        self.style2 = LearnedA(dim_latent, n_channel)

        self.noise1 = quick_scale(LearnedB(n_channel))
        self.noise2 = quick_scale(LearnedB(n_channel))

        self.ada_in = AdaIN(n_channel)
        self.leaky_relu = nn.LeakyReLU(0.2)
        # nn.Conv2d expects `kernel_size`; the previous `kernel=3` raised a TypeError.
        self.conv = ScaledConv2d(n_channel, n_channel, kernel_size=3, padding=1)

    def forward(self, w, noise):
        """Return the block's feature map for latent ``w``.

        ``noise`` presumably has shape (batch, 1, dim_input, dim_input); its
        first dimension sets how often the constant is repeated.
        """
        # Step 1: constant seed + noise, then style modulation.
        seed = self.constant.repeat(noise.shape[0], 1, 1, 1)
        result = seed + self.noise1(noise)
        style1 = self.style1(w)
        result = self.ada_in(result, style1)
        result = self.leaky_relu(result)

        # Step 2: convolution + noise, then style modulation.
        result = self.conv(result)
        result = result + self.noise2(noise)
        style2 = self.style2(w)
        result = self.ada_in(result, style2)
        result = self.leaky_relu(result)

        return result

In [14]:
class SynthesisBlock(nn.Module):
    """Synthesis block applied after upsampling the previous block's output.

    Runs conv3x3 -> noise -> AdaIN -> LeakyReLU twice; the first convolution
    maps ``in_channel`` to ``out_channel``.
    """

    def __init__(self, in_channel, out_channel, dim_latent):
        super().__init__()

        self.style1 = LearnedA(dim_latent, out_channel)
        self.style2 = LearnedA(dim_latent, out_channel)

        self.noise1 = quick_scale(LearnedB(out_channel))
        self.noise2 = quick_scale(LearnedB(out_channel))

        self.ada_in = AdaIN(out_channel)
        self.leaky_relu = nn.LeakyReLU(0.2)
        # nn.Conv2d expects `kernel_size`; the previous `kernel=3` raised a TypeError.
        self.conv1 = ScaledConv2d(in_channel, out_channel, kernel_size=3, padding=1)
        self.conv2 = ScaledConv2d(out_channel, out_channel, kernel_size=3, padding=1)

    def forward(self, input, w, noise):
        """Process ``input`` (already upsampled by the caller) with latent ``w``.

        ``noise`` presumably has shape (batch, 1, H, W) matching ``input``'s
        spatial size.
        """
        # Step 1
        result = self.conv1(input)
        result = result + self.noise1(noise)
        style1 = self.style1(w)
        result = self.ada_in(result, style1)
        result = self.leaky_relu(result)

        # Step 2
        result = self.conv2(result)
        result = result + self.noise2(noise)
        style2 = self.style2(w)
        result = self.ada_in(result, style2)
        result = self.leaky_relu(result)

        return result

## Mapping Network

<img src="https://img1.daumcdn.net/thumb/R1280x0/?scode=mtistory2&fname=https%3A%2F%2Fblog.kakaocdn.net%2Fdn%2FbxAGqg%2FbtrZCpTqDQo%2FETgEG2hnVdrbk8LyBmjLXK%2Fimg.jpg"></img>

In [15]:
class MappingNetwork(nn.Module):
    """Mapping network f: z -> w.

    PixelNorm on the input latent followed by ``n_fc`` repetitions of
    ScaledLinear(dim_latent, dim_latent) + LeakyReLU(0.2).
    """

    def __init__(self, n_fc, dim_latent):
        super().__init__()

        hidden = []
        for _ in range(n_fc):
            hidden += [ScaledLinear(dim_latent, dim_latent), nn.LeakyReLU(0.2)]

        self.mapping = nn.Sequential(PixelNorm(), *hidden)

    def forward(self, z):
        return self.mapping(z)

## Generator

In [16]:
class Generator(nn.Module):
    """Style-based generator: mapping network + 9 progressively growing blocks.

    Block 0 works at ``dim_input`` x ``dim_input``; each later block runs on
    the 2x-upsampled output of the previous one. ``to_rgbs[i]`` converts the
    output of block ``i`` to a 3-channel image.
    """

    def __init__(self, n_fc, dim_latent, dim_input):
        super().__init__()

        # MappingNetwork needs the latent width too (was MappingNetwork(n_fc)).
        self.mapping_net = MappingNetwork(n_fc, dim_latent)

        in_channels = [512, 512, 512, 512, 256, 128, 64, 32]
        out_channels = [512, 512, 512, 256, 128, 64, 32, 16]

        # 9 synthesis blocks (1 + 8); block i >= 1 maps in_channels[i - 1]
        # to out_channels[i - 1].
        self.convs = nn.ModuleList([
            FirstSynthesisBlock(512, dim_latent, dim_input),
            *(SynthesisBlock(c_in, c_out, dim_latent)
              for c_in, c_out in zip(in_channels, out_channels)),
        ])

        # One to-RGB layer per synthesis block, matching that block's output
        # channels (512 for block 0). Previously there were only 8, offset by one.
        self.to_rgbs = nn.ModuleList([
            ScaledConv2d(channels, 3, kernel_size=1) for channels in [512, *out_channels]
        ])

    def _sample_noise(self, w, step):
        """Sample N(0, 1) noise maps (batch, 1, size, size) for blocks 0..step."""
        base = self.convs[0].constant.size(-1)
        n_blocks = min(max(step, 0) + 1, len(self.convs))
        return [
            torch.randn(w.size(0), 1, base * 2 ** i, base * 2 ** i, device=w.device, dtype=w.dtype)
            for i in range(n_blocks)
        ]

    def forward(self, z, step=0, alpha=-1, noise=None, mix_steps=()):
        """Generate images at resolution ``dim_input * 2 ** step``.

        The former unused leading ``x`` parameter was removed: every caller
        passes ``(z, step, alpha, noise)``.

        Args:
            z: a latent tensor, or a list of latent tensors, each presumably
                (batch, dim_latent). Blocks whose index is in ``mix_steps``
                use ``z[1]``; all other blocks use ``z[0]``.
            step: index of the last synthesis block to run.
            alpha: fade-in weight; when ``0 <= alpha < 1`` the output is blended
                with the previous block's RGB output, otherwise no blending.
            noise: list of per-block noise maps, ``noise[i]`` shaped
                (batch, 1, H_i, W_i); sampled from N(0, 1) when None.
            mix_steps: block indices that use the second latent.

        Returns:
            RGB tensor of shape (batch, 3, H, W).
        """
        if not isinstance(z, (list, tuple)):
            z = [z]

        w = [self.mapping_net(latent_z) for latent_z in z]

        if noise is None:
            noise = self._sample_noise(w[0], step)

        result = None
        result_upsampled = None

        for i, conv in enumerate(self.convs):
            current_latent = w[1] if i in mix_steps else w[0]

            if i == 0:
                result = conv(current_latent, noise[i])
            else:
                # Upsample the previous block's output before running this block.
                result_upsampled = nn.functional.interpolate(
                    result, scale_factor=2, mode='bilinear', align_corners=False)
                result = conv(result_upsampled, current_latent, noise[i])

            # Last active block: convert to RGB and stop.
            if i == step:
                result = self.to_rgbs[i](result)

                # Fade-in: blend with the previous resolution's RGB output.
                if i > 0 and 0 <= alpha < 1:
                    result_prev = self.to_rgbs[i - 1](result_upsampled)
                    result = alpha * result + (1 - alpha) * result_prev

                break

        return result

## Discriminator

In [18]:
class DiscriminatorConvBlock(nn.Module):
    """Two scaled convolutions, each followed by LeakyReLU(0.2).

    The second convolution reuses the first one's kernel size and padding
    unless ``size_kernel2`` / ``padding2`` are given.
    """

    def __init__(self, in_channel, out_channel, size_kernel1, padding1, size_kernel2=None, padding2=None):
        super().__init__()

        if size_kernel2 is None:
            size_kernel2 = size_kernel1

        if padding2 is None:
            padding2 = padding1

        self.convs = nn.Sequential(
            ScaledConv2d(in_channel, out_channel, size_kernel1, padding=padding1),
            nn.LeakyReLU(0.2),
            ScaledConv2d(out_channel, out_channel, size_kernel2, padding=padding2),
            nn.LeakyReLU(0.2),
        )

    def forward(self, image):
        # The layers live in `self.convs`; the old `self.conv` raised AttributeError.
        return self.convs(image)

In [19]:
class Discriminator(nn.Module):
    """Progressive discriminator mirroring the generator.

    Layer ``n_layer - 1`` is the final block: it appends the minibatch-stddev
    channel and, assuming a 4 x 4 input, reduces it to 1 x 1 with its 4x4
    convolution. Each smaller layer index handles twice the resolution of the
    next one.
    """

    def __init__(self):
        super().__init__()

        # Channel widths from the highest resolution (index 0) downwards.
        in_channels = [16, 32, 64, 128, 256, 512, 512, 512]
        out_channels = [32, 64, 128, 256, 512, 512, 512, 512]

        # from_rgbs[k] maps an RGB image to the input channels of convs[k].
        self.from_rgbs = nn.ModuleList([
            ScaledConv2d(3, c_in, kernel_size=1) for c_in in in_channels
        ] + [ScaledConv2d(3, 512, 1)])

        self.convs = nn.ModuleList([
            DiscriminatorConvBlock(c_in, c_out, size_kernel1=3, padding1=1)
            for c_in, c_out in zip(in_channels, out_channels)
        ] + [DiscriminatorConvBlock(512 + 1, 512, 3, 1, 4, 0)])

        self.fc = ScaledLinear(512, 1)

        self.n_layer = 9

    def forward(self, image, step=0, alpha=-1):
        """Score ``image``, presumably (batch, 3, 4 * 2 ** step, 4 * 2 ** step).

        Walks from the highest active resolution (``i == step``) down to the
        final block (``i == 0``). The previous loop ran ``range(step)``
        upwards, so it never applied the final block or the minibatch stddev
        and always used the highest-resolution ``from_rgb``.

        Returns:
            Tensor of shape (batch, 1) with unbounded scores.
        """
        result = None
        for i in range(step, -1, -1):
            layer_index = self.n_layer - i - 1

            # Entry point: convert the RGB image to feature maps.
            if i == step:
                result = self.from_rgbs[layer_index](image)

            # Final block: append the mean minibatch standard deviation as one
            # extra channel. unbiased=False avoids NaN for a batch of size 1.
            if i == 0:
                res_var = result.var(dim=0, unbiased=False) + 1e-8
                res_std = torch.sqrt(res_var)
                mean_std = res_std.mean().expand(result.size(0), 1, result.size(2), result.size(3))
                result = torch.cat([result, mean_std], dim=1)

            result = self.convs[layer_index](result)

            if i > 0:
                result = nn.functional.interpolate(result, scale_factor=0.5, mode='bilinear', align_corners=False)

                # Fade-in: blend with the image routed through the next
                # (lower-resolution) from_rgb.
                if i == step and 0 <= alpha < 1:
                    result_next = self.from_rgbs[layer_index + 1](image)
                    result_next = nn.functional.interpolate(result_next, scale_factor=0.5, mode='bilinear', align_corners=False)
                    result = alpha * result + (1 - alpha) * result_next

        # (batch, 512, 1, 1) -> (batch, 512)
        result = result.flatten(1)
        return self.fc(result)

## Settings

In [20]:
# Select the idle GPUs to use. Setting CUDA_VISIBLE_DEVICES in the environment
# before launching is preferable; assigning it here only takes effect if CUDA
# has not been initialised in this process yet.
# If you use a different number of GPUs, adjust the hyperparameters as well.
# Requires PyTorch >= 1.0.1.
import os
# Comma-separated with no spaces ('1,2', not '1, 2'): whitespace is not part
# of the documented format. n_gpu is derived from it so the two stay in sync.
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'
n_gpu             = len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))
device            = torch.device('cuda:0')

learning_rate     = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
batch_size_1gpu   = {4: 128, 8: 128, 16: 64, 32: 32, 64: 16, 128: 16}
mini_batch_size_1 = 8
batch_size        = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}
mini_batch_size   = 8
batch_size_4gpus  = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}
mini_batch_size_4 = 16
batch_size_8gpus  = {4: 512, 8: 256, 16: 128, 32: 64}
mini_batch_size_8 = 32
n_fc              = 8
dim_latent        = 512
dim_input         = 4
n_sample          = 120000
DGR               = 1
n_show_loss       = 40
step              = 1 # Train from (8 * 8)
max_step          = 8 # Maximum step (8 for 1024^2)
style_mixing      = [] # Style mixing is not wired up yet
image_folder_path = './dataset/'
save_folder_path  = './results/'

low_steps         = [0, 1, 2]
# style_mixing    += low_steps
mid_steps         = [3, 4, 5]
# style_mixing    += mid_steps
hig_steps         = [6, 7, 8]
# style_mixing    += hig_steps

# Used to continue training from the last checkpoint
startpoint        = 0
used_sample       = 0
alpha             = 0

# Mode: evaluate or train?
is_train          = True

# How to start training:
# True  -> resume from the saved model
# False -> train from scratch
is_continue       = True
d_losses          = [float('inf')]
g_losses          = [float('inf')]

inputs, outputs = [], []

In [21]:
# Toggle whether a module's parameters take part in gradient computation.
def set_grad_flag(module, flag):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``module``."""
    module.requires_grad_(flag)

# Reset the learning rate of every parameter group.
def reset_LR(optimizer, lr):
    """Set each param group's ``lr`` to ``lr`` times the group's ``mul`` factor (default 1)."""
    for group in optimizer.param_groups:
        group['lr'] = lr * group.get('mul', 1)

In [22]:
# Build a shuffled DataLoader that serves `dataset` at `image_size` x `image_size`.
def gain_sample(dataset, batch_size, image_size=4):
    """Attach the training transform to ``dataset`` and wrap it in a DataLoader.

    The transform resizes the shorter side to ``image_size``, center-crops a
    square, randomly flips horizontally, converts to a tensor and normalizes
    each RGB channel with mean 0.5 and std 0.5. ``dataset.transform`` is
    overwritten in place.
    """
    pipeline = [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    dataset.transform = transforms.Compose(pipeline)

    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)

In [23]:
def imshow(tensor, i):
    """Save and display the first image of ``tensor`` as ``sample-iter{i}.png``.

    Assumes ``tensor`` is a batch shaped (batch, 3, H, W) with values in
    [-1, 1]; values are clamped to that range first. The file is written into
    ``save_folder_path``, which is created if missing.
    """
    import os

    # Out-of-place ops: `.cpu()` on a tensor already on the CPU returns the
    # same storage, so the previous in-place ops mutated the caller's data.
    image = tensor[0].detach().clamp(-1, 1).add(1).div(2)
    # Add 0.5 after scaling to [0, 255] so the uint8 cast rounds to nearest.
    ndarr = image.mul(255).add(0.5).clamp(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    img = Image.fromarray(ndarr)

    if save_folder_path:
        os.makedirs(save_folder_path, exist_ok=True)
    img.save(f'{save_folder_path}sample-iter{i}.png')

    plt.imshow(img)
    plt.show()

In [25]:
def train(generator, discriminator, g_optim, d_optim, dataset, step, start_point=0, used_sample=0, d_losses=None, g_losses=None, alpha=0):
    """Progressively train ``generator`` and ``discriminator`` on ``dataset``.

    Training starts at resolution ``4 * 2 ** step``. It moves to the next step
    once more than ``2 * n_sample`` images have been seen at the current one
    (up to ``max_step``), while ``alpha`` ramps the fade-in from 0 to 1.

    Each iteration does one discriminator update and then one generator
    update:
    - Discriminator: softplus logistic losses on real and fake images, plus a
      gradient penalty on real images with weight 10 / 2.
    - Generator: non-saturating loss.

    Every 1000 iterations a checkpoint is written to
    ``checkpoint/trained.pth``.

    Reads the module-level settings ``batch_size``, ``mini_batch_size``,
    ``learning_rate``, ``n_sample``, ``max_step``, ``n_gpu``, ``device``,
    ``dim_latent`` and ``n_show_loss``.

    Args:
        d_losses, g_losses: lists that sampled losses are appended to. When
            None, fresh lists are created (the previous mutable defaults were
            shared across calls).

    Returns:
        (d_losses, g_losses)
    """
    import os

    d_losses = [] if d_losses is None else d_losses
    g_losses = [] if g_losses is None else g_losses

    resolution = 4 * 2 ** step

    origin_loader = gain_sample(dataset, batch_size.get(resolution, mini_batch_size), resolution)
    data_loader = iter(origin_loader)

    reset_LR(g_optim, learning_rate.get(resolution, 0.001))
    reset_LR(d_optim, learning_rate.get(resolution, 0.001))

    progress_bar = tqdm(range(start_point + 1, n_sample * 5))

    for i in progress_bar:
        alpha = min(1, alpha + batch_size.get(resolution, mini_batch_size) / (n_sample * 2))

        # Grow to the next resolution once enough images have been seen.
        if used_sample > n_sample * 2 and step < max_step:
            step += 1

            alpha = 0
            used_sample = 0

            resolution = 4 * 2 ** step

            # Rebuild the loader for the new resolution / batch size.
            del origin_loader
            del data_loader

            origin_loader = gain_sample(dataset, batch_size.get(resolution, mini_batch_size), resolution)
            data_loader = iter(origin_loader)

            reset_LR(g_optim, learning_rate.get(resolution, 0.001))
            reset_LR(d_optim, learning_rate.get(resolution, 0.001))

        # Batch size for the generated latents and noise at this resolution.
        n_batch = batch_size.get(resolution, mini_batch_size)

        try:
            real_image, _ = next(data_loader)
        except (OSError, StopIteration):
            # Dataset exhausted: start another pass from the first image.
            data_loader = iter(origin_loader)
            real_image, _ = next(data_loader)

        used_sample += real_image.shape[0]

        real_image = real_image.to(device)

        # ---- Discriminator step ----
        discriminator.zero_grad()
        set_grad_flag(discriminator, True)
        set_grad_flag(generator, False)

        # Required by the gradient penalty below (was misspelled
        # `required_grad`, which silently set an unused attribute).
        real_image.requires_grad = True
        if n_gpu > 1:
            real_predict = nn.parallel.data_parallel(discriminator, (real_image, step, alpha), range(n_gpu))
        else:
            real_predict = discriminator(real_image, step, alpha)

        # Real images should get high scores: minimise softplus(-D(x)).
        real_predict = nn.functional.softplus(-real_predict).mean()
        # Keep the graph: the gradient penalty differentiates through it again.
        real_predict.backward(retain_graph=True)

        grad_real = torch.autograd.grad(outputs=real_predict.sum(), inputs=real_image, create_graph=True)[0]
        grad_penalty_real = (grad_real.reshape(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean()
        grad_penalty_real = 10 / 2 * grad_penalty_real
        grad_penalty_real.backward()

        # Latents for the D step (w1) and the G step (w2).
        w1 = [torch.randn((n_batch, dim_latent), device=device)]
        w2 = [torch.randn((n_batch, dim_latent), device=device)]

        noise_1 = []
        noise_2 = []

        for m in range(step + 1):
            size = 4 * 2 ** m  # noise grows with each upsampling stage
            noise_1.append(torch.randn((n_batch, 1, size, size), device=device))
            noise_2.append(torch.randn((n_batch, 1, size, size), device=device))

        if n_gpu > 1:
            fake_image = nn.parallel.data_parallel(generator, (w1, step, alpha, noise_1), range(n_gpu))
            fake_predict = nn.parallel.data_parallel(discriminator, (fake_image, step, alpha), range(n_gpu))
        else:
            fake_image = generator(w1, step, alpha, noise_1)
            fake_predict = discriminator(fake_image, step, alpha)

        # Fake images should get low scores: minimise softplus(D(G(z))).
        fake_predict = nn.functional.softplus(fake_predict).mean()
        fake_predict.backward()

        if i % n_show_loss == 0:
            d_losses.append((real_predict + fake_predict).item())

        d_optim.step()

        del grad_penalty_real, grad_real, fake_predict, real_predict, w1

        # ---- Generator step ----
        generator.zero_grad()
        set_grad_flag(discriminator, False)
        set_grad_flag(generator, True)

        if n_gpu > 1:
            fake_image = nn.parallel.data_parallel(generator, (w2, step, alpha, noise_2), range(n_gpu))
            fake_predict = nn.parallel.data_parallel(discriminator, (fake_image, step, alpha), range(n_gpu))
        else:
            fake_image = generator(w2, step, alpha, noise_2)
            fake_predict = discriminator(fake_image, step, alpha)

        # Non-saturating generator loss: minimise softplus(-D(G(z))). The
        # previous softplus(+D(G(z))) optimised the discriminator's objective.
        fake_predict = nn.functional.softplus(-fake_predict).mean()
        fake_predict.backward()
        g_optim.step()

        if i % n_show_loss == 0:
            g_losses.append(fake_predict.item())
            imshow(fake_image.data.cpu(), i)

        del fake_predict, fake_image, w2

        if (i + 1) % 1000 == 0:
            # Save the model every 1000 iterations.
            os.makedirs('checkpoint', exist_ok=True)
            torch.save({
                'generator'    : generator.state_dict(),
                'discriminator': discriminator.state_dict(),
                'g_optim'      : g_optim.state_dict(),
                'd_optim'      : d_optim.state_dict(),
                'parameters'   : (step, i, used_sample, alpha),
                'd_losses'     : d_losses,
                'g_losses'     : g_losses
            }, 'checkpoint/trained.pth')
            print(f'Iteration {i} successfully saved.')

        # Empty loss lists (e.g. fresh defaults) would otherwise raise IndexError.
        d_last = d_losses[-1] if d_losses else float('nan')
        g_last = g_losses[-1] if g_losses else float('nan')
        progress_bar.set_description(f'Resolution: {resolution}*{resolution}  D_Loss: {d_last:.4f}  G_Loss: {g_last:.4f}  Alpha: {alpha:.4f}')

    return d_losses, g_losses

In [None]:
generator = Generator(n_fc, dim_latent, dim_input).to(device)
discriminator = Discriminator().to(device)

# Synthesis blocks and to-RGB layers train at the base learning rate.
g_optim = optim.Adam([{
    'params': generator.convs.parameters(),
    'lr': 0.001
}, {
    'params': generator.to_rgbs.parameters(),
    'lr': 0.001
}], lr=0.001, betas=(0.0, 0.99))

# The mapping network (attribute `mapping_net`; there is no `fcs`) gets a
# 100x lower learning rate. `mul` lets reset_LR keep that ratio.
g_optim.add_param_group({
    'params': generator.mapping_net.parameters(),
    'lr': 0.001 * 0.01,
    'mul': 0.01
})

d_optim = optim.Adam(discriminator.parameters(), lr=0.001, betas=(0.0, 0.99))
dataset = datasets.ImageFolder(image_folder_path)