# Data Augmentation - Image Erasing
1. Hide and Seek -> divides the image into a grid of square cells and blacks out a random subset of them, masking several parts at the same time
2. CoarseDropout -> blacks out several randomly placed rectangular holes of bounded size
3. GridDropout -> blacks out the cells of a regular grid pattern, leaving the rest of the image visible
4. RandomErasing -> replaces one randomly selected rectangular region (random area and aspect ratio) with a black box

In [None]:
import torch
import shutil
import os
import numpy as np
import cv2
import random
import albumentations as A #this supposedly is much faster than using torchvision
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm

### Augmentation Pipeline (Balanced Erasing)

In [None]:
def hide_and_seek(image, **kwargs):
    """
    Less aggressive Hide-and-Seek occlusion technique without altering colors.

    The image is tiled into square cells whose side length is drawn once per
    call from ``grid_sizes``. Each cell is then independently set to 0 with
    probability ``hide_prob``. All other pixels are left untouched.

    Args:
        image: a NumPy array laid out as (H, W) or (H, W, C), which is the form
            ``A.Lambda`` passes in, or a torch.Tensor laid out as (C, H, W)
            or (H, W).
        **kwargs: extra parameters forwarded by ``A.Lambda``; ignored.

    Returns:
        A new array/tensor with the same type, dtype and shape as ``image``.
        The input itself is never modified.
    """
    grid_sizes = [12, 24, 36, 48]
    hide_prob = 0.2  # Reduced probability to erase smaller portions

    # Tensors are channel-first, so their spatial dims are the last two.
    # NumPy images are channel-last (or have no channel axis at all), so
    # their spatial dims are the first two. Indexing those dims directly
    # avoids transposing and also supports 2-D grayscale arrays.
    channel_first = isinstance(image, torch.Tensor)
    if channel_first:
        img = image.clone()
        h, w = img.shape[-2], img.shape[-1]
    else:
        img = np.array(image, copy=True)
        h, w = img.shape[0], img.shape[1]

    grid_size = random.choice(grid_sizes)

    for x in range(0, w, grid_size):
        for y in range(0, h, grid_size):
            if random.random() <= hide_prob:
                # Slices past the image edge are clamped automatically,
                # so smaller border cells need no explicit min().
                rows = slice(y, y + grid_size)
                cols = slice(x, x + grid_size)
                cell = (..., rows, cols) if channel_first else (rows, cols)
                img[cell] = 0

    return img

def get_random_augmentation():
    """
    Build a pipeline of one randomly chosen erasing transform followed by
    ``ToTensorV2``.

    Each candidate erases regions of the image rather than altering colours.
    The candidates are CoarseDropout, GridDropout, Hide-and-Seek (wrapped in
    ``A.Lambda``) and Erasing.

    Returns:
        An ``A.Compose`` whose first step is picked uniformly at random from
        the four candidates above.
    """
    # NOTE(review): max_holes/max_height/max_width/fill_value (CoarseDropout)
    # and value= (Erasing) are parameter names that differ between
    # albumentations releases. Confirm the installed version honours them
    # instead of ignoring them with a warning.
    coarse = A.CoarseDropout(max_holes=10, max_height=50, max_width=50, fill_value=0, p=1.0)
    grid = A.GridDropout(ratio=0.45, p=1.0)
    hide = A.Lambda(image=hide_and_seek, p=1.0)
    erase = A.Erasing(p=1.0, scale=(0.05, 0.5), ratio=(0.25, 4.0), value=0)

    chosen = random.choice((coarse, grid, hide, erase))
    steps = [chosen, ToTensorV2()]
    return A.Compose(steps)

### Use the pipeline to balance the dataset
### Idea: after augmentation, every category must contain exactly 2 × (size of the largest category) images

In [None]:
# Define dataset paths
input_dir = "../data/dataset_split/train"
output_dir = "../data/dataset_erasing/train"

# File extensions treated as images. Anything else in a class folder
# (e.g. .DS_Store, Thumbs.db, notes) is ignored rather than counted and
# later passed to cv2.imwrite as None.
IMAGE_EXTENSIONS = {".bmp", ".jpeg", ".jpg", ".png", ".tif", ".tiff", ".webp"}


def _list_images(folder):
    """Return the image file names in *folder*, sorted for a reproducible order."""
    return sorted(
        name
        for name in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, name))
        and os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
    )


class_counts = {}
for class_folder in sorted(os.listdir(input_dir)):
    class_path = os.path.join(input_dir, class_folder)
    if not os.path.isdir(class_path):
        continue
    class_counts[class_folder] = len(_list_images(class_path))

if not class_counts or max(class_counts.values()) == 0:
    raise ValueError(f"No class folders with images found in {input_dir!r}")

# Start from an empty output directory. This runs only after the input has
# been validated, so a bad input path does not wipe previous results.
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir, exist_ok=True)

change_factor = 2  # every class is filled up to change_factor x the largest class
max_class_size = max(class_counts.values())
new_target_size = max_class_size * change_factor

print(class_counts)
print(f"Max category is {max_class_size} of class {max(class_counts, key=class_counts.get)}")

for class_folder in tqdm(class_counts, desc="Balancing & Expanding Classes"):
    class_path = os.path.join(input_dir, class_folder)
    augmented_class_path = os.path.join(output_dir, class_folder)
    os.makedirs(augmented_class_path, exist_ok=True)

    # Copy the originals. They go through cv2.imread/imwrite (not a byte
    # copy) so they are stored in the same decoded 3-channel form that the
    # augmented images are produced from. Unreadable files are skipped.
    readable = []
    for img_name in _list_images(class_path):
        src_path = os.path.join(class_path, img_name)
        img = cv2.imread(src_path)
        if img is None:
            tqdm.write(f"Warning: skipping unreadable image {src_path}")
            continue
        cv2.imwrite(os.path.join(augmented_class_path, img_name), img)
        readable.append(img_name)

    if not readable:
        # Without this guard the top-up loop below could never make progress.
        tqdm.write(f"Warning: class {class_folder!r} has no readable images; nothing to augment")
        continue

    # Top the class up to new_target_size, cycling through its images.
    num_needed = new_target_size - len(readable)
    for i in range(num_needed):
        img_name = readable[i % len(readable)]
        image = cv2.imread(os.path.join(class_path, img_name))

        augmentation_pipeline = get_random_augmentation()
        augmented = augmentation_pipeline(image=image)["image"]

        if isinstance(augmented, torch.Tensor):
            # ToTensorV2 yields (C, H, W); cv2.imwrite expects (H, W, C) uint8.
            augmented = augmented.clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()

        # The suffix counts down from num_needed to 1, which keeps names
        # unique within the class.
        output_filename = f"{os.path.splitext(img_name)[0]}_aug_{num_needed - i}.png"
        output_path = os.path.join(augmented_class_path, output_filename)
        cv2.imwrite(output_path, augmented)

print(f"Dataset balanced & expanded! New images saved in {output_dir}")