In [11]:
import os
import random
from PIL import Image
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from pathlib import Path

# Set Data size
data_size = 300

# Define input and output paths
input_path = './data'
output_path = f'./augmented_data_{data_size}'

# Ensure output directory exists
Path(output_path).mkdir(parents=True, exist_ok=True)

# Define the augmentations
augmentations = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=45),
    transforms.ColorJitter(brightness=0.5, contrast=0.5),
    transforms.RandomResizedCrop(size=(256, 256), scale=(0.8, 1.0)),
])

# Load dataset
dataset = ImageFolder(root=input_path)

# Function to augment and save images
def augment_and_save_image(image, save_path, img_name, num_augmented):
    for i in range(num_augmented):
        augmented_image = augmentations(image)
        new_img_name = f"{img_name}_aug_{i+1}.png"
        augmented_image.save(os.path.join(save_path, new_img_name))

# Count and augment images for each class folder
for class_name in dataset.classes:
    print(class_name)
    class_path = os.path.join(input_path, class_name)
    class_output_path = os.path.join(output_path, class_name)
    Path(class_output_path).mkdir(parents=True, exist_ok=True)
    
    current_images = os.listdir(class_path)
    num_existing_images = len(current_images)
    print(num_existing_images)

    if num_existing_images >= data_size:
        print(f"{class_name} already has {data_size} or more images.")
        continue

    num_augmented_images_needed = data_size - num_existing_images
    
    for idx, img_name in enumerate(current_images):
        # Save the original image in the output folder
        image.save(os.path.join(class_output_path, img_name))


    # Augment existing images until we reach 300 images
    for idx, img_name in enumerate(current_images):
        if num_existing_images + idx >= data_size:
            print(f"Reached {data_size} images for class {class_name}")
            break
        img_path = os.path.join(class_path, img_name)
        image = Image.open(img_path)

        # Calculate how many augmentations to apply per image
        num_augment_per_image = num_augmented_images_needed // num_existing_images + 1

        # Augment and save images
        augment_and_save_image(image, class_output_path, img_name.split('.')[0], num_augment_per_image)
    print(f"Augmentation completed for class: {class_name}, total images: {len(os.listdir(class_output_path))}")

Chickenpox
107
Augmentation completed for class: Chickenpox, total images: 321
Measles
91
Augmentation completed for class: Measles, total images: 364
Monkeypox
279
Reached 300 images for class Monkeypox
Augmentation completed for class: Monkeypox, total images: 300
Normal
293
Reached 300 images for class Normal
Augmentation completed for class: Normal, total images: 300


In [2]:
import os
import random
from PIL import Image
from torchvision import transforms
from pathlib import Path

# Set Data size
data_size = 300

# Define input and output paths
input_path = './data'
output_path = f'./augmented_data_{data_size}'

# Ensure output directory exists
Path(output_path).mkdir(parents=True, exist_ok=True)

# Define the augmentations
augmentations = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=45),
    transforms.ColorJitter(brightness=0.5, contrast=0.5),
    transforms.RandomResizedCrop(size=(256, 256), scale=(0.8, 1.0)),
])

# Function to augment and save images
def augment_and_save_image(image, save_path, img_name, num_augmented):
    for i in range(num_augmented):
        augmented_image = augmentations(image)
        new_img_name = f"{img_name}_aug_{i+1}.png"
        augmented_image.save(os.path.join(save_path, new_img_name))

# Load existing images and augment them to reach 300
for class_name in os.listdir(input_path):
    print(f"Processing class: {class_name}")
    
    class_path = os.path.join(input_path, class_name)
    class_output_path = os.path.join(output_path, class_name)
    Path(class_output_path).mkdir(parents=True, exist_ok=True)

    current_images = os.listdir(class_path)
    num_existing_images = len(current_images)

    if num_existing_images >= data_size:
        print(f"{class_name} already has {data_size} or more images.")
        continue

    num_augmented_images_needed = data_size - num_existing_images

    # Loop over the original images and augment them until we reach the desired size
    augmented_image_count = 0
    while len(os.listdir(class_output_path)) < data_size:
        for img_name in current_images:
            if len(os.listdir(class_output_path)) >= data_size:
                break
            
            img_path = os.path.join(class_path, img_name)
            image = Image.open(img_path)

            # Save the original image if it doesn't exist in output
            if augmented_image_count < num_existing_images:
                image.save(os.path.join(class_output_path, img_name))

            # Augment and save images
            augment_and_save_image(image, class_output_path, img_name.split('.')[0], 1)
            augmented_image_count += 1

    print(f"Augmentation completed for class: {class_name}, total images: {len(os.listdir(class_output_path))}")


Processing class: Measles
Augmentation completed for class: Measles, total images: 364
Processing class: Chickenpox
Augmentation completed for class: Chickenpox, total images: 321
Processing class: Normal
Augmentation completed for class: Normal, total images: 300
Processing class: Monkeypox
Augmentation completed for class: Monkeypox, total images: 300
