In [None]:
# | default_exp metrics/detection

# Imports

In [None]:
# | export


from typing import Literal

import torch
from torchmetrics import Metric

from vision_architectures.utils.bounding_boxes import _IndexedConfidenceScore, get_tps_fps_fns

In [None]:
import torch.distributed as dist
import torch.multiprocessing as mp
from monai.data.box_utils import convert_box_to_standard_mode

# Mean Average Precision and Mean Average Recall

### Direct function

In [None]:
# | export


def mean_average_precision_mean_average_recall(
    pred_bboxes: list[torch.Tensor],
    pred_confidence_scores: list[torch.Tensor],
    target_bboxes: list[torch.Tensor],
    target_classes: list[torch.Tensor],
    iou_thresholds: list[float] = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95],
    average_precision_num_points: int = 101,
    min_confidence_threshold: float = 0.0,
    max_bboxes_per_image: int | None = 100,
    return_intermediates: bool = False,
) -> tuple[float, float] | tuple[float, float, dict[float, dict[int, float]], dict[float, dict[int, float]]]:
    """Calculate the COCO mean average precision (mAP) for object detection.

    Args:
        pred_bboxes: A list of length B containing tensors of shape (NP, 4) or (NP, 6) containing the predicted bounding
            box parameters in xyxy or xyzxyz format.
        pred_confidence_scores: A list of length B containing tensors of shape (NP, 1+num_classes) containing the
            predicted confidence scores for each class. Note that the first column corresponds to the "no-object" class,
            and bounding boxes which fall in this category are ignored.
        target_bboxes: A list of length B containing tensors of shape (NT, 4) or (NT, 6) containing the target bounding
            box parameters in xyxy or xyzxyz format.
        target_classes: A list of length B containing tensors of shape (NT,) containing the target class labels for the
            objects in the image.
        iou_thresholds: A list of IoU thresholds to use for calculating mAP and mAR.
        average_precision_num_points: Number of points over which to calculate average precision.
        min_confidence_threshold: Minimum confidence probability threshold to consider a prediction.
        max_bboxes_per_image: Maximum number of bounding boxes to consider per image. If more are present, only the top
            `max_bboxes_per_image` boxes based on confidence scores are considered. If set to None, all bounding boxes
            are considered.
        return_intermediates: If True, return intermediate values used to calculate mAP and mAR.

    Returns:
        The mean average precision (mAP) and mean average recall (mAR) across all classes and IoU thresholds for the
        entire dataset.
        If `return_intermediates` is True, also returns two dictionaries containing the average precision and average
        recall for each class at each IoU threshold.
    """
    # Some basic tests
    assert len(pred_bboxes) == len(pred_confidence_scores) == len(target_bboxes) == len(target_classes), (
        f"All input lists must have the same length. Got lengths: {len(pred_bboxes)}, {len(pred_confidence_scores)}, "
        f"{len(target_bboxes)}, {len(target_classes)}."
    )
    assert all(
        pred_confidence_score.ndim == 2 for pred_confidence_score in pred_confidence_scores
    ), "Each prediction confidence score input list element must be a 2D tensor."

    # Set some globaly used variables
    B = len(pred_bboxes)
    num_classes = pred_confidence_scores[0].shape[-1] - 1

    # Continue tests
    assert all(
        pred_bbox.shape[0] == pred_confidence_score.shape[0]
        for pred_bbox, pred_confidence_score in zip(pred_bboxes, pred_confidence_scores)
    ), "Each prediction input list element must have the same number of bounding boxes."
    assert all(
        pred_bbox.shape[1] == 4 or pred_bbox.shape[1] == 6 for pred_bbox in pred_bboxes
    ), "Prediction bounding boxes must have shape (NP, 4) or (NP, 6)."
    assert all(
        pred_confidence_score.shape[-1] == num_classes + 1 for pred_confidence_score in pred_confidence_scores
    ), "Prediction class probabilities must have shape (NP, 1 + num_classes)."
    assert all(
        target_bbox.shape[0] == target_class.shape[0]
        for target_bbox, target_class in zip(target_bboxes, target_classes)
    ), "Each target must have the same number of bounding boxes."
    assert all(
        target_bbox.shape[1] == 4 or target_bbox.shape[1] == 6 for target_bbox in target_bboxes
    ), "Target bounding boxes must have shape (NT, 4) or (NT, 6)."
    assert all(
        (target_class >= 0).all() and (target_class <= num_classes).all() for target_class in target_classes
    ), f"Target class labels must be between 0 and {num_classes} inclusive."

    # Split everything based on different classes.
    pred_bboxes_by_class = [[] for _ in range(num_classes)]
    pred_confidences_scores_by_class = [[] for _ in range(num_classes)]
    target_bboxes_by_class = [[] for _ in range(num_classes)]
    for b in range(B):
        pred_classes = torch.argmax(pred_confidence_scores[b], dim=-1)
        # (NP,)

        for c in range(num_classes):
            pred_classes_mask = pred_classes == (c + 1)
            # (NP,)
            target_classes_mask = target_classes[b] == (c + 1)
            # (NT,)

            pred_bboxes_by_class[c].append(pred_bboxes[b][pred_classes_mask])
            pred_confidences_scores_by_class[c].append(pred_confidence_scores[b][pred_classes_mask][:, c + 1])
            # (NP,)

            target_bboxes_by_class[c].append(target_bboxes[b][target_classes_mask])
            # (NT,)

    # Limit number of bounding boxes per image if applicable
    if max_bboxes_per_image is not None:
        for b in range(B):
            _confidence_scores = []
            for c in range(num_classes):
                if pred_bboxes_by_class[c][b].numel() > 0:
                    _confidence_scores.extend(
                        _IndexedConfidenceScore.from_batch(
                            pred_confidences_scores_by_class[c][b], batch_index=b, class_index=c
                        )
                    )
            if len(_confidence_scores) == 0:
                continue

            if len(_confidence_scores) > max_bboxes_per_image:
                _confidence_scores = sorted(_confidence_scores, reverse=True)
                topk_confidences = _confidence_scores[:max_bboxes_per_image]
                for c in range(num_classes):
                    topk_confidences_with_class = [x for x in topk_confidences if x.class_index == c]

                    pred_bboxes_by_class[c][b] = torch.stack(
                        [pred_bboxes_by_class[c][b][x.offset_index] for x in topk_confidences_with_class],
                        dim=0,
                        dtype=pred_bboxes_by_class[c][b].dtype,
                        device=pred_bboxes_by_class[c][b].device,
                    )
                    # (NP', 4) or (NP', 6)
                    pred_confidences_scores_by_class[c][b] = torch.stack(
                        [x.score for x in topk_confidences_with_class],
                        dim=0,
                        dtype=pred_confidences_scores_by_class[c][b].dtype,
                        device=pred_confidences_scores_by_class[c][b].device,
                    )
                    # (NP',)

    # For each IOU threshold, calculate average precision and average recall
    average_precisions = {}
    average_recalls = {}
    for iou_threshold in iou_thresholds:
        # For each class calculate average precision and average recall
        class_average_precisions = {}
        class_average_recalls = {}
        for c in range(num_classes):
            # If no target boxes for this class, skip it
            if all(target_bbox.numel() == 0 for target_bbox in target_bboxes_by_class[c]):
                class_average_precisions[c + 1] = float("nan")
                class_average_recalls[c + 1] = float("nan")
                continue

            _, _, _, intermediate_counts = get_tps_fps_fns(
                pred_bboxes=pred_bboxes_by_class[c],
                pred_confidence_scores=pred_confidences_scores_by_class[c],
                target_bboxes=target_bboxes_by_class[c],
                iou_threshold=iou_threshold,
                matching_method="coco",
                min_confidence_threshold=min_confidence_threshold,
                max_bboxes_per_image=None,  # As this has already been done across classes
                return_intermediate_counts=True,
            )
            intermediate_counts = torch.tensor(intermediate_counts, device=pred_bboxes[0].device, dtype=torch.float32)
            # (NC, 3) where the first column is TP, second is FP and third is FN for each prediction considered
            precisions = (intermediate_counts[:, 0] + 1e-5) / (
                intermediate_counts[:, 0] + intermediate_counts[:, 1] + 1e-5
            )
            recalls = (intermediate_counts[:, 0] + 1e-5) / (
                intermediate_counts[:, 0] + intermediate_counts[:, 2] + 1e-5
            )
            # (NC,) each

            # Precision envelope: P_interp(r) = max_{r' >= r} P(r')
            enveloped_precisions = precisions.clone()
            for i in range(len(enveloped_precisions) - 2, -1, -1):
                if enveloped_precisions[i] < enveloped_precisions[i + 1]:
                    enveloped_precisions[i] = enveloped_precisions[i + 1]

            # Calculate average precision using step-wise interpolation
            recall_samples = torch.linspace(0, 1, average_precision_num_points, device=recalls.device)
            idxs = torch.searchsorted(recalls, recall_samples, side="left")
            valid = idxs < enveloped_precisions.numel()
            enveloped_precisions_at_t = torch.zeros_like(recall_samples)
            enveloped_precisions_at_t[valid] = enveloped_precisions[idxs[valid]]
            class_average_precisions[c + 1] = enveloped_precisions_at_t.mean().item()

            # Calculate average recall i.e. maximum recall achieved at this IoU threshold
            class_average_recalls[c + 1] = recalls.max().item() if recalls.numel() > 0 else 0.0

        average_precisions[iou_threshold] = class_average_precisions
        average_recalls[iou_threshold] = class_average_recalls

    map_metric = torch.nanmean(
        torch.stack([torch.tensor(ap) for iou_aps in average_precisions.values() for ap in iou_aps.values()])
    ).item()
    mar_metric = torch.nanmean(
        torch.stack([torch.tensor(ar) for iou_ars in average_recalls.values() for ar in iou_ars.values()])
    ).item()

    if return_intermediates:
        return map_metric, mar_metric, average_precisions, average_recalls
    return map_metric, mar_metric


