In [None]:
# Mount Google Drive at /content/drive; the data and checkpoint paths used by
# the cells below all live under that mount point (Colab-only import).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import os
import cv2
import json
import h5py
import pandas, numpy, random
import matplotlib.pyplot as plt

In [None]:
# Sanity check of the HDF5 training file: confirm the three datasets read by
# the Dataset class exist and report how many entries each one holds.
h5py_dir = '/content/drive/My Drive/Colab Notebooks/Data/train.h5py'

with h5py.File(h5py_dir, 'r') as hdf:
    if not all(key in hdf for key in ('complete', 'local', 'parameters')):
        print("Expected datasets not found in the HDF5 file.")
    else:
        # (label shown in the message, dataset key inside the file)
        for label, key in (('in_images', 'complete'),
                           ('local', 'local'),
                           ('parameters', 'parameters')):
            print(f"Number of {label}: {len(hdf[key])}")

In [None]:
# Pick the compute device once; every later cell refers to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    # Make newly created tensors default to CUDA float tensors.
    # NOTE(review): torch.set_default_tensor_type is reported as deprecated in
    # recent PyTorch releases - confirm against the installed version.
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    print("using cuda:", torch.cuda.get_device_name(0))

print(device)

In [None]:
class View(nn.Module):
    """Reshape layer: a variation of ``.view`` that can be executed inside ``nn.Sequential()``.

    Args:
        shape: target shape, handed unchanged to ``Tensor.view`` on every
            forward pass.
    """

    def __init__(self, shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, x):
        """Return ``x`` viewed with the stored shape (no data is copied)."""
        reshaped = x.view(self.shape)
        return reshaped

In [None]:
def _paste_patch(canvas, patch, cx, cy, w, h):
    """Resize ``patch`` and write it into ``canvas`` (in place) around (cx, cy).

    The target box spans rows ``cy - h/2 .. cy + h/2`` and columns
    ``cx - w/2 .. cx - w/2 + w``; it is clamped to the canvas on all four
    sides and the patch is resized to whatever part of the box remains.
    A box that falls completely outside the canvas is skipped.
    """
    img_h, img_w = canvas.shape[:2]

    up = max(0, int(cy - h / 2))
    down = min(img_h, int(cy + h / 2) + 1)

    start = int(cx - w / 2)
    left = max(0, start)
    right = min(img_w, start + w + 1)

    if down <= up or right <= left:
        # Nothing of the box is visible; resizing to an empty size would fail.
        return

    # cv2.resize takes the target size as (width, height).
    canvas[up:down, left:right] = cv2.resize(patch, (right - left, down - up))


def in_painting(in_image, coordinate, left_patch, right_patch):
    """Return a copy of ``in_image`` with both eye regions replaced by patches.

    Args:
        in_image: H x W x C numpy image. It is NOT modified; the patched
            result is a new array.
        coordinate: indexable holding 8 single-element entries (each must
            support ``.item()``), laid out as
            ``[lx, ly, lw, lh, rx, ry, rw, rh]``: centre x/y and width/height
            of the left eye box followed by the same for the right eye box.
        left_patch: image pasted (after resizing) into the left eye box.
        right_patch: image pasted (after resizing) into the right eye box.

    Returns:
        A new array with the same shape as ``in_image``.
    """
    # Work on a copy so the caller's image (or the tensor memory behind it)
    # is never overwritten.
    patched_image = in_image.copy()

    lx, ly, lw, lh, rx, ry, rw, rh = (int(coordinate[k].item()) for k in range(8))

    _paste_patch(patched_image, left_patch, lx, ly, lw, lh)
    _paste_patch(patched_image, right_patch, rx, ry, rw, rh)

    return patched_image

