##### Dataset

In [None]:
import os
from PIL import Image

# Kaggle input locations (the commented-out paths are the local equivalents).
# ffhq_data_path = "./cartoonized_ffhq_dataset"
ffhq_data_path = "/kaggle/input/ffhq-20000/ffhq_dataset_20000"
# anime_data_path = "./anime_dataset"
anime_data_path = "/kaggle/input/anime-dataset-16101/anime_dataset"


def _collect_png_paths(top):
    """Return the path of every file ending in ``.png`` under ``top``, in ``os.walk`` order."""
    return [
        os.path.join(root, name)
        for root, _, names in os.walk(top)
        for name in names
        if name.endswith(".png")
    ]


ffhq_list = _collect_png_paths(ffhq_data_path)
anime_list = _collect_png_paths(anime_data_path)

print("ffhq :", len(ffhq_list))
print("anime:", len(anime_list))

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.functional import to_pil_image
from IPython.display import display
from torchvision.transforms.functional import to_tensor

class Dataset_(Dataset):
    """Pair each FFHQ entry with an anime entry, cycling through the anime list.

    The dataset length follows ``ffhq_list``; item ``i`` is
    ``(ffhq_list[i], anime_list[i % len(anime_list)])``. The list entries are
    returned unchanged (in this notebook they are image file paths).
    """

    def __init__(self, ffhq_list, anime_list):
        self.ffhq_data, self.anime_data = ffhq_list, anime_list
        self.anime_len = len(anime_list)

    def __len__(self):
        return len(self.ffhq_data)

    def __getitem__(self, index):
        # The anime list may be shorter, so wrap its index around.
        return self.ffhq_data[index], self.anime_data[index % self.anime_len]

batch_size = 2
Cartoon_Dataset = Dataset_(ffhq_list, anime_list)
Cartoon_Dataloader = DataLoader(Cartoon_Dataset, batch_size=batch_size, shuffle=True)

# Preview one shuffled batch. The default collate turns the per-sample
# (ffhq_path, anime_path) tuples into [ffhq_paths, anime_paths]; zipping them
# back re-pairs each face with its anime partner (f0, a0, f1, a1, ...).
# Iterating over the pairs works for any batch size (including a short batch),
# and the context manager closes each image file once it has been displayed.
for batch in Cartoon_Dataloader:
    for pair in zip(*batch):
        for img_path in pair:
            with Image.open(img_path) as img:
                display(img)
    break

# Option

In [None]:
class Option():
    """Default hyper-parameters for the swapping-autoencoder notebook.

    A plain attribute bag (``opt.lr``, ``opt.crop_size``, ...). The values are
    grouped by the section heading they belong to: encoder, generator,
    discriminator, base model, swapping autoencoder, optimizer and training.
    """

    def __init__(self):
        groups = (
            # Encoder
            dict(
                use_antialias=True,
                netE_num_downsampling_sp=4,
                spatial_code_ch=8,
                netE_num_downsampling_gl=2,
                global_code_ch=2048,
                netE_nc_steepness=2.0,
                netE_scale_capacity=1.0,
            ),
            # Generator
            dict(
                num_classes=0,
                netG_num_base_resnet_layers=2,
                netG_use_noise=True,
                netG_scale_capacity=1.0,
            ),
            # Discriminator
            dict(
                crop_size=256,
                netD_scale_capacity=1.0,
                netPatchD_scale_capacity=4.0,
                netPatchD_max_nc=256 + 128,
                patch_size=64,
                max_num_tiles=8,
                patch_random_transformation=True,
            ),
            # BaseModel
            dict(
                num_gpus=1,
                checkpoints_dir="./checkpoints/",
                name="model",
                isTrain=False,
                pretrained_name=None,
                resume_iter=False,
            ),
            # Swapping autoencoder: loss weights and patch sampling
            dict(
                lambda_R1=10.0,
                lambda_patch_R1=1.0,
                lambda_L1=1.0,
                lambda_GAN=1.0,
                lambda_PatchGAN=1.0,
                patch_min_scale=1 / 8,
                patch_max_scale=1 / 4,
                patch_num_crops=8,
                patch_use_aggregation=True,
            ),
            # Optimizer
            dict(lr=0.002, beta1=0.0, beta2=0.99, R1_once_every=16),
            # Training
            dict(print_freq=100, display_freq=200, save_freq=5000),
        )
        for group in groups:
            for key, value in group.items():
                setattr(self, key, value)


opt = Option()

# StyleGAN2 Layers

In [None]:
import torch
from torch import nn as nn
from torch.nn import functional as F

