# In [None]:
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import os
from tqdm import trange
from tensorflow.keras.metrics import MeanIoU
import gc
from PIL import Image
from datetime import datetime
from sklearn.metrics import f1_score, precision_score, recall_score


# Module-level path; presumably the default location for saved figures
# (the functions below take their own ``out_dir`` argument) — TODO confirm.
out_dir="/content/figs"
# Inverse lookup: class id -> RGB colour, keeping only ids below 6.
# NOTE(review): COLOR_TO_CLASS is not defined above this line in this file (it is
# defined further down), so this statement presumably relies on an earlier
# notebook cell having defined it — confirm before running top to bottom.
CLASS_TO_COLOR = {v: k for k, v in COLOR_TO_CLASS.items() if v < 6} # Exclude ignore color
# Display names for the classes; indexed by class id where used in this file.
CLASS_NAMES = ['Building', 'Clutter', 'Vegetation', 'Water', 'Background', 'Car']
# Number of classes scored by the evaluation functions (ids 0..5).
NUM_CLASSES = 6
# RGB colour used as the fill value for reconstructed canvases; it is the same
# tuple that COLOR_TO_CLASS (defined later in this file) maps to id 6.
IGNORE_COLOR = (255, 0, 255)


def measure_inference_time(model, generator, num_batches=5):
    """Time ``model.predict`` over the first batches of a dataset and print a summary.

    Args:
        model: Object exposing ``predict(x, verbose=0)``.
        generator: Dataset exposing ``take(n)`` that yields ``(inputs, targets)``
            pairs. ``inputs.shape[0]`` is counted as the number of images in
            the batch.
        num_batches: Number of batches to time. ``None`` uses the value
            reported by ``tf.data.experimental.cardinality(generator)``.

    Returns:
        None. The total and per-image times are printed; if no image was
        processed a warning is printed instead.
    """
    import time

    if num_batches is None:
        # Imported here so the function does not depend on a module-level
        # ``tf`` name having been bound before it is called.
        import tensorflow as tf

        num_batches = tf.data.experimental.cardinality(generator).numpy()

    total_time = 0.0
    total_images = 0

    for x_batch, _ in generator.take(num_batches):
        # perf_counter is monotonic, so the interval cannot be distorted by
        # system clock adjustments.
        start = time.perf_counter()
        model.predict(x_batch, verbose=0)
        total_time += time.perf_counter() - start
        total_images += x_batch.shape[0]

    if total_images == 0:
        # Avoid dividing by zero when the dataset is empty or num_batches is 0.
        print("⚠️ No images were processed; inference time not measured.")
        return

    print(f"🧠 Inference time: {total_time:.2f} sec for {total_images} images")
    print(f"⏱️ Avg inference time per image: {total_time / total_images:.4f} sec")



def plot_training_curves(history, out_dir):
    """Plot training/validation loss and IoU curves and save them as a PNG.

    Args:
        history: Object with a ``history`` mapping that contains the series
            ``"loss"``, ``"val_loss"``, ``"iou_score"`` and ``"val_iou_score"``.
        out_dir: Directory for ``training_curves.png``; created if missing.

    Returns:
        None. If any required series is missing, a warning is printed and no
        figure is produced (the output directory is still created).
    """
    # exist_ok avoids the check-then-create race and matches the other
    # functions in this file.
    os.makedirs(out_dir, exist_ok=True)

    history_dict = history.history
    required_keys = ["loss", "val_loss", "iou_score", "val_iou_score"]

    missing_keys = [k for k in required_keys if k not in history_dict]
    if missing_keys:
        print(f"⚠️ Missing keys in history: {missing_keys}")
        return

    # (train key, val key, train label, val label, title, y-axis label)
    panels = [
        ("loss", "val_loss", "Train Loss", "Val Loss", "Loss over Epochs", "Loss"),
        ("iou_score", "val_iou_score", "Train IoU", "Val IoU", "Mean IoU over Epochs", "Mean IoU"),
    ]

    fig, axs = plt.subplots(1, 2, figsize=(12, 5))
    for ax, (train_key, val_key, train_label, val_label, title, ylabel) in zip(axs, panels):
        ax.plot(history_dict[train_key], label=train_label)
        ax.plot(history_dict[val_key], label=val_label)
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.legend()

    fig.tight_layout()
    save_path = os.path.join(out_dir, "training_curves.png")
    # Save before showing so the file is written regardless of the backend.
    fig.savefig(save_path)
    plt.show()
    plt.close(fig)
    print(f"Saved training curves to: {save_path}")






