In [None]:
#| default_exp roi_heads
#| export
import math
from typing import List, Tuple, Optional, Dict

import torch
from torch import nn

from qct_3d_nod_detect.structures import Instances3D, pairwise_iou_3d

def add_ground_truth_to_proposals_3d(
    targets: List[Instances3D],
    proposals: List[Instances3D],
) -> List[Instances3D]:
    """
    Augment each image's proposals with that image's ground-truth boxes.

    GT boxes are appended *after* the existing proposals. Each appended box
    gets an objectness logit of ``log((1 - 1e-10) / 1e-10)`` (~23), so that
    ``sigmoid(logit) ~= 1``. GT boxes are therefore treated as maximally
    confident proposals rather than ranking below ordinary ones.

    Args:
        targets: per-image ground truth; each entry must expose ``gt_boxes``
            (a boxes object with a ``.tensor`` attribute).
        proposals: per-image proposals with ``proposal_boxes``. Presumably
            they also carry ``objectness_logits`` so that ``Instances3D.cat``
            sees matching fields (TODO confirm against ``Instances3D.cat``).

    Returns:
        A new list with one entry per image. Images with no GT boxes are
        returned as the original (uncloned) proposals object. All other
        images get a new instance: the cloned proposals followed by the GT
        boxes.

    Raises:
        AssertionError: if ``targets`` and ``proposals`` differ in length.
    """
    assert len(targets) == len(proposals)

    # Logit whose sigmoid is ~1: P(object | GT box) should be essentially certain.
    gt_logit_value = math.log((1.0 - 1e-10) / (1.0 - (1.0 - 1e-10)))

    new_proposals = []

    for proposals_per_image, targets_per_image in zip(proposals, targets):

        if len(targets_per_image) == 0:
            new_proposals.append(proposals_per_image)
            continue

        # Clone to avoid in-place modification of the caller's proposals
        proposals_per_image = proposals_per_image.clone()

        gt_boxes = targets_per_image.gt_boxes
        device = gt_boxes.tensor.device

        # Wrap the GT boxes as proposals with the same fields as real proposals
        gt_proposals = Instances3D(proposals_per_image.image_size)
        gt_proposals.proposal_boxes = gt_boxes
        gt_proposals.objectness_logits = torch.full(
            (len(gt_boxes),), gt_logit_value, device=device
        )

        # Real proposals first, GT boxes appended at the end
        proposals_per_image = Instances3D.cat(
            [proposals_per_image, gt_proposals]
        )

        new_proposals.append(proposals_per_image)

    return new_proposals

