In [None]:
# Image Inpainting Using GANs
# This notebook contains both the training and testing of the model

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Conv2d, ConvTranspose2d, BatchNorm2d, Dropout, LeakyReLU, ReLU, Linear, Flatten, Tanh, InstanceNorm2d
import torchvision
import matplotlib.pyplot as plt
import os
from PIL import Image
from tqdm import tqdm
import random
from IPython.display import clear_output
from torchsummary import summary
import torch.nn.functional as F
from utils import *

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Display the selected device string ('cuda' or 'cpu').
device

In [None]:
# Hand cached-but-unused GPU memory back to the driver; there is nothing to free on CPU.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
IMG_SIZE = 256
MASK_SIZE = 128
BATCH_SIZE = 2
TEST_SIZE = 64
NUM_CHANNELS = 3
LEARNING_RATE_DISC = 0.0001
LEARNING_RATE_GEN = 0.0001
EPOCHS = 30
LAMBDA_AD = 0.01
LAMBDA_R = 1

In [None]:
# Directory holding the aligned CelebA face images; every entry is treated as one image.
root_dir = (
    "archive/img_align_celeba/"
    "img_align_celeba/"
)

In [None]:
def augment(image, img_size=None, mask_size=None):
    """Return a binary mask containing one randomly placed square hole.

    The result has shape ``(1, img_size, img_size)`` and dtype float32. Pixels
    inside the square are 1.0 and all other pixels are 0.0.

    Args:
        image: The image the mask is intended for. It is not read; the parameter
            is kept so existing ``augment(image)`` calls keep working.
        img_size: Side length of the square mask. Defaults to ``IMG_SIZE``.
        mask_size: Side length of the square hole. Defaults to ``MASK_SIZE``.

    Raises:
        ValueError: If the hole is empty or does not fit inside the image.
    """
    if img_size is None:
        img_size = IMG_SIZE
    if mask_size is None:
        mask_size = MASK_SIZE
    if not 0 < mask_size <= img_size:
        raise ValueError(
            f"mask_size ({mask_size}) must be in 1..img_size ({img_size})"
        )

    # np.random.randint's upper bound is exclusive. The +1 makes every valid
    # top-left corner reachable, including the one that puts the hole against
    # the bottom/right edge. It also allows mask_size == img_size.
    x1, y1 = np.random.randint(0, img_size - mask_size + 1, 2)
    x2, y2 = x1 + mask_size, y1 + mask_size

    mask = np.zeros((1, img_size, img_size), dtype=np.float32)
    mask[:, y1:y2, x1:x2] = 1.0

    return mask

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Every image file in ``root_dir`` paired with a random square-hole mask.

    Each item is ``(image, mask)``:

    * ``image`` is read with PIL, converted to RGB, cast to float32 and scaled
      to [0, 1]. It is then passed through ``transforms`` when one is given.
    * ``mask`` comes from ``augment``. It is drawn anew on every access, so the
      same index yields a different hole each time.
    """

    def __init__(self, root_dir, transforms=None):
        self.root_dir = root_dir
        self.transforms = transforms
        # Sorted so the index -> file mapping does not depend on filesystem order.
        self.files = sorted(os.listdir(root_dir))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        image_path = os.path.join(self.root_dir, self.files[idx])

        # The context manager closes the file handle. np.array copies the pixels
        # out before that happens. convert("RGB") guarantees three channels even
        # for grayscale or RGBA files.
        with Image.open(image_path) as img:
            image = np.array(img.convert("RGB"), dtype=np.float32) / 255

        if self.transforms:
            image = self.transforms(image)

        # The mask does not depend on the transforms, so always create it.
        mask = augment(image)

        return image, mask

In [None]:
from torchvision.transforms import Compose, Resize, ToTensor

# ToTensor turns the dataset's HWC float array into a CHW tensor.
# Resize then brings every image to IMG_SIZE x IMG_SIZE.
transforms = Compose([
    ToTensor(),
    Resize((IMG_SIZE, IMG_SIZE)),
])

In [None]:
dataset = CustomDataset(root_dir, transforms)
test_size = TEST_SIZE
train_size = len(dataset) - test_size
if train_size <= 0:
    raise ValueError(
        f"need more than {test_size} images in {root_dir!r}, found {len(dataset)}"
    )

train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2
)
# Evaluation does not need a random order. A fixed order keeps the numbering of
# the images saved from this loader aligned with the held-out subset's order.
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1, shuffle=False, num_workers=2
)

In [None]:
# Save the held-out test images and their masks as img{i}.png so the testing
# section can reload exactly these pairs from disk.
os.makedirs('TestingImages', exist_ok=True)
os.makedirs('TestingMasks', exist_ok=True)

to_pil = torchvision.transforms.ToPILImage()
for i, (image, mask) in enumerate(test_dataloader):
    # test_dataloader uses batch_size=1, so element 0 is the whole batch.
    to_pil(image[0]).save(f'TestingImages/img{i}.png')
    to_pil(mask[0]).save(f'TestingMasks/img{i}.png')

In [None]:
# One fixed test batch used for the sanity-check plot and the training previews.
sample = next(iter(test_dataloader))
# Pick a random element of that batch. With batch_size=1 this is always 0, but it
# stays valid if the test batch size changes.
sample_idx = random.randrange(len(sample[0]))

In [None]:
# Sanity check: the preview image next to its mask (hole shown in white).
fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(8, 4))

ax_img.imshow(sample[0][sample_idx].permute(1, 2, 0))
ax_img.set_title('Image')
ax_img.axis('off')

# Channel 0 of the single-channel mask, drawn as a 2-D grayscale image.
ax_mask.imshow(sample[1][sample_idx][0], cmap='gray', vmin=0, vmax=1)
ax_mask.set_title('Mask')
ax_mask.axis('off')

plt.show()

In [None]:
class Discriminator(nn.Module):
    """Six ``ConvSN`` blocks followed by a flatten.

    Channel widths go ``in_channels -> cnum -> 2*cnum -> 4*cnum`` and then stay
    at ``4*cnum`` for the last three blocks. The output is the flattened final
    feature map, one row per input image.
    """

    def __init__(self, in_channels=3, cnum=64):
        super().__init__()

        widths = [in_channels, cnum, 2 * cnum] + [4 * cnum] * 4
        conv_blocks = [ConvSN(c_in, c_out) for c_in, c_out in zip(widths, widths[1:])]
        self.disc = nn.Sequential(*conv_blocks, nn.Flatten())

    def forward(self, x):
        return self.disc(x)

In [None]:
class Generator(nn.Module):
    """Two-stage (coarse -> refine) inpainting generator.

    Built from the ``GatedConv``, ``DownSample``, ``UpSample`` and
    ``SelfAttention`` blocks imported from ``utils``. Both stages end in
    ``nn.Tanh``.

    Args:
        cnum_in: Channels of the input image.
        cnum: Base channel width; the widest layers use ``2 * cnum``.
        cnum_out: Channels of the predicted image.
    """

    def __init__(self, cnum_in=3, cnum=64, cnum_out=3):
        super(Generator, self).__init__()

        # Stage 1 input is image + all-ones channel + mask, hence cnum_in + 2.
        self.coarse_net = nn.Sequential(

            GatedConv(cnum_in+2, cnum//2, kernel_size=5, stride=1, padding=2),

            DownSample(cnum//2, cnum),
            DownSample(cnum, 2*cnum),

            GatedConv(2*cnum, 2*cnum, kernel_size=3, stride=1),
            GatedConv(2*cnum, 2*cnum, kernel_size=3, rate=2, padding=2),
            GatedConv(2*cnum, 2*cnum, kernel_size=3, rate=4, padding=4),
            GatedConv(2*cnum, 2*cnum, kernel_size=3, rate=8, padding=8),
            GatedConv(2*cnum, 2*cnum, kernel_size=3, rate=16, padding=16),
            GatedConv(2*cnum, 2*cnum, kernel_size=3, stride=1),
            GatedConv(2*cnum, 2*cnum, kernel_size=3, stride=1),

            UpSample(2*cnum, cnum),
            UpSample(cnum, cnum//4, cnum//2),

            nn.Conv2d(cnum//4, cnum_out, kernel_size=3, stride=1, padding = "same"),

            nn.Tanh()
        )

        # Stage 2 encoder, applied to the input with the coarse result pasted into the hole.
        self.refine_down = nn.Sequential(

            GatedConv(cnum_in, cnum//2, kernel_size=5, stride=1, padding=2),

            DownSample(cnum//2, cnum),
            DownSample(cnum, 2*cnum),

            GatedConv(2*cnum, 2*cnum, kernel_size=3, stride=1),
            GatedConv(2*cnum, 2*cnum, kernel_size=3, rate=2, padding=2),
            GatedConv(2*cnum, 2*cnum, kernel_size=3, rate=4, padding=4),
            GatedConv(2*cnum, 2*cnum, kernel_size=3, rate=8, padding=8),
            GatedConv(2*cnum, 2*cnum, kernel_size=3, rate=16, padding=16)
        )

        self.attention = SelfAttention(2*cnum, "relu")

        # Stage 2 decoder. Its input is the encoder features concatenated with
        # their attention output, hence 2 * 2*cnum channels.
        self.refine_up = nn.Sequential(

            GatedConv(2*2*cnum, 2*cnum, kernel_size=3, stride=1),
            GatedConv(2*cnum, 2*cnum, kernel_size=3, stride=1),

            UpSample(2*cnum, cnum),
            UpSample(cnum, cnum//4, cnum//2),

            nn.Conv2d(cnum//4, cnum_out, kernel_size=3, stride=1, padding = "same"),

            nn.Tanh()
        )

    def forward(self, x, mask):
        """Inpaint ``x`` inside the region marked by ``mask``.

        Args:
            x: Image batch with ``cnum_in`` channels whose hole pixels have
                already been filled in (the notebook fills them with 1s).
            mask: Single-channel mask, 1 inside the hole and 0 elsewhere.

        Returns:
            The stage-2 prediction for the whole image, not composited with ``x``.
        """
        x_in = x

        # Constant all-ones channel. Deriving it from `mask` keeps it on the
        # mask's device and dtype without relying on a global `device`.
        ones = torch.ones_like(mask[:, 0:1])
        x = torch.cat([x, ones, ones * mask], dim=1)
        x_coarse = self.coarse_net(x)

        # Known pixels come from the input; only the hole takes the coarse
        # prediction. Without the (1 - mask) factor the hole would hold
        # coarse + fill value rather than the coarse prediction.
        x2 = x_coarse * mask + x_in * (1 - mask)

        # The "conv" and "attention" branches share refine_down, so a single
        # pass supplies both (previously it was computed twice on the same input).
        x_feat = self.refine_down(x2)
        x_att = self.attention(x_feat)
        x_cat = torch.cat([x_feat, x_att], dim=1)

        return self.refine_up(x_cat)

In [None]:
# Construct the discriminator first, then the generator (this order fixes RNG use
# for default init), and move each onto the selected device.
global_disc = Discriminator(in_channels=3)
global_disc = global_disc.to(device)

gen = Generator()
gen = gen.to(device)

In [None]:
# Layer summary for the generator's two inputs: a 3-channel image and a 1-channel mask.
summary_input_shapes = [(3, IMG_SIZE, IMG_SIZE), (1, IMG_SIZE, IMG_SIZE)]
summary(gen, summary_input_shapes, 1)

In [None]:
def initialize_weights(m):
    """DCGAN-style initialisation of a single module.

    Meant to be passed to ``model.apply(initialize_weights)``, which calls it
    on every submodule. Called directly on a container, it only inspects that
    container.

    * ``Conv2d`` / ``ConvTranspose2d``: weight ~ N(0, 0.02).
    * ``BatchNorm2d`` with affine parameters: weight ~ N(1, 0.02), bias = 0.

    Any other module type is left untouched.

    NOTE(review): convolutions wrapped in spectral normalisation may recompute
    ``weight`` from another parameter, in which case this has no lasting
    effect. Confirm against ``utils.ConvSN``.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
        # Scale centred on 1 so normalised activations start at roughly unit
        # scale. A mean of 0 would almost zero the layer's output.
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

