In [None]:
# | default_exp nets/maxvit_3d

# Imports

In [None]:
# | export

from functools import wraps

import torch
from torch import nn

from vision_architectures.blocks.cnn import CNNBlock3D, CNNBlockConfig
from vision_architectures.blocks.mbconv_3d import MBConv3D, MBConv3DConfig
from vision_architectures.blocks.transformer import Attention3DWithMLPConfig
from vision_architectures.docstrings import populate_docstring
from vision_architectures.nets.swinv2_3d import SwinV23DLayer
from vision_architectures.utils.activation_checkpointing import ActivationCheckpointing
from vision_architectures.utils.custom_base_model import CustomBaseModel, Field, model_validator
from vision_architectures.utils.rearrange import rearrange_channels

# Config

In [None]:
# | export


class MaxViT3DStem0Config(CNNBlockConfig):
    """Configuration for MaxViT3DStem0, the convolutional stem that MaxViT3DEncoder runs before its MaxViT stems.

    Extends CNNBlockConfig. The number of output channels of every conv layer in the stem is taken from ``dim``;
    ``out_channels`` is not used.
    """

    in_channels: int = Field(..., description="Number of input channels")
    kernel_size: int = Field(3, description="Kernel size for the convolutional layers in the stem")
    dim: int = Field(..., description="Hidden dimension of the stem")
    depth: int = Field(2, description="Number of convolutional layers in the stem", ge=1)

    # Overridden to None: MaxViT3DStem0 passes out_channels=dim explicitly to every CNNBlock3D it builds.
    out_channels: None = Field(None, description="This is defined by dim")


class MaxViT3DBlockConfig(MBConv3DConfig, Attention3DWithMLPConfig):
    """Configuration for a single MaxViT3DBlock (MBConv followed by block attention and grid attention).

    Combines the MBConv3D settings and the attention/MLP settings, plus the window size used by the attention stages.
    """

    # Window extent along (z, y, x). MaxViT3DBlock passes the same config to both its block and grid attention layers.
    window_size: tuple[int, int, int] = Field(..., description="Size of the window to apply attention over")
    # Channel multiplier applied only when a MaxViT3DBlock is built with modify_dims=True, i.e. the downsampling block
    # at the end of a stem.
    out_dim_ratio: int = Field(
        2, description="Ratio of the output dimension to the input dimension. Used only in the last block of stems"
    )


class MaxViT3DStemConfig(MaxViT3DBlockConfig):
    """Configuration for MaxViT3DStem: the per-block settings plus the number of blocks stacked in the stem."""

    depth: int = Field(..., description="Number of blocks in the stem")


class MaxViT3DEncoderConfig(CustomBaseModel):
    """Configuration for MaxViT3DEncoder: a convolutional stem0 followed by one or more MaxViT stems."""

    stem0: MaxViT3DStem0Config = Field(..., description="Configuration for the stem0")
    stems: list[MaxViT3DStemConfig] = Field(..., description="Configurations for the remaining stems")

    @model_validator(mode="after")
    def validate(self):
        """Check that the channel dimensions of consecutive stems line up.

        - At least one MaxViT stem must be configured.
        - stem0 must output the dimension the first stem expects.
        - Every stem (except the last) multiplies its dim by ``out_dim_ratio`` in its final block, so the next stem
          must be configured with exactly that dim.

        Returns:
            The validated model itself.
        """
        super().validate()
        # Check the length first: indexing ``self.stems[0]`` on an empty list would raise a raw IndexError instead
        # of a validation error.
        assert len(self.stems) >= 1, "At least one stem is required"
        assert self.stem0.dim == self.stems[0].dim, "Stem0 dim should be equal to the first stem dim"
        for i in range(1, len(self.stems)):
            previous = self.stems[i - 1]
            expected_dim = previous.dim * previous.out_dim_ratio
            assert self.stems[i].dim == expected_dim, (
                f"Stem dims should match: stems[{i}].dim should be {expected_dim} "
                f"(stems[{i - 1}].dim * stems[{i - 1}].out_dim_ratio), got {self.stems[i].dim}"
            )
        return self

# Architecture

In [None]:
# | export


