In [1]:
import os
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
from glob import glob
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Constants

In [2]:
# Side length, in pixels, of each square mask cut out of an image.
MASK_SIZE = 50
# Number of masks applied to every image by the dataset.
N_MASKS = 2
# Value written into the masked pixels before the image is fed to the generator.
FILL_VALUE = 0.5
# When True the U-Net generator is instantiated instead of the dilated-convolution one.
UNET_FLAG = False

# Dataset

In [3]:
class InpaintingDataset(Dataset):
    """
    Dataset that pairs every image with randomly placed square masks.

    Each item is a tuple ``(image, mask, mask_region)``:
    - image: the transformed image (the transform must return a ``(C, H, W)`` tensor).
    - mask: ``(1, H, W)`` float tensor, 1 inside the masked squares and 0 elsewhere.
    - mask_region: ``(n_masks, 2)`` integer tensor with the ``(y, x)`` top-left
      corner of every masked square.
    """

    # Upper bound on rejection-sampling attempts per mask, so an impossible layout
    # (masks that cannot fit without overlapping) fails loudly instead of hanging.
    _MAX_PLACEMENT_ATTEMPTS = 1000

    def __init__(self, image_paths, transform=None, mask_size=100, n_masks=1):
        """
        Dataset for image inpainting.

        Args:
        - image_paths (List[str]): List of images paths
        - transform (callable, optional): A function/transform to apply to the images.
        - mask_size (int, optional): Size of the square mask to apply.
        - n_masks (int, optional): Number of masks to apply per image.
        """
        self.image_paths = image_paths
        self.transform = transform
        self.mask_size = mask_size
        self.n_masks = n_masks

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load image ``index``, apply the transform and draw fresh random masks for it."""
        image_path = self.image_paths[index]
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        mask, mask_region = self.apply_random_mask(image, self.n_masks)

        return image, mask, mask_region

    def apply_random_mask(self, img, n_masks=1):
        """
        Apply random masks to the input image, ensuring no overlap.

        Args:
        - img (Tensor): Input image of shape ``(C, H, W)``.
        - n_masks (int, optional): Number of masks to apply.

        Returns:
        - mask (Tensor): Binary mask of shape ``(1, H, W)`` indicating masked regions.
        - mask_region (Tensor): ``(n_masks, 2)`` tensor of ``(y, x)`` top-left corners.

        Raises:
        - ValueError: if the mask is larger than the image.
        - RuntimeError: if a non-overlapping position cannot be found.
        """
        height, width = img.shape[1], img.shape[2]
        size = self.mask_size
        if size > height or size > width:
            raise ValueError(
                f"mask_size {size} does not fit into an image of size {height}x{width}"
            )

        mask = torch.zeros((1, height, width))
        mask_coords = []

        for _ in range(n_masks):
            for _attempt in range(self._MAX_PLACEMENT_ATTEMPTS):
                # torch.randint excludes `high`; the +1 lets a mask touch the
                # bottom/right border and supports mask_size == image size.
                y1 = torch.randint(0, height - size + 1, (1,)).item()
                x1 = torch.randint(0, width - size + 1, (1,)).item()
                y2, x2 = y1 + size, x1 + size

                overlaps = any(
                    y1 < y2p and y2 > y1p and x1 < x2p and x2 > x1p
                    for y1p, x1p, y2p, x2p in mask_coords
                )
                if not overlaps:
                    break
            else:
                raise RuntimeError(
                    f"could not place {n_masks} non-overlapping masks of size {size} "
                    f"in an image of size {height}x{width}"
                )

            mask[:, y1:y2, x1:x2] = 1
            mask_coords.append((y1, x1, y2, x2))

        # The explicit dtype and reshape keep the (n_masks, 2) layout even for n_masks == 0.
        mask_regions = torch.tensor(
            [[top, left] for top, left, _, _ in mask_coords], dtype=torch.long
        ).reshape(-1, 2)
        return mask, mask_regions

In [4]:
import random

train_path = "/kaggle/input/imagenetmini-1000/imagenet-mini/train"
val_path = "/kaggle/input/imagenetmini-1000/imagenet-mini/val"

# Seed of the validation/test split, so that a re-run reproduces the same partition.
SPLIT_SEED = 42

# glob returns paths in an arbitrary, filesystem-dependent order. Sorting first makes
# both the 10000-image training subset and the seeded shuffle below deterministic.
train_paths_list = sorted(glob(train_path + "/**/*.JPEG", recursive=True))[:10000]
val_paths_list = sorted(glob(val_path + "/**/*.JPEG", recursive=True))
# Split the validation data into validation and test subsets (the last 1024 shuffled
# images become the test set). A private Random instance leaves the global RNG untouched.
random.Random(SPLIT_SEED).shuffle(val_paths_list)
val_paths_list, test_paths_list = val_paths_list[:-1024], val_paths_list[-1024:]

# Every image is resized to 256x256 and converted to a tensor.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

train_dataset = InpaintingDataset(train_paths_list, transform, mask_size=MASK_SIZE, n_masks=N_MASKS)
val_dataset = InpaintingDataset(val_paths_list, transform, mask_size=MASK_SIZE, n_masks=N_MASKS)
test_dataset = InpaintingDataset(test_paths_list, transform, mask_size=MASK_SIZE, n_masks=N_MASKS)
len(train_dataset), len(val_dataset), len(test_dataset)

(10000, 2899, 1024)

In [5]:
# Pull one training sample to sanity-check the shapes returned by the dataset.
image, mask, masks_coords = train_dataset[0]
image.shape, mask.shape, masks_coords.shape

(torch.Size([3, 256, 256]), torch.Size([1, 256, 256]), torch.Size([2, 2]))

In [None]:
# Show the sample image, its masked version and the mask itself side by side.
fig, axes = plt.subplots(1, 3, figsize=(10, 5))
axes[0].imshow(image.permute(1, 2, 0).numpy())
# Masked pixels are replaced by FILL_VALUE, all other pixels are kept.
masked_img = image - image * mask + FILL_VALUE * mask
axes[1].imshow(masked_img.permute(1, 2, 0).numpy())
axes[2].imshow(mask.permute(1, 2, 0).numpy(), cmap="binary")

In [7]:
def get_masked_region(image, masks_coords, mask_size):
    """
    Crop the square patches of an image that start at the given mask corners.

    Args:
    - image (Tensor): Input image; the last two dimensions are cropped.
    - masks_coords (Tensor): Rows of ``(y, x)`` top-left corners, one per mask.
    - mask_size (int): Side length of every square patch.

    Returns:
    - regions (list): One cropped patch per row of ``masks_coords``, in order.
    """
    return [
        image[:, top:top + mask_size, left:left + mask_size]
        for top, left in masks_coords
    ]

In [None]:
# Show every masked patch of the sample image. The patch size comes from MASK_SIZE
# and the grid width from the number of regions, so the cell keeps working when
# the constants change.
regions = get_masked_region(image, masks_coords, MASK_SIZE)
for i, region in enumerate(regions):
    plt.subplot(1, len(regions), i + 1)
    plt.imshow(region.permute(1, 2, 0).numpy())

# Model classes

In [9]:
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.functional import mse_loss
from tqdm.notebook import trange, tqdm
from torch.optim import Adam, Adadelta
from torch.nn import BCELoss, MSELoss

## Completion Class

In [10]:
class Generator(nn.Module):
    """
    Completion network: a convolutional encoder (two stride-2 stages followed by
    dilated convolutions) and a transposed-convolution decoder ending in a sigmoid.

    The input has 4 channels and the output has 3 channels with the same spatial
    size whenever height and width are divisible by 4.
    """

    @staticmethod
    def _conv_block(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1):
        """Conv2d -> BatchNorm2d -> ReLU, returned as a list of layers."""
        convolution = nn.Conv2d(in_channels, out_channels, kernel_size,
                                stride=stride, padding=padding, dilation=dilation)
        return [convolution, nn.BatchNorm2d(out_channels), nn.ReLU()]

    @staticmethod
    def _deconv_block(in_channels, out_channels, kernel_size, stride=1, padding=0):
        """ConvTranspose2d -> BatchNorm2d -> ReLU, returned as a list of layers."""
        upsampling = nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                        stride=stride, padding=padding)
        return [upsampling, nn.BatchNorm2d(out_channels), nn.ReLU()]

    def __init__(self):
        """
        Initialize Generator model
        """
        super().__init__()

        # (in_channels, out_channels, kernel_size, stride, padding, dilation)
        encoder_specs = [
            (4, 64, 5, 1, 2, 1),
            (64, 128, 3, 2, 1, 1),
            (128, 128, 3, 1, 1, 1),
            (128, 256, 3, 2, 1, 1),
            (256, 256, 3, 1, 1, 1),
            (256, 256, 3, 1, 1, 1),
            # Dilated convolutions; padding == dilation keeps the spatial size.
            (256, 256, 3, 1, 2, 2),
            (256, 256, 3, 1, 4, 4),
            (256, 256, 3, 1, 8, 8),
            (256, 256, 3, 1, 16, 16),
            (256, 256, 3, 1, 1, 1),
            (256, 256, 3, 1, 1, 1),
        ]
        encoder_layers = []
        for c_in, c_out, kernel, stride, padding, dilation in encoder_specs:
            encoder_layers.extend(
                self._conv_block(c_in, c_out, kernel_size=kernel, stride=stride,
                                 padding=padding, dilation=dilation)
            )
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = []
        decoder_layers.extend(self._deconv_block(256, 128, kernel_size=4, stride=2, padding=1))
        decoder_layers.extend(self._conv_block(128, 128, kernel_size=3, padding=1))
        decoder_layers.extend(self._deconv_block(128, 64, kernel_size=4, stride=2, padding=1))
        decoder_layers.extend(self._conv_block(64, 32, kernel_size=3, padding=1))
        decoder_layers.append(nn.Conv2d(32, 3, kernel_size=3, padding=1))
        decoder_layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode the 4-channel input and decode it into a 3-channel image."""
        return self.decoder(self.encoder(x))

