In [None]:
# | default_exp blocks/mbconv_3d

# Imports

In [None]:
# | export


from functools import wraps

import torch
from torch import nn

from vision_architectures.blocks.cnn import CNNBlock3D, CNNBlockConfig
from vision_architectures.blocks.se import SEBlock3D
from vision_architectures.docstrings import populate_docstring
from vision_architectures.utils.activation_checkpointing import ActivationCheckpointing
from vision_architectures.utils.custom_base_model import Field, model_validator
from vision_architectures.utils.rearrange import rearrange_channels
from vision_architectures.utils.residuals import Residual

# Config

In [None]:
# | export


class MBConv3DConfig(CNNBlockConfig):
    """Configuration for the MBConv3D (mobile inverted residual bottleneck) block.

    The parent CNN-block fields ``in_channels`` / ``out_channels`` are accepted as aliases for
    ``dim`` / ``out_dim``; they are remapped before validation and are always ``None`` afterwards.
    """

    dim: int = Field(..., description="Input channel dimension of the block.")
    out_dim: int | None = Field(
        None, description="Output channel dimension of the block. If None, it will be set to `dim`."
    )
    expansion_ratio: float = Field(6.0, description="Expansion ratio for the block.")
    se_reduction_ratio: float = Field(4.0, description="Squeeze-and-excitation reduction ratio.")

    kernel_size: int = Field(3, description="Kernel size for the convolutional layers.")
    activation: str = Field("relu", description="Activation function to use in the block.")
    normalization: str = Field("batchnorm3d", description="Normalization layer to use in the block.")

    # Inherited from CNNBlockConfig but not used directly; validate_before moves their values
    # onto dim / out_dim.
    in_channels: None = Field(None, description="Use dim instead")
    out_channels: None = Field(None, description="Use out_dim instead")

    @property
    def hidden_dim(self):
        """Channel count of the expanded (inverted bottleneck) representation."""
        # NOTE(review): int() truncates, so float artifacts can lose a channel
        # (e.g. 1.15 * 100 evaluates to 114.99999999999999 -> 114). Left unchanged because
        # rounding differently would alter layer shapes of existing configurations.
        return int(self.expansion_ratio * self.dim)

    @model_validator(mode="before")
    @classmethod
    def validate_before(cls, data: dict):
        """Map the ``in_channels`` / ``out_channels`` aliases onto ``dim`` / ``out_dim``.

        Non-dict inputs (e.g. an already-constructed config instance) are passed through untouched.
        """
        if not isinstance(data, dict):
            return data
        # Copy so the pops below do not mutate the dict the caller passed in.
        data = dict(data)
        data.setdefault("dim", data.pop("in_channels", None))
        data.setdefault("out_dim", data.pop("out_channels", None))
        return data

    @model_validator(mode="after")
    def validate(self):
        """Run the parent checks, require a genuine expansion, and default ``out_dim`` to ``dim``.

        Raises:
            ValueError: if ``dim`` is not positive or ``expansion_ratio`` is too small.
        """
        super().validate()
        # Guard the division below; a raw ZeroDivisionError would not become a validation error.
        if self.dim <= 0:
            raise ValueError(f"dim must be positive, got {self.dim}")
        # expansion_ratio * dim > dim + 1 keeps hidden_dim = int(expansion_ratio * dim) above dim.
        min_expansion_ratio = (self.dim + 1) / self.dim
        # Explicit raise instead of assert so the check survives `python -O`; `not >` also rejects NaN.
        if not self.expansion_ratio > min_expansion_ratio:
            raise ValueError(f"expansion_ratio must be greater than {min_expansion_ratio}")
        if self.out_dim is None:
            self.out_dim = self.dim
        return self

In [None]:
demo_config = MBConv3DConfig(dim=10, expansion_ratio=1.2)
# hidden_dim = int(expansion_ratio * dim) = int(1.2 * 10)
assert demo_config.hidden_dim == 12
demo_config.hidden_dim

