In [1]:
import os.path
import time

import torch
import torch.nn as nn

In [2]:
# Device selection. The CUDA-aware variants below are kept for reference but are
# disabled, so the notebook currently always runs on the CPU.
# cuda_available = torch.cuda.is_available()
# dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# device = "cuda" if torch.cuda.is_available() else "cpu"
# Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
# print(f"Using {device} device")
# torch.cuda.empty_cache()
# ``Tensor`` is the alias used in the type annotations of the modules defined below.
Tensor = torch.Tensor
# Every model and every batch is moved to this device.
device = torch.device("cpu")

In [3]:
# def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
#     """3x3 convolution with padding"""
#     return nn.Conv2d(
#         in_planes,
#         out_planes,
#         kernel_size=3,
#         stride=stride,
#         padding=dilation,
#         groups=groups,
#         bias=False,
#         dilation=dilation,
#     )
#
#
# def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
#     """1x1 convolution"""
#     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

In [4]:
class DownsampleBlock(nn.Module):
    """Stack of stride-2 3x3 convolutions that repeatedly shrinks the spatial size.

    Every layer uses ``kernel_size=3, padding=1, stride=2``, so a side of length
    ``H`` becomes ``floor((H - 1) / 2) + 1`` (exactly ``H / 2`` for even ``H``).
    The channel count is left unchanged.

    Args:
        channels: Number of input and output channels of every convolution.
        num_layers: How many stride-2 convolutions to chain.
    """

    def __init__(self, channels: int = 3, num_layers: int = 5) -> None:
        super().__init__()

        self.layers = nn.Sequential(
            *(
                nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=2)
                for _ in range(num_layers)
            )
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply all downsampling convolutions to ``x`` and return the result."""
        # Call the module itself instead of ``.forward`` so that forward hooks
        # registered on ``self.layers`` are executed.
        return self.layers(x)

In [5]:
# PyTorch version of ResNet
# https://github.com/pytorch/vision/blob/cddad9ca3822011548e18342f52a3e9f4724c2dd/torchvision/models/resnet.py#L88


class ResBlk(nn.Module):
    """Residual block on 3-channel feature maps.

    Two (3x3 convolution -> batch norm) stages with a ReLU in between; the block
    input is added to the second stage's output before the final ReLU. Stride 1
    and padding 1 keep the spatial size unchanged.
    """

    def __init__(self) -> None:
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv_1 = nn.Conv2d(3, 3, **conv_kwargs)
        self.batch_norm_1 = nn.BatchNorm2d(3)
        self.conv_2 = nn.Conv2d(3, 3, **conv_kwargs)
        self.batch_norm_2 = nn.BatchNorm2d(3)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        hidden = self.relu(self.batch_norm_1(self.conv_1(x)))
        residual = self.batch_norm_2(self.conv_2(hidden))
        # Skip connection: added in place onto the freshly computed branch output.
        residual += x
        return self.relu(residual)

In [6]:
# The last two residual blocks are the *same* module objects in both encoders,
# so their weights are shared between E_f and E_r.
shared_layer_1 = ResBlk()
shared_layer_2 = ResBlk()

# Encoder E_f: downsampling stack, four private residual blocks, then the shared tail.
E_f = nn.Sequential(
    DownsampleBlock(),
    ResBlk(),
    ResBlk(),
    ResBlk(),
    ResBlk(),
    shared_layer_1,
    shared_layer_2,
)
E_f = E_f.to(device)

# Encoder E_r: same layout with its own private layers and the same shared tail.
E_r = nn.Sequential(
    DownsampleBlock(),
    ResBlk(),
    ResBlk(),
    ResBlk(),
    ResBlk(),
    shared_layer_1,
    shared_layer_2,
)
E_r = E_r.to(device)

In [7]:
class GeneratorFake(nn.Module):
    """Decoder: two residual blocks, five transposed-conv upsamplers, one output conv.

    All layers operate on 3 channels.
    """

    def __init__(self):
        super().__init__()

        self.resblk1 = ResBlk()
        self.resblk2 = ResBlk()

        # With kernel 3, stride 2, padding 1 and output_padding 1 the output side is
        # (H - 1) * 2 - 2 * 1 + (3 - 1) + 1 + 1 = 2 * H, i.e. each layer doubles
        # the spatial size exactly.
        upsamplers = [
            nn.ConvTranspose2d(3, 3, kernel_size=3, stride=2, padding=1, output_padding=1)
            for _ in range(5)
        ]
        self.upscale_layers = nn.Sequential(*upsamplers)

        self.out_conv = nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through the residual blocks, the upsamplers and the output conv."""
        features = self.resblk2(self.resblk1(x))
        upscaled = self.upscale_layers(features)
        return self.out_conv(upscaled)


