# In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns # Optional plotting styles (note: ConfusionMatrixDisplay itself draws with plain matplotlib)
import random
import gc # For garbage collection
from PIL import Image # For image manipulation, potentially for reconstruction or debugging
from datetime import datetime # For timestamping or time limits

# --- TensorFlow and Keras specific imports ---
import tensorflow as tf
import tensorflow.keras.backend as K
import segmentation_models as sm
from tensorflow.keras.metrics import MeanIoU # Explicit import for clarity

# --- Scikit-learn metrics ---
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    ConfusionMatrixDisplay, # For plotting confusion matrices
    precision_recall_fscore_support # For per-class metrics
)

# --- Progress bar for loops ---
from tqdm import tqdm

# --- Additional Imports ---
import pandas as pd
from PIL import Image
import cv2
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral

# Clear the Keras session (global state Keras keeps between model definitions)
# so models defined when this cell is re-run do not conflict with earlier ones.
K.clear_session()



# --- Measurement Functions ---

def measure_inference_time(
    model: tf.keras.Model,
    generator: tf.data.Dataset,
    num_batches: int = 5
) -> None:
    """Measures and prints the inference time of a Keras model on a dataset.

    Only the ``model.predict`` call is timed; iterating the dataset is not.

    Args:
        model (tf.keras.Model): The trained Keras model to evaluate.
        generator (tf.data.Dataset): Dataset yielding ``(inputs, targets)`` pairs;
            the targets are ignored.
        num_batches (int, optional): Number of batches to time. Defaults to 5.
            ``None`` times every batch in the dataset.

    Returns:
        None

    Raises:
        ValueError: If ``num_batches`` is None and the dataset is infinite
            (timing it would never finish).
    """
    import time

    total_time = 0.0
    total_images = 0

    if num_batches is None:
        num_batches = tf.data.experimental.cardinality(generator).numpy()
        if num_batches == tf.data.experimental.INFINITE_CARDINALITY:
            raise ValueError("Cannot time an infinite dataset; pass an explicit num_batches.")

    for x_batch, _ in generator.take(num_batches):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        model.predict(x_batch, verbose=0)
        total_time += time.perf_counter() - start
        total_images += int(x_batch.shape[0])

    # Guard against an empty dataset / num_batches=0 (previously ZeroDivisionError).
    if total_images == 0:
        print("No images were processed; cannot measure inference time.")
        return

    print(f"Inference time: {total_time:.2f} sec for {total_images} images")
    print(f"Avg inference time per image: {total_time / total_images:.4f} sec")


# --- Plotting Functions ---

def plot_training_curves(
    history: tf.keras.callbacks.History,
    out_dir: str
) -> None:
    """Draws loss and IoU learning curves side by side and saves them as a PNG.

    The figure is written to ``<out_dir>/training_curves.png`` (the directory
    is created first if it does not exist) and then displayed.

    Args:
        history (tf.keras.callbacks.History): Object returned by ``model.fit()``.
            Its ``.history`` dict must contain ``loss``, ``val_loss``,
            ``iou_score`` and ``val_iou_score``; if any is missing, the missing
            keys are printed and nothing is plotted.
        out_dir (str): Directory where the plot is saved.

    Returns:
        None
    """
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    curves = history.history
    absent = [key for key in ("loss", "val_loss", "iou_score", "val_iou_score") if key not in curves]
    if absent:
        print(f"Missing keys in history: {absent}")
        return

    # One entry per panel: (train key, val key, train label, val label, title, y-axis label).
    panels = [
        ("loss", "val_loss", "Train Loss", "Val Loss", "Loss over Epochs", "Loss"),
        ("iou_score", "val_iou_score", "Train IoU", "Val IoU", "Mean IoU over Epochs", "Mean IoU"),
    ]

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, (train_key, val_key, train_label, val_label, title, y_label) in zip(axes, panels):
        ax.plot(curves[train_key], label=train_label)
        ax.plot(curves[val_key], label=val_label)
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(y_label)
        ax.legend()

    plt.tight_layout()
    png_path = os.path.join(out_dir, "training_curves.png")
    plt.savefig(png_path)
    plt.show()
    plt.close()

    print(f"Saved training curves to: {png_path}")


