In [1]:
# | default_exp swinv2_3d_with_sdpa

# Imports

In [2]:
# | export

import numpy as np
import torch
import torch.nn.functional as F

from einops import rearrange, repeat
from torch import nn
from torch.nn.attention import SDPBackend, sdpa_kernel
from vision_architectures.swinv2_3d import (
    populate_and_validate_config,
    get_coords_grid,
    SwinV23DMHSA as SwinV23DMHSAWithoutSDPA,
    SwinV23DLayerMLP,
    SwinV23DLayer as SwinV23DLayerWithoutSDPA,
    SwinV23DBlock as SwinV23DBlockWithoutSDPA,
    SwinV23DPatchMerging,
    SwinV23DStage as SwinV23DStageWithoutSDPA,
    SwinV23DEncoder as SwinV23DEncoderWithoutSDPA,
    SwinV23DPatchEmbeddings,
    get_3d_position_embeddings,
    embed_spacings_in_position_embeddings,
    SwinV23DEmbeddings,
    SwinV23DModel as SwinV23DModelWithoutSDPA,
    SwinV23DMIMDecoder,
    SwinV23DMIM as SwinV23DMIMWithoutSDPA,
)

# Modify MHSA

In [3]:
# | export


class SwinV23DMHSA(SwinV23DMHSAWithoutSDPA):
    def __init__(
        self,
        dim,
        num_heads,
        window_size,
        use_relative_position_bias,
        attn_drop_prob=0.0,
        proj_drop_prob=0.0,
    ):
        super().__init__(dim, num_heads, window_size, use_relative_position_bias, attn_drop_prob, proj_drop_prob)

        # Remove attention dropout layer as that is handled automatically, but store the dropout for later
        del self.attn_drop
        self.attn_drop_prob = attn_drop_prob

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (windowed_b, window_size_z window_size_y window_size_x, dim)
        _, num_patches_z, num_patches_y, num_patches_x, _ = hidden_states.shape

        query, key, value = rearrange(
            self.W_qkv(hidden_states),
            "b nz ny nx (n num_heads d) -> n b num_heads (nz ny nx) d",
            n=3,
            num_heads=self.num_heads,
        )
        # num_patches = window_size_z * window_size_y * window_size_x
        # Each is (windowed_b, num_heads, num_patches, per_head_dim)

        logit_scale = torch.clamp(self.logit_scale, max=np.log(1.0 / 0.01)).exp()

        query_normalized = F.normalize(query, dim=-1)
        key_normalized = F.normalize(key, dim=-1)

        query_normalized_and_scaled = query_normalized * logit_scale  # Scale the query beforehand

        relative_position_bias = None
        if self.use_relative_position_bias:
            relative_position_bias = self.calculate_relative_position_bias()

        context = F.scaled_dot_product_attention(
            query_normalized_and_scaled,
            key_normalized,
            value,
            attn_mask=relative_position_bias,  # Use this as a way to introduce relative position bias
            dropout_p=self.attn_drop_prob,
            is_causal=False,
            scale=1.0,  # Already scaled the vectors
        )
        # (windowed_b, num_heads, num_patches, per_head_dim)
        
        context = rearrange(
            context,
            "b num_heads (num_patches_z num_patches_y num_patches_x) d -> "
            "b num_patches_z num_patches_y num_patches_x (num_heads d)",
            num_patches_z=num_patches_z,
            num_patches_y=num_patches_y,
            num_patches_x=num_patches_x,
        )
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        context = self.proj(context)
        context = self.proj_drop(context)
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        return context

In [4]:
test = SwinV23DMHSA(54, 6, (4, 4, 4), True)
display(test)
display(test(torch.randn(2, 4, 4, 4, 54)).shape)