In [None]:
# Module.apply visits every submodule recursively. Calling initialize_weights on
# the top-level model alone would only inspect the wrapper and initialise nothing.
global_disc.apply(initialize_weights)
gen.apply(initialize_weights)

In [None]:
# Set to True to continue training from the end-of-epoch checkpoints written by the training loop.
RESUME_TRAINING = False

if RESUME_TRAINING:
    # map_location lets GPU-saved checkpoints load on a CPU-only machine.
    global_disc.load_state_dict(torch.load('Model/disc.pth', map_location=device))
    gen.load_state_dict(torch.load('Model/gen.pth', map_location=device))

In [None]:
# Adam with beta1 = 0.5, as is common for GAN training; separate learning rates per network.
ADAM_BETAS = (0.5, 0.999)

gen_opt = torch.optim.Adam(gen.parameters(), lr=LEARNING_RATE_GEN, betas=ADAM_BETAS)
global_opt = torch.optim.Adam(global_disc.parameters(), lr=LEARNING_RATE_DISC, betas=ADAM_BETAS)

In [None]:
# Loss scaling only matters for CUDA mixed precision. Enabling it explicitly only
# when CUDA exists gives the same behaviour on CPU (scaling off) without the
# "CUDA is not available" warning.
_use_amp_scaling = torch.cuda.is_available()
g_scaler = torch.cuda.amp.GradScaler(enabled=_use_amp_scaling)
global_scaler = torch.cuda.amp.GradScaler(enabled=_use_amp_scaling)