def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    """Upsample, pad, FIR-filter and downsample an NCHW tensor (pure PyTorch).

    The steps are:
    1. Insert ``up - 1`` zeros after every input sample.
    2. Pad; negative pads crop instead.
    3. Correlate every channel with the flipped ``kernel``, i.e. a true
       convolution.
    4. Keep every ``down``-th sample.

    Args:
        input: tensor of shape (N, C, H, W); it does not need to be contiguous.
        kernel: 2-D FIR kernel of shape (kh, kw).
        up_x, up_y: integer upsampling factors.
        down_x, down_y: integer downsampling factors.
        pad_x0, pad_x1, pad_y0, pad_y1: left/right/top/bottom padding.

    Returns:
        Tensor of shape (N, C, out_h, out_w), where
        ``out_h = (H * up_y + pad_y0 + pad_y1 - kh) // down_y + 1`` (same for w).
    """
    bs, ch, in_h, in_w = input.shape
    minor = 1
    kernel_h, kernel_w = kernel.shape

    # reshape (not view): a non-contiguous input (e.g. after a transpose of the
    # batch/channel dims) cannot be viewed with those dims folded together.
    out = input.reshape(-1, in_h, 1, in_w, 1, minor)
    if up_x > 1 or up_y > 1:
        # Zero-insertion upsampling: pad the singleton axis after each sample.
        out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])

    out = out.reshape(-1, in_h * up_y, in_w * up_x, minor)

    if pad_x0 > 0 or pad_x1 > 0 or pad_y0 > 0 or pad_y1 > 0:
        out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])

    # Negative padding is applied as a crop.
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    # Filter every channel independently as a single-channel image.
    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )

    # conv2d is a cross-correlation, so flip the kernel to get a convolution.
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)

    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )

    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]
    return out.reshape(bs, ch, out.size(1), out.size(2))


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Isotropic wrapper around ``upfirdn2d_native``.

    The same ``up``/``down`` factor and the same ``pad = (before, after)``
    are used on both spatial axes.
    """
    return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])

In [None]:
import torch
from torch import nn as nn
from torch.nn import functional as F

def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    """Add a per-channel bias, apply leaky ReLU, then multiply by ``scale``.

    ``bias`` holds one value per channel (dim 1 of ``input``). It is reshaped
    to ``(1, C, 1, ..., 1)`` so it broadcasts over every other dimension.
    """
    broadcast_shape = (1, -1, *([1] * (input.dim() - 2)))
    shifted = input + bias.view(broadcast_shape)
    return F.leaky_relu(shifted, negative_slope) * scale


class FusedLeakyReLU(nn.Module):
    """Module form of ``fused_leaky_relu`` with a learnable bias (initialized to zeros)."""

    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(channel))
        self.negative_slope, self.scale = negative_slope, scale

    def forward(self, input):
        return fused_leaky_relu(
            input, self.bias, negative_slope=self.negative_slope, scale=self.scale
        )

In [None]:
import torch
from torch import nn as nn
from torch.nn import functional as F

from collections import OrderedDict
import math

# from upfirdn2d import *
# from fused_leaky_relu import *

def make_kernel(k):
    """Return a float32 2-D FIR kernel normalized to sum to 1.

    Args:
        k: 1-D taps, expanded into a separable 2-D kernel by an outer
           product, or an already 2-D nested sequence.

    Returns:
        Tensor of shape (len(k), len(k)) for 1-D input.
    """
    k = torch.tensor(k, dtype=torch.float32)
    if k.dim() == 1:
        # Separable 1-D taps -> 2-D kernel via outer product.
        k = k[None, :] * k[:, None]

    k /= k.sum()
    return k


class Upsample(nn.Module):
    """Upsample an NCHW tensor by ``factor`` using zero insertion plus an FIR filter.

    The padding is chosen so that the output is exactly ``factor`` times the
    input size along each spatial axis.

    Args:
        kernel: 1-D FIR taps (e.g. [1, 3, 3, 1]).
        factor: integer upsampling factor.
    """

    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor
        # Normalize the kernel first, then multiply by factor**2 so the filter
        # gain compensates for the inserted zeros. Multiplying the Python list
        # instead (`kernel * factor**2`) would repeat the taps, not scale them.
        kernel = make_kernel(kernel) * (factor ** 2)
        self.register_buffer('kernel', kernel)

        p = kernel.shape[0] - factor
        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2
        self.pad = (pad0, pad1)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
    
class Downsample(nn.Module):
    """Anti-aliased downsampling: FIR blur with ``kernel``, then keep every ``factor``-th sample.

    Args:
        kernel: FIR taps passed to ``make_kernel``.
        factor: integer downsampling factor.
        pad: total padding; defaults to ``len(kernel) - factor``. It is split
            as ``((pad + 1) // 2, pad // 2)``.
        reflection_pad: pad with reflected pixels instead of zeros.
    """

    def __init__(self, kernel, factor=2, pad=None, reflection_pad=False):
        super().__init__()

        self.factor = factor
        taps = make_kernel(kernel)
        self.register_buffer('kernel', taps)
        self.reflection = reflection_pad

        total = taps.shape[0] - factor if pad is None else pad
        self.pad = ((total + 1) // 2, total // 2)

    def forward(self, input):
        pad = self.pad
        if self.reflection:
            # Reflect-pad explicitly, then tell upfirdn2d not to zero-pad.
            before, after = self.pad
            input = F.pad(input, (before, after, before, after), mode='reflect')
            pad = (0, 0)
        return upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=pad)

class Blur(nn.Module):
    """Filter an NCHW tensor with a normalized FIR kernel (no resampling).

    Args:
        kernel: FIR taps passed to ``make_kernel``.
        pad: (before, after) padding applied on both spatial axes.
        upsample_factor: if > 1, the kernel is multiplied by its square so the
            blur can follow a zero-insertion upsampling step.
        reflection_pad: pad with reflected pixels instead of zeros.
    """

    def __init__(self, kernel, pad, upsample_factor=1, reflection_pad=False):
        super().__init__()

        taps = make_kernel(kernel)
        if upsample_factor > 1:
            taps = taps * (upsample_factor ** 2)
        # A buffer travels with the module (device moves, state_dict) but is
        # not a trainable parameter.
        self.register_buffer('kernel', taps)

        self.reflection = reflection_pad
        if reflection_pad:
            before, after = pad[0], pad[1]
            self.reflection_pad = nn.ReflectionPad2d((before, after, before, after))
            # Padding is done by reflection_pad, so upfirdn2d adds none.
            self.pad = (0, 0)
        else:
            self.pad = pad

    def forward(self, input):
        x = self.reflection_pad(input) if self.reflection else input
        return upfirdn2d(x, self.kernel, pad=self.pad)

# Equalized learning rate: weights are initialized from N(0, 1) and multiplied at run time by lr_mul / sqrt(fan_in), so every layer learns at a comparable rate and no layer dominates training.
class EqualConv2d(nn.Module):
    """2-D convolution with an equalized learning rate.

    The weight is stored as raw N(0, 1) samples and multiplied by
    ``1 / sqrt(in_channel * kernel_size**2) * lr_mul`` on every forward pass.
    The optional bias starts at zero and is not rescaled.
    """

    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True, lr_mul=1.0
    ):
        super().__init__()
        weight_shape = (out_channel, in_channel, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.randn(*weight_shape))
        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in) * lr_mul

        self.stride, self.padding = stride, padding
        self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None

    def forward(self, input):
        scaled_weight = self.weight * self.scale
        return F.conv2d(
            input, scaled_weight, bias=self.bias, stride=self.stride, padding=self.padding
        )

    def __repr__(self):
        out_ch, in_ch, k = self.weight.shape[:3]
        return (
            f'{type(self).__name__}({in_ch}, {out_ch}, {k}, '
            f'stride={self.stride}, padding={self.padding})'
        )

class EqualLinear(nn.Module):
    """Linear layer with an equalized learning rate.

    The weight is initialized as ``randn / lr_mul`` and multiplied by
    ``(1 / sqrt(in_dim)) * lr_mul`` on every forward pass; the bias is
    multiplied by ``lr_mul``. Inputs with more than two dims are treated as
    NCHW and projected with a 1x1 convolution, i.e. the layer is applied at
    every spatial position.

    Args:
        in_dim, out_dim: feature sizes.
        bias: learn a bias, initialized to ``bias_init``.
        bias_init: initial bias value.
        lr_mul: learning-rate multiplier.
        activation: if truthy (only its truthiness is checked), apply a leaky
            ReLU (slope 0.2, gain sqrt(2)) after adding the bias.
    """

    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        weight = self.weight * self.scale
        # bias=False used to crash here with `None * lr_mul`.
        bias = None if self.bias is None else self.bias * self.lr_mul
        # With an activation, the bias is added inside the fused activation.
        linear_bias = None if self.activation else bias

        if input.dim() > 2:
            out = F.conv2d(input, weight[:, :, None, None], bias=linear_bias)
        else:
            out = F.linear(input, weight, bias=linear_bias)

        if self.activation:
            if bias is None:
                # Same as fused_leaky_relu with a zero bias.
                out = F.leaky_relu(out, 0.2) * (2 ** 0.5)
            else:
                out = fused_leaky_relu(out, bias)

        return out

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
        )
            
class ConvLayer(nn.Sequential):
    """Optional blur or reflection padding, an EqualConv2d, and an optional activation.

    The sub-layers are registered under the names "Blur", "RefPad", "Conv"
    and "Act" (only those that apply), in that order. These names are the
    keys that appear in the state_dict.

    Args:
        in_channel, out_channel: channel counts of the convolution.
        kernel_size: square kernel size of the convolution.
        downsample: if True, blur with ``blur_kernel`` and then convolve with
            stride 2 and no conv padding.
        blur_kernel: FIR taps for the pre-downsampling blur.
        bias: with ``activate``, the bias is held by the activation
            (FusedLeakyReLU); without ``activate``, by the convolution.
        activate: append a leaky-ReLU activation.
        pad: when downsampling, the total blur padding (split as
            ``((pad + 1) // 2, pad // 2)``); otherwise the conv padding.
            Defaults to ``len(blur_kernel) - 2 + kernel_size - 1`` and
            ``kernel_size // 2`` respectively.
        reflection_pad: use reflection padding instead of zero padding.
    """
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1,3,3,1],
        bias=True,
        activate=True,
        pad=None,
        reflection_pad=False
    ):
        layers = []
        
        if downsample:
            factor = 2
            if pad is None:
                pad = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (pad + 1) // 2
            pad1 = pad // 2
            
            layers.append(("Blur", Blur(blur_kernel, pad=(pad0, pad1), reflection_pad=reflection_pad)))
            
            stride = 2
            # Set before nn.Module.__init__ (the super() call below). This
            # works only because it is a plain int; a Parameter or Module
            # could not be assigned yet.
            self.padding = 0
        else:
            stride = 1
            self.padding = kernel_size // 2 if pad is None else pad
            if reflection_pad:
                # Pad by reflection up front, so the conv itself uses no padding.
                layers.append(("RefPad", nn.ReflectionPad2d(self.padding)))
                self.padding = 0
        
        layers.append(("Conv",
                       EqualConv2d(
                           in_channel,
                           out_channel,
                           kernel_size,
                           padding=self.padding,
                           stride=stride,
                           # the activation carries the bias when activate=True
                           bias=bias and not activate,
                       ))
        )

        if activate:
            if bias:
                layers.append(("Act", FusedLeakyReLU(out_channel)))
            else:
                # NOTE(review): ScaledLeakyReLU is not defined in the visible
                # part of this notebook; confirm it exists before using
                # activate=True with bias=False.
                layers.append(("Act", ScaledLeakyReLU(0.2)))

        super().__init__(OrderedDict(layers))

    def forward(self, x):
        # Plain pass-through to nn.Sequential.forward.
        out = super().forward(x)
        return out
    
class ResBlock(nn.Module):
    """Residual block: two 3x3 ConvLayers plus a 1x1 skip, summed and scaled by 1/sqrt(2).

    The second conv and the skip downsample by 2 when ``downsample`` is True.
    The skip conv has no bias or activation, and no reflection padding.
    """

    def __init__(self, in_channel, out_channel, blur_kernel=[1,3,3,1], reflection_pad=False, pad=None, downsample=True):
        super().__init__()
        padding_kwargs = dict(reflection_pad=reflection_pad, pad=pad)
        self.conv1 = ConvLayer(in_channel, in_channel, 3, **padding_kwargs)
        self.conv2 = ConvLayer(
            in_channel, out_channel, 3,
            downsample=downsample, blur_kernel=blur_kernel, **padding_kwargs,
        )
        self.skip = ConvLayer(
            in_channel, out_channel, 1,
            downsample=downsample, blur_kernel=blur_kernel, activate=False, bias=False,
        )

    def forward(self, input):
        residual = self.conv2(self.conv1(input))
        shortcut = self.skip(input)
        return (residual + shortcut) / math.sqrt(2)

# Base Network

In [None]:
class BaseNetwork(torch.nn.Module):
    """Shared base class for the networks in this notebook.

    Stores the option object on ``self.opt`` and provides helpers to:
    - print parameter counts;
    - toggle gradients;
    - collect parameters by module class name;
    - freeze or clear the noise of ``NoiseInjection`` submodules.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt

    def print_architecture(self, verbose=False):
        """Print the parameter count (in millions) of the network.

        The total sums the parameters of the direct children. With
        ``verbose``, each child and each grandchild is also listed.
        """
        # Kept separate from the loop variables so the summary line names the
        # network itself rather than its last child.
        network_name = type(self).__name__
        lines = ['-------------------%s---------------------' % network_name]
        total_num_params = 0
        for child_name, child in self.named_children():
            num_params = sum(p.numel() for p in child.parameters())
            total_num_params += num_params
            if not verbose:
                continue
            lines.append("%s: %3.3fM" % (child_name, num_params / 1e6))
            for grandchild_name, grandchild in child.named_children():
                grand_params = sum(p.numel() for p in grandchild.parameters())
                lines.append("\t%s: %3.3fM" % (grandchild_name, grand_params / 1e6))
        lines.append('[Network %s] Total number of parameters : %.3f M' % (network_name, total_num_params / 1e6))
        lines.append('-----------------------------------------------')
        print('\n'.join(lines) + '\n')

    def set_requires_grad(self, requires_grad):
        """Enable or disable gradients for every parameter of the network."""
        for param in self.parameters():
            param.requires_grad = requires_grad

    def collect_parameters(self, name):
        """Return the parameters of every submodule whose class name equals ``name``."""
        params = []
        for m in self.modules():
            if type(m).__name__ == name:
                params += list(m.parameters())
        return params

    def fix_and_gather_noise_parameters(self):
        """Give every ``NoiseInjection`` submodule a learnable fixed noise tensor.

        Each module must have run at least one forward pass so that its
        ``image_size`` is known. The new tensor is drawn from N(0, 1) with
        shape (image_size[0], 1, image_size[2], image_size[3]) on the
        network's device.

        Returns:
            List of the newly created ``nn.Parameter`` objects.
        """
        params = []
        device = next(self.parameters()).device
        for m in self.modules():
            if type(m).__name__ == "NoiseInjection":
                assert m.image_size is not None, "One forward call should be made to determine size of noise parameters"
                m.fixed_noise = torch.nn.Parameter(torch.randn(m.image_size[0], 1, m.image_size[2], m.image_size[3], device=device))
                params.append(m.fixed_noise)
        return params

    def remove_noise_parameters(self, name=None):
        """Clear the fixed noise of every ``NoiseInjection`` submodule.

        Afterwards the modules fall back to passed-in or freshly sampled
        noise. ``name`` is unused and kept only for backward compatibility.
        """
        for m in self.modules():
            if type(m).__name__ == "NoiseInjection":
                m.fixed_noise = None

    def forward(self, x):
        return x

# Encoder

In [None]:
from __future__ import print_function
import torch
import numbers
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import math
import numpy as np
from PIL import Image
import os
import importlib
import argparse
from argparse import Namespace
from sklearn.decomposition import PCA as PCA

def normalize(v):
    """Scale ``v`` to unit L2 norm along dim 1 (with a 1e-8 epsilon).

    A plain ``list`` is handled element-wise, recursively.
    """
    if type(v) is list:
        return list(map(normalize, v))

    inv_norm = torch.rsqrt(v.pow(2).sum(dim=1, keepdim=True) + 1e-8)
    return v * inv_norm

class RandomSpatialTransformer:
    """Random per-sample affine warps (x-axis reflection + rotation) for image batches.

    Args:
        opt: option object, stored on ``self.opt`` (not read by these methods).
        bs: accepted for interface compatibility; unused.
    """

    def __init__(self, opt, bs):
        self.opt = opt

    def create_affine_transformation(self, ref, rot, sx, sy, tx, ty):
        """Stack affine parameters into a (B, 6) tensor (not used by the methods below)."""
        return torch.stack([-ref * sx * torch.cos(rot), -sy * torch.sin(rot), tx,
                            -ref * sx * torch.sin(rot), sy * torch.cos(rot), ty], axis=1)

    def resample_transformation(self, bs, device, reflection=None, rotation=None, scale=None, translation=None):
        """Sample a batch of 2x3 affine matrices for ``F.affine_grid``.

        Args:
            bs: number of matrices.
            device: device for the sampled tensors.
            reflection: per-sample +/-1 factor on the x column; random if None.
            rotation: rotation angles in radians; uniform in [-30, 30] degrees if None.
            scale: (sx, sy); if None both are sampled from [1.0, 1.0], i.e. 1.
            translation: ignored; the translation is always zero.

        Returns:
            Tensor of shape (bs, 2, 3).
        """
        zero = torch.zeros((bs), device=device)

        if reflection is None:
            ref = torch.round(torch.rand((bs), device=device)) * 2 - 1
        else:
            ref = reflection

        if rotation is None:
            max_rotation = 30 * math.pi / 180
            rot = torch.rand((bs), device=device) * (2 * max_rotation) - max_rotation
        else:
            rot = rotation

        if scale is None:
            min_scale = 1.0
            max_scale = 1.0
            sx = torch.rand((bs), device=device) * (max_scale - min_scale) + min_scale
            sy = torch.rand((bs), device=device) * (max_scale - min_scale) + min_scale
        else:
            sx, sy = scale

        tx, ty = zero, zero

        A = torch.stack([ref * sx * torch.cos(rot), -sy * torch.sin(rot), tx,
                         ref * sx * torch.sin(rot), sy * torch.cos(rot), ty], dim=1)
        return A.view(bs, 2, 3)

    def forward_transform(self, x, size):
        """Warp ``x`` with freshly sampled transforms, resampling to ``size``.

        Args:
            x: (N, C, H, W) tensor, or a list of such tensors. Each list element
               gets its own independently sampled transforms.
            size: (out_h, out_w) of the output.
        """
        if type(x) == list:
            # Forward ``size`` when recursing; it is a required argument.
            return [self.forward_transform(xx, size) for xx in x]

        affine_param = self.resample_transformation(x.size(0), x.device)
        affine_grid = F.affine_grid(affine_param, (x.size(0), x.size(1), size[0], size[1]), align_corners=False)
        x = F.grid_sample(x, affine_grid, padding_mode='reflection', align_corners=False)

        return x
    
def resize2d_tensor(x, size_or_tensor_of_size):
    """Bilinearly resize ``x`` to the last two entries of a target size.

    The target may be a tuple/list (including ``torch.Size``), a tensor
    (its size is used) or a NumPy array (its shape is used). Any other type
    raises ``ValueError``.
    """
    target = size_or_tensor_of_size
    if torch.is_tensor(target):
        target = target.size()
    elif isinstance(target, np.ndarray):
        target = target.shape

    if not isinstance(target, (tuple, list)):
        raise ValueError("%s is unrecognized" % str(type(target)))
    return F.interpolate(x, target[-2:], mode='bilinear', align_corners=False)
    
def visualize_spatial_code(sp):
    """Turn a spatial code of shape (B, C, H, W) into a 3-channel image tensor.

    - C <= 2: channels are tiled and the first three are kept.
    - C == 3: the tensor is returned unchanged.
    - C > 3: PCA over all positions reduces it to 3 components, rescaled to
      [-1, 1] over the whole batch.
    If PCA raises ``ValueError``, a message is printed and zeros are returned.
    """
    device = sp.device
    if sp.size(1) <= 2:
        sp = sp.repeat([1, 3, 1, 1])[:, :3, :, :]
    if sp.size(1) == 3:
        return sp

    feats = np.transpose(sp.detach().cpu().numpy(), (0, 2, 3, 1))
    B, H, W, C = feats.shape
    flat = feats.reshape(-1, C)
    flat = flat - flat.mean(axis=0, keepdims=True)
    try:
        Z = PCA(3).fit_transform(flat)
    except ValueError:
        print("Running PCA on the structure code has failed.")
        print("This is likely a bug of scikit-learn in version 0.18.1.")
        print("https://stackoverflow.com/a/42764378")
        print("The visualization of the structure code on visdom won't work.")
        return torch.zeros(B, 3, H, W, device=device)

    rgb = np.transpose(Z.reshape(B, H, W, -1), (0, 3, 1, 2))
    rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min()) * 2 - 1
    return torch.from_numpy(rgb).to(device)

def to_numpy(metric_dict):
    """Return a new dict whose values are NumPy objects.

    Values whose type string mentions "numpy" pass through unchanged; any
    other value is treated as a tensor, detached, moved to CPU and averaged.
    """
    def _as_numpy(value):
        if "numpy" in str(type(value)):
            return value
        return value.detach().cpu().mean().numpy()

    return {key: _as_numpy(value) for key, value in metric_dict.items()}

In [None]:
import numpy as np
# from util import *

class ToSpatialCode(torch.nn.Module):
    """Two 1x1 ConvLayers (inch -> inch // 2 -> outch) followed by ``int(log2(scale))`` 2x upsamplings.

    ``blur`` and the ``kernel`` buffer are registered but not used in ``forward``.
    """

    def __init__(self, inch, outch, scale):
        super().__init__()
        hidden = inch // 2
        self.conv1 = ConvLayer(inch, hidden, 1, activate=True, bias=True)
        self.conv2 = ConvLayer(hidden, outch, 1, activate=False, bias=True)
        self.scale = scale
        self.upsample = Upsample([1, 3, 3, 1], 2)
        self.blur = Blur([1, 3, 3, 1], pad=(2, 1))
        self.register_buffer('kernel', make_kernel([1, 3, 3, 1]))

    def forward(self, x):
        x = self.conv2(self.conv1(x))
        num_upsamplings = int(np.log2(self.scale))
        for _ in range(num_upsamplings):
            x = self.upsample(x)
        return x
        
class Encoder(BaseNetwork):
    """Encode an RGB image into a spatial (structure) code and a global (texture) code.

    Built from these ``opt`` fields: ``use_antialias``,
    ``netE_num_downsampling_sp``, ``netE_num_downsampling_gl``,
    ``spatial_code_ch``, ``global_code_ch``, ``netE_nc_steepness`` and
    ``netE_scale_capacity``. Submodules are registered in this order:
    FromRGB, DownToSpatialCode, ToSpatialCode, DownToGlobalCode, ToGlobalCode.
    """
    def __init__(self, opt):
        super().__init__(opt)
        
        # Blur taps for the downsampling ResBlocks; [1] means no anti-aliasing.
        blur_kernel = [1, 2, 1] if self.opt.use_antialias else [1]
        self.add_module("FromRGB", ConvLayer(3, self.nc(0), 1))
        
        # netE_num_downsampling_sp ResBlocks (ResBlock downsamples by default),
        # using reflection padding.
        self.DownToSpatialCode = nn.Sequential()
        for i in range(self.opt.netE_num_downsampling_sp):
            self.DownToSpatialCode.add_module(
                "ResBlockDownBy%d" % (2 ** i),
                ResBlock(self.nc(i), self.nc(i + 1), blur_kernel, reflection_pad=True)
            )
            
        # Two 1x1 convs projecting the midpoint features to spatial_code_ch channels.
        nchannels = self.nc(self.opt.netE_num_downsampling_sp)
        self.add_module(
            "ToSpatialCode",
            nn.Sequential(
                ConvLayer(nchannels, nchannels, 1, activate=True, bias=True),
                ConvLayer(nchannels, self.opt.spatial_code_ch, kernel_size=1,
                          activate=False, bias=True)
            )
        )
        
        # Further stride-2 3x3 convs (no blur: blur_kernel=[1], pad=0) on the way
        # to the global code.
        self.DownToGlobalCode = nn.Sequential()
        for i in range(self.opt.netE_num_downsampling_gl):
            idx_from_beginning = self.opt.netE_num_downsampling_sp + i
            self.DownToGlobalCode.add_module(
                "ConvLayerDownBy%d" % (2 ** idx_from_beginning),
                ConvLayer(self.nc(idx_from_beginning),
                          self.nc(idx_from_beginning + 1), kernel_size=3,
                          blur_kernel=[1], downsample=True, pad=0)
            )
            
        nchannels = self.nc(self.opt.netE_num_downsampling_sp +
                            self.opt.netE_num_downsampling_gl)
        self.add_module(
            "ToGlobalCode",
            nn.Sequential(
                EqualLinear(nchannels, self.opt.global_code_ch)
            )
        )
        
    def nc(self, idx):
        """Channel count at level ``idx``.

        Computed as ``round(netE_nc_steepness ** (5 + idx) * netE_scale_capacity)``,
        capped at ``global_code_ch``.
        """
        nc = self.opt.netE_nc_steepness ** (5 + idx)
        nc = nc * self.opt.netE_scale_capacity
        nc = min(self.opt.global_code_ch, int(round(nc)))
        return round(nc)
    
    def forward(self, x, extract_features=False):
        """Return ``(sp, gl)``, or ``(sp, gl, feature)`` when ``extract_features`` is set.

        ``sp`` is the spatial code and ``gl`` the global code, both passed
        through ``normalize`` (unit L2 norm along dim 1). ``feature`` is the
        output of the first DownToGlobalCode layer, resized to 7x7.
        """
        x = self.FromRGB(x)
        midpoint = self.DownToSpatialCode(x)
        sp = self.ToSpatialCode(midpoint)
        
        if extract_features:
            # One pixel of reflect padding on the left/top, so the unpadded
            # stride-2 3x3 conv yields half of sp's spatial size (checked below).
            padded_midpoint = F.pad(midpoint, (1, 0, 1, 0), mode='reflect')
            feature = self.DownToGlobalCode[0](padded_midpoint)
            assert feature.size(2) == sp.size(2) // 2 and feature.size(3) == sp.size(3) // 2
            feature = F.interpolate(
                feature, size=(7, 7), mode='bilinear', align_corners=False)
            
        x = self.DownToGlobalCode(midpoint)
        # Global average pooling over the spatial dims before the linear head.
        x = x.mean(dim=(2, 3))
        gl = self.ToGlobalCode(x)
        sp = normalize(sp)
        gl = normalize(gl)
        if extract_features:
            return sp, gl, feature
        else:
            return sp, gl

# StyleGAN2 Layers

In [None]:
class PixelNorm(nn.Module):
    """Rescale each position's feature vector to unit RMS across channels (dim 1)."""

    def forward(self, input):
        mean_square = input.pow(2).mean(dim=1, keepdim=True)
        return input * (mean_square + 1e-8).rsqrt()

class ModulatedConv2d(nn.Module):
    """Modulated convolution with optional 2x up- or down-sampling.

    ``self.modulation`` (an EqualLinear with ``bias_init=1``) maps the style
    to one gain per input channel. The convolution runs as a grouped conv
    (``groups=batch``), so each sample uses its own weights.

    Args:
        in_channel, out_channel: channel counts.
        kernel_size: square kernel size.
        style_dim: length of the style vector.
        demodulate: normalize the per-sample weights. On the
            ``new_demodulation`` / spatial paths, the style gains are also
            normalized.
        upsample: stride-2 transposed conv followed by ``Blur``.
        downsample: ``Blur`` followed by a stride-2 conv. If both flags are
            set, ``self.blur`` is the downsampling blur, and forward takes the
            upsample branch.
        blur_kernel: FIR taps used by the blur.
    """
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        self.eps = 1e-8  # NOTE(review): not read anywhere in this class
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            # Padding for the blur applied after the stride-2 transposed conv.
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

        if downsample:
            factor = 2
            # Padding for the blur applied before the stride-2 conv.
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        # Equalized learning rate: the raw N(0, 1) weight is scaled by 1/sqrt(fan_in).
        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        # Leading singleton dim so the weight can broadcast against per-sample styles.
        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )

        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate
        # True: modulate the activations; False: fold the style into the weights.
        self.new_demodulation = True

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
            f'upsample={self.upsample}, downsample={self.downsample})'
        )

    def forward(self, input, style):
        batch, in_channel, height, width = input.shape

        if style.dim() > 2:
            # Spatially varying style: resize the style map to the input
            # resolution and modulate the activations per pixel.
            # NOTE(review): `input` is (B, C, H, W) while `style` becomes
            # (B, 1, C, H, W), so for batch > 1 the product `input * style`
            # appears to broadcast to (B, B, C, H, W). The later
            # `input.view(1, batch * in_channel, ...)` would then fail;
            # confirm this path is only used with batch == 1.
            style = F.interpolate(style, size=(input.size(2), input.size(3)), mode='bilinear', align_corners=False)
            style = self.modulation(style).unsqueeze(1)
            if self.demodulate:
                # Unit RMS of the gains across channels, at every position.
                style = style * torch.rsqrt(style.pow(2).mean([2], keepdim=True) + 1e-8)
            input = input * style
            weight = self.scale * self.weight
            weight = weight.repeat(batch, 1, 1, 1, 1)
        else:
            style = style.view(batch, style.size(1))
            style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
            if self.new_demodulation:
                # Modulate the activations; the shared weight is only
                # replicated per sample so the grouped conv below can be used.
                style = style[:, 0, :, :, :]
                if self.demodulate:
                    style = style * torch.rsqrt(style.pow(2).mean([1], keepdim=True) + 1e-8)
                input = input * style
                weight = self.scale * self.weight
                weight = weight.repeat(batch, 1, 1, 1, 1)
            else:
                # Fold the style into per-sample weights: (B, out, in, k, k).
                weight = self.scale * self.weight * style

        if self.demodulate:
            # Rescale every output filter to (approximately) unit L2 norm
            # over (in_channel, kh, kw).
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        # Fold the batch into the output-channel axis for a grouped conv.
        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
        )

        if self.upsample:
            input = input.view(1, batch * in_channel, height, width)
            # conv_transpose2d expects (in, out) per group, so swap the axes.
            weight = weight.view(
                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
            )
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
            )
            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)

        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        else:
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out
            
