In [None]:
# | default_exp utils/bounding_boxes

# Imports

In [None]:
# | export


from typing import Literal

import torch
from monai.data.box_utils import box_iou
from scipy.optimize import linear_sum_assignment

In [None]:
from monai.data.box_utils import convert_box_to_standard_mode

# Functions

### Helper

In [None]:
# | export


def sort_by_first_column_descending(tensor: torch.Tensor) -> torch.Tensor:
    """Reorder the rows of `tensor` so that the values in its first column are in descending order.

    Args:
        tensor: A tensor with at least two dimensions; rows are indexed along dim 0 and column 0 is the sort key.

    Returns:
        A tensor containing the same rows as `tensor`, ordered by decreasing value of `tensor[:, 0]`.
    """
    sort_key = tensor[:, 0]
    row_order = torch.argsort(sort_key, descending=True)
    return tensor[row_order]

### Main

In [None]:
# | export


def get_tps_fps_fns(
    pred_bboxes: list[torch.Tensor],
    pred_confidence_scores: list[torch.Tensor],
    target_bboxes: list[torch.Tensor],
    iou_threshold: float,
    matching_method: Literal["coco", "hungarian"] = "coco",
    min_confidence_threshold: float = 0.0,
    max_bboxes_per_image: int | None = None,
    return_intermediate_counts: bool = False,
) -> (
    tuple[set[tuple[int, int, int]], set[tuple[int, int]], set[tuple[int, int]]]
    | tuple[set, set, set, list[tuple[int, int, int]]]
):
    """Given predicted and target bounding boxes, their confidence scores, and an IOU threshold, get a matching of
    true positives, and a set of false positives and false negatives.

    Args:
        pred_bboxes: A list of length B containing tensors of shape (NP, 4) or (NP, 6) containing the predicted bounding
            box parameters in xyxy or xyzxyz format.
        pred_confidence_scores: A list of length B containing tensors of shape (NP,) containing the predicted confidence
            scores for the corresponding bounding boxes.
        target_bboxes: A list of length B containing tensors of shape (NT, 4) or (NT, 6) containing the target bounding
            box parameters in xyxy or xyzxyz format.
        iou_threshold: The IOU threshold above which a predicted box is considered a match for a target box.
        matching_method: The method to use for matching predicted boxes to target boxes. 'coco' implements the greedy
            matching algorithm used in the COCO dataset. 'hungarian' implements the Hungarian algorithm for optimal
            matching. Note that 'hungarian' is more computationally expensive and may not scale well to large numbers of
            boxes.
        min_confidence_threshold: Minimum confidence score for a predicted box to be considered for matching.
        max_bboxes_per_image: If not None, consider only the top K predicted boxes per image based on confidence scores.
        return_intermediate_counts: Whether to return intermediate counts of true positives, false positives and false
            negatives after each prediction is considered. Useful for plotting precision-recall curves.

    Returns:
        The first set contains tuples of (b, p, t) where b is the batch index, p is the index of the predicted box
        and t is the index of the matched target box. The second set contains tuples of (b, p) where b is the batch
        index and p is the index of the false positive predicted box. The third set contains tuples of (b, t) where b
        is the batch index and t is the index of the false negative target box.
        If `return_intermediate_counts` is True, also returns a list of tuples of (TP, FP, FN) counts after each
        prediction.
    """
    assert (
        len(pred_bboxes) == len(pred_confidence_scores) == len(target_bboxes)
    ), "Batch size must be the same for all inputs"
    assert matching_method in ["coco", "hungarian"], "matching_method must be either 'coco' or 'hungarian'"

    B = len(pred_bboxes)

    # Join all confidence scores and keep track of which batch and which box they correspond to
    pred_confidence_scores_temp = []
    for b in range(B):
        _batch_index = torch.full_like(pred_confidence_scores[b], float(b))
        _offset_index = torch.arange(
            len(pred_confidence_scores[b]),
            device=pred_confidence_scores[b].device,
            dtype=pred_confidence_scores[b].dtype,
        )
        _confidence_scores = torch.stack([pred_confidence_scores[b], _batch_index, _offset_index], dim=-1)
        if max_bboxes_per_image is not None and len(_confidence_scores) > max_bboxes_per_image:
            _confidence_scores = sort_by_first_column_descending(_confidence_scores)[:max_bboxes_per_image]
        pred_confidence_scores_temp.append(_confidence_scores)
    pred_confidence_scores = torch.cat(pred_confidence_scores_temp, dim=0)
    pred_confidence_scores = pred_confidence_scores[torch.argsort(pred_confidence_scores[:, 0], descending=True)]
    del pred_confidence_scores_temp

    # Calculate IOUs between all predicted and target boxes
    ious = []
    hungarian_matchings = []
    for b in range(B):
        _ious = box_iou(pred_bboxes[b], target_bboxes[b])
        ious.append(_ious)
        # (NP, NT), where NP is number of predicted boxes and NT is number of target boxes

        if matching_method == "hungarian" and _ious.numel() > 0:
            # Calculate optimal matching using Hungarian algorithm
            _matching = linear_sum_assignment(-_ious.cpu().numpy())
            hungarian_matchings.append(_matching)

    matched_target_indices = [set() for _ in range(B)]
    tps, fps, fns = set(), set(), {(b, i) for b in range(B) for i in range(len(target_bboxes[b]))}

    # In descending order, update tp, fp, fn and calculate precision and recall at each step
    intermediate_counts = []
    for confidence_score, b, pred_offset in pred_confidence_scores:
        if confidence_score.item() < min_confidence_threshold:
            # Do not consider this and following predictions if confidence score is below threshold
            break

        b, pred_offset = int(b.item()), int(pred_offset.item())

        if matching_method == "coco":
            # COCO-style greedy matching
            pred_ious: torch.Tensor = ious[b][pred_offset].clone()
            # (NT,)

            if pred_ious.numel() > 0:
                if matched_target_indices[b]:
                    # Exclude already matched target boxes
                    pred_ious[list(matched_target_indices[b])] = -1.0
                # Exclude target boxes below IOU threshold
                pred_ious[pred_ious < iou_threshold] = -1.0

            if pred_ious.numel() == 0 or pred_ious.amax() < 0.0:
                # No valid target box to match with
                fps.add((b, pred_offset))
            else:
                target_offset = pred_ious.argmax().item()
                matched_target_indices[b].add(target_offset)
                tps.add((b, pred_offset, target_offset))
                fns.discard((b, target_offset))
        else:
            matched_pred_offsets, matched_target_offsets = hungarian_matchings[b]
            if pred_offset in matched_pred_offsets:
                target_offset_index = (matched_pred_offsets == pred_offset).nonzero()[0].item()
                target_offset = matched_target_offsets[target_offset_index]
                if ious[b][pred_offset, target_offset] >= iou_threshold:
                    tps.add((b, pred_offset, target_offset))
                    fns.discard((b, target_offset))
                else:
                    fps.add((b, pred_offset))
            else:
                fps.add((b, pred_offset))

        if return_intermediate_counts:
            intermediate_counts.append((len(tps), len(fps), len(fns)))

    if return_intermediate_counts:
        return tps, fps, fns, intermediate_counts
    return tps, fps, fns

