In [1]:
import os
import json
import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict

# ============================================================
# IOU function
# ============================================================
def iou_matrix(a, b):
    """Return the pairwise intersection-over-union between two box sets.

    Each box is a row whose first four values are ``(x1, y1, x2, y2)``; any
    extra columns are ignored.

    Args:
        a: sequence/array of N boxes.
        b: sequence/array of M boxes.

    Returns:
        ``np.ndarray`` of shape ``(N, M)`` where entry ``[i, j]`` is the IoU of
        ``a[i]`` and ``b[j]``. An all-zero ``(N, M)`` array is returned when
        either input is empty.
    """
    n_a, n_b = len(a), len(b)
    if not (n_a and n_b):
        return np.zeros((n_a, n_b))

    boxes_a = np.asarray(a)
    boxes_b = np.asarray(b)

    # Broadcast A over rows and B over columns so every pair is handled at
    # once: the overlap rectangle is [max of top-left corners, min of
    # bottom-right corners].
    top_left = np.maximum(boxes_a[:, None, 0:2], boxes_b[None, :, 0:2])
    bottom_right = np.minimum(boxes_a[:, None, 2:4], boxes_b[None, :, 2:4])
    overlap_wh = np.clip(bottom_right - top_left, 0, None)
    overlap = overlap_wh[..., 0] * overlap_wh[..., 1]

    size_a = (boxes_a[:, 2] - boxes_a[:, 0]) * (boxes_a[:, 3] - boxes_a[:, 1])
    size_b = (boxes_b[:, 2] - boxes_b[:, 0]) * (boxes_b[:, 3] - boxes_b[:, 1])
    combined = size_a[:, None] + size_b[None, :] - overlap

    # Floor the denominator to avoid division by zero for degenerate boxes.
    return overlap / np.clip(combined, 1e-8, None)


# ============================================================
# Compute AP from precision/recall
# ============================================================
def compute_ap(rec, prec):
    mrec = np.concatenate(([0], rec, [1]))
    mpre = np.concatenate(([0], prec, [0]))

    for i in range(len(mpre)-1, 0, -1):
        mpre[i-1] = max(mpre[i-1], mpre[i])

    idx = np.where(mrec[1:] != mrec[:-1])[0]
    return float(np.sum((mrec[idx+1] - mrec[idx]) * mpre[idx+1]))


# ============================================================
# MAIN
# ============================================================
def _confusion_matrix(gt_boxes, preds_all, n_classes, iou_thr=0.5):
    """Build a (ground-truth class x predicted class) count matrix.

    Each prediction is attributed to the ground-truth class whose box, in the
    same image, has the highest IoU with it. Predictions whose best IoU is
    below ``iou_thr`` are not counted. Class labels in both mappings are
    assumed to be 1-based: label ``k`` maps to row/column ``k - 1``.

    Args:
        gt_boxes: mapping ``(img_id, cls) -> list of boxes``.
        preds_all: mapping ``cls -> list of dicts`` with ``"img_id"`` and ``"box"``.
        n_classes: size of the square output matrix.
        iou_thr: minimum IoU for a prediction to be counted.

    Returns:
        Integer ``np.ndarray`` of shape ``(n_classes, n_classes)``. Rows are
        ground-truth classes and columns are predicted classes.
    """
    # Index GT by image once, rather than scanning every GT entry of the
    # dataset for each prediction. Skip empty lists (np.argmax/np.max fail on them).
    gt_by_img = defaultdict(list)
    for (img_id, cls_gt), gt_list in gt_boxes.items():
        if len(gt_list) > 0:
            gt_by_img[img_id].append((cls_gt, gt_list))

    cm = np.zeros((n_classes, n_classes), dtype=int)
    for cls_pred_label, preds in preds_all.items():
        for p in preds:
            # Keep the class with the highest IoU, not the first one that clears the threshold.
            best_cls, best_iou = None, -1.0
            for cls_gt, gt_list in gt_by_img.get(p["img_id"], ()):
                iou = float(np.max(iou_matrix([p["box"]], gt_list)[0]))
                if iou > best_iou:
                    best_cls, best_iou = cls_gt, iou
            if best_cls is not None and best_iou >= iou_thr:
                cm[best_cls - 1, cls_pred_label - 1] += 1
    return cm


def _class_average_precision(preds, gt_boxes, cls, total_gt, iou_thr):
    """Compute AP for one class at one IoU threshold, PASCAL-VOC style.

    Predictions are ranked by descending ``score``. A prediction is a true
    positive when its best-overlapping GT box of class ``cls`` in the same
    image has IoU >= ``iou_thr`` and has not already been claimed by a
    higher-scoring prediction. Otherwise it is a false positive; as in VOC,
    the next-best GT box is not tried.

    Args:
        preds: prediction dicts with ``"img_id"``, ``"box"`` and ``"score"``.
        gt_boxes: mapping ``(img_id, cls) -> list of boxes``.
        cls: class label to evaluate.
        total_gt: number of GT boxes of ``cls`` in the dataset.
        iou_thr: IoU needed to count a match.

    Returns:
        AP as a float. Returns 0.0 when ``total_gt == 0``.
    """
    if total_gt == 0:
        return 0.0

    ranked = sorted(preds, key=lambda x: x["score"], reverse=True)
    is_tp = np.zeros(len(ranked))
    matched = defaultdict(set)  # img_id -> indices of GT boxes already claimed
    for k, p in enumerate(ranked):
        img_id = p["img_id"]
        gt_list = gt_boxes.get((img_id, cls), [])
        if len(gt_list) == 0:
            continue
        ious = iou_matrix([p["box"]], gt_list)[0]
        idx = int(np.argmax(ious))
        if ious[idx] >= iou_thr and idx not in matched[img_id]:
            matched[img_id].add(idx)
            is_tp[k] = 1

    tp = np.cumsum(is_tp)
    fp = np.cumsum(1 - is_tp)
    rec = tp / total_gt
    prec = tp / np.maximum(tp + fp, 1e-8)
    return compute_ap(rec, prec)


