In [None]:
# | default_exp metrics/detection

# Imports

In [None]:
# | export


from typing import Literal

import torch
from torchmetrics import Metric

from vision_architectures.utils.bounding_boxes import get_tps_fps_fns, sort_by_first_column_descending

In [None]:
from monai.data.box_utils import convert_box_to_standard_mode

# Mean Average Precision

### Direct function

In [None]:
# | export


def mean_average_precision_mean_average_recall(
    pred_bboxes: list[torch.Tensor],
    pred_confidence_scores: list[torch.Tensor],
    target_bboxes: list[torch.Tensor],
    target_classes: list[torch.Tensor],
    iou_thresholds: list[float] = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95],
    average_precision_num_points: int = 101,
    min_confidence_threshold: float = 0.0,
    max_bboxes_per_image: int | None = 100,
    return_intermediates: bool = False,
) -> tuple[float, float] | tuple[float, float, dict[float, dict[int, float]], dict[float, dict[int, float]]]:
    """Calculate the COCO mean average precision (mAP) for object detection.

    Args:
        pred_bboxes: A list of length B containing tensors of shape (NP, 4) or (NP, 6) containing the predicted bounding
            box parameters in xyxy or xyzxyz format.
        pred_confidence_scores: A list of length B containing tensors of shape (NP, 1+num_classes) containing the
            predicted confidence scores for each class. Note that the first column corresponds to the "no-object" class,
            and bounding boxes which fall in this category are ignored.
        target_bboxes: A list of length B containing tensors of shape (NT, 4) or (NT, 6) containing the target bounding
            box parameters in xyxy or xyzxyz format.
        target_classes: A list of length B containing tensors of shape (NT,) containing the target class labels for the
            objects in the image.
        iou_thresholds: A list of IoU thresholds to use for calculating mAP and mAR.
        average_precision_num_points: Number of points over which to calculate average precision.
        min_confidence_threshold: Minimum confidence probability threshold to consider a prediction.
        max_bboxes_per_image: Maximum number of bounding boxes to consider per image. If more are present, only the top
            `max_bboxes_per_image` boxes based on confidence scores are considered. If set to None, all bounding boxes
            are considered.
        return_intermediates: If True, return intermediate values used to calculate mAP and mAR.

    Returns:
        The mean average precision (mAP) and mean average recall (mAR) across all classes and IoU thresholds for the
        entire dataset.
        If `return_intermediates` is True, also returns two dictionaries containing the average precision and average
        recall for each class at each IoU threshold.
    """
    # Set some globaly used variables
    B = len(pred_bboxes)
    num_classes = pred_confidence_scores[0].shape[-1] - 1

    # Some basic tests
    assert len(pred_bboxes) == len(pred_confidence_scores) == len(target_bboxes) == len(target_classes) == B, (
        f"All input lists must have the same length. Got lengths: {len(pred_bboxes)}, {len(pred_confidence_scores)}, "
        f"{len(target_bboxes)}, {len(target_classes)}."
    )
    assert all(
        pred_bbox.shape[0] == pred_confidence_score.shape[0]
        for pred_bbox, pred_confidence_score in zip(pred_bboxes, pred_confidence_scores)
    ), "Each prediction input list element must have the same number of bounding boxes."
    assert all(
        pred_bbox.shape[1] == 4 or pred_bbox.shape[1] == 6 for pred_bbox in pred_bboxes
    ), "Prediction bounding boxes must have shape (NP, 4) or (NP, 6)."
    assert all(
        pred_confidence_score.shape[-1] == num_classes + 1 for pred_confidence_score in pred_confidence_scores
    ), "Prediction class probabilities must have shape (NP, 1 + num_classes)."
    assert all(
        target_bbox.shape[0] == target_class.shape[0]
        for target_bbox, target_class in zip(target_bboxes, target_classes)
    ), "Each target must have the same number of bounding boxes."

    # Split everything based on different classes. Calculate confidence scores as well.
    pred_bboxes_by_class = [[] for _ in range(num_classes)]
    pred_confidences_scores_by_class = [[] for _ in range(num_classes)]
    target_bboxes_by_class = [[] for _ in range(num_classes)]
    for b in range(B):
        pred_classes = torch.argmax(pred_confidence_scores[b], dim=-1)
        # (NP,)

        for c in range(num_classes):
            pred_classes_mask = pred_classes == (c + 1)
            # (NP,)
            target_classes_mask = target_classes[b] == (c + 1)
            # (NT,)

            pred_bboxes_by_class[c].append(pred_bboxes[b][pred_classes_mask])
            pred_confidences_scores_by_class[c].append(pred_confidence_scores[b][pred_classes_mask][:, c + 1])
            # (NP,)

            target_bboxes_by_class[c].append(target_bboxes[b][target_classes_mask])
            # (NT,)

    # Limit number of bounding boxes per image if applicable
    if max_bboxes_per_image is not None:
        for b in range(B):
            _confidence_scores = []
            for c in range(num_classes):
                if pred_bboxes_by_class[c][b].numel() > 0:
                    _confidence_scores.append(
                        torch.stack(
                            [
                                pred_confidences_scores_by_class[c][b],
                                torch.arange(
                                    pred_confidences_scores_by_class[c][b].shape[0], device=pred_bboxes[b].device
                                ),
                                torch.full_like(pred_confidences_scores_by_class[c][b], c),
                            ],
                            dim=-1,
                        )
                    )
            if len(_confidence_scores) == 0:
                continue
            _confidence_scores = torch.cat(_confidence_scores, dim=0)
            # (NC, 3)

            if _confidence_scores.shape[0] > max_bboxes_per_image:
                _confidence_scores = sort_by_first_column_descending(_confidence_scores)
                topk_confidences = _confidence_scores[:max_bboxes_per_image]
                # (max_bboxes_per_image, 3)
                for c in range(num_classes):
                    class_mask = topk_confidences[:, 2] == c
                    # (max_bboxes_per_image,)
                    offsets_to_keep = topk_confidences[class_mask][:, 1].long()
                    pred_bboxes_by_class[c][b] = pred_bboxes_by_class[c][b][offsets_to_keep]
                    # (NP', 4) or (NP', 6)
                    pred_confidences_scores_by_class[c][b] = pred_confidences_scores_by_class[c][b][offsets_to_keep]
                    # (NP',)

    # For each IOU threshold, calculate average precision and average recall
    average_precisions = {}
    average_recalls = {}
    for iou_threshold in iou_thresholds:
        # For each class calculate average precision and average recall
        class_average_precisions = {}
        class_average_recalls = {}
        for c in range(num_classes):
            # If no target boxes for this class, skip it
            if all(target_bbox.numel() == 0 for target_bbox in target_bboxes_by_class[c]):
                class_average_precisions[c + 1] = float("nan")
                class_average_recalls[c + 1] = float("nan")
                continue

            _, _, _, intermediate_counts = get_tps_fps_fns(
                pred_bboxes=pred_bboxes_by_class[c],
                pred_confidence_scores=pred_confidences_scores_by_class[c],
                target_bboxes=target_bboxes_by_class[c],
                iou_threshold=iou_threshold,
                matching_method="coco",
                min_confidence_threshold=min_confidence_threshold,
                max_bboxes_per_image=max_bboxes_per_image,
                return_intermediate_counts=True,
            )
            intermediate_counts = torch.tensor(intermediate_counts, device=pred_bboxes[0].device, dtype=torch.float32)
            # (NC, 3) where the first column is TP, second is FP and third is FN for each prediction considered
            precisions = intermediate_counts[:, 0] / (intermediate_counts[:, 0] + intermediate_counts[:, 1] + 1e-8)
            recalls = intermediate_counts[:, 0] / (intermediate_counts[:, 0] + intermediate_counts[:, 2] + 1e-8)
            # (NC,), (NC,)

            # Precision envelope: P_interp(r) = max_{r' >= r} P(r')
            enveloped_precisions = precisions.clone()
            for i in range(len(enveloped_precisions) - 2, -1, -1):
                if enveloped_precisions[i] < enveloped_precisions[i + 1]:
                    enveloped_precisions[i] = enveloped_precisions[i + 1]

            # Calculate average precision using step-wise interpolation
            recall_samples = torch.linspace(0, 1, average_precision_num_points, device=recalls.device)
            idxs = torch.searchsorted(recalls, recall_samples, side="left")
            valid = idxs < enveloped_precisions.numel()
            enveloped_precisions_at_t = torch.zeros_like(recall_samples)
            enveloped_precisions_at_t[valid] = enveloped_precisions[idxs[valid]]
            class_average_precisions[c + 1] = enveloped_precisions_at_t.mean().item()

            # Calculate average recall i.e. maximum recall achieved at this IoU threshold
            class_average_recalls[c + 1] = recalls.max().item() if recalls.numel() > 0 else 0.0

        average_precisions[iou_threshold] = class_average_precisions
        average_recalls[iou_threshold] = class_average_recalls

    map_metric = torch.nanmean(
        torch.stack([torch.tensor(ap) for iou_aps in average_precisions.values() for ap in iou_aps.values()])
    ).item()
    mar_metric = torch.nanmean(
        torch.stack([torch.tensor(ar) for iou_ars in average_recalls.values() for ar in iou_ars.values()])
    ).item()

    if return_intermediates:
        return map_metric, mar_metric, average_precisions, average_recalls
    return map_metric, mar_metric


