In [None]:
from glob import glob

# Sanity check: count the training JPEGs visible under the mounted Kaggle input directory.
pathes = glob(
    "/kaggle/input/imagenetmini-1000/imagenet-mini/train//**/*.JPEG",
    recursive=True,
)
n_images = len(pathes)
print(n_images)

In [None]:
import os
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
from glob import glob

In [None]:
class InpaintingDataset(Dataset):
    """
    Dataset for image inpainting: each item is an image plus a copy with square holes.

    Images are discovered recursively under ``root_dir`` (``*.JPEG`` files), loaded as RGB,
    passed through ``transform`` and then masked with ``n_masks`` random squares of side
    ``mask_size``, each filled with the mean value of the (transformed) image.
    """

    def __init__(self, root_dir, transform=None, mask_size=100, n_masks=1):
        """
        Dataset for image inpainting.

        Args:
        - root_dir (str): Root directory searched recursively for ``*.JPEG`` images.
        - transform (callable, optional): Applied to each PIL image. The masking step
          clones and slices the result as (channels, height, width), so in practice the
          transform must produce such a tensor (e.g. end with ``ToTensor``).
        - mask_size (int, optional): Side length, in pixels, of each square mask.
        - n_masks (int, optional): Number of masks to apply per image.
        """
        self.image_paths = glob(root_dir + "/**/*.JPEG", recursive=True)
        self.transform = transform
        self.mask_size = mask_size
        self.n_masks = n_masks

    def __len__(self):
        """Number of image files found under ``root_dir``."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """
        Get a sample from the dataset.

        Args:
        - index (int): Index of the sample to retrieve.

        Returns:
        - image (Tensor): Original (transformed) image.
        - masked_img (Tensor): Image with random masks applied.
        - mask (Tensor): Binary mask, 1 inside masked regions and 0 elsewhere.
        - mask_region (Tensor): (n_masks, 2) top-left (y, x) coordinates of each mask.
        """
        image_path = self.image_paths[index]
        # Use a context manager so the file handle is released promptly; convert()
        # forces the pixel data to load before the file is closed.
        with Image.open(image_path) as pil_image:
            image = pil_image.convert('RGB')

        if self.transform:
            image = self.transform(image)

        masked_img, mask, mask_region = self.apply_random_mask(image, self.n_masks)

        return image, masked_img, mask, mask_region

    def apply_random_mask(self, img, n_masks=1):
        """
        Apply random square masks to the input image.

        Masks are placed independently, so they may overlap.

        Args:
        - img (Tensor): Input image, indexed as (channels, height, width).
        - n_masks (int, optional): Number of masks to apply.

        Returns:
        - masked_img (Tensor): Copy of ``img`` with each square set to ``img.mean()``.
        - mask (Tensor): Binary mask (same shape as ``img``), 1 inside masked squares.
        - mask_region (Tensor): (n_masks, 2) top-left (y, x) coordinates of each mask.

        Raises:
        - ValueError: if ``mask_size`` exceeds the image height or width.
        """
        height, width = img.shape[1], img.shape[2]
        if self.mask_size > height or self.mask_size > width:
            raise ValueError(
                f"mask_size={self.mask_size} does not fit in a {height}x{width} image"
            )

        masked_img = img.clone()
        fill_value = masked_img.mean()
        # Same shape as the image (default float dtype), allocated on the image's device.
        mask = torch.zeros(masked_img.shape, device=masked_img.device)

        # torch.randint's upper bound is exclusive; +1 so a square may touch the
        # bottom/right edge (and mask_size == height/width is valid, not an error).
        y1 = torch.randint(0, height - self.mask_size + 1, (n_masks, ))
        x1 = torch.randint(0, width - self.mask_size + 1, (n_masks, ))
        y2, x2 = y1 + self.mask_size, x1 + self.mask_size

        for i in range(n_masks):
            masked_img[:, y1[i]:y2[i], x1[i]:x2[i]] = fill_value
            mask[:, y1[i]:y2[i], x1[i]:x2[i]] = 1

        return masked_img, mask, torch.stack((y1, x1), dim=1)

In [None]:
root_dir = "/kaggle/input/imagenetmini-1000/imagenet-mini/train"
# Seed torch's global RNG so the random mask placement drawn by the dataset below is
# reproducible on Restart & Run All.
torch.manual_seed(42)
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # fixed size so every sample has the same shape
    transforms.ToTensor(),          # PIL image -> (C, H, W) float tensor
])
dataset = InpaintingDataset(root_dir, transform, mask_size=50, n_masks=2)
len(dataset)

In [None]:
# Pull the first sample and show the shape of every returned tensor.
image, masked_img, mask, masks_coords = dataset[0]
tuple(t.shape for t in (image, masked_img, mask, masks_coords))

In [None]:
import matplotlib.pyplot as plt

# Side-by-side view of one sample: the original, the network input, and the binary mask.
fig, axes = plt.subplots(1, 3, figsize=(10, 5))
panels = [(image, "Original"), (masked_img, "Masked input"), (mask, "Mask")]
for ax, (tensor, title) in zip(axes, panels):
    ax.imshow(tensor.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
    ax.set_title(title)
    ax.axis("off")
plt.show()

In [None]:
def get_masked_region(image, masks_coords, mask_size):
    """
    Crop the square regions described by ``masks_coords`` out of ``image``.

    Args:
    - image (Tensor): Image indexed as (channels, height, width).
    - masks_coords (Tensor): Rows of top-left (y, x) coordinates, one row per mask.
    - mask_size (int): Side length of each square region.

    Returns:
    - regions (list): One (channels, mask_size, mask_size) view of ``image`` per row
      of ``masks_coords``, in the same order.
    """
    return [
        image[:, top:top + mask_size, left:left + mask_size]
        for top, left in masks_coords
    ]

In [None]:
# Crop each masked square back out of the original image. Use the dataset's own mask size
# and one panel per mask so this cell stays correct if mask_size / n_masks change.
regions = get_masked_region(image, masks_coords, dataset.mask_size)
fig, axes = plt.subplots(1, max(len(regions), 1), squeeze=False)
for idx, (ax, region) in enumerate(zip(axes[0], regions)):
    ax.imshow(region.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
    ax.set_title(f"Region {idx} at (y, x) = {tuple(masks_coords[idx].tolist())}")
    ax.axis("off")
plt.show()

In [None]:
import torch.nn as nn
import torch.nn.functional as F

## Completion network (generator)

In [None]:
class Generator(nn.Module):
    """
    Completion (generator) network: maps a 4-channel input to a 3-channel image in [0, 1].

    The first conv expects 4 input channels, presumably the masked RGB image concatenated
    with a 1-channel mask (TODO confirm against the training code). The encoder halves the
    spatial size twice (two stride-2 convs), then applies dilated convs (dilation 2/4/8/16)
    that keep the size while enlarging the receptive field. The decoder upsamples twice
    with stride-2 transposed convs. For H and W divisible by 4 the output therefore has
    the same spatial size as the input. A final Sigmoid bounds the output to [0, 1].
    """

    def __init__(self):
        """
        Initialize Generator model
        """
        # The original referenced an undefined name here (CompletionNetwork),
        # which raised NameError on construction.
        super().__init__()
        conv_block = lambda in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1: [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ]

        deconv_block = lambda in_channels, out_channels, kernel_size, stride=1, padding=0: [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ]

        self.encoder = nn.Sequential(
            *conv_block(4, 64, kernel_size=5, padding=2),
            *conv_block(64, 128, kernel_size=3, stride=2, padding=1),   # H -> H/2
            *conv_block(128, 128, kernel_size=3, padding=1),
            *conv_block(128, 256, kernel_size=3, stride=2, padding=1),  # H/2 -> H/4
            *conv_block(256, 256, kernel_size=3, padding=1),
            *conv_block(256, 256, kernel_size=3, padding=1),
            # padding == dilation keeps the spatial size unchanged for 3x3 kernels.
            *conv_block(256, 256, kernel_size=3, dilation=2, padding=2),
            *conv_block(256, 256, kernel_size=3, dilation=4, padding=4),
            *conv_block(256, 256, kernel_size=3, dilation=8, padding=8),
            *conv_block(256, 256, kernel_size=3, dilation=16, padding=16),
            *conv_block(256, 256, kernel_size=3, padding=1),
            *conv_block(256, 256, kernel_size=3, padding=1)
        )

        self.decoder = nn.Sequential(
            *deconv_block(256, 128, kernel_size=4, stride=2, padding=1),  # x2 upsample
            *conv_block(128, 128, kernel_size=3, padding=1),
            *deconv_block(128, 64, kernel_size=4, stride=2, padding=1),   # x2 upsample
            *conv_block(64, 32, kernel_size=3, padding=1),
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        Args:
        - x (Tensor): (batch, 4, H, W) input.

        Returns:
        - Tensor: (batch, 3, H', W') completed image with values in [0, 1];
          H' == H and W' == W when H and W are divisible by 4.
        """
        encoded = self.encoder(x)
        output = self.decoder(encoded)
        return output

