# Imports and Setup

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optimizers
import torch.nn.functional as F
import torchvision
from torchvision import transforms, datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
%matplotlib inline
import statistics
from tqdm import tqdm
import pickle

# Run on the current CUDA device when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Generator

## Pixel Norm

In [None]:
class PixelNorm(nn.Module):
    """Pixel-wise feature normalization (ProGAN style).

    Each position's feature vector along dim 1 is divided by its root mean
    square, i.e. ``x / sqrt(mean(x**2, dim=1) + 1e-8)``. The epsilon keeps
    the division finite when all features at a position are zero.
    """

    def __init__(self):
        super(PixelNorm, self).__init__()
        self.epsilon = 1e-8  # numerical floor inside the square root

    def forward(self, x):
        # Mean of squares over dim 1, kept as a size-1 axis so it broadcasts back onto x.
        mean_sq = x.pow(2).mean(dim=1, keepdim=True)
        return x / (mean_sq + self.epsilon).sqrt()

## Weight-Scaled Convolution (Equalized Learning Rate)

In [None]:
class WSConv2d(nn.Module):
    """Convolution with equalized learning rate (weight-scaled convolution).

    Weights are initialised from N(0, 1), and the constant
    ``sqrt(gain / (in_channels * kernel_size**2))`` is applied on every
    forward pass. It is applied to the input, which for a linear convolution
    is equivalent to scaling the weights. The bias is detached from the
    wrapped convolution so that it is added unscaled.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: square kernel size as an int (it is squared in the scale).
        stride: convolution stride.
        padding: forwarded to the wrapped layer. PyTorch only accepts
            ``"same"`` for non-strided ``nn.Conv2d``, so pass an explicit
            int when ``stride != 1`` or ``conv=False``.
        gain: numerator of the scale constant.
        conv: ``True`` wraps ``nn.Conv2d``; ``False`` wraps ``nn.ConvTranspose2d``.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding="same", gain=2, conv = True):
        super(WSConv2d, self).__init__()
        if conv:
            layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding)
        else:
            layer = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
        self.conv = layer

        fan_in = in_channels * kernel_size ** 2
        self.scale = (gain / fan_in) ** 0.5

        # Take ownership of the bias so the run-time scale touches the weights only.
        self.bias = self.conv.bias
        self.conv.bias = None

        # Unit-variance weights; the scale above supplies the He-style normalisation.
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        per_channel_bias = self.bias.view(1, self.bias.shape[0], 1, 1)
        return self.conv(x * self.scale) + per_channel_bias

## Encoder Block

In [None]:
class EncoderBlock(nn.Module):
    """Residual encoder stage: two 3x3 weight-scaled convs, then optional 2x downsampling.

    Computes ``a = leaky(norm(conv1(x)))`` and ``b = norm(conv2(a))``, where
    ``norm`` is PixelNorm when ``use_pixelnorm`` is set and the identity
    otherwise. It then returns ``leaky(a + b)``, average-pooled with a 2x2
    window when ``downscale`` is set. The convs use ``"same"`` padding, so
    only the pooling changes the spatial size.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        use_pixelnorm: apply PixelNorm after each convolution.
        downscale: halve H and W with ``AvgPool2d(2, 2)`` at the end.
    """

    def __init__(self, in_channels, out_channels, use_pixelnorm=True, downscale = True):
        super(EncoderBlock, self).__init__()
        self.use_pn = use_pixelnorm
        self.use_downscaling = downscale
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()
        self.downscaled = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)

    def _maybe_norm(self, t):
        """Apply PixelNorm if enabled for this block, else return ``t`` unchanged."""
        return self.pn(t) if self.use_pn else t

    def forward(self, x):
        first = self.leaky(self._maybe_norm(self.conv1(x)))
        second = self._maybe_norm(self.conv2(first))
        merged = self.leaky(second + first)  # residual connection around conv2
        if not self.use_downscaling:
            return merged
        return self.downscaled(merged)

## Decoder Block

In [None]:
class DecoderBlock(nn.Module):
    """Residual decoder stage: optional 2x upsampling, then two 3x3 weight-scaled convs.

    When ``upscale`` is set, the input is first enlarged 2x with
    nearest-neighbour upsampling, mirroring the 2x average pooling in
    ``EncoderBlock``. The block then computes ``a = leaky(norm(conv1(x)))``
    and ``b = norm(conv2(a))``, where ``norm`` is PixelNorm when
    ``use_pixelnorm`` is set and the identity otherwise, and returns
    ``leaky(a + b)``. The convs use ``"same"`` padding, so only the
    upsampling changes the spatial size.

    Args:
        in_channels: channels of the (possibly concatenated) input tensor.
        out_channels: channels produced by both convolutions.
        use_pixelnorm: apply PixelNorm after each convolution.
        upscale: double H and W before the convolutions.
    """

    def __init__(self, in_channels, out_channels, use_pixelnorm=True, upscale=True):
        super(DecoderBlock, self).__init__()
        self.use_pn = use_pixelnorm
        self.use_upscaling = upscale
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()
        # forward() relies on this module; it was previously never defined,
        # so the default upscale=True raised AttributeError. It has no
        # parameters, so existing state dicts are unaffected.
        self.upscaled = nn.Upsample(scale_factor=2, mode="nearest")

    def forward(self, x):
        x = self.upscaled(x) if self.use_upscaling else x
        x1 = self.leaky(self.pn(self.conv1(x))) if self.use_pn else self.leaky(self.conv1(x))
        x2 = self.pn(self.conv2(x1)) if self.use_pn else self.conv2(x1)
        out = self.leaky(x2 + x1)  # residual connection around conv2
        return out

## U-Net Generator (Autoencoder)

In [None]:
class Generator(nn.Module):
    """U-Net generator built from equalized-learning-rate convolutions.

    Encoder: ``enc1`` is a stride-2 4x4 conv, and ``enc2``..``enc8`` each
    halve H and W with average pooling. That is eight halvings, so a
    256x256 input reaches a 1x1 bottleneck.

    Decoder: ``dec1``..``dec7`` each double H and W with upsampling, and
    ``dec8`` is a stride-2 transposed conv followed by ``tanh``. ``dec1``
    receives the bottleneck alone. ``dec2``..``dec8`` receive the previous
    decoder output concatenated (channel axis) with the mirrored encoder
    output, which is why their input widths are doubled.

    H and W must be multiples of 256 so that every upsampled decoder output
    has the same spatial size as the encoder skip it is concatenated with.
    The shape comments below are for a 256x256 input.

    Args:
        in_channels: channels of the input image, also used as the number of
            output channels.
        features: base channel width; deeper stages use 2x, 4x and 8x it.

    ``forward`` returns a tensor with ``in_channels`` channels and the input's
    spatial size, with values in (-1, 1) from ``tanh``.
    """

    def __init__(self, in_channels=3, features=64):
        super().__init__()

        self.enc1 = nn.Sequential(
            # Was ``nn.WSConv2d``, which does not exist; WSConv2d is defined in this file.
            WSConv2d(in_channels, features, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        )  # out = (128*128, 64)

        self.enc2 = EncoderBlock(features, features * 2)  # out = (64*64, 128)
        self.enc3 = EncoderBlock(features * 2, features * 4)  # out = (32*32, 256)
        self.enc4 = EncoderBlock(features * 4, features * 8)  # out = (16*16, 512)
        self.enc5 = EncoderBlock(features * 8, features * 8)  # out = (8*8, 512)
        self.enc6 = EncoderBlock(features * 8, features * 8)  # out = (4*4, 512)
        self.enc7 = EncoderBlock(features * 8, features * 8)  # out = (2*2, 512)
        self.enc8 = EncoderBlock(features * 8, features * 8)  # out = (1*1, 512) bottleneck

        # dec1 sees only the bottleneck (features*8 channels). It previously
        # declared features*8*2, which does not match that tensor.
        self.dec1 = DecoderBlock(features * 8, features * 8)  # out = (2*2, 512)

        # dec2 onward: input width doubled to hold the U-Net skip concatenation.
        self.dec2 = DecoderBlock(features * 8 * 2, features * 8)  # out = (4*4, 512)
        self.dec3 = DecoderBlock(features * 8 * 2, features * 8)  # out = (8*8, 512)
        self.dec4 = DecoderBlock(features * 8 * 2, features * 8)  # out = (16*16, 512)
        self.dec5 = DecoderBlock(features * 8 * 2, features * 4)  # out = (32*32, 256)
        self.dec6 = DecoderBlock(features * 4 * 2, features * 2)  # out = (64*64, 128)
        self.dec7 = DecoderBlock(features * 2 * 2, features)  # out = (128*128, 64)

        self.dec8 = nn.Sequential(
            # Transposed conv: k=4, s=2, p=1 doubles H and W back to the input size.
            WSConv2d(features * 2, in_channels, kernel_size=4, stride=2, padding=1, conv=False),
            nn.Tanh(),
        )

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        e5 = self.enc5(e4)
        e6 = self.enc6(e5)
        e7 = self.enc7(e6)

        bottleneck = self.enc8(e7)

        d1 = self.dec1(bottleneck)
        d2 = self.dec2(torch.cat([d1, e7], 1))
        d3 = self.dec3(torch.cat([d2, e6], 1))
        d4 = self.dec4(torch.cat([d3, e5], 1))
        d5 = self.dec5(torch.cat([d4, e4], 1))
        d6 = self.dec6(torch.cat([d5, e3], 1))
        d7 = self.dec7(torch.cat([d6, e2], 1))

        # Previously ``torch.cat([[d7, e1], 1)``: an unbalanced bracket (syntax error).
        return self.dec8(torch.cat([d7, e1], 1))

# Discriminator (PatchGAN)

In [None]:
class CNNBlock(nn.Module):
    """Discriminator stage: 4x4 conv (reflect padding 1, no bias) -> BatchNorm2d -> LeakyReLU(0.2).

    The conv has no bias because the following BatchNorm2d would cancel it.
    With ``stride=2`` the spatial size is halved (for even inputs); with
    ``stride=1`` it shrinks by one.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the conv and normalised by BatchNorm.
        stride: stride of the 4x4 convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super(CNNBlock, self).__init__()
        conv_layer = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=False,
            padding_mode="reflect",
        )
        norm_layer = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2)
        # Attribute name and layer order are kept so state-dict keys (conv.0, conv.1, ...) still match.
        self.conv = nn.Sequential(conv_layer, norm_layer, activation)

    def forward(self, x):
        return self.conv(x)