# Re-import necessary packages after kernel reset
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix, f1_score, precision_score, recall_score, ConfusionMatrixDisplay
import tensorflow as tf
import gc
from tqdm import tqdm
import cv2
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral

# Constants
# NOTE: NUM_CLASSES, CLASS_NAMES and CLASS_TO_COLOR are also assigned near the
# top of this file with the same values/expression; these assignments rebind them.
# Number of classes scored by the evaluation functions (ids 0..5).
NUM_CLASSES = 6
# Display names for the classes; indexed by class id where used in this file.
CLASS_NAMES = ['Building', 'Clutter', 'Vegetation', 'Water', 'Background', 'Car']
# Maps an RGB colour tuple to its integer class id. The trailing names below
# are CLASS_NAMES entries at the same index.
COLOR_TO_CLASS = {
    (230, 25, 75): 0,  # Building
    (145, 30, 180): 1,  # Clutter
    (60, 180, 75): 2,  # Vegetation
    (245, 130, 48): 3,  # Water
    (255, 255, 255): 4,  # Background
    (0, 130, 200): 5,  # Car
    (255, 0, 255): 6  # id 6 has no CLASS_NAMES entry; same tuple as IGNORE_COLOR
}
# Inverse lookup (class id -> RGB colour) restricted to ids below 6, so the
# id-6 colour is left out.
CLASS_TO_COLOR = {v: k for k, v in COLOR_TO_CLASS.items() if v < 6}

def apply_crf(rgb, probs, t=5, compat=10, sxy=3, srgb=13):
    """Refine per-pixel class probabilities with a dense CRF and return labels.

    Args:
        rgb: Image used for the bilateral pairwise term, with channels on the
            last axis (it is passed with ``chdim=2``). Must have the same
            height and width as ``probs``.
        probs: Array of shape (H, W, C) holding per-pixel class probabilities.
        t: Number of CRF inference iterations.
        compat: Compatibility value passed to ``addPairwiseEnergy``.
        sxy: Spatial scaling passed as ``sdims`` for both image axes.
        srgb: Colour scaling passed as ``schan`` for each of the three channels.

    Returns:
        Integer array of shape (H, W) with the arg-max label per pixel.

    Raises:
        ValueError: If ``probs`` is not a 3-D array.
    """
    probs = np.asarray(probs)
    if probs.ndim != 3:
        raise ValueError(f"probs must have shape (H, W, C), got {probs.shape}")

    # probs is channels-last, so the spatial size is the first two axes
    # (shape[1:3] would give (W, C) and build a CRF over the wrong grid).
    h, w, n_classes = probs.shape

    d = dcrf.DenseCRF2D(w, h, n_classes)
    # unary_from_softmax expects class-first input: (C, H, W).
    unary = unary_from_softmax(probs.transpose(2, 0, 1))
    # Contiguous copy is a no-op when the buffer is already C-ordered.
    d.setUnaryEnergy(np.ascontiguousarray(unary))
    pairwise = create_pairwise_bilateral(sdims=(sxy, sxy), schan=(srgb, srgb, srgb), img=rgb, chdim=2)
    d.addPairwiseEnergy(pairwise, compat=compat)
    Q = d.inference(t)
    return np.argmax(np.array(Q), axis=0).reshape((h, w))

