# In [1]:
import os
import shutil
import random
from collections import defaultdict

def _collect_class_images(phase_input_path):
    """Map each class subfolder of *phase_input_path* to a sorted list of its file paths.

    Entries that are not directories (e.g. ``desktop.ini`` / ``.DS_Store`` sitting
    next to the class folders) are skipped, as are subdirectories inside a class
    folder, so only regular files are treated as images. Listings are sorted so
    that a seeded shuffle gives the same split on every filesystem.
    """
    class_images = {}
    for class_name in sorted(os.listdir(phase_input_path)):
        class_input_path = os.path.join(phase_input_path, class_name)
        if not os.path.isdir(class_input_path):
            continue
        class_images[class_name] = [
            os.path.join(class_input_path, image_name)
            for image_name in sorted(os.listdir(class_input_path))
            if os.path.isfile(os.path.join(class_input_path, image_name))
        ]
    return class_images


def balance_dataset(input_folder, output_folder, removed_folder,
                    phases=('train', 'test'), seed=None):
    """Undersample every class of each phase down to the size of its smallest class.

    Expects the layout ``input_folder/<phase>/<class_name>/<image files>``. For
    each phase, every class's images are shuffled. The first ``min_count`` are
    copied to ``output_folder/<phase>/<class_name>/`` and the remainder to
    ``removed_folder/<phase>/<class_name>/``. ``min_count`` is the image count
    of the smallest non-empty class in that phase. Files are copied, never moved,
    so ``input_folder`` is left untouched. A phase with no images is skipped.

    Args:
        input_folder: Root of the source dataset.
        output_folder: Root that receives the balanced subset.
        removed_folder: Root that receives the surplus images.
        phases: Names of the phase subfolders to process.
        seed: Optional seed for a private RNG so the split is reproducible.
            When None, the module-level ``random`` state is used.

    Raises:
        FileNotFoundError: If a phase folder does not exist under input_folder.
    """
    os.makedirs(output_folder, exist_ok=True)
    os.makedirs(removed_folder, exist_ok=True)
    # Keep using the global RNG by default so callers that call random.seed()
    # themselves still get reproducible results.
    rng = random if seed is None else random.Random(seed)

    for phase in phases:
        phase_input_path = os.path.join(input_folder, phase)
        phase_output_path = os.path.join(output_folder, phase)
        phase_removed_path = os.path.join(removed_folder, phase)
        os.makedirs(phase_output_path, exist_ok=True)
        os.makedirs(phase_removed_path, exist_ok=True)

        class_images = _collect_class_images(phase_input_path)

        # Mirror every class folder (even empty ones) in both destinations.
        for class_name in class_images:
            os.makedirs(os.path.join(phase_output_path, class_name), exist_ok=True)
            os.makedirs(os.path.join(phase_removed_path, class_name), exist_ok=True)

        # Empty classes are ignored when choosing the target size.
        non_empty_sizes = [len(images) for images in class_images.values() if images]
        if not non_empty_sizes:
            continue  # nothing to balance; min() of an empty sequence would raise
        min_count = min(non_empty_sizes)

        for class_name, images in class_images.items():
            rng.shuffle(images)
            for i, image_path in enumerate(images):
                dest_root = phase_output_path if i < min_count else phase_removed_path
                shutil.copy(
                    image_path,
                    os.path.join(dest_root, class_name, os.path.basename(image_path)),
                )

# Run the balancing only when executed as a script or notebook cell
# (__name__ == "__main__" in both), not when this module is imported.
if __name__ == "__main__":
    # Paths for input, output, and removed images
    input_folder = r'C:\Dataset'
    output_folder = r'C:\balanced dataset'
    removed_folder = r'C:\removed dataset'

    balance_dataset(input_folder, output_folder, removed_folder)
