In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display

# --- Global Configuration and Constants ---

# Every source-file prefix in the dataset (a tile_id with its coordinate
# suffix stripped). The train/val/test split is defined in terms of these
# prefixes.
all_files = [
    '107f24d6e9_F1BE1D4184INSPIRE',
    '11cdce7802_B6A62F8BE0INSPIRE',
    '12fa5e614f_53197F206FOPENPIPELINE',
    '130a76ebe1_68B40B480AOPENPIPELINE',
    '1476907971_CHADGRISMOPENPIPELINE',
    '1553541487_APIGENERATED',
    '1553541585_APIGENERATED',
    '1553627230_APIGENERATED',
    '15efe45820_D95DF0B1F4INSPIRE',
    '1726eb08ef_60693DB04DINSPIRE',
    '1d056881e8_29FEA32BC7INSPIRE',
    '1d4fbe33f3_F1BE1D4184INSPIRE',
    '1df70e7340_4413A67E91INSPIRE',
    '2552eb56dd_2AABB46C86OPENPIPELINE',
    '25f1c24f30_EB81FE6E2BOPENPIPELINE',
    '2ef3a4994a_0CCD105428INSPIRE',
    '2ef883f08d_F317F9C1DFOPENPIPELINE',
    '34fbf7c2bd_E8AD935CEDINSPIRE',
    '3502e187b2_23071E4605OPENPIPELINE',
    '39e77bedd0_729FB913CDOPENPIPELINE',
    '420d6b69b8_84B52814D2OPENPIPELINE',
    '520947aa07_8FCB044F58OPENPIPELINE',
    '551063e3c5_8FCB044F58INSPIRE',
    '57426ebe1e_84B52814D2OPENPIPELINE',
    '5fa39d6378_DB9FF730D9OPENPIPELINE',
    '6f93b9026b_F1BFB8B17DOPENPIPELINE',
    '7008b80b00_FF24A4975DINSPIRE',
    '74d7796531_EB81FE6E2BOPENPIPELINE',
    '7c719dfcc0_310490364FINSPIRE',
    '84410645db_8D20F02042OPENPIPELINE',
    '8710b98ea0_06E6522D6DINSPIRE',
    '888432f840_80E7FD39EBINSPIRE',
    '9170479165_625EDFBAB6OPENPIPELINE',
    'a1af86939f_F1BE1D4184OPENPIPELINE',
    'b61673f780_4413A67E91INSPIRE',
    'b705d0cc9c_E5F5E0E316OPENPIPELINE',
    'b771104de5_7E02A41EBEOPENPIPELINE',
    'c2e8370ca3_3340CAC7AEOPENPIPELINE',
    'c37dbfae2f_84B52814D2OPENPIPELINE',
    'c644f91210_27E21B7F30OPENPIPELINE',
    'c6d131e346_536DE05ED2OPENPIPELINE',
    'c8a7031e5f_32156F5DC2INSPIRE',
    'cc4b443c7d_A9CBEF2C97INSPIRE',
    'd06b2c67d2_2A62B67B52OPENPIPELINE',
    'd9161f7e18_C05BA1BC72OPENPIPELINE',
    'dabec5e872_E8AD935CEDINSPIRE',
    'e87da4ebdb_29FEA32BC7INSPIRE',
    'ebffe540d0_7BA042D858OPENPIPELINE',
    'ec09336a6f_06BA0AF311OPENPIPELINE',
    'f0747ed88d_E74C0DD8FDOPENPIPELINE',
    'f4dd768188_NOLANOPENPIPELINE',
    'f56b6b2232_2A62B67B52OPENPIPELINE',
    'f971256246_MIKEINSPIRE',
    'f9f43e5144_1DB9E6F68BINSPIRE',
    'fc5837dcf8_7CD52BE09EINSPIRE',
]

# Source-file prefixes held out for validation. A chip belongs to the
# validation split when its tile_id starts with one of these prefixes.
val_files = [
    'c644f91210_27E21B7F30OPENPIPELINE',
    'f9f43e5144_1DB9E6F68BINSPIRE',
    '1d056881e8_29FEA32BC7INSPIRE',
    '3502e187b2_23071E4605OPENPIPELINE',
    'd9161f7e18_C05BA1BC72OPENPIPELINE',
    'c8a7031e5f_32156F5DC2INSPIRE',
    '551063e3c5_8FCB044F58INSPIRE',
    'fc5837dcf8_7CD52BE09EINSPIRE',
    '39e77bedd0_729FB913CDOPENPIPELINE',
]