from typing import List, Optional
def visualise_prediction_grid(
    rgb_list: List[np.ndarray],
    true_mask_list: List[np.ndarray],
    pred_mask_list: List[np.ndarray],
    tile_id_list: Optional[List[str]] = None,
    all_tile_ids: Optional[List[str]] = None,
    n_rows: int = 4,
    n_cols: int = 3
) -> None:
    """Displays a grid of RGB images, ground truth masks, and predicted masks.

    Pixels whose one-hot ground truth is all zeros are treated as "ignore"
    pixels and painted magenta (255, 0, 255) in both mask panels.

    Args:
        rgb_list (List[np.ndarray]): List of input RGB images (uint8, 0-255).
        true_mask_list (List[np.ndarray]): List of one-hot encoded ground truth masks.
        pred_mask_list (List[np.ndarray]): List of predicted masks (class ID format).
        tile_id_list (Optional[List[str]]): Tile IDs to prioritise for visualisation.
            Matching tiles are shown (in random order) first; remaining grid
            slots are filled with randomly chosen other tiles.
        all_tile_ids (Optional[List[str]]): Full list of tile IDs aligned with inputs.
        n_rows (int): Number of rows in the grid.
        n_cols (int): Number of triplet columns (each triplet is Input, Ground Truth, Prediction).

    Returns:
        None
    """
    total = n_rows * n_cols
    n_items = len(rgb_list)

    # squeeze=False keeps `axs` 2-D even when n_rows == 1, so axs[row, col] always works.
    fig, axs = plt.subplots(
        n_rows, n_cols * 3, figsize=(n_cols * 6.6, n_rows * 2.6), squeeze=False
    )
    # Hide every axis up front so grid cells without data render blank, not as empty frames.
    for ax in axs.flat:
        ax.axis("off")

    if tile_id_list and all_tile_ids:
        wanted = set(tile_id_list)
        # Bound by n_items in case all_tile_ids is longer than the data lists.
        matched = [i for i, tid in enumerate(all_tile_ids) if tid in wanted and i < n_items]
        random.shuffle(matched)
        indices_to_plot = matched[:total]

        if len(indices_to_plot) < total:
            remaining = list(set(range(n_items)) - set(indices_to_plot))
            random.shuffle(remaining)
            indices_to_plot += remaining[:total - len(indices_to_plot)]
    else:
        indices_to_plot = list(range(min(total, n_items)))

    for idx_plot, data_idx in enumerate(indices_to_plot):
        rgb = rgb_list[data_idx]
        true_onehot = true_mask_list[data_idx]
        true_mask = np.argmax(true_onehot, axis=-1)
        pred_mask = pred_mask_list[data_idx]

        h, w = true_mask.shape
        true_rgb = np.zeros((h, w, 3), dtype=np.uint8)
        pred_rgb = np.zeros((h, w, 3), dtype=np.uint8)

        for class_id, color in CLASS_TO_COLOR.items():
            true_rgb[true_mask == class_id] = color
            pred_rgb[pred_mask == class_id] = color

        # All-zero one-hot label => ignore pixel.
        ignore_mask = np.all(true_onehot == 0, axis=-1)
        true_rgb[ignore_mask] = (255, 0, 255)
        pred_rgb[ignore_mask] = (255, 0, 255)

        row = idx_plot // n_cols
        col = (idx_plot % n_cols) * 3

        for offset, (image, title) in enumerate(
            ((rgb, "Input"), (true_rgb, "Ground Truth"), (pred_rgb, "Prediction"))
        ):
            axs[row, col + offset].imshow(image)
            axs[row, col + offset].set_title(title)

    plt.tight_layout()
    plt.show()
    plt.close(fig)


from typing import Union
def normalize_confusion_matrix(
    cm: Union[np.ndarray, list],
    norm: str = 'true',
    zero_division: float = np.nan
) -> np.ndarray:
    """Normalises a confusion matrix by rows, columns, or entire matrix.

    Args:
        cm (Union[np.ndarray, list]): Raw confusion matrix.
        norm (str): Normalisation method. Options:
            - 'true': Normalise by rows (ground truth labels).
            - 'pred': Normalise by columns (predicted labels).
            - 'all': Normalise entire matrix to sum to 1.
        zero_division (float): Value written wherever the normaliser is zero
            (an empty row/column, or an all-zero matrix for 'all'). Defaults to
            NaN, the value plain division produced; pass 0.0 to match
            sklearn's ``confusion_matrix(normalize=...)`` convention.

    Returns:
        np.ndarray: Normalised confusion matrix (float32).

    Raises:
        ValueError: If `norm` is not one of 'true', 'pred', or 'all'.
    """
    cm = np.array(cm, dtype=np.float32)

    if norm == 'true':
        denom = cm.sum(axis=1, keepdims=True)
    elif norm == 'pred':
        denom = cm.sum(axis=0, keepdims=True)
    elif norm == 'all':
        denom = cm.sum()
    else:
        raise ValueError("Unknown normalization type. Use 'true', 'pred', or 'all'.")

    # Masked divide: zero normalisers get `zero_division` instead of emitting
    # "invalid value" RuntimeWarnings from 0/0.
    denom = np.broadcast_to(denom, cm.shape)
    cm_normalized = np.full(cm.shape, zero_division, dtype=np.float32)
    np.divide(cm, denom, out=cm_normalized, where=denom != 0)

    return cm_normalized


