In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# ProGAN building blocks implemented in this notebook:
# - Progressive growing of layers
# - Minibatch standard-deviation layer (Discriminator)
# - Equalized learning rate
# - Pixelwise feature normalization (Generator)

In [None]:
# Pixelwise normalization: x / torch.sqrt(torch.mean(torch.square(x), dim=1, keepdim=True) + 1e-8)

class PixelWiseNormalization(nn.Module):
    """Rescale every position's feature vector to (approximately) unit RMS.

    Computes ``x / sqrt(mean(x**2 over dim 1) + eps)``; for NCHW inputs dim 1
    is the channel dimension, so each spatial location is normalized on its own.

    Args:
        eps: small constant added under the square root for numerical stability.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor):
        mean_sq = x.pow(2).mean(dim=1, keepdim=True)
        rms = (mean_sq + self.eps).sqrt()
        return x / rms

# x = torch.randn(32, 3, 32, 32)
# model = PixelWiseNormalization()
# model(x).size()

In [None]:
class EqualizedConv2d(nn.Module):
    """Conv2d with ProGAN's equalized learning rate.

    Weights are stored as N(0, 1) samples and multiplied at *runtime* by
    ``wscale = sqrt(gain / fan_in)`` instead of being scaled once at init, so
    every layer sees the same effective learning rate under Adam.

    Args:
        in_ch, out_ch, kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        gain: numerator of the He-style scale (2 for LeakyReLU/ReLU layers,
            1 for linear layers).
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, gain=2):
        super().__init__()
        self.conv2d: nn.Conv2d = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding)
        self.fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.conv2d.weight)
        # Plain float: follows the input to any device/dtype without registering a buffer.
        self.wscale = (gain / self.fan_in) ** 0.5
        self.reset_weights()

    def reset_weights(self):
        """Re-draw weights from N(0, 1) and zero the bias; scaling is done in forward."""
        nn.init.normal_(self.conv2d.weight)
        if self.conv2d.bias is not None:
            nn.init.zeros_(self.conv2d.bias)

    def forward(self, x):
        # Convolution is linear in its input (zero padding), so conv(x * s) equals
        # convolving x with weight * s; the bias is added unscaled, as intended.
        return self.conv2d(x * self.wscale)


# x = torch.randn(2, 3, 5, 5)
# model = EqualizedConv2d(3, 6, 3)
# model(x).size()

torch.Size([2, 6, 3, 3])

In [None]:
class EqualizedConvTranspose2d(nn.Module):
    """ConvTranspose2d with ProGAN's equalized learning rate.

    Weights are stored as N(0, 1) samples and multiplied at runtime by
    ``wscale = sqrt(gain / fan_in)``.

    Args:
        in_ch, out_ch, kernel_size, stride, padding: forwarded to ``nn.ConvTranspose2d``.
        gain: numerator of the He-style scale.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, gain=2):
        super().__init__()
        self.conv_transpose2d: nn.ConvTranspose2d = nn.ConvTranspose2d(in_ch, out_ch, kernel_size, stride, padding)
        # PyTorch's helper reads fan_in from weight dim 1, which for ConvTranspose2d
        # (weight shape (in_ch, out_ch, kH, kW)) gives out_ch * kH * kW.
        # NOTE(review): for the 1x1-latent first layer this equals ProGAN's dense-layer
        # scale (sqrt(2)/4 / sqrt(in_ch)) only when in_ch == out_ch — confirm intended.
        self.fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.conv_transpose2d.weight)
        self.wscale = (gain / self.fan_in) ** 0.5
        self.reset_weights()

    def reset_weights(self):
        """Re-draw weights from N(0, 1) and zero the bias; scaling is done in forward."""
        nn.init.normal_(self.conv_transpose2d.weight)
        if self.conv_transpose2d.bias is not None:
            nn.init.zeros_(self.conv_transpose2d.bias)

    def forward(self, x):
        # Transposed convolution is linear in its input, so scaling x is equivalent
        # to scaling the weight; the bias stays unscaled.
        return self.conv_transpose2d(x * self.wscale)


# x = torch.randn(2, 512, 1, 1)
# model = EqualizedConvTranspose2d(512, 512, 4)
# model(x).size()

torch.Size([2, 512, 4, 4])

In [None]:
# Minibatch standard deviation (ProGAN): std of every feature at every spatial
# location over the batch, averaged into one scalar that is appended to the
# input as an extra constant feature map.

class MiniBatchSTD(nn.Module):
    """Append the mean minibatch standard deviation as one extra channel.

    Args:
        eps: added to the variance before the square root, so the result stays
            finite for batch size 1 and the sqrt gradient stays finite.

    Input (N, C, H, W) -> output (N, C + 1, H, W).
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        # Biased variance over the batch for each (channel, h, w) feature.
        centered = x - x.mean(dim=0, keepdim=True)
        std = torch.sqrt(centered.pow(2).mean(dim=0) + self.eps)
        v = std.mean()
        n, _, h, w = x.shape
        # expand() (rather than a fresh torch.empty) keeps x's device/dtype and
        # the autograd link to v, without allocating a new buffer.
        statistics = v.expand(n, 1, h, w)
        return torch.cat([x, statistics], dim=1)

