# PCAM Dataset Extractor - Quick Start Guide

This script extracts and downsamples the PatchCamelyon (PCAM) dataset from HDF5 files while preserving class balance through stratified sampling. For each split, the extractor first reads the labels, then randomly selects indices so that each class keeps its original share, and finally reads only the selected images from the image file. Because images are read one at a time rather than loaded all at once, memory use stays low even for the full dataset.
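The label-first selection described above can be illustrated on a toy label array. This is a minimal sketch, not the script's own function: it uses NumPy's `default_rng` Generator instead of the global `np.random.seed` that `get_stratified_indices` relies on, so it picks different (but equally reproducible) indices. The name `stratified_indices` is hypothetical, and binary 0/1 labels are assumed, as in PCam.

```python
import numpy as np

def stratified_indices(labels, n_samples, seed=42):
    """Return sorted indices whose 0/1 class mix mirrors `labels` (sketch)."""
    labels = np.asarray(labels).ravel()
    rng = np.random.default_rng(seed)  # local generator; global NumPy state is untouched
    idx0 = np.flatnonzero(labels == 0)
    idx1 = np.flatnonzero(labels == 1)
    # Class 0 gets its proportional share; class 1 takes the remainder
    n0 = int(n_samples * len(idx0) / len(labels))
    n1 = n_samples - n0
    n0, n1 = min(n0, len(idx0)), min(n1, len(idx1))  # never request more than exist
    chosen = np.concatenate([
        rng.choice(idx0, size=n0, replace=False),
        rng.choice(idx1, size=n1, replace=False),
    ])
    return np.sort(chosen)  # ascending order suits sequential reads from the H5 file

labels = np.array([0] * 60 + [1] * 40)  # toy split: 60% class 0, 40% class 1
sel = stratified_indices(labels, 10)
print(len(sel), int((labels[sel] == 1).sum()))  # 10 4
```

Only the label array is needed to decide which images to keep, which is why the image file can be read selectively afterwards.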

To create a custom subset, modify three key values in the `if __name__ == "__main__":` section at the bottom:
- **`BASE_DIR`**: Set to the directory containing the PCam H5 files (e.g., `Path(r"E:\PCam")`)
- **`OUTPUT_DIR`**: Set to where you want extracted images saved (e.g., `Path(r"E:\PCam_Extracted_50k")`)
- **`TARGET_SIZE`**: Set to your desired total dataset size (e.g., `50000` for 50k images, `10000` for 10k images)
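Before setting `BASE_DIR`, it may help to confirm that the directory holds the six files the script looks for. The file names below follow the pattern built in `extract_and_downsample_pcam`; the helper `missing_pcam_files` is a hypothetical convenience, not part of the script. Like the script, it treats a directory with the expected name as present, since `find_h5_file` can resolve an H5 file inside such a directory.

```python
from pathlib import Path

SPLITS = ("train", "valid", "test")

def missing_pcam_files(base_dir):
    """List expected PCam H5 names that are absent from base_dir (sketch)."""
    base_dir = Path(base_dir)
    expected = [
        f"camelyonpatch_level_2_split_{split}_{kind}.h5"
        for split in SPLITS
        for kind in ("x", "y")  # x = images, y = labels
    ]
    # Path.exists() is also True for directories, matching the script's check
    return [name for name in expected if not (base_dir / name).exists()]

# Example: print(missing_pcam_files(r"E:\PCam"))  # [] when all six are present
```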

The script automatically splits your target size into 80% train, 10% validation, and 10% test sets. Within each split, sampling preserves the class distribution of the corresponding original PCam split. `RANDOM_SEED = 42` makes the selection reproducible: the same seed applied to the same input files always selects the same images, so you can verify results or share exact subsets with collaborators.
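The split arithmetic can be checked by hand. The sketch below restates the rule used by `calculate_split_sizes`: each share is truncated to an integer, and any shortfall from truncation is added to the training split so the three sizes always sum to the target. The function name `split_sizes` is illustrative.

```python
def split_sizes(total, train=0.8, val=0.1, test=0.1):
    """Truncate each share to an int, then give any rounding shortfall to train."""
    sizes = {
        "train": int(total * train),
        "valid": int(total * val),
        "test": int(total * test),
    }
    sizes["train"] += total - sum(sizes.values())  # absorb truncation loss
    return sizes

print(split_sizes(100000))  # {'train': 80000, 'valid': 10000, 'test': 10000}
print(split_sizes(12345))   # {'train': 9877, 'valid': 1234, 'test': 1234}
```

At `TARGET_SIZE = 100000`, this keeps 80,000 of 262,144 training patches and 10,000 of 32,768 validation and test patches each, about 30.5% of every original split.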

**Note:** Originally created to run in VS Code, not Google Colab.
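Each split is written as `OUTPUT_DIR/<split>/images/NNNNN.png` with a matching `OUTPUT_DIR/<split>/labels/NNNNN.txt` that holds a single `0` or `1`. A minimal sketch for pairing them back up when loading the subset, assuming the default `save_format='png'`; the function name `read_split_labels` is hypothetical.

```python
from pathlib import Path

def read_split_labels(output_dir, split):
    """Map each extracted image path in one split to its integer label (sketch)."""
    split_dir = Path(output_dir) / split
    pairs = {}
    for label_file in sorted((split_dir / "labels").glob("*.txt")):
        # Image and label share the zero-padded stem, e.g. 00042.png / 00042.txt
        image_file = split_dir / "images" / f"{label_file.stem}.png"
        pairs[image_file] = int(label_file.read_text().strip())
    return pairs
```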

100K Images Subset

import h5py # Reading HDF5 (.h5) file format
import numpy as np # Numerical operations and array manipulation
from pathlib import Path # Cross-platform file path handling
from PIL import Image # Image processing and saving
from tqdm import tqdm # Display progress bars during loops

def calculate_split_sizes(total_target=100000, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):
    """
    Purpose:
    --------
    Calculate how many images to keep in each split.
    
    Parameters:
    -----------
    total_target : int
        Total number of images desired in downsampled dataset
    train_ratio, val_ratio, test_ratio : float
        Ratios for train/val/test split
    
    Returns:
    --------
    dict : Number of images to keep for each split
    """
    # Calculate number of training images by multiplying total by train ratio
    train_size = int(total_target * train_ratio)
    # Calculate number of validation images by multiplying total by validation ratio
    val_size = int(total_target * val_ratio)
    # Calculate number of test images by multiplying total by test ratio
    test_size = int(total_target * test_ratio)
    
    # Adjust to ensure exact total
    # Calculate the sum of all three split sizes to check for rounding errors
    actual_total = train_size + val_size + test_size
    # Check if actual total differs from target due to integer rounding
    if actual_total != total_target:
        # Add the difference to the training set to reach the exact target total
        train_size += (total_target - actual_total)
    
    # Return a dictionary mapping split names to their calculated sizes
    return {
        'train': train_size,
        'valid': val_size,
        'test': test_size
    }


def get_stratified_indices(labels, n_samples, random_seed=42):
    """
    Purpose:
    --------
    Get stratified sample indices to maintain class balance.
    Uses a fixed seed for reproducibility.
    
    Parameters:
    -----------
    labels : numpy.ndarray
        Array of binary labels (0 or 1)
    n_samples : int
        Number of samples to select
    random_seed : int
        Random seed for reproducibility
    
    Returns:
    --------
    selected_indices : numpy.ndarray
        Sorted array of selected indices
    """
    # Set random seed for reproducible sampling across runs
    # Initializes NumPy's global pseudo-random number generator (PRNG) with a specific starting point
    np.random.seed(random_seed)
    
    # Get indices for each class
    # Find all indices where the label is 0 (no tumor class)
    class_0_indices = np.where(labels == 0)[0]
    # Find all indices where the label is 1 (tumor class)
    class_1_indices = np.where(labels == 1)[0]
    
    # Calculate how many of each class to keep (proportional)
    # Count total number of class 0 samples
    total_class_0 = len(class_0_indices)
    # Count total number of class 1 samples
    total_class_1 = len(class_1_indices)
    # Calculate total number of samples across both classes
    total = total_class_0 + total_class_1
    
    # Calculate how many class 0 samples to keep, proportional to original distribution
    n_class_0 = int(n_samples * (total_class_0 / total))
    # Calculate class 1 samples as remainder to ensure exact n_samples total
    n_class_1 = n_samples - n_class_0
    
    # Ensure we don't try to sample more than available
    # Cap class 0 at the number of class 0 images actually available
    n_class_0 = min(n_class_0, total_class_0)
    # Cap class 1 at the number of class 1 images actually available
    n_class_1 = min(n_class_1, total_class_1)
    
    # Randomly sample from each class (but reproducibly due to the set seed)
    # Randomly select n_class_0 indices from class 0 without replacement
    selected_class_0 = np.random.choice(class_0_indices, size=n_class_0, replace=False)
    # Randomly select n_class_1 indices from class 1 without replacement
    selected_class_1 = np.random.choice(class_1_indices, size=n_class_1, replace=False)
    
    # Combine and sort indices
    # Combine the selected indices from both classes into a single array
    selected_indices = np.concatenate([selected_class_0, selected_class_1])
    # Sort the indices in ascending order for sequential access
    selected_indices = np.sort(selected_indices)
    
    return selected_indices # Return the sorted array of selected indices


def extract_and_downsample_split(images_h5_path, labels_h5_path, output_dir, 
                                  split_name, n_samples, random_seed=42, 
                                  save_format='png'):
    """
    Purpose:
    --------
    Extract and downsample a single split directly from H5 files.
    Only extracts the selected subset, never loading all images into memory.
    
    Parameters:
    -----------
    images_h5_path : str or Path
        Path to input images .h5 file
    labels_h5_path : str or Path
        Path to input labels .h5 file
    output_dir : str or Path
        Base directory where extracted images will be saved
    split_name : str
        Name of the split ('train', 'valid', 'test')
    n_samples : int
        Number of samples to extract
    random_seed : int
        Random seed for reproducibility
    save_format : str
        Image format to save ('png', 'jpg', etc.)
    
    Returns:
    --------
    selected_indices : numpy.ndarray
        Selected indices (for verification)
    """
    # Ensure the path to the images is a Path object
    images_h5_path = Path(images_h5_path)
    # Ensure the path to the labels is a Path object
    labels_h5_path = Path(labels_h5_path)
    # Convert the output directory (string or Path) to a Path object
    output_dir = Path(output_dir)
    
    # Create split-specific directories: output_dir/train, output_dir/valid, output_dir/test
    split_dir = output_dir / split_name
    # Create path for images subdirectory within the split
    images_output_dir = split_dir / "images"
    # Create path for labels subdirectory within the split
    labels_output_dir = split_dir / "labels"
    # Create images directory and all parent directories if they don't already exist
    images_output_dir.mkdir(parents=True, exist_ok=True)
    # Create labels directory and all parent directories if they don't already exist
    labels_output_dir.mkdir(parents=True, exist_ok=True)
    
    print(f"\n{'='*70}") # Print separator line
    # Print header showing which split is being processed
    print(f"Processing {split_name.upper()} split")
    print(f"{'='*70}") # Print separator line
    # Print the source h5 filename
    print(f"Source: {images_h5_path.name}")
    # Print the target number of samples to extract
    print(f"Target samples: {n_samples}")
    
    # Step 1: Load labels to determine stratified sampling
    print("\nStep 1: Loading labels for stratification...")

    try: # Try to find the actual h5 file (handles both file and directory paths)
        # Call helper function to find the h5 labels file
        actual_labels_path = find_h5_file(labels_h5_path)
        # Print the resolved path to the labels file
        print(f"  Found labels H5: {actual_labels_path}")
    # Catch any file not found errors
    except FileNotFoundError as e:
        print(f"  ERROR: {e}") # Print error message
        raise # Re-raise the exception to stop execution
    # Open the labels h5 file in read mode
    with h5py.File(actual_labels_path, 'r') as h5_file:
        # Read all labels from the 'y' dataset in the h5 file
        labels = h5_file['y'][:]
        # Remove singleton dimensions (e.g., (n, 1) or (n, 1, 1, 1) -> (n,))
        labels = labels.squeeze()
    
    # Get the total number of labels (original dataset size)
    original_size = len(labels)
    # Count how many samples have label 0 (no tumor)
    class_0_count = np.sum(labels == 0)
    # Count how many samples have label 1 (tumor)
    class_1_count = np.sum(labels == 1)
    
    # Print original dataset size with thousands separator
    print(f"  Original size: {original_size:,}")
    # Print class 0 count and percentage
    print(f"  Class 0 (no tumor): {class_0_count:,} ({class_0_count/original_size*100:.1f}%)")
    # Print class 1 count and percentage
    print(f"  Class 1 (tumor):    {class_1_count:,} ({class_1_count/original_size*100:.1f}%)")
    
    # Step 2: Get stratified sample indices
    # Print step 2 header with the random seed value used
    print("\nStep 2: Selecting stratified sample (reproducible with seed={})...".format(random_seed))
    # Call function to get stratified sample indices maintaining class balance
    selected_indices = get_stratified_indices(labels, n_samples, random_seed)
    # Get the labels corresponding to the selected indices
    selected_labels = labels[selected_indices]
    
    # Count class 0 samples in the selected subset
    new_class_0 = np.sum(selected_labels == 0)
    # Count class 1 samples in the selected subset
    new_class_1 = np.sum(selected_labels == 1)
    
    # Print number of selected samples
    print(f"  Selected: {len(selected_indices):,} samples")
    # Print class 0 count and percentage in selected subset
    print(f"  Class 0 (no tumor): {new_class_0:,} ({new_class_0/len(selected_indices)*100:.1f}%)")
    # Print class 1 count and percentage in selected subset
    print(f"  Class 1 (tumor):    {new_class_1:,} ({new_class_1/len(selected_indices)*100:.1f}%)")
    
    # Step 3: Extract only selected images (memory efficient!)
    print(f"\nStep 3: Extracting {len(selected_indices):,} selected images...")
    
    try: # Try to find the actual h5 images file (manages both directory and file path)
        # Call helper function to locate the h5 images file
        actual_images_path = find_h5_file(images_h5_path)
        # Print the resolved path to the images file
        print(f"  Found images H5: {actual_images_path}")
    # Catch any file not found errors
    except FileNotFoundError as e:
        print(f"  ERROR: {e}") # Print error message
        raise # Re-raise the exception to stop execution
    
    # Open the h5 images file in read mode
    with h5py.File(actual_images_path, 'r') as h5_file:
        # Get the reference to the 'x' dataset containing images (without loading all the images into memory)
        images_data = h5_file['x']
        
        # Extract only the selected indices
        extracted_count = 0 # Initialize counter for successfully extracted images
        # Loop through selected indices with enumeration for new sequential index
        for new_idx, original_idx in enumerate(tqdm(selected_indices, desc=f"Extracting {split_name}")):
            img_array = images_data[original_idx] # Read single image array from the h5 file at the original index
            
            # Convert the numpy array to a PIL Image object (specify RGB (color) mode)
            img = Image.fromarray(img_array.astype('uint8'), 'RGB')
            # Create the output path with zero-padded 5-digit filename
            output_path = images_output_dir / f"{new_idx:05d}.{save_format}"
            # Save the image to disk in the specified format
            img.save(output_path)
            
            # Save corresponding label
            # Get the label value for this image from the selected labels
            label_value = selected_labels[new_idx]
            # Create the output path for the label text file
            label_path = labels_output_dir / f"{new_idx:05d}.txt"
            with open(label_path, 'w') as f: # Open the label file for writing
                # Write the label as an integer string to the file
                f.write(str(int(label_value)))
            
            extracted_count += 1 # Increment the extraction counter
    
    # Print success message with the extraction count
    print(f"\n✓ Successfully extracted {extracted_count:,} images")
    # Print the split directory path
    print(f"  Split directory: {split_dir}")
    # Print the images subdirectory path
    print(f"  Images: {images_output_dir}")
    # Print the labels subdirectory path
    print(f"  Labels: {labels_output_dir}")
    
    return selected_indices # Return the selected indices array for verification


def extract_and_downsample_pcam(base_dir, output_dir, target_size=50000,
                                 train_ratio=0.8, val_ratio=0.1, test_ratio=0.1,
                                 random_seed=42, save_format='png'):
    """
    Purpose:
    --------
    Extract and downsample entire PCAM dataset in one pass.
    Memory efficient: loads each split's labels in full, but reads images one at a time.
    
    Parameters:
    -----------
    base_dir : str or Path
        Directory containing original PCAM .h5 files
    output_dir : str or Path
        Directory where downsampled extracted images will be saved
    target_size : int
        Total number of images in downsampled dataset
    train_ratio, val_ratio, test_ratio : float
        Ratios for splits
    random_seed : int
        Random seed for reproducibility (CRITICAL for reproducibility)
    save_format : str
        Image format to save ('png', 'jpg', etc.)
    """
    # Convert the base directory string to a Path object
    base_dir = Path(base_dir)
    # Convert the output directory string to a Path object
    output_dir = Path(output_dir)
    # Create the output directory and all parent directories if they don't already exist
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Call helper function to calculate how many images should be in each split (train, valid, test)
    split_sizes = calculate_split_sizes(target_size, train_ratio, val_ratio, test_ratio)
    
    print("="*70) # Print header separator
    # Print main title
    print("PCAM EXTRACT & DOWNSAMPLE")
    print("="*70) # Print header separator
    # Print the base directory path
    print(f"Base directory: {base_dir}")
    # Print the output directory path
    print(f"Output directory: {output_dir}")
    # Print target total dataset size
    print(f"Target total size: {target_size:,}")
    # Print random seed for reproducibility
    print(f"Random seed: {random_seed} (ensures reproducibility)")
    # Print section header for split sizes
    print(f"\nTarget split sizes:")
    # Print train split size
    print(f"  Train: {split_sizes['train']:,} images")
    # Print validation split size
    print(f"  Valid: {split_sizes['valid']:,} images")
    # Print test split size
    print(f"  Test:  {split_sizes['test']:,} images")
    print("="*70) # Print header separator
    
    # Create list of split names to process
    splits = ['train', 'valid', 'test']
    # Initialize empty dictionary to store selected indices for each split
    indices_log = {}
    
    for split in splits: # Loop through each split name
        # Create path to the images h5 file for the current split
        images_h5 = base_dir / f"camelyonpatch_level_2_split_{split}_x.h5"
        # Create path to the labels h5 file for the current split
        labels_h5 = base_dir / f"camelyonpatch_level_2_split_{split}_y.h5"
        
        # Check if the given images h5 file exists
        if not images_h5.exists():
            # Print warning and skip this split if the file isn't found
            print(f"\nWARNING: {images_h5} not found, skipping {split}...")
            continue # Continue to next iteration of loop

        # Check if the given labels h5 file exists
        if not labels_h5.exists():
            # Print warning and skip this split if the file isn't found
            print(f"\nWARNING: {labels_h5} not found, skipping {split}...")
            continue # Continue to next iteration of loop
        
        # Call helper function to extract and downsample this split
        selected_indices = extract_and_downsample_split(
            images_h5_path=images_h5,
            labels_h5_path=labels_h5,
            output_dir=output_dir,
            split_name=split,
            n_samples=split_sizes[split],
            random_seed=random_seed,
            save_format=save_format
        )
        
        # Store the selected indices in the log dictionary
        indices_log[split] = selected_indices
    
    # Create path for file to save selected indices
    indices_file = output_dir / "selected_indices.npz"
    # Save all selected indices to an uncompressed .npz archive (one array per split)
    np.savez(indices_file, **indices_log)
    print(f"\n{'='*70}") # Print separator
    # Print confirmation of indices file save
    print(f"✓ Saved selected indices to {indices_file}")
    # Print note about file usage
    print("  (This file can be used to verify reproducibility)")
    
    # Print separator for completion section
    print("\n" + "="*70)
    # Print completion header
    print("EXTRACTION & DOWNSAMPLING COMPLETE")
    print("="*70) # Print separator
    # Print output directory path
    print(f"Downsampled dataset saved to: {output_dir}")
    # Calculate and print total images extracted across all splits
    print(f"\nTotal images extracted: {sum(len(indices_log[s]) for s in indices_log):,}")
    # Print section header for reproducibility instructions
    print("\nTo verify reproducibility:")
    # Print instruction with the random seed used
    print(f"  Run this script again with random_seed={random_seed}")
    # Print comparison instruction
    print("  Compare the selected_indices.npz files")
    print("="*70) # Print final separator


def find_h5_file(base_path):
    """
    Purpose:
    --------
    Find the actual H5 file, handling both file and directory structures.
    
    Parameters:
    -----------
    base_path : Path
        Base path that might be a file or directory containing the H5 file
    
    Returns:
    --------
    Path : Path to the actual H5 file
    """
    # Convert the base path to a Path object
    base_path = Path(base_path)
    
    # Check if the path points to an existing file
    if base_path.is_file():
        # Return the path as-is since it's already a file
        return base_path
    
    # Check if the path points to an existing directory
    if base_path.is_dir():
        # Construct path to file with same name inside directory
        same_name_file = base_path / base_path.name
        # Check if that same-name file exists
        if same_name_file.is_file():
            # Return the same-name file path
            return same_name_file
        
        # Search for any h5 files in the directory
        h5_files = list(base_path.glob("*.h5"))
        # Check if any h5 files were found
        if h5_files:
            # Return the first h5 file found
            return h5_files[0]
    
    # Raise error if no h5 file could be found
    raise FileNotFoundError(f"Could not find H5 file at {base_path}")


def print_dataset_summary(base_dir):
    """
    Purpose:
    --------
    Print summary statistics for all splits before extraction.
    
    Parameters:
    -----------
    base_dir : str or Path
        Base directory containing all PCAM .h5 files
    """
    # Convert base directory to Path object
    base_dir = Path(base_dir)
    # Create list of split names to analyze
    splits = ['train', 'valid', 'test']
    
    print("="*70) # Print header separator
    # Print summary title
    print("PCAM Dataset Summary (Before Downsampling)")
    print("="*70) # Print header separator
    
    total_images = 0 # Initialize counter for total images across all splits
    total_class_0 = 0 # Initialize counter for total class 0 samples
    total_class_1 = 0 # Initialize counter for total class 1 samples
    
    # Loop through each split (train, valid, test)
    for split in splits:
        # Construct path to labels h5 file for the current split
        labels_h5_base = base_dir / f"camelyonpatch_level_2_split_{split}_y.h5"
        
        # Check if the base path exists
        if not labels_h5_base.exists():
            # Print message if the file isn't found
            print(f"{split.capitalize():10s}: File not found")
            continue # Continue to the next split (next loop iteration)
        
        try: # Try to find the actual h5 file
            # Call helper function to locate the labels file
            labels_h5 = find_h5_file(labels_h5_base)
        # Catch file not found errors
        except FileNotFoundError:
            # Print error message if h5 file is not found in the directory
            print(f"{split.capitalize():10s}: H5 file not found in directory")
            continue # Continue to the next split (next loop iteration)
        
        # Open the labels h5 file in read mode
        with h5py.File(labels_h5, 'r') as h5_file:
            # Read all labels from the 'y' dataset
            labels = h5_file['y'][:]
            # Remove extra dimensions from labels array
            labels = labels.squeeze()
            
            # Get total number of samples in this split
            total = len(labels)
            # Count class 0 samples in this split
            class_0 = np.sum(labels == 0)
            # Count class 1 samples in this split
            class_1 = np.sum(labels == 1)
            
            # Print split name header with capitalization and padding
            print(f"\n{split.capitalize():10s}:")
            # Print total images in this split
            print(f"  Total images: {total:,}")
            # Print class 0 count and percentage
            print(f"  Class 0 (no tumor): {class_0:,} ({class_0/total*100:.1f}%)")
            # Print class 1 count and percentage
            print(f"  Class 1 (tumor):    {class_1:,} ({class_1/total*100:.1f}%)")
            
            total_images += total # Add current split's total to overall total
            total_class_0 += class_0 # Add current split's class 0 count to overall class 0 count
            total_class_1 += class_1 # Add current split's class 1 count to overall class 1 count
    
    print(f"\n{'='*70}") # Print separator for summary section
    print("Total Dataset:") # Print overall dataset header
    # Guard against division by zero when no label files were found
    if total_images == 0:
        print("  No label files found; nothing to summarize.")
    else:
        # Print total images across all splits
        print(f"  Total images: {total_images:,}")
        # Print total class 0 samples and percentage
        print(f"  Class 0 (no tumor): {total_class_0:,} ({total_class_0/total_images*100:.1f}%)")
        # Print total class 1 samples and percentage
        print(f"  Class 1 (tumor):    {total_class_1:,} ({total_class_1/total_images*100:.1f}%)")
    print("="*70) # Print closing separator

# Actual execution
if __name__ == "__main__":
    # Define the base directory containing the PCAM h5 files
    BASE_DIR = Path(r"E:\PCam")
    # Define the output directory for extracted images
    OUTPUT_DIR = Path(r"E:\PCam_Extracted_100k")
    TARGET_SIZE = 100000 # Set target size for downsampled dataset
    # Set random seed for reproducible sampling (CRITICAL)
    RANDOM_SEED = 42
    
    # Print original dataset summary header
    print("Analyzing original dataset...\n")
    # Call function to print summary of original dataset
    print_dataset_summary(BASE_DIR)
    
    # Call function to calculate split sizes
    split_sizes = calculate_split_sizes(TARGET_SIZE)
    
    # Print header for extraction plan
    print("\n\nDataset Extraction Plan:")
    print("-" * 70) # Print separator line
    # Define original split sizes in dictionary
    original_splits = {
        'train': 262144,
        'valid': 32768,
        'test': 32768
    }
    
    # Loop through each split to print extraction plan
    for split in ['train', 'valid', 'test']:
        # Get original size for this split
        original = original_splits[split]
        # Get new size for this split
        new = split_sizes[split]
        # Calculate percentage of original that will be kept
        kept_pct = (new / original) * 100
        # Print extraction details for this split
        print(f"{split.capitalize():6s}: Extract {new:6,} of {original:,} "
              f"({kept_pct:.1f}% of original)")
    print("-" * 70) # Print separator line
    
    # Print message regarding actions
    print("\nThis will extract and downsample images directly from H5 files.")
    # Print note regarding memory
    print("Only the selected subset will be saved to disk (memory efficient!).")
    # Prompt user for confirmation
    response = input("\nProceed? (yes/no): ")
    
    # Check if user confirmed with yes or y
    if response.lower() in ['yes', 'y']:
        # Run the extraction and downsampling function
        extract_and_downsample_pcam(
            base_dir=BASE_DIR,
            output_dir=OUTPUT_DIR,
            target_size=TARGET_SIZE,
            random_seed=RANDOM_SEED,
            save_format='png'
        )
    else: # If user did not confirm
        # Do nothing and print the cancellation message
        print("Cancelled.")
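The script suggests comparing `selected_indices.npz` files to verify reproducibility. One way to do that is sketched below; `same_selection` is a hypothetical helper, and the second path in the usage comment is only a placeholder for a rerun's output directory.

```python
import numpy as np

def same_selection(npz_a, npz_b):
    """True if two selected_indices.npz files hold identical arrays for every split."""
    with np.load(npz_a) as a, np.load(npz_b) as b:
        if sorted(a.files) != sorted(b.files):
            return False  # different splits were extracted
        return all(np.array_equal(a[name], b[name]) for name in a.files)

# Usage (paths are placeholders):
# same_selection(r"E:\PCam_Extracted_100k\selected_indices.npz",
#                r"E:\PCam_Extracted_100k_rerun\selected_indices.npz")
```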

50K Images Subset

import h5py # Reading HDF5 (.h5) file format
import numpy as np # Numerical operations and array manipulation
from pathlib import Path # Cross-platform file path handling
from PIL import Image # Image processing and saving
from tqdm import tqdm # Display progress bars during loops

def calculate_split_sizes(total_target=50000, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):
    """
    Purpose:
    --------
    Calculate how many images to keep in each split.
    
    Parameters:
    -----------
    total_target : int
        Total number of images desired in downsampled dataset
    train_ratio, val_ratio, test_ratio : float
        Ratios for train/val/test split
    
    Returns:
    --------
    dict : Number of images to keep for each split
    """
    # Calculate number of training images by multiplying total by train ratio
    train_size = int(total_target * train_ratio)
    # Calculate number of validation images by multiplying total by validation ratio
    val_size = int(total_target * val_ratio)
    # Calculate number of test images by multiplying total by test ratio
    test_size = int(total_target * test_ratio)
    
    # Adjust to ensure exact total
    # Calculate the sum of all three split sizes to check for rounding errors
    actual_total = train_size + val_size + test_size
    # Check if actual total differs from target due to integer rounding
    if actual_total != total_target:
        # Add the difference to the training set to reach the exact target total
        train_size += (total_target - actual_total)
    
    # Return a dictionary mapping split names to their calculated sizes
    return {
        'train': train_size,
        'valid': val_size,
        'test': test_size
    }


def get_stratified_indices(labels, n_samples, random_seed=42):
    """
    Purpose:
    --------
    Get stratified sample indices to maintain class balance.
    Uses a fixed seed for reproducibility.
    
    Parameters:
    -----------
    labels : numpy.ndarray
        Array of binary labels (0 or 1)
    n_samples : int
        Number of samples to select
    random_seed : int
        Random seed for reproducibility
    
    Returns:
    --------
    selected_indices : numpy.ndarray
        Sorted array of selected indices
    """
    # Set random seed for reproducible sampling across runs
    # Initializes NumPy's global pseudo-random number generator (PRNG) with a specific starting point
    np.random.seed(random_seed)
    
    # Get indices for each class
    # Find all indices where the label is 0 (no tumor class)
    class_0_indices = np.where(labels == 0)[0]
    # Find all indices where the label is 1 (tumor class)
    class_1_indices = np.where(labels == 1)[0]
    
    # Calculate how many of each class to keep (proportional)
    # Count total number of class 0 samples
    total_class_0 = len(class_0_indices)
    # Count total number of class 1 samples
    total_class_1 = len(class_1_indices)
    # Calculate total number of samples across both classes
    total = total_class_0 + total_class_1
    
    # Calculate how many class 0 samples to keep, proportional to original distribution
    n_class_0 = int(n_samples * (total_class_0 / total))
    # Calculate class 1 samples as remainder to ensure exact n_samples total
    n_class_1 = n_samples - n_class_0
    
    # Ensure we don't try to sample more than available
    # Cap class 0 at the number of class 0 images actually available
    n_class_0 = min(n_class_0, total_class_0)
    # Cap class 1 at the number of class 1 images actually available
    n_class_1 = min(n_class_1, total_class_1)
    
    # Randomly sample from each class (but reproducibly due to the set seed)
    # Randomly select n_class_0 indices from class 0 without replacement
    selected_class_0 = np.random.choice(class_0_indices, size=n_class_0, replace=False)
    # Randomly select n_class_1 indices from class 1 without replacement
    selected_class_1 = np.random.choice(class_1_indices, size=n_class_1, replace=False)
    
    # Combine and sort indices
    # Combine the selected indices from both classes into a single array
    selected_indices = np.concatenate([selected_class_0, selected_class_1])
    # Sort the indices in ascending order for sequential access
    selected_indices = np.sort(selected_indices)
    
    return selected_indices # Return the sorted array of selected indices


def extract_and_downsample_split(images_h5_path, labels_h5_path, output_dir, 
                                  split_name, n_samples, random_seed=42, 
                                  save_format='png'):
    """
    Purpose:
    --------
    Extract and downsample a single split directly from H5 files.
    Only extracts the selected subset, never loading all images into memory.
    
    Parameters:
    -----------
    images_h5_path : str or Path
        Path to input images .h5 file
    labels_h5_path : str or Path
        Path to input labels .h5 file
    output_dir : str or Path
        Base directory where extracted images will be saved
    split_name : str
        Name of the split ('train', 'valid', 'test')
    n_samples : int
        Number of samples to extract
    random_seed : int
        Random seed for reproducibility
    save_format : str
        Image format to save ('png', 'jpg', etc.)
    
    Returns:
    --------
    selected_indices : numpy.ndarray
        Selected indices (for verification)
    """
    # Ensure the path to the images is a Path object
    images_h5_path = Path(images_h5_path)
    # Ensure the path to the labels is a Path object
    labels_h5_path = Path(labels_h5_path)
    # Convert the output directory (string or Path) to a Path object
    output_dir = Path(output_dir)
    
    # Create split-specific directories: output_dir/train, output_dir/valid, output_dir/test
    split_dir = output_dir / split_name
    # Create path for images subdirectory within the split
    images_output_dir = split_dir / "images"
    # Create path for labels subdirectory within the split
    labels_output_dir = split_dir / "labels"
    # Create images directory and all parent directories if they don't already exist
    images_output_dir.mkdir(parents=True, exist_ok=True)
    # Create labels directory and all parent directories if they don't already exist
    labels_output_dir.mkdir(parents=True, exist_ok=True)
    
    print(f"\n{'='*70}") # Print separator line
    # Print header showing which split is being processed
    print(f"Processing {split_name.upper()} split")
    print(f"{'='*70}") # Print separator line
    # Print the source h5 filename
    print(f"Source: {images_h5_path.name}")
    # Print the target number of samples to extract
    print(f"Target samples: {n_samples}")
    
    # Step 1: Load labels to determine stratified sampling
    print("\nStep 1: Loading labels for stratification...")

    try: # Try to find the actual h5 file (handles both file and directory paths)
        # Call helper function to find the h5 labels file
        actual_labels_path = find_h5_file(labels_h5_path)
        # Print the resolved path to the labels file
        print(f"  Found labels H5: {actual_labels_path}")
    # Catch any file not found errors
    except FileNotFoundError as e:
        print(f"  ERROR: {e}") # Print error message
        raise # Re-raise the exception to stop execution
    # Open the labels h5 file in read mode
    with h5py.File(actual_labels_path, 'r') as h5_file:
        # Read all labels from the 'y' dataset in the h5 file
        labels = h5_file['y'][:]
        # Remove extra dimensions (converting shape (n,1) to (n,))
        labels = labels.squeeze()
    
    # Get the total number of labels (original dataset size)
    original_size = len(labels)
    # Count how many samples have label 0 (no tumor)
    class_0_count = np.sum(labels == 0)
    # Count how many samples have label 1 (tumor)
    class_1_count = np.sum(labels == 1)
    
    # Print original dataset size with thousands separator
    print(f"  Original size: {original_size:,}")
    # Print class 0 count and percentage
    print(f"  Class 0 (no tumor): {class_0_count:,} ({class_0_count/original_size*100:.1f}%)")
    # Print class 1 count and percentage
    print(f"  Class 1 (tumor):    {class_1_count:,} ({class_1_count/original_size*100:.1f}%)")
    
    # Step 2: Get stratified sample indices
    # Print step 2 header with the random seed value used
    print("\nStep 2: Selecting stratified sample (reproducible with seed={})...".format(random_seed))
    # Call function to get stratified sample indices maintaining class balance
    selected_indices = get_stratified_indices(labels, n_samples, random_seed)
    # Get the labels corresponding to the selected indices
    selected_labels = labels[selected_indices]
    
    # Count class 0 samples in the selected subset
    new_class_0 = np.sum(selected_labels == 0)
    # Count class 1 samples in the selected subset
    new_class_1 = np.sum(selected_labels == 1)
    
    # Print number of selected samples
    print(f"  Selected: {len(selected_indices):,} samples")
    # Print class 0 count and percentage in selected subset
    print(f"  Class 0 (no tumor): {new_class_0:,} ({new_class_0/len(selected_indices)*100:.1f}%)")
    # Print class 1 count and percentage in selected subset
    print(f"  Class 1 (tumor):    {new_class_1:,} ({new_class_1/len(selected_indices)*100:.1f}%)")
    
    # Step 3: Extract only selected images (memory efficient!)
    print(f"\nStep 3: Extracting {len(selected_indices):,} selected images...")
    
    try: # Try to find the actual h5 images file (manages both directory and file path)
        # Call helper function to locate the h5 images file
        actual_images_path = find_h5_file(images_h5_path)
        # Print the resolved path to the images file
        print(f"  Found images H5: {actual_images_path}")
    # Catch any file not found errors
    except FileNotFoundError as e:
        print(f"  ERROR: {e}") # Print error message
        raise # Re-raise the exception to stop execution
    
    # Open the h5 images file in read mode
    with h5py.File(actual_images_path, 'r') as h5_file:
        # Get the reference to the 'x' dataset containing images (without loading all the images into memory)
        images_data = h5_file['x']
        
        # Extract only the selected indices
        extracted_count = 0 # Initialize counter for successfully extracted images
        # Zero-pad filenames to at least 5 digits, widening if a split exceeds 99,999 images,
        # so that lexicographic filename order always matches extraction order
        pad = max(5, len(str(len(selected_indices) - 1)))
        # Loop through selected indices with enumeration for new sequential index
        for new_idx, original_idx in enumerate(tqdm(selected_indices, desc=f"Extracting {split_name}")):
            img_array = images_data[original_idx] # Read single image array from the h5 file at the original index
            
            # Convert the numpy array to a PIL Image object (specify RGB (color) mode)
            img = Image.fromarray(img_array.astype('uint8'), 'RGB')
            # Create the output path with a zero-padded filename (e.g. 00000.png)
            output_path = images_output_dir / f"{new_idx:0{pad}d}.{save_format}"
            # Save the image to disk in the specified format
            img.save(output_path)
            
            # Save corresponding label
            # Get the label value for this image from the selected labels
            label_value = selected_labels[new_idx]
            # Create the output path for the label text file (same stem as the image)
            label_path = labels_output_dir / f"{new_idx:0{pad}d}.txt"
            with open(label_path, 'w') as f: # Open the label file for writing
                # Write the label as an integer string to the file
                f.write(str(int(label_value)))
            
            extracted_count += 1 # Increment the extraction counter
    
    # Print success message with the extraction count
    print(f"\n✓ Successfully extracted {extracted_count:,} images")
    # Print the split directory path
    print(f"  Split directory: {split_dir}")
    # Print the images subdirectory path
    print(f"  Images: {images_output_dir}")
    # Print the labels subdirectory path
    print(f"  Labels: {labels_output_dir}")
    
    return selected_indices # Return the selected indices array for verification
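

# --- Optional sketch (hypothetical helper, not part of the original pipeline) ---
# extract_and_downsample_split() writes each split as
#   <output_dir>/<split>/images/<index>.<save_format>
#   <output_dir>/<split>/labels/<index>.txt   (file contains "0" or "1")
# so an image and its label share the same filename stem. The hypothetical
# function below pairs them back up for training code. It sorts by the numeric
# stem rather than the string, so the order is correct whatever the padding width.
def load_split_pairs(split_dir, save_format='png'):
    """Return a list of (image_path, label) tuples for one extracted split."""
    split_dir = Path(split_dir)
    pairs = []
    # Sort label files numerically by their stem (e.g. "00042" -> 42)
    for label_path in sorted((split_dir / "labels").glob("*.txt"), key=lambda p: int(p.stem)):
        image_path = split_dir / "images" / f"{label_path.stem}.{save_format}"
        # Each label file holds a single integer written with str(int(label))
        label = int(label_path.read_text().strip())
        pairs.append((image_path, label))
    return pairs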


def extract_and_downsample_pcam(base_dir, output_dir, target_size=50000,
                                 train_ratio=0.8, val_ratio=0.1, test_ratio=0.1,
                                 random_seed=42, save_format='png'):
    """
    Purpose:
    --------
    Extract and downsample all three PCAM splits (train, valid, test) in one run.
    Memory efficient: images are read from the HDF5 file one at a time, so only the
    label array and a single image are held in memory at once.
    
    Parameters:
    -----------
    base_dir : str or Path
        Directory containing original PCAM .h5 files
    output_dir : str or Path
        Directory where downsampled extracted images will be saved
    target_size : int
        Total number of images in downsampled dataset
    train_ratio, val_ratio, test_ratio : float
        Ratios for splits
    random_seed : int
        Random seed (CRITICAL: the same seed always selects the same images)
    save_format : str
        Image format to save ('png', 'jpg', etc.)
    """
    # Convert the base directory string to a Path object
    base_dir = Path(base_dir)
    # Convert the output directory string to a Path object
    output_dir = Path(output_dir)
    # Create the output directory and all parent directories if they don't already exist
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Call helper function to calculate how many images should be in each split (train, valid, test)
    split_sizes = calculate_split_sizes(target_size, train_ratio, val_ratio, test_ratio)
    
    print("="*70) # Print header separator
    # Print main title
    print("PCAM EXTRACT & DOWNSAMPLE")
    print("="*70) # Print header separator
    # Print the base directory path
    print(f"Base directory: {base_dir}")
    # Print the output directory path
    print(f"Output directory: {output_dir}")
    # Print target total dataset size
    print(f"Target total size: {target_size:,}")
    # Print random seed for reproducibility
    print(f"Random seed: {random_seed} (ensures reproducibility)")
    # Print section header for split sizes
    print("\nTarget split sizes:")
    # Print train split size
    print(f"  Train: {split_sizes['train']:,} images")
    # Print validation split size
    print(f"  Valid: {split_sizes['valid']:,} images")
    # Print test split size
    print(f"  Test:  {split_sizes['test']:,} images")
    print("="*70) # Print header separator
    
    # Create list of split names to process
    splits = ['train', 'valid', 'test']
    # Initialize empty dictionary to store selected indices for each split
    indices_log = {}
    
    for split in splits: # Loop through each split name
        # Create path to the images h5 file for the current split
        images_h5 = base_dir / f"camelyonpatch_level_2_split_{split}_x.h5"
        # Create path to the labels h5 file for the current split
        labels_h5 = base_dir / f"camelyonpatch_level_2_split_{split}_y.h5"
        
        # Check if the given images h5 file exists
        if not images_h5.exists():
            # Print warning and skip this split if the file isn't found
            print(f"\nWARNING: {images_h5} not found, skipping {split}...")
            continue # Continue to next iteration of loop

        # Check if the given labels h5 file exists
        if not labels_h5.exists():
            # Print warning and skip this split if the file isn't found
            print(f"\nWARNING: {labels_h5} not found, skipping {split}...")
            continue # Continue to next iteration of loop
        
        # Call helper function to extract and downsample this split
        selected_indices = extract_and_downsample_split(
            images_h5_path=images_h5,
            labels_h5_path=labels_h5,
            output_dir=output_dir,
            split_name=split,
            n_samples=split_sizes[split],
            random_seed=random_seed,
            save_format=save_format
        )
        
        # Store the selected indices in the log dictionary
        indices_log[split] = selected_indices
    
    # Create path for file to save selected indices
    indices_file = output_dir / "selected_indices.npz"
    # Save all selected indices (one array per split) to a single .npz archive (np.savez does not compress)
    np.savez(indices_file, **indices_log)
    print(f"\n{'='*70}") # Print separator
    # Print confirmation of indices file save
    print(f"✓ Saved selected indices to {indices_file}")
    # Print note about file usage
    print("  (This file can be used to verify reproducibility)")
    
    # Print separator for completion section
    print("\n" + "="*70)
    # Print completion header
    print("EXTRACTION & DOWNSAMPLING COMPLETE")
    print("="*70) # Print separator
    # Print output directory path
    print(f"Downsampled dataset saved to: {output_dir}")
    # Calculate and print total images extracted across all splits
    print(f"\nTotal images extracted: {sum(len(indices_log[s]) for s in indices_log):,}")
    # Print section header for reproducibility instructions
    print("\nTo verify reproducibility:")
    # Print instruction with the random seed used
    print(f"  Run this script again with random_seed={random_seed}")
    # Print comparison instruction
    print("  Compare the selected_indices.npz files")
    print("="*70) # Print final separator
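

# --- Optional sketch (hypothetical helper, not part of the original pipeline) ---
# The completion message suggests comparing selected_indices.npz files to verify
# reproducibility. Because the file is written with np.savez(**indices_log), it
# holds one array per split name ('train', 'valid', 'test'). The hypothetical
# function below loads two such files and reports, per split, whether the
# selected indices are identical. A split present in only one file counts as a
# mismatch.
def compare_indices_files(path_a, path_b):
    """Return {split_name: bool} for every split present in either .npz file."""
    with np.load(path_a) as a, np.load(path_b) as b:
        splits = sorted(set(a.files) | set(b.files))
        return {s: (s in a.files and s in b.files and np.array_equal(a[s], b[s]))
                for s in splits}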


def find_h5_file(base_path):
    """
    Purpose:
    --------
    Find the actual H5 file, handling both file and directory structures.
    
    Parameters:
    -----------
    base_path : Path
        Base path that might be a file or directory containing the H5 file
    
    Returns:
    --------
    Path : Path to the actual H5 file
    """
    # Convert the base path to a Path object
    base_path = Path(base_path)
    
    # Check if the path points to an existing file
    if base_path.is_file():
        # Return the path as-is since it's already a file
        return base_path
    
    # Check if the path points to an existing directory
    if base_path.is_dir():
        # Construct path to file with same name inside directory
        same_name_file = base_path / base_path.name
        # Check if that same-name file exists
        if same_name_file.is_file():
            # Return the same-name file path
            return same_name_file
        
        # Search for any h5 files in the directory (sorted so the choice is deterministic)
        h5_files = sorted(base_path.glob("*.h5"))
        # Check if any h5 files were found
        if h5_files:
            # Return the first h5 file in sorted order
            return h5_files[0]
    
    # Raise error if no h5 file could be found
    raise FileNotFoundError(f"Could not find H5 file at {base_path}")


def print_dataset_summary(base_dir):
    """
    Purpose:
    --------
    Print summary statistics for all splits before extraction.
    
    Parameters:
    -----------
    base_dir : str or Path
        Base directory containing all PCAM .h5 files
    """
    # Convert base directory to Path object
    base_dir = Path(base_dir)
    # Create list of split names to analyze
    splits = ['train', 'valid', 'test']
    
    print("="*70) # Print header separator
    # Print summary title
    print("PCAM Dataset Summary (Before Downsampling)")
    print("="*70) # Print header separator
    
    total_images = 0 # Initialize counter for total images across all splits
    total_class_0 = 0 # Initialize counter for total class 0 samples
    total_class_1 = 0 # Initialize counter for total class 1 samples
    
    # Loop through each split (train, valid, test)
    for split in splits:
        # Construct path to labels h5 file for the current split
        labels_h5_base = base_dir / f"camelyonpatch_level_2_split_{split}_y.h5"
        
        # Check if the base path exists
        if not labels_h5_base.exists():
            # Print message if the file isn't found
            print(f"{split.capitalize():10s}: File not found")
            continue # Continue to the next split (next loop iteration)
        
        try: # Try to find the actual h5 file
            # Call helper function to locate the labels file
            labels_h5 = find_h5_file(labels_h5_base)
        # Catch file not found errors
        except FileNotFoundError:
            # Print error message if h5 file is not found in the directory
            print(f"{split.capitalize():10s}: H5 file not found in directory")
            continue # Continue to the next split (next loop iteration)
        
        # Open the labels h5 file in read mode
        with h5py.File(labels_h5, 'r') as h5_file:
            # Read all labels from the 'y' dataset
            labels = h5_file['y'][:]
            # Remove extra dimensions from labels array
            labels = labels.squeeze()
            
            # Get total number of samples in this split
            total = len(labels)
            # Count class 0 samples in this split
            class_0 = np.sum(labels == 0)
            # Count class 1 samples in this split
            class_1 = np.sum(labels == 1)
            
            # Print split name header with capitalization and padding
            print(f"\n{split.capitalize():10s}:")
            # Print total images in this split
            print(f"  Total images: {total:,}")
            # Print class 0 count and percentage
            print(f"  Class 0 (no tumor): {class_0:,} ({class_0/total*100:.1f}%)")
            # Print class 1 count and percentage
            print(f"  Class 1 (tumor):    {class_1:,} ({class_1/total*100:.1f}%)")
            
            total_images += total # Add current split's total to overall total
            total_class_0 += class_0 # Add current split's class 0 count to overall class 0 count
            total_class_1 += class_1 # Add current split's class 1 count to overall class 1 count
    
    print(f"\n{'='*70}") # Print separator for summary section
    # Avoid dividing by zero if no label files were found
    if total_images == 0:
        print("No label files found; nothing to summarize.")
        print("="*70) # Print closing separator
        return
    print("Total Dataset:") # Print overall dataset header
    # Print total images across all splits
    print(f"  Total images: {total_images:,}")
    # Print total class 0 samples and percentage
    print(f"  Class 0 (no tumor): {total_class_0:,} ({total_class_0/total_images*100:.1f}%)")
    # Print total class 1 samples and percentage
    print(f"  Class 1 (tumor):    {total_class_1:,} ({total_class_1/total_images*100:.1f}%)")
    print("="*70) # Print closing separator

# Actual execution
if __name__ == "__main__":
    # Define the base directory containing the PCAM h5 files
    BASE_DIR = Path(r"E:\PCam")
    # Define the output directory for extracted images (name should match TARGET_SIZE)
    OUTPUT_DIR = Path(r"E:\PCam_Extracted_50k")
    TARGET_SIZE = 50000 # Set target size for downsampled dataset
    # Set random seed for reproducible sampling (CRITICAL)
    RANDOM_SEED = 42
    
    # Print original dataset summary header
    print("Analyzing original dataset...\n")
    # Call function to print summary of original dataset
    print_dataset_summary(BASE_DIR)
    
    # Call function to calculate split sizes
    split_sizes = calculate_split_sizes(TARGET_SIZE)
    
    # Print header for extraction plan
    print("\n\nDataset Extraction Plan:")
    print("-" * 70) # Print separator line
    # Define original split sizes in dictionary
    original_splits = {
        'train': 262144,
        'valid': 32768,
        'test': 32768
    }
    
    # Loop through each split to print extraction plan
    for split in ['train', 'valid', 'test']:
        # Get original size for this split
        original = original_splits[split]
        # Get new size for this split
        new = split_sizes[split]
        # Calculate percentage of original that will be kept
        kept_pct = (new / original) * 100
        # Print extraction details for this split
        print(f"{split.capitalize():6s}: Extract {new:6,} of {original:,} "
              f"({kept_pct:.1f}% of original)")
    print("-" * 70) # Print separator line
    
    # Print message regarding actions
    print("\nThis will extract and downsample images directly from H5 files.")
    # Print note regarding memory
    print("Only the selected subset will be saved to disk (memory efficient!).")
    # Prompt user for confirmation
    response = input("\nProceed? (yes/no): ")
    
    # Check if user confirmed with yes or y
    if response.lower() in ['yes', 'y']:
        # Run extraction and downsampling method
        extract_and_downsample_pcam(
            base_dir=BASE_DIR,
            output_dir=OUTPUT_DIR,
            target_size=TARGET_SIZE,
            random_seed=RANDOM_SEED,
            save_format='png'
        )
    else: # If user did not confirm
        # Do nothing and print the cancellation message
        print("Cancelled.")

Analyzing original dataset...

PCAM Dataset Summary (Before Downsampling)

Train     :
  Total images: 262,144
  Class 0 (no tumor): 131,072 (50.0%)
  Class 1 (tumor):    131,072 (50.0%)

Valid     :
  Total images: 32,768
  Class 0 (no tumor): 16,399 (50.0%)
  Class 1 (tumor):    16,369 (50.0%)

Test      :
  Total images: 32,768
  Class 0 (no tumor): 16,391 (50.0%)
  Class 1 (tumor):    16,377 (50.0%)

Total Dataset:
  Total images: 327,680
  Class 0 (no tumor): 163,862 (50.0%)
  Class 1 (tumor):    163,818 (50.0%)


Dataset Extraction Plan:
----------------------------------------------------------------------
Train : Extract 40,000 of 262,144 (15.3% of original)
Valid : Extract  5,000 of 32,768 (15.3% of original)
Test  : Extract  5,000 of 32,768 (15.3% of original)
----------------------------------------------------------------------

This will extract and downsample images directly from H5 files.
Only the selected subset will be saved to disk (memory efficient!).
PCAM EXTRACT & DOWNSAMPLE

Extracting train: 100%|██████████| 40000/40000 [22:27<00:00, 29.69it/s]



✓ Successfully extracted 40,000 images
  Split directory: E:\PCam_Extracted_50k\train
  Images: E:\PCam_Extracted_50k\train\images
  Labels: E:\PCam_Extracted_50k\train\labels

Processing VALID split
Source: camelyonpatch_level_2_split_valid_x.h5
Target samples: 5000

Step 1: Loading labels for stratification...
  Found labels H5: E:\PCam\camelyonpatch_level_2_split_valid_y.h5\camelyonpatch_level_2_split_valid_y.h5
  Original size: 32,768
  Class 0 (no tumor): 16,399 (50.0%)
  Class 1 (tumor):    16,369 (50.0%)

Step 2: Selecting stratified sample (reproducible with seed=42)...
  Selected: 5,000 samples
  Class 0 (no tumor): 2,502 (50.0%)
  Class 1 (tumor):    2,498 (50.0%)

Step 3: Extracting 5,000 selected images...
  Found images H5: E:\PCam\camelyonpatch_level_2_split_valid_x.h5\camelyonpatch_level_2_split_valid_x.h5


Extracting valid: 100%|██████████| 5000/5000 [01:13<00:00, 67.68it/s]



✓ Successfully extracted 5,000 images
  Split directory: E:\PCam_Extracted_50k\valid
  Images: E:\PCam_Extracted_50k\valid\images
  Labels: E:\PCam_Extracted_50k\valid\labels

Processing TEST split
Source: camelyonpatch_level_2_split_test_x.h5
Target samples: 5000

Step 1: Loading labels for stratification...
  Found labels H5: E:\PCam\camelyonpatch_level_2_split_test_y.h5\camelyonpatch_level_2_split_test_y.h5
  Original size: 32,768
  Class 0 (no tumor): 16,391 (50.0%)
  Class 1 (tumor):    16,377 (50.0%)

Step 2: Selecting stratified sample (reproducible with seed=42)...
  Selected: 5,000 samples
  Class 0 (no tumor): 2,501 (50.0%)
  Class 1 (tumor):    2,499 (50.0%)

Step 3: Extracting 5,000 selected images...
  Found images H5: E:\PCam\camelyonpatch_level_2_split_test_x.h5\camelyonpatch_level_2_split_test_x.h5


Extracting test: 100%|██████████| 5000/5000 [01:05<00:00, 75.86it/s]



✓ Successfully extracted 5,000 images
  Split directory: E:\PCam_Extracted_50k\test
  Images: E:\PCam_Extracted_50k\test\images
  Labels: E:\PCam_Extracted_50k\test\labels

✓ Saved selected indices to E:\PCam_Extracted_50k\selected_indices.npz
  (This file can be used to verify reproducibility)

EXTRACTION & DOWNSAMPLING COMPLETE
Downsampled dataset saved to: E:\PCam_Extracted_50k

Total images extracted: 50,000

To verify reproducibility:
  Run this script again with random_seed=42
  Compare the selected_indices.npz files
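
The per-class counts in the log above (2,502 / 2,498 for the validation split, 2,501 / 2,499 for the test split) are consistent with proportional allocation. The part of `get_stratified_indices` that computes `n_class_0` and `n_class_1` is outside this excerpt, so the sketch below is only one allocation rule that reproduces those numbers. It assumes class 0 receives `round(n_samples * class_0_count / total)` and class 1 receives the remainder. `allocate_per_class` is a hypothetical name.

```python
def allocate_per_class(class_0_count, class_1_count, n_samples):
    """Split n_samples between two classes in proportion to their original counts.

    Assumption: class 0 gets the rounded proportional share and class 1 gets the
    remainder, so the two counts always sum exactly to n_samples.
    """
    total = class_0_count + class_1_count
    n_class_0 = round(n_samples * class_0_count / total)
    n_class_1 = n_samples - n_class_0
    return n_class_0, n_class_1


# Class counts taken from the validation and test summaries printed above
print(allocate_per_class(16_399, 16_369, 5_000))  # (2502, 2498)
print(allocate_per_class(16_391, 16_377, 5_000))  # (2501, 2499)
```

Giving the remainder to the second class guarantees the split hits its exact target size, which a per-class rounding of both shares would not.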
