In [None]:
import torch
import random
import math
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

# Generator and Discriminator Utility Functions
def initialize_linear_layer(layer):
    """Initialise a linear layer in place.

    The weight is filled with Xavier-normal values and the bias, when the
    layer has one, is set to zero.

    Args:
        layer: module exposing ``weight`` and ``bias`` attributes.
    """
    # `xavier_normal_` is the in-place replacement for the deprecated
    # `xavier_normal` spelling.
    nn.init.xavier_normal_(layer.weight)
    # Layers built with bias=False expose `bias` as None.
    if layer.bias is not None:
        layer.bias.data.zero_()

def initialize_conv_layer(layer):
    """Initialise a convolution layer in place.

    The weight is filled with Kaiming-normal values and the bias, when the
    layer has one, is set to zero.

    Args:
        layer: module exposing ``weight`` and ``bias`` attributes.
    """
    # `kaiming_normal_` is the in-place replacement for the deprecated
    # `kaiming_normal` spelling.
    nn.init.kaiming_normal_(layer.weight)
    if layer.bias is not None:
        layer.bias.data.zero_()
        
def equal_layer(module, name="weight"):
    """Install the ``EqualLR`` forward pre-hook on ``module`` and return it.

    Args:
        module: the module whose parameter should be rescaled on each call.
        name: name of the parameter to wrap (``"weight"`` by default).

    Returns:
        The same ``module`` object, so the call can be used inline.
    """
    # The hook object returned by `apply` is not needed by callers here.
    EqualLR.apply(module, name)
    wrapped = module
    return wrapped


class EqualLR:
    """Forward pre-hook that recomputes a scaled copy of a parameter.

    ``apply`` moves the parameter ``<name>`` to ``<name>_orig``; before every
    forward call the hook writes ``<name>_orig * sqrt(2 / fan_in)`` back to
    ``<name>`` as a plain attribute.
    """

    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        """Return the stored ``<name>_orig`` parameter times ``sqrt(2 / fan_in)``."""
        raw = getattr(module, f"{self.name}_orig")
        # fan_in = size of dim 1 times the number of elements in one
        # [out, in] slice (1 for a 2-D weight, kh*kw for a conv weight).
        fan_in = raw.data.size(1) * raw.data[0][0].numel()
        scale = math.sqrt(2 / fan_in)
        return raw * scale

    @staticmethod
    def apply(module, name):
        """Re-register ``name`` as ``name + "_orig"`` and attach the hook."""
        hook = EqualLR(name)
        original = getattr(module, name)
        # Drop the original registration so the name can later hold a plain
        # (non-Parameter) tensor set by the hook.
        del module._parameters[name]
        module.register_parameter(f"{name}_orig", nn.Parameter(original.data))
        module.register_forward_pre_hook(hook)
        return hook

    def __call__(self, module, input):
        # Invoked by PyTorch before module.forward; `input` is unused.
        setattr(module, self.name, self.compute_weight(module))

class PixelNorm(nn.Module):
    """Divide the input by the root-mean-square taken over dimension 1."""

    def forward(self, x):
        # Mean of squares along dim 1, kept as a size-1 dim for broadcasting.
        mean_square = torch.mean(x ** 2, dim=1, keepdim=True)
        # The 1e-8 term keeps the division finite for all-zero inputs.
        denom = torch.sqrt(mean_square + 1e-8)
        return x / denom

class EqualConvLayer(nn.Module):
    """``nn.Conv2d`` wrapped with ``equal_layer``.

    The convolution weight is drawn from a standard normal distribution and
    the bias (if present) is zeroed before the wrapper is applied.

    Args:
        *args, **kwargs: forwarded unchanged to ``nn.Conv2d``. Note that
            ``nn.Conv2d``'s fourth positional argument is ``stride``;
            pass ``padding`` by keyword.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        conv = nn.Conv2d(*args, **kwargs)
        conv.weight.data.normal_()
        # `bias=False` leaves conv.bias as None; zeroing it would raise.
        if conv.bias is not None:
            conv.bias.data.zero_()
        self.conv = equal_layer(conv)

    def forward(self, x):
        """Apply the wrapped convolution to ``x``."""
        return self.conv(x)

class EqualLinearLayer(nn.Module):
    """``nn.Linear`` wrapped with ``equal_layer``.

    Args:
        in_channels: input feature count passed to ``nn.Linear``.
        out_channels: output feature count passed to ``nn.Linear``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        dense = nn.Linear(in_channels, out_channels)
        # Standard-normal weights and a zero bias before wrapping.
        dense.weight.data.normal_()
        dense.bias.data.zero_()
        self.linear = equal_layer(dense)

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        result = self.linear(x)
        return result