In [11]:
class UNetGenerator(nn.Module):
    """
    U-Net style generator: five convolutional encoder stages separated by 2x2
    max-pooling, four transposed-convolution decoder stages with skip connections,
    and a final head that outputs a 3-channel image through a sigmoid.

    The input has 4 channels; height and width must be divisible by 16 so that the
    decoder outputs line up with the encoder skip tensors.
    """

    @staticmethod
    def _conv_block(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        """Conv2d -> BatchNorm2d -> in-place ReLU."""
        stage = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*stage)

    @staticmethod
    def _deconv_block(in_channels, out_channels, kernel_size=4, stride=2, padding=1):
        """ConvTranspose2d (doubles the spatial size by default) -> BatchNorm2d -> in-place ReLU."""
        stage = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*stage)

    def __init__(self):
        """Build the encoder, decoder and output head."""
        super().__init__()

        # Encoder: the channel count doubles at every stage.
        self.enc1 = self._conv_block(4, 64)
        self.enc2 = self._conv_block(64, 128)
        self.enc3 = self._conv_block(128, 256)
        self.enc4 = self._conv_block(256, 512)
        self.enc5 = self._conv_block(512, 1024)

        # Decoder: from dec3 on, the input is the previous decoder output
        # concatenated with the matching encoder output, hence the doubled channels.
        self.dec4 = self._deconv_block(1024, 512)
        self.dec3 = self._deconv_block(1024, 256)
        self.dec2 = self._deconv_block(512, 128)
        self.dec1 = self._deconv_block(256, 64)

        head = [
            nn.Conv2d(128, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 3, 1),
            nn.Sigmoid(),
        ]
        self.final = nn.Sequential(*head)

    def forward(self, x):
        """Run the encoder, then decode while concatenating the stored skip tensors."""
        features = self.enc1(x)
        skips = [features]
        for encoder in (self.enc2, self.enc3, self.enc4):
            features = encoder(F.max_pool2d(features, 2))
            skips.append(features)

        features = self.dec4(self.enc5(F.max_pool2d(features, 2)))
        # Skips are consumed deepest-first: enc4, enc3, enc2, and finally enc1.
        for decoder in (self.dec3, self.dec2, self.dec1):
            features = decoder(torch.cat([features, skips.pop()], dim=1))

        return self.final(torch.cat([features, skips.pop()], dim=1))

