In [None]:
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import os
from tqdm import trange

# Directory that generated figures (e.g. the confusion-matrix PNG) are saved to.
out_dir = "/content/figs"

# --- Sanity ---
def test_scoring_sanity():
    """Print a fixed marker line so a caller can confirm these definitions loaded."""
    marker = "✅ from scoring.ipynb"
    print(marker)

# Reverse lookup of COLOR_TO_CLASS (class id -> colour). Only class ids below 6
# are kept, which excludes the ignore colour.
CLASS_TO_COLOR = {
    class_id: color
    for color, class_id in COLOR_TO_CLASS.items()
    if class_id < 6
}


def visualise_prediction(rgb, true_mask_onehot, pred_mask):
    """Show the input tile, its ground-truth mask and the predicted mask side by side.

    Args:
        rgb: Image handed straight to ``imshow`` for the "Input" panel.
        true_mask_onehot: One-hot ground truth; the class axis is the last
            axis. Pixels whose one-hot vector is all zeros are treated as
            "ignore".
        pred_mask: Predicted class ids, compared element-wise against each
            class id in ``CLASS_TO_COLOR``.
    """
    label_map = np.argmax(true_mask_onehot, axis=-1)
    height, width = label_map.shape

    # Start from black canvases and paint each class with its colour.
    gt_canvas = np.zeros((height, width, 3), dtype=np.uint8)
    pred_canvas = np.zeros((height, width, 3), dtype=np.uint8)
    for class_id, color in CLASS_TO_COLOR.items():
        gt_canvas[label_map == class_id] = color
        pred_canvas[pred_mask == class_id] = color

    # Ignored pixels (all-zero one-hot) are drawn magenta in BOTH panels, so
    # predictions made there are hidden rather than shown as errors.
    unlabelled = np.all(true_mask_onehot == 0, axis=-1)
    for canvas in (gt_canvas, pred_canvas):
        canvas[unlabelled] = (255, 0, 255)

    fig, axs = plt.subplots(1, 3, figsize=(14, 4))
    panels = (
        (rgb, "Input"),
        (gt_canvas, "Ground Truth"),
        (pred_canvas, "Prediction"),
    )
    for ax, (image, title) in zip(axs, panels):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()



def evaluate_predictions(pred_mask, true_mask, num_classes=6, ignore_label=6):
    """Print pixel-wise segmentation metrics and plot the confusion matrix.

    Both masks are flattened, pixels whose ground-truth label equals
    ``ignore_label`` are dropped, and the remaining pixels are used to:

    * print a scikit-learn classification report,
    * draw, save and show a row-normalised confusion-matrix heatmap,
    * print the IoU of every class and the mean IoU over all classes.

    Args:
        pred_mask: Array-like of predicted integer class labels.
        true_mask: Array-like of ground-truth integer class labels holding the
            same number of elements as ``pred_mask``.
        num_classes: Number of classes; labels ``0 .. num_classes - 1`` are
            reported.
        ignore_label: Ground-truth label value excluded from every metric.

    Raises:
        ValueError: If the two masks do not hold the same number of pixels.

    Side effects:
        Saves ``conf_matrix.png`` into the module-level ``out_dir`` (created
        when missing) and displays the figure.
    """
    print("🔎 Evaluating predictions...")

    pred_flat = np.asarray(pred_mask).ravel()
    true_flat = np.asarray(true_mask).ravel()

    # Validate BEFORE masking: once both arrays are indexed with the same
    # boolean mask their shapes always agree, so a later check cannot fire.
    if pred_flat.shape != true_flat.shape:
        raise ValueError(f"Prediction and ground truth shapes don't match: {pred_flat.shape} vs {true_flat.shape}")

    # Remove ignored pixels.
    keep = true_flat != ignore_label
    pred_flat = pred_flat[keep]
    true_flat = true_flat[keep]

    if true_flat.size == 0:
        print("⚠️ No valid pixels left to evaluate after removing ignored labels.")
        return

    labels = np.arange(num_classes)

    print("📈 Classification Report:")
    print(classification_report(
        true_flat,
        pred_flat,
        labels=labels,
        zero_division=0,
        digits=3
    ))

    # Known names for the first six classes; any further class gets a generic
    # name so a different ``num_classes`` neither crashes nor mislabels ticks.
    known_names = ['Building', 'Clutter', 'Vegetation', 'Water', 'Background', 'Car']
    class_names = [
        known_names[i] if i < len(known_names) else f"Class {i}"
        for i in range(num_classes)
    ]

    print("\n🌀 Confusion Matrix:")
    cm = confusion_matrix(true_flat, pred_flat, labels=labels)
    # A class with no ground-truth pixels has a zero row sum; clamp the
    # denominator so that row becomes all zeros instead of NaN.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_normalised = cm.astype(np.float32) / np.maximum(row_totals, 1)

    fig = plt.figure(figsize=(8, 6))
    sns.heatmap(cm_normalised, annot=True, fmt='.2f', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Normalised Confusion Matrix")
    plt.yticks(rotation=0)
    plt.tight_layout()
    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, "conf_matrix.png"))
    plt.show()
    plt.close(fig)

    print("\n📊 IoU per Class:")
    intersection = np.diag(cm)
    ground_truth_set = cm.sum(axis=1)
    predicted_set = cm.sum(axis=0)
    union = ground_truth_set + predicted_set - intersection

    # Clamp the union so a class absent from both masks scores 0, not NaN.
    iou_per_class = intersection / np.maximum(union, 1)

    for i, iou in enumerate(iou_per_class):
        print(f"  Class {i} ({class_names[i]}): {iou:.4f}")

    mean_iou = np.mean(iou_per_class)
    print(f"\n📈 Mean IoU: {mean_iou:.4f}")


