In [2]:
#| default_exp anchor_generator_3d
#| export
from torch import nn
import torch
from typing import List
from qct_3d_nod_detect.layers import ShapeSpec
import collections

def _create_grid_offsets_3d(
    size: List[int],          # [D, H, W]
    stride: int,
    offset: float,
    device: torch.device     
):

    D, H, W = size

    shifts_x = (torch.arange(W, device=device) + offset) * stride
    shifts_y = (torch.arange(H, device=device) + offset) * stride
    shifts_z = (torch.arange(D, device=device) + offset) * stride

    shift_z, shift_y, shift_x = torch.meshgrid(
        shifts_z, shifts_y, shifts_x, indexing="ij"
    )

    return (
        shift_x.reshape(-1),
        shift_y.reshape(-1),
        shift_z.reshape(-1),
    )

def _broadcast_params(params, num_features, name):

    assert isinstance(params, collections.abc.Sequence), f"{name} must be a list!"
    assert len(params), f"{name} cannot be empty!"

    if not isinstance(params[0], collections.abc.Sequence):
        return [params] * num_features
    if len(params) == 1:
        return list(params) * num_features
    assert len(params) == num_features, (
        f"Got {name} of length {len(params)}, but {num_features} feature maps!"
    )

    return params

class BufferList(nn.Module):
    """Ordered container of tensors registered as non-persistent buffers.

    Each tensor is registered under its position (``"0"``, ``"1"``, ...).
    Because the buffers are non-persistent they follow the module across
    ``.to()`` / ``.cuda()`` calls but are left out of ``state_dict()``.
    """

    def __init__(self, buffers):
        """Register every tensor in ``buffers`` under its index as the name."""
        super().__init__()
        for position, tensor in enumerate(buffers):
            buffer_name = str(position)
            self.register_buffer(buffer_name, tensor, persistent=False)

    def __len__(self):
        """Return the number of registered buffers."""
        registered = self._buffers
        return len(registered)

    def __iter__(self):
        """Iterate over the buffers in registration order."""
        registered = self._buffers
        return iter(registered.values())

