In [None]:
#| default_exp faster_rcnn
#| export
from typing import List, Union, Tuple, Dict
import torch
import torch.nn as nn
from qct_3d_nod_detect.box_regression import _dense_box_regression_loss_3d
from qct_3d_nod_detect.structures import Boxes3D, Instances3D
from qct_3d_nod_detect.layers import nonzero_tuple
import torch.nn.functional as F
from qct_3d_nod_detect.proposal_utils import batched_nms_3d

def fast_rcnn_inference_3d(
    boxes: List[torch.Tensor],
    scores: List[torch.Tensor],
    volume_shapes: List[Tuple[int, int, int]],
    score_thresh: float,
    nms_thresh: float,
    topk_per_image: int,
):
    """
    Apply ``fast_rcnn_inference_single_image_3d`` to every volume in a batch.

    Args:
        boxes: per-volume predicted boxes, one tensor per volume.
        scores: per-volume class scores, one tensor per volume.
        volume_shapes: per-volume (D, H, W) shapes used for clipping.
        score_thresh, nms_thresh, topk_per_image: forwarded unchanged to the
            single-volume function.

    Returns:
        A pair ``(results, kept_indices)`` of lists. Entry i of each list is
        the first / second element returned by the single-volume call for
        volume i. The three input lists are walked in lockstep with ``zip``,
        so the output length equals the shortest input list.
    """
    per_volume_results = []
    per_volume_kept = []
    for vol_boxes, vol_scores, vol_shape in zip(boxes, scores, volume_shapes):
        outputs = fast_rcnn_inference_single_image_3d(
            vol_boxes,
            vol_scores,
            vol_shape,
            score_thresh,
            nms_thresh,
            topk_per_image,
        )
        per_volume_results.append(outputs[0])
        per_volume_kept.append(outputs[1])
    return per_volume_results, per_volume_kept

def fast_rcnn_inference_single_image_3d(
    boxes: torch.Tensor,
    scores: torch.Tensor,
    volume_shape: Tuple[int, int, int],
    score_thresh: float,
    nms_thresh: float,
    topk_per_image: int,
):
    """
    Post-process one volume's box-head outputs: drop non-finite rows, clip the
    boxes to the volume, threshold the scores, and run class-wise 3D NMS.

    Args:
        boxes: (R, K*6) class-specific or (R, 6) class-agnostic predicted boxes.
        scores: (R, K+1) per-class scores (e.g. the softmax output of
            ``predict_probs``); the last column is treated as background and
            discarded.
        volume_shape: (D, H, W), passed to ``Boxes3D.clip``.
        score_thresh: keep only (row, class) pairs whose score is strictly
            greater than this value.
        nms_thresh: threshold forwarded to ``batched_nms_3d``.
        topk_per_image: keep at most this many entries of the NMS output;
            a negative value disables the cap.

    Returns:
        result: dict with keys "pred_boxes" (N, 6), "scores" (N,) and
            "pred_classes" (N,) tensors.
        proposal_inds: (N,) tensor of row indices into the ORIGINAL ``boxes`` /
            ``scores`` inputs, one per kept detection.
    """
    # Drop rows containing NaN/Inf in either tensor. Remember where the
    # surviving rows were so the returned indices still refer to the caller's
    # rows (otherwise they would index the filtered tensor and be misaligned).
    valid_mask = torch.isfinite(boxes).all(dim=1) & torch.isfinite(scores).all(dim=1)
    valid_rows = None
    if not valid_mask.all():
        valid_rows = valid_mask.nonzero(as_tuple=True)[0]
        boxes = boxes[valid_mask]
        scores = scores[valid_mask]

    scores = scores[:, :-1]  # remove background column
    num_bbox_reg_classes = boxes.shape[1] // 6

    # Clip all boxes (every class flattened to rows of 6) to the volume, then
    # restore the R x C x 6 layout. The return value of clip() is unused, so
    # it is assumed to modify the Boxes3D in place.
    boxes = Boxes3D(boxes.reshape(-1, 6))
    boxes.clip(volume_shape)
    boxes = boxes.tensor.view(-1, num_bbox_reg_classes, 6)  # R x C x 6

    # 1. Score thresholding.
    # filter_inds[:, 0] is the row index, filter_inds[:, 1] the class index.
    filter_mask = scores > score_thresh  # R x K
    filter_inds = filter_mask.nonzero(as_tuple=False)  # (N, 2)

    if num_bbox_reg_classes == 1:
        # Class-agnostic regression: all classes of a row share one box.
        boxes = boxes[filter_inds[:, 0], 0]
    else:
        boxes = boxes[filter_mask]

    scores = scores[filter_mask]

    # 2. Class-wise 3D NMS.
    keep = batched_nms_3d(
        boxes,
        scores,
        filter_inds[:, 1],
        nms_thresh,
    )

    if topk_per_image >= 0:
        keep = keep[:topk_per_image]

    boxes = boxes[keep]
    scores = scores[keep]
    filter_inds = filter_inds[keep]

    proposal_inds = filter_inds[:, 0]
    if valid_rows is not None:
        # Translate positions in the filtered tensor back to original rows.
        proposal_inds = valid_rows[proposal_inds]

    # Output container (Detectron2-style field names)
    result = {
        "pred_boxes": boxes,
        "scores": scores,
        "pred_classes": filter_inds[:, 1],
    }

    return result, proposal_inds

