In [None]:
# | default_exp metrics/detection

# Imports

In [None]:
# | export


from typing import Literal

import torch
from monai.data.box_utils import box_iou
from torchmetrics import Metric

In [None]:
from monai.data.box_utils import convert_box_to_standard_mode

# Mean Average Precision

### Direct function

In [None]:
# | export


def mean_average_precision_mean_average_recall(
    pred_bboxes: list[torch.Tensor],
    pred_objectness_probabilities: list[torch.Tensor] | None,
    pred_class_probabilities: list[torch.Tensor],
    target_bboxes: list[torch.Tensor],
    target_classes: list[torch.Tensor],
    iou_thresholds: list[float] = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95],
    average_precision_num_points: int = 101,
    min_confidence_threshold: float = 0.0,
    max_bboxes_per_image: int | None = 100,
    return_intermediates: bool = False,
) -> tuple[float, float] | tuple[float, float, dict[float, dict[int, float]], dict[float, dict[int, float]]]:
    """Calculate the COCO mean average precision (mAP) for object detection.

    Args:
        pred_bboxes: A list of length B containing tensors of shape (NP, 4) or (NP, 6) containing the predicted bounding
            box parameters in xyxy or xyzxyz format.
        pred_objectness_probabilities: A list of length B containing tensors of shape (NP,) containing the predicted
            objectness probabilities for the corresponding bounding boxes. This can be set to None in which case only
            the class probabilities are considered.
        pred_class_probabilities: A list of length B containing tensors of shape (NP, num_classes) containing the
            predicted class probabilities for the corresponding bounding boxes.
        target_bboxes: A list of length B containing tensors of shape (NT, 4) or (NT, 6) containing the target bounding
            box parameters in xyxy or xyzxyz format.
        target_classes: A list of length B containing tensors of shape (NT,) containing the target class labels for the
            objects in the image.
        iou_thresholds: A list of IoU thresholds to use for calculating mAP and mAR.
        average_precision_num_points: Number of points over which to calculate average precision.
        min_confidence_threshold: Minimum confidence probability threshold to consider a prediction.
        max_bboxes_per_image: Maximum number of bounding boxes to consider per image. If more are present, only the top
            `max_bboxes_per_image` boxes based on confidence scores are considered. If set to None, all bounding boxes
            are considered.
        return_intermediates: If True, return intermediate values used to calculate mAP and mAR.

    Returns:
        The mean average precision (mAP) and mean average recall (mAR) across all classes and IoU thresholds for the
        entire dataset.
        If `return_intermediates` is True, also returns two dictionaries containing the average precision and average
        recall for each class at each IoU threshold.
    """
    # Set some globaly used variables
    B = len(pred_bboxes)
    num_classes = pred_class_probabilities[0].shape[-1]

    if pred_objectness_probabilities is None:
        pred_objectness_probabilities = [
            torch.ones_like(pred_class_probability[:, 0]) for pred_class_probability in pred_class_probabilities
        ]

    # Some basic tests
    assert (
        len(pred_bboxes)
        == len(pred_objectness_probabilities)
        == len(pred_class_probabilities)
        == len(target_bboxes)
        == len(target_classes)
        == B
    ), (
        f"All input lists must have the same length. Got lengths: {len(pred_bboxes)}, "
        f"{len(pred_objectness_probabilities)}, {len(pred_class_probabilities)}, {len(target_bboxes)}, "
        f"{len(target_classes)}."
    )
    assert all(
        pred_bbox.shape[0] == pred_objectness_probability.shape[0] == pred_class_probability.shape[0]
        for pred_bbox, pred_objectness_probability, pred_class_probability in zip(
            pred_bboxes, pred_objectness_probabilities, pred_class_probabilities
        )
    ), "Each prediction input list element must have the same number of bounding boxes."
    assert all(
        pred_bbox.shape[1] == 4 or pred_bbox.shape[1] == 6 for pred_bbox in pred_bboxes
    ), "Prediction bounding boxes must have shape (NP, 4) or (NP, 6)."
    assert all(
        pred_class_probability.shape[1] == num_classes for pred_class_probability in pred_class_probabilities
    ), "Prediction class probabilities must have shape (NP, num_classes)."
    assert all(
        target_bbox.shape[0] == target_class.shape[0]
        for target_bbox, target_class in zip(target_bboxes, target_classes)
    ), "Each target must have the same number of bounding boxes."

    # Split everything based on different classes
    pred_bboxes_by_class = [[] for _ in range(num_classes)]
    pred_objectness_probabilities_by_class = [[] for _ in range(num_classes)]
    pred_class_probabilities_by_class = [[] for _ in range(num_classes)]
    target_bboxes_by_class = [[] for _ in range(num_classes)]
    for b in range(B):
        pred_classes = torch.argmax(pred_class_probabilities[b], dim=-1)
        # (NP,)
        for c in range(num_classes):
            pred_classes_mask = pred_classes == c
            # (NP,)
            target_classes_mask = target_classes[b] == (c + 1)
            # (NT,)

            pred_bboxes_by_class[c].append(pred_bboxes[b][pred_classes_mask])
            pred_objectness_probabilities_by_class[c].append(pred_objectness_probabilities[b][pred_classes_mask])
            pred_class_probabilities_by_class[c].append(pred_class_probabilities[b][pred_classes_mask])
            # (NP,)

            target_bboxes_by_class[c].append(target_bboxes[b][target_classes_mask])
            # (NT,)

    # Helper function to sort a tensor in descending order based on values in first column
    def _sort_by_first_column_descending(tensor: torch.Tensor) -> torch.Tensor:
        return tensor[torch.argsort(tensor[:, 0], descending=True)]

    # Calculate IOUs for all prediction and target bounding box pairs for each class
    # Also calculate confidence scores for each prediction bounding box along with index
    # Also track number of target boxes for each class
    ious = [[] for _ in range(num_classes)]
    confidence_scores = [[] for _ in range(num_classes)]
    num_target_boxes = [0 for _ in range(num_classes)]
    for b in range(B):
        for c in range(num_classes):
            _ious = box_iou(pred_bboxes_by_class[c][b], target_bboxes_by_class[c][b])
            # (NP, NT)

            _confidence_scores = (
                pred_objectness_probabilities_by_class[c][b] * pred_class_probabilities_by_class[c][b][:, c]
            )
            _batch_index = torch.full_like(_confidence_scores, b)
            _offset_index = torch.arange(len(_confidence_scores), device=_confidence_scores.device)
            _confidence_scores = torch.stack([_confidence_scores, _batch_index, _offset_index], dim=-1)
            # (NP, 3) -> (confidence_score, batch_index, offset_index)

            ious[c].append(_ious)
            confidence_scores[c].append(_confidence_scores)
            num_target_boxes[c] += target_bboxes_by_class[c][b].shape[0]

        # Limit number of bounding boxes per image if applicable
        if max_bboxes_per_image is not None:
            all_confidences = torch.cat(
                [
                    torch.cat(
                        [
                            confidence_scores[c][b],  # (NC, 3)
                            torch.full_like(confidence_scores[c][b][:, :1], c),  # (NC, 1)
                        ],
                        dim=-1,
                    )  # (NC, 4)
                    for c in range(num_classes)
                ],
                dim=0,
            )
            if all_confidences.shape[0] > max_bboxes_per_image:
                all_confidences = _sort_by_first_column_descending(all_confidences)
                topk_confidences = all_confidences[:max_bboxes_per_image]
                # (max_bboxes_per_image, 4)
                for c in range(num_classes):
                    class_mask = topk_confidences[:, 3] == c
                    # (max_bboxes_per_image,)
                    confidence_scores[c][b] = topk_confidences[class_mask][:, :3]
                    # (N_class, 3)

    # Concatenate confidence scores and sort them in descending order for each class
    for c in range(num_classes):
        confidence_scores[c] = torch.cat(confidence_scores[c], dim=0)
        confidence_scores[c] = _sort_by_first_column_descending(confidence_scores[c])

    # For each IOU threshold, calculate average precision and average recall
    average_precisions = {}
    average_recalls = {}
    for iou_threshold in iou_thresholds:
        # For each class calculate average precision and average recall
        class_average_precisions = {}
        class_average_recalls = {}
        for c in range(num_classes):
            if num_target_boxes[c] == 0:
                # If no target boxes for this class, skip it
                class_average_precisions[c + 1] = float("nan")
                class_average_recalls[c + 1] = float("nan")
                continue

            matched_target_indices = [set() for _ in range(B)]
            tps, fps, fns = 0, 0, num_target_boxes[c]
            precisions, recalls = [], []

            # In descending order, update tp, fp, fn and calculate precision and recall at each step
            for confidence_score, b, pred_offset in confidence_scores[c]:
                if confidence_score < min_confidence_threshold:
                    # Do not consider this and following predictions if confidence score is below threshold
                    break

                b, pred_offset = int(b), int(pred_offset)

                pred_ious = ious[c][b][pred_offset]
                # (NT,)

                if pred_ious.numel() > 0:
                    if matched_target_indices[b]:
                        # Exclude already matched target boxes
                        pred_ious[list(matched_target_indices[b])] = -1.0
                    # Exclude target boxes below IOU threshold
                    pred_ious[pred_ious < iou_threshold] = -1.0

                if pred_ious.numel() == 0 or pred_ious.amax() < 0.0:
                    # No valid target box to match with
                    fps += 1
                else:
                    target_offset = pred_ious.argmax().item()
                    matched_target_indices[b].add(target_offset)
                    tps += 1
                    fns -= 1

                precisions.append(tps / (tps + fps) if (tps + fps) > 0 else 1.0)
                recalls.append(tps / (tps + fns) if (tps + fns) > 0 else 0.0)

            precisions = torch.tensor(precisions)
            recalls = torch.tensor(recalls)

            # Precision envelope: P_interp(r) = max_{r' >= r} P(r')
            enveloped_precisions = precisions.clone()
            for i in range(len(enveloped_precisions) - 2, -1, -1):
                if enveloped_precisions[i] < enveloped_precisions[i + 1]:
                    enveloped_precisions[i] = enveloped_precisions[i + 1]

            # Calculate average precision using step-wise interpolation
            recall_samples = torch.linspace(0, 1, average_precision_num_points, device=recalls.device)
            idxs = torch.searchsorted(recalls, recall_samples, side="left")
            valid = idxs < enveloped_precisions.numel()
            enveloped_precisions_at_t = torch.zeros_like(recall_samples)
            enveloped_precisions_at_t[valid] = enveloped_precisions[idxs[valid]]
            class_average_precisions[c + 1] = enveloped_precisions_at_t.mean().item()

            # Calculate average recall i.e. maximum recall achieved at this IoU threshold
            class_average_recalls[c + 1] = recalls.max().item() if recalls.numel() > 0 else 0.0

        average_precisions[iou_threshold] = class_average_precisions
        average_recalls[iou_threshold] = class_average_recalls

    map_metric = torch.nanmean(
        torch.stack([torch.tensor(ap) for iou_aps in average_precisions.values() for ap in iou_aps.values()])
    ).item()
    mar_metric = torch.nanmean(
        torch.stack([torch.tensor(ar) for iou_ars in average_recalls.values() for ar in iou_ars.values()])
    ).item()

    if return_intermediates:
        return map_metric, mar_metric, average_precisions, average_recalls
    return map_metric, mar_metric


