In [None]:
# | default_exp nets/swin_3d

# Imports

In [None]:
# | export

from functools import wraps

import numpy as np
import torch
from einops import rearrange
from huggingface_hub import PyTorchModelHubMixin
from loguru import logger
from torch import nn

from vision_architectures.blocks.transformer import Attention3DWithMLP, Attention3DWithMLPConfig
from vision_architectures.docstrings import populate_docstring
from vision_architectures.layers.embeddings import (
    AbsolutePositionEmbeddings3D,
    PatchEmbeddings3D,
    RelativePositionEmbeddings3D,
    RelativePositionEmbeddings3DConfig,
)
from vision_architectures.utils.activation_checkpointing import ActivationCheckpointing
from vision_architectures.utils.custom_base_model import CustomBaseModel, Field, computed_field, model_validator
from vision_architectures.utils.rearrange import rearrange_channels

# Config

In [None]:
# | export


class Swin3DPatchMergingConfig(CustomBaseModel):
    """Configuration of a patch merging layer: the dimensions it maps between and the window of neighbouring patches
    that is merged into a single patch."""

    in_dim: int = Field(..., description="Input dimension before merging")
    out_dim: int = Field(..., description="Output dimension after merging")
    merge_window_size: tuple[int, int, int] = Field(..., description="Size of the window for merging patches")

    @computed_field(description="Factor by which the dimension is increased after merging")
    @property
    def out_dim_ratio(self) -> float:
        return self.out_dim / self.in_dim

    @model_validator(mode="before")
    @classmethod
    def validate_before(cls, data):
        """Expand an integer ``merge_window_size`` into a 3-tuple before the fields are validated."""
        super().validate_before(data)
        window = data.get("merge_window_size")
        if isinstance(window, int):
            # A single integer is shorthand for the same window size along all three axes
            data["merge_window_size"] = (window,) * 3
        return data


class Swin3DPatchSplittingConfig(CustomBaseModel):
    """Configuration of a patch splitting layer: the dimensions it maps between and the window of patches that every
    input patch is split into."""

    in_dim: int = Field(..., description="Input dimension before splitting")
    out_dim: int = Field(..., description="Output dimension after splitting")
    final_window_size: tuple[int, int, int] = Field(..., description="Size of the window to split patches into")

    @computed_field(description="Factor by which the dimension is decreased after splitting")
    @property
    def in_dim_ratio(self) -> float:
        return self.in_dim / self.out_dim

    @model_validator(mode="before")
    @classmethod
    def validate_before(cls, data):
        """Expand an integer ``final_window_size`` into a 3-tuple before the fields are validated."""
        super().validate_before(data)
        window = data.get("final_window_size")
        if isinstance(window, int):
            # A single integer is shorthand for the same window size along all three axes
            data["final_window_size"] = (window,) * 3
        return data


class Swin3DBlockConfig(Attention3DWithMLPConfig):
    """Configuration of a Swin 3D block: the attention window, optional relative position bias, and optional patch
    merging / patch splitting together with the dimensions entering and leaving the block."""

    window_size: tuple[int, int, int] = Field(..., description="Size of the window to apply attention over")

    use_relative_position_bias: bool = Field(False, description="Whether to use relative position bias")
    patch_merging: Swin3DPatchMergingConfig | None = Field(
        None, description="Patch merging config if desired. Patch merging is applied before attention."
    )
    patch_splitting: Swin3DPatchSplittingConfig | None = Field(
        None, description="Patch splitting config if desired. Patch splitting is applied after attention."
    )

    in_dim: int | None = Field(None, description="Input dimension of the stage. Useful if ``patch_merging`` is used.")
    dim: int = Field(..., description="Dim at which attention is performed")
    out_dim: int | None = Field(
        None, description="Output dimension of the stage. Useful if ``patch_splitting`` is used."
    )

    @property
    def spatial_compression_ratio(self):
        """Per-axis factor by which the patch size changes across this block: multiplied by the merge window size
        and divided by the final (split) window size, starting from 1.0 on every axis."""
        ratio = (1.0, 1.0, 1.0)
        merging, splitting = self.patch_merging, self.patch_splitting
        if merging is not None:
            ratio = tuple(current * window for current, window in zip(ratio, merging.merge_window_size))
        if splitting is not None:
            ratio = tuple(current / window for current, window in zip(ratio, splitting.final_window_size))
        return ratio

    def get_out_patch_size(self, in_patch_size: tuple[int, int, int]):
        """Scale ``in_patch_size`` by the spatial compression ratio, truncating every axis to an int."""
        ratio = self.spatial_compression_ratio
        return tuple(int(in_patch_size[axis] * ratio[axis]) for axis in range(3))

    def get_in_patch_size(self, out_patch_size: tuple[int, int, int]):
        """Divide ``out_patch_size`` by the spatial compression ratio, truncating every axis to an int."""
        ratio = self.spatial_compression_ratio
        return tuple(int(out_patch_size[axis] / ratio[axis]) for axis in range(3))

    def get_in_dim(self) -> int:
        """Return ``in_dim``, falling back to ``dim`` when ``in_dim`` is not set."""
        return self.dim if self.in_dim is None else self.in_dim

    def get_out_dim(self) -> int:
        """Return ``out_dim``, falling back to ``dim`` when ``out_dim`` is not set."""
        return self.dim if self.out_dim is None else self.out_dim

    @property
    def out_dim_ratio(self) -> float:
        """Ratio of the block's output dimension to its input dimension."""
        return self.get_out_dim() / self.get_in_dim()

    def populate(self):
        """Populate the in_dim and out_dim of patch_splitting and patch_merging based on the stage's in_dim, dim,
        out_dim.

        In practice the sub-configs take precedence: whenever a stage-level value disagrees with the corresponding
        patch_merging / patch_splitting value, the stage-level value is overwritten, and a warning is logged if the
        overwritten value had been set.
        """
        merging = self.patch_merging
        if merging is not None:
            # Merging happens before attention: it maps in_dim -> dim
            if self.in_dim != merging.in_dim:
                if self.in_dim is not None:
                    logger.warning(
                        f"Overwriting in_dim ({self.in_dim}) for this stage as it does not match patch_merging config "
                        f"in_dim ({self.patch_merging.in_dim})."
                    )
                self.in_dim = merging.in_dim

            if self.dim != merging.out_dim:
                if self.dim is not None:
                    logger.warning(
                        f"Overwriting dim ({self.dim}) for this stage as it does not match patch_merging config out_dim "
                        f"({self.patch_merging.out_dim})."
                    )
                self.dim = merging.out_dim

        splitting = self.patch_splitting
        if splitting is not None:
            # Splitting happens after attention: it maps dim -> out_dim
            if self.dim != splitting.in_dim:
                if self.dim is not None:
                    logger.warning(
                        f"Overwriting dim ({self.dim}) for this stage as it does not match patch_splitting config "
                        f"in_dim ({self.patch_splitting.in_dim})."
                    )
                self.dim = splitting.in_dim

            if self.out_dim != splitting.out_dim:
                if self.out_dim is not None:
                    logger.warning(
                        f"Overwriting out_dim ({self.out_dim}) for this stage as it does not match patch_splitting config "
                        f"out_dim ({self.patch_splitting.out_dim})."
                    )
                self.out_dim = splitting.out_dim

    @model_validator(mode="after")
    def validate(self):
        """Run the parent validation and then reconcile the dimensions via ``populate``."""
        super().validate()
        self.populate()
        return self


