In [None]:
import sys
sys.path.insert(0, '..')

from pathlib import Path
import random

import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
from PIL import Image

from src.config.config import MainConfig
from src.data.coco_loader import COCOLoader
from src.data.preprocessed_dataset import PreprocessedDataset
from src.evaluation import load_or_compute_matching, get_identity_mapping
from src.evaluation.detection_matching import MatchedDetection, patch_centers_in_image_space, _point_in_bbox
from src.visualization.primitives import get_crop_bounds, compute_padding_info, patch_to_crop_coords

In [None]:
CONFIG_PATH = Path('config_zebra_test.yaml')  # <-- change as needed

# Load the run config, the preprocessed detections and the COCO ground truth
config = MainConfig.from_yaml(CONFIG_PATH)
dataset = PreprocessedDataset(config.output_root)
coco_loader = COCOLoader(config.coco_json_path, config.dataset_root)

# Patch geometry forwarded to every plotting helper below
target_size, patch_size = config.active_resize_size, config.active_patch_size

# Detection <-> GT pairing (per its name, the helper either loads a stored
# result under output_root or computes it — TODO confirm caching behavior)
matched = load_or_compute_matching(
    dataset,
    coco_loader,
    config.output_root,
    target_size=target_size,
    patch_size=patch_size,
    category_names=config.matching_categories,
)
n_matched = len(matched)
print(f'{n_matched} matched detections')

In [None]:
def patch_rects_in_image_space(det, img_w, img_h, target_size, patch_size):
    """Map each valid patch of a detection to a rectangle in image coordinates.

    A patch is valid when ``det.patch_mask`` is truthy at its (row, col) position
    and ``patch_to_crop_coords`` returns a rectangle for it. Positions are visited
    row by row, left to right.

    Args:
        det: detection exposing ``square_crop_bbox`` and a 2-D ``patch_mask``.
        img_w, img_h: image size in pixels, forwarded to ``get_crop_bounds``.
        target_size, patch_size: forwarded to ``patch_to_crop_coords``.

    Returns:
        list of ``(x, y, w, h)`` tuples in image coordinates; empty when the crop
        bounds have no positive width or height.
    """
    left, top, right, bottom = get_crop_bounds(det.square_crop_bbox, img_w, img_h)
    crop_width, crop_height = right - left, bottom - top
    if crop_width <= 0 or crop_height <= 0:
        return []

    n_rows, n_cols = det.patch_mask.shape
    masked_cells = (
        (row, col)
        for row in range(n_rows)
        for col in range(n_cols)
        if det.patch_mask[row, col]
    )

    image_rects = []
    for row, col in masked_cells:
        local_rect = patch_to_crop_coords(row, col, crop_width, crop_height, target_size, patch_size)
        if not local_rect:
            # no rectangle for this patch in crop space; skip it
            continue
        rx, ry, rw, rh = local_rect
        # shift from crop-local to image coordinates
        image_rects.append((left + rx, top + ry, rw, rh))
    return image_rects