def subsample_labels(
    labels: torch.Tensor,
    num_samples: int,
    positive_fraction: float,
    num_classes: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Randomly choose foreground and background indices from ``labels``.

    Label convention:
        ``0 <= label < num_classes`` -> foreground
        ``label == num_classes``     -> background
        anything else (e.g. ``-1``)  -> ignored, never sampled

    The foreground draw is capped at ``int(num_samples * positive_fraction)``.
    Background gets whatever budget remains, capped by how many background
    labels exist. The total returned can therefore be smaller than
    ``num_samples``.

    Args:
        labels: 1-D tensor of per-element labels, shape ``(N,)``.
        num_samples: total sampling budget.
        positive_fraction: maximum fraction of the budget given to foreground.
        num_classes: number of foreground classes (also the background label).

    Returns:
        ``(sampled_fg_idxs, sampled_bg_idxs)``: index tensors into ``labels``.
    """
    device = labels.device

    is_fg = labels.ge(0) & labels.lt(num_classes)
    is_bg = labels.eq(num_classes)
    fg_candidates = is_fg.nonzero().squeeze(1)
    bg_candidates = is_bg.nonzero().squeeze(1)

    fg_quota = min(int(num_samples * positive_fraction), fg_candidates.numel())
    bg_quota = min(num_samples - fg_quota, bg_candidates.numel())

    # Foreground is permuted first, then background (keeps RNG usage order).
    fg_order = torch.randperm(fg_candidates.numel(), device=device)
    bg_order = torch.randperm(bg_candidates.numel(), device=device)

    return fg_candidates[fg_order[:fg_quota]], bg_candidates[bg_order[:bg_quota]]

class ROIHeads3D(nn.Module):
    """
    Box-only 3D ROI head.

    In training mode, proposals are matched to ground truth, subsampled and
    labelled before feature pooling. In both modes, features are pooled per
    proposal by ``roi_pooler``, passed through ``box_head``, and scored by
    ``box_predictor``. All sub-components are injected; this class only
    wires them together.
    """

    def __init__(
        self,
        *,
        num_classes: int,
        batch_size_per_image: int,
        positive_fraction: float,
        proposal_matcher,
        proposal_append_gt: bool,
        roi_pooler,
        box_head,
        box_predictor,
        is_training,
    ):
        """
        Args:
            num_classes: number of foreground classes; the label value
                ``num_classes`` denotes background.
            batch_size_per_image: number of proposals sampled per image
                during training.
            positive_fraction: maximum fraction of sampled proposals that
                are foreground.
            proposal_matcher: callable taking the GT-vs-proposal IoU matrix
                from ``pairwise_iou_3d`` and returning
                ``(matched_idxs, matched_labels)``. Here ``matched_idxs``
                holds a GT index per proposal. In ``matched_labels``, 0 means
                background and -1 means ignore; any other value (presumably 1)
                means foreground.
            proposal_append_gt: if True, GT boxes are added to the proposals
                before matching.
            roi_pooler: callable ``(features, proposals) -> box_features``.
            box_head: callable ``box_features -> box_features``.
            box_predictor: callable producing predictions, which must also
                provide ``losses(predictions, proposals)`` and
                ``inference(predictions, proposals)``.
            is_training: initial train/eval mode. ``.train()`` / ``.eval()``
                on this module (or a parent) override it later.
        """
        super().__init__()
        self.num_classes = num_classes
        self.batch_size_per_image = batch_size_per_image
        self.positive_fraction = positive_fraction
        self.proposal_matcher = proposal_matcher
        self.proposal_append_gt = proposal_append_gt

        self.roi_pooler = roi_pooler
        self.box_head = box_head
        self.box_predictor = box_predictor

        # Use train() rather than assigning self.training, so the mode also
        # reaches any registered submodules (e.g. dropout/BN in box_head).
        self.train(bool(is_training))

    def _sample_proposals(
        self,
        matched_idxs: torch.Tensor,
        matched_labels: torch.Tensor,
        gt_classes: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Assign a class label to every proposal and subsample a training batch.

        Args:
            matched_idxs: per-proposal index of the matched GT box.
            matched_labels: per-proposal match label (0 = background,
                -1 = ignore, otherwise foreground).
            gt_classes: class of each GT box; empty when the image has no GT.

        Returns:
            ``(sampled_idxs, sampled_gt_classes)``. The indices point into
            the proposals, foreground first and then background. Classes are
            in ``[0, num_classes]``, where ``num_classes`` is background.
        """
        has_gt = gt_classes.numel() > 0

        if has_gt:
            # Advanced indexing returns a copy, so the in-place edits below
            # do not touch the caller's gt_classes.
            gt_classes = gt_classes[matched_idxs]
            gt_classes[matched_labels == 0] = self.num_classes
            gt_classes[matched_labels == -1] = -1
        else:
            # No GT: every proposal is background.
            gt_classes = torch.zeros_like(matched_idxs) + self.num_classes

        sampled_fg_idxs, sampled_bg_idxs = subsample_labels(
            gt_classes,
            self.batch_size_per_image,
            self.positive_fraction,
            self.num_classes,
        )

        sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0)
        return sampled_idxs, gt_classes[sampled_idxs]

    @torch.no_grad()
    def label_and_sample_proposals(
        self,
        proposals: List[Instances3D],
        targets: List[Instances3D],
    ) -> List[Instances3D]:
        """
        Match proposals to ground truth and sample a training batch per image.

        Args:
            proposals: per-image proposals with ``proposal_boxes``.
            targets: per-image ground truth with ``gt_boxes`` and ``gt_classes``.

        Returns:
            Per-image sampled proposals. Each carries ``gt_classes`` (values
            in ``[0, num_classes]``, with ``num_classes`` meaning background).
            When the image has GT, each also carries ``gt_boxes`` (the matched
            GT box for every sampled proposal). Images without GT get no
            ``gt_boxes`` field.

        Raises:
            ValueError: if ``proposals`` and ``targets`` differ in length.
        """
        if len(proposals) != len(targets):
            raise ValueError(
                f"Expected one target per image, got {len(proposals)} "
                f"proposal lists and {len(targets)} targets"
            )

        if self.proposal_append_gt:
            proposals = add_ground_truth_to_proposals_3d(targets, proposals)

        proposals_with_gt = []

        for proposals_per_image, targets_per_image in zip(proposals, targets):

            has_gt = len(targets_per_image) > 0

            if has_gt:
                match_quality_matrix = pairwise_iou_3d(
                    targets_per_image.gt_boxes,
                    proposals_per_image.proposal_boxes,
                )
                matched_idxs, matched_labels = self.proposal_matcher(match_quality_matrix)
            else:
                device = proposals_per_image.proposal_boxes.tensor.device
                matched_idxs = torch.zeros(
                    len(proposals_per_image), dtype=torch.int64, device=device
                )
                matched_labels = torch.zeros_like(matched_idxs)

            sampled_idxs, gt_classes = self._sample_proposals(
                matched_idxs,
                matched_labels,
                targets_per_image.gt_classes if has_gt else torch.empty(0),
            )

            # Keep only the sampled proposals. Targets are attached to this
            # sampled instance so that field lengths match its length.
            sampled_proposals = proposals_per_image[sampled_idxs]
            sampled_proposals.gt_classes = gt_classes

            if has_gt:
                sampled_targets = matched_idxs[sampled_idxs]
                sampled_proposals.gt_boxes = targets_per_image.gt_boxes[sampled_targets]

            proposals_with_gt.append(sampled_proposals)

        return proposals_with_gt

    def forward(
        self,
        features: Dict[str, torch.Tensor],
        proposals: List[Instances3D],
        targets: Optional[List[Instances3D]] = None,
    ):
        """
        Run the ROI head.

        Args:
            features: named feature maps, passed unchanged to ``roi_pooler``.
            proposals: per-image proposals with ``proposal_boxes``.
            targets: per-image ground truth; required in training mode.

        Returns:
            ``box_predictor.losses(predictions, proposals)`` in training mode,
            otherwise ``box_predictor.inference(predictions, proposals)``.

        Raises:
            ValueError: if in training mode and ``targets`` is None.
        """
        if self.training:
            if targets is None:
                raise ValueError("targets are required when ROIHeads3D is in training mode")
            proposals = self.label_and_sample_proposals(proposals, targets)

        box_features = self.roi_pooler(features, proposals)
        box_features = self.box_head(box_features)

        predictions = self.box_predictor(box_features)

        if self.training:
            return self.box_predictor.losses(predictions, proposals)
        return self.box_predictor.inference(predictions, proposals)