In [1]:
import os.path
import time
import re

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Device selection. Training is pinned to the CPU by default; set USE_CUDA = True to
# use the GPU when one is available (falls back to the CPU otherwise).
USE_CUDA = False
device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu")

# Alias used in this notebook's type annotations.
Tensor = torch.Tensor

In [3]:
# def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
#     """3x3 convolution with padding"""
#     return nn.Conv2d(
#         in_planes,
#         out_planes,
#         kernel_size=3,
#         stride=stride,
#         padding=dilation,
#         groups=groups,
#         bias=False,
#         dilation=dilation,
#     )
#
#
# def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
#     """1x1 convolution"""
#     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

In [4]:
class DownsampleBlock(nn.Module):
    """Stack of stride-2 3x3 convolutions, each followed by a ReLU.

    Every stage maps a spatial size ``n`` to ``ceil(n / 2)`` (kernel 3, padding 1,
    stride 2) and keeps the channel count, so the default five stages shrink a
    512x512 input to 16x16.

    Args:
        channels: input and output channels of every convolution (default 3).
        num_stages: number of conv + ReLU stages (default 5).
    """

    def __init__(self, channels: int = 3, num_stages: int = 5):
        super().__init__()

        layers = []
        for _ in range(num_stages):
            layers.append(nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=2))
            layers.append(nn.ReLU())
        # Same submodule indices as before (layers.0, layers.2, ... are the convs),
        # so existing state dicts still load.
        self.layers = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        # Call the module rather than `.forward()` so registered hooks run.
        return self.layers(x)

In [5]:
# PyTorch version of ResNet
# https://github.com/pytorch/vision/blob/cddad9ca3822011548e18342f52a3e9f4724c2dd/torchvision/models/resnet.py#L88