def plot_detection_vs_gt(det, gt_ann, img, img_w, img_h, target_size, patch_size,
                         ax=None, title_extra="", margin=50):
    """Plot a single detection against a GT annotation on the image.

    Draws every valid patch of ``det`` as a translucent rectangle: green when the
    patch is inside the GT bbox, red otherwise. Also draws the GT bbox (lime,
    solid) and the detection bbox (dodgerblue, dashed), then zooms the view onto
    both boxes.

    Args:
        det: detection record; reads ``square_crop_bbox``, ``patch_mask``,
            ``bbox`` and ``detection_id``.
        gt_ann: GT annotation; reads ``bbox`` (``x1``/``y1``/``x2``/``y2``).
        img: image passed to ``ax.imshow``.
        img_w, img_h: image size in pixels, used to map patches to image space.
        target_size, patch_size: forwarded to the patch-geometry helpers.
        ax: Axes to draw into; a new 16x10 figure is created when None.
        title_extra: text appended to the title.
        margin: pixels of context added around the union of both boxes.

    Returns:
        float: fraction of patch centers inside the GT bbox (0.0 when there are none).
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(16, 10))

    ax.imshow(img)

    # Patch rectangles + centers
    patch_rects = patch_rects_in_image_space(det, img_w, img_h, target_size, patch_size)
    centers = patch_centers_in_image_space(
        det.square_crop_bbox, det.patch_mask, img_w, img_h, target_size, patch_size
    )
    # The reported score is defined on these centers.
    inside = [_point_in_bbox(x, y, gt_ann.bbox) for x, y in centers]

    # patch_rects_in_image_space drops patches on its own (empty crop, or no
    # rectangle from patch_to_crop_coords). The centers come from a separate
    # helper, so the two lists are not guaranteed to line up. Only reuse `inside`
    # when the counts agree. Otherwise classify each rectangle by its own
    # geometric center instead of mis-colouring (or raising IndexError).
    if len(inside) == len(patch_rects):
        rect_inside = inside
    else:
        rect_inside = [_point_in_bbox(x + w / 2, y + h / 2, gt_ann.bbox)
                       for x, y, w, h in patch_rects]

    for (x, y, w, h), is_inside in zip(patch_rects, rect_inside):
        color = (0.2, 0.8, 0.2) if is_inside else (0.8, 0.2, 0.2)
        ax.add_patch(Rectangle((x, y), w, h, facecolor=color, alpha=0.35, edgecolor='none'))

    # GT bbox (green solid)
    gx1, gy1, gx2, gy2 = gt_ann.bbox.x1, gt_ann.bbox.y1, gt_ann.bbox.x2, gt_ann.bbox.y2
    ax.add_patch(Rectangle((gx1, gy1), gx2 - gx1, gy2 - gy1,
                           fill=False, edgecolor='lime', linewidth=2.5))

    # Detection bbox (blue dashed)
    dx1, dy1, dx2, dy2 = det.bbox.int().tolist()
    ax.add_patch(Rectangle((dx1, dy1), dx2 - dx1, dy2 - dy1,
                           fill=False, edgecolor='dodgerblue', linewidth=2.5, linestyle='--'))

    n_inside = sum(inside)
    n_total = len(inside)
    score = n_inside / n_total if n_total > 0 else 0.0
    ax.set_title(
        f'score={score:.3f}  |  patches inside GT: {n_inside}/{n_total}  |  '
        f'det={det.detection_id}{title_extra}',
        fontsize=11,
    )
    # y-limits are passed (bottom, top) so the image keeps its top-down orientation
    ax.set_xlim(min(gx1, dx1) - margin, max(gx2, dx2) + margin)
    ax.set_ylim(max(gy2, dy2) + margin, min(gy1, dy1) - margin)
    ax.axis('off')
    return score


def plot_match(match: MatchedDetection, dataset, coco_loader, target_size, patch_size, figsize=(16, 10)):
    """Render one detection-to-GT match on the image of its GT annotation.

    Looks up the detection referenced by ``match`` and loads the image of the
    matched GT annotation. Draws the comparison with ``plot_detection_vs_gt``,
    with identity and viewpoint appended to the title, then adds a legend for the
    colours.

    Args:
        match: the match to display; reads ``detection_id`` and ``gt_annotation``.
        dataset: provides ``get_detection``.
        coco_loader: provides ``_images`` and ``get_image_path``.
        target_size, patch_size: forwarded to ``plot_detection_vs_gt``.
        figsize: size of the created figure.

    Returns:
        None. If the detection cannot be found, a message is printed instead.
    """
    detection = dataset.get_detection(match.detection_id)
    if detection is None:
        print(f'Detection {match.detection_id} not found')
        return

    gt_annotation = match.gt_annotation
    image_path = coco_loader.get_image_path(coco_loader._images[gt_annotation.image_uuid])
    rgb_image = Image.open(image_path).convert('RGB')
    width, height = rgb_image.size

    _, ax = plt.subplots(1, 1, figsize=figsize)
    extra = f'\nidentity={gt_annotation.individual_id}  |  viewpoint={gt_annotation.viewpoint}'
    plot_detection_vs_gt(detection, gt_annotation, rgb_image, width, height, target_size, patch_size,
                         ax=ax, title_extra=extra)

    from matplotlib.lines import Line2D

    # Legend proxies mirroring the styles used by plot_detection_vs_gt
    inside_color, outside_color = (0.2, 0.8, 0.2), (0.8, 0.2, 0.2)
    handles = [
        Line2D([0], [0], color='lime', linewidth=2.5, label='GT bbox'),
        Line2D([0], [0], color='dodgerblue', linewidth=2.5, linestyle='--', label='Detection bbox'),
        Rectangle((0, 0), 1, 1, facecolor=inside_color, alpha=0.35, label='Patch inside GT'),
        Rectangle((0, 0), 1, 1, facecolor=outside_color, alpha=0.35, label='Patch outside GT'),
    ]
    ax.legend(handles=handles, loc='upper right', fontsize=9)
    plt.tight_layout()
    plt.show()


def show_all_candidates_for_gt(gt_uuid, coco_loader, dataset, target_size, patch_size, matched=None):
    """Show ALL candidate detections for a specific GT annotation UUID.

    Each detection on the GT's image whose bbox overlaps the GT bbox is scored as
    the fraction of its patch centers inside the GT bbox. Every detection with a
    positive score is plotted in its own figure, best first. If ``matched`` pairs
    a detection with this GT, that detection is flagged in the title and framed
    in gold.

    Args:
        gt_uuid: UUID of the GT annotation to inspect.
        coco_loader: provides ``annotations``, ``_images`` and ``get_image_path``.
        dataset: provides ``get_detection`` and ``_index['detection_to_batch']``.
        target_size, patch_size: forwarded to the patch-geometry helpers.
        matched: optional list of matches, used to mark the chosen detection.

    Returns:
        None. Prints a summary and shows figures. If ``gt_uuid`` is unknown, it
        prints a message and returns early.
    """
    from src.evaluation.detection_matching import get_image_uuid_from_detection_id, _bboxes_overlap

    # Find GT annotation
    gt_ann = next((ann for ann in coco_loader.annotations if ann.uuid == gt_uuid), None)
    if gt_ann is None:
        print(f'GT annotation {gt_uuid} not found')
        return

    print(f'GT: uuid={gt_ann.uuid}  identity={gt_ann.individual_id}  viewpoint={gt_ann.viewpoint}')
    print(f'    bbox=({gt_ann.bbox.x1:.0f}, {gt_ann.bbox.y1:.0f}, {gt_ann.bbox.x2:.0f}, {gt_ann.bbox.y2:.0f})')

    # Load the image once; its decoded size drives the patch -> image mapping
    coco_image = coco_loader._images[gt_ann.image_uuid]
    img_path = coco_loader.get_image_path(coco_image)
    img = Image.open(img_path).convert('RGB')
    img_w, img_h = img.size

    # Find which detection was actually matched to this GT (if any)
    matched_det_id = None
    if matched is not None:
        matched_det_id = next(
            (mt.detection_id for mt in matched if mt.gt_annotation.uuid == gt_uuid), None
        )

    # Find all detections on this image
    all_det_ids = [
        did for did in dataset._index['detection_to_batch']
        if get_image_uuid_from_detection_id(did) == gt_ann.image_uuid
    ]
    print(f'{len(all_det_ids)} detections on this image')

    # Score each overlapping detection: fraction of its patch centers inside the GT bbox
    candidates = []
    for det_id in all_det_ids:
        det = dataset.get_detection(det_id)
        if det is None:
            continue
        if not _bboxes_overlap(det.bbox, gt_ann.bbox):
            continue

        centers = patch_centers_in_image_space(
            det.square_crop_bbox, det.patch_mask, img_w, img_h, target_size, patch_size
        )
        if not centers:
            continue
        n_inside = sum(1 for x, y in centers if _point_in_bbox(x, y, gt_ann.bbox))
        score = n_inside / len(centers)
        if score > 0:
            candidates.append((det, score))

    # Stable sort: ties keep detection-id iteration order
    candidates.sort(key=lambda t: t[1], reverse=True)
    print(f'{len(candidates)} candidate detections with overlap > 0\n')

    for rank, (det, score) in enumerate(candidates, start=1):
        # Guard against flagging a detection when no match exists for this GT
        is_chosen = matched_det_id is not None and det.detection_id == matched_det_id
        _, ax = plt.subplots(1, 1, figsize=(16, 10))
        plot_detection_vs_gt(det, gt_ann, img, img_w, img_h, target_size, patch_size, ax=ax,
                             title_extra=f'  |  rank #{rank}' + ('  ** MATCHED **' if is_chosen else ''))
        if is_chosen:
            # plot_detection_vs_gt calls ax.axis('off'), and Matplotlib does not
            # draw the Axes background patch (ax.patch) or spines while the axis
            # is off. Styling ax.patch would therefore be invisible, so draw an
            # explicit gold frame in axes coordinates instead.
            ax.add_patch(Rectangle((0, 0), 1, 1, transform=ax.transAxes, fill=False,
                                   edgecolor='gold', linewidth=4, clip_on=False, zorder=10))
        plt.tight_layout()
        plt.show()

In [None]:
# Random match. Leave RANDOM_SEED as None for a fresh pick on every run, or set it
# to an int to get the same pick again after a kernel restart.
RANDOM_SEED = None  # <-- change as needed
if not matched:
    raise ValueError('No matched detections to sample from')
# A dedicated Random instance leaves the global `random` state untouched
m = random.Random(RANDOM_SEED).choice(matched)
plot_match(m, dataset, coco_loader, target_size, patch_size)

In [None]:
# Display the repr of `m`, the match picked in the random-match cell
m

In [None]:
# Inspect one particular match, chosen by its position in `matched`
IDX = 0  # <-- change as needed
plot_match(
    match=matched[IDX],
    dataset=dataset,
    coco_loader=coco_loader,
    target_size=target_size,
    patch_size=patch_size,
)

In [None]:
# Show the N_WORST lowest-scoring matches to find potential problems.
# The loop variable is deliberately not `m`, so the match picked by the
# random-match cell is not silently overwritten.
N_WORST = 5  # <-- change as needed
sorted_by_score = sorted(matched, key=lambda match: match.score)
for worst_match in sorted_by_score[:N_WORST]:
    print(worst_match.gt_annotation)
    plot_match(worst_match, dataset, coco_loader, target_size, patch_size)

In [None]:
GT_UUID = '81ab419c-6252-409c-b788-8b1e92113014'  # <-- change as needed
# Plot every detection overlapping this GT, ranked, with the matched one flagged
show_all_candidates_for_gt(
    gt_uuid=GT_UUID,
    coco_loader=coco_loader,
    dataset=dataset,
    target_size=target_size,
    patch_size=patch_size,
    matched=matched,
)

In [None]:
# Find where the top-1 candidate for this GT actually went.
# "Top-1" means rank #1 in show_all_candidates_for_gt: the detection with the
# highest fraction of patch centers inside the GT bbox. On ties, the first one in
# detection-id order wins, matching that function's stable descending sort.
GT_UUID = '6da70b81-0f70-418e-9ef5-581eb39537f0'

from src.evaluation.detection_matching import get_image_uuid_from_detection_id, _bboxes_overlap
from matplotlib.lines import Line2D


def _bbox_iou(b1, b2):
    """IoU of two boxes exposing ``x1``/``y1``/``x2``/``y2`` and ``area()``; 0.0 for empty union."""
    xi1, yi1 = max(b1.x1, b2.x1), max(b1.y1, b2.y1)
    xi2, yi2 = min(b1.x2, b2.x2), min(b1.y2, b2.y2)
    inter = max(0, xi2 - xi1) * max(0, yi2 - yi1)
    union = b1.area() + b2.area() - inter
    return inter / union if union > 0 else 0.0


def _plot_rival_match(det, orig_gt, rival_gt, iou, img, img_w, img_h, target_size, patch_size, margin=50):
    """Plot ``det`` scored against ``rival_gt``, with ``orig_gt``'s bbox overlaid in orange."""
    _, ax = plt.subplots(1, 1, figsize=(16, 10))
    plot_detection_vs_gt(det, rival_gt, img, img_w, img_h, target_size, patch_size, ax=ax,
                         title_extra=f'\nRival GT {rival_gt.uuid}  |  Original GT {orig_gt.uuid}  |  IoU={iou:.4f}')

    # Draw the original GT bbox in orange
    gx1, gy1, gx2, gy2 = orig_gt.bbox.x1, orig_gt.bbox.y1, orig_gt.bbox.x2, orig_gt.bbox.y2
    ax.add_patch(Rectangle((gx1, gy1), gx2 - gx1, gy2 - gy1,
                           fill=False, edgecolor='orange', linewidth=2.5, linestyle=':'))

    # Patch colours in this plot are relative to the rival GT (the one scored above)
    legend_elements = [
        Line2D([0], [0], color='lime', linewidth=2.5, label=f'Rival GT {rival_gt.uuid[:8]}...'),
        Line2D([0], [0], color='orange', linewidth=2.5, linestyle=':', label=f'Original GT {orig_gt.uuid[:8]}...'),
        Line2D([0], [0], color='dodgerblue', linewidth=2.5, linestyle='--', label='Detection bbox'),
        Rectangle((0, 0), 1, 1, facecolor=(0.2, 0.8, 0.2), alpha=0.35, label='Patch inside rival GT'),
        Rectangle((0, 0), 1, 1, facecolor=(0.8, 0.2, 0.2), alpha=0.35, label='Patch outside rival GT'),
    ]
    ax.legend(handles=legend_elements, loc='upper right', fontsize=9)

    # plot_detection_vs_gt zooms onto det + rival GT; widen to include the original GT too
    all_x = [gx1, gx2, rival_gt.bbox.x1, rival_gt.bbox.x2, det.bbox[0].item(), det.bbox[2].item()]
    all_y = [gy1, gy2, rival_gt.bbox.y1, rival_gt.bbox.y2, det.bbox[1].item(), det.bbox[3].item()]
    ax.set_xlim(min(all_x) - margin, max(all_x) + margin)
    ax.set_ylim(max(all_y) + margin, min(all_y) - margin)

    plt.tight_layout()
    plt.show()


