# In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display



# --- Helper Functions ---

def compute_chip_score(row: pd.Series) -> float:
    """
    Compute a selection score for one chip from its per-class ``dist_...`` columns.

    Lower scores are preferred. Two sentinel values short-circuit the scoring:

    * ``-1.0`` -- always keep: the chip contains cars, or contains water that
      is below 95% of the chip.
    * ``float('inf')`` -- always skip: background is at least 95%, or at most
      one class exceeds 0.1%.

    Otherwise the score is a weighted sum of the building, water and
    vegetation proportions, minus a bonus for class diversity.

    Class columns are located by substring match on the class name (e.g.
    ``'Car'`` matches ``'dist_4:Car'``); the first matching ``dist_`` column wins.

    Args:
        row: One metadata row. Every column whose name starts with ``dist_``
            is treated as a class proportion (assumed to lie in [0, 1]).

    Returns:
        The chip score, or ``0.0`` (after printing a warning) when any of the
        Car/Background/Water/Building/Vegetation columns cannot be found.
    """
    # Per-class values for this chip, keyed by their 'dist_...' column name.
    dist_cols = {col: row[col] for col in row.index if col.startswith('dist_')}

    def get_col_name(class_name: str):
        # First 'dist_' column whose name contains class_name, or None.
        return next((col for col in dist_cols if class_name in col), None)

    car_col = get_col_name('Car')
    bg_col = get_col_name('Background')
    water_col = get_col_name('Water')
    bldg_col = get_col_name('Building')
    veg_col = get_col_name('Vegetation')

    if not all([car_col, bg_col, water_col, bldg_col, veg_col]):
        print("Warning: Could not find all required columns in row.")
        return 0.0  # neutral score, returned as float to match the annotation

    car_ratio = row[car_col]
    background_ratio = row[bg_col]
    water_ratio = row[water_col]

    # Rule 1: always keep chips with cars.
    if car_ratio > 0:
        return -1.0

    # Rule 2: always keep chips with water, unless it's almost pure water.
    if 0 < water_ratio < 0.95:
        return -1.0

    # Rule 3: skip chips dominated by background.
    if background_ratio >= 0.95:
        return float('inf')

    # Rule 4: skip single-class chips (classes at or below 0.1% are ignored).
    unique_class_count = sum(1 for val in dist_cols.values() if val > 0.001)
    if unique_class_count <= 1:
        return float('inf')

    # Rule 5: weighted class proportions. After rule 2 the water ratio is
    # either 0 or >= 0.95, so the water term mostly separates water-free chips
    # from near-pure-water ones. Vegetation is penalised only above 10%.
    score = (
        2.5 * (1 - row[bldg_col])
        + 11.0 * (1 - water_ratio)
        + 19.5 * max(row[veg_col] - 0.1, 0)
    )

    # Rule 6: reward class diversity (at least two classes are present here).
    if unique_class_count >= 4:
        score -= 6.0
    elif unique_class_count == 3:
        score -= 3.0
    else:
        score -= 0.5

    return score


