In [None]:
# Mount Google Drive so the dataset and checkpoint paths under /content/drive are reachable.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import json
import h5py
import pandas, numpy, random, cv2
import matplotlib.pyplot as plt

In [None]:
# Sanity-check that the HDF5 file holds the three datasets the Dataset class reads,
# and report how many entries each one has.
h5py_dir = '/content/drive/My Drive/Colab Notebooks/Data/train.h5py'

with h5py.File(h5py_dir, 'r') as hdf:
    missing = [name for name in ('in_image', 'local', 'parameters') if name not in hdf]
    if missing:
        print("Expected datasets not found in the HDF5 file.")
    else:
        print(f"Number of in_images: {len(hdf['in_image'])}")
        print(f"Number of local: {len(hdf['local'])}")
        print(f"Number of parameters: {len(hdf['parameters'])}")

In [None]:
# Choose the compute device. On a GPU runtime, also switch the default float tensor type
# to CUDA so tensors created without an explicit device land on the GPU.
cuda_ok = torch.cuda.is_available()
if cuda_ok:
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    print("using cuda:", torch.cuda.get_device_name(0))

device = torch.device("cuda" if cuda_ok else "cpu")

device

In [None]:
class View(nn.Module):
    """Reshape wrapper so ``Tensor.view`` can be used as a layer inside ``nn.Sequential``.

    Args:
        shape: target shape handed unchanged to ``Tensor.view`` (an int or a tuple, which
            may contain -1).
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        target_shape = self.shape
        reshaped = x.view(target_shape)
        return reshaped

In [None]:
def in_painting(in_image, coordinate, left_patch, right_patch):
    """Return a copy of ``in_image`` with the two eye patches resized and pasted in.

    Args:
        in_image: (H, W, C) numpy image. It is NOT modified; a copy is patched and returned.
        coordinate: 8 entries, each readable with ``.item()``, in the order
            (left x, left y, left w, left h, right x, right y, right w, right h).
            x/y are used as the centre of each box and w/h as its size.
        left_patch, right_patch: patches that ``cv2.resize`` stretches to fit each box.

    Returns:
        The patched copy. Each box spans rows ``int(y - h/2) .. int(y + h/2)`` and columns
        ``int(x - w/2) .. int(x - w/2) + w``, both inclusive. Boxes are clipped to the image
        borders; the patch is resized to the clipped box. A box that lies completely outside
        the image is skipped.
    """
    # Work on a copy: patching in place used to overwrite the caller's array (and, via
    # Tensor.numpy(), the caller's CPU tensor).
    patched_image = in_image.copy()
    height, width = patched_image.shape[:2]

    def paste(cx, cy, w, h, patch):
        # Clip the box to the image on both axes (previously only vertically, against a
        # hard-coded height of 80; a box crossing the left/right edge crashed).
        up = max(0, int(cy - h / 2))
        down = min(height, int(cy + h / 2) + 1)
        left = max(0, int(cx - w / 2))
        right = min(width, int(cx - w / 2) + w + 1)
        if down <= up or right <= left:
            return
        # cv2.resize takes dsize as (width, height)
        patched_image[up:down, left:right] = cv2.resize(patch, (right - left, down - up))

    values = [int(coordinate[k].item()) for k in range(8)]
    paste(*values[0:4], left_patch)
    paste(*values[4:8], right_patch)

    return patched_image

In [None]:
class Celeb_ID_Dataset(Dataset):
    """Map-style dataset over an HDF5 file with 'in_image', 'local' and 'parameters' datasets.

    Each item is a tuple ``(in_image, local, parameters, ur_image)`` of float tensors:
      - in_image:   input image viewed as (3, 80, 256), scaled by 1/255
      - local:      eye strip viewed as (3, 100, 200), scaled by 1/255; columns 0:100 hold the
                    left patch and 100:200 the right patch
      - parameters: the raw per-item coordinates passed to ``in_painting`` (8 values)
      - ur_image:   in_image with the two local patches pasted in, viewed as (3, 80, 256)
    """

    def __init__(self, file_path):
        self.file_path = file_path
        # NOTE(review): one h5py handle is shared by the object; this assumes the DataLoader
        # does not fork worker processes (num_workers is not set in this notebook).
        self.file_object = h5py.File(file_path, 'r')
        self.data_len = len(self.file_object['in_image'])

    def __len__(self):
        return self.data_len

    def __getitem__(self, idx):
        in_image = self.file_object['in_image'][idx]
        local = self.file_object['local'][idx]
        parameter = self.file_object['parameters'][idx]
        left = local[0:100, 0:100, ...]
        right = local[0:100, 100:200, ...]
        # Paste into a copy: patching in_image in place made the returned input identical
        # to the target ur_image.
        ur_image = in_painting(in_image.copy(), parameter, left, right)

        # HWC arrays -> CHW float tensors scaled to [0, 1]
        in_image_ = torch.FloatTensor(in_image).permute(2, 0, 1).view(3, 80, 256) / 255.0
        local_ = torch.FloatTensor(local).permute(2, 0, 1).view(3, 100, 200) / 255.0
        param_ = torch.FloatTensor(parameter)
        ur_image_ = torch.FloatTensor(ur_image).permute(2, 0, 1).view(3, 80, 256) / 255.0

        return in_image_, local_, param_, ur_image_

    def __del__(self):
        # __init__ may have failed before the file was opened.
        file_object = getattr(self, 'file_object', None)
        if file_object is not None:
            file_object.close()

# Create Dataset and Dataloader object
file_path = '/content/drive/My Drive/Colab Notebooks/Data/train.h5py'
dataset = Celeb_ID_Dataset(file_path)
# The shuffling sampler's generator has to match the device new tensors default to
# (CUDA once the device cell has switched the default tensor type). Deriving it from
# `device` keeps this working on CPU-only runtimes, where a hard-coded 'cuda' crashes.
params = {'batch_size': 32, 'shuffle': True, 'generator': torch.Generator(device=device)}
dataloader = DataLoader(dataset, **params)

In [None]:
# @title
# Check dataset and dataloader

print("Length of the dataset:", len(dataset))

def show_batch_images(images, batch_size, ncols=8):
    """Show up to ``batch_size`` CHW image tensors from ``images`` in a grid of ``ncols`` columns."""
    nrows = max(1, -(-batch_size // ncols))  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(15, 1.75 * nrows))
    axes = numpy.atleast_1d(axes).flatten()
    for ax in axes:
        ax.axis('off')

    for img, ax in zip(images, axes):
        # Each element is one (C, H, W) tensor; matplotlib expects (H, W, C).
        ax.imshow(img.permute(1, 2, 0).cpu().numpy())

    plt.tight_layout()
    plt.show()

try:
    first_batch = next(iter(dataloader))
except StopIteration:
    print("DataLoader is empty.")
else:
    print("\nDataLoader contains data.")

    # The dataset yields four tensors per item.
    in_image, local, parameters, ur_image = first_batch
    print("\nExample of a batch of parameters:")
    print(parameters)
    print("\nExample of a batch of images:")
    show_batch_images(local, batch_size=local.size(0))

In [None]:
class Discriminator(nn.Module):
    """Global + local real/fake classifier used to train the eye in-painting generator.

    ``forward(global_x, local_x)`` takes:
      - an image batch of shape (N, 3, 80, 256), and
      - an eye-strip batch of shape (N, 3, 100, 200).
    Each goes through its own conv branch down to 100 features. The 200 concatenated
    features are mapped to one probability per sample, of shape (N, 1).
    """

    def __init__(self, lr=0.0001):
        """Build both branches, the head, and an Adam optimiser over all parameters.

        Args:
            lr: Adam learning rate (default 1e-4, the previously hard-coded value).
        """
        super().__init__()

        # Global branch: (N, 3, 80, 256) -> (N, 3, 8, 20) -> (N, 100)
        self.global_discriminator = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=(9, 27), stride=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 64, kernel_size=(7, 27), stride=2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 3, kernel_size=(1, 6), stride=2),
            nn.BatchNorm2d(3),
            nn.LeakyReLU(0.2),
            # Flatten keeps whatever batch size arrives; the former View((32, ...)) crashed
            # on any batch that was not exactly 32 (e.g. the last one). It has no parameters
            # and sits at the same index, so state_dict keys / saved checkpoints still match.
            nn.Flatten(),
            nn.Linear(3 * 8 * 20, 100),
        )

        # Local branch: (N, 3, 100, 200) -> (N, 3, 10, 20) -> (N, 100)
        self.local_discriminator = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=5, stride=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 64, kernel_size=5, stride=2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 3, kernel_size=(4, 8), stride=2),
            nn.BatchNorm2d(3),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(3 * 10 * 20, 100),
        )

        # Head: 200 concatenated features -> one probability
        self.concat_fc = nn.Sequential(
            nn.Linear(200, 1),
            nn.Sigmoid(),
        )

        self.optimiser = torch.optim.Adam(self.parameters(), lr=lr)

        self.counter = 0
        self.progress = []

    def forward(self, global_x, local_x):
        """Score (image, eye strip) pairs; returns a probability of shape (N, 1)."""
        global_output = self.global_discriminator(global_x)
        local_output = self.local_discriminator(local_x)
        output = torch.cat((global_output, local_output), dim=1)
        return self.concat_fc(output)

    def cal_bce_loss(self, d_score, label):
        """Scaled binary cross-entropy of the discriminator output.

        Args:
            d_score: output of ``forward`` (probabilities).
            label: target tensor of the same shape (1 = real, 0 = fake).

        Returns:
            ``5e-4 * BCE(d_score, label)``.
        """
        alpha = 5e-4
        return alpha * F.binary_cross_entropy(d_score, label)

    def cal_mse_loss(self, local_x, ur_local):
        """Mean squared error between a (generated) eye strip and the ground-truth strip."""
        return F.mse_loss(local_x, ur_local, reduction='mean')

    def _step(self, loss):
        """Log ``loss`` every 10 calls, print a heartbeat every 1000, then take one optimiser step."""
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())
        if self.counter % 1000 == 0:
            print("counter = ", self.counter)

        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def train_on_true(self, global_x, local_x):
        """One update that pushes the score of a real pair towards 1."""
        d_score = self.forward(global_x, local_x)
        # ones_like puts the label on d_score's device regardless of the default tensor type
        label = torch.ones_like(d_score)
        self._step(self.cal_bce_loss(d_score, label))

    def train_on_fake(self, global_x, local_x, ur_local):
        """One update that pushes the score of a fake pair towards 0.

        The MSE term between ``local_x`` and ``ur_local`` is included in the logged loss.
        When callers pass a detached ``local_x``, it carries no gradient into this network.
        """
        d_score = self.forward(global_x, local_x)
        label = torch.zeros_like(d_score)
        bce_loss = self.cal_bce_loss(d_score, label)
        mse_loss = self.cal_mse_loss(local_x, ur_local)
        self._step(bce_loss + mse_loss)

    def plot_progress(self):
        """Plot the recorded loss history (one point per 10 updates)."""
        df = pandas.DataFrame(self.progress, columns=['loss'])
        # NOTE(review): the ticks (0, 0.1, 0.02) look like a possible typo (the Generator uses
        # (0, 0.1, 0.2)); kept as-is pending confirmation.
        df.plot(ylim=(0), figsize=(16, 8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.1, 0.02))


In [None]:
class Generator(nn.Module):
    """Encoder-decoder that predicts a (3, 100, 200) eye strip from a (3, 80, 256) image.

    ``concentration`` encodes (N, 3, 80, 256) -> (N, 3, 10, 20). ``model`` decodes that to
    (N, 3, 100, 200), squashed to [0, 1] by a final Sigmoid.
    """

    def __init__(self, lr=0.0001):
        """Build the encoder/decoder and an Adam optimiser over the generator's parameters.

        Args:
            lr: Adam learning rate (default 1e-4, the previously hard-coded value).
        """
        super().__init__()

        # Encoder: (N, 3, 80, 256) -> (N, 3, 10, 20)
        self.concentration = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=(4, 27), stride=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 128, kernel_size=(3, 27), stride=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 3, kernel_size=(1, 7), stride=2),
            nn.BatchNorm2d(3),
            nn.LeakyReLU(0.2),
        )

        # Decoder: (N, 3, 10, 20) -> (N, 3, 100, 200)
        self.model = nn.Sequential(
            nn.ConvTranspose2d(3, 64, kernel_size=5, stride=2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 128, kernel_size=5, stride=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 3, kernel_size=(4, 24), stride=2),
            nn.BatchNorm2d(3),
            nn.Sigmoid(),
        )

        self.optimiser = torch.optim.Adam(self.parameters(), lr=lr)

        self.counter = 0
        self.progress = []

    def forward(self, inputs):
        """Encode ``inputs`` and decode them into the eye strip."""
        return self.model(self.concentration(inputs))

    def train(self, D=True, global_x=None, local_x=None, ur_local=None):
        """Either switch train/eval mode or run one generator training step.

        This name overrides ``nn.Module.train(mode)``, which ``eval()`` calls as
        ``self.train(False)``. A bool first argument is therefore forwarded to the base
        implementation, so ``G.eval()`` and ``G.train()`` keep working. Anything else is
        treated as the discriminator and runs one training step.

        Args:
            D: a discriminator exposing ``forward``, ``cal_bce_loss`` and ``cal_mse_loss``;
                or a bool for a mode switch.
            global_x: full images with the generated strips pasted in.
            local_x: the generated eye strip. It must still carry gradient back to this
                generator.
            ur_local: the ground-truth eye strip for the MSE term.

        Returns:
            ``self`` for a mode switch, otherwise ``None``.
        """
        if isinstance(D, bool):
            return super().train(D)
        self._train_step(D, global_x, local_x, ur_local)

    def _train_step(self, D, global_x, local_x, ur_local):
        """Minimise D's "real" BCE on the fake pair plus the strip MSE; step only G's optimiser."""
        d_score = D.forward(global_x, local_x)
        label = torch.ones_like(d_score)
        loss = D.cal_bce_loss(d_score, label) + D.cal_mse_loss(local_x, ur_local)

        # Record the loss every 10 steps for plot_progress
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())

        # backward() also fills gradients in D's parameters. Only G's parameters are stepped
        # here; presumably D clears its gradients with its own zero_grad() before its next update.
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_progress(self):
        """Plot the recorded loss history (one point per 10 training steps)."""
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0), figsize=(16, 8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.1, 0.2))


In [None]:
# @title
# Test Discriminator: one real and one fake update on a single mini-batch

D = Discriminator().to(device)
G = Generator().to(device)

for batch_in_image, batch_ur_local, batch_parameter, batch_ur_image in dataloader:
    # Move data to the compute device
    batch_in_image = batch_in_image.to(device)
    batch_ur_local = batch_ur_local.to(device)
    batch_parameter = batch_parameter.to(device)
    batch_ur_image = batch_ur_image.to(device)

    # Train the discriminator on true: ground-truth image + ground-truth eye strip
    D.train_on_true(batch_ur_image, batch_ur_local)

    # Generate eye strips and paste each into a copy of its input image
    batch_fake_image = G.forward(batch_in_image).detach()
    patched = []
    for in_chw, patch_chw, param in zip(batch_in_image, batch_fake_image, batch_parameter):
        # Copy: Tensor.numpy() shares memory with CPU tensors, so pasting into it would
        # overwrite batch_in_image.
        in_img = in_chw.permute(1, 2, 0).cpu().numpy().copy()
        # Keep patches as float32 in [0, 1], the same scale as in_img (the dataset divides
        # by 255). Scaling them to uint8 0..255 and dividing the result by 255 again mixed
        # the two scales.
        patches = patch_chw.permute(1, 2, 0).cpu().numpy().clip(0, 1).astype(numpy.float32)
        coordinate = param.cpu().numpy()

        left_patch = patches[0:100, 0:100, ...]
        right_patch = patches[0:100, 100:200, ...]
        completion = in_painting(in_img, coordinate, left_patch, right_patch)
        patched.append(torch.from_numpy(completion).permute(2, 0, 1))

    batch_patched = torch.stack(patched).to(device)
    print(batch_patched.shape)

    # Train the discriminator on false
    D.train_on_fake(batch_patched, batch_fake_image, batch_ur_local)

    break

In [None]:
def process_image(batch_in_image, batch_fake_image, batch_parameter):
    """Paste each generated eye strip into a copy of its input image.

    Args:
        batch_in_image: (N, C, H, W) image batch. It is not modified.
        batch_fake_image: (N, C, h, w) eye strips; the left half of the width is the left
            patch and the right half the right patch. May require grad (it is detached).
        batch_parameter: (N, 8) per-sample coordinates consumed by ``in_painting``.

    Returns:
        (N, C, H, W) float tensor on ``device``. It is rebuilt from numpy, so no gradient
        flows back to ``batch_fake_image``.

    Raises:
        ValueError: if the two batches have different sizes.
    """
    if batch_in_image.size(0) != batch_fake_image.size(0):
        raise ValueError(
            f"batch size mismatch: {batch_in_image.size(0)} images vs "
            f"{batch_fake_image.size(0)} patches"
        )

    patched = []
    for in_chw, patch_chw, param in zip(batch_in_image, batch_fake_image, batch_parameter):
        # CHW -> HWC. Copy because Tensor.numpy() shares memory with CPU tensors, and
        # pasting into it would overwrite the caller's batch.
        in_img = in_chw.detach().permute(1, 2, 0).cpu().numpy().copy()
        patches = patch_chw.detach().permute(1, 2, 0).cpu().numpy().clip(0, 1).astype(numpy.float32)
        half = patches.shape[1] // 2

        completion = in_painting(in_img, param, patches[:, :half], patches[:, half:2 * half])
        patched.append(torch.from_numpy(completion).float().permute(2, 0, 1))

    # Re-pack to mini-batch
    return torch.stack(patched).to(device)

In [None]:
# Instantiate fresh networks on the compute device (rebinding D and G to new, untrained models)
D, G = Discriminator().to(device), Generator().to(device)

In [None]:
# Load parameters (optional): restore previously saved weights.
# map_location remaps the stored tensors onto the current device, so checkpoints saved
# from a GPU session also load on a CPU-only runtime.
D.load_state_dict(torch.load('/content/drive/My Drive/Colab Notebooks/D_06042237_30.pth', map_location=device))
G.load_state_dict(torch.load('/content/drive/My Drive/Colab Notebooks/G_06042237_30.pth', map_location=device))

In [None]:
%%time

# Train the Network

current_epoch = 0
epoch = 5
batch_size = 32

for epoch in range(epoch):
    current_epoch += 1
    print("epoch = ", current_epoch)

    for batch_in_image, batch_ur_local, batch_parameter, batch_ur_image in dataloader:
        # Move data to CUDA device
        gpu_in_image = batch_in_image.to(device)
        gpu_ur_local = batch_ur_local.to(device)
        gpu_ur_image = batch_ur_image.to(device)

        # Train the discriminator on true *
        D.train_on_true(gpu_ur_image, gpu_ur_local)

        # In paint with noise tensor
        noise = (0.01 * torch.randn(3, 100, 200)).to(device)
        batch_noise = noise.unsqueeze(0).repeat(32, 1, 1, 1)
        gpu_input = process_image(batch_in_image, batch_noise, batch_parameter)

        # Generate patches and global for the next trainings
        gpu_g_image = G.forward(gpu_input)
        gpu_patched = process_image(batch_in_image, gpu_g_image, batch_parameter)

        # Train the discriminator on fake (false) *
        gpu_fake_image = gpu_g_image.detach()
        D.train_on_fake(gpu_patched, gpu_fake_image, gpu_ur_local)

        # Train the generator *
        G.train(D, gpu_patched, gpu_g_image, gpu_ur_local)

torch.save(D.state_dict(), '/content/drive/My Drive/Colab Notebooks/D_ep_5.pth')
torch.save(G.state_dict(), '/content/drive/My Drive/Colab Notebooks/G_ep_5.pth')

In [None]:
# Persist both networks' weights under a timestamped name (D first, then G).
checkpoints = {
    '/content/drive/My Drive/Colab Notebooks/D_06050148_35.pth': D,
    '/content/drive/My Drive/Colab Notebooks/G_06050148_35.pth': G,
}
for ckpt_path, net in checkpoints.items():
    torch.save(net.state_dict(), ckpt_path)

In [None]:
# Check the result after training: for the first sample of one batch, show the generated
# eye strip, the ground-truth strip, the input image and the in-painted completion.
# NOTE(review): training feeds G a noise-inpainted input (process_image with a noise strip),
# while this cell feeds the raw batch_in_image -- confirm which input is intended here.

fig, axs = plt.subplots(1, 4, figsize=(15, 5))

batch_in_image, batch_ur_local, batch_parameter, batch_ur_image = next(iter(dataloader))
batch_in_image = batch_in_image.to(device)
batch_ur_local = batch_ur_local.to(device)
batch_parameter = batch_parameter.to(device)

with torch.no_grad():
    batch_patch = G.forward(batch_in_image)

# First sample only: CHW tensors -> HWC numpy arrays for matplotlib
in_img = batch_in_image[0].permute(1, 2, 0).cpu().numpy()
real_patch = batch_ur_local[0].permute(1, 2, 0).cpu().numpy()
patches = batch_patch[0].permute(1, 2, 0).cpu().numpy().clip(0, 1).astype(numpy.float32)
parameter = batch_parameter[0]

left_patch = patches[0:100, 0:100, ...]
right_patch = patches[0:100, 100:200, ...]
# Paste into a copy; pasting in place made the "Input" panel show the completion.
completion = in_painting(in_img.copy(), parameter, left_patch, right_patch)

panels = (
    (patches, "Patch"),
    (real_patch, "Original"),
    (in_img, "Input"),
    (completion.clip(0, 1).astype(numpy.float32), "Completion"),
)
for ax, (img, title) in zip(axs, panels):
    ax.imshow(img)
    ax.set_title(title)

In [None]:
# Discriminator loss history (one point recorded every 10 updates)
D.plot_progress();

In [None]:
# Generator loss history (one point recorded every 10 training steps)
G.plot_progress();

In [None]:
# Application: complete the eyes of one external image with the trained generator.

fig, axes = plt.subplots(1, 2, figsize=(9, 6))

image_path = '/content/drive/My Drive/Colab Notebooks/Data/ts.jpg'
input_image = cv2.imread(image_path)
if input_image is None:  # cv2.imread signals a missing/unreadable file by returning None
    raise FileNotFoundError(f"Could not read image: {image_path}")
input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
# Rows 0:80 are the generator input; rows 80:160 are the reference shown as "Original".
in_image = input_image[0:80, 0:256, ...]
ur_image = input_image[80:160, 0:256, ...]

# Eye boxes in in_painting's order: (left x, y, w, h, right x, y, w, h)
parameter = numpy.array([[81,], [41,], [53,], [48,], [179,], [40,], [48,], [49,]])

gallery = [ur_image]

# Renamed from `input`, which shadowed the builtin.
input_tensor = torch.FloatTensor(in_image).permute(2, 0, 1).view(3, 80, 256) / 255.0
input_tensor = input_tensor.to(device)
# A batch of one: in training-mode BatchNorm, 32 identical copies give the same batch
# statistics as a single copy, so repeating the image only cost compute.
with torch.no_grad():
    batch_patches = G.forward(input_tensor.unsqueeze(0))

patch_strip = batch_patches[0].permute(1, 2, 0).cpu().numpy()
patches = (patch_strip * 255).clip(0, 255).astype(numpy.uint8)

left_patch = patches[0:100, 0:100, ...]
right_patch = patches[0:100, 100:200, ...]

completion = in_painting(in_image.copy(), parameter, left_patch, right_patch)
gallery.append(completion)

for num, (img, ax) in enumerate(zip(gallery, axes.flatten())):
    ax.imshow(img)
    ax.set_title("Original" if num == 0 else f"Completion {num}")
    ax.axis('off')

plt.tight_layout()
plt.show()