In [None]:
class Celeb_ID_Dataset(Dataset):
    """Dataset backed by an HDF5 file with 'complete', 'local' and 'parameters'.

    Each item is a triple ``(complete, local, param)``:
      * complete - float tensor of shape (3, 80, 256), scaled by 1/255
      * local    - float tensor of shape (3, 100, 200), scaled by 1/255
      * param    - float tensor built from the matching 'parameters' entry

    The file is opened once in ``__init__`` and kept open until ``close()``
    is called or the object is garbage collected.
    """

    def __init__(self, file_path):
        # Defined before the file is opened so that close()/__del__ stay safe
        # even when h5py.File raises (e.g. the path does not exist).
        self.file_object = None
        self.file_path = file_path
        self.file_object = h5py.File(file_path, 'r')
        self.data_len = len(self.file_object['complete'])

    def __len__(self):
        return self.data_len

    def __getitem__(self, idx):
        complete = self.file_object['complete'][idx]
        local = self.file_object['local'][idx]
        parameter = self.file_object['parameters'][idx]

        # Stored images are H x W x C; the networks expect C x H x W in [0, 1].
        complete = torch.FloatTensor(complete).permute(2,0,1).view(3, 80, 256) / 255.0
        local = torch.FloatTensor(local).permute(2,0,1).view(3, 100, 200) / 255.0
        param = torch.FloatTensor(parameter)

        return complete, local, param

    def close(self):
        """Close the underlying file handle; safe to call more than once."""
        handle = getattr(self, 'file_object', None)
        if handle is not None:
            handle.close()
            self.file_object = None

    def __del__(self):
        # May run on a half-constructed instance, so it must not assume that
        # __init__ completed.
        self.close()

# Create Dataset and Dataloader object
file_path = '/content/drive/My Drive/Colab Notebooks/Data/train.h5py'
dataset = Celeb_ID_Dataset(file_path)
# The shuffling generator follows `device` instead of a hard-coded 'cuda',
# so this cell also runs on a CPU-only runtime.
params = {'batch_size': 32, 'shuffle': True, 'generator': torch.Generator(device=device)}
dataloader = DataLoader(dataset, **params)

In [None]:
class Discriminator(nn.Module):
    """Two-branch discriminator scoring a (completed image, eye patches) pair.

    The global branch reads the patched 3 x 80 x 256 image and the local
    branch reads the 3 x 100 x 200 patch image. Each branch yields 100
    features per sample; the concatenated 200 features go through one linear
    layer and a sigmoid, giving a score in [0, 1] per sample. The training
    methods label real pairs as 1 and generated pairs as 0.

    Any batch size is accepted: the conv output is flattened per sample
    rather than viewed with a fixed batch dimension.
    """

    def __init__(self):
        super().__init__()

        # Global Discriminator
        self.global_discriminator = nn.Sequential(
            # The input should be a 80 * 256 * 3 image that has been patched
            nn.Conv2d(3, 128, kernel_size=(9, 27), stride=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 64, kernel_size=(7, 27), stride=2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 3, kernel_size=(1, 6), stride=2),
            nn.BatchNorm2d(3),
            nn.LeakyReLU(0.2),
            # (N, 3, 8, 20) -> (N, 480). Parameter-free and at the same
            # position as the former reshape layer, so state_dict keys are
            # unchanged.
            nn.Flatten(),
            nn.Linear(3*8*20, 100)
            # The output is 100 features per sample
        )

        # Local Discriminator
        self.local_discriminator = nn.Sequential(
            # The input should be a 100 * 200 * 3 image
            nn.Conv2d(3, 128, kernel_size=5, stride=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 64, kernel_size=5, stride=2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 3, kernel_size=(4, 8), stride=2),
            nn.BatchNorm2d(3),
            nn.LeakyReLU(0.2),
            # (N, 3, 10, 20) -> (N, 600)
            nn.Flatten(),
            nn.Linear(3*10*20, 100)
            # The output is 100 features per sample
        )

        # The final layer with one neuron whose output is a scalar and a Sigmoid activation
        self.concat_fc = nn.Sequential(
            nn.Linear(200, 1),
            nn.Sigmoid()
        )

        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

        # Number of training steps taken, and the loss sampled every 10th step
        self.counter = 0
        self.progress = []

    def forward(self, global_x, local_x):
        """Return an (N, 1) tensor of scores in [0, 1].

        Args:
            global_x: (N, 3, 80, 256) batch of complete / completed images.
            local_x: (N, 3, 100, 200) batch of eye patches.
        """
        # Concatenate the global and local discriminator outputs
        global_output = self.global_discriminator(global_x)
        local_output = self.local_discriminator(local_x)
        output = torch.cat((global_output, local_output), dim=1)
        return self.concat_fc(output)

    def cal_bce_loss(self, d_score, label):
        """Binary cross entropy between scores and labels, scaled by alpha = 5e-4."""
        alpha = 5e-4
        d_loss = F.binary_cross_entropy(d_score, label)
        return alpha * d_loss

    def cal_mse_loss(self, local_x, ur_local):
        """Mean squared error between generated patches and reference patches."""
        return F.mse_loss(local_x, ur_local, reduction='mean')

    def _step(self, loss):
        """Log ``loss`` and apply one optimiser update with it."""
        # increase counter and accumulate error every 10
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())
        if self.counter % 1000 == 0:
            print("counter = ", self.counter)

        # Reset gradients, perform back propagation and update weights
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def train_on_true(self, global_x, local_x):
        """One update on a real batch (target label 1)."""
        d_score = self(global_x, local_x)
        label = torch.ones_like(d_score)
        self._step(self.cal_bce_loss(d_score, label))

    def train_on_fake(self, global_x, local_x, local_y):
        """One update on a generated batch (target label 0).

        Args:
            global_x: completed images containing the generated patches.
            local_x: generated patches.
            local_y: reference patches, used only for the MSE term.
        """
        d_score = self(global_x, local_x)
        label = torch.zeros_like(d_score)
        bce_loss = self.cal_bce_loss(d_score, label)
        # The MSE term involves none of this module's parameters, so it adds
        # no gradient to the discriminator; it only shifts the logged value.
        mse_loss = self.cal_mse_loss(local_x, local_y)

        self._step(bce_loss + mse_loss)

    def plot_progress(self):
        """Plot the losses recorded in ``self.progress``."""
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.1, 0.02))


