In [None]:
import os
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from math import ceil
import matplotlib.pyplot as plt
from PIL import Image

# Set the path to your folders
folder_path = 'PizzaHut-samples'
augmented_images_dir = 'PizzaHut-samples'

# File extensions (compared in lower case) that are treated as images.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
# Prefix given to every generated file. It is what distinguishes generated
# images from originals when both live in the same folder.
AUGMENTED_PREFIX = 'aug_'
# Number of images the output folder should contain after augmentation.
TARGET_IMAGE_COUNT = 200
# Maximum number of generated images shown in the preview figure.
PREVIEW_COUNT = 5


def list_image_files(directory, exclude_prefix=None):
    """Return the sorted names of the image files directly inside *directory*.

    Args:
        directory: Folder to scan (not recursive).
        exclude_prefix: When given, names starting with this prefix are
            skipped. Used to keep previously generated images out of the
            list of originals.

    Returns:
        A sorted list of file names (not full paths) whose extension is one
        of ``IMAGE_EXTENSIONS``. Sub-directories are ignored even if their
        name looks like an image file.
    """
    selected = []
    for name in sorted(os.listdir(directory)):
        if not name.lower().endswith(IMAGE_EXTENSIONS):
            continue
        if exclude_prefix and name.startswith(exclude_prefix):
            continue
        if not os.path.isfile(os.path.join(directory, name)):
            continue
        selected.append(name)
    return selected


def augmentations_per_original(images_needed, num_originals):
    """Return how many augmented images each original must produce.

    The result is rounded up so that ``num_originals`` originals yield at
    least ``images_needed`` images in total. Returns 0 (instead of dividing
    by zero) when there are no originals or nothing is needed.
    """
    if num_originals <= 0 or images_needed <= 0:
        return 0
    return ceil(images_needed / num_originals)


def augment_images(image_files, images_per_original):
    """Generate and save augmented variants of every image in *image_files*.

    Reads the module-level ``folder_path`` (source folder),
    ``augmented_images_dir`` (output folder) and ``datagen`` (the configured
    ``ImageDataGenerator``); ``datagen`` must exist before this is called.

    Args:
        image_files: File names, relative to ``folder_path``, to augment.
        images_per_original: Number of augmented images to write for each
            original. Values <= 0 write nothing.
    """
    if images_per_original > 0:
        for image_file in image_files:
            img_path = os.path.join(folder_path, image_file)
            img = load_img(img_path)
            # flow() expects a batch, so add a leading batch axis.
            img_array = np.expand_dims(img_to_array(img), axis=0)

            # The iterator is endless and saves each batch to save_to_dir as
            # it is drawn, so stop explicitly after the requested number.
            flow = datagen.flow(
                img_array,
                batch_size=1,
                save_to_dir=augmented_images_dir,
                save_prefix=AUGMENTED_PREFIX + os.path.splitext(image_file)[0],
                save_format='jpg',
            )
            for generated_count, _batch in enumerate(flow, start=1):
                if generated_count >= images_per_original:
                    break

    print("Augmentation complete.")


# Check that the source folder exists (a plain file at that path does not count)
if not os.path.isdir(folder_path):
    print(f"Error: The folder '{folder_path}' does not exist.")
else:
    # Create the augmented_images folder if it doesn't exist
    os.makedirs(augmented_images_dir, exist_ok=True)

    # When originals and outputs share one folder, earlier outputs must not be
    # treated as originals, otherwise a re-run augments its own augmentations.
    shares_folder = os.path.abspath(folder_path) == os.path.abspath(augmented_images_dir)

    # Get list of all original image files in the folder
    image_files = list_image_files(
        folder_path,
        exclude_prefix=AUGMENTED_PREFIX if shares_folder else None,
    )
    num_images = len(image_files)
    print(f"Found {num_images} images in the folder '{folder_path}'.")

    # Check how many images are in the augmented_images folder. This counts
    # every image there, so originals are included when the folders coincide.
    augmented_image_files = os.listdir(augmented_images_dir)
    augmented_images_count = len(list_image_files(augmented_images_dir))
    print(f"Found {augmented_images_count} augmented images in the folder '{augmented_images_dir}'.")

    if augmented_images_count >= TARGET_IMAGE_COUNT:
        # Target already reached, no further augmentation is necessary
        print(f"The folder already contains {TARGET_IMAGE_COUNT} or more augmented images.")
    elif num_images == 0:
        # Nothing to augment from; previously this divided by zero.
        print(f"Error: No source images found in the folder '{folder_path}'; nothing to augment.")
    else:
        # Calculate how many more images are needed
        images_needed = TARGET_IMAGE_COUNT - augmented_images_count
        print(f"Need to generate {images_needed} more augmented images.")

        # Calculate how many augmented images to create per original image
        images_per_original = augmentations_per_original(images_needed, num_images)
        print(f"Each original image will generate {images_per_original} augmented images.")

        # Initialize the custom ImageDataGenerator with additional augmentations
        datagen = ImageDataGenerator(
            rotation_range=40,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True,
            fill_mode='nearest'
        )

        # Call the function to augment images
        augment_images(image_files, images_per_original)

        # Optionally, display some augmented images to verify. Only generated
        # image files are previewed, so originals and stray non-image entries
        # (e.g. notebook checkpoint folders) cannot end up in Image.open.
        augmented_image_files = os.listdir(augmented_images_dir)
        sample_images = [
            f for f in list_image_files(augmented_images_dir)
            if f.startswith(AUGMENTED_PREFIX)
        ][:PREVIEW_COUNT]
        if not sample_images:
            print("No augmented images available to display.")
        else:
            fig, axes = plt.subplots(1, PREVIEW_COUNT, figsize=(15, 15))
            # Hide every axis up front so unused slots stay blank.
            for ax in axes:
                ax.axis('off')
            for ax, img_file in zip(axes, sample_images):
                # Close the file handle once the pixels are handed to matplotlib.
                with Image.open(os.path.join(augmented_images_dir, img_file)) as img:
                    ax.imshow(img)
            plt.show()
