In [10]:
#| default_exp roi_heads
#| export
from torch import nn
import torch
from qct_3d_nod_detect.structures import Instances3D, pairwise_iou_3d
from typing import List, Tuple, Optional, Dict
import math

def add_ground_truth_to_proposals_3d(
    targets: List[Instances3D],
    proposals: List[Instances3D],
) -> List[Instances3D]:
    """
    Augment proposals with ground-truth boxes.
    """

    assert len(targets) == len(proposals)

    new_proposals = []

    for proposals_per_image, targets_per_image in zip(proposals, targets):

        if len(targets_per_image) == 0:
            new_proposals.append(proposals_per_image)
            continue

        # Clone to avoid in-place modification
        proposals_per_image = proposals_per_image.clone()

        gt_boxes = targets_per_image.gt_boxes
        device = gt_boxes.tensor.device

        # Create new Instances3D for GT boxes
        gt_proposals = Instances3D(proposals_per_image.image_size)
        gt_proposals.proposal_boxes = gt_boxes

        # Objectness logits: high confidence for GT
        gt_logit_value = math.log((1.0 - 1e-10) / (1 - (1.0 - 1e-10)))
        gt_logits = gt_logit_value * torch.ones(len(gt_boxes), device=device)
        gt_proposals.objectness_logits = gt_logits

        # Concatenate proposals
        proposals_per_image = Instances3D.cat(
            [proposals_per_image, gt_proposals]
        )

        new_proposals.append(proposals_per_image)

    return new_proposals

