In [30]:
#| default_exp box_heads
#| export 
from torch import nn
from typing import List
import numpy as np
from qct_3d_nod_detect.layers import ShapeSpec

import torch.nn as nn

def get_norm(norm, num_channels, dim=3):
    """
    Build a normalization layer for ``num_channels`` feature channels.

    Args:
        norm (str or callable or None):
            - "" or None: no normalization (returns None)
            - "BN": BatchNorm; 1d / 2d / 3d variant chosen by ``dim``
            - "GN": GroupNorm with 32 groups (``num_channels`` must be divisible by 32)
            - "LN": GroupNorm with a single group, i.e. each sample is normalized
              over all channels and spatial positions together
            - callable: custom norm layer, called as ``norm(num_channels)``
            String names are case-insensitive.
        num_channels (int): number of channels
        dim (int): spatial dimensionality used by "BN": 1, 2 or 3
            (BatchNorm1d / BatchNorm2d / BatchNorm3d). GroupNorm-based norms are
            dimension-agnostic and ignore it.

    Returns:
        nn.Module or None: the norm layer, or None when ``norm`` is falsy.

    Raises:
        ValueError: for an unknown norm name, or an unsupported ``dim`` with "BN".
    """
    if not norm:
        return None

    if callable(norm):
        return norm(num_channels)

    norm = norm.upper()

    if norm == "BN":
        batch_norms = {1: nn.BatchNorm1d, 2: nn.BatchNorm2d, 3: nn.BatchNorm3d}
        # Previously any dim other than 3 silently fell back to BatchNorm2d.
        if dim not in batch_norms:
            raise ValueError(f"Unsupported dim for BN: {dim} (expected 1, 2 or 3)")
        return batch_norms[dim](num_channels)

    if norm == "GN":
        return nn.GroupNorm(32, num_channels)

    if norm == "LN":
        return nn.GroupNorm(1, num_channels)

    raise ValueError(f"Unsupported norm type: {norm}")