def evaluate_model_with_crf(model, test_gen, test_df, out_dir, image_dir, label_dir, tile_size=256):
    """Evaluate a segmentation model with CRF post-processing on a test set.

    Every tile is predicted, refined with ``apply_crf`` and compared with the
    arg-max of its one-hot label. A row-normalised confusion matrix is saved
    to ``out_dir/confusion_matrix_crf.png`` and up to six example tiles are
    shown.

    Args:
        model: Object exposing ``predict(x, verbose=0)`` that returns
            per-pixel class scores with classes on the last axis.
        test_gen: Iterable of ``(x_batch, y_batch)`` pairs. Batches may be
            NumPy arrays or anything ``np.asarray`` can convert. The first
            three input channels, scaled by 255, are used as the RGB image.
        test_df: Unused; kept for interface compatibility.
        out_dir: Directory for the confusion-matrix figure; created if missing.
        image_dir: Unused; kept for interface compatibility.
        label_dir: Unused; kept for interface compatibility.
        tile_size: Unused; visualisations take their size from the predicted
            mask. Kept for interface compatibility.

    Returns:
        dict with keys ``"macro_f1"``, ``"precision"``, ``"recall"``,
        ``"miou"``, ``"per_class_ious"`` and ``"classification_report"``.

    Raises:
        ValueError: If ``test_gen`` yields no non-empty batch.
    """
    os.makedirs(out_dir, exist_ok=True)

    max_visuals = 6
    labels = list(range(NUM_CLASSES))
    pred_chunks = []
    true_chunks = []
    crf_visuals = []

    for x_batch, y_batch in tqdm(test_gen, desc="Evaluating"):
        # Converting up front lets the loop accept tensors as well as arrays
        # (.size / .astype below are ndarray attributes).
        x_batch = np.asarray(x_batch)
        if x_batch.size == 0:
            continue

        soft_preds = model.predict(x_batch, verbose=0)
        # NOTE(review): a pixel whose one-hot label is all zeros arg-maxes to
        # class 0 and is scored as that class — confirm whether such pixels
        # should be excluded from the metrics.
        true_mask = np.argmax(np.asarray(y_batch), axis=-1).astype(np.uint8)

        for i in range(x_batch.shape[0]):
            rgb_img = (x_batch[i][..., :3] * 255).astype(np.uint8)
            crf_mask = apply_crf(rgb_img, soft_preds[i])

            # Keep one flat array per tile; extending a Python list pixel by
            # pixel is far slower and much more memory-hungry.
            pred_chunks.append(crf_mask.reshape(-1).astype(np.uint8))
            true_chunks.append(true_mask[i].reshape(-1))

            if len(crf_visuals) < max_visuals:
                height, width = crf_mask.shape
                pred_rgb = np.zeros((height, width, 3), dtype=np.uint8)
                gt_rgb = np.zeros((height, width, 3), dtype=np.uint8)
                for cid, color in CLASS_TO_COLOR.items():
                    pred_rgb[crf_mask == cid] = color
                    gt_rgb[true_mask[i] == cid] = color
                crf_visuals.append((rgb_img, gt_rgb, pred_rgb))

        gc.collect()

    if not pred_chunks:
        raise ValueError("No predictions were collected from test_gen.")

    all_preds = np.concatenate(pred_chunks)
    all_trues = np.concatenate(true_chunks)

    f1 = f1_score(all_trues, all_preds, average='macro')
    precision = precision_score(all_trues, all_preds, average='macro')
    recall = recall_score(all_trues, all_preds, average='macro')

    conf_matrix = confusion_matrix(all_trues, all_preds, labels=labels)
    per_class_ious = []
    for i in range(NUM_CLASSES):
        intersection = conf_matrix[i, i]
        union = np.sum(conf_matrix[i, :]) + np.sum(conf_matrix[:, i]) - intersection
        iou = intersection / union if union > 0 else 0
        per_class_ious.append(iou)
    miou = np.mean(per_class_ious)

    # Explicit labels keep target_names aligned even when a class is absent
    # from both the labels and the predictions.
    report = classification_report(
        all_trues, all_preds, labels=labels, target_names=CLASS_NAMES, digits=4
    )

    # `where=` without `out=` leaves skipped entries uninitialised, so start
    # from zeros: rows with no true pixels are shown as 0.
    conf_float = conf_matrix.astype(np.float64)
    row_sums = conf_float.sum(axis=1, keepdims=True)
    norm_conf = np.divide(conf_float, row_sums, out=np.zeros_like(conf_float), where=row_sums != 0)

    fig, ax = plt.subplots(figsize=(8, 6))
    disp = ConfusionMatrixDisplay(norm_conf, display_labels=CLASS_NAMES)
    disp.plot(cmap='viridis', ax=ax, values_format=".2f")
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, 'confusion_matrix_crf.png'))
    plt.close(fig)

    # squeeze=False keeps axs two-dimensional even with a single example row.
    fig, axs = plt.subplots(
        len(crf_visuals), 3, figsize=(14, 3 * len(crf_visuals)), squeeze=False
    )
    titles = ("Input", "Ground Truth", "CRF Prediction")
    for idx, panels in enumerate(crf_visuals):
        for col, (panel, title) in enumerate(zip(panels, titles)):
            axs[idx, col].imshow(panel)
            axs[idx, col].set_title(title)
            axs[idx, col].axis('off')
    fig.tight_layout()
    plt.show()
    plt.close(fig)

    return {
        "macro_f1": f1,
        "precision": precision,
        "recall": recall,
        "miou": miou,
        "per_class_ious": per_class_ious,
        "classification_report": report
    }