In [None]:
# Mean absolute error between the prediction and the target image (reconstruction term).
L1_LOSS = nn.L1Loss(reduction='mean')

In [None]:
def _preview_generator(gen):
    """Plot ``gen``'s output for the fixed preview image (globals ``sample`` / ``sample_idx``)."""
    image = sample[0][sample_idx].unsqueeze(0).to(device)
    mask = sample[1][sample_idx].unsqueeze(0).to(device)
    # Same hole filling as in training: masked pixels become 1.
    masked = (1 - mask) * image + mask

    gen.eval()
    # Inference only: no autograd graph needed for the preview.
    with torch.no_grad():
        pred = gen(masked, mask)

    # Clamp to [0, 1], which imshow would otherwise do itself (with a warning).
    plt.imshow(pred[0].permute(1, 2, 0).clamp(0, 1).cpu().numpy())
    plt.axis('off')
    plt.show()


def train_loop(dataloader, global_disc, gen, global_opt, gen_opt, global_scaler, g_scaler, l1):
    """Train ``gen`` and ``global_disc`` for one pass over ``dataloader``.

    For each batch:

    * the holes in the targets are filled with 1s and the generator inpaints them;
    * the discriminator takes a hinge-loss step on real vs. detached fake images;
    * the generator takes a step on ``LAMBDA_AD * adversarial + LAMBDA_R * l1``.

    Every 100 batches the losses are printed and a preview is shown. Every
    10000 batches both state dicts are saved to ``gen{count}.pth`` /
    ``disc{count}.pth``.

    Args:
        dataloader: Yields ``(targets, masks)`` batches.
        global_disc, gen: Discriminator and generator modules.
        global_opt, gen_opt: Their optimizers.
        global_scaler, g_scaler: AMP gradient scalers for each optimizer.
        l1: Reconstruction loss, called as ``l1(prediction, target)``.
    """
    use_amp = torch.cuda.is_available()
    loop = tqdm(dataloader, leave=True)

    for count, (targets, masks) in enumerate(loop, start=1):
        # The preview switches gen to eval mode, so restore training mode every batch.
        gen.train()
        global_disc.train()
        targets = targets.to(device)
        masks = masks.to(device)

        x_t = targets * (1 - masks) + masks
        x_pred = gen(x_t, masks)

        # Discriminator step: hinge loss, mean(relu(1 - D(real))) + mean(relu(1 + D(fake))).
        with torch.cuda.amp.autocast(enabled=use_amp):
            fake_global = global_disc(x_pred.detach())
            real_global = global_disc(targets)
            global_loss = torch.mean(torch.relu(1 - real_global)) + torch.mean(torch.relu(1 + fake_global))

        global_disc.zero_grad()
        # No retain_graph: the generator step below builds its own discriminator graph.
        global_scaler.scale(global_loss).backward()
        global_scaler.step(global_opt)
        global_scaler.update()

        # Generator step.
        with torch.cuda.amp.autocast(enabled=use_amp):
            fake_gen = global_disc(x_pred)
            # Hinge-GAN generator loss is -E[D(G(x))]. A relu here would zero the
            # gradient exactly when the discriminator confidently rejects the fake.
            adversarial_loss = -torch.mean(fake_gen)
            recon_loss = l1(x_pred, targets)
            gen_loss = LAMBDA_AD * adversarial_loss + LAMBDA_R * recon_loss

        gen.zero_grad()
        g_scaler.scale(gen_loss).backward()
        g_scaler.step(gen_opt)
        g_scaler.update()

        if count % 100 == 0:
            clear_output(wait=True)
            print(f"Generator Loss : {gen_loss.item()} Global Loss : {global_loss.item()}")
            _preview_generator(gen)

        if count % 10000 == 0:
            torch.save(gen.state_dict(), f"gen{count}.pth")
            torch.save(global_disc.state_dict(), f"disc{count}.pth")