# Source-file prefixes held out for testing. A chip belongs to the test
# split when its tile_id starts with one of these prefixes.
test_files = [
    '25f1c24f30_EB81FE6E2BOPENPIPELINE',
    '1d4fbe33f3_F1BE1D4184INSPIRE',
    '15efe45820_D95DF0B1F4INSPIRE',
    'c6d131e346_536DE05ED2OPENPIPELINE',
    '12fa5e614f_53197F206FOPENPIPELINE',
    '5fa39d6378_DB9FF730D9OPENPIPELINE',
    'ebffe540d0_7BA042D858OPENPIPELINE',
    '8710b98ea0_06E6522D6DINSPIRE',
    '84410645db_8D20F02042OPENPIPELINE',
    'a1af86939f_F1BE1D4184OPENPIPELINE',
]

# --- Class Definitions ---

# Semantic classes in label order: a name's position in this list is its
# integer class ID.
CLASS_NAMES = ['Building', 'Clutter', 'Vegetation', 'Water', 'Background', 'Car']
NUM_CLASSES = len(CLASS_NAMES)  # Total number of semantic classes (6)

# Metadata-CSV columns holding each class's pixel count, named "<id>: <name>".
class_cols = [f"{class_id}: {label}" for class_id, label in enumerate(CLASS_NAMES)]

# RGB colour (as a tuple) -> integer class ID.
COLOR_TO_CLASS = {
    (230, 25, 75): 0,      # Red: Building
    (145, 30, 180): 1,     # Purple: Clutter
    (60, 180, 75): 2,      # Green: Vegetation
    (245, 130, 48): 3,     # Orange: Water
    (255, 255, 255): 4,    # White: Background
    (0, 130, 200): 5,      # Blue: Car
    # NOTE: a seventh entry, (255, 0, 255): 6, reportedly exists in the
    # original palette but has no counterpart in CLASS_NAMES; it is left out
    # here on the assumption that it is unused/ignored (NUM_CLASSES is 6).
}

# Integer class ID -> RGB colour (as a tuple). IDs at or beyond NUM_CLASSES
# are dropped so this mapping always lines up with CLASS_NAMES.
CLASS_TO_COLOR = {
    class_id: rgb
    for rgb, class_id in COLOR_TO_CLASS.items()
    if class_id < NUM_CLASSES
}


# --- Helper Functions ---

def compute_chip_score(row: pd.Series) -> float:
    """
    Score one chip (a metadata row) to decide whether it joins the training set.

    Lower scores are more desirable. Two sentinel values bypass the ranking:

    * ``-1.0`` -- always keep: the chip contains any car pixels, or contains
      water covering less than 95% of the chip.
    * ``inf``  -- always skip: background covers at least 95% of the chip, or
      at most one class exceeds the 0.1% presence threshold. This includes
      empty chips whose ratios are all zero, which previously slipped through
      with a finite score.

    Every other chip gets a finite score made of penalties for missing
    building and water, a penalty for vegetation above 10%, and a bonus that
    grows with the number of classes present.

    Args:
        row (pd.Series): A metadata row holding one ``"<id>: <name>_norm"``
            entry per class in CLASS_NAMES, i.e. each class's share of the
            chip's pixels.

    Returns:
        float: ``-1.0`` (always keep), ``inf`` (always skip), or a finite
        ranking score where lower is better.
    """
    # Normalized ratio of every class, keyed by the un-suffixed column name.
    class_ratios = {f"{i}: {cls}": row[f"{i}: {cls}_norm"] for i, cls in enumerate(CLASS_NAMES)}

    car_ratio = class_ratios["5: Car"]
    background_ratio = class_ratios["4: Background"]
    water_ratio = class_ratios["3: Water"]

    # Rule 1: always keep chips that contain any car pixels.
    if car_ratio > 0:
        return -1.0  # float, to stay comparable with the inf sentinel

    # Rule 2: always keep chips with water, unless the chip is >= 95% water.
    if 0 < water_ratio < 0.95:
        return -1.0

    # Rule 3: skip chips that are at least 95% background.
    if background_ratio >= 0.95:
        return float('inf')

    # Rule 4: skip homogeneous chips. A class counts as present only above a
    # 0.1% share. `<= 1` (rather than `== 1`) also rejects chips in which no
    # class is present at all, e.g. zero-pixel rows whose ratios are all 0.
    unique_class_count = sum(1 for ratio in class_ratios.values() if ratio > 0.001)
    if unique_class_count <= 1:
        return float('inf')

    # Rule 5: proportion-based penalties (lower score is better).
    score = 0.0
    # The smaller the building share, the larger the penalty.
    score += 2.5 * max(1 - class_ratios["0: Building"], 0)
    # Rule 2 already returned for 0 < water < 0.95, so at this point water is
    # either absent (full 11.0 penalty) or covers >= 95% of the chip.
    score += 11.0 * max(1 - water_ratio, 0)
    # Vegetation is free up to a 10% share; every point above that is
    # penalized (this term is never negative, so it is not a reward).
    score += 19.5 * max(class_ratios["2: Vegetation"] - 0.1, 0)

    # Rule 6: diversity bonus -- more classes present lowers the score.
    # unique_class_count is at least 2 here because of Rule 4.
    if unique_class_count >= 4:
        score -= 6.0
    elif unique_class_count == 3:
        score -= 3.0
    elif unique_class_count == 2:
        score -= 0.5

    return score