@populate_docstring
class MaxViT3DStem0(nn.Module):
    """Stem0 for MaxViT3D. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: MaxViT3DStem0Config = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize the MaxViT3DStem0 block.

        The stem is a stack of ``depth`` CNNBlock3D layers. The first maps ``in_channels`` to ``dim`` channels and
        every following layer maps ``dim`` to ``dim`` channels.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = MaxViT3DStem0Config.model_validate(config | kwargs)

        # First layer: in_channels -> dim
        self.layers = nn.ModuleList()
        self.layers.append(
            CNNBlock3D(
                self.config.model_dump() | dict(out_channels=self.config.dim),
                checkpointing_level,
            )
        )
        # Remaining depth - 1 layers: dim -> dim. Every layer, including the last one, keeps the configured
        # normalization and activation.
        # NOTE(review): it is unclear whether the last stem layer was meant to drop normalization/activation (some
        # 2D MaxViT implementations end the stem with a plain conv). Changing that would alter this module's
        # parameters and state_dict, so confirm the intended design before doing so.
        for _ in range(self.config.depth - 1):
            self.layers.append(
                CNNBlock3D(
                    self.config.model_dump()
                    | dict(
                        in_channels=self.config.dim,
                        out_channels=self.config.dim,
                        normalization=self.config.normalization,
                        activation=self.config.activation,
                    ),
                    checkpointing_level,
                )
            )

        self.checkpointing_level2 = ActivationCheckpointing(2, checkpointing_level)

    @populate_docstring
    def _forward(self, x: torch.Tensor, channels_first: bool = True) -> torch.Tensor:
        """Pass the input through the stem0.

        With the default CNNBlockConfig settings (stride 1 convolutions with "same" padding) the spatial size is
        preserved; only the channel count changes from in_channels to dim.

        Args:
            x: {INPUT_3D_DOC}
            channels_first: {CHANNELS_FIRST_DOC}

        Returns:
            {OUTPUT_3D_DOC}
        """
        # x: (b, [in_channels], z, y, x, [in_channels])

        x = rearrange_channels(x, channels_first, True)
        # Now x is (b, in_channels, z, y, x)

        for layer in self.layers:
            x = layer(x)
            # (b, dim, z1, y1, x1)

        x = rearrange_channels(x, True, channels_first)
        # (b, [dim], z1, y1, x1, [dim])

        return x

    @wraps(_forward)
    def forward(self, *args, **kwargs):
        return self.checkpointing_level2(self._forward, *args, **kwargs)

In [None]:
stem0 = MaxViT3DStem0(in_channels=4, dim=8)
display(stem0)

sample_input = torch.randn(1, 4, 32, 32, 32)
stem0_output = stem0(sample_input)
# Channels change from in_channels=4 to dim=8; the spatial size is preserved
assert stem0_output.shape == (1, 8, 32, 32, 32)
stem0_output.shape


[1;35mMaxViT3DStem0[0m[1m([0m
  [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
      [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m4[0m, [1;36m8[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[35msame[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
      [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m8[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
      [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
    [1m)[0m
    [1m([0m[1;36m1[0m[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
 

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m1[0m, [1;36m8[0m, [1;36m32[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m

In [None]:
# | export


@populate_docstring
class MaxViT3DBlockAttention(SwinV23DLayer):
    """Perform windowed attention on the input tensor. {CLASS_DESCRIPTION_3D_DOC}"""

    # All behavior is inherited from SwinV23DLayer; this subclass only gives the block-attention stage of
    # MaxViT3DBlock its own class name. Unlike MaxViT3DGridAttention it does not override _get_rearrange_patterns,
    # so the window partitioning is whatever SwinV23DLayer defines (presumably contiguous, non-overlapping windows
    # of size window_size; TODO confirm against SwinV23DLayer).

In [None]:
block_attention = MaxViT3DBlockAttention(
    dim=12,
    num_heads=2,
    window_size=(4, 4, 4),
)
display(block_attention)

sample_input = torch.randn(1, 12, 20, 20, 20)
block_attention_output = block_attention(sample_input)
# Windowed attention must preserve the input shape
assert block_attention_output.shape == sample_input.shape
block_attention_output.shape


[1;35mMaxViT3DBlockAttention[0m[1m([0m
  [1m([0mtransformer[1m)[0m: [1;35mAttention3DWithMLP[0m[1m([0m
    [1m([0mattn[1m)[0m: [1;35mAttention3D[0m[1m([0m
      [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m12[0m, [33mout_features[0m=[1;36m12[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m12[0m, [33mout_features[0m=[1;36m12[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m12[0m, [33mout_features[0m=[1;36m12[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m12[0m, [33mout_features[0m=[1;36m12[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
      [1m(

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m1[0m, [1;36m12[0m, [1;36m20[0m, [1;36m20[0m, [1;36m20[0m[1m][0m[1m)[0m

In [None]:
# | export


class MaxViT3DGridAttention(SwinV23DLayer):
    """Perform grid attention on the input tensor.

    Note that the grid attention implementation differs from the paper where the image is being partitioned based on the
    window size and not based on the number of windows. For example:

    Let us say the input is

    .. code-block:: text

        A1 A2 A3 A4 A5 A6
        B1 B2 B3 B4 B5 B6
        C1 C2 C3 C4 C5 C6
        D1 D2 D3 D4 D5 D6

    Let us say the window size is 2x2. The grid attention will be performed on the following 6 windows. Each window
    is made of tokens that are ``num_windows`` positions apart along every axis (here 2 rows and 3 columns apart):

    .. code-block:: text

        A1 A4  A2 A5  A3 A6
        C1 C4  C2 C5  C3 C6

        B1 B4  B2 B5  B3 B6
        D1 D4  D2 D5  D3 D6

    According to the paper, my understanding is that attention should have been applied on the following 4 windows:

    .. code-block:: text

        A1 A3 A5  A2 A4 A6
        C1 C3 C5  C2 C4 C6

        B1 B3 B5  B2 B4 B6
        D1 D3 D5  D2 D4 D6

    i.e. the first token of all 2x2 windows in block attention, the second token of all 2x2 windows in block attention
    and so on.

    This has been implemented differently so as to limit the number of tokens to be attended to in a window. If it were
    implemented as per the paper, the number of windows in block attention would be very large, because 3D inputs
    are usually very large. That would lead to a very large number of tokens to attend to in each grid-attention
    window.

    It would also cause problems when estimating the position embeddings, as the grid size of the position embeddings
    would vary with every input size.
    """

    # NOTE(review): if the grid size is taken equal to window_size, this partitioning appears to match the grid
    # partition used by reference MaxViT implementations (e.g. timm's grid_partition). That partition also keeps the
    # grid size fixed and lets the number of windows vary. Confirm whether the "differs from the paper" remark above
    # still applies.

    @staticmethod
    def _get_rearrange_patterns() -> tuple[str, str]:
        """Return the rearrange patterns (einops syntax) that split a channels-last tensor into grid windows and back.

        Each spatial axis is factored as ``(window_size num_windows)``, with ``num_windows`` varying fastest. As a
        result, the ``window_size`` tokens gathered into one window are ``num_windows`` positions apart (a dilated,
        sparse grid) rather than contiguous. The windows are folded into the batch axis.

        Returns:
            A ``(forward_pattern, reverse_pattern)`` tuple; the reverse pattern undoes the forward one.
        """
        forward_pattern = (
            "b (window_size_z num_windows_z) (window_size_y num_windows_y) (window_size_x num_windows_x) dim -> "
            "(b num_windows_z num_windows_y num_windows_x) window_size_z window_size_y window_size_x dim "
        )
        reverse_pattern = (
            "(b num_windows_z num_windows_y num_windows_x) window_size_z window_size_y window_size_x dim -> "
            "b (window_size_z num_windows_z) (window_size_y num_windows_y) (window_size_x num_windows_x) dim"
        )
        return forward_pattern, reverse_pattern

In [None]:
grid_attention = MaxViT3DGridAttention(
    dim=12,
    num_heads=2,
    window_size=(4, 4, 4),
)
display(grid_attention)

sample_input = torch.randn(1, 12, 20, 20, 20)
grid_attention_output = grid_attention(sample_input)
# Grid attention must preserve the input shape
assert grid_attention_output.shape == sample_input.shape
grid_attention_output.shape


[1;35mMaxViT3DGridAttention[0m[1m([0m
  [1m([0mtransformer[1m)[0m: [1;35mAttention3DWithMLP[0m[1m([0m
    [1m([0mattn[1m)[0m: [1;35mAttention3D[0m[1m([0m
      [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m12[0m, [33mout_features[0m=[1;36m12[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m12[0m, [33mout_features[0m=[1;36m12[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m12[0m, [33mout_features[0m=[1;36m12[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m12[0m, [33mout_features[0m=[1;36m12[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
      [1m([

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m1[0m, [1;36m12[0m, [1;36m20[0m, [1;36m20[0m, [1;36m20[0m[1m][0m[1m)[0m

In [None]:
# | export


@populate_docstring
class MaxViT3DBlock(nn.Module):
    """MaxViT3D block: an MBConv3D followed by block attention and grid attention. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(
        self, config: MaxViT3DBlockConfig = {}, modify_dims: bool = False, checkpointing_level: int = 0, **kwargs
    ):
        """Build the MBConv3D and the two attention stages of the block.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            modify_dims: If True, the MBConv3D is built with stride 2 and padding 1, and outputs
                ``dim * out_dim_ratio`` channels. The attention stages, and the stored config, then use that output
                dimension. If False, the channel dimension is left unchanged.
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = MaxViT3DBlockConfig.model_validate(config | kwargs)

        if modify_dims:
            # Downsampling variant, used for the last block of a stem
            out_dim = self.config.dim * self.config.out_dim_ratio
            mbconv_kwargs = dict(stride=2, padding=1, out_dim=out_dim)
        else:
            out_dim = self.config.dim
            mbconv_kwargs = {}

        self.mbconv = MBConv3D(self.config.model_dump(), checkpointing_level=checkpointing_level, **mbconv_kwargs)

        # Everything after the MBConv works at out_dim, so record it in the config used for the attention stages
        self.config = MaxViT3DBlockConfig.model_validate(self.config | {"dim": out_dim})

        self.block_attention = MaxViT3DBlockAttention(self.config.model_dump(), checkpointing_level=checkpointing_level)
        self.grid_attention = MaxViT3DGridAttention(self.config.model_dump(), checkpointing_level=checkpointing_level)

        self.checkpointing_level3 = ActivationCheckpointing(3, checkpointing_level)

    @populate_docstring
    def _forward(self, x: torch.Tensor, channels_first: bool = True) -> torch.Tensor:
        """Run the MBConv, then block attention, then grid attention.

        Args:
            x: {INPUT_3D_DOC}
            channels_first: {CHANNELS_FIRST_DOC}

        Returns:
            {OUTPUT_3D_DOC}
        """
        # x: (b, [dim], z, y, x, [dim])

        # The MBConv receives the caller's layout flag as-is
        x = self.mbconv(x, channels_first)
        # (b, [out_dim], z1, y1, x1, [out_dim])

        # Both attention stages are driven channels-last, so convert once before them
        x = rearrange_channels(x, channels_first, False)
        for attention_stage in (self.block_attention, self.grid_attention):
            x = attention_stage(x, channels_first=False)
            # (b, z1, y1, x1, out_dim)

        # Restore the caller's layout: (b, [out_dim], z1, y1, x1, [out_dim])
        return rearrange_channels(x, False, channels_first)

    @wraps(_forward)
    def forward(self, *args, **kwargs):
        return self.checkpointing_level3(self._forward, *args, **kwargs)

In [None]:
maxvit_block = MaxViT3DBlock(dim=12, num_heads=3, window_size=(4, 4, 4))
display(maxvit_block)

sample_input = torch.randn(2, 12, 20, 20, 20)
maxvit_block_output = maxvit_block(sample_input)
# Without modify_dims the block is shape-preserving
assert maxvit_block_output.shape == sample_input.shape
maxvit_block_output.shape


[1;35mMaxViT3DBlock[0m[1m([0m
  [1m([0mmbconv[1m)[0m: [1;35mMBConv3D[0m[1m([0m
    [1m([0mexpand[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
      [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m12[0m, [1;36m72[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
      [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m72[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
      [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
    [1m)[0m
    [1m([0mdepthwise_conv[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
      [1m([0mconv[1m)[0m: [1;35

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m12[0m, [1;36m20[0m, [1;36m20[0m, [1;36m20[0m[1m][0m[1m)[0m

In [None]:
downsampling_block = MaxViT3DBlock(dim=12, num_heads=3, window_size=(2, 2, 2), modify_dims=True)
display(downsampling_block)

sample_input = torch.randn(2, 12, 20, 20, 20)
downsampling_block_output = downsampling_block(sample_input)
# modify_dims=True halves each spatial axis and multiplies dim by out_dim_ratio (default 2)
assert downsampling_block_output.shape == (2, 24, 10, 10, 10)
downsampling_block_output.shape


[1;35mMaxViT3DBlock[0m[1m([0m
  [1m([0mmbconv[1m)[0m: [1;35mMBConv3D[0m[1m([0m
    [1m([0mexpand[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
      [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m12[0m, [1;36m72[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
      [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m72[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
      [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
    [1m)[0m
    [1m([0mdepthwise_conv[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
      [1m([0mconv[1m)[0m: [1;35

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m24[0m, [1;36m10[0m, [1;36m10[0m, [1;36m10[0m[1m][0m[1m)[0m

In [None]:
# | export


@populate_docstring
class MaxViT3DStem(nn.Module):
    """Implementation of a group of MaxViT blocks forming a stem. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(
        self, config: MaxViT3DStemConfig = {}, checkpointing_level: int = 0, dont_downsample: bool = False, **kwargs
    ):
        """Build ``depth`` MaxViT3D blocks that share the same configuration.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            dont_downsample: If False (the default), the last block of the stem is built with modify_dims=True. It
                then downsamples with stride 2 and multiplies the channel dimension by out_dim_ratio. If True, no
                block changes the resolution or the channel dimension.
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = MaxViT3DStemConfig.model_validate(config | kwargs)

        last_index = self.config.depth - 1
        blocks = []
        for index in range(self.config.depth):
            # Only the final block may change resolution / channel dimension
            is_downsampling_block = index == last_index and not dont_downsample
            blocks.append(
                MaxViT3DBlock(
                    self.config.model_dump(),
                    checkpointing_level=checkpointing_level,
                    modify_dims=is_downsampling_block,
                )
            )
        self.blocks = nn.ModuleList(blocks)

        self.checkpointing_level4 = ActivationCheckpointing(4, checkpointing_level)

    @populate_docstring
    def _forward(self, x: torch.Tensor, channels_first: bool = True) -> torch.Tensor:
        """Run the input through every block of the stem in order.

        Args:
            x: {INPUT_3D_DOC}
            channels_first: {CHANNELS_FIRST_DOC}

        Returns:
            {OUTPUT_3D_DOC}
        """
        # x: (b, [dim], z, y, x, [dim])

        # The blocks are always driven channels-last: convert once on the way in and once on the way out
        hidden = rearrange_channels(x, channels_first, False)
        for block in self.blocks:
            hidden = block(hidden, channels_first=False)
        # (b, z1, y1, x1, dim1)

        return rearrange_channels(hidden, False, channels_first)

    @wraps(_forward)
    def forward(self, *args, **kwargs):
        return self.checkpointing_level4(self._forward, *args, **kwargs)

In [None]:
downsampling_stem = MaxViT3DStem(
    dim=12,
    num_heads=3,
    depth=3,
    window_size=(2, 2, 2),
)
display(downsampling_stem)

sample_input = torch.randn(2, 12, 20, 20, 20)
downsampling_stem_output = downsampling_stem(sample_input)
# The last block of the stem halves each spatial axis and doubles the channels
assert downsampling_stem_output.shape == (2, 24, 10, 10, 10)
downsampling_stem_output.shape


[1;35mMaxViT3DStem[0m[1m([0m
  [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m1[0m[1m)[0m: [1;36m2[0m x [1;35mMaxViT3DBlock[0m[1m([0m
      [1m([0mmbconv[1m)[0m: [1;35mMBConv3D[0m[1m([0m
        [1m([0mexpand[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
          [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m12[0m, [1;36m72[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
          [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m72[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
          [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpoi

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m24[0m, [1;36m10[0m, [1;36m10[0m, [1;36m10[0m[1m][0m[1m)[0m

In [None]:
final_stem = MaxViT3DStem(
    dim=12,
    num_heads=3,
    depth=3,
    window_size=(2, 2, 2),
    dont_downsample=True,
)
display(final_stem)

sample_input = torch.randn(2, 12, 20, 20, 20)
final_stem_output = final_stem(sample_input)
# With dont_downsample=True the stem is shape-preserving
assert final_stem_output.shape == sample_input.shape
final_stem_output.shape


[1;35mMaxViT3DStem[0m[1m([0m
  [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m2[0m[1m)[0m: [1;36m3[0m x [1;35mMaxViT3DBlock[0m[1m([0m
      [1m([0mmbconv[1m)[0m: [1;35mMBConv3D[0m[1m([0m
        [1m([0mexpand[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
          [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m12[0m, [1;36m72[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
          [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m72[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
          [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpoi

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m12[0m, [1;36m20[0m, [1;36m20[0m, [1;36m20[0m[1m][0m[1m)[0m

In [None]:
# | export


@populate_docstring
class MaxViT3DEncoder(nn.Module):
    """3D MaxViT encoder. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: MaxViT3DEncoderConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Build stem0 followed by one MaxViT3DStem per entry of ``config.stems``.

        Every MaxViT stem except the last one downsamples at its end; the last stem keeps its input resolution.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = MaxViT3DEncoderConfig.model_validate(config | kwargs)

        self.stems = nn.ModuleList([MaxViT3DStem0(self.config.stem0, checkpointing_level)])
        stem_count = len(self.config.stems)
        for index, stem_config in enumerate(self.config.stems):
            is_last_stem = index == stem_count - 1
            self.stems.append(MaxViT3DStem(stem_config, checkpointing_level, dont_downsample=is_last_stem))

        self.checkpointing_level5 = ActivationCheckpointing(5, checkpointing_level)

    @populate_docstring
    def _forward(self, x: torch.Tensor, return_intermediates: bool = False, channels_first: bool = True):
        """Run the input through stem0 and then every MaxViT stem in order.

        Args:
            x: {INPUT_3D_DOC}
            return_intermediates: {RETURN_INTERMEDIATES_DOC}
            channels_first: {CHANNELS_FIRST_DOC}

        Returns:
            {OUTPUT_3D_DOC}. If return_intermediates is True, returns a tuple of the output and a list with one
            intermediate output per stem (stem0 included). Note that the stem outputs are always in channels_last
            format.
        """
        # x: (b, [in_channels], z, y, x, [in_channels])

        # All stems are driven channels-last
        x = rearrange_channels(x, channels_first, False)
        # (b, z, y, x, in_channels)

        stem_outputs = []
        for stem in self.stems:
            x = stem(x, channels_first=False)
            stem_outputs.append(x)  # stored channels-last

        x = rearrange_channels(x, False, channels_first)
        # (b, [final_dim], z1, y1, x1, [final_dim]), final_dim being the channel count of the last stem

        return (x, stem_outputs) if return_intermediates else x

    @wraps(_forward)
    def forward(self, *args, **kwargs):
        return self.checkpointing_level5(self._forward, *args, **kwargs)

