In [1]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import os
from PIL import Image
import glob

In [2]:
def create_augmented_dataset(input_dir, output_dir, target_size=(224, 224), target_per_class=2000):

    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'Normal'), exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'OSCC'), exist_ok=True)
    
    datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        fill_mode='reflect',
        brightness_range=[0.7, 1.3]
    )
    
    # Process each class
    for class_name in ['Normal', 'OSCC']:
        print(f"\nProcessing {class_name} class...")
        
        # Get list of original images
        class_path = os.path.join(input_dir, class_name)
        images = glob.glob(os.path.join(class_path, '*'))
        num_orig = len(images)
        print(f"Found {num_orig} original images")
        
        # Calculate how many augmented images we need per original image
        augmentations_per_image = (target_per_class - num_orig) // num_orig + 1
        
        # Counter for generated images
        generated_count = 0
        
        # Copy original images first
        for img_path in images:
            img_name = os.path.basename(img_path)
            new_path = os.path.join(output_dir, class_name, img_name)
            tf.io.gfile.copy(img_path, new_path, overwrite=True)
            generated_count += 1
        
        # Generate augmented images
        for img_path in images:
            # Load and resize image
            img = tf.keras.preprocessing.image.load_img(
                img_path, 
                target_size=target_size
            )
            x = tf.keras.preprocessing.image.img_to_array(img)
            x = x.reshape((1,) + x.shape)
            
            # Generate augmented images
            batch_count = 0
            for batch in datagen.flow(
                x, 
                batch_size=1,
                save_to_dir=os.path.join(output_dir, class_name),
                save_prefix=f'aug_{os.path.splitext(os.path.basename(img_path))[0]}',
                save_format='jpg'
            ):
                batch_count += 1
                generated_count += 1
                
                # Break when we've generated enough images for this original image
                if batch_count >= augmentations_per_image:
                    break
                
                # Break if we've reached our target
                if generated_count >= target_per_class:
                    break
            
            if generated_count >= target_per_class:
                break
        
        print(f"Generated {generated_count} total images for {class_name}")

In [3]:
def main():
    input_dir = 'dataset'
    output_dir = 'augmented_dataset'
    
    target_size = (224, 224)
    target_per_class = 2000
    
    create_augmented_dataset(input_dir, output_dir, target_size, target_per_class)
    
    for class_name in ['Normal', 'OSCC']:
        path = os.path.join(output_dir, class_name)
        num_images = len(glob.glob(os.path.join(path, '*')))
        print(f"\nFinal count for {class_name}: {num_images} images")

if __name__ == "__main__":
    main()


Processing Normal class...
Found 89 original images
Generated 2000 total images for Normal

Processing OSCC class...
Found 439 original images
Generated 2000 total images for OSCC

Final count for Normal: 1998 images

Final count for OSCC: 2000 images
