In [7]:
import os
import random

import torch
from PIL import Image
from torchvision import transforms

In [8]:
# Class folder excluded from balancing: its images are neither augmented nor deleted.
ignore_folder: str = "cat11 L_ShopMall"

In [9]:
# Directory whose sub-folders are treated as classes; relative to the kernel's
# working directory. The balancing step writes into / deletes from it in place.
dataset_path: str = "validation"

In [10]:
# Augmentation used to create extra images for under-represented classes.
# Every output is mirrored left-right (the flip probability is 1.0), then its
# brightness/contrast/saturation/hue are randomly jittered.
augmentation_transform = transforms.Compose(
    transforms=[
        transforms.RandomHorizontalFlip(p=1.0),  # always mirror the source image
        transforms.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.1,
        ),
    ]
)

In [11]:
# Number of images each non-ignored class folder should contain after balancing.
target_count: int = 105

In [12]:
# Balance every class folder under `dataset_path` to exactly `target_count` images.
# Under-represented classes are topped up with augmented copies of randomly chosen
# existing images; over-represented classes have randomly chosen images DELETED.
# WARNING: this rewrites the dataset on disk in place; deletions cannot be undone.

BALANCE_SEED = 42  # fixes which files are sampled/deleted and the augmentation draws
# Only files with these extensions count as images. Hidden files (e.g. `.DS_Store`),
# sub-folders and other files are never counted, augmented or deleted.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tif", ".tiff", ".webp")


def _list_images(class_path):
    """Return the sorted names of the image files directly inside ``class_path``.

    Sorting removes the dependence on ``os.listdir`` order, so the seeded random
    choices made while balancing pick the same files on every run.
    """
    return sorted(
        name
        for name in os.listdir(class_path)
        if not name.startswith(".")
        and name.lower().endswith(IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(class_path, name))
    )


def _unused_augmented_path(class_path, index):
    """Return the first non-existing path ``augmented_<k>.jpg`` in ``class_path`` with k >= index.

    Prevents silently overwriting augmented files left by an earlier run, which
    would otherwise leave the class short of ``target_count``.
    """
    while True:
        candidate = os.path.join(class_path, f"augmented_{index}.jpg")
        if not os.path.exists(candidate):
            return candidate
        index += 1


balance_rng = random.Random(BALANCE_SEED)
# RandomHorizontalFlip / ColorJitter draw their randomness from torch, not `random`.
torch.manual_seed(BALANCE_SEED)

for class_name in sorted(os.listdir(dataset_path)):
    class_path = os.path.join(dataset_path, class_name)

    # Skip non-directories and the folder to ignore
    if not os.path.isdir(class_path) or class_name == ignore_folder:
        print(f"Skipping folder: {class_name}")
        continue

    images = _list_images(class_path)
    image_count = len(images)

    if image_count < target_count:
        if not images:
            # Nothing to augment from (random.choice([]) would raise IndexError).
            print(f"Skipping folder: {class_name} (no images found)")
            continue
        # Under-represented class: add augmented copies of the existing images.
        for index in range(image_count, target_count):
            source_path = os.path.join(class_path, balance_rng.choice(images))
            # `with` releases the file handle Image.open keeps for lazy loading.
            with Image.open(source_path) as source:
                image = source.convert("RGB")
            augmented_image = augmentation_transform(image)
            augmented_image.save(_unused_augmented_path(class_path, index))
    elif image_count > target_count:
        # Over-represented class: delete a random subset so `target_count` images remain.
        for image_name in balance_rng.sample(images, image_count - target_count):
            os.remove(os.path.join(class_path, image_name))

    print(f"Class '{class_name}' balanced to {target_count} images.")

print("Dataset balancing completed!")

Class 'cat07 MasAptMotel' balanced to 105 images.
Class 'cat10 StripMall' balanced to 105 images.
Class 'cat17 LowRise' balanced to 105 images.
Class 'cat18 MidRise' balanced to 105 images.
Class 'cat19 HighRise' balanced to 105 images.
Class 'cat21 MetalBldg' balanced to 105 images.
Class 'cat22 Canopy' balanced to 105 images.
Dataset balancing completed!
