In [1]:
#| default_exp rpn
#| export
from torch import nn
from typing import List, Union, Tuple, Dict, Optional
import torch
from qct_3d_nod_detect.anchor_generator_3d import build_anchor_generator_3d
from qct_3d_nod_detect.structures import Boxes3D, Instances3D, pairwise_iou_3d
from qct_3d_nod_detect.box_regression import Box3DTransform, _dense_box_regression_loss_3d
from qct_3d_nod_detect.matcher import Matcher
from qct_3d_nod_detect.layers import ShapeSpec, cat
from qct_3d_nod_detect.sampling import subsample_labels
import torch.nn.functional as F
from qct_3d_nod_detect.memory import retry_if_cuda_oom
from qct_3d_nod_detect.proposal_utils import find_top_rpn_proposals_3d

class StandardRPNHead3d(nn.Module):
    """
    3D RPN head shared across all feature levels.

    A stack of 3x3x3 conv + ReLU layers is followed by two sibling 1x1x1 convs.
    For every spatial location they predict ``num_anchors`` objectness logits
    and ``num_anchors * box_dim`` box-regression deltas.
    """

    def __init__(
        self,
        in_channels: int,
        num_anchors: int,
        box_dim: int = 6,
        conv_dims: Union[List[int], Tuple[int, ...]] = (-1,),
    ):
        """
        Args:
            in_channels: number of channels of every input feature map.
            num_anchors: number of anchors per spatial location.
            box_dim: number of regression deltas predicted per anchor.
            conv_dims: output channels of each shared 3x3x3 conv layer. An entry
                of ``-1`` keeps that layer's input channel count. The default is
                a tuple, to avoid a shared mutable default list.

        Raises:
            ValueError: if a resolved conv output channel count is not positive.
        """
        super().__init__()
        cur_channels = in_channels

        if len(conv_dims) == 1:
            # A single conv is stored directly as Sequential(conv, ReLU), not
            # nested, so its parameter names stay "conv.0.*".
            out_channels = cur_channels if conv_dims[0] == -1 else conv_dims[0]
            if out_channels <= 0:
                raise ValueError(f"Conv output channels should be greater than 0. Got {out_channels}")
            self.conv = self._get_rpn_conv(cur_channels, out_channels)  # Activation after conv
            cur_channels = out_channels
        else:
            self.conv = nn.Sequential()
            for k, conv_dim in enumerate(conv_dims):
                # -1 means "keep the current channel count" for this layer.
                out_channels = cur_channels if conv_dim == -1 else conv_dim
                if out_channels <= 0:
                    raise ValueError(f"Conv output channels should be greater than 0. Got {out_channels}")

                conv = self._get_rpn_conv(cur_channels, out_channels)
                self.conv.add_module(f"conv{k}", conv)
                cur_channels = out_channels

        # 1x1x1 conv for objectness logits: one channel per anchor.
        self.objectness_logits = nn.Conv3d(cur_channels, num_anchors, kernel_size=1, stride=1)
        # 1x1x1 conv for box deltas: box_dim channels per anchor.
        self.anchor_deltas = nn.Conv3d(cur_channels, num_anchors * box_dim, kernel_size=1, stride=1)

        # Initialise every Conv3d (shared convs and both predictors) with small
        # random weights and zero bias.
        for layer in self.modules():
            if isinstance(layer, nn.Conv3d):
                nn.init.normal_(layer.weight, std=0.01)
                nn.init.constant_(layer.bias, 0)

    def _get_rpn_conv(self, in_channels, out_channels):
        """Return a shape-preserving 3x3x3 conv (stride 1, padding 1) followed by ReLU."""
        conv = nn.Sequential(
            nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU()
        )
        return conv

    @classmethod
    def from_config(cls, cfg, input_shape):
        """
        Build constructor kwargs from a config node and per-level shape specs.

        All levels must have the same channel count and the same number of
        anchors per location, as reported by the anchor generator built from
        ``cfg``.
        """
        in_channels = [s.channels for s in input_shape]
        assert len(set(in_channels)) == 1, "Each level must have the same channel"

        in_channels = in_channels[0]
        anchor_generator = build_anchor_generator_3d(cfg, input_shape)
        num_anchors = anchor_generator.num_anchors
        box_dim = anchor_generator.box_dim  # Should be 6
        assert len(set(num_anchors)) == 1, "Each level must have the same number of anchors"

        return {
            "in_channels": in_channels,
            "num_anchors": num_anchors[0],
            "box_dim": box_dim,
            "conv_dims": cfg.MODEL.RPN.CONV_DIMS
        }

    def forward(self, features: List[torch.Tensor]):
        """
        Apply the head to each feature map.

        Args:
            features: list of (N, C, D, H, W) tensors, one per level.

        Returns:
            (pred_objectness_logits, pred_anchor_deltas): two lists with one
            entry per level, of shapes (N, A, D, H, W) and
            (N, A * box_dim, D, H, W).
        """
        pred_objectness_logits = []
        pred_anchor_deltas = []

        for x in features:
            t = self.conv(x)
            pred_objectness_logits.append(self.objectness_logits(t))
            pred_anchor_deltas.append(self.anchor_deltas(t))

        return pred_objectness_logits, pred_anchor_deltas