# Create aliases
map_mar = mean_average_precision_mean_average_recall
mean_average_precision_recall = mean_average_precision_mean_average_recall

In [None]:
# Random predicted and target boxes; the seed makes the demo output reproducible

torch.manual_seed(0)

pred_bboxes = [convert_box_to_standard_mode(torch.rand(i + 10, 6) * 128, "cccwhd") for i in range(25)]
pred_objectness_probabilities = [torch.rand(i + 10) for i in range(25)]
pred_class_probabilities = [torch.rand(i + 10, 5) for i in range(25)]

target_bboxes = [convert_box_to_standard_mode(torch.rand(i + 1 + 10 * (i % 2), 6) * 128, "cccwhd") for i in range(25)]
target_classes = [torch.randint(1, 6, (i + 1 + 10 * (i % 2),)) for i in range(25)]

print([x.shape for x in pred_objectness_probabilities])
print([x.shape for x in target_classes])
map_mar(
    pred_bboxes,
    pred_objectness_probabilities,
    pred_class_probabilities,
    target_bboxes,
    target_classes,
    iou_thresholds=[0.001],
    return_intermediates=True,
)

[torch.Size([10]), torch.Size([11]), torch.Size([12]), torch.Size([13]), torch.Size([14]), torch.Size([15]), torch.Size([16]), torch.Size([17]), torch.Size([18]), torch.Size([19]), torch.Size([20]), torch.Size([21]), torch.Size([22]), torch.Size([23]), torch.Size([24]), torch.Size([25]), torch.Size([26]), torch.Size([27]), torch.Size([28]), torch.Size([29]), torch.Size([30]), torch.Size([31]), torch.Size([32]), torch.Size([33]), torch.Size([34])]
[torch.Size([1]), torch.Size([12]), torch.Size([3]), torch.Size([14]), torch.Size([5]), torch.Size([16]), torch.Size([7]), torch.Size([18]), torch.Size([9]), torch.Size([20]), torch.Size([11]), torch.Size([22]), torch.Size([13]), torch.Size([24]), torch.Size([15]), torch.Size([26]), torch.Size([17]), torch.Size([28]), torch.Size([19]), torch.Size([30]), torch.Size([21]), torch.Size([32]), torch.Size([23]), torch.Size([34]), torch.Size([25])]