gt_ann = next((a for a in coco_loader.annotations if a.uuid == GT_UUID), None)
if gt_ann is None:
    raise KeyError(f'GT annotation {GT_UUID} not found')
coco_image = coco_loader._images[gt_ann.image_uuid]

# Decode the image up front and use its real size, as show_all_candidates_for_gt
# does, so the patch centers (and therefore the ranking) agree with that view.
img = Image.open(coco_loader.get_image_path(coco_image)).convert('RGB')
img_w, img_h = img.size

all_det_ids = [
    did for did in dataset._index['detection_to_batch']
    if get_image_uuid_from_detection_id(did) == gt_ann.image_uuid
]

# Find best candidate; strict '>' keeps the earliest detection on ties
best_det_id, best_n, best_score = None, 0, 0.0
for det_id in all_det_ids:
    det = dataset.get_detection(det_id)
    if det is None or not _bboxes_overlap(det.bbox, gt_ann.bbox):
        continue
    centers = patch_centers_in_image_space(
        det.square_crop_bbox, det.patch_mask, img_w, img_h, target_size, patch_size
    )
    if not centers:
        continue
    n_inside = sum(1 for x, y in centers if _point_in_bbox(x, y, gt_ann.bbox))
    cand_score = n_inside / len(centers)
    if cand_score > best_score:
        best_det_id, best_n, best_score = det_id, n_inside, cand_score