class NoiseInjection(nn.Module):
    """Add ``weight * noise`` to an image, where ``weight`` is a learnable scalar starting at 0.

    Noise source, in priority order:
    1. ``fixed_noise`` if set, resized with nearest interpolation when its
       spatial size differs from the image;
    2. the ``noise`` argument;
    3. fresh N(0, 1) noise of shape (B, 1, H, W).
    The shape of the first input seen is recorded in ``image_size``.
    """

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))
        self.fixed_noise = None
        self.image_size = None

    def forward(self, image, noise=None):
        if self.image_size is None:
            self.image_size = image.shape

        if self.fixed_noise is not None:
            noise = self.fixed_noise
            # Allows images of a different resolution (e.g. demo thumbnails).
            if noise.size(2) != image.size(2) or noise.size(3) != image.size(3):
                noise = F.interpolate(noise, image.shape[2:], mode="nearest")
        elif noise is None:
            batch, _, height, width = image.shape
            noise = image.new_empty(batch, 1, height, width).normal_()

        return image + self.weight * noise

class StyledConv(nn.Module):
    """ModulatedConv2d, then optional NoiseInjection, then FusedLeakyReLU.

    ``lr_mul`` is accepted but unused. The noise module is always created,
    but it is applied only when ``use_noise`` is True.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        demodulate=True,
        use_noise=True,
        lr_mul=1.0,
    ):
        super().__init__()
        self.conv = ModulatedConv2d(
            in_channel, out_channel, kernel_size, style_dim,
            demodulate=demodulate, upsample=upsample, blur_kernel=blur_kernel,
        )
        self.use_noise = use_noise
        self.noise = NoiseInjection()
        # FusedLeakyReLU holds the per-channel bias of this layer.
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None):
        out = self.conv(input, style)
        out = self.noise(out, noise=noise) if self.use_noise else out
        return self.activate(out)
    
class ToRGB(nn.Module):
    """Project features to 3 channels with a 1x1 modulated conv (no demodulation) plus a bias.

    If ``skip`` is given it is added to the output. When the layer was built
    with ``upsample=True``, ``skip`` is first upsampled 2x.

    Args:
        in_channel: input feature channels.
        style_dim: length of the style vector.
        upsample: upsample the skip input before adding it.
        blur_kernel: FIR taps for the skip upsampling.
    """

    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upsample(blur_kernel)
        else:
            # Explicit None so forward() can add a same-resolution skip
            # instead of failing with AttributeError.
            self.upsample = None

        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        out = self.conv(input, style) + self.bias

        if skip is not None:
            if self.upsample is not None:
                skip = self.upsample(skip)
            out = out + skip

        return out

# Generator

In [None]:
class UpsamplingBlock(torch.nn.Module):
    """A 2x-upsampling StyledConv followed by a same-resolution StyledConv, sharing one style."""

    def __init__(self, inch, outch, styledim,
                 blur_kernel=[1, 3, 3, 1], use_noise=False):
        super().__init__()
        self.inch = inch
        self.outch = outch
        self.styledim = styledim
        self.conv1 = StyledConv(
            inch, outch, 3, styledim,
            upsample=True, blur_kernel=blur_kernel, use_noise=use_noise,
        )
        self.conv2 = StyledConv(outch, outch, 3, styledim, upsample=False, use_noise=use_noise)

    def forward(self, x, style):
        upsampled = self.conv1(x, style)
        return self.conv2(upsampled, style)
    
class ResolutionPreservingResnetBlock(torch.nn.Module):
    """Residual block of two 3x3 StyledConvs at constant resolution.

    The skip is a 1x1 ConvLayer (no activation, no bias) when the channel
    count changes, and the identity otherwise. The sum is scaled by
    1/sqrt(2). ``opt`` is accepted but unused.
    """

    def __init__(self, opt, inch, outch, styledim):
        super().__init__()
        self.conv1 = StyledConv(inch, outch, 3, styledim, upsample=False)
        self.conv2 = StyledConv(outch, outch, 3, styledim, upsample=False)
        self.skip = (
            ConvLayer(inch, outch, 1, activate=False, bias=False)
            if inch != outch
            else torch.nn.Identity()
        )

    def forward(self, x, style):
        shortcut = self.skip(x)
        out = self.conv1(x, style)
        out = self.conv2(out, style)
        return (shortcut + out) / math.sqrt(2)
    
class UpsamplingResnetBlock(torch.nn.Module):
    """Residual block whose main branch upsamples via its first StyledConv.

    The shortcut is a 1x1 ConvLayer (with activation and bias) when the channel
    count changes, else identity. It is then bilinearly upsampled by 2 to match
    the main branch. The sum of both branches is scaled by 1/sqrt(2).
    """

    def __init__(self, inch, outch, styledim, blur_kernel=[1, 3, 3, 1], use_noise=False):
        super().__init__()
        self.inch = inch
        self.outch = outch
        self.styledim = styledim
        self.conv1 = StyledConv(
            inch, outch, 3, styledim,
            upsample=True, blur_kernel=blur_kernel, use_noise=use_noise,
        )
        self.conv2 = StyledConv(outch, outch, 3, styledim, upsample=False, use_noise=use_noise)
        self.skip = (
            ConvLayer(inch, outch, 1, activate=True, bias=True)
            if inch != outch
            else torch.nn.Identity()
        )

    def forward(self, x, style):
        shortcut = self.skip(x)
        shortcut = F.interpolate(shortcut, scale_factor=2, mode='bilinear', align_corners=False)
        residual = self.conv1(x, style)
        residual = self.conv2(residual, style)
        return (shortcut + residual) / math.sqrt(2)
    
class GeneratorModulation(torch.nn.Module):
    """Affine modulation ``x * scale(style) + bias(style)`` of a feature map.

    ``scale`` and ``bias`` are linear maps from ``styledim`` to ``outch``
    features.
    """

    def __init__(self, styledim, outch):
        super().__init__()
        self.scale = EqualLinear(styledim, outch)
        self.bias = EqualLinear(styledim, outch)

    def forward(self, x, style):
        """Modulate ``x`` by ``style``.

        Args:
            x: feature map; ``x.size(2)``/``x.size(3)`` are used as its
                spatial size, so it is indexed as (B, outch, H, W).
            style: either a (B, styledim) vector applied uniformly, or a 4-D
                (B, styledim, h, w) map that is bilinearly resized to (H, W)
                and applied per location.
        """
        if style.ndimension() <= 2:
            return x * self.scale(style)[:, :, None, None] + self.bias(style)[:, :, None, None]

        style = F.interpolate(style, size=(x.size(2), x.size(3)), mode='bilinear', align_corners=False)
        # The linear layers are assumed to act on the last dimension (like
        # nn.Linear), so move channels last, project, then move them back.
        # Applying them directly to (B, C, H, W) would project the width axis.
        style = style.permute(0, 2, 3, 1)
        scale = self.scale(style).permute(0, 3, 1, 2)
        bias = self.bias(style).permute(0, 3, 1, 2)
        return x * scale + bias

class Generator(BaseNetwork):
    """Decode a (spatial code, global code) pair into a 3-channel output.

    Pipeline:
    1. Both codes go through ``normalize``.
    2. The spatial code is modulated by the global code.
    3. ``opt.netG_num_base_resnet_layers`` resolution-preserving residual
       blocks are applied.
    4. ``opt.netE_num_downsampling_sp`` upsampling residual blocks are applied.
    5. ``ToRGB`` produces the output.

    Every block is conditioned on the global code.
    """

    def __init__(self, opt):
        super().__init__(opt)
        num_upsamplings = opt.netE_num_downsampling_sp
        blur_kernel = [1, 3, 3, 1] if opt.use_antialias else [1]

        # width of the global code fed to every block
        self.global_code_ch = opt.global_code_ch + opt.num_classes

        self.add_module(
            "SpatialCodeModulation",
            GeneratorModulation(self.global_code_ch, opt.spatial_code_ch))

        in_channel = opt.spatial_code_ch
        for i in range(opt.netG_num_base_resnet_layers):
            # gradually increase the number of channels up to nf(0)
            out_channel = (i + 1) / opt.netG_num_base_resnet_layers * self.nf(0)
            out_channel = max(opt.spatial_code_ch, round(out_channel))
            self.add_module(
                self._head_block_name(i),
                ResolutionPreservingResnetBlock(
                    opt, in_channel, out_channel, self.global_code_ch))
            in_channel = out_channel

        for j in range(num_upsamplings):
            out_channel = self.nf(j + 1)
            self.add_module(
                self._upsampling_block_name(j),
                UpsamplingResnetBlock(
                    in_channel, out_channel, self.global_code_ch,
                    blur_kernel, opt.netG_use_noise))
            in_channel = out_channel

        # in_channel is always the width of the last block's output (or of the
        # spatial code when there are no blocks). The loop variable out_channel
        # would be unbound if both loops were empty.
        last_layer = ToRGB(in_channel, self.global_code_ch,
                           blur_kernel=blur_kernel)
        self.add_module("ToRGB", last_layer)

    @staticmethod
    def _head_block_name(i):
        """Module name of the i-th resolution-preserving head block."""
        return "HeadResnetBlock%d" % i

    @staticmethod
    def _upsampling_block_name(j):
        """Module name of the j-th upsampling block (suffix is 2 ** (4 + j))."""
        return "UpsamplingResBlock%d" % (2 ** (4 + j))

    def nf(self, num_up):
        """Channel width after ``num_up`` upsamplings, capped at 512 and scaled by ``netG_scale_capacity``."""
        ch = 128 * (2 ** (self.opt.netE_num_downsampling_sp - num_up))
        ch = int(min(512, ch) * self.opt.netG_scale_capacity)
        return ch

    def forward(self, spatial_code, global_code):
        spatial_code = normalize(spatial_code)
        global_code = normalize(global_code)

        x = self.SpatialCodeModulation(spatial_code, global_code)
        for i in range(self.opt.netG_num_base_resnet_layers):
            x = getattr(self, self._head_block_name(i))(x, global_code)

        for j in range(self.opt.netE_num_downsampling_sp):
            x = getattr(self, self._upsampling_block_name(j))(x, global_code)

        rgb = self.ToRGB(x, global_code, None)
        return rgb

# StyleGAN2 Layers

In [None]:
class ResBlock(nn.Module):
    """Residual block: two 3x3 ConvLayers plus a 1x1 ConvLayer shortcut.

    ``conv2`` and the shortcut both receive ``downsample`` and ``blur_kernel``.
    The shortcut has no activation and no bias. The two branches are summed
    and scaled by 1/sqrt(2).
    """

    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], reflection_pad=False, pad=None, downsample=True):
        super().__init__()
        padding = dict(reflection_pad=reflection_pad, pad=pad)

        self.conv1 = ConvLayer(in_channel, in_channel, 3, **padding)
        self.conv2 = ConvLayer(
            in_channel, out_channel, 3,
            downsample=downsample, blur_kernel=blur_kernel, **padding,
        )
        self.skip = ConvLayer(
            in_channel, out_channel, 1,
            downsample=downsample, blur_kernel=blur_kernel, activate=False, bias=False,
        )

    def forward(self, input):
        residual = self.conv2(self.conv1(input))
        shortcut = self.skip(input)
        return (residual + shortcut) / math.sqrt(2)

class BaseDiscriminator(nn.Module):
    """StyleGAN2-style image discriminator returning one logit per image.

    Args:
        size: input resolution. It is rounded to the nearest power of two to
            choose how many downsampling ResBlocks are stacked.
        channel_multiplier: scales the per-resolution channel widths.
        blur_kernel: anti-aliasing kernel forwarded to every ResBlock.
    """

    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        capped = min(512, int(512 * channel_multiplier))
        channels = {
            4: 512,
            8: 512,
            16: capped,
            32: capped,
            64: int(256 * channel_multiplier),
            128: int(128 * channel_multiplier),
            256: int(64 * channel_multiplier),
            512: int(32 * channel_multiplier),
            1024: int(16 * channel_multiplier),
        }

        requested_size = size
        size = 2 ** int(round(math.log(size, 2)))

        in_channel = channels[size]
        layers = [('0', ConvLayer(3, in_channel, 1))]

        log_size = int(math.log(size, 2))
        for level in range(log_size, 2, -1):
            out_channel = channels[2 ** (level - 1)]
            # names are kept stable so existing checkpoints still load
            if level <= 8:
                name = str(9 - level)
            else:
                name = "%dx%d" % (2 ** level, 2 ** level)
            layers.append((name, ResBlock(in_channel, out_channel, blur_kernel)))
            in_channel = out_channel

        self.convs = nn.Sequential(OrderedDict(layers))

        # This port omits StyleGAN2's minibatch-stddev feature, so final_conv
        # consumes exactly in_channel channels.
        self.final_conv = ConvLayer(in_channel, channels[4], 3)

        side_length = int(4 * requested_size / size)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * (side_length ** 2), channels[4], activation='fused_lrelu'),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        trunk = self.convs(input)
        batch = trunk.shape[0]
        features = self.final_conv(trunk)
        return self.final_linear(features.view(batch, -1))

    def get_features(self, input):
        """Return the feature map fed (flattened) into ``final_linear``."""
        trunk = self.convs(input)
        return self.final_conv(trunk)

In [None]:
class Discriminator(BaseNetwork):
    """Image-level discriminator delegating to a ``BaseDiscriminator``."""

    def __init__(self, opt):
        super().__init__(opt)
        kernel = [1, 3, 3, 1] if self.opt.use_antialias else [1]
        self.stylegan2_D = BaseDiscriminator(
            opt.crop_size,
            2.0 * opt.netD_scale_capacity,
            blur_kernel=kernel,
        )

    def forward(self, x):
        return self.stylegan2_D(x)

    def get_features(self, x):
        """Feature map produced before the final linear layers."""
        return self.stylegan2_D.get_features(x)

    def get_pred_from_features(self, feat, label):
        """Score features from ``get_features``; class labels are not supported (``label`` must be None)."""
        assert label is None
        return self.stylegan2_D.final_linear(feat.flatten(1))

# Patch discriminator

In [None]:
class BasePatchDiscriminator(BaseNetwork):
    """Base class for discriminators that score pairs of image patches.

    Subclasses implement ``extract_features`` and ``discriminate_features``.
    ``forward`` cuts images into ``opt.patch_size`` tiles via
    ``sample_patches`` and scores (real, rolled real) and (real, rolled fake)
    feature pairs.
    """

    def __init__(self, opt):
        super().__init__(opt)

    def needs_regularization(self):
        return False

    def extract_features(self, patches):
        raise NotImplementedError()

    def discriminate_features(self, feature1, feature2):
        raise NotImplementedError()

    def apply_random_transformation(self, patches):
        """Apply a ``RandomSpatialTransformer`` to every tile.

        Args:
            patches: tensor indexed as (B, ntiles, C, H, W).

        Returns:
            Transformed tiles reshaped back to (B, ntiles, C, H, W).
        """
        B, ntiles, C, H, W = patches.size()
        patches = patches.view(B * ntiles, C, H, W)
        # NOTE(review): RandomSpatialTransformer is not defined in the visible
        # part of this file — confirm it exists before relying on this path.
        transformer = RandomSpatialTransformer(self.opt, B * ntiles)
        patches = transformer.forward_transform(patches, (self.opt.patch_size, self.opt.patch_size))
        return patches.view(B, ntiles, C, H, W)

    def sample_patches_old(self, img, indices):
        """Cut ``img`` into a grid of ``opt.patch_size`` tiles and keep some of them.

        Args:
            img: image batch indexed as (B, C, H, W).
            indices: tile indices to keep, or None to draw up to
                ``opt.max_num_tiles`` random ones.

        Returns:
            ``(tiles, indices)`` when ``indices`` is None, otherwise just
            ``tiles``. The tiles are passed through
            ``apply_random_transformation``.
        """
        B, C, H, W = img.size()
        s = self.opt.patch_size
        if H % s > 0 or W % s > 0:
            # Randomly place the s-aligned grid inside the image. An axis that
            # is already a multiple of s gets offset 0, because torch.randint(0)
            # raises instead of returning 0.
            y_offset = torch.randint(H % s, (), device=img.device) if H % s > 0 else 0
            x_offset = torch.randint(W % s, (), device=img.device) if W % s > 0 else 0
            img = img[:, :,
                      y_offset:y_offset + s * (H // s),
                      x_offset:x_offset + s * (W // s)]
        img = img.view(B, C, H//s, s, W//s, s)
        ntiles = (H // s) * (W // s)
        tiles = img.permute(0, 2, 4, 1, 3, 5).reshape(B, ntiles, C, s, s)
        if indices is None:
            indices = torch.randperm(ntiles, device=img.device)[:self.opt.max_num_tiles]
            return self.apply_random_transformation(tiles[:, indices]), indices
        else:
            return self.apply_random_transformation(tiles[:, indices])

    def sample_patches(self, img, indices):
        """Patch sampler used by ``forward``. Delegates to ``sample_patches_old``.

        ``forward`` relies on this contract: with ``indices=None`` it returns
        ``(patches, indices)``; otherwise it returns only the patches.
        Subclasses may override this method.
        """
        return self.sample_patches_old(img, indices)

    def forward(self, real, fake, fake_only=False):
        """Score patch pairs.

        Returns:
            ``(pred_real, real_patches)`` if ``fake`` is None,
            ``pred_fake`` if ``fake_only``, else ``(pred_real, pred_fake)``.
            Predictions are reshaped to (batch, -1).
        """
        assert real is not None
        real_patches, patch_ids = self.sample_patches(real, None)
        if fake is None:
            # gradients w.r.t. the real patches are needed by the caller
            real_patches.requires_grad_()
        real_feat = self.extract_features(real_patches)

        bs = real.size(0)
        if fake is None or not fake_only:
            pred_real = self.discriminate_features(
                real_feat,
                torch.roll(real_feat, 1, 1))
            pred_real = pred_real.view(bs, -1)

        if fake is not None:
            # use the same tile indices as the real image
            fake_patches = self.sample_patches(fake, patch_ids)
            fake_feat = self.extract_features(fake_patches)
            pred_fake = self.discriminate_features(
                real_feat,
                torch.roll(fake_feat, 1, 1))
            pred_fake = pred_fake.view(bs, -1)

        if fake is None:
            return pred_real, real_patches
        elif fake_only:
            return pred_fake
        else:
            return pred_real, pred_fake
        
class PatchDiscriminator(BasePatchDiscriminator):
    """Patch discriminator that scores pairs of per-patch feature maps.

    ``convs`` maps each patch to a feature map. ``pairlinear`` scores the
    concatenation of two flattened feature maps with one logit.
    """

    def __init__(self, opt):
        super().__init__(opt)
        channel_multiplier = self.opt.netPatchD_scale_capacity
        size = self.opt.patch_size
        max_nc = self.opt.netPatchD_max_nc
        channels = {
            4: min(max_nc, int(256 * channel_multiplier)),
            8: min(max_nc, int(128 * channel_multiplier)),
            16: min(max_nc, int(64 * channel_multiplier)),
            32: int(32 * channel_multiplier),
            64: int(16 * channel_multiplier),
            128: int(8 * channel_multiplier),
            256: int(4 * channel_multiplier),
        }

        log_size = int(math.ceil(math.log(size, 2)))

        in_channel = channels[2 ** log_size]

        blur_kernel = [1, 3, 3, 1] if self.opt.use_antialias else [1]

        convs = [('0', ConvLayer(3, in_channel, 3))]

        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            # layer names are kept stable so existing checkpoints still load
            layer_name = str(7 - i) if i <= 6 else "%dx%d" % (2 ** i, 2 ** i)
            convs.append((layer_name, ResBlock(in_channel, out_channel, blur_kernel)))
            in_channel = out_channel

        convs.append(('5', ResBlock(in_channel, max_nc * 2, downsample=False)))
        convs.append(('6', ConvLayer(max_nc * 2, max_nc, 3, pad=0)))

        self.convs = nn.Sequential(OrderedDict(convs))

        out_dim = 1

        # NOTE(review): the input width assumes each patch's feature map has
        # channels[4] channels over a (presumably) 2x2 map, doubled for the
        # concatenated pair. convs ends with netPatchD_max_nc channels, so this
        # only matches when channels[4] == netPatchD_max_nc — confirm.
        pairlinear1 = EqualLinear(channels[4] * 2 * 2 * 2, 2048, activation='fused_lrelu')
        pairlinear2 = EqualLinear(2048, 2048, activation='fused_lrelu')
        pairlinear3 = EqualLinear(2048, 1024, activation='fused_lrelu')
        pairlinear4 = EqualLinear(1024, out_dim)
        self.pairlinear = nn.Sequential(pairlinear1, pairlinear2, pairlinear3, pairlinear4)

    def extract_features(self, patches, aggregate=False):
        """Run ``convs`` over every patch.

        Args:
            patches: (B, T, C, H, W) tiles, or a (B, C, H, W) batch that is
                treated as one tile per image (T = 1).
            aggregate: if True, replace each tile's features with the mean over
                the T tiles of the same image.

        Returns:
            Features flattened to (B * T, ...).
        """
        if patches.ndim == 5:
            B, T = patches.size(0), patches.size(1)
            flattened_patches = patches.flatten(0, 1)
        else:
            # One tile per image. T must not be taken from patches.size(1),
            # which is the channel count here.
            B, T = patches.size(0), 1
            flattened_patches = patches
        features = self.convs(flattened_patches)
        features = features.view(B, T, features.size(1), features.size(2), features.size(3))
        if aggregate:
            features = features.mean(1, keepdim=True).expand(-1, T, -1, -1, -1)
        return features.flatten(0, 1)

    def extract_layerwise_features(self, image):
        """Return ``[image]`` followed by the output of every layer in ``convs``."""
        feats = [image]
        for m in self.convs:
            feats.append(m(feats[-1]))

        return feats

    def discriminate_features(self, feature1, feature2):
        """Score the pair (feature1, feature2) by concatenating their flattened features."""
        feature1 = feature1.flatten(1)
        feature2 = feature2.flatten(1)
        out = self.pairlinear(torch.cat([feature1, feature2], dim=1))
        return out

In [None]:
def apply_random_crop(x, target_size, scale_range, num_crops=1, return_rect=False):
    """Resample ``num_crops`` random, randomly mirrored square crops per image.

    Args:
        x: image batch indexed as (B, C, H, W).
        target_size: side length of every output crop.
        scale_range: ``(low, high)`` range for the crop scale in normalized
            [-1, 1] grid coordinates. It is drawn independently for x and y.
        num_crops: number of crops per image.
        return_rect: accepted for interface compatibility; currently ignored.

    Returns:
        Tensor of shape (B, num_crops, C, target_size, target_size).
    """
    device = x.device
    total = x.size(0) * num_crops

    # The RNG draw order (flip, then scale, then offset) determines the
    # result under a fixed seed. Keep it unchanged.
    flip = torch.round(torch.rand(total, 1, 1, 1, device=device)) * 2 - 1.0

    coords = torch.linspace(-1.0, 1.0, target_size, device=device)
    grid_x = coords[None, None, :, None].repeat(total, target_size, 1, 1)
    grid_y = grid_x.transpose(1, 2)
    # a flip of -1 mirrors the x coordinate, i.e. a horizontal flip
    base_grid = torch.cat([grid_x * flip, grid_y], dim=3)

    tiled = x.unsqueeze(1).expand(-1, num_crops, -1, -1, -1).flatten(0, 1)

    low, high = scale_range[0], scale_range[1]
    scale = torch.rand(total, 1, 1, 2, device=device) * (high - low) + low
    # keep the scaled window inside [-1, 1]
    offset = (torch.rand(total, 1, 1, 2, device=device) * 2 - 1) * (1 - scale)

    crops = F.grid_sample(tiled, base_grid * scale + offset, align_corners=False)
    return crops.view(total // num_crops, num_crops, crops.size(1), crops.size(2), crops.size(3))

# BaseModel

In [None]:
class BaseModel(torch.nn.Module):
    """Base class for models driven through a ``command=`` dispatcher.

    Subclasses override the hook methods (``initialize``,
    ``compute_generator_losses``, ...). ``forward`` calls the method named by
    ``command``, which lets every hook be invoked through
    ``torch.nn.parallel.DataParallel``.

    Args:
        opt: options object. This class reads ``num_gpus``,
            ``checkpoints_dir``, ``name``, ``isTrain``, ``pretrained_name``
            and ``resume_iter``.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.device = torch.device('cuda:0') if opt.num_gpus > 0 else torch.device('cpu')

    def initialize(self):
        pass

    def per_gpu_initialize(self):
        pass

    def compute_generator_losses(self, data_i):
        return {}

    def compute_discriminator_losses(self, data_i):
        return {}

    def get_visuals_for_snapshot(self, data_i):
        return {}

    def get_parameters_for_mode(self, mode):
        return {}

    def save(self, total_steps_so_far):
        """Save the state dict as ``<N>k_checkpoint.pth`` and repoint ``latest_checkpoint.pth``.

        ``N`` is ``total_steps_so_far // 1000``. The directory
        ``<checkpoints_dir>/<name>`` is created if it does not exist.
        """
        savedir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
        os.makedirs(savedir, exist_ok=True)
        checkpoint_name = "%dk_checkpoint.pth" % (total_steps_so_far // 1000)
        savepath = os.path.join(savedir, checkpoint_name)
        torch.save(self.state_dict(), savepath)
        sympath = os.path.join(savedir, "latest_checkpoint.pth")
        # lexists (not exists): a dangling link whose target checkpoint was
        # deleted must still be removed, or os.symlink raises FileExistsError.
        if os.path.lexists(sympath):
            os.remove(sympath)
        # relative target so the link survives moving the directory
        os.symlink(checkpoint_name, sympath)

    def load(self):
        """Copy ``<resume_iter>_checkpoint.pth`` into this model's state.

        Behavior:
        - The file is read from ``opt.pretrained_name`` when training with a
          pretrained name set, otherwise from ``opt.name``.
        - A missing file is tolerated in training mode (training starts from
          scratch) and is an assertion failure otherwise.
        - ``D.``/``Dpatch.`` keys are skipped outside training.
        - On a shape mismatch the user is asked whether to force-load
          (yes / no / all). ``all`` applies the force-load to every later
          mismatch without asking.
        """
        if self.opt.isTrain and self.opt.pretrained_name is not None:
            loaddir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name)
        else:
            loaddir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
        checkpoint_name = "%s_checkpoint.pth" % self.opt.resume_iter
        checkpoint_path = os.path.join(loaddir, checkpoint_name)
        if not os.path.exists(checkpoint_path):
            print("\n\ncheckpoint %s does not exist!" % checkpoint_path)
            assert self.opt.isTrain, "In test mode, the checkpoint file must exist"
            print("Training will start from scratch")
            return
        state_dict = torch.load(checkpoint_path,
                                map_location=str(self.device))
        own_state = self.state_dict()
        skip_all = False
        for name, own_param in own_state.items():
            if not self.opt.isTrain and (name.startswith("D.") or name.startswith("Dpatch.")):
                continue
            if name not in state_dict:
                print("Key %s does not exist in checkpoint. Skipping..." % name)
                continue
            param = state_dict[name]
            if own_param.shape == param.shape:
                own_param.copy_(param)
                continue

            message = "Key [%s]: Shape does not match the created model (%s) and loaded checkpoint (%s)" % (name, str(own_param.shape), str(param.shape))
            if skip_all:
                print(message)
            else:
                userinput = input("%s. Force loading? (yes, no, all) " % (message))
                if userinput.lower() == "no":
                    continue
                if userinput.lower() == "all":
                    skip_all = True
                elif userinput.lower() != "yes":
                    raise ValueError(userinput)
            self._copy_overlapping(own_param, param)
        print("checkpoint loaded from %s" % os.path.join(loaddir, checkpoint_name))

    @staticmethod
    def _copy_overlapping(own_param, param):
        """Copy the leading overlapping slice of ``param`` into ``own_param``.

        Only 1-, 2- and 4-D tensors are handled; other ranks are reported and
        skipped. Besides the copy, only the trailing corner of ``own_param``
        (indices beyond the overlap in every dimension) is zeroed. Other
        non-overlapping entries keep their current values.
        """
        ms = [min(s1, s2) for s1, s2 in zip(own_param.shape, param.shape)]
        if len(ms) == 1:
            own_param[:ms[0]].copy_(param[:ms[0]])
            own_param[ms[0]:].copy_(own_param[ms[0]:] * 0)
        elif len(ms) == 2:
            own_param[:ms[0], :ms[1]].copy_(param[:ms[0], :ms[1]])
            own_param[ms[0]:, ms[1]:].copy_(own_param[ms[0]:, ms[1]:] * 0)
        elif len(ms) == 4:
            own_param[:ms[0], :ms[1], :ms[2], :ms[3]].copy_(param[:ms[0], :ms[1], :ms[2], :ms[3]])
            own_param[ms[0]:, ms[1]:, ms[2]:, ms[3]:].copy_(own_param[ms[0]:, ms[1]:, ms[2]:, ms[3]:] * 0)
        else:
            print("Skipping min_shape of %s" % str(ms))

    def forward(self, *args, command=None, **kwargs):
        """ wrapper for multigpu training. BaseModel is expected to be
        wrapped in nn.parallel.DataParallel, which distributes its call to
        the BaseModel instance on each GPU """
        if command is not None:
            method = getattr(self, command)
            assert callable(method), "[%s] is not a method of %s" % (command, type(self).__name__)
            return method(*args, **kwargs)
        else:
            raise ValueError(command)