In [None]:
encoder = MaxViT3DEncoder(
    stem0={
        "in_channels": 1,
        "dim": 12,
    },
    stems=[
        {
            "dim": 12,
            "num_heads": 3,
            "window_size": (2, 2, 2),
            "depth": 3,
        },
        {
            "dim": 24,
            "num_heads": 3,
            "window_size": (2, 2, 2),
            "depth": 3,
        },
    ],
)
display(encoder)

sample_input = torch.randn(2, 1, 32, 32, 32)
encoder_output, encoder_features = encoder(sample_input, return_intermediates=True)
# stem0 keeps 32^3, the first stem downsamples to 16^3 / 24 channels, the last stem keeps its shape
assert encoder_output.shape == (2, 24, 16, 16, 16)
# Intermediate stem outputs are channels-last, one per stem (stem0 included)
assert [tuple(feature.shape) for feature in encoder_features] == [
    (2, 32, 32, 32, 12),
    (2, 16, 16, 16, 24),
    (2, 16, 16, 16, 24),
]
encoder_output.shape


[1;35mMaxViT3DEncoder[0m[1m([0m
  [1m([0mstems[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mMaxViT3DStem0[0m[1m([0m
      [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
        [1m([0m[1;36m0[0m[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
          [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m12[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[35msame[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
          [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m12[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
          [1m([0mcheckpointing_level1[1m)[0m: [1;35mActiv

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m24[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m

# nbdev

In [None]:
# Write the `# | export` cells of this notebook to the library module declared by `# | default_exp`
!nbdev_export