In [None]:
import torch
import torch.nn as nn
import torchvision.transforms.v2 as transforms
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import os
import numpy as np
from tqdm.auto import tqdm
from torchvision import datasets

In [None]:
# Run on the GPU whenever PyTorch can see one; models and batches are moved to this device.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Root of the BG-20k training images, read below with torchvision's ImageFolder.
# Set the BG20K_TRAIN_DIR environment variable instead of editing this machine-specific default.
IMAGE_PATH = os.environ.get('BG20K_TRAIN_DIR', '/media/kati/drive1TB/bg-20k/train_images')

# Load Data & Augment Images

In [None]:
class Cutout(object):
    """Cutout augmentation: overwrite one randomly placed square patch with a constant value.

    Args:
        box_length: side length of the square patch in pixels. It is clipped to the image
            height/width, so a box at least as large as the image covers the whole image.
        fill: value written into the patch. The default 255 is white for uint8 images, which
            is how this notebook uses it (Cutout runs before ``ToDtype``). Pass 1.0 for float
            images scaled to [0, 1].

    Calling an instance with an image tensor whose last two dims are ``(H, W)`` (for example
    ``(C, H, W)``) returns a masked copy. The input tensor is left unchanged.
    """

    def __init__(self, box_length, fill=255):
        self.box_length = box_length
        self.fill = fill

    def __call__(self, image):
        height, width = image.shape[-2:]
        box_h = min(self.box_length, height)
        box_w = min(self.box_length, width)

        # torch.randint's upper bound is exclusive. The "+ 1" lets the patch reach the bottom
        # and right edges, and keeps the call valid when the box is as large as the image.
        # Drawing from torch's RNG also makes torch.manual_seed control the augmentation.
        y1 = int(torch.randint(0, height - box_h + 1, (1,)))
        x1 = int(torch.randint(0, width - box_w + 1, (1,)))

        masked = image.clone()
        masked[..., y1:y1 + box_h, x1:x1 + box_w] = self.fill
        return masked

In [None]:
class PairedDataset(Dataset):
    """Zip two equally long datasets into ``(augmented, original)`` pairs.

    Sample ``i`` is ``(dataset_arg[i][0], dataset_ori[i][0])``. Only element 0 of each
    underlying sample is kept; for ImageFolder that is the image, so labels are dropped.
    """

    def __init__(self, dataset_arg, dataset_ori):
        assert len(dataset_arg) == len(dataset_ori), "Datasets must have the same length"
        self.dataset_arg, self.dataset_ori = dataset_arg, dataset_ori

    def __len__(self):
        return len(self.dataset_arg)

    def __getitem__(self, index):
        # Fetch the augmented sample first, then the original one.
        source_image = self.dataset_arg[index][0]
        target_image = self.dataset_ori[index][0]
        return source_image, target_image

In [None]:
WIDTH = 100
HEIGHT = 100
# Per-channel BG-20k statistics used by both pipelines (denormalize uses the same values as defaults).
BG20K_MEAN = [0.4338, 0.4341, 0.4100]
BG20K_STD = [0.3002, 0.2798, 0.3021]
# NOTE(review): with these stats, normalized pixels span roughly [-1.6, 2.0], while the generator
# ends in Tanh (range [-1, 1]). Confirm this mismatch is intended.

# Input pipeline. Cutout runs on the uint8 tensor (before ToDtype), so its fill of 255 is white.
transform_arg = transforms.Compose([
    transforms.PILToTensor(),
    transforms.Resize(size=(HEIGHT, WIDTH), antialias=True),  # Resize takes (h, w)
    Cutout(40),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=BG20K_MEAN, std=BG20K_STD),
])

# Target pipeline. It uses the same steps in the same order, minus Cutout, so every unmasked
# pixel of the input matches the target exactly. Previously, resizing before vs. after ToDtype
# introduced rounding differences.
transform_ori = transforms.Compose([
    transforms.PILToTensor(),
    transforms.Resize(size=(HEIGHT, WIDTH), antialias=True),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=BG20K_MEAN, std=BG20K_STD),
])

def denormalize(tensor, mean=(0.4338, 0.4341, 0.4100), std=(0.3002, 0.2798, 0.3021)):
    """Undo ``transforms.Normalize`` for display and clamp the result to [0, 1].

    Args:
        tensor: normalized image(s) with channels at dim -3, i.e. ``(C, H, W)`` or
            ``(N, C, H, W)``. The number of channels must match ``len(mean)``.
        mean, std: per-channel statistics that were passed to Normalize (BG-20k by default).

    Returns:
        A new tensor ``clamp(tensor * std + mean, 0, 1)``. The input tensor is not modified.
    """
    # Shape (C, 1, 1) broadcasts over H and W and over any leading batch dimensions.
    mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    return (tensor * std + mean).clamp(0, 1)

In [None]:
# NOTE(review): with batch_size = 1, every BatchNorm layer normalizes over one image's pixels only;
# confirm this is intended.
batch_size = 1

# Two views of the same image folder that differ only in their transform. PairedDataset relies
# on both listing the files in the same order. Each image is therefore loaded twice per sample.
dataset_arg = datasets.ImageFolder(root=IMAGE_PATH, transform=transform_arg)
dataset_ori = datasets.ImageFolder(root=IMAGE_PATH, transform=transform_ori)

paired_dataset = PairedDataset(dataset_arg=dataset_arg, dataset_ori=dataset_ori)
dataloader = DataLoader(dataset=paired_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Preview one (masked input, target) pair drawn from the loader.
x, y = next(iter(dataloader))
print(x.shape)
# clone() keeps the batch tensors intact however denormalize treats its argument.
masked_preview = denormalize(x[0].clone())
target_preview = denormalize(y[0].clone())
fig, axs = plt.subplots(1, 2)
for ax, image, title in zip(axs, (masked_preview, target_preview), ('Input (masked)', 'Target')):
    ax.imshow(image.permute(1, 2, 0))
    ax.set_title(title)
    ax.axis(False)
plt.show()

# Patch-Based Image Inpainting with Generative Adversarial Networks (Demir and Unal, 2018)
Based on this [paper](https://arxiv.org/pdf/1803.07422.pdf) by Ugur Demir and Gozde Unal.

## ResNet

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit computing ``x + conv3x3(relu(bn(conv3x3(x))))``; shape is unchanged.

    Args:
        in_channels: number of input channels (also the number of output channels).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding='same'),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding='same'),
        )

    def forward(self, x):
        return x + self.conv(x)


class RasNet(nn.Module):
    """Fully convolutional ResNet-style generator whose output has the same shape as its input.

    Structure: two conv-BN-ReLU layers, a trunk of ``n_residual_blocks`` ResidualBlocks, then
    two transposed convs and a final Tanh. Output values therefore lie in [-1, 1].

    Args:
        in_channels: image channels (for both input and output).
        hidden_channels: width of the first conv layer. The residual trunk uses
            ``2 * hidden_channels`` channels.
        n_residual_blocks: number of residual blocks in the trunk (default 6). Submodule names,
            and thus state-dict keys, are unchanged for the default.
    """

    def __init__(self, in_channels=3, hidden_channels=32, n_residual_blocks=6):
        super().__init__()
        trunk_channels = hidden_channels * 2
        # Despite the name, nothing here changes the spatial size (stride 1, 'same' padding).
        self.downsample = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=9, padding='same'),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, trunk_channels, kernel_size=3, padding='same'),
            nn.BatchNorm2d(trunk_channels),
            nn.ReLU()
        )
        self.residual_block = nn.Sequential(
            *(ResidualBlock(trunk_channels) for _ in range(n_residual_blocks))
        )
        # Stride-1 transposed convs with padding (kernel_size - 1) / 2 also preserve H and W.
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(trunk_channels, hidden_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_channels, in_channels, kernel_size=9, padding=4),
            # NOTE(review): Tanh limits outputs to [-1, 1]. Targets normalized with the BG-20k
            # mean/std used in this notebook extend beyond that, so extreme pixels are
            # unreachable. Confirm this is intended.
            nn.Tanh()
        )

    def forward(self, x):
        x = self.downsample(x)
        x = self.residual_block(x)
        x = self.upsample(x)
        return x

In [None]:
class SharedLayer(nn.Module):
    """Shared trunk of the discriminator: four conv3x3-BN-LeakyReLU stages.

    Channel widths go ``in_channels -> h -> 2h -> 4h -> 8h`` with ``h = hidden_channels``.
    Every conv has stride 1 and 'same' padding, so H and W are preserved.
    """

    def __init__(self, in_channels=3, hidden_channels=32):
        super().__init__()
        widths = [in_channels] + [hidden_channels * factor for factor in (1, 2, 4, 8)]
        self.shared_layer = nn.Sequential(*(
            self.make_shared_block(c_in, c_out)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ))

    @staticmethod
    def make_shared_block(in_channels, out_channels):
        """Build one conv3x3 ('same' padding) -> BatchNorm -> LeakyReLU(0.2) stage."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same')
        norm = nn.BatchNorm2d(out_channels)
        act = nn.LeakyReLU(0.2, inplace=True)
        return nn.Sequential(conv, norm, act)

    def forward(self, x):
        features = self.shared_layer(x)
        return features

In [None]:
class GlobalDiscriminator(nn.Module):
    """'Global' head applied to SharedLayer features; returns raw logits (no sigmoid).

    For input of shape ``(N, 8*hidden_channels, H, W)`` the output has shape
    ``(N, 1, ceil(H/2), ceil(W/2))``, because the last conv has stride 2.

    NOTE(review): despite the name, this head emits a spatial map rather than one score per
    image. Its receptive field is local (3x3 convs on top of the stride-1 SharedLayer), so
    confirm it matches the paper's global branch.
    """

    def __init__(self, hidden_channels=32):
        super().__init__()
        in_ch = hidden_channels * 8
        mid_ch = hidden_channels * 16
        layers = [
            nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding='same'),
            nn.BatchNorm2d(mid_ch),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(mid_ch, 1, kernel_size=3, stride=2, padding=1),
        ]
        self.global_path = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.global_path(x)
        return logits

In [None]:
class PatchDiscriminator(nn.Module):
    """Patch head applied to SharedLayer features: conv3x3-BN-LeakyReLU, then a 1x1 conv.

    For input of shape ``(N, 8*hidden_channels, H, W)`` it returns raw logits (no sigmoid) of
    shape ``(N, 1, H, W)``: one real/fake score per spatial location.
    """

    def __init__(self, hidden_channels=32):
        super().__init__()
        feat_ch, wide_ch = hidden_channels * 8, hidden_channels * 16
        self.patch_gan = nn.Sequential(
            nn.Conv2d(feat_ch, wide_ch, kernel_size=3, padding='same'),
            nn.BatchNorm2d(wide_ch),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(wide_ch, 1, kernel_size=1),
        )

    def forward(self, x):
        return self.patch_gan(x)

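# Compute Loss
Simplified: the reconstruction loss is the L1 loss, and the adversarial losses are binary cross-entropy.
$$L_{rec} = L_{1}\text{Loss}$$
$$L_{adv} = \text{BCELoss}$$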

## Generator Loss
We update the generator parameters with the joint loss $$L_{joint} = \lambda_1L_{rec} + \lambda_2L_{g-adv} + \lambda_3L_{p-adv}$$ where $(\lambda_1, \lambda_2, \lambda_3)$ default to $(0.995, 0.0025, 0.0025)$.

In [None]:
def get_gen_loss(input_image, real_image, gen, shared, global_disc, patch_disc, recon_criterion, adver_criterion, lambdas=(0.995, 0.0025, 0.0025)):
    """Joint generator loss ``lambda1*L_rec + lambda2*L_g-adv + lambda3*L_p-adv``.

    Args:
        input_image: masked input batch fed to ``gen``.
        real_image: unmasked target, compared with ``gen(input_image)`` by ``recon_criterion``.
        gen: the generator module.
        shared, global_disc, patch_disc: discriminator trunk and its two heads. The fake image
            goes through ``shared`` once, and both heads read those features.
        recon_criterion: reconstruction loss ``(prediction, target) -> scalar``.
        adver_criterion: adversarial loss applied to each head's output with all-ones targets,
            so the generator is rewarded when the heads score its output as real.
        lambdas: weights ``(lambda1, lambda2, lambda3)``; any indexable sequence works.

    Returns:
        Scalar loss tensor. Its graph also reaches the discriminator parameters (nothing is
        detached), so their ``.grad`` must be zeroed before their own updates.
    """
    fake_image = gen(input_image)
    l_rec = recon_criterion(fake_image, real_image)

    # Discriminator trunk, shared by both heads.
    shared_fake_output = shared(fake_image)

    # Global head: the generator wants it to predict "real" (1).
    global_fake_pred = global_disc(shared_fake_output)
    l_g_adv = adver_criterion(global_fake_pred, torch.ones_like(global_fake_pred))

    # Patch head: same objective, per patch.
    patch_fake_pred = patch_disc(shared_fake_output)
    l_p_adv = adver_criterion(patch_fake_pred, torch.ones_like(patch_fake_pred))

    return lambdas[0] * l_rec + lambdas[1] * l_g_adv + lambdas[2] * l_p_adv

## Global Discriminator Loss
Given by $$L_{g-adv} = \text{BCELoss}$$

In [None]:
def get_global_disc_loss(input_image, real_image, gen, shared, global_disc, adver_criterion):
    """Real/fake loss for the global head: fake -> 0, real -> 1, averaged over both halves.

    Returns:
        Scalar loss tensor. Gradients reach ``shared`` and ``global_disc`` but never ``gen``.
    """
    # This loss does not train the generator, so skip building its autograd graph entirely.
    with torch.no_grad():
        fake_image = gen(input_image)

    shared_fake_output = shared(fake_image)
    shared_real_output = shared(real_image)

    global_fake_pred = global_disc(shared_fake_output)
    global_real_pred = global_disc(shared_real_output)

    global_fake_loss = adver_criterion(global_fake_pred, torch.zeros_like(global_fake_pred))
    global_real_loss = adver_criterion(global_real_pred, torch.ones_like(global_real_pred))

    return (global_fake_loss + global_real_loss) / 2

## Patch Discriminator Loss
Given by $$L_{p-adv} = \text{BCELoss}$$

In [None]:
def get_patch_disc_loss(input_image, real_image, gen, shared, patch_disc, adver_criterion):
    """Real/fake loss for the patch head: fake -> 0, real -> 1, averaged over both halves.

    Returns:
        Scalar loss tensor. Gradients reach ``shared`` and ``patch_disc`` but never ``gen``.
    """
    # This loss does not train the generator, so skip building its autograd graph entirely.
    with torch.no_grad():
        fake_image = gen(input_image)

    shared_fake_output = shared(fake_image)
    shared_real_output = shared(real_image)

    patch_fake_pred = patch_disc(shared_fake_output)
    patch_real_pred = patch_disc(shared_real_output)

    patch_fake_loss = adver_criterion(patch_fake_pred, torch.zeros_like(patch_fake_pred))
    patch_real_loss = adver_criterion(patch_real_pred, torch.ones_like(patch_real_pred))

    return (patch_fake_loss + patch_real_loss) / 2

## Shared Layer Loss
Given by the sum of both discriminator losses, $$L_{g-adv} + L_{p-adv}$$

In [None]:
def get_shared_layer_loss(input_image, real_image, gen, shared, global_disc, patch_disc, adver_criterion):
    """Loss for the shared discriminator trunk: global-head loss + patch-head loss.

    Each head's term is the real/fake objective (fake -> 0, real -> 1) averaged over its fake
    and real halves. Both heads read the same ``shared`` features.

    Returns:
        Scalar tensor ``global_loss + patch_loss``. Gradients reach ``shared``,
        ``global_disc`` and ``patch_disc`` but never ``gen``.
    """
    # This loss does not train the generator, so skip building its autograd graph entirely.
    with torch.no_grad():
        fake_image = gen(input_image)

    shared_fake_output = shared(fake_image)
    shared_real_output = shared(real_image)

    def head_loss(head):
        # Real/fake objective for one head, averaged over its fake and real halves.
        fake_pred = head(shared_fake_output)
        real_pred = head(shared_real_output)
        fake_loss = adver_criterion(fake_pred, torch.zeros_like(fake_pred))
        real_loss = adver_criterion(real_pred, torch.ones_like(real_pred))
        return (fake_loss + real_loss) / 2

    # Patch head first, then global: the same call order as before, which matters for the
    # BatchNorm running statistics inside the heads.
    patch_loss = head_loss(patch_disc)
    global_loss = head_loss(global_disc)

    return global_loss + patch_loss

# Init Model

In [None]:
def weights_init(m):
    """DCGAN-style initialisation, meant to be used as ``module.apply(weights_init)``.

    - Modules whose class name contains ``'Conv'`` (Conv2d, ConvTranspose2d, ...): weights are
      drawn from N(0, 0.02). Their biases are left untouched.
    - Modules whose class name contains ``'BatchNorm'``: weights are drawn from N(1, 0.02) and
      biases are set to 0.
    - Missing parameters (e.g. ``BatchNorm2d(affine=False)`` or ``bias=False``) are skipped
      instead of raising. All other modules are ignored.
    """
    classname = type(m).__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

In [None]:
# Learning rate passed to all four Adam optimizers.
# NOTE(review): 1e-6 is far below the ~2e-4 commonly used with Adam for GANs; confirm it is intended.
lr: float = 1e-6

In [None]:
# Build the generator and the three discriminator parts, in this order, on `device`.
gen = RasNet().to(device)
shared_layer = SharedLayer().to(device)
global_disc = GlobalDiscriminator().to(device)
patch_disc = PatchDiscriminator().to(device)

# One Adam optimizer per network, all sharing the learning rate `lr`.
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
shared_opt = torch.optim.Adam(shared_layer.parameters(), lr=lr)
global_opt = torch.optim.Adam(global_disc.parameters(), lr=lr)
patch_opt = torch.optim.Adam(patch_disc.parameters(), lr=lr)

In [None]:
# Re-initialise every network in place. Module.apply returns the module itself, so the names
# above keep referring to the same objects and the optimizers stay attached to them.
for network in (gen, shared_layer, global_disc, patch_disc):
    network.apply(weights_init)
del network

# Load Pre-trained Model (Optional)

In [None]:
# Resume from a checkpoint saved by the training loop, if the file exists. The guard lets
# "Restart & Run All" work without it: the freshly initialised models above are then kept.
CHECKPOINT_PATH = "pix2pix_210000.pth"

if os.path.exists(CHECKPOINT_PATH):
    # Fresh instances whose weights and optimizer states are all overwritten below.
    # Optimizer.load_state_dict also restores each param group's hyperparameters (lr, betas,
    # ...), so the Adam defaults used here do not survive the load.
    gen = RasNet().to(device)
    gen_opt = torch.optim.Adam(gen.parameters())
    global_disc = GlobalDiscriminator().to(device)
    global_opt = torch.optim.Adam(global_disc.parameters())
    shared_layer = SharedLayer().to(device)
    shared_opt = torch.optim.Adam(shared_layer.parameters())
    patch_disc = PatchDiscriminator().to(device)
    patch_opt = torch.optim.Adam(patch_disc.parameters())

    # map_location lets a GPU-saved checkpoint load on a CPU-only machine. weights_only=True
    # refuses arbitrary pickled objects, since the file only needs to hold state dicts.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
    gen.load_state_dict(checkpoint['gen'])
    gen_opt.load_state_dict(checkpoint['gen_opt'])
    global_disc.load_state_dict(checkpoint['global_disc'])
    global_opt.load_state_dict(checkpoint['global_opt'])
    shared_layer.load_state_dict(checkpoint['shared_layer'])
    shared_opt.load_state_dict(checkpoint['shared_layer_opt'])
    patch_disc.load_state_dict(checkpoint['patch_disc'])
    patch_opt.load_state_dict(checkpoint['patch_disc_opt'])
else:
    print(f"Checkpoint {CHECKPOINT_PATH!r} not found; keeping the freshly initialised models.")

In [None]:
# Per the "Compute Loss" section: L1 for reconstruction, BCE for the adversarial terms.
# BCEWithLogitsLoss is used because every discriminator head ends in a plain Conv2d, which
# produces raw logits with no sigmoid.
adver_criterion = nn.BCEWithLogitsLoss()
recon_criterion = nn.L1Loss()

# Train Model

In [None]:
# Anomaly detection slows every backward pass considerably; switch it on only to debug NaNs.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

cur_step = 0
n_epochs = 20
display_step = 100

# Running means over the current window of `display_step` steps; reset after each report.
mean_global_loss = 0
mean_patch_loss = 0
mean_shared_loss = 0
mean_gen_loss = 0


def to_display_image(batch):
    """Return the first image of a normalized batch as an (H, W, C) CPU tensor for imshow."""
    # clone() guarantees denormalize cannot alter tensors still used for training
    # (.cpu() on a CPU tensor returns that very same tensor).
    return denormalize(batch[0].detach().cpu().clone()).permute(1, 2, 0)


for epoch in range(n_epochs):
    for input_image, real_image in tqdm(dataloader, desc=f'epoch {epoch + 1}/{n_epochs}'):
        input_image = input_image.to(device)
        real_image = real_image.to(device)

        # --- Train discriminators ---
        # Each get_*_loss call runs its own forward pass, so every graph feeds exactly one
        # backward() and retain_graph is not needed.
        # NOTE(review): gen therefore runs four forward passes per step, and BatchNorm running
        # stats are updated each time. Computing the fake once and passing it in would be cheaper.
        global_opt.zero_grad()
        global_loss = get_global_disc_loss(input_image, real_image, gen, shared_layer, global_disc, adver_criterion)
        global_loss.backward()
        global_opt.step()

        patch_opt.zero_grad()
        patch_loss = get_patch_disc_loss(input_image, real_image, gen, shared_layer, patch_disc, adver_criterion)
        patch_loss.backward()
        patch_opt.step()

        shared_opt.zero_grad()
        shared_loss = get_shared_layer_loss(input_image, real_image, gen, shared_layer, global_disc, patch_disc, adver_criterion)
        shared_loss.backward()
        shared_opt.step()

        # --- Train generator ---
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(input_image, real_image, gen, shared_layer, global_disc, patch_disc, recon_criterion, adver_criterion)
        gen_loss.backward()
        gen_opt.step()

        # Accumulate loss / display_step so the printed values are true means over the window.
        mean_global_loss += global_loss.item() / display_step
        mean_patch_loss += patch_loss.item() / display_step
        mean_shared_loss += shared_loss.item() / display_step
        mean_gen_loss += gen_loss.item() / display_step
        cur_step += 1

        if cur_step % display_step == 0:
            print(f'steps:{cur_step}, mean_gen_loss:{mean_gen_loss}, mean_global_loss:{mean_global_loss}, '
                  f'mean_patch_loss:{mean_patch_loss}, mean_shared_loss:{mean_shared_loss}')
            with torch.no_grad():
                fake_image = gen(input_image)
            fig, axs = plt.subplots(1, 3)
            panels = (('Generate', fake_image), ('Input', input_image), ('Real', real_image))
            for ax, (title, batch) in zip(axs, panels):
                ax.imshow(to_display_image(batch))
                ax.axis(False)
                ax.set_title(title)
            plt.show()
            plt.close(fig)  # avoid accumulating open figures over a long run
            mean_global_loss = mean_patch_loss = mean_shared_loss = mean_gen_loss = 0

    # Checkpoint once per epoch. The keys match those read by the checkpoint-loading cell.
    torch.save({'gen': gen.state_dict(),
        'gen_opt': gen_opt.state_dict(),
        'global_disc': global_disc.state_dict(),
        'global_opt': global_opt.state_dict(),
        'shared_layer': shared_layer.state_dict(),
        'shared_layer_opt': shared_opt.state_dict(),
        'patch_disc': patch_disc.state_dict(),
        'patch_disc_opt': patch_opt.state_dict()
    }, f"pix2pix_{cur_step}.pth")