In [None]:
# torch.save does not create directories, so make sure Model/ exists.
# The end-of-epoch checkpoints overwrite each other.
os.makedirs("Model", exist_ok=True)

for epoch in range(EPOCHS):

    print(f"EPOCH {epoch+1}:")
    train_loop(train_dataloader, global_disc, gen, global_opt, gen_opt, global_scaler, g_scaler, L1_LOSS)

    torch.save(gen.state_dict(), "Model/gen.pth")
    torch.save(global_disc.state_dict(), "Model/disc.pth")

In [None]:
# Testing section

In [None]:
# Fresh model instances for evaluation; the trained generator weights are loaded next.
gen = Generator()
gen = gen.to(device)

disc = Discriminator()
disc = disc.to(device)

In [None]:
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine (and vice versa).
gen.load_state_dict(torch.load('Model/gen.pth', map_location=device))

In [None]:
def _file_index(name):
    """Sort key using the number in names like 'img12.png', so img2 sorts before img10."""
    digits = ''.join(ch for ch in name if ch.isdigit())
    return (int(digits) if digits else -1, name)


# Numeric order makes position i in these lists correspond to img{i}.png, so the
# Results/img{index} files below line up with TestingImages/img{index}.png.
images = sorted(os.listdir('TestingImages/'), key=_file_index)
masks = sorted(os.listdir('TestingMasks/'), key=_file_index)
assert len(images) == len(masks), "every test image needs a matching mask"

