In [1]:
# | default_exp swin3d_old

# Imports

In [2]:
# | export

import torch
import numpy as np
from torch import nn
from einops import rearrange, repeat

from vision_architectures.swin3d import Swin3DMIM as Swin3DNewMIM
from vision_architectures.swin3d import populate_and_validate_config

# Architecture

### Basic Layers

In [3]:
# | export


def get_coords_grid(grid_size):
    d, h, w = grid_size

    grid_d = torch.arange(d, dtype=torch.int32)
    grid_h = torch.arange(h, dtype=torch.int32)
    grid_w = torch.arange(w, dtype=torch.int32)

    grid = torch.meshgrid(grid_w, grid_h, grid_d, indexing="ij")
    grid = torch.stack(grid, axis=0)
    # (3, d, h, w)

    return grid

In [4]:
# | export


class Swin3DMHSA(nn.Module):
    def __init__(self, dim, num_heads, window_size, use_relative_position_bias):
        super().__init__()

        assert dim % num_heads == 0, "dimension must be divisible by number of heads"

        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size

        self.per_head_dim = int(dim // num_heads)

        self.W_q = nn.Linear(dim, dim)
        self.W_k = nn.Linear(dim, dim)
        self.W_v = nn.Linear(dim, dim)

        self.proj = nn.Linear(dim, dim)

        # TODO: Add embed_spacing_info functionality
        # TODO: Add dropout everywhere
        # TODO: Combine qkv into one linear layer
        self.use_relative_position_bias = use_relative_position_bias
        if use_relative_position_bias:
            relative_limits = (2 * window_size[0] - 1, 2 * window_size[1] - 1, 2 * window_size[2] - 1)

            self.relative_position_bias_table = nn.Parameter(torch.randn(num_heads, np.prod(relative_limits)))

            coords = get_coords_grid(window_size)
            coords_flatten = rearrange(
                coords, "three_dimensional d h w -> three_dimensional (d h w)", three_dimensional=3
            )
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()
            relative_coords[:, :, 0] += window_size[0] - 1
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 2] += window_size[2] - 1
            relative_position_index: torch.Tensor = (
                relative_coords[:, :, 0] * relative_limits[1] * relative_limits[2]
                + relative_coords[:, :, 1] * relative_limits[2]
                + relative_coords[:, :, 2]
            )
            self.relative_position_index = relative_position_index.flatten()

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (windowed_b, window_size_z window_size_y window_size_x, dim)
        _, num_patches_z, num_patches_y, num_patches_x, _ = hidden_states.shape

        query = rearrange(
            self.W_q(hidden_states),
            "b nz ny nx (num_heads d) -> b num_heads (nz ny nx) d",
            num_heads=self.num_heads,
        )
        key = rearrange(
            self.W_k(hidden_states),
            "b nz ny nx (num_heads d) -> b num_heads (nz ny nx) d",
            num_heads=self.num_heads,
        )
        value = rearrange(
            self.W_v(hidden_states),
            "b nz ny nx (num_heads d) -> b num_heads (nz ny nx) d",
            num_heads=self.num_heads,
        )
        # num_patches = window_size_z window_size_y window_size_x
        # Each is (windowed_b, num_heads, num_patches, per_head_dim)

        attention_scores = query @ rearrange(key, "b num_heads n d -> b num_heads d n")
        attention_scores = attention_scores / (self.per_head_dim**0.5)
        # (windowed_b, num_heads, num_patches, num_patches)

        if self.use_relative_position_bias:
            relative_position_bias = self.relative_position_bias_table[:, self.relative_position_index]
            relative_position_bias = relative_position_bias.reshape(
                1, np.prod(self.window_size), np.prod(self.window_size), -1
            )
            relative_position_bias = relative_position_bias.permute(0, 3, 1, 2).contiguous()
            attention_scores = attention_scores + relative_position_bias

        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        # (windowed_b, num_heads, num_patches, num_patches)

        context = attention_probs @ value
        # (windowed_b, num_heads, num_patches, per_head_dim)
        context = rearrange(
            context,
            "b num_heads (num_patches_z num_patches_y num_patches_x) d -> "
            "b num_patches_z num_patches_y num_patches_x (num_heads d)",
            num_patches_z=num_patches_z,
            num_patches_y=num_patches_y,
            num_patches_x=num_patches_x,
        )
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        context = self.proj(context)
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        return context

In [5]:
test = Swin3DMHSA(54, 6, (4, 4, 4), True)
display(test)
display(test(torch.randn(2, 4, 4, 4, 54)).shape)


