In [39]:
#| default_exp poolers
#| export
import math
from typing import List, Optional, Tuple
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from qct_3d_nod_detect.layers import nonzero_tuple, cat, shapes_to_tensor
from qct_3d_nod_detect.structures import Boxes3D, Instances3D

In [40]:
#| export
def assign_boxes_to_levels_3d(
        box_lists: List[Boxes3D],
        min_level: int,
        max_level: int,
        canonical_box_size: float = 224.0,
        canonical_level: int = 4,
) -> torch.Tensor:
    """
    Map every box to the index of the feature level it should be pooled from.

    This is the 3D analogue of the FPN level heuristic. A box's linear size is
    the cube root of its volume, and its absolute level is
    ``floor(canonical_level + log2(size / canonical_box_size))``, clamped to
    ``[min_level, max_level]``.

    Args:
        box_lists: per-image box collections; each must provide ``volume()``.
        min_level: smallest absolute level available.
        max_level: largest absolute level available.
        canonical_box_size: box size that maps exactly to ``canonical_level``.
        canonical_level: absolute level assigned to a box of canonical size.

    Returns:
        int64 tensor with one entry per box (concatenated over ``box_lists``),
        holding the level index relative to ``min_level``.
    """
    edge_lengths = torch.cat([boxes.volume() for boxes in box_lists]).pow(1.0 / 3.0)
    # The small epsilon keeps log2 finite for zero-volume boxes.
    raw_levels = (edge_lengths / canonical_box_size + 1e-6).log2().add(canonical_level).floor()
    return raw_levels.clamp(min=min_level, max=max_level).sub(min_level).long()

def _convert_boxes_to_pooler_format_3d(boxes: Tensor, sizes: Tensor) -> Tensor:
    """
    Prefix every box with the index of the image it belongs to.

    Args:
        boxes: (M_total, 6) tensor of [x1, y1, z1, x2, y2, z2], the boxes of all
            images concatenated in image order.
        sizes: (B,) tensor; ``sizes[i]`` is how many rows of ``boxes`` belong
            to image ``i``.

    Returns:
        (M_total, 7) tensor of [batch_idx, x1, y1, z1, x2, y2, z2]. The batch
        index column has the same dtype as ``boxes``.
    """
    target_device = boxes.device
    image_ids = torch.arange(len(sizes), dtype=boxes.dtype, device=target_device)
    # Repeat image id i exactly sizes[i] times so it lines up with its boxes.
    batch_column = image_ids.repeat_interleave(sizes.to(device=target_device))
    return cat([batch_column.unsqueeze(1), boxes], dim=1)

def convert_boxes_to_pooler_format_3d(box_lists: List[Boxes3D]) -> Tensor:
    """
    Flatten a batch of per-image 3D boxes into the RoI layout used by the
    3D pooling modules.

    Input boxes use ``Boxes3D.tensor`` rows of [x1, y1, z1, x2, y2, z2]
    (min/max corners). The result feeds ROIAlign3D / RoIPool3D-style ops that
    expect [batch_idx, x1, y1, z1, x2, y2, z2].

    Args:
        box_lists (List[Boxes3D]): one Boxes3D per image (length = batch size);
            each ``.tensor`` has shape (num_boxes_i, 6).

    Returns:
        Tensor of shape (M, 7), where M is the total number of boxes in the
        batch, with columns [batch_index, x1, y1, z1, x2, y2, z2]. For an empty
        batch this is an empty float32 tensor on the CPU.
    """
    if not box_lists:
        return torch.empty((0, 7), dtype=torch.float32, device=torch.device("cpu"))

    stacked = torch.cat([boxes.tensor for boxes in box_lists], dim=0)  # (M_total, 6)
    # Per-image box counts, built with shapes_to_tensor so tracing stays happy.
    counts = shapes_to_tensor([len(boxes) for boxes in box_lists], device=stacked.device)
    return _convert_boxes_to_pooler_format_3d(stacked, counts)

