In [None]:
# | default_exp metrics/detection

# Imports

In [None]:
# | export


from typing import Literal

import torch
from torchmetrics import Metric

from vision_architectures.utils.bounding_boxes import get_tps_fps_fns, sort_by_first_column_descending

In [None]:
from monai.data.box_utils import convert_box_to_standard_mode

# Mean Average Precision

### Direct function

In [None]:
# | export


def mean_average_precision_mean_average_recall(
    pred_bboxes: list[torch.Tensor],
    pred_confidence_scores: list[torch.Tensor],
    target_bboxes: list[torch.Tensor],
    target_classes: list[torch.Tensor],
    iou_thresholds: list[float] = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95],
    average_precision_num_points: int = 101,
    min_confidence_threshold: float = 0.0,
    max_bboxes_per_image: int | None = 100,
    return_intermediates: bool = False,
) -> tuple[float, float] | tuple[float, float, dict[float, dict[int, float]], dict[float, dict[int, float]]]:
    """Calculate the COCO mean average precision (mAP) for object detection.

    Args:
        pred_bboxes: A list of length B containing tensors of shape (NP, 4) or (NP, 6) containing the predicted bounding
            box parameters in xyxy or xyzxyz format.
        pred_confidence_scores: A list of length B containing tensors of shape (NP, 1+num_classes) containing the
            predicted confidence scores for each class. Note that the first column corresponds to the "no-object" class,
            and bounding boxes which fall in this category are ignored.
        target_bboxes: A list of length B containing tensors of shape (NT, 4) or (NT, 6) containing the target bounding
            box parameters in xyxy or xyzxyz format.
        target_classes: A list of length B containing tensors of shape (NT,) containing the target class labels for the
            objects in the image.
        iou_thresholds: A list of IoU thresholds to use for calculating mAP and mAR.
        average_precision_num_points: Number of points over which to calculate average precision.
        min_confidence_threshold: Minimum confidence probability threshold to consider a prediction.
        max_bboxes_per_image: Maximum number of bounding boxes to consider per image. If more are present, only the top
            `max_bboxes_per_image` boxes based on confidence scores are considered. If set to None, all bounding boxes
            are considered.
        return_intermediates: If True, return intermediate values used to calculate mAP and mAR.

    Returns:
        The mean average precision (mAP) and mean average recall (mAR) across all classes and IoU thresholds for the
        entire dataset.
        If `return_intermediates` is True, also returns two dictionaries containing the average precision and average
        recall for each class at each IoU threshold.
    """
    # Set some globaly used variables
    B = len(pred_bboxes)
    num_classes = pred_confidence_scores[0].shape[-1] - 1

    # Some basic tests
    assert len(pred_bboxes) == len(pred_confidence_scores) == len(target_bboxes) == len(target_classes) == B, (
        f"All input lists must have the same length. Got lengths: {len(pred_bboxes)}, {len(pred_confidence_scores)}, "
        f"{len(target_bboxes)}, {len(target_classes)}."
    )
    assert all(
        pred_bbox.shape[0] == pred_confidence_score.shape[0]
        for pred_bbox, pred_confidence_score in zip(pred_bboxes, pred_confidence_scores)
    ), "Each prediction input list element must have the same number of bounding boxes."
    assert all(
        pred_bbox.shape[1] == 4 or pred_bbox.shape[1] == 6 for pred_bbox in pred_bboxes
    ), "Prediction bounding boxes must have shape (NP, 4) or (NP, 6)."
    assert all(
        pred_confidence_score.shape[-1] == num_classes + 1 for pred_confidence_score in pred_confidence_scores
    ), "Prediction class probabilities must have shape (NP, 1 + num_classes)."
    assert all(
        target_bbox.shape[0] == target_class.shape[0]
        for target_bbox, target_class in zip(target_bboxes, target_classes)
    ), "Each target must have the same number of bounding boxes."

    # Split everything based on different classes.
    pred_bboxes_by_class = [[] for _ in range(num_classes)]
    pred_confidences_scores_by_class = [[] for _ in range(num_classes)]
    target_bboxes_by_class = [[] for _ in range(num_classes)]
    for b in range(B):
        pred_classes = torch.argmax(pred_confidence_scores[b], dim=-1)
        # (NP,)

        for c in range(num_classes):
            pred_classes_mask = pred_classes == (c + 1)
            # (NP,)
            target_classes_mask = target_classes[b] == (c + 1)
            # (NT,)

            pred_bboxes_by_class[c].append(pred_bboxes[b][pred_classes_mask])
            pred_confidences_scores_by_class[c].append(pred_confidence_scores[b][pred_classes_mask][:, c + 1])
            # (NP,)

            target_bboxes_by_class[c].append(target_bboxes[b][target_classes_mask])
            # (NT,)

    # Limit number of bounding boxes per image if applicable
    if max_bboxes_per_image is not None:
        for b in range(B):
            _confidence_scores = []
            for c in range(num_classes):
                if pred_bboxes_by_class[c][b].numel() > 0:
                    _confidence_scores.append(
                        torch.stack(
                            [
                                pred_confidences_scores_by_class[c][b],
                                torch.arange(
                                    pred_confidences_scores_by_class[c][b].shape[0], device=pred_bboxes[b].device
                                ),
                                torch.full_like(pred_confidences_scores_by_class[c][b], c),
                            ],
                            dim=-1,
                        )
                    )
            if len(_confidence_scores) == 0:
                continue
            _confidence_scores = torch.cat(_confidence_scores, dim=0)
            # (NC, 3)

            if _confidence_scores.shape[0] > max_bboxes_per_image:
                _confidence_scores = sort_by_first_column_descending(_confidence_scores)
                topk_confidences = _confidence_scores[:max_bboxes_per_image]
                # (max_bboxes_per_image, 3)
                for c in range(num_classes):
                    class_mask = topk_confidences[:, 2] == c
                    # (max_bboxes_per_image,)
                    offsets_to_keep = topk_confidences[class_mask][:, 1].long()
                    pred_bboxes_by_class[c][b] = pred_bboxes_by_class[c][b][offsets_to_keep]
                    # (NP', 4) or (NP', 6)
                    pred_confidences_scores_by_class[c][b] = pred_confidences_scores_by_class[c][b][offsets_to_keep]
                    # (NP',)

    # For each IOU threshold, calculate average precision and average recall
    average_precisions = {}
    average_recalls = {}
    for iou_threshold in iou_thresholds:
        # For each class calculate average precision and average recall
        class_average_precisions = {}
        class_average_recalls = {}
        for c in range(num_classes):
            # If no target boxes for this class, skip it
            if all(target_bbox.numel() == 0 for target_bbox in target_bboxes_by_class[c]):
                class_average_precisions[c + 1] = float("nan")
                class_average_recalls[c + 1] = float("nan")
                continue

            _, _, _, intermediate_counts = get_tps_fps_fns(
                pred_bboxes=pred_bboxes_by_class[c],
                pred_confidence_scores=pred_confidences_scores_by_class[c],
                target_bboxes=target_bboxes_by_class[c],
                iou_threshold=iou_threshold,
                matching_method="coco",
                min_confidence_threshold=min_confidence_threshold,
                max_bboxes_per_image=max_bboxes_per_image,
                return_intermediate_counts=True,
            )
            intermediate_counts = torch.tensor(intermediate_counts, device=pred_bboxes[0].device, dtype=torch.float32)
            # (NC, 3) where the first column is TP, second is FP and third is FN for each prediction considered
            precisions = intermediate_counts[:, 0] / (intermediate_counts[:, 0] + intermediate_counts[:, 1] + 1e-8)
            recalls = intermediate_counts[:, 0] / (intermediate_counts[:, 0] + intermediate_counts[:, 2] + 1e-8)
            # (NC,) each

            # Precision envelope: P_interp(r) = max_{r' >= r} P(r')
            enveloped_precisions = precisions.clone()
            for i in range(len(enveloped_precisions) - 2, -1, -1):
                if enveloped_precisions[i] < enveloped_precisions[i + 1]:
                    enveloped_precisions[i] = enveloped_precisions[i + 1]

            # Calculate average precision using step-wise interpolation
            recall_samples = torch.linspace(0, 1, average_precision_num_points, device=recalls.device)
            idxs = torch.searchsorted(recalls, recall_samples, side="left")
            valid = idxs < enveloped_precisions.numel()
            enveloped_precisions_at_t = torch.zeros_like(recall_samples)
            enveloped_precisions_at_t[valid] = enveloped_precisions[idxs[valid]]
            class_average_precisions[c + 1] = enveloped_precisions_at_t.mean().item()

            # Calculate average recall i.e. maximum recall achieved at this IoU threshold
            class_average_recalls[c + 1] = recalls.max().item() if recalls.numel() > 0 else 0.0

        average_precisions[iou_threshold] = class_average_precisions
        average_recalls[iou_threshold] = class_average_recalls

    map_metric = torch.nanmean(
        torch.stack([torch.tensor(ap) for iou_aps in average_precisions.values() for ap in iou_aps.values()])
    ).item()
    mar_metric = torch.nanmean(
        torch.stack([torch.tensor(ar) for iou_ars in average_recalls.values() for ar in iou_ars.values()])
    ).item()

    if return_intermediates:
        return map_metric, mar_metric, average_precisions, average_recalls
    return map_metric, mar_metric


