In [None]:
import math
import os
import random

from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img, array_to_img, save_img

In [None]:
# --------------------------
# Parameters
# --------------------------
# Root of the training split; each entry directly under it is treated as one
# class by the code below.
TRAIN_PATH = "../data/splits/train"
# Output root for the generated images; it receives one sub-folder per class.
AUGMENTED_PATH = "../data/splits/train_augmented"  # save augmented images here
# Passed as `target_size` to load_img, so every source image is resized to
# this size before being augmented.
IMG_SIZE = (224, 224)
# Per class, ceil(number_of_images * AUGMENT_FACTOR) augmented images are
# generated.
AUGMENT_FACTOR = 0.3  # 30% increase

# --------------------------
# Create augmented folder structure
# --------------------------
# Only real, non-hidden sub-directories of TRAIN_PATH are classes. Stray files
# (e.g. .DS_Store) or hidden folders (e.g. .ipynb_checkpoints) must not become
# classes: they would get an output folder and then break the per-class
# directory listing later on.
classes = sorted(
    entry for entry in os.listdir(TRAIN_PATH)
    if not entry.startswith('.') and os.path.isdir(os.path.join(TRAIN_PATH, entry))
)

# exist_ok=True makes the script safe to re-run and avoids the check-then-create
# race of a separate os.path.exists() test.
os.makedirs(AUGMENTED_PATH, exist_ok=True)
for cls in classes:
    os.makedirs(os.path.join(AUGMENTED_PATH, cls), exist_ok=True)

# --------------------------
# Define augmentation
# --------------------------
# One shared generator; every augmented sample gets a fresh random
# combination of the transforms configured here.
datagen = ImageDataGenerator(
    # Geometric jitter
    rotation_range=20,
    zoom_range=0.1,
    width_shift_range=0.1,
    height_shift_range=0.1,
    # Mirroring: left/right only, never upside-down
    horizontal_flip=True,
    vertical_flip=False,
    # Photometric jitter
    brightness_range=[0.8, 1.2],
    # Strategy for filling pixels uncovered by the geometric transforms
    fill_mode='nearest',
)

# --------------------------
# Apply augmentation per class
# --------------------------
# For every class, generate ceil(len(images) * AUGMENT_FACTOR) augmented
# images, each one derived from a different (randomly chosen) source image.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

for cls in classes:
    class_dir = os.path.join(TRAIN_PATH, cls)
    if not os.path.isdir(class_dir):
        continue  # stray file under TRAIN_PATH, not a class folder

    augmented_class_dir = os.path.join(AUGMENTED_PATH, cls)
    images = sorted(f for f in os.listdir(class_dir) if f.lower().endswith(IMAGE_EXTENSIONS))
    if not images:
        continue  # nothing to augment for this class

    # Number of augmented images needed
    n_aug = math.ceil(len(images) * AUGMENT_FACTOR)

    # Shuffle so the sources are a random subset of the class rather than
    # whatever happens to be listed first.
    random.shuffle(images)

    for count in range(n_aug):
        # One augmented sample per source image; the modulo wraps around so
        # an AUGMENT_FACTOR above 1 still works.
        img_name = images[count % len(images)]
        img_path = os.path.join(class_dir, img_name)
        img = load_img(img_path, target_size=IMG_SIZE)
        x = img_to_array(img)
        x = x.reshape((1,) + x.shape)  # add the batch axis expected by datagen.flow

        # flow() on a single image yields batches forever, so take exactly one
        # batch; producing it is what writes the file into save_to_dir.
        # NOTE(review): Keras appears to make saved filenames unique only via a
        # short random suffix, so a unique prefix per output guards against
        # accidental overwrites - confirm against the installed Keras version.
        stem = os.path.splitext(img_name)[0]
        flow = datagen.flow(
            x,
            batch_size=1,
            save_to_dir=augmented_class_dir,
            save_prefix=f'aug_{count}_{stem}',
            save_format='jpg',
        )
        next(flow)

print("✅ Data augmentation completed!")