# Loss

In [None]:
import torchvision

def gan_loss(pred, should_be_classified_as_real):
    """Logistic (softplus) GAN loss, averaged per sample.

    Computes ``softplus(-pred)`` when ``pred`` should be classified as real and
    ``softplus(pred)`` otherwise. The result is averaged over every non-batch
    dimension and has shape ``(pred.size(0),)``.
    """
    logits = -pred if should_be_classified_as_real else pred
    return F.softplus(logits).view(pred.size(0), -1).mean(dim=1)


def feature_matching_loss(xs, ys, equal_weights=False, num_layers=6):
    """Weighted per-sample L1 distance between two lists of feature tensors.

    Only the first ``num_layers`` pairs are compared. Let
    ``n = min(num_layers, len(xs))``. With ``equal_weights`` every pair has
    weight ``1 / n``; otherwise pair ``i`` has weight ``1 / 2 ** (n - i)``, so
    later pairs weigh more.

    Returns:
        A tensor with one entry per sample (dim 0 of the inputs), or the float
        ``0.0`` when no pairs are compared.
    """
    n = min(num_layers, len(xs))
    total = 0.0
    for idx, (a, b) in enumerate(zip(xs[:num_layers], ys[:num_layers])):
        weight = 1.0 / n if equal_weights else 1 / (2 ** (n - idx))
        total = total + (a - b).abs().flatten(1).mean(1) * weight
    return total