class FastRCNNConvFCHead3D(nn.Sequential):
    """
    3D version of FastRCNNConvFCHead.

    Consists of ``len(conv_dims)`` blocks of Conv3d (kernel 3, padding 1)
    [+ norm] + ReLU, followed by a Flatten and ``len(fc_dims)`` blocks of
    Linear + ReLU. The convolutions keep D/H/W and only change the channels.

    Input: ROI features shaped (N, C, D, H, W), with (C, D, H, W) given by
    ``input_shape``. Output: (N, fc_dims[-1]) when ``fc_dims`` is non-empty,
    otherwise (N, conv_dims[-1], D, H, W).
    """

    def __init__(
        self,
        input_shape: ShapeSpec,
        *,
        conv_dims: List[int],
        fc_dims: List[int],
        conv_norm: str = "",
    ):
        """
        Args:
            input_shape: shape of the input ROI features; ``channels``,
                ``depth``, ``height`` and ``width`` must all be set.
            conv_dims: output channels of each Conv3d layer (may be empty).
            fc_dims: output width of each Linear layer (may be empty).
            conv_norm: normalization after each conv, resolved by ``get_norm``
                ("" disables it; when enabled the convs are built without bias).

        Raises:
            AssertionError: if both ``conv_dims`` and ``fc_dims`` are empty.
        """
        super().__init__()
        assert len(conv_dims) + len(fc_dims) > 0

        # Running shape of the activations: (C, D, H, W) through the conv
        # stack, then a single int (feature width) once the FC stack starts.
        self._output_size = (
            input_shape.channels,
            input_shape.depth,
            input_shape.height,
            input_shape.width,
        )

        # -------------------------
        # Conv3D stack
        # -------------------------
        # Plain Python list (not registered as submodules); used for init below.
        self.conv_norm_relus = []
        for k, conv_dim in enumerate(conv_dims):
            conv = nn.Conv3d(
                in_channels=self._output_size[0],
                out_channels=conv_dim,
                kernel_size=3,
                padding=1,
                # Conv bias is disabled whenever a norm layer is used.
                bias=not conv_norm,
            )

            self.add_module(f"conv{k+1}", conv)
            if conv_norm:
                self.add_module(
                    f"conv{k+1}_norm",
                    get_norm(conv_norm, conv_dim),
                )
            self.add_module(f"conv{k+1}_relu", nn.ReLU(inplace=True))

            self.conv_norm_relus.append(conv)
            # kernel 3 / padding 1 keeps D, H, W; only the channel count changes.
            self._output_size = (conv_dim, *self._output_size[1:])

        # -------------------------
        # FC stack
        # -------------------------
        self.fcs = []
        if fc_dims:
            # Collapse each ROI's (C, D, H, W) block into one vector before fc1.
            self.add_module("flatten", nn.Flatten(start_dim=1))
        for k, fc_dim in enumerate(fc_dims):
            fc = nn.Linear(int(np.prod(self._output_size)), fc_dim)
            self.add_module(f"fc{k+1}", fc)
            self.add_module(f"fc{k+1}_relu", nn.ReLU(inplace=True))

            self.fcs.append(fc)
            self._output_size = fc_dim

        # -------------------------
        # Initialization
        # -------------------------
        for layer in self.conv_norm_relus:
            nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
            # The bias only exists when conv_norm is disabled. Zero it like the
            # FC biases below instead of leaving PyTorch's default uniform init.
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)
        for layer in self.fcs:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.constant_(layer.bias, 0)

    @classmethod
    def from_config(cls, cfg, input_shape):
        """
        Translate ``cfg.MODEL.ROI_BOX_HEAD.*`` into ``__init__`` keyword args.

        NOTE(review): no ``@configurable`` wrapper is visible here, so callers
        presumably need ``cls(**cls.from_config(cfg, input_shape))`` — confirm.
        """
        num_conv = cfg.MODEL.ROI_BOX_HEAD.NUM_CONV
        conv_dim = cfg.MODEL.ROI_BOX_HEAD.CONV_DIM
        num_fc = cfg.MODEL.ROI_BOX_HEAD.NUM_FC
        fc_dim = cfg.MODEL.ROI_BOX_HEAD.FC_DIM

        return {
            "input_shape": input_shape,
            "conv_dims": [conv_dim] * num_conv,
            "fc_dims": [fc_dim] * num_fc,
            "conv_norm": cfg.MODEL.ROI_BOX_HEAD.NORM,
        }

    def forward(self, x):
        """Apply every registered submodule in registration order."""
        for layer in self:
            x = layer(x)
        return x

    @property
    def output_shape(self):
        """
        ShapeSpec of the head's output: ``channels`` only after an FC stack,
        otherwise the full (channels, depth, height, width) of the conv output.
        """
        o = self._output_size
        if isinstance(o, int):
            return ShapeSpec(channels=o)
        else:
            return ShapeSpec(
                channels=o[0],
                depth=o[1],
                height=o[2],
                width=o[3],
            )

In [32]:
import torch

torch.manual_seed(0)  # reproducible random input

# Keyword arguments: don't rely on ShapeSpec's positional field order.
box_head = FastRCNNConvFCHead3D(
    input_shape=ShapeSpec(channels=256, depth=7, height=7, width=7),
    conv_dims=[256, 256],
    fc_dims=[256, 256],
)
x = torch.rand(1, 256, 7, 7, 7)

In [35]:
# Run the forward pass here: no cell in this notebook assigns `out`, so a bare
# `out.shape` only worked thanks to leftover kernel state.
out = box_head(x)
out.shape

torch.Size([1, 256])

In [2]:
from qct_3d_nod_detect.layers import ShapeSpec
from qct_3d_nod_detect.box_heads import FastRCNNConvFCHead3D
import torch

torch.manual_seed(0)  # reproducible random ROI features

