In [1]:
import os
import shutil

import cv2
import numpy as np
from tqdm import tqdm

In [2]:
# Root folder that holds the un-augmented datasets. The default is the original
# author's local path; set the DR_DATA_ROOT environment variable to point the
# notebook at another location without editing this cell.
DATA_ROOT = os.environ.get(
    'DR_DATA_ROOT',
    '/Users/abhiruppaul/Abhirup/DCU/Practicum/dataset_not_aug',
)

# Maps dataset name -> {'input': folder with class sub-folders,
#                       'output': folder that receives originals + augmentations}
datasets = {
    # 'drc' is currently disabled; uncomment this entry to process it as well.
    # 'drc': {
    #     'input': os.path.join(DATA_ROOT, 'drc', 'training'),
    #     'output': os.path.join(DATA_ROOT, 'augmented', 'drc_aug'),
    # },
    'aptos': {
        'input': os.path.join(DATA_ROOT, 'aptos', 'training'),
        'output': os.path.join(DATA_ROOT, 'augmented', 'aptos_aug_2'),
    },
}


# Class folder names: class_0 to class_4
classes = [f'class_{i}' for i in range(5)]


In [3]:
# --- Augmentation functions ---
def horizontal_flip(image):
    """Flip `image` with ``cv2.flip(image, 1)`` (left-right mirror).

    Returns:
        A ``(flipped_image, '_hf')`` tuple; the second item is the tag that
        callers append to the output filename.
    """
    mirrored = cv2.flip(image, 1)
    return mirrored, '_hf'

def vertical_flip(image):
    """Flip `image` with ``cv2.flip(image, 0)`` (top-bottom mirror).

    Returns:
        A ``(flipped_image, '_vf')`` tuple; the second item is the tag that
        callers append to the output filename.
    """
    upside_down = cv2.flip(image, 0)
    return upside_down, '_vf'

def adjust_brightness(image, factor=1.2):
    """Scale the brightness (HSV value channel) of an RGB image by `factor`.

    Args:
        image: Image in RGB channel order (it is converted with
            ``cv2.COLOR_RGB2HSV``). The 0-255 clip assumes 8-bit data.
        factor: Multiplier applied to the V channel; may be an int or a float.

    Returns:
        ``(image_out, suffix)`` where `image_out` is the adjusted RGB image and
        `suffix` is ``"_b"`` for the default factor 1.2, otherwise ``"_b"``
        followed by ``int(factor * 10)`` (e.g. ``"_b13"`` for 1.3).
    """
    hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
    # Promote to float before multiplying: with an integer `factor` the product
    # of a uint8 array would stay uint8 and wrap around modulo 256 *before* the
    # clip could saturate it. For float factors the result is unchanged.
    scaled = hsv[..., 2].astype(np.float64) * factor
    hsv[..., 2] = np.clip(scaled, 0, 255)
    image_out = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
    suffix = f"_b{int(factor*10)}" if factor != 1.2 else "_b"
    return image_out, suffix

def adjust_saturation(image, factor=1.2):
    """Scale the saturation (HSV S channel) of an RGB image by `factor`.

    Args:
        image: Image in RGB channel order (it is converted with
            ``cv2.COLOR_RGB2HSV``). The 0-255 clip assumes 8-bit data.
        factor: Multiplier applied to the S channel; may be an int or a float.

    Returns:
        ``(image_out, suffix)`` where `image_out` is the adjusted RGB image and
        `suffix` is ``"_s"`` for the default factor 1.2, otherwise ``"_s"``
        followed by ``int(factor * 10)`` (e.g. ``"_s13"`` for 1.3).
    """
    hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
    # Promote to float before multiplying: with an integer `factor` the product
    # of a uint8 array would stay uint8 and wrap around modulo 256 *before* the
    # clip could saturate it. For float factors the result is unchanged.
    scaled = hsv[..., 1].astype(np.float64) * factor
    hsv[..., 1] = np.clip(scaled, 0, 255)
    image_out = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
    suffix = f"_s{int(factor*10)}" if factor != 1.2 else "_s"
    return image_out, suffix