def reconstruct_canvas(model, df, source_file, generator_class, img_dir, elev_dir, slope_dir, label_dir, tile_size=256):
    """Reconstruct RGB, ground-truth and prediction canvases for one source file.

    Tiles are pasted at the ``x``/``y`` offsets of their ``df`` rows, relative
    to the smallest offsets of that file. Rows are matched to generator
    output by position, so the generator must yield tiles in ``df`` row order
    (it is created with ``shuffle=False``).

    Args:
        model: Object exposing ``predict(x, verbose=0)`` that returns
            per-pixel class scores with classes on the last axis.
        df: DataFrame with at least ``source_file``, ``x`` and ``y`` columns.
        source_file: Value of ``source_file`` selecting the rows to use.
        generator_class: Callable invoked as ``generator_class(df_file,
            img_dir, elev_dir, slope_dir, label_dir, batch_size=64,
            shuffle=False, augment=False, split="custom")`` that returns an
            iterable of ``(batch_x, batch_y)`` pairs.
        img_dir: Forwarded to ``generator_class``.
        elev_dir: Forwarded to ``generator_class``.
        slope_dir: Forwarded to ``generator_class``.
        label_dir: Forwarded to ``generator_class``.
        tile_size: Side length in pixels of each square tile. Defaults to 256.

    Returns:
        Tuple ``(rgb_canvas, gt_canvas, pred_canvas)`` of ``np.uint8`` arrays
        with shape (canvas_h, canvas_w, 3). Areas not covered by any tile keep
        ``IGNORE_COLOR``.

    Raises:
        ValueError: If ``df`` has no rows for ``source_file``.
    """
    def _colourise(mask):
        # Class ids without an entry in CLASS_TO_COLOR stay black.
        coloured = np.zeros(mask.shape + (3,), dtype=np.uint8)
        for cid, colour in CLASS_TO_COLOR.items():
            coloured[mask == cid] = colour
        return coloured

    # 1. Filter for the specific source file
    df_file = df[df['source_file'] == source_file].copy()
    if df_file.empty:
        raise ValueError(f"No chips found for source file: {source_file}")

    # 2. Determine canvas shape (plain ints so they are valid as shape/slices)
    min_x = int(df_file['x'].min())
    min_y = int(df_file['y'].min())
    canvas_w = int(df_file['x'].max()) + tile_size - min_x
    canvas_h = int(df_file['y'].max()) + tile_size - min_y

    canvas_shape = (canvas_h, canvas_w, 3)
    rgb_canvas = np.full(canvas_shape, IGNORE_COLOR, dtype=np.uint8)
    gt_canvas = np.full(canvas_shape, IGNORE_COLOR, dtype=np.uint8)
    pred_canvas = np.full(canvas_shape, IGNORE_COLOR, dtype=np.uint8)

    # 3. Build dataset for just this file
    gen = generator_class(
        df_file, img_dir, elev_dir, slope_dir, label_dir,
        batch_size=64, shuffle=False, augment=False, split="custom"
    )

    # 4. Predict and fill canvas in order
    n_tiles = len(df_file)
    row_index = 0
    for batch_x, batch_y in gen:
        if row_index >= n_tiles:
            # Every row is placed; do not predict further batches.
            break

        batch_x = np.asarray(batch_x)
        preds = model.predict(batch_x, verbose=0)
        pred_mask = np.argmax(preds, axis=-1)
        true_mask = np.argmax(np.asarray(batch_y), axis=-1)

        for i in range(batch_x.shape[0]):
            if row_index >= n_tiles:
                break

            row = df_file.iloc[row_index]
            row_index += 1
            top = int(row.y) - min_y
            left = int(row.x) - min_x
            window = (slice(top, top + tile_size), slice(left, left + tile_size))

            # RGB image: first three channels scaled to 0..255; clipping keeps
            # out-of-range inputs from wrapping around in uint8.
            rgb_canvas[window] = np.clip(batch_x[i][..., :3] * 255.0, 0, 255).astype(np.uint8)
            gt_canvas[window] = _colourise(true_mask[i])
            pred_canvas[window] = _colourise(pred_mask[i])

    return rgb_canvas, gt_canvas, pred_canvas



