In [None]:
# | default_exp nets/swinv2_3d

# Imports

In [None]:
# | export

import numpy as np
import torch
from huggingface_hub import PyTorchModelHubMixin  # TODO
from torch import nn

from vision_architectures.blocks.transformer import Attention3DWithMLP, Attention3DWithMLPConfig
from vision_architectures.docstrings import populate_docstring
from vision_architectures.layers.embeddings import (
    AbsolutePositionEmbeddings3D,
    PatchEmbeddings3D,
    RelativePositionEmbeddings3DConfig,
    RelativePositionEmbeddings3DMetaNetwork,
)
from vision_architectures.nets.swin_3d import (
    Swin3DBlock,
    Swin3DBlockConfig,
    Swin3DEncoderDecoderBase,
    Swin3DEncoderDecoderConfig,
    Swin3DEncoderWithPatchEmbeddings,
    Swin3DEncoderWithPatchEmbeddingsConfig,
    Swin3DLayer,
    Swin3DPatchMerging,
    Swin3DPatchMergingConfig,
    Swin3DPatchSplitting,
    Swin3DPatchSplittingConfig,
    Swin3DStage,
    Swin3DStageConfig,
)
from vision_architectures.utils.custom_base_model import Field

# Config

In [None]:
# | export


class SwinV23DPatchMergingConfig(Swin3DPatchMergingConfig):
    """Config type validated by `SwinV23DPatchMerging`. Declares no fields beyond `Swin3DPatchMergingConfig`."""

    pass


class SwinV23DPatchSplittingConfig(Swin3DPatchSplittingConfig):
    """Config type validated by `SwinV23DPatchSplitting`. Declares no fields beyond `Swin3DPatchSplittingConfig`."""

    pass


class SwinV23DBlockConfig(Swin3DBlockConfig):
    """Config type validated by `SwinV23DBlock`.

    Re-declares the two optional sub-configs so that they are typed with the SwinV2 config classes. Both default to
    None, meaning no patch merging / patch splitting.
    """

    patch_merging: SwinV23DPatchMergingConfig | None = Field(
        None, description="Patch merging config if desired. Patch merging is applied before attention."
    )
    patch_splitting: SwinV23DPatchSplittingConfig | None = Field(
        None, description="Patch splitting config if desired. Patch splitting is applied after attention."
    )


class SwinV23DStageConfig(SwinV23DBlockConfig, Swin3DStageConfig):
    """Config type validated by `SwinV23DStage`.

    Combines `SwinV23DBlockConfig` (SwinV2-typed patch merging / splitting) with `Swin3DStageConfig`; no fields of its
    own are declared here.
    """

    pass


class SwinV23DEncoderDecoderConfig(Swin3DEncoderDecoderConfig):
    """Config type validated by `SwinV23DEncoderDecoderBase` and used by the SwinV2 encoder and decoder."""

    # Re-declared so every stage is validated as a SwinV2 stage config; one SwinV23DStage is built per entry.
    stages: list[SwinV23DStageConfig]


class SwinV23DEncoderWithPatchEmbeddingsConfig(SwinV23DEncoderDecoderConfig, Swin3DEncoderWithPatchEmbeddingsConfig):
    """Config type validated by `SwinV23DEncoderWithPatchEmbeddings`.

    Combines the SwinV2 stage list of `SwinV23DEncoderDecoderConfig` with `Swin3DEncoderWithPatchEmbeddingsConfig`; no
    fields of its own are declared here.
    """

    pass

In [None]:
# Example: validate a nested dict into the full encoder config and display the parsed result.
# Three stages: the first has no patch merging, the next two merge (2, 2, 2) windows, and the last one additionally
# has a patch splitting step.
test_config = SwinV23DEncoderWithPatchEmbeddingsConfig.model_validate(
    {
        "patch_size": (1, 8, 8),
        "in_channels": 1,
        "use_absolute_position_embeddings": True,
        "learnable_absolute_position_embeddings": False,
        "stages": [
            {
                "dim": 36,
                "depth": 1,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": False,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "in_dim": 36,
                    "out_dim": 72,
                },
                "dim": 72,
                "depth": 3,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "in_dim": 72,
                    "out_dim": 144,
                },
                "dim": 144,
                "depth": 1,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
                "patch_splitting": {
                    "final_window_size": (2, 2, 2),
                    "in_dim": 144,
                    "out_dim": 72,
                },
            },
        ],
    }
)