In [None]:
# Random predicted and target boxes. The RNG is seeded so that re-running the notebook yields the same boxes (the
# next cell reuses them). Boxes are sampled in (center, size) "cccwhd" mode and converted to the standard xyzxyz
# mode expected by `get_tps_fps_fns`.
torch.manual_seed(0)

pred_bboxes = [convert_box_to_standard_mode(torch.rand(i + 10, 6) * 128, "cccwhd") for i in range(5)]
pred_confidence_scores = [torch.rand(i + 10) for i in range(5)]
target_bboxes = [convert_box_to_standard_mode(torch.rand(i + 1 + 10 * (i % 2), 6) * 128, "cccwhd") for i in range(5)]

print([x.shape for x in pred_bboxes])
print([x.shape for x in target_bboxes])

get_tps_fps_fns(
    pred_bboxes,
    pred_confidence_scores,
    target_bboxes,
    iou_threshold=0.1,
    matching_method="coco",
    return_intermediate_counts=True,
)

[torch.Size([10, 6]), torch.Size([11, 6]), torch.Size([12, 6]), torch.Size([13, 6]), torch.Size([14, 6])]
[torch.Size([1, 6]), torch.Size([12, 6]), torch.Size([3, 6]), torch.Size([14, 6]), torch.Size([5, 6])]



