In [1]:
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit, StratifiedKFold, KFold
import os
import shutil
import numpy as np

In [2]:
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit, StratifiedKFold, KFold # KFold is already imported
import os
import shutil
import numpy as np

# --- Configuration ---
# !!! IMPORTANT: Update these paths to match your environment !!!
# All paths below are relative, so they resolve against the kernel's current
# working directory at the time the cell runs.
# Species-level summary table; main() reads the columns 'full_species_name',
# 'total_shapes' and 'num_images' from it.
SPECIES_CSV_PATH = "../20250515/species_summary_stats.csv"
# Per-image metadata table; main() reads 'image_name' and 'num_shapes', plus either
# 'full_species_name' or the pair 'genus' / 'species'.
IMAGES_CSV_PATH = "../20250515/image_metadata.csv"
SOURCE_FILES_BASE_PATH = "../20250515/data/"  # Folder containing the actual image and JSON files

# --- New Output Locations ---
# main() creates these directories if they are missing. Existing contents are not
# cleared, so files from an earlier run stay in place.
DESTINATION_TEST_SET_PATH = "./test_set_output/" # Specific location for the test set
DESTINATION_CLASSIFICATION_FOLDS_BASE_PATH = "./classification_folds_output/" # Base for fold1, fold2 (species specific)
DESTINATION_DETECTION_FOLDS_BASE_PATH = "./detection_folds_output/" # Base for detection_fold1, detection_fold2 (num_shapes >=1, no species limit)


# --- Helper Function to Copy Files ---
def copy_files_to_destination(df_subset, image_column_name, source_base_path, destination_folder_path):
    """
    Copy every image listed in ``df_subset`` and its sidecar JSON file into a folder.

    The JSON file is assumed to have the same name as the image file but with a
    ``.json`` extension, and to live next to the image under ``source_base_path``.

    Args:
        df_subset: DataFrame with one row per image to copy.
        image_column_name: Name of the column holding the image filename
            (relative to ``source_base_path``).
        source_base_path: Folder containing the source image and JSON files.
        destination_folder_path: Folder to copy into; created if missing.

    Returns:
        None. A summary (and a warning per problem row) is printed instead.

    Rows are skipped, not raised on, when the filename is not a string or the
    image is not a regular file under ``source_base_path``. A missing JSON file
    does not prevent the image from being copied.
    """
    os.makedirs(destination_folder_path, exist_ok=True)

    copied_count = 0
    skipped_image_not_found = 0
    skipped_json_not_found = 0

    for _, row in df_subset.iterrows():
        image_filename = row[image_column_name]
        if not isinstance(image_filename, str):
            # Typically NaN from an empty CSV cell.
            print(f"Warning: Image filename is not a string, skipping row: {row}")
            skipped_image_not_found += 1
            continue

        json_filename = os.path.splitext(image_filename)[0] + ".json"

        source_image_full_path = os.path.join(source_base_path, image_filename)
        source_json_full_path = os.path.join(source_base_path, json_filename)

        dest_image_full_path = os.path.join(destination_folder_path, image_filename)
        dest_json_full_path = os.path.join(destination_folder_path, json_filename)

        # isfile() rather than exists(): an empty or directory-like filename resolves
        # to a directory, which exists() accepts but shutil.copy2() cannot copy.
        if not os.path.isfile(source_image_full_path):
            print(f"Warning: Image file not found, skipping: {source_image_full_path}")
            skipped_image_not_found += 1
            continue  # If image not found, don't try to copy JSON

        # The filename may carry a relative sub-folder ("site_a/img.jpg"); make sure
        # the matching folder exists under the destination before copying.
        dest_parent = os.path.dirname(dest_image_full_path)
        if dest_parent:
            os.makedirs(dest_parent, exist_ok=True)

        shutil.copy2(source_image_full_path, dest_image_full_path)
        copied_count += 1

        if os.path.isfile(source_json_full_path):
            shutil.copy2(source_json_full_path, dest_json_full_path)
        else:
            print(f"Warning: JSON file not found for {image_filename} (image was copied): {source_json_full_path}")
            skipped_json_not_found += 1

    print(f"Copied {copied_count} images to {destination_folder_path}.")
    if skipped_image_not_found > 0:
        print(f"Skipped {skipped_image_not_found} images (source file not found or invalid filename).")
    if skipped_json_not_found > 0:
        print(f"For {skipped_json_not_found} copied images, their corresponding JSON files were not found (but images were copied).")


