In [8]:
# Sanity check
import math
import torch
from torch import nn
import torch.nn.functional as F
from qct_3d_nod_detect.anchor_generator_3d import DefaultAnchorGenerator3D
from qct_3d_nod_detect.box_regression import Box3DTransform
from qct_3d_nod_detect.matcher import Matcher
from qct_3d_nod_detect.rpn import RPN3D, StandardRPNHead3d
from qct_3d_nod_detect.structures import Boxes3D, Instances3D
from qct_3d_nod_detect.poolers import ROIPooler3D

# ──────────────────────────────────────────────────────────────
#  Reuse your existing setup
# ──────────────────────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

N = 2
C = 256

image_sizes = [(32, 128, 128), (32, 128, 128)]

features = {
    "p3": torch.randn(N, C, 16, 64, 64, device=device),   # stride ~8
    "p4": torch.randn(N, C,  8, 32, 32, device=device),   # stride ~16
    "p5": torch.randn(N, C,  4, 16, 16, device=device),   # stride ~32
}

class ImageList3D:
    def __init__(self, image_sizes):
        self.image_sizes = image_sizes

images = ImageList3D(image_sizes)

gt_instances = []

for i in range(N):
    inst = Instances3D(image_sizes[i])

    inst.gt_boxes = Boxes3D(
        torch.tensor(
            [
                [10, 20, 5, 40, 60, 20],
                [50, 40, 10, 90, 100, 30]
            ],
            dtype=torch.float32,
            device=device
        )
    )

    gt_instances.append(inst)

anchor_generator_3d = DefaultAnchorGenerator3D(
    sizes=[[2], [4], [8]],
    aspect_ratios_3d=[[(1.0, 1.0)], [(1.0, 1.0)], [(1.0, 1.0)]],
    strides=[8, 16, 32],
    offset=0.5,
).to(device)

print("Anchors per level:", anchor_generator_3d.num_cell_anchors)

box3d2box3d_transform = Box3DTransform(
    weights=(1.0, 1.0, 1.0, 1.0, 1.0, 1.0),
    scale_clamp=math.log(1000.0),
)

num_anchors = anchor_generator_3d.num_cell_anchors[0]  # same for all levels
rpn_head_3d = StandardRPNHead3d(
    in_channels=C,
    num_anchors=num_anchors,
    box_dim=6,
).to(device)

anchor_matcher = Matcher(
    thresholds=[0.3, 0.7],
    labels=[0, -1, 1],
    allow_low_quality_matches=True,
)

rpn = RPN3D(
    in_features=["p3", "p4", "p5"],
    head=rpn_head_3d,
    anchor_generator=anchor_generator_3d,
    anchor_matcher=anchor_matcher,
    box3d_transform=box3d2box3d_transform,
    batch_size_per_image=256,
    positive_fraction=0.5,
    pre_nms_topk=(200, 100),
    post_nms_topk=(100, 50),
    nms_thresh=0.5,
    min_box_size=2.0,
    box_reg_loss_type="smooth_l1",
    smooth_l1_beta=0.0,
).to(device)

rpn.eval()
with torch.no_grad():
    proposals, losses = rpn(images, features, gt_instances)

total_proposals = sum(len(p) for p in proposals)
print(f"Total proposals: {total_proposals}")

pooler = ROIPooler3D(
    output_size     = (7, 7, 7),
    scales          = [1/8.0, 1/16.0, 1/32.0],   # must match feature strides
    sampling_ratio  = 0,                         # your impl ignores it anyway
    pooler_type     = "ROIALign3DV2",            # or "ROIAlign3D" or "ROIPool3D"
    canonical_box_size = 32.0,                   # smaller than 224 — your volumes are tiny
    canonical_level = 1,                         # adjust depending on how you number levels
).to(device)

print("ROIPooler3D created with pooler type:", pooler.pooler_type)
print("Levels:", pooler.min_level, "→", pooler.max_level)

multi_scale_features = [features["p3"], features["p4"], features["p5"]]

proposal_boxes_list = [inst.proposal_boxes for inst in proposals]
pooled = pooler(
    x          = multi_scale_features,
    box_lists  = proposal_boxes_list
)

print(pooled.shape)

Using device: cuda
Anchors per level: [1, 1, 1]
Total proposals: 25
ROIPooler3D created with pooler type: ROIALign3DV2
Levels: 3 → 5
torch.Size([25, 256, 7, 7, 7])