class RoIPool3D(nn.Module):
    """
    3D RoI max pooling.

    Each RoI is cropped out of its image's feature volume and reduced to a
    fixed grid with ``F.adaptive_max_pool3d``.

    Args:
        output_size (tuple[int]): (out_d, out_h, out_w), size of the output feature map.
        spatial_scale (float): factor applied to RoI coordinates before cropping.
            The default 1.0 keeps the original contract that RoIs are already in
            feature-map coordinates.
    """

    def __init__(self, output_size: Tuple[int, int, int], spatial_scale: float = 1.0):
        super().__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale

    def forward(self, features: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features (Tensor[N, C, D, H, W]): input volumetric feature maps.
            rois (Tensor[num_rois, 7]): RoI boxes as (batch_idx, z1, y1, x1, z2, y2, x2).
                Note the z-first order. After multiplying by ``spatial_scale``,
                coordinates must be in **feature map scale**.

        Returns:
            Tensor[num_rois, C, out_d, out_h, out_w]: pooled features. Each crop is
            clamped to the feature volume and always covers at least one voxel
            per axis.
        """
        num_rois = rois.size(0)
        _, C, D, H, W = features.shape
        out_d, out_h, out_w = self.output_size

        pooled = features.new_zeros((num_rois, C, out_d, out_h, out_w))
        if num_rois == 0:
            return pooled

        # One host transfer for all RoIs instead of several device syncs per RoI.
        roi_rows = rois.detach().to(device="cpu", dtype=torch.float64).tolist()
        for i, (batch_idx, z1, y1, x1, z2, y2, x2) in enumerate(roi_rows):
            z_lo, z_hi = self._voxel_range(z1, z2, D)
            y_lo, y_hi = self._voxel_range(y1, y2, H)
            x_lo, x_hi = self._voxel_range(x1, x2, W)
            b = int(batch_idx)

            # Crop the RoI and apply adaptive max pooling.
            roi_feature = features[b : b + 1, :, z_lo:z_hi, y_lo:y_hi, x_lo:x_hi]
            pooled[i] = F.adaptive_max_pool3d(roi_feature, self.output_size)

        return pooled

    def _voxel_range(self, start: float, end: float, size: int) -> Tuple[int, int]:
        """
        Turn one RoI extent into a slice ``[lo, hi)`` along a feature axis.

        Coordinates are scaled and floored, which matches the previous ``int()``
        cast for non-negative values. The slice is then clamped to lie inside
        ``[0, size)`` and to hold at least one voxel. Without the clamp,
        sub-voxel RoIs gave empty crops (adaptive_max_pool3d raises), and
        negative coordinates were silently wrapped by Python slicing.
        """
        lo = min(max(math.floor(start * self.spatial_scale), 0), size - 1)
        hi = min(max(math.floor(end * self.spatial_scale), lo + 1), size)
        return lo, hi

class ROIAlign3D(nn.Module):
    """
    3D RoIAlign.

    Pools each RoI to a fixed grid by trilinearly sampling the feature volume
    and averaging the samples that fall into each output bin.
    """

    def __init__(self, output_size, spatial_scale, sampling_ratio=-1, aligned=True):
        """
        Args:
            output_size (int or tuple): (out_d, out_h, out_w); an int is used for all three axes.
            spatial_scale (float): factor mapping input box coordinates to feature-map voxels.
            sampling_ratio (int): number of sample points per output bin along each axis.
                A value <= 0 means adaptive: ceil(roi_extent / bins), computed per RoI and axis.
            aligned (bool): if True, shift box coordinates by -0.5 voxel so voxel i is
                centred at coordinate i (Detectron2's ROIAlignV2 convention); if False,
                no shift is applied.
        """
        super().__init__()
        if isinstance(output_size, int):
            output_size = (output_size, output_size, output_size)
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio
        self.aligned = aligned

    def forward(self, input: torch.Tensor, rois: torch.Tensor):
        """
        Args:
            input: (N, C, D, H, W) feature volume.
            rois: (num_rois, 7) rows of (batch_idx, x1, y1, z1, x2, y2, z2) in box
                coordinates; x indexes W, y indexes H, z indexes D.

        Returns:
            (num_rois, C, out_d, out_h, out_w) tensor with the same dtype as ``input``.
            Sample points outside the feature volume contribute zeros.
        """
        assert rois.dim() == 2 and rois.size(1) == 7

        _, C, D, H, W = input.shape
        out_d, out_h, out_w = self.output_size

        if rois.size(0) == 0:
            # torch.cat on an empty list would raise; return a well-shaped empty result.
            return input.new_zeros((0, C, out_d, out_h, out_w))

        offset = 0.5 if self.aligned else 0.0
        # A single device->host transfer instead of a sync per RoI.
        roi_rows = rois.detach().to(device="cpu", dtype=torch.float64).tolist()

        output = []
        for batch_idx, x1, y1, z1, x2, y2, z2 in roi_rows:
            # Box -> feature-map coordinates (with the optional half-voxel shift).
            x1, y1, z1, x2, y2, z2 = [
                c * self.spatial_scale - offset for c in (x1, y1, z1, x2, y2, z2)
            ]

            coords_d, ratio_d = self._axis_samples(z1, z2, out_d, D, input)
            coords_h, ratio_h = self._axis_samples(y1, y2, out_h, H, input)
            coords_w, ratio_w = self._axis_samples(x1, x2, out_w, W, input)

            mesh_d, mesh_h, mesh_w = torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij")
            # grid_sample expects the last dimension ordered (x, y, z) = (W, H, D).
            grid = torch.stack((mesh_w, mesh_h, mesh_d), dim=-1).unsqueeze(0)

            b = int(batch_idx)
            samples = F.grid_sample(
                input[b : b + 1],  # shape (1, C, D, H, W)
                grid,
                mode="bilinear",  # trilinear for 5-D input
                padding_mode="zeros",
                align_corners=False,
            )  # (1, C, out_d*ratio_d, out_h*ratio_h, out_w*ratio_w)

            if (ratio_d, ratio_h, ratio_w) != (1, 1, 1):
                # Average the sub-samples belonging to each output bin.
                samples = F.avg_pool3d(samples, kernel_size=(ratio_d, ratio_h, ratio_w))
            output.append(samples)

        return torch.cat(output, dim=0)  # (num_rois, C, out_d, out_h, out_w)

    def _axis_samples(self, start: float, end: float, bins: int, size: int, ref: torch.Tensor):
        """
        Sample positions along one axis of a RoI.

        Returns ``(coords, ratio)``. ``coords`` holds ``bins * ratio`` coordinates,
        normalized for ``grid_sample(..., align_corners=False)``, with the dtype and
        device of ``ref``. ``ratio`` is the number of samples per bin. Sub-sample
        k of bin j sits at ``start + (j + (k + 0.5) / ratio) * bin_size``, so
        ``ratio == 1`` samples exactly the bin centres.
        """
        bin_size = (end - start) / max(bins, 1)
        if self.sampling_ratio > 0:
            ratio = int(self.sampling_ratio)
        else:
            ratio = max(1, math.ceil(bin_size))

        idx = torch.arange(bins * ratio, dtype=torch.float64)
        voxel = start + (idx + 0.5) * (bin_size / ratio)
        # (2p + 1) / size - 1 places voxel index p at its centre under
        # align_corners=False. Unlike p / (size - 1) * 2 - 1, it stays finite
        # when size == 1.
        normalized = (2.0 * voxel + 1.0) / size - 1.0
        return normalized.to(dtype=ref.dtype).to(device=ref.device), ratio

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"output_size={self.output_size}, "
            f"spatial_scale={self.spatial_scale}, "
            f"sampling_ratio={self.sampling_ratio}, "
            f"aligned={self.aligned}"
            f")"
        )

In [41]:
#| export
class ROIPooler3D(nn.Module):
    """
    Multi-level 3D RoI pooler.

    Each box is assigned to one feature level by its size (see
    assign_boxes_to_levels_3d). It is then pooled from that level with a
    per-level ROIAlign3D or RoIPool3D module.

    Args:
        output_size: (out_d, out_h, out_w), or an int used for all three.
        scales: per-level feature scale (1 / stride), finest level first. Strides
            must be consecutive powers of two.
        sampling_ratio: forwarded to ROIAlign3D.
        pooler_type: one of
            - "ROIAlign3D" / "ROIAlign": ROIAlign3D with aligned=False
            - "ROIAlign3DV2" / "ROIAlignV2": ROIAlign3D with aligned=True
              (the legacy spelling "ROIALign3DV2" is also accepted)
            - "ROIPool3D" / "ROIPool": RoIPool3D
            The string is stored verbatim in ``self.pooler_type``.
        canonical_box_size: box size that maps to ``canonical_level``.
        canonical_level: level for a box of ``canonical_box_size``. It uses the
            same absolute numbering as the levels derived from ``scales``
            (log2 of the stride).
    """

    # pooler_type -> `aligned` flag for ROIAlign3D. "ROIALign3DV2" is kept for
    # backward compatibility with the previously accepted (misspelled) name.
    _ALIGN_TYPES = {
        "ROIAlign3D": False,
        "ROIAlign": False,
        "ROIAlign3DV2": True,
        "ROIALign3DV2": True,
        "ROIAlignV2": True,
    }
    _POOL_TYPES = ("ROIPool3D", "ROIPool")

    def __init__(
        self,
        output_size: Tuple[int, int, int],
        scales: List[float],
        sampling_ratio: int = 0,
        pooler_type: str = "ROIAlignV2",
        canonical_box_size: float = 224.0,
        canonical_level: int = 4,
    ):
        super().__init__()
        if isinstance(output_size, int):
            output_size = (output_size, output_size, output_size)

        self.output_size = output_size
        self.scales = list(scales)
        self.canonical_box_size = canonical_box_size
        self.canonical_level = canonical_level

        # Level = log2(stride) with stride = 1 / scale; levels must be consecutive.
        strides = [1.0 / s for s in scales]
        min_level = int(round(math.log2(strides[0])))
        max_level = int(round(math.log2(strides[-1])))

        assert max_level - min_level + 1 == len(scales), (
            f"scales {scales} do not correspond to consecutive power-of-two strides"
        )

        self.min_level = min_level
        self.max_level = max_level

        # Build the per-level poolers once.
        if pooler_type in self._ALIGN_TYPES:
            aligned = self._ALIGN_TYPES[pooler_type]
            self.level_poolers = nn.ModuleList(
                ROIAlign3D(
                    output_size,
                    spatial_scale=scale,
                    sampling_ratio=sampling_ratio,
                    aligned=aligned,
                )
                for scale in scales
            )
        elif pooler_type in self._POOL_TYPES:
            # RoIPool3D takes feature-scale, z-first RoIs; forward() converts them.
            self.level_poolers = nn.ModuleList(RoIPool3D(output_size) for _ in scales)
        else:
            raise ValueError(f"Unknown pooler type: {pooler_type}")

        self.pooler_type = pooler_type

    def forward(
        self,
        x: List[torch.Tensor],
        box_lists: List[Boxes3D],
    ) -> torch.Tensor:
        """
        Args:
            x: per-level feature maps, finest first, one per entry of ``scales``,
                e.g. [(B, C, D, H, W), (B, C, D/2, H/2, W/2), ...].
            box_lists: one Boxes3D per image (usually length B), in input coordinates.

        Returns:
            (M, C, out_d, out_h, out_w) pooled features with the dtype and device of
            ``x[0]``. M is the total number of boxes, and rows follow the
            concatenation order of ``box_lists``.
        """
        num_levels = len(self.level_poolers)
        assert len(x) == num_levels, f"expected {num_levels} feature maps, got {len(x)}"

        num_channels = x[0].shape[1]
        if len(box_lists) == 0:
            return x[0].new_zeros((0, num_channels, *self.output_size))

        pooler_fmt_boxes = convert_boxes_to_pooler_format_3d(box_lists)

        # Multi-level box assignment (indices relative to min_level).
        level_assignments = assign_boxes_to_levels_3d(
            box_lists,
            self.min_level,
            self.max_level,
            self.canonical_box_size,
            self.canonical_level,
        )

        output = x[0].new_zeros((pooler_fmt_boxes.shape[0], num_channels, *self.output_size))

        for lvl, pooler in enumerate(self.level_poolers):
            inds = (level_assignments == lvl).nonzero(as_tuple=True)[0]
            if inds.numel() == 0:
                continue
            boxes_lvl = pooler_fmt_boxes[inds]
            if isinstance(pooler, RoIPool3D):
                boxes_lvl = self._to_roipool_format(boxes_lvl, self.scales[lvl])
            output.index_put_((inds,), pooler(x[lvl], boxes_lvl))

        return output

    @staticmethod
    def _to_roipool_format(boxes: torch.Tensor, scale: float) -> torch.Tensor:
        """
        Convert [batch_idx, x1, y1, z1, x2, y2, z2] input-space boxes to RoIPool3D's
        contract: [batch_idx, z1, y1, x1, z2, y2, x2] in feature-map coordinates.
        """
        coords = boxes[:, 1:] * scale
        return torch.cat([boxes[:, :1], coords[:, [2, 1, 0, 5, 4, 3]]], dim=1)

In [45]:
# Sanity check
# Builds a 3D RPN on random multi-level features and runs it in eval mode to get
# per-image proposals for the ROIPooler3D cells below.
import math
import torch
from torch import nn
import torch.nn.functional as F
from qct_3d_nod_detect.anchor_generator_3d import DefaultAnchorGenerator3D
from qct_3d_nod_detect.box_regression import Box3DTransform
from qct_3d_nod_detect.matcher import Matcher
from qct_3d_nod_detect.rpn import RPN3D, StandardRPNHead3d

# ──────────────────────────────────────────────────────────────
#  Reuse your existing setup
# ──────────────────────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Seed everything random below (feature maps, RPN head init) so the proposal
# count reported in later cells is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

N = 2
C = 256

image_sizes = [(32, 128, 128), (32, 128, 128)]

features = {
    "p3": torch.randn(N, C, 16, 64, 64, device=device),   # stride ~8
    "p4": torch.randn(N, C,  8, 32, 32, device=device),   # stride ~16
    "p5": torch.randn(N, C,  4, 16, 16, device=device),   # stride ~32
}

class ImageList3D:
    """Minimal stand-in exposing only the `image_sizes` attribute handed to the RPN."""
    def __init__(self, image_sizes):
        self.image_sizes = image_sizes

images = ImageList3D(image_sizes)

gt_instances = []

for i in range(N):
    inst = Instances3D(image_sizes[i])

    # Two GT boxes per image in Boxes3D's [x1, y1, z1, x2, y2, z2] format.
    inst.gt_boxes = Boxes3D(
        torch.tensor(
            [
                [10, 20, 5, 40, 60, 20],
                [50, 40, 10, 90, 100, 30]
            ],
            dtype=torch.float32,
            device=device
        )
    )

    gt_instances.append(inst)

anchor_generator_3d = DefaultAnchorGenerator3D(
    sizes=[[2], [4], [8]],
    aspect_ratios_3d=[[(1.0, 1.0)], [(1.0, 1.0)], [(1.0, 1.0)]],
    strides=[8, 16, 32],
    offset=0.5,
).to(device)

print("Anchors per level:", anchor_generator_3d.num_cell_anchors)

box3d2box3d_transform = Box3DTransform(
    weights=(1.0, 1.0, 1.0, 1.0, 1.0, 1.0),
    scale_clamp=math.log(1000.0),
)

# The head below is built with a single anchor count, so enforce the
# "same for all levels" assumption instead of just stating it.
assert len(set(anchor_generator_3d.num_cell_anchors)) == 1, "expected the same number of anchors on every level"
num_anchors = anchor_generator_3d.num_cell_anchors[0]  # same for all levels
rpn_head_3d = StandardRPNHead3d(
    in_channels=C,
    num_anchors=num_anchors,
    box_dim=6,
).to(device)

anchor_matcher = Matcher(
    thresholds=[0.3, 0.7],
    labels=[0, -1, 1],
    allow_low_quality_matches=True,
)

rpn = RPN3D(
    in_features=["p3", "p4", "p5"],
    head=rpn_head_3d,
    anchor_generator=anchor_generator_3d,
    anchor_matcher=anchor_matcher,
    box3d_transform=box3d2box3d_transform,
    batch_size_per_image=256,
    positive_fraction=0.5,
    pre_nms_topk=(200, 100),
    post_nms_topk=(100, 50),
    nms_thresh=0.5,
    min_box_size=2.0,
    box_reg_loss_type="smooth_l1",
    smooth_l1_beta=0.0,
).to(device)

rpn.eval()
with torch.no_grad():
    proposals, losses = rpn(images, features, gt_instances)

Using device: cuda
Anchors per level: [1, 1, 1]


In [46]:
# Count proposals across all images in the batch.
total_proposals = sum(map(len, proposals))
print(f"Total proposals: {total_proposals}")

Total proposals: 45


In [47]:
# NOTE(review): level = floor(canonical_level + log2(size / canonical_box_size)), clamped
# to [3, 5]. With canonical_level=1 and canonical_box_size=32, only boxes whose
# cube-root size is >= 256 reach p4. The largest box in a 32x128x128 volume is ~81,
# so every proposal is pooled from p3. Confirm this is intended.
pooler = ROIPooler3D(
    output_size     = (7, 7, 7),
    scales          = [1/8.0, 1/16.0, 1/32.0],   # 1/stride of p3, p4, p5 (finest first)
    sampling_ratio  = 0,                         # <= 0 requests adaptive sampling per the ROIAlign3D docstring
    pooler_type     = "ROIALign3DV2",            # aligned ROIAlign3D; alternatives: "ROIAlign3D", "ROIPool3D"
    canonical_box_size = 32.0,                   # cube-root box size that maps to canonical_level
    canonical_level = 1,                         # absolute level (log2 stride); results below min_level clamp to p3
).to(device)

In [48]:
# Show the configured pooler type and the absolute level range derived from `scales`.
print("ROIPooler3D created with pooler type:", pooler.pooler_type)
print("Levels:", pooler.min_level, "→", pooler.max_level)

ROIPooler3D created with pooler type: ROIALign3DV2
Levels: 3 → 5


In [49]:
# Feature maps ordered finest to coarsest, matching the pooler's `scales`.
multi_scale_features = [features[level_name] for level_name in ("p3", "p4", "p5")]

In [50]:
# One Boxes3D of proposals per image, in batch order.
proposal_boxes_list = [per_image.proposal_boxes for per_image in proposals]

In [51]:
# Pool every proposal from its assigned level -> (num_proposals, C, 7, 7, 7).
pooled = pooler(multi_scale_features, proposal_boxes_list)

In [52]:
pooled.size()

torch.Size([45, 256, 7, 7, 7])