def plot_class_distribution_from_df(dataframe: pd.DataFrame, title: str = "Class Distribution"):
    """
    Plot the share of pixels belonging to each class across a set of chips.

    Every ``count_...`` column is summed over all rows and normalised by the
    grand total; one bar is drawn per class and labelled with its percentage.
    Bar colours come from the ``CLASS_TO_COLOR`` mapping (defined elsewhere),
    indexed by the integer parsed from the text before the first ``:`` in the
    column label and scaled by 1/255. A class whose id cannot be parsed or has
    no palette entry is drawn in grey, with a warning, instead of raising.

    Args:
        dataframe: Chip metadata; columns are presumably named like
            ``count_<id>:<name>`` (TODO confirm against the CSV writer).
        title: Plot title, also used in the warning messages.

    Returns:
        None. A warning is printed and nothing is drawn when the frame is
        empty, has no ``count_`` columns, or counts zero pixels in total.
    """
    if dataframe.empty:
        print(f"Warning: DataFrame for '{title}' is empty. Cannot plot distribution.")
        return

    # Pixel-count columns; non-string column labels are ignored.
    count_cols = [
        col for col in dataframe.columns
        if isinstance(col, str) and col.startswith('count_')
    ]
    if not count_cols:
        print("Warning: No 'count_' columns found for plotting.")
        return

    pixel_sums = dataframe[count_cols].sum()
    total_pixels_across_df = pixel_sums.sum()

    if total_pixels_across_df == 0:
        # Nothing to normalise by; report it like the other early exits.
        print(f"Warning: DataFrame for '{title}' has no counted pixels. Cannot plot distribution.")
        return

    pixel_props = pixel_sums / total_pixels_across_df
    class_labels = [c.replace('count_', '') for c in count_cols]

    # One colour per bar, in column order. Unknown classes fall back to grey
    # so a single unexpected column does not abort the whole plot.
    fallback_colour = np.array([0.5, 0.5, 0.5])
    colours = []
    for label in class_labels:
        try:
            class_id = int(label.split(':', 1)[0])
            rgb = CLASS_TO_COLOR[class_id]
        except (ValueError, KeyError, IndexError):
            print(f"Warning: No colour found for class '{label}'; drawing it in grey.")
            colours.append(fallback_colour)
        else:
            colours.append(np.array(rgb) / 255.0)

    plt.figure(figsize=(10, 5))
    bars = plt.bar(class_labels, pixel_props, color=colours, edgecolor='black')
    plt.title(title)
    plt.xlabel("Class")
    plt.ylabel("Proportion")
    plt.grid(True, axis='y', linestyle='--', alpha=0.5)

    # Percentage label just above each bar.
    for bar, prop in zip(bars, pixel_props):
        plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height(), f"{prop:.2%}",
                 ha='center', va='bottom', fontsize=9)

    plt.tight_layout()
    plt.show()

# --- Main Dataframe Loading Functions ---

def csv_to_df(split: str, metadata_path: "str | None" = None, subset: float = 1.0) -> pd.DataFrame:
    """
    Load chip metadata for one data split; score and down-select the training split.

    Rows are kept when their ``source_file`` is in the scene list for
    ``split`` (``TRAIN_SCENES``, ``VAL_SCENES`` or ``TEST_SCENES``, defined
    elsewhere). For ``'train'``, each chip is scored with ``compute_chip_score``:

    * chips scoring ``-1.0`` are always kept;
    * chips scoring ``inf`` are always dropped;
    * the remaining chips are added in ascending score order until the total
      reaches ``int(len(split_rows) * subset)``. If the must-keep chips alone
      already meet that target, none of the remaining chips are added.

    Val/test rows are returned without scoring. The class distribution of the
    returned chips is plotted in every case.

    Args:
        split: ``'train'``, ``'val'`` or ``'test'``.
        metadata_path: CSV to read. Defaults to ``<base_dir>/train_metadata.csv``,
            built when the function is called rather than when it is defined.
        subset: Target fraction of the split's chips to return for ``'train'``;
            ignored for the other splits.

    Returns:
        The selected rows. The training frame carries an extra ``score`` column.

    Raises:
        ValueError: If ``split`` is not one of the supported values.
    """
    # Resolve the split first so a bad argument fails before any file I/O.
    if split == 'train':
        scene_list = TRAIN_SCENES
    elif split == 'val':
        scene_list = VAL_SCENES
    elif split == 'test':
        scene_list = TEST_SCENES
    else:
        raise ValueError(f"Invalid split type: {split}")

    # A default expression in the signature would run once at definition time,
    # freezing base_dir (or raising NameError if it is not yet defined).
    if metadata_path is None:
        metadata_path = os.path.join(base_dir, "train_metadata.csv")

    df = pd.read_csv(metadata_path)
    df_split = df[df['source_file'].isin(scene_list)].copy()

    print(f"Loaded {len(df_split)} chips for the '{split}' split from {len(scene_list)} scenes.")

    title = f"'{split.capitalize()}' Class Distribution"
    if split != 'train':
        # For val and test, just return the filtered data.
        plot_class_distribution_from_df(df_split, title=title)
        return df_split

    # On an empty frame pandas may probe the scorer with an empty Series to
    # infer the result type, which would print a spurious warning.
    if df_split.empty:
        df_split["score"] = pd.Series(dtype=float)
    else:
        df_split["score"] = df_split.apply(compute_chip_score, axis=1)

    scores = df_split["score"]
    keep_chips = df_split[scores == -1.0]
    # Stable sort: equal scores keep their CSV order, so the head() cut-off
    # below is reproducible regardless of the sorting algorithm.
    rest = df_split[(scores != -1.0) & (scores != float('inf'))].sort_values('score', kind='stable')

    num_to_select = max(0, int(len(df_split) * subset) - len(keep_chips))

    final_df = pd.concat([keep_chips, rest.head(num_to_select)])
    print(f"After scoring, returning {len(final_df)} chips for training.")
    plot_class_distribution_from_df(final_df, title=title)
    return final_df





