# A Modern Reimplementation of "Context Encoders: Feature Learning by Inpainting" using PyTorch



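This notebook prepares masked training images (a central 64x64 mask for 128x128 images and a random rectangular mask for 227x227 images) and then trains a context-encoder generator against a discriminator, using a weighted sum of a reconstruction (L2) loss and an adversarial (BCE) loss.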

# Part I. Data Preparation

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

#dependencies for importing/reshaping data

import os
import zipfile
from PIL import Image, ImageDraw

import matplotlib.pyplot as plt

import random

In [None]:
def apply_central_mask_with_block(img, mask_size=64):
    """
    Black out a centred square of an image and return the pixels that were hidden.

    The square has side ``mask_size`` and is centred with integer division, so
    for the intended 128x128 input and the default size it spans
    (32, 32, 96, 96).

    Args:
      img: image exposing ``size``, ``crop``, ``copy`` and ``paste``
           (a PIL image in this notebook).
      mask_size: side length, in pixels, of the square to hide.

    Returns:
      masked_img: copy of ``img`` with the central square filled with black.
      block_img: the central square cropped from the untouched ``img``.
      coords: (left, top, right, bottom) of the hidden square.
    """
    img_w, img_h = img.size
    x0 = (img_w - mask_size) // 2
    y0 = (img_h - mask_size) // 2
    box = (x0, y0, x0 + mask_size, y0 + mask_size)

    # Take the ground-truth block first, then blank the same box on a copy so
    # the caller's image is left untouched.
    block_img = img.crop(box)
    masked_img = img.copy()
    masked_img.paste((0, 0, 0), box)

    return masked_img, block_img, box

def apply_random_region_mask_with_block(img, dropout_fraction=0.25):
    """
    Black out one randomly placed rectangle of an image.

    The rectangle's width is drawn uniformly from [20, width // 2]. Its height
    is then ``target_area / width`` (with ``target_area`` being
    ``dropout_fraction`` of the image area), capped at ``height // 2``.
    Because of that cap the covered area can be smaller than requested: with
    the defaults on a 227x227 image the cap always applies, so the mask is
    113 px tall and covers roughly 4% to 25% of the image.

    Images narrower than 40 px, for which [20, width // 2] would be an empty
    range, use a width of ``width // 2`` (at least 1 px) instead of failing
    inside ``random.randint``.

    Args:
      img: image exposing ``size``, ``crop``, ``copy`` and ``paste``
           (a PIL image in this notebook; 227x227 is the intended size).
      dropout_fraction: fraction of the image area the mask aims for.

    Returns:
      - masked_img: copy of ``img`` with the random region set to black.
      - block_img: the region cropped from the untouched ``img``.
      - coords: (left, top, right, bottom) of the masked region.

    Raises:
      ValueError: if the image has zero width or height.
    """
    width, height = img.size  # Expected to be 227x227
    if width < 1 or height < 1:
        raise ValueError(f"cannot mask an empty image of size {img.size}")

    target_area = width * height * dropout_fraction

    # Neither side may exceed half of the image; keep the bounds at least 1 px
    # and never let the lower width bound exceed the upper one.
    max_w = max(1, width // 2)
    max_h = max(1, height // 2)
    min_w = min(20, max_w)

    # The order of the random draws (width, left, top) is kept as before so
    # seeded runs on regular-sized images are unchanged.
    w = random.randint(min_w, max_w)
    h = min(int(target_area / w), max_h)

    left = random.randint(0, width - w)
    top = random.randint(0, height - h)
    box = (left, top, left + w, top + h)

    block_img = img.crop(box)
    masked_img = img.copy()
    masked_img.paste((0, 0, 0), box)

    return masked_img, block_img, box


In [None]:
def apply_masks_to_resized_images(root_dir, target_sizes=None):
    """
    Mask every resized image under ``root_dir`` and save the results.

    For each ``size_label`` in ``target_sizes`` the images below
    ``root_dir/<size_label>`` are walked recursively and the matching strategy
    is applied:
      - '128x128': central 64x64 mask; the masked image goes to
        ``root_dir/128x128_masked`` and the extracted block to
        ``root_dir/128x128_block``.
      - '227x227': random region mask (see
        ``apply_random_region_mask_with_block``); only the masked image is
        saved, to ``root_dir/227x227_masked``.
    The sub-directory layout of the input folder is mirrored in the outputs.
    Only .png/.jpg/.jpeg files are considered; files whose size differs from
    the expected one are skipped with a message, and a failure on one file is
    printed without aborting the run.

    Args:
      root_dir: directory containing one folder per size label.
      target_sizes: mapping of size label -> (width, height). Defaults to
        {'128x128': (128, 128), '227x227': (227, 227)}.
    """
    # A None default replaces the former dict literal so no mutable object is
    # shared between calls.
    if target_sizes is None:
        target_sizes = {'128x128': (128, 128), '227x227': (227, 227)}

    for size_label, dims in target_sizes.items():
        # PIL reports sizes as tuples; normalise so a list such as [128, 128]
        # still compares equal.
        expected_size = tuple(dims)

        # Input directory (resized images)
        input_dir = os.path.join(root_dir, size_label)
        # Output directory for masked images
        output_masked_base = os.path.join(root_dir, f"{size_label}_masked")
        os.makedirs(output_masked_base, exist_ok=True)
        print(f"Processing {size_label} images from {input_dir} -> saving masked images to {output_masked_base}")

        # Only the 128x128 set also stores the extracted central block.
        save_blocks = size_label == '128x128'
        output_block_base = os.path.join(root_dir, f"{size_label}_block")
        if save_blocks:
            os.makedirs(output_block_base, exist_ok=True)

        # Walk through the input directory recursively
        for current_dir, _subdirs, files in os.walk(input_dir):
            rel_dir = os.path.relpath(current_dir, input_dir)
            output_masked_subdir = os.path.join(output_masked_base, rel_dir)
            os.makedirs(output_masked_subdir, exist_ok=True)

            output_block_subdir = os.path.join(output_block_base, rel_dir)
            if save_blocks:
                os.makedirs(output_block_subdir, exist_ok=True)

            for file in files:
                if not file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    continue

                file_path = os.path.join(current_dir, file)
                # Deliberately broad: one unreadable or odd file must not
                # abort the whole preprocessing run.
                try:
                    with Image.open(file_path) as source:
                        img = source.convert('RGB')
                        if img.size != expected_size:
                            print(f"Skipping {file_path}: size {img.size} does not match expected {dims}")
                            continue

                        if size_label == '128x128':
                            # Apply central mask for 128x128 images
                            masked_img, block_img, _ = apply_central_mask_with_block(img, mask_size=64)
                            masked_img.save(os.path.join(output_masked_subdir, file))
                            block_img.save(os.path.join(output_block_subdir, file))
                        elif size_label == '227x227':
                            # Apply random region mask for 227x227 images
                            masked_img, _, _ = apply_random_region_mask_with_block(img, dropout_fraction=0.25)
                            masked_img.save(os.path.join(output_masked_subdir, file))
                except Exception as e:
                    print(f"Error processing {file_path}: {e}")

# Example usage: run the masking pass over the resized images stored under ./data.
base_dir = './data'
apply_masks_to_resized_images(
    base_dir,
    target_sizes={
        '128x128': (128, 128),
        '227x227': (227, 227),
    },
)

# Part II. Training the model

In [None]:
from dataset_module import InpaintingDataset

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device:", device)

# Training hyperparameters
num_epochs, batch_size = 50, 32
lr_gen = 1e-4    # learning rate of the generator's optimizer
lr_disc = 1e-5   # learning rate of the discriminator's optimizer
lambda_rec = 0.999  # weight of the reconstruction loss in the generator objective
lambda_adv = 0.001  # weight of the adversarial loss in the generator objective

# Folder holding the unmasked 128x128 training images.
dataset_dir = './data/128x128'

# Preprocessing applied to every image: force a 128x128 size, then convert to a tensor.
train_transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ]
)

# Dataset and shuffled mini-batch loader.
# NOTE(review): InpaintingDataset is defined in dataset_module (not shown here); it is
# described as applying the central mask on the fly - confirm against that module.
train_dataset = InpaintingDataset(transform=train_transform, root_dir=dataset_dir)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)

# Import models from your models directory
from models.FeatureLearner import ContextEncoder
from models.SemanticInpainter import Discriminator

# Build both networks on the selected device (generator first, then discriminator).
generator = ContextEncoder().to(device)
discriminator = Discriminator().to(device)

# One Adam optimizer per network; both use betas=(0.5, 0.999) and differ only in learning rate.
optimizer_G = optim.Adam(
    generator.parameters(),
    lr=lr_gen,
    betas=(0.5, 0.999),
)
optimizer_D = optim.Adam(
    discriminator.parameters(),
    lr=lr_disc,
    betas=(0.5, 0.999),
)

# Loss functions
criterion_adv = nn.BCELoss()   # adversarial term: binary cross-entropy on discriminator outputs
criterion_rec = nn.MSELoss()   # reconstruction term: L2 distance to the ground-truth block

# Helper function to visualize inpainting results
def visualize_inpainting(masked, generated, ground_truth, idx=0):
    """
    Show one sample of a batch as three side-by-side panels.

    Args:
      masked: batch of masked input images (tensor indexed by sample).
      generated: batch of generator outputs.
      ground_truth: batch of ground-truth blocks.
      idx: index of the sample within the batch to display.
    """
    to_pil = transforms.ToPILImage()

    gen_tensor = generated[idx].detach().cpu()
    # ToPILImage scales float tensors by 255 and casts to bytes, so values
    # outside [0, 1] would wrap around into garbage colours. Clamp for display
    # only; this is a no-op for outputs already in range.
    if gen_tensor.is_floating_point():
        gen_tensor = gen_tensor.clamp(0, 1)

    panels = (
        (to_pil(masked[idx].detach().cpu()), "Masked Input"),
        (to_pil(gen_tensor), "Generated Output"),
        (to_pil(ground_truth[idx].detach().cpu()), "Ground Truth Block"),
    )

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    for ax, (image, title) in zip(axs, panels):
        ax.imshow(image)
        ax.set_title(title)
    plt.show()
    # Release the figure so repeated calls during training do not pile up
    # open figures on backends where show() does not close them.
    plt.close(fig)

# Training loop: alternate one discriminator step and one generator step per batch.
print("Starting training...")
for epoch in range(num_epochs):
    for i, (masked_imgs, blocks) in enumerate(train_loader):
        masked_imgs = masked_imgs.to(device)
        blocks = blocks.to(device)

        # --- Update Discriminator ---
        discriminator.zero_grad()

        # Real blocks should be classified as 1.
        # (*_like already creates the labels on the same device as the outputs.)
        outputs_real = discriminator(blocks)
        d_loss_real = criterion_adv(outputs_real, torch.ones_like(outputs_real))

        # Generated blocks should be classified as 0. The generator runs once
        # per batch; detach() keeps the discriminator's backward pass out of
        # the generator's graph.
        fake_blocks = generator(masked_imgs)
        outputs_fake = discriminator(fake_blocks.detach())
        d_loss_fake = criterion_adv(outputs_fake, torch.zeros_like(outputs_fake))

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # --- Update Generator ---
        generator.zero_grad()
        # Reuse fake_blocks instead of a second forward pass: the generator's
        # weights have not changed since it was computed and its graph is
        # still intact, because only a detached copy was used above.
        outputs_gen = discriminator(fake_blocks)
        adv_loss = criterion_adv(outputs_gen, torch.ones_like(outputs_gen))

        # Bring the generated block to the ground truth's spatial size so the
        # reconstruction loss compares tensors of the same shape.
        fake_blocks_resized = torch.nn.functional.interpolate(
            fake_blocks, size=blocks.shape[2:], mode='bilinear', align_corners=False
        )
        rec_loss = criterion_rec(fake_blocks_resized, blocks)

        g_loss = lambda_rec * rec_loss + lambda_adv * adv_loss
        g_loss.backward()
        optimizer_G.step()

        # Periodic progress report with a visual sample.
        if i % 50 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{i}/{len(train_loader)}] "
                  f"d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, "
                  f"rec_loss: {rec_loss.item():.4f}, adv_loss: {adv_loss.item():.4f}")
            visualize_inpainting(masked_imgs, fake_blocks, blocks)

    # Save model checkpoints after each epoch
    torch.save(generator.state_dict(), f"generator_epoch_{epoch+1}.pth")
    torch.save(discriminator.state_dict(), f"discriminator_epoch_{epoch+1}.pth")

print("Training complete!")


# Part III. Evaluating the model