# Create aliases
# Both names are bound to the very same function object as `mean_average_precision_mean_average_recall`; they only
# exist as shorter / alternative spellings.
map_mar = mean_average_precision_mean_average_recall
mean_average_precision_recall = mean_average_precision_mean_average_recall

In [None]:
# Random predicted and target boxes. Predictions and targets are drawn independently, so a tiny IoU threshold is used
# to get any matches at all.

torch.manual_seed(0)  # Fixed seed so that the displayed values are reproducible across runs

pred_bboxes = [convert_box_to_standard_mode(torch.rand(i + 10, 6) * 128, "cccwhd") for i in range(25)]
pred_confidence_scores = [torch.rand(i + 10, 6) for i in range(25)]

target_bboxes = [convert_box_to_standard_mode(torch.rand(i + 1 + 10 * (i % 2), 6) * 128, "cccwhd") for i in range(25)]
target_classes = [torch.randint(1, 6, (i + 1 + 10 * (i % 2),)) for i in range(25)]

print([x.shape for x in pred_confidence_scores])
print([x.shape for x in target_classes])
map_mar(
    pred_bboxes,
    pred_confidence_scores,
    target_bboxes,
    target_classes,
    iou_thresholds=[0.001],
    return_intermediates=True,
)

[torch.Size([10, 6]), torch.Size([11, 6]), torch.Size([12, 6]), torch.Size([13, 6]), torch.Size([14, 6]), torch.Size([15, 6]), torch.Size([16, 6]), torch.Size([17, 6]), torch.Size([18, 6]), torch.Size([19, 6]), torch.Size([20, 6]), torch.Size([21, 6]), torch.Size([22, 6]), torch.Size([23, 6]), torch.Size([24, 6]), torch.Size([25, 6]), torch.Size([26, 6]), torch.Size([27, 6]), torch.Size([28, 6]), torch.Size([29, 6]), torch.Size([30, 6]), torch.Size([31, 6]), torch.Size([32, 6]), torch.Size([33, 6]), torch.Size([34, 6])]
[torch.Size([1]), torch.Size([12]), torch.Size([3]), torch.Size([14]), torch.Size([5]), torch.Size([16]), torch.Size([7]), torch.Size([18]), torch.Size([9]), torch.Size([20]), torch.Size([11]), torch.Size([22]), torch.Size([13]), torch.Size([24]), torch.Size([15]), torch.Size([26]), torch.Size([17]), torch.Size([28]), torch.Size([19]), torch.Size([30]), torch.Size([21]), torch.Size([32]), torch.Size([23]), torch.Size([34]), torch.Size([25])]



