In [None]:
# | default_exp nets/upernet_3d

# Imports

In [None]:
# | export

from typing import Literal

import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from torch.nn import functional as F

from vision_architectures.blocks.cnn import CNNBlock3D, CNNBlockConfig
from vision_architectures.nets.fpn_3d import FPN3D, FPN3DConfig
from vision_architectures.utils.activation_checkpointing import ActivationCheckpointing
from vision_architectures.utils.custom_base_model import Field, model_validator

# Config

In [None]:
# | export


class UPerNet3DFusionConfig(CNNBlockConfig):
    """Configuration of the fusion block that merges several feature maps into one.

    The fusion module resizes ``num_features`` feature maps of ``dim`` channels each to a common shape, concatenates
    them along the channel axis and reduces the result back to ``dim`` channels with a single convolution block.
    """

    # Channels of each incoming feature map and of the fused output
    dim: int
    # Number of feature maps that get concatenated (the fusion conv takes dim * num_features input channels)
    num_features: int
    # Kernel size handed to the fusion convolution block
    kernel_size: int = 3
    fused_shape: tuple[int, int, int] | None = Field(
        None,
        description=(
            "Shape of the fused feature map. It can also be provided during runtime. "
            "If None, highest input resolution is used."
        ),
    )
    # Forwarded as `mode` to torch.nn.functional.interpolate when resizing the feature maps
    interpolation_mode: str = "trilinear"

    # Typed as None on purpose: the fusion module computes both values from `dim` and `num_features`
    in_channels: None = Field(None, description="Calculated based on other parameters")
    out_channels: None = Field(None, description="Calculated based on other parameters")


class UPerNet3DConfig(FPN3DConfig):
    """Configuration of UPerNet3D: the FPN3D settings plus the fusion block and the output heads.

    If ``fusion`` is not supplied it is derived from the remaining fields: it inherits every top-level entry and
    takes ``dim`` from the first FPN block and ``num_features`` from the number of FPN blocks.
    """

    fusion: UPerNet3DFusionConfig

    # Outputs to build heads for; the model currently implements only "object"
    enabled_outputs: set[Literal["object", "part", "scene", "material", "texture"]] = {"object"}
    # Number of output channels of the object head; required when "object" is enabled
    num_objects: int | None = None

    @model_validator(mode="before")
    @classmethod
    def validate_before(cls, data: dict):
        """Pre-process raw config data and fill in a default ``fusion`` entry when possible.

        Args:
            data: Raw mapping of config fields.

        Returns:
            The processed mapping, with ``fusion`` added if it was absent and at least one block is available.
        """
        data = FPN3DConfig.validate_before(data)
        data = UPerNet3DFusionConfig.validate_before(data)

        # Derive the fusion config lazily: an explicitly supplied `fusion` must not depend on `blocks` being
        # well-formed, and missing/empty `blocks` should lead to a regular "field required" validation error
        # instead of an unrelated TypeError/IndexError raised from here.
        if "fusion" not in data:
            blocks = data.get("blocks") or []
            if blocks:
                data["fusion"] = data | {"dim": blocks[0].get("dim"), "num_features": len(blocks)}
        return data

    @model_validator(mode="after")
    def validate(self):
        """Check cross-field consistency once all fields are parsed.

        Raises:
            AssertionError: If the fusion ``dim`` differs from the FPN ``dim``, or if the "object" output is enabled
                without ``num_objects``.
        """
        super().validate()
        assert self.dim == self.fusion.dim, "Fusion dim must match the FPN output dim"
        if "object" in self.enabled_outputs:
            assert self.num_objects is not None, "num_objects must be set when 'object' output is enabled"
        return self

In [None]:
# Only the FPN dims and the class count are given; the `fusion` entry is filled in by the before-validator.
UPerNet3DConfig(
    dim=128,
    skip_conn_dims=[12, 24, 36, 48],
    num_objects=3,
)


