In [1]:
# | default_exp swin_3d

# Imports

In [2]:
# | export

import torch
import numpy as np
from torch import nn
from einops import rearrange, repeat

# Config

In [3]:
# | export


def populate_and_validate_config(config: dict) -> dict:
    assert config["stages"][0]["patch_merging"] is None

    # Prepare config based on provided values
    dim = config["dim"]
    patch_size = config["patch_size"]
    # image_size = config["image_size"]  # This may not be fixed while fine-tuning.
    for i in range(len(config["stages"])):
        stage = config["stages"][i]
        stage["_in_dim"] = dim
        stage["_in_patch_size"] = patch_size
        # stage['_in"grid_size'] = tuple([image // patch for image, patch in zip(image_size, patch_size)])
        if stage["patch_merging"] is not None:
            dim *= stage["patch_merging"]["out_dim_ratio"]
            patch_size = tuple(
                [patch * window for patch, window in zip(patch_size, stage["patch_merging"]["merge_window_size"])]
            )
        stage["_out_dim"] = dim
        stage["_out_patch_size"] = patch_size
        # stage["_out_grid_size"] = tuple([image // patch for image, patch in zip(image_size, patch_size)])

    for stage in config["stages"]:
        assert stage["_out_dim"] % stage["num_heads"] == 0, stage

    return config

In [4]:
test_config = {
    "patch_size": (1, 8, 8),
    "in_channels": 1,
    "use_absolute_position_embeddings": True,
    "learnable_absolute_position_embeddings": False,
    "embed_spacing_info": False,
    "image_size": (32, 512, 512),
    "dim": 36,
    "stages": [
        {
            "patch_merging": None,
            "depth": 1,
            "num_heads": 4,
            "intermediate_ratio": 4,
            "layer_norm_eps": 1e-6,
            "window_size": (4, 4, 4),
            "use_relative_position_bias": False,
        },
        {
            "patch_merging": {
                "merge_window_size": (2, 2, 2),
                "out_dim_ratio": 3,
            },
            "depth": 3,
            "num_heads": 4,
            "intermediate_ratio": 4,
            "layer_norm_eps": 1e-6,
            "window_size": (4, 4, 4),
            "use_relative_position_bias": True,
        },
        {
            "patch_merging": {
                "merge_window_size": (2, 2, 2),
                "out_dim_ratio": 3,
            },
            "depth": 1,
            "num_heads": 4,
            "intermediate_ratio": 4,
            "layer_norm_eps": 1e-6,
            "window_size": (4, 4, 4),
            "use_relative_position_bias": True,
        },
    ],
}

populate_and_validate_config(test_config)


