In [1]:
# | default_exp swinv2_3d_with_flash_attention

# Imports

In [2]:
# | export

import torch
import numpy as np
from einops import rearrange, repeat
from torch import nn
import torch.nn.functional as F

from vision_architectures.swinv2_3d import (
    populate_and_validate_config,
    get_coords_grid,
    SwinV23DMHSA as SwinV23DMHSAWithoutFlashAttention,
    SwinV23DLayerMLP,
    SwinV23DLayer as SwinV23DLayerWithoutFlashAttention,
    SwinV23DBlock as SwinV23DBlockWithoutFlashAttention,
    SwinV23DPatchMerging,
    SwinV23DStage as SwinV23DStageWithoutFlashAttention,
    SwinV23DEncoder as SwinV23DEncoderWithoutFlashAttention,
    SwinV23DPatchEmbeddings,
    get_3d_position_embeddings,
    embed_spacings_in_position_embeddings,
    SwinV23DEmbeddings,
    SwinV23DModel as SwinV23DModelWithoutFlashAttention,
    SwinV23DMIMDecoder,
    SwinV23DMIM as SwinV23DMIMWithoutFlashAttention,
)

# Modify MHSA

In [5]:
# | export


class SwinV23DMHSA(SwinV23DMHSAWithoutFlashAttention):
    def __init__(
        self,
        dim,
        num_heads,
        window_size,
        use_relative_position_bias,
        attn_drop_prob=0.0,
        proj_drop_prob=0.0,
    ):
        super().__init__(dim, num_heads, window_size, use_relative_position_bias, attn_drop_prob, proj_drop_prob)

        # Remove attention dropout layer as that is handled automatically, but store the dropout for later
        del self.attn_drop
        self.attn_drop_prob = attn_drop_prob

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (windowed_b, window_size_z window_size_y window_size_x, dim)
        _, num_patches_z, num_patches_y, num_patches_x, _ = hidden_states.shape

        query, key, value = rearrange(
            self.W_qkv(hidden_states),
            "b nz ny nx (n num_heads d) -> n b num_heads (nz ny nx) d",
            n=3,
            num_heads=self.num_heads,
        )
        # num_patches = window_size_z * window_size_y * window_size_x
        # Each is (windowed_b, num_heads, num_patches, per_head_dim)

        query_normalized = F.normalize(query, dim=-1)
        key_normalized = F.normalize(key, dim=-1)

        relative_position_bias = None
        if self.use_relative_position_bias:
            relative_position_bias = self.calculate_relative_position_bias()

        logit_scale = torch.clamp(self.logit_scale, max=np.log(1.0 / 0.01)).exp()

        context = F.scaled_dot_product_attention(
            query_normalized,
            key_normalized,
            value,
            attn_mask=relative_position_bias,  # Use this as a way to introduce relative position bias
            dropout_p=self.attn_drop_prob,
            is_causal=False,
            # scale=logit_scale, # TODO: Allow learnable scaling per head, otherwise it won't work
        )
        # (windowed_b, num_heads, num_patches, per_head_dim)
        context = rearrange(
            context,
            "b num_heads (num_patches_z num_patches_y num_patches_x) d -> "
            "b num_patches_z num_patches_y num_patches_x (num_heads d)",
            num_patches_z=num_patches_z,
            num_patches_y=num_patches_y,
            num_patches_x=num_patches_x,
        )
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        context = self.proj(context)
        context = self.proj_drop(context)
        # (windowed_b, window_size_z window_size_y window_size_x, dim)

        return context

In [6]:
test = SwinV23DMHSA(54, 6, (4, 4, 4), True)
display(test)
display(test(torch.randn(2, 4, 4, 4, 54)).shape)