## Global discriminator class

In [None]:
class GlobalDiscriminator(nn.Module):
    """
    Global discriminator branch: encodes a whole image into a 1024-d feature vector.

    Six Conv(5x5, stride 2, padding 2) + BatchNorm + ReLU blocks each halve the spatial
    size, rounding up. The result is flattened and projected to 1024 features by a
    Linear layer followed by ReLU.
    """

    def __init__(self, image_shape):
        """
        Initialize Global Discriminator model.

        Args:
        - image_shape (sequence of int): (channels, height, width) of the input images.
          It sizes the first conv layer and the final Linear layer, so inputs must match it.
        """
        # Required before registering sub-modules; without it nn.Module attribute
        # assignment raises AttributeError.
        super().__init__()

        kernel_size, stride, padding = 5, 2, 2

        def conv_block(in_channels, out_channels):
            # One strided Conv -> BatchNorm -> ReLU stage.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]

        channels = [image_shape[0], 64, 128, 256, 512, 512, 512]
        blocks = []
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            blocks.extend(conv_block(in_ch, out_ch))
        self.layers = nn.Sequential(*blocks)

        self.flatten_layer = nn.Flatten()
        # Track the exact conv output size: each layer maps n -> (n + 2p - k) // s + 1,
        # i.e. ceil(n / 2) here. A plain n // 2**6 under-counts when H or W is not a
        # multiple of 64, which would mis-size the Linear layer.
        out_h, out_w = image_shape[1], image_shape[2]
        for _ in range(len(channels) - 1):
            out_h = (out_h + 2 * padding - kernel_size) // stride + 1
            out_w = (out_w + 2 * padding - kernel_size) // stride + 1
        self.linear = nn.Linear(channels[-1] * out_h * out_w, 1024)
        self.activation = nn.ReLU()

    def forward(self, x):
        """
        Args:
        - x (Tensor): (batch, *image_shape) images.

        Returns:
        - Tensor: (batch, 1024) non-negative (post-ReLU) features.
        """
        conv_output = self.layers(x)
        flatten_output = self.flatten_layer(conv_output)
        output = self.activation(self.linear(flatten_output))
        return output