In [None]:
class Generator(nn.Module):
    """Maps a masked 3 x 80 x 256 image to a 3 x 100 x 200 image of eye patches.

    ``concentration`` convolves the input down to 3 x 10 x 20 and ``model``
    expands that with transposed convolutions to 3 x 100 x 200, ending in a
    sigmoid so that outputs lie in [0, 1].
    """

    def __init__(self):
        super().__init__()

        self.concentration = nn.Sequential(
            # The input should be a N * 3 * 80 * 256 image
            nn.Conv2d(3, 128, kernel_size=(4, 27), stride=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 128, kernel_size=(3, 27), stride=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 3, kernel_size=(1, 7), stride=2),
            nn.BatchNorm2d(3),
            nn.LeakyReLU(0.2)
            # This gives a convolved image of shape N * 3 * 10 * 20
        )

        self.model = nn.Sequential(
            # The input should be a N * 3 * 10 * 20 tensor
            nn.ConvTranspose2d(3, 64, kernel_size=5, stride=2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 128, kernel_size=5, stride=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 3, kernel_size=(4, 24), stride=2),
            nn.BatchNorm2d(3),
            nn.Sigmoid()
            # The output should be a N * 3 * 100 * 200 image of the in painting region
        )

        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

        # Number of training steps taken, and the loss sampled every 10th step
        self.counter = 0
        self.progress = []

    def forward(self, inputs):
        """Generate the patch image for a batch of masked inputs."""
        convolved_image = self.concentration(inputs)
        return self.model(convolved_image)

    def train(self, D=True, global_x=None, local_x=None, local_y=None):
        """Either switch train/eval mode or run one generator training step.

        This method has the same name as ``nn.Module.train``, which
        ``nn.Module.eval()`` calls internally as ``self.train(False)``. To keep
        both uses working it dispatches on the first argument:

        * ``train()`` / ``train(True)`` / ``train(False)`` set the module mode
          and return ``self``, exactly like ``nn.Module.train``;
        * ``train(D, global_x, local_x, local_y)`` with a discriminator as
          ``D`` performs one optimisation step (see ``train_step``).
        """
        if isinstance(D, bool):
            return super().train(D)
        return self.train_step(D, global_x, local_x, local_y)

    def train_step(self, D, global_x, local_x, local_y):
        """One generator update against discriminator ``D``.

        Args:
            D: discriminator providing ``forward``, ``cal_bce_loss`` and
                ``cal_mse_loss``.
            global_x: completed images containing the generated patches.
            local_x: generated patches, still attached to the autograd graph.
            local_y: reference patches for the MSE term.
        """
        # The generator wants D to score its output as real, hence label 1.
        d_score = D(global_x, local_x)
        label = torch.ones_like(d_score)
        bce_loss = D.cal_bce_loss(d_score, label)
        mse_loss = D.cal_mse_loss(local_x, local_y)

        loss = bce_loss + mse_loss

        # Increase counter and accumulate error every 10
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())

        # Reset gradients, perform back propagation and update weights.
        # Only this module's optimiser steps here, so D's weights are not
        # updated by this call.
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_progress(self):
        """Plot the losses recorded in ``self.progress``."""
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.1, 0.02))