class IntraImageNCELoss(nn.Module):
    """Contrastive loss matching spatial locations within each image.

    Up to ``opt.intraimage_num_locations`` locations are sampled from the H*W
    grid; the same locations are used for every image. For each image,
    location k of ``query`` must identify location k of ``target`` among the
    sampled target locations. The loss is cross-entropy over dot products
    divided by a temperature of 0.07.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='mean')

    def forward(self, query, target):
        """
        Args:
            query, target: feature maps indexed as (B, C, H, W). Dot products
                are used directly as similarities, so the inputs are presumably
                L2-normalized along C by the caller — TODO confirm.
        """
        total_locations = query.size(2) * query.size(3)
        num_locations = min(total_locations, self.opt.intraimage_num_locations)
        bs = query.size(0)
        # Sample num_locations distinct locations out of all H*W.
        # randperm(num_locations) would always select the first ones.
        patch_ids = torch.randperm(total_locations, device=query.device)[:num_locations]

        # both query and target become B x C x num_locations
        query = query.flatten(2, 3)[:, :, patch_ids]
        target = target.flatten(2, 3)[:, :, patch_ids]

        # (B * num_locations) rows of logits over num_locations candidates
        cosine_similarity = torch.bmm(query.transpose(1, 2), target).flatten(0, 1)
        target_label = torch.arange(num_locations, dtype=torch.long, device=query.device).repeat(bs)
        loss = self.cross_entropy_loss(cosine_similarity / 0.07, target_label)
        return loss


class VGG16Loss(torch.nn.Module):
    """Weighted L1 distance between VGG16 features of two images.

    Features are collected just before each of the first three max-pooling
    layers of ``vgg16().features``. Each pooling layer is replaced by
    ``Downsample([1, 2, 1], factor=2)``.
    """

    def __init__(self):
        super().__init__()
        self.vgg_convs = torchvision.models.vgg16(pretrained=True).features
        imagenet_mean = torch.tensor([0.485, 0.456, 0.406])[None, :, None, None]
        imagenet_std = torch.tensor([0.229, 0.224, 0.225])[None, :, None, None]
        # NOTE(review): stdev is doubled, which suggests inputs in [-1, 1], but
        # the mean is only shifted by 0.5. Mapping [-1, 1] inputs to ImageNet
        # normalization would need (2 * mean - 1). Confirm the intended input range.
        self.register_buffer('mean', imagenet_mean - 0.5)
        self.register_buffer('stdev', imagenet_std * 2)
        self.downsample = Downsample([1, 2, 1], factor=2)

    def copy_section(self, source, start, end):
        """Return a Sequential holding ``source[start:end]`` under their index names."""
        section = torch.nn.Sequential()
        for idx in range(start, end):
            section.add_module(str(idx), source[idx])
        return section

    def vgg_forward(self, x):
        """Return the (up to three) activations found right before a MaxPool2d."""
        x = (x - self.mean) / self.stdev
        features = []
        for layer in self.vgg_convs.children():
            if type(layer).__name__ != "MaxPool2d":
                x = layer(x)
                continue
            features.append(x)
            if len(features) == 3:
                break
            x = self.downsample(x)
        return features

    def forward(self, x, y):
        """Weighted mean of per-level L1 distances; ``y`` is treated as a constant target."""
        y = y.detach()
        weights = [1 / 32, 1 / 16, 1 / 8, 1 / 4, 1.0]
        loss = 0
        total_weights = 0.0
        for weight, xf, yf in zip(weights, self.vgg_forward(x), self.vgg_forward(y)):
            loss += F.l1_loss(xf, yf) * weight
            total_weights += weight
        return loss / total_weights


class NCELoss(nn.Module):
    """InfoNCE loss with one positive per query and a shared set of negatives.

    Query, target and negatives are flattened per row and passed through
    ``normalize``. Each query row is scored against its own target row (logit
    column 0) and against every negative row, with temperature 0.07.
    """

    def __init__(self):
        super().__init__()
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='mean')

    def forward(self, query, target, negatives):
        q, t, neg = (normalize(v.flatten(1)) for v in (query, target, negatives))

        positive_logit = (q * t).sum(dim=1, keepdim=True)
        negative_logits = torch.mm(q, neg.transpose(0, 1))
        logits = torch.cat([positive_logit, negative_logits], dim=1) / 0.07

        # the positive always sits in column 0
        labels = torch.zeros(q.size(0), dtype=torch.long, device=q.device)
        return self.cross_entropy_loss(logits, labels)


class ScaleInvariantReconstructionLoss(nn.Module):
    """Per-location dot-product reconstruction loss plus a decorrelation term.

    ``forward(query, target)`` takes two maps indexed as (B, C, H, W) and sums:
    * the mean of ``1 - <query, target>`` over all locations (inputs are
      presumably unit-normalized along C by the caller — TODO confirm);
    * the mean positive part of dot products between up to 64 randomly
      sampled target locations of image b and sampled target locations of
      image B-1-b.
    """

    def forward(self, query, target):
        # channels last: (B, C, H, W) -> (B, W, H, C)
        query_flat = query.transpose(1, 3)
        target_flat = target.transpose(1, 3)
        dist = 1.0 - torch.bmm(
            query_flat[:, :, :, None, :].flatten(0, 2),
            target_flat[:, :, :, :, None].flatten(0, 2),
        )

        # (B, W*H, C): one C-dimensional feature per spatial location.
        # target.flatten(1, 2) would merge C with H instead.
        target_spatially_flat = target_flat.flatten(1, 2)
        num_locations = target_spatially_flat.size(1)
        num_samples = min(num_locations, 64)
        # draw from all locations, not only the first num_samples
        random_indices = torch.randperm(num_locations, dtype=torch.long, device=target.device)[:num_samples]
        randomly_sampled = target_spatially_flat[:, random_indices]
        random_indices = torch.randperm(num_locations, dtype=torch.long, device=target.device)[:num_samples]
        another_random_sample = target_spatially_flat[:, random_indices]

        # flipping dim 0 pairs image b with image B-1-b
        random_similarity = torch.bmm(
            randomly_sampled[:, :, None, :].flatten(0, 1),
            torch.flip(another_random_sample, [0])[:, :, :, None].flatten(0, 1)
        )

        return dist.mean() + random_similarity.clamp(min=0.0).mean()


# Swapping Autoencoder Model

In [None]:
class SwappingAutoencoderModel(BaseModel):
    """Swapping Autoencoder: encoder ``E``, generator ``G`` and discriminators.

    Components:
    - ``E(image)`` returns a ``(spatial code, global code)`` pair.
    - ``G(spatial, global)`` decodes such a pair into an image.
    - "mix" images are produced by swapping codes within each pair of the
      minibatch (see ``swap``).
    - ``D`` is built only when ``opt.lambda_GAN > 0``; ``Dpatch`` only when
      ``opt.lambda_PatchGAN > 0``.
    """

    def initialize(self):
        """Build the sub-networks from ``self.opt`` and move them to ``cuda:0`` if GPUs are used."""
        opt = self.opt
        self.E = Encoder(opt)
        self.G = Generator(opt)
        if opt.lambda_GAN > 0.0:
            self.D = Discriminator(opt)
        if opt.lambda_PatchGAN > 0.0:
            self.Dpatch = PatchDiscriminator(opt)

        # Count the iteration count of the discriminator
        # Used for lazy R1 regularization (c.f. Appendix B of StyleGAN2)
        self.register_buffer(
            "num_discriminator_iters", torch.zeros(1, dtype=torch.long)
        )
        self.l1_loss = torch.nn.L1Loss()

        # NOTE: self.load() is not called here; restore checkpoints explicitly if needed.

        if opt.num_gpus > 0:
            self.to("cuda:0")

    def per_gpu_initialize(self):
        pass

    def get_random_crops(self, x):
        """ Make random crops.
            Corresponds to the yellow and blue random crops of Figure 2.
        """
        crops = apply_random_crop(
            x, self.opt.patch_size,
            (self.opt.patch_min_scale, self.opt.patch_max_scale),
            num_crops=self.opt.patch_num_crops
        )
        return crops

    def swap(self, x):
        """Exchange the elements of each consecutive pair (0<->1, 2<->3, ...) along dim 0."""
        shape = x.shape
        assert shape[0] % 2 == 0, "Minibatch size must be a multiple of 2"
        new_shape = [shape[0] // 2, 2] + list(shape[1:])
        x = x.view(*new_shape)
        x = torch.flip(x, [1])
        return x.view(*shape)

    def compute_image_discriminator_losses(self, real, rec, mix):
        """Image-level D losses; returns {} when the image discriminator is disabled."""
        # self.D only exists when lambda_GAN > 0 (see initialize).
        if self.opt.lambda_GAN <= 0.0:
            return {}

        pred_real = self.D(real)
        pred_rec = self.D(rec)
        pred_mix = self.D(mix)

        losses = {}
        losses["D_real"] = gan_loss(
            pred_real, should_be_classified_as_real=True
        ) * self.opt.lambda_GAN

        losses["D_rec"] = gan_loss(
            pred_rec, should_be_classified_as_real=False
        ) * (0.5 * self.opt.lambda_GAN)
        losses["D_mix"] = gan_loss(
            pred_mix, should_be_classified_as_real=False
        ) * (0.5 * self.opt.lambda_GAN)

        return losses

    def compute_patch_discriminator_losses(self, real, mix):
        """Patch D losses: (real, real) crop pairs are real, (real, mix) pairs are fake."""
        losses = {}
        real_feat = self.Dpatch.extract_features(
            self.get_random_crops(real),
            aggregate=self.opt.patch_use_aggregation
        )
        target_feat = self.Dpatch.extract_features(self.get_random_crops(real))
        mix_feat = self.Dpatch.extract_features(self.get_random_crops(mix))

        losses["PatchD_real"] = gan_loss(
            self.Dpatch.discriminate_features(real_feat, target_feat),
            should_be_classified_as_real=True,
        ) * self.opt.lambda_PatchGAN

        losses["PatchD_mix"] = gan_loss(
            self.Dpatch.discriminate_features(real_feat, mix_feat),
            should_be_classified_as_real=False,
        ) * self.opt.lambda_PatchGAN

        return losses

    def compute_discriminator_losses(self, real):
        """Return ``(losses, metrics, sp.detach(), gl.detach())`` for a discriminator step."""
        self.num_discriminator_iters.add_(1)

        sp, gl = self.E(real)
        B = real.size(0)
        assert B % 2 == 0, "Batch size must be even on each GPU."

        # To save memory, compute the GAN loss on only
        # half of the reconstructed images
        rec = self.G(sp[:B // 2], gl[:B // 2])
        mix = self.G(self.swap(sp), gl)

        losses = self.compute_image_discriminator_losses(real, rec, mix)

        if self.opt.lambda_PatchGAN > 0.0:
            patch_losses = self.compute_patch_discriminator_losses(real, mix)
            losses.update(patch_losses)

        metrics = {}  # no metrics to report for the Discriminator iteration

        return losses, metrics, sp.detach(), gl.detach()

    def compute_R1_loss(self, real):
        """R1 gradient penalty on real images (``D``) and real crops (``Dpatch``); returns {"D_R1": ...}."""
        losses = {}
        if self.opt.lambda_R1 > 0.0:
            real.requires_grad_()
            pred_real = self.D(real).sum()
            grad_real, = torch.autograd.grad(
                outputs=pred_real,
                inputs=[real],
                create_graph=True,
                retain_graph=True,
            )
            grad_real2 = grad_real.pow(2)
            dims = list(range(1, grad_real2.ndim))
            grad_penalty = grad_real2.sum(dims) * (self.opt.lambda_R1 * 0.5)
        else:
            grad_penalty = 0.0

        if self.opt.lambda_patch_R1 > 0.0:
            real_crop = self.get_random_crops(real).detach()
            real_crop.requires_grad_()
            target_crop = self.get_random_crops(real).detach()
            target_crop.requires_grad_()

            real_feat = self.Dpatch.extract_features(
                real_crop,
                aggregate=self.opt.patch_use_aggregation)
            target_feat = self.Dpatch.extract_features(target_crop)
            pred_real_patch = self.Dpatch.discriminate_features(
                real_feat, target_feat
            ).sum()

            grad_real, grad_target = torch.autograd.grad(
                outputs=pred_real_patch,
                inputs=[real_crop, target_crop],
                create_graph=True,
                retain_graph=True,
            )

            dims = list(range(1, grad_real.ndim))
            grad_crop_penalty = grad_real.pow(2).sum(dims) + grad_target.pow(2).sum(dims)
            grad_crop_penalty *= (0.5 * self.opt.lambda_patch_R1 * 0.5)
        else:
            grad_crop_penalty = 0.0

        losses["D_R1"] = grad_penalty + grad_crop_penalty

        return losses

    def compute_generator_losses(self, real, sp_ma, gl_ma):
        """Return ``(losses, metrics)`` for a generator/encoder step.

        ``sp_ma`` and ``gl_ma`` are accepted but not used here.
        """
        losses, metrics = {}, {}
        B = real.size(0)

        sp, gl = self.E(real)
        rec = self.G(sp[:B // 2], gl[:B // 2])  # only on B//2 to save memory
        sp_mix = self.swap(sp)

        # record the error of the reconstructed images for monitoring purposes
        metrics["L1_dist"] = self.l1_loss(rec, real[:B // 2])

        if self.opt.lambda_L1 > 0.0:
            losses["G_L1"] = metrics["L1_dist"] * self.opt.lambda_L1

        if self.opt.crop_size >= 1024:
            # another memory-saving trick: reduce #outputs to save memory
            real = real[B // 2:]
            gl = gl[B // 2:]
            sp_mix = sp_mix[B // 2:]

        mix = self.G(sp_mix, gl)

        if self.opt.lambda_GAN > 0.0:
            losses["G_GAN_rec"] = gan_loss(
                self.D(rec),
                should_be_classified_as_real=True
            ) * (self.opt.lambda_GAN * 0.5)

            losses["G_GAN_mix"] = gan_loss(
                self.D(mix),
                should_be_classified_as_real=True
            ) * (self.opt.lambda_GAN * 1.0)

        if self.opt.lambda_PatchGAN > 0.0:
            real_feat = self.Dpatch.extract_features(
                self.get_random_crops(real),
                aggregate=self.opt.patch_use_aggregation).detach()
            mix_feat = self.Dpatch.extract_features(self.get_random_crops(mix))

            losses["G_mix"] = gan_loss(
                self.Dpatch.discriminate_features(real_feat, mix_feat),
                should_be_classified_as_real=True,
            ) * self.opt.lambda_PatchGAN

        return losses, metrics

    def get_visuals_for_snapshot(self, real):
        """Return real images, spatial-code layout, reconstructions and global-code swaps."""
        if self.opt.isTrain:
            # avoid the overhead of generating too many visuals during training
            real = real[:2] if self.opt.num_gpus > 1 else real[:4]
        sp, gl = self.E(real)
        layout = resize2d_tensor(visualize_spatial_code(sp), real)
        rec = self.G(sp, gl)
        mix = self.G(sp, self.swap(gl))

        visuals = {"real": real, "layout": layout, "rec": rec, "mix": mix}

        return visuals

    def fix_noise(self, sample_image=None):
        if sample_image is not None:
            # The generator should be run at least once,
            # so that the noise dimensions could be computed
            sp, gl = self.E(sample_image)
            self.G(sp, gl)
        noise_var = self.G.fix_and_gather_noise_parameters()
        return noise_var

    def encode(self, image, extract_features=False):
        return self.E(image, extract_features=extract_features)

    def decode(self, spatial_code, global_code):
        return self.G(spatial_code, global_code)

    def get_parameters_for_mode(self, mode):
        """Parameters optimized in ``mode`` ("generator" or "discriminator")."""
        if mode == "generator":
            return list(self.G.parameters()) + list(self.E.parameters())
        elif mode == "discriminator":
            Dparams = []
            if self.opt.lambda_GAN > 0.0:
                Dparams += list(self.D.parameters())
            if self.opt.lambda_PatchGAN > 0.0:
                Dparams += list(self.Dpatch.parameters())
            return Dparams

# Optimizer

In [None]:
class MultiGPUModelWrapper():
    """Hold a model both wrapped in ``torch.nn.parallel.DataParallel`` and unwrapped.

    Parameter queries and checkpoint saving go to the unwrapped module
    (``singlegpu_model``); calling the wrapper forwards through DataParallel
    (``parallelized_model``).
    """

    def __init__(self, opt, model: BaseModel):
        self.opt = opt
        # Put the model on the first GPU when GPUs are configured; otherwise leave it as is.
        placed = model.to('cuda:0') if opt.num_gpus > 0 else model
        self.parallelized_model = torch.nn.parallel.DataParallel(placed)
        # The "per_gpu_initialize" command is issued twice: once through
        # DataParallel, then once more directly on the underlying module.
        self.parallelized_model(command="per_gpu_initialize")
        self.singlegpu_model = self.parallelized_model.module
        self.singlegpu_model(command="per_gpu_initialize")

    def get_parameters_for_mode(self, mode):
        """Delegate to the unwrapped model's ``get_parameters_for_mode``."""
        inner = self.singlegpu_model
        return inner.get_parameters_for_mode(mode)

    def save(self, total_steps_so_far):
        """Delegate checkpoint saving to the unwrapped model."""
        inner = self.singlegpu_model
        inner.save(total_steps_so_far)

    def __call__(self, *args, **kwargs):
        """Forward the call through DataParallel to the wrapped model.

        The ``command=...`` keyword selects which model method runs; see the
        model's ``forward`` for the dispatch.
        """
        parallel = self.parallelized_model
        return parallel(*args, **kwargs)