[1;35mSwinV23DMHSA[0m[1m([0m
  [1m([0mW_qkv[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m162[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m([0mcpb_mlp[1m)[0m: [1;35mSequential[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m3[0m, [33mout_features[0m=[1;36m512[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0m[1;36m1[0m[1m)[0m: [1;35mReLU[0m[1m([0m[33minplace[0m=[3;92mTrue[0m[1m)[0m
    [1m([0m[1;36m2[0m[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m512[0m, [33mout_features[0m=[1;36m6[0m, [33mbias[0m=[3;91mFalse

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m54[0m[1m][0m[1m)[0m

In [5]:
for use_relative_position_bias in [True, False]:
    m1 = SwinV23DMHSA(54, 6, (4, 4, 4), use_relative_position_bias)
    m2 = SwinV23DMHSAWithoutSDPA(54, 6, (4, 4, 4), use_relative_position_bias)

    m1.load_state_dict(m2.state_dict())
    m1.eval(), m2.eval()

    example_input = torch.randn(2, 4, 4, 4, 54)
    o1 = m1(example_input)
    o2 = m2(example_input)

    assert torch.allclose(o1, o2, atol=1e-6), (o1 - o2).abs().max()

# Modify other classes accordingly

In [6]:
# | export


class SwinV23DLayer(SwinV23DLayerWithoutSDPA):
    def __init__(
        self,
        dim,
        num_heads,
        intermediate_ratio,
        layer_norm_eps,
        window_size,
        use_relative_position_bias,
        attn_drop_prob=0.0,
        proj_drop_prob=0.0,
        mlp_drop_prob=0.0,
    ):
        super().__init__(
            dim,
            num_heads,
            intermediate_ratio,
            layer_norm_eps,
            window_size,
            use_relative_position_bias,
            attn_drop_prob,
            proj_drop_prob,
            mlp_drop_prob,
        )

        self.mhsa = SwinV23DMHSA(
            dim, num_heads, window_size, use_relative_position_bias, attn_drop_prob, proj_drop_prob
        )

In [7]:
# | export


class SwinV23DBlock(SwinV23DBlockWithoutSDPA):
    def __init__(self, stage_config):
        super().__init__(stage_config)

        self.stage_config = stage_config
        self.w_layer = SwinV23DLayer(
            stage_config["_out_dim"],
            stage_config["num_heads"],
            stage_config["intermediate_ratio"],
            stage_config["layer_norm_eps"],
            stage_config["window_size"],
            stage_config["use_relative_position_bias"],
            stage_config.get("attn_drop_prob", 0.0),
            stage_config.get("proj_drop_prob", 0.0),
            stage_config.get("mlp_drop_prob", 0.0),
        )
        self.sw_layer = SwinV23DLayer(
            stage_config["_out_dim"],
            stage_config["num_heads"],
            stage_config["intermediate_ratio"],
            stage_config["layer_norm_eps"],
            stage_config["window_size"],
            stage_config["use_relative_position_bias"],
            stage_config.get("attn_drop_prob", 0.0),
            stage_config.get("proj_drop_prob", 0.0),
            stage_config.get("mlp_drop_prob", 0.0),
        )

In [8]:
# | export


class SwinV23DStage(SwinV23DStageWithoutSDPA):
    def __init__(self, stage_config):
        super().__init__(stage_config)

        self.blocks = nn.ModuleList(
            [SwinV23DBlock(stage_config) for _ in range(stage_config["depth"])],
        )

In [9]:
# | export


class SwinV23DEncoder(SwinV23DEncoderWithoutSDPA):
    def __init__(self, config):
        super().__init__(config)

        self.stages = nn.ModuleList([SwinV23DStage(stage_config) for stage_config in config["stages"]])

In [10]:
# | export


class SwinV23DModel(SwinV23DModelWithoutSDPA):
    def __init__(self, config):
        super().__init__(config)

        self.encoder = SwinV23DEncoder(config)

In [11]:
# | export


class SwinV23DMIM(SwinV23DMIMWithoutSDPA):
    def __init__(self, config):
        super().__init__(config)

        self.swin = SwinV23DModel(config)

# Some more tests

### Overfitting tests

In [12]:
from tqdm.auto import tqdm

sample_spacings = torch.tensor([[1, 0.1, 0.1], [2, 0.2, 0.2], [3, 0.3, 0.3], [4, 0.4, 0.4], [5, 0.5, 0.5]])
sample_batch = torch.rand(3, 1, 16, 128, 128)
sample_config = populate_and_validate_config(
    {
        "patch_size": (1, 4, 4),
        "dim": 12,
        "in_channels": 1,
        "use_absolute_position_embeddings": True,
        "learnable_absolute_position_embeddings": False,
        "embed_spacing_info": False,
        "image_size": (16, 128, 128),
        "drop_prob": 0.2,
        "stages": [
            {
                "patch_merging": None,
                "depth": 1,
                "num_heads": 4,
                "intermediate_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
                "attn_drop_prob": 0.2,
                "proj_drop_prob": 0.2,
                "mlp_drop_prob": 0.2,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "out_dim_ratio": 4,
                },
                "depth": 3,
                "num_heads": 4,
                "intermediate_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "out_dim_ratio": 4,
                },
                "depth": 1,
                "num_heads": 4,
                "intermediate_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
        ],
        "mim": {
            "mask_ratio": 0.7,
            "mask_grid_size": (8, 8, 8),
        },
    }
)

model = SwinV23DMIM(sample_config)

sum(x.numel() for x in model.swin.parameters()), sum(x.numel() for x in model.decoder.parameters())

[1m([0m[1;36m1183892[0m, [1;36m197632[0m[1m)[0m

In [13]:
from neuro_utils.describe import describe_model

describe_model(model)

Total Parameters: 1,381,536
+---------------------------------------------------------------+------------+
|                             Module                            | Parameters |
+---------------------------------------------------------------+------------+
|                           mask_token                          |     12     |
|    swin.embeddings.patch_embeddings.patch_embeddings.weight   |    192     |
|     swin.embeddings.patch_embeddings.patch_embeddings.bias    |     12     |
|               swin.embeddings.layer_norm.weight               |     12     |
|                swin.embeddings.layer_norm.bias                |     12     |
|    swin.encoder.stages.0.blocks.0.w_layer.mhsa.logit_scale    |     4      |
|    swin.encoder.stages.0.blocks.0.w_layer.mhsa.W_qkv.weight   |    432     |
|     swin.encoder.stages.0.blocks.0.w_layer.mhsa.W_qkv.bias    |     36     |
|    swin.encoder.stages.0.blocks.0.w_layer.mhsa.proj.weight    |    144     |
|     swin.encoder.stage

In [14]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

In [15]:
sample_batch = sample_batch.cuda()
sample_spacings = sample_spacings.cuda()
model = model.cuda()

In [16]:
for i in tqdm(range(200)):
    optimizer.zero_grad()
    output = model(sample_batch, sample_spacings)
    print(f"Loss: {output[1]:f}\tLR: {scheduler.get_last_lr()[0]:f}")
    output[1].backward()
    optimizer.step()
    scheduler.step()

  0%|          | 0/200 [00:00<?, ?it/s]

Loss: 3.286826	LR: 0.500000
Loss: 3.551189	LR: 0.500000
Loss: 5.618439	LR: 0.500000
Loss: 3.360113	LR: 0.500000
Loss: 3.476582	LR: 0.500000
Loss: 2.366527	LR: 0.450000
Loss: 1.915956	LR: 0.450000
Loss: 1.969440	LR: 0.450000
Loss: 1.796858	LR: 0.450000
Loss: 1.840450	LR: 0.450000
Loss: 1.803444	LR: 0.405000
Loss: 1.675246	LR: 0.405000
Loss: 1.577672	LR: 0.405000
Loss: 1.529282	LR: 0.405000
Loss: 1.477986	LR: 0.405000
Loss: 1.581834	LR: 0.364500
Loss: 1.315247	LR: 0.364500
Loss: 1.345085	LR: 0.364500
Loss: 1.300409	LR: 0.364500
Loss: 1.340105	LR: 0.364500
Loss: 1.328799	LR: 0.328050
Loss: 1.138287	LR: 0.328050
Loss: 1.161520	LR: 0.328050
Loss: 1.210390	LR: 0.328050
Loss: 1.176165	LR: 0.328050
Loss: 1.152816	LR: 0.295245
Loss: 1.039268	LR: 0.295245
Loss: 1.030901	LR: 0.295245
Loss: 1.058415	LR: 0.295245
Loss: 1.032585	LR: 0.295245
Loss: 1.005308	LR: 0.265721
Loss: 0.917888	LR: 0.265721
Loss: 0.892791	LR: 0.265721
Loss: 0.905334	LR: 0.265721
Loss: 0.891857	LR: 0.265721
Loss: 0.905171	LR: 0

In [17]:
for name, param in model.named_parameters():
    if param.grad is None:
        print(name)

# nbdev

In [18]:
!nbdev_export

# Rough work