# --- Main Evaluation Function ---
from typing import Optional, List

def evaluate_on_test(
    model: tf.keras.Model,
    test_gen: tf.data.Dataset,
    test_df: pd.DataFrame,
    out_dir: str,
    image_dir: str,
    label_dir: str,
    tile_size: int = 256,
    n_rows: int = 4,
    n_cols: int = 3,
    specific_tile_ids: Optional[List[str]] = None
) -> None:
    """Evaluates the model on the test set and generates metrics and visualisations.

    This includes mean IoU, macro F1, Precision, Recall, per-class IoU, a
    row-normalised confusion matrix (saved to ``out_dir``), and a prediction
    visualisation grid.

    Pixels whose one-hot label is all zeros are treated as "ignore" pixels
    (the same convention `visualise_prediction_grid` uses to paint them
    magenta) and are excluded from every metric instead of being counted as
    class 0.

    Args:
        model (tf.keras.Model): Trained segmentation model.
        test_gen (tf.data.Dataset): Test dataset generator yielding (images, one-hot masks).
        test_df (pd.DataFrame): DataFrame with test metadata including tile IDs,
            assumed to be in the same order as `test_gen`.
        out_dir (str): Directory to save output plots.
        image_dir (str): Directory containing RGB images (not used in this function).
        label_dir (str): Directory containing label images (not used in this function).
        tile_size (int): Size of each tile in pixels (not used in this function).
        n_rows (int): Number of rows in the prediction grid.
        n_cols (int): Number of columns in the prediction grid.
        specific_tile_ids (Optional[List[str]]): Tile IDs to prioritise for visualisation.
            Matching tiles are collected from anywhere in the test set.

    Returns:
        None
    """
    print("Running evaluation on test set...")

    # Per-batch flat label arrays, concatenated once after the loop. Extending a
    # Python list pixel by pixel stores one boxed numpy scalar per pixel.
    pred_chunks = []
    true_chunks = []

    visual_rgb = []
    visual_true = []
    visual_pred = []
    visual_tile_ids = []
    visual_limit = n_rows * n_cols if n_rows and n_cols else 5

    # Requested tiles are collected even after the first `visual_limit` tiles;
    # otherwise the grid could only prioritise among the first few tiles.
    wanted_ids = set(specific_tile_ids) if specific_tile_ids else set()
    n_wanted_collected = 0

    os.makedirs(out_dir, exist_ok=True)
    test_tile_ids = test_df['tile_id'].tolist()
    tile_index = 0

    for batch_x, batch_y in test_gen.as_numpy_iterator():
        if batch_x.size == 0:
            continue

        pred = model.predict(batch_x, verbose=0)
        pred_mask = np.argmax(pred, axis=-1).astype(np.uint8)
        true_mask = np.argmax(batch_y, axis=-1).astype(np.uint8)
        # argmax maps an all-zero one-hot vector to class 0; mark those pixels invalid.
        valid = np.any(batch_y != 0, axis=-1)

        for j in range(batch_x.shape[0]):
            if tile_index >= len(test_tile_ids):
                break

            tile_id = test_tile_ids[tile_index]
            tile_index += 1

            is_wanted = tile_id in wanted_ids and n_wanted_collected < visual_limit
            if len(visual_rgb) < visual_limit or is_wanted:
                rgb_tile = (batch_x[j][:, :, :3] * 255).astype(np.uint8)
                visual_rgb.append(rgb_tile)
                # Copy so the stored tile does not keep the whole batch array alive.
                visual_true.append(batch_y[j].copy())
                visual_pred.append(pred_mask[j].copy())
                visual_tile_ids.append(tile_id)
                if is_wanted:
                    n_wanted_collected += 1

        pred_chunks.append(pred_mask[valid])
        true_chunks.append(true_mask[valid])

        del batch_x, batch_y, pred, pred_mask, true_mask, valid
        gc.collect()

    all_test_preds = np.concatenate(pred_chunks) if pred_chunks else np.empty(0, dtype=np.uint8)
    all_test_trues = np.concatenate(true_chunks) if true_chunks else np.empty(0, dtype=np.uint8)
    del pred_chunks, true_chunks

    if all_test_preds.size == 0:
        print("No test predictions were collected.")
        return

    if visual_rgb:
        visualise_prediction_grid(
            visual_rgb,
            visual_true,
            visual_pred,
            tile_id_list=specific_tile_ids,  # If provided, prioritise these tiles
            all_tile_ids=visual_tile_ids,
            n_rows=n_rows,
            n_cols=n_cols
        )

    mean_iou_metric = tf.keras.metrics.MeanIoU(num_classes=NUM_CLASSES)
    mean_iou_metric.update_state(all_test_trues, all_test_preds)
    miou = mean_iou_metric.result().numpy()

    macro_f1 = f1_score(all_test_trues, all_test_preds, average='macro')
    macro_precision = precision_score(all_test_trues, all_test_preds, average='macro')
    macro_recall = recall_score(all_test_trues, all_test_preds, average='macro')

    print(f"\n📊 Macro Metrics:")
    print(f"  F1 Score     : {macro_f1:.4f}")
    print(f"  Precision    : {macro_precision:.4f}")
    print(f"  Recall       : {macro_recall:.4f}")

    conf_matrix = confusion_matrix(all_test_trues, all_test_preds, labels=list(range(NUM_CLASSES)))
    print("\n📏 Per-class IoU Scores:")
    class_ious = []
    for i in range(NUM_CLASSES):
        intersection = conf_matrix[i, i]
        union = np.sum(conf_matrix[i, :]) + np.sum(conf_matrix[:, i]) - intersection
        iou = intersection / union if union > 0 else float('nan')
        class_ious.append(iou)
        print(f"  {CLASS_NAMES[i]:<12} IoU: {iou:.4f}")

    print(f"\n📈 Mean IoU (mIoU): {miou:.4f}")

    print("\n🔍 Classification Report:")
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_test_trues,
        all_test_preds,
        labels=list(range(NUM_CLASSES)),
        zero_division=0
    )
    for i, name in enumerate(CLASS_NAMES):
        print(f"{name:<12} | F1: {f1[i]:.4f} | Prec: {precision[i]:.4f} | Recall: {recall[i]:.4f}")

    print("\n🌀 Row-Normalised Confusion Matrix:")
    row_sums = conf_matrix.sum(axis=1, keepdims=True)
    # `out=` is required: with `where=` alone, masked entries (classes with no
    # ground-truth pixels) would be left as uninitialised memory.
    norm_conf = np.divide(
        conf_matrix.astype(np.float32),
        row_sums,
        out=np.zeros(conf_matrix.shape, dtype=np.float32),
        where=row_sums != 0,
    )

    fig, ax = plt.subplots(figsize=(8, 6))
    disp = ConfusionMatrixDisplay(norm_conf, display_labels=CLASS_NAMES)
    disp.plot(cmap='viridis', ax=ax, values_format=".2f")
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, 'confusion_matrix_row_norm.png'))
    plt.show()
    plt.close(fig)



