In [None]:
# | default_exp blocks/mbconv_3d

# Imports

In [None]:
# | export


import torch
from einops import rearrange
from torch import nn

from vision_architectures.blocks.cnn import CNNBlock3D, CNNBlock3DConfig
from vision_architectures.blocks.se import SEBlock3D
from vision_architectures.utils.activation_checkpointing import ActivationCheckpointing
from vision_architectures.utils.custom_base_model import model_validator
from vision_architectures.utils.residuals import Residual

# Config

In [None]:
# | export


class MBConv3DConfig(CNNBlock3DConfig):
    """Configuration of an `MBConv3D` block.

    Builds on `CNNBlock3DConfig` and adds the block-level channel settings. The channel
    count used between the expansion and projection convolutions is exposed as `hidden_dim`.
    """

    # Channels entering the block
    dim: int
    # Channels leaving the block; replaced by `dim` during validation when left as None
    out_dim: int | None = None
    # Multiplier applied to `dim` to obtain `hidden_dim`
    expansion_ratio: float = 6.0
    # Handed to the squeeze-and-excitation block as its reduction ratio
    se_reduction_ratio: float = 4.0

    kernel_size: int = 3
    activation: str = "relu"
    normalization: str = "batchnorm3d"

    # Typed as None on purpose: channel counts come from dim, hidden_dim and out_dim
    in_channels: None = None
    out_channels: None = None

    @property
    def hidden_dim(self):
        """Number of channels after expansion: `expansion_ratio * dim`, truncated to an int."""
        expanded_channels = self.expansion_ratio * self.dim
        return int(expanded_channels)

    @model_validator(mode="after")
    def validate(self):
        """Run the parent validation, check the expansion ratio and fill in `out_dim`."""
        super().validate()

        # The ratio has to exceed the value that would add exactly one channel on top of `dim`
        ratio_floor = (self.dim + 1) / self.dim
        assert self.expansion_ratio > ratio_floor, f"expansion_ratio must be greater than {ratio_floor}"

        # Only assign when unset so an explicitly provided out_dim is never touched
        if self.out_dim is None:
            self.out_dim = self.dim

        return self

In [None]:
# Sanity check of the derived fields: hidden_dim is int(expansion_ratio * dim) and
# out_dim falls back to dim when it is not given.
example_config = MBConv3DConfig(dim=10, expansion_ratio=1.2)
assert example_config.hidden_dim == 12
assert example_config.out_dim == 10
example_config.hidden_dim

[1;36m12[0m

# Architecture

In [None]:
# | export


class MBConv3D(nn.Module):
    """3D mobile inverted bottleneck convolution block.

    Stages, in order: 1x1x1 expansion conv (dim -> hidden_dim), depthwise conv
    (groups = hidden_dim), squeeze-and-excitation, and a 1x1x1 projection conv
    (hidden_dim -> out_dim) without activation. A residual connection is added
    whenever the block output has the same shape as its input.
    """

    def __init__(self, config: MBConv3DConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Build the block.

        Args:
            config: An `MBConv3DConfig`, or a mapping with its fields.
            checkpointing_level: Forwarded to every `CNNBlock3D` and used for this block's own
                `ActivationCheckpointing` wrapper (registered at level 2).
            **kwargs: Additional config fields; merged on top of `config` before validation.
        """
        super().__init__()

        self.config = MBConv3DConfig.model_validate(config | kwargs)
        cfg = self.config

        in_dim = cfg.dim
        wide_dim = cfg.hidden_dim
        proj_dim = cfg.out_dim

        # Both 1x1x1 convolutions override the spatial settings of the shared config
        pointwise_kwargs = dict(kernel_size=1, stride=1, padding=0)

        # NOTE: submodules are registered in the same order as they are applied in `_forward`
        self.expand = CNNBlock3D(
            cfg, checkpointing_level, in_channels=in_dim, out_channels=wide_dim, **pointwise_kwargs
        )
        # One group per channel makes this convolution depthwise
        self.depthwise_conv = CNNBlock3D(
            cfg,
            checkpointing_level,
            in_channels=wide_dim,
            out_channels=wide_dim,
            conv_kwargs=cfg.conv_kwargs | {"groups": wide_dim},
        )
        self.se = SEBlock3D(dim=wide_dim, r=cfg.se_reduction_ratio)
        # Projection back down to out_dim, with the activation switched off
        self.pointwise_conv = CNNBlock3D(
            cfg, checkpointing_level, in_channels=wide_dim, out_channels=proj_dim, **pointwise_kwargs, activation=None
        )

        self.residual = Residual()

        self.checkpointing_level2 = ActivationCheckpointing(2, checkpointing_level)

    def _forward(self, x: torch.Tensor, channels_first: bool = True):
        """Apply the block to `x`.

        Args:
            x: Input of shape (b, dim, z, y, x) if `channels_first`, else (b, z, y, x, dim).
            channels_first: Layout of `x`; the output is returned in the same layout.

        Returns:
            The block output, with out_dim channels.
        """
        if not channels_first:
            x = rearrange(x, "b z y x d -> b d z y x").contiguous()
        # From here on x is channels-first: (b, dim, z, y, x)

        shortcut = x

        # expand -> depthwise conv -> SE -> projection.
        # Channels go dim -> hidden_dim -> hidden_dim -> hidden_dim -> out_dim.
        for stage in (self.expand, self.depthwise_conv, self.se, self.pointwise_conv):
            x = stage(x, channels_first=True)
        # (b, out_dim, z, y, x)

        # The skip connection is only usable when nothing about the shape changed
        if x.shape == shortcut.shape:
            x = self.residual(x, shortcut)

        if channels_first:
            return x

        # Hand the result back in the caller's channels-last layout: (b, z, y, x, out_dim)
        return rearrange(x, "b d z y x -> b z y x d").contiguous()

    def forward(self, x: torch.Tensor, channels_first: bool = True):
        """Run `_forward` through this block's activation checkpointing wrapper."""
        return self.checkpointing_level2(self._forward, x, channels_first)

In [None]:
test = MBConv3D(dim=10)
display(test)

sample_input = torch.randn(2, 10, 32, 32, 32)
sample_output = test(sample_input)

# With the default out_dim (= dim) the block is expected to keep the input shape,
# which is also the condition under which the residual connection is applied.
assert sample_output.shape == sample_input.shape
sample_output.shape


[1;35mMBConv3D[0m[1m([0m
  [1m([0mexpand[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
    [1m([0mcnn[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m10[0m, [1;36m60[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mnorm_layer[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m60[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mact_layer[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
    [1m([0mdropout[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mdepthwise_conv[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
    [1m([0mcnn[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m60[0m, [1;36m60[0m, [33mkern

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m32[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m

# nbdev

In [None]:
# Run the nbdev CLI export step, which writes the cells marked `# | export`
# into the module named by the `# | default_exp` directive at the top of this notebook.
!nbdev_export