In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns # Used for confusion matrix plotting (implicitly by ConfusionMatrixDisplay's default style)
import random
import gc # For garbage collection
from PIL import Image # For image manipulation, potentially for reconstruction or debugging
from datetime import datetime # For timestamping or time limits

# --- TensorFlow and Keras specific imports ---
import tensorflow as tf
import tensorflow.keras.backend as K
import segmentation_models as sm
from tensorflow.keras.metrics import MeanIoU # Explicit import for clarity

# --- Scikit-learn metrics ---
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    ConfusionMatrixDisplay, # For plotting confusion matrices
    precision_recall_fscore_support # For per-class metrics
)

# --- Progress bar for loops ---
from tqdm import tqdm

# --- Additional Imports ---
import pandas as pd
from PIL import Image
import cv2
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral

# Clear Keras session to avoid conflicts from previous model definitions
K.clear_session()


# --- Global Configuration and Constants ---

# RGB colour that marks ignored / padded regions (magenta).
IGNORE_COLOR = (255, 0, 255)

# Human-readable class names; a name's position in the list is its class ID.
CLASS_NAMES = ['Building', 'Clutter', 'Vegetation', 'Water', 'Background', 'Car']
# Number of semantic classes (the ignore label below is not counted).
NUM_CLASSES = 6

# Label colour (RGB tuple) -> integer class ID.
COLOR_TO_CLASS = {
    (230, 25, 75): 0,      # Red: Building
    (145, 30, 180): 1,     # Purple: Clutter
    (60, 180, 75): 2,      # Green: Vegetation
    (245, 130, 48): 3,     # Orange: Water
    (255, 255, 255): 4,    # White: Background
    (0, 130, 200): 5,      # Blue: Car
    IGNORE_COLOR: 6,       # Magenta: "ignore" / padding label
}

# Integer class ID -> label colour, restricted to the real classes so the
# ignore label (ID 6) never takes part in normal visualisation.
CLASS_TO_COLOR = {
    class_id: rgb
    for rgb, class_id in COLOR_TO_CLASS.items()
    if class_id in range(NUM_CLASSES)
}

# Default output directory for plots and saved models
out_dir = "/content/figs"

# --- Measurement Functions ---

def measure_inference_time(
    model: tf.keras.Model,
    generator: tf.data.Dataset,
    num_batches: int = 5
) -> None:
    """Measures inference time of a Keras model on a dataset and prints it.

    Only the `model.predict` call is timed; fetching a batch from the dataset
    is excluded from the measurement.

    Args:
        model (tf.keras.Model): The trained Keras model to evaluate.
        generator (tf.data.Dataset): Dataset yielding `(inputs, labels)` batches.
        num_batches (int, optional): Number of batches to use for timing.
            Defaults to 5. Pass `None` to time every batch of the dataset.

    Returns:
        None
    """
    import time

    if num_batches is None:
        num_batches = int(tf.data.experimental.cardinality(generator).numpy())
        if num_batches < 0:
            # Negative values are the "unknown"/"infinite" cardinality
            # sentinels; -1 makes `take` iterate the dataset as-is.
            num_batches = -1

    total_time = 0.0
    total_images = 0

    for x_batch, _ in generator.take(num_batches):
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        start = time.perf_counter()
        model.predict(x_batch, verbose=0)
        total_time += time.perf_counter() - start
        total_images += x_batch.shape[0]

    if total_images == 0:
        # Without this guard the per-image average below divides by zero.
        print("⚠️ No images were processed; inference time cannot be measured.")
        return

    print(f"🧠 Inference time: {total_time:.2f} sec for {total_images} images")
    print(f"⏱️ Avg inference time per image: {total_time / total_images:.4f} sec")


# --- Plotting Functions ---

def plot_training_curves(
    history: tf.keras.callbacks.History,
    out_dir: str
) -> None:
    """Draws the loss and IoU learning curves side by side and saves them.

    The figure is written to `<out_dir>/training_curves.png` and also shown.
    If any of the four required series is missing from the history, a warning
    is printed and nothing is plotted.

    Args:
        history (tf.keras.callbacks.History): History object returned by `model.fit()`.
        out_dir (str): Directory path to save the generated plot (created if absent).

    Returns:
        None
    """
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    series = history.history

    absent = [
        key
        for key in ("loss", "val_loss", "iou_score", "val_iou_score")
        if key not in series
    ]
    if absent:
        print(f"⚠️ Missing keys in history: {absent}")
        return

    # One entry per panel: (metric key, legend stem, panel title, y-axis label).
    panels = (
        ("loss", "Loss", "Loss over Epochs", "Loss"),
        ("iou_score", "IoU", "Mean IoU over Epochs", "Mean IoU"),
    )

    figure, axes = plt.subplots(1, 2, figsize=(12, 5))
    for axis, (metric, stem, title, y_label) in zip(axes, panels):
        axis.plot(series[metric], label=f"Train {stem}")
        axis.plot(series[f"val_{metric}"], label=f"Val {stem}")
        axis.set_title(title)
        axis.set_xlabel("Epoch")
        axis.set_ylabel(y_label)
        axis.legend()

    plt.tight_layout()
    target_file = os.path.join(out_dir, "training_curves.png")
    plt.savefig(target_file)
    plt.show()
    plt.close()

    print(f"✅ Saved training curves to: {target_file}")


from typing import List, Optional
def visualise_prediction_grid(
    rgb_list: List[np.ndarray],
    true_mask_list: List[np.ndarray],
    pred_mask_list: List[np.ndarray],
    tile_id_list: Optional[List[str]] = None,
    all_tile_ids: Optional[List[str]] = None,
    n_rows: int = 4,
    n_cols: int = 3
) -> None:
    """Displays a grid of RGB images, ground truth masks, and predicted masks.

    Each grid cell is a triplet (Input, Ground Truth, Prediction). Pixels whose
    one-hot ground-truth vector is all zeros are drawn in `IGNORE_COLOR` in both
    the ground-truth and the prediction panel. Cells for which there is no
    sample are left blank.

    Args:
        rgb_list (List[np.ndarray]): List of input RGB images (uint8, 0–255).
        true_mask_list (List[np.ndarray]): List of one-hot encoded ground truth masks.
        pred_mask_list (List[np.ndarray]): List of predicted masks (class ID format).
        tile_id_list (Optional[List[str]]): Tile IDs to prioritise for visualisation.
        all_tile_ids (Optional[List[str]]): Full list of tile IDs aligned with inputs.
        n_rows (int): Number of rows in the grid.
        n_cols (int): Number of triplet columns (each triplet is Input, Ground Truth, Prediction).

    Returns:
        None
    """
    total = n_rows * n_cols
    n_samples = len(rgb_list)

    # squeeze=False keeps `axs` two-dimensional even for a single-row grid, so
    # the `axs[row, col]` indexing below is always valid.
    fig, axs = plt.subplots(
        n_rows, n_cols * 3, figsize=(n_cols * 6.6, n_rows * 2.6), squeeze=False
    )
    # Switch every axis off up front so unused cells do not show empty frames.
    for ax in axs.ravel():
        ax.axis("off")

    if tile_id_list and all_tile_ids:
        # Requested tiles first (random order), then random filler tiles.
        tile_id_set = set(tile_id_list)
        indices_to_plot = [
            i for i, tid in enumerate(all_tile_ids)
            if tid in tile_id_set and i < n_samples
        ]
        random.shuffle(indices_to_plot)
        indices_to_plot = indices_to_plot[:total]

        if len(indices_to_plot) < total:
            already_chosen = set(indices_to_plot)
            filler = [i for i in range(n_samples) if i not in already_chosen]
            random.shuffle(filler)
            indices_to_plot += filler[:total - len(indices_to_plot)]
    else:
        indices_to_plot = list(range(min(total, n_samples)))

    for idx_plot, data_idx in enumerate(indices_to_plot):
        one_hot = true_mask_list[data_idx]
        true_mask = np.argmax(one_hot, axis=-1)
        pred_mask = pred_mask_list[data_idx]

        h, w = true_mask.shape
        true_rgb = np.zeros((h, w, 3), dtype=np.uint8)
        pred_rgb = np.zeros((h, w, 3), dtype=np.uint8)

        for class_id, color in CLASS_TO_COLOR.items():
            true_rgb[true_mask == class_id] = color
            pred_rgb[pred_mask == class_id] = color

        # argmax maps unlabeled (all-zero) pixels to class 0, so they are
        # repainted with the ignore colour after the class colours.
        ignore_mask = np.all(one_hot == 0, axis=-1)
        true_rgb[ignore_mask] = IGNORE_COLOR
        pred_rgb[ignore_mask] = IGNORE_COLOR

        row = idx_plot // n_cols
        col = (idx_plot % n_cols) * 3

        panels = (
            (rgb_list[data_idx], "Input"),
            (true_rgb, "Ground Truth"),
            (pred_rgb, "Prediction"),
        )
        for offset, (image, title) in enumerate(panels):
            axs[row, col + offset].imshow(image)
            axs[row, col + offset].set_title(title)

    plt.tight_layout()
    plt.show()
    plt.close(fig)


'''
from typing import List, Optional

def visualise_prediction_grid(
    rgb_list: List[np.ndarray],
    true_mask_list: List[np.ndarray],
    pred_mask_list: List[np.ndarray],
    specific_tile_ids: Optional[List[str]] = None,
    all_tile_ids: Optional[List[str]] = None,
    n_rows: int = 4,
    n_cols: int = 3
) -> None:
    """Displays a grid of RGB images, ground truth masks, and predicted masks.

    Args:
        rgb_list (List[np.ndarray]): List of input RGB images (uint8, 0–255).
        true_mask_list (List[np.ndarray]): List of one-hot encoded ground truth masks.
        pred_mask_list (List[np.ndarray]): List of predicted masks (class ID format).
        specific_tile_ids (Optional[List[str]]): Tile IDs to prioritise for visualisation.
        all_tile_ids (Optional[List[str]]): Full list of tile IDs aligned with inputs.
        n_rows (int): Number of rows in the grid.
        n_cols (int): Number of triplet columns (each triplet is Input, Ground Truth, Prediction).

    Returns:
        None
    """
    import random
    total = n_rows * n_cols
    fig, axs = plt.subplots(n_rows, n_cols * 3, figsize=(n_cols * 6.6, n_rows * 2.6))

    random.seed(420)
    indices_to_plot = []

    if specific_tile_ids and all_tile_ids:
        tile_id_set = set(specific_tile_ids)
        matched_indices = [i for i, tid in enumerate(all_tile_ids) if tid in tile_id_set]
        random.shuffle(matched_indices)
        indices_to_plot = matched_indices[:total]

        if len(indices_to_plot) < total:
            filler = list(set(range(len(rgb_list))) - set(indices_to_plot))
            random.shuffle(filler)
            indices_to_plot += filler[:total - len(indices_to_plot)]
    else:
        # Fallback: prioritise chips with water, clutter, or car (class IDs 3, 1, 5)
        scores = []
        for i, one_hot_mask in enumerate(true_mask_list):
            class_ids = np.unique(np.argmax(one_hot_mask, axis=-1))
            score = sum(1 for cls in class_ids if cls in {1, 3, 5})  # clutter, water, car
            scores.append((i, score))
        scores.sort(key=lambda x: -x[1])
        indices_to_plot = [i for i, _ in scores[:total]]

    for idx_plot, data_idx in enumerate(indices_to_plot):
        rgb = rgb_list[data_idx]
        true_mask = np.argmax(true_mask_list[data_idx], axis=-1)
        pred_mask = pred_mask_list[data_idx]

        h, w = true_mask.shape
        true_rgb = np.zeros((h, w, 3), dtype=np.uint8)
        pred_rgb = np.zeros((h, w, 3), dtype=np.uint8)

        for class_id, colour in CLASS_TO_COLOR.items():
            true_rgb[true_mask == class_id] = colour
            pred_rgb[pred_mask == class_id] = colour

        ignore_mask = np.all(true_mask_list[data_idx] == 0, axis=-1)
        true_rgb[ignore_mask] = (255, 0, 255)
        pred_rgb[ignore_mask] = (255, 0, 255)

        row = idx_plot // n_cols
        col = (idx_plot % n_cols) * 3

        axs[row, col + 0].imshow(rgb)
        axs[row, col + 0].set_title("Input")
        axs[row, col + 1].imshow(true_rgb)
        axs[row, col + 1].set_title("Ground Truth")
        axs[row, col + 2].imshow(pred_rgb)
        axs[row, col + 2].set_title("Prediction")

        for i in range(3):
            axs[row, col + i].axis("off")

    plt.tight_layout()
    plt.show()
    plt.close(fig)

'''











from typing import Union
def normalize_confusion_matrix(
    cm: Union[np.ndarray, list],
    norm: str = 'true'
) -> np.ndarray:
    """Normalises a confusion matrix by rows, columns, or entire matrix.

    Rows/columns (or a whole matrix) whose sum is zero are returned as zeros
    instead of NaN, so a class that never occurs does not poison the result
    or trigger divide-by-zero warnings.

    Args:
        cm (Union[np.ndarray, list]): Raw confusion matrix.
        norm (str): Normalisation method. Options:
            - 'true': Normalise by rows (ground truth labels).
            - 'pred': Normalise by columns (predicted labels).
            - 'all': Normalise entire matrix to sum to 1.

    Returns:
        np.ndarray: Normalised confusion matrix (float32, new array).

    Raises:
        ValueError: If `norm` is not one of 'true', 'pred', or 'all'.
    """
    # Validate first so a bad `norm` fails before any array work is done.
    if norm not in ('true', 'pred', 'all'):
        raise ValueError("Unknown normalization type. Use 'true', 'pred', or 'all'.")

    cm = np.asarray(cm, dtype=np.float32)

    if norm == 'true':
        totals = cm.sum(axis=1, keepdims=True)
    elif norm == 'pred':
        totals = cm.sum(axis=0, keepdims=True)
    else:
        totals = cm.sum()

    # `out` pre-fills the result with zeros; `where` skips the division for
    # zero totals, leaving those entries at 0 rather than NaN.
    return np.divide(cm, totals, out=np.zeros_like(cm), where=totals != 0)


# --- Main Evaluation Function ---
from typing import Optional, List

def evaluate_on_test(
    model: tf.keras.Model,
    test_gen: tf.data.Dataset,
    test_df: pd.DataFrame,
    out_dir: str,
    image_dir: str,
    label_dir: str,
    tile_size: int = 256,
    n_rows: int = 4,
    n_cols: int = 3,
    specific_tile_ids: Optional[List[str]] = None
) -> None:
    """Evaluates the model on the test set and generates metrics and visualisations.

    This includes mean IoU, macro F1, Precision, Recall, a confusion matrix,
    and a prediction visualisation grid.

    Pixels whose one-hot ground-truth vector is all zeros are treated as
    unlabeled (the same convention the prediction grid uses for its ignore
    overlay) and are left out of every metric; argmax would otherwise count
    them as class 0.

    Tile IDs are taken from `test_df['tile_id']` in row order, so `test_gen`
    is assumed to yield tiles in that same order (i.e. unshuffled).

    Args:
        model (tf.keras.Model): Trained segmentation model.
        test_gen (tf.data.Dataset): Test dataset generator.
        test_df (pd.DataFrame): DataFrame with test metadata including tile IDs.
        out_dir (str): Directory to save output plots.
        image_dir (str): Directory containing RGB images (not used in this function).
        label_dir (str): Directory containing label images (not used in this function).
        tile_size (int): Size of each tile in pixels (not used in this function).
        n_rows (int): Number of rows in the prediction grid.
        n_cols (int): Number of columns in the prediction grid.
        specific_tile_ids (Optional[List[str]]): Tile IDs to prioritise for visualisation.

    Returns:
        None
    """
    print("🧪 Running evaluation on test set...")

    # One flat array per batch; concatenated once after the loop. Extending a
    # Python list pixel-by-pixel would allocate one object per pixel.
    pred_chunks = []
    true_chunks = []

    visual_rgb = []
    visual_true = []
    visual_pred = []
    visual_tile_ids = []
    visual_limit = n_rows * n_cols if n_rows and n_cols else 5

    # Requested tiles are kept for the grid even when they appear after the
    # first `visual_limit` tiles; otherwise they could never be displayed.
    requested_ids = set(specific_tile_ids or ())
    requested_kept = 0

    os.makedirs(out_dir, exist_ok=True)
    test_tile_ids = test_df['tile_id'].tolist()
    tile_index = 0

    for batch_x, batch_y in test_gen.as_numpy_iterator():
        if batch_x.size == 0:
            continue

        pred = model.predict(batch_x, verbose=0)
        pred_mask = np.argmax(pred, axis=-1).astype(np.uint8)
        true_mask = np.argmax(batch_y, axis=-1).astype(np.uint8)

        for j in range(batch_x.shape[0]):
            if tile_index >= len(test_tile_ids):
                break

            tile_id = test_tile_ids[tile_index]
            tile_index += 1

            is_requested = tile_id in requested_ids
            keep_for_grid = len(visual_rgb) < visual_limit or (
                is_requested and requested_kept < visual_limit
            )
            if keep_for_grid:
                requested_kept += is_requested
                # Assumes the first three input channels are RGB scaled to
                # [0, 1] — TODO confirm against the dataset builder.
                rgb_tile = (batch_x[j][:, :, :3] * 255).astype(np.uint8)
                visual_rgb.append(rgb_tile)
                visual_true.append(batch_y[j])
                visual_pred.append(pred_mask[j])
                visual_tile_ids.append(tile_id)

        # Keep only pixels that actually carry a label.
        labelled = batch_y.any(axis=-1)
        pred_chunks.append(pred_mask[labelled])
        true_chunks.append(true_mask[labelled])

        del batch_x, batch_y, pred, pred_mask, true_mask, labelled
        gc.collect()

    all_test_preds = (
        np.concatenate(pred_chunks) if pred_chunks else np.empty(0, dtype=np.uint8)
    )
    all_test_trues = (
        np.concatenate(true_chunks) if true_chunks else np.empty(0, dtype=np.uint8)
    )
    del pred_chunks, true_chunks

    if all_test_preds.size == 0:
        print("⚠️ No test predictions were collected.")
        return

    if visual_rgb and n_rows and n_cols:
        # `tile_id_list` is the keyword the active visualise_prediction_grid
        # definition accepts for the prioritised tile IDs.
        visualise_prediction_grid(
            visual_rgb,
            visual_true,
            visual_pred,
            tile_id_list=specific_tile_ids,
            all_tile_ids=visual_tile_ids,
            n_rows=n_rows,
            n_cols=n_cols
        )

    class_labels = list(range(NUM_CLASSES))

    mean_iou_metric = tf.keras.metrics.MeanIoU(num_classes=NUM_CLASSES)
    mean_iou_metric.update_state(all_test_trues, all_test_preds)
    miou = mean_iou_metric.result().numpy()

    # zero_division=0 yields the same scores as the default but without the
    # undefined-metric warnings for classes that are never predicted.
    macro_f1 = f1_score(all_test_trues, all_test_preds, average='macro', zero_division=0)
    macro_precision = precision_score(all_test_trues, all_test_preds, average='macro', zero_division=0)
    macro_recall = recall_score(all_test_trues, all_test_preds, average='macro', zero_division=0)

    print(f"\n📊 Macro Metrics:")
    print(f"  F1 Score     : {macro_f1:.4f}")
    print(f"  Precision    : {macro_precision:.4f}")
    print(f"  Recall       : {macro_recall:.4f}")

    conf_matrix = confusion_matrix(all_test_trues, all_test_preds, labels=class_labels)
    print("\n📏 Per-class IoU Scores:")
    class_ious = []
    for i in class_labels:
        intersection = conf_matrix[i, i]
        union = np.sum(conf_matrix[i, :]) + np.sum(conf_matrix[:, i]) - intersection
        iou = intersection / union if union > 0 else float('nan')
        class_ious.append(iou)
        print(f"  {CLASS_NAMES[i]:<12} IoU: {iou:.4f}")

    print(f"\n📈 Mean IoU (mIoU): {miou:.4f}")

    print("\n🔍 Classification Report:")
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_test_trues,
        all_test_preds,
        labels=class_labels,
        zero_division=0
    )
    for i, name in enumerate(CLASS_NAMES):
        print(f"{name:<12} | F1: {f1[i]:.4f} | Prec: {precision[i]:.4f} | Recall: {recall[i]:.4f}")

    print("\n🌀 Row-Normalised Confusion Matrix:")
    row_sums = conf_matrix.sum(axis=1, keepdims=True)
    conf_float = conf_matrix.astype(np.float32)
    # `where` skips rows with no ground-truth pixels; without an explicit
    # `out` those skipped entries would be uninitialised memory.
    norm_conf = np.divide(
        conf_float, row_sums, out=np.zeros_like(conf_float), where=row_sums != 0
    )

    fig, ax = plt.subplots(figsize=(8, 6))
    disp = ConfusionMatrixDisplay(norm_conf, display_labels=CLASS_NAMES)
    disp.plot(cmap='viridis', ax=ax, values_format=".2f")
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, 'confusion_matrix_row_norm.png'))
    plt.show()
    plt.close(fig)



# --- Reconstruction Utility ---

def reconstruct_canvas(
    model: tf.keras.Model,
    df: pd.DataFrame,
    source_file_prefix: str, # The common prefix for all chips of one large image
    generator_class: callable, # The function to build the TensorFlow dataset (e.g., `build_tf_dataset`)
    img_dir: str,
    elev_dir: str,
    slope_dir: str,
    label_dir: str,
    tile_size: int = 256
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Reconstructs the full RGB image, ground truth mask, and model prediction mask
    for a given source file by stitching together its individual chips.

    Canvas areas not covered by any chip stay `IGNORE_COLOR`. Inside a chip,
    pixels whose one-hot label vector is all zeros are also painted
    `IGNORE_COLOR` in the ground-truth and prediction canvases, matching the
    ignore overlay used by the prediction grid (argmax alone would colour
    them as class 0).

    Chips are placed using the row order of the filtered DataFrame, so the
    dataset returned by `generator_class` is assumed to preserve that order.

    Args:
        model (tf.keras.Model): The trained Keras model.
        df (pd.DataFrame): The DataFrame containing metadata for all chips (e.g., test_df).
            Must provide the columns 'tile_id', 'x' and 'y'.
        source_file_prefix (str): The common prefix for all chips belonging to the same large image.
                                  (e.g., "25f1c24f30_EB81FE6E2BOPENPIPELINE")
        generator_class (callable): The function to build the TensorFlow dataset (e.g., `build_tf_dataset`).
        img_dir (str): Directory containing RGB images.
        elev_dir (str): Directory containing elevation .npy files.
        slope_dir (str): Directory containing slope .npy files.
        label_dir (str): Directory containing label images.
        tile_size (int): The size of individual tiles/chips (e.g., 256).

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing the
        reconstructed RGB canvas, Ground Truth canvas, and Prediction canvas (all as np.uint8).

    Raises:
        ValueError: If no chips are found for the specified source file prefix.
    """
    def _colourise(mask_ids: np.ndarray, ignore: np.ndarray) -> np.ndarray:
        """Turn a class-ID mask into an RGB chip, painting `ignore` pixels magenta."""
        coloured = np.zeros(mask_ids.shape + (3,), dtype=np.uint8)
        for cid, colour in CLASS_TO_COLOR.items():
            coloured[mask_ids == cid] = colour
        coloured[ignore] = IGNORE_COLOR
        return coloured

    # Filter for chips belonging to the specific source file prefix
    # Assuming 'tile_id' column contains the full ID like 'prefix_x_y'
    df_file = df[df['tile_id'].str.startswith(source_file_prefix)].copy()
    if df_file.empty:
        raise ValueError(f"No chips found for source file prefix: {source_file_prefix}")

    # Coordinates are forced to plain integers: a float-typed column (e.g.
    # after a CSV round trip) cannot be used for array shapes or slice bounds.
    chip_xs = df_file['x'].astype(int).to_numpy()
    chip_ys = df_file['y'].astype(int).to_numpy()
    min_x, min_y = int(chip_xs.min()), int(chip_ys.min())

    # Canvas spans from the top-left chip origin to the far edge of the last chip.
    canvas_w = int(chip_xs.max()) + tile_size - min_x
    canvas_h = int(chip_ys.max()) + tile_size - min_y
    canvas_shape = (canvas_h, canvas_w, 3)  # masks are stored as coloured RGB too

    # Initialize canvases with IGNORE_COLOR (magenta) for padding/uncovered areas
    rgb_canvas = np.full(canvas_shape, IGNORE_COLOR, dtype=np.uint8)
    gt_canvas = np.full(canvas_shape, IGNORE_COLOR, dtype=np.uint8)
    pred_canvas = np.full(canvas_shape, IGNORE_COLOR, dtype=np.uint8)

    # Build dataset for just this file's chips, ensuring correct order for stitching
    # Use shuffle=False and augment=False for reconstruction, and a 'val' or 'test' split
    # to guarantee no augmentation. Batch size can be larger for efficiency during inference.
    # Assuming 'generator_class' (e.g., build_tf_dataset) can accept a 'split' argument.
    gen = generator_class(
        df_file, img_dir, elev_dir, slope_dir, label_dir,
        batch_size=64, shuffle=False, augment=False, split="val" # Use 'val' or 'test' to ensure no augments
    )

    n_chips = len(df_file)
    chip_index = 0  # position in df_file of the next chip to place
    for batch_x, batch_y_onehot in tqdm(gen, desc=f"Reconstructing {source_file_prefix}"):
        if chip_index >= n_chips:
            # Every chip has been placed; skip predicting on surplus batches.
            break

        preds_softmax = model.predict(batch_x, verbose=0)
        pred_mask_ids = np.argmax(preds_softmax, axis=-1)

        onehot = np.asarray(batch_y_onehot)
        true_mask_ids = np.argmax(onehot, axis=-1)
        ignore_masks = ~onehot.any(axis=-1)

        # RGB is taken from the first 3 channels and scaled from [0, 1] to
        # [0, 255] — assumes inputs are normalised that way; TODO confirm.
        batch_rgb = (np.asarray(batch_x)[..., :3] * 255.0).astype(np.uint8)

        for i in range(batch_rgb.shape[0]):
            if chip_index >= n_chips:  # Safety break if we've processed all chips
                break

            rel_x = int(chip_xs[chip_index]) - min_x
            rel_y = int(chip_ys[chip_index]) - min_y
            chip_index += 1

            rows = slice(rel_y, rel_y + tile_size)
            cols = slice(rel_x, rel_x + tile_size)

            rgb_canvas[rows, cols] = batch_rgb[i]
            gt_canvas[rows, cols] = _colourise(true_mask_ids[i], ignore_masks[i])
            pred_canvas[rows, cols] = _colourise(pred_mask_ids[i], ignore_masks[i])

    return rgb_canvas, gt_canvas, pred_canvas


def plot_reconstruction(img: np.ndarray, label: np.ndarray, pred: np.ndarray, source_file_prefix: str):
    """
    Shows three reconstructed canvases next to each other: the RGB image, the
    coloured ground truth, and the coloured model prediction.

    Args:
        img (np.ndarray): The reconstructed RGB image canvas (uint8).
        label (np.ndarray): The reconstructed ground truth mask canvas (colored, uint8).
        pred (np.ndarray): The reconstructed prediction mask canvas (colored, uint8).
        source_file_prefix (str): The common prefix of the source file, used in the figure title.
    """
    fig, axs = plt.subplots(1, 3, figsize=(26.5, 13))

    # Left-to-right panel content paired with its heading.
    panels = (
        (img, "RGB Image"),
        (label, "Ground Truth"),
        (pred, "Model Prediction"),
    )
    for axis, (canvas, heading) in zip(axs, panels):
        axis.imshow(canvas)
        axis.set_title(heading, fontsize=18)
        axis.axis('off')

    plt.suptitle(f"Reconstruction for: {source_file_prefix}", fontsize=24, y=0.95)
    # Reserve the top strip of the figure for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.93])
    plt.show()
    plt.close(fig)




# --- Unused Utility Functions ---

def visualise_prediction_grid_by_performance(
    rgb_list,
    true_mask_list,
    pred_mask_list,
    miou_list,
    class_list,
    n_rows=4,
    n_cols=3,
    class_names=CLASS_NAMES,
    class_to_color=CLASS_TO_COLOR
):
    """
    Visualizes a grid of predictions grouped by performance (per-chip mIoU).

    The chips are sorted by mIoU and split into `n_cols` consecutive tiers of
    (near) equal size, from worst to best. Each tier fills one triplet column
    (Input, Ground Truth, Prediction) with up to `n_rows` chips: the
    lowest-scoring chips of that tier. In the worst tier, one chip containing
    "Water" and one containing "Building" is included when such chips exist.
    Grid cells without a chip are left blank.

    Args:
        rgb_list (list): List of input RGB images (NumPy arrays).
        true_mask_list (list): List of ground truth masks (one-hot encoded NumPy arrays).
        pred_mask_list (list): List of predicted masks (integer class ID NumPy arrays).
        miou_list (list): List of per-chip mIoU scores.
        class_list (list): List where each element is a list of integer class IDs present in that chip.
        n_rows (int): Number of rows in the visualization grid.
        n_cols (int): Number of triplet columns, i.e. performance tiers.
        class_names (list): List of class names.
        class_to_color (dict): Mapping from class ID to RGB color tuple.
    """
    # Sorting uses only the score as key, so the NumPy arrays inside the
    # tuples are never compared with each other.
    chips = sorted(
        zip(rgb_list, true_mask_list, pred_mask_list, miou_list, class_list),
        key=lambda chip: chip[3],
    )
    if not chips or n_rows < 1 or n_cols < 1:
        print("⚠️ Nothing to visualise.")
        return

    # Tiers partition the data itself (not the grid size), worst tier first.
    n_chips = len(chips)
    tiers = [
        chips[n_chips * k // n_cols: n_chips * (k + 1) // n_cols]
        for k in range(n_cols)
    ]

    def _pick(tier, needed, reserved_ids=()):
        """Pick up to `needed` chips from a tier, covering reserved classes first."""
        # Positions are tracked instead of the chip tuples themselves:
        # membership tests on tuples holding arrays raise ValueError.
        picked = []
        for class_id in reserved_ids:
            if len(picked) >= needed:
                break
            if any(class_id in tier[pos][4] for pos in picked):
                continue  # already covered by an earlier pick
            hit = next(
                (pos for pos, chip in enumerate(tier) if class_id in chip[4]),
                None,
            )
            if hit is not None:
                picked.append(hit)
        for pos in range(len(tier)):
            if len(picked) >= needed:
                break
            if pos not in picked:
                picked.append(pos)
        # Restore ascending-mIoU order inside the column.
        return [tier[pos] for pos in sorted(picked)]

    reserved = [
        class_names.index(name)
        for name in ("Water", "Building")
        if name in class_names
    ]
    selected_per_tier = [
        _pick(tier, n_rows, reserved if tier_idx == 0 else ())
        for tier_idx, tier in enumerate(tiers)
    ]

    # squeeze=False keeps `axs` 2-D for single-row grids.
    fig, axs = plt.subplots(
        n_rows, n_cols * 3, figsize=(n_cols * 6.6, n_rows * 2.6), squeeze=False
    )
    for ax in axs.ravel():
        ax.axis("off")

    for tier_idx, tier_chips in enumerate(selected_per_tier):
        col = tier_idx * 3
        for row, (rgb, true_mask_oh, pred_mask, _, _) in enumerate(tier_chips):
            true_mask = np.argmax(true_mask_oh, axis=-1)
            h, w = true_mask.shape
            true_rgb = np.zeros((h, w, 3), dtype=np.uint8)
            pred_rgb = np.zeros((h, w, 3), dtype=np.uint8)

            for class_id, color in class_to_color.items():
                true_rgb[true_mask == class_id] = color
                pred_rgb[pred_mask == class_id] = color

            # Unlabeled (all-zero one-hot) pixels are shown in the ignore colour.
            ignore_mask = np.all(true_mask_oh == 0, axis=-1)
            true_rgb[ignore_mask] = IGNORE_COLOR
            pred_rgb[ignore_mask] = IGNORE_COLOR

            panels = (
                (rgb, "Input"),
                (true_rgb, "Ground Truth"),
                (pred_rgb, "Prediction"),
            )
            for offset, (image, title) in enumerate(panels):
                axs[row, col + offset].imshow(image)
                axs[row, col + offset].set_title(title)

    plt.tight_layout()
    plt.show()
    plt.close(fig)


def apply_crf(rgb, probs, t=5, compat=10, sxy=3, srgb=13):
    """Refine per-pixel class probabilities with a dense CRF and return class IDs.

    Args:
        rgb: Image used for the bilateral pairwise term; channels on the last axis.
        probs: Class probabilities indexed as (height, width, num_classes).
        t: Number of CRF inference iterations.
        compat: Compatibility value passed to the pairwise term.
        sxy: Spatial scaling of the bilateral kernel (used for both image axes).
        srgb: Colour scaling of the bilateral kernel (used for all three channels).

    Returns:
        Array of shape (height, width) holding the arg-max class ID per pixel.
    """
    import numpy as np
    import pydensecrf.densecrf as dcrf
    from pydensecrf.utils import create_pairwise_bilateral, unary_from_softmax

    height, width = probs.shape[:2]
    n_labels = probs.shape[-1]

    crf = dcrf.DenseCRF2D(width, height, n_labels)

    # Unary term: move classes to the leading axis, then flatten to
    # (n_labels, height * width) in C order.
    class_first = probs.transpose(2, 0, 1)
    unary_energy = unary_from_softmax(class_first).reshape((n_labels, -1)).copy(order='C')
    crf.setUnaryEnergy(unary_energy)

    # Pairwise term: the image is handed over as C-contiguous float32.
    image = np.ascontiguousarray(rgb.astype(np.float32))
    bilateral = create_pairwise_bilateral(
        sdims=(sxy, sxy),
        schan=(srgb, srgb, srgb),
        img=image,
        chdim=2,
    )
    crf.addPairwiseEnergy(bilateral, compat=compat)

    marginals = crf.inference(t)
    return np.argmax(marginals, axis=0).reshape((height, width))




def evaluate_model_with_crf(model, test_gen, test_df, out_dir, image_dir, label_dir, tile_size=256, n_rows=4, n_cols=3):
    """Evaluate the model with dense-CRF post-processing applied to every chip.

    For each chip the softmax output is refined with `apply_crf`; the refined
    masks are scored against the ground truth. A prediction grid is shown, and
    a row-normalised confusion matrix (`confusion_matrix_crf.png`) and a
    per-chip mIoU histogram (`per_chip_miou_hist.png`) are saved to `out_dir`.

    Pixels whose one-hot ground-truth vector is all zeros are treated as
    unlabeled (the convention the prediction grid uses for its ignore overlay)
    and are excluded from all metrics. Mean IoU, both per chip and overall, is
    averaged over the classes that occur in the ground truth or prediction.

    Args:
        model: Trained segmentation model exposing `predict`.
        test_gen: Iterable yielding `(inputs, one_hot_labels)` batches.
        test_df: Test metadata; only its length is used, as an upper bound on
            the number of chips evaluated.
        out_dir (str): Directory for the saved figures (created if missing).
        image_dir: Not used in this function.
        label_dir: Not used in this function.
        tile_size (int): Not used in this function.
        n_rows (int): Rows of the prediction grid.
        n_cols (int): Triplet columns of the prediction grid.

    Returns:
        dict: Keys "macro_f1", "precision", "recall", "miou",
        "per_class_ious" (0 for classes that never occur) and
        "classification_report".

    Raises:
        ValueError: If no labelled pixels were evaluated.
    """
    def _class_ious(cm):
        """Return (per-class IoU with 0 for absent classes, mask of present classes)."""
        hits = np.diag(cm).astype(np.float64)
        unions = cm.sum(axis=0) + cm.sum(axis=1) - hits
        present = unions > 0
        ious = np.divide(hits, unions, out=np.zeros_like(hits), where=present)
        return ious, present

    os.makedirs(out_dir, exist_ok=True)
    class_labels = list(range(NUM_CLASSES))

    # One flat uint8 array per chip, concatenated once after the loop; a
    # per-pixel Python list would allocate one object per pixel.
    pred_chunks = []
    true_chunks = []

    rgb_list = []
    true_mask_list = []
    pred_mask_list = []
    chip_ious = []

    visual_limit = n_rows * n_cols
    tile_index = 0

    for x_batch, y_batch in tqdm(test_gen, desc="Evaluating"):
        x_np = np.asarray(x_batch)
        y_np = np.asarray(y_batch)
        if x_np.size == 0:
            continue

        soft_preds = model.predict(x_batch, verbose=0)
        true_mask_batch = np.argmax(y_np, axis=-1).astype(np.uint8)
        labelled_batch = y_np.any(axis=-1)

        for i in range(x_np.shape[0]):
            if tile_index >= len(test_df):
                break
            tile_index += 1

            # Assumes the first three input channels are RGB scaled to
            # [0, 1] — TODO confirm against the dataset builder.
            rgb_img = (x_np[i][..., :3] * 255).astype(np.uint8)
            crf_mask = apply_crf(rgb_img, soft_preds[i])

            if len(rgb_list) < visual_limit:
                rgb_list.append(rgb_img)
                true_mask_list.append(y_np[i])
                pred_mask_list.append(crf_mask)

            labelled = labelled_batch[i]
            chip_true = true_mask_batch[i][labelled]
            if chip_true.size == 0:
                # Fully unlabeled chip: nothing to score.
                continue
            chip_pred = crf_mask[labelled].astype(np.uint8)

            true_chunks.append(chip_true)
            pred_chunks.append(chip_pred)

            # Per-chip IoU, averaged over the classes present in the chip.
            chip_cm = confusion_matrix(chip_true, chip_pred, labels=class_labels)
            ious, present = _class_ious(chip_cm)
            chip_ious.append(float(ious[present].mean()) if present.any() else 0)

        gc.collect()

    if not pred_chunks:
        raise ValueError("No test predictions were collected.")

    all_preds = np.concatenate(pred_chunks)
    all_trues = np.concatenate(true_chunks)
    del pred_chunks, true_chunks

    if rgb_list:
        visualise_prediction_grid(
            rgb_list=rgb_list,
            true_mask_list=true_mask_list,
            pred_mask_list=pred_mask_list,
            n_rows=n_rows,
            n_cols=n_cols
        )

    # --- Metrics ---
    f1 = f1_score(all_trues, all_preds, average='macro', zero_division=0)
    precision = precision_score(all_trues, all_preds, average='macro', zero_division=0)
    recall = recall_score(all_trues, all_preds, average='macro', zero_division=0)

    conf_matrix = confusion_matrix(all_trues, all_preds, labels=class_labels)
    ious, present = _class_ious(conf_matrix)
    per_class_ious = ious.tolist()
    miou = float(ious[present].mean()) if present.any() else 0.0

    # `labels` pins the report to all classes; without it the call raises when
    # a class is missing from the data and target_names no longer lines up.
    report = classification_report(
        all_trues, all_preds, labels=class_labels,
        target_names=CLASS_NAMES, digits=4, zero_division=0
    )

    # Confusion Matrix Plot
    row_sums = conf_matrix.sum(axis=1, keepdims=True)
    conf_float = conf_matrix.astype(np.float32)
    # Explicit `out`: entries skipped by `where` would otherwise be
    # uninitialised memory.
    norm_conf = np.divide(
        conf_float, row_sums, out=np.zeros_like(conf_float), where=row_sums != 0
    )
    fig, ax = plt.subplots(figsize=(8, 6))
    disp = ConfusionMatrixDisplay(norm_conf, display_labels=CLASS_NAMES)
    disp.plot(cmap='viridis', ax=ax, values_format=".2f")
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, 'confusion_matrix_crf.png'))
    plt.close(fig)

    # --- Per-chip mIoU histogram ---
    print("\n📊 Generating per-chip mIoU histogram...")
    bin_edges = np.linspace(0, 1, 21)  # 5% bins
    hist_fig, hist_ax = plt.subplots(figsize=(10, 6))
    hist_ax.hist(chip_ious, bins=bin_edges, edgecolor='black')
    hist_ax.set_xlabel("Per-Chip mIoU")
    hist_ax.set_ylabel("Number of Chips")
    hist_ax.set_title("Distribution of Per-Chip mIoU (CRF Post-Processed)")
    hist_ax.grid(True)
    hist_fig.tight_layout()
    hist_fig.savefig(os.path.join(out_dir, "per_chip_miou_hist.png"))
    plt.show()
    # Close explicitly so repeated evaluations do not accumulate open figures.
    plt.close(hist_fig)

    return {
        "macro_f1": f1,
        "precision": precision,
        "recall": recall,
        "miou": miou,
        "per_class_ious": per_class_ious,
        "classification_report": report
    }

In [None]:
'''
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns # Used for confusion matrix plotting (implicitly by ConfusionMatrixDisplay's default style)
import random
import gc # For garbage collection
from PIL import Image # For image manipulation, potentially for reconstruction or debugging
from datetime import datetime # For timestamping or time limits

# --- TensorFlow and Keras specific imports ---
import tensorflow as tf
import tensorflow.keras.backend as K
import segmentation_models as sm
from tensorflow.keras.metrics import MeanIoU # Explicit import for clarity

# --- Scikit-learn metrics ---
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    ConfusionMatrixDisplay, # For plotting confusion matrices
    precision_recall_fscore_support # For per-class metrics
)

# --- Progress bar for loops ---
from tqdm import tqdm


# Reset Keras' global state before any model is (re)built in this cell, so that
# leftovers from models defined earlier in the session do not interfere.
K.clear_session()


# --- Global Configuration and Constants ---

# Display name of every semantic class; the list position is the class ID
CLASS_NAMES = ['Building', 'Clutter', 'Vegetation', 'Water', 'Background', 'Car']
# How many semantic classes there are
NUM_CLASSES = 6

# RGB color (as tuple) that marks ignored regions (magenta)
IGNORE_COLOR = (255, 0, 255)

# Label-image RGB color (as tuple) -> integer class ID
COLOR_TO_CLASS = {
    (230, 25, 75): 0,      # Red: Building
    (145, 30, 180): 1,     # Purple: Clutter
    (60, 180, 75): 2,      # Green: Vegetation
    (245, 130, 48): 3,     # Orange: Water
    (255, 255, 255): 4,    # White: Background
    (0, 130, 200): 5,      # Blue: Car
    IGNORE_COLOR: 6,       # Magenta: "ignore"/"padding" pixels, ID one past the real classes
}

# Integer class ID -> RGB color (as tuple), used for normal visualization.
# The ignore entry (ID 6) is dropped because it is not below NUM_CLASSES.
CLASS_TO_COLOR = {
    class_id: rgb
    for rgb, class_id in COLOR_TO_CLASS.items()
    if class_id < NUM_CLASSES
}

# Default output directory for plots and saved models
out_dir = "/content/figs"

def visualise_prediction_grid_by_performance(
    rgb_list: list,
    true_mask_list: list,
    pred_mask_list: list,
    miou_list: list,
    class_list: list, # List of lists of present class IDs per chip
    n_rows: int = 4,
    n_cols: int = 3,
    class_names: list = CLASS_NAMES, # Default to global CLASS_NAMES
    class_to_color: dict = CLASS_TO_COLOR # Default to global CLASS_TO_COLOR
):
    """
    Plot Input / Ground Truth / Prediction triplets for chips sampled across mIoU tiers.

    The chips are ranked by per-chip mIoU and split into a bottom, middle and top
    third of the *available chips*. The grid slots (n_rows * n_cols) are shared as
    evenly as possible between the three tiers. Inside the bottom tier one chip
    containing "Water" and one containing "Building" are reserved first (when such
    chips and class names exist); every other slot is filled by a random draw from
    the tier (module-level `random`, not seeded here). If a tier is too small to
    fill its share, the remaining slots are topped up with unused chips, worst first.

    Args:
        rgb_list (list): List of input RGB images (NumPy arrays).
        true_mask_list (list): List of ground truth masks (one-hot encoded NumPy arrays).
            Pixels whose one-hot vector is all zeros are drawn in IGNORE_COLOR.
        pred_mask_list (list): List of predicted masks (integer class ID NumPy arrays).
        miou_list (list): List of per-chip mIoU scores.
        class_list (list): List where each element is a list of integer class IDs present in that chip.
        n_rows (int): Number of rows in the visualization grid.
        n_cols (int): Number of chips per row (each chip occupies three axes).
        class_names (list): List of class names.
        class_to_color (dict): Mapping from class ID to RGB color tuple.

    Returns:
        None. The figure is shown with plt.show() and then closed.
    """
    # Step 1: rank chips from worst to best mIoU. Sorting with a key never compares
    # the tuples themselves, which matters because they hold NumPy arrays.
    chip_data = sorted(
        zip(rgb_list, true_mask_list, pred_mask_list, miou_list, class_list),
        key=lambda chip: chip[3],
    )
    n_chips = len(chip_data)
    n_total_plots = n_rows * n_cols
    if n_chips == 0 or n_total_plots <= 0:
        print("No chips to plot in the performance grid.")
        return

    # Step 2: split the ranked chips (by position) into thirds of the data and
    # share the grid slots between the tiers.
    low_cut = n_chips // 3
    high_cut = (2 * n_chips) // 3
    tiers = [
        list(range(0, low_cut)),          # bottom third (worst mIoU)
        list(range(low_cut, high_cut)),   # middle third
        list(range(high_cut, n_chips)),   # top third (best mIoU)
    ]
    base_quota, leftover_slots = divmod(n_total_plots, 3)
    quotas = [base_quota + (1 if tier_idx < leftover_slots else 0) for tier_idx in range(3)]

    def select_chips(pool, needed_count, reserve_class_ids=()):
        """Pick up to `needed_count` positions from `pool`, reserving one chip per requested class."""
        pool = list(pool)  # work on a copy so the caller's tier is left untouched
        chosen = []
        for class_id in reserve_class_ids:
            if len(chosen) >= needed_count:
                break
            for position in pool:
                if class_id in chip_data[position][4]:  # index 4 holds the chip's class IDs
                    chosen.append(position)
                    pool.remove(position)
                    break
        # Fill the remaining slots at random; max() keeps the slice bound non-negative
        random.shuffle(pool)
        chosen.extend(pool[:max(needed_count - len(chosen), 0)])
        return chosen

    # Prioritize Water and Building in the bottom (worst performing) tier
    priority_ids = [class_names.index(name) for name in ("Water", "Building") if name in class_names]

    picked = select_chips(tiers[0], quotas[0], priority_ids)
    picked += select_chips(tiers[1], quotas[1])
    picked += select_chips(tiers[2], quotas[2])

    # Top up from unused chips when a tier could not fill its share
    if len(picked) < n_total_plots:
        already_taken = set(picked)
        spare = [position for position in range(n_chips) if position not in already_taken]
        picked.extend(spare[:n_total_plots - len(picked)])

    all_selected = [chip_data[position] for position in picked[:n_total_plots]]

    # Step 3: Plotting. squeeze=False always yields a 2D axes array, so the
    # single-row and single-column grids need no special-casing.
    fig, axs = plt.subplots(
        n_rows, n_cols * 3, figsize=(n_cols * 6.6, n_rows * 2.6), squeeze=False
    )
    for ax in axs.ravel():
        ax.axis("off")  # also blanks the slots that end up without a chip

    for idx_plot, (rgb, true_mask_oh, pred_mask_ids, _, _) in enumerate(all_selected):
        # Convert true_mask from one-hot to class IDs for coloring
        true_mask_ids = np.argmax(true_mask_oh, axis=-1)
        h, w = true_mask_ids.shape

        # RGB renderings of both masks
        true_rgb_display = np.zeros((h, w, 3), dtype=np.uint8)
        pred_rgb_display = np.zeros((h, w, 3), dtype=np.uint8)
        for class_id, color in class_to_color.items():
            rgb_color = np.array(color, dtype=np.uint8)
            true_rgb_display[true_mask_ids == class_id] = rgb_color
            pred_rgb_display[pred_mask_ids == class_id] = rgb_color

        # Pixels with an all-zero one-hot vector are ignored regions: paint them
        # magenta in both panels so they are not mistaken for class 0.
        ignore_mask = np.all(true_mask_oh == 0, axis=-1)
        true_rgb_display[ignore_mask] = IGNORE_COLOR
        pred_rgb_display[ignore_mask] = IGNORE_COLOR

        row = idx_plot // n_cols
        col = (idx_plot % n_cols) * 3

        axs[row, col + 0].imshow(rgb)
        axs[row, col + 0].set_title("Input", fontsize=10)
        axs[row, col + 1].imshow(true_rgb_display)
        axs[row, col + 1].set_title("Ground Truth", fontsize=10)
        axs[row, col + 2].imshow(pred_rgb_display)
        axs[row, col + 2].set_title("Prediction", fontsize=10)

    fig.tight_layout()
    plt.show()
    plt.close(fig) # Explicitly close the figure


# --- Conditional Random Field (CRF) Post-processing ---
# This block attempts to import pydensecrf. If it's not installed, it defines
# dummy functions to prevent NameErrors elsewhere.

try:
    import pydensecrf.densecrf as dcrf
    from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral

    def apply_crf(rgb: np.ndarray, probs: np.ndarray, t: int = 5, compat: int = 10, sxy: int = 3, schan: int = 13) -> np.ndarray:
        """
        Refine per-pixel softmax scores with a dense CRF (pydensecrf).

        The model probabilities provide the unary term; a single bilateral
        pairwise term, built from pixel position and the RGB values of `rgb`,
        is added on top before running inference.

        Args:
            rgb (np.ndarray): The input RGB image (H, W, 3) as np.uint8 (0-255).
            probs (np.ndarray): Softmax output probabilities from the model (H, W, num_classes).
            t (int): Number of CRF inference iterations.
            compat (int): Compatibility constant for the pairwise term.
            sxy (int): Spatial standard deviation for the pairwise bilateral term.
            schan (int): Color standard deviation for the pairwise bilateral term.

        Returns:
            np.ndarray: The refined segmentation mask (H, W) as integer class IDs.
        """
        height, width = probs.shape[:2]
        n_labels = probs.shape[-1]

        # Note the (width, height) argument order expected by DenseCRF2D
        crf = dcrf.DenseCRF2D(width, height, n_labels)

        # Unary term: class axis first, flattened to (num_classes, H*W) and
        # copied into C-contiguous memory for pydensecrf
        class_first = probs.transpose(2, 0, 1)
        unary_energy = unary_from_softmax(class_first)
        unary_energy = unary_energy.reshape((n_labels, -1)).copy(order='C')
        crf.setUnaryEnergy(unary_energy)

        # Pairwise term: bilateral features over position and float32 RGB
        image_f32 = np.ascontiguousarray(rgb.astype(np.float32))
        bilateral = create_pairwise_bilateral(
            sdims=(sxy, sxy),
            schan=(schan, schan, schan),
            img=image_f32,
            chdim=2,
        )
        crf.addPairwiseEnergy(bilateral, compat=compat)

        # Run inference and keep the best-scoring class per pixel
        marginals = crf.inference(t)
        best_labels = np.argmax(marginals, axis=0)
        return best_labels.reshape((height, width))

    def evaluate_model_with_crf(
        model: tf.keras.Model,
        test_gen: tf.data.Dataset,
        test_df: pd.DataFrame,
        out_dir: str,
        image_dir: str, # Not used inside this function
        label_dir: str, # Not used inside this function
        tile_size: int = 256,
        n_rows: int = 4,
        n_cols: int = 3
    ) -> dict:
        """
        Evaluates the trained model on the test set with CRF post-processing.

        Every chip is predicted, refined with `apply_crf`, and compared with its
        ground truth. Pixels whose one-hot ground truth is all zeros (the ignore
        encoding used by the plotting helpers) are excluded from every metric;
        otherwise argmax would silently score them as class 0.
        Prints overall and per-class metrics, saves a row-normalised confusion
        matrix and a per-chip mIoU histogram to `out_dir`, and shows a grid of
        sample predictions.

        Args:
            model (tf.keras.Model): The trained Keras model.
            test_gen (tf.data.Dataset): TensorFlow Dataset yielding (inputs, one-hot labels) batches.
                The first three input channels are treated as RGB scaled to [0, 1].
            test_df (pd.DataFrame): Test-set metadata; its 'tile_id' column caps how many chips are evaluated.
            out_dir (str): Directory to save plots (created if missing).
            image_dir (str): Unused here.
            label_dir (str): Unused here.
            tile_size (int): Unused here.
            n_rows (int): Number of rows for the prediction visualization grid.
            n_cols (int): Number of columns for the prediction visualization grid.

        Returns:
            dict: Keys "macro_f1", "precision", "recall", "miou", "per_class_ious"
            (NaN where a class has no pixels) and "classification_report_str";
            an empty dict when no labelled pixel was collected.
        """
        os.makedirs(out_dir, exist_ok=True)
        print("🧪 Running evaluation on test set with CRF post-processing...")

        class_labels = list(range(NUM_CLASSES))

        def iou_per_class(conf: np.ndarray) -> list:
            """Per-class IoU from a confusion matrix; NaN where the union is empty."""
            hits = np.diag(conf)
            unions = conf.sum(axis=1) + conf.sum(axis=0) - hits
            return [hits[c] / unions[c] if unions[c] > 0 else float('nan') for c in class_labels]

        # One 1-D array per chip; concatenated once at the end. This avoids
        # building a Python list with one object per pixel.
        pred_chunks = []
        true_chunks = []

        rgb_list_viz = [] # RGB images for visualization
        true_mask_list_viz = [] # True masks for visualization (one-hot)
        pred_mask_list_viz = [] # Predicted masks for visualization (class IDs after CRF)
        chip_ious = [] # Per-chip mIoUs (after CRF)

        visual_limit = n_rows * n_cols # Max number of samples to collect for visualization

        test_df_tile_ids = test_df['tile_id'].tolist()
        tile_index_in_df = 0 # To track which chip from test_df is being processed

        for x_batch, y_batch_onehot in tqdm(test_gen, desc="CRF Evaluating"):
            # Every tile listed in test_df has been handled: stop instead of
            # predicting further batches whose results would be discarded.
            if tile_index_in_df >= len(test_df_tile_ids):
                break
            # Skip empty batches
            if tf.size(x_batch).numpy() == 0:
                continue

            # Get model's raw softmax predictions
            soft_preds = model.predict(x_batch, verbose=0)

            for i in range(x_batch.shape[0]):
                if tile_index_in_df >= len(test_df_tile_ids): # Safety break
                    break

                # Extract RGB image for CRF (needs to be 0-255 uint8)
                rgb_img = (x_batch[i][..., :3].numpy() * 255).astype(np.uint8)

                # Apply CRF post-processing; returns class IDs (H, W)
                crf_mask = apply_crf(rgb_img, soft_preds[i])

                true_onehot = y_batch_onehot[i].numpy()
                true_mask_ids = np.argmax(true_onehot, axis=-1).astype(np.uint8)

                # Keep only labelled pixels (at least one non-zero one-hot entry)
                valid = np.any(true_onehot != 0, axis=-1)
                chip_pred = crf_mask[valid]
                chip_true = true_mask_ids[valid]
                pred_chunks.append(chip_pred)
                true_chunks.append(chip_true)

                # --- Per-chip mIoU over the classes that occur in this chip ---
                if chip_true.size:
                    cm_chip = confusion_matrix(chip_true, chip_pred, labels=class_labels)
                    ious_chip = [v for v in iou_per_class(cm_chip) if not np.isnan(v)]
                else:
                    ious_chip = []
                chip_ious.append(np.mean(ious_chip) if ious_chip else 0)

                # Collect visuals up to the limit
                if len(rgb_list_viz) < visual_limit:
                    rgb_list_viz.append(rgb_img)
                    true_mask_list_viz.append(true_onehot) # Keep one-hot for plotting
                    pred_mask_list_viz.append(crf_mask) # CRF output (class IDs)

                tile_index_in_df += 1 # Move to the next tile in the test_df

            gc.collect() # Garbage collection after each batch

        # Merge the per-chip arrays for the final metric calculations
        if pred_chunks:
            all_preds_flat = np.concatenate(pred_chunks)
            all_trues_flat = np.concatenate(true_chunks)
        else:
            all_preds_flat = np.array([])
            all_trues_flat = np.array([])

        if not all_preds_flat.size:
            print("⚠️ No predictions collected for CRF evaluation. Skipping metrics and plots.")
            return {}

        # --- Visualise Grid (using CRF outputs) ---
        if rgb_list_viz:
            print(f"\n🖼️ Visualizing {len(rgb_list_viz)} sample predictions with CRF...")
            # True masks are passed one-hot; the grid helper converts them itself.
            visualise_prediction_grid(
                rgb_list=rgb_list_viz,
                true_mask_list=true_mask_list_viz,
                pred_mask_list=pred_mask_list_viz,
                n_rows=n_rows,
                n_cols=n_cols,
            )
        else:
            print("No samples collected for CRF visualization grid.")

        # --- Metrics Calculation (using CRF output) ---
        miou_metric_tf = tf.keras.metrics.MeanIoU(num_classes=NUM_CLASSES)
        miou_metric_tf.update_state(all_trues_flat, all_preds_flat)
        overall_miou = miou_metric_tf.result().numpy()

        macro_f1 = f1_score(all_trues_flat, all_preds_flat, average='macro', zero_division=0)
        macro_precision = precision_score(all_trues_flat, all_preds_flat, average='macro', zero_division=0)
        macro_recall = recall_score(all_trues_flat, all_preds_flat, average='macro', zero_division=0)

        conf_matrix_crf = confusion_matrix(all_trues_flat, all_preds_flat, labels=class_labels)

        # Per-class IoU (NaN for classes absent from both truth and prediction)
        per_class_ious_crf = iou_per_class(conf_matrix_crf)

        # Per-class F1, Precision, Recall
        precision_crf, recall_crf, f1_crf, _ = precision_recall_fscore_support(
            all_trues_flat, all_preds_flat, labels=class_labels, zero_division=0
        )

        print(f"\n📊 CRF Macro Metrics:")
        print(f"  F1 Score     : {macro_f1:.4f}")
        print(f"  Precision    : {macro_precision:.4f}")
        print(f"  Recall       : {macro_recall:.4f}")
        print(f"\n📈 CRF Mean IoU (mIoU): {overall_miou:.4f}")
        print("\n📏 CRF Per-class IoU Scores:")
        for i in range(NUM_CLASSES):
            print(f"  {CLASS_NAMES[i]:<12} IoU: {per_class_ious_crf[i]:.4f}")
        print("\n🔍 CRF Per-class Classification Report (F1 | Precision | Recall):")
        for i, name in enumerate(CLASS_NAMES):
            print(f"{name:<12} | F1: {f1_crf[i]:.4f} | Prec: {precision_crf[i]:.4f} | Recall: {recall_crf[i]:.4f}")

        # Confusion Matrix Plot (Row-Normalised) for CRF
        print("\n🌀 Generating CRF Row-Normalised Confusion Matrix plot...")
        # Use the normalize_confusion_matrix helper for consistent normalization
        norm_conf_crf = normalize_confusion_matrix(conf_matrix_crf, norm='true')

        fig_cm, ax_cm = plt.subplots(figsize=(8, 6))
        disp_crf = ConfusionMatrixDisplay(norm_conf_crf, display_labels=CLASS_NAMES)
        disp_crf.plot(cmap='viridis', ax=ax_cm, values_format=".2f")
        ax_cm.set_title("CRF Row-Normalised Confusion Matrix (True Label %) - Test Set")
        fig_cm.tight_layout()
        cm_crf_save_path = os.path.join(out_dir, 'confusion_matrix_crf_row_norm.png')
        fig_cm.savefig(cm_crf_save_path)
        plt.show()
        plt.close(fig_cm)
        print(f"Saved CRF confusion matrix to: {cm_crf_save_path}")

        # Per-chip mIoU histogram for CRF results
        if chip_ious:
            print("\n📊 Generating per-chip mIoU histogram (CRF post-processed)...")
            bin_edges = np.linspace(0, 1, 21) # 5% bins (0.0 to 1.0 in 0.05 increments)
            fig_hist, ax_hist = plt.subplots(figsize=(10, 6))
            ax_hist.hist(chip_ious, bins=bin_edges, edgecolor='black')
            ax_hist.set_xlabel("Per-Chip mIoU")
            ax_hist.set_ylabel("Number of Chips")
            ax_hist.set_title("Distribution of Per-Chip mIoU (CRF Post-Processed)")
            ax_hist.grid(True)
            fig_hist.tight_layout()
            hist_crf_save_path = os.path.join(out_dir, "per_chip_miou_hist_crf.png")
            fig_hist.savefig(hist_crf_save_path)
            plt.show()
            plt.close(fig_hist)
            print(f"Saved CRF per-chip mIoU histogram to: {hist_crf_save_path}")

        # `labels` pins the report to all NUM_CLASSES rows; without it sklearn
        # raises when a class is missing from both truth and prediction.
        return {
            "macro_f1": macro_f1,
            "precision": macro_precision,
            "recall": macro_recall,
            "miou": overall_miou,
            "per_class_ious": per_class_ious_crf,
            "classification_report_str": classification_report(
                all_trues_flat, all_preds_flat, labels=class_labels,
                target_names=CLASS_NAMES, digits=4, zero_division=0
            )
        }

except ImportError:
    print("\nWarning: pydensecrf is not installed. Skipping CRF-related functions (apply_crf, evaluate_model_with_crf).")
    # Define dummy functions if CRF is not available, to avoid NameError elsewhere
    def apply_crf(*args, **kwargs):
        """
        Stand-in for `apply_crf` when pydensecrf could not be imported.

        Accepts the same leading arguments as the real function (`rgb`, `probs`),
        positionally or by keyword. No CRF refinement is performed:
        - if `probs` is given, the plain per-pixel argmax over its last axis is
          returned, i.e. the unrefined prediction;
        - otherwise an all-zero uint8 mask with the first two dimensions of
          `rgb` is returned.

        Raises:
            TypeError: If neither `rgb` nor `probs` is supplied.
        """
        print("CRF post-processing skipped (pydensecrf not installed).")
        rgb = args[0] if args else kwargs.get("rgb")
        probs = args[1] if len(args) > 1 else kwargs.get("probs")

        if probs is not None:
            # Best available answer without a CRF: the raw most-likely class
            return np.argmax(np.asarray(probs), axis=-1)
        if rgb is None:
            raise TypeError("apply_crf() needs at least an 'rgb' or 'probs' argument")
        return np.zeros(np.asarray(rgb).shape[:2], dtype=np.uint8) # Dummy mask

    def evaluate_model_with_crf(*args, **kwargs):
        """
        Stand-in for `evaluate_model_with_crf` when pydensecrf is unavailable.

        Ignores every argument, prints a notice and returns an empty metrics dict.
        """
        notice = "CRF evaluation skipped because pydensecrf is not installed."
        print(notice)
        return dict()


# --- Main Evaluation Function (without CRF) ---

def evaluate_on_test(
    model: tf.keras.Model,
    test_gen: tf.data.Dataset,
    test_df: pd.DataFrame,
    out_dir: str,
    image_dir: str, # Not used inside this function
    label_dir: str, # Not used inside this function
    tile_size: int = 256,
    n_rows: int = 4,
    n_cols: int = 3,
    specific_tile_ids: list = None # List of tile IDs to prioritize for visualization
) -> dict:
    """
    Evaluates the trained model on the test set (without CRF).

    Predictions are the per-pixel argmax of the model output. Pixels whose
    one-hot ground truth is all zeros (the ignore encoding used by the plotting
    helpers) are excluded from every metric; otherwise argmax would silently
    score them as class 0. Prints mIoU, macro and per-class F1/precision/recall
    and per-class IoU, saves a row-normalised confusion matrix to `out_dir`,
    and shows a grid of sample predictions.

    Args:
        model (tf.keras.Model): The trained Keras model.
        test_gen (tf.data.Dataset): TensorFlow Dataset yielding (inputs, one-hot labels) batches.
            The first three input channels are treated as RGB scaled to [0, 1].
        test_df (pd.DataFrame): Test-set metadata; its 'tile_id' column supplies the tile IDs
            (matched to chips by iteration order) and caps how many chips are evaluated.
        out_dir (str): Directory to save plots (created if missing).
        image_dir (str): Unused here.
        label_dir (str): Unused here.
        tile_size (int): Unused here.
        n_rows (int): Number of rows for the prediction visualization grid.
        n_cols (int): Number of columns for the prediction visualization grid.
        specific_tile_ids (list, optional): Tile IDs forwarded to `visualise_prediction_grid`
            as `tile_id_list`. Defaults to None.

    Returns:
        dict: Keys "macro_f1", "precision", "recall", "miou", "per_class_ious"
        (NaN where a class has no pixels) and "classification_report_str";
        an empty dict when no labelled pixel was collected.
    """
    os.makedirs(out_dir, exist_ok=True)
    print("🧪 Running evaluation on test set (without CRF)...")

    class_labels = list(range(NUM_CLASSES))

    # One 1-D array per chip; concatenated once at the end. This avoids
    # building a Python list with one object per pixel.
    pred_chunks = []
    true_chunks = []

    # Data collection for visualization
    visual_rgb = []
    visual_true = [] # Stores one-hot true masks for plotting
    visual_pred = [] # Stores class ID predicted masks for plotting
    visual_tile_ids = [] # Stores tile_ids for matching with specific_tile_ids

    # Calculate desired limit for visualization based on grid size
    visual_limit = n_rows * n_cols

    # Get all tile IDs from the test DataFrame for potential matching
    test_df_tile_ids = test_df['tile_id'].tolist()

    tile_index_in_df = 0 # To track which chip from test_df is being processed

    # Iterate through the test generator
    for batch_x, batch_y_onehot in tqdm(test_gen, desc="Evaluating"):
        # Every tile listed in test_df has been handled: stop instead of
        # predicting further batches whose results would be discarded.
        if tile_index_in_df >= len(test_df_tile_ids):
            break
        # Skip empty batches if any
        if tf.size(batch_x).numpy() == 0:
            continue

        # Get model predictions (softmax probabilities) and their class IDs
        pred_probs = model.predict(batch_x, verbose=0)
        pred_mask_batch = np.argmax(pred_probs, axis=-1).astype(np.uint8)

        # Convert the label tensor to NumPy first: np.argmax on the tensor is
        # expected to return a plain ndarray (no `.numpy()` method) — the
        # original called `.numpy()` on that result.
        true_onehot_batch = batch_y_onehot.numpy()
        true_mask_batch_ids = np.argmax(true_onehot_batch, axis=-1).astype(np.uint8)
        # Labelled pixels have at least one non-zero one-hot entry
        valid_batch = np.any(true_onehot_batch != 0, axis=-1)

        # Process each image in the current batch
        for j in range(batch_x.shape[0]):
            # Stop if we have processed all tiles in test_df (safety break)
            if tile_index_in_df >= len(test_df_tile_ids):
                break

            current_tile_id = test_df_tile_ids[tile_index_in_df]

            # Collect data for visualization up to the visual_limit
            if len(visual_rgb) < visual_limit:
                # RGB image (scale from [0,1] to [0,255] and cast to uint8)
                # Assumes RGB is always the first 3 channels if it's a multi-channel input
                rgb_tile = (batch_x[j][:, :, :3].numpy() * 255).astype(np.uint8)
                visual_rgb.append(rgb_tile)
                # Copy so the stored mask does not keep the whole batch array alive
                visual_true.append(true_onehot_batch[j].copy())
                visual_pred.append(pred_mask_batch[j]) # Predicted mask (class IDs)
                visual_tile_ids.append(current_tile_id)

            # Collect the labelled pixels of this chip for the overall metrics
            valid = valid_batch[j]
            pred_chunks.append(pred_mask_batch[j][valid])
            true_chunks.append(true_mask_batch_ids[j][valid])

            tile_index_in_df += 1 # Move to the next tile in the test_df

        # Explicitly delete batch variables and collect garbage to free memory
        del batch_x, batch_y_onehot, pred_probs, pred_mask_batch
        del true_onehot_batch, true_mask_batch_ids, valid_batch
        gc.collect()

    # Merge the per-chip arrays for metric calculations
    if pred_chunks:
        all_test_preds_flat = np.concatenate(pred_chunks)
        all_test_trues_flat = np.concatenate(true_chunks)
    else:
        all_test_preds_flat = np.array([])
        all_test_trues_flat = np.array([])

    if not all_test_preds_flat.size:
        print("⚠️ No test predictions were collected. Evaluation skipped.")
        return {}

    # --- Visualise Grid ---
    if visual_rgb: # Only plot if there's data to visualize
        print(f"\n🖼️ Visualizing {len(visual_rgb)} sample predictions (without CRF)...")
        visualise_prediction_grid(
            rgb_list=visual_rgb,
            true_mask_list=visual_true, # Pass one-hot true masks
            pred_mask_list=visual_pred,
            tile_id_list=specific_tile_ids, # Pass specific tile IDs to prioritize
            all_tile_ids=visual_tile_ids, # Pass the tile IDs collected during evaluation
            n_rows=n_rows,
            n_cols=n_cols
        )
    else:
        print("No samples collected for visualization grid.")

    # --- Overall Mean IoU ---
    mean_iou_metric_tf = tf.keras.metrics.MeanIoU(num_classes=NUM_CLASSES)
    mean_iou_metric_tf.update_state(all_test_trues_flat, all_test_preds_flat)
    miou = mean_iou_metric_tf.result().numpy()
    print(f"\n📈 Overall Mean IoU (mIoU): {miou:.4f}")

    # --- Macro F1, Precision, Recall ---
    # Using sklearn for macro averages
    macro_f1 = f1_score(all_test_trues_flat, all_test_preds_flat, average='macro', zero_division=0)
    macro_precision = precision_score(all_test_trues_flat, all_test_preds_flat, average='macro', zero_division=0)
    macro_recall = recall_score(all_test_trues_flat, all_test_preds_flat, average='macro', zero_division=0)

    print(f"\n📊 Macro Metrics:")
    print(f" F1 Score   : {macro_f1:.4f}")
    print(f" Precision  : {macro_precision:.4f}")
    print(f" Recall     : {macro_recall:.4f}")

    # --- Confusion Matrix & Per-class IoU ---
    conf_matrix = confusion_matrix(all_test_trues_flat, all_test_preds_flat, labels=class_labels)
    print("\n📏 Per-class IoU Scores:")
    class_ious = []
    for i in range(NUM_CLASSES):
        intersection = conf_matrix[i, i]
        # Row sum = true pixels of class i, column sum = predicted pixels of class i
        union = np.sum(conf_matrix[i, :]) + np.sum(conf_matrix[:, i]) - intersection
        iou = intersection / union if union > 0 else float('nan') # Use NaN for undefined IoU
        class_ious.append(iou)
        print(f"  {CLASS_NAMES[i]:<12} IoU: {iou:.4f}")

    # --- Per-class F1, Precision, Recall ---
    # Using sklearn's precision_recall_fscore_support for per-class metrics
    precision_per_class, recall_per_class, f1_per_class, _ = precision_recall_fscore_support(
        all_test_trues_flat, all_test_preds_flat, labels=class_labels, zero_division=0
    )
    print("\n🔍 Per-class Classification Report (F1 | Precision | Recall):")
    for i, name in enumerate(CLASS_NAMES):
        print(f"{name:<12} | F1: {f1_per_class[i]:.4f} | Prec: {precision_per_class[i]:.4f} | Recall: {recall_per_class[i]:.4f}")

    # --- Confusion Matrix Plot (Row-Normalised) ---
    print("\n🌀 Generating Row-Normalised Confusion Matrix plot...")
    norm_conf = normalize_confusion_matrix(conf_matrix, norm='true')

    fig_cm, ax_cm = plt.subplots(figsize=(8, 6))
    disp_cm = ConfusionMatrixDisplay(norm_conf, display_labels=CLASS_NAMES)
    disp_cm.plot(cmap='viridis', ax=ax_cm, values_format=".2f")
    ax_cm.set_title("Row-Normalised Confusion Matrix (True Label %) - Test Set")
    fig_cm.tight_layout()
    cm_save_path = os.path.join(out_dir, 'confusion_matrix_row_norm.png')
    fig_cm.savefig(cm_save_path)
    plt.show()
    plt.close(fig_cm)
    print(f"Saved confusion matrix to: {cm_save_path}")

    # --- Per-chip mIoU histogram ---
    # Not produced here: per-chip IoUs are only collected in evaluate_model_with_crf.
    # To add it, compute a confusion matrix per chip inside the loop above, average
    # the defined per-class IoUs into a `chip_ious` list, and plot its histogram.

    # `labels` pins the report to all NUM_CLASSES rows; without it sklearn
    # raises when a class is missing from both truth and prediction.
    return {
        "macro_f1": macro_f1,
        "precision": macro_precision,
        "recall": macro_recall,
        "miou": miou,
        "per_class_ious": class_ious,
        "classification_report_str": classification_report(
            all_test_trues_flat, all_test_preds_flat, labels=class_labels,
            target_names=CLASS_NAMES, digits=4, zero_division=0
        )
    }


# --- Reconstruction Utility ---

def reconstruct_canvas(
    model: tf.keras.Model,
    df: pd.DataFrame,
    source_file_prefix: str, # The common prefix for all chips of one large image
    generator_class: callable, # The function to build the TensorFlow dataset (e.g., `build_tf_dataset`)
    img_dir: str,
    elev_dir: str,
    slope_dir: str,
    label_dir: str,
    tile_size: int = 256
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Stitches the chips of one source image back into full-size canvases.

    Three canvases are produced: the RGB input, the colour-coded ground truth and the
    colour-coded model prediction. Areas not covered by any chip keep IGNORE_COLOR.

    Args:
        model (tf.keras.Model): Model whose `predict` output is arg-maxed over the last axis.
        df (pd.DataFrame): Chip metadata; the 'tile_id', 'x' and 'y' columns are read here.
        source_file_prefix (str): Chips whose 'tile_id' starts with this prefix are selected.
        generator_class (callable): Dataset builder, called as
            `generator_class(df, img_dir, elev_dir, slope_dir, label_dir, batch_size=64,
            shuffle=False, augment=False, split="val")`.
        img_dir (str): Passed through to `generator_class`.
        elev_dir (str): Passed through to `generator_class`.
        slope_dir (str): Passed through to `generator_class`.
        label_dir (str): Passed through to `generator_class`.
        tile_size (int): Edge length of one chip in pixels.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: (rgb_canvas, gt_canvas, pred_canvas),
        each a uint8 array of shape (canvas_h, canvas_w, 3).

    Raises:
        ValueError: If no chip in `df` matches `source_file_prefix`.
    """
    chips = df[df['tile_id'].str.startswith(source_file_prefix)].copy()
    if chips.empty:
        raise ValueError(f"No chips found for source file prefix: {source_file_prefix}")

    # The canvas spans from the smallest chip origin to the far edge of the largest one.
    origin_x = chips['x'].min()
    origin_y = chips['y'].min()
    canvas_w = chips['x'].max() + tile_size - origin_x
    canvas_h = chips['y'].max() + tile_size - origin_y
    canvas_shape = (canvas_h, canvas_w, 3)

    # Every canvas starts out filled with the ignore colour so gaps are visible.
    rgb_canvas = np.full(canvas_shape, IGNORE_COLOR, dtype=np.uint8)
    gt_canvas = np.full(canvas_shape, IGNORE_COLOR, dtype=np.uint8)
    pred_canvas = np.full(canvas_shape, IGNORE_COLOR, dtype=np.uint8)

    def _paint(class_ids: np.ndarray) -> np.ndarray:
        # Class IDs missing from CLASS_TO_COLOR stay black (zeros).
        tile = np.zeros((tile_size, tile_size, 3), dtype=np.uint8)
        for cid, colour in CLASS_TO_COLOR.items():
            tile[class_ids == cid] = np.array(colour, dtype=np.uint8)
        return tile

    # NOTE(review): placement below assumes the dataset yields chips in the same order
    # as the rows of `chips` (shuffle=False) — confirm against the dataset builder.
    dataset = generator_class(
        chips, img_dir, elev_dir, slope_dir, label_dir,
        batch_size=64, shuffle=False, augment=False, split="val"
    )

    n_chips = len(chips)
    next_row = 0  # Index of the metadata row matching the next chip from the dataset
    for batch_x, batch_y_onehot in tqdm(dataset, desc=f"Reconstructing {source_file_prefix}"):
        probabilities = model.predict(batch_x, verbose=0)
        pred_ids = tf.argmax(probabilities, axis=-1).numpy()
        gt_ids = tf.argmax(batch_y_onehot, axis=-1).numpy()

        for i in range(batch_x.shape[0]):
            if next_row >= n_chips:
                # More samples than metadata rows: nothing left to place in this batch.
                break

            meta = chips.iloc[next_row]
            next_row += 1
            left = meta.x - origin_x
            top = meta.y - origin_y
            rows = slice(top, top + tile_size)
            cols = slice(left, left + tile_size)

            # The first three input channels are taken as RGB and scaled by 255.
            rgb_canvas[rows, cols] = tf.cast(batch_x[i][..., :3] * 255.0, tf.uint8).numpy()
            gt_canvas[rows, cols] = _paint(gt_ids[i])
            pred_canvas[rows, cols] = _paint(pred_ids[i])

    return rgb_canvas, gt_canvas, pred_canvas


def plot_reconstruction(img: np.ndarray, label: np.ndarray, pred: np.ndarray, source_file_prefix: str):
    """
    Shows a reconstructed image, its ground truth and the prediction side by side.

    Args:
        img (np.ndarray): Reconstructed RGB canvas.
        label (np.ndarray): Colour-coded ground truth canvas.
        pred (np.ndarray): Colour-coded prediction canvas.
        source_file_prefix (str): Source file prefix, shown in the figure title.
    """
    fig, axs = plt.subplots(1, 3, figsize=(26.5, 13))

    panels = (
        (img, "RGB Image"),
        (label, "Ground Truth"),
        (pred, "Model Prediction"),
    )
    for axis, (picture, caption) in zip(axs, panels):
        axis.imshow(picture)
        axis.set_title(caption, fontsize=18)
        axis.axis('off')

    plt.suptitle(f"Reconstruction for: {source_file_prefix}", fontsize=24, y=0.95)
    # Leave the top 7% of the figure free for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.93])
    plt.show()
    plt.close(fig)


def normalize_confusion_matrix(cm: np.ndarray, norm: str = 'true') -> np.ndarray:
    """
    Normalizes a confusion matrix.

    Cells whose normalisation denominator is zero (e.g. a class with no true samples
    under 'true') are returned as exactly 0.0. The input matrix is not modified.

    Args:
        cm (np.ndarray): Confusion matrix to be normalized (integer counts).
        norm (str): Type of normalization.
                    - 'true': Normalize by true labels (rows sum to 1).
                    - 'pred': Normalize by predicted labels (columns sum to 1).
                    - 'all': Normalize by total count (matrix sums to 1).

    Returns:
        np.ndarray: Normalized confusion matrix as a new float32 array.

    Raises:
        ValueError: If an unknown normalization type is provided.
    """
    counts = np.array(cm, dtype=np.float32)  # always a fresh copy

    if norm == 'true':
        denominators = counts.sum(axis=1, keepdims=True)
    elif norm == 'pred':
        denominators = counts.sum(axis=0, keepdims=True)
    elif norm == 'all':
        denominators = counts.sum(keepdims=True)
    else:
        raise ValueError("Unknown normalization type. Use 'true', 'pred', or 'all'.")

    # `where=` skips the division for zero denominators; without an explicit `out`
    # those cells would hold uninitialised memory, so pre-fill the output with zeros.
    return np.divide(
        counts,
        denominators,
        out=np.zeros_like(counts),
        where=denominators != 0,
    )
    '''

In [None]:
'''import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import pandas as pd
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TerminateOnNaN, Callback
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import LearningRateSchedule
from tensorflow.keras.losses import CategoricalCrossentropy
import segmentation_models as sm
from collections import defaultdict
from tensorflow.keras.metrics import MeanIoU
from tensorflow.keras import mixed_precision # Required for LossScaleOptimizer
import gc
from PIL import Image # Used for Image.fromarray in some visualization tasks (though not directly here)
from datetime import datetime # Often used for timestamps in model saving
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix, classification_report, ConfusionMatrixDisplay
import random
# from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral # Imports for CRF, often within function
# import pydensecrf.densecrf as dcrf # Imports for CRF, often within function

# Reset Keras global state so models built in earlier cells do not interfere with this one
tf.keras.backend.clear_session()


# --- Global Configuration and Constants ---

# Number of semantic classes (valid class IDs are 0 .. NUM_CLASSES - 1)
NUM_CLASSES = 6
# Display name for each class ID; the list index is the class ID
CLASS_NAMES = ['Building', 'Clutter', 'Vegetation', 'Water', 'Background', 'Car']

# Mapping from RGB color (as tuple) to integer class ID.
# The last entry maps magenta to ID 6, which lies outside range(NUM_CLASSES).
COLOR_TO_CLASS = {
    (230, 25, 75): 0,      # Building
    (145, 30, 180): 1,     # Clutter
    (60, 180, 75): 2,      # Vegetation
    (245, 130, 48): 3,     # Water
    (255, 255, 255): 4,    # Background
    (0, 130, 200): 5,      # Car
    (255, 0, 255): 6       # Often used as an "ignore" or "padding" pixel color
}

# Inverse mapping from integer class ID to RGB color (as tuple).
# The `v < NUM_CLASSES` filter drops the ignore entry (ID 6), so only real classes get a color.
CLASS_TO_COLOR = {v: k for k, v in COLOR_TO_CLASS.items() if v < NUM_CLASSES}
IGNORE_COLOR = (255, 0, 255) # Magenta; same RGB value as the ID-6 key in COLOR_TO_CLASS

# Output directory for saved figures (and, per the original note, saved models)
out_dir = "/content/figs"

# Supported input types and their channel counts (RGB only, or RGB + elevation + slope)
INPUT_TYPE_CONFIG = {
    "rgb": {"description": "RGB only", "channels": 3},
    "rgb_elev": {"description": "RGB + elevation + slope", "channels": 5} # RGB (3) + Elev (1) + Slope (1)
}


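# Tile IDs singled out for detailed visualization (e.g., hard examples), grouped by source file.
# Each entry is a full tile_id string of the form '<source_file_prefix>_<x>_<y>'.
# NOTE: "1d4fbe33f3_F1BE1D4184INSPIRE_1280_2432" is listed twice in Group 2.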
specific_tile_ids = [
    # Group 1 (Example source file: 25f1c24f30_EB81FE6E2BOPENPIPELINE)
    "25f1c24f30_EB81FE6E2BOPENPIPELINE_3456_1280", "25f1c24f30_EB81FE6E2BOPENPIPELINE_3584_8320",
    "25f1c24f30_EB81FE6E2BOPENPIPELINE_896_2816", "25f1c24f30_EB81FE6E2BOPENPIPELINE_3840_4736",
    "25f1c24f30_EB81FE6E2BOPENPIPELINE_3968_384", "25f1c24f30_EB81FE6E2BOPENPIPELINE_4736_512",
    "25f1c24f30_EB81FE6E2BOPENPIPELINE_4736_768", "25f1c24f30_EB81FE6E2BOPENPIPELINE_1024_5888",
    "25f1c24f30_EB81FE6E2BOPENPIPELINE_896_5888", "25f1c24f30_EB81FE6E2BOPENPIPELINE_1024_6016",

    # Group 2 (Example source file: 1d4fbe33f3_F1BE1D4184INSPIRE)
    "1d4fbe33f3_F1BE1D4184INSPIRE_2560_4864", "1d4fbe33f3_F1BE1D4184INSPIRE_896_3584",
    "1d4fbe33f3_F1BE1D4184INSPIRE_768_3584", "1d4fbe33f3_F1BE1D4184INSPIRE_896_3712",
    "1d4fbe33f3_F1BE1D4184INSPIRE_1280_2432", "1d4fbe33f3_F1BE1D4184INSPIRE_1536_4608",
    "1d4fbe33f3_F1BE1D4184INSPIRE_1152_2432", "1d4fbe33f3_F1BE1D4184INSPIRE_1664_4864",
    "1d4fbe33f3_F1BE1D4184INSPIRE_1664_4736", "1d4fbe33f3_F1BE1D4184INSPIRE_1408_1280",
    "1d4fbe33f3_F1BE1D4184INSPIRE_1152_4864", "1d4fbe33f3_F1BE1D4184INSPIRE_1280_2432",
    "1d4fbe33f3_F1BE1D4184INSPIRE_1408_1408", "1d4fbe33f3_F1BE1D4184INSPIRE_1536_4736",
    "1d4fbe33f3_F1BE1D4184INSPIRE_384_1152",

    # Group 3 (Example source file: 15efe45820_D95DF0B1F4INSPIRE)
    "15efe45820_D95DF0B1F4INSPIRE_4736_9472", "15efe45820_D95DF0B1F4INSPIRE_9600_6016",
    "15efe45820_D95DF0B1F4INSPIRE_5888_6272", "15efe45820_D95DF0B1F4INSPIRE_7168_7936",
    "15efe45820_D95DF0B1F4INSPIRE_6016_6272", "15efe45820_D95DF0B1F4INSPIRE_8704_1024",
    "15efe45820_D95DF0B1F4INSPIRE_7040_6912", "15efe45820_D95DF0B1F4INSPIRE_8064_3968",
    "15efe45820_D95DF0B1F4INSPIRE_2688_2048", "15efe45820_D95DF0B1F4INSPIRE_7680_1920",
    "15efe45820_D95DF0B1F4INSPIRE_6272_10624", "15efe45820_D95DF0B1F4INSPIRE_6784_6784",
    "15efe45820_D95DF0B1F4INSPIRE_6528_8576",

    # Group 4 (Example source file: c6d131e346_536DE05ED2OPENPIPELINE)
    "c6d131e346_536DE05ED2OPENPIPELINE_128_896", "c6d131e346_536DE05ED2OPENPIPELINE_256_768",
    "c6d131e346_536DE05ED2OPENPIPELINE_256_896", "c6d131e346_536DE05ED2OPENPIPELINE_1792_512",
    "c6d131e346_536DE05ED2OPENPIPELINE_1792_640", "c6d131e346_536DE05ED2OPENPIPELINE_256_640",
    "c6d131e346_536DE05ED2OPENPIPELINE_128_640", "c6d131e346_536DE05ED2OPENPIPELINE_128_128",
    "c6d131e346_536DE05ED2OPENPIPELINE_256_128", "c6d131e346_536DE05ED2OPENPIPELINE_256_512",
    "c6d131e346_536DE05ED2OPENPIPELINE_2688_2176", "c6d131e346_536DE05ED2OPENPIPELINE_2560_2176",
    "c6d131e346_536DE05ED2OPENPIPELINE_2688_2048", "c6d131e346_536DE05ED2OPENPIPELINE_2560_2048",
    "c6d131e346_536DE05ED2OPENPIPELINE_2688_2304", "c6d131e346_536DE05ED2OPENPIPELINE_2560_2304",
    "c6d131e346_536DE05ED2OPENPIPELINE_2816_2176", "c6d131e346_536DE05ED2OPENPIPELINE_2816_2048",
    "c6d131e346_536DE05ED2OPENPIPELINE_2816_2304", "c6d131e346_536DE05ED2OPENPIPELINE_2304_2560",
    "c6d131e346_536DE05ED2OPENPIPELINE_2304_2688", "c6d131e346_536DE05ED2OPENPIPELINE_2432_2688",
    "c6d131e346_536DE05ED2OPENPIPELINE_2432_2560", "c6d131e346_536DE05ED2OPENPIPELINE_2176_2560",
    "c6d131e346_536DE05ED2OPENPIPELINE_2176_2688",

    # Group 5 (Example source file: 12fa5e614f_53197F206FOPENPIPELINE)
    "12fa5e614f_53197F206FOPENPIPELINE_384_3072", "12fa5e614f_53197F206FOPENPIPELINE_512_3072",
    "12fa5e614f_53197F206FOPENPIPELINE_256_3200", "12fa5e614f_53197F206FOPENPIPELINE_1024_3712",
    "12fa5e614f_53197F206FOPENPIPELINE_384_3200", "12fa5e614f_53197F206FOPENPIPELINE_640_3072",
    "12fa5e614f_53197F206FOPENPIPELINE_256_3328", "12fa5e614f_53197F206FOPENPIPELINE_256_3072",
    "12fa5e614f_53197F206FOPENPIPELINE_3200_1152", "12fa5e614f_53197F206FOPENPIPELINE_1152_2688",
    "12fa5e614f_53197F206FOPENPIPELINE_1536_2432", "12fa5e614f_53197F206FOPENPIPELINE_1280_2560",
    "12fa5e614f_53197F206FOPENPIPELINE_1536_2048", "12fa5e614f_53197F206FOPENPIPELINE_512_3840",
    "12fa5e614f_53197F206FOPENPIPELINE_512_3712", "12fa5e614f_53197F206FOPENPIPELINE_1664_2304",
    "12fa5e614f_53197F206FOPENPIPELINE_384_3456", "12fa5e614f_53197F206FOPENPIPELINE_384_3328",
    "12fa5e614f_53197F206FOPENPIPELINE_1280_3584", "12fa5e614f_53197F206FOPENPIPELINE_384_3584",
    "12fa5e614f_53197F206FOPENPIPELINE_3072_1152", "12fa5e614f_53197F206FOPENPIPELINE_3456_1024",

    # Group 6 (Example source file: 5fa39d6378_DB9FF730D9OPENPIPELINE)
    "5fa39d6378_DB9FF730D9OPENPIPELINE_3072_2688", "5fa39d6378_DB9FF730D9OPENPIPELINE_1024_6784",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_3712_2816", "5fa39d6378_DB9FF730D9OPENPIPELINE_3200_2688",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_2944_2816", "5fa39d6378_DB9FF730D9OPENPIPELINE_4224_3072",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_3328_4992", "5fa39d6378_DB9FF730D9OPENPIPELINE_1024_6528",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_3840_5888", "5fa39d6378_DB9FF730D9OPENPIPELINE_2816_4224",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_5760_1920", "5fa39d6378_DB9FF730D9OPENPIPELINE_3328_2816",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_4352_4864", "5fa39d6378_DB9FF730D9OPENPIPELINE_3072_6912",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_4096_3328", "5fa39d6378_DB9FF730D9OPENPIPELINE_2816_3968",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_5888_1920", "5fa39d6378_DB9FF730D9OPENPIPELINE_1280_2432",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_3584_2560", "5fa39d6378_DB9FF730D9OPENPIPELINE_1280_5632",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_1280_5504", "5fa39d6378_DB9FF730D9OPENPIPELINE_1408_5504",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_1408_5632", "5fa39d6378_DB9FF730D9OPENPIPELINE_4608_4480",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_4608_4352", "5fa39d6378_DB9FF730D9OPENPIPELINE_4736_4480",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_2816_6400", "5fa39d6378_DB9FF730D9OPENPIPELINE_2944_6400",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_1280_5760", "5fa39d6378_DB9FF730D9OPENPIPELINE_2816_6528",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_2944_6528", "5fa39d6378_DB9FF730D9OPENPIPELINE_1408_5760",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_4608_4608", "5fa39d6378_DB9FF730D9OPENPIPELINE_4480_4352",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_4736_4608", "5fa39d6378_DB9FF730D9OPENPIPELINE_4480_4480",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_1280_5376", "5fa39d6378_DB9FF730D9OPENPIPELINE_1408_5376",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_4736_4352", "5fa39d6378_DB9FF730D9OPENPIPELINE_2816_6272",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_2944_6272", "5fa39d6378_DB9FF730D9OPENPIPELINE_1152_5760"
]

# --- Custom Callbacks (training time limit and per-step timing) ---

class TimeLimitCallback(Callback):
    """
    Keras Callback that requests a training stop once a time budget is exceeded.

    The clock starts in `on_train_begin`; the budget is only checked at the end of
    each epoch, so the epoch in progress always finishes first.
    Relies on `datetime` (from the `datetime` module) being imported.
    """
    def __init__(self, max_minutes):
        """
        Args:
            max_minutes: Time budget in minutes, measured from the start of training.
        """
        super().__init__()
        self.max_minutes = max_minutes
        self.start_time = None

    def on_train_begin(self, logs=None):
        """Records the moment training starts."""
        self.start_time = datetime.now()

    def on_epoch_end(self, epoch, logs=None):
        """Sets `model.stop_training` when the elapsed time exceeds the budget."""
        started = self.start_time
        if started is not None:  # None only if on_train_begin never ran
            elapsed = datetime.now() - started
            minutes_used = elapsed.total_seconds() / 60
            if minutes_used > self.max_minutes:
                print(f"\nTime limit of {self.max_minutes} minutes reached. Stopping training.")
                self.model.stop_training = True

class StepTimer(Callback):
    """
    Keras Callback that records the wall-clock time of every training step (batch)
    in `self.times` and prints it every 100 batches.

    Timing uses `tf.timestamp()`, which returns seconds since the Unix epoch as a
    float64 scalar.
    """
    def on_train_begin(self, logs=None):
        """Resets the recorded step times and the start-of-step timestamp."""
        self.times = []
        # Must be float64: tf.timestamp() is float64, and float32 cannot represent
        # epoch seconds with sub-second resolution.
        self.timetaken = tf.Variable(0., dtype=tf.float64)

    def on_train_batch_begin(self, batch, logs=None):
        """Stores the timestamp at which the batch starts."""
        self.timetaken.assign(tf.timestamp())

    def on_train_batch_end(self, batch, logs=None):
        """Appends the batch duration (seconds) to `self.times`; prints every 100th."""
        batch_time = (tf.timestamp() - self.timetaken).numpy()
        self.times.append(batch_time)
        if batch % 100 == 0: # Print every 100 batches
            print(f"Batch {batch}: {batch_time:.4f} seconds/step")

# --- Custom Learning Rate Schedule ---

class TransformerLRSchedule(LearningRateSchedule):
    """
    Custom learning rate schedule inspired by the Transformer model.
    The learning rate increases linearly for the first `warmup_steps`
    and then decreases proportionally to the inverse square root of the step number:

        lr(step) = d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    """
    def __init__(self, d_model: float, warmup_steps: int = 4000):
        """
        Initializes the Transformer learning rate schedule.

        Args:
            d_model (float): The dimensionality of the model's embeddings.
            warmup_steps (int): The number of warm-up steps during which the learning rate increases.
        """
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)

    def __call__(self, step: tf.Tensor) -> tf.Tensor:
        """
        Calculates the learning rate for a given training step.

        Args:
            step (tf.Tensor): The current global training step (scalar).

        Returns:
            tf.Tensor: The calculated learning rate (float32).
        """
        step = tf.cast(step, tf.float32)
        # Decay branch: step ** -0.5 (infinite at step 0, where the warm-up branch wins)
        decay = tf.math.rsqrt(step)
        # Warm-up branch: grows linearly with step
        warmup = step * tf.pow(self.warmup_steps, -1.5)

        # The two branches cross at step == warmup_steps.
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(decay, warmup)

    def get_config(self) -> dict:
        """
        Returns the configuration of the learning rate schedule for serialization.

        Returns:
            dict: `d_model` and `warmup_steps` as built-in Python floats.
        """
        # `.numpy()` alone yields np.float32 scalars, which e.g. `json` cannot
        # serialise; wrap them in float() to get plain Python numbers.
        return {
            "d_model": float(self.d_model.numpy()),
            "warmup_steps": float(self.warmup_steps.numpy())
        }

# --- Dynamic Class Weight Updater Callback ---

class DynamicClassWeightUpdater(Callback):
    """
    A Keras Callback to dynamically update class weights during training based on
    per-class F1-score or IoU on the validation set. Every `update_every` epochs the
    weights are set to 1/metric, scaled so the largest weight is 1, and written
    into the module-level `class_weights` tf.Variable.
    """
    def __init__(self, val_data: tf.data.Dataset, update_every: int = 5,
                 target: str = 'f1', ignore_class: int = None):
        """
        Initializes the DynamicClassWeightUpdater callback.

        Args:
            val_data (tf.data.Dataset): Iterable of (inputs, one-hot labels) batches used for the metrics.
            update_every (int): How often (in epochs) to update the class weights.
            target (str): 'f1' selects F1; any other value selects IoU.
            ignore_class (int, optional): Class ID to ignore (set its weight to 0.0). Defaults to None.
        """
        super().__init__()
        self.val_data = val_data
        self.update_every = update_every
        self.target = target
        self.ignore_class = ignore_class

    def on_epoch_end(self, epoch: int, logs: dict = None):
        """
        Method called at the end of each epoch. Updates weights if `epoch + 1` is a multiple
        of `update_every`.

        Args:
            epoch (int): The current epoch number (0-based).
            logs (dict, optional): Dictionary of logs. Unused.
        """
        if (epoch + 1) % self.update_every != 0:
            return

        # The loss is expected to read its weights from a module-level tf.Variable
        # named `class_weights`; it is updated in place at the end of this method.
        global class_weights

        print(f"\n📊 Epoch {epoch+1}: Computing per-class metrics for dynamic weight update...")

        # Accumulate a confusion matrix batch by batch instead of storing every pixel.
        # One extra bucket collects any ID >= NUM_CLASSES so such pixels still count
        # as false positives / negatives for the real classes.
        n_bins = NUM_CLASSES + 1
        confusion = np.zeros((n_bins, n_bins), dtype=np.int64)
        for x_batch, y_batch in self.val_data:
            preds = self.model.predict(x_batch, verbose=0)
            y_true = np.minimum(tf.argmax(y_batch, axis=-1).numpy().ravel(), NUM_CLASSES)
            y_pred = np.minimum(tf.argmax(preds, axis=-1).numpy().ravel(), NUM_CLASSES)
            # Encode each (true, pred) pair as one flat index and count occurrences.
            pair_counts = np.bincount(y_true * n_bins + y_pred, minlength=n_bins * n_bins)
            confusion += pair_counts.reshape(n_bins, n_bins)

        true_positives = np.diag(confusion)[:NUM_CLASSES].astype(np.float64)
        true_pixels = confusion.sum(axis=1)[:NUM_CLASSES].astype(np.float64)
        pred_pixels = confusion.sum(axis=0)[:NUM_CLASSES].astype(np.float64)

        if self.target == 'f1':
            # F1 = 2TP / (2TP + FP + FN) = 2TP / (|true| + |pred|)
            numerator = 2.0 * true_positives
            denominator = true_pixels + pred_pixels
        else: # target == 'iou'
            # IoU = TP / (|true| + |pred| - TP)
            numerator = true_positives
            denominator = true_pixels + pred_pixels - true_positives

        # A class that never occurs (zero denominator) gets metric 0.
        metric = np.divide(numerator, denominator,
                           out=np.zeros(NUM_CLASSES), where=denominator > 0)

        # weight = 1 / metric; a metric of 0 maps to 1.0.
        # NOTE(review): since metric <= 1, 1.0 is the smallest possible raw weight, so
        # classes the model misses entirely end up with the lowest weight — confirm intended.
        new_weights = np.divide(1.0, metric, out=np.ones(NUM_CLASSES), where=metric > 0)

        if self.ignore_class is not None:
            new_weights[np.arange(NUM_CLASSES) == self.ignore_class] = 0.0

        # Scale so the largest weight is 1; skip when every weight is 0.
        new_weights = new_weights.astype(np.float32)
        largest = new_weights.max()
        if largest > 0:
            new_weights = new_weights / largest

        # Update the TensorFlow Variable
        class_weights.assign(new_weights)
        print(f"\n📈 Epoch {epoch+1}: Updated class weights: {new_weights}\n")

# --- Custom Metric for Mean IoU ---

# Ask Segmentation Models to use the tf.keras backend.
# NOTE(review): this assignment runs after `import segmentation_models as sm` at the top of
# this cell; presumably the library reads SM_FRAMEWORK at import time, in which case the
# line belongs above that import — confirm.
os.environ["SM_FRAMEWORK"] = "tf.keras" # Ensure Segmentation Models uses TensorFlow/Keras backend

class MeanIoUMetric(tf.keras.metrics.MeanIoU):
    """
    Mean IoU metric that accepts per-class score tensors for both labels and predictions.

    Both inputs are reduced to class IDs with an argmax over the last axis before
    being handed to the stock `tf.keras.metrics.MeanIoU` implementation.
    """
    def __init__(self, num_classes: int, name: str = "mean_iou", dtype=None):
        """
        Creates the metric.

        Args:
            num_classes (int): Number of classes, forwarded to the parent metric.
            name (str): Name of the metric instance.
            dtype: Data type forwarded to the parent metric.
        """
        super().__init__(num_classes=num_classes, name=name, dtype=dtype)

    def update_state(self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: tf.Tensor = None):
        """
        Accumulates one batch into the metric state.

        Args:
            y_true (tf.Tensor): Ground truth with classes on the last axis (one-hot).
            y_pred (tf.Tensor): Predictions with classes on the last axis (e.g. softmax).
            sample_weight (tf.Tensor, optional): Forwarded unchanged to the parent. Defaults to None.
        """
        true_ids = tf.math.argmax(y_true, axis=-1)
        pred_ids = tf.math.argmax(y_pred, axis=-1)
        return super().update_state(true_ids, pred_ids, sample_weight)


# --- Loss Function Helpers ---

def apply_label_smoothing(y_true: tf.Tensor, smoothing: float = 0.1) -> tf.Tensor:
    """
    Softens one-hot labels by blending them with a uniform distribution.

    Each label vector becomes `y_true * (1 - smoothing) + smoothing / num_classes`,
    where `num_classes` is the size of the last axis of `y_true`.

    Args:
        y_true (tf.Tensor): True labels with classes on the last axis (one-hot encoded).
        smoothing (float): Blend factor (0.0 leaves labels unchanged, 1.0 yields a uniform distribution).

    Returns:
        tf.Tensor: Smoothed labels, same shape as `y_true`.
    """
    class_count = tf.cast(tf.shape(y_true)[-1], tf.float32)
    retained = 1.0 - smoothing            # share kept by the original label
    uniform_share = smoothing / class_count  # share spread evenly over all classes
    return y_true * retained + uniform_share

def apply_ignore_class_mask(y_true: tf.Tensor, y_pred: tf.Tensor,
                            ignore_class: int = 4, loss_fn: callable = None) -> tf.Tensor:
    """
    Averages a pixel-wise loss over all pixels whose true class is not `ignore_class`.

    `loss_fn(y_true, y_pred)` must return an unreduced loss. Two layouts are supported:
    rank 3 `(batch, h, w)`, as produced by Keras losses with no reduction, and rank 4
    `(batch, h, w, k)`. A loss of lower rank is already aggregated and cannot be masked;
    it is returned unchanged after printing a warning.

    Args:
        y_true (tf.Tensor): Ground truth labels, one-hot encoded, shape (batch, h, w, num_classes).
        y_pred (tf.Tensor): Model predictions, passed to `loss_fn` unchanged.
        ignore_class (int): The class ID to ignore in loss calculation.
        loss_fn (callable): The base loss function; called as `loss_fn(y_true, y_pred)`.

    Returns:
        tf.Tensor: Scalar float32 loss: sum of the masked loss divided by the number of
                   non-ignored pixels. For an already-aggregated loss, that loss as-is.

    Raises:
        TypeError: If `loss_fn` is not provided.
    """
    if loss_fn is None:
        raise TypeError("apply_ignore_class_mask requires a callable `loss_fn`.")

    class_ids = tf.argmax(y_true, axis=-1)  # shape: (batch, h, w)

    loss = tf.convert_to_tensor(loss_fn(y_true, y_pred))
    rank = loss.shape.rank

    # Fewer than 3 axes means the loss no longer has one value per pixel.
    if rank is None or rank < 3:
        tf.print("Warning: apply_ignore_class_mask received a scalar loss. Pixel-wise masking may not apply correctly.")
        return loss

    # Work in float32 so the mask and the loss always share a dtype.
    loss = tf.cast(loss, tf.float32)
    keep = tf.cast(tf.not_equal(class_ids, ignore_class), tf.float32)  # (batch, h, w)
    if rank >= 4:
        # Add a trailing axis so the mask broadcasts over the loss's last axis.
        keep = tf.expand_dims(keep, axis=-1)

    masked_total = tf.reduce_sum(loss * keep)
    # Epsilon keeps the division finite when every pixel is ignored.
    return masked_total / (tf.reduce_sum(keep) + tf.keras.backend.epsilon())


# --- Image Decoding (Similar to data_pipeline, ensuring consistency) ---

# Note: This function is slightly different from decode_coloured_label in data_pipeline
# as it iterates through pixels, which is slower in graph mode.
# If decode_coloured_label from data_pipeline is used, this might be redundant.
def decode_label_image(label_img: np.ndarray, color_to_class: dict = None) -> np.ndarray:
    """
    Decodes an RGB label image (NumPy array) into a 2D integer class map.

    Each palette colour is matched against the whole image at once, so the cost is
    one vectorised pass per colour rather than a Python loop over every pixel.

    Args:
        label_img (np.ndarray): A 3-channel RGB label image (H, W, 3) as a NumPy array.
        color_to_class (dict, optional): Mapping from (R, G, B) tuples to class IDs.
            Defaults to the module-level COLOR_TO_CLASS.

    Returns:
        np.ndarray: A 2D uint8 array (H, W) where each element is an integer class ID.

    Raises:
        ValueError: If the image is not (H, W, 3), or if a pixel's colour is not in the
            mapping (the first such pixel in row-major order is reported).
    """
    palette = COLOR_TO_CLASS if color_to_class is None else color_to_class

    pixels = np.asarray(label_img)
    if pixels.ndim != 3 or pixels.shape[-1] != 3:
        raise ValueError(f"Expected a label image of shape (H, W, 3), got {pixels.shape}")

    label_map = np.zeros(pixels.shape[:2], dtype=np.uint8)
    recognised = np.zeros(pixels.shape[:2], dtype=bool)
    for colour, class_id in palette.items():
        hits = np.all(pixels == np.asarray(colour), axis=-1)
        label_map[hits] = class_id
        recognised |= hits

    if not recognised.all():
        # np.argwhere scans in row-major order, so index 0 is the first offender.
        y, x = (int(v) for v in np.argwhere(~recognised)[0])
        pixel = tuple(int(c) for c in pixels[y, x])
        raise ValueError(f"Unknown label colour {pixel} at ({y}, {x})")

    return label_map


# --- Utility Functions for File Filtering ---

def filter_tile_ids_by_substring(image_dir: str, base_names: list) -> list:
    """
    Lists the entries of `image_dir` whose file name contains any of `base_names`.

    Matching is a plain substring test on the full file name. Every occurrence of
    '-ortho.png' is removed from each returned name.

    Args:
        image_dir (str): The directory whose entries are scanned.
        base_names (list): Substrings (e.g. source file prefixes) to look for.

    Returns:
        list: Matching names with '-ortho.png' stripped, in `os.listdir` order.
    """
    selected = []
    for filename in os.listdir(image_dir):
        for base in base_names:
            if base in filename:
                selected.append(filename.replace('-ortho.png', ''))
                break  # one match is enough; avoid adding the same file twice
    return selected


# --- Visualization of Prediction Grids ---

def plot_colored_mask(mask_2d: np.ndarray) -> np.ndarray:
    """
    Renders an integer class mask as an RGB image by painting every class ID
    with the colour that CLASS_TO_COLOR assigns to it.

    Args:
        mask_2d (np.ndarray): Class-ID mask of shape (H, W).

    Returns:
        np.ndarray: uint8 array of shape (H, W, 3). Pixels whose class ID has no
        entry in CLASS_TO_COLOR stay black (0, 0, 0).
    """
    height, width = mask_2d.shape[0], mask_2d.shape[1]
    rgb_image = np.zeros((height, width, 3), dtype=np.uint8)
    for class_id in CLASS_TO_COLOR:
        selected = mask_2d == class_id
        rgb_image[selected] = CLASS_TO_COLOR[class_id]
    return rgb_image


def visualise_prediction_grid(
    rgb_list: list,
    true_mask_list: list,
    pred_mask_list: list,
    tile_id_list: list = None,
    all_tile_ids: list = None,
    n_rows: int = 4,
    n_cols: int = 3
):
    """
    Visualizes a grid of input RGB images, their ground truth masks, and model predictions.
    Optionally prioritizes specific tile IDs for display.

    Each sample occupies three adjacent axes (Input, Ground Truth, Prediction), so the
    figure has `n_rows` rows and `n_cols * 3` columns. Pixels whose one-hot ground-truth
    vector is all zeros are treated as ignored and painted with IGNORE_COLOR in both
    the ground-truth and the prediction panels.

    Args:
        rgb_list (list): List of input RGB images (NumPy arrays).
        true_mask_list (list): List of ground truth masks (one-hot encoded NumPy arrays).
        pred_mask_list (list): List of predicted masks (integer class ID NumPy arrays).
        tile_id_list (list, optional): A list of specific tile IDs to prioritize for plotting.
                                       If None, random chips are selected.
        all_tile_ids (list, optional): A list of all tile IDs corresponding to `rgb_list`, etc.
                                       Required if `tile_id_list` is provided; when it is
                                       missing, the selection falls back to random chips.
        n_rows (int): Number of rows in the visualization grid.
        n_cols (int): Number of columns in the visualization grid.
    """
    total_plots = n_rows * n_cols
    # squeeze=False keeps `axs` two-dimensional even for a single-row grid, so the
    # `axs[row, col]` indexing below works for every grid shape.
    fig, axs = plt.subplots(
        n_rows, n_cols * 3, figsize=(n_cols * 6.6, n_rows * 2.6), squeeze=False
    )

    n_available = len(rgb_list)

    if tile_id_list and all_tile_ids:
        # Prioritize specific tile IDs if requested (ignoring IDs that have no data).
        tile_id_set = set(tile_id_list)
        matched_indices = [
            i for i, tid in enumerate(all_tile_ids)
            if tid in tile_id_set and i < n_available
        ]
        random.shuffle(matched_indices)
        indices_to_plot = matched_indices[:total_plots]

        # Fill remaining slots with random other chips if not enough specific ones.
        if len(indices_to_plot) < total_plots:
            already_chosen = set(indices_to_plot)
            other_indices = [i for i in range(n_available) if i not in already_chosen]
            random.shuffle(other_indices)
            indices_to_plot.extend(other_indices[:total_plots - len(indices_to_plot)])
    else:
        # No specific IDs: draw a random sample from *all* available chips.
        indices_to_plot = random.sample(range(n_available), min(total_plots, n_available))

    for idx_plot, data_idx in enumerate(indices_to_plot):
        rgb = rgb_list[data_idx]
        true_mask_onehot = true_mask_list[data_idx]
        true_mask = np.argmax(true_mask_onehot, axis=-1)  # One-hot -> class IDs
        pred_mask = pred_mask_list[data_idx]

        # Colour both masks via CLASS_TO_COLOR.
        true_rgb = plot_colored_mask(true_mask)
        pred_rgb = plot_colored_mask(pred_mask)

        # An all-zero one-hot vector means "no class assigned": mark it as ignored
        # in both panels so the prediction is not judged there.
        ignore_mask = np.all(true_mask_onehot == 0, axis=-1)
        true_rgb[ignore_mask] = IGNORE_COLOR
        pred_rgb[ignore_mask] = IGNORE_COLOR

        row = idx_plot // n_cols
        col = (idx_plot % n_cols) * 3  # Each group (RGB, GT, Pred) takes 3 columns

        axs[row, col + 0].imshow(rgb)
        axs[row, col + 0].set_title("Input", fontsize=10)
        axs[row, col + 1].imshow(true_rgb)
        axs[row, col + 1].set_title("Ground Truth", fontsize=10)
        axs[row, col + 2].imshow(pred_rgb)
        axs[row, col + 2].set_title("Prediction", fontsize=10)

    # Hide the frame/ticks of every axis, including slots left empty when fewer
    # samples than grid cells are available.
    for ax in axs.ravel():
        ax.axis("off")

    plt.tight_layout()
    plt.show()
    plt.close(fig)


# --- Evaluation and Metrics ---

def measure_inference_time(model: tf.keras.Model, generator: tf.data.Dataset, num_batches: int = 5):
    """
    Measures the inference time of a Keras model on a given TensorFlow dataset generator.

    Only the `model.predict` calls are timed; fetching batches from the dataset is not.
    Results are printed, nothing is returned.

    Args:
        model (tf.keras.Model): The trained Keras model.
        generator (tf.data.Dataset): The TensorFlow dataset generator for inference.
                                     Elements are unpacked as (inputs, labels) pairs.
        num_batches (int): Number of batches to run inference on for measurement.
                           If None, the dataset cardinality is used; when that is
                           infinite, unknown, or cannot be determined, 5 batches are used.
    """
    import time  # Import time locally as it's only used here.

    if num_batches is None:
        try:
            cardinality = int(tf.data.experimental.cardinality(generator).numpy())
        except Exception:
            # Best-effort: any failure to query the dataset falls back to the default.
            print("Warning: Could not determine generator cardinality. Measuring on first 5 batches.")
            num_batches = 5
        else:
            if cardinality == tf.data.experimental.INFINITE_CARDINALITY:
                print("Warning: Generator has infinite cardinality. Measuring on first 5 batches.")
                num_batches = 5
            elif cardinality == tf.data.experimental.UNKNOWN_CARDINALITY:
                # UNKNOWN is reported as a negative sentinel, not a usable batch count.
                print("Warning: Could not determine generator cardinality. Measuring on first 5 batches.")
                num_batches = 5
            else:
                num_batches = cardinality

    total_time = 0.0
    total_images = 0

    print(f"\n⏱️ Measuring inference time over {num_batches} batches...")
    for x_batch, _ in generator.take(num_batches):
        # perf_counter is monotonic, so the interval is immune to wall-clock adjustments.
        start = time.perf_counter()
        _ = model.predict(x_batch, verbose=0)  # verbose=0 suppresses progress bar
        total_time += time.perf_counter() - start
        total_images += x_batch.shape[0]

    if total_images > 0:
        print(f"🧠 Inference time: {total_time:.2f} sec for {total_images} images")
        print(f"⏱️ Avg inference time per image: {total_time / total_images:.4f} sec")
    else:
        print("No images processed for inference time measurement.")


def plot_training_curves(history: tf.keras.callbacks.History, out_dir: str):
    """
    Plots and saves the training and validation loss, and IoU scores over epochs.

    The figure has two panels (loss on the left, IoU on the right). A panel whose
    train/validation keys are missing from the history is blanked and titled
    accordingly instead of aborting the whole plot. The figure is written to
    `<out_dir>/training_curves.png`, shown, and then closed.

    Args:
        history (tf.keras.callbacks.History): The History object returned by model.fit().
        out_dir (str): Directory to save the plot (created if it does not exist).
    """
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(out_dir, exist_ok=True)

    history_dict = history.history
    # Define required keys, ensuring they match actual metric names in history
    required_keys = ["loss", "val_loss", "iou_score", "val_iou_score"]

    missing_keys = [k for k in required_keys if k not in history_dict]
    if missing_keys:
        # Not fatal: the panels below plot whatever is available.
        print(f"⚠️ Missing keys in history: {missing_keys}. Cannot plot all curves.")

    fig, axs = plt.subplots(1, 2, figsize=(12, 5))

    # One entry per panel:
    # (train key, val key, train label, val label, title, y-axis label, fallback title)
    panels = [
        ("loss", "val_loss", "Train Loss", "Val Loss",
         "Loss over Epochs", "Loss", "Loss Data Not Available"),
        ("iou_score", "val_iou_score", "Train IoU", "Val IoU",
         "Mean IoU over Epochs", "Mean IoU", "IoU Data Not Available"),
    ]
    for ax, (train_key, val_key, train_label, val_label, title, y_label, fallback) in zip(axs, panels):
        if train_key in history_dict and val_key in history_dict:
            ax.plot(history_dict[train_key], label=train_label)
            ax.plot(history_dict[val_key], label=val_label)
            ax.set_title(title)
            ax.set_xlabel("Epoch")
            ax.set_ylabel(y_label)
            ax.legend()
        else:
            ax.set_title(fallback)
            ax.axis('off')

    fig.tight_layout()
    save_path = os.path.join(out_dir, "training_curves.png")
    fig.savefig(save_path)
    plt.show()
    plt.close(fig)  # Use close(fig) to explicitly close the figure
    print(f"Saved training curves to: {save_path}")


def evaluate_on_test(
    model: tf.keras.Model,
    test_gen: tf.data.Dataset,
    test_df: pd.DataFrame,
    out_dir: str,
    image_dir: str, # Required for reconstruction if 'yummy' is true and not handled by generator
    label_dir: str, # Required for reconstruction if 'yummy' is true and not handled by generator
    tile_size: int = 256,
    n_rows: int = 4,
    n_cols: int = 3,
    specific_tile_ids: list = None # List of tile IDs to prioritize for visualization
):
    """
    Evaluates the trained model on the test set, calculates various metrics (mIoU, F1, Precision, Recall),
    generates a confusion matrix plot, and visualizes sample predictions.

    Tiles are matched to `test_df` purely by position: the k-th tile yielded by
    `test_gen` is assumed to correspond to the k-th row of `test_df`, and evaluation
    stops once every row has been consumed.

    Pixels whose one-hot ground-truth vector is all zeros are treated as ignored
    (the same convention `visualise_prediction_grid` uses) and are excluded from all
    metrics; a plain argmax would otherwise count them as class 0.

    Results are printed and the confusion-matrix plot is saved to
    `<out_dir>/confusion_matrix_row_norm.png`; nothing is returned.

    Args:
        model (tf.keras.Model): The trained Keras model.
        test_gen (tf.data.Dataset): TensorFlow Dataset for the test set, yielding
                                    (inputs, one-hot labels) batches.
        test_df (pd.DataFrame): DataFrame containing metadata for the test set
                                (must have a 'tile_id' column).
        out_dir (str): Directory to save plots.
        image_dir (str): Path to the directory containing RGB images. Currently unused
                         by this function; kept for interface compatibility.
        label_dir (str): Path to the directory containing label images. Currently unused
                         by this function; kept for interface compatibility.
        tile_size (int): The size of the tiles (e.g., 256x256). Currently unused.
        n_rows (int): Number of rows for the prediction visualization grid.
        n_cols (int): Number of columns for the prediction visualization grid.
        specific_tile_ids (list, optional): List of specific tile IDs to prioritize for plotting.
                                            Defaults to None (random selection).
    """
    os.makedirs(out_dir, exist_ok=True)
    print("🧪 Running evaluation on test set...")

    # Flattened per-tile arrays (valid pixels only). Keeping compact uint8 arrays
    # instead of extending a Python list avoids creating one Python object per pixel.
    pred_chunks = []
    true_chunks = []

    visual_rgb = [] # To store RGB images for plotting
    visual_true = [] # To store ground truth masks for plotting
    visual_pred = [] # To store predicted masks for plotting
    visual_tile_ids = [] # To store tile_ids for plotting, useful for specific_tile_ids matching

    # Calculate desired limit for visualization based on grid size
    visual_limit = n_rows * n_cols

    # Requested tiles are collected wherever they occur in the dataset (capped at
    # visual_limit), otherwise they could only be shown if they happened to be
    # among the first visual_limit tiles.
    priority_ids = set(specific_tile_ids) if specific_tile_ids else set()
    n_priority_collected = 0

    # Get all tile IDs from the test DataFrame for potential matching
    test_df_tile_ids = test_df['tile_id'].tolist()
    n_tiles = len(test_df_tile_ids)

    tile_index_in_df = 0 # To track which chip from test_df is being processed

    # Iterate through the test generator
    for batch_x, batch_y in tqdm(test_gen, desc="Evaluating"):
        # Every row of test_df has been consumed: no point predicting further batches.
        if tile_index_in_df >= n_tiles:
            break

        # Skip empty batches if any
        if tf.size(batch_x).numpy() == 0:
            continue

        # Get model predictions (softmax probabilities)
        pred_probs = model.predict(batch_x, verbose=0)
        # Convert probabilities to class IDs (predicted masks)
        pred_mask_batch = np.argmax(pred_probs, axis=-1).astype(np.uint8)
        # Convert one-hot true labels to class IDs (ground truth masks).
        # The labels are converted to NumPy first; np.argmax already returns an
        # ndarray, so calling `.numpy()` on its result would fail.
        batch_y_np = np.asarray(batch_y)
        true_mask_batch = np.argmax(batch_y_np, axis=-1).astype(np.uint8)
        # True where a class is assigned; all-zero one-hot vectors mark ignored pixels.
        valid_batch = batch_y_np.any(axis=-1)

        # Process each image in the current batch
        for j in range(batch_x.shape[0]):
            # Stop if we have processed all tiles in test_df (safety break)
            if tile_index_in_df >= n_tiles:
                break

            current_tile_id = test_df_tile_ids[tile_index_in_df]
            is_priority = current_tile_id in priority_ids
            want_priority = is_priority and n_priority_collected < visual_limit

            # Collect data for visualization: the first visual_limit tiles, plus
            # up to visual_limit explicitly requested tiles.
            if len(visual_rgb) < visual_limit or want_priority:
                # RGB image (scale from [0,1] to [0,255] and cast to uint8)
                # Assumes RGB is always the first 3 channels if it's a multi-channel input
                rgb_tile = (batch_x[j][:, :, :3].numpy() * 255).astype(np.uint8)
                visual_rgb.append(rgb_tile)
                # Copies keep only this tile alive instead of the whole batch array.
                visual_true.append(batch_y_np[j].copy()) # Keep true mask as one-hot for plotting
                visual_pred.append(pred_mask_batch[j].copy()) # Predicted mask (class IDs)
                visual_tile_ids.append(current_tile_id)
                if is_priority:
                    n_priority_collected += 1

            # Collect valid (non-ignored) pixels for overall metric calculation.
            # Boolean indexing returns a flattened copy.
            valid = valid_batch[j]
            pred_chunks.append(pred_mask_batch[j][valid])
            true_chunks.append(true_mask_batch[j][valid])

            tile_index_in_df += 1 # Move to the next tile in the test_df

        # Explicitly delete batch variables and collect garbage to free memory
        del batch_x, batch_y, batch_y_np, valid_batch, pred_probs, pred_mask_batch, true_mask_batch
        gc.collect()

    # Convert collected chunks to flat NumPy arrays for metric calculations
    empty = np.empty(0, dtype=np.uint8)
    all_test_preds = np.concatenate(pred_chunks) if pred_chunks else empty
    all_test_trues = np.concatenate(true_chunks) if true_chunks else empty
    del pred_chunks, true_chunks

    if all_test_preds.size == 0:
        print("⚠️ No test predictions were collected. Evaluation skipped.")
        return

    # --- Visualise Grid ---
    if visual_rgb: # Only plot if there's data to visualize
        print(f"\nVisualizing {min(len(visual_rgb), visual_limit)} sample predictions...")
        visualise_prediction_grid(
            rgb_list=visual_rgb,
            true_mask_list=visual_true, # Pass one-hot true masks
            pred_mask_list=visual_pred,
            tile_id_list=specific_tile_ids, # Pass specific tile IDs to prioritize
            all_tile_ids=visual_tile_ids, # Pass the tile IDs collected during evaluation
            n_rows=n_rows,
            n_cols=n_cols
        )
    else:
        print("No samples collected for visualization grid.")

    # --- Overall Mean IoU ---
    # Using tf.keras.metrics.MeanIoU directly as it's robust and standard
    mean_iou_metric_tf = tf.keras.metrics.MeanIoU(num_classes=NUM_CLASSES)
    mean_iou_metric_tf.update_state(all_test_trues, all_test_preds)
    miou = mean_iou_metric_tf.result().numpy()
    print(f"\n📈 Overall Mean IoU (mIoU): {miou:.4f}")

    # --- Macro F1, Precision, Recall ---
    # Using sklearn for macro averages
    macro_f1 = f1_score(all_test_trues, all_test_preds, average='macro', zero_division=0)
    macro_precision = precision_score(all_test_trues, all_test_preds, average='macro', zero_division=0)
    macro_recall = recall_score(all_test_trues, all_test_preds, average='macro', zero_division=0)

    print(f"\n📊 Macro Metrics:")
    print(f"    F1 Score    : {macro_f1:.4f}")
    print(f"    Precision   : {macro_precision:.4f}")
    print(f"    Recall      : {macro_recall:.4f}")

    # --- Per-class IoU ---
    # Using sklearn's confusion matrix to calculate per-class IoU
    conf_matrix = confusion_matrix(all_test_trues, all_test_preds, labels=list(range(NUM_CLASSES)))
    print("\n📏 Per-class IoU Scores:")
    class_ious = []
    for i in range(NUM_CLASSES):
        intersection = conf_matrix[i, i]
        # |true_i ∪ pred_i| = row sum + column sum - diagonal (counted twice)
        union = np.sum(conf_matrix[i, :]) + np.sum(conf_matrix[:, i]) - intersection
        iou = intersection / union if union > 0 else float('nan') # Use NaN for undefined IoU
        class_ious.append(iou)
        print(f"  {CLASS_NAMES[i]:<12} IoU: {iou:.4f}")

    # --- Per-class F1, Precision, Recall ---
    # Using sklearn's precision_recall_fscore_support for per-class metrics
    precision_per_class, recall_per_class, f1_per_class, _ = precision_recall_fscore_support(
        all_test_trues, all_test_preds, labels=list(range(NUM_CLASSES)), zero_division=0
    )
    print("\n🔍 Per-class Classification Report (F1 | Precision | Recall):")
    for i, name in enumerate(CLASS_NAMES):
        print(f"{name:<12} | F1: {f1_per_class[i]:.4f} | Prec: {precision_per_class[i]:.4f} | Recall: {recall_per_class[i]:.4f}")

    # --- Confusion Matrix Plot (Row-Normalised) ---
    print("\n🌀 Generating Row-Normalised Confusion Matrix plot...")
    counts = conf_matrix.astype(np.float64)
    row_sums = counts.sum(axis=1, keepdims=True)
    # `where=` skips rows without any true pixels; `out=` supplies their value (0).
    # Without `out=`, those entries would be left uninitialised.
    norm_conf = np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums != 0)

    fig, ax = plt.subplots(figsize=(8, 6))
    disp = ConfusionMatrixDisplay(norm_conf, display_labels=CLASS_NAMES)
    disp.plot(cmap='viridis', ax=ax, values_format=".2f")
    ax.set_title("Row-Normalised Confusion Matrix (True Label %) - Test Set")
    fig.tight_layout()
    cm_save_path = os.path.join(out_dir, 'confusion_matrix_row_norm.png')
    fig.savefig(cm_save_path)
    plt.show()
    plt.close(fig)
    print(f"Saved confusion matrix to: {cm_save_path}")

    # NOTE: A per-chip mIoU histogram is not produced here because this function does
    # not compute per-chip IoUs; `evaluate_model_with_crf` collects them (`chip_ious`).
    # To add the histogram, compute a per-tile IoU inside the batch loop above and
    # plot it with plt.hist (e.g. 5% bins via np.linspace(0, 1, 21)).


# --- Reconstruction Functions ---

def reconstruct_canvas(
    model: tf.keras.Model,
    df: pd.DataFrame,
    source_file_prefix: str, # Changed from source_file to source_file_prefix for clarity
    generator_class: callable, # e.g., build_tf_dataset
    img_dir: str,
    elev_dir: str,
    slope_dir: str,
    label_dir: str,
    tile_size: int = 256
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Reconstructs the full RGB image, ground truth mask, and model prediction mask
    for a given source file by stitching together its individual chips.

    Chips are placed at the (x, y) offsets stored in the DataFrame, relative to the
    smallest x/y of the selected chips. Areas not covered by any chip, and pixels
    whose one-hot ground truth is all zeros (ignored pixels), are painted with
    IGNORE_COLOR - the latter in both the ground-truth and prediction canvases,
    matching `visualise_prediction_grid`.

    Args:
        model (tf.keras.Model): The trained Keras model.
        df (pd.DataFrame): The DataFrame containing metadata for all chips (e.g., test_df).
                           Must provide 'tile_id', 'x' and 'y' columns.
        source_file_prefix (str): The common prefix for all chips belonging to the same large image.
                                  (e.g., "25f1c24f30_EB81FE6E2BOPENPIPELINE")
        generator_class (callable): The function to build the TensorFlow dataset (e.g., `build_tf_dataset`).
        img_dir (str): Directory containing RGB images.
        elev_dir (str): Directory containing elevation .npy files.
        slope_dir (str): Directory containing slope .npy files.
        label_dir (str): Directory containing label images.
        tile_size (int): The size of individual tiles/chips (e.g., 256).

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing the
        reconstructed RGB canvas, Ground Truth canvas, and Prediction canvas (all as np.uint8).

    Raises:
        ValueError: If no chips are found for the specified source file prefix.
    """
    # Filter for chips belonging to the specific source file
    df_file = df[df['tile_id'].str.startswith(source_file_prefix)].copy()
    if df_file.empty:
        raise ValueError(f"No chips found for source file prefix: {source_file_prefix}")

    # Determine the overall canvas shape. Offsets are cast to int so they are valid
    # array shapes / slice bounds even if the columns are stored as floats.
    min_x = int(df_file['x'].min())
    min_y = int(df_file['y'].min())
    max_x = int(df_file['x'].max()) + tile_size
    max_y = int(df_file['y'].max()) + tile_size

    canvas_shape = (max_y - min_y, max_x - min_x, 3)  # GT/Pred canvases are coloured RGB too

    # Initialize canvases with IGNORE_COLOR (magenta) for padding/uncovered areas
    rgb_canvas = np.full(canvas_shape, IGNORE_COLOR, dtype=np.uint8)
    gt_canvas = np.full(canvas_shape, IGNORE_COLOR, dtype=np.uint8)
    pred_canvas = np.full(canvas_shape, IGNORE_COLOR, dtype=np.uint8)

    # Build dataset for just this file's chips. Stitching relies on the generator
    # yielding chips in the same order as the rows of df_file, hence shuffle=False;
    # augment=False / split="val" keep the chips unaugmented.
    gen = generator_class(
        df_file, img_dir, elev_dir, slope_dir, label_dir,
        batch_size=64, shuffle=False, augment=False, split="val"
    )

    n_chips = len(df_file)
    row_index_in_df = 0 # Tracks position in the filtered df_file
    for batch_x, batch_y_onehot in gen:
        # All chips placed: skip predicting on any further batches.
        if row_index_in_df >= n_chips:
            break

        preds_softmax = model.predict(batch_x, verbose=0)
        pred_mask_ids = tf.argmax(preds_softmax, axis=-1).numpy() # Predicted class IDs
        true_mask_ids = tf.argmax(batch_y_onehot, axis=-1).numpy() # Ground truth class IDs
        # All-zero one-hot vectors carry no class; argmax alone would show them as class 0.
        ignore_masks = np.all(np.asarray(batch_y_onehot) == 0, axis=-1)

        for i in range(batch_x.shape[0]):
            if row_index_in_df >= n_chips: # Safety break
                break

            # Get chip's original (x, y) coordinates from the DataFrame
            row_df_entry = df_file.iloc[row_index_in_df]
            rel_x = int(row_df_entry['x']) - min_x
            rel_y = int(row_df_entry['y']) - min_y
            row_index_in_df += 1

            # Canvas window covered by this chip
            window = (slice(rel_y, rel_y + tile_size), slice(rel_x, rel_x + tile_size))
            ignore_chip = ignore_masks[i]

            # Place RGB chip onto the canvas (inputs are scaled back to 0-255)
            rgb_canvas[window] = tf.cast(batch_x[i][..., :3] * 255.0, tf.uint8).numpy()

            # Color and place Ground Truth mask onto the canvas
            gt_rgb_chip = plot_colored_mask(true_mask_ids[i])
            gt_rgb_chip[ignore_chip] = IGNORE_COLOR
            gt_canvas[window] = gt_rgb_chip

            # Color and place Prediction mask onto the canvas
            pred_rgb_chip = plot_colored_mask(pred_mask_ids[i])
            pred_rgb_chip[ignore_chip] = IGNORE_COLOR
            pred_canvas[window] = pred_rgb_chip

    return rgb_canvas, gt_canvas, pred_canvas


def plot_reconstruction(img: np.ndarray, label: np.ndarray, pred: np.ndarray, source_file_prefix: str):
    """
    Shows the three reconstructed canvases of one source file side by side:
    the RGB image, the coloured ground truth, and the coloured model prediction.

    Args:
        img (np.ndarray): The reconstructed RGB image canvas.
        label (np.ndarray): The reconstructed ground truth mask canvas (colored).
        pred (np.ndarray): The reconstructed prediction mask canvas (colored).
        source_file_prefix (str): The common prefix of the source file, used in the figure title.
    """
    fig, axs = plt.subplots(1, 3, figsize=(26.5, 13))

    # (canvas, panel title) in left-to-right display order
    panels = (
        (img, "RGB Image"),
        (label, "Ground Truth"),
        (pred, "Model Prediction"),
    )
    for ax, (canvas, panel_title) in zip(axs, panels):
        ax.imshow(canvas)
        ax.set_title(panel_title)
        ax.axis('off')

    plt.suptitle(f"Reconstruction for: {source_file_prefix}", fontsize=24, y=0.95)
    # Leave the top 7% of the figure free so the suptitle does not overlap the panels.
    plt.tight_layout(rect=[0, 0, 1, 0.93])
    plt.show()
    plt.close(fig)


def normalize_confusion_matrix(cm: np.ndarray, norm: str = 'true') -> np.ndarray:
    """
    Normalizes a confusion matrix.

    The input is never modified; a new float32 array is returned. Rows/columns
    (or a whole matrix) that sum to zero are returned as zeros.

    Args:
        cm (np.ndarray): Confusion matrix to be normalized (any array-like of counts).
        norm (str): Type of normalization.
                    - 'true': Normalize by true labels (rows sum to 1).
                    - 'pred': Normalize by predicted labels (columns sum to 1).
                    - 'all': Normalize by total count (matrix sums to 1).

    Returns:
        np.ndarray: Normalized confusion matrix (float32, same shape as `cm`).

    Raises:
        ValueError: If an unknown normalization type is provided.
    """
    if norm not in ('true', 'pred', 'all'):
        raise ValueError("Unknown normalization type. Use 'true', 'pred', or 'all'.")

    # np.array always copies, so the caller's matrix is left untouched.
    cm_normalized = np.array(cm, dtype=np.float32)

    if norm == 'all':
        # Normalize by the grand total sum
        total_sum = cm_normalized.sum()
        if total_sum > 0:
            cm_normalized = cm_normalized / total_sum
    else:
        # 'true' divides each row by its sum (axis=1), 'pred' each column (axis=0).
        axis = 1 if norm == 'true' else 0
        sums = cm_normalized.sum(axis=axis, keepdims=True)
        # `where=` skips zero sums; `out=` defines the value of the skipped entries.
        # Without `out=`, NumPy leaves them uninitialised (arbitrary memory).
        cm_normalized = np.divide(
            cm_normalized, sums, out=np.zeros_like(cm_normalized), where=sums != 0
        )

    # Replace any NaN (e.g. propagated from NaN inputs) with 0
    return np.nan_to_num(cm_normalized, nan=0.0)


# --- Optional: CRF post-processing (requires pydensecrf) ---
# Install the dependency with `pip install pydensecrf`.
# `apply_crf` below is called by `evaluate_model_with_crf` to refine the model's softmax output.

try:
    import pydensecrf.densecrf as dcrf
    from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral

    def apply_crf(rgb: np.ndarray, probs: np.ndarray, t: int = 5, compat: int = 10, sxy: int = 3, srgb: int = 13) -> np.ndarray:
        """
        Refines a softmax prediction with a dense Conditional Random Field (CRF).

        The unary term is built from the class probabilities; a single bilateral
        pairwise term (spatial position + RGB colour) is added on top, and the
        per-pixel arg-max of the CRF inference result is returned.

        Args:
            rgb (np.ndarray): The input RGB image (H, W, 3) as np.uint8.
            probs (np.ndarray): Softmax output probabilities from the model (H, W, num_classes).
            t (int): Number of CRF inference iterations.
            compat (int): Compatibility constant for the pairwise term.
            sxy (int): Spatial standard deviation for the pairwise bilateral term.
            srgb (int): Color standard deviation for the pairwise bilateral term.

        Returns:
            np.ndarray: The refined segmentation mask (H, W) as integer class IDs.
        """
        height, width = probs.shape[:2]
        n_labels = probs.shape[-1]

        # DenseCRF2D takes (width, height), i.e. the reverse of NumPy's (H, W) order.
        crf = dcrf.DenseCRF2D(width, height, n_labels)

        # Unary term: channels-first probabilities, flattened to (n_labels, H*W)
        # in a C-contiguous buffer.
        class_first_probs = probs.transpose(2, 0, 1)
        unary_energy = unary_from_softmax(class_first_probs)
        unary_energy = unary_energy.reshape((n_labels, -1)).copy(order='C')
        crf.setUnaryEnergy(unary_energy)

        # Pairwise bilateral term: needs a C-contiguous float32 image with channels last.
        image_f32 = np.ascontiguousarray(rgb.astype(np.float32))
        bilateral_energy = create_pairwise_bilateral(
            sdims=(sxy, sxy),
            schan=(srgb, srgb, srgb),
            img=image_f32,
            chdim=2,
        )
        crf.addPairwiseEnergy(bilateral_energy, compat=compat)

        # Run `t` inference iterations and pick the most likely class per pixel.
        marginals = crf.inference(t)
        refined_labels = np.argmax(marginals, axis=0)
        return refined_labels.reshape((height, width))

    def evaluate_model_with_crf(model, test_gen, test_df, out_dir, image_dir, label_dir, tile_size=256, n_rows=4, n_cols=3):
        """
        Evaluates the trained model on the test set with CRF post-processing,
        calculates metrics, and visualizes samples.
        Note: This function is similar to `evaluate_on_test` but explicitly includes CRF.
        It collects per-chip IoUs which is useful for histogram plotting.
        """
        os.makedirs(out_dir, exist_ok=True)
        print("🧪 Running evaluation on test set with CRF post-processing...")
        all_preds = []
        all_trues = []

        rgb_list_viz = [] # RGB images for visualization
        true_mask_list_viz = [] # True masks for visualization (one-hot)
        pred_mask_list_viz = [] # Predicted masks for visualization (class IDs after CRF)
        present_classes_per_chip = [] # For analyse_prediction_grid_by_performance if re-enabled
        chip_ious = [] # Per-chip mIoUs

        visual_limit = n_rows * n_cols # Max number of samples to collect for visualization
        
        test_df_tile_ids = test_df['tile_id'].tolist()
        tile_index_in_df = 0 # To track which chip from test_df is being processed

        for x_batch, y_batch in tqdm(test_gen, desc="CRF Evaluating"):
            if tf.size(x_batch).numpy() == 0:
                continue

            # Get model's raw softmax predictions
            soft_preds = model.predict(x_batch, verbose=0)
            true_mask_batch_onehot = y_batch.numpy() # Keep as one-hot for direct comparison if needed

            for i in range(x_batch.shape[0]):
                if tile_index_in_df >= len(test_df_tile_ids): # Safety break
                    break

                # Extract RGB image for CRF (needs to be 0-255 uint8)
                rgb_img = (x_batch[i][..., :3].numpy() * 255).astype(np.uint8)
                softmax_pred_for_crf = soft_preds[i] # Softmax probabilities for current chip

                # Apply CRF post-processing
                crf_mask = apply_crf(rgb_img, softmax_pred_for_crf) # Returns class IDs

                # Get true mask as class IDs for metric calculation
                true_mask_ids = np.argmax(true_mask_batch_onehot[i], axis=-1).astype(np.uint8)

                # Collect flattened predictions and true labels for overall metrics
                all_preds.extend(crf_mask.reshape(-1))
                all_trues.extend(true_mask_ids.reshape(-1))

                # --- Per-chip IoU calculation (if needed for histogram/sorting) ---
                # This needs to be done here as it uses the CRF-processed mask
                cm_chip = confusion_matrix(true_mask_ids.flatten(), crf_mask.flatten(), labels=list(range(NUM_CLASSES)))
                ious_chip = []
                for j in range(NUM_CLASSES):
                    intersection_chip = cm_chip[j, j]
                    union_chip = np.sum(cm_chip[j, :]) + np.sum(cm_chip[:, j]) - intersection_chip
                    if union_chip > 0:
                        iou_chip = intersection_chip / union_chip
                        ious_chip.append(iou_chip)
                chip_ious.append(np.mean(ious_chip) if ious_chip else 0)


                # Collect visuals up to the limit
                if len(rgb_list_viz) < visual_limit:
                    rgb_list_viz.append(rgb_img)
                    true_mask_list_viz.append(true_mask_batch_onehot[i]) # Keep one-hot for plotting
                    pred_mask_list_viz.append(crf_mask) # CRF output (class IDs)
                    # You might also want to append the tile_id here if sorting by performance is used

                tile_index_in_df += 1 # Move to the next tile in the test_df

            gc.collect() # Garbage collection after each batch

        # Convert collected lists to NumPy arrays for final metric calculations
        all_preds = np.array(all_preds)
        all_trues = np.array(all_trues)

        # --- Visualise Grid (using CRF outputs) ---
        if rgb_list_viz:
            print(f"\n🖼️ Visualizing {len(rgb_list_viz)} sample predictions with CRF...")
            visualise_prediction_grid(
                rgb_list=rgb_list_viz,
                true_mask_list=true_mask_list_viz, # These are one-hot, visualise_prediction_grid will convert
                pred_mask_list=pred_mask_list_viz,
                n_rows=n_rows,
                n_cols=n_cols,
                # If you want to use the performance-based grid, uncomment the old function call
                # and pass the appropriate lists (miou_list, class_list) to it.
            )
        else:
            print("No samples collected for CRF visualization grid.")

        # --- Metrics Calculation (using CRF output) ---
        # Note: These are identical to evaluate_on_test's metric calculations,
        # but they operate on the CRF-processed predictions.
        
        miou = tf.keras.metrics.MeanIoU(num_classes=NUM_CLASSES)
        miou.update_state(all_trues, all_preds)
        overall_miou = miou.result().numpy()

        macro_f1 = f1_score(all_trues, all_preds, average='macro', zero_division=0)
        macro_precision = precision_score(all_trues, all_preds, average='macro', zero_division=0)
        macro_recall = recall_score(all_trues, all_preds, average='macro', zero_division=0)

        conf_matrix_crf = confusion_matrix(all_trues, all_preds, labels=list(range(NUM_CLASSES)))
        per_class_ious_crf = []
        for i in range(NUM_CLASSES):
            intersection_crf = conf_matrix_crf[i, i]
            union_crf = np.sum(conf_matrix_crf[i, :]) + np.sum(conf_matrix_crf[:, i]) - intersection_crf
            iou_crf = intersection_crf / union_crf if union_crf > 0 else float('nan')
            per_class_ious_crf.append(iou_crf)
        
        # Per-class F1, Precision, Recall
        precision_crf, recall_crf, f1_crf, _ = precision_recall_fscore_support(
            all_trues, all_preds, labels=list(range(NUM_CLASSES)), zero_division=0
        )

        print(f"\n📊 CRF Macro Metrics:")
        print(f"  F1 Score     : {macro_f1:.4f}")
        print(f"  Precision    : {macro_precision:.4f}")
        print(f"  Recall       : {macro_recall:.4f}")
        print(f"\n📈 CRF Mean IoU (mIoU): {overall_miou:.4f}")
        print("\n📏 CRF Per-class IoU Scores:")
        for i in range(NUM_CLASSES):
            print(f"  {CLASS_NAMES[i]:<12} IoU: {per_class_ious_crf[i]:.4f}")
        print("\n🔍 CRF Per-class Classification Report (F1 | Precision | Recall):")
        for i, name in enumerate(CLASS_NAMES):
            print(f"{name:<12} | F1: {f1_crf[i]:.4f} | Prec: {precision_crf[i]:.4f} | Recall: {recall_crf[i]:.4f}")

        # Confusion Matrix Plot (Row-Normalised) for CRF
        print("\n🌀 Generating CRF Row-Normalised Confusion Matrix plot...")
        with np.errstate(divide='ignore', invalid='ignore'):
            row_sums_crf = conf_matrix_crf.sum(axis=1, keepdims=True)
            norm_conf_crf = np.divide(conf_matrix_crf.astype(np.float32), row_sums_crf, where=row_sums_crf != 0)

        fig_cm, ax_cm = plt.subplots(figsize=(8, 6))
        disp_crf = ConfusionMatrixDisplay(norm_conf_crf, display_labels=CLASS_NAMES)
        disp_crf.plot(cmap='viridis', ax=ax_cm, values_format=".2f")
        ax_cm.set_title("CRF Row-Normalised Confusion Matrix (True Label %) - Test Set")
        fig_cm.tight_layout()
        cm_crf_save_path = os.path.join(out_dir, 'confusion_matrix_crf_row_norm.png')
        fig_cm.savefig(cm_crf_save_path)
        plt.show()
        plt.close(fig_cm)
        print(f"Saved CRF confusion matrix to: {cm_crf_save_path}")

        # Per-chip mIoU histogram for CRF results
        if chip_ious:
            print("\n📊 Generating per-chip mIoU histogram (CRF post-processed)...")
            bin_edges = np.linspace(0, 1, 21) # 5% bins
            fig_hist, ax_hist = plt.subplots(figsize=(10, 6))
            ax_hist.hist(chip_ious, bins=bin_edges, edgecolor='black')
            ax_hist.set_xlabel("Per-Chip mIoU")
            ax_hist.set_ylabel("Number of Chips")
            ax_hist.set_title("Distribution of Per-Chip mIoU (CRF Post-Processed)")
            ax_hist.grid(True)
            fig_hist.tight_layout()
            hist_crf_save_path = os.path.join(out_dir, "per_chip_miou_hist_crf.png")
            fig_hist.savefig(hist_crf_save_path)
            plt.show()
            plt.close(fig_hist)
            print(f"Saved CRF per-chip mIoU histogram to: {hist_crf_save_path}")


        # Return results if needed by a calling function
        return {
            "macro_f1": macro_f1,
            "precision": macro_precision,
            "recall": macro_recall,
            "miou": overall_miou,
            "per_class_ious": per_class_ious_crf,
            "classification_report_str": classification_report(
                all_trues, all_preds, target_names=CLASS_NAMES, digits=4, zero_division=0
            )
        }

except ImportError:
    print("Warning: pydensecrf not installed. Skipping CRF-related functions.")
    # Define a dummy function if CRF is not available, to avoid NameError
    def evaluate_model_with_crf(*args, **kwargs):
        """
        Fallback stand-in used when the CRF dependency could not be imported.

        Accepts and ignores any positional or keyword arguments, prints a notice
        that CRF evaluation was skipped, and returns an empty dict.
        """
        skip_notice = "CRF evaluation skipped because pydensecrf is not installed."
        print(skip_notice)
        return dict()


# --- Reconstruction Functions ---

def _colorize_mask_chip(mask_ids: np.ndarray, tile_size: int) -> np.ndarray:
    """
    Renders one chip of integer class IDs as an RGB image using CLASS_TO_COLOR.

    Pixels whose class ID has no entry in CLASS_TO_COLOR stay black (0, 0, 0).

    Args:
        mask_ids (np.ndarray): 2D array of class IDs for one chip.
        tile_size (int): Side length of the (square) chip.

    Returns:
        np.ndarray: A (tile_size, tile_size, 3) uint8 colored mask.
    """
    coloured_chip = np.zeros((tile_size, tile_size, 3), dtype=np.uint8)
    for cid, colour in CLASS_TO_COLOR.items():
        coloured_chip[mask_ids == cid] = colour
    return coloured_chip


def reconstruct_canvas(
    model: tf.keras.Model,
    df: pd.DataFrame,
    source_file_prefix: str,
    generator_class: callable,
    img_dir: str,
    elev_dir: str,
    slope_dir: str,
    label_dir: str,
    tile_size: int = 256,
    batch_size: int = 64
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Reconstructs the full RGB image, ground truth mask, and model prediction mask
    for a given source file by stitching together its individual chips.

    Chips are matched to their (x, y) position by order: the k-th chip yielded by
    the dataset is placed at the coordinates of the k-th row of the filtered
    DataFrame. This assumes the dataset builder preserves DataFrame row order when
    called with shuffle=False -- TODO confirm against the builder.

    Args:
        model (tf.keras.Model): The trained Keras model.
        df (pd.DataFrame): The DataFrame containing metadata for all chips (e.g., test_df).
                           Must provide 'tile_id', 'x' and 'y' columns.
        source_file_prefix (str): The common prefix for all chips belonging to the same large image.
                                  (e.g., "25f1c24f30_EB81FE6E2BOPENPIPELINE")
        generator_class (callable): The function to build the TensorFlow dataset (e.g., `build_tf_dataset`).
        img_dir (str): Directory containing RGB images.
        elev_dir (str): Directory containing elevation .npy files.
        slope_dir (str): Directory containing slope .npy files.
        label_dir (str): Directory containing label images.
        tile_size (int): The size of individual tiles/chips (e.g., 256).
        batch_size (int): Batch size passed to the dataset builder for inference.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing the
        reconstructed RGB canvas, Ground Truth canvas, and Prediction canvas (all as np.uint8).
        Areas not covered by any chip are filled with IGNORE_COLOR.

    Raises:
        ValueError: If no chips are found for the specified source file prefix.
    """
    # Filter for chips belonging to the specific source file prefix
    df_file = df[df['tile_id'].str.startswith(source_file_prefix)].copy()
    if df_file.empty:
        raise ValueError(f"No chips found for source file prefix: {source_file_prefix}")

    n_chips = len(df_file)

    # Pull coordinates out once; positional indexing on arrays is much cheaper than
    # a DataFrame .iloc lookup per chip inside the stitching loop.
    chip_xs = df_file['x'].to_numpy()
    chip_ys = df_file['y'].to_numpy()

    # Cast to plain ints so canvas shapes and slice bounds are valid even when the
    # coordinate columns were loaded with a float dtype.
    min_x = int(chip_xs.min())
    min_y = int(chip_ys.min())
    canvas_w = int(chip_xs.max()) + tile_size - min_x
    canvas_h = int(chip_ys.max()) + tile_size - min_y

    canvas_shape = (canvas_h, canvas_w, 3)  # GT/Pred canvases are colored RGB too

    # Initialize canvases with IGNORE_COLOR for padding/uncovered areas
    rgb_canvas = np.full(canvas_shape, IGNORE_COLOR, dtype=np.uint8)
    gt_canvas = np.full(canvas_shape, IGNORE_COLOR, dtype=np.uint8)
    pred_canvas = np.full(canvas_shape, IGNORE_COLOR, dtype=np.uint8)

    # Build dataset for just this file's chips. shuffle=False / augment=False and the
    # 'val' split are requested so chips arrive unaugmented and in DataFrame order.
    gen = generator_class(
        df_file, img_dir, elev_dir, slope_dir, label_dir,
        batch_size=batch_size, shuffle=False, augment=False, split="val"
    )

    chips_placed = 0  # Position in df_file of the next chip to place
    for batch_x, batch_y_onehot in tqdm(gen, desc=f"Reconstructing {source_file_prefix}"):
        # Get model predictions (softmax probabilities) and reduce to class IDs
        preds_softmax = model.predict(batch_x, verbose=0)
        pred_mask_ids = tf.argmax(preds_softmax, axis=-1).numpy()
        # Convert one-hot true labels to class IDs (ground truth masks)
        true_mask_ids = tf.argmax(batch_y_onehot, axis=-1).numpy()

        # RGB is the first 3 channels; scaling by 255 presumes inputs normalised
        # to [0, 1] -- TODO confirm against the dataset builder.
        rgb_chips = tf.cast(batch_x[..., :3] * 255.0, tf.uint8).numpy()

        # Never place more chips than the DataFrame has coordinates for
        n_in_batch = min(int(batch_x.shape[0]), n_chips - chips_placed)
        for i in range(n_in_batch):
            rel_x = int(chip_xs[chips_placed]) - min_x  # Relative X on the canvas
            rel_y = int(chip_ys[chips_placed]) - min_y  # Relative Y on the canvas
            chips_placed += 1

            rows = slice(rel_y, rel_y + tile_size)
            cols = slice(rel_x, rel_x + tile_size)

            rgb_canvas[rows, cols] = rgb_chips[i]
            gt_canvas[rows, cols] = _colorize_mask_chip(true_mask_ids[i], tile_size)
            pred_canvas[rows, cols] = _colorize_mask_chip(pred_mask_ids[i], tile_size)

        # Stop pulling batches once every chip is placed. Without this, a dataset
        # that repeats would keep the loop (and model.predict) running forever.
        if chips_placed >= n_chips:
            break

    if chips_placed < n_chips:
        print(
            f"Warning: dataset yielded {chips_placed} chips but {n_chips} were expected "
            f"for {source_file_prefix}; uncovered areas keep the ignore color."
        )

    return rgb_canvas, gt_canvas, pred_canvas


def plot_reconstruction(img: np.ndarray, label: np.ndarray, pred: np.ndarray, source_file_prefix: str):
    """
    Displays three reconstructed canvases side by side in a single figure:
    the RGB image, the ground truth mask, and the model prediction mask.

    The figure is shown with `plt.show()` and then closed; nothing is returned.

    Args:
        img (np.ndarray): The reconstructed RGB image canvas.
        label (np.ndarray): The reconstructed ground truth mask canvas (colored).
        pred (np.ndarray): The reconstructed prediction mask canvas (colored).
        source_file_prefix (str): The common prefix of the source file, used in the figure title.
    """
    # One (canvas, title) pair per subplot, left to right
    panels = (
        (img, "RGB Image"),
        (label, "Ground Truth"),
        (pred, "Model Prediction"),
    )

    fig, axs = plt.subplots(1, len(panels), figsize=(26.5, 13))

    for ax, (canvas, panel_title) in zip(axs, panels):
        ax.imshow(canvas)
        ax.set_title(panel_title, fontsize=18)
        ax.axis('off')

    plt.suptitle(f"Reconstruction for: {source_file_prefix}", fontsize=24, y=0.95)
    # Reserve the top strip of the figure so the suptitle is not overlapped
    plt.tight_layout(rect=[0, 0, 1, 0.93])
    plt.show()
    plt.close(fig)

'''