In [None]:
# | default_exp blocks/cnn

# Imports

In [None]:
# | export


from functools import cache
from itertools import chain, permutations
from typing import Any, Literal

import torch
from torch import nn

from vision_architectures.utils.activation_checkpointing import ActivationCheckpointing
from vision_architectures.utils.activations import get_act_layer
from vision_architectures.utils.custom_base_model import CustomBaseModel, Field, field_validator, model_validator
from vision_architectures.utils.normalizations import get_norm_layer
from vision_architectures.utils.rearrange import rearrange_channels
from vision_architectures.utils.residuals import Residual
from vision_architectures.utils.splitter_merger import Splitter

# Config

In [None]:
# | export


# Every layer order accepted by ``CNNBlockConfig.sequence``: each ordering of 0-4 distinct letters taken from "ACDN"
# (A = activation, C = conv, D = dropout, N = normalization) that contains the conv "C", shortest orders first.
possible_sequences = [
    "".join(order) for length in range(5) for order in permutations("ACDN", length) if "C" in order
]


class CNNBlockConfig(CustomBaseModel):
    """Configuration of a single convolution block.

    ``sequence`` fixes the order of the layers: ``C`` conv, ``N`` normalization, ``A`` activation, ``D`` dropout.
    Optional layers that are switched off (``normalization=None``, ``activation=None``, ``drop_prob == 0``) are
    stripped from ``sequence``; optional layers that are switched on must appear in it.
    """

    # Convolution
    in_channels: int
    out_channels: int
    kernel_size: int | tuple[int, ...]
    padding: int | tuple[int, ...] | str = "same"
    stride: int = 1
    conv_kwargs: dict[str, Any] = {}  # extra keyword arguments forwarded to the conv constructor
    transposed: bool = Field(False, description="Whether to perform ConvTranspose instead of Conv")

    # Normalization: the pre/post args are placed before/after the channel count when building the norm layer
    normalization: str | None = "batchnorm3d"
    normalization_pre_args: list = []
    normalization_post_args: list = []
    normalization_kwargs: dict = {}
    # Activation
    activation: str | None = "relu"
    activation_kwargs: dict = {}

    # Layer order, e.g. "CNA" = conv -> normalization -> activation
    sequence: Literal[tuple(possible_sequences)] = "CNA"

    drop_prob: float = 0.0

    @model_validator(mode="after")
    def validate(self):
        """Reconcile ``sequence`` with the optional layers that are actually enabled.

        Raises:
            ValueError: If an enabled optional layer is missing from ``sequence``.
        """
        super().validate()
        # (sequence letter, layer is disabled, layer is enabled, error raised when enabled but missing)
        rules = (
            (
                "N",
                self.normalization is None,
                self.normalization is not None,
                "Add N to the sequence or set normalization=None.",
            ),
            (
                "A",
                self.activation is None,
                self.activation is not None,
                "Add A to the sequence or set activation=None.",
            ),
            (
                "D",
                self.drop_prob == 0.0,
                self.drop_prob > 0.0,
                "Add D to the sequence or set drop_prob=0.",
            ),
        )
        for letter, disabled, enabled, message in rules:
            if disabled and letter in self.sequence:
                self.sequence = self.sequence.replace(letter, "")
            if enabled and letter not in self.sequence:
                raise ValueError(message)
        return self