[1m([0m
    [1m{[0m
        [1m([0m[1;36m4[0m, [1;36m9[0m, [1;36m4[0m[1m)[0m,
        [1m([0m[1;36m1[0m, [1;36m5[0m, [1;36m1[0m[1m)[0m,
        [1m([0m[1;36m4[0m, [1;36m11[0m, [1;36m3[0m[1m)[0m,
        [1m([0m[1;36m3[0m, [1;36m1[0m, [1;36m10[0m[1m)[0m,
        [1m([0m[1;36m3[0m, [1;36m7[0m, [1;36m13[0m[1m)[0m,
        [1m([0m[1;36m4[0m, [1;36m3[0m, [1;36m1[0m[1m)[0m,
        [1m([0m[1;36m1[0m, [1;36m0[0m, [1;36m10[0m[1m)[0m,
        [1m([0m[1;36m3[0m, [1;36m2[0m, [1;36m4[0m[1m)[0m,
        [1m([0m[1;36m2[0m, [1;36m2[0m, [1;36m1[0m[1m)[0m,
        [1m([0m[1;36m3[0m, [1;36m8[0m, [1;36m5[0m[1m)[0m,
        [1m([0m[1;36m1[0m, [1;36m6[0m, [1;36m4[0m[1m)[0m,
        [1m([0m[1;36m3[0m, [1;36m0[0m, [1;36m1[0m[1m)[0m,
        [1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m9[0m[1m)[0m,
        [1m([0m[1;36m1[0m, [1;36m10[0m, [1;36m5[0m[1m)[0m,
        [1m([0m

In [None]:
# Same boxes as in the previous cell, matched with the Hungarian algorithm (optimal assignment) instead of greedy
# COCO-style matching
get_tps_fps_fns(
    pred_bboxes=pred_bboxes,
    pred_confidence_scores=pred_confidence_scores,
    target_bboxes=target_bboxes,
    matching_method="hungarian",
    iou_threshold=0.1,
    return_intermediate_counts=True,
)


[1m([0m
    [1m{[0m
        [1m([0m[1;36m1[0m, [1;36m5[0m, [1;36m1[0m[1m)[0m,
        [1m([0m[1;36m3[0m, [1;36m7[0m, [1;36m13[0m[1m)[0m,
        [1m([0m[1;36m4[0m, [1;36m3[0m, [1;36m1[0m[1m)[0m,
        [1m([0m[1;36m4[0m, [1;36m7[0m, [1;36m4[0m[1m)[0m,
        [1m([0m[1;36m1[0m, [1;36m0[0m, [1;36m10[0m[1m)[0m,
        [1m([0m[1;36m3[0m, [1;36m1[0m, [1;36m9[0m[1m)[0m,
        [1m([0m[1;36m2[0m, [1;36m2[0m, [1;36m1[0m[1m)[0m,
        [1m([0m[1;36m4[0m, [1;36m1[0m, [1;36m3[0m[1m)[0m,
        [1m([0m[1;36m3[0m, [1;36m8[0m, [1;36m5[0m[1m)[0m,
        [1m([0m[1;36m3[0m, [1;36m2[0m, [1;36m10[0m[1m)[0m,
        [1m([0m[1;36m1[0m, [1;36m6[0m, [1;36m4[0m[1m)[0m,
        [1m([0m[1;36m3[0m, [1;36m0[0m, [1;36m1[0m[1m)[0m,
        [1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m9[0m[1m)[0m,
        [1m([0m[1;36m1[0m, [1;36m10[0m, [1;36m5[0m[1m)[0m,
        [1m([0m[

# nbdev

In [None]:
# Export the cells tagged `# | export` to the module named by the `# | default_exp` directive at the top of this notebook
!nbdev_export