## Simple FPN

In [34]:
#| default_exp SimpleFPN
#| export 

from torch import nn
import torch
from typing import Tuple

class SimpleFPN(nn.Module):
    """Build a simple multi-level feature pyramid from one 3D feature map.

    One branch is created per requested scale. Each branch first resamples
    the ``dim``-channel input, then projects it to ``out_channels``:

    * ``2``    -> ``"p2"``: one stride-2 ``ConvTranspose3d`` (2x upsample)
    * ``1``    -> ``"p3"``: no resampling
    * ``0.5``  -> ``"p4"``: one stride-2 ``Conv3d`` (2x downsample)
    * ``0.25`` -> ``"p5"``: two stride-2 ``Conv3d`` (4x downsample)

    Every branch ends with
    ``Conv3d(1x1) -> GroupNorm(1) -> Conv3d(3x3, pad 1) -> GroupNorm(1)``.

    Args:
        dim: Channel count of the input feature map.
        out_channels: Channel count of every output level.
        scales: Scale factors to build; each must be one of 2, 1, 0.5, 0.25.

    Raises:
        NotImplementedError: If a scale is not one of the supported values.
    """

    def __init__(self, dim, out_channels, scales):
        super().__init__()
        self.blocks = nn.ModuleDict()

        # Output level name for each supported scale factor.
        scale_to_level = {
            2: "p2",
            1: "p3",
            0.5: "p4",
            0.25: "p5",
        }

        # Largest scale first, so blocks (and forward()'s output dict) are
        # ordered p2, p3, p4, p5.
        for scale in sorted(scales, reverse=True):
            layers = self._resample_layers(dim, scale)

            # Shared projection head. The layer order here determines the
            # Sequential indices, i.e. the state_dict keys of each block.
            layers.extend([
                nn.Conv3d(dim, out_channels, 1),
                nn.GroupNorm(1, out_channels),
                nn.Conv3d(out_channels, out_channels, 3, padding=1),
                nn.GroupNorm(1, out_channels),
            ])

            self.blocks[scale_to_level[scale]] = nn.Sequential(*layers)

    @staticmethod
    def _resample_layers(dim, scale):
        """Return the (possibly empty) list of resampling layers for ``scale``."""
        if scale == 2:
            return [nn.ConvTranspose3d(dim, dim, 2, 2)]
        if scale == 1:
            return []
        if scale == 0.5:
            return [nn.Conv3d(dim, dim, 2, 2)]
        if scale == 0.25:
            return [nn.Conv3d(dim, dim, 2, 2), nn.Conv3d(dim, dim, 2, 2)]
        raise NotImplementedError(
            f"Unsupported scale {scale!r}; supported scales are 2, 1, 0.5 and 0.25."
        )

    def forward(self, x):
        """Apply every branch to ``x`` and return ``{level_name: features}``."""
        return {k: block(x) for k, block in self.blocks.items()}

class BackboneFPN(nn.Module):
    """Reshape a backbone's patch tokens into a 3D grid and run an FPN on it.

    Args:
        backbone: Module whose output is indexed with ``'patch_tokens'``;
            that entry must be a rank-3 tensor ``(B, N, C)``.
        fpn: Module applied to the channels-first ``(B, C, Dp, Hp, Wp)``
            grid. It must expose a ``blocks`` mapping of level name ->
            module (as ``SimpleFPN`` does), used to infer ``out_channels``.
        patch_grid_size: ``(Dp, Hp, Wp)``; ``Dp * Hp * Wp`` must equal the
            number of tokens ``N``.

    Attributes:
        out_channels: Channel count of the FPN outputs, read from the last
            ``GroupNorm`` or conv-like layer of the ``"p2"`` block (or of the
            first block when there is no ``"p2"``).

    Raises:
        ValueError: If ``out_channels`` cannot be inferred from ``fpn``.
    """

    def __init__(
        self,
        backbone: nn.Module,
        fpn: nn.Module,
        patch_grid_size: Tuple[int, int, int] = (16, 16, 16),
    ):
        super().__init__()
        self.backbone = backbone
        self.fpn = fpn
        self.patch_grid_size = patch_grid_size
        self.out_channels = self._infer_out_channels(fpn)

    @staticmethod
    def _infer_out_channels(fpn: nn.Module) -> int:
        """Return the output channel count of ``fpn`` from one of its blocks."""
        blocks = fpn.blocks
        if "p2" in blocks:
            block = blocks["p2"]
        else:
            # Not every FPN builds a scale-2 level; any block will do.
            block = next(iter(blocks.values()), None)

        if block is not None:
            # Walk sub-modules last-to-first so the block's final layer wins.
            for layer in reversed(list(block.modules())):
                if isinstance(layer, nn.GroupNorm):
                    return layer.num_channels
                channels = getattr(layer, "out_channels", None)
                if isinstance(channels, int):
                    return channels

        raise ValueError("Could not determine out_channels from FPN block.")

    def forward(self, x):
        """Run the backbone, arrange its patch tokens as a grid, apply the FPN.

        Args:
            x: Input passed unchanged to the backbone (documented as
                ``Tensor[B, 1, D, H, W]``; the actual requirement is the
                backbone's).

        Returns:
            Whatever ``self.fpn`` returns; for ``SimpleFPN`` this is a
            ``Dict[str, Tensor]`` of multi-scale features.
        """
        patch_tokens = self.backbone(x)['patch_tokens']  # B, N, C
        B, N, C = patch_tokens.shape

        Dp, Hp, Wp = self.patch_grid_size
        # NOTE(review): assert is stripped under `python -O`.
        assert N == Dp * Hp * Wp, f"Expected N={Dp*Hp*Wp}, got N={N}"

        # (B, N, C) -> (B, Dp, Hp, Wp, C) -> channels-first (B, C, Dp, Hp, Wp)
        grid = patch_tokens.view(B, Dp, Hp, Wp, C).permute(0, 4, 1, 2, 3).contiguous()
        return self.fpn(grid)