class MultiResCNNBlockConfig(CNNBlockConfig):
    """Configuration of a MultiRes block: a chain of 3x3 conv blocks whose outputs are concatenated, plus a 1x1
    residual conv block.

    ``filter_ratios`` decides how ``out_channels`` is shared between the chained convs and is normalized to sum to 1.
    """

    kernel_sizes: tuple[int | tuple[int, ...], ...] = (3, 5, 7)
    filter_ratios: tuple[float, ...] = Field(
        (1, 2, 3), description="Ratio of filters to out_channels for each conv layer. Will be scaled to sum to 1."
    )
    padding: Literal["same"] = "same"

    kernel_size: int = 3

    @field_validator("filter_ratios", mode="after")
    @classmethod
    def scale_filter_ratios(cls, filter_ratios):
        """Normalize ``filter_ratios`` so that they sum to 1.

        Raises:
            ValueError: If a ratio is negative or all ratios are zero. Without this check an all-zero tuple raised a
                bare ZeroDivisionError, which pydantic does not report as a validation error.
        """
        if any(ratio < 0 for ratio in filter_ratios):
            raise ValueError(f"filter_ratios must be non-negative, got {filter_ratios}")
        total = sum(filter_ratios)
        if filter_ratios and total == 0:
            raise ValueError(f"filter_ratios must not all be zero, got {filter_ratios}")
        return tuple(ratio / total for ratio in filter_ratios)

    @model_validator(mode="after")
    def validate(self):
        """Check the MultiRes-specific constraints on top of ``CNNBlockConfig`` validation.

        Explicit ``raise`` is used instead of ``assert`` so the checks survive ``python -O``.

        Raises:
            ValueError: If kernel sizes are not the supported (3, 5, 7) / 3, or filter_ratios has the wrong length.
        """
        super().validate()
        if self.kernel_sizes != (3, 5, 7):
            raise ValueError("Only kernel sizes of (3, 5, 7) are supported for MultiResCNNBlock")
        if self.kernel_size != 3:
            raise ValueError("only kernel_size = 3 is supported for MultiResCNNBlock")
        if len(self.kernel_sizes) != len(self.filter_ratios):
            raise ValueError("kernel_sizes and filter_ratios must have the same length")
        return self

# Architecture

### Simple blocks

In [None]:
# | export


class _CNNBlock(nn.Module):
    def __init__(
        self, spatial_dims: Literal[2, 3], config: CNNBlockConfig = {}, checkpointing_level: int = 0, **kwargs
    ):
        super().__init__()

        self.config = CNNBlockConfig.model_validate(config | kwargs)

        normalization = self.config.normalization
        activation = self.config.activation
        drop_prob = self.config.drop_prob
        sequence = self.config.sequence

        bias = True
        if normalization is not None and normalization.startswith("batchnorm") and "CN" in sequence:
            bias = False

        match spatial_dims, self.config.transposed:
            case 2, False:
                conv_module = nn.Conv2d
            case 2, True:
                conv_module = nn.ConvTranspose2d
            case 3, False:
                conv_module = nn.Conv3d
            case 3, True:
                conv_module = nn.ConvTranspose3d
            case _:
                raise ValueError(f"Unsupported spatial dimensions: {spatial_dims}")

        self.conv = conv_module(
            in_channels=self.config.in_channels,
            out_channels=self.config.out_channels,
            kernel_size=self.config.kernel_size,
            padding=self.config.padding,
            stride=self.config.stride,
            bias=bias,
            **self.config.conv_kwargs,
        )

        self.norm = None
        self.act = None
        self.dropout = None

        norm_channels = self.config.out_channels
        if "N" in sequence.split("C")[0]:
            norm_channels = self.config.in_channels

        if "N" in sequence:
            self.norm = get_norm_layer(
                normalization,
                *self.config.normalization_pre_args,
                norm_channels,
                *self.config.normalization_post_args,
                **self.config.normalization_kwargs,
            )
        if "A" in sequence:
            self.act = get_act_layer(activation, **self.config.activation_kwargs)
        if "D" in sequence:
            self.dropout = nn.Dropout(drop_prob)

        self.checkpointing_level1 = ActivationCheckpointing(1, checkpointing_level)

    def _forward(self, x: torch.Tensor, channels_first: bool = True):
        # x: (b, [in_channels], [z], y, x, [in_channels])

        x = rearrange_channels(x, channels_first, True)
        # Now x is (b, in_channels, [z], y, x)

        for layer in self.config.sequence:
            if layer == "C":
                x = self.conv(x)
            if layer == "A":
                x = self.act(x)
            elif layer == "D":
                x = self.dropout(x)
            elif layer == "N":
                x = self.norm(x)
        # (b, out_channels, [z], y, x)

        x = rearrange_channels(x, True, channels_first)
        # (b, [out_channels], [z], y, x, [out_channels])

        return x

    def forward(self, *args, **kwargs):
        return self.checkpointing_level1(self._forward, *args, **kwargs)