In [None]:
test_images = []
test_masked_images = []
test_predictions = []

assert len(images) == len(masks), "every test image needs a matching mask"

to_tensor = ToTensor()
gen.eval()

for image_name, mask_name in zip(images, masks):

    # Context managers close the PNG files once the pixels have been copied out.
    with Image.open('TestingImages/' + image_name) as img_file:
        image = np.array(img_file)
    with Image.open('TestingMasks/' + mask_name) as mask_file:
        mask = np.array(mask_file)

    test_images.append(image)

    # ToTensor scales 8-bit pixels to [0, 1] and moves channels first; unsqueeze adds the batch dim.
    image = to_tensor(image).unsqueeze(0).to(device)
    mask = to_tensor(mask).unsqueeze(0).to(device)

    # Same hole filling as in training: masked pixels become 1.
    masked_image = image * (1 - mask) + mask
    test_masked_images.append(masked_image[0].permute(1, 2, 0).cpu().numpy())

    with torch.no_grad():
        pred = gen(masked_image, mask)

    # Keep the known pixels and paste the prediction only into the hole.
    # Clamp to the displayable [0, 1] range.
    pred_image = pred * mask + image * (1 - mask)
    test_predictions.append(pred_image[0].permute(1, 2, 0).clamp(0, 1).cpu().numpy())
    

In [None]:
# savefig does not create directories.
os.makedirs('Results', exist_ok=True)

result_panels = (
    ('Original Image', test_images),
    ('Masked Image', test_masked_images),
    ('Inpainted Image', test_predictions),
)

for index in range(len(test_images)):

    fig, axes = plt.subplots(1, 3, figsize=(8, 3))
    for ax, (title, panel_images) in zip(axes, result_panels):
        ax.imshow(panel_images[index])
        ax.set_title(title)
        ax.axis('off')

    # The name has no extension, so matplotlib uses rcParams['savefig.format'] and appends it.
    fig.savefig(f'Results/img{index}')
    # One figure per test image: close each so they do not pile up in memory.
    plt.close(fig)