# --- Reconstruction Utility ---

def reconstruct_canvas(
    model: tf.keras.Model,
    df: pd.DataFrame,
    source_file_prefix: str, # The common prefix for all chips of one large image
    generator_class: callable, # The function to build the TensorFlow dataset (e.g., `build_tf_dataset`)
    img_dir: str,
    elev_dir: str,
    slope_dir: str,
    label_dir: str,
    tile_size: int = 256
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Reconstructs the full RGB image, ground truth mask, and model prediction mask
    for a given source file by stitching together its individual chips.

    Areas not covered by any chip, and pixels whose one-hot label is all zeros
    ("ignore" pixels, as in `visualise_prediction_grid`), are painted with
    IGNORE_COLOR on the mask canvases; uncovered areas are IGNORE_COLOR on the
    RGB canvas too.

    Args:
        model (tf.keras.Model): The trained Keras model.
        df (pd.DataFrame): The DataFrame containing metadata for all chips (e.g., test_df),
            with 'tile_id', 'x' and 'y' columns.
        source_file_prefix (str): The common prefix for all chips belonging to the same large image.
                                  (e.g., "25f1c24f30_EB81FE6E2BOPENPIPELINE")
        generator_class (callable): The function to build the TensorFlow dataset (e.g., `build_tf_dataset`).
            Assumed to yield chips in the row order of the DataFrame it is given
            when called with shuffle=False — TODO confirm.
        img_dir (str): Directory containing RGB images.
        elev_dir (str): Directory containing elevation .npy files.
        slope_dir (str): Directory containing slope .npy files.
        label_dir (str): Directory containing label images.
        tile_size (int): The size of individual tiles/chips (e.g., 256).

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing the
        reconstructed RGB canvas, Ground Truth canvas, and Prediction canvas (all as np.uint8).

    Raises:
        ValueError: If no chips are found for the specified source file prefix.
    """
    df_file = df[df['tile_id'].str.startswith(source_file_prefix)].copy()
    if df_file.empty:
        raise ValueError(f"No chips found for source file prefix: {source_file_prefix}")

    # Integer pixel offsets: float coordinate columns cannot be used as slice bounds.
    xs = df_file['x'].to_numpy().astype(np.int64)
    ys = df_file['y'].to_numpy().astype(np.int64)
    n_chips = len(df_file)

    min_x, min_y = int(xs.min()), int(ys.min())
    canvas_w = int(xs.max()) + tile_size - min_x
    canvas_h = int(ys.max()) + tile_size - min_y

    # Start every canvas as IGNORE_COLOR so uncovered areas stay visibly marked.
    rgb_canvas = np.full((canvas_h, canvas_w, 3), IGNORE_COLOR, dtype=np.uint8)
    gt_canvas = np.full((canvas_h, canvas_w, 3), IGNORE_COLOR, dtype=np.uint8)
    pred_canvas = np.full((canvas_h, canvas_w, 3), IGNORE_COLOR, dtype=np.uint8)

    ignore_rgb = np.asarray(IGNORE_COLOR, dtype=np.uint8)

    def _colourise(class_ids: np.ndarray, ignore: np.ndarray) -> np.ndarray:
        """Map a 2-D class-ID mask to RGB via CLASS_TO_COLOR, painting `ignore` pixels IGNORE_COLOR."""
        chip = np.zeros(class_ids.shape + (3,), dtype=np.uint8)
        for cid, colour in CLASS_TO_COLOR.items():
            chip[class_ids == cid] = np.asarray(colour, dtype=np.uint8)
        chip[ignore] = ignore_rgb
        return chip

    # shuffle/augment disabled so chips arrive in df_file order with no augmentation.
    gen = generator_class(
        df_file, img_dir, elev_dir, slope_dir, label_dir,
        batch_size=64, shuffle=False, augment=False, split="val"
    )

    chip_index = 0  # position in df_file, used to look up each chip's (x, y)
    for batch_x, batch_y_onehot in tqdm(gen, desc=f"Reconstructing {source_file_prefix}"):
        if chip_index >= n_chips:
            # Every chip is placed; don't keep predicting (also stops a repeating dataset).
            break

        preds_softmax = model.predict(batch_x, verbose=0)
        pred_ids = np.argmax(preds_softmax, axis=-1)

        x_np = np.asarray(batch_x)
        y_np = np.asarray(batch_y_onehot)
        true_ids = np.argmax(y_np, axis=-1)
        ignore = ~np.any(y_np != 0, axis=-1)  # all-zero one-hot => ignore pixel

        for i in range(x_np.shape[0]):
            if chip_index >= n_chips:
                break

            rel_x = xs[chip_index] - min_x
            rel_y = ys[chip_index] - min_y
            chip_index += 1
            window = (slice(rel_y, rel_y + tile_size), slice(rel_x, rel_x + tile_size))

            # RGB is the first 3 channels in [0, 1]; clip before the uint8 cast to avoid wrap-around.
            rgb_canvas[window] = np.clip(x_np[i, ..., :3] * 255.0, 0, 255).astype(np.uint8)
            gt_canvas[window] = _colourise(true_ids[i], ignore[i])
            pred_canvas[window] = _colourise(pred_ids[i], ignore[i])

    return rgb_canvas, gt_canvas, pred_canvas


def plot_reconstruction(img: np.ndarray, label: np.ndarray, pred: np.ndarray, source_file_prefix: str):
    """
    Shows the stitched RGB image, ground truth and prediction side by side.

    Args:
        img (np.ndarray): Reconstructed RGB canvas (uint8).
        label (np.ndarray): Reconstructed, colour-coded ground truth canvas (uint8).
        pred (np.ndarray): Reconstructed, colour-coded prediction canvas (uint8).
        source_file_prefix (str): Source-file prefix shown in the figure title.
    """
    panels = (
        (img, "RGB Image"),
        (label, "Ground Truth"),
        (pred, "Model Prediction"),
    )

    fig, axs = plt.subplots(1, 3, figsize=(26.5, 13))
    for ax, (canvas, heading) in zip(axs, panels):
        ax.imshow(canvas)
        ax.set_title(heading, fontsize=18)
        ax.axis('off')

    plt.suptitle(f"Reconstruction for: {source_file_prefix}", fontsize=24, y=0.95)
    # Leave the top 7% of the figure free for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.93])
    plt.show()
    plt.close(fig)




# --- Unused Utility Functions ---

def visualise_prediction_grid_by_performance(
    rgb_list,
    true_mask_list,
    pred_mask_list,
    miou_list,
    class_list,
    n_rows=4,
    n_cols=3,
    class_names=CLASS_NAMES,
    class_to_color=CLASS_TO_COLOR
):
    """
    Visualises a grid of predictions grouped by per-chip performance (mIoU).

    Chips are sorted by mIoU (ascending, ties keep input order) and split into
    ``n_cols`` equally sized performance tiers over the WHOLE dataset (with
    the default ``n_cols=3``: bottom, middle and top third). Each tier fills
    one column of Input / Ground Truth / Prediction triplets, lowest tier on
    the left, showing the first ``n_rows`` chips of that tier. In the lowest
    tier, one chip containing "Water" and one containing "Building" are picked
    first when such chips exist and those names appear in ``class_names``.
    Pixels whose one-hot ground truth is all zeros are painted magenta.

    Args:
        rgb_list (list): List of input RGB images (NumPy arrays).
        true_mask_list (list): List of ground truth masks (one-hot encoded NumPy arrays).
        pred_mask_list (list): List of predicted masks (integer class ID NumPy arrays).
        miou_list (list): List of per-chip mIoU scores.
        class_list (list): List where each element is a list of integer class IDs present in that chip.
        n_rows (int): Number of chips shown per tier (grid rows).
        n_cols (int): Number of performance tiers (grid triplet columns).
        class_names (list): List of class names.
        class_to_color (dict): Mapping from class ID to RGB color tuple.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    # Step 1: sort chip indices by mIoU. Working with indices (rather than
    # tuples of arrays) avoids ambiguous numpy-array equality in membership tests.
    n_chips = min(len(rgb_list), len(true_mask_list), len(pred_mask_list), len(miou_list), len(class_list))
    order = sorted(range(n_chips), key=lambda k: miou_list[k])

    # Step 2: split the sorted chips into n_cols tiers covering the whole dataset.
    n_tiers = n_cols
    tiers = [order[t * n_chips // n_tiers:(t + 1) * n_chips // n_tiers] for t in range(n_tiers)]

    # Classes to make sure appear in the lowest tier, if available.
    priority_ids = [class_names.index(name) for name in ("Water", "Building") if name in class_names]

    def select_chips(tier, priority):
        """Pick up to n_rows chip indices: one per priority class first, then in mIoU order."""
        picked = []
        for class_id in priority:
            for k in tier:
                if k not in picked and class_id in class_list[k]:
                    picked.append(k)
                    break
        for k in tier:
            if len(picked) >= n_rows:
                break
            if k not in picked:
                picked.append(k)
        return picked[:n_rows]

    # Step 3: plotting. squeeze=False keeps axs 2-D even when n_rows == 1.
    fig, axs = plt.subplots(n_rows, n_cols * 3, figsize=(n_cols * 6.6, n_rows * 2.6), squeeze=False)
    for ax in axs.flat:
        ax.axis("off")  # cells without a chip stay blank

    for tier_idx, tier in enumerate(tiers):
        chosen = select_chips(tier, priority_ids if tier_idx == 0 else [])
        col = tier_idx * 3

        for row, k in enumerate(chosen):
            true_mask_oh = true_mask_list[k]
            true_mask = np.argmax(true_mask_oh, axis=-1)
            pred_mask = pred_mask_list[k]

            h, w = true_mask.shape
            true_rgb = np.zeros((h, w, 3), dtype=np.uint8)
            pred_rgb = np.zeros((h, w, 3), dtype=np.uint8)

            for class_id, color in class_to_color.items():
                true_rgb[true_mask == class_id] = color
                pred_rgb[pred_mask == class_id] = color

            ignore_mask = np.all(true_mask_oh == 0, axis=-1)
            true_rgb[ignore_mask] = (255, 0, 255)
            pred_rgb[ignore_mask] = (255, 0, 255)

            axs[row, col + 0].imshow(rgb_list[k])
            axs[row, col + 0].set_title("Input")
            axs[row, col + 1].imshow(true_rgb)
            axs[row, col + 1].set_title("Ground Truth")
            axs[row, col + 2].imshow(pred_rgb)
            axs[row, col + 2].set_title("Prediction")

    plt.tight_layout()
    plt.show()
    plt.close(fig)


def apply_crf(rgb, probs, t=5, compat=10, sxy=3, srgb=13):
    """Refines a softmax segmentation with a dense CRF (pydensecrf).

    Args:
        rgb: Image used for the bilateral pairwise term; assumed (H, W, 3) — TODO confirm.
        probs: Per-pixel class probabilities, channels last: (H, W, num_classes).
        t: Number of inference iterations passed to ``DenseCRF.inference``.
        compat: Compatibility weight of the bilateral pairwise term.
        sxy: Spatial standard deviation of the bilateral kernel (same on both axes).
        srgb: Colour standard deviation of the bilateral kernel (same for each channel).

    Returns:
        np.ndarray: (H, W) array of refined class IDs.
    """
    import pydensecrf.densecrf as dcrf
    from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral
    import numpy as np

    height, width = probs.shape[:2]
    n_labels = probs.shape[-1]

    crf = dcrf.DenseCRF2D(width, height, n_labels)

    # pydensecrf expects class-first probabilities and a C-ordered (n_labels, H*W) unary.
    class_first = probs.transpose(2, 0, 1)
    unary_energy = unary_from_softmax(class_first).reshape((n_labels, -1)).copy(order='C')
    crf.setUnaryEnergy(unary_energy)

    # Bilateral kernel needs a C-contiguous float image.
    image = np.ascontiguousarray(rgb.astype(np.float32))
    bilateral = create_pairwise_bilateral(
        sdims=(sxy, sxy), schan=(srgb,) * 3, img=image, chdim=2
    )
    crf.addPairwiseEnergy(bilateral, compat=compat)

    marginals = crf.inference(t)
    return np.argmax(marginals, axis=0).reshape((height, width))




def _confusion_counts(y_true, y_pred, num_classes):
    """Return a (num_classes, num_classes) confusion matrix (rows = truth) built with bincount.

    Pairs where either label is outside [0, num_classes) are dropped, matching
    sklearn's confusion_matrix(..., labels=range(num_classes)); empty inputs
    give an all-zero matrix.
    """
    y_true = np.asarray(y_true, dtype=np.int64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.int64).ravel()
    keep = (y_true >= 0) & (y_true < num_classes) & (y_pred >= 0) & (y_pred < num_classes)
    flat = y_true[keep] * num_classes + y_pred[keep]
    return np.bincount(flat, minlength=num_classes * num_classes).reshape(num_classes, num_classes)


def _per_class_iou(conf_matrix):
    """Return (iou, present): per-class IoU (0 where the union is 0) and a mask of classes with a non-zero union."""
    conf_matrix = np.asarray(conf_matrix)
    diag = np.diag(conf_matrix)
    union = conf_matrix.sum(axis=1) + conf_matrix.sum(axis=0) - diag
    present = union > 0
    iou = np.zeros(conf_matrix.shape[0], dtype=np.float64)
    iou[present] = diag[present] / union[present]
    return iou, present


def evaluate_model_with_crf(model, test_gen, test_df, out_dir, image_dir, label_dir, tile_size=256, n_rows=4, n_cols=3):
    """Evaluates a segmentation model with dense-CRF post-processing (`apply_crf`) on a test set.

    For every chip (up to ``len(test_df)`` chips) the softmax prediction is
    refined with `apply_crf`. Pixels whose one-hot label is all zeros are
    treated as "ignore" pixels (as in `visualise_prediction_grid`) and are
    excluded from all metrics.

    Side effects: shows a prediction grid, saves a row-normalised confusion
    matrix to ``out_dir/confusion_matrix_crf.png`` and a per-chip mIoU
    histogram to ``out_dir/per_chip_miou_hist.png``.

    Args:
        model: Trained Keras segmentation model producing per-pixel softmax.
        test_gen: Iterable of (images, one-hot masks) batches.
        test_df: Test metadata; only its length is used (caps the number of chips).
        out_dir: Directory for saved plots (created if missing).
        image_dir, label_dir, tile_size: Not used in this function.
        n_rows, n_cols: Prediction-grid layout; the first n_rows*n_cols chips are shown.

    Returns:
        dict with "macro_f1", "precision", "recall", "miou" (mean IoU over
        classes that occur in ground truth or predictions), "per_class_ious"
        (0 for classes that never occur) and "classification_report".

    Raises:
        ValueError: If no labelled (non-ignore) pixels were collected.
    """
    os.makedirs(out_dir, exist_ok=True)
    labels = list(range(NUM_CLASSES))

    # Per-chip flat label arrays, concatenated once after the loop.
    pred_chunks = []
    true_chunks = []

    rgb_list = []
    true_mask_list = []
    pred_mask_list = []
    chip_ious = []

    visual_limit = n_rows * n_cols
    tile_index = 0

    for x_batch, y_batch in tqdm(test_gen, desc="Evaluating"):
        x_np = np.asarray(x_batch)
        if x_np.size == 0:
            continue
        y_np = np.asarray(y_batch)

        soft_preds = model.predict(x_batch, verbose=0)
        true_mask_batch = np.argmax(y_np, axis=-1).astype(np.uint8)
        # argmax maps an all-zero one-hot vector to class 0; mark those pixels invalid.
        valid_batch = np.any(y_np != 0, axis=-1)

        for i in range(x_np.shape[0]):
            if tile_index >= len(test_df):
                break

            rgb_img = (x_np[i][..., :3] * 255).astype(np.uint8)
            crf_mask = apply_crf(rgb_img, soft_preds[i])

            valid = valid_batch[i]
            chip_true = true_mask_batch[i][valid]
            chip_pred = crf_mask[valid]
            true_chunks.append(chip_true)
            pred_chunks.append(chip_pred)

            # Per-chip mIoU over classes present in this chip's truth or prediction.
            chip_iou, chip_present = _per_class_iou(_confusion_counts(chip_true, chip_pred, NUM_CLASSES))
            chip_ious.append(float(chip_iou[chip_present].mean()) if chip_present.any() else 0)

            if len(rgb_list) < visual_limit:
                rgb_list.append(rgb_img)
                true_mask_list.append(y_np[i].copy())
                pred_mask_list.append(crf_mask)

            tile_index += 1

        gc.collect()

    all_trues = np.concatenate(true_chunks) if true_chunks else np.empty(0, dtype=np.uint8)
    all_preds = np.concatenate(pred_chunks) if pred_chunks else np.empty(0, dtype=np.int64)
    del true_chunks, pred_chunks
    if all_trues.size == 0:
        raise ValueError("No labelled test pixels were collected; cannot compute metrics.")

    visualise_prediction_grid(
        rgb_list=rgb_list,
        true_mask_list=true_mask_list,
        pred_mask_list=pred_mask_list,
        n_rows=n_rows,
        n_cols=n_cols
    )

    # --- Metrics ---
    f1 = f1_score(all_trues, all_preds, average='macro')
    precision = precision_score(all_trues, all_preds, average='macro')
    recall = recall_score(all_trues, all_preds, average='macro')

    conf_matrix = confusion_matrix(all_trues, all_preds, labels=labels)
    class_iou, class_present = _per_class_iou(conf_matrix)
    per_class_ious = class_iou.tolist()
    # Average only classes that actually occur, consistent with the per-chip
    # mIoU above and with tf.keras.metrics.MeanIoU used in evaluate_on_test.
    miou = float(class_iou[class_present].mean()) if class_present.any() else 0.0

    # Explicit labels keep target_names aligned even when a class never occurs
    # (otherwise classification_report raises a size-mismatch ValueError).
    report = classification_report(
        all_trues, all_preds, labels=labels, target_names=CLASS_NAMES, digits=4, zero_division=0
    )

    # Confusion Matrix Plot (row-normalised; rows with no ground truth stay 0).
    row_sums = conf_matrix.sum(axis=1, keepdims=True)
    norm_conf = np.divide(
        conf_matrix.astype(np.float32),
        row_sums,
        out=np.zeros(conf_matrix.shape, dtype=np.float32),
        where=row_sums != 0,
    )
    fig, ax = plt.subplots(figsize=(8, 6))
    disp = ConfusionMatrixDisplay(norm_conf, display_labels=CLASS_NAMES)
    disp.plot(cmap='viridis', ax=ax, values_format=".2f")
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, 'confusion_matrix_crf.png'))
    plt.close(fig)

    # --- Per-chip mIoU histogram ---
    print("\n📊 Generating per-chip mIoU histogram...")
    bin_edges = np.linspace(0, 1, 21)  # 5% bins
    hist_fig = plt.figure(figsize=(10, 6))
    plt.hist(chip_ious, bins=bin_edges, edgecolor='black')
    plt.xlabel("Per-Chip mIoU")
    plt.ylabel("Number of Chips")
    plt.title("Distribution of Per-Chip mIoU (CRF Post-Processed)")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "per_chip_miou_hist.png"))
    plt.show()
    plt.close(hist_fig)

    return {
        "macro_f1": f1,
        "precision": precision,
        "recall": recall,
        "miou": miou,
        "per_class_ious": per_class_ious,
        "classification_report": report
    }