In [None]:
# | default_exp blocks/se

# Imports

In [None]:
# | export


from functools import wraps

import torch
from torch import nn

from vision_architectures.blocks.cnn import CNNBlock3D, CNNBlockConfig
from vision_architectures.docstrings import populate_docstring
from vision_architectures.utils.activation_checkpointing import ActivationCheckpointing
from vision_architectures.utils.custom_base_model import Field
from vision_architectures.utils.rearrange import rearrange_channels

# Config

In [None]:
# | export


class SEBlock3DConfig(CNNBlockConfig):
    """Configuration for ``SEBlock3D`` (squeeze-and-excitation over 3D feature maps).

    ``in_channels`` and ``out_channels`` are not user-settable: ``SEBlock3D`` derives them from
    ``dim`` and ``r``. The excitation convolutions run on a globally pooled (1, 1, 1) map, and
    ``SEBlock3D`` always builds them with kernel size 1, stride 1 and padding 0, so
    ``kernel_size`` has no effect on that block.
    """

    dim: int = Field(..., description="Number of input channels.")
    r: float = Field(..., description="Reduction ratio for the number of channels in the SE block.")

    # Inherited knob kept for config compatibility; SEBlock3D overrides it with 1.
    kernel_size: int = Field(
        1, description="Not used by SEBlock3D: its excitation convolutions are always 1x1x1."
    )
    normalization: str = Field("batchnorm3d", description="Normalization layer to use in the SE block.")
    activation: str = Field("silu", description="Activation function to use in the SE block.")

    # Derived inside SEBlock3D from dim and r; declared as None so they cannot be set here.
    in_channels: None = Field(None, description="determined by dim and r")
    out_channels: None = Field(None, description="determined by dim and r")

# Architecture

In [None]:
# | export


class SEBlock3D(nn.Module):
    """Squeeze-and-excitation block for 3D feature maps.

    The input is average-pooled to one value per channel ("squeeze"). That vector goes through a
    bottleneck of two pointwise ``CNNBlock3D`` layers (``dim -> excitation_dim -> dim``, the second
    with a sigmoid activation) ("excite"). The input is then rescaled channel-wise by the result.
    """

    @populate_docstring
    def __init__(self, config: SEBlock3DConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize an SEBlock3D block. Activation checkpointing level 2.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}

        Raises:
            ValueError: If the reduction ratio r is not strictly positive.
        """
        super().__init__()

        # `config | kwargs` builds a new dict, so the shared `{}` default is never mutated.
        self.config = SEBlock3DConfig.model_validate(config | kwargs)

        dim = self.config.dim
        r = self.config.r
        if r <= 0:
            raise ValueError(f"SEBlock3D reduction ratio r must be > 0, got {r}")

        # Bottleneck width. Clamp to at least one channel: for r > dim the floor division yields 0,
        # which would otherwise build a degenerate, empty excitation path without any error.
        excitation_dim = max(1, int(dim // r))

        base_config = self.config.model_dump()

        def pointwise_block(in_channels: int, out_channels: int, **overrides) -> CNNBlock3D:
            # The excitation path only ever sees a (b, c, 1, 1, 1) tensor, so its convolutions are
            # pointwise regardless of `config.kernel_size`. `overrides` are applied last.
            return CNNBlock3D(
                base_config
                | {
                    "in_channels": in_channels,
                    "out_channels": out_channels,
                    "kernel_size": 1,
                    "stride": 1,
                    "padding": 0,
                }
                | overrides,
                checkpointing_level,
            )

        self.squeeze = nn.AdaptiveAvgPool3d((1, 1, 1))
        # NOTE(review): with the default "batchnorm3d" normalization each channel sees a single
        # value per sample here, so training with batch size 1 will be rejected by BatchNorm.
        self.excite = nn.Sequential(
            pointwise_block(dim, excitation_dim),
            # Sigmoid gates every channel into (0, 1) before rescaling the input.
            pointwise_block(excitation_dim, dim, activation="sigmoid"),
        )

        self.checkpointing_level2 = ActivationCheckpointing(2, checkpointing_level)

    @populate_docstring
    def _forward(self, x: torch.Tensor, channels_first: bool = True) -> torch.Tensor:
        """Forward pass of the SEBlock3D block.

        Args:
            x: {INPUT_3D_DOC}
            channels_first: {CHANNELS_FIRST_DOC}

        Returns:
            {OUTPUT_3D_DOC}
        """
        # x: (b, [dim], z, y, x, [dim])

        x = rearrange_channels(x, channels_first, True)
        # Now x is (b, dim, z, y, x)

        p = self.squeeze(x)
        # (b, dim, 1, 1, 1)
        p = self.excite(p)
        # (b, dim, 1, 1, 1): per-channel gates, broadcast over the spatial dims below
        x = x * p
        # (b, dim, z, y, x)

        x = rearrange_channels(x, True, channels_first)
        # (b, [dim], z, y, x, [dim])

        return x

    @wraps(_forward)
    def forward(self, *args, **kwargs):
        # Run the whole block through the level-2 activation-checkpointing wrapper.
        return self.checkpointing_level2(self._forward, *args, **kwargs)

In [None]:
# kernel_size is not passed: SEBlock3D always uses 1x1x1 convolutions in its excitation path.
test = SEBlock3D(dim=12, r=3)
display(test)

sample_input = torch.randn(2, 12, 4, 4, 4)
sample_output = test(sample_input)
# Squeeze-and-excitation only rescales channels, so the shape must be preserved.
assert sample_output.shape == sample_input.shape
sample_output.shape


[1;35mSEBlock3D[0m[1m([0m
  [1m([0msqueeze[1m)[0m: [1;35mAdaptiveAvgPool3d[0m[1m([0m[33moutput_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m[1m)[0m
  [1m([0mexcite[1m)[0m: [1;35mSequential[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
      [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m12[0m, [1;36m4[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
      [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m4[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mact[1m)[0m: [1;35mSiLU[0m[1m([0m[1m)[0m
      [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m12[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m][0m[1m)[0m

# nbdev

In [None]:
# Export the cells marked `# | export` to the module set by `# | default_exp` (blocks/se).
!nbdev_export