# --- Main Script Logic ---
def main():
    """
    Split the image dataset into a test set, classification folds and detection folds.

    Steps, in order:
      1. Load the species summary CSV and the image metadata CSV.
      2. Keep species with ``total_shapes >= 50``, ``num_images > 50`` and a name that
         neither contains "unknown" nor equals "nan" (case-insensitive).
      3. Keep images of those species that have ``num_shapes >= 1``.
      4. Draw a 15% test split stratified by species (seed 42).
      5. Copy the test images (and JSON files) to ``DESTINATION_TEST_SET_PATH``.
      6. Split the remaining classification images into 5 stratified folds (seed 42)
         and copy each fold to ``fold<k>`` under the classification base path.
      7. Split every image with ``num_shapes >= 1`` that is not in the test set into
         5 shuffled, non-stratified folds (seed 123) and copy each fold to
         ``detection_fold<k>`` under the detection base path.

    Problems (missing CSVs or columns, too little data, splitter errors) are reported
    with ``print``; the function then returns early or skips the affected step.

    Returns:
        None.
    """
    N_CLASSIFICATION_FOLDS = 5
    N_DETECTION_FOLDS = 5

    # Create all base destination directories if they don't exist
    for base_dir in (
        DESTINATION_TEST_SET_PATH,
        DESTINATION_CLASSIFICATION_FOLDS_BASE_PATH,
        DESTINATION_DETECTION_FOLDS_BASE_PATH,
    ):
        os.makedirs(base_dir, exist_ok=True)

    # --- 1. Load Data ---
    print("Loading data...")
    try:
        df_species = pd.read_csv(SPECIES_CSV_PATH)
        df_images = pd.read_csv(IMAGES_CSV_PATH)
    except FileNotFoundError as e:
        print(f"Error: CSV file not found. {e}")
        print(f"Please ensure '{SPECIES_CSV_PATH}' and '{IMAGES_CSV_PATH}' are correct.")
        return

    print(f"Initial df_species shape: {df_species.shape}")
    print(f"Initial df_images shape: {df_images.shape}\n")

    # Fail early with a readable message instead of a KeyError half-way through,
    # possibly after some files have already been copied.
    missing_species_cols = [
        col for col in ('full_species_name', 'total_shapes', 'num_images')
        if col not in df_species.columns
    ]
    missing_image_cols = [
        col for col in ('image_name', 'num_shapes')
        if col not in df_images.columns
    ]
    if missing_species_cols or missing_image_cols:
        if missing_species_cols:
            print(f"Error: species CSV is missing required column(s): {missing_species_cols}")
        if missing_image_cols:
            print(f"Error: images CSV is missing required column(s): {missing_image_cols}")
        return

    # --- 2. Filter df_species (for classification dataset) ---
    print("Filtering species dataframe (for classification dataset)...")
    df_species['full_species_name'] = df_species['full_species_name'].astype(str)
    df_species_filtered = df_species[
        (df_species['total_shapes'] >= 50) &
        (df_species['num_images'] > 50) &
        (~df_species['full_species_name'].str.contains("unknown", case=False, na=False)) &
        (df_species['full_species_name'].str.lower() != "nan")
    ]
    valid_species_names = df_species_filtered['full_species_name'].unique()
    print(f"Number of species after filtering: {len(valid_species_names)}")
    if not len(valid_species_names):
        print("No species meet the criteria from df_species for classification dataset. Classification folds might be empty or skipped.")
    else:
        print(f"First 10 selected species names for classification: {valid_species_names[:10]} ...\n")

    # --- 3. Filter df_images (for classification dataset) ---
    print("Filtering images dataframe (for classification dataset)...")
    df_images_for_classification = df_images.copy()

    if 'full_species_name' not in df_images_for_classification.columns:
        if 'genus' in df_images_for_classification.columns and 'species' in df_images_for_classification.columns:
            print("Constructing 'full_species_name' in df_images_for_classification...")
            genus_col = df_images_for_classification['genus']
            species_col = df_images_for_classification['species']
            # Blank out missing values explicitly. Calling fillna('') *after*
            # astype(str) is a no-op once NaN has become the text 'nan', which would
            # yield names like 'nan_nan' and make the '_' check below unreachable.
            df_genus_str = genus_col.astype(str).where(genus_col.notna(), '')
            df_species_str = species_col.astype(str).where(species_col.notna(), '')
            df_images_for_classification['full_species_name'] = df_genus_str + "_" + df_species_str
            df_images_for_classification.loc[df_images_for_classification['full_species_name'] == '_', 'full_species_name'] = 'unknown_both_empty'
        else:
            print("Error: 'full_species_name' not in df_images_for_classification, and 'genus'/'species' cols missing.")
            return
    else:
        print("'full_species_name' column already exists in df_images_for_classification. Assuming correct format.")

    df_images_for_classification['full_species_name'] = df_images_for_classification['full_species_name'].astype(str)
    df_images_classification_filtered = df_images_for_classification[
        df_images_for_classification['full_species_name'].isin(valid_species_names) &
        (df_images_for_classification['num_shapes'] >= 1)
    ].copy()

    print(f"Number of images for classification dataset after filtering: {df_images_classification_filtered.shape[0]}")

    # Both stay empty unless the split below succeeds; step 7 checks df_test.
    df_test = pd.DataFrame()
    df_train_val_classification = pd.DataFrame()

    if df_images_classification_filtered.empty:
        print("No images meet criteria for classification dataset. Test set and classification folds will be empty/skipped.")
    else:
        stratify_column_classification = 'full_species_name'
        # Defensive guard so the stratification labels never contain missing values.
        df_images_classification_filtered[stratify_column_classification] = df_images_classification_filtered[stratify_column_classification].fillna('unknown_species_final')
        class_counts_classification = df_images_classification_filtered[stratify_column_classification].value_counts()
        print("\nClass distribution in filtered classification image data (top 10):")
        print(class_counts_classification.head(10))

        if class_counts_classification.empty:
            print("No class data for classification stratification. Test set/classification folds might be problematic.")
        else:
            # --- 4. Create Test Set (from classification filtered data) ---
            print("\nCreating test set (from classification data, 15%)...")
            if df_images_classification_filtered.shape[0] < 2:
                print("Not enough data for test split. Test set will be empty.")
                df_train_val_classification = df_images_classification_filtered.copy()
            else:
                if (class_counts_classification < 2).any():
                    print("\nWarning: Some classes (classification) have only 1 sample. Stratified test split problematic.")

                sss = StratifiedShuffleSplit(n_splits=1, test_size=0.15, random_state=42)
                X_indices_cls = df_images_classification_filtered.index
                y_labels_cls = df_images_classification_filtered[stratify_column_classification]
                try:
                    # split() yields positions, which are mapped back to index labels.
                    train_val_indices_cls, test_indices_cls = next(sss.split(X_indices_cls, y_labels_cls))
                    df_test = df_images_classification_filtered.loc[X_indices_cls[test_indices_cls]]
                    df_train_val_classification = df_images_classification_filtered.loc[X_indices_cls[train_val_indices_cls]]
                except ValueError as e:
                    print(f"Error during test set splitting: {e}. Test set may be empty. Using all for train/val.")
                    df_train_val_classification = df_images_classification_filtered.copy()

            print(f"Test set size: {df_test.shape[0]} images")
            print(f"Data for classification folds: {df_train_val_classification.shape[0]} images")
            if not df_test.empty:
                print("Test set class distribution (normalized, top 5):")
                print(df_test[stratify_column_classification].value_counts(normalize=True).head())

            # --- 5. Copy Test Set Files ---
            print("\nCopying test set files...")
            if not df_test.empty:
                copy_files_to_destination(df_test, 'image_name', SOURCE_FILES_BASE_PATH, DESTINATION_TEST_SET_PATH)
            else:
                print("Test set is empty, no files to copy.")

            # --- 6. Create Classification Folds ---
            print(f"\nCreating {N_CLASSIFICATION_FOLDS} classification folds...")
            if df_train_val_classification.empty or df_train_val_classification.shape[0] < N_CLASSIFICATION_FOLDS:
                print(f"Not enough data for {N_CLASSIFICATION_FOLDS} classification folds. Skipping.")
            else:
                cls_fold_class_counts = df_train_val_classification[stratify_column_classification].value_counts()
                if (cls_fold_class_counts < N_CLASSIFICATION_FOLDS).any():
                    print(f"\nWarning: Some classes for classification folds have fewer than {N_CLASSIFICATION_FOLDS} samples. Stratification problematic.")

                skf_classification = StratifiedKFold(n_splits=N_CLASSIFICATION_FOLDS, shuffle=True, random_state=42)
                X_cv_cls_indices = df_train_val_classification.index
                y_cv_cls_labels = df_train_val_classification[stratify_column_classification]
                try:
                    # Only the held-out part of each split is used: those parts are
                    # disjoint and together cover every row exactly once.
                    cls_splits = skf_classification.split(X_cv_cls_indices, y_cv_cls_labels)
                    for fold_num_cls, (_, val_fold_indices_pos_cls) in enumerate(cls_splits, start=1):
                        current_fold_df_indices_cls = X_cv_cls_indices[val_fold_indices_pos_cls]
                        df_current_cls_fold = df_train_val_classification.loc[current_fold_df_indices_cls]
                        cls_fold_dest_path = os.path.join(DESTINATION_CLASSIFICATION_FOLDS_BASE_PATH, f"fold{fold_num_cls}")

                        print(f"\n--- Processing Classification Fold {fold_num_cls} ---")
                        print(f"Size: {df_current_cls_fold.shape[0]} images. Path: {cls_fold_dest_path}")
                        if not df_current_cls_fold.empty:
                            copy_files_to_destination(df_current_cls_fold, 'image_name', SOURCE_FILES_BASE_PATH, cls_fold_dest_path)
                except ValueError as e:
                    print(f"Error during Classification K-Fold splitting: {e}")

    # --- 7. Create Detection Folds (num_shapes >= 1, excluding test set images) ---
    print(f"\nCreating {N_DETECTION_FOLDS} detection folds...")
    df_detection_source = df_images[df_images['num_shapes'] >= 1].copy()
    print(f"Number of images initially with num_shapes >= 1: {df_detection_source.shape[0]}")

    if not df_test.empty:
        # Keep test images out of the detection folds, matching on image name.
        test_set_image_names = df_test['image_name'].unique()
        df_detection_candidates = df_detection_source[~df_detection_source['image_name'].isin(test_set_image_names)].copy()
        print(f"Number of images for detection folds (after excluding test set): {df_detection_candidates.shape[0]}")
    else:
        print("Test set (df_test) was empty. Using all images with num_shapes >= 1 for detection folds.")
        df_detection_candidates = df_detection_source.copy()

    if df_detection_candidates.empty:
        print("No images available for detection folds. Skipping this step.")
    elif df_detection_candidates.shape[0] < N_DETECTION_FOLDS:
        print(f"Not enough data ({df_detection_candidates.shape[0]}) for {N_DETECTION_FOLDS} detection folds. Skipping.")
    else:
        # User specified class distribution does not matter for detection data.
        # Using KFold for non-stratified splitting. Shuffle for randomness.
        print(f"\nSplitting detection data into {N_DETECTION_FOLDS} folds (non-stratified, shuffled).")
        kf_detection = KFold(n_splits=N_DETECTION_FOLDS, shuffle=True, random_state=123)

        X_detection_indices = df_detection_candidates.index  # We'll split based on these indices

        # KFold.split(X) yields (train_indices, test_indices) relative to X.
        # Only the 'test_indices' part is used as each fold's content.
        det_splits = kf_detection.split(X_detection_indices)  # No y_labels needed for KFold
        for fold_num_det, (_, val_fold_indices_pos_det) in enumerate(det_splits, start=1):
            current_fold_df_indices_det = X_detection_indices[val_fold_indices_pos_det]
            df_current_detection_fold = df_detection_candidates.loc[current_fold_df_indices_det]

            det_fold_dest_path = os.path.join(DESTINATION_DETECTION_FOLDS_BASE_PATH, f"detection_fold{fold_num_det}")

            print(f"\n--- Processing Detection Fold {fold_num_det} ---")
            print(f"Size: {df_current_detection_fold.shape[0]} images. Path: {det_fold_dest_path}")

            if not df_current_detection_fold.empty:
                copy_files_to_destination(df_current_detection_fold, 'image_name', SOURCE_FILES_BASE_PATH, det_fold_dest_path)
        # No specific ValueError for stratification issues is expected here with KFold

    print("\n--- Script Finished ---")

