# Dataset Preparation and Masking Tool

This notebook prepares a dataset for image reconstruction using a GAN model by applying random rectangular masks to images and organizing the dataset into train, validation, and test splits. The goal is to ensure that each split contains a balanced representation of each class.

## Overview
The notebook:
1. **Applies Random Rectangular Masks**: Adds a black rectangular mask at a random position on each image. The mask size is randomly chosen within a specified range.
2. **Organizes Dataset into Splits**: Splits the dataset into train, validation, and test sets, ensuring that each set contains a proportional distribution of each class (stratified sampling).
3. **Names Images Consistently**: Saves each original image and its corresponding masked version with a consistent naming convention, making it easy for the GAN model to pair original and masked images during training.
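
The naming convention in step 3 can be exploited by a data loader roughly as follows. This is a minimal sketch rather than part of the notebook: `find_image_pairs` is a hypothetical helper, and it assumes the `<class>_<index>.jpg` / `<class>_<index>_masked.jpg` layout that `save_image_pair` writes.

```python
import os


def find_image_pairs(split_dir):
    """Return (original_path, masked_path) tuples found in split_dir.

    Hypothetical helper: assumes every masked file is named
    '<base>_masked.jpg' and sits next to its original '<base>.jpg'.
    """
    suffix = "_masked.jpg"
    pairs = []
    for name in sorted(os.listdir(split_dir)):
        if not name.endswith(suffix):
            continue
        original = name[: -len(suffix)] + ".jpg"
        original_path = os.path.join(split_dir, original)
        # Keep only complete pairs; an orphaned masked file is skipped
        if os.path.isfile(original_path):
            pairs.append((original_path, os.path.join(split_dir, name)))
    return pairs
```

Pairing by filename keeps the loader independent of any separate index file, at the cost of relying on the `_masked` suffix staying unchanged.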

## Parameters
- `source_dir`: Path to the original dataset directory containing subdirectories for each class.
- `target_dir`: Path where the processed dataset will be saved, organized into train, validation, and test directories.
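- `min_mask_size` and `max_mask_size`: Range, in pixels, for the height and width of the rectangular mask. Each dimension is drawn independently.
- `test_split` and `val_split`: Fractions of the full dataset allocated to the test and validation sets, respectively. The remaining portion is used for training.
- `image_size`: Side length, in pixels, of the square to which every image is resized before masking.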
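
To make the split parameters concrete, the sketch below estimates how many images of one class land in each split. `expected_split_counts` is a hypothetical helper that uses simple rounding; the counts produced by `train_test_split` can differ by an image or so because it applies its own rounding.

```python
def expected_split_counts(n_images, test_split=0.2, val_split=0.1):
    """Approximate (train, validation, test) counts for one class.

    Assumption: both fractions are taken from the full class, and the
    remainder goes to training. Rounding here is illustrative only.
    """
    if not 0 < test_split + val_split < 1:
        raise ValueError("test_split + val_split must be between 0 and 1")
    n_test = round(n_images * test_split)
    n_val = round(n_images * val_split)
    n_train = n_images - n_test - n_val
    return n_train, n_val, n_test


print(expected_split_counts(100))   # (70, 10, 20)
print(expected_split_counts(1300))  # (910, 130, 260)
```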

## Execution
After setting the parameters, run the cells in order; the final cell processes the dataset. This will:
- Apply masks to each image in the specified directory.
- Save both the original and masked images in the train, validation, and test directories within `target_dir`.
- Ensure each split has a balanced mix of classes and a consistent naming scheme.

The final dataset will be structured and ready for training in an image reconstruction task using a GAN model.

### Imports

In [7]:
import os
import random
from PIL import Image, ImageDraw
from sklearn.model_selection import train_test_split

### Set Parameters

In [8]:
# Use the current working directory of the Jupyter notebook
current_dir = os.getcwd()

# Define source and target directories relative to the working directory
source_dir = os.path.join(current_dir, "imagenet_selected_raw_classes")
target_dir = os.path.join(current_dir, "..", "Dataset")  # One level up; created during processing if missing

min_mask_size = 30                        # Minimum mask dimension in pixels
max_mask_size = 60                        # Maximum mask dimension in pixels

# Split ratios
test_split = 0.2                          # Fraction of the full dataset for the test set
val_split = 0.1                           # Fraction of the full dataset for the validation set

image_size = 224                          # Size of the images for the GAN model

## Dataset Processing with Stratified Split and Random Rectangle Mask

This section processes a dataset by creating train, validation, and test splits that evenly represent each class.
A random rectangular mask is applied to each image, saving both the original and masked versions in the target directory.
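
The idea behind the per-class split can be sketched without scikit-learn. The snippet below is an illustration only: `split_per_class` is a hypothetical stand-in that shuffles with Python's `random` module, whereas the notebook itself calls `train_test_split`.

```python
import random


def split_per_class(paths_by_class, test_split=0.2, val_split=0.1, seed=42):
    """Split each class separately so every split keeps the class proportions.

    paths_by_class maps a class name to a list of image paths.
    Returns three lists of (path, class_name) tuples: train, val, test.
    """
    rng = random.Random(seed)
    train, val, test = [], [], []
    for class_name in sorted(paths_by_class):
        paths = sorted(paths_by_class[class_name])
        rng.shuffle(paths)
        n_test = round(len(paths) * test_split)
        n_val = round(len(paths) * val_split)
        test += [(p, class_name) for p in paths[:n_test]]
        val += [(p, class_name) for p in paths[n_test:n_test + n_val]]
        train += [(p, class_name) for p in paths[n_test + n_val:]]
    return train, val, test


# Two toy classes of different sizes; the file names are made up
toy = {
    "cat": [f"cat/{i}.jpg" for i in range(50)],
    "dog": [f"dog/{i}.jpg" for i in range(100)],
}
train, val, test = split_per_class(toy)
print(len(train), len(val), len(test))  # 105 15 30
```

Because each class is divided on its own, the smaller class contributes half as many images as the larger one to every split, which is the property stratification is meant to preserve.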

### Masking Function

In [9]:
def apply_random_rectangle_mask(image, min_mask_size=30, max_mask_size=60):
    """
    Applies a black rectangular mask of random size within a specified range at a random location on an image.
    
    Parameters:
    - image (PIL.Image.Image): The image object to be masked.
    - min_mask_size (int): Minimum dimension for the rectangular mask (height and width).
    - max_mask_size (int): Maximum dimension for the rectangular mask (height and width).
    
    Returns:
    - masked_image (PIL.Image.Image): Image with a black rectangular mask applied.
    """
    
    # Create a copy of the image to avoid modifying the original
    masked_image = image.copy()
    width, height = masked_image.size

    # Randomly determine the rectangle width and height, clamped to the image size
    rect_width = min(random.randint(min_mask_size, max_mask_size), width)
    rect_height = min(random.randint(min_mask_size, max_mask_size), height)

    # Pick the top-left corner so the whole rectangle fits inside the image
    x = random.randint(0, width - rect_width)
    y = random.randint(0, height - rect_height)

    # Draw the black rectangle mask. ImageDraw treats both corners as inclusive,
    # so subtract 1 to cover exactly rect_width x rect_height pixels.
    draw = ImageDraw.Draw(masked_image)
    draw.rectangle([x, y, x + rect_width - 1, y + rect_height - 1], fill="black")
    
    return masked_image
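
Pillow's `ImageDraw.rectangle` treats both corner points as part of the rectangle. The sketch below, which uses a synthetic white image rather than a dataset file, recovers the masked region from an in-memory pair with `ImageChops.difference` and prints its extent. The image size and coordinates are arbitrary values chosen for illustration.

```python
from PIL import Image, ImageChops, ImageDraw

# Synthetic stand-in for a resized dataset image (assumption: 224 x 224 RGB)
original = Image.new("RGB", (224, 224), "white")

# Mask a known region so the recovered box can be checked by eye
masked = original.copy()
ImageDraw.Draw(masked).rectangle([10, 20, 40, 50], fill="black")

# Pixels that differ between the two images are exactly the masked ones
diff = ImageChops.difference(original, masked)
left, upper, right, lower = diff.getbbox()  # right/lower are exclusive

print((left, upper, right, lower))   # (10, 20, 41, 51)
print(right - left, lower - upper)   # 31 31
```

A rectangle drawn from `(10, 20)` to `(40, 50)` therefore spans 31 pixels per side, one more than the coordinate difference. The check works on in-memory images; files saved as JPEG are lossy, so comparing saved pairs may show small differences outside the mask.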

### Dataset Processing with Stratified Splits

In [10]:
def process_dataset_with_stratified_splits(source_dir, target_dir, min_mask_size=30, max_mask_size=60, test_split=0.2, val_split=0.1, image_size=224):
    """
    Processes a dataset by applying a black rectangular mask to each image, renaming them by class,
    and organizing them into stratified train, validation, and test splits.
    
    Parameters:
    - source_dir (str): Directory containing the original dataset with subdirectories for each class.
    - target_dir (str): Directory to save the processed dataset.
    - min_mask_size (int): Minimum dimension for the rectangular mask (height and width).
    - max_mask_size (int): Maximum dimension for the rectangular mask (height and width).
    - test_split (float): Fraction of each class to reserve for testing.
    - val_split (float): Fraction of each class to reserve for validation (of the full class, not of the remainder).
    - image_size (int): Side length in pixels to which each image is resized.
    """
    # Create directories for train, validation, and test
    train_dir = os.path.join(target_dir, "train")
    val_dir = os.path.join(target_dir, "validation")
    test_dir = os.path.join(target_dir, "test")
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    # Collect all images by class
    train_images, val_images, test_images = [], [], []

    for class_name in os.listdir(source_dir):
        class_dir = os.path.join(source_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # Skip if it's not a directory

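        # Gather all image paths in the class subdirectory. The extension check is
        # case-insensitive (ImageNet files typically end in '.JPEG'), and the list is
        # sorted so the seeded split is reproducible across machines.
        image_paths = sorted(
            os.path.join(class_dir, f)
            for f in os.listdir(class_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )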

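        # Split this class on its own so every split keeps the class proportions.
        # First hold out (test_split + val_split) of the images, then divide that
        # hold-out set so the test set receives test_split / (test_split + val_split) of it.
        train_paths, temp_paths = train_test_split(image_paths, test_size=(test_split + val_split), random_state=42)
        val_paths, test_paths = train_test_split(temp_paths, test_size=test_split / (test_split + val_split), random_state=42)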

        # Append to overall lists with a consistent naming format
        train_images += [(path, class_name, i) for i, path in enumerate(train_paths)]
        val_images += [(path, class_name, i) for i, path in enumerate(val_paths)]
        test_images += [(path, class_name, i) for i, path in enumerate(test_paths)]
    
    # Process and save images in each split, forwarding image_size so the
    # configured value is used instead of the default
    for image_path, class_name, index in train_images:
        save_image_pair(image_path, train_dir, class_name, index, min_mask_size, max_mask_size, image_size)

    for image_path, class_name, index in val_images:
        save_image_pair(image_path, val_dir, class_name, index, min_mask_size, max_mask_size, image_size)

    for image_path, class_name, index in test_images:
        save_image_pair(image_path, test_dir, class_name, index, min_mask_size, max_mask_size, image_size)

### Save Image Pairs Function

In [11]:
def save_image_pair(image_path, save_dir, class_name, index, min_mask_size=30, max_mask_size=60, image_size=224):
    """
    Saves the original and masked versions of an image with a consistent naming convention, resizing each to a specified size.
    
    Parameters:
    - image_path (str): Path to the original image file.
    - save_dir (str): Directory where the processed images will be saved.
    - class_name (str): Name of the class to prefix the filenames.
    - index (int): Index of the image in its class for unique naming.
    - min_mask_size (int): Minimum size of the mask.
    - max_mask_size (int): Maximum size of the mask.
    - image_size (int): Size to which the image will be resized (width and height).
    """
    # Define unique names for the original and masked images
    base_filename = f"{class_name}_{index:03d}"
    original_save_path = os.path.join(save_dir, f"{base_filename}.jpg")
    masked_save_path = os.path.join(save_dir, f"{base_filename}_masked.jpg")
    
    # Open and resize the original image
    original_image = Image.open(image_path).convert("RGB")
    original_image = original_image.resize((image_size, image_size))
    original_image.save(original_save_path)
    
    # Apply mask and save masked image
    masked_image = apply_random_rectangle_mask(original_image, min_mask_size, max_mask_size)
    masked_image.save(masked_save_path)

### Run Dataset Processing

In [12]:
# Run the dataset processing with the specified parameters
process_dataset_with_stratified_splits(
    source_dir=source_dir,
    target_dir=target_dir,
    min_mask_size=min_mask_size,
    max_mask_size=max_mask_size,
    test_split=test_split,
    val_split=val_split,
    image_size=image_size
)
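
Once the cell above has finished, the class balance of each split can be checked with a short script along these lines. `count_classes` is a hypothetical helper, and it assumes the `<class>_<index>.jpg` naming written by `save_image_pair`.

```python
import os
from collections import Counter


def count_classes(split_dir):
    """Count original (unmasked) images per class in one split directory."""
    counts = Counter()
    for name in os.listdir(split_dir):
        stem, ext = os.path.splitext(name)
        if ext.lower() != ".jpg" or stem.endswith("_masked"):
            continue
        # Class names may contain underscores, so split off only the index
        class_name, _, _index = stem.rpartition("_")
        counts[class_name] += 1
    return counts
```

For example, `count_classes(os.path.join(target_dir, "train"))` returns a `Counter` keyed by class name. With the default parameters, comparing the three splits should show roughly a 70/10/20 ratio for every class.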