In [3]:
import os
import random
import shutil

def _list_image_files(directory, image_extensions):
    """Return the sorted names of visible image files directly inside `directory`.

    A name is kept only if it ends with one of `image_extensions`
    (case-insensitive), does not start with '.', and is a regular file.
    Skipping dot-files keeps hidden entries such as macOS '._photo.jpg'
    metadata from being counted (or deleted) as images; the isfile check
    keeps a sub-directory named like 'x.png' from being counted. Sorting
    makes the order independent of the filesystem, so seeded sampling is
    reproducible.
    """
    suffixes = tuple(ext.lower() for ext in image_extensions)
    return sorted(
        name for name in os.listdir(directory)
        if not name.startswith('.')
        and name.lower().endswith(suffixes)
        and os.path.isfile(os.path.join(directory, name))
    )


def balance_image_dataset(
    dataset_path,
    seed=None,
    dry_run=False,
    image_extensions=('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'),
):
    """
    Balances the number of images across all class directories in a dataset.

    Undersamples every class down to the size of the smallest class by deleting
    randomly chosen image files, to address class imbalance before training an
    image classifier.

    WARNING: files are deleted with os.remove (permanently). Run with
    dry_run=True first, or work on a copy of the dataset.

    Key Steps:
    1. Identify all class directories (visible sub-directories of dataset_path)
    2. List the image files in each class
    3. Find the minimum number of images across all classes
    4. Randomly remove excess images from classes with more images

    Args:
    dataset_path (str): Full path to the root directory containing class subdirectories
    seed (int | None): If given, files to remove are chosen with a private
        random.Random(seed), making the result reproducible. None (default)
        uses the global `random` module, as before.
    dry_run (bool): If True, report what would be removed but delete nothing.
    image_extensions (Iterable[str]): File suffixes, matched case-insensitively,
        that count as images.

    Returns:
    dict: Class directory name -> number of images in that class after balancing
        (with dry_run=True, the number that would remain).

    Raises:
    ValueError: If dataset_path contains no class sub-directories, or if any
        class contains no images (balancing would otherwise delete every image).
    """
    # Step 1: Discover Class Directories
    # Each visible subdirectory is one class; hidden ones (e.g. '.ipynb_checkpoints')
    # are skipped because an empty pseudo-class would drive the target count to 0.
    class_dirs = sorted(
        d for d in os.listdir(dataset_path)
        if not d.startswith('.') and os.path.isdir(os.path.join(dataset_path, d))
    )
    if not class_dirs:
        raise ValueError(f"No class subdirectories found in {dataset_path!r}")

    # Step 2: List each class's images once and reuse that listing for removal,
    # so the counts and the files sampled for deletion cannot drift apart.
    class_files = {
        class_dir: _list_image_files(os.path.join(dataset_path, class_dir), image_extensions)
        for class_dir in class_dirs
    }
    image_counts = {class_dir: len(files) for class_dir, files in class_files.items()}

    # Step 3: The smallest class size becomes the target for every class
    min_images = min(image_counts.values())

    # Provide initial dataset information
    print("Initial Dataset Composition:")
    for class_dir, count in image_counts.items():
        print(f"{class_dir}: {count} images")
    print(f"\nTarget images per class: {min_images}")

    # A target of 0 would delete every image in every class; refuse instead.
    empty_classes = [class_dir for class_dir, count in image_counts.items() if count == 0]
    if empty_classes:
        raise ValueError(
            f"Classes with no images: {empty_classes}; refusing to delete every image"
        )

    # Seeded private generator for reproducibility; otherwise keep the global one.
    rng = random.Random(seed) if seed is not None else random

    # Step 4: Balance the Dataset
    balanced_counts = {}
    for class_dir, files in class_files.items():
        excess = len(files) - min_images
        if excess > 0:
            full_path = os.path.join(dataset_path, class_dir)
            for file in rng.sample(files, excess):
                if not dry_run:
                    os.remove(os.path.join(full_path, file))
            action = "Would remove" if dry_run else "Removed"
            print(f"{action} {excess} images from {class_dir}")

        # Every class ends at exactly min_images (min_images is the minimum count)
        balanced_counts[class_dir] = min_images

    # Final Report
    print("\nFinal Dataset Composition:")
    for class_dir, count in balanced_counts.items():
        print(f"{class_dir}: {count} images")

    return balanced_counts