In [None]:
def process_image(batch_complete, batch_patches, batch_parameter):
    """In-paint every image of a batch with its eye patches.

    Args:
        batch_complete: (N, 3, 80, 256) tensor of images; left unmodified.
        batch_patches: (M, 3, 100, 200) tensor whose left / right halves
            (columns 0-99 / 100-199) are the left / right eye patches.
        batch_parameter: per-image tensor of 8 box parameters, forwarded to
            ``in_painting`` as a column of single-element entries.

    Returns:
        A tensor of patched images, stacked and moved to ``device``. Samples
        are paired positionally, so the result has as many entries as the
        shortest of the three inputs.
    """
    # Unpack the batch, perform individual operation on each and repack in to batch
    patched = []
    for complete, patch_pair, parameter in zip(batch_complete, batch_patches, batch_parameter):
        # C x H x W -> H x W x C. The explicit copy matters on CPU, where
        # .numpy() shares memory with the tensor and in-painting would
        # otherwise write into the caller's batch.
        in_image = complete.detach().permute(1, 2, 0).cpu().numpy().copy()
        raw_patches = patch_pair.detach().permute(1, 2, 0).cpu().numpy()
        coordinate = parameter.unsqueeze(1)

        patches = raw_patches.clip(0, 1).astype(numpy.float32)
        left_patch = patches[0:100, 0:100, ...]
        right_patch = patches[0:100, 100:200, ...]

        completion = in_painting(in_image, coordinate, left_patch, right_patch)
        patched.append(torch.FloatTensor(completion).permute(2, 0, 1))

    # Re-pack to mini-batch on GPU device
    return torch.stack(patched).to(device)

In [None]:
# Instantiate the discriminator first, then the generator, and move both
# onto the selected device.
D, G = (net.to(device) for net in (Discriminator(), Generator()))

In [None]:
# Load parameters (opt)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU runtime also loads on a CPU-only one.
D.load_state_dict(torch.load('/content/drive/My Drive/Colab Notebooks/D_epoch_25.pth', map_location=device))
G.load_state_dict(torch.load('/content/drive/My Drive/Colab Notebooks/G_epoch_25.pth', map_location=device))

In [None]:
%%time

# Train the Network

current_epoch = 80   # epoch number training resumes from (used in checkpoint names)
itr_epoch = 20       # number of epochs to run in this cell
batch_size = 32      # nominal batch size; the loop below uses each batch's real size

for epoch in range(itr_epoch):
    current_epoch += 1
    print("epoch = ", current_epoch)

    for batch_complete, batch_local, batch_parameter in dataloader:
        # Move data to CUDA device
        batch_complete = batch_complete.to(device)
        batch_local = batch_local.to(device)

        # Train the discriminator on true *
        D.train_on_true(batch_complete, batch_local)

        # In paint with noise as input tensor. The noise is sized from the
        # batch actually received, which may differ from `batch_size`.
        batch_noise = 0.2 * torch.randn(batch_complete.size(0), 3, 100, 200)
        batch_input = process_image(batch_complete, batch_noise, batch_parameter)

        # Generate patches and in-paint the images
        batch_patches = G.forward(batch_input)
        batch_completion = process_image(batch_complete, batch_patches, batch_parameter)

        # Train the discriminator on fake (false) *
        # The patches are detached so this step does not back-propagate into G.
        local_x = batch_patches.detach()
        D.train_on_fake(batch_completion, local_x, batch_local)

        # Train the generator *
        G.train(D, batch_completion, batch_patches, batch_local)

    if current_epoch % 5 == 0:
        # Save the model weights
        model_dir = '/content/drive/My Drive/Colab Notebooks/'
        D_name = 'D_epoch_{}.pth'.format(current_epoch)
        G_name = 'G_epoch_{}.pth'.format(current_epoch)
        torch.save(D.state_dict(), os.path.join(model_dir, D_name))
        torch.save(G.state_dict(), os.path.join(model_dir, G_name))

        # Plot Progress
        # D.plot_progress()
        # G.plot_progress()

In [None]:
# Persist the current weights of both networks; the file names carry the
# epoch number reached so far.
model_dir = '/content/drive/My Drive/Colab Notebooks/'