class ConvBlock(nn.Module):
    """Two ``EqualConvLayer`` convolutions, each followed by ``LeakyReLU(0.2)``.

    Args:
        in_channels: channels of the input to the first convolution.
        out_channels: channels produced by both convolutions.
        kernel_size: kernel size of the first convolution (and of the second
            one unless ``kernel_size2`` is given).
        padding: padding of the first convolution (and of the second one
            unless ``padding2`` is given).
        kernel_size2: optional kernel size override for the second convolution.
        padding2: optional padding override for the second convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding, kernel_size2=None, padding2=None):
        super().__init__()
        pad1 = padding
        pad2 = padding if padding2 is None else padding2

        kernel1 = kernel_size
        kernel2 = kernel_size if kernel_size2 is None else kernel_size2

        # Padding must be passed by keyword: the fourth positional argument
        # of nn.Conv2d is `stride`, so a positional pad would be used as the
        # stride (and a pad of 0 would become an invalid stride of 0).
        self.conv_block = nn.Sequential(
            EqualConvLayer(in_channels, out_channels, kernel1, padding=pad1),
            nn.LeakyReLU(0.2),
            EqualConvLayer(out_channels, out_channels, kernel2, padding=pad2),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Run ``x`` through both convolution/activation pairs."""
        return self.conv_block(x)

class AdaIN(nn.Module):
    """Instance-normalise ``x`` and apply a style-derived scale and shift.

    Args:
        in_channels: number of channels of the feature map.
        style_dim: size of the style vector fed to the linear projection.
    """

    def __init__(self, in_channels, style_dim):
        super().__init__()
        self.adain = nn.InstanceNorm2d(in_channels)
        self.style = EqualLinearLayer(style_dim, in_channels*2)
        # First half of the projection is the scale (starts at 1), second
        # half is the shift (starts at 0).
        bias = self.style.linear.bias.data
        bias[:in_channels] = 1
        bias[in_channels:] = 0

    def forward(self, x, style):
        """Return ``gamma * instance_norm(x) + beta`` with gamma/beta from ``style``."""
        projected = self.style(style)
        # Append two singleton dims so the halves broadcast over H and W.
        projected = projected.unsqueeze(2).unsqueeze(3)
        gamma, beta = projected.chunk(2, 1)
        normalised = self.adain(x)
        return gamma*normalised + beta

class AddNoise(nn.Module):
    """Add ``noise`` to ``x`` scaled by a learned per-channel weight.

    Args:
        channels: number of channels; the weight has shape
            ``(1, channels, 1, 1)`` and starts at zero.
    """

    def __init__(self, channels):
        super().__init__()
        initial = torch.zeros(1, channels, 1, 1)
        self.weight = nn.Parameter(initial)

    def forward(self, x, noise):
        """Return ``x + weight * noise`` (broadcast by PyTorch)."""
        scaled_noise = self.weight*noise
        return x + scaled_noise

class ConstantIn(nn.Module):
    """Learned constant tensor repeated along the batch dimension.

    Args:
        channels: number of channels of the constant.
        size: spatial height and width of the constant (default 4).
    """

    def __init__(self, channels, size=4):
        super().__init__()
        self.input = nn.Parameter(torch.randn(1, channels, size, size))

    def forward(self, x):
        """Return the constant repeated ``x.shape[0]`` times.

        Only the first dimension of ``x`` is read; its values are ignored.
        """
        batch_size = x.shape[0]
        return self.input.repeat(batch_size, 1, 1, 1)

