In [None]:
import os
import random
from shutil import copyfile
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array


In [None]:
# Locations of the source images and of the balanced copies.
# Class folder '1' holds the fake images and class folder '2' the real ones;
# the augmented_* folders live next to them under the same dataset root.
data_dir = 'dataset'
original_fake_dir, original_real_dir = (
    os.path.join(data_dir, class_folder) for class_folder in ('1', '2')
)
augmented_fake_dir, augmented_real_dir = (
    os.path.join(data_dir, output_folder)
    for output_folder in ('augmented_fake', 'augmented_real')
)


In [None]:
# Make sure both output folders exist before anything is written to them.
# exist_ok=True keeps this cell safe to re-run when the folders are already there.
for _output_dir in (augmented_fake_dir, augmented_real_dir):
    os.makedirs(_output_dir, exist_ok=True)


In [None]:
# Class sizes, measured as the number of directory entries in each source folder.
num_fake_images, num_real_images = (
    len(os.listdir(class_dir))
    for class_dir in (original_fake_dir, original_real_dir)
)


In [None]:
# Calculate the augmentation factor to balance the data.
# The factor is the whole-number ratio of the larger class to the smaller one.
# An empty class makes that ratio undefined (ZeroDivisionError) and leaves no
# source images to augment, so stop here with an explicit message instead.
smaller_class_count = min(num_fake_images, num_real_images)
if smaller_class_count == 0:
    raise ValueError(
        "Cannot balance the dataset: at least one class directory is empty "
        f"(fake={num_fake_images}, real={num_real_images})."
    )
augmentation_factor = max(num_fake_images, num_real_images) // smaller_class_count

# Data augmentation: random geometric transforms applied to each generated image.
datagen = ImageDataGenerator(
    rotation_range=20,        # random rotation range
    width_shift_range=0.2,    # random horizontal shift
    height_shift_range=0.2,   # random vertical shift
    shear_range=0.2,          # random shear
    zoom_range=0.2,           # random zoom
    horizontal_flip=True,     # random left/right mirroring
    fill_mode='nearest'       # how pixels exposed by a transform are filled in
)


In [None]:
def augment_images(original_dir, augmented_dir, target_num_images):
    """Copy a class's images into ``augmented_dir`` and top it up with augmented variants.

    Every entry of ``original_dir`` is first copied into ``augmented_dir``.
    Then, while ``augmented_dir`` holds fewer than ``target_num_images``
    entries, a random file from ``augmented_dir`` is loaded and fed through the
    module-level ``datagen``, which writes transformed copies (prefix ``aug``,
    format ``jpeg``) back into ``augmented_dir``.

    Each pick produces up to ``augmentation_factor`` (module-level) variants,
    but never more than are still missing, so generation does not push the
    directory past ``target_num_images``.

    Args:
        original_dir: Directory holding the source images.
        augmented_dir: Existing directory that receives the copies and the
            generated images.
        target_num_images: Number of entries ``augmented_dir`` should reach.
            If the directory already holds at least this many entries after
            the copy step, nothing is generated.

    Raises:
        ValueError: If images have to be generated but ``augmented_dir`` is
            empty after the copy step (there is nothing to augment).
    """
    # Start from a verbatim copy of every original file.
    for img_file in os.listdir(original_dir):
        img_path = os.path.join(original_dir, img_file)
        copyfile(img_path, os.path.join(augmented_dir, img_file))

    img_list = os.listdir(augmented_dir)
    if not img_list and target_num_images > 0:
        raise ValueError(
            f"No images available to augment: {original_dir!r} and "
            f"{augmented_dir!r} are both empty."
        )

    while len(img_list) < target_num_images:
        # Cap the burst at what is still missing; always generate at least one
        # image so the loop makes progress even for a non-positive factor.
        remaining = target_num_images - len(img_list)
        burst_size = max(1, min(augmentation_factor, remaining))

        chosen_img = random.choice(img_list)
        img_path = os.path.join(augmented_dir, chosen_img)
        x = img_to_array(load_img(img_path))
        x = x.reshape((1,) + x.shape)  # add a leading batch axis of size 1

        # datagen.flow yields batches indefinitely and saves one file per
        # batch, so stop explicitly once the burst is complete.
        flow = datagen.flow(
            x,
            batch_size=1,
            save_to_dir=augmented_dir,
            save_prefix='aug',
            save_format='jpeg',
        )
        for produced, _ in enumerate(flow, start=1):
            if produced >= burst_size:
                break

        # Re-list the directory instead of counting in memory: the listing is
        # the ground truth even if a generated file name overwrote an older one.
        img_list = os.listdir(augmented_dir)
        

In [None]:
# Balance the data by augmenting whichever class is the minority.
# Both classes are brought up to the size of the larger one: the minority class
# gets augmented images, while the majority class already meets the target, so
# its call only copies the originals into its augmented folder.
# When the classes are already the same size nothing is done and the
# original_* paths keep pointing at the source folders.
if num_fake_images != num_real_images:
    majority_count = max(num_fake_images, num_real_images)
    augment_images(original_fake_dir, augmented_fake_dir, majority_count)
    augment_images(original_real_dir, augmented_real_dir, majority_count)
    # From here on the balanced folders are the dataset to use.
    # NOTE: run this cell once per fresh kernel - after the reassignment a
    # re-run would copy each folder onto itself (shutil.SameFileError).
    original_fake_dir = augmented_fake_dir
    original_real_dir = augmented_real_dir