In [4]:
# --- Augmentation strategy ---
def augment_image(image, count):
    """Build up to `count` augmented variants of `image`, in a fixed order.

    The first four variants are: horizontal flip, vertical flip, saturation
    x1.2 and brightness x1.2. Every further variant alternates saturation and
    brightness, with the factor growing by 0.1 per pair (1.3, 1.3, 1.4, 1.4, ...).

    Args:
        image: Source image passed unchanged to each augmentation function.
        count: Number of variants requested; values below 1 yield an empty list.

    Returns:
        A list of ``(augmented_image, suffix)`` tuples.
    """
    # The fixed opening sequence; only the first `count` entries are applied.
    opening_moves = (
        lambda src: horizontal_flip(src),
        lambda src: vertical_flip(src),
        lambda src: adjust_saturation(src, 1.2),
        lambda src: adjust_brightness(src, 1.2),
    )

    results = []
    for make_variant in opening_moves[:max(count, 0)]:
        variant, tag = make_variant(image)
        results.append((variant, tag))

    # Beyond the fourth variant: even slots tweak saturation, odd slots tweak
    # brightness; both members of a pair share the same factor.
    for slot in range(4, count):
        strength = 1.2 + 0.1 * ((slot - 4) // 2 + 1)
        transform = adjust_saturation if slot % 2 == 0 else adjust_brightness
        variant, tag = transform(image, strength)
        results.append((variant, tag))

    return results

In [5]:
# --- Main processing function ---
def _list_image_names(directory):
    """Return the sorted .jpg/.jpeg/.png file names found directly in `directory`."""
    # Sorted because os.listdir() returns entries in arbitrary order, which
    # would make the choice of augmented images differ from run to run.
    return sorted(
        name for name in os.listdir(directory)
        if name.lower().endswith(('.jpg', '.jpeg', '.png'))
    )


def _generate_augmentations(input_class_dir, output_class_dir, img_names,
                            num_to_generate, max_aug_per_image):
    """Write up to `num_to_generate` augmented PNGs, spread evenly over `img_names`.

    Returns the number of files actually written.
    """
    candidates = list(img_names)
    aug_index = 0  # number of augmented files written so far (also used in names)

    while aug_index < num_to_generate and candidates:
        # Split what is still missing evenly: the first `extra` images get one
        # more augmentation than the rest, and nobody exceeds the per-pass cap.
        base, extra = divmod(num_to_generate - aug_index, len(candidates))
        unreadable = set()
        wrote_something = False

        for position, img_name in enumerate(candidates):
            quota = min(max_aug_per_image, base + (1 if position < extra else 0))
            if quota <= 0:
                break  # every later image has an equal or smaller share

            input_img_path = os.path.join(input_class_dir, img_name)
            image = cv2.imread(input_img_path)
            if image is None:
                # cv2.imread signals failure by returning None instead of raising.
                print(f"Warning: could not read {input_img_path}; skipping it")
                unreadable.add(img_name)
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            stem = os.path.splitext(img_name)[0]
            for aug_img, suffix in augment_image(image, count=quota):
                if aug_index >= num_to_generate:
                    break
                aug_filename = f"{stem}_aug{suffix}_{aug_index + 1}.png"
                save_path = os.path.join(output_class_dir, aug_filename)
                if not cv2.imwrite(save_path, cv2.cvtColor(aug_img, cv2.COLOR_RGB2BGR)):
                    raise OSError(f"Failed to write augmented image: {save_path}")
                aug_index += 1
                wrote_something = True

        if unreadable:
            # Redistribute the shortfall over the readable images on the next pass.
            candidates = [name for name in candidates if name not in unreadable]
        elif not wrote_something:
            break  # nothing can be produced (e.g. max_aug_per_image <= 0)

    return aug_index


def process_dataset(dataset_name, max_aug_per_image=10):
    """Copy a dataset to its output folder and oversample the smaller classes.

    For every class in `classes`, the original images are copied unchanged to
    the output folder. Classes with fewer images than the largest class are
    then topped up with augmented PNGs until they reach the same size.

    Augmentations are spread evenly over a class's images. (Previously the
    output counter doubled as the image selector and advanced ten at a time,
    so most images were skipped while others were augmented twice with
    identical results.)

    Args:
        dataset_name: Key into the module-level `datasets` mapping.
        max_aug_per_image: Upper bound on augmentations taken from one image
            per pass. If a class needs more than
            ``max_aug_per_image * number_of_images`` files, further passes
            repeat earlier augmentations under new file names.

    Raises:
        KeyError: If `dataset_name` is not configured in `datasets`.
        FileNotFoundError: If a class folder is missing from the input folder.
        OSError: If an augmented image cannot be written.
    """
    input_dir = datasets[dataset_name]['input']
    output_dir = datasets[dataset_name]['output']

    # Step 1: List every class once and find the size to balance towards
    class_images = {
        class_name: _list_image_names(os.path.join(input_dir, class_name))
        for class_name in classes
    }
    max_class_size = max((len(names) for names in class_images.values()), default=0)

    # Step 2: Copy originals; augment only underrepresented classes
    for class_name in classes:
        input_class_dir = os.path.join(input_dir, class_name)
        output_class_dir = os.path.join(output_dir, class_name)
        os.makedirs(output_class_dir, exist_ok=True)

        img_names = class_images[class_name]
        current_count = len(img_names)
        required_total = max_class_size

        if current_count == 0:
            continue  # Skip empty class

        # Progress bar
        print(f"\nProcessing class: {class_name} ({current_count} → {required_total})")
        for img_name in tqdm(img_names, desc=f'{class_name}', leave=False):
            # Plain file copy: decoding and re-encoding would recompress JPEGs
            # (losing quality) for no benefit.
            shutil.copy2(
                os.path.join(input_class_dir, img_name),
                os.path.join(output_class_dir, img_name),
            )

        if current_count >= required_total:
            continue  # No augmentation needed

        # Step 3: Augmentation to balance class
        num_to_generate = required_total - current_count
        if num_to_generate > max_aug_per_image * current_count:
            print(
                f"Warning: {class_name} needs more than {max_aug_per_image} "
                f"augmentations per image; some augmented images will repeat"
            )

        generated = _generate_augmentations(
            input_class_dir, output_class_dir, img_names,
            num_to_generate, max_aug_per_image,
        )
        if generated < num_to_generate:
            print(
                f"Warning: {class_name} is still short by "
                f"{num_to_generate - generated} image(s)"
            )

In [6]:
# --- Run every dataset configured in `datasets` ---
for dataset in datasets:
    # One print call, newline-separated, emits the same three banner lines.
    print(
        "\n====================",
        f"Processing dataset: {dataset}",
        "====================",
        sep="\n",
    )
    process_dataset(dataset)


Processing dataset: aptos

Processing class: class_0 (1444 → 1444)


                                                             


Processing class: class_1 (296 → 1444)


                                                           


Processing class: class_2 (799 → 1444)


                                                           


Processing class: class_3 (154 → 1444)


                                                           


Processing class: class_4 (236 → 1444)


                                                           