class StyleConvBlock(nn.Module):
    """Two stages of conv -> noise injection -> AdaIN -> ``LeakyReLU(0.2)``.

    Args:
        in_channels: channels entering the block.
        out_channels: channels produced by the block.
        kernel_size: kernel size of the convolutions (default 3).
        padding: padding of the convolutions (default 1).
        style_dim: size of the style vector given to both AdaIN layers.
        initial: when true the first stage uses a ``ConstantIn`` with
            ``in_channels`` channels instead of a convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, style_dim=512, initial=False):
        super().__init__()
        # Stage 1
        if not initial:
            self.conv1 = EqualConvLayer(in_channels, out_channels, kernel_size, padding=padding)
        else:
            self.conv1 = ConstantIn(in_channels)
        self.noise1 = equal_layer(AddNoise(out_channels))
        self.adain1 = AdaIN(out_channels, style_dim)
        self.leakyrelu1 = nn.LeakyReLU(0.2)

        # Stage 2
        self.conv2 = EqualConvLayer(out_channels, out_channels, kernel_size, padding=padding)
        self.noise2 = equal_layer(AddNoise(out_channels))
        self.adain2 = AdaIN(out_channels, style_dim)
        self.leakyrelu2 = nn.LeakyReLU(0.2)

    def forward(self, x, style, noise):
        """Run both stages; the same ``style`` and ``noise`` feed each stage."""
        first = self.conv1(x)
        first = self.noise1(first, noise)
        first = self.leakyrelu1(self.adain1(first, style))

        second = self.conv2(first)
        second = self.noise2(second, noise)
        second = self.leakyrelu2(self.adain2(second, style))
        return second
    

# Training Utility Functions
def requires_grad(model, flag=True):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad = flag

def accumulate(model1, model2, decay=0.999):
    """Blend ``model2``'s parameters into ``model1`` in place.

    For every parameter name of ``model1`` the update is
    ``p1 = decay * p1 + (1 - decay) * p2``.

    Args:
        model1: model whose parameters are overwritten.
        model2: model providing the new values; it must expose every
            parameter name that ``model1`` has (a missing name raises
            ``KeyError``).
        decay: weight kept from ``model1``; ``0`` copies ``model2``.
    """
    target = dict(model1.named_parameters())
    source = dict(model2.named_parameters())
    for name, param in target.items():
        # Use the `alpha=` keyword: the positional `add_(scalar, tensor)`
        # overload is deprecated.
        param.data.mul_(decay).add_(source[name].data, alpha=1 - decay)

def get_data_sample(dataset, batch_size, image_size=4):
    """Build a shuffling ``DataLoader`` over ``dataset`` at ``image_size``.

    Side effect: ``dataset.transform`` is replaced with a pipeline that
    resizes, centre-crops, randomly flips horizontally, converts to a tensor
    and normalises each of three channels with mean 0.5 and std 0.5.

    Args:
        dataset: dataset object with an assignable ``transform`` attribute.
        batch_size: batch size handed to the ``DataLoader``.
        image_size: target size for the resize and the centre crop.

    Returns:
        A ``DataLoader`` with ``shuffle=True`` and 16 worker processes.
    """
    channel_stats = (0.5, 0.5, 0.5)
    pipeline = [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_stats, channel_stats),
    ]
    dataset.transform = transforms.Compose(pipeline)
    loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=16)
    return loader

def adjust_lr(optimizer, lr):
    """Set every parameter group's learning rate to ``lr * group["mult"]``.

    Groups without a ``"mult"`` entry use a multiplier of 1.
    """
    for param_group in optimizer.param_groups:
        multiplier = param_group.get("mult", 1)
        param_group["lr"] = lr * multiplier

# Shoe Dataset
def target_to_oh_shoe(target):
    """Return the one-hot encoding of ``target`` over three classes.

    The result is the ``target``-th row (or rows) of a 3x3 identity matrix.
    """
    num_classes = 3
    identity = torch.eye(num_classes)
    return identity[target]

In [None]:
class Discriminator(nn.Module):
    """Progressively grown, label-conditioned discriminator.

    ``progression_list`` and ``rgb_list`` are ordered from the highest
    resolution (index 0) to the 4x4 block (last index). At the resolution
    levels listed in ``condition`` the embedded label is tiled over the image
    and concatenated to the RGB input as extra channels.

    Args:
        num_classes: number of label classes; also the embedding width and
            the number of extra input channels at conditioned levels.
        condition: a level index or a list/tuple of level indices (0 is the
            4x4 level) at which the label is concatenated to the input.
    """

    def __init__(self, num_classes=10, condition=2):
        super().__init__()
        if not isinstance(condition, (list, tuple)):
            condition = [condition]
        # Extra input channels per level: num_classes where conditioned, else 0.
        condition_channels = [num_classes if i in condition else 0 for i in range(9)]
        self.progression_list = nn.ModuleList(
            [
                ConvBlock(16, 32, 3, 1),
                ConvBlock(32, 64, 3, 1),
                ConvBlock(64, 128, 3, 1),
                ConvBlock(128, 256, 3, 1),
                ConvBlock(256, 512, 3, 1),
                ConvBlock(512, 512, 3, 1),
                ConvBlock(512, 512, 3, 1),
                ConvBlock(512, 512, 3, 1),
                # 513 inputs: forward() concatenates one minibatch-stddev
                # channel to the 512 feature channels before this block.
                ConvBlock(513, 512, 3, 1, 4, 0),
            ]
        )
        self.rgb_list = nn.ModuleList(
            [
                EqualConvLayer(3+condition_channels[8], 16, 1),
                EqualConvLayer(3+condition_channels[7], 32, 1),
                EqualConvLayer(3+condition_channels[6], 64, 1),
                EqualConvLayer(3+condition_channels[5], 128, 1),
                EqualConvLayer(3+condition_channels[4], 256, 1),
                EqualConvLayer(3+condition_channels[3], 512, 1),
                EqualConvLayer(3+condition_channels[2], 512, 1),
                EqualConvLayer(3+condition_channels[1], 512, 1),
                EqualConvLayer(3+condition_channels[0], 512, 1),
            ]
        )
        self.num_layers = len(self.progression_list)
        self.equal_linear = EqualLinearLayer(512, 1)
        self.condition = condition
        self.num_classes = num_classes
        self.label_emb = nn.Embedding(num_classes, num_classes)

    def forward(self, x, label, step=0, alpha=-1):
        """Score a batch of images.

        Args:
            x: image batch whose spatial size matches level ``step``.
            label: class indices accepted by ``nn.Embedding``.
            step: current resolution level (0 is 4x4).
            alpha: fade-in weight; blending with the lower-resolution RGB
                path only happens when ``0 <= alpha < 1``.

        Returns:
            The output of the final linear layer (one value per sample).
        """
        label = self.label_emb(label).view(-1, self.num_classes, 1, 1)

        for i in range(step, -1, -1):
            # Modules are stored highest-resolution first, so level i lives
            # at position num_layers - i - 1.
            index = self.num_layers - i - 1
            downsample_input = current_input = x
            if i in self.condition:
                current_input = torch.cat([x, label.repeat(1, 1, *x.shape[2:])], dim=1)
            if i-1 in self.condition:
                downsample_input = torch.cat([x, label.repeat(1, 1, *x.shape[2:])], dim=1)
            if i == step:
                out = self.rgb_list[index](current_input)
            if i == 0:
                # Minibatch standard deviation: one scalar statistic appended
                # as an extra 4x4 channel.
                out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8)
                mean_std = out_std.mean()
                mean_std = mean_std.expand(out.size(0), 1, 4, 4)
                out = torch.cat([out, mean_std], 1)

            out = self.progression_list[index](out)
            if i > 0:
                out = F.interpolate(out, scale_factor=0.5, mode="bilinear", align_corners=False)

                if i == step and 0 <= alpha < 1:
                    # Fade-in: blend with the next-lower level's from-RGB
                    # applied to the (downsampled) input image.
                    skip_rgb = self.rgb_list[index + 1](downsample_input)
                    skip_rgb = F.interpolate(skip_rgb, scale_factor=0.5, mode="bilinear", align_corners=False)
                    out = (1 - alpha)*skip_rgb + alpha*out

        out = out.squeeze(2).squeeze(2)
        out = self.equal_linear(out)
        return out
    

In [None]:
class Generator(nn.Module):
    """Progressively grown synthesis network with label-conditioned styles.

    At the levels listed in ``condition`` the first ``num_classes`` entries
    of the style vector are replaced with the embedded label; at every other
    level those entries are zeroed.

    Args:
        code_dim: accepted for interface compatibility; not read here.
        num_classes: number of label classes and embedding width.
        condition: a level index or list/tuple of level indices (0 is the
            first block) at which the label is written into the style.
    """

    def __init__(self, code_dim, num_classes=10, condition=2):
        super().__init__()
        if not isinstance(condition, (list, tuple)):
            condition = [condition]
        self.progression_list = nn.ModuleList(
            [
                StyleConvBlock(512, 512, 3, 1, initial=True),
                StyleConvBlock(512, 512, 3, 1),
                StyleConvBlock(512, 512, 3, 1),
                StyleConvBlock(512, 512, 3, 1),
                StyleConvBlock(512, 256, 3, 1),
                StyleConvBlock(256, 128, 3, 1),
                StyleConvBlock(128, 64, 3, 1),
                StyleConvBlock(64, 32, 3, 1),
                StyleConvBlock(32, 16, 3, 1),
            ]
        )
        self.rgb_list = nn.ModuleList(
            [
                EqualConvLayer(512, 3, 1),
                EqualConvLayer(512, 3, 1),
                EqualConvLayer(512, 3, 1),
                EqualConvLayer(512, 3, 1),
                EqualConvLayer(256, 3, 1),
                EqualConvLayer(128, 3, 1),
                EqualConvLayer(64, 3, 1),
                EqualConvLayer(32, 3, 1),
                EqualConvLayer(16, 3, 1),
            ]
        )
        self.condition = condition
        self.num_classes = num_classes
        self.label_emb = nn.Embedding(num_classes, num_classes)

    def forward(self, style, label, noise, step=0, alpha=-1, mixing_range=(-1, -1)):
        """Synthesise an image batch at resolution level ``step``.

        Args:
            style: sequence of style tensors (one, or several for mixing).
            label: class indices accepted by ``nn.Embedding``.
            noise: sequence of noise tensors, indexed by level.
            step: level at which to stop and convert to RGB.
            alpha: fade-in weight; blending with the upsampled previous
                level only happens when ``0 <= alpha < 1`` and ``step > 0``.
            mixing_range: ``(-1, -1)`` selects random crossover points;
                otherwise levels inside ``[lo, hi]`` use ``style[1]`` and
                the rest use ``style[0]``.

        Returns:
            The RGB output of level ``step``.
        """
        out = noise[0]
        if len(style) < 2:
            # Index beyond every level: the single style is never switched.
            inject_index = [len(self.progression_list) + 1]
        else:
            # Only `step` distinct crossover levels exist, so never ask for
            # more (random.sample would raise). The walk below compares
            # against inject_index[crossover] in order, hence the sort.
            num_points = min(len(style) - 1, step)
            inject_index = sorted(random.sample(list(range(step)), num_points))

        crossover = 0
        for i, (conv, to_rgb) in enumerate(zip(self.progression_list, self.rgb_list)):
            if mixing_range == (-1, -1):
                if crossover < len(inject_index) and i > inject_index[crossover]:
                    crossover = min(crossover + 1, len(style))
                style_step = style[crossover]
            else:
                if mixing_range[0] <= i <= mixing_range[1]:
                    style_step = style[1]
                else:
                    style_step = style[0]

            # Clone so the caller's style tensor is not modified in place.
            style_step = style_step.clone()
            if i in self.condition:
                style_step[:, :self.num_classes] = self.label_emb(label).view(-1, self.num_classes)
            else:
                style_step[:, :self.num_classes] = 0

            if i > 0 and step > 0:
                upsample = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
                out = conv(upsample, style_step, noise[i])
            else:
                out = conv(out, style_step, noise[i])

            if i == step:
                out = to_rgb(out)
                if i > 0 and 0 <= alpha < 1:
                    # Fade-in: blend with the previous level's RGB output
                    # computed on the upsampled features.
                    skip_rgb = self.rgb_list[i - 1](upsample)
                    out = (1 - alpha)*skip_rgb + alpha*out
                break

        return out


class StyledGenerator(nn.Module):
    """Mapping network (``PixelNorm`` + MLP) followed by ``Generator``.

    Args:
        code_dim: width of the latent code and of every MLP layer.
        num_mlp: number of linear + ``LeakyReLU(0.2)`` pairs in the mapping.
        num_classes: forwarded to ``Generator``.
        condition: forwarded to ``Generator``.
    """

    def __init__(self, code_dim=512, num_mlp=8, num_classes=10, condition=2):
        super().__init__()
        self.generator = Generator(code_dim, num_classes=num_classes, condition=condition)
        mapping = [PixelNorm()]
        for _ in range(num_mlp):
            mapping += [EqualLinearLayer(code_dim, code_dim), nn.LeakyReLU(0.2)]
        self.style = nn.Sequential(*mapping)

    def forward(self, x, label, noise=None, step=0, alpha=-1, mean_style=None, style_weight=0, mixing_range=(-1, -1),):
        """Map latent code(s) to styles and run the synthesis network.

        Args:
            x: a latent tensor, or a list/tuple of latent tensors.
            label: passed through to ``Generator``.
            noise: optional per-level noise list; when ``None`` standard
                normal noise of shape ``(batch, 1, 4*2**i, 4*2**i)`` is
                drawn for levels ``0..step``.
            step, alpha, mixing_range: passed through to ``Generator``.
            mean_style: when given, each style becomes
                ``mean_style + style_weight * (style - mean_style)``.
            style_weight: interpolation weight used with ``mean_style``.
        """
        if type(x) not in (list, tuple):
            x = [x]

        styles = [self.style(latent) for latent in x]

        first = x[0]
        batch = first.shape[0]
        if noise is None:
            noise = []
            for level in range(step + 1):
                side = 4 * 2 ** level
                noise.append(torch.randn(batch, 1, side, side, device=first.device))

        if mean_style is not None:
            styles = [mean_style + style_weight * (s - mean_style) for s in styles]

        return self.generator(styles, label, noise, step, alpha, mixing_range=mixing_range)

    def mean_style(self, x):
        """Return the mapped styles of ``x`` averaged over dim 0 (dim kept)."""
        mapped = self.style(x)
        return mapped.mean(0, keepdim=True)


In [None]:
import argparse
from tqdm.auto import tqdm
from torch import optim
from torch.autograd import Variable, grad
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

# Device string used for the models and batches below: CUDA when available,
# otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

def _as_class_index(label):
    """Return integer class indices for ``label``.

    ``nn.Embedding`` needs integer indices. Labels with more than one
    dimension (e.g. one-hot rows) are collapsed with ``argmax`` over dim 1;
    1-D labels are only cast to ``long``.
    """
    if label.dim() > 1:
        label = label.argmax(dim=1)
    return label.long()


def train(args, dataset, generator, discriminator):
    """Run the progressive-growing training loop.

    Relies on module-level names defined elsewhere in this file:
    ``g_optimizer``, ``d_optimizer``, ``g_running``, ``code_size``,
    ``n_critic`` and ``device``. ``generator`` and ``discriminator`` are
    expected to be wrapped so that ``.module`` is available.

    Args:
        args: namespace providing ``init_size``, ``max_size``, ``phase``,
            ``loss`` (``"wgan-gp"`` or ``"r1"``), ``mixing``, and the dicts
            ``lr``, ``batch``, ``gen_sample`` plus ``batch_default``.
        dataset: dataset yielding ``(image, label)`` pairs.
        generator: conditional generator called as
            ``generator(latent, label, step=..., alpha=...)``.
        discriminator: conditional discriminator called as
            ``discriminator(image, label, step=..., alpha=...)`` and
            returning a single score tensor.

    Side effects:
        Writes sample grids to ``sample/`` and checkpoints to
        ``checkpoint/`` (both created if missing).
    """
    import os

    # torch.save / save_image do not create missing directories.
    os.makedirs("checkpoint", exist_ok=True)
    os.makedirs("sample", exist_ok=True)

    max_step = int(math.log2(args.max_size)) - 2
    step = int(math.log2(args.init_size)) - 2
    resolution = 4 * 2 ** step
    loader = get_data_sample(dataset, args.batch.get(resolution, args.batch_default), resolution)
    data_loader = iter(loader)

    adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
    adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))
    pbar = tqdm(range(3_000_000), mininterval=30, maxinterval=60)
    requires_grad(generator, False)
    requires_grad(discriminator, True)

    disc_loss_val = 0
    gen_loss_val = 0
    grad_loss_val = 0
    alpha = 0
    used_sample = 0

    for i in pbar:
        discriminator.zero_grad()
        alpha = min(1, 1 / args.phase * (used_sample + 1))

        # Grow to the next resolution only while one exists. Without the
        # `step < max_step` guard the loader would be rebuilt and a
        # checkpoint written on every iteration once the final resolution
        # has consumed 2 * phase samples.
        if used_sample > args.phase * 2 and step < max_step:
            step += 1
            alpha = 0
            used_sample = 0

            resolution = 4 * 2 ** step
            loader = get_data_sample(dataset, args.batch.get(resolution, args.batch_default), resolution)
            data_loader = iter(loader)
            torch.save(
                {
                    'generator': generator.module.state_dict(),
                    'discriminator': discriminator.module.state_dict(),
                    'g_optimizer': g_optimizer.state_dict(),
                    'd_optimizer': d_optimizer.state_dict(),
                },
                f'checkpoint/train_step-{step}.model',
            )
            adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
            adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))

        try:
            real_image, label = next(data_loader)
        except (OSError, StopIteration):
            # Epoch exhausted (or a worker read failed): restart the loader.
            data_loader = iter(loader)
            real_image, label = next(data_loader)

        used_sample += real_image.shape[0]
        b_size = real_image.size(0)
        real_image = real_image.to(device)
        label = _as_class_index(label.to(device))

        # ---- Discriminator: real batch ----
        if args.loss == 'wgan-gp':
            real_predict = discriminator(real_image, label, step=step, alpha=alpha)
            # Small drift term keeps the real scores from growing unbounded.
            real_predict = real_predict.mean() - 0.001 * (real_predict ** 2).mean()
            (-real_predict).backward()
        elif args.loss == 'r1':
            real_image.requires_grad = True
            real_scores = discriminator(real_image, label, step=step, alpha=alpha)
            real_predict = F.softplus(-real_scores).mean()
            real_predict.backward(retain_graph=True)
            # R1 penalises the gradient of the raw scores w.r.t. the real
            # images, not the gradient of the softplus loss.
            grad_real = grad(outputs=real_scores.sum(), inputs=real_image, create_graph=True)[0]
            grad_penalty = (grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean()
            grad_penalty = 10 / 2 * grad_penalty
            grad_penalty.backward()
            grad_loss_val = grad_penalty.item()

        # Two independent latent draws: gen_in1 for the discriminator step,
        # gen_in2 for the generator step. With mixing, each is a pair.
        if args.mixing and random.random() < 0.9:
            gen_in11, gen_in12, gen_in21, gen_in22 = torch.randn(4, b_size, code_size, device=device).chunk(4, 0)
            gen_in1 = [gen_in11.squeeze(0), gen_in12.squeeze(0)]
            gen_in2 = [gen_in21.squeeze(0), gen_in22.squeeze(0)]
        else:
            gen_in1, gen_in2 = torch.randn(2, b_size, code_size, device=device).chunk(2, 0)
            gen_in1 = gen_in1.squeeze(0)
            gen_in2 = gen_in2.squeeze(0)

        # ---- Discriminator: fake batch (conditioned on the real labels) ----
        fake_image = generator(gen_in1, label, step=step, alpha=alpha)
        fake_predict = discriminator(fake_image, label, step=step, alpha=alpha)

        if args.loss == 'wgan-gp':
            fake_predict = fake_predict.mean()
            fake_predict.backward()
            # Gradient penalty on random interpolates of real and fake.
            eps = torch.rand(b_size, 1, 1, 1).to(device)
            x_hat = eps * real_image.data + (1 - eps) * fake_image.data
            x_hat.requires_grad = True
            hat_predict = discriminator(x_hat, label, step=step, alpha=alpha)
            grad_x_hat = grad(outputs=hat_predict.sum(), inputs=x_hat, create_graph=True)[0]
            grad_penalty = ((grad_x_hat.view(grad_x_hat.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()
            grad_penalty = 10 * grad_penalty
            grad_penalty.backward()
            grad_loss_val = grad_penalty.item()
            disc_loss_val = (real_predict - fake_predict).item()
        elif args.loss == 'r1':
            fake_predict = F.softplus(fake_predict).mean()
            fake_predict.backward()
            disc_loss_val = (real_predict + fake_predict).item()

        d_optimizer.step()

        # ---- Generator step, once every n_critic discriminator steps ----
        if (i + 1) % n_critic == 0:
            generator.zero_grad()
            requires_grad(generator, True)
            requires_grad(discriminator, False)
            fake_image = generator(gen_in2, label, step=step, alpha=alpha)
            predict = discriminator(fake_image, label, step=step, alpha=alpha)

            if args.loss == 'wgan-gp':
                loss = -predict.mean()
            elif args.loss == 'r1':
                loss = F.softplus(-predict).mean()

            gen_loss_val = loss.item()
            loss.backward()
            g_optimizer.step()
            # Update the running-average copy of the generator.
            accumulate(g_running, generator.module)
            requires_grad(generator, False)
            requires_grad(discriminator, True)

        if (i + 1) % 100 == 0:
            images = []
            gen_i, gen_j = args.gen_sample.get(resolution, (10, 5))

            with torch.no_grad():
                for _ in range(gen_i):
                    # Draw sample labels from the current batch so they are
                    # guaranteed to be classes present in the dataset.
                    pick = torch.randint(0, b_size, (gen_j,), device=label.device)
                    sample_label = label[pick]
                    images.append(
                        g_running(
                            torch.randn(gen_j, code_size).to(device),
                            sample_label,
                            step=step,
                            alpha=alpha,
                        ).data.cpu()
                    )

            # `value_range` replaces the deprecated `range` keyword.
            utils.save_image(
                torch.cat(images, 0),
                f'sample/{str(i + 1).zfill(6)}.png',
                nrow=gen_i,
                normalize=True,
                value_range=(-1, 1),
            )
        if (i + 1) % 10000 == 0:
            torch.save(g_running.state_dict(), f'checkpoint/{str(i + 1).zfill(6)}.model')
        if (i + 1) % 100 == 0:
            state_msg = (
                f'Size: {4 * 2 ** step}; G: {gen_loss_val:.3f}; D: {disc_loss_val:.3f};'
                f' Grad: {grad_loss_val:.3f}; Alpha: {alpha:.5f}'
            )
            pbar.set_description(state_msg)


# Module-level training constants. `code_size` and `n_critic` are read by
# train(); `image_size` is used for the dataset's initial transform.
code_size = 512
batch_size = 16
n_critic = 1
image_size = 64

parser = argparse.ArgumentParser(description="Progressive Growing of GANs")

parser.add_argument("--path", type=str, default="data", help="path of specified dataset")
parser.add_argument("--phase", type=int, default=600_000, help="number of samples used for each training phase")
parser.add_argument("--lr", default=0.001, type=float, help="learning rate")
parser.add_argument("--sched", action="store_true", help="use lr scheduling")
parser.add_argument("--init_size", default=64, type=int, help="initial image size")
parser.add_argument("--max_size", default=1024, type=int, help="max image size")
parser.add_argument("--mixing", action="store_true", help="use mixing regularization")
parser.add_argument("--loss", type=str, default="wgan-gp", choices=["wgan-gp", "r1"], help="class of gan loss")
# The two help fragments are joined by implicit string concatenation, so the
# separating space has to be written explicitly.
parser.add_argument("-d", "--data", default="shoe", type=str, choices=["shoe", "lsun"], help=("Specify dataset. " "Currently Shoe Dataset is supported"))

# parse_known_args tolerates extra argv entries (e.g. those injected by a
# notebook kernel) instead of exiting on them.
args, _ = parser.parse_known_args()

# Trainable generator and discriminator, wrapped in DataParallel (train()
# reaches the underlying models through `.module`).
generator = nn.DataParallel(StyledGenerator(code_size)).to(device)
discriminator = nn.DataParallel(Discriminator()).to(device)
# Separate, unwrapped generator kept in eval mode; train() updates it via
# accumulate() and uses it for sample images and periodic checkpoints.
g_running = StyledGenerator(code_size).to(device)
g_running.train(False)

# NOTE(review): `class_loss` is not referenced anywhere in the visible code.
class_loss = nn.CrossEntropyLoss()

# The synthesis network uses the base learning rate; the mapping network
# (`style`) is added as its own group at 1% of it. The "mult" entry is what
# adjust_lr() reads to keep that ratio when the rate is changed later.
g_optimizer = optim.Adam(generator.module.generator.parameters(), lr=args.lr, betas=(0.0, 0.99))
g_optimizer.add_param_group(
    {
        "params": generator.module.style.parameters(),
        "lr": args.lr * 0.01,
        "mult": 0.01,
    }
)
d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.0, 0.99))

# decay=0 copies the generator's parameters into g_running outright.
accumulate(g_running, generator.module, 0)

if args.data == "shoe":
    data_dir = "/kaggle/input/shoe-vs-sandal-vs-boot-dataset-15k-images/Shoe vs Sandal vs Boot Dataset"

    # The transform given here is replaced by get_data_sample() whenever a
    # loader is built for a specific resolution.
    dataset = datasets.ImageFolder(
        root=data_dir,
        transform=transforms.Compose(
            [
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        ),
        target_transform=target_to_oh_shoe,
    )
else:
    # Fail here with a clear message instead of a NameError on `dataset`
    # when train() is called below.
    raise NotImplementedError(f"dataset '{args.data}' is not supported yet")

# From here on `args.lr` and `args.batch` are per-resolution lookup tables
# (resolution -> value); train() falls back to 0.001 and
# `args.batch_default` for resolutions that are not listed.
if args.sched:
    args.lr = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
    args.batch = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 32, 256: 32}
else:
    args.lr = {}
    args.batch = {}

# resolution -> (number of generator calls, images per call) for sample grids.
args.gen_sample = {512: (8, 4), 1024: (4, 2)}

args.batch_default = 32

train(args, dataset, generator, discriminator)