[1m([0m
    [1;36m0.33234840631484985[0m,
    [1;36m0.5730844736099243[0m,
    [1m{[0m
        [1;36m0.001[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.32793092727661133[0m,
            [1;36m2[0m: [1;36m0.3697085976600647[0m,
            [1;36m3[0m: [1;36m0.32134315371513367[0m,
            [1;36m4[0m: [1;36m0.32542985677719116[0m,
            [1;36m5[0m: [1;36m0.3173292577266693[0m
        [1m}[0m
    [1m}[0m,
    [1m{[0m
        [1;36m0.001[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.6052631735801697[0m,
            [1;36m2[0m: [1;36m0.5495495200157166[0m,
            [1;36m3[0m: [1;36m0.5280898809432983[0m,
            [1;36m4[0m: [1;36m0.6022727489471436[0m,
            [1;36m5[0m: [1;36m0.5802469253540039[0m
        [1m}[0m
    [1m}[0m
[1m)[0m

In [None]:
# Predicted boxes are approximately equal to target boxes (shifted by 0.5) with matching classes, i.e. precision
# should be high. The seed makes the demo output reproducible.

torch.manual_seed(0)

pred_bboxes = [convert_box_to_standard_mode(torch.rand(i, 6) * 128, "cccwhd") for i in range(25)]
pred_objectness_probabilities = [torch.rand(i) for i in range(25)]
pred_class_probabilities = [torch.rand(i, 5) for i in range(25)]

target_bboxes = [pred_bboxes[i] + 0.5 for i in range(25)]
target_classes = [pred_class_probabilities[i].argmax(dim=-1) + 1 for i in range(25)]

print([x.shape for x in pred_objectness_probabilities])
print([x.shape for x in target_classes])
map_mar(
    pred_bboxes,
    pred_objectness_probabilities,
    pred_class_probabilities,
    target_bboxes,
    target_classes,
    return_intermediates=True,
)

[torch.Size([0]), torch.Size([1]), torch.Size([2]), torch.Size([3]), torch.Size([4]), torch.Size([5]), torch.Size([6]), torch.Size([7]), torch.Size([8]), torch.Size([9]), torch.Size([10]), torch.Size([11]), torch.Size([12]), torch.Size([13]), torch.Size([14]), torch.Size([15]), torch.Size([16]), torch.Size([17]), torch.Size([18]), torch.Size([19]), torch.Size([20]), torch.Size([21]), torch.Size([22]), torch.Size([23]), torch.Size([24])]
[torch.Size([0]), torch.Size([1]), torch.Size([2]), torch.Size([3]), torch.Size([4]), torch.Size([5]), torch.Size([6]), torch.Size([7]), torch.Size([8]), torch.Size([9]), torch.Size([10]), torch.Size([11]), torch.Size([12]), torch.Size([13]), torch.Size([14]), torch.Size([15]), torch.Size([16]), torch.Size([17]), torch.Size([18]), torch.Size([19]), torch.Size([20]), torch.Size([21]), torch.Size([22]), torch.Size([23]), torch.Size([24])]



[1m([0m
    [1;36m0.7826094627380371[0m,
    [1;36m0.8521413207054138[0m,
    [1m{[0m
        [1;36m0.5[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.9727721810340881[0m,
            [1;36m2[0m: [1;36m0.9051622152328491[0m,
            [1;36m3[0m: [1;36m0.9708970785140991[0m,
            [1;36m4[0m: [1;36m1.0[0m,
            [1;36m5[0m: [1;36m0.965676486492157[0m
        [1m}[0m,
        [1;36m0.55[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.9727721810340881[0m,
            [1;36m2[0m: [1;36m0.9051622152328491[0m,
            [1;36m3[0m: [1;36m0.9423943161964417[0m,
            [1;36m4[0m: [1;36m0.9658967852592468[0m,
            [1;36m5[0m: [1;36m0.965676486492157[0m
        [1m}[0m,
        [1;36m0.6[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.9375[0m,
            [1;36m2[0m: [1;36m0.864537239074707[0m,
            [1;36m3[0m: [1;36m0.9423943161964417[0m,
            [1;36m4[0m: [1;36m0.9333105683326721[0m

In [None]:
# Predicted boxes (with correctly predicted classes) are a subset of the target boxes; the extra targets have random
# classes and no matching prediction. Precision should therefore be high, but only about half of the targets can be
# recalled, so recall and AP should be around 0.5. The seed makes the demo output reproducible.

torch.manual_seed(0)

pred_bboxes = [convert_box_to_standard_mode(torch.rand(i, 6) * 128, "cccwhd") for i in range(25)]
pred_objectness_probabilities = [torch.rand(i) for i in range(25)]
pred_class_probabilities = [torch.rand(i, 5) for i in range(25)]

target_bboxes = [
    torch.cat([pred_bboxes[i], convert_box_to_standard_mode(torch.rand(i, 6) * 128, "cccwhd")]) for i in range(25)
]
target_classes = [
    torch.cat([pred_class_probabilities[i].argmax(dim=-1) + 1, torch.randint(1, 6, (i,))]) for i in range(25)
]

print([x.shape for x in pred_objectness_probabilities])
print([x.shape for x in target_classes])
map_mar(
    pred_bboxes,
    pred_objectness_probabilities,
    pred_class_probabilities,
    target_bboxes,
    target_classes,
    return_intermediates=True,
)

[torch.Size([0]), torch.Size([1]), torch.Size([2]), torch.Size([3]), torch.Size([4]), torch.Size([5]), torch.Size([6]), torch.Size([7]), torch.Size([8]), torch.Size([9]), torch.Size([10]), torch.Size([11]), torch.Size([12]), torch.Size([13]), torch.Size([14]), torch.Size([15]), torch.Size([16]), torch.Size([17]), torch.Size([18]), torch.Size([19]), torch.Size([20]), torch.Size([21]), torch.Size([22]), torch.Size([23]), torch.Size([24])]
[torch.Size([0]), torch.Size([2]), torch.Size([4]), torch.Size([6]), torch.Size([8]), torch.Size([10]), torch.Size([12]), torch.Size([14]), torch.Size([16]), torch.Size([18]), torch.Size([20]), torch.Size([22]), torch.Size([24]), torch.Size([26]), torch.Size([28]), torch.Size([30]), torch.Size([32]), torch.Size([34]), torch.Size([36]), torch.Size([38]), torch.Size([40]), torch.Size([42]), torch.Size([44]), torch.Size([46]), torch.Size([48])]



[1m([0m
    [1;36m0.49702972173690796[0m,
    [1;36m0.49871501326560974[0m,
    [1m{[0m
        [1;36m0.5[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.41584157943725586[0m,
            [1;36m2[0m: [1;36m0.5148515105247498[0m,
            [1;36m3[0m: [1;36m0.5049505233764648[0m,
            [1;36m4[0m: [1;36m0.5049505233764648[0m,
            [1;36m5[0m: [1;36m0.5445544719696045[0m
        [1m}[0m,
        [1;36m0.55[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.41584157943725586[0m,
            [1;36m2[0m: [1;36m0.5148515105247498[0m,
            [1;36m3[0m: [1;36m0.5049505233764648[0m,
            [1;36m4[0m: [1;36m0.5049505233764648[0m,
            [1;36m5[0m: [1;36m0.5445544719696045[0m
        [1m}[0m,
        [1;36m0.6[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.41584157943725586[0m,
            [1;36m2[0m: [1;36m0.5148515105247498[0m,
            [1;36m3[0m: [1;36m0.5049505233764648[0m,
            [1;36m4

In [None]:
# Target boxes are a subset of the predicted boxes and use the predicted classes; the extra predictions are false
# positives. Recall should therefore be 1 while precision (and hence AP) is lower. The seed makes the demo output
# reproducible.

torch.manual_seed(0)

pred_bboxes = [convert_box_to_standard_mode(torch.rand(i + 10, 6) * 128, "cccwhd") for i in range(25)]
pred_objectness_probabilities = [torch.rand(i + 10) for i in range(25)]
pred_class_probabilities = [torch.rand(i + 10, 5) for i in range(25)]

target_bboxes = [pred_bboxes[i][: i + 1] for i in range(25)]
target_classes = [pred_class_probabilities[i][: i + 1].argmax(dim=-1) + 1 for i in range(25)]

print([x.shape for x in pred_objectness_probabilities])
print([x.shape for x in target_classes])
map_mar(
    pred_bboxes,
    pred_objectness_probabilities,
    pred_class_probabilities,
    target_bboxes,
    target_classes,
    return_intermediates=True,
)

[torch.Size([10]), torch.Size([11]), torch.Size([12]), torch.Size([13]), torch.Size([14]), torch.Size([15]), torch.Size([16]), torch.Size([17]), torch.Size([18]), torch.Size([19]), torch.Size([20]), torch.Size([21]), torch.Size([22]), torch.Size([23]), torch.Size([24]), torch.Size([25]), torch.Size([26]), torch.Size([27]), torch.Size([28]), torch.Size([29]), torch.Size([30]), torch.Size([31]), torch.Size([32]), torch.Size([33]), torch.Size([34])]
[torch.Size([1]), torch.Size([2]), torch.Size([3]), torch.Size([4]), torch.Size([5]), torch.Size([6]), torch.Size([7]), torch.Size([8]), torch.Size([9]), torch.Size([10]), torch.Size([11]), torch.Size([12]), torch.Size([13]), torch.Size([14]), torch.Size([15]), torch.Size([16]), torch.Size([17]), torch.Size([18]), torch.Size([19]), torch.Size([20]), torch.Size([21]), torch.Size([22]), torch.Size([23]), torch.Size([24]), torch.Size([25])]



[1m([0m
    [1;36m0.6408372521400452[0m,
    [1;36m1.0[0m,
    [1m{[0m
        [1;36m0.5[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.6826467514038086[0m,
            [1;36m2[0m: [1;36m0.6561459898948669[0m,
            [1;36m3[0m: [1;36m0.730953574180603[0m,
            [1;36m4[0m: [1;36m0.5114720463752747[0m,
            [1;36m5[0m: [1;36m0.622967541217804[0m
        [1m}[0m,
        [1;36m0.55[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.6826467514038086[0m,
            [1;36m2[0m: [1;36m0.6561459898948669[0m,
            [1;36m3[0m: [1;36m0.730953574180603[0m,
            [1;36m4[0m: [1;36m0.5114720463752747[0m,
            [1;36m5[0m: [1;36m0.622967541217804[0m
        [1m}[0m,
        [1;36m0.6[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.6826467514038086[0m,
            [1;36m2[0m: [1;36m0.6561459898948669[0m,
            [1;36m3[0m: [1;36m0.730953574180603[0m,
            [1;36m4[0m: [1;36m0.5114720463

### Lightning metrics

In [None]:
# | export


class MeanAveragePrecisionMeanAverageRecall(Metric):
    """Calculate the COCO mean average precision (mAP) and mean average recall (mAR) for object detection.

    Predictions and targets are accumulated (per image) across calls to `update`, and `compute` evaluates all of them
    together with `mean_average_precision_mean_average_recall`, returning `(map, mar)`.
    """

    def __init__(
        self,
        iou_thresholds: list[float] = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95],
        average_precision_num_points: int = 101,
        min_confidence_threshold: float = 0.0,
        max_bboxes_per_image: int | None = 100,
    ):
        """Initialize the MeanAveragePrecisionMeanAverageRecall metric.

        Args:
            iou_thresholds: A list of IoU thresholds to use for calculating mAP and mAR.
            average_precision_num_points: Number of points over which to calculate average precision.
            min_confidence_threshold: Minimum confidence score threshold to consider a prediction.
            max_bboxes_per_image: Maximum number of bounding boxes to consider per image. If more are present, only the
                top `max_bboxes_per_image` boxes based on confidence scores are considered. If set to None, all
                bounding boxes are considered.
        """
        super().__init__()

        # Copy so that instances never share (and cannot mutate) the default list object
        self.iou_thresholds = list(iou_thresholds)
        self.average_precision_num_points = average_precision_num_points
        self.min_confidence_threshold = min_confidence_threshold
        self.max_bboxes_per_image = max_bboxes_per_image

        # One list entry per image; these are concatenated lists rather than reduced tensors
        self.add_state("pred_bboxes", [], dist_reduce_fx=None, persistent=False)
        self.add_state("pred_objectness_probabilities", [], dist_reduce_fx=None, persistent=False)
        self.add_state("pred_class_probabilities", [], dist_reduce_fx=None, persistent=False)
        self.add_state("target_bboxes", [], dist_reduce_fx=None, persistent=False)
        self.add_state("target_classes", [], dist_reduce_fx=None, persistent=False)

    def update(
        self,
        pred_bboxes: list[torch.Tensor],
        pred_objectness_probabilities: list[torch.Tensor] | None,
        pred_class_probabilities: list[torch.Tensor],
        target_bboxes: list[torch.Tensor],
        target_classes: list[torch.Tensor],
    ):
        """Accumulate a batch of per-image predictions and targets.

        The arguments follow the same conventions as `mean_average_precision_mean_average_recall`. If
        `pred_objectness_probabilities` is None, an objectness of 1 is used for every predicted box so that only the
        class probabilities determine the confidence scores.
        """
        if pred_objectness_probabilities is None:
            pred_objectness_probabilities = [
                torch.ones_like(pred_class_probability[:, 0]) for pred_class_probability in pred_class_probabilities
            ]

        self.pred_bboxes.extend(pred_bboxes)
        self.pred_objectness_probabilities.extend(pred_objectness_probabilities)
        self.pred_class_probabilities.extend(pred_class_probabilities)
        self.target_bboxes.extend(target_bboxes)
        self.target_classes.extend(target_classes)

    def compute(self):
        """Return `(map, mar)` computed over everything accumulated so far."""
        return mean_average_precision_mean_average_recall(
            self.pred_bboxes,
            self.pred_objectness_probabilities,
            self.pred_class_probabilities,
            self.target_bboxes,
            self.target_classes,
            iou_thresholds=self.iou_thresholds,
            average_precision_num_points=self.average_precision_num_points,
            min_confidence_threshold=self.min_confidence_threshold,
            max_bboxes_per_image=self.max_bboxes_per_image,
        )

    def forward(self, *args, return_metrics: Literal["map_only", "mar_only", "map_mar"] = "map_mar", **kwargs):
        """Delegate to torchmetrics' `Metric.forward` and select which of the returned metrics to expose.

        Args:
            *args, **kwargs: Forwarded to `Metric.forward` (and from there to `update`).
            return_metrics: "map_only" returns only mAP, "mar_only" only mAR; any other value returns `(map, mar)`.
        """
        map_value, mar_value = super().forward(*args, **kwargs)
        if return_metrics == "map_only":
            return map_value
        elif return_metrics == "mar_only":
            return mar_value
        return map_value, mar_value


# Aliases
MeanAveragePrecisionRecall = MeanAveragePrecisionMeanAverageRecall

In [None]:
# Stream 100 random batches through the metric; each call prints the batch-level (mAP, mAR)
metric = MeanAveragePrecisionMeanAverageRecall(max_bboxes_per_image=100)

torch.manual_seed(0)
for _ in range(100):
    pred_bboxes = [convert_box_to_standard_mode(torch.rand(i + 5, 6) * 128, "cccwhd") for i in range(10)]
    pred_objectness_probabilities = [torch.rand(i + 5) for i in range(10)]
    pred_class_probabilities = [torch.rand(i + 5, 3) for i in range(10)]

    # Targets reuse the predicted boxes but get random class labels
    target_bboxes = pred_bboxes
    target_classes = [torch.randint(1, 4, (i + 5,)) for i in range(10)]

    # Avoid shadowing the builtin `map`
    map_value, mar_value = metric(
        pred_bboxes,
        pred_objectness_probabilities,
        pred_class_probabilities,
        target_bboxes,
        target_classes,
    )
    print(map_value, mar_value)

# Accumulated number of images before and after resetting the metric state
print(len(metric.pred_bboxes))
metric.reset()
print(len(metric.pred_bboxes))

0.12013497948646545 0.31321558356285095
0.1737494319677353 0.3393162786960602
0.13819558918476105 0.3115581274032593
0.14782793819904327 0.354580283164978
0.1189122125506401 0.29157084226608276
0.10286863148212433 0.2706470191478729
0.1637032926082611 0.31627780199050903
0.10249726474285126 0.2940656542778015
0.17166513204574585 0.37528958916664124
0.08065803349018097 0.24506725370883942
0.11709938198328018 0.2689473032951355
0.15382486581802368 0.298595130443573
0.09263339638710022 0.2678799629211426
0.16220563650131226 0.3891448974609375
0.07644350826740265 0.21815475821495056
0.1997639238834381 0.3837607502937317
0.10294722765684128 0.29992738366127014
0.12229587137699127 0.2946907877922058
0.1488412320613861 0.3504139482975006
0.1799801141023636 0.37070709466934204
0.14648985862731934 0.3479166328907013
0.08585381507873535 0.20489419996738434
0.1423136293888092 0.30973130464553833
0.15525752305984497 0.32660767436027527
0.11408495903015137 0.27222222089767456
0.18081574141979218 0.

In [None]:
# | export


class MeanAveragePrecision(MeanAveragePrecisionMeanAverageRecall):
    """Calculate the COCO mean average precision (mAP) for object detection.

    Behaves like `MeanAveragePrecisionMeanAverageRecall`, except that calling the metric (`forward`) yields only the
    mAP value. Note that the inherited `compute()` is not overridden and still returns `(map, mar)`.
    """

    def forward(self, *args, **kwargs):
        """Run the parent `forward` and keep only the mAP component."""
        return super().forward(*args, **kwargs, return_metrics="map_only")

In [None]:
# | export


class MeanAverageRecall(MeanAveragePrecisionMeanAverageRecall):
    """Calculate the COCO mean average recall (mAR) for object detection.

    Behaves like `MeanAveragePrecisionMeanAverageRecall`, except that calling the metric (`forward`) yields only the
    mAR value. Note that the inherited `compute()` is not overridden and still returns `(map, mar)`.
    """

    def forward(self, *args, **kwargs):
        """Run the parent `forward` and keep only the mAR component."""
        return super().forward(*args, **kwargs, return_metrics="mar_only")

# nbdev

In [None]:
!nbdev_export