if best_det_id is None:
    print(f"No detection on this image has a patch center inside GT {GT_UUID}")
else:
    print(f"Top candidate: {best_det_id} with {best_n} patches inside GT {GT_UUID} (score={best_score:.3f})")

    # Which GT was this detection actually matched to? `next` avoids clobbering the global `m`.
    rival_match = next((mt for mt in matched if mt.detection_id == best_det_id), None)
    if rival_match is None:
        print("This detection was NOT matched to any GT (not in matched list)")
    else:
        rival_gt = rival_match.gt_annotation
        print(f"\nIt was matched to GT {rival_gt.uuid} (score={rival_match.score})")
        print(f"  identity={rival_gt.individual_id}  viewpoint={rival_gt.viewpoint}")
        print(f"  bbox=({rival_gt.bbox.x1:.0f}, {rival_gt.bbox.y1:.0f}, {rival_gt.bbox.x2:.0f}, {rival_gt.bbox.y2:.0f})")

        # IoU between the two GT bboxes
        iou = _bbox_iou(gt_ann.bbox, rival_gt.bbox)
        print(f"\n  IoU between GT {GT_UUID} and GT {rival_gt.uuid}: {iou:.4f}")

        # Plot the match with BOTH GT bboxes
        _plot_rival_match(dataset.get_detection(best_det_id), gt_ann, rival_gt, iou,
                          img, img_w, img_h, target_size, patch_size)