# Practical Usage Example
# WARNING: balance_image_dataset deletes image files with os.remove, so back up
# the dataset (or point this at a copy) before running the cell.
# The path can be overridden with the BALANCE_DATASET_PATH environment variable
# instead of editing this machine-specific default.
dataset_path = os.environ.get(
    'BALANCE_DATASET_PATH', '/Users/jakehopkins/Desktop/clean:dirty/Train'
)
final_image_counts = balance_image_dataset(dataset_path)

Initial Dataset Composition:
Clean 5 gpm: 4366 images
Clean 2.5 gpm: 1829 images
Clean .75 gpm: 1123 images
Clean 1.75 gpm: 1810 images

Target images per class: 1123
Removed 3243 images from Clean 5 gpm
Removed 706 images from Clean 2.5 gpm
Removed 687 images from Clean 1.75 gpm

Final Dataset Composition:
Clean 5 gpm: 1123 images
Clean 2.5 gpm: 1123 images
Clean .75 gpm: 1123 images
Clean 1.75 gpm: 1123 images


In [None]:
# Define the learning rate schedule.
# Imported here so the cell runs on a fresh kernel instead of relying on an
# ExponentialDecay name imported somewhere else.
import tensorflow as tf

# lr(step) = initial_learning_rate * decay_rate ** (step / decay_steps); with
# staircase=True the exponent is floored, so the rate is multiplied by 0.96
# once every `decay_steps` optimizer steps.
# NOTE(review): 100000 steps is likely more than this run ever performs
# (~4.5k training images after balancing), in which case the rate never
# actually decays — confirm against batch size and epoch count.
initial_learning_rate = 0.001
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=100000,
    decay_rate=0.96,
    staircase=True)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import tensorflow as tf


def to_class_indices(values):
    """Convert labels or model outputs to a 1-D integer array of class indices.

    - 2-D input with more than one column (one-hot labels or softmax
      probabilities): argmax over the last axis.
    - 1-D / single-column float input (binary labels or sigmoid
      probabilities): threshold at 0.5 (np.round rounds 0.5 down, so the
      original behavior is kept for the binary case).
    - Integer input (sparse labels): flattened and returned as int.
    """
    arr = np.asarray(values)
    if arr.ndim >= 2 and arr.shape[-1] > 1:
        return arr.argmax(axis=-1).astype(int)
    arr = arr.reshape(-1)
    if np.issubdtype(arr.dtype, np.floating):
        return (arr > 0.5).astype(int)
    return arr.astype(int)


def collect_labels_and_predictions(model, dataset):
    """Run `model` over `dataset` batch by batch; return aligned (y_true, y_pred) indices.

    Labels and predictions come from the same pass over the dataset, so they
    stay aligned even if the dataset reshuffles on every iteration (as a
    dataset built with shuffling enabled does). Iterating once for labels and
    again for model.predict would silently mismatch them.
    """
    true_batches, pred_batches = [], []
    for images, labels in dataset:
        true_batches.append(to_class_indices(labels))
        pred_batches.append(to_class_indices(model.predict_on_batch(images)))
    return np.concatenate(true_batches), np.concatenate(pred_batches)


# Assumes `model` and `validation_dataset` (yielding (images, labels) batches)
# were created by earlier cells.
val_labels, val_predictions = collect_labels_and_predictions(model, validation_dataset)

# Prefer class names recorded on the dataset (present on datasets returned by
# image_dataset_from_directory, lost after .map/.prefetch); otherwise use indices.
class_names = getattr(validation_dataset, "class_names", None)
if class_names is None:
    n_classes = int(max(val_labels.max(), val_predictions.max())) + 1
    class_names = [str(i) for i in range(n_classes)]

# Pin `labels` so the matrix is square and lines up with the tick labels even
# when some class never appears in the true or predicted labels.
cm = confusion_matrix(val_labels, val_predictions, labels=np.arange(len(class_names)))

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel('Predicted Label')
ax.set_ylabel('True Label')
ax.set_title('Confusion Matrix')
plt.show()