In [1]:
# | default_exp upernet_3d

# Imports

In [2]:
# | export

import torch
from torch import nn
from torch.nn import functional as F
from huggingface_hub import PyTorchModelHubMixin

from vision_architectures.fpn_3d import FPN3D
from vision_architectures.activation_checkpointing import ActivationCheckpointing

# Architecture

### Basic block

In [3]:
# | export


class UPerNet3DFusion(nn.Module):
    """Fuse a list of same-channel 3D feature maps into a single feature map.

    Every input feature map is resized (trilinear) to a common spatial shape, the results are
    concatenated along the channel axis, and a 3x3x3 Conv-BN-ReLU projects the
    ``dim * num_layers`` channels back down to ``dim``.

    Args:
        dim: Channel count of every input feature map, and of the output.
        num_layers: Number of feature maps passed to ``forward`` (sets the conv's input channels).
        fusion_shape: Spatial shape ``(d, h, w)`` all features are resized to. If ``None``, the
            spatial shape of the first feature map of *each call* is used.
        checkpointing_level: Forwarded to ``ActivationCheckpointing`` for the concat (level 1)
            and fuse (level 2) stages.
    """

    def __init__(self, dim, num_layers, fusion_shape=None, checkpointing_level=0):
        super().__init__()

        # (d, h, w) | None. Kept exactly as configured; forward never overwrites it.
        self.fusion_shape = fusion_shape

        self.conv = nn.Sequential(
            nn.Conv3d(dim * num_layers, dim, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(dim),
            nn.ReLU(inplace=True),
        )

        self.checkpointing_level1 = ActivationCheckpointing(1, checkpointing_level)
        self.checkpointing_level2 = ActivationCheckpointing(2, checkpointing_level)

    def concat_features(self, features: list[torch.Tensor]):
        """Resize every feature map to the fusion shape and concatenate them along channels.

        Args:
            features: List of tensors ``[(b, dim, d1, h1, w1), (b, dim, d2, h2, w2), ...]``.
                Neither the list nor its tensors are modified.

        Returns:
            Tensor of shape ``(b, dim * len(features), d, h, w)``.

        Raises:
            ValueError: If ``features`` is empty.
        """
        if not features:
            raise ValueError("UPerNet3DFusion.concat_features expects at least one feature map")

        # Resolve the target shape per call instead of caching it on self: caching the first
        # call's shape would silently resize later, differently sized inputs to a stale shape.
        fusion_shape = self.fusion_shape
        if fusion_shape is None:
            fusion_shape = features[0].shape[-3:]
        elif isinstance(fusion_shape, int):
            # F.interpolate accepts a scalar size and applies it to all three spatial dims.
            fusion_shape = (fusion_shape,) * 3
        fusion_shape = tuple(int(size) for size in fusion_shape)
        # (d, h, w)

        # Build a new list so the caller's list (e.g. the FPN outputs) is left untouched.
        # Maps that already have the target shape are used as-is (resizing them is a no-op).
        resized_features = [
            feature
            if tuple(feature.shape[-3:]) == fusion_shape
            else F.interpolate(feature, size=fusion_shape, mode="trilinear", align_corners=False)
            for feature in features
        ]
        # Each is (b, dim, d, h, w)

        concatenated_features = torch.cat(resized_features, dim=1)
        # (b, dim * num_layers, d, h, w)

        return concatenated_features

    def fuse_features(self, concatenated_features: torch.Tensor):
        """Project concatenated features ``(b, dim * num_layers, d, h, w)`` to ``(b, dim, d, h, w)``."""
        fused_features = self.conv(concatenated_features)
        # (b, dim, d, h, w)

        return fused_features

    def forward(self, features: list[torch.Tensor]):
        """Fuse ``[(b, dim, d1, h1, w1), (b, dim, d2, h2, w2), ...]`` into ``(b, dim, d, h, w)``."""
        concatenated_features = self.checkpointing_level1(self.concat_features, features)
        # (b, dim * num_layers, d, h, w)
        fused_features = self.checkpointing_level2(self.fuse_features, concatenated_features)
        # (b, dim, d, h, w)

        return fused_features

In [4]:
# Smoke test: four pyramid levels with halving resolution should all be fused at the
# resolution of the first (largest) level.
torch.manual_seed(0)
test_input = [
    torch.randn(2, 128, 8, 16, 16),
    torch.randn(2, 128, 4, 8, 8),
    torch.randn(2, 128, 2, 4, 4),
    torch.randn(2, 128, 1, 2, 2),
]
test = UPerNet3DFusion(128, 4)
test_output = test(test_input)

display(test)
display(test_output.shape)
assert test_output.shape == (2, 128, 8, 16, 16)


UPerNet3DFusion(
  (conv): Sequential(
    (0): Conv3d(512, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
)

torch.Size([2, 128, 8, 16, 16])

### Complete architecture

In [5]:
# | export


class UPerNet3D(nn.Module, PyTorchModelHubMixin):
    """UPerNet-style 3D head on top of an FPN.

    The FPN maps the backbone features to ``fpn_dim`` channels, ``UPerNet3DFusion`` merges all
    levels into one map, and the enabled heads turn that map into logits that are resized to
    ``output_shape``. Only the ``"object"`` output is implemented; requesting ``"part"``,
    ``"scene"``, ``"material"`` or ``"texture"`` raises ``NotImplementedError``.

    Config keys read here (the whole config is also passed to ``FPN3D``):
        fpn_dim, in_dims, num_objects, checkpointing_level, enabled_outputs, output_shape,
        and optionally fusion_shape.
    """

    def __init__(self, config):
        super().__init__()

        self.fpn = FPN3D(config)

        dim = config["fpn_dim"]
        num_layers = len(config["in_dims"])
        num_objects = config["num_objects"]
        checkpointing_level = config["checkpointing_level"]
        # Only membership tests are needed, so any iterable of names (list, tuple, set) works.
        enabled_outputs = set(config["enabled_outputs"])
        fusion_shape = config.get("fusion_shape", None)

        self.output_shape = config["output_shape"]
        # (d, h, w)

        # Every component stays None unless its output is requested.
        self.fusion = None
        self.object_head = None
        self.scene_head = None
        self.part_head = None
        self.material_head = None
        self.texture_head = None

        # TODO: Implement scene, part, material, texture
        if {"object", "part"} & enabled_outputs:
            # The fused multi-level map is built when either "object" or "part" is requested.
            self.fusion = UPerNet3DFusion(dim, num_layers, fusion_shape, checkpointing_level=checkpointing_level)

            if "object" in enabled_outputs:
                self.object_head = nn.Sequential(
                    nn.Conv3d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm3d(dim),
                    nn.ReLU(inplace=True),
                    nn.Conv3d(dim, num_objects, kernel_size=1, stride=1),
                )

            if "part" in enabled_outputs:
                raise NotImplementedError("Part output not implemented yet")

        if "scene" in enabled_outputs:
            raise NotImplementedError("Scene output not implemented yet")

        if "material" in enabled_outputs:
            raise NotImplementedError("Material output not implemented yet")

        if "texture" in enabled_outputs:
            raise NotImplementedError("Texture output not implemented yet")

    def forward(self, features: list[torch.Tensor]):
        """Run the FPN, fuse its levels and apply the enabled heads.

        Args:
            features: Backbone feature maps
                ``[(b, in_dim1, d1, h1, w1), (b, in_dim2, d2, h2, w2), ...]``.

        Returns:
            Dict mapping output name to logits. Currently only ``"object"`` can be present, with
            shape ``(b, num_objects, *output_shape)``. Empty if no implemented output is enabled.
        """
        features = self.fpn(features)
        # features: [
        #   (b, fpn_dim, d1, h1, w1),
        #   (b, fpn_dim, d2, h2, w2),
        #   ...
        # ]

        output = {}

        if self.fusion is None:
            return output

        fused_features = self.fusion(features)
        # (b, fpn_dim, d, h, w): the configured fusion_shape, otherwise (d1, h1, w1)

        # Guarded on its own: the fusion block is also built when only "part" is requested,
        # in which case there is no object head to run.
        if self.object_head is not None:
            object_logits = self.object_head(fused_features)
            # (b, num_objects, d, h, w)

            object_logits = F.interpolate(object_logits, size=self.output_shape, mode="trilinear", align_corners=False)
            # (b, num_objects, *output_shape)

            output["object"] = object_logits

        return output

In [6]:
# enabled_outputs is a list rather than a set: PyTorchModelHubMixin serializes the config to
# JSON on save_pretrained, and a Python set is not JSON-serializable. UPerNet3D only needs
# membership tests on it, so a list behaves the same.
test_config = {
    "fpn_dim": 128,
    "in_dims": [64, 128, 256, 512],
    "checkpointing_level": 1,
    "num_objects": 3,
    "output_shape": (16, 32, 32),
    "enabled_outputs": ["object"],
}
torch.manual_seed(0)
test_input = [
    torch.randn(2, 64, 8, 16, 16),
    torch.randn(2, 128, 4, 8, 8),
    torch.randn(2, 256, 2, 4, 4),
    torch.randn(2, 512, 1, 2, 2),
]
test = UPerNet3D(test_config)
test_output = test(test_input)

display(test)
display({key: value.shape for key, value in test_output.items()})
assert {key: tuple(value.shape) for key, value in test_output.items()} == {"object": (2, 3, 16, 32, 32)}


UPerNet3D(
  (fpn): FPN3D(
    (blocks): ModuleList(
      (0): FPN3DBlock(
        (in_conv): Sequential(
          (0): Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
          (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (out_conv): 

{'object': torch.Size([2, 3, 16, 32, 32])}

In [7]:
# Per-parameter breakdown of the UPerNet3D test model built above. neuro_utils is only used
# while developing, so it is imported here rather than in the exported imports cell.
import neuro_utils.describe

neuro_utils.describe.describe_model(test)

Total Parameters: 3,664,515
+--------------------------------+------------+
|             Module             | Parameters |
+--------------------------------+------------+
| fpn.blocks.0.in_conv.0.weight  |   8,192    |
| fpn.blocks.0.in_conv.1.weight  |    128     |
|  fpn.blocks.0.in_conv.1.bias   |    128     |
| fpn.blocks.0.out_conv.0.weight |  442,368   |
| fpn.blocks.0.out_conv.1.weight |    128     |
|  fpn.blocks.0.out_conv.1.bias  |    128     |
| fpn.blocks.1.in_conv.0.weight  |   16,384   |
| fpn.blocks.1.in_conv.1.weight  |    128     |
|  fpn.blocks.1.in_conv.1.bias   |    128     |
| fpn.blocks.1.out_conv.0.weight |  442,368   |
| fpn.blocks.1.out_conv.1.weight |    128     |
|  fpn.blocks.1.out_conv.1.bias  |    128     |
| fpn.blocks.2.in_conv.0.weight  |   32,768   |
| fpn.blocks.2.in_conv.1.weight  |    128     |
|  fpn.blocks.2.in_conv.1.bias   |    128     |
| fpn.blocks.2.out_conv.0.weight |  442,368   |
| fpn.blocks.2.out_conv.1.weight |    128     |
|  fpn.block

# nbdev

In [8]:
# Shell out to nbdev to export the cells marked "# | export" into the library module.
!nbdev_export