class DefaultAnchorGenerator3D(nn.Module):
    """Generate axis-aligned 3D anchor boxes for a list of feature maps.

    Every anchor is a row ``(x0, y0, z0, x1, y1, z1)``. For each feature map a
    small set of "cell anchors" centred on the origin is computed once at
    construction time; ``forward`` translates them to every grid location of
    the corresponding feature map.
    """

    # Number of coordinates describing one box: (x0, y0, z0, x1, y1, z1).
    box_dim: int = 6

    def __init__(
        self,
        *,
        sizes,
        aspect_ratios_3d,
        strides,
        offset: float = 0.5
    ):
        """Precompute the cell anchors for every feature map.

        Args:
            sizes: Anchor sizes. Either a flat list shared by all feature
                maps, or a list with one list of sizes per feature map.
            aspect_ratios_3d: ``(dz_dx, dy_dx)`` ratio pairs. Either a flat
                list of pairs shared by all feature maps, or a list with one
                list of pairs per feature map.
            strides: One stride per feature map; its length defines the
                number of feature maps.
            offset: Position of the anchor centre inside a grid cell, as a
                fraction of the stride. Must satisfy ``0.0 <= offset < 1.0``.

        Raises:
            AssertionError: If ``sizes`` / ``aspect_ratios_3d`` cannot be
                matched to the feature maps or ``offset`` is out of range.
        """

        super().__init__()
        self.strides = strides
        self.num_features = len(self.strides)

        sizes = _broadcast_params(sizes, self.num_features, "sizes")
        # A ratio is itself a pair, so a flat list of pairs looks "nested" to
        # _broadcast_params; wrap it first so it is broadcast like a flat
        # list of sizes instead of being split across feature maps.
        aspect_ratios_3d = _broadcast_params(
            self._nest_flat_ratio_pairs(aspect_ratios_3d),
            self.num_features,
            "aspect_ratios_3d",
        )

        self.cell_anchors = self._calculate_anchors(sizes, aspect_ratios_3d)
        self.offset = offset

        assert 0.0 <= self.offset < 1.0, self.offset

    @staticmethod
    def _nest_flat_ratio_pairs(aspect_ratios_3d):
        """Wrap a flat list of ratio pairs in an outer list; pass anything else through."""
        sequence_type = collections.abc.Sequence
        is_flat_pair_list = (
            isinstance(aspect_ratios_3d, sequence_type)
            and len(aspect_ratios_3d) > 0
            and isinstance(aspect_ratios_3d[0], sequence_type)
            and len(aspect_ratios_3d[0]) > 0
            # Elements of the first entry are scalars -> the entry is one pair.
            and not isinstance(aspect_ratios_3d[0][0], sequence_type)
        )
        return [aspect_ratios_3d] if is_flat_pair_list else aspect_ratios_3d

    @classmethod
    def from_config(cls, cfg, input_shape: List[ShapeSpec]):
        """Translate a config into keyword arguments for ``__init__``.

        Note that this returns the keyword-argument dict only; it does not
        construct an instance.

        Args:
            cfg: Config exposing ``MODEL.ANCHOR_GENERATOR.SIZES``,
                ``.ASPECT_RATIOS_3D`` and ``.OFFSET``.
            input_shape: One spec per feature map; only ``stride`` is read.

        Returns:
            Dict with keys ``sizes``, ``aspect_ratios_3d``, ``strides`` and
            ``offset``.
        """
        return {
            "sizes": cfg.MODEL.ANCHOR_GENERATOR.SIZES,
            "aspect_ratios_3d": cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS_3D,
            "strides": [x.stride for x in input_shape],
            "offset": cfg.MODEL.ANCHOR_GENERATOR.OFFSET,
        }

    def _calculate_anchors(self, sizes, aspect_ratios_3d):
        """Build one float cell-anchor tensor per feature map, stored as buffers."""
        cell_anchors = [
            self.generate_cell_anchors_3d(s, ar).float()
            for s, ar in zip(sizes, aspect_ratios_3d)
        ]

        return BufferList(cell_anchors)

    @property
    def num_cell_anchors(self):
        """Number of anchors per grid location, one entry per feature map."""
        return [len(ca) for ca in self.cell_anchors]

    def _grid_anchors(self, grid_sizes: List[List[int]]):
        """Place the cell anchors at every location of every feature grid.

        Args:
            grid_sizes: ``[D, H, W]`` of each feature map.

        Returns:
            One tensor per feature map with shape ``(D * H * W * A, 6)``,
            where ``A`` is the number of cell anchors of that map. Rows are
            ordered location-major: all anchors of one location are adjacent.
        """
        anchors = []
        # Create the grid coordinates on the same device as the anchor buffers.
        device = next(iter(self.cell_anchors._buffers.values())).device

        for size, stride, base_anchors in zip(grid_sizes, self.strides, self.cell_anchors):
            shift_x, shift_y, shift_z = _create_grid_offsets_3d(size, stride, self.offset, device)
            # The same translation applies to both corners of a box.
            shifts = torch.stack((shift_x, shift_y, shift_z, shift_x, shift_y, shift_z), dim=1)

            # (N, 1, 6) + (1, A, 6) -> (N, A, 6) -> (N * A, 6)
            grid_anchors = (
                shifts[:, None, :] + base_anchors[None, :, :]
            ).reshape(-1, self.box_dim)

            anchors.append(grid_anchors)

        return anchors

    def generate_cell_anchors_3d(
        self,
        sizes=(32, 64, 128 ,256),
        aspect_ratios_3d=((1.0, 1.0), (1.0, 2.0), (2.0, 1.0), (0.5, 1.0), (1.0, 0.5)),
    ):
        """Create origin-centred anchors for every size / ratio combination.

        For a given ``size`` and ``(dz_dx, dy_dx)`` pair the box extents are
        ``w = size`` along x, ``h = size * dy_dx`` along y and
        ``d = size * dz_dx`` along z.

        Args:
            sizes: Extents along x.
            aspect_ratios_3d: ``(dz_dx, dy_dx)`` pairs.

        Returns:
            Tensor of shape ``(len(sizes) * len(aspect_ratios_3d), 6)`` with
            rows ``(x0, y0, z0, x1, y1, z1)``; sizes vary slowest.
        """

        anchors = []
        for size in sizes:
            for dz_dx, dy_dx in aspect_ratios_3d:
                w = size
                h = size * dy_dx
                d = size * dz_dx

                x0 = -w / 2.0
                y0 = -h / 2.0
                z0 = -d / 2.0
                x1 = +w / 2.0
                y1 = +h / 2.0
                z1 = +d / 2.0

                anchors.append([x0, y0, z0, x1, y1, z1])

        # Reshape so an empty anchor list still yields a (0, 6) tensor that
        # broadcasts correctly in _grid_anchors; a no-op otherwise.
        return torch.tensor(anchors).reshape(-1, self.box_dim)

    def forward(self, features: List[torch.Tensor]):
        """Generate anchors for the given feature maps.

        Args:
            features: Feature tensors whose last three dimensions are taken
                as the ``(D, H, W)`` grid size. They are paired positionally
                with the strides given at construction.

        Returns:
            List with one ``(D * H * W * A, 6)`` anchor tensor per feature map.
        """

        grid_sizes = [f.shape[-3:] for f in features]
        anchors = self._grid_anchors(grid_sizes)

        return anchors

In [3]:
#| export
def build_anchor_generator_3d(cfg, input_shape):
    """Build a ``DefaultAnchorGenerator3D`` from a config.

    Args:
        cfg: Config passed through to ``DefaultAnchorGenerator3D.from_config``.
        input_shape: Per-feature-map shape specs passed through to
            ``from_config``.

    Returns:
        A constructed ``DefaultAnchorGenerator3D`` module.
    """
    # from_config only yields the constructor's keyword arguments, so the
    # instance has to be created explicitly here.
    init_kwargs = DefaultAnchorGenerator3D.from_config(cfg, input_shape)
    return DefaultAnchorGenerator3D(**init_kwargs)