[1m{[0m
    [32m'patch_size'[0m: [1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m,
    [32m'in_channels'[0m: [1;36m1[0m,
    [32m'use_absolute_position_embeddings'[0m: [3;92mTrue[0m,
    [32m'learnable_absolute_position_embeddings'[0m: [3;91mFalse[0m,
    [32m'embed_spacing_info'[0m: [3;91mFalse[0m,
    [32m'image_size'[0m: [1m([0m[1;36m32[0m, [1;36m512[0m, [1;36m512[0m[1m)[0m,
    [32m'dim'[0m: [1;36m36[0m,
    [32m'stages'[0m: [1m[[0m
        [1m{[0m
            [32m'patch_merging'[0m: [3;35mNone[0m,
            [32m'depth'[0m: [1;36m1[0m,
            [32m'num_heads'[0m: [1;36m4[0m,
            [32m'intermediate_ratio'[0m: [1;36m4[0m,
            [32m'layer_norm_eps'[0m: [1;36m1e-06[0m,
            [32m'window_size'[0m: [1m([0m[1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m)[0m,
            [32m'use_relative_position_bias'[0m: [3;91mFalse[0m,
            [32m'_in_dim'[0m: [1;36m36[0m,
            [32m

# Architecture

### Basic Layers

In [5]:
# | export


def get_coords_grid(grid_size):
    d, h, w = grid_size

    grid_d = torch.arange(d, dtype=torch.int32)
    grid_h = torch.arange(h, dtype=torch.int32)
    grid_w = torch.arange(w, dtype=torch.int32)

    grid = torch.meshgrid(grid_w, grid_h, grid_d, indexing="ij")
    grid = torch.stack(grid, axis=0)
    # (3, d, h, w)

    return grid

In [6]:
# | export


class Swin3DMHSA(nn.Module):
    def __init__(self, dim, num_heads, window_size, use_relative_position_bias):
        super().__init__()

        assert dim % num_heads == 0, "dimension must be divisible by number of heads"

        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size

        self.per_head_dim = int(dim // num_heads)

        self.W_qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)

        # TODO: Add embed_spacing_info functionality
        # TODO: Add dropout everywhere
        self.use_relative_position_bias = use_relative_position_bias
        if use_relative_position_bias:
            relative_limits = (2 * window_size[0] - 1, 2 * window_size[1] - 1, 2 * window_size[2] - 1)

            self.relative_position_bias_table = nn.Parameter(torch.randn(num_heads, np.prod(relative_limits)))

            coords = get_coords_grid(window_size)
            coords_flatten = rearrange(
                coords, "three_dimensional d h w -> three_dimensional (d h w)", three_dimensional=3
            )
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()
            relative_coords[:, :, 0] += window_size[0] - 1
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 2] += window_size[2] - 1
            relative_position_index: torch.Tensor = (
                relative_coords[:, :, 0] * relative_limits[1] * relative_limits[2]
                + relative_coords[:, :, 1] * relative_limits[2]
                + relative_coords[:, :, 2]
            )
            self.relative_position_index = relative_position_index.flatten()
        
        self.init()

    def init(self):
        nn.init.xavier_normal_(self.W_qkv.weight)
        nn.init.xavier_normal_(self.proj.weight)

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (windowed_b, window_size_z window_size_y window_size_x, dim)
        _, num_patches_z, num_patches_y, num_patches_x, _ = hidden_states.shape

        query, key, value = rearrange(
            self.W_qkv(hidden_states),
            "b nz ny nx (n num_heads d) -> n b num_heads (nz ny nx) d",
            n=3,
            num_heads=self.num_heads,
        )
        # num_patches = window_size_z window_size_y window_size_x
        # Each is (windowed_b, num_heads, num_patches, per_head_dim)

        attention_scores = query @ rearrange(key, "b num_heads n d -> b num_heads d n")
        attention_scores = attention_scores / (self.per_head_dim**0.5)
        # (windowed_b, num_heads, num_patches, num_patches)

        if self.use_relative_position_bias:
            relative_position_bias = self.relative_position_bias_table[:, self.relative_position_index]
            relative_position_bias = relative_position_bias.reshape(
                1, np.prod(self.window_size), np.prod(self.window_size), -1
            )
            relative_position_bias = relative_position_bias.permute(0, 3, 1, 2).contiguous()
            attention_scores = attention_scores + relative_position_bias

        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        # (windowed_b, num_heads, num_patches, num_patches)

        context = attention_probs @ value
        # (windowed_b, num_heads, num_patches, per_head_dim)
        context = rearrange(
            context,
            "b num_heads (num_patches_z num_patches_y num_patches_x) d -> "
            "b num_patches_z num_patches_y num_patches_x (num_heads d)",
            num_patches_z=num_patches_z,
            num_patches_y=num_patches_y,
            num_patches_x=num_patches_x,
        )
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        context = self.proj(context)
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        return context

In [7]:
test = Swin3DMHSA(54, 6, (4, 4, 4), True)
display(test)
display(test(torch.randn(2, 4, 4, 4, 54)).shape)


[1;35mSwin3DMHSA[0m[1m([0m
  [1m([0mW_qkv[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m162[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m54[0m[1m][0m[1m)[0m

In [8]:
# | export


class Swin3DLayerMLP(nn.Module):
    def __init__(self, dim, intermediate_ratio):
        super().__init__()
        self.dense1 = nn.Linear(dim, dim * intermediate_ratio)
        self.act = nn.GELU()
        self.dense2 = nn.Linear(dim * intermediate_ratio, dim)
        self.init()

    def init(self):
        nn.init.xavier_normal_(self.dense1.weight)
        nn.init.xavier_normal_(self.dense2.weight)

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (windowed_b, window_size_z window_size_y window_size_x, dim)
        hidden_states = self.dense1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dense2(hidden_states)
        return hidden_states

In [9]:
test = Swin3DLayerMLP(64, 256)
display(test)
display(test(torch.randn(2, 4, 4, 4, 64)).shape)


[1;35mSwin3DLayerMLP[0m[1m([0m
  [1m([0mdense1[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m16384[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mact[1m)[0m: [1;35mGELU[0m[1m([0m[33mapproximate[0m=[32m'none'[0m[1m)[0m
  [1m([0mdense2[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m16384[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m

In [10]:
# | export


class Swin3DLayer(nn.Module):
    def __init__(self, dim, num_heads, intermediate_ratio, layer_norm_eps, window_size, use_relative_position_bias):
        super().__init__()

        self.window_size = window_size

        self.layernorm_before = nn.LayerNorm(dim, eps=layer_norm_eps)
        self.mhsa = Swin3DMHSA(dim, num_heads, window_size, use_relative_position_bias)
        self.layernorm_after = nn.LayerNorm(dim, eps=layer_norm_eps)
        self.mlp = Swin3DLayerMLP(dim, intermediate_ratio)
        
        self.init()

    def init(self):
        nn.init.constant_(self.layernorm_before.weight, 1.0)
        nn.init.constant_(self.layernorm_before.bias, 0.0)
        nn.init.constant_(self.layernorm_after.weight, 1.0)
        nn.init.constant_(self.layernorm_after.bias, 0.0)

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (b, num_patches_z, num_patches_y, num_patches_x, dim)
        _, num_patches_z, num_patches_y, num_patches_x, _ = hidden_states.shape

        # Perform windowing
        window_size_z, window_size_y, window_size_x = self.window_size
        num_windows_z, num_windows_y, num_windows_x = (
            num_patches_z // window_size_z,
            num_patches_y // window_size_y,
            num_patches_x // window_size_x,
        )
        hidden_states = rearrange(
            hidden_states,
            "b (num_windows_z window_size_z) (num_windows_y window_size_y) (num_windows_x window_size_x) dim -> "
            "(b num_windows_z num_windows_y num_windows_x) window_size_z window_size_y window_size_x dim ",
            num_windows_z=num_windows_z,
            num_windows_y=num_windows_y,
            num_windows_x=num_windows_x,
            window_size_z=window_size_z,
            window_size_y=window_size_y,
            window_size_x=window_size_x,
        )

        res_connection1 = hidden_states
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        hidden_states = self.layernorm_before(hidden_states)
        hidden_states = self.mhsa(hidden_states)
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        res_connection2 = hidden_states + res_connection1
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        hidden_states = self.layernorm_after(res_connection2)
        hidden_states = self.mlp(hidden_states)
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        hidden_states = hidden_states + res_connection2
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        # Undo windowing
        output = rearrange(
            hidden_states,
            "(b num_windows_z num_windows_y num_windows_x) window_size_z window_size_y window_size_x dim -> "
            "b (num_windows_z window_size_z) (num_windows_y window_size_y) (num_windows_x window_size_x) dim",
            num_windows_z=num_windows_z,
            num_windows_y=num_windows_y,
            num_windows_x=num_windows_x,
            window_size_z=window_size_z,
            window_size_y=window_size_y,
            window_size_x=window_size_x,
        )

        return output

In [11]:
test = Swin3DLayer(64, 4, 256, 1e-6, (4, 4, 4), True)
display(test)
display(test(torch.randn(2, 4, 4, 4, 64)).shape)


[1;35mSwin3DLayer[0m[1m([0m
  [1m([0mlayernorm_before[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m64[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-06[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mmhsa[1m)[0m: [1;35mSwin3DMHSA[0m[1m([0m
    [1m([0mW_qkv[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m192[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m)[0m
  [1m([0mlayernorm_after[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m64[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-06[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mmlp[1m)[0m: [1;35mSwin3DLayerMLP[0m[1m([0m
    [1m([0mdense1[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m

### Stage layers

In [12]:
# | export


class Swin3DBlock(nn.Module):
    def __init__(self, stage_config):
        super().__init__()

        self.stage_config = stage_config
        self.w_layer = Swin3DLayer(
            stage_config["_out_dim"],
            stage_config["num_heads"],
            stage_config["intermediate_ratio"],
            stage_config["layer_norm_eps"],
            stage_config["window_size"],
            stage_config["use_relative_position_bias"],
        )
        self.sw_layer = Swin3DLayer(
            stage_config["_out_dim"],
            stage_config["num_heads"],
            stage_config["intermediate_ratio"],
            stage_config["layer_norm_eps"],
            stage_config["window_size"],
            stage_config["use_relative_position_bias"],
        )

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (b, num_patches_z, num_patches_y, num_patches_x, dim)

        layer_outputs = []

        # First layer
        hidden_states = self.w_layer(hidden_states)
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        layer_outputs.append(hidden_states)

        # Shift windows
        window_size_z, window_size_y, window_size_x = self.stage_config["window_size"]
        shifts = (window_size_z // 2, window_size_y // 2, window_size_x // 2)
        hidden_states = torch.roll(hidden_states, shifts=shifts, dims=(1, 2, 3))
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        # Second layer
        hidden_states = self.sw_layer(hidden_states)
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        # Reverse window shift
        shifts = tuple(-shift for shift in shifts)
        hidden_states = torch.roll(hidden_states, shifts=shifts, dims=(1, 2, 3))
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        layer_outputs.append(hidden_states)

        return hidden_states, layer_outputs

In [13]:
test_stage_config = {
    "depth": 4,
    "_out_dim": 64,
    "num_heads": 4,
    "intermediate_ratio": 4,
    "layer_norm_eps": 1e-6,
    "window_size": (4, 4, 4),
    "use_relative_position_bias": True,
}

test = Swin3DBlock(test_stage_config)
display(test)
o = test(torch.randn(2, 4, 4, 4, 64))
display((o[0].shape, (o[1][0].shape, o[1][1].shape)))


[1;35mSwin3DBlock[0m[1m([0m
  [1m([0mw_layer[1m)[0m: [1;35mSwin3DLayer[0m[1m([0m
    [1m([0mlayernorm_before[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m64[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-06[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mmhsa[1m)[0m: [1;35mSwin3DMHSA[0m[1m([0m
      [1m([0mW_qkv[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m192[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m)[0m
    [1m([0mlayernorm_after[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m64[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-06[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mmlp[1m)[0m: [1;35mSwin3DLayerMLP[0m[1m([0m
      [1m([0mdense1[1m)[0m: [1;35

[1m([0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m, [1m([0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m[1m)[0m[1m)[0m

In [14]:
# | export


class Swin3DPatchMerging(nn.Module):
    def __init__(self, merge_window_size, in_dim, out_dim):
        super().__init__()

        self.merge_window_size = merge_window_size

        in_dim = in_dim * np.prod(merge_window_size)
        self.layer_norm = nn.LayerNorm(in_dim)
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (b, num_patches_z, num_patches_y, num_patches_x, dim)

        window_size_z, window_size_y, window_size_x = self.merge_window_size

        hidden_states = rearrange(
            hidden_states,
            "b (new_num_patches_z window_size_z) (new_num_patches_y window_size_y) (new_num_patches_x window_size_x) dim -> "
            "b new_num_patches_z new_num_patches_y new_num_patches_x (window_size_z window_size_y window_size_x dim)",
            window_size_z=window_size_z,
            window_size_y=window_size_y,
            window_size_x=window_size_x,
        )

        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.proj(hidden_states)
        return hidden_states

In [15]:
test_stage_config = {
    "_in_dim": 64,
    "_out_dim": 64 * 3,
    "patch_merging": {
        "merge_window_size": (2, 2, 2),
        "out_dim_ratio": 3,
    },
    "depth": 4,
    "num_heads": 4,
    "intermediate_size": 256,
    "layer_norm_eps": 1e-6,
    "window_size": (4, 4, 4),
    "use_relative_position_bias": True,
}

test = Swin3DPatchMerging(
    test_stage_config["patch_merging"]["merge_window_size"],
    test_stage_config["_in_dim"],
    test_stage_config["_out_dim"],
)
display(test)
display(test(torch.randn(2, 4, 4, 4, 64)).shape)


[1;35mSwin3DPatchMerging[0m[1m([0m
  [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m512[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m512[0m, [33mout_features[0m=[1;36m192[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m2[0m, [1;36m2[0m, [1;36m2[0m, [1;36m192[0m[1m][0m[1m)[0m

In [16]:
# | export


class Swin3DStage(nn.Module):
    def __init__(self, stage_config):
        super().__init__()

        self.config = stage_config

        self.patch_merging = None
        if stage_config["patch_merging"] is not None:
            self.patch_merging = Swin3DPatchMerging(
                stage_config["patch_merging"]["merge_window_size"],
                stage_config["_in_dim"],
                stage_config["_out_dim"],
            )

        self.blocks = nn.ModuleList(
            [Swin3DBlock(stage_config) for _ in range(stage_config["depth"])],
        )

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (b, num_patches_z, num_patches_y, num_patches_x, dim)

        if self.patch_merging:
            hidden_states = self.patch_merging(hidden_states)
            # (b, new_num_patches_z, new_num_patches_y, new_num_patches_x, new_dim)

        layer_outputs = []
        for layer_module in self.blocks:
            hidden_states, _layer_outputs = layer_module(hidden_states)
            # (b, new_num_patches_z, new_num_patches_y, new_num_patches_x, new_dim)
            layer_outputs.extend(_layer_outputs)

        return hidden_states, layer_outputs

In [17]:
test_stage_config = {
    "patch_merging": {
        "merge_window_size": (2, 2, 2),
        "out_dim_ratio": 3,
    },
    "depth": 2,
    "_in_dim": 48,
    "_out_dim": 48 * 3,
    "num_heads": 4,
    "intermediate_ratio": 4,
    "layer_norm_eps": 1e-6,
    "window_size": (4, 4, 4),
    "use_relative_position_bias": True,
}

test = Swin3DStage(test_stage_config)
display(test)
o = test(torch.randn(2, 8, 8, 8, 48))
display((o[0].shape, [x.shape for x in o[1]]))


[1;35mSwin3DStage[0m[1m([0m
  [1m([0mpatch_merging[1m)[0m: [1;35mSwin3DPatchMerging[0m[1m([0m
    [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m384[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m384[0m, [33mout_features[0m=[1;36m144[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m)[0m
  [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m1[0m[1m)[0m: [1;36m2[0m x [1;35mSwin3DBlock[0m[1m([0m
      [1m([0mw_layer[1m)[0m: [1;35mSwin3DLayer[0m[1m([0m
        [1m([0mlayernorm_before[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m144[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-06[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mmhsa[1m)[0m: [1;35mSwin3DMHSA[0m[1m([0m
          [1m([0mW_qkv[1m)[0m:


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m144[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m144[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m144[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m144[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m144[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m

### Encoder

In [18]:
# | export


class Swin3DEncoder(nn.Module):
    def __init__(self, config, default_layer_norm_eps=1e-6):
        super().__init__()

        self.stages = nn.ModuleList([Swin3DStage(stage_config) for stage_config in config["stages"]])

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (b, num_patches_z, num_patches_y, num_patches_x, dim)

        stage_outputs, layer_outputs = [], []
        for stage_module in self.stages:
            hidden_states, _layer_outputs = stage_module(hidden_states)
            # (b, new_num_patches_z, new_num_patches_y, new_num_patches_x, dim)

            stage_outputs.append(hidden_states)
            layer_outputs.extend(_layer_outputs)

        return hidden_states, stage_outputs, layer_outputs

In [19]:
test_config = {
    "stages": [
        {
            "patch_merging": None,
            "_in_dim": 32,
            "_out_dim": 32,
            "depth": 1,
            "num_heads": 4,
            "intermediate_ratio": 4,
            "layer_norm_eps": 1e-6,
            "window_size": (4, 4, 4),
            "use_relative_position_bias": False,
        },
        {
            "patch_merging": {
                "merge_window_size": (2, 2, 2),
                "out_dim_ratio": 3,
            },
            "_in_dim": 32,
            "_out_dim": 32 * 3,
            "depth": 3,
            "num_heads": 4,
            "intermediate_ratio": 4,
            "layer_norm_eps": 1e-6,
            "window_size": (4, 4, 4),
            "use_relative_position_bias": True,
        },
    ],
}

test = Swin3DEncoder(test_config)
display(test)
o = test(torch.randn(2, 16, 16, 16, 32))
display((o[0].shape, [x.shape for x in o[1]], [x.shape for x in o[2]]))


[1;35mSwin3DEncoder[0m[1m([0m
  [1m([0mstages[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwin3DStage[0m[1m([0m
      [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
        [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwin3DBlock[0m[1m([0m
          [1m([0mw_layer[1m)[0m: [1;35mSwin3DLayer[0m[1m([0m
            [1m([0mlayernorm_before[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m32[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-06[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
            [1m([0mmhsa[1m)[0m: [1;35mSwin3DMHSA[0m[1m([0m
              [1m([0mW_qkv[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m32[0m, [33mout_features[0m=[1;36m96[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
              [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m32[0m, [33mout_features[0m=[1;36m32[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m96[0m[1m][0m[1m)[0m,
    [1m[[0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m32[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m96[0m[1m][0m[1m)[0m[1m][0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m96[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m96[0m[1m][0m[1m)[0m,
        [1;

# Embeddings

### Patch embeddings

In [20]:
# | export


class Swin3DPatchEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()

        patch_size = config["patch_size"]
        num_channels = config["in_channels"]
        dim = config["dim"]

        self.patch_embeddings = nn.Conv3d(
            in_channels=num_channels,
            out_channels=dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

        self.init()

    def init(self):
        nn.init.xavier_normal_(self.patch_embeddings.weight)
        nn.init.constant_(self.patch_embeddings.bias, 0)

    def forward(self, pixel_values: torch.Tensor):
        # pixel_values: (b, c, z, y, x)

        embeddings = self.patch_embeddings(pixel_values)
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        return embeddings

In [21]:
test_config = {
    "patch_size": (1, 8, 8),
    "in_channels": 1,
    "dim": 12,
}

test = Swin3DPatchEmbeddings(test_config)
display(test)
o = test(torch.randn(2, 1, 32, 512, 512))
display(o.shape)


[1;35mSwin3DPatchEmbeddings[0m[1m([0m
  [1m([0mpatch_embeddings[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m12[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m12[0m, [1;36m32[0m, [1;36m64[0m, [1;36m64[0m[1m][0m[1m)[0m

### Position embeddings

In [22]:
# | export


def get_3d_position_embeddings(embedding_size, grid_size, patch_size=(1, 1, 1)):
    if embedding_size % 6 != 0:
        raise ValueError("embed_dim must be divisible by 6")

    grid = get_coords_grid(grid_size)
    # (3, d, h, w)

    grid = rearrange(grid, "x d h w -> x 1 d h w")
    # (3, 1, d, h, w)

    omega = torch.arange(embedding_size // 6, dtype=torch.float32)
    omega /= embedding_size / 6.0
    omega = 1.0 / 10000**omega
    # (d // 6)

    patch_multiplier = torch.Tensor(patch_size) / min(patch_size)

    position_embeddings = []
    for i, grid_subset in enumerate(grid):
        grid_subset = grid_subset.reshape(-1)
        out = torch.einsum("m,d->md", grid_subset, omega)

        emb_sin = torch.sin(out)
        emb_cos = torch.cos(out)

        emb = torch.cat([emb_sin, emb_cos], axis=1) * patch_multiplier[i]
        position_embeddings.append(emb)

    position_embeddings = torch.cat(position_embeddings, axis=1)
    # (embedding_size, d * h * w)
    d, h, w = grid_size
    position_embeddings = rearrange(position_embeddings, "(d h w) e -> 1 e d h w", d=d, h=h, w=w)
    # (1, embedding_size, d, h, w)

    return position_embeddings

In [23]:
# | export


def embed_spacings_in_position_embeddings(embeddings: torch.Tensor, spacings: torch.Tensor):
    assert spacings.ndim == 2, "Please provide spacing information for each batch element"
    _, embedding_size, _, _, _ = embeddings.shape
    assert embedding_size % 3 == 0, "To embed spacing info, the embedding size must be divisible by 3"
    embeddings = embeddings * repeat(spacings, f"B S -> B (S {int(embedding_size / 3)}) 1 1 1", S=3)

    return embeddings

In [24]:
# | export


class Swin3DEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config

        dim = config["dim"]

        self.patch_embeddings = Swin3DPatchEmbeddings(config)
        self.layer_norm = nn.LayerNorm(dim)

        self.absolute_position_embeddings = None
        if config["use_absolute_position_embeddings"]:
            grid_size = (
                config["image_size"][0] // config["patch_size"][0],
                config["image_size"][1] // config["patch_size"][1],
                config["image_size"][2] // config["patch_size"][2],
            )
            if config["learnable_absolute_position_embeddings"]:
                self.absolute_position_embeddings = nn.Parameter(
                    torch.randn(1, dim, grid_size[0], grid_size[1], grid_size[2])
                )
            else:
                self.absolute_position_embeddings = get_3d_position_embeddings(dim, grid_size, config["patch_size"])

        self.init()

    def init(self):
        nn.init.constant_(self.layer_norm.weight, 1.0)
        nn.init.constant_(self.layer_norm.bias, 0.0)
        if self.absolute_position_embeddings is not None:
            nn.init.xavier_normal_(self.absolute_position_embeddings)

    def forward(
        self,
        pixel_values: torch.Tensor,
        spacings: torch.Tensor,
        mask_patches: torch.Tensor = None,
        mask_token: torch.Tensor = None,
    ):
        # pixel_values: (b, c, z, y, x)

        embeddings = self.patch_embeddings(pixel_values)
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)
        embeddings = rearrange(embeddings, "b d nz ny nx -> b nz ny nx d")
        embeddings = self.layer_norm(embeddings)
        embeddings = rearrange(embeddings, "b nz ny nx d -> b d nz ny nx")
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        if mask_patches is not None:
            # mask_patches (binary mask): (b, num_patches_z, num_patches_y, num_patches_x)
            # mask_token: (1, dim, 1, 1, 1)
            mask_patches = repeat(mask_patches, "b z y x -> b d z y x", d=embeddings.shape[1])
            embeddings = (embeddings * (1 - mask_patches)) + (mask_patches * mask_token)

        if self.absolute_position_embeddings is not None:
            absolute_position_embeddings = self.absolute_position_embeddings.to(embeddings.device)
            # (1, dim, num_patches_z, num_patches_y, num_patches_x)
            if self.config["embed_spacing_info"]:
                absolute_position_embeddings = embed_spacings_in_position_embeddings(
                    absolute_position_embeddings, spacings
                )
                # (b, dim, num_patches_z, num_patches_y, num_patches_x)

            embeddings = embeddings + absolute_position_embeddings
            # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        return embeddings

In [25]:
test_config = {
    "patch_size": (1, 8, 8),
    "in_channels": 1,
    "dim": 36,
    "use_absolute_position_embeddings": True,
    "learnable_absolute_position_embeddings": False,
    "embed_spacing_info": False,
    "image_size": (32, 512, 512),
}

test = Swin3DEmbeddings(test_config)
display(test)
o = test(
    torch.randn(2, 1, 32, 512, 512),
    torch.randn(2, 3),
)
display(o.shape)


[1;35mSwin3DEmbeddings[0m[1m([0m
  [1m([0mpatch_embeddings[1m)[0m: [1;35mSwin3DPatchEmbeddings[0m[1m([0m
    [1m([0mpatch_embeddings[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m36[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m[1m)[0m
  [1m)[0m
  [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m36[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m36[0m, [1;36m32[0m, [1;36m64[0m, [1;36m64[0m[1m][0m[1m)[0m

In [26]:
e = get_3d_position_embeddings(6, (2, 2, 2), (1, 1, 1))
display(e.shape)

print("Distance of (0, 0, 0) with other coords")
for i in range(2):
    for j in range(2):
        for k in range(2):
            print((i, j, k), np.linalg.norm(e[0, :, 0, 0, 0] - e[0, :, i, j, k]))

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m1[0m, [1;36m6[0m, [1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m][0m[1m)[0m

Distance of (0, 0, 0) with other coords
(0, 0, 0) 0.0
(0, 0, 1) 0.95885104
(0, 1, 0) 0.95885104
(0, 1, 1) 1.3560202
(1, 0, 0) 0.95885104
(1, 0, 1) 1.3560202
(1, 1, 0) 1.3560202
(1, 1, 1) 1.6607788


# Models

In [27]:
# | export


class Swin3DModel(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.embeddings = Swin3DEmbeddings(config)
        self.encoder = Swin3DEncoder(config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        spacings: torch.Tensor,
        mask_patches: torch.Tensor = None,
        mask_token: torch.Tensor = None,
    ):
        # pixel_values: (b, c, z, y, x)
        # spacings: (b, 3)
        # mask_patches: (num_patches_z, num_patches_y, num_patches_x)

        embeddings = self.embeddings(pixel_values, spacings, mask_patches, mask_token)
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        embeddings = rearrange(embeddings, "b e nz ny nx -> b nz ny nx e")
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        encoded, stage_outputs, layer_outputs = self.encoder(embeddings)
        # encoded: (b, new_num_patches_z, new_num_patches_y, new_num_patches_x, dim)
        # stage_outputs, layer_outputs: list of (b, some_num_patches_z, some_num_patches_y, some_num_patches_x, dim)

        encoded = rearrange(encoded, "b nz ny nx d -> b d nz ny nx")
        # (b, dim, new_num_patches_z, new_num_patches_y, new_num_patches_x)

        for i in range(len(stage_outputs)):
            stage_outputs[i] = rearrange(stage_outputs[i], "b nz ny nx d -> b d nz ny nx")
            # (b, dim, some_num_patches_z, some_num_patches_y, some_num_patches_x)

        for i in range(len(layer_outputs)):
            layer_outputs[i] = rearrange(layer_outputs[i], "b nz ny nx d -> b d nz ny nx")
            # (b, dim, some_num_patches_z, some_num_patches_y, some_num_patches_x)

        return encoded, stage_outputs, layer_outputs

In [28]:
test_config = populate_and_validate_config(
    {
        "patch_size": (1, 8, 8),
        "in_channels": 1,
        "use_absolute_position_embeddings": True,
        "learnable_absolute_position_embeddings": False,
        "embed_spacing_info": False,
        "image_size": (32, 512, 512),
        "dim": 36,
        "stages": [
            {
                "patch_merging": None,
                "depth": 1,
                "num_heads": 4,
                "intermediate_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": False,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "out_dim_ratio": 3,
                },
                "depth": 3,
                "num_heads": 4,
                "intermediate_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "out_dim_ratio": 3,
                },
                "depth": 1,
                "num_heads": 4,
                "intermediate_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
        ],
    }
)

test = Swin3DModel(test_config)
display(test)
o = test(
    torch.randn(2, 1, 32, 512, 512),
    torch.randn(2, 3),
)
display((o[0].shape, [x.shape for x in o[1]], [x.shape for x in o[2]]))


[1;35mSwin3DModel[0m[1m([0m
  [1m([0membeddings[1m)[0m: [1;35mSwin3DEmbeddings[0m[1m([0m
    [1m([0mpatch_embeddings[1m)[0m: [1;35mSwin3DPatchEmbeddings[0m[1m([0m
      [1m([0mpatch_embeddings[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m36[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m[1m)[0m
    [1m)[0m
    [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m36[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
  [1m)[0m
  [1m([0mencoder[1m)[0m: [1;35mSwin3DEncoder[0m[1m([0m
    [1m([0mstages[1m)[0m: [1;35mModuleList[0m[1m([0m
      [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwin3DStage[0m[1m([0m
        [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
          [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwin3DBlock[0m[1m([0m
   


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m324[0m, [1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m36[0m, [1;36m32[0m, [1;36m64[0m, [1;36m64[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m108[0m, [1;36m16[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m324[0m, [1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m
    [1m][0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m36[0m, [1;36m32[0m, [1;36m64[0m, [1;36m64[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m36[0m, [1;36m32[0m, [1;36m64[0m, [1;36m64[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m108[0m, [1;36m16[0m, [1;36m32[0m, [1;36m32

# For MIM

In [29]:
# | export


class Swin3DMIMDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.image_size = config["image_size"]
        self.in_channels = config["in_channels"]

        dim = config["stages"][-1]["_out_dim"]
        patch_size = config["stages"][-1]["_out_patch_size"]

        out_dim = np.prod(patch_size) * self.in_channels
        self.final_patch_size = patch_size

        self.decoder = nn.Conv3d(dim, out_dim, kernel_size=1)

        self.init()

    def init(self):
        nn.init.xavier_normal_(self.decoder.weight)
        nn.init.constant_(self.decoder.bias, 0)

    def forward(self, encodings: torch.Tensor):
        # encodings: (b, dim, num_patches_z, num_patches_y, num_patches_x)

        decoded = self.decoder(encodings)
        # (b, new_dim, num_patches_z, num_patches_y, num_patches_x)

        decoded = rearrange(
            decoded,
            "b (c pz py px) nz ny nx -> b c (nz pz) (ny py) (nx px)",
            c=self.in_channels,
            pz=self.final_patch_size[0],
            py=self.final_patch_size[1],
            px=self.final_patch_size[2],
        )
        # (b, c, z, y, x)

        return decoded

In [30]:
test_config = populate_and_validate_config(
    {
        "dim": 36,
        "patch_size": (1, 8, 8),
        "in_channels": 1,
        "use_absolute_position_embeddings": True,
        "learnable_absolute_position_embeddings": False,
        "embed_spacing_info": False,
        "image_size": (32, 512, 512),
        "stages": [
            {
                "patch_merging": None,
                "depth": 2,
                "num_heads": 4,
                "intermediate_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": False,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "out_dim_ratio": 3,
                },
                "depth": 6,
                "num_heads": 4,
                "intermediate_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
        ],
    }
)

test = Swin3DMIMDecoder(test_config)
display(test)
display(test(torch.randn(2, 108, 16, 32, 32)).shape)


[1;35mSwin3DMIMDecoder[0m[1m([0m
  [1m([0mdecoder[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m108[0m, [1;36m512[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1[0m, [1;36m32[0m, [1;36m512[0m, [1;36m512[0m[1m][0m[1m)[0m

In [31]:
# | export


class Swin3DMIM(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config

        self.swin = Swin3DModel(config)
        self.decoder = Swin3DMIMDecoder(config)

        self.mask_token = nn.Parameter(torch.randn(1, config["dim"], 1, 1, 1))

        self.init()

    def init(self):
        nn.init.xavier_normal_(self.mask_token)

    def forward(self, pixel_values: torch.Tensor, spacings: torch.Tensor):
        b = pixel_values.shape[0]

        mask_ratio = self.config["mim"]["mask_ratio"]
        mask_grid_size = self.config["mim"]["mask_grid_size"]
        num_patches = np.prod(mask_grid_size)
        mask_patches = []
        for _ in range(b):
            _mask_patches = torch.zeros(num_patches, dtype=torch.int8, device=pixel_values.device)
            _mask_patches[: int(mask_ratio * num_patches)] = 1
            _mask_patches = _mask_patches[torch.randperm(num_patches)]
            _mask_patches = rearrange(
                _mask_patches, "(z y x) -> z y x", z=mask_grid_size[0], y=mask_grid_size[1], x=mask_grid_size[2]
            )
            mask_patches.append(_mask_patches)
        mask_patches: torch.Tensor = torch.stack(mask_patches, dim=0)

        grid_size = tuple([size // patch for size, patch in zip(self.config["image_size"], self.config["patch_size"])])
        assert all(
            [x % y == 0 for x, y in zip(grid_size, mask_grid_size)]
        ), "Mask grid size must divide image grid size"
        mask_patches = repeat(
            mask_patches,
            "b z y x -> b (z gz) (y gy) (x gx)",
            gz=grid_size[0] // mask_grid_size[0],
            gy=grid_size[1] // mask_grid_size[1],
            gx=grid_size[2] // mask_grid_size[2],
        )

        encodings, _, _ = self.swin(pixel_values, spacings, mask_patches, self.mask_token)

        decoded = self.decoder(encodings)

        loss = nn.functional.l1_loss(decoded, pixel_values, reduction="none")
        mask = repeat(
            mask_patches,
            "b z y x -> b (z pz) (y py) (x px)",
            pz=self.config["patch_size"][0],
            py=self.config["patch_size"][1],
            px=self.config["patch_size"][2],
        )
        loss = (loss * mask).sum() / ((mask.sum() + 1e-5) * self.config["in_channels"])

        return decoded, loss, mask

In [32]:
test_config = populate_and_validate_config(
    {
        "dim": 36,
        "patch_size": (1, 8, 8),
        "in_channels": 1,
        "use_absolute_position_embeddings": True,
        "learnable_absolute_position_embeddings": False,
        "embed_spacing_info": False,
        "image_size": (32, 512, 512),
        "stages": [
            {
                "patch_merging": None,
                "depth": 2,
                "num_heads": 4,
                "intermediate_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": False,
            },
            {
                "depth": 6,
                "num_heads": 4,
                "intermediate_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "out_dim_ratio": 3,
                },
            },
        ],
        "mim": {
            "mask_ratio": 0.8,
            "mask_grid_size": (8, 8, 8),
        },
    }
)

test = Swin3DMIM(test_config)
display(test)
o = test(
    torch.randn(2, 1, 32, 512, 512),
    torch.randn(2, 3),
)
display((o[0].shape, o[1], o[2].shape))


[1;35mSwin3DMIM[0m[1m([0m
  [1m([0mswin[1m)[0m: [1;35mSwin3DModel[0m[1m([0m
    [1m([0membeddings[1m)[0m: [1;35mSwin3DEmbeddings[0m[1m([0m
      [1m([0mpatch_embeddings[1m)[0m: [1;35mSwin3DPatchEmbeddings[0m[1m([0m
        [1m([0mpatch_embeddings[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m36[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m[1m)[0m
      [1m)[0m
      [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m36[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
    [1m)[0m
    [1m([0mencoder[1m)[0m: [1;35mSwin3DEncoder[0m[1m([0m
      [1m([0mstages[1m)[0m: [1;35mModuleList[0m[1m([0m
        [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwin3DStage[0m[1m([0m
          [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
 


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1[0m, [1;36m32[0m, [1;36m512[0m, [1;36m512[0m[1m][0m[1m)[0m,
    [1;35mtensor[0m[1m([0m[1;36m2.9981[0m, [33mgrad_fn[0m=[1m<[0m[1;95mDivBackward0[0m[1m>[0m[1m)[0m,
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m32[0m, [1;36m512[0m, [1;36m512[0m[1m][0m[1m)[0m
[1m)[0m

In [33]:
from neuro_utils.visualize import plot_scans

plot_scans([o[0][0, 0].detach(), o[0][0, 0].detach() * (1 - o[2][0])])

interactive(children=(IntSlider(value=0, description='z', max=31), Output()), _dom_classes=('widget-interact',…

# Some more tests

### Overfitting tests

In [34]:
from tqdm.auto import tqdm

sample_spacings = torch.tensor([[1, 0.1, 0.1], [2, 0.2, 0.2], [3, 0.3, 0.3], [4, 0.4, 0.4], [5, 0.5, 0.5]])
sample_batch = torch.rand(3, 1, 16, 128, 128)
sample_config = populate_and_validate_config(
    {
        "patch_size": (1, 4, 4),
        "dim": 12,
        "in_channels": 1,
        "use_absolute_position_embeddings": True,
        "learnable_absolute_position_embeddings": False,
        "embed_spacing_info": False,
        "image_size": (16, 128, 128),
        "stages": [
            {
                "patch_merging": None,
                "depth": 1,
                "num_heads": 4,
                "intermediate_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "out_dim_ratio": 4,
                },
                "depth": 3,
                "num_heads": 4,
                "intermediate_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "out_dim_ratio": 4,
                },
                "depth": 1,
                "num_heads": 4,
                "intermediate_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
        ],
        "mim": {
            "mask_ratio": 0.7,
            "mask_grid_size": (8, 8, 8),
        },
    }
)

model = Swin3DMIM(sample_config)

sum(x.numel() for x in model.swin.parameters()), sum(x.numel() for x in model.decoder.parameters())

[1m([0m[1;36m1156612[0m, [1;36m197632[0m[1m)[0m

In [35]:
from neuro_utils.describe import describe_model

describe_model(model)

Total Parameters: 1,354,256
+---------------------------------------------------------------------------+------------+
|                                   Module                                  | Parameters |
+---------------------------------------------------------------------------+------------+
|                                 mask_token                                |     12     |
|          swin.embeddings.patch_embeddings.patch_embeddings.weight         |    192     |
|           swin.embeddings.patch_embeddings.patch_embeddings.bias          |     12     |
|                     swin.embeddings.layer_norm.weight                     |     12     |
|                      swin.embeddings.layer_norm.bias                      |     12     |
|       swin.encoder.stages.0.blocks.0.w_layer.layernorm_before.weight      |     12     |
|        swin.encoder.stages.0.blocks.0.w_layer.layernorm_before.bias       |     12     |
|  swin.encoder.stages.0.blocks.0.w_layer.mhsa.relative_positi

In [36]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

In [37]:
sample_batch = sample_batch.cuda()
sample_spacings = sample_spacings.cuda()
model = model.cuda()

In [38]:
for i in tqdm(range(200)):
    optimizer.zero_grad()
    output = model(sample_batch, sample_spacings)
    print(f"Loss: {output[1]:f}\tLR: {scheduler.get_last_lr()[0]:f}")
    output[1].backward()
    optimizer.step()
    scheduler.step()

  0%|          | 0/200 [00:00<?, ?it/s]

Loss: 2.223207	LR: 0.500000
Loss: 2.255578	LR: 0.500000
Loss: 3.108296	LR: 0.500000
Loss: 3.395616	LR: 0.500000
Loss: 4.185133	LR: 0.500000
Loss: 3.663680	LR: 0.450000
Loss: 3.548665	LR: 0.450000
Loss: 3.401049	LR: 0.450000
Loss: 3.566421	LR: 0.450000
Loss: 3.479365	LR: 0.450000
Loss: 3.367604	LR: 0.405000
Loss: 3.601780	LR: 0.405000
Loss: 3.210312	LR: 0.405000
Loss: 2.804216	LR: 0.405000
Loss: 3.156502	LR: 0.405000
Loss: 3.162444	LR: 0.364500
Loss: 3.256133	LR: 0.364500
Loss: 3.092483	LR: 0.364500
Loss: 2.965374	LR: 0.364500
Loss: 2.843133	LR: 0.364500
Loss: 2.761711	LR: 0.328050
Loss: 2.883082	LR: 0.328050
Loss: 2.789261	LR: 0.328050
Loss: 2.474174	LR: 0.328050
Loss: 2.722669	LR: 0.328050
Loss: 2.551768	LR: 0.295245
Loss: 2.528475	LR: 0.295245
Loss: 2.450451	LR: 0.295245
Loss: 2.684007	LR: 0.295245
Loss: 2.483982	LR: 0.295245
Loss: 2.675984	LR: 0.265721
Loss: 2.218823	LR: 0.265721
Loss: 2.600932	LR: 0.265721
Loss: 1.994672	LR: 0.265721
Loss: 2.093859	LR: 0.265721
Loss: 2.016809	LR: 0

# nbdev

In [39]:
!nbdev_export

# Rough work