# Shape (C, D, H, W) of each pooled ROI fed to the head.
input_shape = ShapeSpec(
    channels=256,
    depth=4,
    height=7,
    width=7,
)

box_head = FastRCNNConvFCHead3D(
    input_shape=input_shape,
    conv_dims=[256, 256],
    fc_dims=[1024],
)

num_rois = 8
x = torch.randn(num_rois, 256, 4, 7, 7)
y = box_head(x)

# The head flattens every ROI into one fc_dims[-1]-wide vector.
assert y.shape == (num_rois, 1024)
print(y.shape)


torch.Size([8, 1024])


In [3]:
from qct_3d_nod_detect.structures import Instances3D, Boxes3D
from qct_3d_nod_detect.matcher import Matcher
import torch

# Fake proposals / GT for exercising ROIHeads3D end to end.
# NOTE(review): boxes are assumed to be (x1, y1, z1, x2, y2, z2) corners in voxel
# coordinates — confirm against Boxes3D. (All image sides are 64, so the axis
# order of IMAGE_SIZE does not matter here.)
IMAGE_SIZE = (64, 64, 64)
NUM_IMAGES = 2        # batch size
NUM_PROPOSALS = 10    # proposals per image
NUM_GT = 3            # GT boxes per image
NUM_CLASSES = 1       # must match num_classes passed to ROIHeads3D / FasterRCNNOutputLayers3D

generator = torch.Generator().manual_seed(0)


def random_boxes3d(num_boxes, image_size, generator, min_size=4.0, max_size=16.0):
    """Random valid 3D boxes (min corner < max corner) lying inside ``image_size``.

    Unlike ``torch.rand(n, 6)``, which yields sub-voxel boxes whose "max" corner
    is below the "min" corner about half the time, every box here has positive
    extent between ``min_size`` and ``max_size`` along each axis.
    """
    extent = torch.tensor(image_size, dtype=torch.float32)
    sizes = min_size + (max_size - min_size) * torch.rand(num_boxes, 3, generator=generator)
    mins = torch.rand(num_boxes, 3, generator=generator) * (extent - sizes)
    return torch.cat([mins, mins + sizes], dim=1)


proposals = []
for _ in range(NUM_IMAGES):
    inst = Instances3D(image_size=IMAGE_SIZE)
    inst.proposal_boxes = Boxes3D(random_boxes3d(NUM_PROPOSALS, IMAGE_SIZE, generator))
    inst.objectness_logits = torch.rand(NUM_PROPOSALS, generator=generator)
    proposals.append(inst)

# Fake GT
targets = []
for _ in range(NUM_IMAGES):
    inst = Instances3D(image_size=IMAGE_SIZE)
    inst.gt_boxes = Boxes3D(random_boxes3d(NUM_GT, IMAGE_SIZE, generator))
    # Foreground labels must lie in [0, NUM_CLASSES). Presumably (detectron2
    # convention) label == NUM_CLASSES means background, so the previous
    # randint(0, 2) with num_classes=1 produced "background" GT boxes.
    inst.gt_classes = torch.randint(0, NUM_CLASSES, (NUM_GT,), generator=generator)
    targets.append(inst)


In [4]:
from qct_3d_nod_detect.roi_heads import ROIHeads3D
from qct_3d_nod_detect.poolers import ROIPooler3D
from qct_3d_nod_detect.roi_head import FasterRCNNOutputLayers3D
from qct_3d_nod_detect.box_regression import Box3DTransform
import math

box3d2box3d_transform = Box3DTransform(
    weights=(1.0, 1.0, 1.0, 1.0, 1.0, 1.0),
    scale_clamp=math.log(1000.0),
)

# Assuming detectron2 Matcher semantics: IoU < 0.1 -> background (0),
# 0.1 <= IoU < 0.2 -> ignored (-1), IoU >= 0.2 -> foreground (1).
proposal_matcher = Matcher(
    thresholds=[0.1, 0.2],
    labels=[0, -1, 1],
    allow_low_quality_matches=True,
)

