In [33]:
# !wget https://raw.githubusercontent.com/TamirPuzanov/Pix2Pix-PyTorch/5e6e3ee676a20b32d4d1b2548a831bd72ba9a777/download_pix2pix_dataset.sh
# !bash download_pix2pix_dataset.sh maps

In [34]:
from torch.autograd import Variable
from tqdm.notebook import tqdm

import torch.nn as nn
import torch

from torchvision import models
import torchvision.transforms as tt
import torchvision.transforms.functional as TF

from PIL import Image
import numpy as np

from torch.utils.data import Dataset, DataLoader
import os, random

from torchvision.utils import make_grid

In [35]:
import warnings
warnings.filterwarnings("ignore")

In [36]:
# Facades folders (used only by the commented-out FacedesData loader):
# trainB is read as domain A and trainA as domain B.
path_A = "../input/facades/facades/trainB/"
path_B = "../input/facades/facades/trainA/"

# Square resize target for every image; scale_size mirrors it and appears unused.
image_size = 128
scale_size = image_size

# Per-channel Normalize statistics (also inverted by Denormalize for display).
mean = (0.5,) * 3
std = (0.5,) * 3

batch_size = 64

In [37]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [38]:
def numel(m: torch.nn.Module, only_trainable: bool = False):
    """
    Count the scalar parameters of `m`.

    Tensors that share the same storage pointer are counted once. When
    `only_trainable` is True, parameters with `requires_grad == False`
    are skipped.
    """
    by_storage = {}
    for param in m.parameters():
        if only_trainable and not param.requires_grad:
            continue
        by_storage[param.data_ptr()] = param
    return sum(param.numel() for param in by_storage.values())

In [39]:
from pathlib import Path

In [40]:
class Night2DayData(Dataset):
    """Paired-image dataset where each file holds both domains side by side.

    Every file matching ``*.*`` under ``path`` (searched recursively) is split
    down the middle: the left half is returned as ``B`` and the right half as
    ``A``. Both halves receive the same random horizontal flip, then the
    optional ``transform`` is applied to each half independently.

    NOTE(review): the search recurses into every subfolder, so a root that
    contains several splits (e.g. train/ and val/) mixes them — confirm the
    intended directory.
    """

    def __init__(self, path, transform=None):
        super().__init__()

        # Sorted so the index -> file mapping is identical on every machine
        # (rglob order depends on the filesystem); directories whose names
        # contain a dot also match "*.*", so keep regular files only.
        self.files = sorted(p for p in Path(path).rglob("*.*") if p.is_file())
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path_ = self.files[idx]

        # Close the file handle deterministically; convert() returns a new image.
        with Image.open(path_) as src:
            image = src.convert("RGB")
        w, h = image.size
        mid = w // 2  # integer split column instead of relying on PIL rounding

        # Left half -> B, right half -> A.
        B = image.crop((0, 0, mid, h))
        A = image.crop((mid, 0, w, h))

        # Flip both halves together so the pair stays spatially aligned.
        if random.random() > 0.5:
            A = TF.hflip(A)
            B = TF.hflip(B)

        if self.transform is not None:
            A = self.transform(A)
            B = self.transform(B)

        return A, B

In [41]:
class FacedesData(Dataset):
    """Paired dataset built from two parallel folders.

    The file list comes from ``path_A``; each partner file in ``path_B`` is
    found by replacing every ``"B"`` in the name with ``"A"`` (e.g.
    ``1_B.jpg`` -> ``1_A.jpg``). Both images receive the same random
    horizontal flip, then the optional ``transform`` is applied to each.

    NOTE(review): ``str.replace`` rewrites *every* "B" in the name — confirm
    file names contain no other capital B.
    """

    def __init__(self, path_A, path_B, transform=None):
        super().__init__()

        # os.listdir order is arbitrary; sort for a reproducible index mapping.
        self.files_A = sorted(os.listdir(path_A))

        self.path_A = path_A
        self.path_B = path_B

        self.transform = transform

    def __len__(self):
        return len(self.files_A)

    def __getitem__(self, idx):
        name = self.files_A[idx]
        path_A = os.path.join(self.path_A, name)
        path_B = os.path.join(self.path_B, name.replace("B", "A"))

        # Close file handles deterministically; convert() returns new images.
        with Image.open(path_A) as src:
            A = src.convert("RGB")
        with Image.open(path_B) as src:
            B = src.convert("RGB")

        # Flip both images together so the pair stays aligned.
        if random.random() > 0.5:
            A = TF.hflip(A)
            B = TF.hflip(B)

        if self.transform is not None:
            A = self.transform(A)
            B = self.transform(B)

        return A, B

