In [None]:
# | default_exp nets/fpn_2d

# Imports

In [None]:
# | export

from typing import Literal

import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from torch.nn import functional as F

from vision_architectures.blocks.cnn import CNNBlock2D, CNNBlockConfig
from vision_architectures.utils.activation_checkpointing import ActivationCheckpointing
from vision_architectures.utils.custom_base_model import CustomBaseModel, Field, model_validator

# Config

In [None]:
# | export


class FPN2DBlockConfig(CNNBlockConfig):
    """Configuration of a single FPN level (one ``FPN2DBlock``).

    Extends ``CNNBlockConfig`` with the FPN-specific settings below. The conv channel counts are
    not configured directly: ``FPN2DBlock`` passes ``in_channels`` / ``out_channels`` explicitly
    to each ``CNNBlock2D`` it builds, derived from ``skip_conn_dim`` and ``dim``.
    """

    # Number of channels produced by the block (and expected from the deeper block's output)
    dim: int
    # Kernel size default for the block's convolutions
    # NOTE(review): FPN2DBlock.__init__ forces kernel_size to 3 when validating its config
    kernel_size: int = 3
    # Number of channels of the incoming skip-connection feature map
    skip_conn_dim: int
    is_deepest: bool = Field(False, description="True if this is the deepest block in the FPN, else False")
    # Passed as ``mode`` to ``F.interpolate`` when resizing the deeper features; not validated here
    interpolation_mode: str = "bilinear"
    # How the projected skip connection and the resized deeper features are combined
    merge_method: Literal["add", "concat"] = "add"

    normalization: str = "batchnorm2d"

    # Typed as ``None`` so that only None is accepted for these two fields in the config
    in_channels: None = Field(None, description="Calculated based on other parameters")
    out_channels: None = Field(None, description="Calculated based on other parameters")


class FPN2DConfig(CustomBaseModel):
    """Configuration of a full FPN: an ordered list of per-level block configs.

    The config can be built in two ways:

    * ``blocks=[...]``: explicit per-block configs (dicts or ``FPN2DBlockConfig`` objects).
    * ``skip_conn_dims=[...]``: one block is generated per entry, with the last one marked as
      the deepest block.

    In both cases every other top-level key (e.g. ``dim``, ``interpolation_mode``) is used as a
    default for each dict-style block that does not set that key itself.
    """

    blocks: list[FPN2DBlockConfig]

    @property
    def dim(self):
        """Channel width shared by all blocks, read from the first block."""
        return self.blocks[0].dim

    @model_validator(mode="before")
    @classmethod
    def validate_before(cls, data):
        """Expand ``skip_conn_dims`` into blocks and push shared top-level keys into each block.

        Args:
            data: Raw input handed to the model. Only dict inputs are transformed; anything else
                is returned untouched.

        Returns:
            A new dict (the caller's dict is not modified) with ``blocks`` populated.

        Raises:
            ValueError: If both ``skip_conn_dims`` and ``blocks`` are provided.
        """
        if not isinstance(data, dict):
            return data

        # Work on a shallow copy so the caller's dict is left as it was passed in
        data = dict(data)

        # Add skip_conn_dim first
        if "skip_conn_dims" in data:
            # Assume blocks are to be built from scratch
            if "blocks" in data:
                raise ValueError("Cannot provide both skip_conn_dims and blocks")
            skip_conn_dims = data.pop("skip_conn_dims")
            deepest_index = len(skip_conn_dims) - 1
            data["blocks"] = [
                {"skip_conn_dim": skip_conn_dim, "is_deepest": (i == deepest_index)}
                for i, skip_conn_dim in enumerate(skip_conn_dims)
            ]

        # Add the remaining: every top-level key except "blocks" itself acts as a per-block default.
        # Excluding "blocks" avoids injecting the block list into each of its own entries.
        shared_defaults = {key: value for key, value in data.items() if key != "blocks"}
        blocks = data.get("blocks")
        if isinstance(blocks, (list, tuple)):
            data["blocks"] = [
                # Block-level values win over shared defaults; non-dict blocks (e.g. already
                # constructed config objects) are passed through unchanged.
                {**shared_defaults, **block} if isinstance(block, dict) else block
                for block in blocks
            ]
        return data

    @model_validator(mode="after")
    def validate(self):
        """Check that every block uses the same ``dim`` as the first block.

        Raises:
            ValueError: If any block's ``dim`` differs from ``self.dim``.
        """
        for block in self.blocks:
            if block.dim != self.dim:
                raise ValueError("All blocks must have the same dim")
        return self