def plot_reconstruction(img, label, pred, source_file):
    """Show the RGB image, ground truth and model prediction side by side.

    Args:
        img: Image drawn in the left panel.
        label: Image drawn in the middle panel.
        pred: Image drawn in the right panel.
        source_file: Name used in the figure's overall title.
    """
    panels = (
        (img, "RGB Image"),
        (label, "Ground Truth"),
        (pred, "Model Prediction"),
    )

    fig, axs = plt.subplots(1, 3, figsize=(26.5, 13))  # Adjust size as needed
    for axis, (picture, caption) in zip(axs, panels):
        axis.imshow(picture)
        axis.set_title(caption)
        axis.axis('off')

    # Overall heading; the layout rect below reserves room for it.
    plt.suptitle(f"Reconstruction for: {source_file}", fontsize=24, y=0.95)
    plt.tight_layout(rect=[0, 0, 1, 0.93])
    plt.show()
    plt.close()





def evaluate_on_test(model, test_gen, test_df, out_dir, image_dir, label_dir, tile_size=256, n_rows=2, n_cols=3):
    """Evaluate a segmentation model on the test set and print/plot the results.

    Prints macro F1/precision/recall, per-class IoU, mean IoU and a
    classification report. Shows a grid of example predictions and saves a
    row-normalised confusion matrix to
    ``out_dir/confusion_matrix_row_norm.png``.

    Args:
        model: Object exposing ``predict(x, verbose=0)`` that returns
            per-pixel class scores with classes on the last axis.
        test_gen: Dataset exposing ``as_numpy_iterator()`` that yields
            ``(batch_x, batch_y)`` NumPy pairs with one-hot labels.
        test_df: Unused; kept for interface compatibility.
        out_dir: Directory for the confusion-matrix figure; created if missing.
        image_dir: Unused; kept for interface compatibility.
        label_dir: Unused; kept for interface compatibility.
        tile_size: Unused; kept for interface compatibility.
        n_rows: Rows of the example grid.
        n_cols: Sample columns of the example grid. If ``n_rows`` or
            ``n_cols`` is falsy, no example grid is collected or shown.

    Returns:
        None. If no predictions were collected, a warning is printed and the
        function returns early.
    """
    import os
    import gc
    import numpy as np
    import tensorflow as tf
    import matplotlib.pyplot as plt
    from sklearn.metrics import (
        ConfusionMatrixDisplay,
        classification_report,
        confusion_matrix,
        f1_score,
        precision_score,
        recall_score,
    )

    print("🧪 Running evaluation on test set...")

    labels = list(range(NUM_CLASSES))
    pred_chunks = []
    true_chunks = []

    visual_rgb = []
    visual_true = []
    visual_pred = []
    # The grid cannot be drawn without both dimensions, so collect nothing
    # in that case.
    visual_limit = n_rows * n_cols if n_rows and n_cols else 0

    os.makedirs(out_dir, exist_ok=True)

    for batch_x, batch_y in test_gen.as_numpy_iterator():
        if batch_x.size == 0:
            continue

        pred = model.predict(batch_x, verbose=0)
        pred_mask = np.argmax(pred, axis=-1).astype(np.uint8)
        # NOTE(review): a pixel whose one-hot label is all zeros arg-maxes to
        # class 0 and is scored as that class — confirm whether such pixels
        # should be excluded from the metrics.
        true_mask = np.argmax(batch_y, axis=-1).astype(np.uint8)

        # Collect visuals until the grid is full (range() of a non-positive
        # count is empty).
        n_new = min(visual_limit - len(visual_rgb), batch_x.shape[0])
        for j in range(n_new):
            visual_rgb.append((batch_x[j][:, :, :3] * 255).astype(np.uint8))
            visual_true.append(batch_y[j])
            visual_pred.append(pred_mask[j])

        # One flat array per batch; extending a Python list pixel by pixel is
        # far slower and much more memory-hungry.
        pred_chunks.append(pred_mask.reshape(-1))
        true_chunks.append(true_mask.reshape(-1))

        del batch_x, batch_y, pred, pred_mask, true_mask
        gc.collect()

    if not pred_chunks:
        print("⚠️ No test predictions were collected.")
        return

    all_test_preds = np.concatenate(pred_chunks)
    all_test_trues = np.concatenate(true_chunks)

    # Visualise Grid
    if visual_rgb:
        visualise_prediction_grid(visual_rgb, visual_true, visual_pred, n_rows, n_cols)

    # --- Mean IoU ---
    mean_iou_metric = tf.keras.metrics.MeanIoU(num_classes=NUM_CLASSES)
    mean_iou_metric.update_state(all_test_trues, all_test_preds)
    miou = mean_iou_metric.result().numpy()

    # --- F1, Precision, Recall (Macro) ---
    macro_f1 = f1_score(all_test_trues, all_test_preds, average='macro')
    macro_precision = precision_score(all_test_trues, all_test_preds, average='macro')
    macro_recall = recall_score(all_test_trues, all_test_preds, average='macro')

    print(f"\n📊 Macro Metrics:")
    print(f"  F1 Score     : {macro_f1:.4f}")
    print(f"  Precision    : {macro_precision:.4f}")
    print(f"  Recall       : {macro_recall:.4f}")

    # --- Confusion Matrix & Per-class IoU ---
    conf_matrix = confusion_matrix(all_test_trues, all_test_preds, labels=labels)
    print("\n📏 Per-class IoU Scores:")
    for i in range(NUM_CLASSES):
        intersection = conf_matrix[i, i]
        union = np.sum(conf_matrix[i, :]) + np.sum(conf_matrix[:, i]) - intersection
        # NaN marks a class that occurs in neither labels nor predictions.
        iou = intersection / union if union > 0 else float('nan')
        print(f"  {CLASS_NAMES[i]:<12} IoU: {iou:.4f}")

    print(f"\n📈 Mean IoU (mIoU): {miou:.4f}")

    # --- Classification Report ---
    print("\n🔍 Classification Report:")
    print(classification_report(
        all_test_trues,
        all_test_preds,
        labels=labels,
        target_names=CLASS_NAMES,
        digits=4
    ))

    # --- Confusion Matrix Plot (Row-Normalised using sklearn) ---
    print("\n🌀 Row-Normalised Confusion Matrix:")

    # `where=` without `out=` leaves skipped entries uninitialised, so start
    # from zeros: rows with no true pixels are shown as 0.
    conf_float = conf_matrix.astype(np.float64)
    row_sums = conf_float.sum(axis=1, keepdims=True)
    norm_conf = np.divide(conf_float, row_sums, out=np.zeros_like(conf_float), where=row_sums != 0)

    fig, ax = plt.subplots(figsize=(8, 6))
    disp = ConfusionMatrixDisplay(norm_conf, display_labels=CLASS_NAMES)
    disp.plot(cmap='viridis', ax=ax, values_format=".2f")
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, 'confusion_matrix_row_norm.png'))
    plt.show()
    plt.close(fig)