In [42]:
class Denormalize:
    """Undo ``Normalize(mean, std)`` and clamp the result to ``[0, 1]``.

    ``Normalize`` computes ``(x - mean) / std``. Normalizing again with
    ``mean' = -mean / std`` and ``std' = 1 / std`` gives back ``x * std + mean``.
    With ``inplace=True`` the input tensor is overwritten before clamping.
    """

    def __init__(self, mean, std, inplace=False):
        self.mean, self.std = mean, std
        self.inplace = inplace
        self.demean = [-m / s for m, s in zip(mean, std)]
        self.destd = [1 / s for s in std]

    def __call__(self, tensor):
        restored = TF.normalize(tensor, self.demean, self.destd, self.inplace)
        return restored.clamp(0.0, 1.0)

In [43]:
class Buffer:
    """Fixed-size history pool of past samples.

    Until ``size`` samples are stored, every incoming sample is stored and
    returned unchanged. Afterwards, each incoming sample has a 50% chance of
    replacing a randomly chosen stored sample, in which case the stored one
    is returned in its place; otherwise the incoming sample is returned.
    Samples are kept on the CPU.
    """

    def __init__(self, size=100):
        self.size = size
        self.buffer = []

    def push_and_pop(self, data):
        """Push a batch ``data`` (first dim = batch) and return a same-sized batch.

        The result is placed on ``data``'s device rather than on a global one.
        """
        out_device = data.device
        samples = data.detach().cpu()
        result = []

        for el in samples:
            # clone(): a bare row view would keep the whole batch's storage
            # alive while it sits in the pool and alias the caller's tensor.
            el = el.unsqueeze(0).clone()

            if len(self.buffer) < self.size:
                self.buffer.append(el)
                result.append(el)
            elif random.uniform(0, 1) > 0.5:
                i = random.randint(0, self.size - 1)
                result.append(self.buffer[i])
                self.buffer[i] = el
            else:
                result.append(el)

        return torch.cat(result).to(out_device)

In [44]:
def weights_normal(model):
    """Re-initialize one module in place (meant for ``Module.apply``).

    * Class name containing "Conv": weight ~ N(0, 0.02), bias (if any) = 0.
    * Class name containing "BatchNorm2d": weight ~ N(1, 0.02), bias = 0.
    * Anything else is left untouched.
    """
    name = type(model).__name__

    if "Conv" in name:
        torch.nn.init.normal_(model.weight.data, 0.0, 0.02)
        if getattr(model, "bias", None) is not None:
            torch.nn.init.constant_(model.bias.data, 0.0)
        return

    if "BatchNorm2d" in name:
        torch.nn.init.normal_(model.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(model.bias.data, 0.0)

In [45]:
# Resize to image_size x image_size with bicubic interpolation, convert the
# PIL image to a float tensor in [0, 1], then apply (x - mean) / std per channel.
# InterpolationMode replaces the deprecated PIL integer constant (Image.BICUBIC);
# the resulting resampling filter is the same.
train_transform = tt.Compose([
    tt.Resize((image_size, image_size), interpolation=tt.InterpolationMode.BICUBIC),
    tt.ToTensor(),
    tt.Normalize(mean, std),
])

In [46]:
# train_ds = FacedesData(path_A, path_B, train_transform)
# train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)

In [47]:
# NOTE(review): Night2DayData searches this folder recursively; if it holds
# both train/ and val/ splits, validation images are mixed into training —
# confirm the intended subfolder.
train_ds = Night2DayData("../input/maps-data/maps/", transform=train_transform)
train_dl = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True,  # every batch has exactly batch_size samples
)