def subsample_labels(
    labels: torch.Tensor,
    num_samples: int,
    positive_fraction: float,
    num_classes: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    
    """
    Args:
        labels (Tensor): shape (N,), values in:
            [0, num_classes) = foreground
            num_classes      = background
            -1               = ignore
        num_samples (int): total number of samples
        positive_fraction (float): fraction of positives
        num_classes (int): number of foreground classes

    Returns:
        sampled_fg_idxs (Tensor)
        sampled_bg_idxs (Tensor)
    """

    # foreground: [0, num_classes)
    fg_mask = (labels >= 0) & (labels < num_classes)
    fg_idxs = torch.nonzero(fg_mask).squeeze(1)

    # background: == num_classes
    bg_mask = labels == num_classes
    bg_idxs = torch.nonzero(bg_mask).squeeze(1)

    num_fg = int(num_samples * positive_fraction)
    num_fg = min(num_fg, fg_idxs.numel())

    num_bg = num_samples - num_fg
    num_bg = min(num_bg, bg_idxs.numel())

    # Random sampling
    perm_fg = torch.randperm(fg_idxs.numel(), device=labels.device)[:num_fg]
    perm_bg = torch.randperm(bg_idxs.numel(), device=labels.device)[:num_bg]

    sampled_fg_idxs = fg_idxs[perm_fg]
    sampled_bg_idxs = bg_idxs[perm_bg]

    return sampled_fg_idxs, sampled_bg_idxs

class ROIHeads3D(nn.Module):
    def __init__(
            self,
            *,
            num_classes: int,
            batch_size_per_image: int,
            positive_fraction: float,
            proposal_matcher,
            proposal_append_gt: bool,
            roi_pooler,
            box_head,
            box_predictor,
    ):
        
        super().__init__()
        self.num_classes = num_classes
        self.batch_size_per_image = batch_size_per_image
        self.positive_fraction = positive_fraction
        self.proposal_matcher = proposal_matcher
        self.proposal_append_gt = proposal_append_gt

        self.roi_pooler = roi_pooler
        self.box_head = box_head
        self.box_predictor = box_predictor

    def _sample_proposals(
            self,
            matched_idxs: torch.Tensor,
            matched_labels: torch.Tensor,
            gt_classes: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        
        has_gt = gt_classes.numel() > 0

        if has_gt:
            gt_classes = gt_classes[matched_idxs]
            gt_classes[matched_labels == 0] = self.num_classes
            gt_classes[matched_labels == -1] = -1
        else:
            gt_classes = torch.zeros_like(matched_idxs) + self.num_classes

        sampled_fg_idxs, sampled_bg_idxs = subsample_labels(
            gt_classes,
            self.batch_size_per_image,
            self.positive_fraction,
            self.num_classes
        )

        sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0)
        return sampled_idxs, gt_classes[sampled_idxs]

    @torch.no_grad()
    def label_and_sample_proposals(
        self,
        proposals: List[Instances3D],
        targets: List[Instances3D],
    ) -> List[Instances3D]:
        
        if self.proposal_append_gt:
            proposals = add_ground_truth_to_proposals_3d(targets, proposals)

        proposals_with_gt = []

        for proposal_per_image, targets_per_image in zip(proposals, targets):

            has_gt = len(targets_per_image) > 0

            if has_gt:
                match_quality_matrix = pairwise_iou_3d(
                    targets_per_image.gt_boxes,
                    proposal_per_image.proposal_boxes
                )

                matched_idxs, matched_labels = self.proposal_matcher(match_quality_matrix)

            else:
                device = proposal_per_image.proposal_boxes.tensor.device
                matched_idxs = torch.zeros(
                    len(proposal_per_image), dtype=torch.int64, device=device

                )

                matched_labels = torch.zeros_like(matched_idxs)

            sampled_idxs, gt_classes = self._sample_proposals(
                matched_idxs,
                matched_labels,
                targets_per_image.gt_classes if has_gt else torch.empty(0),
            )

            proposals_per_image = proposal_per_image[sampled_idxs]
            proposals_per_image.gt_classes = gt_classes

            if has_gt:
                sampled_targets = matched_idxs[sampled_idxs]
                proposals_per_image.gt_boxes = targets_per_image.gt_boxes[sampled_targets]

            proposals_with_gt.append(proposals_per_image)

        return proposals_with_gt
    
    def forward(
            self,
            features: Dict[str, torch.Tensor],
            proposals: List[Instances3D],
            targets: Optional[List[Instances3D]] = None,
            training: bool = True,
    ):
        
        if training:
            assert targets is not None
            proposals = self.label_and_sample_proposals(proposals, targets)
        
        proposal_boxes = [p.proposal_boxes for p in proposals]
        feature_tensors = [v for k, v in features.items()]

        box_features = self.roi_pooler(feature_tensors, proposal_boxes)
        # print(f"Box features pool - {box_features.shape}")
        box_features = self.box_head(box_features)
        # print(f"Box features shape - {box_features.shape}")

        predictions = self.box_predictor(box_features)

        if training:
            return self.box_predictor.losses(predictions, proposals)
        else:
            return self.box_predictor.inference(predictions, proposals)

In [7]:
from qct_3d_nod_detect.layers import ShapeSpec
from qct_3d_nod_detect.box_heads import FastRCNNConvFCHead3D
import torch

def random_valid_boxes_3d(num_boxes, image_size):
    """
    Returns valid 3D boxes in (x1, y1, z1, x2, y2, z2) format
    """
    D, H, W = image_size

    x1 = torch.rand(num_boxes) * (W * 0.8)
    y1 = torch.rand(num_boxes) * (H * 0.8)
    z1 = torch.rand(num_boxes) * (D * 0.8)

    w = torch.rand(num_boxes) * (W * 0.2) + 1.0
    h = torch.rand(num_boxes) * (H * 0.2) + 1.0
    d = torch.rand(num_boxes) * (D * 0.2) + 1.0

    x2 = x1 + w
    y2 = y1 + h
    z2 = z1 + d

    return torch.stack([x1, y1, z1, x2, y2, z2], dim=1)


input_shape = ShapeSpec(
    channels=256,
    depth=7,
    height=7,
    width=7,
)

box_head = FastRCNNConvFCHead3D(
    input_shape=input_shape,
    conv_dims=[256, 256],
    fc_dims=[512],
)

x = torch.randn(8, 256, 7, 7, 7)  # 8 ROIs
y = box_head(x)

print(y.shape)

from qct_3d_nod_detect.structures import Instances3D, Boxes3D
from qct_3d_nod_detect.matcher import Matcher
import torch

# Fake proposals
proposals = []
for _ in range(2):  # batch size = 2
    inst = Instances3D(image_size=(128, 128, 128))
    inst.proposal_boxes = Boxes3D(torch.rand(10, 6)) # 10 proposals
    inst.objectness_logits = torch.rand(10)
    proposals.append(inst)

# Fake GT
targets = []
for _ in range(2):
    inst = Instances3D(image_size=(128, 128, 128))
    inst.gt_boxes = Boxes3D(random_valid_boxes_3d(3, (128,)*3)) # 3 GT boxes
    inst.gt_classes = torch.randint(0, 1, (3,)) # 2 classes
    targets.append(inst)


torch.Size([8, 512])


In [8]:
from qct_3d_nod_detect.poolers import ROIPooler3D
from qct_3d_nod_detect.faster_rcnn import FasterRCNNOutputLayers3D
from qct_3d_nod_detect.box_regression import Box3DTransform
import math

box3d2box3d_transform = Box3DTransform(
    weights=(1.0, 1.0, 1.0, 1.0, 1.0, 1.0),
    scale_clamp=math.log(1000.0),
)

proposal_matcher = Matcher(
    thresholds=[0.1, 0.2],
    labels = [0, -1, 1],
    allow_low_quality_matches=True
)

roi_pooler = ROIPooler3D(
    output_size=(7, 7, 7),
    canonical_level=4,
    canonical_box_size=224,
    pooler_type="ROIALign3DV2",
    scales=[1, 2, 0.5, 0.25]
)

box_predictor = FasterRCNNOutputLayers3D(
    input_dim=512,
    num_classes=1,
    box2box_transform=box3d2box3d_transform,
    cls_agnostic_bbox_reg=False,
)

In [9]:
roi_heads = ROIHeads3D(
    num_classes=1,
    batch_size_per_image=2,
    positive_fraction=0.5,
    proposal_matcher=proposal_matcher,
    proposal_append_gt=True,       # IMPORTANT for early training
    roi_pooler=roi_pooler,         # your 3D ROI pooler
    box_head=box_head,             # FastRCNNConvFCHead3D
    box_predictor=box_predictor,   # FasterRCNNOutputLayers3D
    is_training=True
)

TypeError: ROIHeads3D.__init__() got an unexpected keyword argument 'is_training'

In [5]:
B, C = 2, 256

features = {}
features['p2'] = torch.rand(B, C, 32, 32, 32)
features['p3'] = torch.rand(B, C, 16, 16, 16)
features['p4'] = torch.rand(B, C, 8, 8, 8)
features['p5'] = torch.rand(B, C, 4, 4, 4)

In [12]:
losses = roi_heads(features, proposals, targets)

Box features pool - torch.Size([4, 256, 7, 7, 7])
Box features shape - torch.Size([4, 512])


In [13]:
losses

{'loss_cls': tensor(0.6722, grad_fn=<MulBackward0>),
 'loss_box_reg': tensor(0.0004, grad_fn=<MulBackward0>)}