class Swin3DStageConfig(Swin3DBlockConfig):
    """Configuration of a Swin 3D stage: every setting of a Swin 3D block plus the number of blocks in the stage."""

    depth: int = Field(..., description="Number of transformer blocks in this stage")


class Swin3DEncoderDecoderConfig(CustomBaseModel):
    """Configuration of a sequence of Swin 3D stages whose dimensions must chain from one stage to the next."""

    stages: list[Swin3DStageConfig]

    def populate(self):
        """Populate the in_dim, dim, out_dim of each stage."""
        for stage in self.stages:
            stage.populate()

    @model_validator(mode="after")
    def validate(self):
        """Populate every stage and check that the stack of stages is consistent.

        Asserts that at least one stage exists, that each stage's ``dim`` is divisible by its ``num_heads``, and
        that the output dimension of every stage equals the input dimension of the stage that follows it.
        """
        super().validate()
        self.populate()

        # An encoder / decoder without stages is meaningless
        assert len(self.stages) > 0, "Must have at least one stage"

        # Every attention head must receive an equal share of the channels
        for stage in self.stages:
            assert (
                stage.dim % stage.num_heads == 0
            ), f"stage.dim {stage.dim} is not divisible by stage.num_heads {stage.num_heads}"

        # Walk over consecutive pairs of stages and compare the dimensions at their interface
        for preceding, succeeding in zip(self.stages, self.stages[1:]):
            prev_out_dim = preceding.get_out_dim()
            succ_in_dim = succeeding.get_in_dim()
            assert prev_out_dim == succ_in_dim, (
                f"Dimensionality mismatch between stages. Preceding stage has out_dim "
                f"{prev_out_dim} and succeeding stage has in_dim {succ_in_dim}."
            )

        return self

    def get_out_dim_ratios(self):
        """Return the ``out_dim_ratio`` of every stage, in stage order."""
        ratios = []
        for stage in self.stages:
            ratios.append(stage.out_dim_ratio)
        return ratios


class Swin3DEncoderWithPatchEmbeddingsConfig(Swin3DEncoderDecoderConfig):
    """Configuration of a Swin 3D encoder together with its patch embedding and absolute position embedding
    settings."""

    in_channels: int = Field(..., description="Number of input channels in the input image/video")
    patch_size: tuple[int, int, int] = Field(
        ..., description="Size of the patches to be extracted from the input image/video"
    )
    image_size: tuple[int, int, int] | None = Field(
        None, description="Size of the input image/video. Required if absolute position embeddings are learnable."
    )

    use_absolute_position_embeddings: bool = Field(True, description="Whether to use absolute position embeddings.")
    learnable_absolute_position_embeddings: bool = Field(
        False, description="Whether to use learnable absolute position embeddings."
    )

    @model_validator(mode="after")
    def validate(self):
        """Run the parent validation and require ``image_size`` when the absolute position embeddings are
        learnable."""
        super().validate()
        # image_size must be provided whenever learnable absolute position embeddings are requested
        assert (
            not self.learnable_absolute_position_embeddings or self.image_size is not None
        ), "Please provide image_size if absolute position embeddings are learnable"
        return self