In [48]:
class UnetGenerator(nn.Module):
    """U-Net-shaped generator mapping a 3-channel image to a 3-channel image.

    ``c`` lists channel widths: ``c[0]`` is the stem width, then each
    consecutive pair of entries adds one stride-2 down-sampling block. The
    decoder mirrors the encoder. Each up block receives the running decoder
    features concatenated with the matching encoder output (at the deepest
    level the encoder output is concatenated with itself). The head also
    sees the raw input and ends in ``Tanh``.
    """

    def __init__(self, c=[64, 128, 256, 512]):
        super().__init__()

        stem_ch = c[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, stem_ch, 3, 1, 1),
            nn.BatchNorm2d(stem_ch),
            nn.LeakyReLU(),
        )

        half_ch = stem_ch // 2
        self.conv2 = nn.Sequential(
            # stem features + raw input image
            nn.ConvTranspose2d(stem_ch + 3, half_ch, 3, 1, 1),
            nn.BatchNorm2d(half_ch),
            nn.LeakyReLU(),
            # x2 up then x2 down: net spatial size is unchanged for even sizes
            nn.ConvTranspose2d(half_ch, 3, 4, 2, 1),
            nn.BatchNorm2d(3),
            nn.LeakyReLU(),
            nn.Conv2d(3, 3, 4, 2, 1),
            nn.Tanh(),
        )

        self.down_sample = nn.Sequential(
            *(self.down_block(cin, cout) for cin, cout in zip(c[:-1], c[1:]))
        )

        rev = list(reversed(c))
        self.up_sample = nn.Sequential(
            *(self.up_block(cin, cout) for cin, cout in zip(rev[:-1], rev[1:]))
        )

    def down_block(self, inp, out):
        # Stride-2 conv halves height and width.
        layers = [nn.Conv2d(inp, out, 4, 2, 1), nn.BatchNorm2d(out), nn.LeakyReLU()]
        return nn.Sequential(*layers)

    def up_block(self, inp, out):
        # Input is the concat of two inp-channel tensors; stride-2 transposed
        # conv doubles height and width.
        layers = [nn.ConvTranspose2d(2 * inp, out, 4, 2, 1), nn.BatchNorm2d(out), nn.LeakyReLU()]
        return nn.Sequential(*layers)

    def forward(self, x0):
        feats = self.conv1(x0)

        skips = []
        for stage in self.down_sample:
            feats = stage(feats)
            skips.append(feats)

        for stage, skip in zip(self.up_sample, reversed(skips)):
            feats = stage(torch.cat((feats, skip), dim=1))

        return self.conv2(torch.cat((feats, x0), dim=1))

In [49]:
class ResBlock(nn.Module):
    """Residual unit: conv-BN-LeakyReLU-conv, add the input, then BN + LeakyReLU.

    Channel count and spatial size are preserved (3x3 convs, stride 1, padding 1).
    """

    def __init__(self, c):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv2d(c, c, 3, 1, 1),
            nn.BatchNorm2d(c),
            nn.LeakyReLU(),
            nn.Conv2d(c, c, 3, 1, 1),
        )
        self.norm = nn.Sequential(nn.BatchNorm2d(c), nn.LeakyReLU())

    def forward(self, x):
        branch = self.block(x)
        return self.norm(branch + x)