[1;36m12[0m

# Architecture

In [None]:
# | export


@populate_docstring
class MBConv3D(nn.Module):
    """Mobile Inverted Residual Bottleneck Block. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: MBConv3DConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize the MBConv3D block. Activation checkpointing level 2.

        Layers, in order: a 1x1x1 expansion conv (dim -> hidden_dim), a depthwise conv over the
        hidden_dim channels, squeeze-and-excitation, and a 1x1x1 pointwise projection
        (hidden_dim -> out_dim) without activation. A residual connection is added whenever the
        output shape equals the input shape.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        # `config | kwargs` requires a mapping. A config instance may not support `|`, so turn
        # it into a plain dict first. The shared default dict is never mutated (`|` builds a new one).
        if isinstance(config, MBConv3DConfig):
            config = config.model_dump()
        self.config = MBConv3DConfig.model_validate(config | kwargs)

        dim = self.config.dim
        hidden_dim = self.config.hidden_dim
        out_dim = self.config.out_dim
        se_reduction_ratio = self.config.se_reduction_ratio

        # SE runs on the expanded representation, so its channel dim is hidden_dim.
        se_config = self.config.model_dump() | {"dim": hidden_dim, "r": se_reduction_ratio}

        # 1x1x1 expansion: dim -> hidden_dim with stride 1 / no padding, so spatial size is kept.
        self.expand = CNNBlock3D(
            self.config.model_dump()
            | dict(
                in_channels=dim,
                out_channels=hidden_dim,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            checkpointing_level,
        )
        # Depthwise conv: groups == channels. Kernel size and any stride/padding settings are
        # taken from the block config.
        self.depthwise_conv = CNNBlock3D(
            self.config.model_dump()
            | dict(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                conv_kwargs=self.config.conv_kwargs | dict(groups=hidden_dim),
            ),
            checkpointing_level,
        )
        self.se = SEBlock3D(se_config, checkpointing_level)
        # 1x1x1 projection hidden_dim -> out_dim; no activation (linear bottleneck).
        self.pointwise_conv = CNNBlock3D(
            self.config.model_dump()
            | dict(
                in_channels=hidden_dim,
                out_channels=out_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                activation=None,
            ),
            checkpointing_level,
        )

        self.residual = Residual()

        self.checkpointing_level2 = ActivationCheckpointing(2, checkpointing_level)

    @populate_docstring
    def _forward(self, x: torch.Tensor, channels_first: bool = True) -> torch.Tensor:
        """Forward pass of the MBConv3D block.

        Args:
            x: {INPUT_3D_DOC}
            channels_first: {CHANNELS_FIRST_DOC}

        Returns:
            {OUTPUT_3D_DOC}
        """
        # x: (b, [dim], z, y, x, [dim])

        x = rearrange_channels(x, channels_first, True)
        # Now x is (b, dim, z, y, x)

        res_connection = x

        # Expand
        x = self.expand(x, channels_first=True)
        # (b, hidden_dim, z, y, x)

        # Depthwise Conv
        x = self.depthwise_conv(x, channels_first=True)
        # (b, hidden_dim, z, y, x)

        # SE
        x = self.se(x, channels_first=True)
        # (b, hidden_dim, z, y, x)

        # Pointwise Conv
        x = self.pointwise_conv(x, channels_first=True)
        # (b, out_dim, z, y, x)

        # Residual: only when shapes match, i.e. skipped when out_dim != dim or the depthwise conv
        # changed the spatial size.
        if x.shape == res_connection.shape:
            x = self.residual(x, res_connection)

        x = rearrange_channels(x, True, channels_first)
        # (b, [out_dim], z, y, x, [out_dim])

        return x

    @wraps(_forward)
    def forward(self, *args, **kwargs):
        # Routes through the level-2 activation checkpointing wrapper configured in __init__.
        return self.checkpointing_level2(self._forward, *args, **kwargs)

In [None]:
test = MBConv3D(dim=10)
display(test)

sample_input = torch.randn(2, 10, 32, 32, 32)
sample_output = test(sample_input)
# With default settings out_dim == dim, so the block should preserve the input shape.
assert sample_output.shape == sample_input.shape
sample_output.shape


[1;35mMBConv3D[0m[1m([0m
  [1m([0mexpand[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
    [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m10[0m, [1;36m60[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m60[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
    [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mdepthwise_conv[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
    [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m60[0m, [1;36m60[0m, [33mkernel_size[0m=[1m([

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m32[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m

# nbdev

In [None]:
# Write the cells marked `# | export` to the module set by `default_exp` (blocks/mbconv_3d).
!nbdev_export