# Generator instance used by the training loop, moved to the configured device.
G_f = GeneratorFake().to(device)
# Earlier nn.Sequential formulation, kept for reference only (its transposed
# convolutions have no output_padding and its final convolution is disabled):
# G_f = nn.Sequential(
#     *(ResBlk() for _ in range(2)),
#     *(nn.ConvTranspose2d(3, 3, kernel_size=3, stride=2, padding=1) for _ in range(5)),
#     # nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1),
# )

In [8]:
# https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/stargan/models.py

class Discriminator(nn.Module):
    """Convolutional discriminator that maps an RGB image to one raw score (logit).

    Four stride-2 4x4 convolution blocks (each halving height and width with floor
    division) are followed by zero padding, a 4x4 convolution down to one channel,
    flattening and a linear layer that produces a single value per image.

    Args:
        img_height: Height of the input images. Defaults to 512.
        img_width: Width of the input images. Defaults to 512.

    Raises:
        ValueError: If either side is smaller than 16 pixels, in which case the
            feature map in front of the linear layer would be empty.
    """

    def __init__(self, img_height: int = 512, img_width: int = 512):
        super().__init__()

        if img_height < 16 or img_width < 16:
            raise ValueError(
                f"Discriminator needs inputs of at least 16x16, got {img_height}x{img_width}"
            )

        def discriminator_block(in_filters, out_filters, normalization=True):
            """Conv (stride 2) -> optional InstanceNorm -> LeakyReLU, as a list of layers."""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalization:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # Four halvings give (H // 16, W // 16). ZeroPad2d adds one row and one column,
        # and the final 4x4 / padding-1 convolution removes exactly one again, so
        # the single-channel map that gets flattened has (H // 16) * (W // 16) values
        # (1024 for the default 512x512 input).
        flat_features = (img_height // 16) * (img_width // 16)

        self.layers = nn.Sequential(
            *discriminator_block(3, 64, normalization=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, padding=1, bias=False),
            nn.Flatten(),
            nn.Linear(flat_features, 1),
        )

    def forward(self, img: Tensor) -> Tensor:
        """Return the discriminator score for ``img`` with shape ``(batch, 1)``."""
        return self.layers(img)


# Discriminator instance used by the training loop, moved to the configured device.
D_f = Discriminator().to(device)

In [9]:
# class EmbeddingDiscriminator(nn.Module):
#     def __init__(self):
#         super().__init__()
#
#         self.layers = nn.Sequential(
#             nn.Flatten(),
#             nn.Linear(3 * 16 * 16),
#             nn.Linear()
#         )

In [10]:
class Options:
    """Run configuration shared by the dataset, the data loader and the training loop."""

    # Size every image is resized to by the dataset transform.
    img_height: int = 512
    img_width: int = 512

    # Step size for both Adam optimizers.
    learning_rate: float = 1e-4  # from paper

    # Training covers epochs ``starting_epoch`` .. ``total_epochs - 1``.
    starting_epoch: int = 0
    total_epochs: int = 10

    # DataLoader settings: samples per batch and number of worker processes.
    batch_size: int = 8
    cpu_count: int = 4


opt = Options()

In [11]:
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import torchvision.transforms as transforms

# Root directory of all data files. It can be overridden with the
# MASTERS_DATA_ROOT environment variable so the notebook also runs on machines
# other than the original author's; the default is the original location.
FILE_ROOT = os.environ.get("MASTERS_DATA_ROOT", "/home/andresokol/code/mastersdata")
# Original FFHQ photographs (1024x1024 PNGs).
FFHQ_DIR = f"{FILE_ROOT}/ffhq-dataset/images1024x1024"
PREPARED_ROOT = f"{FILE_ROOT}/prepared"
# Rendered images: "<part>_base.png" and "<part>_structure.png".
RENDERED_ROOT = f"{FILE_ROOT}/rendered"


class ImageDataset(Dataset):
    """Dataset of rendered image pairs (base render and structure render).

    A sample id ("part") is kept only when the FFHQ photograph, the base render
    and the structure render all exist on disk. Only the two renders are loaded;
    the FFHQ photograph is used purely as an existence check.

    Each item is a dict with the keys ``"render_base"`` and ``"render_struct"``,
    both passed through ``transform`` (resize, tensor conversion, normalisation
    with mean 0.5 and std 0.5 per channel).
    """

    transform = transforms.Compose([
        transforms.Resize((opt.img_height, opt.img_width), transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    def __init__(self, mode="train"):
        """Scan the data directories for complete samples.

        Args:
            mode: Accepted for interface compatibility; currently unused.
        """
        self.files = []

        # Candidate ids follow the layout "<ii>000/<ii><jjj>" for ii in 00..69, jjj in 000..999.
        for i in range(70):
            for j in range(1000):
                part = f"{i:02}000/{i:02}{j:03}"
                required_paths = (
                    f"{FFHQ_DIR}/{part}.png",
                    f"{RENDERED_ROOT}/{part}_base.png",
                    f"{RENDERED_ROOT}/{part}_structure.png",
                )
                # all() short-circuits, so later paths are only checked when the
                # earlier ones exist.
                if all(os.path.exists(path) for path in required_paths):
                    self.files.append(part)

        print(f"Read {len(self.files)} images")

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy, closing the file."""
        # Image.open is lazy; converting inside the ``with`` block forces the
        # pixel data to be read before the underlying file is closed.
        with Image.open(path) as image:
            return image.convert("RGB")

    def __getitem__(self, index):
        """Return the transformed render pair for ``index`` (wrapped modulo the dataset size).

        Raises:
            IndexError: If no complete sample was found on disk.
        """
        if not self.files:
            raise IndexError("ImageDataset is empty: no complete samples were found")

        part = self.files[index % len(self.files)]
        render_base = self._load_rgb(f"{RENDERED_ROOT}/{part}_base.png")
        render_struct = self._load_rgb(f"{RENDERED_ROOT}/{part}_structure.png")

        return {
            "render_base": self.transform(render_base),
            "render_struct": self.transform(render_struct),
        }

    def __len__(self):
        """Number of complete samples found on disk."""
        return len(self.files)

In [12]:
from torch.utils.data import DataLoader

# Shuffled mini-batches of render pairs; batch size and worker count come from ``opt``.
dataloader = DataLoader(
    dataset=ImageDataset(),
    shuffle=True,
    batch_size=opt.batch_size,
    num_workers=opt.cpu_count,
)

Read 7649 images


In [13]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard writer for the training loop's losses and image grids. No log
# directory is passed, so the writer's default location is used.
writer = SummaryWriter()

In [14]:
def weights_init_normal(m):
    """Re-initialise convolution weights from N(0, 0.02); meant for ``Module.apply``.

    A module is treated as a convolution when its class name contains ``"Conv"``
    (this covers ``Conv2d`` and ``ConvTranspose2d``). Modules that match by name
    but carry no ``weight`` (for example a container class whose name happens to
    contain "Conv") are skipped instead of raising ``AttributeError``. Biases and
    all other module types are left untouched.

    Args:
        m: The module visited by ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    if "Conv" not in classname:
        return

    weight = getattr(m, "weight", None)
    if weight is None:
        return

    torch.nn.init.normal_(weight.data, 0.0, 0.02)


# Apply the custom initialisation only when training starts from scratch
# (starting_epoch == 0).
# NOTE(review): presumably a resumed run is meant to load checkpointed weights
# instead, but no checkpoint loading is visible in this notebook — confirm.
# NOTE(review): E_r is not initialised here; only the two ResBlk layers it shares
# with E_f are touched. Confirm whether that is intended.
if opt.starting_epoch == 0:
    E_f.apply(weights_init_normal)
    G_f.apply(weights_init_normal)
    D_f.apply(weights_init_normal)

In [15]:
# One Adam optimizer updates the encoder E_f and the generator G_f together
# (two parameter groups, same learning rate). The ResBlk layers shared with E_r
# are part of E_f and are therefore updated here as well.
optimizer_F = torch.optim.Adam([
    {"params": E_f.parameters()},
    {"params": G_f.parameters()},
], lr=opt.learning_rate)
# Separate Adam optimizer for the discriminator.
optimizer_D_F = torch.optim.Adam(D_f.parameters(), lr=opt.learning_rate)

In [16]:
import torchvision
import tqdm.auto  # binds ``tqdm`` and guarantees the ``tqdm.auto`` submodule is loaded
from torch.nn import functional

prev_time = time.time()

# https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
# https://github.com/hse-ds/iad-applied-ds/blob/master/2022/seminars/sem03/sem03-gan-task.ipynb
# https://colab.research.google.com/github/yandexdataschool/mlhep2019/blob/master/notebooks/day-6/06_GAN_faces_solution.ipynb#scrollTo=-Z7YAPGxp2Xl

# Most recent loss of each branch; ``None`` until that branch has run at least once.
branch_F_loss, D_f_loss = None, None

# Alternating GAN training: even batches update encoder + generator, odd batches
# update the discriminator.
for epoch in range(opt.starting_epoch, opt.total_epochs):
    with tqdm.auto.tqdm(total=len(dataloader), unit="batch", desc=f"Epoch {epoch}") as pbar:
        for i, batch in enumerate(dataloader):
            # Monotonic step across epochs; using the bare batch index would make
            # every epoch overwrite the previous epoch's TensorBoard points.
            global_step = epoch * len(dataloader) + i

            # Model inputs
            render_base = batch["render_base"].to(device)
            render_struct = batch["render_struct"].to(device)

            embedding = E_f(render_struct)
            generated_f = G_f(embedding)

            if i % 2 == 0:
                # Train encoder + generator: make the discriminator score fakes as real.
                optimizer_F.zero_grad()
                branch_F_loss = -functional.logsigmoid(D_f(generated_f)).mean()
                branch_F_loss.backward()
                optimizer_F.step()

                writer.add_scalar('Branch_F loss', branch_F_loss.item(), global_step)
            else:
                # Train discriminator: real renders should score high, fakes low.
                optimizer_D_F.zero_grad()
                real_loss = -functional.logsigmoid(D_f(render_base)).mean()
                # detach(): the generator is not updated in this step, so do not
                # backpropagate through it.
                fake_loss = -functional.logsigmoid(-D_f(generated_f.detach())).mean()
                D_f_loss = real_loss + fake_loss
                D_f_loss.backward()
                optimizer_D_F.step()

                writer.add_scalar('D_f loss', D_f_loss.item(), global_step)

            # The dataset normalises images to [-1, 1]; map them back to [0, 1],
            # which is the range add_image expects for float tensors.
            grid = torchvision.utils.make_grid(render_struct * 0.5 + 0.5)
            writer.add_image('images', grid, global_step)

            # Compare against None explicitly: tensor truthiness would hide a loss
            # that happens to be exactly zero.
            if branch_F_loss is not None and D_f_loss is not None:
                pbar.set_postfix(loss_D_f=D_f_loss.item(), loss_F=branch_F_loss.item())
            pbar.update(1)

Epoch 0:   0%|          | 0/957 [00:00<?, ?batch/s]

KeyboardInterrupt: 

In [None]:
# NOTE(review): scratch cell that was never executed (it has no execution count).
# The call passes no arguments, while nn.BatchNorm1d expects ``num_features``, so
# this would fail under "Run All" — TODO confirm intent or delete the cell.
nn.BatchNorm1d()