In [None]:
#| default_exp proposal_utils
#| export

import torch
from typing import List, Tuple
from qct_3d_nod_detect.structures import Boxes3D, Instances3D
from qct_3d_nod_detect.box_regression import pairwise_iou_3d

def nms_3d(
    boxes: torch.Tensor,
    scores: torch.Tensor,
    iou_threshold: float,
):

    """
    Greedy 3D non-maximum suppression built on pairwise_iou_3d(Boxes3D, Boxes3D).

    Boxes are visited from highest to lowest score. Each visited box is kept,
    and every remaining box whose IoU with it is strictly greater than
    `iou_threshold` is discarded.

    Args:
        boxes: box tensor, one row per box; it is wrapped in Boxes3D, so it
            must be in whatever layout Boxes3D expects.
        scores: one score per box.
        iou_threshold: remaining boxes with IoU <= this value survive a round.

    Returns:
        1-D tensor of indices into `boxes` for the kept boxes, ordered by
        decreasing score. An empty int64 tensor on `boxes.device` is returned
        when `boxes` has no elements.
    """

    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.long, device=boxes.device)

    box_tensor = Boxes3D(boxes).tensor
    _, ranking = torch.sort(scores, descending=True)
    selected = []

    while ranking.numel() > 0:
        # The head of the ranking is the best-scoring box still alive.
        best, rest = ranking[0], ranking[1:]
        selected.append(best)

        if rest.numel() == 0:
            break

        # IoU of the selected box against every remaining candidate.
        overlap = pairwise_iou_3d(
            Boxes3D(box_tensor[best].unsqueeze(0)),
            Boxes3D(box_tensor[rest]),
        ).squeeze(0)

        # Boolean masking preserves order, so the ranking stays score-sorted.
        ranking = rest[overlap <= iou_threshold]

    return torch.stack(selected)

def batched_nms_3d(
    boxes: torch.Tensor,
    scores: torch.Tensor,
    idxs: torch.Tensor,
    iou_threshold: float
):

    """
    3D version of torchvision.ops.batched_nms

    NMS is applied independently per group: a box can only suppress boxes
    that share its value in `idxs`. This is done by shifting every box by a
    group-specific offset large enough that boxes of different groups cannot
    overlap, then running `nms_3d` once on the shifted boxes.

    Args:
        boxes: box tensor, one row per box. Every column is shifted by the
            group offset, so all columns are presumed to be coordinates
            (corner-style layout) -- TODO confirm against Boxes3D.
        scores: one score per box.
        idxs: group id per box (e.g. the feature level the box came from).
        iou_threshold: IoU threshold forwarded to `nms_3d`.

    Returns:
        The result of `nms_3d` on the shifted boxes (indices into `boxes`).
        An empty int64 tensor on `boxes.device` is returned when `boxes` has
        no elements.
    """

    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)

    # Shift by the full coordinate span rather than `max + 1`: with negative
    # coordinates `max + 1` can be zero or negative, in which case boxes from
    # different groups would still overlap and suppress each other. For
    # non-negative inputs the groups are separated exactly as before.
    span = boxes.max() - boxes.min() + 1
    offsets = idxs.to(boxes) * span

    # Broadcasting adds each box's offset to all of its columns and leaves
    # the caller's tensor untouched.
    boxes_for_nms = boxes + offsets[:, None]

    return nms_3d(boxes_for_nms, scores, iou_threshold)

