In [None]:
import os
import shutil
import random
import numpy as np
from sklearn.model_selection import train_test_split

def create_split_dirs(output_base_path, split_names=('train', 'val', 'test')):
    """Create one sub-directory per split under ``output_base_path``.

    Args:
        output_base_path (str): Directory that will contain the split folders.
            Missing parent directories are created as well.
        split_names (iterable of str): Names of the split folders to create
            (default: 'train', 'val', 'test').

    Directories that already exist are left untouched, so the function can be
    called repeatedly.
    """
    for split in split_names:
        # exist_ok=True replaces the separate "exists?" check, which was racy
        # and needed a second filesystem call.
        os.makedirs(os.path.join(output_base_path, split), exist_ok=True)

def split_dataset(dataset_path, output_base_path, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, random_seed=42):
    """
    Split a class-per-folder dataset into train, validation, and test sets.

    Every immediate sub-directory of ``dataset_path`` is treated as one class.
    Each class is split on its own with the same ratios, so the class
    distribution is preserved. Files are copied (not moved) to
    ``output_base_path/<split>/<class_name>/`` and the per-class split sizes
    are printed.

    Args:
        dataset_path (str): Path to the original dataset directory
        output_base_path (str): Path to save the split datasets
        train_ratio (float): Proportion of data for training (default: 0.7)
        val_ratio (float): Proportion of data for validation (default: 0.15)
        test_ratio (float): Proportion of data for testing (default: 0.15)
        random_seed (int): Seed for reproducibility

    Raises:
        ValueError: If the ratios do not sum to 1.0 (tolerance 0.001) or if
            any ratio is not strictly positive.

    Notes:
        - Seeds the global ``random`` and ``numpy.random`` generators.
        - Classes with very few files can still make ``train_test_split``
          raise ``ValueError`` (a split would be empty); that case is not
          handled here.
    """
    # Validate ratios
    ratios = (train_ratio, val_ratio, test_ratio)
    if abs(sum(ratios) - 1.0) > 0.001:
        raise ValueError("Train, validation, and test ratios must sum to 1.0")
    # A zero or negative ratio would otherwise only fail later, inside
    # train_test_split, after output directories had already been created.
    if any(ratio <= 0 for ratio in ratios):
        raise ValueError("Train, validation, and test ratios must each be greater than 0")

    # Set random seed for reproducibility
    random.seed(random_seed)
    np.random.seed(random_seed)

    # Create train, val, test directories
    create_split_dirs(output_base_path)

    # Proportion of the non-training remainder that goes to validation
    val_size = val_ratio / (val_ratio + test_ratio)

    # Process each class. os.listdir returns entries in arbitrary order and the
    # shuffle below consumes the shared RNG, so the classes are visited in
    # sorted order to make the result depend only on the seed.
    for class_name in sorted(os.listdir(dataset_path)):
        class_path = os.path.join(dataset_path, class_name)
        if not os.path.isdir(class_path):
            continue

        # Get list of images in the class, sorted for the same reason as above
        images = sorted(
            f for f in os.listdir(class_path)
            if os.path.isfile(os.path.join(class_path, f))
        )
        if not images:
            print(f"Warning: No images found in {class_name}")
            continue

        # Shuffle images
        random.shuffle(images)

        # Split images into train, val, test
        train_images, temp_images = train_test_split(images, train_size=train_ratio, random_state=random_seed)
        val_images, test_images = train_test_split(temp_images, train_size=val_size, random_state=random_seed)

        # Create the class directory in each split and copy its images there
        split_images = (
            ('train', train_images),
            ('val', val_images),
            ('test', test_images),
        )
        for split_name, selected in split_images:
            split_class_path = os.path.join(output_base_path, split_name, class_name)
            os.makedirs(split_class_path, exist_ok=True)
            for img in selected:
                shutil.copy(os.path.join(class_path, img), os.path.join(split_class_path, img))

        # Print split counts for verification
        print(f"Class: {class_name}")
        print(f"  Train: {len(train_images)} images")
        print(f"  Validation: {len(val_images)} images")
        print(f"  Test: {len(test_images)} images")

def get_class_counts(dataset_path):
    """Return a dict mapping each class folder name to its number of files.

    Only immediate sub-directories of ``dataset_path`` count as classes, and
    only regular files directly inside a class folder are counted.
    """
    counts = {}
    for entry in os.listdir(dataset_path):
        folder = os.path.join(dataset_path, entry)
        # Loose files next to the class folders are not classes.
        if not os.path.isdir(folder):
            continue
        counts[entry] = sum(
            1
            for name in os.listdir(folder)
            if os.path.isfile(os.path.join(folder, name))
        )
    return counts

if __name__ == "__main__":
    # Source dataset and destination folder -- adjust these for your machine.
    input_dataset_path = "C:/Users/Rasel/Downloads/Compressed/archive_2"
    output_split_path = "C:/Users/Rasel/Desktop/output dataset"

    # Report how many files each class has before anything is copied.
    original_counts = get_class_counts(input_dataset_path)
    print("Original class counts:", original_counts)

    # Perform the 70/15/15 split.
    split_dataset(
        input_dataset_path,
        output_split_path,
        train_ratio=0.7,
        val_ratio=0.15,
        test_ratio=0.15,
    )

    # Re-count the files on disk in each split folder as a sanity check.
    for split in ('train', 'val', 'test'):
        split_path = os.path.join(output_split_path, split)
        split_counts = get_class_counts(split_path)
        print(f"\n{split.capitalize()} class counts:", split_counts)

Original class counts: {'[Malignant] Pro-B': 979, 'Benign': 979, '[Malignant] Pre-B': 979, '[Malignant] early Pre-B': 979}
Class: [Malignant] Pro-B
  Train: 685 images
  Validation: 147 images
  Test: 147 images
Class: Benign
  Train: 685 images
  Validation: 147 images
  Test: 147 images
Class: [Malignant] Pre-B
  Train: 685 images
  Validation: 147 images
  Test: 147 images
Class: [Malignant] early Pre-B
  Train: 685 images
  Validation: 147 images
  Test: 147 images

Train class counts: {'[Malignant] Pro-B': 685, 'Benign': 685, '[Malignant] Pre-B': 685, '[Malignant] early Pre-B': 685}

Val class counts: {'[Malignant] Pro-B': 147, 'Benign': 147, '[Malignant] Pre-B': 147, '[Malignant] early Pre-B': 147}

Test class counts: {'[Malignant] Pro-B': 147, 'Benign': 147, '[Malignant] Pre-B': 147, '[Malignant] early Pre-B': 147}
