In [1]:
# | default_exp nets/swinv2_3d

# Imports

In [2]:
# | export

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from huggingface_hub import PyTorchModelHubMixin
from munch import munchify
from pydantic import BaseModel, model_validator
from torch import nn

# Config

In [3]:
# | export


class SwinV23DStageConfig(BaseModel):
    depth: int
    window_size: tuple[int, int, int]

    num_heads: int = 4
    mlp_ratio: int = 4
    layer_norm_eps: float = 1e-6
    use_relative_position_bias: bool = False
    patch_merging: dict | None = None
    patch_splitting: dict | None = None

    attn_drop_prob: float = 0.0
    proj_drop_prob: float = 0.0
    mlp_drop_prob: float = 0.0

    _in_dim: int
    _in_patch_size: tuple[int, int, int]
    _attention_dim: int = None
    _out_dim: int
    _out_patch_size: tuple[int, int, int]

    @model_validator(mode="after")
    def validate_after(self):
        if isinstance(self.patch_merging, dict):
            assert {"merge_window_size", "out_dim_ratio"}.issubset(self.patch_merging), "Missing keys"
        if isinstance(self.patch_splitting, dict):
            assert {"final_window_size", "out_dim_ratio"}.issubset(self.patch_splitting), "Missing keys"
        return self


class SwinV23DConfig(BaseModel):
    in_channels: int
    dim: int
    patch_size: tuple[int, int, int]
    stages: list[SwinV23DStageConfig]

    image_size: tuple[int, int, int] | None = None  # required for learnable absolute position embeddings
    drop_prob: float = 0.0
    embed_spacing_info: bool = False
    use_absolute_position_embeddings: bool = True
    learnable_absolute_position_embeddings: bool = False

    def populate(self):
        dim = self.dim
        patch_size = self.patch_size

        # Prepare config based on provided values
        for i in range(len(self.stages)):
            stage = self.stages[i]
            stage._in_dim = dim
            stage._in_patch_size = patch_size
            if stage.patch_merging is not None:
                dim *= stage.patch_merging["out_dim_ratio"]
                stage._attention_dim = dim  # attention will happen after merging
                patch_size = tuple(
                    [patch * window for patch, window in zip(patch_size, stage.patch_merging["merge_window_size"])]
                )
            if stage.patch_splitting is not None:
                stage._attention_dim = dim  # attention will happen before splitting
                dim //= stage.patch_splitting["out_dim_ratio"]
                patch_size = tuple(
                    [patch * window for patch, window in zip(patch_size, stage.patch_splitting["final_window_size"])]
                )
            if stage._attention_dim is None:
                stage._attention_dim = dim  # In case it is not yet set
            stage._out_dim = dim
            stage._out_patch_size = patch_size

    @model_validator(mode="after")
    def validate_after(self):
        self.populate()

        # test divisibility of dim with number of attention heads
        for stage in self.stages:
            assert (
                stage._out_dim % stage.num_heads == 0
            ), f"stage._out_dim {stage._out_dim} is not divisible by stage.num_heads {stage.num_heads}"

        # test population of image_size field iff the absolute position embeddings are relative
        if self.learnable_absolute_position_embeddings:
            assert (
                self.image_size is not None
            ), "Please provide image_size if absolute position embeddings are learnable"

        return self

In [4]:
test_config = SwinV23DConfig.model_validate(
    {
        "patch_size": (1, 8, 8),
        "in_channels": 1,
        "use_absolute_position_embeddings": True,
        "learnable_absolute_position_embeddings": False,
        "embed_spacing_info": False,
        "dim": 36,
        "stages": [
            {
                "depth": 1,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": False,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "out_dim_ratio": 3,
                },
                "depth": 3,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "out_dim_ratio": 3,
                },
                "patch_splitting": {
                    "final_window_size": (2, 2, 2),
                    "out_dim_ratio": 3,
                },
                "depth": 1,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
        ],
    }
)

# Architecture

### Basic Layers

In [5]:
# | export


def get_coords_grid(grid_size):
    d, h, w = grid_size

    grid_d = torch.arange(d, dtype=torch.int32)
    grid_h = torch.arange(h, dtype=torch.int32)
    grid_w = torch.arange(w, dtype=torch.int32)

    grid = torch.meshgrid(grid_d, grid_h, grid_w, indexing="ij")
    grid = torch.stack(grid, axis=0)
    # (3, d, h, w)

    return grid

In [6]:
# | export


class SwinV23DMHSA(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        window_size,
        use_relative_position_bias,
        attn_drop_prob=0.0,
        proj_drop_prob=0.0,
    ):
        super().__init__()

        assert dim % num_heads == 0, "dimension must be divisible by number of heads"

        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size

        self.per_head_dim = int(dim // num_heads)

        self.W_qkv = nn.Linear(dim, 3 * dim)
        self.attn_drop_prob = attn_drop_prob
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_prob)

        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))

        # TODO: Add embed_spacing_info functionality
        self.use_relative_position_bias = use_relative_position_bias
        if use_relative_position_bias:
            self.cpb_mlp = nn.Sequential(
                nn.Linear(3, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
            )

            relative_limits = (2 * window_size[0] - 1, 2 * window_size[1] - 1, 2 * window_size[2] - 1)

            # Relative coordinates table
            relative_coords_table = get_coords_grid(relative_limits).float()
            for i in range(3):
                relative_coords_table[i] = (relative_coords_table[i] - (window_size[0] - 1)) / (
                    window_size[0] - 1 + 1e-8  # small value added to ensure there is no NaN when window size is 1
                )
            relative_coords_table = rearrange(
                relative_coords_table,
                "three window_size_z window_size_y window_size_x -> window_size_z window_size_y window_size_x three",
            ).contiguous()
            relative_coords_table *= 8  # Normalize to -8, 8
            relative_coords_table = (
                torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / np.log2(8)
            )
            # (window_size_z, window_size_y, window_size_x, 3)
            # Allow moving this to and from cuda whenever required but don't save to state_dict
            self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)

            # Pair-wise relative position index for each token inside the window
            coords = get_coords_grid(window_size)
            coords_flatten = rearrange(
                coords, "three_dimensional d h w -> three_dimensional (d h w)", three_dimensional=3
            )
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()
            relative_coords[:, :, 0] += window_size[0] - 1
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 2] += window_size[2] - 1
            relative_position_index: torch.Tensor = (
                relative_coords[:, :, 0] * relative_limits[1] * relative_limits[2]
                + relative_coords[:, :, 1] * relative_limits[2]
                + relative_coords[:, :, 2]
            )
            self.relative_position_index = relative_position_index.flatten()
            # (num_patches, num_patches)

    def calculate_relative_position_bias(self):
        # (window_size_z, window_size_y, window_size_x, 3)
        relative_position_bias_table = self.cpb_mlp(self.relative_coords_table)
        # (window_size_z, window_size_y, window_size_x, num_heads)
        relative_position_bias_table = relative_position_bias_table.reshape(-1, self.num_heads)
        # (num_patches, num_heads)
        relative_position_bias = relative_position_bias_table[self.relative_position_index]
        # (num_patches * num_patches, num_heads)
        relative_position_bias = rearrange(
            relative_position_bias,
            "(num_patches1 num_patches2) num_heads -> num_heads num_patches1 num_patches2",
            num_patches1=np.prod(self.window_size),
            num_patches2=np.prod(self.window_size),
            num_heads=self.num_heads,
        ).contiguous()
        # (num_heads, num_patches, num_patches)
        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
        # (num_heads, num_patches, num_patches)
        return relative_position_bias

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (windowed_b, window_size_z window_size_y window_size_x, dim)
        _, num_patches_z, num_patches_y, num_patches_x, _ = hidden_states.shape

        query, key, value = rearrange(
            self.W_qkv(hidden_states),
            "b nz ny nx (n num_heads d) -> n b num_heads (nz ny nx) d",
            n=3,
            num_heads=self.num_heads,
        )
        # num_patches = window_size_z * window_size_y * window_size_x
        # Each is (windowed_b, num_heads, num_patches, per_head_dim)

        logit_scale = torch.clamp(self.logit_scale, max=np.log(1.0 / 0.01)).exp()

        query_normalized = F.normalize(query, dim=-1)
        key_normalized = F.normalize(key, dim=-1)

        query_normalized_and_scaled = query_normalized * logit_scale  # Scale the query beforehand

        relative_position_bias = None
        if self.use_relative_position_bias:
            relative_position_bias = self.calculate_relative_position_bias()

        context = F.scaled_dot_product_attention(
            query_normalized_and_scaled,
            key_normalized,
            value,
            attn_mask=relative_position_bias,  # Use this as a way to introduce relative position bias
            dropout_p=self.attn_drop_prob,
            is_causal=False,
            scale=1.0,  # Already scaled the vectors
        )
        # (windowed_b, num_heads, num_patches, per_head_dim)

        context = rearrange(
            context,
            "b num_heads (num_patches_z num_patches_y num_patches_x) d -> "
            "b num_patches_z num_patches_y num_patches_x (num_heads d)",
            num_patches_z=num_patches_z,
            num_patches_y=num_patches_y,
            num_patches_x=num_patches_x,
        )
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        context = self.proj(context)
        context = self.proj_drop(context)
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        return context