## Global discriminator class

In [12]:
class GlobalDiscriminator(nn.Module):
    """
    Global branch of the context discriminator: encodes an image into a
    1024-dimensional feature vector with six stride-2 convolutions followed by a
    fully connected layer.
    """

    # Number of stride-2 convolution blocks in ``self.layers``.
    _N_DOWNSAMPLES = 6

    def __init__(self, image_shape):
        """
        Initialize Global Discriminator model

        Args:
        - image_shape (tuple): ``(channels, height, width)`` of the input images.
        """
        super(GlobalDiscriminator, self).__init__()

        def conv_block(in_channels, out_channels, kernel_size=5, stride=2, padding=2):
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]

        channels = [image_shape[0], 64, 128, 256, 512, 512, 512]
        blocks = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            blocks.extend(conv_block(in_channels=c_in, out_channels=c_out))
        self.layers = nn.Sequential(*blocks)

        self.flatten_layer = nn.Flatten()
        # A kernel-5, stride-2, padding-2 convolution maps a side of length n to
        # ceil(n / 2), so the final side is ceil(n / 2**k). Floor division would be
        # wrong for sizes that are not a multiple of 2**k.
        scale = 2 ** self._N_DOWNSAMPLES
        out_h = -(-image_shape[1] // scale)
        out_w = -(-image_shape[2] // scale)
        self.linear = nn.Linear(512 * out_h * out_w, 1024)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Return a ``(batch, 1024)`` feature tensor for a batch of images."""
        conv_output = self.layers(x)
        flatten_output = self.flatten_layer(conv_output)
        output = self.activation(self.linear(flatten_output))
        return output


## Local discriminator class

In [13]:
class LocalDiscriminator(nn.Module):
    """
    Local branch of the context discriminator: encodes an image patch into a
    1024-dimensional feature vector with five stride-2 convolutions followed by a
    fully connected layer.
    """

    # Number of stride-2 convolution blocks in ``self.layers``.
    _N_DOWNSAMPLES = 5

    def __init__(self, image_shape):
        """
        Initialize Local Discriminator model

        Args:
        - image_shape (tuple): ``(channels, height, width)`` of the input patches.
        """
        super(LocalDiscriminator, self).__init__()

        def conv_block(in_channels, out_channels, kernel_size=5, stride=2, padding=2):
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]

        channels = [image_shape[0], 64, 128, 256, 512, 512]
        blocks = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            blocks.extend(conv_block(in_channels=c_in, out_channels=c_out))
        self.layers = nn.Sequential(*blocks)

        self.flatten_layer = nn.Flatten()
        # A kernel-5, stride-2, padding-2 convolution maps a side of length n to
        # ceil(n / 2), so the final side is ceil(n / 2**k). Floor division would be
        # wrong for sizes that are not a multiple of 2**k.
        scale = 2 ** self._N_DOWNSAMPLES
        out_h = -(-image_shape[1] // scale)
        out_w = -(-image_shape[2] // scale)
        self.linear = nn.Linear(512 * out_h * out_w, 1024)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Return a ``(batch, 1024)`` feature tensor for a batch of patches."""
        conv_output = self.layers(x)
        flatten_output = self.flatten_layer(conv_output)
        output = self.activation(self.linear(flatten_output))
        return output

## Context discriminator class

In [14]:
class ContextDiscriminator(nn.Module):
    """
    Combines the local and the global discriminator: their 1024-dimensional
    features are concatenated and mapped to one probability per local patch.
    """

    def __init__(self, local_input_shape, global_input_shape):
        """
        Initialize Context Discriminator model

        Args:
        - local_input_shape (tuple): ``(channels, height, width)`` of the local patches.
        - global_input_shape (tuple): ``(channels, height, width)`` of the full images.
        """
        super(ContextDiscriminator, self).__init__()
        self.local_discrimitator = LocalDiscriminator(local_input_shape)
        self.global_discrimitator = GlobalDiscriminator(global_input_shape)

        self.linear_layer = nn.Linear(2048, 1)
        self.sigmoid_layer = nn.Sigmoid()

    def forward(self, x):
        """
        Score every local patch together with the image it was cut from.

        Args:
        - x (tuple): ``(x_global, x_local)``. ``x_global`` holds B images and
          ``x_local`` holds ``k * B`` patches ordered region-major: all B patches
          of region 0 first, then all B patches of region 1, and so on.

        Returns:
        - Tensor of shape ``(k * B, 1)`` with values in (0, 1).

        Raises:
        - ValueError: if the number of patches is not a positive multiple of B.
        """
        x_global, x_local = x
        local_features = self.local_discrimitator(x_local)
        global_features = self.global_discrimitator(x_global)

        # Take the number of patches per image from the inputs rather than a
        # notebook-level constant, so the module works for any mask count.
        n_images = global_features.shape[0]
        n_patches = local_features.shape[0]
        if n_images == 0 or n_patches == 0 or n_patches % n_images != 0:
            raise ValueError(
                f"expected a positive multiple of {n_images} local patches, got {n_patches}"
            )
        # Tiling along the batch dimension matches the region-major patch order.
        global_features = global_features.repeat(n_patches // n_images, 1)

        concatenated_output = torch.cat([local_features, global_features], dim=-1)
        output = self.sigmoid_layer(self.linear_layer(concatenated_output))
        return output
    

## Training code

In [15]:
def get_masked_regions_batched(images, masks_coords, MASK_SIZE=50):
    """
    Cut a 64x64 context window around every masked square of every image.

    The window is the mask (side ``min(MASK_SIZE, 50)``) plus an equal margin on
    each side. Where the window would leave the image, the missing part stays zero.

    Args:
    - images (Tensor): Batch of images, shape ``(B, C, H, W)``.
    - masks_coords (Tensor): Integer tensor of shape ``(B, n_masks, 2)`` with the
      ``(y, x)`` top-left corner of each mask.
    - MASK_SIZE (int, optional): Side length of the square masks (capped at 50).

    Returns:
    - Tensor of shape ``(n_masks * B, C, 64, 64)``, ordered region-major: the
      windows of mask 0 for all images first, then those of mask 1, and so on.
    """
    padded_size = 64
    MASK_SIZE = min(MASK_SIZE, 50)
    margin = (padded_size - MASK_SIZE) // 2

    corners = masks_coords.cpu()
    windows = []

    for region_idx in range(corners.shape[1]):
        for img, (top, left) in zip(images, corners[:, region_idx].tolist()):
            # Crop bounds, clipped to the image.
            row_lo = max(0, top - margin)
            row_hi = min(img.shape[1], top + MASK_SIZE + margin)
            col_lo = max(0, left - margin)
            col_hi = min(img.shape[2], left + MASK_SIZE + margin)
            crop = img[:, row_lo:row_hi, col_lo:col_hi]

            canvas = torch.zeros(
                (img.shape[0], padded_size, padded_size), dtype=img.dtype, device=img.device
            )
            # When the margin was clipped at the top/left border, shift the crop so
            # the mask keeps its position inside the window.
            row_off = max(0, margin - (top - row_lo))
            col_off = max(0, margin - (left - col_lo))
            row_end = row_off + min(crop.shape[1], padded_size)
            col_end = col_off + min(crop.shape[2], padded_size)
            canvas[:, row_off:row_end, col_off:col_end] = crop

            windows.append(canvas)

    return torch.stack(windows)


In [17]:
class Trainer():
    """
    Training harness for the inpainting generator and its context discriminator.

    The generator is always trained with an MSE reconstruction loss. Once
    adversarial training is switched on, the generator loss additionally contains
    ``alpha`` times the BCE loss of fooling the discriminator, and the
    discriminator is trained on real and completed images.
    """

    def __init__(self, generator, context_discriminator, config, 
                 train_data, val_data, test_data):
        """
        Trainer class for training a generative model and a context discriminator.

        Args:
            generator (nn.Module): Maps a 4-channel input (masked image + mask) to an image.
            context_discriminator (nn.Module): Called with ``(images, local_regions)``.
            config (dict): Must contain "device", "generator_lr", "discriminator_lr",
                "batch_size" and "alpha"; "num_workers" is optional (default 4).
            train_data, val_data, test_data (Dataset): Datasets yielding
                ``(image, mask, mask_coords)`` items.
        """
        self.generator = generator
        self.context_disc = context_discriminator
        self.device = config["device"]

        self.optimizer_gen = Adam(generator.parameters(), lr=config["generator_lr"])
        self.optimizer_disc = Adam(context_discriminator.parameters(), lr=config["discriminator_lr"])

        num_workers = config.get("num_workers", 4)
        self.train_loader = DataLoader(train_data, batch_size=config["batch_size"], 
                                       shuffle=True, pin_memory=True, num_workers=num_workers)
        self.val_loader = DataLoader(val_data, batch_size=config["batch_size"], 
                                     shuffle=False, pin_memory=True, num_workers=num_workers)
        self.test_data = DataLoader(test_data, batch_size=config["batch_size"], 
                                    shuffle=False, pin_memory=True, num_workers=num_workers)

        self.disc_loss = BCELoss()
        self.alpha = config["alpha"]
        # Per-epoch averages, appended at the end of every epoch in train().
        self.logs = {
            "train_loss_gen": [],
            "train_loss_disc" : [],
            "val_loss_gen": []
        }

    def train_generator_step(self, batch, train_disc):
        """
        Compute the generator loss for one batch (no optimizer step is taken here).

        Args:
            batch (tuple): ``(images, masks, masks_coords)`` already on the target device.
            train_disc (bool): If True, add the adversarial term weighted by ``alpha``.

        Returns:
            tuple: ``(loss, completed_images)`` where the completed images are
            detached from the graph.
        """
        images, masks, masks_coords = batch
        masked_img = images * (1 - masks) + FILL_VALUE * masks
        gen_input = torch.cat((masked_img, masks), dim=1)
        out_gen = self.generator(gen_input)
        loss = mse_loss(out_gen, images)

        if train_disc:
            fake_input_disc_local = get_masked_regions_batched(out_gen, masks_coords)
            out_disc_fake = self.context_disc((out_gen, fake_input_disc_local))
            # The discriminator returns one score per masked region, so there are
            # n_masks * batch rows. BCELoss needs labels of exactly that shape.
            real_labels = torch.ones_like(out_disc_fake)
            loss_disc = self.disc_loss(out_disc_fake, real_labels)
            loss += self.alpha * loss_disc

        return loss, out_gen.detach()

    def train_disc_only_step(self, batch):
        """
        Perform a training step for the discriminator only.

        Args:
            batch (tuple): A batch of data containing images, masks, mask coordinates, and completed images.

        Returns:
            torch.Tensor: The discriminator loss.
        """
        images, masks, masks_coords, completed_image = batch
        # Detach so no gradient flows back into the generator.
        completed_image = completed_image.detach()

        fake_input_disc_local = get_masked_regions_batched(completed_image, masks_coords)
        out_disc_fake = self.context_disc((completed_image, fake_input_disc_local))
        # Labels follow the discriminator output shape (one score per masked region).
        fake_labels = torch.zeros_like(out_disc_fake)
        loss_fake = self.disc_loss(out_disc_fake, fake_labels)

        real_input_disc_local = get_masked_regions_batched(images, masks_coords)
        out_disc_real = self.context_disc((images, real_input_disc_local))
        real_labels = torch.ones_like(out_disc_real)
        loss_real = self.disc_loss(out_disc_real, real_labels)

        loss = (loss_fake + loss_real) * self.alpha / 2
        return loss

    def eval_generator(self, with_adversarial=False):
        """
        Average the generator loss over the validation loader.

        Both networks are switched to eval mode; train() switches them back at the
        start of the next epoch.

        Args:
            with_adversarial (bool): Include the adversarial term in the loss.

        Returns:
            float: Mean per-batch generator loss.
        """
        self.generator.eval()
        self.context_disc.eval()
        val_loss = 0.0
        for batch in tqdm(self.val_loader, desc="Evaluation"):
            images, masks, masks_coords = batch
            images, masks, masks_coords = images.to(self.device), masks.to(self.device), masks_coords.to(self.device)
            with torch.no_grad():
                loss_gen, _ = self.train_generator_step((images, masks, masks_coords),
                                                        train_disc=with_adversarial)
            val_loss += loss_gen.item()
        return val_loss / len(self.val_loader)

    def train(self, epochs, train_disc=3):
        """
        Train the generator and discriminator models for a given number of epochs.

        Args:
            epochs (int): Number of epochs to run.
            train_disc (int): First epoch index in which the discriminator is
                trained and the adversarial term is added to the generator loss.
        """
        for epoch in range(epochs):
            adversarial = epoch >= train_disc
            train_loss_gen = 0.0
            train_loss_disc = 0.0
            self.generator.train()
            self.context_disc.train()
            for batch in tqdm(self.train_loader, desc="Training"):
                images, masks, masks_coords = batch
                images, masks, masks_coords = images.to(self.device), masks.to(self.device), masks_coords.to(self.device)

                self.optimizer_gen.zero_grad()
                loss_gen, completed_image = self.train_generator_step((images, masks, masks_coords),
                                                                      train_disc=adversarial)
                loss_gen.backward()
                self.optimizer_gen.step()
                train_loss_gen += loss_gen.item()

                if adversarial:
                    # The generator backward pass also left gradients on the
                    # discriminator parameters; clear them before its own step.
                    self.optimizer_disc.zero_grad()
                    loss_disc = self.train_disc_only_step((images, masks, masks_coords, completed_image))
                    loss_disc.backward()
                    self.optimizer_disc.step()
                    train_loss_disc += loss_disc.item()

            train_loss_gen /= len(self.train_loader)
            train_loss_disc /= len(self.train_loader)

            val_loss_gen = self.eval_generator(with_adversarial=adversarial)
            print(f"Epoch: {epoch} | Train gen loss: {train_loss_gen:.5f}| Train disc loss: {train_loss_disc:.5f} | Eval loss: {val_loss_gen:.5f}")

            self.logs["train_loss_gen"].append(train_loss_gen)
            self.logs["train_loss_disc"].append(train_loss_disc)
            self.logs["val_loss_gen"].append(val_loss_gen)

In [None]:
# Use the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)
print("Using UNet Generator:", UNET_FLAG)

# Choose the generator architecture first, then instantiate it on the device.
generator_cls = UNetGenerator if UNET_FLAG else Generator
generator = generator_cls().to(device)

# The local branch receives 64x64 patches, the global branch the 256x256 images.
local_shape, global_shape = (3, 64, 64), (3, 256, 256)
context_discriminator = ContextDiscriminator(local_shape, global_shape).to(device)

# Hyper-parameters read by Trainer.
config = {
    "device": device,
    "generator_lr": 5e-4,
    "discriminator_lr": 2e-4,
    "batch_size": 32,
    "alpha": 4e-4,
}

In [19]:
# Build the trainer; its constructor creates the optimizers and the
# train/val/test DataLoaders from the datasets defined earlier.
trainer = Trainer(generator, context_discriminator, config, 
                 train_dataset, val_dataset, test_dataset)

In [None]:
# Train for 15 epochs; epochs 0-4 use the MSE loss only, and from epoch index 5
# on the discriminator is trained and the adversarial term is added.
trainer.train(epochs=15, train_disc=5)

In [None]:
# Take the first validation batch and build the generator input: the image with
# masked pixels replaced by FILL_VALUE, concatenated with the mask along the
# channel dimension.
images, masks, masks_coords = next(iter(trainer.val_loader))
images, masks, masks_coords = images.to(device), masks.to(device), masks_coords.to(device)
masked_img = images - images * masks + FILL_VALUE * masks
gen_input = torch.cat((masked_img, masks), dim=1)

In [None]:
# Run the generator in eval mode and without gradient tracking.
generator.eval()
with torch.no_grad():
    out = generator(gen_input)

In [None]:
# Move everything to the CPU so it can be converted to numpy for plotting.
images = images.detach().cpu()
out = out.detach().cpu()
masks = masks.detach().cpu()
masks_coords = masks_coords.detach().cpu()

In [None]:
# Plot the original validation images in a 2x4 grid.
n_rows, n_cols = 2, 4
# Never index past the end of the batch.
n_images = min(n_rows * n_cols, len(images))
plt.figure(figsize=(10, 5))
for i in range(n_images):
    # A dedicated name keeps the `image`/`mask` sample tensors from the earlier
    # cells intact, so those cells can still be re-run afterwards.
    target_np = images[i].permute(1, 2, 0).numpy()
    plt.subplot(n_rows, n_cols, i+1)
    plt.imshow(target_np)

In [None]:
# Plot the generator outputs in a 2x4 grid and outline the masked squares in red.
n_rows, n_cols = 2, 4
# Never index past the end of the batch.
n_images = min(n_rows * n_cols, len(out))
plt.figure(figsize=(10, 5))
for i in range(n_images):
    # A dedicated name keeps the `image`/`mask` sample tensors from the earlier
    # cells intact, so those cells can still be re-run afterwards.
    completed_np = out[i].permute(1, 2, 0).numpy()
    ax = plt.subplot(n_rows, n_cols, i+1)
    ax.imshow(completed_np)
    for coords in masks_coords[i]:
        # Coordinates are stored as (y, x); Rectangle expects the (x, y) corner.
        rect = Rectangle(coords.numpy()[::-1], MASK_SIZE, MASK_SIZE, linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(rect)

In [None]:
# Drop the notebook-level references to the models and release cached CUDA memory.
# NOTE(review): `trainer` still holds references to both models (and its optimizers
# to their parameters), so the memory is presumably not freed until `trainer` is
# deleted as well - confirm whether later cells still need it.
del generator
del context_discriminator
torch.cuda.empty_cache()