# Create aliases
map_mar = mean_average_precision_mean_average_recall
mean_average_precision_recall = mean_average_precision_mean_average_recall

In [None]:
# Random predicted and target boxes

pred_bboxes = [convert_box_to_standard_mode(torch.rand(i + 10, 6) * 128, "cccwhd") for i in range(25)]
pred_confidence_scores = [torch.rand(i + 10, 6) for i in range(25)]

target_bboxes = [convert_box_to_standard_mode(torch.rand(i + 1 + 10 * (i % 2), 6) * 128, "cccwhd") for i in range(25)]
target_classes = [torch.randint(1, 6, (i + 1 + 10 * (i % 2),)) for i in range(25)]

print([x.shape for x in pred_confidence_scores])
print([x.shape for x in target_classes])
map_mar(
    pred_bboxes,
    pred_confidence_scores,
    target_bboxes,
    target_classes,
    iou_thresholds=[0.001],
    return_intermediates=True,
)

[torch.Size([10, 6]), torch.Size([11, 6]), torch.Size([12, 6]), torch.Size([13, 6]), torch.Size([14, 6]), torch.Size([15, 6]), torch.Size([16, 6]), torch.Size([17, 6]), torch.Size([18, 6]), torch.Size([19, 6]), torch.Size([20, 6]), torch.Size([21, 6]), torch.Size([22, 6]), torch.Size([23, 6]), torch.Size([24, 6]), torch.Size([25, 6]), torch.Size([26, 6]), torch.Size([27, 6]), torch.Size([28, 6]), torch.Size([29, 6]), torch.Size([30, 6]), torch.Size([31, 6]), torch.Size([32, 6]), torch.Size([33, 6]), torch.Size([34, 6])]
[torch.Size([1]), torch.Size([12]), torch.Size([3]), torch.Size([14]), torch.Size([5]), torch.Size([16]), torch.Size([7]), torch.Size([18]), torch.Size([9]), torch.Size([20]), torch.Size([11]), torch.Size([22]), torch.Size([13]), torch.Size([24]), torch.Size([15]), torch.Size([26]), torch.Size([17]), torch.Size([28]), torch.Size([19]), torch.Size([30]), torch.Size([21]), torch.Size([32]), torch.Size([23]), torch.Size([34]), torch.Size([25])]



