In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix

# Folder that evaluation figures (e.g. "conf_matrix.png") are saved into.
out_dir = "/content/figs"

# --- Sanity ---
def test_scoring_sanity():
    """Print the fixed confirmation line "✅ from scoring.ipynb"; returns None."""
    banner = "✅ from scoring.ipynb"
    print(banner)


# Display name for each class index, in index order (0 -> 'Building', ...).
DEFAULT_CLASS_NAMES = ('Building', 'Clutter', 'Vegetation', 'Water', 'Background', 'Car')


def _resolve_class_names(num_classes, class_names=None):
    """Return exactly ``num_classes`` display names, padding missing ones as 'Class <i>'."""
    source = DEFAULT_CLASS_NAMES if class_names is None else tuple(class_names)
    return [source[i] if i < len(source) else f"Class {i}" for i in range(num_classes)]


def evaluate_predictions(pred_mask, true_mask, num_classes=6, class_names=None):
    """
    Compare predicted and true class masks, print evaluation metrics and
    save/show a row-normalised confusion-matrix heatmap.

    Parameters:
    - pred_mask: array-like of predicted class indices (any shape; it is flattened)
    - true_mask: array-like of ground-truth class indices with the same
      number of elements as pred_mask
    - num_classes: total number of semantic classes; labels 0..num_classes-1
      are evaluated
    - class_names: optional sequence of display names indexed by class id.
      Defaults to DEFAULT_CLASS_NAMES; ids without a name are shown as "Class <i>".

    Side effects:
    - prints a classification report, per-class IoU and mean IoU
    - writes "conf_matrix.png" into the module-level ``out_dir`` (created if
      missing), then calls plt.show()

    Raises:
    - ValueError: if the flattened masks have different lengths.
    """
    print("🔎 Evaluating predictions...")

    # --- Flatten if needed ---
    pred_flat = np.asarray(pred_mask).ravel()
    true_flat = np.asarray(true_mask).ravel()

    if pred_flat.shape != true_flat.shape:
        raise ValueError(f"Prediction and ground truth shapes don't match: {pred_flat.shape} vs {true_flat.shape}")

    labels = np.arange(num_classes)
    names = _resolve_class_names(num_classes, class_names)

    # --- Classification Report ---
    print("📈 Classification Report:")
    print(classification_report(
        true_flat,
        pred_flat,
        labels=labels,
        zero_division=0,
        digits=3
    ))

    # --- Confusion Matrix ---
    print("\n🌀 Confusion Matrix:")
    cm = confusion_matrix(true_flat, pred_flat, labels=labels)

    # Normalise rows so each true class sums to 1. A class with no ground-truth
    # pixels would divide by zero (NaN row); clamping the denominator shows 0s instead.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_normalised = cm.astype(np.float32) / np.maximum(row_sums, 1)

    plt.figure(figsize=(9, 8))
    sns.heatmap(cm_normalised, annot=True, fmt='.2f', cmap='Blues',
                xticklabels=names,
                yticklabels=names)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Normalised Confusion Matrix")
    plt.yticks(rotation=0)  # keep y-axis class labels horizontal
    plt.tight_layout()
    os.makedirs(out_dir, exist_ok=True)  # savefig raises if the folder is missing
    plt.savefig(os.path.join(out_dir, "conf_matrix.png"))
    plt.show()

    # --- IoU Calculation ---
    print("\n📊 IoU per Class:")
    intersection = np.diag(cm)
    ground_truth_set = cm.sum(axis=1)
    predicted_set = cm.sum(axis=0)
    union = ground_truth_set + predicted_set - intersection

    # A class absent from both masks has union 0: it is reported as IoU 0
    # (no division by zero) and still counts toward the mean below.
    iou_per_class = intersection / np.maximum(union, 1)

    for i, (name, iou) in enumerate(zip(names, iou_per_class)):
        print(f"  Class {i} ({name}): {iou:.4f}")

    mean_iou = np.mean(iou_per_class)
    print(f"\n📈 Mean IoU: {mean_iou:.4f}")


def evaluate_on_test(model, test_gen, n_vis=10):
    """
    Predict every batch of ``test_gen`` with ``model``, visualise up to
    ``n_vis`` tiles, then print metrics for the whole set via evaluate_predictions.

    Parameters:
    - model: object exposing ``predict(x, verbose=0)`` whose output is
      argmax'ed over the last axis (presumably a Keras model — confirm)
    - test_gen: indexable object supporting ``len()``; item ``i`` is an
      ``(images, labels)`` pair of numpy arrays, with labels argmax'ed over
      the last axis (presumably one-hot)
    - n_vis: maximum number of tiles passed to ``visualise_prediction``

    Returns None. Prints a warning and returns early if no batch was evaluated.
    """
    from tqdm import trange  # imported here so loading this module does not require tqdm

    print("🧪 Running evaluation on test set...")
    pred_batches = []
    true_batches = []
    shown = 0

    total_batches = len(test_gen)

    for i in trange(total_batches, desc="🧪 Evaluating test set", unit="batch"):
        test_imgs, test_lbls = test_gen[i]
        if test_imgs.size == 0:
            continue

        pred = model.predict(test_imgs, verbose=0)
        pred_mask = np.argmax(pred, axis=-1).astype(np.uint8)
        true_mask = np.argmax(test_lbls, axis=-1).astype(np.uint8)

        # Keep whole flattened arrays; extending a Python list pixel-by-pixel
        # boxes every value and costs far more memory and time.
        pred_batches.append(pred_mask.reshape(-1))
        true_batches.append(true_mask.reshape(-1))

        for j in range(test_imgs.shape[0]):
            if shown >= n_vis:
                break
            # NOTE(review): assumes the first 3 channels are RGB scaled to [0, 1] — confirm
            rgb_tile = (test_imgs[j][:, :, :3] * 255).astype(np.uint8)
            visualise_prediction(rgb_tile, true_mask[j], pred_mask[j])
            shown += 1

    if not pred_batches:
        print("⚠️ No test predictions were collected.")
        return

    print("\n📊 Test Set Evaluation Results:")
    evaluate_predictions(np.concatenate(pred_batches), np.concatenate(true_batches))


