# DETR Model Evaluation Dashboard

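This notebook evaluates the fine-tuned DETR model on the test split. It reports COCO-style metrics (mAP and AR) with MLOps-style visuals, summarizes per-class detections alongside training-set coverage, and showcases predictions on sample test images. Feel free to extend it with additional analyses as the experiment evolves.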


In [None]:
import importlib
import sys
import warnings
from collections.abc import Iterable
from pathlib import Path

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from PIL import Image

# Locate this notebook. Inside Jupyter ``__file__`` is not defined, so fall back
# to the current working directory.
try:
    NOTEBOOK_DIR = Path(__file__).resolve().parent
except NameError:
    NOTEBOOK_DIR = Path.cwd().resolve()

# The project root is the nearest directory, starting with the notebook's own,
# that contains a ``src`` entry; if none does, use the notebook directory.
PROJECT_ROOT = next(
    (
        directory
        for directory in (NOTEBOOK_DIR, *NOTEBOOK_DIR.parents)
        if (directory / "src").exists()
    ),
    NOTEBOOK_DIR,
)

# Make ``src.*`` importable without installing the project as a package.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

eval_utils = importlib.import_module("src.evaluation.evaluate_detr")

sns.set_theme(style="whitegrid")
# Silence the noisy "for ...: copying from a non-meta parameter" UserWarning.
# NOTE(review): presumably raised while model weights are loaded — confirm.
warnings.filterwarnings(
    "ignore",
    message=r"for .*: copying from a non-meta parameter",
    category=UserWarning,
)
print(f"Project root: {PROJECT_ROOT}")

In [None]:
# Experiment configuration: checkpoint location, test data and score cut-offs.
MODEL_DIR = PROJECT_ROOT / "models" / "detr-finetuned" / "v2"
TEST_IMAGES_DIR = PROJECT_ROOT / "data" / "images" / "test"
BATCH_SIZE = 2
METRIC_SCORE_THRESHOLD = 0.1  # score_threshold passed to run_test_evaluation
VISUAL_SCORE_THRESHOLD = 0.5  # score_threshold used for the sample visualisations

# Prefer the local fine-tuned checkpoint. ``None`` asks load_model for its
# fallback, which is reported here as the "fallback HF checkpoint".
model_source = MODEL_DIR if MODEL_DIR.exists() else None
model_dir_label = (
    model_source if model_source is not None else "fallback HF checkpoint"
)
print(f"Model directory: {model_dir_label}")
print(f"Test images directory: {TEST_IMAGES_DIR}")

model, processor, id2label = eval_utils.load_model(model_source)
print(f"Loaded {len(id2label)} classes; model is on {eval_utils.DEVICE}.")

In [None]:
# Score the model on the test split. ``predictions`` and ``coco_gt`` are kept
# for the cells below.
metrics, predictions, coco_gt = eval_utils.run_test_evaluation(
    model=model,
    processor=processor,
    batch_size=BATCH_SIZE,
    score_threshold=METRIC_SCORE_THRESHOLD,
)

# Long-format table with one row per metric, highest value first.
metrics_df = (
    pd.Series(metrics, name="value")
    .sort_values(ascending=False)
    .rename_axis("metric")
    .reset_index()
)
metrics_df

In [None]:
# Tag each metric with a family. Names starting with "mAP" form one family, and
# every other metric is grouped under "AR".
metrics_df["family"] = metrics_df["metric"].map(
    lambda metric_name: "mAP" if metric_name.startswith("mAP") else "AR"
)

fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharey=False)
for family, ax in zip(["mAP", "AR"], axes, strict=False):
    family_rows = metrics_df[metrics_df["family"] == family]
    if family_rows.empty:
        # Keep the panel, but state explicitly that nothing was reported.
        ax.axis("off")
        ax.set_title(f"No {family} metrics available")
        continue

    bar_ax = sns.barplot(
        data=family_rows,
        x="value",
        y="metric",
        hue="metric",
        palette="viridis",
        dodge=False,
        ax=ax,
    )
    # hue mirrors the y variable so the palette colours each bar. The legend it
    # creates would only repeat the y-axis labels, so drop it.
    family_legend = bar_ax.get_legend()
    if family_legend is not None:
        family_legend.remove()
    ax.set_title(f"{family}-metrics")
    ax.set_xlabel("Score")
    ax.set_ylabel("Metric")
    ax.set_xlim(0, 1)

plt.suptitle("Test split metric overview")
plt.tight_layout()

In [None]:
# Tabulate the raw detections returned by the evaluation run.
if not predictions:
    # Keep an empty frame with the expected columns so later cells can simply
    # branch on ``preds_df.empty``.
    preds_df = pd.DataFrame(
        columns=["image_id", "category_id", "bbox", "score", "label"]
    )
    print("No predictions above the selected threshold.")
else:
    raw_preds = pd.DataFrame(predictions)
    # Attach human-readable class names and shorten scores for display.
    preds_df = raw_preds.assign(
        label=raw_preds["category_id"].map(id2label),
        score=raw_preds["score"].round(3),
    )
    display(preds_df.head())

In [None]:
# Number of predictions (with a non-null score) per class label on the test set.
if preds_df.empty:
    print("No detection stats to visualize.")
else:
    per_label = preds_df.groupby("label")["score"].count().rename("detections")
    label_counts = per_label.sort_values(ascending=False).reset_index()
    display(label_counts.set_index("label"))

    plt.figure(figsize=(8, 4))
    counts_ax = sns.barplot(
        data=label_counts,
        x="label",
        y="detections",
        hue="label",
        palette="mako",
        dodge=False,
    )
    # hue duplicates the x variable purely for colouring, so the legend is redundant.
    counts_legend = counts_ax.get_legend()
    if counts_legend is not None:
        counts_legend.remove()
    plt.xticks(rotation=45, ha="right")
    plt.ylabel("Number of detections")
    plt.title("Detections per class (test set predictions)")
    plt.tight_layout()

In [None]:
import json

# Training-data coverage per class, as a quick sanity check next to the test
# metrics: for every declared category, count the distinct training images that
# contain at least one annotation of it.
# The processed-data version is taken from the model directory name (e.g. "v2").
DATA_VERSION = MODEL_DIR.name or "v2"
train_coco_path = PROJECT_ROOT / "data" / "processed" / DATA_VERSION / "train.json"
print(f"Training dataset stats from: {train_coco_path}")

train_counts_df = pd.DataFrame(columns=["class", "image_count"])
if not train_coco_path.exists():
    print("Training annotations not available locally.")
else:
    # JSON is UTF-8 by specification. Pass the encoding explicitly so that
    # non-ASCII class names decode the same way under any platform locale.
    with open(train_coco_path, encoding="utf-8") as f:
        train_coco = json.load(f)

    class_lookup = {cat["id"]: cat["name"] for cat in train_coco.get("categories", [])}
    # Seed every declared category with 0 so classes without annotations still show up.
    counts = {cid: 0 for cid in class_lookup}

    ann_df = pd.DataFrame(train_coco.get("annotations", []))
    if not ann_df.empty:
        # Count distinct images rather than annotations: an image with several
        # boxes of the same class contributes once.
        unique_counts = ann_df.groupby("category_id")["image_id"].nunique().to_dict()
        counts.update(unique_counts)
        # Annotations whose category is not declared cannot be named or plotted.
        # Report them instead of dropping them silently.
        undeclared_ids = sorted(set(unique_counts) - set(class_lookup))
        if undeclared_ids:
            print(f"Ignoring annotations with undeclared category ids: {undeclared_ids}")

    if not class_lookup:
        print("No categories found in training annotations.")
    else:
        ordered_ids = sorted(class_lookup)
        train_counts_df = (
            pd.DataFrame(
                {
                    "class": [class_lookup[cid] for cid in ordered_ids],
                    "image_count": [counts[cid] for cid in ordered_ids],
                }
            )
            # A stable sort keeps ties in category-id order, so the table is deterministic.
            .sort_values("image_count", ascending=False, kind="stable")
            .reset_index(drop=True)
        )
        display(train_counts_df.set_index("class"))

        plt.figure(figsize=(11, 4))
        sns.barplot(
            data=train_counts_df,
            x="class",
            y="image_count",
            hue="class",
            dodge=False,
            legend=False,
            palette="viridis",
        )
        plt.xticks(rotation=45, ha="right")
        plt.ylabel("Unique training images")
        plt.xlabel("Class")
        plt.title("Training images per class")
        plt.tight_layout()

In [None]:
def draw_detections(
    ax, payload: dict, color_cycle: Iterable[tuple[float, float, float]] | None = None
) -> None:
    """Draw an image with its predicted bounding boxes and labels on a Matplotlib axis.

    Args:
        ax: Matplotlib axis to draw on. Its ticks and frame are hidden.
        payload: Mapping with ``"image_path"`` (``Path`` or ``str``) and the
            parallel sequences ``"boxes"`` (``(x0, y0, x1, y1)`` corners in the
            image's pixel frame), ``"scores"`` and ``"labels"``. If the sequences
            differ in length, the extra entries are ignored.
        color_cycle: Optional RGB colours, reused cyclically across the boxes.
            Any iterable is accepted, including generators. When omitted or empty,
            a ``tab10`` palette sized to the number of boxes is used.
    """
    image_path = Path(payload["image_path"])
    # Close the file handle promptly. ``convert`` loads the pixel data and
    # returns a new image, which stays valid after the ``with`` block exits.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    ax.imshow(image)
    ax.set_axis_off()

    boxes = payload["boxes"]
    scores = payload["scores"]
    labels = payload["labels"]

    # Materialise the colours so ``len`` and indexing work for any iterable. Fall
    # back to the default palette instead of failing on an empty sequence.
    colors = list(color_cycle) if color_cycle is not None else []
    if not colors:
        colors = sns.color_palette("tab10", n_colors=max(1, len(boxes)))

    for idx, (box, score, label) in enumerate(zip(boxes, scores, labels, strict=False)):
        x0, y0, x1, y1 = box
        width = x1 - x0
        height = y1 - y0
        color = colors[idx % len(colors)]
        rect = patches.Rectangle(
            (x0, y0), width, height, linewidth=2, edgecolor=color, facecolor="none"
        )
        ax.add_patch(rect)
        # Place the caption just above the box, clamped at the image's top edge.
        ax.text(
            x0,
            max(y0 - 2, 0),
            f"{label} ({score:.2f})",
            color="white",
            fontsize=10,
            ha="left",
            va="bottom",
            bbox=dict(facecolor=color, alpha=0.6, edgecolor="none", pad=2),
        )
    ax.set_title(image_path.name)

In [None]:
# Visualise model predictions on the first few test images (sorted by filename).
visual_payloads = []
example_images = sorted(TEST_IMAGES_DIR.glob("*.jpg"))[:3]
if not example_images:
    print("No test images found for visualization.")
else:
    visual_payloads = eval_utils.get_visual_predictions(
        model=model,
        processor=processor,
        id2label=id2label,
        image_paths=example_images,
        score_threshold=VISUAL_SCORE_THRESHOLD,
    )
    if not visual_payloads:
        # plt.subplots(1, 0) raises ValueError, so stop here with a clear message.
        print("No visual predictions returned for the selected test images.")
    else:
        cols = len(visual_payloads)
        # squeeze=False always returns a 2-D array of axes, so a single image
        # needs no special case.
        fig, axes = plt.subplots(1, cols, figsize=(6 * cols, 6), squeeze=False)
        for ax, payload in zip(axes[0], visual_payloads, strict=False):
            draw_detections(ax, payload)
        plt.suptitle("Sample predictions on test images")
        plt.tight_layout()

In [None]:
# One row per box drawn in the sample visualisations, for closer inspection.
if example_images and visual_payloads:
    detailed_rows = []
    for payload in visual_payloads:
        image_name = Path(payload["image_path"]).name
        for box, score, label in zip(
            payload["boxes"],
            payload["scores"],
            payload["labels"],
            strict=False,
        ):
            detailed_rows.append(
                {
                    "image": image_name,
                    "label": label,
                    "score": round(score, 3),
                    "x_min": round(box[0], 1),
                    "y_min": round(box[1], 1),
                    "x_max": round(box[2], 1),
                    "y_max": round(box[3], 1),
                }
            )
    if detailed_rows:
        # A bare expression nested inside an ``if`` is not the cell's final
        # top-level expression, so Jupyter would never render it. Display it
        # explicitly.
        display(pd.DataFrame(detailed_rows))
    else:
        print("No bounding boxes above the visualization threshold.")
else:
    print("Skipping detail table; no visualizations available.")