[1m([0m
    [1;36m0.3264089822769165[0m,
    [1;36m0.502989649772644[0m,
    [1m{[0m
        [1;36m0.001[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.2578059136867523[0m,
            [1;36m2[0m: [1;36m0.3967687785625458[0m,
            [1;36m3[0m: [1;36m0.32028132677078247[0m,
            [1;36m4[0m: [1;36m0.30382004380226135[0m,
            [1;36m5[0m: [1;36m0.35336896777153015[0m
        [1m}[0m
    [1m}[0m,
    [1m{[0m[1;36m0.001[0m: [1m{[0m[1;36m1[0m: [1;36m0.5066666603088379[0m, [1;36m2[0m: [1;36m0.5952380895614624[0m, [1;36m3[0m: [1;36m0.5[0m, [1;36m4[0m: [1;36m0.41304346919059753[0m, [1;36m5[0m: [1;36m0.5[0m[1m}[0m[1m}[0m
[1m)[0m

In [None]:
# Predicted boxes are approximately equal to target boxes i.e. precision should be high

pred_bboxes = [convert_box_to_standard_mode(torch.rand(i, 6) * 128, "cccwhd") for i in range(25)]
pred_confidence_scores = [torch.rand(i, 6) for i in range(25)]

target_bboxes = [pred_bboxes[i] + 0.5 for i in range(25)]
target_classes = [pred_confidence_scores[i].argmax(dim=-1) for i in range(25)]

print([x.shape for x in pred_confidence_scores])
print([x.shape for x in target_classes])
map_mar(
    pred_bboxes,
    pred_confidence_scores,
    target_bboxes,
    target_classes,
    return_intermediates=True,
)

[torch.Size([0, 6]), torch.Size([1, 6]), torch.Size([2, 6]), torch.Size([3, 6]), torch.Size([4, 6]), torch.Size([5, 6]), torch.Size([6, 6]), torch.Size([7, 6]), torch.Size([8, 6]), torch.Size([9, 6]), torch.Size([10, 6]), torch.Size([11, 6]), torch.Size([12, 6]), torch.Size([13, 6]), torch.Size([14, 6]), torch.Size([15, 6]), torch.Size([16, 6]), torch.Size([17, 6]), torch.Size([18, 6]), torch.Size([19, 6]), torch.Size([20, 6]), torch.Size([21, 6]), torch.Size([22, 6]), torch.Size([23, 6]), torch.Size([24, 6])]
[torch.Size([0]), torch.Size([1]), torch.Size([2]), torch.Size([3]), torch.Size([4]), torch.Size([5]), torch.Size([6]), torch.Size([7]), torch.Size([8]), torch.Size([9]), torch.Size([10]), torch.Size([11]), torch.Size([12]), torch.Size([13]), torch.Size([14]), torch.Size([15]), torch.Size([16]), torch.Size([17]), torch.Size([18]), torch.Size([19]), torch.Size([20]), torch.Size([21]), torch.Size([22]), torch.Size([23]), torch.Size([24])]



[1m([0m
    [1;36m0.7494872212409973[0m,
    [1;36m0.8362684845924377[0m,
    [1m{[0m
        [1;36m0.5[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.9171066284179688[0m,
            [1;36m2[0m: [1;36m0.9698019623756409[0m,
            [1;36m3[0m: [1;36m0.8550244569778442[0m,
            [1;36m4[0m: [1;36m0.9474145770072937[0m,
            [1;36m5[0m: [1;36m0.9071193933486938[0m
        [1m}[0m,
        [1;36m0.55[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.9171066284179688[0m,
            [1;36m2[0m: [1;36m0.9698019623756409[0m,
            [1;36m3[0m: [1;36m0.8550244569778442[0m,
            [1;36m4[0m: [1;36m0.9474145770072937[0m,
            [1;36m5[0m: [1;36m0.9071193933486938[0m
        [1m}[0m,
        [1;36m0.6[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.9171066284179688[0m,
            [1;36m2[0m: [1;36m0.9504950642585754[0m,
            [1;36m3[0m: [1;36m0.8550244569778442[0m,
            [1;36m4[0m:

In [None]:
# Predicted boxes are subset of target boxes but with random classes i.e. precision should be high

pred_bboxes = [convert_box_to_standard_mode(torch.rand(i, 6) * 128, "cccwhd") for i in range(25)]
pred_confidence_scores = [torch.rand(i, 6) for i in range(25)]

target_bboxes = [
    torch.cat([pred_bboxes[i], convert_box_to_standard_mode(torch.rand(i, 6) * 128, "cccwhd")]) for i in range(25)
]
target_classes = [torch.cat([pred_confidence_scores[i].argmax(dim=-1), torch.randint(1, 6, (i,))]) for i in range(25)]

print([x.shape for x in pred_confidence_scores])
print([x.shape for x in target_classes])
map_mar(
    pred_bboxes,
    pred_confidence_scores,
    target_bboxes,
    target_classes,
    return_intermediates=True,
)

[torch.Size([0, 6]), torch.Size([1, 6]), torch.Size([2, 6]), torch.Size([3, 6]), torch.Size([4, 6]), torch.Size([5, 6]), torch.Size([6, 6]), torch.Size([7, 6]), torch.Size([8, 6]), torch.Size([9, 6]), torch.Size([10, 6]), torch.Size([11, 6]), torch.Size([12, 6]), torch.Size([13, 6]), torch.Size([14, 6]), torch.Size([15, 6]), torch.Size([16, 6]), torch.Size([17, 6]), torch.Size([18, 6]), torch.Size([19, 6]), torch.Size([20, 6]), torch.Size([21, 6]), torch.Size([22, 6]), torch.Size([23, 6]), torch.Size([24, 6])]
[torch.Size([0]), torch.Size([2]), torch.Size([4]), torch.Size([6]), torch.Size([8]), torch.Size([10]), torch.Size([12]), torch.Size([14]), torch.Size([16]), torch.Size([18]), torch.Size([20]), torch.Size([22]), torch.Size([24]), torch.Size([26]), torch.Size([28]), torch.Size([30]), torch.Size([32]), torch.Size([34]), torch.Size([36]), torch.Size([38]), torch.Size([40]), torch.Size([42]), torch.Size([44]), torch.Size([46]), torch.Size([48])]



[1m([0m
    [1;36m0.46138614416122437[0m,
    [1;36m0.46011048555374146[0m,
    [1m{[0m
        [1;36m0.5[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.5148515105247498[0m,
            [1;36m2[0m: [1;36m0.48514851927757263[0m,
            [1;36m3[0m: [1;36m0.3762376308441162[0m,
            [1;36m4[0m: [1;36m0.4752475321292877[0m,
            [1;36m5[0m: [1;36m0.4554455578327179[0m
        [1m}[0m,
        [1;36m0.55[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.5148515105247498[0m,
            [1;36m2[0m: [1;36m0.48514851927757263[0m,
            [1;36m3[0m: [1;36m0.3762376308441162[0m,
            [1;36m4[0m: [1;36m0.4752475321292877[0m,
            [1;36m5[0m: [1;36m0.4554455578327179[0m
        [1m}[0m,
        [1;36m0.6[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.5148515105247498[0m,
            [1;36m2[0m: [1;36m0.48514851927757263[0m,
            [1;36m3[0m: [1;36m0.3762376308441162[0m,
            [1;36m4

In [None]:
# Target boxes are subset of prediction boxes but with random classes i.e. recall should be high

pred_bboxes = [convert_box_to_standard_mode(torch.rand(i + 10, 6) * 128, "cccwhd") for i in range(25)]
pred_confidence_scores = [torch.rand(i + 10, 6) for i in range(25)]

target_bboxes = [pred_bboxes[i][: i + 1] for i in range(25)]
target_classes = [pred_confidence_scores[i][: i + 1].argmax(dim=-1) for i in range(25)]

print([x.shape for x in pred_confidence_scores])
print([x.shape for x in target_classes])
map_mar(
    pred_bboxes,
    pred_confidence_scores,
    target_bboxes,
    target_classes,
    return_intermediates=True,
)

[torch.Size([10, 6]), torch.Size([11, 6]), torch.Size([12, 6]), torch.Size([13, 6]), torch.Size([14, 6]), torch.Size([15, 6]), torch.Size([16, 6]), torch.Size([17, 6]), torch.Size([18, 6]), torch.Size([19, 6]), torch.Size([20, 6]), torch.Size([21, 6]), torch.Size([22, 6]), torch.Size([23, 6]), torch.Size([24, 6]), torch.Size([25, 6]), torch.Size([26, 6]), torch.Size([27, 6]), torch.Size([28, 6]), torch.Size([29, 6]), torch.Size([30, 6]), torch.Size([31, 6]), torch.Size([32, 6]), torch.Size([33, 6]), torch.Size([34, 6])]
[torch.Size([1]), torch.Size([2]), torch.Size([3]), torch.Size([4]), torch.Size([5]), torch.Size([6]), torch.Size([7]), torch.Size([8]), torch.Size([9]), torch.Size([10]), torch.Size([11]), torch.Size([12]), torch.Size([13]), torch.Size([14]), torch.Size([15]), torch.Size([16]), torch.Size([17]), torch.Size([18]), torch.Size([19]), torch.Size([20]), torch.Size([21]), torch.Size([22]), torch.Size([23]), torch.Size([24]), torch.Size([25])]



[1m([0m
    [1;36m0.6364901661872864[0m,
    [1;36m1.0[0m,
    [1m{[0m
        [1;36m0.5[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.7791070342063904[0m,
            [1;36m2[0m: [1;36m0.623065710067749[0m,
            [1;36m3[0m: [1;36m0.5761346817016602[0m,
            [1;36m4[0m: [1;36m0.5862584710121155[0m,
            [1;36m5[0m: [1;36m0.6204817891120911[0m
        [1m}[0m,
        [1;36m0.55[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.7762216329574585[0m,
            [1;36m2[0m: [1;36m0.623065710067749[0m,
            [1;36m3[0m: [1;36m0.5761346817016602[0m,
            [1;36m4[0m: [1;36m0.5862584710121155[0m,
            [1;36m5[0m: [1;36m0.6204817891120911[0m
        [1m}[0m,
        [1;36m0.6[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.7762216329574585[0m,
            [1;36m2[0m: [1;36m0.623065710067749[0m,
            [1;36m3[0m: [1;36m0.5761346817016602[0m,
            [1;36m4[0m: [1;36m0.58625847

### Lightning metrics

In [None]:
# | export


class MeanAveragePrecisionMeanAverageRecall(Metric):
    """Calculate the COCO mean average precision (mAP) and mean average recall (mAR) for object detection."""

    def __init__(
        self,
        iou_thresholds: list[float] = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95],
        average_precision_num_points: int = 101,
        min_confidence_threshold: float = 0.0,
        max_bboxes_per_image: int | None = 100,
    ):
        """Initialize the MeanAveragePrecisionMeanAverageRecall metric.

        Args:
            num_classes: Number of classes in the dataset.
            iou_thresholds: A list of IoU thresholds to use for calculating mAP and mAR.
            average_precision_num_points: Number of points over which to calculate average precision.
            min_confidence_threshold: Minimum confidence score threshold to consider a prediction.
            max_bboxes_per_image: Maximum number of bounding boxes to consider per image. If more are present, only the
                top `max_bboxes_per_image` boxes based on confidence scores are considered.
        """
        super().__init__()

        self.iou_thresholds = iou_thresholds
        self.average_precision_num_points = average_precision_num_points
        self.min_confidence_threshold = min_confidence_threshold
        self.max_bboxes_per_image = max_bboxes_per_image

        self.add_state("pred_bboxes", [], dist_reduce_fx=None, persistent=False)
        self.add_state("pred_confidence_scores", [], dist_reduce_fx=None, persistent=False)
        self.add_state("target_bboxes", [], dist_reduce_fx=None, persistent=False)
        self.add_state("target_classes", [], dist_reduce_fx=None, persistent=False)

    def update(
        self,
        pred_bboxes: list[torch.Tensor],
        pred_confidence_scores: list[torch.Tensor],
        target_bboxes: list[torch.Tensor],
        target_classes: list[torch.Tensor],
    ):
        self.pred_bboxes.extend(pred_bboxes)
        self.pred_confidence_scores.extend(pred_confidence_scores)
        self.target_bboxes.extend(target_bboxes)
        self.target_classes.extend(target_classes)

    def compute(self):
        return mean_average_precision_mean_average_recall(
            self.pred_bboxes,
            self.pred_confidence_scores,
            self.target_bboxes,
            self.target_classes,
            iou_thresholds=self.iou_thresholds,
            average_precision_num_points=self.average_precision_num_points,
            min_confidence_threshold=self.min_confidence_threshold,
            max_bboxes_per_image=self.max_bboxes_per_image,
        )

    def forward(self, *args, return_metrics: Literal["map_only", "mar_only", "map_mar"] = "map_mar", **kwargs):
        map, mar = super().forward(*args, **kwargs)
        if return_metrics == "map_only":
            return map
        elif return_metrics == "mar_only":
            return mar
        return map, mar


# Aliases
MeanAveragePrecisionRecall = MeanAveragePrecisionMeanAverageRecall

In [None]:
test = MeanAveragePrecisionMeanAverageRecall(max_bboxes_per_image=100)

for _ in range(100):
    pred_bboxes = [convert_box_to_standard_mode(torch.rand(i + 5, 6) * 128, "cccwhd") for i in range(10)]
    pred_confidence_scores = [torch.rand(i + 5, 4) for i in range(10)]

    target_bboxes = pred_bboxes
    target_classes = [torch.randint(1, 4, (i + 5,)) for i in range(10)]

    map, mar = test(
        pred_bboxes,
        pred_confidence_scores,
        pred_bboxes,
        target_classes,
    )
    print(map, mar)

print(len(test.pred_bboxes))
test.reset()
print(len(test.pred_bboxes))

0.12114287912845612 0.23144379258155823
0.11896856874227524 0.2432991862297058
0.17773547768592834 0.31904760003089905
0.046452268958091736 0.175459086894989
0.1610197126865387 0.29822856187820435
0.13873833417892456 0.28324511647224426
0.15099795162677765 0.2914254069328308
0.1306196004152298 0.27521368861198425
0.12080462276935577 0.2408391684293747
0.1591007113456726 0.2665651738643646
0.11366598308086395 0.23442761600017548
0.09732644259929657 0.25192898511886597
0.13474041223526 0.26367121934890747
0.08466680347919464 0.22331950068473816
0.14574941992759705 0.29970234632492065
0.09319876879453659 0.2093474268913269
0.13157980144023895 0.29380589723587036
0.18510419130325317 0.31570881605148315
0.1347496509552002 0.29645636677742004
0.12988512217998505 0.26022130250930786
0.1308775097131729 0.24292929470539093
0.08382631093263626 0.2107023298740387
0.1526949256658554 0.2888981103897095
0.11463499069213867 0.243418350815773
0.08263934403657913 0.1864435374736786
0.12882666289806366 

In [None]:
# | export


class MeanAveragePrecision(MeanAveragePrecisionMeanAverageRecall):
    """Calculate the COCO mean average precision (mAP) for object detection."""

    def forward(self, *args, **kwargs):
        return super().forward(*args, return_metrics="map_only", **kwargs)

In [None]:
# | export


class MeanAverageRecall(MeanAveragePrecisionMeanAverageRecall):
    """Calculate the COCO mean average recall (mAR) for object detection."""

    def forward(self, *args, **kwargs):
        return super().forward(*args, return_metrics="mar_only", **kwargs)

In [None]:
# | export


class AveragePrecision(MeanAveragePrecision):
    """Calculate the COCO average precision (AP) for object detection."""

    def __init__(self, iou_threshold: float, *args, **kwargs):
        iou_thresholds = [iou_threshold]
        super().__init__(iou_thresholds=iou_thresholds, *args, **kwargs)

In [None]:
# | export


class AverageRecall(MeanAverageRecall):
    """Calculate the COCO average recall (AR) for object detection."""

    def __init__(self, iou_threshold: float, *args, **kwargs):
        iou_thresholds = [iou_threshold]
        super().__init__(iou_thresholds=iou_thresholds, *args, **kwargs)

# nbdev

In [None]:
!nbdev_export