## Local discriminator class

In [None]:
class LocalDiscriminator(nn.Module):
    """
    Local discriminator branch: encodes an image patch into a 1024-d feature vector.

    Five Conv(5x5, stride 2, padding 2) + BatchNorm + ReLU blocks each halve the spatial
    size, rounding up. The result is flattened and projected to 1024 features by a
    Linear layer followed by ReLU.
    """

    def __init__(self, image_shape):
        """
        Initialize Local Discriminator model.

        Args:
        - image_shape (sequence of int): (channels, height, width) of the input patches.
          It sizes the first conv layer and the final Linear layer, so inputs must match it.
        """
        # Required before registering sub-modules; without it nn.Module attribute
        # assignment raises AttributeError.
        super().__init__()

        kernel_size, stride, padding = 5, 2, 2

        def conv_block(in_channels, out_channels):
            # One strided Conv -> BatchNorm -> ReLU stage.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]

        channels = [image_shape[0], 64, 128, 256, 512, 512]
        blocks = []
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            blocks.extend(conv_block(in_ch, out_ch))
        self.layers = nn.Sequential(*blocks)

        self.flatten_layer = nn.Flatten()
        # Track the exact conv output size: each layer maps n -> (n + 2p - k) // s + 1,
        # i.e. ceil(n / 2) here. A plain n // 2**5 under-counts when H or W is not a
        # multiple of 32, which would mis-size the Linear layer.
        out_h, out_w = image_shape[1], image_shape[2]
        for _ in range(len(channels) - 1):
            out_h = (out_h + 2 * padding - kernel_size) // stride + 1
            out_w = (out_w + 2 * padding - kernel_size) // stride + 1
        self.linear = nn.Linear(channels[-1] * out_h * out_w, 1024)
        self.activation = nn.ReLU()

    def forward(self, x):
        """
        Args:
        - x (Tensor): (batch, *image_shape) patches.

        Returns:
        - Tensor: (batch, 1024) non-negative (post-ReLU) features.
        """
        conv_output = self.layers(x)
        flatten_output = self.flatten_layer(conv_output)
        output = self.activation(self.linear(flatten_output))
        return output