test_config


[1;35mSwinV23DEncoderWithPatchEmbeddingsConfig[0m[1m([0m
    [33mstages[0m=[1m[[0m
        [1;35mSwinV23DStageConfig[0m[1m([0m
            [33mdim[0m=[1;36m36[0m,
            [33mnum_heads[0m=[1;36m4[0m,
            [33mratio_q_to_kv_heads[0m=[1;36m1[0m,
            [33mlogit_scale_learnable[0m=[3;91mFalse[0m,
            [33mattn_drop_prob[0m=[1;36m0[0m[1;36m.0[0m,
            [33mproj_drop_prob[0m=[1;36m0[0m[1;36m.0[0m,
            [33mmax_attention_batch_size[0m=[1;36m-1[0m,
            [33mmlp_ratio[0m=[1;36m4[0m,
            [33mactivation[0m=[32m'gelu'[0m,
            [33mmlp_drop_prob[0m=[1;36m0[0m[1;36m.0[0m,
            [33mnorm_location[0m=[32m'post'[0m,
            [33mlayer_norm_eps[0m=[1;36m1e[0m[1;36m-06[0m,
            [33mwindow_size[0m=[1m([0m[1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m)[0m,
            [33muse_relative_position_bias[0m=[3;91mFalse[0m,
            [33mpatch_merging[0m=[3;3

# Architecture

### Basic Layers

In [None]:
# | export


class SwinV23DLayerLogitScale(nn.Module):
    """Learnable per-head logit scale handed to the attention module of a SwinV2 layer.

    The parameter has shape (num_heads, 1, 1), is kept in log-space and starts at log(10). Calling the module returns
    the exponentiated value, after capping the stored log-scale at log(1 / 0.01).
    """

    def __init__(self, num_heads):
        """Creates the trainable log-scale parameter with one entry per attention head."""
        super().__init__()

        initial_log_scale = torch.full((num_heads, 1, 1), 10.0).log()
        self.logit_scale = nn.Parameter(initial_log_scale, requires_grad=True)

    def forward(self):
        """Returns the capped scale in linear space, shape (num_heads, 1, 1)."""
        # Upper bound applied in log-space, i.e. the returned scale never exceeds 1 / 0.01.
        max_log_scale = np.log(1.0 / 0.01)
        capped = torch.clamp(self.logit_scale, max=max_log_scale)
        return capped.exp()

In [None]:
# | export


@populate_docstring
class SwinV23DLayer(Swin3DLayer):
    """Single SwinV2 3D layer: windowed attention with an optional meta-network relative position bias and a
    learnable per-head logit scale.
    {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(
        self,
        config: RelativePositionEmbeddings3DConfig | Attention3DWithMLPConfig = {},
        checkpointing_level: int = 0,
        **kwargs,
    ):
        """Builds the layer through the parent class, then swaps in a SwinV2-flavoured transformer.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__(config, checkpointing_level, **kwargs)

        # SwinV2 derives the relative position bias from a meta network; it is only built when the parent
        # flagged that a relative position bias is wanted.
        meta_network_bias = (
            RelativePositionEmbeddings3DMetaNetwork(self.embeddings_config, checkpointing_level=checkpointing_level)
            if self._use_relative_position_bias
            else None
        )

        # Replace the transformer created by the parent with one wired to the meta-network bias and the
        # SwinV2 logit scale.
        self.transformer = Attention3DWithMLP(
            self.transformer_config,
            relative_position_bias=meta_network_bias,
            logit_scale=SwinV23DLayerLogitScale(self.transformer_config.num_heads),
            checkpointing_level=checkpointing_level,
        )

In [None]:
# Smoke test: build one SwinV2 layer straight from keyword arguments and push a random batch through it.
test = SwinV23DLayer(
    dim=64,
    num_heads=4,
    mlp_ratio=4,
    layer_norm_eps=1e-6,
    window_size=(2, 2, 2),
    use_relative_position_bias=True,
)
display(test)
# Per the recorded output, the (2, 64, 4, 4, 4) input shape is preserved.
display(test(torch.randn(2, 64, 4, 4, 4)).shape)


[1;35mSwinV23DLayer[0m[1m([0m
  [1m([0mtransformer[1m)[0m: [1;35mAttention3DWithMLP[0m[1m([0m
    [1m([0mattn[1m)[0m: [1;35mAttention3D[0m[1m([0m
      [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
      [1m([0mlogit_

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m][0m[1m)[0m

### Stage Layers

In [None]:
# | export


@populate_docstring
class SwinV23DBlock(Swin3DBlock):
    """SwinV2 3D block made of a pair of SwinV23DLayers, one for regular windows and one for shifted windows.
    {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: SwinV23DBlockConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Builds the block through the parent class, then replaces both layers with SwinV2 layers.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__(config, checkpointing_level, **kwargs)

        self.config = SwinV23DBlockConfig.model_validate(config | kwargs)

        def build_layer():
            # Each layer is created from its own fresh dump of the validated block config.
            return SwinV23DLayer(self.config.model_dump(), checkpointing_level=checkpointing_level)

        self.w_layer = build_layer()
        self.sw_layer = build_layer()

In [None]:
# Smoke test for SwinV23DBlock, built from an already validated block config.
test_stage_config = SwinV23DBlockConfig.model_validate(
    {
        "dim": 64,
        "depth": 4,
        "num_heads": 4,
        "mlp_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": True,
    }
)

test = SwinV23DBlock(test_stage_config)
display(test)
# With return_intermediates=True the result is indexed as (output, (first, second)); per the recorded output the
# final tensor is channels-first and the two intermediates are channels-last.
o = test(torch.randn(2, 64, 4, 4, 4), return_intermediates=True)
display((o[0].shape, (o[1][0].shape, o[1][1].shape)))


[1;35mSwinV23DBlock[0m[1m([0m
  [1m([0mw_layer[1m)[0m: [1;35mSwinV23DLayer[0m[1m([0m
    [1m([0mtransformer[1m)[0m: [1;35mAttention3DWithMLP[0m[1m([0m
      [1m([0mattn[1m)[0m: [1;35mAttention3D[0m[1m([0m
        [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[

[1m([0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m][0m[1m)[0m, [1m([0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m[1m)[0m[1m)[0m

In [None]:
# | export


@populate_docstring
class SwinV23DPatchMerging(Swin3DPatchMerging):
    """Patch merging layer for SwinV23D. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: SwinV23DPatchMergingConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize the SwinV23DPatchMerging layer.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__(config, checkpointing_level, **kwargs)

        self.config = SwinV23DPatchMergingConfig.model_validate(config | kwargs)

        # The norm/projection input width is in_dim times the number of positions in one merge window.
        # np.prod returns a NumPy scalar; cast it to a built-in int so LayerNorm and Linear store plain Python
        # ints (otherwise normalized_shape ends up as (np.int64(...),), as visible in the module repr).
        merged_dim = self.config.in_dim * int(np.prod(self.config.merge_window_size))
        self.layer_norm = nn.LayerNorm(merged_dim)
        self.proj = nn.Linear(merged_dim, self.config.out_dim)

In [None]:
# Smoke test for SwinV23DPatchMerging with a (2, 2, 2) merge window.
test_stage_config = SwinV23DPatchMergingConfig.model_validate(
    {
        "merge_window_size": (2, 2, 2),
        "in_dim": 64,
        "out_dim": 108,
    }
)

test = SwinV23DPatchMerging(test_stage_config)
display(test)
# Per the recorded output: (2, 64, 4, 4, 4) -> (2, 108, 2, 2, 2), i.e. each spatial dim is halved.
display(test(torch.randn(2, 64, 4, 4, 4)).shape)


[1;35mSwinV23DPatchMerging[0m[1m([0m
  [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;35mnp.int64[0m[1m([0m[1;36m512[0m[1m)[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m512[0m, [33mout_features[0m=[1;36m108[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m108[0m, [1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m][0m[1m)[0m

In [None]:
# | export


@populate_docstring
class SwinV23DPatchSplitting(Swin3DPatchSplitting):
    """Patch splitting layer for SwinV23D. {CLASS_DESCRIPTION_3D_DOC}

    This is a self-implemented class and is not part of the paper."""

    @populate_docstring
    def __init__(self, config: SwinV23DPatchSplittingConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize the SwinV23DPatchSplitting layer.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        # `config` defaults to an empty dict, like the other SwinV2 modules, so the layer can also be configured
        # purely through keyword arguments.
        super().__init__(config, checkpointing_level, **kwargs)

        self.config = SwinV23DPatchSplittingConfig.model_validate(config | kwargs)

        # The projection output width is out_dim times the number of positions in one final window.
        # np.prod returns a NumPy scalar; cast it to a built-in int so Linear stores a plain Python int.
        split_dim = self.config.out_dim * int(np.prod(self.config.final_window_size))
        self.layer_norm = nn.LayerNorm(self.config.in_dim)
        self.proj = nn.Linear(self.config.in_dim, split_dim)

In [None]:
# Smoke test for SwinV23DPatchSplitting with a (2, 2, 2) final window.
test_stage_config = SwinV23DPatchSplittingConfig.model_validate(
    {
        "final_window_size": (2, 2, 2),
        "in_dim": 72,
        "out_dim": 64,
    }
)

test = SwinV23DPatchSplitting(test_stage_config)
display(test)
# Per the recorded output: (2, 72, 4, 4, 4) -> (2, 64, 8, 8, 8), i.e. each spatial dim is doubled.
display(test(torch.randn(2, 72, 4, 4, 4)).shape)


[1;35mSwinV23DPatchSplitting[0m[1m([0m
  [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m72[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m72[0m, [33mout_features[0m=[1;36m512[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m[1m][0m[1m)[0m

In [None]:
# | export


@populate_docstring
class SwinV23DStage(Swin3DStage):
    """SwinV23D stage for SwinV23D. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: SwinV23DStageConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize the SwinV23DStage.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__(config, checkpointing_level, **kwargs)

        self.config = SwinV23DStageConfig.model_validate(config | kwargs)

        # The sub-modules built by the parent are replaced with their SwinV2 counterparts below.
        # `checkpointing_level` is forwarded to each of them; without it they would silently fall back to their
        # default level of 0 regardless of what the caller asked for.
        self.patch_merging = None
        if self.config.patch_merging is not None:
            self.patch_merging = SwinV23DPatchMerging(
                self.config.patch_merging, checkpointing_level=checkpointing_level
            )

        self.blocks = nn.ModuleList(
            [SwinV23DBlock(self.config, checkpointing_level=checkpointing_level) for _ in range(self.config.depth)],
        )

        self.patch_splitting = None
        if self.config.patch_splitting is not None:
            # This has been implemented to create a Swin-based decoder
            self.patch_splitting = SwinV23DPatchSplitting(
                self.config.patch_splitting, checkpointing_level=checkpointing_level
            )

In [None]:
# Smoke test for a stage that starts with patch merging and then runs `depth` blocks.
test_stage_config = SwinV23DStageConfig.model_validate(
    {
        "patch_merging": {
            "merge_window_size": (2, 2, 2),
            "in_dim": 48,
            "out_dim": 108,
        },
        "dim": 108,
        "depth": 2,
        "num_heads": 4,
        "mlp_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": True,
    }
)
display(test_stage_config)

test = SwinV23DStage(test_stage_config)
display(test)
# Per the recorded output: (2, 48, 8, 8, 8) -> (2, 108, 4, 4, 4), with four channels-last intermediates.
o = test(torch.randn(2, 48, 8, 8, 8), return_intermediates=True)
display((o[0].shape, [x.shape for x in o[1]]))


[1;35mSwinV23DStageConfig[0m[1m([0m
    [33mdim[0m=[1;36m108[0m,
    [33mnum_heads[0m=[1;36m4[0m,
    [33mratio_q_to_kv_heads[0m=[1;36m1[0m,
    [33mlogit_scale_learnable[0m=[3;91mFalse[0m,
    [33mattn_drop_prob[0m=[1;36m0[0m[1;36m.0[0m,
    [33mproj_drop_prob[0m=[1;36m0[0m[1;36m.0[0m,
    [33mmax_attention_batch_size[0m=[1;36m-1[0m,
    [33mmlp_ratio[0m=[1;36m4[0m,
    [33mactivation[0m=[32m'gelu'[0m,
    [33mmlp_drop_prob[0m=[1;36m0[0m[1;36m.0[0m,
    [33mnorm_location[0m=[32m'post'[0m,
    [33mlayer_norm_eps[0m=[1;36m1e[0m[1;36m-06[0m,
    [33mwindow_size[0m=[1m([0m[1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m)[0m,
    [33muse_relative_position_bias[0m=[3;92mTrue[0m,
    [33mpatch_merging[0m=[1;35mSwinV23DPatchMergingConfig[0m[1m([0m
        [33min_dim[0m=[1;36m48[0m,
        [33mout_dim[0m=[1;36m108[0m,
        [33mmerge_window_size[0m=[1m([0m[1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m)[0m,
  


[1;35mSwinV23DStage[0m[1m([0m
  [1m([0mpatch_merging[1m)[0m: [1;35mSwinV23DPatchMerging[0m[1m([0m
    [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;35mnp.int64[0m[1m([0m[1;36m384[0m[1m)[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m384[0m, [33mout_features[0m=[1;36m108[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m1[0m[1m)[0m: [1;36m2[0m x [1;35mSwinV23DBlock[0m[1m([0m
      [1m([0mw_layer[1m)[0m: [1;35mSwinV23DLayer[0m[1m([0m
        [1m([0mtransformer[1m)[0m: [1;35mAttention3DWithMLP[0m[1m([0m
          [1m([0mattn[1m)[0m: [1;35mAttent


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m108[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m108[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m108[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m108[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m108[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m

In [None]:
# Smoke test for a stage without patch merging that ends with patch splitting.
test_stage_config = SwinV23DStageConfig.model_validate(
    {
        "patch_merging": None,
        "dim": 48,
        "depth": 2,
        "num_heads": 4,
        "mlp_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": True,
        "patch_splitting": {
            "final_window_size": (2, 2, 2),
            "in_dim": 48,
            "out_dim": 18,
        },
    }
)

test = SwinV23DStage(test_stage_config)
display(test)
# Per the recorded output: (2, 48, 8, 8, 8) -> (2, 18, 16, 16, 16).
o = test(torch.randn(2, 48, 8, 8, 8), return_intermediates=True)
display((o[0].shape, [x.shape for x in o[1]]))
# NOTE(review): the forward pass is repeated on a second random input - presumably to confirm that a second call
# on the same module yields the same shapes; confirm this duplication is intentional.
o = test(torch.randn(2, 48, 8, 8, 8), return_intermediates=True)
display((o[0].shape, [x.shape for x in o[1]]))


[1;35mSwinV23DStage[0m[1m([0m
  [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m1[0m[1m)[0m: [1;36m2[0m x [1;35mSwinV23DBlock[0m[1m([0m
      [1m([0mw_layer[1m)[0m: [1;35mSwinV23DLayer[0m[1m([0m
        [1m([0mtransformer[1m)[0m: [1;35mAttention3DWithMLP[0m[1m([0m
          [1m([0mattn[1m)[0m: [1;35mAttention3D[0m[1m([0m
            [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m48[0m, [33mout_features[0m=[1;36m48[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
            [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m48[0m, [33mout_features[0m=[1;36m48[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
            [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m48[0m, [33mout_features[0m=[1;36m48[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
            [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m18[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m48[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m48[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m48[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m48[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m18[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m48[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m48[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m48[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m48[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m

### Encoder/Decoder Base

In [None]:
# | export


class SwinV23DEncoderDecoderBase(Swin3DEncoderDecoderBase, PyTorchModelHubMixin):
    """Common base of SwinV23DEncoder and SwinV23DDecoder that swaps the stage list for SwinV23DStage modules."""

    @populate_docstring
    def __init__(self, config: SwinV23DEncoderDecoderConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Builds the parent encoder/decoder, then rebuilds its stages as SwinV2 stages.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__(config, checkpointing_level, **kwargs)

        self.config = SwinV23DEncoderDecoderConfig.model_validate(config | kwargs)

        # One SwinV2 stage per configured stage, kept in configuration order.
        v2_stages = []
        for stage_config in self.config.stages:
            v2_stages.append(SwinV23DStage(stage_config, checkpointing_level))
        self.stages = nn.ModuleList(v2_stages)

### Encoder

In [None]:
# | export


@populate_docstring
class SwinV23DEncoder(SwinV23DEncoderDecoderBase):
    """3D SwinV2 Transformer encoder operating on already patchified/tokenized input. {CLASS_DESCRIPTION_3D_DOC}"""

    def __init__(self, config: SwinV23DEncoderDecoderConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Builds the stages via the base class and rejects configs containing a decoder-only stage.

        A stage may define patch splitting only if it also defines patch merging; otherwise an AssertionError is
        raised.
        """
        super().__init__(config, checkpointing_level, **kwargs)

        for stage_config in self.config.stages:
            has_splitting = stage_config.patch_splitting is not None
            has_merging = stage_config.patch_merging is not None
            # Splitting without merging marks a decoder stage; splitting together with merging is tolerated.
            assert has_merging or not has_splitting, "SwinV23DEncoder is not for decoding (mid blocks are ok)."

In [None]:
# Smoke test for SwinV23DEncoder: one plain stage, one merging stage, and a final stage that both merges and splits.
# NOTE(review): `patch_size` and `in_channels` are passed here although SwinV23DEncoderDecoderConfig only declares
# `stages` in this notebook - presumably they are ignored or accepted by the parent config; confirm.
test_config = SwinV23DEncoderDecoderConfig.model_validate(
    {
        "patch_size": (2, 2, 2),
        "in_channels": 32,
        "stages": [
            {
                "dim": 32,
                "depth": 1,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": False,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "in_dim": 32,
                    "out_dim": 64,
                },
                "dim": 64,
                "depth": 3,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "in_dim": 64,
                    "out_dim": 128,
                },
                "dim": 128,
                "depth": 3,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
                "patch_splitting": {
                    "final_window_size": (2, 2, 2),
                    "in_dim": 128,
                    "out_dim": 100,
                },
            },
        ],
    }
)

test = SwinV23DEncoder(test_config)
display(test)
# With return_intermediates=True the result is indexed as (output, list, list); per the recorded output the final
# tensor has shape (2, 100, 8, 8, 8).
o = test(torch.randn(2, 32, 16, 16, 16), return_intermediates=True)
display((o[0].shape, [x.shape for x in o[1]], [x.shape for x in o[2]]))


[1;35mSwinV23DEncoder[0m[1m([0m
  [1m([0mstages[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwinV23DStage[0m[1m([0m
      [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
        [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwinV23DBlock[0m[1m([0m
          [1m([0mw_layer[1m)[0m: [1;35mSwinV23DLayer[0m[1m([0m
            [1m([0mtransformer[1m)[0m: [1;35mAttention3DWithMLP[0m[1m([0m
              [1m([0mattn[1m)[0m: [1;35mAttention3D[0m[1m([0m
                [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m32[0m, [33mout_features[0m=[1;36m32[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
                [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m32[0m, [33mout_features[0m=[1;36m32[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
                [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m32[0m, [33mout_features[


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m100[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m[1m][0m[1m)[0m,
    [1m[[0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m32[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m64[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m100[0m[1m][0m[1m)[0m[1m][0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m64[0m[1m][0m[1m)[0m,
        [1;35mtor

### Decoder

In [None]:
# | export


@populate_docstring
class SwinV23DDecoder(SwinV23DEncoderDecoderBase):
    """3D Swin Transformer decoder. Assumes input has already been patchified/tokenized. {CLASS_DESCRIPTION_3D_DOC}"""

    def __init__(self, config: SwinV23DEncoderDecoderConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Builds the stages via the base class and rejects configs containing an encoder-only stage.

        Args:
            config: Decoder configuration, given as a config instance or as a plain dict of its fields.
            checkpointing_level: Passed through unchanged to the base class.
            kwargs: Additional config fields; the base class merges them with `config` before validation.

        Raises:
            AssertionError: If a stage defines patch merging without also defining patch splitting.
        """
        super().__init__(config, checkpointing_level, **kwargs)

        # Iterate over the validated `self.config`, not the raw `config` argument: the argument may be a plain
        # dict (or the empty default combined with kwargs), which has no `.stages` attribute.
        for stage_config in self.config.stages:
            if stage_config.patch_merging is not None:
                assert (
                    stage_config.patch_splitting is not None
                ), "SwinV23DDecoder is not for encoding (mid blocks are ok)."

In [None]:
# Smoke test for SwinV23DDecoder: a first stage that merges and splits, a splitting stage, and a final plain stage.
test_config = SwinV23DEncoderDecoderConfig.model_validate(
    {
        "stages": [
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "in_dim": 96,
                    "out_dim": 288,
                },
                "dim": 288,
                "depth": 3,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
                "patch_splitting": {
                    "final_window_size": (2, 2, 2),
                    "in_dim": 288,
                    "out_dim": 96,
                },
            },
            {
                "dim": 96,
                "depth": 3,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
                "patch_splitting": {
                    "final_window_size": (2, 2, 2),
                    "in_dim": 96,
                    "out_dim": 32,
                },
            },
            {
                "dim": 32,
                "depth": 1,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": False,
            },
        ],
    }
)

test = SwinV23DDecoder(test_config)
display(test)
# Per the recorded output: (2, 96, 16, 16, 16) -> (2, 32, 32, 32, 32).
o = test(torch.randn(2, 96, 16, 16, 16), return_intermediates=True)
display((o[0].shape, [x.shape for x in o[1]], [x.shape for x in o[2]]))


[1;35mSwinV23DDecoder[0m[1m([0m
  [1m([0mstages[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwinV23DStage[0m[1m([0m
      [1m([0mpatch_merging[1m)[0m: [1;35mSwinV23DPatchMerging[0m[1m([0m
        [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;35mnp.int64[0m[1m([0m[1;36m768[0m[1m)[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m768[0m, [33mout_features[0m=[1;36m288[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
      [1m)[0m
      [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
        [1m([0m[1;36m0[0m-[1;36m2[0m[1m)[0m: [1;36m3[0m x [1;35mSwinV23DBlock[0m[1m([0m
          [1m([0mw_layer[1m)[


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m32[0m, [1;36m32[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m96[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m32[0m, [1;36m32[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m32[0m, [1;36m32[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m
    [1m][0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m288[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m288[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m288[0m[1m

# Models

In [None]:
# | export


@populate_docstring
class SwinV23DEncoderWithPatchEmbeddings(Swin3DEncoderWithPatchEmbeddings, PyTorchModelHubMixin):
    """SwinV2 3D encoder bundled with its own 3D patch embeddings. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: SwinV23DEncoderWithPatchEmbeddingsConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Builds the parent model, then re-creates patchify, position embeddings and encoder for SwinV2.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__(config, checkpointing_level, **kwargs)

        self.config = SwinV23DEncoderWithPatchEmbeddingsConfig.model_validate(config | kwargs)

        # Patch and position embeddings are both sized to the input width of the first stage.
        first_stage_in_dim = self.config.stages[0].get_in_dim()

        self.patchify = PatchEmbeddings3D(
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            dim=first_stage_in_dim,
            checkpointing_level=checkpointing_level,
        )
        # NOTE(review): `learnable` is fixed to False here instead of being read from the config - confirm that
        # this is intended.
        self.absolute_position_embeddings = AbsolutePositionEmbeddings3D(dim=first_stage_in_dim, learnable=False)
        self.encoder = SwinV23DEncoder(self.config, checkpointing_level=checkpointing_level)

In [None]:
# Settings that are identical for all three stages of this smoke-test encoder.
shared_stage_settings = {
    "num_heads": 4,
    "mlp_ratio": 4,
    "layer_norm_eps": 1e-6,
    "window_size": (4, 4, 4),
    "attn_drop_prob": 0.2,
    "proj_drop_prob": 0.2,
}

stage_settings = [
    # Stage 0: no patch merging, relative position bias disabled.
    shared_stage_settings
    | {
        "dim": 36,
        "depth": 1,
        "use_relative_position_bias": False,
        "mlp_drop_prob": 0.2,
    },
    # Stage 1: 2x2x2 patch merging from 36 to 72 channels.
    shared_stage_settings
    | {
        "patch_merging": {"merge_window_size": (2, 2, 2), "in_dim": 36, "out_dim": 72},
        "dim": 72,
        "depth": 3,
        "use_relative_position_bias": True,
        "mlp_drop_prob": 0.2,
    },
    # Stage 2: 2x2x2 patch merging from 72 to 144 channels.
    # `mlp_drop_prob` is not set for this stage, so whatever default the config defines is used.
    shared_stage_settings
    | {
        "patch_merging": {"merge_window_size": (2, 2, 2), "in_dim": 72, "out_dim": 144},
        "dim": 144,
        "depth": 1,
        "use_relative_position_bias": True,
    },
]

test_config = SwinV23DEncoderWithPatchEmbeddingsConfig.model_validate(
    {
        "patch_size": (1, 8, 8),
        "in_channels": 1,
        "use_absolute_position_embeddings": True,
        "learnable_absolute_position_embeddings": False,
        "drop_prob": 0.2,
        "stages": stage_settings,
    }
)

test = SwinV23DEncoderWithPatchEmbeddings(test_config)
display(test)

# Random inputs are drawn in the same order in which they are passed to the model.
volume = torch.randn(2, 1, 32, 512, 512)
# (2, 3) tensor passed as the second positional argument — presumably per-sample spacings; TODO confirm.
aux_input = torch.randn(2, 3)
o = test(
    volume,
    aux_input,
    crop_offsets=torch.Tensor((10, 10, 10)),
    return_intermediates=True,
)

# Show the shape of the first output and of every tensor in the two returned lists.
display((o[0].shape, [t.shape for t in o[1]], [t.shape for t in o[2]]))


[1;35mSwinV23DEncoderWithPatchEmbeddings[0m[1m([0m
  [1m([0mpatchify[1m)[0m: [1;35mPatchEmbeddings3D[0m[1m([0m
    [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m36[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[35msame[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m36[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
    [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mabsolute_position_embeddings[1m)[0m: [1;35mAbsolutePositionEmbeddings3D[0m[1m([0m[1m)[0m
  [1

  return F.conv3d(



[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m144[0m, [1;36m8[0m, [1;36m128[0m, [1;36m128[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m32[0m, [1;36m512[0m, [1;36m512[0m, [1;36m36[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m256[0m, [1;36m256[0m, [1;36m72[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m128[0m, [1;36m128[0m, [1;36m144[0m[1m][0m[1m)[0m
    [1m][0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m32[0m, [1;36m512[0m, [1;36m512[0m, [1;36m36[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m32[0m, [1;36m512[0m, [1;36m512[0m, [1;36m36[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m256[0m, [1;36m256[0

# nbdev

In [None]:
# Runs nbdev's export command from the shell; presumably this writes the cells marked `# | export`
# to the module named by the notebook's `# | default_exp` directive — confirm against the nbdev docs.
!nbdev_export

# Rough work