class ResNetGenerator(nn.Module):
    """Encoder / residual bottleneck / decoder generator.

    A stride-2 stem maps 3 -> ``c[0]`` channels; each consecutive pair in ``c``
    adds another stride-2 conv; ``blocks`` ResBlocks run at ``c[-1]``
    channels; the decoder mirrors the encoder with stride-2 transposed convs
    and a final transposed conv back to 3 channels followed by ``Tanh``.
    """

    def __init__(self, c=[64, 128, 256, 1024, 2048], blocks=9):
        super().__init__()

        def unit(layer, width):
            # layer -> BatchNorm -> LeakyReLU
            return [layer, nn.BatchNorm2d(width), nn.LeakyReLU()]

        layers = unit(nn.Conv2d(3, c[0], 4, 2, 1), c[0])

        # Encoder: c[0] -> c[1] -> ... -> c[-1]
        for cin, cout in zip(c[:-1], c[1:]):
            layers.extend(unit(nn.Conv2d(cin, cout, 4, 2, 1), cout))

        layers.extend(ResBlock(c[-1]) for _ in range(blocks))

        # Decoder: c[-1] -> ... -> c[0]
        for cin, cout in zip(c[:0:-1], c[-2::-1]):
            layers.extend(unit(nn.ConvTranspose2d(cin, cout, 4, 2, 1), cout))

        layers += [nn.ConvTranspose2d(c[0], 3, 4, 2, 1), nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [50]:
class Discriminator(nn.Module):
    """Conditional discriminator scoring an ``(A, x)`` pair.

    ``A`` and ``x`` are concatenated along channels (6 channels expected in
    total) and passed through stride-2 conv-BN-LeakyReLU stages whose widths
    come from ``c``. The result is then globally average-pooled, flattened, and
    mapped by a linear layer plus sigmoid to an ``(N, 1)`` probability.
    """

    def __init__(self, c=[64, 128, 256]):
        super().__init__()

        widths = [6] + list(c)
        layers = []
        for cin, cout in zip(widths[:-1], widths[1:]):
            layers += [nn.Conv2d(cin, cout, 4, 2, 1), nn.BatchNorm2d(cout), nn.LeakyReLU()]

        layers += [nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(c[-1], 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, A, x):
        pair = torch.cat((A, x), dim=1)
        return self.model(pair)

In [51]:
# Build the generator (four stride-2 encoder stages) and the conditional
# discriminator (five stride-2 stages).
# NOTE(review): weights_normal is defined above but never applied, so every
# layer keeps PyTorch's default initialization.
model_g = UnetGenerator([64, 256, 512, 1024, 2048])
model_d = Discriminator([32, 64, 128, 256, 512])

In [52]:
# Move both networks to the selected device (nn.Module.to modifies the module
# in place). The trailing semicolon suppresses the notebook repr.
model_g.to(device)
model_d.to(device);

In [53]:
# Adamax for both networks, same base learning rate.
optim_g = torch.optim.Adamax(model_g.parameters(), lr=0.0002)
optim_d = torch.optim.Adamax(model_d.parameters(), lr=0.0002)

# Cosine-anneal each optimizer's LR over T_max scheduler steps (one step per
# epoch in the training loop). Each scheduler must wrap its own optimizer:
# previously scheduler_d also wrapped optim_g, so the generator LR was stepped
# twice per epoch and the discriminator LR never decayed.
scheduler_g = torch.optim.lr_scheduler.CosineAnnealingLR(optim_g, T_max=4500)
scheduler_d = torch.optim.lr_scheduler.CosineAnnealingLR(optim_d, T_max=4500)

In [54]:
# c = torch.load("../input/dls-gan-checkpoints/edges2handbags_best.pth")
# model_g.load_state_dict(c["model_G"])
# model_d.load_state_dict(c["model_D"])
# optim_g.load_state_dict(c["optim_G"])
# optim_d.load_state_dict(c["optim_D"])

In [55]:
# Constant BCE targets for one full batch: 1.0 for "real", 0.0 for "fake".
real_label = torch.full((batch_size,), 1.0, device=device)
fake_label = torch.full((batch_size,), 0.0, device=device)

In [56]:
criterion_d = nn.BCELoss()  # adversarial loss on the discriminator's sigmoid output
criterion_g = nn.L1Loss()   # pixel-wise reconstruction loss between fake and target

lambda_ = 2.5  # weight of the L1 term in the generator loss

In [57]:
def set_requires_grad(model, p=True):
    """Set ``requires_grad`` to ``p`` on every parameter of ``model``.

    Used to freeze the discriminator during the generator step.
    """
    for _, tensor in model.named_parameters():
        tensor.requires_grad = p

In [58]:
def train_batch(batch, buffer):
    """Run one discriminator step and one generator step on a paired batch.

    Args:
        batch: ``(A, B)`` tensors; the generator maps A to an estimate of B.
        buffer: ``Buffer`` pooling past (A, fake) pairs for the discriminator.

    Returns:
        dict with ``lossG``, ``lossD`` and the mean discriminator outputs on
        real (``real``) and fake (``fake``) pairs for this batch.
    """
    A = batch[0].to(device)
    B = batch[1].to(device)

    # ---- Discriminator step ----
    set_requires_grad(model_d, p=True)
    optim_d.zero_grad()

    fake = model_g(A)

    out = model_d(A, B).view(-1)
    # Targets shaped like the output, so any batch size works.
    lossD_real = criterion_d(out, torch.ones_like(out))
    real_score = out.mean().item()

    # Pool the whole (A, fake) pair, not just the fake: the discriminator is
    # conditional, so a pooled fake must stay paired with the input it was
    # generated from rather than with the current batch's A.
    pooled = buffer.push_and_pop(torch.cat((A, fake.detach()), dim=1))
    n_cond = A.size(1)
    out = model_d(pooled[:, :n_cond], pooled[:, n_cond:]).view(-1)
    lossD_fake = criterion_d(out, torch.zeros_like(out))
    fake_score = out.mean().item()

    lossD = (lossD_real + lossD_fake) / 2
    lossD.backward()
    optim_d.step()

    # ---- Generator step (D frozen so only G's parameters get gradients) ----
    set_requires_grad(model_d, p=False)
    optim_g.zero_grad()

    out = model_d(A, fake).view(-1)

    lossG_gan = criterion_d(out, torch.ones_like(out))
    lossG_pix = criterion_g(fake, B) * lambda_

    lossG = lossG_pix + lossG_gan
    lossG.backward()
    optim_g.step()

    return {
        "lossG": lossG.item(), "lossD": lossD.item(),
        "real": real_score, "fake": fake_score
    }

In [59]:
def train_epoch(epoch):
    """Train both networks for one pass over ``train_dl``.

    A fresh ``Buffer(150)`` is created every epoch, so pooled samples do not
    carry over between epochs. The progress bar shows running means of the
    per-batch metrics.

    Returns:
        dict with the epoch means of ``lossG``, ``lossD``, ``real`` and
        ``fake`` (all 0.0 if the loader yields no batches).
    """
    model_g.train()
    model_d.train()

    buffer = Buffer(150)

    totals = {"lossG": 0.0, "lossD": 0.0, "real": 0.0, "fake": 0.0}
    n_batches = 0

    tq = tqdm(train_dl, total=len(train_dl), desc=f"Epoch #{epoch}")
    for batch in tq:
        metrics = train_batch(batch, buffer)
        n_batches += 1

        for key, value in metrics.items():
            totals[key] += value

        # Only the metric means; the batch counter itself is not displayed
        # (previously it showed up as a constant n=1.0).
        tq.set_postfix({k: v / n_batches for k, v in totals.items()})

    return {k: v / max(n_batches, 1) for k, v in totals.items()}

In [60]:
import matplotlib.pyplot as plt

In [61]:
# Maps normalized tensors back to [0, 1] for plotting.
denorm = Denormalize(mean=mean, std=std)

In [62]:
def show_sample(n=5):
    """Plot the first ``n`` training inputs and the generator's outputs for them.

    Both networks are put in eval mode and left there (``train_epoch`` switches
    them back). Two figures are drawn: the inputs ``A`` and then ``model_g(A)``,
    both passed through ``denorm`` before display.
    """
    model_g.eval()
    model_d.eval()

    inputs = torch.stack([train_ds[i][0] for i in range(n)])

    # no_grad: without it the generated images require grad, matplotlib's
    # numpy conversion of such tensors raises, and a graph is built for nothing.
    # In eval mode BatchNorm uses running statistics, so one batched forward
    # pass gives the same images as n single-image passes.
    with torch.no_grad():
        outputs = model_g(inputs.to(device)).cpu()

    for images in (inputs, outputs):
        grid = make_grid([denorm(img) for img in images], nrow=n).permute(1, 2, 0)
        plt.figure(figsize=(15, 4))
        plt.imshow(grid)
        plt.show()

In [63]:
# NOTE(review): 4250 epochs vs T_max=4500 in the schedulers — the LR stops
# before reaching its cosine minimum; confirm this is intended.
for epoch in range(4250):
    # Release cached CUDA memory blocks before each epoch.
    torch.cuda.empty_cache()
    train_epoch(epoch)

    # Advance both cosine schedules once per epoch.
    scheduler_g.step()
    scheduler_d.step()

    # Visual check every 25 epochs, including epoch 0.
    if not epoch % 25:
        show_sample(n=5)

In [64]:
show_sample(n=5)  # final visual check after training

In [65]:
# Save everything needed to resume training: network weights, optimizer state,
# and the LR-scheduler state, so the cosine schedules continue where they
# stopped instead of restarting. Loaders that read only the original four
# keys keep working.
# NOTE: despite the file name, this is the state after the last epoch, not a
# metric-selected "best" model.
torch.save({
    "model_G": model_g.state_dict(),
    "model_D": model_d.state_dict(),
    "optim_D": optim_d.state_dict(),
    "optim_G": optim_g.state_dict(),
    "sched_G": scheduler_g.state_dict(),
    "sched_D": scheduler_d.state_dict(),
}, "maps_best.pth")

In [66]:
from IPython.display import FileLink

In [68]:
# Render a clickable download link to the saved checkpoint.
FileLink(path="maps_best.pth")