[1m([0m
    [1;36m0.3235795497894287[0m,
    [1;36m0.5082506537437439[0m,
    [1m{[0m
        [1;36m0.001[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.3801289200782776[0m,
            [1;36m2[0m: [1;36m0.30422544479370117[0m,
            [1;36m3[0m: [1;36m0.34670326113700867[0m,
            [1;36m4[0m: [1;36m0.3678421974182129[0m,
            [1;36m5[0m: [1;36m0.21899795532226562[0m
        [1m}[0m
    [1m}[0m,
    [1m{[0m
        [1;36m0.001[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.5529412031173706[0m,
            [1;36m2[0m: [1;36m0.45652174949645996[0m,
            [1;36m3[0m: [1;36m0.5333333611488342[0m,
            [1;36m4[0m: [1;36m0.5049505233764648[0m,
            [1;36m5[0m: [1;36m0.4935064911842346[0m
        [1m}[0m
    [1m}[0m
[1m)[0m

In [None]:
# Predicted boxes are approximately equal to target boxes i.e. precision should be high.
# Targets are the predictions shifted by 0.5 and labelled with the predictions' argmax class (label 0 = "no-object").
# The first image (i = 0) has no boxes at all, which also exercises the empty-image path.

torch.manual_seed(0)  # Fixed seed so that the displayed values are reproducible across runs

pred_bboxes = [convert_box_to_standard_mode(torch.rand(i, 6) * 128, "cccwhd") for i in range(25)]
pred_confidence_scores = [torch.rand(i, 6) for i in range(25)]

target_bboxes = [pred_bboxes[i] + 0.5 for i in range(25)]
target_classes = [pred_confidence_scores[i].argmax(dim=-1) for i in range(25)]

print([x.shape for x in pred_confidence_scores])
print([x.shape for x in target_classes])
map_mar(
    pred_bboxes,
    pred_confidence_scores,
    target_bboxes,
    target_classes,
    return_intermediates=True,
)

[torch.Size([0, 6]), torch.Size([1, 6]), torch.Size([2, 6]), torch.Size([3, 6]), torch.Size([4, 6]), torch.Size([5, 6]), torch.Size([6, 6]), torch.Size([7, 6]), torch.Size([8, 6]), torch.Size([9, 6]), torch.Size([10, 6]), torch.Size([11, 6]), torch.Size([12, 6]), torch.Size([13, 6]), torch.Size([14, 6]), torch.Size([15, 6]), torch.Size([16, 6]), torch.Size([17, 6]), torch.Size([18, 6]), torch.Size([19, 6]), torch.Size([20, 6]), torch.Size([21, 6]), torch.Size([22, 6]), torch.Size([23, 6]), torch.Size([24, 6])]
[torch.Size([0]), torch.Size([1]), torch.Size([2]), torch.Size([3]), torch.Size([4]), torch.Size([5]), torch.Size([6]), torch.Size([7]), torch.Size([8]), torch.Size([9]), torch.Size([10]), torch.Size([11]), torch.Size([12]), torch.Size([13]), torch.Size([14]), torch.Size([15]), torch.Size([16]), torch.Size([17]), torch.Size([18]), torch.Size([19]), torch.Size([20]), torch.Size([21]), torch.Size([22]), torch.Size([23]), torch.Size([24])]



[1m([0m
    [1;36m0.7756006717681885[0m,
    [1;36m0.840447187423706[0m,
    [1m{[0m
        [1;36m0.5[0m: [1m{[0m[1;36m1[0m: [1;36m1.0[0m, [1;36m2[0m: [1;36m0.8976839184761047[0m, [1;36m3[0m: [1;36m1.0[0m, [1;36m4[0m: [1;36m0.9666128754615784[0m, [1;36m5[0m: [1;36m0.9423359036445618[0m[1m}[0m,
        [1;36m0.55[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.9762377738952637[0m,
            [1;36m2[0m: [1;36m0.8976839184761047[0m,
            [1;36m3[0m: [1;36m0.9657965302467346[0m,
            [1;36m4[0m: [1;36m0.9666128754615784[0m,
            [1;36m5[0m: [1;36m0.9113795161247253[0m
        [1m}[0m,
        [1;36m0.6[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.9762377738952637[0m,
            [1;36m2[0m: [1;36m0.8734544515609741[0m,
            [1;36m3[0m: [1;36m0.9657965302467346[0m,
            [1;36m4[0m: [1;36m0.9666128754615784[0m,
            [1;36m5[0m: [1;36m0.9113795161247253[0m
        [1m}

In [None]:
# Predicted boxes are a subset of the target boxes: every prediction also appears as a target labelled with the
# prediction's argmax class, and each image additionally gets as many extra random targets with random classes.
# Roughly half of the targets therefore have no matching prediction, which limits the achievable recall.

torch.manual_seed(0)  # Fixed seed so that the displayed values are reproducible across runs

pred_bboxes = [convert_box_to_standard_mode(torch.rand(i, 6) * 128, "cccwhd") for i in range(25)]
pred_confidence_scores = [torch.rand(i, 6) for i in range(25)]

target_bboxes = [
    torch.cat([pred_bboxes[i], convert_box_to_standard_mode(torch.rand(i, 6) * 128, "cccwhd")]) for i in range(25)
]
target_classes = [torch.cat([pred_confidence_scores[i].argmax(dim=-1), torch.randint(1, 6, (i,))]) for i in range(25)]

print([x.shape for x in pred_confidence_scores])
print([x.shape for x in target_classes])
map_mar(
    pred_bboxes,
    pred_confidence_scores,
    target_bboxes,
    target_classes,
    return_intermediates=True,
)

[torch.Size([0, 6]), torch.Size([1, 6]), torch.Size([2, 6]), torch.Size([3, 6]), torch.Size([4, 6]), torch.Size([5, 6]), torch.Size([6, 6]), torch.Size([7, 6]), torch.Size([8, 6]), torch.Size([9, 6]), torch.Size([10, 6]), torch.Size([11, 6]), torch.Size([12, 6]), torch.Size([13, 6]), torch.Size([14, 6]), torch.Size([15, 6]), torch.Size([16, 6]), torch.Size([17, 6]), torch.Size([18, 6]), torch.Size([19, 6]), torch.Size([20, 6]), torch.Size([21, 6]), torch.Size([22, 6]), torch.Size([23, 6]), torch.Size([24, 6])]
[torch.Size([0]), torch.Size([2]), torch.Size([4]), torch.Size([6]), torch.Size([8]), torch.Size([10]), torch.Size([12]), torch.Size([14]), torch.Size([16]), torch.Size([18]), torch.Size([20]), torch.Size([22]), torch.Size([24]), torch.Size([26]), torch.Size([28]), torch.Size([30]), torch.Size([32]), torch.Size([34]), torch.Size([36]), torch.Size([38]), torch.Size([40]), torch.Size([42]), torch.Size([44]), torch.Size([46]), torch.Size([48])]



[1m([0m
    [1;36m0.4534653425216675[0m,
    [1;36m0.45213639736175537[0m,
    [1m{[0m
        [1;36m0.5[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.4653465449810028[0m,
            [1;36m2[0m: [1;36m0.4653465449810028[0m,
            [1;36m3[0m: [1;36m0.4455445408821106[0m,
            [1;36m4[0m: [1;36m0.48514851927757263[0m,
            [1;36m5[0m: [1;36m0.40594059228897095[0m
        [1m}[0m,
        [1;36m0.55[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.4653465449810028[0m,
            [1;36m2[0m: [1;36m0.4653465449810028[0m,
            [1;36m3[0m: [1;36m0.4455445408821106[0m,
            [1;36m4[0m: [1;36m0.48514851927757263[0m,
            [1;36m5[0m: [1;36m0.40594059228897095[0m
        [1m}[0m,
        [1;36m0.6[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.4653465449810028[0m,
            [1;36m2[0m: [1;36m0.4653465449810028[0m,
            [1;36m3[0m: [1;36m0.4455445408821106[0m,
            [1;36m4

In [None]:
# Target boxes are a subset of the prediction boxes i.e. recall should be high.
# Targets are the first i + 1 predictions of every image, labelled with those predictions' argmax class; the remaining
# predictions have no target and can only count as false positives.

torch.manual_seed(0)  # Fixed seed so that the displayed values are reproducible across runs

pred_bboxes = [convert_box_to_standard_mode(torch.rand(i + 10, 6) * 128, "cccwhd") for i in range(25)]
pred_confidence_scores = [torch.rand(i + 10, 6) for i in range(25)]

target_bboxes = [pred_bboxes[i][: i + 1] for i in range(25)]
target_classes = [pred_confidence_scores[i][: i + 1].argmax(dim=-1) for i in range(25)]

print([x.shape for x in pred_confidence_scores])
print([x.shape for x in target_classes])
map_mar(
    pred_bboxes,
    pred_confidence_scores,
    target_bboxes,
    target_classes,
    return_intermediates=True,
)

[torch.Size([10, 6]), torch.Size([11, 6]), torch.Size([12, 6]), torch.Size([13, 6]), torch.Size([14, 6]), torch.Size([15, 6]), torch.Size([16, 6]), torch.Size([17, 6]), torch.Size([18, 6]), torch.Size([19, 6]), torch.Size([20, 6]), torch.Size([21, 6]), torch.Size([22, 6]), torch.Size([23, 6]), torch.Size([24, 6]), torch.Size([25, 6]), torch.Size([26, 6]), torch.Size([27, 6]), torch.Size([28, 6]), torch.Size([29, 6]), torch.Size([30, 6]), torch.Size([31, 6]), torch.Size([32, 6]), torch.Size([33, 6]), torch.Size([34, 6])]
[torch.Size([1]), torch.Size([2]), torch.Size([3]), torch.Size([4]), torch.Size([5]), torch.Size([6]), torch.Size([7]), torch.Size([8]), torch.Size([9]), torch.Size([10]), torch.Size([11]), torch.Size([12]), torch.Size([13]), torch.Size([14]), torch.Size([15]), torch.Size([16]), torch.Size([17]), torch.Size([18]), torch.Size([19]), torch.Size([20]), torch.Size([21]), torch.Size([22]), torch.Size([23]), torch.Size([24]), torch.Size([25])]



[1m([0m
    [1;36m0.6328680515289307[0m,
    [1;36m1.0[0m,
    [1m{[0m
        [1;36m0.5[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.6783559322357178[0m,
            [1;36m2[0m: [1;36m0.6269821524620056[0m,
            [1;36m3[0m: [1;36m0.5948416590690613[0m,
            [1;36m4[0m: [1;36m0.5922386050224304[0m,
            [1;36m5[0m: [1;36m0.6719220876693726[0m
        [1m}[0m,
        [1;36m0.55[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.6783559322357178[0m,
            [1;36m2[0m: [1;36m0.6269821524620056[0m,
            [1;36m3[0m: [1;36m0.5948416590690613[0m,
            [1;36m4[0m: [1;36m0.5922386050224304[0m,
            [1;36m5[0m: [1;36m0.6719220876693726[0m
        [1m}[0m,
        [1;36m0.6[0m: [1m{[0m
            [1;36m1[0m: [1;36m0.6783559322357178[0m,
            [1;36m2[0m: [1;36m0.6269821524620056[0m,
            [1;36m3[0m: [1;36m0.5948416590690613[0m,
            [1;36m4[0m: [1;36m0.59223

### Lightning metrics

In [None]:
# | export


class _MeanAveragePrecisionMeanAverageRecallBase(Metric):
    """Calculate the COCO mean average precision (mAP) and mean average recall (mAR) for object detection.

    This base class only accumulates the predictions and targets passed to `update` and offers `run_functional` to
    evaluate them with `mean_average_precision_mean_average_recall`. Subclasses implement `compute` by choosing which
    of the two metrics to return.
    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0

    def __init__(
        self,
        iou_thresholds: list[float] = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95],
        average_precision_num_points: int = 101,
        min_confidence_threshold: float = 0.0,
        max_bboxes_per_image: int | None = 100,
    ):
        """Initialize the MeanAveragePrecisionMeanAverageRecall metric.

        Args:
            iou_thresholds: A list of IoU thresholds to use for calculating mAP and mAR.
            average_precision_num_points: Number of points over which to calculate average precision.
            min_confidence_threshold: Minimum confidence score threshold to consider a prediction.
            max_bboxes_per_image: Maximum number of bounding boxes to consider per image. If more are present, only the
                top `max_bboxes_per_image` boxes based on confidence scores are considered. If set to None, all
                bounding boxes are considered.
        """
        super().__init__()

        # Store a copy: the default argument is a single list object shared by every call, so keeping a reference to
        # it would let a modification through one instance leak into all other instances.
        self.iou_thresholds = list(iou_thresholds)
        self.average_precision_num_points = average_precision_num_points
        self.min_confidence_threshold = min_confidence_threshold
        self.max_bboxes_per_image = max_bboxes_per_image

        # One list entry per image seen so far. The raw inputs are kept (instead of running counts) because the
        # functional evaluates the whole dataset at once.
        self.add_state("pred_bboxes", [], dist_reduce_fx=None, persistent=False)
        self.add_state("pred_confidence_scores", [], dist_reduce_fx=None, persistent=False)
        self.add_state("target_bboxes", [], dist_reduce_fx=None, persistent=False)
        self.add_state("target_classes", [], dist_reduce_fx=None, persistent=False)

    def update(
        self,
        pred_bboxes: list[torch.Tensor],
        pred_confidence_scores: list[torch.Tensor],
        target_bboxes: list[torch.Tensor],
        target_classes: list[torch.Tensor],
    ):
        """Append a batch of predictions and targets to the accumulated state.

        Args:
            pred_bboxes: A list of length B containing tensors of shape (NP, 4) or (NP, 6) with the predicted bounding
                boxes in xyxy or xyzxyz format.
            pred_confidence_scores: A list of length B containing tensors of shape (NP, 1+num_classes) with the
                predicted confidence scores, the first column being the "no-object" class.
            target_bboxes: A list of length B containing tensors of shape (NT, 4) or (NT, 6) with the target bounding
                boxes in xyxy or xyzxyz format.
            target_classes: A list of length B containing tensors of shape (NT,) with the target class labels.
        """
        self.pred_bboxes.extend(pred_bboxes)
        self.pred_confidence_scores.extend(pred_confidence_scores)
        self.target_bboxes.extend(target_bboxes)
        self.target_classes.extend(target_classes)

    def run_functional(self, return_metrics: Literal["map_only", "mar_only"]):
        """Evaluate the accumulated state and return one of the two metrics.

        Args:
            return_metrics: "map_only" to return the mean average precision, "mar_only" to return the mean average
                recall.

        Returns:
            A scalar tensor holding the requested metric, placed on the device of the first accumulated prediction.

        Raises:
            NotImplementedError: If `return_metrics` is neither "map_only" nor "mar_only". Note that the metrics are
                calculated before this check is made.
        """
        map_metric, mar_metric = mean_average_precision_mean_average_recall(
            self.pred_bboxes,
            self.pred_confidence_scores,
            self.target_bboxes,
            self.target_classes,
            iou_thresholds=self.iou_thresholds,
            average_precision_num_points=self.average_precision_num_points,
            min_confidence_threshold=self.min_confidence_threshold,
            max_bboxes_per_image=self.max_bboxes_per_image,
        )
        if return_metrics == "map_only":
            return torch.tensor(map_metric, device=self.pred_bboxes[0].device)
        elif return_metrics == "mar_only":
            return torch.tensor(mar_metric, device=self.pred_bboxes[0].device)
        raise NotImplementedError('Only "map_only" and "mar_only" are supported.')

In [None]:
# | export


class MeanAveragePrecision(_MeanAveragePrecisionMeanAverageRecallBase):
    """COCO mean average precision (mAP) metric for object detection.

    Accumulation of predictions and targets is inherited from the base class; only the choice of the returned metric is
    made here.
    """

    def compute(self):
        """Return the mAP over everything accumulated so far as a scalar tensor."""
        map_value = self.run_functional("map_only")
        return map_value

In [None]:
torch.manual_seed(0)  # Fixed seed so that the displayed values are reproducible across runs

test = MeanAveragePrecision(max_bboxes_per_image=100)

for _ in range(100):
    pred_bboxes = [convert_box_to_standard_mode(torch.rand(i + 5, 6) * 128, "cccwhd") for i in range(10)]
    pred_confidence_scores = [torch.rand(i + 5, 4) for i in range(10)]

    # Targets share the predicted boxes but get random classes
    target_bboxes = pred_bboxes
    target_classes = [torch.randint(1, 4, (i + 5,)) for i in range(10)]

    map_metric = test(
        pred_bboxes,
        pred_confidence_scores,
        target_bboxes,
        target_classes,
    )
    print(map_metric)

# The state holds one entry per image until it is reset
print(len(test.pred_bboxes))
test.reset()
print(len(test.pred_bboxes))

tensor(0.1404)
tensor(0.0991)
tensor(0.0623)
tensor(0.1222)
tensor(0.0668)
tensor(0.1412)
tensor(0.1888)
tensor(0.0514)
tensor(0.0890)
tensor(0.1042)
tensor(0.1249)
tensor(0.1432)
tensor(0.0962)
tensor(0.1763)
tensor(0.1479)
tensor(0.1851)
tensor(0.1664)
tensor(0.1232)
tensor(0.0904)
tensor(0.1330)
tensor(0.0839)
tensor(0.1173)
tensor(0.0551)
tensor(0.1098)
tensor(0.0788)
tensor(0.1773)
tensor(0.1451)
tensor(0.1120)
tensor(0.1534)
tensor(0.1284)
tensor(0.1153)
tensor(0.0499)
tensor(0.1323)
tensor(0.0828)
tensor(0.1559)
tensor(0.1699)
tensor(0.0912)
tensor(0.0595)
tensor(0.0677)
tensor(0.0897)
tensor(0.1617)
tensor(0.1017)
tensor(0.1365)
tensor(0.1214)
tensor(0.0767)
tensor(0.1222)
tensor(0.1675)
tensor(0.1821)
tensor(0.1345)
tensor(0.1514)
tensor(0.1188)
tensor(0.1550)
tensor(0.1478)
tensor(0.1590)
tensor(0.1045)
tensor(0.0940)
tensor(0.1022)
tensor(0.1750)
tensor(0.1999)
tensor(0.1444)
tensor(0.0864)
tensor(0.1668)
tensor(0.1513)
tensor(0.1144)
tensor(0.1269)
tensor(0.1276)
tensor(0.0

In [None]:
# | export


class MeanAverageRecall(_MeanAveragePrecisionMeanAverageRecallBase):
    """COCO mean average recall (mAR) metric for object detection.

    Accumulation of predictions and targets is inherited from the base class; only the choice of the returned metric is
    made here.
    """

    def compute(self):
        """Return the mAR over everything accumulated so far as a scalar tensor."""
        mar_value = self.run_functional("mar_only")
        return mar_value

In [None]:
torch.manual_seed(0)  # Fixed seed so that the displayed values are reproducible across runs

test = MeanAverageRecall(max_bboxes_per_image=100)

for _ in range(100):
    pred_bboxes = [convert_box_to_standard_mode(torch.rand(i + 5, 6) * 128, "cccwhd") for i in range(10)]
    pred_confidence_scores = [torch.rand(i + 5, 4) for i in range(10)]

    # Targets share the predicted boxes but get random classes
    target_bboxes = pred_bboxes
    target_classes = [torch.randint(1, 4, (i + 5,)) for i in range(10)]

    mar_metric = test(
        pred_bboxes,
        pred_confidence_scores,
        target_bboxes,
        target_classes,
    )
    print(mar_metric)

# The state holds one entry per image until it is reset
print(len(test.pred_bboxes))
test.reset()
print(len(test.pred_bboxes))

tensor(0.2619)
tensor(0.2240)
tensor(0.2443)
tensor(0.2224)
tensor(0.1682)
tensor(0.2160)
tensor(0.3165)
tensor(0.2741)
tensor(0.2338)
tensor(0.3032)
tensor(0.3166)
tensor(0.3248)
tensor(0.3276)
tensor(0.2116)
tensor(0.2416)
tensor(0.2536)
tensor(0.2780)
tensor(0.2766)
tensor(0.3098)
tensor(0.3622)
tensor(0.2107)
tensor(0.3280)
tensor(0.2614)
tensor(0.2324)
tensor(0.2317)
tensor(0.2069)
tensor(0.1509)
tensor(0.2263)
tensor(0.3582)
tensor(0.2630)
tensor(0.2254)
tensor(0.3161)
tensor(0.2550)
tensor(0.2412)
tensor(0.2006)
tensor(0.3181)
tensor(0.2735)
tensor(0.2209)
tensor(0.2254)
tensor(0.2716)
tensor(0.1993)
tensor(0.1679)
tensor(0.2543)
tensor(0.2125)
tensor(0.2533)
tensor(0.3133)
tensor(0.3558)
tensor(0.1845)
tensor(0.3096)
tensor(0.3223)
tensor(0.2793)
tensor(0.2643)
tensor(0.1999)
tensor(0.2638)
tensor(0.3167)
tensor(0.3174)
tensor(0.2171)
tensor(0.3020)
tensor(0.2225)
tensor(0.2793)
tensor(0.2211)
tensor(0.2657)
tensor(0.2245)
tensor(0.2605)
tensor(0.3287)
tensor(0.3077)
tensor(0.2

In [None]:
# | export


class AveragePrecision(MeanAveragePrecision):
    """Calculate the COCO average precision (AP) for object detection at a single IoU threshold."""

    def __init__(self, iou_threshold: float, *args, **kwargs):
        """Initialize the AveragePrecision metric.

        Args:
            iou_threshold: The single IoU threshold at which average precision is calculated.
            *args: Further positional arguments of the parent constructor, i.e. everything after `iou_thresholds`.
            **kwargs: Further keyword arguments of the parent constructor.
        """
        # The list is passed positionally on purpose: `iou_thresholds` is the first parameter of the parent
        # constructor, so passing it by keyword together with `*args` would make the first extra positional argument
        # collide with it and raise a TypeError.
        super().__init__([iou_threshold], *args, **kwargs)

In [None]:
torch.manual_seed(0)  # Fixed seed so that the displayed values are reproducible across runs

test = AveragePrecision(iou_threshold=0.1, max_bboxes_per_image=100)

for _ in range(100):
    pred_bboxes = [convert_box_to_standard_mode(torch.rand(i + 5, 6) * 128, "cccwhd") for i in range(10)]
    pred_confidence_scores = [torch.rand(i + 5, 4) for i in range(10)]

    # Targets share the predicted boxes but get random classes
    target_bboxes = pred_bboxes
    target_classes = [torch.randint(1, 4, (i + 5,)) for i in range(10)]

    ap10 = test(
        pred_bboxes,
        pred_confidence_scores,
        target_bboxes,
        target_classes,
    )
    print(ap10)

# The state holds one entry per image until it is reset
print(len(test.pred_bboxes))
test.reset()
print(len(test.pred_bboxes))

tensor(0.1928)
tensor(0.2184)
tensor(0.1991)
tensor(0.2696)
tensor(0.1385)
tensor(0.1686)
tensor(0.1754)
tensor(0.1207)
tensor(0.1436)
tensor(0.1552)
tensor(0.1292)
tensor(0.1554)
tensor(0.1337)
tensor(0.1203)
tensor(0.1738)
tensor(0.2090)
tensor(0.1681)
tensor(0.1294)
tensor(0.2085)
tensor(0.1600)
tensor(0.2941)
tensor(0.1859)
tensor(0.2067)
tensor(0.1602)
tensor(0.2566)
tensor(0.1181)
tensor(0.1716)
tensor(0.1515)
tensor(0.1993)
tensor(0.1820)
tensor(0.2554)
tensor(0.1597)
tensor(0.1419)
tensor(0.3179)
tensor(0.1547)
tensor(0.1632)
tensor(0.2357)
tensor(0.2232)
tensor(0.2361)
tensor(0.1290)
tensor(0.1436)
tensor(0.1443)
tensor(0.1413)
tensor(0.1707)
tensor(0.1946)
tensor(0.2324)
tensor(0.2241)
tensor(0.1922)
tensor(0.2396)
tensor(0.1638)
tensor(0.1751)
tensor(0.2145)
tensor(0.1496)
tensor(0.1936)
tensor(0.1394)
tensor(0.1144)
tensor(0.2061)
tensor(0.2624)
tensor(0.1995)
tensor(0.1759)
tensor(0.1834)
tensor(0.1364)
tensor(0.1722)
tensor(0.1137)
tensor(0.1549)
tensor(0.1783)
tensor(0.1

In [None]:
# | export


class AverageRecall(MeanAverageRecall):
    """Calculate the COCO average recall (AR) for object detection at a single IoU threshold."""

    def __init__(self, iou_threshold: float, *args, **kwargs):
        """Initialize the AverageRecall metric.

        Args:
            iou_threshold: The single IoU threshold at which average recall is calculated.
            *args: Further positional arguments of the parent constructor, i.e. everything after `iou_thresholds`.
            **kwargs: Further keyword arguments of the parent constructor.
        """
        # The list is passed positionally on purpose: `iou_thresholds` is the first parameter of the parent
        # constructor, so passing it by keyword together with `*args` would make the first extra positional argument
        # collide with it and raise a TypeError.
        super().__init__([iou_threshold], *args, **kwargs)

In [None]:
torch.manual_seed(0)  # Fixed seed so that the displayed values are reproducible across runs

test = AverageRecall(iou_threshold=0.1, max_bboxes_per_image=100)

for _ in range(100):
    pred_bboxes = [convert_box_to_standard_mode(torch.rand(i + 5, 6) * 128, "cccwhd") for i in range(10)]
    pred_confidence_scores = [torch.rand(i + 5, 4) for i in range(10)]

    # Targets share the predicted boxes but get random classes
    target_bboxes = pred_bboxes
    target_classes = [torch.randint(1, 4, (i + 5,)) for i in range(10)]

    ar10 = test(
        pred_bboxes,
        pred_confidence_scores,
        target_bboxes,
        target_classes,
    )
    print(ar10)

# The state holds one entry per image until it is reset
print(len(test.pred_bboxes))
test.reset()
print(len(test.pred_bboxes))

tensor(0.3118)
tensor(0.3958)
tensor(0.3361)
tensor(0.3693)
tensor(0.3675)
tensor(0.3491)
tensor(0.2971)
tensor(0.3728)
tensor(0.3054)
tensor(0.3075)
tensor(0.2865)
tensor(0.3315)
tensor(0.2254)
tensor(0.2240)
tensor(0.3199)
tensor(0.2811)
tensor(0.2381)
tensor(0.3233)
tensor(0.3201)
tensor(0.3192)
tensor(0.3095)
tensor(0.3231)
tensor(0.2634)
tensor(0.3040)
tensor(0.2982)
tensor(0.2189)
tensor(0.3628)
tensor(0.2727)
tensor(0.2615)
tensor(0.2737)
tensor(0.3659)
tensor(0.2063)
tensor(0.2967)
tensor(0.2901)
tensor(0.2788)
tensor(0.3231)
tensor(0.2992)
tensor(0.3917)
tensor(0.3527)
tensor(0.3106)
tensor(0.3959)
tensor(0.4212)
tensor(0.2536)
tensor(0.2991)
tensor(0.3865)
tensor(0.3193)
tensor(0.2706)
tensor(0.2346)
tensor(0.3431)
tensor(0.2574)
tensor(0.3275)
tensor(0.3302)
tensor(0.3136)
tensor(0.3662)
tensor(0.3728)
tensor(0.3269)
tensor(0.3012)
tensor(0.3756)
tensor(0.2569)
tensor(0.3337)
tensor(0.3097)
tensor(0.2257)
tensor(0.2854)
tensor(0.2753)
tensor(0.3474)
tensor(0.2519)
tensor(0.2

# nbdev

In [None]:
# Shell escape: runs the nbdev CLI, which presumably writes the `# | export` cells of this notebook to the module
# named by the `default_exp` directive in the first cell -- see the nbdev documentation.
!nbdev_export