class Discriminator(nn.Module):
    """PatchGAN discriminator that scores an image pair patch by patch.

    ``x`` and ``y`` are concatenated along the channel axis, so each must
    have ``in_channels`` channels. The pair is mapped to a single-channel
    grid of raw logits; no sigmoid is applied here.

    Architecture:
        * ``initial``: stride-2 4x4 conv with reflect padding, then LeakyReLU(0.2).
        * ``model``: one CNNBlock per remaining entry of ``features``, with
          stride 2 everywhere except the last stage, which uses stride 1.
          The stack ends with a stride-1 4x4 conv to 1 channel.

    For the default ``features``, an HxW input (H, W multiples of 8) yields
    an (H/8 - 2) x (W/8 - 2) map, e.g. 30x30 for 256x256.

    Args:
        in_channels: channels of each of the two input images.
        features: output widths of the successive stages (any sequence).
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        # Tuple default avoids a shared mutable default; copy so any iterable works.
        features = list(features)
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels * 2,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )

        layers = []
        prev_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Pick the last stage by position, not by value: comparing to
            # features[-1] would also give stride 1 to any earlier stage of equal width.
            stride = 1 if index == last_index else 2
            layers.append(CNNBlock(prev_channels, feature, stride=stride))
            prev_channels = feature

        layers.append(
            nn.Conv2d(
                prev_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
            ),
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        x = torch.cat([x, y], dim=1)
        x = self.initial(x)
        x = self.model(x)
        return x  # (N, 1, H/8 - 2, W/8 - 2) for default features; 30x30 for 256x256 input