def plot_class_distribution_from_df(dataframe: pd.DataFrame, title: str = "Class Distribution"):
    """
    Draw a bar chart of each class's share of all pixels in ``dataframe``.

    Pixel counts are summed per class over every row, divided by the grand
    total, and shown as one bar per class in that class's label colour with
    the percentage printed above it. If the frame is empty or contains no
    pixels, a warning is printed and nothing is drawn.

    Args:
        dataframe (pd.DataFrame): Chips to summarise; must contain the
                                  per-class columns listed in ``class_cols``.
        title (str): Plot title, also quoted in the warning messages.
    """
    if dataframe.empty:
        print(f"Warning: DataFrame for '{title}' is empty. Cannot plot distribution.")
        return

    # Per-class pixel totals over all rows, and their grand total.
    per_class_pixels = dataframe[class_cols].sum()
    grand_total = per_class_pixels.sum()
    if grand_total == 0:
        print(f"Warning: No pixels found in DataFrame for '{title}'. Cannot plot distribution.")
        return

    # Share of all pixels that belongs to each class.
    shares = per_class_pixels / grand_total

    # One tick label and one bar colour per class ID.
    tick_labels = []
    bar_colours = []
    for class_id in range(NUM_CLASSES):
        tick_labels.append(f"{class_id}: {CLASS_NAMES[class_id]}")
        # matplotlib expects RGB components in [0, 1], not 0-255.
        bar_colours.append(np.array(CLASS_TO_COLOR[class_id]) / 255.0)

    plt.figure(figsize=(10, 5))
    rects = plt.bar(tick_labels, shares, color=bar_colours, edgecolor='black')
    plt.title(title)
    plt.xlabel("Class")
    plt.ylabel("Proportion")
    plt.grid(True, axis='y', linestyle='--', alpha=0.5)

    # Print each bar's percentage, centred just above the bar.
    for rect, share in zip(rects, shares):
        x_centre = rect.get_x() + rect.get_width() / 2
        plt.text(x_centre, rect.get_height(), f"{share:.2%}",
                 ha='center', va='bottom', fontsize=9)

    plt.tight_layout()
    plt.show()


# --- Main Dataframe Loading Functions ---