In [None]:
# | export


class CNNBlock3D(_CNNBlock):
    """Volumetric conv block (``Conv3d`` or ``ConvTranspose3d``); options are described by ``CNNBlockConfig``."""

    def __init__(self, config: CNNBlockConfig = {}, checkpointing_level: int = 0, **kwargs):
        super().__init__(spatial_dims=3, config=config, checkpointing_level=checkpointing_level, **kwargs)

In [None]:
test = CNNBlock3D(
    in_channels=4,
    out_channels=8,
    kernel_size=3,
    normalization="groupnorm",
    normalization_pre_args=[2],
    activation="silu",
    drop_prob=0.5,
    padding=1,
    conv_kwargs={"groups": 2},
    sequence="NCDA",
)
display(test)

sample_input = torch.randn(2, 4, 16, 16, 16)
test_output = test(sample_input)
# kernel_size=3 with padding=1 and stride 1 keeps the spatial size; channels go 4 -> 8
assert test_output.shape == (2, 8, 16, 16, 16), test_output.shape
test_output.shape


[1;35mCNNBlock3D[0m[1m([0m
  [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m4[0m, [1;36m8[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mgroups[0m=[1;36m2[0m[1m)[0m
  [1m([0mnorm[1m)[0m: [1;35mGroupNorm[0m[1m([0m[1;36m2[0m, [1;36m4[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33maffine[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mact[1m)[0m: [1;35mSiLU[0m[1m([0m[1m)[0m
  [1m([0mdropout[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.5[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m

In [None]:
test = CNNBlock3D(
    in_channels=4,
    out_channels=8,
    kernel_size=4,
    normalization="batchnorm3d",
    activation="prelu",
    drop_prob=0.5,
    padding=1,
    stride=2,
    conv_kwargs={"groups": 2},
    sequence="NDAC",
    transposed=True,
)
display(test)

sample_input = torch.randn(2, 4, 16, 16, 16)
test_output = test(sample_input)
# ConvTranspose with kernel_size=4, stride=2, padding=1 doubles each spatial dim: (16 - 1) * 2 - 2 * 1 + 4 = 32
assert test_output.shape == (2, 8, 32, 32, 32), test_output.shape
test_output.shape


[1;35mCNNBlock3D[0m[1m([0m
  [1m([0mconv[1m)[0m: [1;35mConvTranspose3d[0m[1m([0m[1;36m4[0m, [1;36m8[0m, [33mkernel_size[0m=[1m([0m[1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m)[0m, [33mpadding[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mgroups[0m=[1;36m2[0m[1m)[0m
  [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m4[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mact[1m)[0m: [1;35mPReLU[0m[1m([0m[33mnum_parameters[0m=[1;36m1[0m[1m)[0m
  [1m([0mdropout[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.5[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m


[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m32[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m

In [None]:
# | export


class CNNBlock2D(_CNNBlock):
    """Planar conv block (``Conv2d`` or ``ConvTranspose2d``); options are described by ``CNNBlockConfig``."""

    def __init__(self, config: CNNBlockConfig = {}, checkpointing_level: int = 0, **kwargs):
        super().__init__(spatial_dims=2, config=config, checkpointing_level=checkpointing_level, **kwargs)

In [None]:
test = CNNBlock2D(
    in_channels=4,
    out_channels=8,
    kernel_size=3,
    normalization="groupnorm",
    normalization_pre_args=[2],
    activation="silu",
    drop_prob=0.5,
    padding=1,
    conv_kwargs={"groups": 2},
    sequence="NCDA",
)
display(test)

sample_input = torch.randn(2, 4, 16, 16)
test_output = test(sample_input)
# kernel_size=3 with padding=1 and stride 1 keeps the spatial size; channels go 4 -> 8
assert test_output.shape == (2, 8, 16, 16), test_output.shape
test_output.shape


[1;35mCNNBlock2D[0m[1m([0m
  [1m([0mconv[1m)[0m: [1;35mConv2d[0m[1m([0m[1;36m4[0m, [1;36m8[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[1m([0m[1;36m1[0m, [1;36m1[0m[1m)[0m, [33mgroups[0m=[1;36m2[0m[1m)[0m
  [1m([0mnorm[1m)[0m: [1;35mGroupNorm[0m[1m([0m[1;36m2[0m, [1;36m4[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33maffine[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mact[1m)[0m: [1;35mSiLU[0m[1m([0m[1m)[0m
  [1m([0mdropout[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.5[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m

In [None]:
test = CNNBlock2D(
    in_channels=4,
    out_channels=8,
    kernel_size=4,
    normalization="batchnorm2d",
    activation="prelu",
    drop_prob=0.5,
    padding=1,
    stride=2,
    conv_kwargs={"groups": 2},
    sequence="NDAC",
    transposed=True,
)
display(test)

sample_input = torch.randn(2, 4, 16, 16)
test_output = test(sample_input)
# ConvTranspose with kernel_size=4, stride=2, padding=1 doubles each spatial dim: (16 - 1) * 2 - 2 * 1 + 4 = 32
assert test_output.shape == (2, 8, 32, 32), test_output.shape
test_output.shape


[1;35mCNNBlock2D[0m[1m([0m
  [1m([0mconv[1m)[0m: [1;35mConvTranspose2d[0m[1m([0m[1;36m4[0m, [1;36m8[0m, [33mkernel_size[0m=[1m([0m[1;36m4[0m, [1;36m4[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m2[0m, [1;36m2[0m[1m)[0m, [33mpadding[0m=[1m([0m[1;36m1[0m, [1;36m1[0m[1m)[0m, [33mgroups[0m=[1;36m2[0m[1m)[0m
  [1m([0mnorm[1m)[0m: [1;35mBatchNorm2d[0m[1m([0m[1;36m4[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mact[1m)[0m: [1;35mPReLU[0m[1m([0m[33mnum_parameters[0m=[1;36m1[0m[1m)[0m
  [1m([0mdropout[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.5[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m

### Multi resolution

In [None]:
# | export


class _MultiResCNNBlock(nn.Module):
    """MultiRes block: a chain of 3x3 ``_CNNBlock``s whose outputs are concatenated along the channel axis and then
    combined through ``Residual`` with a 1x1 ``_CNNBlock`` projection of the input.

    Because each 3x3 conv consumes the previous one's output, the i-th chunk of output channels sees a progressively
    larger neighbourhood (3, 5, 7, ... at unit stride), which is what the nominal ``kernel_sizes`` refer to.
    """

    def __init__(
        self, spatial_dims: Literal[2, 3], config: MultiResCNNBlockConfig = {}, checkpointing_level: int = 0, **kwargs
    ):
        """Build the block.

        Args:
            spatial_dims: 2 or 3, forwarded to every ``_CNNBlock``.
            config: ``MultiResCNNBlockConfig`` instance or a dict of its fields. Entries in ``kwargs`` take precedence.
            checkpointing_level: Forwarded to the sub-blocks and to ``ActivationCheckpointing(2, ...)``.

        Raises:
            ValueError: If ``filter_ratios`` leave no channels for the last conv.
        """
        super().__init__()

        # Model instances are not guaranteed to support ``|``; flatten them so kwargs can override single fields.
        if isinstance(config, CustomBaseModel):
            config = config.model_dump()
        self.config = MultiResCNNBlockConfig.model_validate(config | kwargs)

        assert self.config.kernel_sizes == (3, 5, 7), "Only kernel sizes of (3, 5, 7) are supported for now"

        # Every conv except the last gets at least one channel; the last takes whatever remains of out_channels
        all_out_channels = [max(1, int(self.config.out_channels * ratio)) for ratio in self.config.filter_ratios[:-1]]
        last_out_channels = self.config.out_channels - sum(all_out_channels)
        all_out_channels.append(last_out_channels)
        if last_out_channels <= 0:
            raise ValueError(
                f"These filter values ({self.config.filter_ratios}) won't work with the given out_channels. Please "
                f"adjust them. The out_channels of each conv layer is coming out to be {all_out_channels}."
            )
        # Convs are chained: each one consumes the previous conv's output
        all_in_channels = [self.config.in_channels] + all_out_channels[:-1]

        # Sub-blocks inherit the full config (normalization, activation, ...); only channels and kernel differ
        self.convs = nn.ModuleList(
            [
                _CNNBlock(
                    spatial_dims,
                    self.config.model_dump(),
                    checkpointing_level,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                )
                for in_channels, out_channels in zip(all_in_channels, all_out_channels)
            ]
        )

        # 1x1 projection so the shortcut has out_channels channels
        self.residual_conv = _CNNBlock(
            spatial_dims,
            self.config.model_dump(),
            checkpointing_level,
            in_channels=self.config.in_channels,
            out_channels=self.config.out_channels,
            kernel_size=1,
        )

        self.residual = Residual()

        self.checkpointing_level2 = ActivationCheckpointing(2, checkpointing_level)

    def _forward(self, x: torch.Tensor, channels_first: bool = True):
        """Run the conv chain, concatenate its outputs and combine them with the projected shortcut.

        Args:
            x: (b, in_channels, [z], y, x) if ``channels_first`` else (b, [z], y, x, in_channels).
            channels_first: Layout of ``x``; the output uses the same layout.

        Returns:
            Tensor with ``out_channels`` channels in the same layout as the input.
        """
        # x: (b, [in_channels], [z], y, x, [in_channels])

        x = rearrange_channels(x, channels_first, True)
        # (b, in_channels, [z], y, x)

        residual = self.residual_conv(x)
        # (b, out_channels, [z], y, x)

        conv_outputs = []
        for conv in self.convs:
            conv_input = conv_outputs[-1] if conv_outputs else x
            conv_output = conv(conv_input)
            conv_outputs.append(conv_output)
            # (b, one_of_all_out_channels, [z], y, x)

        x = torch.cat(conv_outputs, dim=1)
        # (b, out_channels, [z], y, x)

        x = self.residual(x, residual)
        # (b, out_channels, [z], y, x)

        x = rearrange_channels(x, True, channels_first)
        # (b, [out_channels], [z], y, x, [out_channels])

        return x

    def forward(self, *args, **kwargs):
        """Run ``_forward`` through ``self.checkpointing_level2`` (arguments are forwarded unchanged)."""
        return self.checkpointing_level2(self._forward, *args, **kwargs)

In [None]:
# | export


class MultiResCNNBlock3D(_MultiResCNNBlock):
    """Volumetric MultiRes block; options are described by ``MultiResCNNBlockConfig``."""

    def __init__(self, config: MultiResCNNBlockConfig = {}, checkpointing_level: int = 0, **kwargs):
        super().__init__(spatial_dims=3, config=config, checkpointing_level=checkpointing_level, **kwargs)

In [None]:
test = MultiResCNNBlock3D(
    in_channels=4,
    out_channels=8,
    filter_ratios=(3, 2, 1),
    activation="gelu",
)
display(test)

sample_input = torch.randn(2, 4, 16, 16, 16)
test_output = test(sample_input)
# "same" padding keeps the spatial size; the concatenated conv outputs add up to out_channels=8
assert test_output.shape == (2, 8, 16, 16, 16), test_output.shape
test_output.shape


[1;35mMultiResCNNBlock3D[0m[1m([0m
  [1m([0mconvs[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35m_CNNBlock[0m[1m([0m
      [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m4[0m, [1;36m4[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[35msame[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
      [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m4[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mact[1m)[0m: [1;35mGELU[0m[1m([0m[33mapproximate[0m=[32m'none'[0m[1m)[0m
      [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
    [1m)[0m
    [1m([0m[1;36m1[0m[1m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m

In [None]:
# | export


class MultiResCNNBlock2D(_MultiResCNNBlock):
    """Planar MultiRes block; options are described by ``MultiResCNNBlockConfig``."""

    def __init__(self, config: MultiResCNNBlockConfig = {}, checkpointing_level: int = 0, **kwargs):
        super().__init__(spatial_dims=2, config=config, checkpointing_level=checkpointing_level, **kwargs)

In [None]:
test = MultiResCNNBlock2D(
    in_channels=4,
    out_channels=8,
    filter_ratios=(3, 2, 1),
    activation="gelu",
    normalization="batchnorm2d",
)
display(test)

sample_input = torch.randn(2, 4, 16, 16)
test_output = test(sample_input)
# "same" padding keeps the spatial size; the concatenated conv outputs add up to out_channels=8
assert test_output.shape == (2, 8, 16, 16), test_output.shape
test_output.shape


[1;35mMultiResCNNBlock2D[0m[1m([0m
  [1m([0mconvs[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35m_CNNBlock[0m[1m([0m
      [1m([0mconv[1m)[0m: [1;35mConv2d[0m[1m([0m[1;36m4[0m, [1;36m4[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[35msame[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
      [1m([0mnorm[1m)[0m: [1;35mBatchNorm2d[0m[1m([0m[1;36m4[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mact[1m)[0m: [1;35mGELU[0m[1m([0m[33mapproximate[0m=[32m'none'[0m[1m)[0m
      [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
    [1m)[0m
    [1m([0m[1;36m1[0m[1m)[0m: [1;35m_CNNBlock[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m

# Tensor splitting inference

In [None]:
# | export


class TensorSplittingConv(nn.Module):
    """Convolution layer that operates on splits of a tensor on desired device and concatenates the results to give a
    lossless output. This is useful for large tensors that cannot fit in memory."""

    def __init__(self, conv: nn.Module, num_splits: int | tuple[int, ...]):
        super().__init__()

        if isinstance(conv, nn.Conv2d):
            self.spatial_dims = 2
        elif isinstance(conv, nn.Conv3d):
            self.spatial_dims = 3
        else:
            raise ValueError("Unsupported convolution type. Only Conv2d and Conv3d are supported.")

        assert conv.stride == (1,) * self.spatial_dims, "Stride must be 1 for tensor splitting convolution."
        assert conv.padding == "same", "Padding must be 'same' for tensor splitting convolution."

        if isinstance(num_splits, int):
            num_splits = (num_splits,) * self.spatial_dims
        assert len(num_splits) == self.spatial_dims, "num_splits must be a tuple of length equal to spatial_dims"

        self.conv = conv
        self.num_splits = num_splits

    @cache
    def get_receptive_field(self) -> tuple[int, ...]:
        """Calculate the receptive field of the convolution layer."""
        kernel_size = torch.tensor(self.conv.kernel_size)
        dilation = torch.tensor(self.conv.dilation)
        receptive_field = dilation * (kernel_size - 1) + 1
        return tuple(receptive_field.tolist())

    @cache
    def get_edge_context(self):
        """Calculate the context size required to eliminate edge effects when merging the conv outputs into one."""
        receptive_field = self.get_receptive_field()
        context = torch.tensor(receptive_field) // 2
        return tuple(context.tolist())

    def get_split_size(self, input_shape: tuple[int, ...] | torch.Tensor) -> tuple[int, ...]:
        """Calculate the split size for each dimension based on the input shape and number of splits.

        Args:
            input_shape: Shape of the input tensor. If a tensor is provided, its shape will be used.

        Returns:
            Tuple of split sizes for each dimension.
        """
        if isinstance(input_shape, torch.Tensor):
            input_shape = input_shape.shape
        input_shape = input_shape[-self.spatial_dims :]

        context = self.get_edge_context()

        split_size = []
        for i in range(self.spatial_dims):
            dim = input_shape[i]
            num_splits = self.num_splits[i]
            if dim % num_splits != 0:
                raise ValueError(f"Input dimension {dim} is not divisible by number of splits {num_splits}.")
            split_size.append(dim // num_splits + 2 * context[i])
        split_size = tuple(split_size)
        return split_size

    def get_split_stride(self, input_shape: tuple[int, ...] | torch.Tensor) -> tuple[int, ...]:
        """Calculate the split stride for each dimension based on the input shape and context size."""
        context = self.get_edge_context()
        split_size = self.get_split_size(input_shape)
        split_stride = [split_size[i] - 2 * context[i] for i in range(self.spatial_dims)]
        assert all(
            split_stride[i] > 0 for i in range(self.spatial_dims)
        ), "Split stride must be greater than 0 for all dimensions."
        return split_stride

    def pad_input(self, x: torch.Tensor) -> torch.Tensor:
        """Pad the input with the context size for consistent merging."""
        context = self.get_edge_context()
        padding = [0, 0] * (x.ndim - self.spatial_dims)
        for i in range(self.spatial_dims):
            padding.extend([context[i], context[i]])
        x = nn.functional.pad(x, list(reversed(padding)))
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the convolution layer with tensor splitting parallelism. Main convolution occurs on it's
             device, but the output is built on the input tensor's device.

        Args:
            x: Input tensor of shape (batch_size, in_channels, [z], y, x).

        Returns:
            Output tensor of shape (batch_size, out_channels, [z], y, x).
        """
        input_device = x.device
        B, DIMS = x.shape[0], x.shape[2:]  # (batch_size, in_channels, [z], y, x)

        # Calculate split size
        split_size = self.get_split_size(x)

        # Identify the stride required to split the input tensor such that overlapping regions can be counted only once
        split_stride = self.get_split_stride(x)

        # Pad the input
        x = self.pad_input(x)

        # Split the input tensor
        splitter = Splitter(
            split_dims=self.spatial_dims,
            split_size=split_size,
            stride=split_stride,
        )
        positions = splitter.get_positions(x)
        x = splitter(x)
        # (num_splits, batch_size, in_channels, [z1], y1, x1)

        # Run the convolution on each split
        outputs = []
        for x_split in x:
            x_split = x_split.to(self.conv.weight.device)
            x_split = self.conv(x_split)
            x_split = x_split.to(input_device)
            outputs.append(x_split)
        outputs = torch.stack(outputs, dim=0)
        # (num_splits, batch_size, out_channels, [z1], y1, x1)

        # Merge the outputs
        context = self.get_edge_context()
        merged = torch.zeros((B, outputs.shape[2], *DIMS), device=input_device)
        for output, position in zip(outputs, positions):
            output_slices = [slice(None), slice(None)]
            for i in range(self.spatial_dims):
                output_slices.append(slice(context[i], -context[i] if context[i] != 0 else None))
            output = output[tuple(output_slices)]

            merged_slices = [slice(None), slice(None)]
            for i in range(self.spatial_dims):
                merged_slices.append(slice(position[i], position[i] + split_stride[i]))
            merged[tuple(merged_slices)] = output

        return merged

In [None]:
# Lossless-splitting check. Uses the GPU when available so the cross-device path is exercised, otherwise the CPU.
torch.manual_seed(0)
device = "cuda:0" if torch.cuda.is_available() else "cpu"

conv = nn.Conv3d(in_channels=2, out_channels=4, kernel_size=3, padding="same").eval()
sample_input = torch.randn(2, 2, 16, 18, 16)

test = TensorSplittingConv(conv, num_splits=(2, 3, 4))
test.to(device)  # also moves the wrapped `conv`, which is the same module object
display(test)

display(test.get_receptive_field())
display(test.get_edge_context())
display(test.get_split_size(sample_input))
display(test.get_split_stride(sample_input))

# The reference runs on the same device as the split convs, with cuDNN TF32 disabled: comparing a CPU reference
# against TF32 GPU results with default allclose tolerances reported False even though the splitting is lossless.
with torch.no_grad(), torch.backends.cudnn.flags(enabled=True, allow_tf32=False):
    test_output1 = conv(sample_input.to(device)).cpu()
    test_output2 = test(sample_input)

is_close = torch.allclose(test_output1, test_output2, atol=1e-5)
assert is_close, f"max abs diff: {(test_output1 - test_output2).abs().max().item()}"
test_output2.shape, is_close


[1;35mTensorSplittingConv[0m[1m([0m
  [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m2[0m, [1;36m4[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[35msame[0m[1m)[0m
[1m)[0m

[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m

[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m

[1m([0m[1;36m10[0m, [1;36m8[0m, [1;36m6[0m[1m)[0m

[1m[[0m[1;36m8[0m, [1;36m6[0m, [1;36m4[0m[1m][0m

[1m([0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m16[0m, [1;36m18[0m, [1;36m16[0m[1m][0m[1m)[0m, [3;91mFalse[0m[1m)[0m

# nbdev

In [None]:
!nbdev_export