In [None]:
# | default_exp blocks/se

# Imports

In [None]:
# | export


import torch
from einops import rearrange
from torch import nn

from vision_architectures.utils.activation_checkpointing import ActivationCheckpointing
from vision_architectures.utils.activations import get_act_layer
from vision_architectures.utils.custom_base_model import CustomBaseModel
from vision_architectures.utils.normalizations import get_norm_layer

# Config

In [None]:
# | export


class SEBlock3DConfig(CustomBaseModel):
    """Configuration consumed by ``SEBlock3D``."""

    # Channel count of the feature map: input width of the first excitation conv
    # and output width of the second one.
    dim: int
    # Reduction ratio; the excitation bottleneck has ``int(dim // r)`` channels.
    r: float
    # Name handed to ``get_norm_layer`` for both excitation norms. Names starting
    # with "batchnorm" also switch off the bias of the excitation convs.
    normalization: str = "batchnorm3d"
    # Name handed to ``get_act_layer`` for the activation inside the bottleneck.
    activation: str = "silu"

# Architecture

In [None]:
# | export


class SEBlock3D(nn.Module):
    """Squeeze-and-excitation block for 5D (volumetric) feature maps.

    The input is average-pooled to a single value per channel ("squeeze"), passed
    through a two-layer 1x1x1 convolutional bottleneck ending in a sigmoid
    ("excitation"), and the resulting per-channel gates are multiplied back onto
    the input.
    """

    def __init__(self, config: SEBlock3DConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Build the squeeze and excitation layers.

        Args:
            config: Block configuration; merged with ``kwargs`` and validated via
                ``SEBlock3DConfig.model_validate``.
            checkpointing_level: Forwarded to ``ActivationCheckpointing`` together
                with this block's level (2). NOTE(review): presumably checkpointing
                is active when this is >= 2 -- confirm against that helper.
            **kwargs: Individual config fields (e.g. ``dim=12, r=3``) that override
                or complete ``config``.
        """
        super().__init__()

        self.config = SEBlock3DConfig.model_validate(config | kwargs)

        dim = self.config.dim
        r = self.config.r
        activation = self.config.activation
        normalization = self.config.normalization

        # Width of the bottleneck between the two excitation convolutions.
        excitation_dim = int(dim // r)

        # Each conv is immediately followed by a norm layer; when that norm is a
        # batchnorm variant the conv bias is dropped.
        use_bias = not normalization.startswith("batchnorm")

        self.squeeze = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.excitation = nn.Sequential(
            nn.Conv3d(dim, excitation_dim, kernel_size=1, bias=use_bias),
            get_norm_layer(normalization, excitation_dim),
            get_act_layer(activation),
            nn.Conv3d(excitation_dim, dim, kernel_size=1, bias=use_bias),
            get_norm_layer(normalization, dim),
            nn.Sigmoid(),
        )

        self.checkpointing_level2 = ActivationCheckpointing(2, checkpointing_level)

    def _forward(self, x: torch.Tensor, channels_first: bool = True):
        """Apply channel-wise gating to ``x``.

        Args:
            x: Tensor of shape (b, dim, z, y, x) when ``channels_first`` is True,
                otherwise (b, z, y, x, dim).
            channels_first: Layout of ``x``; the output uses the same layout.

        Returns:
            Tensor with the same shape and layout as the input.
        """
        # x: (b, [dim], z, y, x, [dim])

        if not channels_first:
            x = rearrange(x, "b z y x d -> b d z y x").contiguous()
        # Now x is (b, dim, z, y, x)

        p = self.squeeze(x)
        # (b, dim, 1, 1, 1)
        p = self.excitation(p)
        # (b, dim, 1, 1, 1)
        x = x * p  # gates broadcast over the three spatial axes
        # (b, dim, z, y, x)

        if not channels_first:
            # Move channels back to the last axis. The pattern must list each of
            # the five axes exactly once on both sides.
            x = rearrange(x, "b d z y x -> b z y x d").contiguous()
            # (b, z, y, x, dim)

        return x

    def forward(self, *args, **kwargs):
        """Run ``_forward`` through the level-2 activation-checkpointing wrapper."""
        return self.checkpointing_level2(self._forward, *args, **kwargs)

In [None]:
# Smoke test: build a small block, show its layers, and push a random
# channels-first volume through it.
test = SEBlock3D(dim=12, r=3)
display(test)

sample_input = torch.randn(2, 12, 4, 4, 4)
sample_output = test(sample_input)

# The block only rescales channels, so the output must keep the input's shape.
assert sample_output.shape == sample_input.shape, sample_output.shape
sample_output.shape


[1;35mSEBlock3D[0m[1m([0m
  [1m([0msqueeze[1m)[0m: [1;35mAdaptiveAvgPool3d[0m[1m([0m[33moutput_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m[1m)[0m
  [1m([0mexcitation[1m)[0m: [1;35mSequential[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m12[0m, [1;36m4[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
    [1m([0m[1;36m1[0m[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m4[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
    [1m([0m[1;36m2[0m[1m)[0m: [1;35mSiLU[0m[1m([0m[1m)[0m
    [1m([0m[1;36m3[0m[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m4[0m, [1;36m12[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m12[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m][0m[1m)[0m

# nbdev

In [None]:
# Run nbdev's export CLI so the `# | export` cells above are written to the
# library module (target set by `# | default_exp` in the first cell).
!nbdev_export