In [1]:
#| export
class FasterRCNNOutputLayers3D(nn.Module):
    """
    3D Fast R-CNN box-head output layers.

    Holds two linear heads on top of pooled ROI features:
    * ``cls_score``: ``num_classes + 1`` logits per ROI (indices
      ``0..num_classes-1`` are foreground; ``num_classes`` is background, as
      implied by the foreground test in ``box_reg_loss``).
    * ``bbox_pred``: 6 box deltas per ROI, either once (class-agnostic) or
      once per foreground class.
    """

    def __init__(
        self,
        input_dim: int,
        num_classes: int,
        *,
        box2box_transform,
        cls_agnostic_bbox_reg: bool = False,
        test_score_thresh: float = 0.05,
        test_nms_thresh: float = 0.5,
        test_topk_per_image: int = 100,
        loss_weight: Union[float, Dict[str, float]] = 1.0,
        box_reg_loss_type: str = "smooth_l1",
        smooth_l1_beta: float = 0.0,
    ):
        """
        Args:
            input_dim: size of the (flattened) per-ROI feature vector.
            num_classes: number of foreground classes.
            box2box_transform: object providing ``apply_deltas`` (used here)
                and whatever ``_dense_box_regression_loss_3d`` needs.
            cls_agnostic_bbox_reg: predict one set of 6 deltas per ROI instead
                of one per class.
            test_score_thresh / test_nms_thresh / test_topk_per_image:
                inference post-processing parameters.
            loss_weight: a single number applied to both losses, or a dict
                keyed by "loss_cls" / "loss_box_reg"; missing keys default
                to 1.0.
            box_reg_loss_type: loss type forwarded to the regression loss.
            smooth_l1_beta: beta forwarded to the regression loss.
        """
        super().__init__()

        self.num_classes = num_classes
        self.box2box_transform = box2box_transform
        self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg
        self.bbox_reg_loss_type = box_reg_loss_type
        self.smooth_l1_beta = smooth_l1_beta

        # ---- heads ----
        self.cls_score = nn.Linear(input_dim, num_classes + 1)
        num_bbox_reg_classes = 1 if cls_agnostic_bbox_reg else num_classes
        self.bbox_pred = nn.Linear(input_dim, num_bbox_reg_classes * 6)

        # ---- init ----
        nn.init.normal_(self.cls_score.weight, std=0.01)
        nn.init.normal_(self.bbox_pred.weight, std=0.001)
        nn.init.constant_(self.cls_score.bias, 0)
        nn.init.constant_(self.bbox_pred.bias, 0)

        self.test_score_thresh = test_score_thresh
        self.test_nms_thresh = test_nms_thresh
        self.test_topk_per_image = test_topk_per_image

        # Accept ints as well as floats (e.g. loss_weight=1); previously an int
        # was stored as-is and crashed later when indexed like a dict.
        if isinstance(loss_weight, (int, float)):
            loss_weight = {"loss_cls": loss_weight, "loss_box_reg": loss_weight}

        self.loss_weight = loss_weight

    def forward(self, x):
        """
        Args:
            x: (N, C) or (N, C, ...); extra dims are flattened into C.

        Returns:
            scores: (N, K + 1) classification logits.
            proposal_deltas: (N, 6) or (N, K * 6) box deltas.
        """
        if x.dim() > 2:
            x = torch.flatten(x, start_dim=1)

        scores = self.cls_score(x)
        proposal_deltas = self.bbox_pred(x)

        return scores, proposal_deltas

    def inference(
        self,
        predictions: Tuple[torch.Tensor, torch.Tensor],
        proposals: List[Instances3D],
    ):
        """
        Decode predictions into per-volume detections.

        Args:
            predictions: ``(scores, proposal_deltas)`` as returned by ``forward``.
            proposals: per-volume proposals; each must expose
                ``proposal_boxes.tensor`` and ``image_size``.

        Returns:
            results: list (one per volume) of dicts with keys "pred_boxes",
                "scores" and "pred_classes".
            kept_indices: list (one per volume) of index tensors, as returned
                by ``fast_rcnn_inference_3d``.
        """
        boxes = self.predict_boxes(predictions, proposals)
        scores = self.predict_probs(predictions, proposals)

        # NOTE(review): image_size is used as the (D, H, W) clip shape —
        # confirm Instances3D stores it in that order.
        image_shapes = [p.image_size for p in proposals]

        return fast_rcnn_inference_3d(
            boxes,
            scores,
            image_shapes,
            self.test_score_thresh,
            self.test_nms_thresh,
            self.test_topk_per_image,
        )

    def box_reg_loss(
            self,
            proposal_boxes,
            gt_boxes,
            pred_deltas,
            gt_classes
    ):
        """
        Box-regression loss over foreground proposals.

        Foreground means ``0 <= gt_class < num_classes``. The summed loss is
        divided by the total number of proposals (not just foreground ones).

        Args:
            proposal_boxes: (R, 6) proposal box tensor.
            gt_boxes: (R, 6) matched ground-truth box tensor.
            pred_deltas: (R, 6) or (R, K*6) predicted deltas.
            gt_classes: (R,) ground-truth class labels.
        """
        box_dim = proposal_boxes.shape[1]  # 6
        fg_inds = nonzero_tuple((gt_classes >= 0) & (gt_classes < self.num_classes))[0]

        if fg_inds.numel() == 0:
            # Keep the graph connected so backward() still works.
            return pred_deltas.sum() * 0.0

        if pred_deltas.shape[1] == box_dim:
            # Class-agnostic deltas.
            fg_pred_deltas = pred_deltas[fg_inds]
        else:
            # Select each foreground ROI's deltas for its ground-truth class.
            fg_pred_deltas = pred_deltas.view(-1, self.num_classes, box_dim)[
                fg_inds, gt_classes[fg_inds]
            ]

        # All selected entries are foreground, so the mask is all True; it is
        # shaped (1, F) to match the single-"image" batch dimension below.
        fg_mask = torch.ones(len(fg_inds), dtype=torch.bool, device=proposal_boxes.device)

        loss_bbox_reg = _dense_box_regression_loss_3d(
            [proposal_boxes[fg_inds]],
            self.box2box_transform,
            [fg_pred_deltas.unsqueeze(0)],
            [gt_boxes[fg_inds]],
            fg_mask.unsqueeze(0),
            self.bbox_reg_loss_type,
            self.smooth_l1_beta,
        )

        return loss_bbox_reg / max(gt_classes.numel(), 1.0)

    def losses(
            self,
            predictions,
            proposals,
    ):
        """
        Compute weighted classification and box-regression losses.

        Args:
            predictions: ``(scores, proposal_deltas)`` as returned by ``forward``.
            proposals: per-volume proposals exposing ``gt_classes``,
                ``proposal_boxes.tensor`` and ``gt_boxes.tensor``.

        Returns:
            dict with "loss_cls" and "loss_box_reg", each multiplied by its
            entry in ``self.loss_weight`` (default 1.0 if the key is absent).
        """
        scores, proposal_deltas = predictions

        gt_classes = torch.cat([p.gt_classes for p in proposals], dim=0)
        proposal_boxes = torch.cat([p.proposal_boxes.tensor for p in proposals], dim=0)
        gt_boxes = torch.cat([p.gt_boxes.tensor for p in proposals], dim=0)

        # Negative labels are excluded from the classification loss.
        valid_mask = gt_classes >= 0
        if valid_mask.any():
            loss_cls = F.cross_entropy(scores[valid_mask], gt_classes[valid_mask])
        else:
            # Mean cross-entropy over zero elements is NaN; return a zero that
            # stays attached to the graph instead.
            loss_cls = scores.sum() * 0.0

        loss_box_reg = self.box_reg_loss(
            proposal_boxes,
            gt_boxes,
            proposal_deltas,
            gt_classes
        )

        return {
            "loss_cls": loss_cls * self.loss_weight.get("loss_cls", 1.0),
            "loss_box_reg": loss_box_reg * self.loss_weight.get("loss_box_reg", 1.0),
        }

    def predict_boxes(
            self,
            predictions: Tuple[torch.Tensor, torch.Tensor],
            proposals: List[Instances3D],
    ):
        """
        Apply predicted deltas to the proposal boxes.

        Returns:
            A per-volume sequence of box tensors ((R_i, K*6) or (R_i, 6)),
            split by each volume's proposal count; ``[]`` if there are no
            proposals.
        """
        if not len(proposals):
            return []

        _, proposal_deltas = predictions
        proposal_boxes = torch.cat([p.proposal_boxes.tensor for p in proposals], dim=0)

        pred_boxes = self.box2box_transform.apply_deltas(
            proposal_deltas, proposal_boxes
        )  # (N, K*6) or (N, 6)

        num_props = [len(p) for p in proposals]
        return pred_boxes.split(num_props)

    def predict_probs(
            self,
            predictions: Tuple[torch.Tensor, torch.Tensor],
            proposals: List[Instances3D],
    ):
        """
        Softmax the classification logits.

        Returns:
            A per-volume sequence of (R_i, K + 1) probability tensors, split by
            each volume's proposal count; ``[]`` if there are no proposals.
        """
        if not len(proposals):
            return []

        scores, _ = predictions
        num_props = [len(p) for p in proposals]
        probs = torch.softmax(scores, dim=-1)
        return probs.split(num_props)

NameError: name 'nn' is not defined