In [7]:
test = SwinV23DMHSA(54, 6, (4, 4, 4), True)
display(test)
display(test(torch.randn(2, 4, 4, 4, 54)).shape)


[1;35mSwinV23DMHSA[0m[1m([0m
  [1m([0mW_qkv[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m162[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m([0mcpb_mlp[1m)[0m: [1;35mSequential[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m3[0m, [33mout_features[0m=[1;36m512[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0m[1;36m1[0m[1m)[0m: [1;35mReLU[0m[1m([0m[33minplace[0m=[3;92mTrue[0m[1m)[0m
    [1m([0m[1;36m2[0m[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m512[0m, [33mout_features[0m=[1;36m6[0m, [33mbias[0m=[3;91mFalse

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m54[0m[1m][0m[1m)[0m

In [8]:
# | export


class SwinV23DLayerMLP(nn.Module):
    def __init__(self, dim, mlp_ratio, dropout_prob=0.0):
        super().__init__()
        self.dense1 = nn.Linear(dim, dim * mlp_ratio)
        self.act = nn.GELU()
        self.dense2 = nn.Linear(dim * mlp_ratio, dim)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (windowed_b, window_size_z window_size_y window_size_x, dim)
        hidden_states = self.dense1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dense2(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states

In [9]:
test = SwinV23DLayerMLP(64, 256)
display(test)
display(test(torch.randn(2, 4, 4, 4, 64)).shape)


[1;35mSwinV23DLayerMLP[0m[1m([0m
  [1m([0mdense1[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m16384[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mact[1m)[0m: [1;35mGELU[0m[1m([0m[33mapproximate[0m=[32m'none'[0m[1m)[0m
  [1m([0mdense2[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m16384[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mdropout[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m

In [10]:
# | export


class SwinV23DLayer(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio,
        layer_norm_eps,
        window_size,
        use_relative_position_bias,
        attn_drop_prob=0.0,
        proj_drop_prob=0.0,
        mlp_drop_prob=0.0,
    ):
        super().__init__()

        self.window_size = window_size

        self.mhsa = SwinV23DMHSA(
            dim, num_heads, window_size, use_relative_position_bias, attn_drop_prob, proj_drop_prob
        )
        self.layernorm1 = nn.LayerNorm(dim, eps=layer_norm_eps)
        self.mlp = SwinV23DLayerMLP(dim, mlp_ratio, mlp_drop_prob)
        self.layernorm2 = nn.LayerNorm(dim, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (b, num_patches_z, num_patches_y, num_patches_x, dim)
        _, num_patches_z, num_patches_y, num_patches_x, _ = hidden_states.shape

        # Perform windowing
        window_size_z, window_size_y, window_size_x = self.window_size
        num_windows_z, num_windows_y, num_windows_x = (
            num_patches_z // window_size_z,
            num_patches_y // window_size_y,
            num_patches_x // window_size_x,
        )
        hidden_states = rearrange(
            hidden_states,
            "b (num_windows_z window_size_z) (num_windows_y window_size_y) (num_windows_x window_size_x) dim -> "
            "(b num_windows_z num_windows_y num_windows_x) window_size_z window_size_y window_size_x dim ",
            num_windows_z=num_windows_z,
            num_windows_y=num_windows_y,
            num_windows_x=num_windows_x,
            window_size_z=window_size_z,
            window_size_y=window_size_y,
            window_size_x=window_size_x,
        )

        res_connection1 = hidden_states
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        hidden_states = self.mhsa(hidden_states)
        hidden_states = self.layernorm1(hidden_states)
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        res_connection2 = hidden_states + res_connection1
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        hidden_states = self.mlp(res_connection2)
        hidden_states = self.layernorm2(hidden_states)
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        hidden_states = hidden_states + res_connection2
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        # Undo windowing
        output = rearrange(
            hidden_states,
            "(b num_windows_z num_windows_y num_windows_x) window_size_z window_size_y window_size_x dim -> "
            "b (num_windows_z window_size_z) (num_windows_y window_size_y) (num_windows_x window_size_x) dim",
            num_windows_z=num_windows_z,
            num_windows_y=num_windows_y,
            num_windows_x=num_windows_x,
            window_size_z=window_size_z,
            window_size_y=window_size_y,
            window_size_x=window_size_x,
        )

        return output

In [11]:
test = SwinV23DLayer(64, 4, 256, 1e-6, (2, 2, 2), True)
display(test)
display(test(torch.randn(2, 4, 4, 4, 64)).shape)


[1;35mSwinV23DLayer[0m[1m([0m
  [1m([0mmhsa[1m)[0m: [1;35mSwinV23DMHSA[0m[1m([0m
    [1m([0mW_qkv[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m192[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mcpb_mlp[1m)[0m: [1;35mSequential[0m[1m([0m
      [1m([0m[1;36m0[0m[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m3[0m, [33mout_features[0m=[1;36m512[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0m[1;36m1[0m[1m)[0m: [1;35mReLU[0m[1m([0m[33minplace[0m=[3;92mTrue[0m[1m)[0m
      [1m([0m[1;36m2[0m[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m

### Stage layers

In [12]:
# | export


class SwinV23DBlock(nn.Module):
    def __init__(self, stage_config):
        super().__init__()

        self.stage_config = stage_config
        self.w_layer = SwinV23DLayer(
            stage_config._attention_dim,
            stage_config.num_heads,
            stage_config.mlp_ratio,
            stage_config.layer_norm_eps,
            stage_config.window_size,
            stage_config.use_relative_position_bias,
            stage_config.attn_drop_prob,
            stage_config.proj_drop_prob,
            stage_config.mlp_drop_prob,
        )
        self.sw_layer = SwinV23DLayer(
            stage_config._attention_dim,
            stage_config.num_heads,
            stage_config.mlp_ratio,
            stage_config.layer_norm_eps,
            stage_config.window_size,
            stage_config.use_relative_position_bias,
            stage_config.attn_drop_prob,
            stage_config.proj_drop_prob,
            stage_config.mlp_drop_prob,
        )

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (b, num_patches_z, num_patches_y, num_patches_x, dim)

        layer_outputs = []

        # First layer
        hidden_states = self.w_layer(hidden_states)
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        layer_outputs.append(hidden_states)

        # Shift windows
        window_size_z, window_size_y, window_size_x = self.stage_config.window_size
        shifts = (window_size_z // 2, window_size_y // 2, window_size_x // 2)
        hidden_states = torch.roll(hidden_states, shifts=shifts, dims=(1, 2, 3))
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        # Second layer
        hidden_states = self.sw_layer(hidden_states)
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        # Reverse window shift
        shifts = tuple(-shift for shift in shifts)
        hidden_states = torch.roll(hidden_states, shifts=shifts, dims=(1, 2, 3))
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        layer_outputs.append(hidden_states)

        return hidden_states, layer_outputs

In [13]:
test_stage_config = SwinV23DStageConfig.model_validate(
    {
        "depth": 4,
        "num_heads": 4,
        "mlp_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": True,
    }
)
test_stage_config._attention_dim = 64

test = SwinV23DBlock(test_stage_config)
display(test)
o = test(torch.randn(2, 4, 4, 4, 64))
display((o[0].shape, (o[1][0].shape, o[1][1].shape)))


[1;35mSwinV23DBlock[0m[1m([0m
  [1m([0mw_layer[1m)[0m: [1;35mSwinV23DLayer[0m[1m([0m
    [1m([0mmhsa[1m)[0m: [1;35mSwinV23DMHSA[0m[1m([0m
      [1m([0mW_qkv[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m192[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
      [1m([0mcpb_mlp[1m)[0m: [1;35mSequential[0m[1m([0m
        [1m([0m[1;36m0[0m[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m3[0m, [33mout_features[0m=[1;36m512[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0m[1;36m1[0m[1m)[0m: [1;35mReLU[0m[1m([0m[33minplace[0m=[3;92mTrue[0m[1m)[0m
        [1m

[1m([0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m, [1m([0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m[1m)[0m[1m)[0m

In [14]:
# | export


class SwinV23DPatchMerging(nn.Module):
    def __init__(self, merge_window_size, in_dim, out_dim):
        super().__init__()

        self.merge_window_size = merge_window_size

        in_dim = in_dim * np.prod(merge_window_size)
        self.layer_norm = nn.LayerNorm(in_dim)
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (b, num_patches_z, num_patches_y, num_patches_x, dim)

        window_size_z, window_size_y, window_size_x = self.merge_window_size

        hidden_states = rearrange(
            hidden_states,
            "b (new_num_patches_z window_size_z) (new_num_patches_y window_size_y) (new_num_patches_x window_size_x) dim -> "
            "b new_num_patches_z new_num_patches_y new_num_patches_x (window_size_z window_size_y window_size_x dim)",
            window_size_z=window_size_z,
            window_size_y=window_size_y,
            window_size_x=window_size_x,
        )

        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.proj(hidden_states)
        return hidden_states

In [15]:
test_stage_config = SwinV23DStageConfig.model_validate(
    {
        "patch_merging": {
            "merge_window_size": (2, 2, 2),
            "out_dim_ratio": 3,
        },
        "depth": 4,
        "num_heads": 4,
        "intermediate_size": 256,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": True,
    }
)
test_stage_config._in_dim = 64
test_stage_config._out_dim = 64 * 3

test = SwinV23DPatchMerging(
    test_stage_config.patch_merging["merge_window_size"],
    test_stage_config._in_dim,
    test_stage_config._out_dim,
)
display(test)
display(test(torch.randn(2, 4, 4, 4, 64)).shape)


[1;35mSwinV23DPatchMerging[0m[1m([0m
  [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m512[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m512[0m, [33mout_features[0m=[1;36m192[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m2[0m, [1;36m2[0m, [1;36m2[0m, [1;36m192[0m[1m][0m[1m)[0m

In [16]:
# | export


class SwinV23DPatchSplitting(nn.Module):  # This is a self-implemented class and is not part of the paper.
    def __init__(self, final_window_size, in_dim, out_dim):
        super().__init__()

        self.final_window_size = final_window_size

        out_dim = out_dim * np.prod(final_window_size)
        self.layer_norm = nn.LayerNorm(in_dim)
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (b, num_patches_z, num_patches_y, num_patches_x, dim)

        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.proj(hidden_states)

        window_size_z, window_size_y, window_size_x = self.final_window_size

        hidden_states = rearrange(
            hidden_states,
            "b num_patches_z num_patches_y num_patches_x (window_size_z window_size_y window_size_x dim) -> "
            "b (num_patches_z window_size_z) (num_patches_y window_size_y) (num_patches_x window_size_x) dim",
            window_size_z=window_size_z,
            window_size_y=window_size_y,
            window_size_x=window_size_x,
        )

        return hidden_states

In [17]:
test_stage_config = SwinV23DStageConfig.model_validate(
    {
        "patch_splitting": {
            "final_window_size": (2, 2, 2),
            "out_dim_ratio": 3,
        },
        "depth": 4,
        "num_heads": 4,
        "intermediate_size": 256,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": True,
    }
)
test_stage_config._in_dim = 64 * 3
test_stage_config._out_dim = 64 * 3

test = SwinV23DPatchSplitting(
    test_stage_config.patch_splitting["final_window_size"],
    test_stage_config._in_dim,
    test_stage_config._out_dim,
)
display(test)
display(test(torch.randn(2, 4, 4, 4, 64 * 3)).shape)


[1;35mSwinV23DPatchSplitting[0m[1m([0m
  [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m192[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m192[0m, [33mout_features[0m=[1;36m1536[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m192[0m[1m][0m[1m)[0m

In [18]:
# | export


class SwinV23DStage(nn.Module):
    def __init__(self, stage_config):
        super().__init__()

        self.config = stage_config

        self.patch_merging = None
        if stage_config.patch_merging is not None:
            self.patch_merging = SwinV23DPatchMerging(
                stage_config.patch_merging["merge_window_size"],
                stage_config._in_dim,
                stage_config._attention_dim,
            )

        self.blocks = nn.ModuleList(
            [SwinV23DBlock(stage_config) for _ in range(stage_config.depth)],
        )

        self.patch_splitting = None
        if stage_config.patch_splitting is not None:  # This has been implemented to create a Swin-based decoder
            self.patch_splitting = SwinV23DPatchSplitting(
                stage_config.patch_splitting["final_window_size"],
                stage_config._attention_dim,
                stage_config._out_dim,
            )

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (b, num_patches_z, num_patches_y, num_patches_x, dim)

        if self.patch_merging:
            hidden_states = self.patch_merging(hidden_states)
            # (b, new_num_patches_z, new_num_patches_y, new_num_patches_x, new_dim)

        layer_outputs = []
        for layer_module in self.blocks:
            hidden_states, _layer_outputs = layer_module(hidden_states)
            # (b, new_num_patches_z, new_num_patches_y, new_num_patches_x, new_dim)
            layer_outputs.extend(_layer_outputs)

        if self.patch_splitting:
            hidden_states = self.patch_splitting(hidden_states)
            # (b, new_num_patches_z, new_num_patches_y, new_num_patches_x, new_dim)

        return hidden_states, layer_outputs

In [19]:
test_stage_config = SwinV23DStageConfig.model_validate(
    {
        "patch_merging": {
            "merge_window_size": (2, 2, 2),
            "out_dim_ratio": 3,
        },
        "patch_splitting": None,
        "depth": 2,
        "num_heads": 4,
        "mlp_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": True,
    }
)
test_stage_config._in_dim = 48
test_stage_config._attention_dim = 48 * 3
test_stage_config._out_dim = 48 * 3

test = SwinV23DStage(test_stage_config)
display(test)
o = test(torch.randn(2, 8, 8, 8, 48))
display((o[0].shape, [x.shape for x in o[1]]))


[1;35mSwinV23DStage[0m[1m([0m
  [1m([0mpatch_merging[1m)[0m: [1;35mSwinV23DPatchMerging[0m[1m([0m
    [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m384[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m384[0m, [33mout_features[0m=[1;36m144[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m)[0m
  [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m1[0m[1m)[0m: [1;36m2[0m x [1;35mSwinV23DBlock[0m[1m([0m
      [1m([0mw_layer[1m)[0m: [1;35mSwinV23DLayer[0m[1m([0m
        [1m([0mmhsa[1m)[0m: [1;35mSwinV23DMHSA[0m[1m([0m
          [1m([0mW_qkv[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m144[0m, [33mout_features[0m=[1;36m432[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mproj[1m)[0m: [1;35mLinear[0m


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m144[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m144[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m144[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m144[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m144[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m

In [20]:
test_stage_config = SwinV23DStageConfig.model_validate(
    {
        "patch_merging": None,
        "patch_splitting": {
            "final_window_size": (2, 2, 2),
            "out_dim_ratio": 3,
        },
        "depth": 2,
        "num_heads": 4,
        "mlp_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": True,
    }
)
test_stage_config._in_dim = 48
test_stage_config._attention_dim = 48
test_stage_config._out_dim = 16

test = SwinV23DStage(test_stage_config)
display(test)
o = test(torch.randn(2, 8, 8, 8, 48))
display((o[0].shape, [x.shape for x in o[1]]))
o = test(torch.randn(2, 8, 8, 8, 48))
display((o[0].shape, [x.shape for x in o[1]]))


[1;35mSwinV23DStage[0m[1m([0m
  [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m1[0m[1m)[0m: [1;36m2[0m x [1;35mSwinV23DBlock[0m[1m([0m
      [1m([0mw_layer[1m)[0m: [1;35mSwinV23DLayer[0m[1m([0m
        [1m([0mmhsa[1m)[0m: [1;35mSwinV23DMHSA[0m[1m([0m
          [1m([0mW_qkv[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m48[0m, [33mout_features[0m=[1;36m144[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m48[0m, [33mout_features[0m=[1;36m48[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
          [1m([0mcpb_mlp[1m)[0m: [1;35mSequential[0m[1m([0m
            [1m([0m[1;36m0[0m[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m3[0m, [33mout_f


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m48[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m48[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m48[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m48[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m48[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m48[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m48[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m48[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m

### Encoder

In [21]:
# | export


class SwinV23DEncoder(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config):
        super().__init__()

        for stage_config in config.stages:
            if stage_config.patch_splitting is not None:
                assert (
                    stage_config.patch_merging is not None
                ), "SwinV23DEncoder is not for decoding (mid blocks are ok)."

        self.stages = nn.ModuleList([SwinV23DStage(stage_config) for stage_config in config.stages])

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (b, num_patches_z, num_patches_y, num_patches_x, dim)

        stage_outputs, layer_outputs = [], []
        for stage_module in self.stages:
            hidden_states, _layer_outputs = stage_module(hidden_states)
            # (b, new_num_patches_z, new_num_patches_y, new_num_patches_x, dim)

            stage_outputs.append(hidden_states)
            layer_outputs.extend(_layer_outputs)

        return hidden_states, stage_outputs, layer_outputs

In [22]:
test_config = SwinV23DConfig.model_validate(
    {
        "dim": 32,
        "patch_size": (2, 2, 2),
        "in_channels": 32,
        "stages": [
            {
                "depth": 1,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": False,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "out_dim_ratio": 3,
                },
                "depth": 3,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "out_dim_ratio": 3,
                },
                "patch_splitting": {
                    "final_window_size": (2, 2, 2),
                    "out_dim_ratio": 3,
                },
                "depth": 3,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
        ],
    }
)

test = SwinV23DEncoder(test_config)
display(test)
o = test(torch.randn(2, 16, 16, 16, 32))
display((o[0].shape, [x.shape for x in o[1]], [x.shape for x in o[2]]))


[1;35mSwinV23DEncoder[0m[1m([0m
  [1m([0mstages[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwinV23DStage[0m[1m([0m
      [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
        [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwinV23DBlock[0m[1m([0m
          [1m([0mw_layer[1m)[0m: [1;35mSwinV23DLayer[0m[1m([0m
            [1m([0mmhsa[1m)[0m: [1;35mSwinV23DMHSA[0m[1m([0m
              [1m([0mW_qkv[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m32[0m, [33mout_features[0m=[1;36m96[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
              [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m32[0m, [33mout_features[0m=[1;36m32[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
              [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
            [1m)[0m
            [1m([0mlayernorm1


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m96[0m[1m][0m[1m)[0m,
    [1m[[0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m32[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m96[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m96[0m[1m][0m[1m)[0m[1m][0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m96[0m[1m][0m[1m)[0m,
        [1;35mtorch

### Decoder

In [23]:
# | export


class SwinV23DDecoder(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config):
        super().__init__()

        for stage_config in config.stages:
            if stage_config.patch_merging is not None:
                assert (
                    stage_config.patch_splitting is not None
                ), "SwinV23DDecoder is not for encoding (mid blocks are ok)."

        self.stages = nn.ModuleList([SwinV23DStage(stage_config) for stage_config in config.stages])

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (b, num_patches_z, num_patches_y, num_patches_x, dim)

        stage_outputs, layer_outputs = [], []
        for stage_module in self.stages:
            hidden_states, _layer_outputs = stage_module(hidden_states)
            # (b, new_num_patches_z, new_num_patches_y, new_num_patches_x, dim)

            stage_outputs.append(hidden_states)
            layer_outputs.extend(_layer_outputs)

        return hidden_states, stage_outputs, layer_outputs

In [24]:
test_config = SwinV23DConfig.model_validate(
    {
        "dim": 96,
        "patch_size": (2, 2, 2),
        "in_channels": 32,
        "stages": [
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "out_dim_ratio": 3,
                },
                "patch_splitting": {
                    "final_window_size": (2, 2, 2),
                    "out_dim_ratio": 3,
                },
                "depth": 3,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
            {
                "patch_splitting": {
                    "final_window_size": (2, 2, 2),
                    "out_dim_ratio": 3,
                },
                "depth": 3,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
            {
                "depth": 1,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": False,
            },
        ],
    }
)

test = SwinV23DDecoder(test_config)
display(test)
o = test(torch.randn(2, 16, 16, 16, 96))
display((o[0].shape, [x.shape for x in o[1]], [x.shape for x in o[2]]))


[1;35mSwinV23DDecoder[0m[1m([0m
  [1m([0mstages[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwinV23DStage[0m[1m([0m
      [1m([0mpatch_merging[1m)[0m: [1;35mSwinV23DPatchMerging[0m[1m([0m
        [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m768[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m768[0m, [33mout_features[0m=[1;36m288[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m)[0m
      [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
        [1m([0m[1;36m0[0m-[1;36m2[0m[1m)[0m: [1;36m3[0m x [1;35mSwinV23DBlock[0m[1m([0m
          [1m([0mw_layer[1m)[0m: [1;35mSwinV23DLayer[0m[1m([0m
            [1m([0mmhsa[1m)[0m: [1;35mSwinV23DMHSA[0m[1m([0m
              [1m([0mW_qkv[1m)[0m: [1;35mLinear[0m[1m([0


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m32[0m, [1;36m32[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m96[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m32[0m, [1;36m32[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m32[0m, [1;36m32[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m
    [1m][0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m288[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m288[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m288[0m[1m

# Embeddings

### Patch embeddings

In [25]:
# | export


class SwinV23DPatchEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()

        patch_size = config.patch_size
        num_channels = config.in_channels
        dim = config.dim

        self.patch_embeddings = nn.Conv3d(
            in_channels=num_channels,
            out_channels=dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, pixel_values: torch.Tensor):
        # pixel_values: (b, c, z, y, x)

        embeddings = self.patch_embeddings(pixel_values)
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        return embeddings

In [26]:
test_config = SwinV23DConfig.model_validate(
    {
        "patch_size": (1, 8, 8),
        "in_channels": 1,
        "dim": 12,
        "stages": [],
    }
)

test = SwinV23DPatchEmbeddings(test_config)
display(test)
o = test(torch.randn(2, 1, 32, 512, 512))
display(o.shape)


[1;35mSwinV23DPatchEmbeddings[0m[1m([0m
  [1m([0mpatch_embeddings[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m12[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m12[0m, [1;36m32[0m, [1;36m64[0m, [1;36m64[0m[1m][0m[1m)[0m

### Position embeddings

In [27]:
# | export


def get_3d_position_embeddings(embedding_size, grid_size, patch_size=(1, 1, 1)):
    if embedding_size % 6 != 0:
        raise ValueError("embed_dim must be divisible by 6")

    grid = get_coords_grid(grid_size)
    # (3, d, h, w)

    grid = rearrange(grid, "x d h w -> x 1 d h w")
    # (3, 1, d, h, w)

    omega = torch.arange(embedding_size // 6, dtype=torch.float32)
    omega /= embedding_size / 6.0
    omega = 1.0 / 10000**omega
    # (d // 6)

    patch_multiplier = torch.Tensor(patch_size) / min(patch_size)

    position_embeddings = []
    for i, grid_subset in enumerate(grid):
        grid_subset = grid_subset.reshape(-1)
        out = torch.einsum("m,d->md", grid_subset, omega)

        emb_sin = torch.sin(out)
        emb_cos = torch.cos(out)

        emb = torch.cat([emb_sin, emb_cos], axis=1) * patch_multiplier[i]
        position_embeddings.append(emb)

    position_embeddings = torch.cat(position_embeddings, axis=1)
    # (embedding_size, d * h * w)
    d, h, w = grid_size
    position_embeddings = rearrange(position_embeddings, "(d h w) e -> 1 e d h w", d=d, h=h, w=w)
    # (1, embedding_size, d, h, w)

    return position_embeddings

In [28]:
# | export


def embed_spacings_in_position_embeddings(embeddings: torch.Tensor, spacings: torch.Tensor):
    assert spacings is not None, "spacing information cannot be None"
    assert spacings.ndim == 2, "Please provide spacing information for each batch element"
    _, embedding_size, _, _, _ = embeddings.shape
    assert embedding_size % 3 == 0, "To embed spacing info, the embedding size must be divisible by 3"
    embeddings = embeddings * repeat(spacings, f"B S -> B (S {int(embedding_size / 3)}) 1 1 1", S=3)

    return embeddings

In [29]:
# | export


class SwinV23DEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config

        dim = config.dim

        self.patch_embeddings = SwinV23DPatchEmbeddings(config)
        self.layer_norm = nn.LayerNorm(dim)

        self.absolute_position_embeddings = None
        if config.use_absolute_position_embeddings:
            self.absolute_position_embeddings = {}  # embeddings will be cached in this for every input image size
            if config.learnable_absolute_position_embeddings:
                grid_size = self._get_grid_size(config.image_size)
                self.absolute_position_embeddings[config.image_size] = nn.Parameter(
                    torch.randn(1, dim, grid_size[0], grid_size[1], grid_size[2])
                )

    def _get_grid_size(self, image_size):
        grid_size = (
            image_size[0] // self.config.patch_size[0],
            image_size[1] // self.config.patch_size[1],
            image_size[2] // self.config.patch_size[2],
        )
        return grid_size

    def forward(
        self,
        pixel_values: torch.Tensor,
        spacings: torch.Tensor = None,
        mask_patches: torch.Tensor = None,
        mask_token: torch.Tensor = None,
    ):
        # pixel_values: (b, c, z, y, x)

        embeddings = self.patch_embeddings(pixel_values)
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)
        embeddings = rearrange(embeddings, "b d nz ny nx -> b nz ny nx d")
        embeddings = self.layer_norm(embeddings)
        embeddings = rearrange(embeddings, "b nz ny nx d -> b d nz ny nx")
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        if mask_patches is not None:
            # mask_patches (binary mask): (b, num_patches_z, num_patches_y, num_patches_x)
            # mask_token: (1, dim, 1, 1, 1)
            mask_patches = repeat(mask_patches, "b z y x -> b d z y x", d=embeddings.shape[1])
            embeddings = (embeddings * (1 - mask_patches)) + (mask_patches * mask_token)

        if self.absolute_position_embeddings is not None:
            image_size = tuple(pixel_values.shape[-3:])

            if image_size not in self.absolute_position_embeddings:
                grid_size = self._get_grid_size(image_size)
                self.absolute_position_embeddings[image_size] = get_3d_position_embeddings(
                    self.config.dim, grid_size, self.config.patch_size
                )

            absolute_position_embeddings = self.absolute_position_embeddings[image_size].to(embeddings.device)
            # (1, dim, num_patches_z, num_patches_y, num_patches_x)
            if self.config.embed_spacing_info:
                absolute_position_embeddings = embed_spacings_in_position_embeddings(
                    absolute_position_embeddings, spacings
                )
                # (b, dim, num_patches_z, num_patches_y, num_patches_x)

            embeddings = embeddings + absolute_position_embeddings
            # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        return embeddings

In [30]:
test_config = SwinV23DConfig.model_validate(
    {
        "patch_size": (1, 8, 8),
        "in_channels": 1,
        "dim": 36,
        "use_absolute_position_embeddings": True,
        "learnable_absolute_position_embeddings": False,
        "embed_spacing_info": False,
        "stages": [],
    }
)

test = SwinV23DEmbeddings(test_config)
display(test)
o = test(
    torch.randn(2, 1, 32, 512, 512),
    torch.randn(2, 3),
)
display(o.shape)


[1;35mSwinV23DEmbeddings[0m[1m([0m
  [1m([0mpatch_embeddings[1m)[0m: [1;35mSwinV23DPatchEmbeddings[0m[1m([0m
    [1m([0mpatch_embeddings[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m36[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m[1m)[0m
  [1m)[0m
  [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m36[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m36[0m, [1;36m32[0m, [1;36m64[0m, [1;36m64[0m[1m][0m[1m)[0m

In [31]:
e = get_3d_position_embeddings(6, (2, 2, 2), (1, 1, 1))
display(e.shape)

print("Distance of (0, 0, 0) with other coords")
for i in range(2):
    for j in range(2):
        for k in range(2):
            print((i, j, k), np.linalg.norm(e[0, :, 0, 0, 0] - e[0, :, i, j, k]))

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m1[0m, [1;36m6[0m, [1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m][0m[1m)[0m

Distance of (0, 0, 0) with other coords
(0, 0, 0) 0.0
(0, 0, 1) 0.95885104
(0, 1, 0) 0.95885104
(0, 1, 1) 1.3560202
(1, 0, 0) 0.95885104
(1, 0, 1) 1.3560202
(1, 1, 0) 1.3560202
(1, 1, 1) 1.6607786


# Models

In [32]:
# | export


class SwinV23DModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config):
        super().__init__()

        self.embeddings = SwinV23DEmbeddings(config)
        self.pos_drop = nn.Dropout(config.drop_prob)
        self.encoder = SwinV23DEncoder(config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        spacings: torch.Tensor = None,
        mask_patches: torch.Tensor = None,
        mask_token: torch.Tensor = None,
    ):
        # pixel_values: (b, c, z, y, x)
        # spacings: (b, 3)
        # mask_patches: (num_patches_z, num_patches_y, num_patches_x)

        embeddings = self.embeddings(pixel_values, spacings, mask_patches, mask_token)
        embeddings = self.pos_drop(embeddings)
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        embeddings = rearrange(embeddings, "b e nz ny nx -> b nz ny nx e")
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        encoded, stage_outputs, layer_outputs = self.encoder(embeddings)
        # encoded: (b, new_num_patches_z, new_num_patches_y, new_num_patches_x, dim)
        # stage_outputs, layer_outputs: list of (b, some_num_patches_z, some_num_patches_y, some_num_patches_x, dim)

        encoded = rearrange(encoded, "b nz ny nx d -> b d nz ny nx")
        # (b, dim, new_num_patches_z, new_num_patches_y, new_num_patches_x)

        for i in range(len(stage_outputs)):
            stage_outputs[i] = rearrange(stage_outputs[i], "b nz ny nx d -> b d nz ny nx")
            # (b, dim, some_num_patches_z, some_num_patches_y, some_num_patches_x)

        for i in range(len(layer_outputs)):
            layer_outputs[i] = rearrange(layer_outputs[i], "b nz ny nx d -> b d nz ny nx")
            # (b, dim, some_num_patches_z, some_num_patches_y, some_num_patches_x)

        return encoded, stage_outputs, layer_outputs

In [33]:
test_config = SwinV23DConfig.model_validate(
    {
        "patch_size": (1, 8, 8),
        "in_channels": 1,
        "use_absolute_position_embeddings": True,
        "learnable_absolute_position_embeddings": False,
        "embed_spacing_info": False,
        "dim": 36,
        "drop_prob": 0.2,
        "stages": [
            {
                "patch_merging": None,
                "depth": 1,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": False,
                "attn_drop_prob": 0.2,
                "proj_drop_prob": 0.2,
                "mlp_drop_prob": 0.2,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "out_dim_ratio": 3,
                },
                "depth": 3,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
                "attn_drop_prob": 0.2,
                "proj_drop_prob": 0.2,
                "mlp_drop_prob": 0.2,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "out_dim_ratio": 3,
                },
                "depth": 1,
                "num_heads": 4,
                "mlp_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
                "attn_drop_prob": 0.2,
                "proj_drop_prob": 0.2,
            },
        ],
    }
)

test = SwinV23DModel(test_config)
display(test)
o = test(
    torch.randn(2, 1, 32, 512, 512),
    torch.randn(2, 3),
)
display((o[0].shape, [x.shape for x in o[1]], [x.shape for x in o[2]]))


[1;35mSwinV23DModel[0m[1m([0m
  [1m([0membeddings[1m)[0m: [1;35mSwinV23DEmbeddings[0m[1m([0m
    [1m([0mpatch_embeddings[1m)[0m: [1;35mSwinV23DPatchEmbeddings[0m[1m([0m
      [1m([0mpatch_embeddings[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m36[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m[1m)[0m
    [1m)[0m
    [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m36[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
  [1m)[0m
  [1m([0mpos_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.2[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m([0mencoder[1m)[0m: [1;35mSwinV23DEncoder[0m[1m([0m
    [1m([0mstages[1m)[0m: [1;35mModuleList[0m[1m([0m
      [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwinV23DStage[0m[1m


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m324[0m, [1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m36[0m, [1;36m32[0m, [1;36m64[0m, [1;36m64[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m108[0m, [1;36m16[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m324[0m, [1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m
    [1m][0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m36[0m, [1;36m32[0m, [1;36m64[0m, [1;36m64[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m36[0m, [1;36m32[0m, [1;36m64[0m, [1;36m64[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m108[0m, [1;36m16[0m, [1;36m32[0m, [1;36m32

# Masked Image Modeling

In [34]:
# | export


class SwinV23DReconstructionDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.in_channels = config.in_channels

        dim = config.dim
        patch_size = config.patch_size

        out_dim = np.prod(patch_size) * self.in_channels
        self.final_patch_size = patch_size

        self.decoder = nn.Conv3d(dim, out_dim, kernel_size=1)

    def forward(self, encodings: torch.Tensor):
        # encodings: (b, dim, num_patches_z, num_patches_y, num_patches_x)

        decoded = self.decoder(encodings)
        # (b, new_dim, num_patches_z, num_patches_y, num_patches_x)

        decoded = rearrange(
            decoded,
            "b (c pz py px) nz ny nx -> b c (nz pz) (ny py) (nx px)",
            c=self.in_channels,
            pz=self.final_patch_size[0],
            py=self.final_patch_size[1],
            px=self.final_patch_size[2],
        )
        # (b, c, z, y, x)

        return decoded

In [35]:
test_config = SwinV23DConfig.model_validate(
    {
        "dim": 108,
        "patch_size": (2, 16, 16),
        "in_channels": 1,
        "stages": [],
    }
)

test = SwinV23DReconstructionDecoder(test_config)
display(test)
display(test(torch.randn(2, 108, 16, 32, 32)).shape)


[1;35mSwinV23DReconstructionDecoder[0m[1m([0m
  [1m([0mdecoder[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m108[0m, [1;36m512[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1[0m, [1;36m32[0m, [1;36m512[0m, [1;36m512[0m[1m][0m[1m)[0m

In [36]:
# | export


class SwinV23DMIM(nn.Module):
    def __init__(self, swin_config, decoder_config, mim_config):
        super().__init__()

        self.swin_config = swin_config
        self.decoder_config = decoder_config
        self.mim_config = mim_config

        self.swin = SwinV23DModel(swin_config)
        self.decoder = SwinV23DReconstructionDecoder(decoder_config)

        self.mask_token = nn.Parameter(torch.randn(1, swin_config.dim, 1, 1, 1))

    def _get_grid_size(self, image_size):
        grid_size = (
            image_size[0] // self.swin_config.patch_size[0],
            image_size[1] // self.swin_config.patch_size[1],
            image_size[2] // self.swin_config.patch_size[2],
        )
        return grid_size

    def mask_image(self, pixel_values: torch.Tensor):
        b = pixel_values.shape[0]

        mask_ratio = self.mim_config["mask_ratio"]
        mask_grid_size = self.mim_config["mask_grid_size"]
        num_patches = np.prod(mask_grid_size)
        mask_patches = []
        for _ in range(b):
            _mask_patches = torch.zeros(num_patches, dtype=torch.int8, device=pixel_values.device)
            _mask_patches[: int(mask_ratio * num_patches)] = 1
            _mask_patches = _mask_patches[torch.randperm(num_patches)]
            _mask_patches = rearrange(
                _mask_patches, "(z y x) -> z y x", z=mask_grid_size[0], y=mask_grid_size[1], x=mask_grid_size[2]
            )
            mask_patches.append(_mask_patches)
        mask_patches: torch.Tensor = torch.stack(mask_patches, dim=0)

        grid_size = self._get_grid_size(self.swin_config.image_size)
        assert all(
            [x % y == 0 for x, y in zip(grid_size, mask_grid_size)]
        ), "Mask grid size must divide image grid size"
        mask_patches = repeat(
            mask_patches,
            "b z y x -> b (z gz) (y gy) (x gx)",
            gz=grid_size[0] // mask_grid_size[0],
            gy=grid_size[1] // mask_grid_size[1],
            gx=grid_size[2] // mask_grid_size[2],
        )

        return mask_patches

In [37]:
# | export


class SwinV23DSimMIM(SwinV23DMIM, PyTorchModelHubMixin):
    def __init__(self, swin_config, decoder_config, mim_config):
        super().__init__(swin_config, decoder_config, mim_config)

    @staticmethod
    def loss_fn(pred: torch.Tensor, target: torch.Tensor, reduction="mean"):
        return nn.functional.l1_loss(pred, target, reduction=reduction)

    def forward(self, pixel_values: torch.Tensor, spacings: torch.Tensor = None):
        mask_patches = self.mask_image(pixel_values)

        encodings, _, _ = self.swin(pixel_values, spacings, mask_patches, self.mask_token)

        decoded = self.decoder(encodings)

        loss = self.loss_fn(decoded, pixel_values, reduction="none")
        mask = repeat(
            mask_patches,
            "b z y x -> b (z pz) (y py) (x px)",
            pz=self.swin_config.patch_size[0],
            py=self.swin_config.patch_size[1],
            px=self.swin_config.patch_size[2],
        )
        loss = (loss * mask).sum() / ((mask.sum() + 1e-5) * self.swin_config.in_channels)

        return decoded, loss, mask

In [38]:
test_config = {
    "swin": SwinV23DConfig.model_validate(
        {
            "dim": 36,
            "patch_size": (1, 8, 8),
            "image_size": (32, 512, 512),
            "in_channels": 1,
            "use_absolute_position_embeddings": True,
            "learnable_absolute_position_embeddings": False,
            "embed_spacing_info": False,
            "stages": [
                {
                    "patch_merging": None,
                    "depth": 2,
                    "num_heads": 4,
                    "mlp_ratio": 4,
                    "layer_norm_eps": 1e-6,
                    "window_size": (4, 4, 4),
                    "use_relative_position_bias": False,
                },
                {
                    "patch_merging": {
                        "merge_window_size": (2, 2, 2),
                        "out_dim_ratio": 3,
                    },
                    "depth": 6,
                    "num_heads": 4,
                    "mlp_ratio": 4,
                    "layer_norm_eps": 1e-6,
                    "window_size": (4, 4, 4),
                    "use_relative_position_bias": True,
                },
            ],
        }
    ),
    "mim": {
        "mask_ratio": 0.8,
        "mask_grid_size": (8, 8, 8),
    },
}
test_config["decoder"] = SwinV23DConfig.model_validate(
    {
        "dim": test_config["swin"].stages[-1]._out_dim,
        "patch_size": test_config["swin"].stages[-1]._out_patch_size,
        "in_channels": test_config["swin"].in_channels,
        "stages": [],
    }
)

test = SwinV23DSimMIM(test_config["swin"], test_config["decoder"], test_config["mim"])
display(test)
o = test(
    torch.randn(2, 1, 32, 512, 512),
    torch.randn(2, 3),
)
display((o[0].shape, o[1], o[2].shape))


[1;35mSwinV23DSimMIM[0m[1m([0m
  [1m([0mswin[1m)[0m: [1;35mSwinV23DModel[0m[1m([0m
    [1m([0membeddings[1m)[0m: [1;35mSwinV23DEmbeddings[0m[1m([0m
      [1m([0mpatch_embeddings[1m)[0m: [1;35mSwinV23DPatchEmbeddings[0m[1m([0m
        [1m([0mpatch_embeddings[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m36[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m[1m)[0m
      [1m)[0m
      [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m36[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
    [1m)[0m
    [1m([0mpos_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mencoder[1m)[0m: [1;35mSwinV23DEncoder[0m[1m([0m
      [1m([0mstages[1m)[0m: [1;35mModuleList


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1[0m, [1;36m32[0m, [1;36m512[0m, [1;36m512[0m[1m][0m[1m)[0m,
    [1;35mtensor[0m[1m([0m[1;36m4.7475[0m, [33mgrad_fn[0m=[1m<[0m[1;95mDivBackward0[0m[1m>[0m[1m)[0m,
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m32[0m, [1;36m512[0m, [1;36m512[0m[1m][0m[1m)[0m
[1m)[0m

In [39]:
# | export


class SwinV23DVAEMIM(SwinV23DMIM, PyTorchModelHubMixin):
    def __init__(self, swin_config, decoder_config, mim_config):
        super().__init__(swin_config, decoder_config, mim_config)

        assert (decoder_config["beta"] is None) is not (
            decoder_config["beta_schedule"] is None
        ), "Only one of beta or beta_schedule should be provided"

        if decoder_config["beta_schedule"] is not None:
            self.beta_schedule = decoder_config["beta_schedule"]
            self.beta_increment = (self.beta_schedule[2] - self.beta_schedule[1]) / self.beta_schedule[0]
            self.beta = None
        else:
            self.beta = decoder_config["beta"]
            self.beta_schedule = None
            self.beta_increment = None

        self.mu_layer = nn.Conv3d(swin_config.stages[-1]._out_dim, decoder_config.dim, kernel_size=1)
        self.logvar_layer = nn.Conv3d(swin_config.stages[-1]._out_dim, decoder_config.dim, kernel_size=1)

    def get_beta(self):
        # If fixed beta
        if self.beta_schedule is None:
            return self.beta

        # Else there is a beta schedule
        if self.beta is None:
            # If first iteration
            self.beta = self.beta_schedule[1]
        else:
            # Calculate new beta and return
            self.beta = min(self.beta + self.beta_increment, self.beta_schedule[2])
        return self.beta

    def reparameterize(self, mu, logvar):
        return mu + torch.randn_like(logvar) * torch.exp(0.5 * logvar)

    @staticmethod
    def reconstruction_loss_fn(pred: torch.Tensor, target: torch.Tensor, loss_type: str = "l2", reduction="mean"):
        loss = ...
        if loss_type == "l2":
            loss = nn.functional.mse_loss(pred, target, reduction=reduction)
        elif loss_type == "l1":
            loss = nn.functional.l1_loss(pred, target, reduction=reduction)
        else:
            raise NotImplementedError(f"Loss type {loss_type} not implemented")
        return loss

    @staticmethod
    def kl_divergence_loss_fn(mu: torch.Tensor, logvar: torch.Tensor):
        return torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))

    def forward(
        self,
        pixel_values: torch.Tensor,
        spacings: torch.Tensor = None,
        reconstruction_loss_type: str = "l2",
    ):
        mask_patches = self.mask_image(pixel_values)

        encodings, _, _ = self.swin(pixel_values, spacings, mask_patches, self.mask_token)

        mu = self.mu_layer(encodings)
        logvar = self.logvar_layer(encodings)
        kl_loss = self.kl_divergence_loss_fn(mu, logvar)

        sampled = self.reparameterize(mu, logvar)
        decoded = self.decoder(sampled)

        reconstruction_loss = self.reconstruction_loss_fn(decoded, pixel_values, reconstruction_loss_type)

        mask = repeat(
            mask_patches,
            "b z y x -> b (z pz) (y py) (x px)",
            pz=self.swin_config.patch_size[0],
            py=self.swin_config.patch_size[1],
            px=self.swin_config.patch_size[2],
        )

        beta = self.get_beta()
        loss = reconstruction_loss + beta * kl_loss

        return decoded, loss, mask, [reconstruction_loss, kl_loss, beta]

In [40]:
test_config = {
    "swin": SwinV23DConfig.model_validate(
        {
            "dim": 36,
            "patch_size": (1, 8, 8),
            "image_size": (32, 512, 512),
            "in_channels": 1,
            "use_absolute_position_embeddings": True,
            "learnable_absolute_position_embeddings": False,
            "embed_spacing_info": False,
            "stages": [
                {
                    "patch_merging": None,
                    "depth": 2,
                    "num_heads": 4,
                    "mlp_ratio": 4,
                    "layer_norm_eps": 1e-6,
                    "window_size": (4, 4, 4),
                    "use_relative_position_bias": False,
                },
                {
                    "patch_merging": {
                        "merge_window_size": (2, 2, 2),
                        "out_dim_ratio": 3,
                    },
                    "depth": 6,
                    "num_heads": 4,
                    "mlp_ratio": 4,
                    "layer_norm_eps": 1e-6,
                    "window_size": (4, 4, 4),
                    "use_relative_position_bias": True,
                },
            ],
        }
    ),
    "mim": {
        "mask_ratio": 0.8,
        "mask_grid_size": (8, 8, 8),
    },
}
test_config["decoder"] = munchify(
    {
        "dim": test_config["swin"].stages[-1]._out_dim,
        "patch_size": test_config["swin"].stages[-1]._out_patch_size,
        "in_channels": test_config["swin"].in_channels,
        "beta": 1,
        "beta_schedule": None,
    }
)

test = SwinV23DVAEMIM(test_config["swin"], test_config["decoder"], test_config["mim"])
display(test)
o = test(
    torch.randn(2, 1, 32, 512, 512),
    torch.randn(2, 3),
)
display((o[0].shape, o[1], o[2].shape, o[3]))


[1;35mSwinV23DVAEMIM[0m[1m([0m
  [1m([0mswin[1m)[0m: [1;35mSwinV23DModel[0m[1m([0m
    [1m([0membeddings[1m)[0m: [1;35mSwinV23DEmbeddings[0m[1m([0m
      [1m([0mpatch_embeddings[1m)[0m: [1;35mSwinV23DPatchEmbeddings[0m[1m([0m
        [1m([0mpatch_embeddings[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m36[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m[1m)[0m
      [1m)[0m
      [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m36[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
    [1m)[0m
    [1m([0mpos_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mencoder[1m)[0m: [1;35mSwinV23DEncoder[0m[1m([0m
      [1m([0mstages[1m)[0m: [1;35mModuleList


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1[0m, [1;36m32[0m, [1;36m512[0m, [1;36m512[0m[1m][0m[1m)[0m,
    [1;35mtensor[0m[1m([0m[1;36m2845.8687[0m, [33mgrad_fn[0m=[1m<[0m[1;95mAddBackward0[0m[39m>[0m[1;39m)[0m[39m,[0m
[39m    [0m[1;35mtorch.Size[0m[1;39m([0m[1;39m[[0m[1;36m2[0m[39m, [0m[1;36m32[0m[39m, [0m[1;36m512[0m[39m, [0m[1;36m512[0m[1;39m][0m[1;39m)[0m[39m,[0m
[39m    [0m[1;39m[[0m[1;35mtensor[0m[1;39m([0m[1;36m18.8984[0m[39m, [0m[33mgrad_fn[0m[39m=<MseLossBackward0>[0m[1;39m)[0m[39m, [0m[1;35mtensor[0m[1;39m([0m[1;36m2826.9702[0m[39m, [0m[33mgrad_fn[0m[39m=<MeanBackward0[0m[1m>[0m[1m)[0m, [1;36m1[0m[1m][0m
[1m)[0m

# Some more tests

### Overfitting tests

In [41]:
from tqdm.auto import tqdm

sample_spacings = torch.tensor([[1, 0.1, 0.1], [2, 0.2, 0.2], [3, 0.3, 0.3], [4, 0.4, 0.4], [5, 0.5, 0.5]])
sample_batch = torch.rand(3, 1, 16, 128, 128)
sample_config = {
    "swin": SwinV23DConfig.model_validate(
        {
            "patch_size": (1, 4, 4),
            "image_size": (16, 128, 128),
            "dim": 12,
            "in_channels": 1,
            "use_absolute_position_embeddings": True,
            "learnable_absolute_position_embeddings": False,
            "embed_spacing_info": False,
            "drop_prob": 0.2,
            "stages": [
                {
                    "patch_merging": None,
                    "depth": 1,
                    "num_heads": 4,
                    "mlp_ratio": 4,
                    "layer_norm_eps": 1e-6,
                    "window_size": (4, 4, 4),
                    "use_relative_position_bias": True,
                    "attn_drop_prob": 0.2,
                    "proj_drop_prob": 0.2,
                    "mlp_drop_prob": 0.2,
                },
                {
                    "patch_merging": {
                        "merge_window_size": (2, 2, 2),
                        "out_dim_ratio": 4,
                    },
                    "depth": 3,
                    "num_heads": 4,
                    "mlp_ratio": 4,
                    "layer_norm_eps": 1e-6,
                    "window_size": (4, 4, 4),
                    "use_relative_position_bias": True,
                },
                {
                    "patch_merging": {
                        "merge_window_size": (2, 2, 2),
                        "out_dim_ratio": 4,
                    },
                    "depth": 1,
                    "num_heads": 4,
                    "mlp_ratio": 4,
                    "layer_norm_eps": 1e-6,
                    "window_size": (4, 4, 4),
                    "use_relative_position_bias": True,
                },
            ],
        }
    ),
    "mim": {
        "mask_ratio": 0.7,
        "mask_grid_size": (8, 8, 8),
    },
}
sample_config["decoder"] = munchify(
    {
        "dim": sample_config["swin"].stages[-1]._out_dim,
        "patch_size": sample_config["swin"].stages[-1]._out_patch_size,
        "in_channels": sample_config["swin"].in_channels,
    }
)

model = SwinV23DSimMIM(sample_config["swin"], sample_config["decoder"], sample_config["mim"])

sum(x.numel() for x in model.swin.parameters()), sum(x.numel() for x in model.decoder.parameters())

[1m([0m[1;36m1183892[0m, [1;36m197632[0m[1m)[0m

In [42]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

In [43]:
sample_batch = sample_batch.cuda()
sample_spacings = sample_spacings.cuda()
model = model.cuda()

In [44]:
for i in tqdm(range(200)):
    optimizer.zero_grad()
    output = model(sample_batch, sample_spacings)
    print(f"Loss: {output[1]:f}\tLR: {scheduler.get_last_lr()[0]:f}")
    output[1].backward()
    optimizer.step()
    scheduler.step()

  0%|          | 0/200 [00:00<?, ?it/s]

In [45]:
for name, param in model.named_parameters():
    if param.grad is None:
        print(name)

In [None]:
sample_config["decoder"]["beta"] = None
sample_config["decoder"]["beta_schedule"] = (100, 0.0, 2.0)

model = SwinV23DVAEMIM(sample_config["swin"], sample_config["decoder"], sample_config["mim"])

encoder_params = sum(x.numel() for x in model.swin.parameters())
decoder_params = sum(x.numel() for x in model.decoder.parameters())
sampling_params = sum(x.numel() for x in model.mu_layer.parameters()) + sum(
    x.numel() for x in model.logvar_layer.parameters()
)
encoder_params, decoder_params, sampling_params

In [47]:
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

In [48]:
sample_batch = sample_batch.cuda()
sample_spacings = sample_spacings.cuda()
model = model.cuda()

In [None]:
for i in tqdm(range(200)):
    optimizer.zero_grad()
    output = model(sample_batch, sample_spacings)
    print(f"Loss: {output[1]:f}\tLR: {scheduler.get_last_lr()[0]:f}\tBeta: {output[3][2]:f}")
    # print(output[-1])
    output[1].backward()
    optimizer.step()
    scheduler.step()

In [50]:
for name, param in model.named_parameters():
    if param.grad is None:
        print(name)

# nbdev

In [51]:
!nbdev_export

# Rough work