import gc

def evaluate_on_test(model, test_gen, n_vis=10, max_batches=None):
    """Run the model over the test generator, visualise a few tiles and score all of them.

    For every batch the model's predictions and the labels are reduced to
    class-id masks with ``argmax`` over the last axis. The first ``n_vis``
    tiles are plotted with ``visualise_prediction``; all masks are then
    concatenated and passed to ``evaluate_predictions``.

    Args:
        model: Object exposing ``predict(batch, verbose=0)`` that returns
            per-class scores with the class axis last.
        test_gen: Indexable batch source supporting ``len()`` and
            ``test_gen[i] -> (images, one_hot_labels)``.
        n_vis: Maximum number of tiles to visualise.
        max_batches: Optional cap on the number of batches processed;
            ``None`` processes every batch.

    Returns:
        None. Results are printed and plotted.
    """
    print("🧪 Running evaluation on test set...")

    # Label value that evaluate_predictions drops by default.
    ignore_label = 6
    n_shown = 0

    all_preds = []
    all_trues = []

    total_batches = len(test_gen)
    if max_batches is not None:
        total_batches = min(total_batches, max_batches)

    for i in trange(total_batches, desc="🧪 Evaluating test set", unit="batch"):
        test_imgs, test_lbls = test_gen[i]
        if test_imgs.size == 0:
            continue

        # 🧠 Predict
        pred = model.predict(test_imgs, verbose=0)
        pred_mask = np.argmax(pred, axis=-1).astype(np.uint8)
        true_mask = np.argmax(test_lbls, axis=-1).astype(np.uint8)

        # argmax of an all-zero one-hot vector is 0, which would score ignored
        # pixels as class 0. Mark them with the ignore label instead, matching
        # the all-zero "ignore" convention used by visualise_prediction.
        ignored = np.all(np.asarray(test_lbls) == 0, axis=-1)
        true_mask[ignored] = ignore_label

        # 💡 Keep only the masks for metrics, not the images
        all_preds.append(pred_mask)
        all_trues.append(true_mask)

        # 🔍 Visualise tiles until n_vis have been shown in total
        n_to_show = min(n_vis - n_shown, test_imgs.shape[0])
        for j in range(n_to_show):
            # NOTE(review): the * 255 scaling assumes pixel values in [0, 1];
            # confirm against the generator's preprocessing.
            rgb_tile = (test_imgs[j][:, :, :3] * 255).astype(np.uint8)
            visualise_prediction(rgb_tile, test_lbls[j], pred_mask[j])
            n_shown += 1

        # 🧹 Drop per-batch arrays before loading the next batch
        del test_imgs, test_lbls, pred, pred_mask, true_mask, ignored
        gc.collect()

    if not all_preds:
        print("⚠️ No test predictions were collected.")
        return

    # 🔻 Concatenate once at the end rather than growing an array per batch
    all_preds = np.concatenate(all_preds).ravel()
    all_trues = np.concatenate(all_trues).ravel()

    print("\n📊 Test Set Evaluation Results:")
    evaluate_predictions(all_preds, all_trues)