# x = torch.randn(2, 3, 5, 5)
# model = MiniBatchSTD()
# model(x)[:, 3, ...]

In [None]:
class BlockConv2d(nn.Module):
    """Two equalized conv layers, each followed by LeakyReLU and optional pixel norm.

    With the defaults (kernel_size=3, stride=1, padding=1) the spatial size is
    preserved and the channel count goes from ``in_ch`` to ``out_ch``.

    Args:
        in_ch: input channels of the first conv.
        out_ch: output channels of both convs.
        kernel_size, stride, padding: forwarded to both ``EqualizedConv2d`` layers.
        leakyscale: negative slope of both LeakyReLUs.
        use_pn: if True, apply ``PixelWiseNormalization`` after each activation.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, leakyscale=0.2, use_pn=True):
        super().__init__()
        self.use_pn = use_pn
        conv_kwargs = dict(stride=stride, padding=padding)
        self.cn1 = EqualizedConv2d(in_ch, out_ch, kernel_size, **conv_kwargs)
        self.leaky_relu1 = nn.LeakyReLU(leakyscale)
        self.cn2 = EqualizedConv2d(out_ch, out_ch, kernel_size, **conv_kwargs)
        self.leaky_relu2 = nn.LeakyReLU(leakyscale)
        # One normalization module shared by both stages.
        self.pn = PixelWiseNormalization()

    def forward(self, x):
        stages = ((self.cn1, self.leaky_relu1), (self.cn2, self.leaky_relu2))
        for conv, activation in stages:
            x = activation(conv(x))
            if self.use_pn:
                x = self.pn(x)
        return x


# Smoke test: BlockConv2d keeps the 5x5 spatial size and maps 3 -> 10 channels.
x = torch.randn((2, 3, 5, 5))
model = BlockConv2d(in_ch=3, out_ch=10)
model(x).shape


torch.Size([2, 10, 5, 5])

In [None]:
# Nearest-neighbour 2x upsampling, as used between generator stages.
# (Named upsample_input so the builtin `input` is not shadowed.)
upsample_input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
upsampled = F.interpolate(upsample_input, scale_factor=2)  # default mode='nearest'
upsampled.size()


torch.Size([1, 1, 4, 4])

In [None]:
# Latent vector -> 512 × 1 × 1
############ step0: initial block (4 × 4)
# ConvTranspose2d 4 × 4, LReLU -> 512 × 4 × 4
# Conv 3 × 3, LReLU -> 512 × 4 × 4
############ step1 (8 × 8)
# Upsample -> 512 × 8 × 8
# Conv 3 × 3, LReLU -> 512 × 8 × 8
# Conv 3 × 3, LReLU -> 512 × 8 × 8
############ step2 (16 × 16)
# Upsample -> 512 × 16 × 16
# Conv 3 × 3, LReLU -> 512 × 16 × 16
# Conv 3 × 3, LReLU -> 512 × 16 × 16
############ step3 (32 × 32)
# Upsample -> 512 × 32 × 32
# Conv 3 × 3, LReLU -> 512 × 32 × 32
# Conv 3 × 3, LReLU -> 512 × 32 × 32
############ step4 (64 × 64)
# Upsample -> 512 × 64 × 64
# Conv 3 × 3, LReLU -> 256 × 64 × 64
# Conv 3 × 3, LReLU -> 256 × 64 × 64
############ step5 (128 × 128)
# Upsample -> 256 × 128 × 128
# Conv 3 × 3, LReLU -> 128 × 128 × 128
# Conv 3 × 3, LReLU -> 128 × 128 × 128
############ step6 (256 × 256)
# Upsample -> 128 × 256 × 256
# Conv 3 × 3, LReLU -> 64 × 256 × 256
# Conv 3 × 3, LReLU -> 64 × 256 × 256
############ step7 (512 × 512)
# Upsample -> 64 × 512 × 512
# Conv 3 × 3, LReLU -> 32 × 512 × 512
# Conv 3 × 3, LReLU -> 32 × 512 × 512
############ step8 (1024 × 1024)
# Upsample -> 32 × 1024 × 1024
# Conv 3 × 3, LReLU -> 16 × 1024 × 1024
# Conv 3 × 3, LReLU -> 16 × 1024 × 1024


class Generator(nn.Module):
    """Progressively grown generator (ProGAN-style).

    Step 0 produces a 4x4 image from the initial block (for a 1x1 latent map,
    as in the demo below). Each further step upsamples 2x and runs one
    ``BlockConv2d``, up to step 8 (1024x1024 for a 1x1 latent).

    Args:
        latent_channel: channel count of the latent input and of the initial block.
            The input presumably has shape (N, latent_channel, 1, 1) — TODO confirm
            for callers other than the demo.
        leaky_scale: negative slope of the LeakyReLU activations.
    """

    def __init__(self, latent_channel, leaky_scale=0.2):
        super().__init__()
        self.initial_layer = nn.Sequential(
            EqualizedConvTranspose2d(latent_channel, latent_channel, 4),
            nn.LeakyReLU(leaky_scale),
            PixelWiseNormalization(),
            EqualizedConv2d(latent_channel, latent_channel, 3, padding=1),
            nn.LeakyReLU(leaky_scale),
            PixelWiseNormalization(),
        )
        # toRGB layers are linear projections, so they use gain=1 as in the ProGAN
        # reference implementation.
        self.initial_toRGB = EqualizedConv2d(latent_channel, 3, 1, gain=1)

        # Channels entering/leaving each progressive stage. The first entry must match
        # what initial_layer emits (latent_channel); with the standard 512 it is unchanged.
        self.block_ch = [latent_channel, 512, 512, 512, 256, 128, 64, 32, 16]
        self.progressive_blocks = nn.ModuleList()
        self.toRGB = nn.ModuleList()
        # toRGB[0] is the same module as initial_toRGB (shared parameters).
        self.toRGB.append(self.initial_toRGB)

        for i in range(len(self.block_ch) - 1):
            self.progressive_blocks.append(BlockConv2d(self.block_ch[i], self.block_ch[i+1]))
            self.toRGB.append(EqualizedConv2d(self.block_ch[i+1], 3, 1, gain=1))

    def forward(self, x, step, alpha):
        """Generate images at growth stage ``step``.

        Args:
            x: latent tensor fed to ``initial_layer``.
            step: growth stage, 0 .. len(self.progressive_blocks).
            alpha: fade-in weight of the newest stage (presumably in [0, 1]);
                ignored at step 0.

        Returns:
            RGB tensor passed through tanh, so values lie in (-1, 1) at every step.

        Raises:
            ValueError: if ``step`` is outside the supported range.
        """
        max_step = len(self.progressive_blocks)
        if not 0 <= step <= max_step:
            raise ValueError(f"step must be in [0, {max_step}], got {step}")

        x = self.initial_layer(x)

        if step == 0:
            # Same tanh as the later stages, so the output range does not change
            # when training moves from step 0 to step 1.
            return torch.tanh(self.initial_toRGB(x))

        for i in range(step):
            upsample_out = F.interpolate(x, scale_factor=2)
            x = self.progressive_blocks[i](upsample_out)

        # Fade-in: blend the previous stage's RGB head (applied to the upsampled
        # features that entered the newest block) with the newest stage's RGB head.
        output_old_layers = (1 - alpha) * self.toRGB[step - 1](upsample_out)
        output_new_layer = alpha * self.toRGB[step](x)
        return torch.tanh(output_old_layers + output_new_layer)

# Smoke test: the final stage (step 8) should emit 3-channel 1024 x 1024 images.
generator = Generator(latent_channel=512)
x = torch.randn(2, 512, 1, 1)
generator(x, step=8, alpha=0.2).shape


torch.Size([2, 3, 1024, 1024])