# Run the full pipeline when this cell (or script) is executed directly; the guard
# keeps main() from running if the code is imported as a module instead.
# NOTE: main() copies files into the configured destination folders as a side effect.
if __name__ == '__main__':
    main()

Loading data...
Initial df_species shape: (235, 3)
Initial df_images shape: (54421, 10)

Filtering species dataframe (for classification dataset)...
Number of species after filtering: 63
First 10 selected species names for classification: ['Ambrosiodmus_minor' 'Ambrosiophilus_atratus' 'Anisandrus_dispar'
 'Anisandrus_sayi' 'Cnestus_mutilatus' 'Coccotrypes_carpophagus'
 'Coccotrypes_dactyliperda' 'Coptoborus_ricini' 'Cryptocarenus_heveae'
 'Ctonoxylon_hagedorn'] ...

Filtering images dataframe (for classification dataset)...
Constructing 'full_species_name' in df_images_for_classification...
Number of images for classification dataset after filtering: 13538

Class distribution in filtered classification image data (top 10):
full_species_name
Ips_typographus              2093
Dendroctonus_valens          1759
Ips_sexdentatus              1516
Ips_acuminatus                964
Hylesinus_varius              519
Xylosandrus_crassiusculus     399
Dendroctonus_terebrans        309
Platypus_cy