In [None]:
import os
import shutil
import random
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import torchvision.models as models
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import StepLR
import numpy as np
import cv2

In [None]:
def delete_all_in_folder(folder_path):
    """Empty ``folder_path`` in place, keeping the directory itself.

    Regular files and symlinks are unlinked; sub-directories are removed
    recursively. Failures on individual entries are reported and skipped
    (best-effort cleanup). A missing folder is reported, not raised.

    Args:
    - folder_path (str): Directory whose contents should be removed.
    """
    if not os.path.exists(folder_path):
        print(f'The folder {folder_path} does not exist.')
        return
    for entry_name in os.listdir(folder_path):
        entry_path = os.path.join(folder_path, entry_name)
        try:
            is_plain_entry = os.path.isfile(entry_path) or os.path.islink(entry_path)
            if is_plain_entry:
                os.unlink(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except Exception as exc:
            print(f'Failed to delete {entry_path}. Reason: {exc}')
# Start every run from an empty Kaggle output directory.
delete_all_in_folder('/kaggle/working/')

In [None]:
def set_seed(seed_value: int):
    """
    Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU,
    the current CUDA device and all CUDA devices), then puts cuDNN into
    deterministic, non-benchmarking mode.
    Args:
    - seed_value (int): The seed value to use for reproducibility.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed_value)
    # cuDNN autotuning may pick different (nondeterministic) kernels per run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
seed_value = 42
set_seed(seed_value)

In [None]:
class CIFAR100InpaintingDataset(Dataset):
    """Inpainting dataset: each image is paired with a random free-form hole mask.

    Items are ``(masked_image, mask, image)`` triples. ``mask`` is a float
    tensor of shape ``(1, H, W)`` holding 1 where pixels are hidden and 0
    elsewhere, and ``masked_image = image * (1 - mask)``. A new mask is drawn
    on every access, so repeated reads of one index give different masks.

    Args:
        folder: Root directory searched recursively for ``.jpg``, ``.jpeg``
            and ``.png`` files (extension match is case-insensitive).
        transform: Optional callable applied to the RGB PIL image; expected to
            return a ``(C, H, W)`` tensor (e.g. ``transforms.ToTensor``). When
            omitted the image is converted to a float tensor in [0, 1].
    """

    def __init__(self, folder, transform=None):
        self.folder = folder
        self.transform = transform
        self.samples = []
        for root, _, files in os.walk(folder):
            for file in files:
                if file.lower().endswith(('.jpg', '.jpeg', '.png')):
                    self.samples.append(os.path.join(root, file))
        # os.walk order is filesystem-dependent; sorting makes indices (and thus
        # any seeded train/val split built over them) reproducible across machines.
        self.samples.sort()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path = self.samples[idx]
        # Context manager releases the file handle as soon as pixels are decoded.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        else:
            # A PIL image has no tensor-style ``.size(dim)``; convert HWC uint8
            # to a CHW float tensor in [0, 1] so mask generation below works.
            image = torch.from_numpy(np.asarray(image, dtype=np.float32) / 255.0).permute(2, 0, 1).contiguous()
        mask = self.generate_random_mask(image.size(1), image.size(2))
        masked_image = image * (1 - mask)
        return masked_image, mask, image

    def generate_random_mask(self, height, width):
        """Draw a random binary hole mask using Python's ``random`` module.

        Five straight strokes (thickness 1-5 px) and five filled circles
        (radius 5-20 px) are rasterised with OpenCV.

        Returns:
            Float32 tensor of shape ``(1, height, width)`` with 1 inside holes.
        """
        mask = np.zeros((height, width), dtype=np.uint8)
        for _ in range(5):
            x1, y1 = random.randint(0, width-1), random.randint(0, height-1)
            x2, y2 = random.randint(0, width-1), random.randint(0, height-1)
            cv2.line(mask, (x1, y1), (x2, y2), 1, thickness=random.randint(1, 5))
        for _ in range(5):
            x, y = random.randint(0, width-1), random.randint(0, height-1)
            radius = random.randint(5, 20)
            cv2.circle(mask, (x, y), radius, 1, -1)  # thickness -1 => filled circle
        return torch.tensor(mask, dtype=torch.float32).unsqueeze(0)

    def get_sample(self, idx):
        """Return ``self[idx]``; named accessor used by the plotting/saving helpers."""
        return self[idx]
class EnhancedInpaintingResNet(nn.Module):
    """U-Net-style inpainting network with a randomly initialised ResNet-50 encoder.

    The encoder is the ResNet-50 stem plus its four residual stages. Each
    decoder stage upsamples with a stride-2 transposed convolution; its output
    is resized to the next encoder feature map and concatenated with it (skip
    connection). The head maps back to 3 channels, resizes to ``output_size``
    and applies a sigmoid so outputs lie in (0, 1).

    Args:
        output_size: Spatial size ``(H, W)`` of the returned image. Defaults to
            ``(224, 224)``, matching the resize applied to the training data.
    """

    def __init__(self, output_size=(224, 224)):
        super().__init__()
        # ``weights=None`` requests random initialisation; torchvision's weights
        # API expects ``None`` or a ``ResNet50_Weights`` member, not a bool.
        resnet = models.resnet50(weights=None)
        # Drop the global average pool and FC classifier (last two children).
        self.encoder_layers = list(resnet.children())[:-2]
        self.encoder1 = nn.Sequential(*self.encoder_layers[:4])  # stem -> 64 ch
        self.encoder2 = self.encoder_layers[4]  # -> 256 ch
        self.encoder3 = self.encoder_layers[5]  # -> 512 ch
        self.encoder4 = self.encoder_layers[6]  # -> 1024 ch
        self.encoder5 = self.encoder_layers[7]  # -> 2048 ch
        # Decoder input widths = previous decoder output + skip-connection channels.
        self.decoder5 = self._upsample_block(2048, 1024)
        self.decoder4 = self._upsample_block(1024 + 1024, 512)
        self.decoder3 = self._upsample_block(512 + 512, 256)
        self.decoder2 = self._upsample_block(256 + 256, 128)
        self.decoder1 = self._upsample_block(128 + 64, 64)
        self.final_layer = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
            nn.Upsample(size=tuple(output_size), mode='bilinear', align_corners=False),
            nn.Sigmoid()
        )

    @staticmethod
    def _upsample_block(in_channels, out_channels):
        """Stride-2 transposed conv (doubles H and W) followed by BatchNorm and ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        """Inpaint a batch of masked images.

        Args:
            x: Image batch, presumably ``(N, 3, H, W)`` in [0, 1] (the
               ResNet stem expects 3 input channels).
        Returns:
            Tensor of shape ``(N, 3, *output_size)`` with values in (0, 1).
        """
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(enc1)
        enc3 = self.encoder3(enc2)
        enc4 = self.encoder4(enc3)
        enc5 = self.encoder5(enc4)
        dec = self.decoder5(enc5)
        # Resize each decoder output to its skip partner, concatenate, decode.
        # NOTE(review): decoder2 upsamples 2x and is then resized to enc1's size,
        # which presumably equals enc2's (ResNet layer1 keeps resolution) —
        # that upsample/downsample pair looks wasteful; confirm before changing.
        stages = (
            (self.decoder4, enc4),
            (self.decoder3, enc3),
            (self.decoder2, enc2),
            (self.decoder1, enc1),
        )
        for decoder, skip in stages:
            dec = nn.functional.interpolate(dec, size=skip.shape[2:], mode='bilinear', align_corners=False)
            dec = decoder(torch.cat([dec, skip], dim=1))
        return self.final_layer(dec)

In [None]:
def debug_plot_samples(dataset, num_samples=10):
    """Plot random (original, mask, masked) triplets from ``dataset`` for a visual check.

    Args:
        dataset: Object exposing ``__len__`` and ``get_sample(idx)`` that returns
            ``(masked_image, mask, original_image)`` as CPU tensors, images as
            ``(C, H, W)`` and the mask as ``(1, H, W)``.
        num_samples: Number of rows to draw; clamped to ``len(dataset)``.
            Nothing is plotted when this is 0 or the dataset is empty.
    """
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return
    # squeeze=False keeps ``axs`` 2-D even when a single row is requested.
    fig, axs = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)
    sampled_indices = random.sample(range(len(dataset)), num_samples)
    for row, idx in enumerate(sampled_indices):
        # masked_image is already zero inside the holes, so no extra masking is needed.
        masked_image, mask, original_image = dataset.get_sample(idx)
        panels = (
            (original_image.permute(1, 2, 0).numpy(), 'Original Image', None),
            (mask.squeeze(0).numpy(), 'Mask', 'gray'),
            (masked_image.permute(1, 2, 0).numpy(), 'Masked Image', None),
        )
        for col, (panel, title, cmap) in enumerate(panels):
            ax = axs[row, col]
            ax.imshow(panel, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
def save_predictions(model, dataset, device, epoch, num_examples=5, save_dir='/kaggle/working'):
    """Save original / masked / inpainted PNGs for randomly chosen samples.

    Files are written to
    ``<save_dir>/epoch<epoch>/example_<k>_{original,masked,predicted}.png``.
    The "predicted" image is the original with only the masked pixels
    replaced by the model output.

    Args:
        model: Inpainting network; run in eval mode without gradients. Its
            train/eval mode is restored afterwards.
        dataset: Object with ``__len__`` and ``get_sample(idx)`` returning
            ``(masked_image, mask, original_image)`` CPU tensors.
        device: Device the model lives on.
        epoch: Epoch number used in the output folder name.
        num_examples: Number of samples to save (indices drawn with replacement).
        save_dir: Parent directory for the per-epoch folder.
    """
    # This runs mid-training; remember the caller's mode so BatchNorm layers
    # are not silently left in eval mode afterwards.
    was_training = model.training
    model.eval()
    epoch_dir = os.path.join(save_dir, f'epoch{epoch}')
    os.makedirs(epoch_dir, exist_ok=True)
    try:
        with torch.no_grad():
            for i in range(num_examples):
                idx = random.randint(0, len(dataset) - 1)
                masked_image, mask, original_image = dataset.get_sample(idx)
                predicted_image = model(masked_image.unsqueeze(0).to(device)).squeeze(0).cpu()
                mask = mask.expand_as(predicted_image)
                # Keep known pixels from the original; take hole pixels from the prediction.
                inpainted_image = torch.where(mask == 1, predicted_image, original_image)
                images_to_save = {
                    'original': original_image,
                    'masked': masked_image,
                    'predicted': inpainted_image,
                }
                for suffix, image in images_to_save.items():
                    image_path = os.path.join(epoch_dir, f'example_{i+1}_{suffix}.png')
                    plt.imsave(image_path, image.permute(1, 2, 0).numpy())
    finally:
        model.train(was_training)
    print(f"Saved predictions for epoch {epoch} in {epoch_dir}")

In [None]:
def _masked_batch_loss(model, criterion, batch, device):
    """Move one ``(masked_img, mask, original_img)`` batch to ``device`` and return the masked loss."""
    masked_img, mask, original_img = (t.to(device) for t in batch)
    outputs = model(masked_img)
    mask = mask.expand_as(outputs)
    return criterion(outputs * mask, original_img * mask)
def train_and_validate(model, train_loader, val_loader, dataset, criterion, optimizer, num_epochs=50, device='cuda'):
    """Train ``model`` with per-epoch validation and best-checkpoint saving.

    The loss is ``criterion(outputs * mask, original * mask)``: prediction and
    target are both multiplied by the hole mask, so only hidden pixels
    contribute error (with ``MSELoss`` the mean still runs over all pixels).

    After each epoch the mean validation loss is printed; when it improves,
    ``model.state_dict()`` is written to ``bestmodel.pth`` in the current
    working directory. Every second epoch ``save_predictions`` writes example
    images drawn from ``dataset``.

    Args:
        model: Network to train; moved to ``device``.
        train_loader: Iterable of ``(masked_img, mask, original_img)`` batches.
        val_loader: Same format, used for validation without gradients.
        dataset: Passed to ``save_predictions`` for qualitative samples.
        criterion: Loss function, e.g. ``nn.MSELoss()``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
        device: Device to train on.
    Returns:
        The lowest mean validation loss observed (``inf`` if none was recorded).
    """
    best_val_loss = float('inf')
    model.to(device)
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for batch in tqdm(train_loader, desc=f'Training Epoch {epoch+1}/{num_epochs}'):
            loss = _masked_batch_loss(model, criterion, batch, device)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Guard against empty loaders instead of raising ZeroDivisionError.
        avg_train_loss = running_loss / len(train_loader) if len(train_loader) else float('nan')
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_train_loss:.5f}')
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc='Validation'):
                val_loss += _masked_batch_loss(model, criterion, batch, device).item()
        # nan (empty val_loader) never compares below best_val_loss, so no checkpoint is written.
        avg_val_loss = val_loss / len(val_loader) if len(val_loader) else float('nan')
        print(f'Validation Loss: {avg_val_loss:.5f}')
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), 'bestmodel.pth')
            print(f"Saved new best model with validation loss: {best_val_loss:.5f}")
        if (epoch + 1) % 2 == 0:
            save_predictions(model, dataset, device, epoch + 1, num_examples=5)
    return best_val_loss
