# In[ ]:
import os
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.utils import resample
import shutil
from pathlib import Path
import cv2

class ImageDatasetBalancer:
    """
    A class to handle imbalanced image datasets stored in folders.

    Expected layout: ``base_path/<class_name>/<image files>``. Every immediate
    sub-directory of ``base_path`` is treated as one class, and the image files
    directly inside it are that class's samples.
    """

    # File suffixes (compared case-insensitively) that are counted as images.
    IMAGE_EXTENSIONS = frozenset(
        {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tif', '.tiff', '.webp'}
    )

    def __init__(self, base_path, target_size=(224, 224)):
        """
        Initialize the dataset balancer and scan the dataset folder.

        Parameters:
        -----------
        base_path : str or path-like
            Path to the main folder containing class subfolders
        target_size : tuple
            (width, height) used by ``load_and_preprocess_image`` when
            resizing images for display. Balancing copies the original files
            unchanged.
        """
        self.base_path = Path(base_path)
        self.target_size = target_size
        self.class_paths = {}   # class name -> sorted list of image Paths
        self.class_counts = {}  # class name -> number of images
        self.load_dataset_info()

    def _list_images(self, folder):
        """Return the image files directly inside *folder*, sorted by path."""
        # Sorting makes the order (and therefore seeded sampling) independent
        # of the filesystem's directory-listing order.
        return sorted(
            path for path in folder.iterdir()
            if path.is_file() and path.suffix.lower() in self.IMAGE_EXTENSIONS
        )

    def load_dataset_info(self):
        """Rescan ``base_path``, refresh ``class_paths``/``class_counts`` and print the counts."""
        class_folders = sorted(
            folder for folder in self.base_path.iterdir() if folder.is_dir()
        )
        self.class_paths = {
            folder.name: self._list_images(folder) for folder in class_folders
        }
        self.class_counts = {
            class_name: len(paths)
            for class_name, paths in self.class_paths.items()
        }

        print("Dataset structure:")
        for class_name, count in self.class_counts.items():
            print(f"Class '{class_name}': {count} images")

    def load_and_preprocess_image(self, image_path):
        """
        Load one image as an RGB array resized to ``target_size``.

        Returns
        -------
        numpy.ndarray or None
            The RGB pixel array, or None if the file could not be read. The
            error is printed, not raised.
        """
        try:
            # The context manager closes the file handle PIL opens lazily.
            with Image.open(image_path) as image:
                rgb = image.convert('RGB').resize(self.target_size)
            return np.array(rgb)
        except Exception as e:
            # Best-effort: one unreadable file should not abort a preview.
            print(f"Error loading image {image_path}: {e}")
            return None

    @staticmethod
    def _select_paths(image_paths, target_count, rng):
        """Pick exactly ``target_count`` paths, never discarding originals when growing."""
        current_count = len(image_paths)
        if current_count > target_count:
            # Undersample: distinct images only, no duplicates.
            keep = rng.choice(current_count, size=target_count, replace=False)
            return [image_paths[i] for i in sorted(keep)]
        # Oversample (or keep as-is when equal): all originals plus random
        # duplicates for the gap. Redrawing the whole set with replacement
        # would drop some originals (about 1/e of them when sizes are equal).
        extra = rng.choice(current_count, size=target_count - current_count,
                           replace=True)
        return list(image_paths) + [image_paths[i] for i in extra]

    def balance_dataset(self, output_path, method='random_oversampling',
                        random_state=42):
        """
        Balance the dataset using the specified method and save to output directory.

        Every class is brought to the same number of images (copied with
        ``shutil.copy2`` and renamed ``<class>_<idx><suffix>``). Afterwards this
        balancer points at ``output_path`` and its counts are refreshed.

        Parameters:
        -----------
        output_path : str or path-like
            Path where balanced dataset will be saved. It must differ from the
            dataset folder itself.
        method : str
            'random_oversampling' (target = largest class),
            'random_undersampling' (target = smallest non-empty class), or
            'combination' (target = mean of the two, rounded down).
        random_state : int or None
            Seed for the sampling RNG. None gives a different result each run.

        Raises:
        -------
        ValueError
            If ``method`` is unknown, no class contains any image, or
            ``output_path`` resolves to the dataset folder.
        """
        target_rules = {
            'random_oversampling': max,
            'random_undersampling': min,
            'combination': lambda counts: (max(counts) + min(counts)) // 2,
        }
        # Validate before touching the filesystem.
        if method not in target_rules:
            raise ValueError(f"Unknown method: {method}")

        output_path = Path(output_path)
        if output_path.resolve() == self.base_path.resolve():
            raise ValueError("output_path must differ from the dataset folder")

        # Empty class folders are ignored when choosing the target; otherwise
        # undersampling would shrink every class to zero.
        non_empty_counts = [c for c in self.class_counts.values() if c > 0]
        if not non_empty_counts:
            raise ValueError(f"No images found under {self.base_path}")
        target_count = target_rules[method](non_empty_counts)

        output_path.mkdir(parents=True, exist_ok=True)
        rng = np.random.default_rng(random_state)

        for class_name, image_paths in self.class_paths.items():
            class_output_path = output_path / class_name
            class_output_path.mkdir(exist_ok=True)
            if not image_paths:
                print(f"Warning: class '{class_name}' has no images; left empty")
                continue

            selected_paths = self._select_paths(image_paths, target_count, rng)
            for idx, src_path in enumerate(selected_paths):
                dst_path = class_output_path / f"{class_name}_{idx}{src_path.suffix}"
                shutil.copy2(src_path, dst_path)

        # NOTE: files already present in output_path from an earlier run are
        # not removed, so they will be included in the refreshed counts.
        self.base_path = output_path
        self.load_dataset_info()

    def visualize_distribution(self):
        """Draw a bar chart of images per class (labelled with counts) and return the Figure."""
        fig, ax = plt.subplots(figsize=(10, 6))

        classes = list(self.class_counts.keys())
        counts = list(self.class_counts.values())
        bars = ax.bar(classes, counts)

        # Print each count just above its bar.
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2., height,
                    f'{int(height)}',
                    ha='center', va='bottom')

        ax.set_title('Class Distribution in Dataset')
        ax.set_xlabel('Classes')
        ax.set_ylabel('Number of Images')
        ax.tick_params(axis='x', labelrotation=45)
        fig.tight_layout()

        return fig

    def show_sample_images(self, num_samples=5):
        """
        Display sample images from each class.

        Each class gets one row of up to ``num_samples`` randomly chosen images
        (drawn with numpy's global RNG), titled with the class name in the
        first column.

        Parameters:
        -----------
        num_samples : int
            Number of sample images to show from each class

        Raises:
        -------
        ValueError
            If there are no class folders or ``num_samples`` < 1.
        """
        num_classes = len(self.class_paths)
        if num_classes == 0 or num_samples < 1:
            raise ValueError("Need at least one class folder and num_samples >= 1")

        # squeeze=False keeps `axes` 2-D for every grid shape, including a
        # single row or a single column.
        fig, axes = plt.subplots(num_classes, num_samples,
                                 figsize=(15, 3 * num_classes), squeeze=False)
        # Hide every cell up front so unused or failed cells stay blank.
        for ax in axes.flat:
            ax.axis('off')

        for class_idx, (class_name, image_paths) in enumerate(self.class_paths.items()):
            axes[class_idx, 0].set_title(f'{class_name}', pad=10)
            if not image_paths:
                continue
            picks = np.random.choice(len(image_paths),
                                     min(num_samples, len(image_paths)),
                                     replace=False)
            for img_idx, path_idx in enumerate(picks):
                img = self.load_and_preprocess_image(image_paths[path_idx])
                if img is not None:
                    axes[class_idx, img_idx].imshow(img)

        fig.tight_layout()
        return fig

# In[ ]:
# Point the balancer at a folder holding one sub-folder per class
# (this example assumes two classes); previews are resized to 224x224.
balancer = ImageDatasetBalancer("path/to/your/dataset", target_size=(224, 224))

# Class counts before balancing.
balancer.visualize_distribution()
plt.show()

# A grid of five random images per class.
balancer.show_sample_images(5)
plt.show()

# Write the balanced copy. Other methods: 'random_undersampling', 'combination'.
balanced_dataset_path = "path/to/output/balanced_dataset"
balancer.balance_dataset(balanced_dataset_path, method='random_oversampling')

# Class counts after balancing (the balancer now reads from the output folder).
balancer.visualize_distribution()
plt.show()