def find_top_rpn_proposals_3d(
    proposals: List[torch.Tensor],
    pred_objectness_logits: List[torch.Tensor],
    image_sizes: List[Tuple[int, int, int]],
    nms_thresh: float,
    pre_nms_topk: int,
    post_nms_topk: int,
    min_box_size: float,
    training: bool
):

    """
    Select the final RPN proposals for every image in a batch.

    For each feature level the `pre_nms_topk` highest-scoring candidates are
    taken per image. The levels are then concatenated and, image by image,
    the boxes are filtered for non-finite values, clipped to the image,
    filtered by `min_box_size`, run through per-level NMS and truncated to
    `post_nms_topk`.

    Args:
        proposals: one tensor per feature level, indexed as
            [image, candidate, box coordinate] (presumably 6 box values,
            matching the Boxes3D layout -- confirm against the caller).
        pred_objectness_logits: one tensor per feature level, indexed as
            [image, candidate]; aligned with `proposals`.
        image_sizes: one size tuple per image, used for clipping and stored
            on the returned Instances3D.
        nms_thresh: IoU threshold passed to `batched_nms_3d`.
        pre_nms_topk: maximum candidates kept per level before NMS.
        post_nms_topk: maximum proposals kept per image after NMS.
        min_box_size: threshold passed to `Boxes3D.nonempty`.
        training: when True, non-finite boxes or scores raise instead of
            being dropped.

    Returns:
        A list with one Instances3D per image, carrying `proposal_boxes`
        and `objectness_logits`.

    Raises:
        FloatingPointError: if `training` is True and any selected box or
            score is Inf/NaN.
    """

    # Index helpers are created on CPU only while tracing (and not scripting).
    if torch.jit.is_scripting() or not torch.jit.is_tracing():
        device = proposals[0].device
    else:
        device = "cpu"

    # Column of image indices, broadcast against each level's top-k indices.
    image_rows = torch.arange(len(image_sizes), device=device)[:, None]

    # Stage 1: per-level top-k selection.
    per_level_scores, per_level_boxes, per_level_ids = [], [], []
    level_pairs = zip(proposals, pred_objectness_logits)
    for level, (level_boxes, level_logits) in enumerate(level_pairs):
        num_candidates = level_logits.shape[1]

        # The shape entry may be a tensor rather than an int (e.g. under tracing).
        if isinstance(num_candidates, torch.Tensor):
            k = torch.clamp(num_candidates, max=pre_nms_topk)
        else:
            k = min(num_candidates, pre_nms_topk)

        best_scores, best_idx = level_logits.topk(k, dim=1)

        per_level_scores.append(best_scores)
        per_level_boxes.append(level_boxes[image_rows, best_idx])
        per_level_ids.append(
            torch.full((k, ), level, dtype=torch.int64, device=device)
        )

    # Stage 2: merge all levels.
    all_scores = torch.cat(per_level_scores, dim=1)  # (N, sum(topk))
    all_boxes = torch.cat(per_level_boxes, dim=1)  # (N, sum(topk), 6)
    all_level_ids = torch.cat(per_level_ids, dim=0)  # (sum(topk),)

    # Stage 3: per-image filtering and NMS.
    results = []
    for img_idx, image_size in enumerate(image_sizes):
        img_boxes = Boxes3D(all_boxes[img_idx])  # (K, 6)
        img_scores = all_scores[img_idx]
        img_levels = all_level_ids

        finite = (
            torch.isfinite(img_boxes.tensor).all(dim=1)
            & torch.isfinite(img_scores)
        )
        if not finite.all():
            if training:
                raise FloatingPointError(
                    "Predicted boxes or scores contain Inf/NaN. Training has diverged."
                )
            img_boxes, img_scores, img_levels = (
                img_boxes[finite],
                img_scores[finite],
                img_levels[finite],
            )

        img_boxes.clip(image_size)

        # Drop boxes that fail the minimum-size check; skip the indexing
        # entirely when every box passes.
        large_enough = img_boxes.nonempty(threshold=min_box_size)
        if large_enough.sum().item() != len(img_boxes):
            img_boxes, img_scores, img_levels = (
                img_boxes[large_enough],
                img_scores[large_enough],
                img_levels[large_enough],
            )

        # NMS is run per level (level ids act as the group ids), then only
        # the leading `post_nms_topk` survivors are retained.
        survivors = batched_nms_3d(
            img_boxes.tensor, img_scores, img_levels, nms_thresh
        )[:post_nms_topk]

        instances = Instances3D(image_size)
        instances.proposal_boxes = img_boxes[survivors]
        instances.objectness_logits = img_scores[survivors]
        results.append(instances)

    return results