In [None]:
# Smoke test: validate a three-stage encoder config. The last stage deliberately declares a ``dim`` that disagrees
# with its patch_merging ``out_dim`` to exercise the overwrite-with-warning path of ``Swin3DBlockConfig.populate``.
test_config = Swin3DEncoderWithPatchEmbeddingsConfig.model_validate(
    {
        "patch_size": (1, 8, 8),
        "in_channels": 1,
        "use_absolute_position_embeddings": True,
        "learnable_absolute_position_embeddings": False,
        "image_size": (32, 512, 512),
        "stages": [
            {
                "patch_merging": None,
                "depth": 1,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": False,
                "dim": 36,
            },
            {
                "patch_merging": {
                    "in_dim": 36,
                    "out_dim": 108,
                    "merge_window_size": (2, 2, 2),
                },
                "dim": 108,
                "depth": 3,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
            {
                "patch_merging": {
                    "in_dim": 108,
                    "out_dim": 216,
                    "merge_window_size": (2, 2, 2),
                },
                "dim": 1000,  # Purposefully set incorrectly
                "depth": 1,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
        ],
    }
)

test_config




[1;35mSwin3DEncoderWithPatchEmbeddingsConfig[0m[1m([0m
    [33mstages[0m=[1m[[0m
        [1;35mSwin3DStageConfig[0m[1m([0m
            [33mdim[0m=[1;36m36[0m,
            [33mnum_heads[0m=[1;36m4[0m,
            [33mratio_q_to_kv_heads[0m=[1;36m1[0m,
            [33mlogit_scale_learnable[0m=[3;91mFalse[0m,
            [33mattn_drop_prob[0m=[1;36m0[0m[1;36m.0[0m,
            [33mproj_drop_prob[0m=[1;36m0[0m[1;36m.0[0m,
            [33mmax_attention_batch_size[0m=[1;36m-1[0m,
            [33mmlp_ratio[0m=[1;36m4[0m,
            [33mactivation[0m=[32m'gelu'[0m,
            [33mmlp_drop_prob[0m=[1;36m0[0m[1;36m.0[0m,
            [33mnorm_location[0m=[32m'post'[0m,
            [33mlayer_norm_eps[0m=[1;36m1e[0m[1;36m-06[0m,
            [33mwindow_size[0m=[1m([0m[1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m)[0m,
            [33muse_relative_position_bias[0m=[3;91mFalse[0m,
            [33mpatch_merging[0m=[3;35mNo

# Architecture

### Basic Layers

In [None]:
# | export


@populate_docstring
class Swin3DLayer(nn.Module):
    """Swin 3D Layer applying windowed attention with optional relative position embeddings.
    {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(
        self,
        config: RelativePositionEmbeddings3DConfig | Attention3DWithMLPConfig = {},
        checkpointing_level: int = 0,
        **kwargs
    ):
        """Initializes the Swin3DLayer.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        # Work on a plain dict so that the same set of arguments can be validated against two different config
        # classes (relative position embeddings and attention-with-MLP)
        if isinstance(config, CustomBaseModel):
            config = config.model_dump()
        self._all_kwargs = config | kwargs
        self._window_size = self._all_kwargs.get("window_size")
        self._use_relative_position_bias = self._all_kwargs.get("use_relative_position_bias")

        # The relative position embeddings span exactly one attention window
        self.embeddings_config = RelativePositionEmbeddings3DConfig.model_validate(
            self._all_kwargs | {"grid_size": self._window_size}
        )
        self.transformer_config = Attention3DWithMLPConfig.model_validate(self._all_kwargs)

        relative_position_bias = None
        if self._use_relative_position_bias:
            relative_position_bias = RelativePositionEmbeddings3D(self.embeddings_config)

        self.transformer = Attention3DWithMLP(
            self.transformer_config,
            relative_position_bias=relative_position_bias,
            checkpointing_level=checkpointing_level,
        )

        self.checkpointing_level3 = ActivationCheckpointing(3, checkpointing_level)

    @staticmethod
    def _get_rearrange_patterns() -> tuple[str, str]:
        """Return the einops patterns that partition a volume into windows and that undo the partitioning.

        Note that the patterns will be applied on tensors that are in channels_last format"""
        forward_pattern = (
            "b (num_windows_z window_size_z) (num_windows_y window_size_y) (num_windows_x window_size_x) dim -> "
            "(b num_windows_z num_windows_y num_windows_x) window_size_z window_size_y window_size_x dim "
        )
        reverse_pattern = (
            "(b num_windows_z num_windows_y num_windows_x) window_size_z window_size_y window_size_x dim -> "
            "b (num_windows_z window_size_z) (num_windows_y window_size_y) (num_windows_x window_size_x) dim"
        )
        return forward_pattern, reverse_pattern

    # The decorator below fills in the placeholders of the docstring, exactly like it does for the ``_forward``
    # methods of the other modules in this file. Since ``forward`` copies this docstring through ``wraps``, it would
    # otherwise expose the raw, unformatted placeholders.
    @populate_docstring
    def _forward(self, hidden_states: torch.Tensor, channels_first: bool = True) -> torch.Tensor:
        """Window the input features and apply self attention on each window.

        Args:
            hidden_states: {INPUT_3D_DOC}
            channels_first: {CHANNELS_FIRST_DOC}

        Returns:
            {OUTPUT_3D_DOC}
        """
        # hidden_states: (b, [dim], num_patches_z, num_patches_y, num_patches_x, [dim])

        hidden_states = rearrange_channels(hidden_states, channels_first, False)
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        _, num_patches_z, num_patches_y, num_patches_x, _ = hidden_states.shape

        forward_pattern, reverse_pattern = self._get_rearrange_patterns()

        # Perform windowing: every window becomes its own entry along the batch axis so that attention is computed
        # independently per window
        window_size_z, window_size_y, window_size_x = self._window_size
        num_windows_z, num_windows_y, num_windows_x = (
            num_patches_z // window_size_z,
            num_patches_y // window_size_y,
            num_patches_x // window_size_x,
        )
        hidden_states = rearrange(
            hidden_states,
            forward_pattern,
            num_windows_z=num_windows_z,
            num_windows_y=num_windows_y,
            num_windows_x=num_windows_x,
            window_size_z=window_size_z,
            window_size_y=window_size_y,
            window_size_x=window_size_x,
        ).contiguous()

        # Self attention: the windowed features are passed as all three inputs of the attention block
        hidden_states = self.transformer(hidden_states, hidden_states, hidden_states, channels_first=False)

        # Undo windowing
        output = rearrange(
            hidden_states,
            reverse_pattern,
            num_windows_z=num_windows_z,
            num_windows_y=num_windows_y,
            num_windows_x=num_windows_x,
            window_size_z=window_size_z,
            window_size_y=window_size_y,
            window_size_x=window_size_x,
        ).contiguous()

        output = rearrange_channels(output, False, channels_first)
        # (b, [dim], num_patches_z, num_patches_y, num_patches_x, [dim])

        return output

    @wraps(_forward)
    def forward(self, *args, **kwargs) -> torch.Tensor:
        return self.checkpointing_level3(self._forward, *args, **kwargs)

In [None]:
# Smoke test: build a single layer with relative position bias and run it on a random channels-last input whose
# spatial size equals exactly one attention window. The displayed output shape equals the input shape.
test = Swin3DLayer(
    dim=64,
    num_heads=4,
    mlp_ratio=4,
    layer_norm_eps=1e-6,
    window_size=(4, 4, 4),
    use_relative_position_bias=True,
)
display(test)
display(test(torch.randn(2, 4, 4, 4, 64), channels_first=False).shape)


[1;35mSwin3DLayer[0m[1m([0m
  [1m([0mtransformer[1m)[0m: [1;35mAttention3DWithMLP[0m[1m([0m
    [1m([0mattn[1m)[0m: [1;35mAttention3D[0m[1m([0m
      [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
      [1m([0mrelative

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m

### Stage Layers

In [None]:
# | export


@populate_docstring
class Swin3DBlock(nn.Module):
    """Swin 3D Block consisting of two Swin3DLayers: one with regular windows and one with shifted windows.
    {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: Swin3DBlockConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initializes the Swin3DBlock.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = Swin3DBlockConfig.model_validate(config | kwargs)

        def build_layer() -> Swin3DLayer:
            # Both layers are configured identically; each receives its own dump of the config
            return Swin3DLayer(self.config.model_dump(), checkpointing_level=checkpointing_level)

        self.w_layer = build_layer()
        self.sw_layer = build_layer()

    @populate_docstring
    def forward(
        self, hidden_states: torch.Tensor, channels_first: bool = True, return_intermediates: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run the regular-window layer followed by the shifted-window layer on the input features.

        Args:
            hidden_states: {INPUT_3D_DOC}
            channels_first: {CHANNELS_FIRST_DOC}
            return_intermediates: {RETURN_INTERMEDIATES_DOC}

        Returns:
            {OUTPUT_3D_DOC}. If return_intermediates is True, also returns a list of intermediate layer outputs. Note
            that the intermediate layer outputs returned will always be in ``channels_last`` format.
        """
        # hidden_states: (b, [dim], num_patches_z, num_patches_y, num_patches_x, [dim])

        features = rearrange_channels(hidden_states, channels_first, False)
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        # Regular windows
        features = self.w_layer(features, channels_first=False)
        intermediates = [features]

        # Shifting by half a window along every spatial axis is implemented as a cyclic roll of the volume, which is
        # rolled back by the same amount once the second layer has been applied
        spatial_dims = (1, 2, 3)
        forward_shifts = tuple(size // 2 for size in self.config.window_size)
        backward_shifts = tuple(-shift for shift in forward_shifts)

        # Shifted windows
        features = torch.roll(features, shifts=forward_shifts, dims=spatial_dims)
        features = self.sw_layer(features, channels_first=False)
        features = torch.roll(features, shifts=backward_shifts, dims=spatial_dims)
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)
        intermediates.append(features)

        features = rearrange_channels(features, False, channels_first)
        # (b, [dim], num_patches_z, num_patches_y, num_patches_x, [dim])

        if not return_intermediates:
            return features
        return features, intermediates

In [None]:
# Smoke test: run a block on a random channels-first input (the default layout) and request the intermediates.
# The final output is displayed in channels-first layout, the two intermediate outputs in channels-last layout.
test_stage_config = Swin3DBlockConfig.model_validate(
    {
        "dim": 64,
        "num_heads": 4,
        "mlp_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": True,
    }
)

test = Swin3DBlock(test_stage_config)
display(test)
o = test(torch.randn(2, 64, 4, 4, 4), return_intermediates=True)
display((o[0].shape, (o[1][0].shape, o[1][1].shape)))


[1;35mSwin3DBlock[0m[1m([0m
  [1m([0mw_layer[1m)[0m: [1;35mSwin3DLayer[0m[1m([0m
    [1m([0mtransformer[1m)[0m: [1;35mAttention3DWithMLP[0m[1m([0m
      [1m([0mattn[1m)[0m: [1;35mAttention3D[0m[1m([0m
        [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36

[1m([0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m][0m[1m)[0m, [1m([0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m[1m)[0m[1m)[0m

In [None]:
# | export


@populate_docstring
class Swin3DPatchMerging(nn.Module):
    """Patch merging layer for Swin3D. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: Swin3DPatchMergingConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize the Swin3DPatchMerging layer.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = Swin3DPatchMergingConfig.model_validate(config | kwargs)

        # All patches of a merge window are concatenated along the channel axis before being projected.
        # np.prod returns a numpy integer, so it is cast to a builtin int: otherwise the numpy scalar ends up inside
        # the submodules (LayerNorm then reports its normalized shape as np.int64 instead of a plain int).
        in_dim = self.config.in_dim * int(np.prod(self.config.merge_window_size))
        self.layer_norm = nn.LayerNorm(in_dim)
        self.proj = nn.Linear(in_dim, self.config.out_dim)

        self.checkpointing_level1 = ActivationCheckpointing(1, checkpointing_level)

    @populate_docstring
    def _forward(self, hidden_states: torch.Tensor, channels_first: bool = True) -> torch.Tensor:
        """Merge multiple patches into a single patch.

        Args:
            hidden_states: {INPUT_3D_DOC}
            channels_first: {CHANNELS_FIRST_DOC}

        Returns:
            {OUTPUT_3D_DOC}
        """
        # hidden_states: (b, [dim], num_patches_z, num_patches_y, num_patches_x, [dim])

        hidden_states = rearrange_channels(hidden_states, channels_first, False)
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        window_size_z, window_size_y, window_size_x = self.config.merge_window_size

        # Fold every merge window into the channel axis
        hidden_states = rearrange(
            hidden_states,
            "b (new_num_patches_z window_size_z) (new_num_patches_y window_size_y) (new_num_patches_x window_size_x) dim -> "
            "b new_num_patches_z new_num_patches_y new_num_patches_x (window_size_z window_size_y window_size_x dim)",
            window_size_z=window_size_z,
            window_size_y=window_size_y,
            window_size_x=window_size_x,
        ).contiguous()

        # Normalize the concatenated features and project them to the output dimension
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.proj(hidden_states)

        hidden_states = rearrange_channels(hidden_states, False, channels_first)
        # (b, [dim], new_num_patches_z, new_num_patches_y, new_num_patches_x, [dim])

        return hidden_states

    @wraps(_forward)
    def forward(self, *args, **kwargs):
        return self.checkpointing_level1(self._forward, *args, **kwargs)

In [None]:
# Smoke test: merge 2x2x2 windows of a random channels-last input. The displayed output shape has every spatial axis
# halved (4 -> 2) and the channel axis changed from 64 to 108.
test_stage_config = Swin3DPatchMergingConfig.model_validate(
    {
        "merge_window_size": (2, 2, 2),
        "in_dim": 64,
        "out_dim": 108,
    }
)

test = Swin3DPatchMerging(test_stage_config)
display(test)
display(test(torch.randn(2, 4, 4, 4, 64), channels_first=False).shape)


[1;35mSwin3DPatchMerging[0m[1m([0m
  [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;35mnp.int64[0m[1m([0m[1;36m512[0m[1m)[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m512[0m, [33mout_features[0m=[1;36m108[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m2[0m, [1;36m2[0m, [1;36m2[0m, [1;36m108[0m[1m][0m[1m)[0m

In [None]:
# | export


@populate_docstring
class Swin3DPatchSplitting(nn.Module):
    """Patch splitting layer for Swin3D. {CLASS_DESCRIPTION_3D_DOC}

    This is a self-implemented class and is not part of the paper."""

    @populate_docstring
    def __init__(self, config: Swin3DPatchSplittingConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize the Swin3DPatchSplitting layer.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        # ``config`` defaults to an empty mapping, like in the other modules of this file, so that the layer can be
        # constructed from keyword arguments alone
        self.config = Swin3DPatchSplittingConfig.model_validate(config | kwargs)

        # Every input patch is projected to the features of all patches of its split window at once.
        # np.prod returns a numpy integer, so it is cast to a builtin int to keep numpy scalars out of the Linear
        # layer's attributes.
        split_dim = self.config.out_dim * int(np.prod(self.config.final_window_size))
        self.layer_norm = nn.LayerNorm(self.config.in_dim)
        self.proj = nn.Linear(self.config.in_dim, split_dim)

        self.checkpointing_level1 = ActivationCheckpointing(1, checkpointing_level)

    @populate_docstring
    def _forward(self, hidden_states: torch.Tensor, channels_first: bool = True) -> torch.Tensor:
        """Split patches into multiple patches.

        Args:
            hidden_states: {INPUT_3D_DOC}
            channels_first: {CHANNELS_FIRST_DOC}

        Returns:
            {OUTPUT_3D_DOC}
        """
        # hidden_states: (b, [dim], num_patches_z, num_patches_y, num_patches_x, [dim])

        hidden_states = rearrange_channels(hidden_states, channels_first, False)
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        # Normalize and project before the channel axis is unfolded into the split window
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.proj(hidden_states)

        window_size_z, window_size_y, window_size_x = self.config.final_window_size

        # Unfold the channel axis into a window of patches along every spatial axis
        hidden_states = rearrange(
            hidden_states,
            "b num_patches_z num_patches_y num_patches_x (window_size_z window_size_y window_size_x dim) -> "
            "b (num_patches_z window_size_z) (num_patches_y window_size_y) (num_patches_x window_size_x) dim",
            window_size_z=window_size_z,
            window_size_y=window_size_y,
            window_size_x=window_size_x,
        ).contiguous()

        hidden_states = rearrange_channels(hidden_states, False, channels_first)
        # (b, [dim], num_patches_z, num_patches_y, num_patches_x, [dim])

        return hidden_states

    @wraps(_forward)
    def forward(self, *args, **kwargs):
        return self.checkpointing_level1(self._forward, *args, **kwargs)

In [None]:
# Smoke test: split every patch of a random channels-first input into a 2x2x2 window. The displayed output shape has
# every spatial axis doubled (4 -> 8) and the channel axis changed from 108 to 64.
test_stage_config = Swin3DPatchSplittingConfig.model_validate(
    {
        "final_window_size": (2, 2, 2),
        "in_dim": 108,
        "out_dim": 64,
    },
)

test = Swin3DPatchSplitting(test_stage_config)
display(test)
display(test(torch.randn(2, 108, 4, 4, 4)).shape)


[1;35mSwin3DPatchSplitting[0m[1m([0m
  [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m108[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m108[0m, [33mout_features[0m=[1;36m512[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m[1m][0m[1m)[0m

In [None]:
# | export


@populate_docstring
class Swin3DStage(nn.Module):
    """Swin3D stage for Swin3D. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: Swin3DStageConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize the Swin3DStage.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = Swin3DStageConfig.model_validate(config | kwargs)

        # The requested checkpointing level is forwarded to every sub-module; otherwise only the stage-level
        # wrapper created at the end of this constructor could ever be enabled.

        # Optional patch merging, applied before the attention blocks (used by the encoder).
        self.patch_merging = None
        if self.config.patch_merging is not None:
            self.patch_merging = Swin3DPatchMerging(
                self.config.patch_merging, checkpointing_level=checkpointing_level
            )

        self.blocks = nn.ModuleList(
            [Swin3DBlock(self.config, checkpointing_level=checkpointing_level) for _ in range(self.config.depth)],
        )

        # Optional patch splitting, applied after the attention blocks.
        self.patch_splitting = None
        if self.config.patch_splitting is not None:
            # This has been implemented to create a Swin-based decoder
            self.patch_splitting = Swin3DPatchSplitting(
                self.config.patch_splitting, checkpointing_level=checkpointing_level
            )

        self.checkpointing_level4 = ActivationCheckpointing(4, checkpointing_level)

    @populate_docstring
    def _forward(
        self, hidden_states: torch.Tensor, channels_first: bool = True, return_intermediates: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Merge patches if applicable (used by the encoder), perform a series of window and shifted window attention,
        and then split patches if applicable (used by the decoder).

        Args:
            hidden_states: {INPUT_3D_DOC}
            channels_first: {CHANNELS_FIRST_DOC}
            return_intermediates: {RETURN_INTERMEDIATES_DOC}

        Returns:
            {OUTPUT_3D_DOC}. If return_intermediates is True, a tuple is returned instead: the output followed by a
            list of intermediate layer outputs collected from the blocks. Note that the intermediate layer outputs
            returned will always be in ``channels_last`` format.
        """
        # hidden_states: (b, [dim], num_patches_z, num_patches_y, num_patches_x, [dim])

        # All sub-modules are called in channels_last; the requested layout is restored once at the end.
        hidden_states = rearrange_channels(hidden_states, channels_first, False)
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        # Explicit `is not None` checks: the attribute is either None or a module, and relying on module
        # truthiness would obscure that.
        if self.patch_merging is not None:
            hidden_states = self.patch_merging(hidden_states, channels_first=False)
            # (b, new_num_patches_z, new_num_patches_y, new_num_patches_x, new_dim)

        layer_outputs = []
        for layer_module in self.blocks:
            hidden_states, _layer_outputs = layer_module(hidden_states, channels_first=False, return_intermediates=True)
            # (b, new_num_patches_z, new_num_patches_y, new_num_patches_x, new_dim)
            layer_outputs.extend(_layer_outputs)

        if self.patch_splitting is not None:
            hidden_states = self.patch_splitting(hidden_states, channels_first=False)
            # (b, new_num_patches_z, new_num_patches_y, new_num_patches_x, new_dim)

        hidden_states = rearrange_channels(hidden_states, False, channels_first)
        # (b, [dim], num_patches_z, num_patches_y, num_patches_x, [dim])

        if return_intermediates:
            return hidden_states, layer_outputs
        return hidden_states

    @wraps(_forward)
    def forward(self, *args, **kwargs):
        return self.checkpointing_level4(self._forward, *args, **kwargs)

In [None]:
# Smoke test: a merging stage halves each spatial axis (8 -> 4) and maps 48 -> 100 channels.
test_stage_config = Swin3DStageConfig.model_validate(
    {
        "patch_merging": {"in_dim": 48, "merge_window_size": (2, 2, 2), "out_dim": 100},
        "depth": 2,
        "num_heads": 4,
        "mlp_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": True,
        "dim": 100,
    }
)

test = Swin3DStage(test_stage_config)
display(test)
o = test(torch.randn(2, 48, 8, 8, 8), return_intermediates=True)
display((o[0].shape, [x.shape for x in o[1]]))

# The final output follows the (channels_first) input layout; intermediates are always channels_last.
assert o[0].shape == (2, 100, 4, 4, 4), o[0].shape
assert [tuple(x.shape) for x in o[1]] == [(2, 4, 4, 4, 100)] * 4


[1;35mSwin3DStage[0m[1m([0m
  [1m([0mpatch_merging[1m)[0m: [1;35mSwin3DPatchMerging[0m[1m([0m
    [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;35mnp.int64[0m[1m([0m[1;36m384[0m[1m)[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m384[0m, [33mout_features[0m=[1;36m100[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m1[0m[1m)[0m: [1;36m2[0m x [1;35mSwin3DBlock[0m[1m([0m
      [1m([0mw_layer[1m)[0m: [1;35mSwin3DLayer[0m[1m([0m
        [1m([0mtransformer[1m)[0m: [1;35mAttention3DWithMLP[0m[1m([0m
          [1m([0mattn[1m)[0m: [1;35mAttention3D[0


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m100[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m100[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m100[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m100[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m100[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m

### Encoder/Decoder Base

In [None]:
# | export


class Swin3DEncoderDecoderBase(nn.Module, PyTorchModelHubMixin):
    """Shared implementation of the Swin3D encoder and decoder: a list of ``Swin3DStage`` modules applied in order."""

    @populate_docstring
    def __init__(self, config: Swin3DEncoderDecoderConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initializes the Swin3DEncoder/Swin3DDecoder.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = Swin3DEncoderDecoderConfig.model_validate(config | kwargs)

        # One stage per entry of the validated config; each stage receives the same checkpointing level.
        self.stages = nn.ModuleList(
            [Swin3DStage(stage_config, checkpointing_level) for stage_config in self.config.stages]
        )

        self.checkpointing_level5 = ActivationCheckpointing(5, checkpointing_level)

    @populate_docstring
    def _forward(
        self, hidden_states: torch.Tensor, channels_first: bool = True, return_intermediates: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
        """Passes the input features through every stage of the Swin Transformer hierarchy in order.

        Args:
            hidden_states: {INPUT_3D_DOC}
            channels_first: {CHANNELS_FIRST_DOC}
            return_intermediates: {RETURN_INTERMEDIATES_DOC}

        Returns:
            {OUTPUT_3D_DOC}. If return_intermediates is True, a tuple is returned instead: the output, a list with
            the output of every stage, and a flat list of the intermediate layer outputs of all stages. Note that
            the stage outputs and the intermediate layer outputs returned will always be in ``channels_last`` format.
        """
        # hidden_states: (b, [dim], num_patches_z, num_patches_y, num_patches_x, [dim])

        hidden_states = rearrange_channels(hidden_states, channels_first, False)
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        stage_outputs, layer_outputs = [], []
        for stage_module in self.stages:
            # Stages are always asked for their intermediates; they are dropped below if not requested.
            hidden_states, _layer_outputs = stage_module(hidden_states, channels_first=False, return_intermediates=True)
            # (b, new_num_patches_z, new_num_patches_y, new_num_patches_x, dim)

            stage_outputs.append(hidden_states)
            layer_outputs.extend(_layer_outputs)

        # Only the final output is converted back to the caller's layout.
        hidden_states = rearrange_channels(hidden_states, False, channels_first)
        # (b, [dim], num_patches_z, num_patches_y, num_patches_x, [dim])

        if return_intermediates:
            return hidden_states, stage_outputs, layer_outputs
        return hidden_states

    @wraps(_forward)
    def forward(self, *args, **kwargs):
        return self.checkpointing_level5(self._forward, *args, **kwargs)

### Encoder

In [None]:
# | export


@populate_docstring
class Swin3DEncoder(Swin3DEncoderDecoderBase):
    """3D Swin Transformer encoder. Assumes input has already been patchified/tokenized. {CLASS_DESCRIPTION_3D_DOC}"""

    def __init__(self, config: Swin3DEncoderDecoderConfig = {}, checkpointing_level: int = 0, **kwargs):
        super().__init__(config, checkpointing_level, **kwargs)

        # A stage may split patches only if it also merges them (a "mid block"); a stage that only splits
        # belongs in a decoder.
        for stage in self.config.stages:
            splits_patches = stage.patch_splitting is not None
            merges_patches = stage.patch_merging is not None
            assert merges_patches or not splits_patches, "Swin3DEncoder is not for decoding (mid blocks are ok)."

In [None]:
# Smoke test: a two-stage encoder on channels_last input; the second stage merges patches (16 -> 8, 32 -> 96).
test_config = Swin3DEncoderDecoderConfig.model_validate(
    {
        "in_channels": 32,
        "patch_size": (1, 1, 1),
        "stages": [
            {
                "dim": 32,
                "depth": 1,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": False,
            },
            {
                "patch_merging": {"merge_window_size": (2, 2, 2), "in_dim": 32, "out_dim": 96},
                "dim": 96,
                "depth": 3,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
        ],
    }
)

test = Swin3DEncoder(test_config)
display(test)
o = test(torch.randn(2, 16, 16, 16, 32), channels_first=False, return_intermediates=True)
display((o[0].shape, [x.shape for x in o[1]], [x.shape for x in o[2]]))

# Pin the final and per-stage shapes so a regression fails the cell.
assert o[0].shape == (2, 8, 8, 8, 96), o[0].shape
assert [tuple(x.shape) for x in o[1]] == [(2, 16, 16, 16, 32), (2, 8, 8, 8, 96)]


[1;35mSwin3DEncoder[0m[1m([0m
  [1m([0mstages[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwin3DStage[0m[1m([0m
      [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
        [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwin3DBlock[0m[1m([0m
          [1m([0mw_layer[1m)[0m: [1;35mSwin3DLayer[0m[1m([0m
            [1m([0mtransformer[1m)[0m: [1;35mAttention3DWithMLP[0m[1m([0m
              [1m([0mattn[1m)[0m: [1;35mAttention3D[0m[1m([0m
                [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m32[0m, [33mout_features[0m=[1;36m32[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
                [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m32[0m, [33mout_features[0m=[1;36m32[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
                [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m32[0m, [33mout_features[0m=[1;3


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m96[0m[1m][0m[1m)[0m,
    [1m[[0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m32[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m96[0m[1m][0m[1m)[0m[1m][0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m96[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m96[0m[1m][0m[1m)[0m,
        [1;

### Decoder

In [None]:
# | export


@populate_docstring
class Swin3DDecoder(Swin3DEncoderDecoderBase):
    """3D Swin Transformer decoder. Assumes input has already been patchified/tokenized. {CLASS_DESCRIPTION_3D_DOC}"""

    def __init__(self, config: Swin3DEncoderDecoderConfig = {}, checkpointing_level: int = 0, **kwargs):
        super().__init__(config, checkpointing_level, **kwargs)

        # Check the validated `self.config`, not the raw `config` argument: the argument may be a plain dict
        # (including the default `{}`), which has no `.stages` attribute, and it does not include overrides
        # supplied through `**kwargs`.
        for stage_config in self.config.stages:
            # A stage may merge patches only if it also splits them (a "mid block"); a stage that only merges
            # belongs in an encoder.
            if stage_config.patch_merging is not None:
                assert (
                    stage_config.patch_splitting is not None
                ), "Swin3DDecoder is not for encoding (mid blocks are ok)."

In [None]:
# Smoke test: a two-stage decoder. The first stage is a mid block (merge then split); the second only splits.
test_config = Swin3DEncoderDecoderConfig.model_validate(
    {
        "depth": 1,
        "num_heads": 4,
        "mlp_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": False,
        "stages": [
            {
                "patch_merging": {
                    "in_dim": 96,
                    "out_dim": 32,
                    "merge_window_size": (2, 2, 2),
                },
                "dim": 32,
                "depth": 3,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
                "patch_splitting": {
                    "final_window_size": (2, 2, 2),
                    "in_dim": 32,
                    "out_dim": 96,
                },
            },
            {
                "dim": 96,
                "depth": 3,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
                "patch_splitting": {
                    "final_window_size": (2, 2, 2),
                    "in_dim": 96,
                    "out_dim": 32,
                },
            },
        ],
    }
)

test = Swin3DDecoder(test_config)
display(test)
o = test(torch.randn(2, 96, 16, 16, 16), return_intermediates=True)
display((o[0].shape, [x.shape for x in o[1]], [x.shape for x in o[2]]))

# The final output follows the (channels_first) input layout; stage outputs are always channels_last.
assert o[0].shape == (2, 32, 32, 32, 32), o[0].shape
assert [tuple(x.shape) for x in o[1]] == [(2, 16, 16, 16, 96), (2, 32, 32, 32, 32)]


[1;35mSwin3DDecoder[0m[1m([0m
  [1m([0mstages[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwin3DStage[0m[1m([0m
      [1m([0mpatch_merging[1m)[0m: [1;35mSwin3DPatchMerging[0m[1m([0m
        [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;35mnp.int64[0m[1m([0m[1;36m768[0m[1m)[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m768[0m, [33mout_features[0m=[1;36m32[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
      [1m)[0m
      [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
        [1m([0m[1;36m0[0m-[1;36m2[0m[1m)[0m: [1;36m3[0m x [1;35mSwin3DBlock[0m[1m([0m
          [1m([0mw_layer[1m)[0m: [1;3


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m32[0m, [1;36m32[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m,
    [1m[[0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m96[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m32[0m, [1;36m32[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m[1m][0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;

# Models

In [None]:
# | export


@populate_docstring
class Swin3DEncoderWithPatchEmbeddings(nn.Module, PyTorchModelHubMixin):
    """3D Swin transformer with 3D patch embeddings. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: Swin3DEncoderWithPatchEmbeddingsConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Builds the patch embedder, the absolute position embeddings and the Swin3D encoder.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = Swin3DEncoderWithPatchEmbeddingsConfig.model_validate(config | kwargs)

        # Patches are embedded directly into the input width reported by the first stage.
        embedding_dim = self.config.stages[0].get_in_dim()

        self.patchify = PatchEmbeddings3D(
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            dim=embedding_dim,
            checkpointing_level=checkpointing_level,
        )
        # NOTE(review): position embeddings are always created here, and always as non-learnable; confirm
        # whether any config options are meant to control this.
        self.absolute_position_embeddings = AbsolutePositionEmbeddings3D(dim=embedding_dim, learnable=False)
        self.encoder = Swin3DEncoder(self.config, checkpointing_level=checkpointing_level)

    @populate_docstring
    def forward(
        self,
        pixel_values: torch.Tensor,
        spacings: torch.Tensor = None,
        crop_offsets: torch.Tensor = None,
        channels_first: bool = True,
        return_intermediates: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
        """Embed the input as patches, add absolute position embeddings and run the Swin transformer encoder.

        Args:
            pixel_values: {INPUT_3D_DOC}
            spacings: {SPACINGS_DOC}
            crop_offsets: Used if the embeddings required are of a crop of a larger image. If provided, the grid
                coordinates will be offset accordingly.
            channels_first: {CHANNELS_FIRST_DOC}
            return_intermediates: {RETURN_INTERMEDIATES_DOC}

        Returns:
            {OUTPUT_3D_DOC}. If `return_intermediates` is True, the intermediate stage outputs and layer outputs
            are returned as well, after the encoded output.
        """
        # pixel_values: (b, [c], z, y, x, [c])
        # spacings: (b, 3)

        # Every sub-module is called in channels_last; only the final output is converted back.
        volume = rearrange_channels(pixel_values, channels_first, False)
        # (b, z, y, x, c)

        tokens = self.patchify(volume, channels_first=False)
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        tokens = self.absolute_position_embeddings(
            tokens, spacings=spacings, crop_offsets=crop_offsets, channels_first=False
        )
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        # Intermediates are always requested from the encoder and discarded below if the caller did not ask.
        encoded, stage_outputs, layer_outputs = self.encoder(tokens, channels_first=False, return_intermediates=True)
        # encoded: (b, new_num_patches_z, new_num_patches_y, new_num_patches_x, dim)
        # stage_outputs, layer_outputs: list of (b, some_num_patches_z, some_num_patches_y, some_num_patches_x, dim)

        encoded = rearrange_channels(encoded, False, channels_first)
        # (b, [dim], new_num_patches_z, new_num_patches_y, new_num_patches_x, [dim])

        if not return_intermediates:
            return encoded
        return encoded, stage_outputs, layer_outputs

In [None]:
# Smoke test: patchify a single-channel volume with (1, 8, 8) patches, then run three encoder stages.
test_config = Swin3DEncoderWithPatchEmbeddingsConfig.model_validate(
    {
        "patch_size": (1, 8, 8),
        "in_channels": 1,
        "use_absolute_position_embeddings": True,
        "learnable_absolute_position_embeddings": False,
        "image_size": (32, 512, 512),
        "stages": [
            {
                "dim": 36,
                "patch_merging": None,
                "depth": 1,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": False,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "in_dim": 36,
                    "out_dim": 96,
                },
                "dim": 96,
                "depth": 3,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "in_dim": 96,
                    "out_dim": 192,
                },
                "dim": 192,
                "depth": 1,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
        ],
    }
)

test = Swin3DEncoderWithPatchEmbeddings(test_config)
display(test)
# `spacings` is optional and omitted here; to exercise it, pass a (b, 3) tensor such as torch.randn(2, 3).
o = test(
    torch.randn(2, 1, 32, 128, 128),
    return_intermediates=True,
)
display((o[0].shape, [x.shape for x in o[1]], [x.shape for x in o[2]]))

# The final output follows the (channels_first) input layout; stage outputs are always channels_last.
assert o[0].shape == (2, 192, 8, 4, 4), o[0].shape
assert [tuple(x.shape) for x in o[1]] == [(2, 32, 16, 16, 36), (2, 16, 8, 8, 96), (2, 8, 4, 4, 192)]


[1;35mSwin3DEncoderWithPatchEmbeddings[0m[1m([0m
  [1m([0mpatchify[1m)[0m: [1;35mPatchEmbeddings3D[0m[1m([0m
    [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m36[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m36[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
    [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mabsolute_position_embeddings[1m)[0m: [1;35mAbsolutePositionEmbeddings3D[0m[1m([0m[1m)[0m
  [1m([0mencoder[1m)[0m: [1;35mSwi


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m192[0m, [1;36m8[0m, [1;36m4[0m, [1;36m4[0m[1m][0m[1m)[0m,
    [1m[[0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m32[0m, [1;36m16[0m, [1;36m16[0m, [1;36m36[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m8[0m, [1;36m8[0m, [1;36m96[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m4[0m, [1;36m4[0m, [1;36m192[0m[1m][0m[1m)[0m[1m][0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m32[0m, [1;36m16[0m, [1;36m16[0m, [1;36m36[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m32[0m, [1;36m16[0m, [1;36m16[0m, [1;36m36[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m8[0m, [1;36m8[0m, [1;36m96[0m[1m][0m[1m)[0m,
        [1;35mt

# nbdev

In [None]:
# Run nbdev's export so that cells marked `# | export` are written out to the Python package.
!nbdev_export