In [None]:
class BaseOptimizer():
    """Minimal optimizer interface; subclasses override the hooks below."""

    def __init__(self, model: MultiGPUModelWrapper):
        # Only the options object is kept here; subclasses store the model themselves.
        self.opt = model.opt

    def train_one_step(self, data_i, total_steps_so_far):
        """Run one optimization step. The base implementation does nothing."""
        return None

    def get_visuals_for_snapshot(self, data_i):
        """Return a dict of images to visualize; empty in the base class."""
        return dict()

    def save(self, total_steps_so_far):
        """Persist optimizer state. The base implementation is a no-op."""
        return None

In [None]:
class SwappingAutoencoderOptimizer(BaseOptimizer):
    """Alternating Adam updates for the swapping autoencoder.

    The generator side (generator + encoder parameters) and the discriminator
    side (image and/or patch discriminator parameters) each get their own Adam
    optimizer. ``train_one_step`` follows the schedule returned by
    ``toggle_training_mode`` (``["discriminator", "generator", "generator"]``),
    i.e. one discriminator update for every two generator updates.

    ``model`` may be a ``MultiGPUModelWrapper`` or any object exposing ``opt``,
    ``get_parameters_for_mode(mode)`` and a call interface taking ``command=...``.
    """

    def __init__(self, model: MultiGPUModelWrapper):
        super().__init__(model)  # sets self.opt = model.opt
        opt = self.opt
        self.model = model
        self.train_mode_counter = 0
        self.discriminator_iter_counter = 0

        self.Gparams = self.model.get_parameters_for_mode("generator")
        self.Dparams = self.model.get_parameters_for_mode("discriminator")

        self.optimizer_G = torch.optim.Adam(
            self.Gparams, lr=opt.lr, betas=(opt.beta1, opt.beta2)
        )

        # Lazy R1 regularization, c.f. StyleGAN2 (https://arxiv.org/abs/1912.04958)
        # Appendix B: when R1 runs once every k steps, lr and betas are rescaled
        # by c = k / (k + 1).
        c = opt.R1_once_every / (1 + opt.R1_once_every)
        # With both GAN losses disabled there are no discriminator parameters, and
        # torch.optim.Adam rejects an empty parameter list. In that configuration
        # train_discriminator_one_step returns before touching optimizer_D.
        self.optimizer_D = (
            torch.optim.Adam(
                self.Dparams, lr=opt.lr * c, betas=(opt.beta1 ** c, opt.beta2 ** c)
            )
            if self.Dparams
            else None
        )

    def set_requires_grad(self, params, requires_grad):
        """ For more efficient optimization, turn on and off
            recording of gradients for |params|.
        """
        for p in params:
            p.requires_grad_(requires_grad)

    def toggle_training_mode(self):
        """Advance the schedule and return the mode for the next step.

        The cycle is discriminator, generator, generator; because the counter is
        advanced before indexing, the first call returns "generator".
        """
        modes = ["discriminator", "generator", "generator"]
        self.train_mode_counter = (self.train_mode_counter + 1) % len(modes)
        return modes[self.train_mode_counter]

    def train_one_step(self, images_minibatch, total_steps_so_far=None):
        """Run one step of whichever network the schedule selects.

        Args:
            images_minibatch: image batch passed to the model's loss commands.
            total_steps_so_far: unused; accepted for compatibility with
                ``BaseOptimizer.train_one_step``.

        Returns:
            The step's loss/metric dict converted with ``to_numpy``.
        """
        # The mode name must match the network being trained; otherwise the
        # 2:1 generator:discriminator schedule above would run inverted.
        if self.toggle_training_mode() == "discriminator":
            losses = self.train_discriminator_one_step(images_minibatch)
        else:
            losses = self.train_generator_one_step(images_minibatch)
        return to_numpy(losses)

    def train_generator_one_step(self, images):
        """One Adam step on the generator/encoder; returns losses merged with metrics."""
        # Freeze D so only G/E parameters accumulate gradients.
        self.set_requires_grad(self.Dparams, False)
        self.set_requires_grad(self.Gparams, True)
        sp_ma, gl_ma = None, None
        self.optimizer_G.zero_grad()
        g_losses, g_metrics = self.model(
            images, sp_ma, gl_ma, command="compute_generator_losses"
        )
        g_loss = sum([v.mean() for v in g_losses.values()])
        g_loss.backward()
        self.optimizer_G.step()
        g_losses.update(g_metrics)
        return g_losses

    def train_discriminator_one_step(self, images):
        """One Adam step on the discriminator(s), plus a lazy R1 step when due.

        Returns an empty dict when both GAN loss weights are 0. Otherwise returns
        the discriminator losses (plus R1 losses when computed this step), a
        ``"D_total"`` entry with the sum of their means, and the metrics.
        """
        if self.opt.lambda_GAN == 0.0 and self.opt.lambda_PatchGAN == 0.0:
            return {}
        # Freeze G/E so only discriminator parameters accumulate gradients.
        self.set_requires_grad(self.Dparams, True)
        self.set_requires_grad(self.Gparams, False)
        self.discriminator_iter_counter += 1
        self.optimizer_D.zero_grad()
        d_losses, d_metrics, sp, gl = self.model(
            images, command="compute_discriminator_losses"
        )
        self.previous_sp = sp.detach()
        self.previous_gl = gl.detach()
        d_loss = sum([v.mean() for v in d_losses.values()])
        d_loss.backward()
        self.optimizer_D.step()

        needs_R1 = self.opt.lambda_R1 > 0.0 or self.opt.lambda_patch_R1 > 0.0
        needs_R1_at_current_iter = needs_R1 and self.discriminator_iter_counter % self.opt.R1_once_every == 0
        if needs_R1_at_current_iter:
            self.optimizer_D.zero_grad()
            r1_losses = self.model(images, command="compute_R1_loss")
            d_losses.update(r1_losses)
            r1_loss = sum([v.mean() for v in r1_losses.values()])
            # Scale up to compensate for running R1 only every R1_once_every steps.
            r1_loss = r1_loss * self.opt.R1_once_every
            r1_loss.backward()
            self.optimizer_D.step()

        d_losses["D_total"] = sum([v.mean() for v in d_losses.values()])
        d_losses.update(d_metrics)
        return d_losses

    def get_visuals_for_snapshot(self, data_i):
        """Return the model's snapshot visuals for an image batch, without gradients.

        ``data_i`` is the same image minibatch tensor that ``train_one_step``
        receives; it is passed to the model unchanged.
        """
        with torch.no_grad():
            return self.model(data_i, command="get_visuals_for_snapshot")

    def save(self, total_steps_so_far):
        """Delegate saving to the model."""
        self.model.save(total_steps_so_far)