In [None]:
# Smoke check: one block is generated per skip_conn_dims entry, and the other top-level keyword
# arguments (dim, interpolation_mode) are filled into every generated block config.
# interpolation_mode is a plain str field, so an arbitrary value such as "test" is accepted here.
FPN2DConfig(skip_conn_dims=[12, 24, 36], dim=24, interpolation_mode="test")


[1;35mFPN2DConfig[0m[1m([0m
    [33mblocks[0m=[1m[[0m
        [1;35mFPN2DBlockConfig[0m[1m([0m
            [33min_channels[0m=[3;35mNone[0m,
            [33mout_channels[0m=[3;35mNone[0m,
            [33mkernel_size[0m=[1;36m3[0m,
            [33mpadding[0m=[32m'same'[0m,
            [33mstride[0m=[1;36m1[0m,
            [33mconv_kwargs[0m=[1m{[0m[1m}[0m,
            [33mtransposed[0m=[3;91mFalse[0m,
            [33mnormalization[0m=[32m'batchnorm2d'[0m,
            [33mnormalization_pre_args[0m=[1m[[0m[1m][0m,
            [33mnormalization_post_args[0m=[1m[[0m[1m][0m,
            [33mnormalization_kwargs[0m=[1m{[0m[1m}[0m,
            [33mactivation[0m=[32m'relu'[0m,
            [33mactivation_kwargs[0m=[1m{[0m[1m}[0m,
            [33msequence[0m=[32m'CNA'[0m,
            [33mdrop_prob[0m=[1;36m0[0m[1;36m.0[0m,
            [33mdim[0m=[1;36m24[0m,
            [33mskip_conn_dim[0m=[1;36m12[0m,
 

# Architecture

### Basic block

In [None]:
# | export


class FPN2DBlock(nn.Module):
    """One level of a 2D feature pyramid network.

    The skip-connection feature map is projected to ``dim`` channels with a 1x1 ``CNNBlock2D``.
    Unless this is the deepest block, the features coming from the deeper level are resized to
    the skip connection's spatial size, merged with it (``add`` or ``concat``), and passed
    through an output ``CNNBlock2D``.

    Args:
        config: Block configuration (anything ``FPN2DBlockConfig.model_validate`` accepts after
            being merged with ``kwargs``).
        checkpointing_level: Forwarded to the inner ``CNNBlock2D`` modules and to this block's
            ``ActivationCheckpointing(2, ...)`` wrapper.
        **kwargs: Extra config values merged on top of ``config``.
    """

    def __init__(self, config: FPN2DBlockConfig = {}, checkpointing_level: int = 0, **kwargs):
        super().__init__()

        # NOTE(review): kernel_size is forced to 3 here, overriding any value supplied through
        # config/kwargs — confirm this is intended before relying on FPN2DBlockConfig.kernel_size
        self.config = FPN2DBlockConfig.model_validate(config | kwargs | {"kernel_size": 3})

        # 1x1 projection of the skip connection to the common FPN width
        self.skip_conn_conv = CNNBlock2D(
            self.config,
            in_channels=self.config.skip_conn_dim,
            out_channels=self.config.dim,
            kernel_size=1,
            checkpointing_level=checkpointing_level,
        )

        # The deepest block has nothing to merge with, so it needs no output conv
        if not self.config.is_deepest:
            self.out_conv = CNNBlock2D(
                self.config,
                # Concatenation doubles the channel count seen by the output conv
                in_channels=self.config.dim if self.config.merge_method == "add" else 2 * self.config.dim,
                out_channels=self.config.dim,
                checkpointing_level=checkpointing_level,
            )

        self.checkpointing_level2 = ActivationCheckpointing(2, checkpointing_level)

    def _forward(self, skip_conn_features: torch.Tensor, features: torch.Tensor | None):
        """Project the skip connection and merge it with the deeper features.

        Args:
            skip_conn_features: Skip-connection feature map of shape (b, skip_conn_dim, h1, w1).
            features: Output of the deeper block, shape (b, dim, h2, w2). Ignored (may be None)
                when this is the deepest block; required otherwise.

        Returns:
            Tensor of shape (b, dim, h1, w1).

        Raises:
            ValueError: If ``features`` is None for a non-deepest block, or if the configured
                merge method is not supported.
        """
        # skip_conn_features: (b, skip_conn_dim, h1, w1)
        # features: (b, dim, h2, w2)

        skip_conn_features = self.skip_conn_conv(skip_conn_features)
        # (b, dim, h1, w1)

        if self.config.is_deepest:
            return skip_conn_features

        if features is None:
            raise ValueError("features must be provided for every block except the deepest one")

        # F.interpolate only accepts an explicit align_corners for the interpolating modes and
        # raises for the others (e.g. "nearest", "area"), so pass None in that case.
        mode = self.config.interpolation_mode
        align_corners = False if mode in ("linear", "bilinear", "bicubic", "trilinear") else None
        features = F.interpolate(
            features,
            size=skip_conn_features.shape[2:],
            mode=mode,
            align_corners=align_corners,
        )
        # (b, dim, h1, w1)

        if self.config.merge_method == "add":
            merged_features = skip_conn_features + features
            # (b, dim, h1, w1)
        elif self.config.merge_method == "concat":
            merged_features = torch.cat((skip_conn_features, features), dim=1)
            # (b, 2 * dim, h1, w1)
        else:
            raise ValueError(f"Unsupported merge_method: {self.config.merge_method}")

        merged_features = self.out_conv(merged_features)
        # (b, dim, h1, w1)

        return merged_features

    def forward(self, *args, **kwargs):
        """Run ``_forward`` through the level-2 activation checkpointing wrapper."""
        return self.checkpointing_level2(self._forward, *args, **kwargs)

In [None]:
# Deepest block: only the skip connection is used, so no deeper features are passed (None)
test = FPN2DBlock(dim=128, skip_conn_dim=256, is_deepest=True, checkpointing_level=2)
display(test)
test_output_shape = test(torch.randn(2, 256, 8, 8), None).shape
display(test_output_shape)
# Channels are projected to dim while the spatial size of the skip connection is kept
assert test_output_shape == (2, 128, 8, 8)


[1;35mFPN2DBlock[0m[1m([0m
  [1m([0mskip_conn_conv[1m)[0m: [1;35mCNNBlock2D[0m[1m([0m
    [1m([0mconv[1m)[0m: [1;35mConv2d[0m[1m([0m[1;36m256[0m, [1;36m128[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[35msame[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mnorm[1m)[0m: [1;35mBatchNorm2d[0m[1m([0m[1;36m128[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
    [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;92mTrue[0m[1m)[0m
  [1m)[0m
  [1m([0mcheckpointing_level2[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;92mTrue[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m128[0m, [1;36m8[0m, [1;36m8[0m[1m][0m[1m)[0m

In [None]:
# Non-deepest block: deeper features (8x8) are resized to the skip connection's size (16x16) and merged
test = FPN2DBlock(dim=128, skip_conn_dim=256, is_deepest=False, checkpointing_level=1)
display(test)
test_output_shape = test(torch.randn(2, 256, 16, 16), torch.randn(2, 128, 8, 8)).shape
display(test_output_shape)
# Output has dim channels at the skip connection's spatial resolution
assert test_output_shape == (2, 128, 16, 16)


[1;35mFPN2DBlock[0m[1m([0m
  [1m([0mskip_conn_conv[1m)[0m: [1;35mCNNBlock2D[0m[1m([0m
    [1m([0mconv[1m)[0m: [1;35mConv2d[0m[1m([0m[1;36m256[0m, [1;36m128[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[35msame[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mnorm[1m)[0m: [1;35mBatchNorm2d[0m[1m([0m[1;36m128[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
    [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;92mTrue[0m[1m)[0m
  [1m)[0m
  [1m([0mout_conv[1m)[0m: [1;35mCNNBlock2D[0m[1m([0m
    [1m([0mconv[1m)[0m: [1;35mConv2d[0m[1m([0m[1;36m128[0m, [1;36m128[0m, [33mkernel_size

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m128[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m

### Complete architecture

In [None]:
# | export


class FPN2D(nn.Module, PyTorchModelHubMixin):
    """2D feature pyramid network built from a stack of ``FPN2DBlock`` modules.

    Block ``i`` consumes feature map ``i`` of the input list. Processing starts at the last
    (deepest) feature map and walks towards the first, feeding each block's output into the
    next shallower block.

    Args:
        config: FPN configuration (anything ``FPN2DConfig.model_validate`` accepts after being
            merged with ``kwargs``).
        checkpointing_level: Forwarded to every block and to this module's
            ``ActivationCheckpointing(4, ...)`` wrapper.
        **kwargs: Extra config values merged on top of ``config``.
    """

    def __init__(self, config: FPN2DConfig = {}, checkpointing_level: int = 0, **kwargs):
        super().__init__()

        self.config = FPN2DConfig.model_validate(config | kwargs)

        self.blocks = nn.ModuleList(
            [FPN2DBlock(block_config, checkpointing_level) for block_config in self.config.blocks]
        )

        self.checkpointing_level4 = ActivationCheckpointing(4, checkpointing_level)

    def _forward(self, features: list[torch.Tensor]):
        """Run the top-down pathway over the given feature maps.

        Args:
            features: Feature maps ordered shallowest to deepest:
                [(b, in_dim1, h1, w1), (b, in_dim2, h2, w2), ...]. Entry ``i`` is fed to
                ``self.blocks[i]``.

        Returns:
            List of block outputs in processing order, i.e. deepest level first (the reverse of
            the input order).

        Raises:
            ValueError: If more feature maps than blocks are given, or if the block handling the
                last feature map is not a deepest block (it would have no deeper input to merge).
        """
        # features: [
        #   (b, in_dim1, h1, w1),
        #   (b, in_dim2, h2, w2),
        #   ...
        # ]

        num_features = len(features)
        num_blocks = len(self.blocks)
        if num_features > num_blocks:
            raise ValueError(f"Got {num_features} feature maps but the FPN only has {num_blocks} blocks")
        # The first block to run receives no deeper features, so it must be a deepest block
        if num_features > 0 and not self.blocks[num_features - 1].config.is_deepest:
            raise ValueError(
                f"Got {num_features} feature maps for {num_blocks} blocks, but block {num_features - 1} "
                "is not marked as deepest and would receive no deeper features"
            )

        outputs = []
        deeper_output = None  # nothing to merge with at the deepest level
        for i in reversed(range(num_features)):
            deeper_output = self.blocks[i](features[i], deeper_output)
            outputs.append(deeper_output)

        return outputs

    def forward(self, *args, **kwargs):
        """Run ``_forward`` through the level-4 activation checkpointing wrapper."""
        return self.checkpointing_level4(self._forward, *args, **kwargs)

In [None]:
# End-to-end check: four pyramid levels, shallowest (32x32) to deepest (4x4)
test_config = FPN2DConfig(skip_conn_dims=[64, 128, 256, 512], dim=128)
test_input = [
    torch.randn(2, 64, 32, 32),
    torch.randn(2, 128, 16, 16),
    torch.randn(2, 256, 8, 8),
    torch.randn(2, 512, 4, 4),
]
test = FPN2D(test_config, 3)

display(test)
test_output_shapes = [output.shape for output in test(test_input)]
display(test_output_shapes)
# Outputs come back deepest level first, all with dim channels
assert test_output_shapes == [(2, 128, 4, 4), (2, 128, 8, 8), (2, 128, 16, 16), (2, 128, 32, 32)]


[1;35mFPN2D[0m[1m([0m
  [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mFPN2DBlock[0m[1m([0m
      [1m([0mskip_conn_conv[1m)[0m: [1;35mCNNBlock2D[0m[1m([0m
        [1m([0mconv[1m)[0m: [1;35mConv2d[0m[1m([0m[1;36m64[0m, [1;36m128[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[35msame[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
        [1m([0mnorm[1m)[0m: [1;35mBatchNorm2d[0m[1m([0m[1;36m128[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
        [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;92mTrue[0m[1m)[0m
      [1m)[0m
      [1m([0mout_


[1m[[0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m128[0m, [1;36m4[0m, [1;36m4[0m[1m][0m[1m)[0m,
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m128[0m, [1;36m8[0m, [1;36m8[0m[1m][0m[1m)[0m,
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m128[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m,
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m128[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m
[1m][0m

# nbdev

In [None]:
!nbdev_export