In [2]:
# Demo head: 256 input channels, 3 anchors per location, 6 deltas per anchor,
# default single shared conv layer.
standard_rpn = StandardRPNHead3d(in_channels=256, num_anchors=3, box_dim=6)

# Random (N=1, C=256, D, H, W) feature maps for three cubic levels.
features = [torch.rand(1, 256, side, side, side) for side in (32, 64, 128)]

In [4]:
# Run the head over every level; it returns (objectness logits, anchor deltas).
head_outputs = standard_rpn(features)
pred_objectness_logits, pred_anchor_deltas = head_outputs

In [11]:
pred_objectness_logits[0].size()  # level 0 logits: (N, A, D, H, W)
pred_anchor_deltas[0].size()  # level 0 deltas: (N, A * box_dim, D, H, W); last expression is the one echoed

torch.Size([1, 18, 32, 32, 32])

In [None]:
#| export
def build_rpn_head_3d(
    input_shapes: List[ShapeSpec],
    cfg=None,
    num_anchors: int = 9,
    box_dim: int = 6,
    conv_dims: Optional[List[int]] = None,
):
    """
    Very simple 3D RPN head builder: no registry, just creates a StandardRPNHead3d.

    Args:
        input_shapes: shape specs of the feature levels the head will run on.
            All levels must have the same number of channels.
        cfg: currently unused; accepted so existing call sites keep working.
        num_anchors: anchors per spatial location. It must equal the
            per-location anchor count of the anchor generator used alongside
            this head, so that predictions line up with anchors.
        box_dim: number of regression deltas predicted per anchor.
        conv_dims: output channels of the shared conv layers. ``None`` means
            ``[-1]``: a single conv that keeps ``in_channels``.

    Returns:
        StandardRPNHead3d: the constructed head.
    """
    # All levels share one head, so they must agree on channel count.
    in_channels = input_shapes[0].channels
    assert all(s.channels == in_channels for s in input_shapes), \
        "All feature levels must have the same number of channels"

    return StandardRPNHead3d(
        in_channels=in_channels,
        num_anchors=num_anchors,
        box_dim=box_dim,
        conv_dims=[-1] if conv_dims is None else list(conv_dims),
    )