def normalize_confusion_matrix(cm, norm='true'):
    """
    Normalize a confusion matrix.

    Rows, columns or a whole matrix that sum to zero are returned as zeros
    instead of NaN/inf.

    Parameters:
    cm (array-like): Two-dimensional confusion matrix to be normalized
        (rows: true labels, columns: predictions).
    norm (str): Type of normalization: 'true' divides each row by its sum,
        'pred' divides each column by its sum, 'all' divides by the total.

    Returns:
    ndarray: Normalized confusion matrix (floating point).

    Raises:
    ValueError: If ``norm`` is not 'true', 'pred' or 'all'.
    """
    # asarray lets plain nested lists be passed, as the docstring promises.
    counts = np.asarray(cm, dtype=np.float64)

    if norm == 'true':
        totals = counts.sum(axis=1, keepdims=True)
    elif norm == 'pred':
        totals = counts.sum(axis=0, keepdims=True)
    elif norm == 'all':
        totals = counts.sum(keepdims=True)
    else:
        raise ValueError("Unknown normalization type. Use 'true', 'pred', or 'all'.")

    # Entries whose divisor is zero are skipped and keep the 0 from `out`.
    return np.divide(counts, totals, out=np.zeros_like(counts), where=totals != 0)



def visualise_prediction_grid(rgb_list, true_mask_list, pred_mask_list, n_rows, n_cols):
    """Show input / ground-truth / prediction triplets in a grid.

    Each sample takes three adjacent axes, so the figure has ``n_rows`` rows
    and ``n_cols * 3`` columns. If fewer samples than ``n_rows * n_cols`` are
    supplied, the remaining cells are left blank.

    Args:
        rgb_list: Input images to display.
        true_mask_list: One-hot label arrays with classes on the last axis.
            Pixels whose class vector is all zeros are painted (255, 0, 255)
            in both mask panels.
        pred_mask_list: Predicted class-id masks, one per sample.
        n_rows: Number of grid rows.
        n_cols: Number of samples per row.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    n_samples = min(n_rows * n_cols, len(rgb_list), len(true_mask_list), len(pred_mask_list))

    # squeeze=False keeps axs two-dimensional even when n_rows == 1.
    fig, axs = plt.subplots(
        n_rows, n_cols * 3, figsize=(n_cols * 6.6, n_rows * 2.6), squeeze=False
    )
    for ax in axs.ravel():
        ax.axis("off")

    for idx in range(n_samples):
        one_hot = np.asarray(true_mask_list[idx])
        true_mask = np.argmax(one_hot, axis=-1)
        pred_mask = pred_mask_list[idx]

        h, w = true_mask.shape
        true_rgb = np.zeros((h, w, 3), dtype=np.uint8)
        pred_rgb = np.zeros((h, w, 3), dtype=np.uint8)

        for class_id, color in CLASS_TO_COLOR.items():
            true_rgb[true_mask == class_id] = color
            pred_rgb[pred_mask == class_id] = color

        # Pixels with no class set in the one-hot label are marked in both panels.
        ignore_mask = np.all(one_hot == 0, axis=-1)
        true_rgb[ignore_mask] = (255, 0, 255)
        pred_rgb[ignore_mask] = (255, 0, 255)

        row = idx // n_cols
        col = (idx % n_cols) * 3
        panels = ((rgb_list[idx], "Input"), (true_rgb, "Ground Truth"), (pred_rgb, "Prediction"))
        for offset, (picture, title) in enumerate(panels):
            axs[row, col + offset].imshow(picture)
            axs[row, col + offset].set_title(title)

    fig.tight_layout()
    plt.show()
    plt.close(fig)