In [None]:
def count_parameters(model):
    """Return the total element count of ``model.parameters()``, in millions."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total / 1000000

In [None]:
# Instantiate the networks, load pretrained weights, and build the optimizer.
# NOTE(review): `base` is not used anywhere in this chunk; it may exist only for
# BaseModel's constructor side effects — confirm before removing.
base = BaseModel(opt)
model = SwappingAutoencoderModel(opt)
model.initialize()
# Restore weights from the ffhq_670k checkpoint; load_state_dict is strict by
# default, so the checkpoint keys must match this model's.
param = torch.load("/kaggle/input/ffhq-670k/ffhq_670k.pth")
model.load_state_dict(param)
# The optimizer receives the bare model (no MultiGPUModelWrapper); it only uses
# `opt`, `get_parameters_for_mode`, and the `command=` call interface.
optimizer = SwappingAutoencoderOptimizer(model)

# Report each sub-network's size in millions of parameters.
print(f"Number of Encoder parameters:             {count_parameters(model.E):.2f}M")
print(f"Number of Generator parameters:           {count_parameters(model.G):.2f}M")
print(f"Number of Discriminator parameters:       {count_parameters(model.D):.2f}M")
print(f"Number of Patch Discriminator parameters: {count_parameters(model.Dpatch):.2f}M")

# Train

In [None]:
class MetricTracker:
    """Keep an exponentially smoothed running value for each training metric."""

    def __init__(self, opt):
        self.opt = opt
        self.metrics = {}  # metric name -> latest (possibly smoothed) value

    def moving_average(self, old, new):
        """Blend ``new`` into ``old`` with weights 0.98 (old) and 0.02 (new)."""
        decay = 0.98
        return old * (decay) + new * (1 - decay)

    def update_metrics(self, metric_dict, smoothe=True):
        """Merge ``metric_dict`` into the tracked metrics.

        With ``smoothe`` true, keys that are already tracked are blended in via
        ``moving_average``; new keys, or any key when ``smoothe`` is false, are
        stored as given. ``"D_R1"`` is never smoothed (presumably because the
        optimizer only computes R1 every ``R1_once_every`` discriminator steps).
        """
        for name, value in metric_dict.items():
            blend = smoothe and name != "D_R1"
            if blend and name in self.metrics:
                self.metrics[name] = self.moving_average(self.metrics[name], value)
            else:
                self.metrics[name] = value

    def current_metrics(self):
        """Return the tracked metrics as an ``OrderedDict`` sorted by name."""
        return OrderedDict((name, self.metrics[name]) for name in sorted(self.metrics))
    
# Shared tracker that smooths the loss dicts returned by the training loop.
metric_tracker = MetricTracker(opt=opt)

In [None]:
def print_current_losses(iters, losses):
    """Print a one-line summary of the current training losses.

    Parameters:
        iters (int) -- global training iteration (the training loop increments it
            every step and never resets it between epochs)
        losses (Mapping[str, array-like or number]) -- metric name -> value.
            Values with a ``.mean()`` method (NumPy arrays, tensors) are averaged
            first; plain numbers are printed as they are.

    The line has the form ``iters: <n>, <name>: <value:.3f>, ...`` and keeps the
    trailing ", " after the last entry.
    """
    parts = [f"iters: {int(iters)}, "]
    for name, value in losses.items():
        # Reduce array-like values to a scalar for display.
        scalar = value.mean() if hasattr(value, "mean") else value
        parts.append(f"{name}: {float(scalar):.3f}, ")
    print("".join(parts))
    
    
def display_current_results(model, real, num_samples=2):
    """Display reconstructions and style swaps for the first images of a batch.

    For each of the first ``num_samples`` images (capped at the batch size), the
    reconstruction ``G(sp, gl)`` is displayed, followed by the mix
    ``G(sp, swap(gl))``.

    Args:
        model: object exposing ``E`` (encoder), ``G`` (generator) and ``swap``.
        real: image batch to encode.
        num_samples: how many images of the batch to show (default 2).
    """
    # Visualization only: skip building an autograd graph so no extra GPU
    # memory is held during training.
    with torch.no_grad():
        sp, gl = model.E(real)
        rec = model.G(sp, gl)
        mix = model.G(sp, model.swap(gl))

    for i in range(min(num_samples, len(rec))):
        display(to_pil_image(rec[i]))
        display(to_pil_image(mix[i]))

In [None]:
import time
from torchvision.transforms.functional import to_tensor

# Training schedule and counters used by the training-loop cell.
epochs, iters = 50, 0  # passes over Cartoon_Dataloader; global step counter

start_time = time.time()  # wall-clock reference for elapsed-time reports

In [None]:
# anime_path = "/kaggle/input/test-jin/Test/J (1).png"
# human_path = "/kaggle/input/test-jin/Test/jin (13).png"

# anime_img = Image.open(anime_path)
# human_img = Image.open(human_path)

# anime_tensor = torch.unsqueeze(to_tensor(anime_img), 0).to("cuda")
# human_tensor = torch.unsqueeze(to_tensor(human_img), 0).to("cuda")

# anime_sp, anime_gl = model.E(anime_tensor)
# human_sp, human_gl = model.E(human_tensor)

# aa_rec = model.G(anime_sp, anime_gl)
# ah_rec = model.G(anime_sp, human_gl)
# hh_rec = model.G(human_sp, human_gl)
# ha_rec = model.G(human_sp, anime_gl)

# display(to_pil_image(aa_rec[0]))
# display(to_pil_image(ah_rec[0]))
# display(to_pil_image(hh_rec[0]))
# display(to_pil_image(ha_rec[0]))

# del anime_sp, anime_gl, human_sp, human_gl, aa_rec, ah_rec, hh_rec, ha_rec

In [None]:
def load_image_batch(batch):
    """Open the image paths of one dataloader batch and stack them into a tensor.

    ``batch`` is the (ffhq_paths, anime_paths) pair yielded by
    ``Cartoon_Dataloader``; images are interleaved as ffhq[0], anime[0],
    ffhq[1], anime[1], ... Each file is converted to RGB, so stray RGBA or
    grayscale files stack cleanly, and is closed right after reading. Works for
    a short final batch as well as full ones.
    """
    tensors = []
    for pair in zip(*batch):
        for path in pair:
            with Image.open(path) as img:
                tensors.append(to_tensor(img.convert("RGB")))
    return torch.stack(tensors)


for epoch in range(epochs):
    for batch in Cartoon_Dataloader:
        img_batch = load_image_batch(batch).to('cuda')

        losses = optimizer.train_one_step(img_batch)
        metric_tracker.update_metrics(losses, smoothe=True)

        iters += 1
        if iters % opt.print_freq == 0:
            print_current_losses(iters, metric_tracker.current_metrics())

        if iters % opt.display_freq == 0:
            display_current_results(model, img_batch)
            # Elapsed wall-clock time in minutes, rounded to two decimals.
            print(round((time.time() - start_time) / 60, 2))

        if iters % opt.save_freq == 0:
            torch.save(model.state_dict(), "./ffhq_690k.pth")