class ResBlk(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers with an identity skip.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``. Stride 1 and padding 1
    keep the spatial size, and the channel count is unchanged, so the input can be
    added to the output directly.

    Args:
        channels: number of input (and output) channels. Defaults to 3, the value
            used throughout this notebook.
    """

    def __init__(self, channels: int = 3) -> None:
        super().__init__()
        self.conv_1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.batch_norm_1 = nn.BatchNorm2d(channels)
        self.conv_2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.batch_norm_2 = nn.BatchNorm2d(channels)
        # In-place ReLU is safe here: it only ever overwrites intermediate results
        # produced inside this block, never the caller's `x`.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        out = self.conv_1(x)
        out = self.batch_norm_1(out)
        out = self.relu(out)

        out = self.conv_2(out)
        out = self.batch_norm_2(out)

        # Identity skip connection (modifies `out` in place, not `x`).
        out += x
        out = self.relu(out)

        return out

In [6]:
# Two residual blocks created up front so they can also terminate a second encoder
# (the commented-out E_r below) and share their weights with it.
shared_layer_1 = ResBlk()
shared_layer_2 = ResBlk()

# Encoder of the "fake" branch: downsampling, four private residual blocks,
# then the two shared ones.
E_f = nn.Sequential(
    DownsampleBlock(),
    *[ResBlk() for _ in range(4)],
    shared_layer_1,
    shared_layer_2,
)
E_f = E_f.to(device)

# E_r = nn.Sequential(
#     DownsampleBlock(),
#     *(ResBlk() for _ in range(4)),
#     shared_layer_1,
#     shared_layer_2,
# ).to(device)

In [7]:
class GeneratorFake(nn.Module):
    """Decoder of the "fake" branch: two residual blocks, five 2x upsamplings, a 3x3 conv.

    Each upsampling stage is ``ConvTranspose2d(kernel_size=3, stride=2, padding=1,
    output_padding=1)`` followed by LeakyReLU. Its output size is
    ``(H - 1) * 2 - 2 * 1 + (3 - 1) + 1 + 1 = 2 * H``, so every stage exactly doubles
    height and width (a 16x16 input becomes 512x512 after five stages). The final
    3x3 convolution keeps the size and has no activation.
    """

    def __init__(self):
        super().__init__()

        self.resblk1 = ResBlk()
        self.resblk2 = ResBlk()

        stages = []
        for _ in range(5):
            stages += [
                nn.ConvTranspose2d(3, 3, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.LeakyReLU(),
            ]
        self.upscale_layers = nn.Sequential(*stages)

        self.out_conv = nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        features = self.resblk2(self.resblk1(x))
        upscaled = self.upscale_layers(features)
        return self.out_conv(upscaled)


G_f = GeneratorFake().to(device)
# G_f = nn.Sequential(
#     *(ResBlk() for _ in range(2)),
#     *(nn.ConvTranspose2d(3, 3, kernel_size=3, stride=2, padding=1) for _ in range(5)),
#     # nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1),
# )

In [8]:
# from https://github.com/NVlabs/SPADE

# Creates SPADE normalization layer based on the given configuration
# SPADE consists of two steps. First, it normalizes the activations using
# your favorite normalization method, such as Batch Norm or Instance Norm.
# Second, it applies scale and bias to the normalized output, conditioned on
# the segmentation map.
# The format of |config_text| is spade(norm)(ks), where
# (norm) specifies the type of parameter-free normalization
#       (batch or instance; syncbatch from the original code is disabled here)
# (ks) specifies the size of kernel in the SPADE module (e.g. 3x3)
# Example |config_text| values: spadebatch3x3 or spadeinstance5x5.
# Also, the other arguments are
# |norm_nc|: the #channels of the normalized activations, hence the output dim of SPADE
# |label_nc|: the #channels of the input semantic map, hence the input dim of SPADE
class SPADE(nn.Module):
    """Spatially-adaptive normalization, adapted from NVlabs/SPADE.

    The activations are first normalized without learned affine parameters. They
    are then scaled by ``1 + gamma`` and shifted by ``beta``, where ``gamma`` and
    ``beta`` are predicted from the segmentation map by a small conv network.

    Args:
        config_text: string ``spade<norm><k>x<k>``, e.g. ``spadeinstance3x3`` or
            ``spadebatch5x5``. ``<norm>`` must be ``instance`` or ``batch``;
            ``<k>`` must be an odd kernel size.
        norm_nc: channels of the normalized activations (output channels of SPADE).
        label_nc: channels of the segmentation map (input channels of the conv net).

    Raises:
        ValueError: if ``config_text`` cannot be parsed, names an unsupported
            normalization, or has an even kernel size.
    """

    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()

        assert config_text.startswith('spade')
        # Raw string: '\D' and '\d' in a plain literal are invalid escape sequences.
        parsed = re.search(r'spade(\D+)(\d+)x\d+', config_text)
        if parsed is None:
            raise ValueError('%s is not a valid SPADE config (expected e.g. spadeinstance3x3)'
                             % config_text)
        param_free_norm_type = parsed.group(1)
        ks = int(parsed.group(2))
        if ks % 2 == 0:
            # padding = ks // 2 only preserves the spatial size for odd kernels; an even
            # kernel would make gamma/beta one pixel larger than `x`.
            raise ValueError('SPADE kernel size must be odd, got %d' % ks)

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        # elif param_free_norm_type == 'syncbatch':
        #     self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError('%s is not a recognized param-free norm type in SPADE'
                             % param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128

        pw = ks // 2
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap):
        """Normalize ``x`` and modulate it with parameters predicted from ``segmap``.

        ``segmap`` is resized (nearest neighbour) to the spatial size of ``x``
        (``x.size()[2:]``) before predicting ``gamma`` and ``beta``.
        """
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on semantic map
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias (1 + gamma so that gamma = 0 is the identity scale)
        out = normalized * (1 + gamma) + beta

        return out


# ResNet block that uses SPADE.
# It differs from the ResNet block of pix2pixHD in that
# it takes in the segmentation map as input, learns the skip connection if necessary,
# and applies normalization first and then convolution.
# This is a standard residual-block architecture for unconditional or
# class-conditional GANs.
# The code was inspired from https://github.com/LMescheder/GAN_stability.
class SPADEResnetBlock(nn.Module):
    """Residual block whose normalization layers are SPADE (adapted from NVlabs/SPADE).

    Both branches normalize before convolving. The main branch is
    ``SPADE -> LeakyReLU(0.2) -> 3x3 conv``, applied twice. The skip branch is the
    identity when ``fin == fout``; otherwise it is a SPADE-normalized 1x1 conv
    without bias or activation.

    Args:
        fin: number of input channels.
        fout: number of output channels.
        opt: options object providing ``norm_G`` (SPADE config string; any
            ``'spectral'`` substring is stripped) and ``semantic_nc`` (channels of
            the segmentation map).

    Note: the upstream spectral-norm wrapping of the convolutions is disabled here.
    """

    def __init__(self, fin, fout, opt):
        super().__init__()
        # A learned 1x1 shortcut is only needed when the channel count changes.
        self.learned_shortcut = fin != fout
        fmiddle = min(fin, fout)

        # Convolutions.
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # One SPADE normalization in front of each convolution.
        spade_config_str = opt.norm_G.replace('spectral', '')
        semantic_nc = opt.semantic_nc
        self.norm_0 = SPADE(spade_config_str, fin, semantic_nc)
        self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_config_str, fin, semantic_nc)

    def forward(self, x, seg):
        """Return ``shortcut(x) + residual(x)``, both conditioned on the segmentation map ``seg``."""
        skip = self.shortcut(x, seg)

        hidden = self.norm_0(x, seg)
        hidden = self.conv_0(self.actvn(hidden))
        hidden = self.norm_1(hidden, seg)
        hidden = self.conv_1(self.actvn(hidden))

        return skip + hidden

    def shortcut(self, x, seg):
        """Identity when channels match, else SPADE-normalized 1x1 projection."""
        if not self.learned_shortcut:
            return x
        return self.conv_s(self.norm_s(x, seg))

    def actvn(self, x):
        """LeakyReLU with negative slope 0.2."""
        return F.leaky_relu(x, 0.2)

In [9]:
# class GeneratorReal(nn.Module):
#     def __init__(self):
#         super().__init__()
#
#         self.resblk1 = ResBlk()
#         self.resblk2 = ResBlk()
#
#         self.spade1 = SPADEResnetBlock()
#         self.spade2 = SPADEResnetBlock()
#         self.spade3 = SPADEResnetBlock()
#         self.spade4 = SPADEResnetBlock()
#         self.spade5 = SPADEResnetBlock()


In [10]:
# VGG19 architecture, used for the perceptual loss with a pretrained VGG network
class VGG19(torch.nn.Module):
    """Five consecutive slices of an ImageNet-pretrained VGG19 feature extractor.

    Layers ``[0:2]``, ``[2:7]``, ``[7:12]``, ``[12:21]`` and ``[21:30]`` of
    ``torchvision.models.vgg19().features`` become ``slice1`` ... ``slice5``. In
    torchvision's standard VGG19 layout these slices end at the relu1_1, relu2_1,
    relu3_1, relu4_1 and relu5_1 activations.

    NOTE(review): VGG was trained on ImageNet-normalized inputs, while this
    notebook's images are normalized to [-1, 1]; confirm this mismatch is acceptable.

    Args:
        requires_grad: if False (default), all VGG parameters are frozen.
    """

    def __init__(self, requires_grad=False):
        super().__init__()
        # Imported locally: the notebook's `import torchvision` lives in a later cell,
        # and `import torchvision.transforms as transforms` does not bind `torchvision`.
        import torchvision

        # NOTE: torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`.
        vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features

        # (start, stop) ranges of `features` copied into slice1 ... slice5; the
        # submodule names keep the original layer indices.
        slice_ranges = [(0, 2), (2, 7), (7, 12), (12, 21), (21, 30)]
        for slice_number, (start, stop) in enumerate(slice_ranges, start=1):
            layers = torch.nn.Sequential()
            for x in range(start, stop):
                layers.add_module(str(x), vgg_pretrained_features[x])
            setattr(self, f"slice{slice_number}", layers)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the outputs of the five slices, shallowest first."""
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out

# Perceptual loss that uses a pretrained VGG network
class VGGLoss(nn.Module):
    """Perceptual loss: weighted L1 distance between VGG19 feature maps.

    The five VGG19 slice outputs of ``x`` and ``y`` are compared with ``nn.L1Loss``.
    Deeper slices get larger weights (1/32, 1/16, 1/8, 1/4, 1). The features of
    ``y`` are detached, so gradients flow only through ``x``.

    Args:
        gpu_ids: unused; presumably kept for signature compatibility with the code
            this was adapted from.
    """

    def __init__(self, gpu_ids):
        super().__init__()
        self.vgg = VGG19().to(device)
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        x_feats = self.vgg(x)
        y_feats = self.vgg(y)
        return sum(
            self.weights[idx] * self.criterion(x_feat, y_feats[idx].detach())
            for idx, x_feat in enumerate(x_feats)
        )

In [11]:
# https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/stargan/models.py

class Discriminator(nn.Module):
    """Convolutional discriminator with a linear head (adapted from PyTorch-GAN's StarGAN).

    Four stride-2 4x4 convolutions (3 -> 64 -> 128 -> 256 -> 512 channels, with
    InstanceNorm after all but the first) reduce an ``img_size`` x ``img_size``
    image to ``img_size // 16``. The asymmetric zero-pad plus 4x4 convolution keeps
    that size while projecting to a single channel. A linear layer then maps the
    flattened ``(img_size // 16) ** 2`` values to one logit per image.

    Args:
        img_size: spatial size of the square input images (at least 16). The
            default 512 gives the original ``nn.Linear(1024, 1)`` head.

    Raises:
        ValueError: if ``img_size`` is smaller than 16.
    """

    def __init__(self, img_size: int = 512):
        super().__init__()
        if img_size < 16:
            raise ValueError(f"img_size must be at least 16, got {img_size}")

        def discriminator_block(in_filters, out_filters, normalization=True):
            """4x4 stride-2 conv (+ optional InstanceNorm) + LeakyReLU(0.2)."""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalization:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # Each stride-2 block floors the size by half; pad + 4x4 conv keeps it unchanged.
        head_features = (img_size // 16) ** 2

        self.layers = nn.Sequential(
            *discriminator_block(3, 64, normalization=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, padding=1, bias=False),
            nn.Flatten(),
            nn.Linear(head_features, 1),
        )

    def forward(self, img: Tensor) -> Tensor:
        """Return raw (pre-sigmoid) logits of shape ``(batch, 1)``."""
        return self.layers(img)


# Discriminator of the "fake" branch, moved to the configured device.
D_f = Discriminator()
D_f = D_f.to(device)

In [12]:
# class EmbeddingDiscriminator(nn.Module):
#     def __init__(self):
#         super().__init__()
#
#         self.layers = nn.Sequential(
#             nn.Flatten(),
#             nn.Linear(3 * 16 * 16),
#             nn.Linear()
#         )

In [13]:
class Options:
    """Hyper-parameters and runtime settings, read as attributes of ``opt``."""

    # Resolution every image is resized to by ImageDataset's transform.
    img_height: int = 512
    img_width: int = 512

    # Optimisation schedule.
    learning_rate: float = 0.0001  # from paper
    starting_epoch: int = 0
    total_epochs: int = 10
    batch_size: int = 8

    # Number of DataLoader worker processes.
    cpu_count: int = 4


opt = Options()

In [14]:
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import torchvision.transforms as transforms

# Root of the dataset tree. Set the MASTERSDATA_ROOT environment variable to point
# elsewhere instead of editing this machine-specific default.
FILE_ROOT = os.environ.get("MASTERSDATA_ROOT", "/home/andresokol/code/mastersdata")
FFHQ_DIR = f"{FILE_ROOT}/ffhq-dataset/images1024x1024"
PREPARED_ROOT = f"{FILE_ROOT}/prepared"
RENDERED_ROOT = f"{FILE_ROOT}/rendered"


class ImageDataset(Dataset):
    """Pairs of "base" and "structure" renders of FFHQ images.

    Sample ids have the form ``"NN000/NNMMM"`` (shard directory / file stem). A
    sample is kept only when the FFHQ source PNG and both ``<id>_base.png`` and
    ``<id>_structure.png`` renders exist. Each item is a dict with
    ``"render_base"`` and ``"render_struct"`` tensors produced by ``transform``.
    """

    # Bicubic resize to the configured resolution, then a tensor conversion, then
    # per-channel normalization with mean 0.5 / std 0.5 (maps [0, 1] to [-1, 1]).
    transform = transforms.Compose([
        transforms.Resize((opt.img_height, opt.img_width), transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    def __init__(self, mode="train"):
        """Scan the data directories and index every complete sample.

        Args:
            mode: currently ignored; every complete sample is used regardless of split.
        """
        self.files = [part for part in self._candidate_parts() if self._is_complete(part)]
        print(f"Read {len(self.files)} images")

    @staticmethod
    def _candidate_parts():
        """Yield every possible sample id: 70 shards x 1000 images."""
        for shard in range(70):
            for idx in range(1000):
                yield f"{shard:02}000/{shard:02}{idx:03}"

    @staticmethod
    def _is_complete(part):
        """True if the FFHQ source image and both renders for ``part`` exist on disk."""
        return (
            os.path.exists(f"{FFHQ_DIR}/{part}.png")
            and os.path.exists(f"{RENDERED_ROOT}/{part}_base.png")
            and os.path.exists(f"{RENDERED_ROOT}/{part}_structure.png")
        )

    @staticmethod
    def _load_rgb(path):
        """Read an image as RGB, closing the file handle deterministically."""
        with Image.open(path) as img:
            # convert() returns a new, fully loaded image independent of the file.
            return img.convert("RGB")

    def __getitem__(self, index):
        """Return the transformed renders of sample ``index`` (indices wrap around)."""
        part = self.files[index % len(self.files)]
        render_base = self._load_rgb(f"{RENDERED_ROOT}/{part}_base.png")
        render_struct = self._load_rgb(f"{RENDERED_ROOT}/{part}_structure.png")

        return {
            "render_base": self.transform(render_base),
            "render_struct": self.transform(render_struct),
        }

    def __len__(self):
        return len(self.files)

In [15]:
from torch.utils.data import DataLoader

# Shuffled mini-batches of paired renders, loaded in `opt.cpu_count` worker processes.
dataloader = DataLoader(
    ImageDataset(),
    shuffle=True,
    batch_size=opt.batch_size,
    num_workers=opt.cpu_count,
)

Read 19985 images


In [16]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard writer; log_dir=None selects PyTorch's default run directory.
writer = SummaryWriter(log_dir=None)

In [17]:
def weights_init_normal(m):
    """Re-draw the weights of convolution-like modules from N(0, 0.02).

    Intended for ``Module.apply``. Any module whose class name contains "Conv"
    (e.g. ``Conv2d``, ``ConvTranspose2d``) gets its ``weight`` re-sampled. Biases
    and all other modules are left untouched.
    """
    if "Conv" not in type(m).__name__:
        return
    torch.nn.init.normal_(m.weight.data, 0.0, 0.02)


# Only a fresh run (starting_epoch == 0) re-draws the conv weights of every network.
# Nothing in this notebook loads a checkpoint, so with starting_epoch > 0 the
# networks would keep PyTorch's default initialisation.
# TODO: load saved weights when resuming.
if opt.starting_epoch == 0:
    for network in (E_f, G_f, D_f):
        network.apply(weights_init_normal)

In [18]:
# One Adam optimizer for the whole "fake" branch (encoder E_f and generator G_f as
# separate parameter groups) and one for its discriminator D_f.
optimizer_F = torch.optim.Adam(
    [{"params": net.parameters()} for net in (E_f, G_f)],
    lr=opt.learning_rate,
)
optimizer_D_F = torch.optim.Adam(params=D_f.parameters(), lr=opt.learning_rate)

In [None]:
import torchvision
import tqdm.auto  # binds `tqdm` AND guarantees the `tqdm.auto` submodule is loaded

# References:
# https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
# https://github.com/hse-ds/iad-applied-ds/blob/master/2022/seminars/sem03/sem03-gan-task.ipynb
# https://colab.research.google.com/github/yandexdataschool/mlhep2019/blob/master/notebooks/day-6/06_GAN_faces_solution.ipynb#scrollTo=-Z7YAPGxp2Xl

# Log image grids every this many batches (a 512x512 grid per batch bloats the event files).
IMAGE_LOG_INTERVAL = 100
# Weight of the perceptual (VGG) term in the generator loss.
VGG_LOSS_WEIGHT = 1.0

# Most recent loss of each branch; None until that branch has run at least once.
branch_F_loss, D_f_loss = None, None

# Perceptual loss on frozen VGG19 features (the argument is VGGLoss's unused `gpu_ids`).
vgg_loss = VGGLoss(None)

batches_per_epoch = len(dataloader)

for epoch in range(opt.starting_epoch, opt.total_epochs):
    with tqdm.auto.tqdm(total=batches_per_epoch, unit="batch", desc=f"Epoch {epoch}") as pbar:
        for i, batch in enumerate(dataloader):
            # Monotonic TensorBoard step: `i` alone restarts at 0 every epoch and
            # overwrites the previous epoch's points.
            global_step = epoch * batches_per_epoch + i

            render_base = batch["render_base"].to(device)
            render_struct = batch["render_struct"].to(device)

            embedding = E_f(render_struct)
            generated_f = G_f(embedding)

            if i % 2 == 0:
                # Generator step (E_f + G_f). Adversarial term log(1 - D(G(x))) plus the
                # perceptual distance to the target render. VGGLoss detaches its second
                # argument, so the generated image must be passed first to get gradients.
                # NOTE: the non-saturating alternative -log D(G(x)) is
                # `-F.logsigmoid(D_f(generated_f)).mean()`; consider switching.
                optimizer_F.zero_grad()
                adversarial_loss = F.logsigmoid(-D_f(generated_f)).mean()
                perceptual_loss = vgg_loss(generated_f, render_base)
                branch_F_loss = adversarial_loss + VGG_LOSS_WEIGHT * perceptual_loss
                branch_F_loss.backward()
                optimizer_F.step()

                writer.add_scalar('Branch_F loss', branch_F_loss.item(), global_step)
            else:
                # Discriminator step: real renders should score high, generated ones low.
                # Detach so this backward pass does not propagate into E_f / G_f.
                optimizer_D_F.zero_grad()
                fake_detached = generated_f.detach()
                D_f_loss = (
                    -F.logsigmoid(D_f(render_base)).mean()
                    - F.logsigmoid(-D_f(fake_detached)).mean()
                )
                D_f_loss.backward()
                optimizer_D_F.step()

                writer.add_scalar('D_f loss', D_f_loss.item(), global_step)

            if global_step % IMAGE_LOG_INTERVAL == 0:
                # Images are normalized to [-1, 1]; map them to [0, 1] for display.
                input_grid = torchvision.utils.make_grid((render_struct + 1) / 2)
                output_grid = torchvision.utils.make_grid(((generated_f.detach() + 1) / 2).clamp(0, 1))
                writer.add_image('images', input_grid, global_step)
                writer.add_image('generated_f', output_grid, global_step)

            # Compare against None: a tensor's truthiness is its value, so a zero loss
            # would otherwise hide the postfix.
            if branch_F_loss is not None and D_f_loss is not None:
                pbar.set_postfix(loss_D_f=D_f_loss.item(), loss_F=branch_F_loss.item())
            pbar.update(1)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /home/andresokol/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


  0%|          | 0.00/548M [00:00<?, ?B/s]

Epoch 0:   0%|          | 0/2499 [00:00<?, ?batch/s]

In [None]:
# Scratch check of the BatchNorm1d constructor. Calling `nn.BatchNorm1d()` with no
# arguments raises TypeError because `num_features` is required, so display the
# constructor signature instead.
import inspect

inspect.signature(nn.BatchNorm1d)