In [1]:
# | default_exp nets/fpn_3d

# Imports

In [2]:
# | export

import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from torch.nn import functional as F

from vision_architectures.utils.activation_checkpointing import ActivationCheckpointing

# Architecture

### Basic block

In [3]:
# | export


class FPN3DBlock(nn.Module):
    """One level of a 3D feature pyramid network (FPN).

    The block projects a backbone ("shallow") feature map to ``fpn_dim`` channels with a
    1x1x1 convolution. Unless it is the deepest level, it then resizes the already-processed
    deeper pyramid feature to the same spatial size, adds the two, and smooths the sum with a
    3x3x3 convolution.

    Args:
        shallow_dim: Number of channels of the incoming backbone feature map.
        fpn_dim: Number of channels used throughout the pyramid (also the output channels).
        is_deepest: If True, the block has no deeper level to merge with; it only applies the
            input projection and owns no ``out_conv``.
        checkpointing_level: Forwarded to ``ActivationCheckpointing(1, checkpointing_level)``,
            which wraps the merge step. NOTE(review): presumably the merge is checkpointed when
            this is >= 1 -- confirm against ``ActivationCheckpointing``.
    """

    def __init__(self, shallow_dim, fpn_dim, is_deepest=False, checkpointing_level=0):
        super().__init__()

        self.is_deepest = is_deepest
        self.checkpointing_level = checkpointing_level

        self.checkpointing_level1 = ActivationCheckpointing(1, checkpointing_level)

        # Lateral projection: bring the backbone feature to the common pyramid width.
        self.in_conv = nn.Sequential(
            nn.Conv3d(shallow_dim, fpn_dim, kernel_size=1, bias=False),
            nn.BatchNorm3d(fpn_dim),
            nn.ReLU(inplace=True),
        )

        # The deepest level has nothing to merge, so it needs no smoothing convolution.
        if not is_deepest:
            self.out_conv = nn.Sequential(
                nn.Conv3d(fpn_dim, fpn_dim, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm3d(fpn_dim),
                nn.ReLU(inplace=True),
            )

    def merge_features(self, shallow_features, deep_features):
        """Resize ``deep_features`` to the shallow spatial size, add, and smooth.

        Args:
            shallow_features: Projected shallow features, (b, fpn_dim, d1, h1, w1).
            deep_features: Deeper pyramid features, (b, fpn_dim, d2, h2, w2).

        Returns:
            Merged features of shape (b, fpn_dim, d1, h1, w1).
        """
        deep_features = F.interpolate(
            deep_features, size=shallow_features.shape[2:], mode="trilinear", align_corners=False
        )
        # (b, fpn_dim, d1, h1, w1)

        merged_features = shallow_features + deep_features
        # (b, fpn_dim, d1, h1, w1)

        merged_features = self.out_conv(merged_features)
        # (b, fpn_dim, d1, h1, w1)

        return merged_features

    def forward(self, shallow_features: torch.Tensor, deep_features: "torch.Tensor | None" = None):
        """Compute this level's pyramid feature.

        Args:
            shallow_features: Backbone features, (b, in_dim, d1, h1, w1).
            deep_features: Output of the next-deeper pyramid level, (b, fpn_dim, d2, h2, w2).
                Ignored (and may be omitted) when the block was built with ``is_deepest=True``.

        Returns:
            Tensor of shape (b, fpn_dim, d1, h1, w1).

        Raises:
            ValueError: If the block is not the deepest one and ``deep_features`` is None.
        """
        # shallow_features: (b, in_dim, d1, h1, w1)
        # deep_features: (b, fpn_dim, d2, h2, w2)

        # Fail early with a readable message instead of crashing inside F.interpolate.
        if not self.is_deepest and deep_features is None:
            raise ValueError("deep_features must be provided unless the block was created with is_deepest=True")

        shallow_features = self.in_conv(shallow_features)
        # (b, fpn_dim, d1, h1, w1)

        if self.is_deepest:
            return shallow_features

        merged_features = self.checkpointing_level1(self.merge_features, shallow_features, deep_features)
        # (b, fpn_dim, d1, h1, w1)

        return merged_features

In [4]:
# Smoke test for the deepest pyramid level: only the 1x1x1 projection runs.
test = FPN3DBlock(shallow_dim=256, fpn_dim=128, is_deepest=True, checkpointing_level=0)
display(test)
# No deeper level exists, so None is passed for the deep features; channels go 256 -> 128.
display(test(torch.randn(2, 256, 4, 8, 8), None).shape)


[1;35mFPN3DBlock[0m[1m([0m
  [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[1m)[0m
  [1m([0min_conv[1m)[0m: [1;35mSequential[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m256[0m, [1;36m128[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
    [1m([0m[1;36m1[0m[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m128[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
    [1m([0m[1;36m2[0m[1m)[0m: [1;35mReLU[0m[1m([0m[33minplace[0m=[3;92mTrue[0m[1m)[0m
  [1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m128[0m, [1;36m4[0m, [1;36m8[0m, [1;36m8[0m[1m][0m[1m)[0m

In [5]:
# Smoke test for a merging level with checkpointing_level=1.
test = FPN3DBlock(shallow_dim=256, fpn_dim=128, is_deepest=False, checkpointing_level=1)
display(test)
# The deep map is resized to the shallow map's spatial size before the sum,
# so the output keeps the shallow (4, 8, 8) grid.
display(
    test(
        shallow_features=torch.randn(2, 256, 4, 8, 8, requires_grad=True),
        deep_features=torch.randn(2, 128, 8, 16, 16, requires_grad=True),
    ).shape
)


[1;35mFPN3DBlock[0m[1m([0m
  [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[1m)[0m
  [1m([0min_conv[1m)[0m: [1;35mSequential[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m256[0m, [1;36m128[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
    [1m([0m[1;36m1[0m[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m128[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
    [1m([0m[1;36m2[0m[1m)[0m: [1;35mReLU[0m[1m([0m[33minplace[0m=[3;92mTrue[0m[1m)[0m
  [1m)[0m
  [1m([0mout_conv[1m)[0m: [1;35mSequential[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m128[0m, [1;36m128[



[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m128[0m, [1;36m4[0m, [1;36m8[0m, [1;36m8[0m[1m][0m[1m)[0m

### Complete architecture

In [6]:
# | export


class FPN3D(nn.Module, PyTorchModelHubMixin):
    """3D feature pyramid network built from one ``FPN3DBlock`` per backbone level.

    Args:
        config: Mapping with the keys
            - ``"fpn_dim"``: channel width shared by every pyramid output.
            - ``"in_dims"``: channels of each backbone feature map, ordered shallowest to
              deepest; the last entry is treated as the deepest level.
            - ``"checkpointing_level"`` (optional): forwarded to every block; defaults to 0,
              matching the default of ``FPN3DBlock``.
    """

    def __init__(self, config):
        super().__init__()

        fpn_dim = config["fpn_dim"]
        in_dims = config["in_dims"]
        checkpointing_level = config.get("checkpointing_level", 0)

        deepest_index = len(in_dims) - 1
        self.blocks = nn.ModuleList(
            FPN3DBlock(
                in_dim,
                fpn_dim,
                is_deepest=(i == deepest_index),
                checkpointing_level=checkpointing_level,
            )
            for i, in_dim in enumerate(in_dims)
        )

    def forward(self, features: list[torch.Tensor]):
        """Run the top-down pathway over the backbone features.

        Args:
            features: Sequence (list or tuple) of backbone feature maps ordered shallowest to
                deepest, one per entry of ``config["in_dims"]``:
                [(b, in_dim1, d1, h1, w1), (b, in_dim2, d2, h2, w2), ...]

        Returns:
            A new list with one tensor per level; each block outputs ``fpn_dim`` channels at its
            input's spatial size. The input sequence is not modified.

        Raises:
            ValueError: If the number of feature maps differs from the number of blocks.
        """
        # features: [
        #   (b, in_dim1, d1, h1, w1),
        #   (b, in_dim2, d2, h2, w2),
        #   ...
        # ]

        # Copy into a list so tuples (or other sequences) are accepted and the input is untouched.
        features = list(features)
        if len(features) != len(self.blocks):
            raise ValueError(f"Expected {len(self.blocks)} feature maps (one per block), got {len(features)}")

        outputs = [None] * len(features)
        deeper = None  # The deepest block has no deeper level to merge with.
        # Walk from the deepest level to the shallowest, feeding each output to the level above.
        for i in reversed(range(len(features))):
            deeper = self.blocks[i](features[i], deeper)
            outputs[i] = deeper

        return outputs

In [7]:
# End-to-end smoke test: four pyramid levels, shallowest first.
test_config = dict(fpn_dim=128, in_dims=[64, 128, 256, 512], checkpointing_level=1)
# One random feature map per level; channels follow `in_dims` while the spatial grid halves.
test_input = [
    torch.randn(2, channels, *grid)
    for channels, grid in zip(
        test_config["in_dims"],
        [(16, 32, 32), (8, 16, 16), (4, 8, 8), (2, 4, 4)],
    )
]
test = FPN3D(test_config)

display(test)
# Every output should have fpn_dim=128 channels at its input's spatial size.
display([level_output.shape for level_output in test(test_input)])


[1;35mFPN3D[0m[1m([0m
  [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mFPN3DBlock[0m[1m([0m
      [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[1m)[0m
      [1m([0min_conv[1m)[0m: [1;35mSequential[0m[1m([0m
        [1m([0m[1;36m0[0m[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m64[0m, [1;36m128[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
        [1m([0m[1;36m1[0m[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m128[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
        [1m([0m[1;36m2[0m[1m)[0m: [1;35mReLU[0m[1m([0m[33minplace[0m=[3;92mTrue[0m[1m)[0m
      [1m)[0m
      


[1m[[0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m128[0m, [1;36m16[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m,
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m128[0m, [1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m,
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m128[0m, [1;36m4[0m, [1;36m8[0m, [1;36m8[0m[1m][0m[1m)[0m,
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m128[0m, [1;36m2[0m, [1;36m4[0m, [1;36m4[0m[1m][0m[1m)[0m
[1m][0m

# nbdev

In [9]:
# Shell out to nbdev to export the cells marked `# | export` into the module named by `default_exp`.
!nbdev_export