In [None]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import os
from PIL import Image
import glob
from scipy.ndimage import convolve

In [2]:
def apply_sharpening(img_array, strength=1.0):
    # Sharpening kernel
    kernel = np.array([
        [-1, -1, -1],
        [-1,  9, -1],
        [-1, -1, -1]
    ])
    
    # Normalize kernel
    kernel = kernel * strength / kernel.sum()
    
    # Apply sharpening to each channel
    sharpened = np.zeros_like(img_array)
    for i in range(img_array.shape[-1]):
        sharpened[..., i] = convolve(img_array[..., i], kernel)
    
    return np.clip(sharpened, 0, 255)

def create_augmented_dataset(input_dir, output_dir, target_size=(224, 224), target_per_class=2500, sharpen_strength=1.0):
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'Normal'), exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'OSCC'), exist_ok=True)
    
    datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        fill_mode='reflect',
        brightness_range=[0.7, 1.3],
        preprocessing_function=lambda x: apply_sharpening(x, strength=sharpen_strength)
    )
    
    for class_name in ['Normal', 'OSCC']:
        print(f"\nProcessing {class_name} class...")
        
        class_path = os.path.join(input_dir, class_name)
        images = glob.glob(os.path.join(class_path, '*'))
        num_orig = len(images)
        print(f"Found {num_orig} original images")
        
        augmentations_per_image = (target_per_class - num_orig) // num_orig + 1
        
        generated_count = 0
        
        for img_path in images:
            img = Image.open(img_path)
            img = img.resize(target_size)
            img_array = np.array(img)
            
            sharpened = apply_sharpening(img_array, strength=sharpen_strength)
            
            img_name = f"sharp_{os.path.basename(img_path)}"
            new_path = os.path.join(output_dir, class_name, img_name)
            Image.fromarray(sharpened.astype(np.uint8)).save(new_path)
            generated_count += 1
        
        # Generate augmented images
        for img_path in images:
            img = tf.keras.preprocessing.image.load_img(
                img_path, 
                target_size=target_size
            )
            x = tf.keras.preprocessing.image.img_to_array(img)
            x = x.reshape((1,) + x.shape)
            
            batch_count = 0
            for batch in datagen.flow(
                x, 
                batch_size=1,
                save_to_dir=os.path.join(output_dir, class_name),
                save_prefix=f'aug_{os.path.splitext(os.path.basename(img_path))[0]}',
                save_format='jpg'
            ):
                batch_count += 1
                generated_count += 1
                
                if batch_count >= augmentations_per_image:
                    break
                
                if generated_count >= target_per_class:
                    break
            
            if generated_count >= target_per_class:
                break
        
        print(f"Generated {generated_count} total images for {class_name}")

In [3]:

def main():
    input_dir = 'dataset'
    output_dir = 'augmented_dataset'
    
    target_size = (224, 224)
    target_per_class = 2500
    sharpen_strength = 1.0
    
    create_augmented_dataset(
        input_dir, 
        output_dir, 
        target_size, 
        target_per_class,
        sharpen_strength
    )
    
    for class_name in ['Normal', 'OSCC']:
        path = os.path.join(output_dir, class_name)
        num_images = len(glob.glob(os.path.join(path, '*')))
        print(f"\nFinal count for {class_name}: {num_images} images")

if __name__ == "__main__":
    main()


Processing Normal class...
Found 89 original images
Generated 2500 total images for Normal

Processing OSCC class...
Found 439 original images
Generated 2500 total images for OSCC

Final count for Normal: 2496 images

Final count for OSCC: 2500 images