def main():
    """Evaluate stored detections and write the metrics to ``metrics/``.

    Reads these files:
      * ``outputs/gt_boxes.pkl``: ``(img_id, cls) -> list of boxes``
      * ``outputs/preds_per_class.pkl``: ``cls -> list of prediction dicts``
      * ``outputs/class_names.json``: list of class names, in label order

    Writes ``metrics/confusion_matrix.png`` and ``metrics/final_metrics.json``
    (mAP@[.50:.05:.95], AP50, AP75 and per-class AP), then prints the report.
    """
    # NOTE: pickle.load can execute arbitrary code; only load files produced
    # by this project's own pipeline.
    with open("outputs/gt_boxes.pkl", "rb") as f:
        gt_boxes = pickle.load(f)

    with open("outputs/preds_per_class.pkl", "rb") as f:
        preds_all = pickle.load(f)

    with open("outputs/class_names.json", "r", encoding="utf-8") as f:
        class_names = json.load(f)

    n_classes = len(class_names)
    print(f"✔ Loaded {n_classes} classes")

    os.makedirs("metrics", exist_ok=True)

    # ---- Confusion matrix ------------------------------------------------
    cm = _confusion_matrix(gt_boxes, preds_all, n_classes, iou_thr=0.5)

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel("Prediction")
    plt.ylabel("Ground truth")
    plt.title("Confusion Matrix (IoU>=0.5)")
    plt.tight_layout()
    plt.savefig("metrics/confusion_matrix.png", dpi=300)
    plt.close()
    print("✔ Saved confusion matrix")

    # ---- AP per class for each IoU threshold --------------------------------
    # COCO-style thresholds 0.50:0.05:0.95. This list already contains 0.5 and
    # 0.75, so they must not be prepended again: doing so evaluated both twice
    # and appended duplicate per-class entries to their lists.
    iou_thresholds = [round(x / 100, 2) for x in range(50, 100, 5)]

    # Labels must be 1..n_classes; any other label raises KeyError here.
    gt_count = {c: 0 for c in range(1, n_classes + 1)}
    for (_, cls), lst in gt_boxes.items():
        gt_count[cls] += len(lst)

    # ap_per_iou[thr][i] is the AP of class label i + 1 at threshold thr.
    ap_per_iou = {
        thr: [
            _class_average_precision(preds_all.get(cls, []), gt_boxes, cls,
                                     gt_count[cls], thr)
            for cls in range(1, n_classes + 1)
        ]
        for thr in iou_thresholds
    }

    # Classes without any GT contribute AP=0 to these means; they are not excluded.
    map_50_95 = float(np.mean([np.mean(v) for v in ap_per_iou.values()]))
    ap50 = float(np.mean(ap_per_iou[0.5]))
    ap75 = float(np.mean(ap_per_iou[0.75]))

    # ---- JSON report ---------------------------------------------------------
    report = {
        "mAP_50_95": map_50_95,
        "AP50": ap50,
        "AP75": ap75,
        "per_class_AP50": {name: float(ap) for name, ap in zip(class_names, ap_per_iou[0.5])},
        "per_class_AP75": {name: float(ap) for name, ap in zip(class_names, ap_per_iou[0.75])},
        "num_ground_truth": sum(gt_count.values()),
        "num_predictions": sum(len(v) for v in preds_all.values()),
        "class_names": class_names
    }

    with open("metrics/final_metrics.json", "w", encoding="utf-8") as f:
        json.dump(report, f, indent=4)

    print("✔ Saved metrics/final_metrics.json")
    print("\n===== SUMMARY =====")
    print(json.dumps(report, indent=4))


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

✔ Loaded 8 classes
✔ Saved confusion matrix
✔ Saved metrics/final_metrics.json

===== SUMMARY =====
{
    "mAP_50_95": 0.5669891379379539,
    "AP50": 0.8028507496731394,
    "AP75": 0.6127044507987772,
    "per_class_AP50": {
        "temoin:0": 1.0,
        "temoin:25": 0.9437264560996812,
        "temoin:50": 0.9026581041463073,
        "temoin:75": 0.9090909090909091,
        "temoin:80": 0.9,
        "temoin:90": 0.7857142857142858,
        "temoin:95": 0.0,
        "temoin:100": 0.9816162423339315
    },
    "per_class_AP75": {
        "temoin:0": 0.25,
        "temoin:25": 0.6055546887189889,
        "temoin:50": 0.7942185117967334,
        "temoin:75": 0.9090909090909091,
        "temoin:80": 0.7200000000000001,
        "temoin:90": 0.7857142857142858,
        "temoin:95": 0.0,
        "temoin:100": 0.8370572110693009
    },
    "num_ground_truth": 145,
    "num_predictions": 260,
    "class_names": [
        "temoin:0",
        "temoin:25",
        "temoin:50",
        "temoin