[1;35mSwinV23DMHSA[0m[1m([0m
  [1m([0mW_qkv[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m162[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m([0mcpb_mlp[1m)[0m: [1;35mSequential[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m3[0m, [33mout_features[0m=[1;36m512[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0m[1;36m1[0m[1m)[0m: [1;35mReLU[0m[1m([0m[33minplace[0m=[3;92mTrue[0m[1m)[0m
    [1m([0m[1;36m2[0m[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m512[0m, [33mout_features[0m=[1;36m6[0m, [33mbias[0m=[3;91mFalse

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m, [1;36m54[0m[1m][0m[1m)[0m

# Modify other classes accordingly

In [7]:
# | export


class SwinV23DLayer(SwinV23DLayerWithoutFlashAttention):
    def __init__(
        self,
        dim,
        num_heads,
        intermediate_ratio,
        layer_norm_eps,
        window_size,
        use_relative_position_bias,
        attn_drop_prob=0.0,
        proj_drop_prob=0.0,
        mlp_drop_prob=0.0,
    ):
        super().__init__(
            dim,
            num_heads,
            intermediate_ratio,
            layer_norm_eps,
            window_size,
            use_relative_position_bias,
            attn_drop_prob,
            proj_drop_prob,
            mlp_drop_prob,
        )

        self.mhsa = SwinV23DMHSA(
            dim, num_heads, window_size, use_relative_position_bias, attn_drop_prob, proj_drop_prob
        )

In [8]:
# | export


class SwinV23DBlock(SwinV23DBlockWithoutFlashAttention):
    def __init__(self, stage_config):
        super().__init__(stage_config)

        self.stage_config = stage_config
        self.w_layer = SwinV23DLayer(
            stage_config["_out_dim"],
            stage_config["num_heads"],
            stage_config["intermediate_ratio"],
            stage_config["layer_norm_eps"],
            stage_config["window_size"],
            stage_config["use_relative_position_bias"],
            stage_config.get("attn_drop_prob", 0.0),
            stage_config.get("proj_drop_prob", 0.0),
            stage_config.get("mlp_drop_prob", 0.0),
        )
        self.sw_layer = SwinV23DLayer(
            stage_config["_out_dim"],
            stage_config["num_heads"],
            stage_config["intermediate_ratio"],
            stage_config["layer_norm_eps"],
            stage_config["window_size"],
            stage_config["use_relative_position_bias"],
            stage_config.get("attn_drop_prob", 0.0),
            stage_config.get("proj_drop_prob", 0.0),
            stage_config.get("mlp_drop_prob", 0.0),
        )

In [9]:
# | export


class SwinV23DStage(SwinV23DStageWithoutFlashAttention):
    def __init__(self, stage_config):
        super().__init__(stage_config)

        self.blocks = nn.ModuleList(
            [SwinV23DBlock(stage_config) for _ in range(stage_config["depth"])],
        )

In [10]:
# | export


class SwinV23DEncoder(SwinV23DEncoderWithoutFlashAttention):
    def __init__(self, config):
        super().__init__(config)

        self.stages = nn.ModuleList([SwinV23DStage(stage_config) for stage_config in config["stages"]])

In [11]:
# | export


class SwinV23DModel(SwinV23DModelWithoutFlashAttention):
    def __init__(self, config):
        super().__init__(config)

        self.encoder = SwinV23DEncoder(config)

In [12]:
class SwinV23DMIM(SwinV23DMIMWithoutFlashAttention):
    def __init__(self, config):
        super().__init__(config)

        self.swin = SwinV23DModel(config)

# Some more tests

### Overfitting tests

In [13]:
from tqdm.auto import tqdm

sample_spacings = torch.tensor([[1, 0.1, 0.1], [2, 0.2, 0.2], [3, 0.3, 0.3], [4, 0.4, 0.4], [5, 0.5, 0.5]])
sample_batch = torch.rand(3, 1, 16, 128, 128)
sample_config = populate_and_validate_config(
    {
        "patch_size": (1, 4, 4),
        "dim": 12,
        "in_channels": 1,
        "use_absolute_position_embeddings": True,
        "learnable_absolute_position_embeddings": False,
        "embed_spacing_info": False,
        "image_size": (16, 128, 128),
        "drop_prob": 0.2,
        "stages": [
            {
                "patch_merging": None,
                "depth": 1,
                "num_heads": 4,
                "intermediate_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
                "attn_drop_prob": 0.2,
                "proj_drop_prob": 0.2,
                "mlp_drop_prob": 0.2,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "out_dim_ratio": 4,
                },
                "depth": 3,
                "num_heads": 4,
                "intermediate_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
            {
                "patch_merging": {
                    "merge_window_size": (2, 2, 2),
                    "out_dim_ratio": 4,
                },
                "depth": 1,
                "num_heads": 4,
                "intermediate_ratio": 4,
                "layer_norm_eps": 1e-6,
                "window_size": (4, 4, 4),
                "use_relative_position_bias": True,
            },
        ],
        "mim": {
            "mask_ratio": 0.7,
            "mask_grid_size": (8, 8, 8),
        },
    }
)

model = SwinV23DMIM(sample_config)

sum(x.numel() for x in model.swin.parameters()), sum(x.numel() for x in model.decoder.parameters())

[1m([0m[1;36m1183892[0m, [1;36m197632[0m[1m)[0m

In [14]:
from neuro_utils.describe import describe_model

describe_model(model)

Total Parameters: 1,381,536
+---------------------------------------------------------------+------------+
|                             Module                            | Parameters |
+---------------------------------------------------------------+------------+
|                           mask_token                          |     12     |
|    swin.embeddings.patch_embeddings.patch_embeddings.weight   |    192     |
|     swin.embeddings.patch_embeddings.patch_embeddings.bias    |     12     |
|               swin.embeddings.layer_norm.weight               |     12     |
|                swin.embeddings.layer_norm.bias                |     12     |
|    swin.encoder.stages.0.blocks.0.w_layer.mhsa.logit_scale    |     4      |
|    swin.encoder.stages.0.blocks.0.w_layer.mhsa.W_qkv.weight   |    432     |
|     swin.encoder.stages.0.blocks.0.w_layer.mhsa.W_qkv.bias    |     36     |
|    swin.encoder.stages.0.blocks.0.w_layer.mhsa.proj.weight    |    144     |
|     swin.encoder.stage

In [15]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

In [16]:
sample_batch = sample_batch.cuda()
sample_spacings = sample_spacings.cuda()
model = model.cuda()

In [17]:
for i in tqdm(range(200)):
    optimizer.zero_grad()
    output = model(sample_batch, sample_spacings)
    print(f"Loss: {output[1]:f}\tLR: {scheduler.get_last_lr()[0]:f}")
    output[1].backward()
    optimizer.step()
    scheduler.step()

  0%|          | 0/200 [00:00<?, ?it/s]

Loss: 3.210407	LR: 0.500000
Loss: 3.960272	LR: 0.500000
Loss: 5.666294	LR: 0.500000
Loss: 3.388160	LR: 0.500000
Loss: 3.186559	LR: 0.500000
Loss: 2.869438	LR: 0.450000
Loss: 2.782217	LR: 0.450000
Loss: 2.372935	LR: 0.450000
Loss: 1.711966	LR: 0.450000
Loss: 1.976822	LR: 0.450000
Loss: 1.917054	LR: 0.405000
Loss: 1.574714	LR: 0.405000
Loss: 1.528946	LR: 0.405000
Loss: 1.693159	LR: 0.405000
Loss: 1.569976	LR: 0.405000
Loss: 1.565363	LR: 0.364500
Loss: 1.387463	LR: 0.364500
Loss: 1.574689	LR: 0.364500
Loss: 1.272892	LR: 0.364500
Loss: 1.371980	LR: 0.364500
Loss: 1.410923	LR: 0.328050
Loss: 1.224110	LR: 0.328050
Loss: 1.263726	LR: 0.328050
Loss: 1.214210	LR: 0.328050
Loss: 1.185062	LR: 0.328050
Loss: 1.291223	LR: 0.295245
Loss: 1.057200	LR: 0.295245
Loss: 1.117633	LR: 0.295245
Loss: 1.109142	LR: 0.295245
Loss: 1.088373	LR: 0.295245
Loss: 1.038932	LR: 0.265721
Loss: 0.964423	LR: 0.265721
Loss: 1.012183	LR: 0.265721
Loss: 1.008810	LR: 0.265721
Loss: 1.001049	LR: 0.265721
Loss: 0.943001	LR: 0

In [19]:
for name, param in model.named_parameters():
    if param.grad is None:
        print(name)

swin.encoder.stages.0.blocks.0.w_layer.mhsa.logit_scale
swin.encoder.stages.0.blocks.0.sw_layer.mhsa.logit_scale
swin.encoder.stages.1.blocks.0.w_layer.mhsa.logit_scale
swin.encoder.stages.1.blocks.0.sw_layer.mhsa.logit_scale
swin.encoder.stages.1.blocks.1.w_layer.mhsa.logit_scale
swin.encoder.stages.1.blocks.1.sw_layer.mhsa.logit_scale
swin.encoder.stages.1.blocks.2.w_layer.mhsa.logit_scale
swin.encoder.stages.1.blocks.2.sw_layer.mhsa.logit_scale
swin.encoder.stages.2.blocks.0.w_layer.mhsa.logit_scale
swin.encoder.stages.2.blocks.0.sw_layer.mhsa.logit_scale


# nbdev

In [20]:
!nbdev_export

# Rough work