D_name = f'D_epoch_{current_epoch}.pth'
G_name = f'G_epoch_{current_epoch}.pth'

for net, filename in ((D, D_name), (G, G_name)):
    torch.save(net.state_dict(), os.path.join(model_dir, filename))

In [None]:
# Check the result after training: show the first sample of one batch as
# generated patches, reference patches, masked input and final completion.

fig, axs = plt.subplots(1, 4, figsize=(15, 5))

batch_complete, batch_local, batch_parameter = next(iter(dataloader))

# Move data to CUDA device
batch_complete = batch_complete.to(device)
batch_local = batch_local.to(device)

# Inference only, so no autograd graph is needed.
with torch.no_grad():
    # In paint with noise as input tensor (sized from the batch received)
    batch_noise = 0.2 * torch.randn(batch_complete.size(0), 3, 100, 200)
    batch_input = process_image(batch_complete, batch_noise, batch_parameter)

    # Generate patches and in-paint the images
    batch_patches = G.forward(batch_input)
    batch_completion = process_image(batch_complete, batch_patches, batch_parameter)

sample = 0
panels = (
    ("Patches", batch_patches[sample]),
    ("Original", batch_local[sample]),
    ("Input", batch_input[sample]),
    ("Completion", batch_completion[sample]),
)

for ax, (title, chw) in zip(axs, panels):
    # C x H x W tensor -> H x W x C float image clipped to the displayable range
    hwc = chw.detach().permute(1, 2, 0).cpu().numpy()
    ax.imshow(hwc.clip(0, 1).astype(numpy.float32), cmap='gray')
    ax.set_title(title)

In [None]:
# Plot discriminator error (the loss values the discriminator recorded every
# 10th training step)
D.plot_progress()

In [None]:
# Plot generator error (the loss values the generator recorded every 10th
# training step)
G.plot_progress()

In [None]:
# Application: run the trained generator on a single test photo and show the
# original next to the in-painted completion.

fig, axes = plt.subplots(1, 2, figsize=(9, 6))

# Taylor Swift
image_path = '/content/drive/My Drive/Colab Notebooks/Data/Test Image/TS.jpg'
input_image = cv2.imread(image_path)
parameter = numpy.array([[81,], [41,], [53,], [48,], [179,], [40,], [48,], [49,]])

# Donald Trump
# input_image = cv2.imread('/content/drive/My Drive/Colab Notebooks/Data/Test Image/DT.jpg')
# parameter = numpy.array([[87,], [42,], [45,], [45,], [190,], [33,], [45,], [45,]])

# cv2.imread signals an unreadable or missing file by returning None.
if input_image is None:
    raise FileNotFoundError(f"Could not read test image: {image_path}")

image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)

# The tensor conversion below reshapes to (3, 80, 256), so other sizes cannot work.
if image.shape != (80, 256, 3):
    raise ValueError(f"Expected an 80 x 256 RGB image, got shape {image.shape}")

gallery = [image.copy()]

# Mask both eye regions with clipped Gaussian noise (uint8, like the image).
noise = 0.2 * torch.randn(100, 200, 3).cpu().numpy()
noise = (noise * 255).clip(0, 255).astype(numpy.uint8)
left = noise[0:100, 0:100, ...]
right = noise[0:100, 100:200, ...]

# in_painting is given copies so that `image` keeps the untouched photo.
masked_image = in_painting(image.copy(), parameter, left, right)
batch_input = torch.FloatTensor(masked_image).permute(2,0,1).view(3, 80, 256) / 255.0
# The generator is fed 32 identical copies; only the first output is used below.
batch_input = batch_input.repeat(32, 1, 1, 1).to(device)
with torch.no_grad():
    batch_patches = G.forward(batch_input)

# C x H x W in [0, 1] -> H x W x C uint8
gray_patches = batch_patches[0].permute(1, 2, 0).cpu().numpy()
patches = (gray_patches * 255).clip(0, 255).astype(numpy.uint8)

left_patch = patches[0:100, 0:100, ...]
right_patch = patches[0:100, 100:200, ...]

completion = in_painting(image.copy(), parameter, left_patch, right_patch)
gallery.append(completion)

for img, ax, title in zip(gallery, axes.flatten(), ("Original", "Completion")):
    ax.imshow(img)
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()