[1;35mSwin3DMHSA[0m[1m([0m
  [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m54[0m[1m][0m[1m)[0m

In [6]:
# | export


class Swin3DLayerMLP(nn.Module):
    def __init__(self, dim, intermediate_ratio):
        super().__init__()
        self.dense1 = nn.Linear(dim, dim * intermediate_ratio)
        self.act = nn.GELU()
        self.dense2 = nn.Linear(dim * intermediate_ratio, dim)

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (windowed_b, window_size_z window_size_y window_size_x, dim)
        hidden_states = self.dense1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dense2(hidden_states)
        return hidden_states

In [7]:
test = Swin3DLayerMLP(64, 256)
display(test)
display(test(torch.randn(2, 4, 4, 4, 64)).shape)


[1;35mSwin3DLayerMLP[0m[1m([0m
  [1m([0mdense1[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m16384[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mact[1m)[0m: [1;35mGELU[0m[1m([0m[33mapproximate[0m=[32m'none'[0m[1m)[0m
  [1m([0mdense2[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m16384[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m

In [8]:
# | export


class Swin3DLayer(nn.Module):
    def __init__(self, dim, num_heads, intermediate_ratio, layer_norm_eps, window_size, use_relative_position_bias):
        super().__init__()

        self.window_size = window_size

        self.layernorm_before = nn.LayerNorm(dim, eps=layer_norm_eps)
        self.mhsa = Swin3DMHSA(dim, num_heads, window_size, use_relative_position_bias)
        self.layernorm_after = nn.LayerNorm(dim, eps=layer_norm_eps)
        self.mlp = Swin3DLayerMLP(dim, intermediate_ratio)

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (b, num_patches_z, num_patches_y, num_patches_x, dim)
        _, num_patches_z, num_patches_y, num_patches_x, _ = hidden_states.shape

        # Perform windowing
        window_size_z, window_size_y, window_size_x = self.window_size
        num_windows_z, num_windows_y, num_windows_x = (
            num_patches_z // window_size_z,
            num_patches_y // window_size_y,
            num_patches_x // window_size_x,
        )
        hidden_states = rearrange(
            hidden_states,
            "b (num_windows_z window_size_z) (num_windows_y window_size_y) (num_windows_x window_size_x) dim -> "
            "(b num_windows_z num_windows_y num_windows_x) window_size_z window_size_y window_size_x dim ",
            num_windows_z=num_windows_z,
            num_windows_y=num_windows_y,
            num_windows_x=num_windows_x,
            window_size_z=window_size_z,
            window_size_y=window_size_y,
            window_size_x=window_size_x,
        )

        res_connection1 = hidden_states
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        hidden_states = self.layernorm_before(hidden_states)
        hidden_states = self.mhsa(hidden_states)
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        res_connection2 = hidden_states + res_connection1
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        hidden_states = self.layernorm_after(res_connection2)
        hidden_states = self.mlp(hidden_states)
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        hidden_states = hidden_states + res_connection2
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        # Undo windowing
        output = rearrange(
            hidden_states,
            "(b num_windows_z num_windows_y num_windows_x) window_size_z window_size_y window_size_x dim -> "
            "b (num_windows_z window_size_z) (num_windows_y window_size_y) (num_windows_x window_size_x) dim",
            num_windows_z=num_windows_z,
            num_windows_y=num_windows_y,
            num_windows_x=num_windows_x,
            window_size_z=window_size_z,
            window_size_y=window_size_y,
            window_size_x=window_size_x,
        )

        return output

In [9]:
test = Swin3DLayer(64, 4, 256, 1e-6, (4, 4, 4), True)
display(test)
display(test(torch.randn(2, 4, 4, 4, 64)).shape)


[1;35mSwin3DLayer[0m[1m([0m
  [1m([0mlayernorm_before[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m64[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-06[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mmhsa[1m)[0m: [1;35mSwin3DMHSA[0m[1m([0m
    [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m)[0m
  [1m([0mlayernorm_after[1m)[0m: [1;

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m

### Stage layers

In [10]:
# | export


class Swin3DBlock(nn.Module):
    def __init__(self, stage_config):
        super().__init__()

        self.stage_config = stage_config
        self.w_layer = Swin3DLayer(
            stage_config["dim"],
            stage_config["num_heads"],
            stage_config["intermediate_ratio"],
            stage_config["layer_norm_eps"],
            stage_config["window_size"],
            stage_config["use_relative_position_bias"],
        )
        self.sw_layer = Swin3DLayer(
            stage_config["dim"],
            stage_config["num_heads"],
            stage_config["intermediate_ratio"],
            stage_config["layer_norm_eps"],
            stage_config["window_size"],
            stage_config["use_relative_position_bias"],
        )

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (b, num_patches_z, num_patches_y, num_patches_x, dim)

        layer_outputs = []

        # First layer
        hidden_states = self.w_layer(hidden_states)
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        layer_outputs.append(hidden_states)

        # Shift windows
        window_size_z, window_size_y, window_size_x = self.stage_config["window_size"]
        shifts = (window_size_z // 2, window_size_y // 2, window_size_x // 2)
        hidden_states = torch.roll(hidden_states, shifts=shifts, dims=(1, 2, 3))
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        # Second layer
        hidden_states = self.sw_layer(hidden_states)
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        # Reverse window shift
        shifts = tuple(-shift for shift in shifts)
        hidden_states = torch.roll(hidden_states, shifts=shifts, dims=(1, 2, 3))
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        layer_outputs.append(hidden_states)

        return hidden_states, layer_outputs

In [11]:
test_stage_config = {
    "depth": 4,
    "dim": 64,
    "num_heads": 4,
    "intermediate_ratio": 4,
    "layer_norm_eps": 1e-6,
    "window_size": (4, 4, 4),
    "use_relative_position_bias": True,
}

test = Swin3DBlock(test_stage_config)
display(test)
o = test(torch.randn(2, 4, 4, 4, 64))
display((o[0].shape, (o[1][0].shape, o[1][1].shape)))


[1;35mSwin3DBlock[0m[1m([0m
  [1m([0mw_layer[1m)[0m: [1;35mSwin3DLayer[0m[1m([0m
    [1m([0mlayernorm_before[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m64[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-06[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mmhsa[1m)[0m: [1;35mSwin3DMHSA[0m[1m([0m
      [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;9

[1m([0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m, [1m([0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m[1m)[0m[1m)[0m

In [12]:
# | export


class Swin3DPatchMerging(nn.Module):
    def __init__(self, merge_window_size, dim, out_dim_ratio):
        super().__init__()

        self.merge_window_size = merge_window_size

        in_dim = dim * np.prod(merge_window_size)
        self.layer_norm = nn.LayerNorm(in_dim)
        self.proj = nn.Linear(in_dim, dim * out_dim_ratio)

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (b, num_patches_z, num_patches_y, num_patches_x, dim)

        window_size_z, window_size_y, window_size_x = self.merge_window_size

        hidden_states = rearrange(
            hidden_states,
            "b (new_num_patches_z window_size_z) (new_num_patches_y window_size_y) (new_num_patches_x window_size_x) dim -> "
            "b new_num_patches_z new_num_patches_y new_num_patches_x (window_size_z window_size_y window_size_x dim)",
            window_size_z=window_size_z,
            window_size_y=window_size_y,
            window_size_x=window_size_x,
        )

        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.proj(hidden_states)
        return hidden_states

In [13]:
test_stage_config = {
    "depth": 4,
    "dim": 64,
    "num_heads": 4,
    "intermediate_size": 256,
    "layer_norm_eps": 1e-6,
    "window_size": (4, 4, 4),
    "use_relative_position_bias": True,
    "patch_merging": {
        "merge_window_size": (2, 2, 2),
        "out_dim_ratio": 3,
    },
}

test = Swin3DPatchMerging(
    test_stage_config["patch_merging"]["merge_window_size"],
    test_stage_config["dim"],
    test_stage_config["patch_merging"]["out_dim_ratio"],
)
display(test)
display(test(torch.randn(2, 4, 4, 4, 64)).shape)


[1;35mSwin3DPatchMerging[0m[1m([0m
  [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m512[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m512[0m, [33mout_features[0m=[1;36m192[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m2[0m, [1;36m2[0m, [1;36m2[0m, [1;36m192[0m[1m][0m[1m)[0m

In [14]:
# | export


class Swin3DStage(nn.Module):
    def __init__(self, stage_config):
        super().__init__()

        self.config = stage_config

        self.blocks = nn.ModuleList(
            [Swin3DBlock(stage_config) for _ in range(stage_config["depth"])],
        )

        self.patch_merging = None
        if stage_config["patch_merging"] is not None:
            self.patch_merging = Swin3DPatchMerging(
                stage_config["patch_merging"]["merge_window_size"],
                stage_config["dim"],
                stage_config["patch_merging"]["out_dim_ratio"],
            )

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (b, num_patches_z, num_patches_y, num_patches_x, dim)

        layer_outputs = []
        for layer_module in self.blocks:
            hidden_states, _layer_outputs = layer_module(hidden_states)
            # (b, num_patches_z, num_patches_y, num_patches_x, dim)
            layer_outputs.extend(_layer_outputs)

        if self.patch_merging:
            hidden_states = self.patch_merging(hidden_states)
            # (b, new_num_patches_z, new_num_patches_y, new_num_patches_x, new_dim)

        return hidden_states, layer_outputs

In [15]:
test_stage_config = {
    "depth": 4,
    "dim": 64,
    "num_heads": 4,
    "intermediate_ratio": 4,
    "layer_norm_eps": 1e-6,
    "window_size": (4, 4, 4),
    "use_relative_position_bias": True,
    "patch_merging": {
        "merge_window_size": (2, 2, 2),
        "out_dim_ratio": 3,
    },
}

test = Swin3DStage(test_stage_config)
display(test)
o = test(torch.randn(2, 4, 4, 4, 64))
display((o[0].shape, [x.shape for x in o[1]]))


[1;35mSwin3DStage[0m[1m([0m
  [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m3[0m[1m)[0m: [1;36m4[0m x [1;35mSwin3DBlock[0m[1m([0m
      [1m([0mw_layer[1m)[0m: [1;35mSwin3DLayer[0m[1m([0m
        [1m([0mlayernorm_before[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m64[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-06[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mmhsa[1m)[0m: [1;35mSwin3DMHSA[0m[1m([0m
          [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m64[0m, [33mbia


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m2[0m, [1;36m2[0m, [1;36m2[0m, [1;36m192[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m64[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size

### Encoder

In [16]:
# | export


class Swin3DEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.stages = nn.ModuleList([Swin3DStage(stage_config) for stage_config in config["stages"]])

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (b, num_patches_z, num_patches_y, num_patches_x, dim)

        stage_outputs, layer_outputs = [], []
        for stage_module in self.stages:
            hidden_states, _layer_outputs = stage_module(hidden_states)
            # (b, new_num_patches_z, new_num_patches_y, new_num_patches_x, dim)

            stage_outputs.append(hidden_states)
            layer_outputs.extend(_layer_outputs)

        return hidden_states, stage_outputs, layer_outputs

In [17]:
test_config = {"stages": []}
test_config["stages"].append(
    {
        "depth": 2,
        "dim": 32,
        "num_heads": 4,
        "intermediate_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": False,
        "patch_merging": {
            "merge_window_size": (2, 2, 2),
            "out_dim_ratio": 3,
        },
    }
)
test_config["stages"].append(
    {
        "depth": 6,
        "dim": test_config["stages"][-1]["dim"] * test_config["stages"][-1]["patch_merging"]["out_dim_ratio"],
        "num_heads": 4,
        "intermediate_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": True,
        "patch_merging": {
            "merge_window_size": (2, 2, 2),
            "out_dim_ratio": 3,
        },
    }
)

test = Swin3DEncoder(test_config)
display(test)
o = test(torch.randn(2, 16, 16, 16, 32))
display((o[0].shape, [x.shape for x in o[1]], [x.shape for x in o[2]]))


[1;35mSwin3DEncoder[0m[1m([0m
  [1m([0mstages[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwin3DStage[0m[1m([0m
      [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
        [1m([0m[1;36m0[0m-[1;36m1[0m[1m)[0m: [1;36m2[0m x [1;35mSwin3DBlock[0m[1m([0m
          [1m([0mw_layer[1m)[0m: [1;35mSwin3DLayer[0m[1m([0m
            [1m([0mlayernorm_before[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m32[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-06[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
            [1m([0mmhsa[1m)[0m: [1;35mSwin3DMHSA[0m[1m([0m
              [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m32[0m, [33mout_features[0m=[1;36m32[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
              [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m32[0m, [33mout_features[0m=[1;36m32[0m, [33mbias[0m=[3;92mT


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m288[0m[1m][0m[1m)[0m,
    [1m[[0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m, [1;36m96[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m288[0m[1m][0m[1m)[0m[1m][0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m16[0m, [1;36m16[0m, [1;36m16[0m, [1;36m32[0m[1m][0m[1m)[0m,
       

# Embeddings

### Patch embeddings

In [18]:
# | export


class Swin3DPatchEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()

        patch_size = config["patch_size"]
        num_channels = config["in_channels"]
        dim = config["stages"][0]["dim"]

        self.patch_embeddings = nn.Conv3d(
            in_channels=num_channels,
            out_channels=dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, pixel_values: torch.Tensor):
        # pixel_values: (b, c, z, y, x)

        embeddings = self.patch_embeddings(pixel_values)
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        return embeddings

In [19]:
test_config = {
    "patch_size": (1, 8, 8),
    "in_channels": 1,
    "stages": [{"dim": 12}],
}

test = Swin3DPatchEmbeddings(test_config)
display(test)
o = test(torch.randn(2, 1, 32, 512, 512))
display(o.shape)


[1;35mSwin3DPatchEmbeddings[0m[1m([0m
  [1m([0mpatch_embeddings[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m12[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m12[0m, [1;36m32[0m, [1;36m64[0m, [1;36m64[0m[1m][0m[1m)[0m

### Position embeddings

In [20]:
# | export


def get_3d_position_embeddings(embedding_size, grid_size, patch_size=(1, 1, 1)):
    if embedding_size % 6 != 0:
        raise ValueError("embed_dim must be divisible by 6")

    grid = get_coords_grid(grid_size)
    # (3, d, h, w)

    grid = rearrange(grid, "x d h w -> x 1 d h w")
    # (3, 1, d, h, w)

    omega = torch.arange(embedding_size // 6, dtype=torch.float32)
    omega /= embedding_size / 6.0
    omega = 1.0 / 10000**omega
    # (d // 6)

    patch_multiplier = torch.Tensor(patch_size) / min(patch_size)

    position_embeddings = []
    for i, grid_subset in enumerate(grid):
        grid_subset = grid_subset.reshape(-1)
        out = torch.einsum("m,d->md", grid_subset, omega)

        emb_sin = torch.sin(out)
        emb_cos = torch.cos(out)

        emb = torch.cat([emb_sin, emb_cos], axis=1) * patch_multiplier[i]
        position_embeddings.append(emb)

    position_embeddings = torch.cat(position_embeddings, axis=1)
    # (embedding_size, d * h * w)
    d, h, w = grid_size
    position_embeddings = rearrange(position_embeddings, "(d h w) e -> 1 e d h w", d=d, h=h, w=w)
    # (1, embedding_size, d, h, w)

    return position_embeddings

In [21]:
# | export


def embed_spacings_in_position_embeddings(embeddings: torch.Tensor, spacings: torch.Tensor):
    assert spacings.ndim == 2, "Please provide spacing information for each batch element"
    _, embedding_size, _, _, _ = embeddings.shape
    assert embedding_size % 3 == 0, "To embed spacing info, the embedding size must be divisible by 3"
    embeddings = embeddings * repeat(spacings, f"B S -> B (S {int(embedding_size / 3)}) 1 1 1", S=3)

    return embeddings

In [22]:
# | export


class Swin3DEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config

        dim = config["stages"][0]["dim"]

        self.patch_embeddings = Swin3DPatchEmbeddings(config)
        self.layer_norm = nn.LayerNorm(dim)

        self.absolute_position_embeddings = None
        if config["use_absolute_position_embeddings"]:
            grid_size = (
                config["image_size"][0] // config["patch_size"][0],
                config["image_size"][1] // config["patch_size"][1],
                config["image_size"][2] // config["patch_size"][2],
            )
            if config["learnable_absolute_position_embeddings"]:
                self.absolute_position_embeddings = nn.Parameter(
                    torch.randn(1, dim, grid_size[0], grid_size[1], grid_size[2])
                )
            else:
                self.absolute_position_embeddings = get_3d_position_embeddings(dim, grid_size, config["patch_size"])

    def forward(
        self,
        pixel_values: torch.Tensor,
        spacings: torch.Tensor,
        mask_patches: torch.Tensor = None,
        mask_token: torch.Tensor = None,
    ):
        # pixel_values: (b, c, z, y, x)

        embeddings = self.patch_embeddings(pixel_values)
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)
        embeddings = rearrange(embeddings, "b d nz ny nx -> b nz ny nx d")
        embeddings = self.layer_norm(embeddings)
        embeddings = rearrange(embeddings, "b nz ny nx d -> b d nz ny nx")
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        if mask_patches is not None:
            # mask_patches (binary mask): (b, num_patches_z, num_patches_y, num_patches_x)
            # mask_token: (1, dim, 1, 1, 1)
            mask_patches = repeat(mask_patches, "b z y x -> b d z y x", d=embeddings.shape[1])
            embeddings = (embeddings * (1 - mask_patches)) + (mask_patches * mask_token)

        if self.absolute_position_embeddings is not None:
            absolute_position_embeddings = self.absolute_position_embeddings.to(embeddings.device)
            # (1, dim, num_patches_z, num_patches_y, num_patches_x)
            if self.config["embed_spacing_info"]:
                absolute_position_embeddings = embed_spacings_in_position_embeddings(
                    absolute_position_embeddings, spacings
                )
                # (b, dim, num_patches_z, num_patches_y, num_patches_x)

            embeddings = embeddings + absolute_position_embeddings
            # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        return embeddings

In [23]:
test_config = {
    "patch_size": (1, 8, 8),
    "in_channels": 1,
    "stages": [{"dim": 36}],
    "use_absolute_position_embeddings": True,
    "learnable_absolute_position_embeddings": False,
    "embed_spacing_info": False,
    "image_size": (32, 512, 512),
}

test = Swin3DEmbeddings(test_config)
display(test)
o = test(
    torch.randn(2, 1, 32, 512, 512),
    torch.randn(2, 3),
)
display(o.shape)


[1;35mSwin3DEmbeddings[0m[1m([0m
  [1m([0mpatch_embeddings[1m)[0m: [1;35mSwin3DPatchEmbeddings[0m[1m([0m
    [1m([0mpatch_embeddings[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m36[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m[1m)[0m
  [1m)[0m
  [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m36[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m36[0m, [1;36m32[0m, [1;36m64[0m, [1;36m64[0m[1m][0m[1m)[0m

In [24]:
e = get_3d_position_embeddings(6, (2, 2, 2), (1, 1, 1))
display(e.shape)

print("Distance of (0, 0, 0) with other coords")
for i in range(2):
    for j in range(2):
        for k in range(2):
            print((i, j, k), np.linalg.norm(e[0, :, 0, 0, 0] - e[0, :, i, j, k]))

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m1[0m, [1;36m6[0m, [1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m][0m[1m)[0m

Distance of (0, 0, 0) with other coords
(0, 0, 0) 0.0
(0, 0, 1) 0.95885104
(0, 1, 0) 0.95885104
(0, 1, 1) 1.3560202
(1, 0, 0) 0.95885104
(1, 0, 1) 1.3560202
(1, 1, 0) 1.3560202
(1, 1, 1) 1.6607786


# Models

In [39]:
# | export


class Swin3DModel(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.embeddings = Swin3DEmbeddings(config)
        self.encoder = Swin3DEncoder(config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        spacings: torch.Tensor,
        mask_patches: torch.Tensor = None,
        mask_token: torch.Tensor = None,
    ):
        # pixel_values: (b, c, z, y, x)
        # spacings: (b, 3)
        # mask_patches: (num_patches_z, num_patches_y, num_patches_x)

        embeddings = self.embeddings(pixel_values, spacings, mask_patches, mask_token)
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        embeddings = rearrange(embeddings, "b e nz ny nx -> b nz ny nx e")
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)

        stage_outputs = [embeddings]

        encoded, stage_outputs_, layer_outputs = self.encoder(embeddings)
        stage_outputs.extend(stage_outputs_)
        # encoded: (b, new_num_patches_z, new_num_patches_y, new_num_patches_x, dim)
        # stage_outputs, layer_outputs: list of (b, some_num_patches_z, some_num_patches_y, some_num_patches_x, dim)

        encoded = rearrange(encoded, "b nz ny nx d -> b d nz ny nx")
        # (b, dim, new_num_patches_z, new_num_patches_y, new_num_patches_x)

        for i in range(len(stage_outputs)):
            stage_outputs[i] = rearrange(stage_outputs[i], "b nz ny nx d -> b d nz ny nx")
            # (b, dim, some_num_patches_z, some_num_patches_y, some_num_patches_x)

        for i in range(len(layer_outputs)):
            layer_outputs[i] = rearrange(layer_outputs[i], "b nz ny nx d -> b d nz ny nx")
            # (b, dim, some_num_patches_z, some_num_patches_y, some_num_patches_x)

        return encoded, stage_outputs, layer_outputs

In [26]:
test_config = {
    "patch_size": (1, 8, 8),
    "in_channels": 1,
    "use_absolute_position_embeddings": True,
    "learnable_absolute_position_embeddings": False,
    "embed_spacing_info": False,
    "image_size": (32, 512, 512),
    "stages": [],
}
test_config["stages"].append(
    {
        "depth": 2,
        "dim": 36,
        "num_heads": 4,
        "intermediate_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": False,
        "patch_merging": {
            "merge_window_size": (2, 2, 2),
            "out_dim_ratio": 3,
        },
    }
)
test_config["stages"].append(
    {
        "depth": 6,
        "dim": test_config["stages"][-1]["dim"] * test_config["stages"][-1]["patch_merging"]["out_dim_ratio"],
        "num_heads": 4,
        "intermediate_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": True,
        "patch_merging": {
            "merge_window_size": (2, 2, 2),
            "out_dim_ratio": 3,
        },
    }
)

test = Swin3DModel(test_config)
display(test)
o = test(
    torch.randn(2, 1, 32, 512, 512),
    torch.randn(2, 3),
)
display((o[0].shape, [x.shape for x in o[1]], [x.shape for x in o[2]]))


[1;35mSwin3DModel[0m[1m([0m
  [1m([0membeddings[1m)[0m: [1;35mSwin3DEmbeddings[0m[1m([0m
    [1m([0mpatch_embeddings[1m)[0m: [1;35mSwin3DPatchEmbeddings[0m[1m([0m
      [1m([0mpatch_embeddings[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m36[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m[1m)[0m
    [1m)[0m
    [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m36[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
  [1m)[0m
  [1m([0mencoder[1m)[0m: [1;35mSwin3DEncoder[0m[1m([0m
    [1m([0mstages[1m)[0m: [1;35mModuleList[0m[1m([0m
      [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwin3DStage[0m[1m([0m
        [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
          [1m([0m[1;36m0[0m-[1;36m1[0m[1m)[0m: [1;36m2[0m x [1;35m


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m324[0m, [1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m36[0m, [1;36m32[0m, [1;36m64[0m, [1;36m64[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m108[0m, [1;36m16[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m324[0m, [1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m
    [1m][0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m36[0m, [1;36m32[0m, [1;36m64[0m, [1;36m64[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m36[0m, [1;36m32[0m, [1;36m64[0m, [1;36m64[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m36[0m, [1;36m32[0m, [1;36m64[0m, [1;36m64

# For MIM

In [27]:
# | export


class Swin3DMIMDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.image_size = config["image_size"]
        self.in_channels = config["in_channels"]

        dim = config["stages"][0]["dim"]
        patch_size = config["patch_size"]
        for stage in config["stages"]:
            if stage["patch_merging"] is not None:
                dim *= stage["patch_merging"]["out_dim_ratio"]
                patch_size = tuple(
                    [patch * window for patch, window in zip(patch_size, stage["patch_merging"]["merge_window_size"])]
                )

        out_dim = np.prod(patch_size) * self.in_channels
        self.final_patch_size = patch_size

        self.decoder = nn.Conv3d(dim, out_dim, kernel_size=1)

    def forward(self, encodings: torch.Tensor):
        # encodings: (b, dim, num_patches_z, num_patches_y, num_patches_x)

        decoded = self.decoder(encodings)
        # (b, new_dim, num_patches_z, num_patches_y, num_patches_x)

        decoded = rearrange(
            decoded,
            "b (c pz py px) nz ny nx -> b c (nz pz) (ny py) (nx px)",
            c=self.in_channels,
            pz=self.final_patch_size[0],
            py=self.final_patch_size[1],
            px=self.final_patch_size[2],
        )
        # (b, c, z, y, x)

        return decoded

In [28]:
test_config = {
    "patch_size": (1, 8, 8),
    "in_channels": 1,
    "use_absolute_position_embeddings": True,
    "learnable_absolute_position_embeddings": False,
    "embed_spacing_info": False,
    "image_size": (32, 512, 512),
    "stages": [],
}
test_config["stages"].append(
    {
        "depth": 2,
        "dim": 36,
        "num_heads": 4,
        "intermediate_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": False,
        "patch_merging": {
            "merge_window_size": (2, 2, 2),
            "out_dim_ratio": 3,
        },
    }
)
test_config["stages"].append(
    {
        "depth": 6,
        "dim": test_config["stages"][-1]["dim"] * test_config["stages"][-1]["patch_merging"]["out_dim_ratio"],
        "num_heads": 4,
        "intermediate_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": True,
        "patch_merging": {
            "merge_window_size": (2, 2, 2),
            "out_dim_ratio": 3,
        },
    }
)

test = Swin3DMIMDecoder(test_config)
display(test)
display(test(torch.randn(2, 324, 8, 16, 16)).shape)


[1;35mSwin3DMIMDecoder[0m[1m([0m
  [1m([0mdecoder[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m324[0m, [1;36m4096[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1[0m, [1;36m32[0m, [1;36m512[0m, [1;36m512[0m[1m][0m[1m)[0m

In [29]:
# | export


class Swin3DMIM(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config

        self.swin = Swin3DModel(config)
        self.decoder = Swin3DMIMDecoder(config)

        self.mask_token = nn.Parameter(torch.randn(1, config["stages"][0]["dim"], 1, 1, 1))

    def forward(self, pixel_values: torch.Tensor, spacings: torch.Tensor):
        b = pixel_values.shape[0]

        mask_ratio = self.config["mim"]["mask_ratio"]
        mask_grid_size = self.config["mim"]["mask_grid_size"]
        num_patches = np.prod(mask_grid_size)
        mask_patches = []
        for _ in range(b):
            _mask_patches = torch.zeros(num_patches, dtype=torch.int8, device=pixel_values.device)
            _mask_patches[: int(mask_ratio * num_patches)] = 1
            _mask_patches = _mask_patches[torch.randperm(num_patches)]
            _mask_patches = rearrange(
                _mask_patches, "(z y x) -> z y x", z=mask_grid_size[0], y=mask_grid_size[1], x=mask_grid_size[2]
            )
            mask_patches.append(_mask_patches)
        mask_patches: torch.Tensor = torch.stack(mask_patches, dim=0)

        grid_size = tuple([size // patch for size, patch in zip(self.config["image_size"], self.config["patch_size"])])
        assert all(
            [x % y == 0 for x, y in zip(grid_size, mask_grid_size)]
        ), "Mask grid size must divide image grid size"
        mask_patches = repeat(
            mask_patches,
            "b z y x -> b (z gz) (y gy) (x gx)",
            gz=grid_size[0] // mask_grid_size[0],
            gy=grid_size[1] // mask_grid_size[1],
            gx=grid_size[2] // mask_grid_size[2],
        )

        encodings, _, _ = self.swin(pixel_values, spacings, mask_patches, self.mask_token)

        decoded = self.decoder(encodings)

        loss = nn.functional.l1_loss(decoded, pixel_values, reduction="none")
        mask = repeat(
            mask_patches,
            "b z y x -> b (z pz) (y py) (x px)",
            pz=self.config["patch_size"][0],
            py=self.config["patch_size"][1],
            px=self.config["patch_size"][2],
        )
        loss = (loss * mask).sum() / ((mask.sum() + 1e-5) * self.config["in_channels"])

        return decoded, loss, mask

In [30]:
test_config = {
    "patch_size": (1, 8, 8),
    "in_channels": 1,
    "use_absolute_position_embeddings": True,
    "learnable_absolute_position_embeddings": False,
    "embed_spacing_info": False,
    "image_size": (32, 512, 512),
    "stages": [],
    "mim": {
        "mask_ratio": 0.8,
        "mask_grid_size": (8, 8, 8),
    },
}
test_config["stages"].append(
    {
        "depth": 2,
        "dim": 36,
        "num_heads": 4,
        "intermediate_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": False,
        "patch_merging": {
            "merge_window_size": (2, 2, 2),
            "out_dim_ratio": 3,
        },
    }
)
test_config["stages"].append(
    {
        "depth": 6,
        "dim": test_config["stages"][-1]["dim"] * test_config["stages"][-1]["patch_merging"]["out_dim_ratio"],
        "num_heads": 4,
        "intermediate_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": True,
        "patch_merging": {
            "merge_window_size": (2, 2, 2),
            "out_dim_ratio": 3,
        },
    }
)

test = Swin3DMIM(test_config)
display(test)
o = test(
    torch.randn(2, 1, 32, 512, 512),
    torch.randn(2, 3),
)
display((o[0].shape, o[1], o[2].shape))


[1;35mSwin3DMIM[0m[1m([0m
  [1m([0mswin[1m)[0m: [1;35mSwin3DModel[0m[1m([0m
    [1m([0membeddings[1m)[0m: [1;35mSwin3DEmbeddings[0m[1m([0m
      [1m([0mpatch_embeddings[1m)[0m: [1;35mSwin3DPatchEmbeddings[0m[1m([0m
        [1m([0mpatch_embeddings[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m36[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m[1m)[0m
      [1m)[0m
      [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m36[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
    [1m)[0m
    [1m([0mencoder[1m)[0m: [1;35mSwin3DEncoder[0m[1m([0m
      [1m([0mstages[1m)[0m: [1;35mModuleList[0m[1m([0m
        [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwin3DStage[0m[1m([0m
          [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
 


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1[0m, [1;36m32[0m, [1;36m512[0m, [1;36m512[0m[1m][0m[1m)[0m,
    [1;35mtensor[0m[1m([0m[1;36m1.6864[0m, [33mgrad_fn[0m=[1m<[0m[1;95mDivBackward0[0m[1m>[0m[1m)[0m,
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m32[0m, [1;36m512[0m, [1;36m512[0m[1m][0m[1m)[0m
[1m)[0m

In [31]:
from neuro_utils.visualize import plot_scans

plot_scans([o[0][0, 0].detach(), o[0][0, 0].detach() * (1 - o[2][0])])

interactive(children=(IntSlider(value=0, description='z', max=31), Output()), _dom_classes=('widget-interact',…

# Some more tests

In [32]:
test_config = {
    "patch_size": (1, 8, 8),
    "in_channels": 1,
    "use_absolute_position_embeddings": True,
    "learnable_absolute_position_embeddings": False,
    "embed_spacing_info": False,
    "image_size": (32, 512, 512),
    "stages": [],
    "mask_ratio": 0.8,
}
test_config["stages"].append(
    {
        "depth": 1,
        "dim": 48,
        "num_heads": 4,
        "intermediate_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": False,
        "patch_merging": {
            "merge_window_size": (2, 2, 2),
            "out_dim_ratio": 2,
        },
    }
)
test_config["stages"].append(
    {
        "depth": 1,
        "dim": test_config["stages"][-1]["dim"] * test_config["stages"][-1]["patch_merging"]["out_dim_ratio"],
        "num_heads": 4,
        "intermediate_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": False,
        "patch_merging": {
            "merge_window_size": (2, 2, 2),
            "out_dim_ratio": 2,
        },
    }
)
test_config["stages"].append(
    {
        "depth": 3,
        "dim": test_config["stages"][-1]["dim"] * test_config["stages"][-1]["patch_merging"]["out_dim_ratio"],
        "num_heads": 4,
        "intermediate_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": False,
        "patch_merging": {
            "merge_window_size": (2, 2, 2),
            "out_dim_ratio": 2,
        },
    }
)
test_config["stages"].append(
    {
        "depth": 1,
        "dim": test_config["stages"][-1]["dim"] * test_config["stages"][-1]["patch_merging"]["out_dim_ratio"],
        "num_heads": 4,
        "intermediate_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": True,
        "patch_merging": None,
    }
)

test = Swin3DModel(test_config)
display(test)
o = test(
    torch.randn(2, 1, 32, 512, 512),
    torch.randn(2, 3),
)
display((o[0].shape, [x.shape for x in o[1]], [x.shape for x in o[2]]))


[1;35mSwin3DModel[0m[1m([0m
  [1m([0membeddings[1m)[0m: [1;35mSwin3DEmbeddings[0m[1m([0m
    [1m([0mpatch_embeddings[1m)[0m: [1;35mSwin3DPatchEmbeddings[0m[1m([0m
      [1m([0mpatch_embeddings[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m48[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m[1m)[0m
    [1m)[0m
    [1m([0mlayer_norm[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m48[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
  [1m)[0m
  [1m([0mencoder[1m)[0m: [1;35mSwin3DEncoder[0m[1m([0m
    [1m([0mstages[1m)[0m: [1;35mModuleList[0m[1m([0m
      [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwin3DStage[0m[1m([0m
        [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
          [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwin3DBlock[0m[1m([0m
   


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m384[0m, [1;36m4[0m, [1;36m8[0m, [1;36m8[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m48[0m, [1;36m32[0m, [1;36m64[0m, [1;36m64[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m96[0m, [1;36m16[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m192[0m, [1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m384[0m, [1;36m4[0m, [1;36m8[0m, [1;36m8[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m384[0m, [1;36m4[0m, [1;36m8[0m, [1;36m8[0m[1m][0m[1m)[0m
    [1m][0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m48[0m, [1;36m32[0m, [1;36m64[0m, [1;36m64[0m[1m

In [33]:
from neuro_utils.describe import describe_model

describe_model(test)

* 'smart_union' has been removed


Total Parameters: 7,284,488
+----------------------------------------------------------------------+------------+
|                                Module                                | Parameters |
+----------------------------------------------------------------------+------------+
|         embeddings.patch_embeddings.patch_embeddings.weight          |   3,072    |
|          embeddings.patch_embeddings.patch_embeddings.bias           |     48     |
|                     embeddings.layer_norm.weight                     |     48     |
|                      embeddings.layer_norm.bias                      |     48     |
|      encoder.stages.0.blocks.0.w_layer.layernorm_before.weight       |     48     |
|       encoder.stages.0.blocks.0.w_layer.layernorm_before.bias        |     48     |
|          encoder.stages.0.blocks.0.w_layer.mhsa.W_q.weight           |   2,304    |
|           encoder.stages.0.blocks.0.w_layer.mhsa.W_q.bias            |     48     |
|          encoder.stages.

### Overfitting tests

In [34]:
from tqdm.auto import tqdm

sample_spacings = torch.tensor([[1, 0.1, 0.1], [2, 0.2, 0.2], [3, 0.3, 0.3], [4, 0.4, 0.4], [5, 0.5, 0.5]])
sample_batch = torch.rand(3, 1, 16, 128, 128)
sample_config = {
    "patch_size": (1, 4, 4),
    "in_channels": 1,
    "use_absolute_position_embeddings": True,
    "learnable_absolute_position_embeddings": False,
    "embed_spacing_info": False,
    "image_size": (16, 128, 128),
    "stages": [],
    "mim": {
        "mask_ratio": 0.7,
        "mask_grid_size": (8, 8, 8),
    },
}
sample_config["stages"].append(
    {
        "depth": 1,
        "dim": 12,
        "num_heads": 4,
        "intermediate_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": True,
        "patch_merging": {
            "merge_window_size": (2, 2, 2),
            "out_dim_ratio": 4,
        },
    }
)
sample_config["stages"].append(
    {
        "depth": 3,
        "dim": sample_config["stages"][-1]["dim"] * sample_config["stages"][-1]["patch_merging"]["out_dim_ratio"],
        "num_heads": 4,
        "intermediate_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": True,
        "patch_merging": {
            "merge_window_size": (2, 2, 2),
            "out_dim_ratio": 4,
        },
    }
)
sample_config["stages"].append(
    {
        "depth": 1,
        "dim": sample_config["stages"][-1]["dim"] * sample_config["stages"][-1]["patch_merging"]["out_dim_ratio"],
        "num_heads": 4,
        "intermediate_ratio": 4,
        "layer_norm_eps": 1e-6,
        "window_size": (4, 4, 4),
        "use_relative_position_bias": True,
        "patch_merging": None,
    }
)

model = Swin3DMIM(sample_config)

sum(x.numel() for x in model.swin.parameters()), sum(x.numel() for x in model.decoder.parameters())

[1m([0m[1;36m1156612[0m, [1;36m197632[0m[1m)[0m

In [35]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

In [36]:
sample_batch = sample_batch.cuda()
sample_spacings = sample_spacings.cuda()
model = model.cuda()

In [37]:
for i in tqdm(range(200)):
    optimizer.zero_grad()
    output = model(sample_batch, sample_spacings)
    print(f"Loss: {output[1]:f}\tLR: {scheduler.get_last_lr()[0]:f}")
    output[1].backward()
    optimizer.step()
    scheduler.step()

  0%|          | 0/200 [00:00<?, ?it/s]

Loss: 1.820429	LR: 0.500000
Loss: 1.987945	LR: 0.500000
Loss: 1.967435	LR: 0.500000
Loss: 2.308640	LR: 0.500000
Loss: 2.168925	LR: 0.500000
Loss: 2.207896	LR: 0.450000
Loss: 2.152745	LR: 0.450000
Loss: 2.326722	LR: 0.450000
Loss: 2.213367	LR: 0.450000
Loss: 2.072788	LR: 0.450000
Loss: 2.212431	LR: 0.405000
Loss: 2.075707	LR: 0.405000
Loss: 2.134526	LR: 0.405000
Loss: 2.043120	LR: 0.405000
Loss: 1.940446	LR: 0.405000
Loss: 1.967878	LR: 0.364500
Loss: 1.656061	LR: 0.364500
Loss: 1.712492	LR: 0.364500
Loss: 1.583969	LR: 0.364500
Loss: 1.320167	LR: 0.364500
Loss: 1.395069	LR: 0.328050
Loss: 1.224852	LR: 0.328050
Loss: 1.678762	LR: 0.328050
Loss: 1.253203	LR: 0.328050
Loss: 1.520050	LR: 0.328050
Loss: 1.030859	LR: 0.295245
Loss: 1.008602	LR: 0.295245
Loss: 1.087134	LR: 0.295245
Loss: 1.035728	LR: 0.295245
Loss: 1.025624	LR: 0.295245
Loss: 1.029193	LR: 0.265721
Loss: 0.838549	LR: 0.265721
Loss: 0.813711	LR: 0.265721
Loss: 0.810683	LR: 0.265721
Loss: 0.822165	LR: 0.265721
Loss: 0.841314	LR: 0

# Migration from old to new

In [119]:
# | export


def migrate_checkpoint_to_new(checkpoint, prefix="model"):
    model_config = checkpoint["hyper_parameters"]["model_config"]

    model_config["dim"] = model_config["stages"][0]["dim"]
    for stage in model_config["stages"]:
        del stage["dim"]

    for i in range(len(model_config["stages"]) - 1, 0, -1):
        model_config["stages"][i]["patch_merging"] = model_config["stages"][i - 1]["patch_merging"]
    model_config["stages"][0]["patch_merging"] = None

    checkpoint["hyper_parameters"]["model_config"] = populate_and_validate_config(
        checkpoint["hyper_parameters"]["model_config"]
    )

    state_dict = checkpoint["state_dict"]

    for stage_i in range(len(model_config["stages"])):
        for block_i in range(model_config["stages"][stage_i]["depth"]):
            for layer_name in {"w_layer", "sw_layer"}:
                for state_type in {"weight", "bias"}:
                    q = state_dict.pop(
                        f"{prefix}.swin.encoder.stages.{stage_i}.blocks.{block_i}.{layer_name}.mhsa.W_q.{state_type}"
                    )
                    k = state_dict.pop(
                        f"{prefix}.swin.encoder.stages.{stage_i}.blocks.{block_i}.{layer_name}.mhsa.W_k.{state_type}"
                    )
                    v = state_dict.pop(
                        f"{prefix}.swin.encoder.stages.{stage_i}.blocks.{block_i}.{layer_name}.mhsa.W_v.{state_type}"
                    )
                    qkv = torch.concat([q, k, v], dim=0)
                    state_dict[
                        f"{prefix}.swin.encoder.stages.{stage_i}.blocks.{block_i}.{layer_name}.mhsa.W_qkv.{state_type}"
                    ] = qkv

    for i in range(len(model_config["stages"]) - 1, 0, -1):
        state_dict[f"{prefix}.swin.encoder.stages.{i}.patch_merging.layer_norm.weight"] = state_dict.pop(
            f"{prefix}.swin.encoder.stages.{i-1}.patch_merging.layer_norm.weight"
        )
        state_dict[f"{prefix}.swin.encoder.stages.{i}.patch_merging.layer_norm.bias"] = state_dict.pop(
            f"{prefix}.swin.encoder.stages.{i-1}.patch_merging.layer_norm.bias"
        )
        state_dict[f"{prefix}.swin.encoder.stages.{i}.patch_merging.proj.weight"] = state_dict.pop(
            f"{prefix}.swin.encoder.stages.{i-1}.patch_merging.proj.weight"
        )
        state_dict[f"{prefix}.swin.encoder.stages.{i}.patch_merging.proj.bias"] = state_dict.pop(
            f"{prefix}.swin.encoder.stages.{i-1}.patch_merging.proj.bias"
        )

    return checkpoint

In [120]:
old = torch.load(
    r"/cache/expdata1/arjun/checkpoints/ct_pretraining/v11__2024_04_19/version_0/checkpoints/epoch=16.ckpt",
    map_location="cpu",
)

new = migrate_checkpoint_to_new(old)

display(new["hyper_parameters"]["model_config"].toDict())
model = Swin3DNewMIM(new["hyper_parameters"]["model_config"])

state_dict = new["state_dict"]
for k in state_dict.copy():
    state_dict[k.removeprefix("model.")] = state_dict.pop(k)

model.load_state_dict(state_dict)
describe_model(model)


[1m{[0m
    [32m'patch_size'[0m: [1m([0m[1;36m1[0m, [1;36m4[0m, [1;36m4[0m[1m)[0m,
    [32m'in_channels'[0m: [1;36m1[0m,
    [32m'use_absolute_position_embeddings'[0m: [3;92mTrue[0m,
    [32m'learnable_absolute_position_embeddings'[0m: [3;91mFalse[0m,
    [32m'embed_spacing_info'[0m: [3;91mFalse[0m,
    [32m'image_size'[0m: [1m([0m[1;36m32[0m, [1;36m256[0m, [1;36m256[0m[1m)[0m,
    [32m'stages'[0m: [1m[[0m
        [1m{[0m
            [32m'depth'[0m: [1;36m1[0m,
            [32m'num_heads'[0m: [1;36m4[0m,
            [32m'intermediate_ratio'[0m: [1;36m4[0m,
            [32m'layer_norm_eps'[0m: [1;36m1e-06[0m,
            [32m'window_size'[0m: [1m([0m[1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m)[0m,
            [32m'use_relative_position_bias'[0m: [3;92mTrue[0m,
            [32m'patch_merging'[0m: [3;35mNone[0m,
            [32m'_in_dim'[0m: [1;36m12[0m,
            [32m'_in_patch_size'[0m: [1m([0m[1;3

Total Parameters: 24,484,616
+---------------------------------------------------------------------------+------------+
|                                   Module                                  | Parameters |
+---------------------------------------------------------------------------+------------+
|                                 mask_token                                |     12     |
|          swin.embeddings.patch_embeddings.patch_embeddings.weight         |    192     |
|           swin.embeddings.patch_embeddings.patch_embeddings.bias          |     12     |
|                     swin.embeddings.layer_norm.weight                     |     12     |
|                      swin.embeddings.layer_norm.bias                      |     12     |
|       swin.encoder.stages.0.blocks.0.w_layer.layernorm_before.weight      |     12     |
|        swin.encoder.stages.0.blocks.0.w_layer.layernorm_before.bias       |     12     |
|  swin.encoder.stages.0.blocks.0.w_layer.mhsa.relative_posit

# nbdev

In [1]:
!nbdev_export

# Rough work