class RPN3D(nn.Module):
    """
    Region Proposal Network for 3D (volumetric) feature maps.

    For the feature maps named in ``in_features`` it:

    - runs ``head`` to get per-anchor objectness logits and box deltas;
    - decodes the deltas against the anchors from ``anchor_generator``;
    - selects proposals with ``find_top_rpn_proposals_3d``.

    When ``training`` is True it also labels and samples anchors against the
    ground truth and returns the RPN losses.
    """

    def __init__(
        self,
        in_features: List[str],
        head: nn.Module,
        anchor_generator: nn.Module,
        anchor_matcher: Matcher,
        box3d_transform: Box3DTransform,  # 3D box transform
        batch_size_per_image: int,
        positive_fraction: float,
        pre_nms_topk: Tuple[float, float],
        post_nms_topk: Tuple[float, float],
        nms_thresh: float = 0.7,
        min_box_size: float = 0.0,
        anchor_boundary_thresh: float = -1.0,
        loss_weight: Union[float, Dict[str, float]] = 1.0,
        box_reg_loss_type: str = "smooth_l1",
        smooth_l1_beta: float = 0.0,
    ):
        """
        Args:
            in_features: keys of the ``features`` dict (see :meth:`forward`) to use.
            head: module mapping a list of feature maps to
                ``(pred_objectness_logits, pred_anchor_deltas)`` lists.
            anchor_generator: module producing anchors for the feature maps. It
                must expose ``box_dim``.
            anchor_matcher: assigns anchors to ground truth from an IoU matrix.
            box3d_transform: decodes predicted deltas into boxes
                (``apply_deltas``) and is passed to the regression loss.
            batch_size_per_image: number of anchors sampled per image for the loss.
                It is also the loss normaliser, multiplied by the number of images.
            positive_fraction: target fraction of positives among sampled anchors.
            pre_nms_topk: (train, test) pre-NMS top-k, forwarded to
                ``find_top_rpn_proposals_3d``.
            post_nms_topk: (train, test) post-NMS top-k, forwarded likewise.
            nms_thresh: NMS threshold, forwarded to ``find_top_rpn_proposals_3d``.
            min_box_size: minimum proposal size, forwarded likewise.
            anchor_boundary_thresh: if >= 0, anchors for which
                ``inside_box_3d(image_size, thresh)`` is False are labelled -1
                (ignored). A negative value disables this.
            loss_weight: a scalar applied to both losses, or a dict keyed by
                ``"loss_rpn_cls"`` / ``"loss_rpn_loc"``.
            box_reg_loss_type: regression loss name, forwarded to
                ``_dense_box_regression_loss_3d``.
            smooth_l1_beta: smooth-L1 beta, forwarded likewise.
        """
        super().__init__()

        self.in_features = in_features
        self.rpn_head = head
        self.anchor_generator = anchor_generator
        self.anchor_matcher = anchor_matcher
        self.box3d_transform = box3d_transform
        self.batch_size_per_image = batch_size_per_image
        self.positive_fraction = positive_fraction
        # Looked up with the ``training`` flag: True -> train value, False -> test value.
        self.pre_nms_topk = {True: pre_nms_topk[0], False: pre_nms_topk[1]}
        self.post_nms_topk = {True: post_nms_topk[0], False: post_nms_topk[1]}
        self.nms_thresh = nms_thresh
        self.min_box_size = float(min_box_size)
        self.anchor_boundary_thresh = anchor_boundary_thresh

        # Expand a scalar weight to both losses. Ints (e.g. a config value of 1)
        # count as scalars too; otherwise losses() would call .get on an int.
        if isinstance(loss_weight, (int, float)):
            loss_weight = {"loss_rpn_cls": loss_weight, "loss_rpn_loc": loss_weight}

        self.loss_weight = loss_weight
        self.box_reg_loss_type = box_reg_loss_type
        self.smooth_l1_beta = smooth_l1_beta

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        """
        Build the constructor kwargs of :class:`RPN3D` from a config node.

        The head is sized from the anchor generator (anchors per location and
        box dimension), so its predictions line up with the generated anchors.
        """
        in_features = cfg.MODEL.RPN.IN_FEATURES
        feature_shapes = [input_shape[f] for f in in_features]
        ret = {
            "in_features": in_features,
            "min_box_size": cfg.MODEL.PROPOSAL_GENERATOR.MIN_SIZE,
            "nms_thresh": cfg.MODEL.RPN.NMS_THRESH,
            "batch_size_per_image": cfg.MODEL.RPN.BATCH_SIZE_PER_IMAGE,
            "positive_fraction": cfg.MODEL.RPN.POSITIVE_FRACTION,
            "loss_weight": {
                "loss_rpn_cls": cfg.MODEL.RPN.LOSS_WEIGHT,
                "loss_rpn_loc": cfg.MODEL.RPN.BBOX_REG_LOSS_WEIGHT * cfg.MODEL.RPN.LOSS_WEIGHT,
            },
            "anchor_boundary_thresh": cfg.MODEL.RPN.BOUNDARY_THRESH,
            "box3d_transform": Box3DTransform(weights=cfg.MODEL.RPN.BBOX_REG_WEIGHTS),  # 3D weights (e.g., tuple of 6)
            "box_reg_loss_type": cfg.MODEL.RPN.BBOX_REG_LOSS_TYPE,
            "smooth_l1_beta": cfg.MODEL.RPN.SMOOTH_L1_BETA,
        }
        ret["pre_nms_topk"] = (cfg.MODEL.RPN.PRE_NMS_TOPK_TRAIN, cfg.MODEL.RPN.PRE_NMS_TOPK_TEST)
        ret["post_nms_topk"] = (cfg.MODEL.RPN.POST_NMS_TOPK_TRAIN, cfg.MODEL.RPN.POST_NMS_TOPK_TEST)

        anchor_generator = build_anchor_generator_3d(cfg, feature_shapes)
        ret["anchor_generator"] = anchor_generator
        ret["anchor_matcher"] = Matcher(
            cfg.MODEL.RPN.IOU_THRESHOLDS, cfg.MODEL.RPN.IOU_LABELS, allow_low_quality_matches=True
        )

        # A fixed anchor count here would silently disagree with the anchor
        # generator, so take both the anchor count and box_dim from it.
        num_anchors = anchor_generator.num_anchors
        assert len(set(num_anchors)) == 1, "Each level must have the same number of anchors"
        in_channels = feature_shapes[0].channels
        assert all(s.channels == in_channels for s in feature_shapes), \
            "All feature levels must have the same number of channels"
        ret["head"] = StandardRPNHead3d(
            in_channels=in_channels,
            num_anchors=num_anchors[0],
            box_dim=anchor_generator.box_dim,
            conv_dims=[-1],
        )
        return ret

    def _subsample_labels(self, label):
        """
        Sample at most ``batch_size_per_image`` anchors with the configured
        ``positive_fraction``. The result is written IN PLACE into ``label``:
        sampled positives become 1, sampled negatives 0, everything else -1.
        Returns the same tensor.
        """
        pos_idx, neg_idx = subsample_labels(
            label, self.batch_size_per_image, self.positive_fraction, 0
        )

        label.fill_(-1)
        label.scatter_(0, pos_idx, 1)
        label.scatter_(0, neg_idx, 0)

        return label

    @torch.no_grad()
    def label_and_sample_anchors(
        self, anchors: List[Boxes3D], gt_instances: List[Instances3D]
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Args:
            anchors (list[Boxes3D]): anchors for each feature map.
            gt_instances (list[Instances3D]): ground truth per image. Each item
                must provide ``gt_boxes`` and ``image_size``.

        Returns:
            list[Tensor]:
                One label vector per image, of length R = total number of
                anchors over all feature maps (sum of Di * Hi * Wi * A). Values
                are -1 = ignore, 0 = negative, 1 = positive.
            list[Tensor]:
                One (R, B) tensor per image (B = box dimension) with the matched
                gt box of every anchor. It is only meaningful where the label is
                1 and is all zeros for images without gt boxes.
        """
        anchors = Boxes3D.cat(anchors)
        gt_boxes = [x.gt_boxes for x in gt_instances]
        image_sizes = [x.image_size for x in gt_instances]

        del gt_instances

        gt_labels = []
        matched_gt_boxes = []

        for image_size_i, gt_boxes_i in zip(image_sizes, gt_boxes):

            # (num_gt, R) IoU matrix between this image's gt boxes and all anchors.
            match_quality_matrix = retry_if_cuda_oom(pairwise_iou_3d)(gt_boxes_i, anchors)
            matched_idxs, gt_labels_i = retry_if_cuda_oom(self.anchor_matcher)(match_quality_matrix)
            # Keep the labels on the same device as the gt boxes.
            gt_labels_i = gt_labels_i.to(device=gt_boxes_i.device)

            del match_quality_matrix

            if self.anchor_boundary_thresh >= 0:
                anchors_inside_image = anchors.inside_box_3d(image_size_i, self.anchor_boundary_thresh)
                gt_labels_i[~anchors_inside_image] = -1

            gt_labels_i = self._subsample_labels(gt_labels_i)

            if gt_boxes_i.tensor.shape[0] == 0:
                # No ground truth in this image: matched_idxs cannot index into
                # an empty box set, so return placeholder zeros (never used,
                # since no anchor can be positive).
                matched_gt_boxes_i = torch.zeros_like(anchors.tensor)
            else:
                matched_gt_boxes_i = gt_boxes_i[matched_idxs].tensor

            gt_labels.append(gt_labels_i)
            matched_gt_boxes.append(matched_gt_boxes_i)

        return gt_labels, matched_gt_boxes

    @torch.jit.unused
    def losses(
        self,
        anchors: List[Boxes3D],
        pred_objectness_logits: List[torch.Tensor],
        pred_anchor_deltas: List[torch.Tensor],
        gt_labels: List[torch.Tensor],
        gt_boxes: List[torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Return the losses from a set of RPN predictions and their associated ground truth.

        Args:
            anchors (list[Boxes3D]): anchors for each feature map; level i has
                Di*Hi*Wi*A boxes.
            pred_objectness_logits (list[Tensor]): L elements. Element i has shape
                (N, Di*Hi*Wi*A) and holds the objectness logits of all anchors.
            pred_anchor_deltas (list[Tensor]): L elements. Element i has shape
                (N, Di*Hi*Wi*A, box_dim) and holds the predicted deltas that
                turn anchors into proposals.
            gt_labels (list[Tensor]): output of :meth:`label_and_sample_anchors`.
            gt_boxes (list[Tensor]): output of :meth:`label_and_sample_anchors`.

        Returns:
            dict[str, Tensor]: ``loss_rpn_cls`` (objectness classification) and
            ``loss_rpn_loc`` (proposal localisation). Each is normalised by
            ``batch_size_per_image * num_images`` and scaled by ``loss_weight``.
        """
        num_images = len(gt_labels)
        gt_labels = torch.stack(gt_labels)  # (N, R)

        pos_mask = gt_labels == 1

        localization_loss = _dense_box_regression_loss_3d(
            anchors,
            self.box3d_transform,
            pred_anchor_deltas,
            gt_boxes,
            pos_mask,
            box_reg_loss_type=self.box_reg_loss_type,
            smooth_l1_beta=self.smooth_l1_beta
        )

        # Only sampled anchors (label 0 or 1) contribute to the objectness loss.
        valid_mask = gt_labels >= 0

        objectness_loss = F.binary_cross_entropy_with_logits(
            cat(pred_objectness_logits, dim=1)[valid_mask],
            gt_labels[valid_mask].to(torch.float32),
            reduction="sum"
        )

        normalizer = self.batch_size_per_image * num_images
        losses = {
            "loss_rpn_cls": objectness_loss / normalizer,
            "loss_rpn_loc": localization_loss / normalizer,
        }

        losses = {k: v * self.loss_weight.get(k, 1.0) for k, v in losses.items()}

        return losses

    def forward(
        self,
        images,
        features: Dict[str, torch.Tensor],
        gt_instances: Optional[List[Instances3D]] = None,
        training: bool = True
    ):
        """
        Args:
            images: object exposing ``image_sizes`` (per-image sizes, forwarded
                to proposal selection).
            features: feature maps keyed by name. Only ``self.in_features`` are used.
            gt_instances: ground truth per image; required when ``training``.
            training: selects the train/test top-k values and enables losses.

        Returns:
            (proposals, losses): proposals from :meth:`predict_proposals`, and a
            loss dict, which is empty when not training.
        """
        features = [features[f] for f in self.in_features]
        anchors = self.anchor_generator(features)
        anchors = [Boxes3D(anchor) for anchor in anchors]

        pred_objectness_logits, pred_anchor_deltas = self.rpn_head(features)

        # Objectness logits:
        # (N, A, D, H, W) -> (N, D, H, W, A) -> (N, D*H*W*A)
        pred_objectness_logits = [
            score.permute(0, 2, 3, 4, 1).flatten(1)
            for score in pred_objectness_logits
        ]

        # Box deltas:
        # (N, A*6, D, H, W)
        # -> (N, A, 6, D, H, W)
        # -> (N, D, H, W, A, 6)
        # -> (N, D*H*W*A, 6)
        pred_anchor_deltas = [
            x.view(x.shape[0], -1, self.anchor_generator.box_dim, x.shape[-3], x.shape[-2], x.shape[-1])
            .permute(0, 3, 4, 5, 1, 2)
            .flatten(1, -2)
            for x in pred_anchor_deltas
        ]

        if training:

            assert gt_instances is not None, "RPN3D requires gt_instances in training!"

            gt_labels, gt_boxes = self.label_and_sample_anchors(
                anchors, gt_instances
            )

            losses = self.losses(
                anchors,
                pred_objectness_logits,
                pred_anchor_deltas,
                gt_labels,
                gt_boxes
            )

        else:
            losses = {}

        proposals = self.predict_proposals(
            anchors,
            pred_objectness_logits,
            pred_anchor_deltas,
            images.image_sizes,
            training
        )

        return proposals, losses

    def predict_proposals(
        self,
        anchors: List[Boxes3D],
        pred_objectness_logits: List[torch.Tensor],
        pred_anchor_deltas: List[torch.Tensor],
        image_sizes: List[Tuple[int, ...]],
        training: bool = True,
    ):
        """
        Decode all predicted box-regression deltas into proposals, then select
        the top proposals with ``find_top_rpn_proposals_3d`` (NMS and min-size
        filtering are delegated to it). Runs without gradient tracking.

        Returns:
            Whatever ``find_top_rpn_proposals_3d`` returns. Presumably that is
            one Instances per image, sorted by objectness (TODO confirm
            against proposal_utils).
        """
        with torch.no_grad():
            pred_proposals = self._decode_proposals_3d(anchors, pred_anchor_deltas)

            return find_top_rpn_proposals_3d(
                pred_proposals,
                pred_objectness_logits,
                image_sizes,
                self.nms_thresh,
                self.pre_nms_topk[training],
                self.post_nms_topk[training],
                self.min_box_size,
                training
            )

    def _decode_proposals_3d(self, anchors: List[Boxes3D], pred_anchor_deltas: List[torch.Tensor]):
        """
        Transform anchors into proposals by applying the predicted anchor deltas.

        Returns:
            proposals (list[Tensor]): L tensors. Tensor i has shape
                (N, Di*Hi*Wi*A, B), where B is the anchor box dimension.
        """
        N = pred_anchor_deltas[0].shape[0]
        proposals = []

        for anchors_i, pred_anchor_deltas_i in zip(anchors, pred_anchor_deltas):
            B = anchors_i.tensor.size(1)
            pred_anchor_deltas_i = pred_anchor_deltas_i.reshape(-1, B)

            # Repeat this level's anchors for each image so they pair 1:1 with the deltas.
            anchors_i = anchors_i.tensor.unsqueeze(0).expand(N, -1, -1).reshape(-1, B)
            proposals_i = self.box3d_transform.apply_deltas(
                pred_anchor_deltas_i, anchors_i
            )

            proposals.append(proposals_i.view(N, -1, B))

        return proposals