In [None]:
# | default_exp blocks/cnn

# Imports

In [None]:
# | export


from typing import Any, Literal

import torch
from torch import nn

from vision_architectures.utils.activation_checkpointing import ActivationCheckpointing
from vision_architectures.utils.activations import get_act_layer
from vision_architectures.utils.custom_base_model import CustomBaseModel
from vision_architectures.utils.normalizations import get_norm_layer
from vision_architectures.utils.rearrange import rearrange_channels

# Config

In [None]:
# | export


class CNNBlock3DConfig(CustomBaseModel):
    """Configuration for ``CNNBlock3D``: one 3D convolution followed by norm / dropout / activation.

    Fields:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution (also the width of the norm layer).
        kernel_size: convolution kernel size; a single int for all three spatial dims or a per-dim
            ``(z, y, x)`` tuple, matching what ``nn.Conv3d`` accepts.
        padding: convolution padding; an int, a per-dim tuple, or a string mode such as ``"same"``.
            NOTE: PyTorch rejects ``padding="same"`` for strided convolutions, so set an explicit
            padding when ``stride`` is not 1.
        stride: convolution stride; a single int or a per-dim ``(z, y, x)`` tuple.
        conv_kwargs: extra keyword arguments forwarded verbatim to ``nn.Conv3d`` (e.g. ``groups``).
        sequence: order in which Normalization ("N"), Dropout ("D") and Activation ("A") are applied
            to the convolution output.
        normalization: name handed to ``get_norm_layer``; ``None`` presumably means "no normalization"
            (that behaviour lives in ``get_norm_layer`` — confirm there).
        drop_prob: probability handed to ``nn.Dropout``.
        activation: name handed to ``get_act_layer``; ``None`` presumably means "no activation"
            (confirm in ``get_act_layer``).
    """

    in_channels: int
    out_channels: int
    # Widened from ``int`` so anisotropic kernels/strides (e.g. (1, 3, 3)) can be configured; a plain
    # int is still accepted and validated unchanged.
    kernel_size: int | tuple[int, int, int]
    padding: int | tuple[int, ...] | str = "same"
    stride: int | tuple[int, int, int] = 1
    # The mutable default is safe here: pydantic copies field defaults per instance.
    conv_kwargs: dict[str, Any] = {}

    sequence: Literal["ADN", "AND", "DAN", "DNA", "NAD", "NDA"] = "NDA"

    normalization: str | None = None
    drop_prob: float = 0.0
    activation: str | None = None

# Architecture

In [None]:
# | export


class CNNBlock3D(nn.Module):
    """A 3D convolution followed by normalization, dropout and activation in a configurable order.

    The convolution always runs first; ``config.sequence`` then decides the order in which the
    normalization ("N"), dropout ("D") and activation ("A") layers are applied to its output.

    Args:
        config: a ``CNNBlock3DConfig`` instance or a plain dict of its fields. It is never mutated.
        checkpointing_level: forwarded to ``ActivationCheckpointing(1, checkpointing_level)``, which
            wraps the whole forward pass (presumably enabled when the level is >= 1 — confirm in
            ``ActivationCheckpointing``).
        **kwargs: individual config fields; these override the corresponding entries of ``config``.
    """

    def __init__(self, config: CNNBlock3DConfig | dict = {}, checkpointing_level: int = 0, **kwargs):
        super().__init__()

        # ``config | kwargs`` only works on dicts, so turn a config object into one first; ``kwargs``
        # keep precedence over whatever the config object held.
        if isinstance(config, CNNBlock3DConfig):
            config = config.model_dump()
        self.config = CNNBlock3DConfig.model_validate(config | kwargs)

        normalization = self.config.normalization
        activation = self.config.activation
        drop_prob = self.config.drop_prob
        sequence = self.config.sequence

        # When batchnorm is applied directly to the conv output, its per-channel mean subtraction
        # cancels any conv bias, so the bias parameter would be redundant.
        bias = True
        if normalization is not None and normalization.startswith("batchnorm") and sequence.startswith("N"):
            bias = False
        self.cnn = nn.Conv3d(
            in_channels=self.config.in_channels,
            out_channels=self.config.out_channels,
            kernel_size=self.config.kernel_size,
            padding=self.config.padding,
            stride=self.config.stride,
            bias=bias,
            **self.config.conv_kwargs,
        )

        self.norm_layer = get_norm_layer(normalization, self.config.out_channels)
        self.act_layer = get_act_layer(activation)
        self.dropout = nn.Dropout(drop_prob)

        self.checkpointing_level1 = ActivationCheckpointing(1, checkpointing_level)

    def _forward(self, x: torch.Tensor, channels_first: bool = True):
        """Run conv + the configured N/D/A sequence.

        Args:
            x: ``(b, in_channels, z, y, x)`` when ``channels_first`` else ``(b, z, y, x, in_channels)``.
            channels_first: layout of both the input and the returned tensor.

        Returns:
            The block output in the same channel layout as the input, with ``out_channels`` channels.
        """
        # x: (b, [in_channels], z, y, x, [in_channels])

        x = rearrange_channels(x, channels_first, True)
        # Now x is (b, in_channels, z, y, x)

        x = self.cnn(x)
        # Now x is (b, out_channels, z, y, x)

        for layer in self.config.sequence:
            if layer == "A":
                x = self.act_layer(x)
            elif layer == "D":
                x = self.dropout(x)
            elif layer == "N":
                x = self.norm_layer(x)
            # (b, out_channels, z, y, x)

        x = rearrange_channels(x, True, channels_first)
        # (b, [out_channels], z, y, x, [out_channels])

        return x

    def forward(self, *args, **kwargs):
        """Entry point; routes through ``ActivationCheckpointing`` to ``_forward`` (same arguments)."""
        return self.checkpointing_level1(self._forward, *args, **kwargs)

In [None]:
# Smoke test: grouped conv + batchnorm -> dropout -> SiLU; "same"-style padding keeps spatial size.
cnn_block = CNNBlock3D(
    in_channels=4,
    out_channels=8,
    kernel_size=3,
    normalization="batchnorm3d",
    activation="silu",
    drop_prob=0.5,
    padding=1,
    conv_kwargs={"groups": 2},
    sequence="NDA",
)
display(cnn_block)

sample_input = torch.randn(2, 4, 16, 16, 16)
sample_output = cnn_block(sample_input)
assert sample_output.shape == (2, 8, 16, 16, 16)

# The channels-last path must return channels-last output with out_channels in the last dim.
channels_last_output = cnn_block(sample_input.permute(0, 2, 3, 4, 1), channels_first=False)
assert channels_last_output.shape == (2, 16, 16, 16, 8)

sample_output.shape


[1;35mCNNBlock3D[0m[1m([0m
  [1m([0mcnn[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m4[0m, [1;36m8[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mgroups[0m=[1;36m2[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
  [1m([0mnorm_layer[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m8[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mact_layer[1m)[0m: [1;35mSiLU[0m[1m([0m[1m)[0m
  [1m([0mdropout[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.5[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m, [33mcheckpo

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m

# nbdev

In [None]:
# Regenerate the library module from the cells marked "# | export" (target set by "default_exp blocks/cnn").
!nbdev_export