# Shorter public names that refer to the very same function object
map_mar = mean_average_precision_mean_average_recall
mean_average_precision_recall = map_mar

In [None]:
# Unrelated random predictions and targets: 25 images of 3D boxes, 5 foreground classes, near-zero IoU threshold

pred_bboxes = [
    convert_box_to_standard_mode(torch.rand(10 + idx, 6) * 128, "cccwhd") for idx in range(25)
]
pred_confidence_scores = [torch.rand(10 + idx, 6) for idx in range(25)]

# Odd-numbered images get 10 extra targets
target_bboxes = [
    convert_box_to_standard_mode(torch.rand(idx + 1 + 10 * (idx % 2), 6) * 128, "cccwhd") for idx in range(25)
]
target_classes = [torch.randint(1, 6, (idx + 1 + 10 * (idx % 2),)) for idx in range(25)]

print([scores.shape for scores in pred_confidence_scores])
print([labels.shape for labels in target_classes])
map_mar(
    pred_bboxes=pred_bboxes,
    pred_confidence_scores=pred_confidence_scores,
    target_bboxes=target_bboxes,
    target_classes=target_classes,
    iou_thresholds=[0.001],
    return_intermediates=True,
)

[torch.Size([10, 6]), torch.Size([11, 6]), torch.Size([12, 6]), torch.Size([13, 6]), torch.Size([14, 6]), torch.Size([15, 6]), torch.Size([16, 6]), torch.Size([17, 6]), torch.Size([18, 6]), torch.Size([19, 6]), torch.Size([20, 6]), torch.Size([21, 6]), torch.Size([22, 6]), torch.Size([23, 6]), torch.Size([24, 6]), torch.Size([25, 6]), torch.Size([26, 6]), torch.Size([27, 6]), torch.Size([28, 6]), torch.Size([29, 6]), torch.Size([30, 6]), torch.Size([31, 6]), torch.Size([32, 6]), torch.Size([33, 6]), torch.Size([34, 6])]
[torch.Size([1]), torch.Size([12]), torch.Size([3]), torch.Size([14]), torch.Size([5]), torch.Size([16]), torch.Size([7]), torch.Size([18]), torch.Size([9]), torch.Size([20]), torch.Size([11]), torch.Size([22]), torch.Size([13]), torch.Size([24]), torch.Size([15]), torch.Size([26]), torch.Size([17]), torch.Size([28]), torch.Size([19]), torch.Size([30]), torch.Size([21]), torch.Size([32]), torch.Size([23]), torch.Size([34]), torch.Size([25])]
tensor([[0.7500, 0.0000],
  


[1m([0m
    [1;36m0.3113456070423126[0m,
    [1;36m0.5027083158493042[0m,
    [1m{[0m
        [1;36m0.001[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.3091408610343933[0m,
            [1;36m2[0m: [1;36m0.265804260969162[0m,
            [1;36m3[0m: [1;36m0.349173367023468[0m,
            [1;36m4[0m: [1;36m0.274908185005188[0m,
            [1;36m5[0m: [1;36m0.35770130157470703[0m
        [1m}[0m
    [1m}[0m,
    [1m{[0m
        [1;36m0.001[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.45370370149612427[0m,
            [1;36m2[0m: [1;36m0.4637681245803833[0m,
            [1;36m3[0m: [1;36m0.5952380895614624[0m,
            [1;36m4[0m: [1;36m0.5058823823928833[0m,
            [1;36m5[0m: [1;36m0.49494948983192444[0m
        [1m}[0m
    [1m}[0m
[1m)[0m

In [None]:
# Targets are the predictions shifted by 0.5 with the predicted (argmax) labels, so precision and recall should be high

pred_bboxes = [convert_box_to_standard_mode(torch.rand(idx, 6) * 128, "cccwhd") for idx in range(25)]
pred_confidence_scores = [torch.rand(idx, 6) for idx in range(25)]

target_bboxes = [boxes + 0.5 for boxes in pred_bboxes]
target_classes = [scores.argmax(dim=-1) for scores in pred_confidence_scores]

print([scores.shape for scores in pred_confidence_scores])
print([labels.shape for labels in target_classes])
map_mar(
    pred_bboxes=pred_bboxes,
    pred_confidence_scores=pred_confidence_scores,
    target_bboxes=target_bboxes,
    target_classes=target_classes,
    return_intermediates=True,
)

[torch.Size([0, 6]), torch.Size([1, 6]), torch.Size([2, 6]), torch.Size([3, 6]), torch.Size([4, 6]), torch.Size([5, 6]), torch.Size([6, 6]), torch.Size([7, 6]), torch.Size([8, 6]), torch.Size([9, 6]), torch.Size([10, 6]), torch.Size([11, 6]), torch.Size([12, 6]), torch.Size([13, 6]), torch.Size([14, 6]), torch.Size([15, 6]), torch.Size([16, 6]), torch.Size([17, 6]), torch.Size([18, 6]), torch.Size([19, 6]), torch.Size([20, 6]), torch.Size([21, 6]), torch.Size([22, 6]), torch.Size([23, 6]), torch.Size([24, 6])]
[torch.Size([0]), torch.Size([1]), torch.Size([2]), torch.Size([3]), torch.Size([4]), torch.Size([5]), torch.Size([6]), torch.Size([7]), torch.Size([8]), torch.Size([9]), torch.Size([10]), torch.Size([11]), torch.Size([12]), torch.Size([13]), torch.Size([14]), torch.Size([15]), torch.Size([16]), torch.Size([17]), torch.Size([18]), torch.Size([19]), torch.Size([20]), torch.Size([21]), torch.Size([22]), torch.Size([23]), torch.Size([24])]
tensor([[1.0000, 0.0213],
        [1.0000, 


[1m([0m
    [1;36m0.7450447082519531[0m,
    [1;36m0.8355322480201721[0m,
    [1m{[0m
        [1;36m0.5[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.9559721946716309[0m,
            [1;36m2[0m: [1;36m0.9739857316017151[0m,
            [1;36m3[0m: [1;36m0.8929669260978699[0m,
            [1;36m4[0m: [1;36m0.8727788925170898[0m,
            [1;36m5[0m: [1;36m0.9252988696098328[0m
        [1m}[0m,
        [1;36m0.55[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.9559721946716309[0m,
            [1;36m2[0m: [1;36m0.9739857316017151[0m,
            [1;36m3[0m: [1;36m0.8929669260978699[0m,
            [1;36m4[0m: [1;36m0.856163501739502[0m,
            [1;36m5[0m: [1;36m0.9252988696098328[0m
        [1m}[0m,
        [1;36m0.6[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.9559721946716309[0m,
            [1;36m2[0m: [1;36m0.9739857316017151[0m,
            [1;36m3[0m: [1;36m0.8929669260978699[0m,
            [1;36m4[0m: 

In [None]:
# Every prediction is also a target (same box, argmax label); image idx additionally gets idx random extra targets

pred_bboxes = [convert_box_to_standard_mode(torch.rand(idx + 1, 6) * 128, "cccwhd") for idx in range(25)]
pred_confidence_scores = [torch.rand(idx + 1, 6) for idx in range(25)]

target_bboxes = [
    torch.cat([boxes, convert_box_to_standard_mode(torch.rand(idx, 6) * 128, "cccwhd")])
    for idx, boxes in enumerate(pred_bboxes)
]
target_classes = [
    torch.cat([scores.argmax(dim=-1), torch.randint(1, 6, (idx,))])
    for idx, scores in enumerate(pred_confidence_scores)
]

print([scores.shape for scores in pred_confidence_scores])
print([labels.shape for labels in target_classes])
map_mar(
    pred_bboxes=pred_bboxes,
    pred_confidence_scores=pred_confidence_scores,
    target_bboxes=target_bboxes,
    target_classes=target_classes,
    return_intermediates=True,
)

[torch.Size([1, 6]), torch.Size([2, 6]), torch.Size([3, 6]), torch.Size([4, 6]), torch.Size([5, 6]), torch.Size([6, 6]), torch.Size([7, 6]), torch.Size([8, 6]), torch.Size([9, 6]), torch.Size([10, 6]), torch.Size([11, 6]), torch.Size([12, 6]), torch.Size([13, 6]), torch.Size([14, 6]), torch.Size([15, 6]), torch.Size([16, 6]), torch.Size([17, 6]), torch.Size([18, 6]), torch.Size([19, 6]), torch.Size([20, 6]), torch.Size([21, 6]), torch.Size([22, 6]), torch.Size([23, 6]), torch.Size([24, 6]), torch.Size([25, 6])]
[torch.Size([1]), torch.Size([3]), torch.Size([5]), torch.Size([7]), torch.Size([9]), torch.Size([11]), torch.Size([13]), torch.Size([15]), torch.Size([17]), torch.Size([19]), torch.Size([21]), torch.Size([23]), torch.Size([25]), torch.Size([27]), torch.Size([29]), torch.Size([31]), torch.Size([33]), torch.Size([35]), torch.Size([37]), torch.Size([39]), torch.Size([41]), torch.Size([43]), torch.Size([45]), torch.Size([47]), torch.Size([49])]
tensor([[1.0000, 0.0097],
        [1.


[1m([0m
    [1;36m0.4792078733444214[0m,
    [1;36m0.4800601601600647[0m,
    [1m{[0m
        [1;36m0.5[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.49504950642585754[0m,
            [1;36m2[0m: [1;36m0.5346534848213196[0m,
            [1;36m3[0m: [1;36m0.4356435537338257[0m,
            [1;36m4[0m: [1;36m0.48514851927757263[0m,
            [1;36m5[0m: [1;36m0.4455445408821106[0m
        [1m}[0m,
        [1;36m0.55[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.49504950642585754[0m,
            [1;36m2[0m: [1;36m0.5346534848213196[0m,
            [1;36m3[0m: [1;36m0.4356435537338257[0m,
            [1;36m4[0m: [1;36m0.48514851927757263[0m,
            [1;36m5[0m: [1;36m0.4455445408821106[0m
        [1m}[0m,
        [1;36m0.6[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.49504950642585754[0m,
            [1;36m2[0m: [1;36m0.5346534848213196[0m,
            [1;36m3[0m: [1;36m0.4356435537338257[0m,
            [1;36m4

In [None]:
# Targets are the first idx + 1 predictions of each image (same box, argmax label), so recall should be high

pred_bboxes = [convert_box_to_standard_mode(torch.rand(idx + 10, 6) * 128, "cccwhd") for idx in range(25)]
pred_confidence_scores = [torch.rand(idx + 10, 6) for idx in range(25)]

target_bboxes = [boxes[: idx + 1] for idx, boxes in enumerate(pred_bboxes)]
target_classes = [scores[: idx + 1].argmax(dim=-1) for idx, scores in enumerate(pred_confidence_scores)]

print([scores.shape for scores in pred_confidence_scores])
print([labels.shape for labels in target_classes])
map_mar(
    pred_bboxes=pred_bboxes,
    pred_confidence_scores=pred_confidence_scores,
    target_bboxes=target_bboxes,
    target_classes=target_classes,
    return_intermediates=True,
)

[torch.Size([10, 6]), torch.Size([11, 6]), torch.Size([12, 6]), torch.Size([13, 6]), torch.Size([14, 6]), torch.Size([15, 6]), torch.Size([16, 6]), torch.Size([17, 6]), torch.Size([18, 6]), torch.Size([19, 6]), torch.Size([20, 6]), torch.Size([21, 6]), torch.Size([22, 6]), torch.Size([23, 6]), torch.Size([24, 6]), torch.Size([25, 6]), torch.Size([26, 6]), torch.Size([27, 6]), torch.Size([28, 6]), torch.Size([29, 6]), torch.Size([30, 6]), torch.Size([31, 6]), torch.Size([32, 6]), torch.Size([33, 6]), torch.Size([34, 6])]
[torch.Size([1]), torch.Size([2]), torch.Size([3]), torch.Size([4]), torch.Size([5]), torch.Size([6]), torch.Size([7]), torch.Size([8]), torch.Size([9]), torch.Size([10]), torch.Size([11]), torch.Size([12]), torch.Size([13]), torch.Size([14]), torch.Size([15]), torch.Size([16]), torch.Size([17]), torch.Size([18]), torch.Size([19]), torch.Size([20]), torch.Size([21]), torch.Size([22]), torch.Size([23]), torch.Size([24]), torch.Size([25])]



[1m([0m
    [1;36m0.659028172492981[0m,
    [1;36m1.0[0m,
    [1m{[0m
        [1;36m0.5[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.5099009871482849[0m,
            [1;36m2[0m: [1;36m0.8180058598518372[0m,
            [1;36m3[0m: [1;36m0.7023004293441772[0m,
            [1;36m4[0m: [1;36m0.610049307346344[0m,
            [1;36m5[0m: [1;36m0.6548842787742615[0m
        [1m}[0m,
        [1;36m0.55[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.5099009871482849[0m,
            [1;36m2[0m: [1;36m0.8180058598518372[0m,
            [1;36m3[0m: [1;36m0.7023004293441772[0m,
            [1;36m4[0m: [1;36m0.610049307346344[0m,
            [1;36m5[0m: [1;36m0.6548842787742615[0m
        [1m}[0m,
        [1;36m0.6[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.5099009871482849[0m,
            [1;36m2[0m: [1;36m0.8180058598518372[0m,
            [1;36m3[0m: [1;36m0.7023004293441772[0m,
            [1;36m4[0m: [1;36m0.61004930

### Lightning metrics

In [None]:
# | export


class _MeanAveragePrecisionMeanAverageRecallBase(Metric):
    """Shared state handling for the COCO mean average precision (mAP) and mean average recall (mAR) metrics.

    The metric stores every prediction / target tensor passed to `update` and evaluates the functional
    `mean_average_precision_mean_average_recall` over everything accumulated so far. Subclasses implement `compute`
    by calling `run_functional` with the metric they expose.
    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0

    def __init__(
        self,
        iou_thresholds: list[float] = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95],
        average_precision_num_points: int = 101,
        min_confidence_threshold: float = 0.0,
        max_bboxes_per_image: int | None = 100,
        *args,
        **kwargs,
    ):
        """Initialize the MeanAveragePrecisionMeanAverageRecall metric.

        Args:
            iou_thresholds: A list of IoU thresholds to use for calculating mAP and mAR.
            average_precision_num_points: Number of points over which to calculate average precision.
            min_confidence_threshold: Minimum confidence score threshold to consider a prediction.
            max_bboxes_per_image: Maximum number of bounding boxes to consider per image. If more are present, only the
                top `max_bboxes_per_image` boxes based on confidence scores are considered. If set to None, all
                bounding boxes are considered.
            *args: Forwarded to the `Metric` constructor.
            **kwargs: Forwarded to the `Metric` constructor.
        """
        super().__init__(*args, **kwargs)

        # Copy the thresholds so the instance never aliases the shared default list (or a caller-owned list)
        self.iou_thresholds = list(iou_thresholds)
        self.average_precision_num_points = average_precision_num_points
        self.min_confidence_threshold = min_confidence_threshold
        self.max_bboxes_per_image = max_bboxes_per_image

        # One list entry per image; no reduction function is registered for these states
        self.add_state("pred_bboxes", [], dist_reduce_fx=None, persistent=False)
        self.add_state("pred_confidence_scores", [], dist_reduce_fx=None, persistent=False)
        self.add_state("target_bboxes", [], dist_reduce_fx=None, persistent=False)
        self.add_state("target_classes", [], dist_reduce_fx=None, persistent=False)

    def update(
        self,
        pred_bboxes: list[torch.Tensor],
        pred_confidence_scores: list[torch.Tensor],
        target_bboxes: list[torch.Tensor],
        target_classes: list[torch.Tensor],
    ):
        """Append one batch of per-image predictions and targets to the accumulated state.

        Args:
            pred_bboxes: Per-image predicted boxes, as expected by `mean_average_precision_mean_average_recall`.
            pred_confidence_scores: Per-image confidence scores, as expected by the functional.
            target_bboxes: Per-image target boxes, as expected by the functional.
            target_classes: Per-image target class labels, as expected by the functional.
        """
        self.pred_bboxes.extend(pred_bboxes)
        self.pred_confidence_scores.extend(pred_confidence_scores)
        self.target_bboxes.extend(target_bboxes)
        self.target_classes.extend(target_classes)

    def run_functional(self, return_metrics: Literal["map_only", "mar_only"]):
        """Evaluate the functional over the accumulated state and return one of its two metrics.

        Args:
            return_metrics: "map_only" to return mAP, "mar_only" to return mAR.

        Returns:
            A scalar tensor holding the requested metric, placed on the device of the first stored prediction tensor
            (so at least one image must have been accumulated).

        Raises:
            NotImplementedError: If `return_metrics` is neither "map_only" nor "mar_only".
        """
        # Reject an unsupported selector before doing the expensive evaluation
        if return_metrics not in ("map_only", "mar_only"):
            raise NotImplementedError('Only "map_only" and "mar_only" are supported.')

        map_metric, mar_metric = mean_average_precision_mean_average_recall(
            self.pred_bboxes,
            self.pred_confidence_scores,
            self.target_bboxes,
            self.target_classes,
            iou_thresholds=self.iou_thresholds,
            average_precision_num_points=self.average_precision_num_points,
            min_confidence_threshold=self.min_confidence_threshold,
            max_bboxes_per_image=self.max_bboxes_per_image,
        )
        selected_metric = map_metric if return_metrics == "map_only" else mar_metric
        return torch.tensor(selected_metric, device=self.pred_bboxes[0].device)

In [None]:
# | export


class MeanAveragePrecision(_MeanAveragePrecisionMeanAverageRecallBase):
    """COCO mean average precision (mAP) for object detection, evaluated over all accumulated images."""

    def compute(self):
        """Return the mAP half of the shared mAP / mAR evaluation."""
        return self.run_functional(return_metrics="map_only")

In [None]:
# Smoke test: feed 100 random batches of 10 images, then print the stored image count before and after reset()
test = MeanAveragePrecision(max_bboxes_per_image=100)

for _ in range(100):
    pred_bboxes = [convert_box_to_standard_mode(torch.rand(5 + idx, 6) * 128, "cccwhd") for idx in range(10)]
    pred_confidence_scores = [torch.rand(5 + idx, 4) for idx in range(10)]

    # Targets reuse the predicted boxes but carry random labels
    target_bboxes = pred_bboxes
    target_classes = [torch.randint(1, 4, (5 + idx,)) for idx in range(10)]

    map_metric = test(pred_bboxes, pred_confidence_scores, target_bboxes, target_classes)
    print(map_metric)

print(len(test.pred_bboxes))
test.reset()
print(len(test.pred_bboxes))

tensor(0.1534)
tensor(0.1436)
tensor(0.0960)
tensor(0.1073)
tensor(0.1372)
tensor(0.1314)
tensor(0.1473)
tensor(0.1393)
tensor(0.0852)
tensor(0.1629)
tensor(0.1593)
tensor(0.0826)
tensor(0.1325)
tensor(0.0982)
tensor(0.2100)
tensor(0.1371)
tensor(0.0835)
tensor(0.1193)
tensor(0.1173)
tensor(0.1142)
tensor(0.2146)
tensor(0.0750)
tensor(0.1102)
tensor(0.1462)
tensor(0.0917)
tensor(0.0990)
tensor(0.1126)
tensor(0.0715)
tensor(0.1182)
tensor(0.0591)
tensor(0.0613)
tensor(0.0848)
tensor(0.1237)
tensor(0.1612)
tensor(0.1387)
tensor(0.1147)
tensor(0.0965)
tensor(0.0440)
tensor(0.1512)
tensor(0.1822)
tensor(0.1268)
tensor(0.1262)
tensor(0.0915)
tensor(0.0932)
tensor(0.1481)
tensor(0.0896)
tensor(0.1205)
tensor(0.1107)
tensor(0.1586)
tensor(0.0814)
tensor(0.0941)
tensor(0.1501)
tensor(0.1387)
tensor(0.0993)
tensor(0.1589)
tensor(0.1150)
tensor(0.1353)
tensor(0.1593)
tensor(0.1467)
tensor(0.1103)
tensor(0.1373)
tensor(0.0846)
tensor(0.1396)
tensor(0.1176)
tensor(0.0986)
tensor(0.1200)
tensor(0.0

In [None]:
# | export


class MeanAverageRecall(_MeanAveragePrecisionMeanAverageRecallBase):
    """COCO mean average recall (mAR) for object detection, evaluated over all accumulated images."""

    def compute(self):
        """Return the mAR half of the shared mAP / mAR evaluation."""
        return self.run_functional(return_metrics="mar_only")

In [None]:
# Smoke test: feed 100 random batches of 10 images, then print the stored image count before and after reset()
test = MeanAverageRecall(max_bboxes_per_image=100)

for _ in range(100):
    pred_bboxes = [convert_box_to_standard_mode(torch.rand(i + 5, 6) * 128, "cccwhd") for i in range(10)]
    pred_confidence_scores = [torch.rand(i + 5, 4) for i in range(10)]

    # Targets reuse the predicted boxes but carry random labels
    target_bboxes = pred_bboxes
    target_classes = [torch.randint(1, 4, (i + 5,)) for i in range(10)]

    # Pass `target_bboxes` (previously assigned but unused; `pred_bboxes` was passed twice)
    mar_metric = test(
        pred_bboxes,
        pred_confidence_scores,
        target_bboxes,
        target_classes,
    )
    print(mar_metric)

print(len(test.pred_bboxes))
test.reset()
print(len(test.pred_bboxes))

tensor(0.2294)
tensor(0.2866)
tensor(0.3679)
tensor(0.2191)
tensor(0.2506)
tensor(0.2643)
tensor(0.2170)
tensor(0.2434)
tensor(0.3153)
tensor(0.2122)
tensor(0.2508)
tensor(0.2428)
tensor(0.1789)
tensor(0.2537)
tensor(0.2434)
tensor(0.2165)
tensor(0.2849)
tensor(0.2144)
tensor(0.2489)
tensor(0.2104)
tensor(0.1815)
tensor(0.2860)
tensor(0.1706)
tensor(0.2419)
tensor(0.2731)
tensor(0.2528)
tensor(0.1822)
tensor(0.1325)
tensor(0.2013)
tensor(0.1804)
tensor(0.2560)
tensor(0.2621)
tensor(0.2231)
tensor(0.2320)
tensor(0.2327)
tensor(0.3896)
tensor(0.3112)
tensor(0.2449)
tensor(0.2634)
tensor(0.2114)
tensor(0.3195)
tensor(0.2429)
tensor(0.2667)
tensor(0.2400)
tensor(0.3582)
tensor(0.2407)
tensor(0.2348)
tensor(0.2229)
tensor(0.2217)
tensor(0.3180)
tensor(0.3130)
tensor(0.1933)
tensor(0.2734)
tensor(0.2123)
tensor(0.2478)
tensor(0.2643)
tensor(0.2961)
tensor(0.2246)
tensor(0.1929)
tensor(0.2511)
tensor(0.2500)
tensor(0.2275)
tensor(0.2325)
tensor(0.2242)
tensor(0.2329)
tensor(0.3358)
tensor(0.2

In [None]:
# | export


class AveragePrecision(MeanAveragePrecision):
    """Calculate the COCO average precision (AP) for object detection at a single IoU threshold."""

    def __init__(self, iou_threshold: float, *args, **kwargs):
        """Initialize the metric for one IoU threshold.

        Args:
            iou_threshold: The single IoU threshold at which average precision is evaluated.
            *args: Forwarded to `MeanAveragePrecision` after the threshold list.
            **kwargs: Forwarded to `MeanAveragePrecision`.
        """
        # Pass the threshold list positionally: as a keyword placed before `*args`, any extra positional argument
        # would be bound to `iou_thresholds` as well and raise "got multiple values for argument".
        super().__init__([iou_threshold], *args, **kwargs)

In [None]:
# Smoke test at IoU 0.1: feed 100 random batches, then print the stored image count before and after reset()
test = AveragePrecision(iou_threshold=0.1, max_bboxes_per_image=100)

for _ in range(100):
    pred_bboxes = [convert_box_to_standard_mode(torch.rand(i + 5, 6) * 128, "cccwhd") for i in range(10)]
    pred_confidence_scores = [torch.rand(i + 5, 4) for i in range(10)]

    # Targets reuse the predicted boxes but carry random labels
    target_bboxes = pred_bboxes
    target_classes = [torch.randint(1, 4, (i + 5,)) for i in range(10)]

    # Pass `target_bboxes` (previously assigned but unused; `pred_bboxes` was passed twice)
    ap10 = test(
        pred_bboxes,
        pred_confidence_scores,
        target_bboxes,
        target_classes,
    )
    print(ap10)

print(len(test.pred_bboxes))
test.reset()
print(len(test.pred_bboxes))

tensor(0.1099)
tensor(0.2714)
tensor(0.0887)
tensor(0.1543)
tensor(0.2709)
tensor(0.3276)
tensor(0.1317)
tensor(0.1898)
tensor(0.2650)
tensor(0.2256)
tensor(0.1914)
tensor(0.1290)
tensor(0.2257)
tensor(0.1951)
tensor(0.1970)
tensor(0.1395)
tensor(0.1860)
tensor(0.1978)
tensor(0.1468)
tensor(0.1438)
tensor(0.1571)
tensor(0.0773)
tensor(0.1463)
tensor(0.1900)
tensor(0.1046)
tensor(0.1162)
tensor(0.1430)
tensor(0.1342)
tensor(0.1491)
tensor(0.2177)
tensor(0.1767)
tensor(0.2247)
tensor(0.2434)
tensor(0.1874)
tensor(0.1405)
tensor(0.1491)
tensor(0.1213)
tensor(0.1820)
tensor(0.0970)
tensor(0.1547)
tensor(0.1906)
tensor(0.1782)
tensor(0.1877)
tensor(0.1869)
tensor(0.2053)
tensor(0.1565)
tensor(0.1814)
tensor(0.1334)
tensor(0.2548)
tensor(0.1222)
tensor(0.1800)
tensor(0.1975)
tensor(0.1806)
tensor(0.2372)
tensor(0.2073)
tensor(0.1560)
tensor(0.1871)
tensor(0.1311)
tensor(0.1403)
tensor(0.1828)
tensor(0.1630)
tensor(0.1592)
tensor(0.1912)
tensor(0.1644)
tensor(0.1483)
tensor(0.2777)
tensor(0.1

In [None]:
# | export


class AverageRecall(MeanAverageRecall):
    """Calculate the COCO average recall (AR) for object detection at a single IoU threshold."""

    def __init__(self, iou_threshold: float, *args, **kwargs):
        """Initialize the metric for one IoU threshold.

        Args:
            iou_threshold: The single IoU threshold at which average recall is evaluated.
            *args: Forwarded to `MeanAverageRecall` after the threshold list.
            **kwargs: Forwarded to `MeanAverageRecall`.
        """
        # Pass the threshold list positionally: as a keyword placed before `*args`, any extra positional argument
        # would be bound to `iou_thresholds` as well and raise "got multiple values for argument".
        super().__init__([iou_threshold], *args, **kwargs)

In [None]:
# Smoke test at IoU 0.1: feed 100 random batches, then print the stored image count before and after reset()
test = AverageRecall(iou_threshold=0.1, max_bboxes_per_image=100)

for _ in range(100):
    pred_bboxes = [convert_box_to_standard_mode(torch.rand(i + 5, 6) * 128, "cccwhd") for i in range(10)]
    pred_confidence_scores = [torch.rand(i + 5, 4) for i in range(10)]

    # Targets reuse the predicted boxes but carry random labels
    target_bboxes = pred_bboxes
    target_classes = [torch.randint(1, 4, (i + 5,)) for i in range(10)]

    # Pass `target_bboxes` (previously assigned but unused; `pred_bboxes` was passed twice)
    ar10 = test(
        pred_bboxes,
        pred_confidence_scores,
        target_bboxes,
        target_classes,
    )
    print(ar10)

print(len(test.pred_bboxes))
test.reset()
print(len(test.pred_bboxes))

tensor(0.2828)
tensor(0.2866)
tensor(0.3844)
tensor(0.3709)
tensor(0.2954)
tensor(0.3364)
tensor(0.3216)
tensor(0.2218)
tensor(0.3167)
tensor(0.2398)
tensor(0.3361)
tensor(0.3275)
tensor(0.2768)
tensor(0.2670)
tensor(0.3062)
tensor(0.3645)
tensor(0.3088)
tensor(0.3405)
tensor(0.2954)
tensor(0.2741)
tensor(0.3771)
tensor(0.2795)
tensor(0.3678)
tensor(0.2948)
tensor(0.2550)
tensor(0.3180)
tensor(0.2503)
tensor(0.3179)
tensor(0.3483)
tensor(0.2510)
tensor(0.3077)
tensor(0.2637)
tensor(0.3886)
tensor(0.2968)
tensor(0.3361)
tensor(0.3714)
tensor(0.3134)
tensor(0.3127)
tensor(0.2696)
tensor(0.2315)
tensor(0.2465)
tensor(0.3458)
tensor(0.3051)
tensor(0.3666)
tensor(0.3366)
tensor(0.2726)
tensor(0.3250)
tensor(0.3102)
tensor(0.2340)
tensor(0.3461)
tensor(0.3016)
tensor(0.3811)
tensor(0.3207)
tensor(0.3050)
tensor(0.3372)
tensor(0.3131)
tensor(0.1949)
tensor(0.2759)
tensor(0.3560)
tensor(0.2738)
tensor(0.3564)
tensor(0.3576)
tensor(0.3518)
tensor(0.3672)
tensor(0.3153)
tensor(0.3570)
tensor(0.3

In [None]:
# Targets equal the predictions (same boxes, argmax labels), so the recall at IoU 0.1 is expected to be 1
test = AverageRecall(iou_threshold=0.1, max_bboxes_per_image=100)

for _ in range(5):
    pred_bboxes = [convert_box_to_standard_mode(torch.rand(i + 5, 6) * 128, "cccwhd") for i in range(10)]
    pred_confidence_scores = [torch.rand(i + 5, 4) for i in range(10)]

    target_bboxes = pred_bboxes
    target_classes = [torch.argmax(pred_confidence_score, dim=-1) for pred_confidence_score in pred_confidence_scores]

    # Pass `target_bboxes` (previously assigned but unused; `pred_bboxes` was passed twice)
    ar10 = test(
        pred_bboxes,
        pred_confidence_scores,
        target_bboxes,
        target_classes,
    )
    print(ar10)

print(len(test.pred_bboxes))
test.reset()
print(len(test.pred_bboxes))

tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
50
0


# nbdev

In [None]:
!nbdev_export