## Context discriminator class

In [None]:
class ConcatedDiscriminator(nn.Module):
    """
    Context discriminator: fuses local-patch and global-image features into one score.

    A ``LocalDiscriminator`` and a ``GlobalDiscriminator`` each produce a 1024-d feature
    vector. The two are concatenated (2048-d) and mapped by Linear + Sigmoid to a
    single value in [0, 1] per sample.
    """
    def __init__(self, local_input_shape, global_input_shape):
        """
        Initialize Context Discriminator model.

        Args:
        - local_input_shape (sequence of int): (channels, height, width) of local patches.
        - global_input_shape (sequence of int): (channels, height, width) of full images.
        """
        # Required before registering sub-modules on an nn.Module.
        super().__init__()

        # Attribute names kept as-is (including spelling) so external references still work.
        self.local_discrimitator = LocalDiscriminator(local_input_shape)
        self.global_discrimitator = GlobalDiscriminator(global_input_shape)

        # 1024 local features + 1024 global features.
        self.linear_layer = nn.Linear(2048, 1)
        self.sigmoid_layer = nn.Sigmoid()

    def forward(self, x, global_x=None):
        """
        Score a batch.

        The two branches expect differently shaped inputs, so pass either
        ``forward((local_x, global_x))`` or ``forward(local_x, global_x)``. A single tensor
        is still accepted and fed to both branches, which only works when both input
        shapes are identical.

        Returns:
        - Tensor: (batch, 1) values in [0, 1].
        """
        if global_x is not None:
            local_x = x
        elif isinstance(x, (tuple, list)):
            local_x, global_x = x
        else:
            local_x = global_x = x

        local_discriminator_output = self.local_discrimitator(local_x)
        global_discrimitator_output = self.global_discrimitator(global_x)

        # nn.Concatenate does not exist in PyTorch; concatenate along the feature dim.
        concatenated_output = torch.cat(
            [local_discriminator_output, global_discrimitator_output], dim=-1
        )
        output = self.sigmoid_layer(self.linear_layer(concatenated_output))
        return output
    

## Training