def csv_to_df(
    split: str,
    subset: float = 0.27,
    metadata_path: str = "/content/chipped_data/content/chipped_data/train_metadata.csv",
) -> pd.DataFrame:
    """
    Loads metadata from CSV and filters/scores chips for a specific data split (train, val, or test).

    * ``'val'`` / ``'test'``: keeps chips whose ``tile_id`` starts with a
      prefix from ``val_files`` / ``test_files`` and whose ``x`` and ``y`` are
      both multiples of 256.
    * ``'train'``: drops every chip belonging to a val/test file, scores the
      remainder with ``compute_chip_score``, keeps all chips scored ``-1.0``,
      discards all chips scored ``inf``, and tops the selection up with the
      lowest-scoring remaining chips until ``subset`` of the train pool is
      reached. If the always-kept chips already exceed ``subset``, all of
      them are still returned and nothing is added.

    The returned train frame carries the helper columns ``total``,
    ``<class>_norm`` and ``score``.

    Args:
        split (str): The desired data split ('train', 'val', or 'test').
        subset (float): For 'train' split, the proportion of chips to select based on scoring.
                        Ignored for 'val' and 'test'.
        metadata_path (str): Location of the metadata CSV. Defaults to the
                             path used by the original notebook.

    Returns:
        pd.DataFrame: A DataFrame of selected chips for the specified split.

    Raises:
        ValueError: If an invalid split type is provided. Raised before the
                    CSV is read, so a bad argument never triggers file I/O.
    """
    # Fail fast on a bad argument instead of after (possibly failing) file I/O.
    if split not in ('train', 'val', 'test'):
        raise ValueError(f"Invalid split: {split}. Choose from 'train', 'val', or 'test'.")

    df = pd.read_csv(metadata_path)

    if split != 'train':
        # Validation/test chips are picked purely by source-file prefix.
        wanted_prefixes = tuple(val_files if split == 'val' else test_files)
        belongs = df['tile_id'].apply(lambda tid: tid.startswith(wanted_prefixes)).astype(bool)
        df = df[belongs].copy()
        # Keep only chips whose origin lies on the 256-pixel grid
        # (presumably the non-overlapping chips -- confirm against the
        # chipping stride).
        df = df[(df['x'] % 256 == 0) & (df['y'] % 256 == 0)].copy()
        plot_class_distribution_from_df(df, title=f"{split.capitalize()} Class Distribution")
        return df

    # --- train ---
    # Exclude chips that belong to the validation or test sets based on their tile_id prefix
    held_out_prefixes = tuple(val_files + test_files)
    is_held_out = df['tile_id'].apply(lambda tid: tid.startswith(held_out_prefixes)).astype(bool)
    df = df[~is_held_out].copy()

    # Per-chip class ratios used for scoring. A zero-pixel chip would divide
    # by zero, so its total is masked with NaN and its ratios become 0.
    df['total'] = df[class_cols].sum(axis=1)
    safe_total = df['total'].replace(0, np.nan)
    for col in class_cols:
        df[col + '_norm'] = (df[col] / safe_total).fillna(0)

    if df.empty:
        # Nothing to score or rank; the percentage summaries below would
        # otherwise divide by zero.
        df['score'] = np.nan
        print("Warning: no training chips remain after excluding val/test files.")
        return df

    # Compute a score for each chip to determine inclusion in the training set
    df["score"] = df.apply(compute_chip_score, axis=1)
    pool_size = len(df)

    always_keep = df[df["score"] == -1.0].copy()
    always_skip = df[df["score"] == float('inf')].copy()

    print(f"🚗 Chips with cars/significant water (kept): {len(always_keep)} ({len(always_keep)/pool_size:.2%})")
    print(f"🧱 Chips skipped due to background/homogeneity: {len(always_skip)} ({len(always_skip)/pool_size:.2%})")

    # Everything that is neither always-kept nor always-skipped, best first.
    ranked = df[(df['score'] != -1.0) & (df['score'] != float('inf'))].sort_values('score').copy()

    # How many ranked chips are needed to reach `subset` of the pool. Clamped
    # at 0 for the case where the always-kept chips alone already exceed it.
    n_from_ranked = max(int(pool_size * subset) - len(always_keep), 0)

    # Always-kept chips first, then the best-scoring ranked chips
    # (`head` silently caps at the number of ranked chips available).
    best_chips = pd.concat([always_keep, ranked.head(n_from_ranked)])
    final_n = len(best_chips)

    plot_class_distribution_from_df(best_chips, title="Training Class Distribution")
    print(f"\n📦 Selected {final_n:,} chips from {pool_size:,} total ({final_n / pool_size:.2%})")

    return best_chips