'''
def csv_to_hard_df() -> pd.DataFrame:
    """
    Loads metadata and filters for 'hard' chips, typically used for a second stage of training.
    'Hard' chips are defined by specific criteria focusing on mixed classes,
    presence of water, building, clutter, and specific thresholds.

    Returns:
        pd.DataFrame: A DataFrame containing chips selected based on 'hard' criteria.
    """
    metadata_path = "/content/chipped_data/content/chipped_data/train_metadata.csv"
    df = pd.read_csv(metadata_path)

    # Exclude chips that belong to the validation or test sets (for training data)
    exclude_files = set(val_files + test_files)
    df = df[~df["source_file"].isin(exclude_files)].copy() # Use .copy() to avoid SettingWithCopyWarning

    # Define column references for convenience
    cols = {
        "building": "0: Building",
        "clutter": "1: Clutter",
        "vegetation": "2: Vegetation",
        "water": "3: Water",
        "background": "4: Background",
        "car": "5: Car"
    }

    # Helper column to count classes with non-zero pixel counts
    df['non_zero_classes'] = df[[*cols.values()]].gt(0).sum(axis=1)

    # --- Criteria for 'Hard' Chips ---
    # 1. Chips with water present alongside background
    water_and_background = df[(df[cols["water"]] > 0) & (df[cols["background"]] > 0)].copy()
    # 2. Chips with water present alongside vegetation
    water_and_veg = df[(df[cols["water"]] > 0) & (df[cols["vegetation"]] > 0)].copy()
    # 3. Chips with water, background, and vegetation all present
    water_background_veg = df[
        (df[cols["water"]] > 0) &
        (df[cols["background"]] > 0) &
        (df[cols["vegetation"]] > 0)
    ].copy()
    # 4. Chips with significant building and background, but building not dominating
    building_and_background = df[
        (df[cols["building"]] > 0.05) & (df[cols["background"]] > 0.1) &
        (df[cols["building"]] < 0.4) # reduce building-dominant chips
    ].copy()
    # 5. Chips that are almost pure water
    pure_water = df[(df[[*cols.values()]].sum(axis=1) > 0) & # Ensure not entirely empty
                    (df[cols["water"]] / df[[*cols.values()]].sum(axis=1) > 0.7) 
                   ].copy()

    # 6. A sample of chips that are almost pure building (to ensure building class isn't forgotten)
    # Ensure there are enough pure building chips to sample from before trying to sample
    potential_pure_building = df[(df[[*cols.values()]].sum(axis=1) > 0) &
                                 (df[cols["building"]] / df[[*cols.values()]].sum(axis=1) > 0.7)
                                ].copy()
    num_pure_building_to_sample = min(100, len(potential_pure_building)) 
    pure_building = potential_pure_building.sample(n=num_pure_building_to_sample, random_state=42).copy()


    # Combine all selected 'hard' chips and remove duplicates
    stage2_df = pd.concat([
        water_and_background,
        water_and_veg,
        water_background_veg,
        building_and_background,
        pure_water,
        pure_building,
    ]).drop_duplicates().reset_index(drop=True) # Reset index after concat and drop_duplicates

    plot_class_distribution_from_df(stage2_df, title="Stage 2 Class Distribution")
    print(f"\nSelected {len(stage2_df):,} hard chips for Stage 2 training.")
    return stage2_df
'''

def csv_to_full_df(
    metadata_path: str = "/content/chipped_data/content/chipped_data/train_metadata.csv",
) -> pd.DataFrame:
    """
    Load every row of the metadata CSV into a DataFrame, with no filtering.

    No split, background exclusion or score-based selection is applied, so
    chips from all source files are included.

    Args:
        metadata_path: CSV to read. Defaults to the original hard-coded
            Colab location, so existing zero-argument calls are unchanged.

    Returns:
        pd.DataFrame: All entries from the metadata CSV.
    """
    df = pd.read_csv(metadata_path)
    print(f"Loaded {len(df):,} chips from the CSV, no filters applied.")
    return df