[1;35mUPerNet3DConfig[0m[1m([0m
    [33mblocks[0m=[1m[[0m
        [1;35mFPN3DBlockConfig[0m[1m([0m
            [33min_channels[0m=[3;35mNone[0m,
            [33mout_channels[0m=[3;35mNone[0m,
            [33mkernel_size[0m=[1;36m3[0m,
            [33mpadding[0m=[32m'same'[0m,
            [33mstride[0m=[1;36m1[0m,
            [33mconv_kwargs[0m=[1m{[0m[1m}[0m,
            [33mtransposed[0m=[3;91mFalse[0m,
            [33mnormalization[0m=[32m'batchnorm3d'[0m,
            [33mnormalization_pre_args[0m=[1m[[0m[1m][0m,
            [33mnormalization_post_args[0m=[1m[[0m[1m][0m,
            [33mnormalization_kwargs[0m=[1m{[0m[1m}[0m,
            [33mactivation[0m=[32m'relu'[0m,
            [33mactivation_kwargs[0m=[1m{[0m[1m}[0m,
            [33msequence[0m=[32m'CNA'[0m,
            [33mdrop_prob[0m=[1;36m0[0m[1;36m.0[0m,
            [33mdim[0m=[1;36m128[0m,
            [33mskip_conn_dim[0m=[1;36m12[

# Architecture

### Basic block

In [None]:
# | export


class UPerNet3DFusion(nn.Module):
    """Fuse several feature maps of different resolutions into a single feature map.

    Every input is resized to one common spatial shape, the results are concatenated along the channel axis and a
    single ``CNNBlock3D`` reduces the ``dim * num_features`` channels back to ``dim``.
    """

    def __init__(self, config: UPerNet3DFusionConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Build the fusion convolution and the activation checkpointing wrappers.

        Args:
            config: A ``UPerNet3DFusionConfig`` or a mapping of its fields.
            checkpointing_level: Activation checkpointing level, also forwarded to the convolution block.
            **kwargs: Config fields merged on top of ``config`` before validation.
        """
        super().__init__()

        self.config = UPerNet3DFusionConfig.model_validate(config | kwargs)

        self.conv = CNNBlock3D(
            self.config,
            in_channels=self.config.dim * self.config.num_features,
            out_channels=self.config.dim,
            checkpointing_level=checkpointing_level,
        )

        self.checkpointing_level1 = ActivationCheckpointing(1, checkpointing_level)
        self.checkpointing_level2 = ActivationCheckpointing(2, checkpointing_level)

    def concat_features(self, features: list[torch.Tensor], fused_shape: tuple[int, int, int] | None = None):
        """Concatenate features from different resolutions and interpolate them to the same size.

        The input list is left untouched; the resized tensors are collected in a new list.

        Args:
            features (list[torch.Tensor]): List of feature maps to be concatenated.
                Each feature map should have shape (b, dim, d, h, w).
            fused_shape (tuple[int, int, int] | None): Shape to which all feature maps will be interpolated.
                If None, value entered in the config is used. If that is None too, the shape of the largest feature map
                is used.

        Returns:
            torch.Tensor: Tensor of shape (b, dim * num_features, d, h, w).
        """
        # features: List of [(b, dim, d1, h1, w1), (b, dim, d2, h2, w2), ...]

        if fused_shape is None:
            fused_shape = self.config.fused_shape
        if fused_shape is None:
            # Largest spatial volume wins; on ties the earliest feature map is kept
            fused_shape = max((feature.shape[2:] for feature in features), key=lambda shape: shape.numel())
        # (d, h, w)

        mode = self.config.interpolation_mode
        # F.interpolate rejects an explicit align_corners for non-linear modes (e.g. "nearest", "area"),
        # so it is only passed for the linear family of modes.
        align_corners = False if mode in ("linear", "bilinear", "bicubic", "trilinear") else None

        # Collect into a new list instead of overwriting the caller's list entries
        resized_features = [
            F.interpolate(feature, size=fused_shape, mode=mode, align_corners=align_corners) for feature in features
        ]
        # Each is (b, dim, d, h, w)

        concatenated_features = torch.cat(resized_features, dim=1)
        # (b, dim * num_features, d, h, w)

        return concatenated_features

    def fuse_features(self, concatenated_features: torch.Tensor):
        """Reduce the concatenated feature maps back to ``dim`` channels with the fusion convolution block."""
        # (b, dim * num_features, d, h, w)
        fused_features = self.conv(concatenated_features)
        # (b, dim, d, h, w)

        return fused_features

    def _forward(self, features: list[torch.Tensor], fused_shape: tuple[int, int, int] | None = None):
        """Resize + concatenate the feature maps and fuse them, each step wrapped in level-1 checkpointing."""
        # features: List of [(b, dim, d1, h1, w1), (b, dim, d2, h2, w2), ...]
        # The order of resolutions is not relied upon here.
        concatenated_features = self.checkpointing_level1(self.concat_features, features, fused_shape)
        # (b, dim * num_features, d, h, w)
        fused_features = self.checkpointing_level1(self.fuse_features, concatenated_features)
        # (b, dim, d, h, w)

        return fused_features

    def forward(self, *args, **kwargs):
        """Run ``_forward`` through the level-2 checkpointing wrapper; accepts the same arguments as ``_forward``."""
        return self.checkpointing_level2(self._forward, *args, **kwargs)

In [None]:
# Fusion without any fused_shape: the largest input resolution is used for the fused map.
test_input = [
    torch.randn(2, 128, 1, 2, 2),
    torch.randn(2, 128, 2, 4, 4),
    torch.randn(2, 128, 4, 8, 8),
    torch.randn(2, 128, 8, 16, 16),
]
test = UPerNet3DFusion(dim=128, num_features=4)

display(test)
test_output_shape = test(test_input).shape
display(test_output_shape)
# Pin the shape so an automated notebook run fails on regressions instead of silently displaying a new value
assert test_output_shape == (2, 128, 8, 16, 16)


[1;35mUPerNet3DFusion[0m[1m([0m
  [1m([0mconv[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
    [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m512[0m, [1;36m128[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[35msame[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m128[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
    [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
  [1m([0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m128[0m, [1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m

In [None]:
# Fusion with fused_shape set in the config: every feature map is resized to that shape.
test_input = [
    torch.randn(2, 128, 1, 2, 2),
    torch.randn(2, 128, 2, 4, 4),
    torch.randn(2, 128, 4, 8, 8),
    torch.randn(2, 128, 8, 16, 16),
]
test = UPerNet3DFusion(dim=128, num_features=4, fused_shape=(6, 12, 12))

display(test)
test_output_shape = test(test_input).shape
display(test_output_shape)
# Pin the shape so an automated notebook run fails on regressions instead of silently displaying a new value
assert test_output_shape == (2, 128, 6, 12, 12)


[1;35mUPerNet3DFusion[0m[1m([0m
  [1m([0mconv[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
    [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m512[0m, [1;36m128[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[35msame[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m128[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
    [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
  [1m([0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m128[0m, [1;36m6[0m, [1;36m12[0m, [1;36m12[0m[1m][0m[1m)[0m

In [None]:
# fused_shape passed at call time takes precedence over the value stored in the config.
test_input = [
    torch.randn(2, 128, 1, 2, 2),
    torch.randn(2, 128, 2, 4, 4),
    torch.randn(2, 128, 4, 8, 8),
    torch.randn(2, 128, 8, 16, 16),
]
test = UPerNet3DFusion(dim=128, num_features=4, fused_shape=(3, 6, 6))

display(test)
test_output_shape = test(test_input, fused_shape=(6, 12, 12)).shape
display(test_output_shape)
# The runtime shape (6, 12, 12) must win over the configured (3, 6, 6)
assert test_output_shape == (2, 128, 6, 12, 12)


[1;35mUPerNet3DFusion[0m[1m([0m
  [1m([0mconv[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
    [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m512[0m, [1;36m128[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[35msame[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m128[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
    [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
  [1m([0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m128[0m, [1;36m6[0m, [1;36m12[0m, [1;36m12[0m[1m][0m[1m)[0m

### Complete architecture

In [None]:
# | export


class UPerNet3D(nn.Module, PyTorchModelHubMixin):
    """UPerNet-style 3D head: an FPN3D followed by feature fusion and per-output prediction heads.

    Only the "object" output is implemented; enabling any other output raises ``NotImplementedError``.
    """

    def __init__(self, config: UPerNet3DConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Build the FPN, the fusion block and the heads of the enabled outputs.

        Args:
            config: A ``UPerNet3DConfig`` or a mapping of its fields.
            checkpointing_level: Activation checkpointing level forwarded to the submodules.
            **kwargs: Config fields merged on top of ``config`` before validation.

        Raises:
            NotImplementedError: If "part", "scene", "material" or "texture" is enabled.
        """
        super().__init__()

        self.config = UPerNet3DConfig.model_validate(config | kwargs)

        # Use the validated config (not the raw argument) so that fields passed via **kwargs reach the FPN too
        self.fpn = FPN3D(self.config, checkpointing_level=checkpointing_level)

        enabled_outputs = self.config.enabled_outputs

        self.fusion = None
        self.object_head = None
        self.scene_head = None
        self.part_head = None
        self.material_head = None
        self.texture_head = None

        # TODO: Implement scene, part, material, texture
        if {"object", "part"} & set(enabled_outputs):
            # Both of these outputs are predicted from the fused feature map
            self.fusion = UPerNet3DFusion(self.config.fusion, checkpointing_level=checkpointing_level)

            if "object" in enabled_outputs:
                self.object_head = nn.Sequential(
                    CNNBlock3D(
                        in_channels=self.config.dim,
                        out_channels=self.config.dim,
                        kernel_size=3,
                        checkpointing_level=checkpointing_level,
                    ),
                    # 1x1x1 projection to class logits: no activation and no normalization
                    CNNBlock3D(
                        in_channels=self.config.dim,
                        out_channels=self.config.num_objects,
                        kernel_size=1,
                        activation=None,
                        normalization=None,
                        checkpointing_level=checkpointing_level,
                    ),
                )

            if "part" in enabled_outputs:
                raise NotImplementedError("Part output not implemented yet")

        if "scene" in enabled_outputs:
            raise NotImplementedError("Scene output not implemented yet")

        if "material" in enabled_outputs:
            raise NotImplementedError("Material output not implemented yet")

        if "texture" in enabled_outputs:
            raise NotImplementedError("Texture output not implemented yet")

    def forward(self, features: list[torch.Tensor], output_shape: tuple[int, int, int] | None = None):
        """Forward pass of the UPerNet3D model.

        Args:
            features (list[torch.Tensor]): List of multi-scale feature maps that are fed to the FPN.
                Each feature map should have shape (b, in_dim_i, d_i, h_i, w_i).
            output_shape (tuple[int, int, int], optional): Desired output shape for the object head. If None, the
                ``fused_shape`` of the fusion config is used; if that is None too, the shape of the highest
                resolution feature map is used.

        Returns:
            dict[str, torch.Tensor]: Logits per enabled output. Currently only the key "object" can be present, with
                shape (b, num_objects, d, h, w).
        """
        # features: [
        #   (b, in_dim1, d1, h1, w1),
        #   (b, in_dim2, d2, h2, w2),
        #   ...
        # ] where d1 > d2 > ...

        features = self.fpn(features)
        # features: [
        #   (b, fpn_dim, d1, h1, w1),
        #   (b, fpn_dim, d2, h2, w2),
        #   ...
        # ] where d1 < d2 < ...

        output = {}

        if self.fusion is not None:
            fused_features = self.fusion(features, output_shape)
            # (b, fpn_dim, d, h, w)

            # The fusion block may exist without an object head, so the head is checked separately
            if self.object_head is not None:
                object_logits = self.object_head(fused_features)
                # (b, num_objects, d, h, w)

                output["object"] = object_logits

        return output

In [None]:
# Full model with fused_shape coming from the config: the object logits take that spatial shape.
test_config = UPerNet3DConfig.model_validate(
    {
        "dim": 128,
        "skip_conn_dims": [64, 128, 256, 512],
        "num_objects": 3,
        "fused_shape": (16, 32, 32),
        "enabled_outputs": {"object"},
    }
)
test_input = [
    torch.randn(2, 64, 8, 16, 16),
    torch.randn(2, 128, 4, 8, 8),
    torch.randn(2, 256, 2, 4, 4),
    torch.randn(2, 512, 1, 2, 2),
]
test = UPerNet3D(test_config)

display(test)
test_output_shapes = {key: value.shape for key, value in test(test_input).items()}
display(test_output_shapes)
# Pin the result so an automated notebook run fails on regressions instead of silently displaying a new value
assert test_output_shapes == {"object": (2, 3, 16, 32, 32)}


[1;35mUPerNet3D[0m[1m([0m
  [1m([0mfpn[1m)[0m: [1;35mFPN3D[0m[1m([0m
    [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
      [1m([0m[1;36m0[0m[1m)[0m: [1;35mFPN3DBlock[0m[1m([0m
        [1m([0mskip_conn_conv[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
          [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m64[0m, [1;36m128[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[35msame[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
          [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m128[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
          [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckp

[1m{[0m[32m'object'[0m: [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m, [1;36m16[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m[1m}[0m

In [None]:
# Full model with an output shape passed at call time: it takes precedence over the configured fused_shape.
test_config = UPerNet3DConfig.model_validate(
    {
        "dim": 128,
        "skip_conn_dims": [64, 128, 256, 512],
        "num_objects": 3,
        "fused_shape": (16, 32, 32),
        "enabled_outputs": {"object"},
    }
)
test_input = [
    torch.randn(2, 64, 8, 16, 16),
    torch.randn(2, 128, 4, 8, 8),
    torch.randn(2, 256, 2, 4, 4),
    torch.randn(2, 512, 1, 2, 2),
]
test = UPerNet3D(test_config)

display(test)
test_output_shapes = {key: value.shape for key, value in test(test_input, (8, 12, 12)).items()}
display(test_output_shapes)
# The runtime shape (8, 12, 12) must win over the configured (16, 32, 32)
assert test_output_shapes == {"object": (2, 3, 8, 12, 12)}


[1;35mUPerNet3D[0m[1m([0m
  [1m([0mfpn[1m)[0m: [1;35mFPN3D[0m[1m([0m
    [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
      [1m([0m[1;36m0[0m[1m)[0m: [1;35mFPN3DBlock[0m[1m([0m
        [1m([0mskip_conn_conv[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
          [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m64[0m, [1;36m128[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[35msame[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
          [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m128[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
          [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckp

[1m{[0m[32m'object'[0m: [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m, [1;36m8[0m, [1;36m12[0m, [1;36m12[0m[1m][0m[1m)[0m[1m}[0m

# nbdev

In [None]:
# Run the `nbdev_export` command in a shell via IPython's `!` escape.
!nbdev_export