def csv_to_hard_df(
    metadata_path: str = "/content/chipped_data/content/chipped_data/train_metadata.csv",
) -> pd.DataFrame:
    """
    Loads metadata and filters for 'hard' chips, typically used for a second stage of training.

    Chips from validation/test files are removed first. A remaining chip is
    selected when it meets at least one of these criteria:

    1. water and background are both present;
    2. water and vegetation are both present;
    3. building covers more than 5% but less than 40% of the chip while
       background covers more than 10%;
    4. water covers more than 70% of the chip;
    5. building covers more than 70% of the chip -- at most 100 such chips,
       drawn with a fixed random seed so the sample is reproducible.

    Criterion 3 is evaluated on per-chip ratios. It was previously compared
    against the raw pixel-count columns, where "more than 0.05 and less than
    0.4" cannot be satisfied by whole-number counts.

    The returned frame carries a helper column ``non_zero_classes`` (number
    of classes with a non-zero pixel count).

    Args:
        metadata_path (str): Location of the metadata CSV. Defaults to the
                             path used by the original notebook.

    Returns:
        pd.DataFrame: A DataFrame containing chips selected based on 'hard' criteria.

    Raises:
        KeyError: If the metadata has neither a 'tile_id' nor a 'source_file'
                  column, so val/test chips cannot be excluded.
    """
    df = pd.read_csv(metadata_path)

    # Exclude chips that belong to the validation or test sets. A chip is held
    # out when its source_file equals a held-out name or, as in csv_to_df,
    # when its tile_id starts with a held-out prefix. Honouring both prevents
    # val/test chips from leaking in if source_file is formatted differently
    # from the prefixes.
    held_out = set(val_files + test_files)
    held_out_prefixes = tuple(held_out)
    if "tile_id" not in df.columns and "source_file" not in df.columns:
        raise KeyError("Metadata needs a 'tile_id' or 'source_file' column to exclude val/test chips.")
    is_held_out = pd.Series(False, index=df.index)
    if "source_file" in df.columns:
        is_held_out |= df["source_file"].isin(held_out)
    if "tile_id" in df.columns:
        is_held_out |= df["tile_id"].apply(lambda tid: tid.startswith(held_out_prefixes)).astype(bool)
    df = df[~is_held_out].copy()

    # Define column references for convenience
    cols = {
        "building": "0: Building",
        "clutter": "1: Clutter",
        "vegetation": "2: Vegetation",
        "water": "3: Water",
        "background": "4: Background",
        "car": "5: Car"
    }
    class_columns = list(cols.values())

    # Helper column to count classes with non-zero pixel counts
    df['non_zero_classes'] = df[class_columns].gt(0).sum(axis=1)

    # Per-chip ratios, kept as local Series so the output columns are not
    # changed. Zero-pixel chips get a NaN total, hence NaN ratios, which
    # compare False against every threshold below and so are never selected
    # by a ratio criterion.
    total = df[class_columns].sum(axis=1).replace(0, np.nan)
    building_ratio = df[cols["building"]] / total
    background_ratio = df[cols["background"]] / total
    water_ratio = df[cols["water"]] / total

    has_water = df[cols["water"]] > 0
    has_background = df[cols["background"]] > 0
    has_vegetation = df[cols["vegetation"]] > 0

    # --- Criteria for 'Hard' Chips ---
    # 1. Water present alongside background.
    water_and_background = df[has_water & has_background]
    # 2. Water present alongside vegetation.
    #    (Chips with water, background AND vegetation satisfy both 1 and 2,
    #    so they need no separate selection.)
    water_and_veg = df[has_water & has_vegetation]
    # 3. Noticeable building and background, but building not dominating.
    building_and_background = df[
        (building_ratio > 0.05) & (background_ratio > 0.1) &
        (building_ratio < 0.4)  # reduce building-dominant chips
    ]
    # 4. Water-dominated chips (more than 70% water).
    pure_water = df[water_ratio > 0.7]
    # 5. A capped, reproducible sample of building-dominated chips (more than
    #    70% building) so the building class is not forgotten. The cap is
    #    lowered to the number available so `sample` never over-draws.
    potential_pure_building = df[building_ratio > 0.7]
    num_pure_building_to_sample = min(100, len(potential_pure_building))
    pure_building = potential_pure_building.sample(n=num_pure_building_to_sample, random_state=42)

    # Combine all selected 'hard' chips; a chip matched by several criteria
    # appears once, at the position of its first match.
    stage2_df = pd.concat([
        water_and_background,
        water_and_veg,
        building_and_background,
        pure_water,
        pure_building,
    ]).drop_duplicates().reset_index(drop=True)

    plot_class_distribution_from_df(stage2_df, title="Stage 2 Class Distribution")
    print(f"\n📦 Selected {len(stage2_df):,} hard chips for Stage 2 training.")
    return stage2_df