def test(model, test_loader, criterion, device='cuda'):
    """Evaluate ``model`` on ``test_loader`` with the masked reconstruction loss.

    Uses the same loss as training, ``criterion(outputs * mask, original * mask)``,
    averaged over batches, and prints it.

    Args:
        model: Network to evaluate; moved to ``device`` and set to eval mode.
        test_loader: Iterable of ``(masked_img, mask, original_img)`` batches.
        criterion: Loss function, e.g. ``nn.MSELoss()``.
        device: Device to run on.
    Returns:
        Mean test loss over batches, or ``nan`` for an empty loader.
    """
    model.to(device)
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for masked_img, mask, original_img in tqdm(test_loader, desc='Testing'):
            masked_img, mask, original_img = masked_img.to(device), mask.to(device), original_img.to(device)
            outputs = model(masked_img)
            mask = mask.expand_as(outputs)
            loss = criterion(outputs * mask, original_img * mask)
            test_loss += loss.item()
    avg_test_loss = test_loss / len(test_loader) if len(test_loader) else float('nan')
    print(f'Test Loss: {avg_test_loss:.5f}')
    return avg_test_loss

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
train_dataset = CIFAR100InpaintingDataset(folder='/kaggle/input/cifar100/cifar100/train', transform=transform)
test_dataset = CIFAR100InpaintingDataset(folder='/kaggle/input/cifar100/cifar100/test', transform=transform)
# Stratify on each image's parent directory name — presumably its class label
# (ImageFolder-style layout); os.path keeps this portable across path separators.
train_labels = [os.path.basename(os.path.dirname(path)) for path in train_dataset.samples]
# Explicit random_state makes the split identical every time this cell runs,
# instead of depending on how far the global NumPy RNG has advanced.
train_idx, val_idx = train_test_split(
    list(range(len(train_dataset))), test_size=0.20, stratify=train_labels, random_state=seed_value
)
train_set = Subset(train_dataset, train_idx)
val_set = Subset(train_dataset, val_idx)
print(f'Training set size: {len(train_set)}')
print(f'Validation set size: {len(val_set)}')
print(f'Test set size: {len(test_dataset)}')
batch_size = 64
# Don't spawn more loader workers than the machine has CPUs.
num_workers = min(8, os.cpu_count() or 1)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
model = EnhancedInpaintingResNet().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [None]:
# Visual sanity check: show 10 random (original, mask, masked) triplets from the training images.
debug_plot_samples(dataset=train_dataset, num_samples=10)

In [None]:
# Full training run; the best-validation checkpoint is written to bestmodel.pth.
train_and_validate(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    dataset=train_dataset,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=50,
    device=device,
)

In [None]:
# Reload the best checkpoint saved during training and score it on the test split.
# map_location places tensors on the active device; weights_only=True restricts
# unpickling to tensors and plain containers, which is all a state_dict needs.
model.load_state_dict(torch.load('bestmodel.pth', map_location=device, weights_only=True))
test(model, test_loader, criterion, device=device)