# One scale (= 1 / stride) per feature level, in p2..p5 order. The fake pyramid
# built below for a 64^3 volume has p2..p5 at 32^3, 16^3, 8^3, 4^3, i.e.
# strides 2, 4, 8, 16. (The previous [1, 2, 0.5, 0.25] matched none of these
# levels, and a scale of 2 would mean a feature map larger than the input.)
# NOTE(review): canonical_box_size=224 is an ImageNet-scale default; with 64^3
# volumes nearly every ROI will presumably be assigned to the finest level.
roi_pooler = ROIPooler3D(
    output_size=(7, 7, 7),
    canonical_level=4,
    canonical_box_size=224,
    pooler_type="ROIALign3DV2",
    scales=[1 / 2, 1 / 4, 1 / 8, 1 / 16],
)

# NOTE(review): detectron2's FastRCNNOutputLayers takes the *box head output*
# shape (e.g. `box_head.output_shape`, 1024 channels here), not the pooled
# feature shape. Confirm what FasterRCNNOutputLayers3D expects for input_shape.
box_predictor = FasterRCNNOutputLayers3D(
    input_shape=(32, 256, 7, 7, 7),
    num_classes=1,
    box2box_transform=box3d2box3d_transform,
    cls_agnostic_bbox_reg=False,
)

In [5]:
# The box head consumes ROIPooler3D's output, so its input (D, H, W) must equal
# the pooler's output_size=(7, 7, 7). The most recently built `box_head` (demo
# cell with depth=4) expects 256*4*7*7 flattened features in its first Linear
# and would fail on the 256*7*7*7 features the pooler produces.
roi_box_head = FastRCNNConvFCHead3D(
    input_shape=ShapeSpec(channels=256, depth=7, height=7, width=7),
    conv_dims=[256, 256],
    fc_dims=[1024],
)

roi_heads = ROIHeads3D(
    num_classes=1,
    batch_size_per_image=2,
    positive_fraction=0.5,
    proposal_matcher=proposal_matcher,
    proposal_append_gt=True,       # IMPORTANT for early training
    roi_pooler=roi_pooler,         # your 3D ROI pooler
    box_head=roi_box_head,         # FastRCNNConvFCHead3D matching the pooler output
    box_predictor=box_predictor,   # FasterRCNNOutputLayers3D
    is_training=True,
)

In [6]:
# Label the proposals against the GT boxes and draw a training sample from them
# (presumably governed by batch_size_per_image / positive_fraction above).
sampled = roi_heads.label_and_sample_proposals(
    proposals,
    targets,
)

In [25]:
for img_idx, per_image in enumerate(proposals):
    n_boxes = len(per_image.proposal_boxes.tensor)  # rows of the box tensor
    print(f"Proposals in image - {img_idx} - {n_boxes}")

Proposals in image - 0 - 10
Proposals in image - 1 - 10


In [27]:
B, C = 2, 256

# Fake p2..p5 feature pyramid: each level halves the spatial side (32 -> 4).
features = {
    f"p{level}": torch.rand(B, C, side, side, side)
    for level, side in zip(range(2, 6), (32, 16, 8, 4))
}

In [28]:
# Full ROI-heads forward pass with targets (ROIHeads3D was built with is_training=True).
# NOTE(review): the last recorded run raised
#   AttributeError: Cannot find field 'tensor' in the given Instances!
# i.e. `.tensor` is read from an Instances3D object somewhere inside
# ROIHeads3D / ROIPooler3D (presumably instead of from its Boxes3D field).
# That code is not in this notebook — investigate there. The NameError shown
# below it appears to be stale output from a run before `features` existed.
losses = roi_heads(
    features,
    proposals,
    targets,
)

AttributeError: Cannot find field 'tensor' in the given Instances!

NameError: name 'features' is not defined