In [None]:
# | default_exp nets/vit_3d

# Imports

In [None]:
# | export

from typing import Literal

import numpy as np
import torch
from einops import rearrange, repeat
from huggingface_hub import PyTorchModelHubMixin
from torch import nn

from vision_architectures.layers.attention import (
    Attention1D,
    Attention1DMLP,
    Attention1DWithMLP,
)
from vision_architectures.layers.embeddings import (
    AbsolutePositionEmbeddings3D,
    PatchEmbeddings3D,
)
from vision_architectures.utils.custom_base_model import CustomBaseModel

# Config

In [None]:
# | export


# Hyperparameters consumed by ViT3DEncoder when it builds its stack of encoder layers.
class ViT3DEncoderConfig(CustomBaseModel):
    dim: int  # token embedding width handed to every encoder layer
    num_heads: int  # attention heads per layer
    mlp_ratio: int  # forwarded to each layer's MLP
    layer_norm_eps: float  # epsilon forwarded to the layers' LayerNorms
    attn_drop_prob: float = 0.0
    proj_drop_prob: float = 0.0  # declared once; it was previously listed twice
    mlp_drop_prob: float = 0.0
    norm_location: Literal["pre", "post"] = "pre"

    encoder_depth: int  # number of encoder layers to stack


# Full model configuration: the encoder hyperparameters plus patch embedding, class-token and
# masked-image-modeling settings.
class ViT3DConfig(ViT3DEncoderConfig):
    # Patch size and input channel count handed to the patch embedding layer.
    patch_size: tuple[int, int, int]
    in_channels: int
    # Number of learnable class tokens prepended to the patch tokens (0 disables them).
    num_class_tokens: int

    # Dropout probability applied to the tokens after position embeddings are added.
    drop_prob: float = 0.0

    # For MIM: the fixed input volume size and the fraction of patches to mask.
    image_size: tuple[int, int, int] | None = None
    mask_ratio: float | None = None


# Hyperparameters consumed by ViT3DDecoder when it builds its stack of decoder layers.
class ViT3DDecoderConfig(CustomBaseModel):
    dim: int  # token embedding width handed to every decoder layer
    num_heads: int  # attention heads per attention block
    mlp_ratio: int  # forwarded to each layer's MLP
    layer_norm_eps: float  # epsilon for the layers' LayerNorms
    attn_drop_prob: float = 0.0
    proj_drop_prob: float = 0.0  # declared once; it was previously listed twice
    mlp_drop_prob: float = 0.0
    norm_location: Literal["pre", "post"] = "pre"

    decoder_depth: int  # number of decoder layers to stack

In [None]:
# Smoke test: a minimal ViT3DConfig validates and falls back to the declared defaults.
# ("decoder_depth" was dropped from the payload: it is not a ViT3DConfig field.)
test_config = ViT3DConfig.model_validate(
    {
        "patch_size": (2, 2, 2),
        "in_channels": 3,
        "dim": 64,
        "num_heads": 8,
        "mlp_ratio": 4,
        "layer_norm_eps": 1e-6,
        "encoder_depth": 3,
        "num_class_tokens": 0,
    }
)
assert test_config.norm_location == "pre"
assert test_config.image_size is None and test_config.mask_ratio is None
test_config


[1;35mViT3DConfig[0m[1m([0m
    [33mdim[0m=[1;36m64[0m,
    [33mnum_heads[0m=[1;36m8[0m,
    [33mmlp_ratio[0m=[1;36m4[0m,
    [33mlayer_norm_eps[0m=[1;36m1e[0m[1;36m-06[0m,
    [33mattn_drop_prob[0m=[1;36m0[0m[1;36m.0[0m,
    [33mproj_drop_prob[0m=[1;36m0[0m[1;36m.0[0m,
    [33mmlp_drop_prob[0m=[1;36m0[0m[1;36m.0[0m,
    [33mnorm_location[0m=[32m'pre'[0m,
    [33mencoder_depth[0m=[1;36m3[0m,
    [33mpatch_size[0m=[1m([0m[1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m)[0m,
    [33min_channels[0m=[1;36m3[0m,
    [33mnum_class_tokens[0m=[1;36m0[0m,
    [33mdrop_prob[0m=[1;36m0[0m[1;36m.0[0m,
    [33mimage_size[0m=[3;35mNone[0m,
    [33mmask_ratio[0m=[3;35mNone[0m
[1m)[0m

# Architecture

### Encoder

In [None]:
# | export


class ViT3DEncoderLayer(Attention1DWithMLP):
    """Self-attention encoder layer: an ``Attention1DWithMLP`` fed the same tokens as query, key and value."""

    def __init__(
        self,
        dim,
        num_heads,
        *args,
        **kwargs,
    ):
        # Any extra positional/keyword arguments are handed to the parent untouched.
        super().__init__(*args, dim=dim, num_heads=num_heads, **kwargs)

    def forward(self, qkv: torch.Tensor):
        """Run the parent layer with ``qkv`` used as query, key and value."""
        # qkv: (b, num_tokens, dim)
        query = key = value = qkv
        return super().forward(query, key, value)

In [None]:
# Smoke test: a single encoder layer keeps the (b, num_tokens, dim) shape.
test = ViT3DEncoderLayer(
    dim=54,
    num_heads=6,
    mlp_ratio=2,
    layer_norm_eps=1e-6,
    attn_drop_prob=0.0,
    proj_drop_prob=0.0,
    mlp_drop_prob=0.0,
)
display(test)
o = test(torch.randn(2, 64, 54))
display(o.shape)
# Assert instead of only displaying, so the test run fails on a shape regression.
assert o.shape == (2, 64, 54)


[1;35mViT3DEncoderLayer[0m[1m([0m
  [1m([0mattn[1m)[0m: [1;35mAttention1D[0m[1m([0m
    [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mlayernorm1[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m54[0m,[1m)

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m

In [None]:
# | export


class ViT3DEncoder(nn.Module, PyTorchModelHubMixin):
    """A stack of ``config.encoder_depth`` ``ViT3DEncoderLayer`` modules applied one after another."""

    def __init__(self, config: ViT3DEncoderConfig):
        super().__init__()

        # Settings shared by every layer of the stack.
        layer_kwargs = dict(
            mlp_ratio=config.mlp_ratio,
            layer_norm_eps=config.layer_norm_eps,
            attn_drop_prob=config.attn_drop_prob,
            proj_drop_prob=config.proj_drop_prob,
            mlp_drop_prob=config.mlp_drop_prob,
            norm_location=config.norm_location,
        )
        stack = []
        for _ in range(config.encoder_depth):
            stack.append(ViT3DEncoderLayer(config.dim, config.num_heads, **layer_kwargs))
        self.layers = nn.ModuleList(stack)

    def forward(self, embeddings: torch.Tensor, return_all: bool = False):
        """Apply every layer in order.

        Returns the final embeddings, or (when ``return_all`` is true) a dict with the final
        ``"embeddings"`` and the list of per-layer ``"layer_outputs"``.
        """
        # embeddings: (b, num_tokens, dim)

        layer_outputs = []
        for encoder_layer in self.layers:
            embeddings = encoder_layer(embeddings)
            # (b, num_tokens, dim)
            layer_outputs.append(embeddings)

        if not return_all:
            return embeddings

        return {
            "embeddings": embeddings,
            "layer_outputs": layer_outputs,
        }

In [None]:
# Smoke test: a 3-layer encoder returns the final embeddings plus one output per layer.
test_config = ViT3DEncoderConfig.model_validate(
    {
        "dim": 54,
        "num_heads": 6,
        "mlp_ratio": 2,
        "layer_norm_eps": 1e-6,
        "attn_drop_prob": 0.0,
        "proj_drop_prob": 0.0,
        "mlp_drop_prob": 0.0,
        "encoder_depth": 3,
    }
)

test = ViT3DEncoder(test_config)
display(test)
o = test(torch.randn(2, 64, 54), return_all=True)
display((o["embeddings"].shape, [x.shape for x in o["layer_outputs"]]))
# Assert instead of only displaying, so the test run fails on a regression.
assert o["embeddings"].shape == (2, 64, 54)
assert len(o["layer_outputs"]) == test_config.encoder_depth
assert all(x.shape == (2, 64, 54) for x in o["layer_outputs"])


[1;35mViT3DEncoder[0m[1m([0m
  [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m2[0m[1m)[0m: [1;36m3[0m x [1;35mViT3DEncoderLayer[0m[1m([0m
      [1m([0mattn[1m)[0m: [1;35mAttention1D[0m[1m([0m
        [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[

[1m([0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m, [1m[[0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m[1m][0m[1m)[0m

### Decoder

In [None]:
# | export


class ViT3DDecoderLayer(nn.Module):
    """Transformer decoder layer: self-attention over ``q``, cross-attention from ``q`` to ``kv``, then an MLP.

    Each of the three sub-blocks has its own LayerNorm (``layernorm1`` for self-attention,
    ``layernorm2`` for cross-attention, ``layernorm3`` for the MLP) and a residual connection.
    With ``use_post_norm=False`` the norm is applied to the sub-block input (pre-norm); with
    ``use_post_norm=True`` it is applied to the sub-block output before the residual is added.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio,
        layer_norm_eps,
        attn_drop_prob=0.0,
        proj_drop_prob=0.0,
        mlp_drop_prob=0.0,
        use_post_norm=False,
    ):
        super().__init__()

        self.use_post_norm = use_post_norm

        self.mhsa = Attention1D(
            dim=dim,
            num_heads=num_heads,
            attn_drop_prob=attn_drop_prob,
            proj_drop_prob=proj_drop_prob,
        )
        self.layernorm1 = nn.LayerNorm(dim, eps=layer_norm_eps)
        self.mhca = Attention1D(
            dim=dim,
            num_heads=num_heads,
            attn_drop_prob=attn_drop_prob,
            proj_drop_prob=proj_drop_prob,
        )
        self.layernorm2 = nn.LayerNorm(dim, eps=layer_norm_eps)
        self.mlp = Attention1DMLP(dim, mlp_ratio=mlp_ratio, activation="gelu", mlp_drop_prob=mlp_drop_prob)
        self.layernorm3 = nn.LayerNorm(dim, eps=layer_norm_eps)

    def forward(self, q: torch.Tensor, kv: torch.Tensor):
        """Decode ``q`` against ``kv``.

        Args:
            q: query tokens, (b, num_tokens_in_q, dim).
            kv: key/value tokens attended to by the cross-attention, (b, num_tokens_in_kv, dim).

        Returns:
            Tensor of shape (b, num_tokens_in_q, dim).
        """
        # q: (b, num_tokens_in_q, dim)
        # kv: (b, num_tokens_in_kv, dim)

        # --- Self-attention over the query tokens ---
        res_connection1 = q
        # (b, num_tokens_in_q, dim)

        if not self.use_post_norm:
            q = self.layernorm1(q)
            # (b, num_tokens_in_q, dim)
            # NOTE(review): kv is normalised with the self-attention norm; kept as in the
            # original implementation, confirm this weight sharing is intended.
            kv = self.layernorm1(kv)
            # (b, num_tokens_in_kv, dim)

        hidden_states = self.mhsa(q, q, q)
        # (b, num_tokens_in_q, dim)

        if self.use_post_norm:
            hidden_states = self.layernorm1(hidden_states)
            # (b, num_tokens_in_q, dim)

        hidden_states = hidden_states + res_connection1

        # --- Cross-attention from the query tokens to kv ---
        res_connection2 = hidden_states
        # (b, num_tokens_in_q, dim)

        if not self.use_post_norm:
            # Fix: this used layernorm1, leaving layernorm2 unused in pre-norm mode. Checkpoints
            # trained before the fix hold untrained (initial) layernorm2 weights.
            hidden_states = self.layernorm2(hidden_states)
            # (b, num_tokens_in_q, dim)

        hidden_states = self.mhca(hidden_states, kv, kv)
        # (b, num_tokens_in_q, dim)

        if self.use_post_norm:
            hidden_states = self.layernorm2(hidden_states)
            # (b, num_tokens_in_q, dim)

        hidden_states = hidden_states + res_connection2

        # --- MLP ---
        res_connection3 = hidden_states
        # (b, num_tokens_in_q, dim)

        if not self.use_post_norm:
            hidden_states = self.layernorm3(hidden_states)
            # (b, num_tokens_in_q, dim)

        hidden_states = self.mlp(hidden_states)
        # (b, num_tokens_in_q, dim)

        if self.use_post_norm:
            hidden_states = self.layernorm3(hidden_states)
            # (b, num_tokens_in_q, dim)

        hidden_states = hidden_states + res_connection3
        # (b, num_tokens_in_q, dim)

        return hidden_states

In [None]:
# Smoke test: query and key/value sequences of the same length.
test = ViT3DDecoderLayer(52, 4, 2, 1e-6)
display(test)
o = test(torch.randn(2, 64, 52), torch.randn(2, 64, 52))
display(o.shape)
# Assert instead of only displaying, so the test run fails on a shape regression.
assert o.shape == (2, 64, 52)


[1;35mViT3DDecoderLayer[0m[1m([0m
  [1m([0mmhsa[1m)[0m: [1;35mAttention1D[0m[1m([0m
    [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mlayernorm1[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m52[0m,[1m)

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m52[0m[1m][0m[1m)[0m

In [None]:
# Smoke test: fewer query tokens (3) than key/value tokens (64); the output follows the query length.
test = ViT3DDecoderLayer(52, 4, 2, 1e-6)
display(test)
o = test(torch.randn(2, 3, 52), torch.randn(2, 64, 52))
display(o.shape)
# Assert instead of only displaying, so the test run fails on a shape regression.
assert o.shape == (2, 3, 52)


[1;35mViT3DDecoderLayer[0m[1m([0m
  [1m([0mmhsa[1m)[0m: [1;35mAttention1D[0m[1m([0m
    [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mlayernorm1[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m52[0m,[1m)

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m, [1;36m52[0m[1m][0m[1m)[0m

In [None]:
# | export


class ViT3DDecoder(nn.Module, PyTorchModelHubMixin):
    """A stack of ``config.decoder_depth`` ``ViT3DDecoderLayer`` modules that all attend to the same ``kv``."""

    def __init__(self, config: ViT3DDecoderConfig):
        super().__init__()

        # Fix: norm_location used to be ignored, so a "post" config silently built pre-norm layers.
        use_post_norm = config.norm_location == "post"

        self.layers = nn.ModuleList(
            [
                ViT3DDecoderLayer(
                    config.dim,
                    config.num_heads,
                    config.mlp_ratio,
                    config.layer_norm_eps,
                    config.attn_drop_prob,
                    config.proj_drop_prob,
                    config.mlp_drop_prob,
                    use_post_norm=use_post_norm,
                )
                for _ in range(config.decoder_depth)
            ]
        )

    def forward(self, q: torch.Tensor, kv: torch.Tensor, return_all: bool = False):
        """Refine ``q`` layer by layer, cross-attending to ``kv`` in every layer.

        Returns the final embeddings, or (when ``return_all`` is true) a dict with the final
        ``"embeddings"`` and the list of per-layer ``"layer_outputs"``.
        """
        # q: (b, num_q_tokens, dim)
        # kv: (b, num_kv_tokens, dim)

        embeddings = q

        layer_outputs = []
        for decoder_layer in self.layers:
            embeddings = decoder_layer(embeddings, kv)
            # (b, num_q_tokens, dim)

            layer_outputs.append(embeddings)

        return_value = embeddings
        if return_all:
            return_value = {
                "embeddings": embeddings,
                "layer_outputs": layer_outputs,
            }

        return return_value

In [None]:
# Smoke test: a 5-layer decoder with 64 query tokens attending to 128 key/value tokens.
test_config = ViT3DDecoderConfig.model_validate(
    {
        "dim": 54,
        "num_heads": 6,
        "mlp_ratio": 2,
        "layer_norm_eps": 1e-6,
        "attn_drop_prob": 0.0,
        "proj_drop_prob": 0.0,
        "mlp_drop_prob": 0.0,
        "decoder_depth": 5,
    }
)

test = ViT3DDecoder(test_config)
display(test)
o = test(torch.randn(2, 64, 54), torch.randn(2, 128, 54), return_all=True)
display((o["embeddings"].shape, [x.shape for x in o["layer_outputs"]]))
# Assert instead of only displaying, so the test run fails on a regression.
assert o["embeddings"].shape == (2, 64, 54)
assert len(o["layer_outputs"]) == test_config.decoder_depth
assert all(x.shape == (2, 64, 54) for x in o["layer_outputs"])


[1;35mViT3DDecoder[0m[1m([0m
  [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m4[0m[1m)[0m: [1;36m5[0m x [1;35mViT3DDecoderLayer[0m[1m([0m
      [1m([0mmhsa[1m)[0m: [1;35mAttention1D[0m[1m([0m
        [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m

# Models

In [None]:
# | export


class ViT3DModel(nn.Module, PyTorchModelHubMixin):
    """3D ViT backbone: patch embedding, absolute position embeddings, optional class tokens, then a ``ViT3DEncoder``."""

    def __init__(self, config: ViT3DConfig):
        super().__init__()

        self.patchify = PatchEmbeddings3D(config.patch_size, config.in_channels, config.dim)
        self.absolute_position_embeddings = AbsolutePositionEmbeddings3D(config.dim, learnable=False)
        self.pos_drop = nn.Dropout(config.drop_prob)
        self.num_class_tokens = config.num_class_tokens
        if self.num_class_tokens > 0:
            self.class_tokens = nn.Parameter(torch.randn(1, config.num_class_tokens, config.dim))
        self.encoder = ViT3DEncoder(config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        spacings: torch.Tensor,
        mask_patches: torch.Tensor | None = None,
        mask_token: torch.Tensor | None = None,
        return_all: bool = False,
    ):
        """Encode a batch of volumes.

        Args:
            pixel_values: (b, c, z, y, x) input volumes.
            spacings: (b, 3) spacings passed to the absolute position embeddings.
            mask_patches: optional binary mask, (b, num_patches_z, num_patches_y, num_patches_x);
                patches marked 1 (or True) have their embedding replaced by ``mask_token``.
            mask_token: (1, dim, 1, 1, 1) replacement embedding; required with ``mask_patches``.
            return_all: also return the per-layer encoder outputs.

        Returns:
            ``(class_tokens, encoded)``, or a dict with ``"class_tokens"``, ``"encoded"`` and
            ``"layer_outputs"`` when ``return_all`` is true. ``class_tokens`` is ``None`` when the
            model has no class tokens.

        Raises:
            ValueError: if ``mask_patches`` is given without ``mask_token``.
        """
        # pixel_values: (b, c, z, y, x)
        # spacings: (b, 3)
        # mask_patches: (b, num_patches_z, num_patches_y, num_patches_x)
        # mask_token: (1, dim, 1, 1, 1)

        embeddings = self.patchify(pixel_values)
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        if mask_patches is not None:
            if mask_token is None:
                # Previously this surfaced as an opaque TypeError from `tensor * None`.
                raise ValueError("mask_token must be provided when mask_patches is given")
            # Add a channel axis and let broadcasting expand it; casting to the embedding dtype
            # also makes boolean masks usable (`1 - mask` is not defined for bool tensors).
            mask = mask_patches.unsqueeze(1).to(embeddings.dtype)
            # (b, 1, num_patches_z, num_patches_y, num_patches_x)
            embeddings = (embeddings * (1 - mask)) + (mask * mask_token)
            # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        absolute_position_embeddings = self.absolute_position_embeddings(
            batch_size=embeddings.shape[0],
            grid_size=embeddings.shape[2:],
            spacings=spacings,
            device=pixel_values.device,
        )
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)
        embeddings = embeddings + absolute_position_embeddings
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        embeddings = rearrange(embeddings, "b e nz ny nx -> b (nz ny nx) e")
        # (b, num_tokens, dim)

        embeddings = self.pos_drop(embeddings)
        # (b, num_tokens, dim)

        class_tokens = None
        if self.num_class_tokens > 0:
            class_tokens = repeat(self.class_tokens, "1 n d -> b n d", b=embeddings.shape[0])
            embeddings = torch.cat([class_tokens, embeddings], dim=1)
            # (b, num_class_tokens + num_tokens, dim)

        encoder_output = self.encoder(embeddings, return_all=True)
        encoded, layer_outputs = (
            encoder_output["embeddings"],
            encoder_output["layer_outputs"],
        )
        # encoded: (b, num_tokens (+ num_class_tokens), dim)
        # layer_outputs: list of (b, num_tokens (+ num_class_tokens), dim)

        if self.num_class_tokens > 0:
            # The class tokens were prepended, so they occupy the first positions.
            class_tokens = encoded[:, : self.num_class_tokens]
            encoded = encoded[:, self.num_class_tokens :]

        return_value = class_tokens, encoded
        if return_all:
            return_value = {
                "class_tokens": class_tokens,
                "encoded": encoded,
                "layer_outputs": layer_outputs,
            }

        return return_value

In [None]:
# Smoke test: full backbone with two class tokens on a (32, 512, 512) volume.
# ("embed_spacing_info" was dropped from the payload: it is not a ViT3DConfig field.)
test_config = ViT3DConfig.model_validate(
    {
        "num_class_tokens": 2,
        "attn_drop_prob": 0.2,
        "dim": 768,
        "drop_prob": 0.2,
        "encoder_depth": 4,
        "image_size": (32, 512, 512),
        "in_channels": 1,
        "mlp_ratio": 2,
        "layer_norm_eps": 1e-6,
        "mlp_drop_prob": 0.2,
        "num_heads": 4,
        "patch_size": (8, 16, 16),
        "proj_drop_prob": 0.2,
    }
)

test = ViT3DModel(test_config)
display(test)
o = test(
    torch.randn(2, 1, 32, 512, 512),
    torch.randn(2, 3),
    return_all=True,
)
display((o["class_tokens"].shape, o["encoded"].shape, [x.shape for x in o["layer_outputs"]]))
# (32 / 8) * (512 / 16) * (512 / 16) = 4096 patch tokens, plus 2 class tokens inside the encoder.
assert o["class_tokens"].shape == (2, 2, 768)
assert o["encoded"].shape == (2, 4096, 768)
assert len(o["layer_outputs"]) == test_config.encoder_depth
assert all(x.shape == (2, 4098, 768) for x in o["layer_outputs"])


[1;35mViT3DModel[0m[1m([0m
  [1m([0mpatchify[1m)[0m: [1;35mPatchEmbeddings3D[0m[1m([0m
    [1m([0mpatch_embeddings[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m768[0m, [33mkernel_size[0m=[1m([0m[1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m)[0m[1m)[0m
    [1m([0mnormalization[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m768[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
  [1m)[0m
  [1m([0mabsolute_position_embeddings[1m)[0m: [1;35mAbsolutePositionEmbeddings3D[0m[1m([0m[1m)[0m
  [1m([0mpos_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.2[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m([0mencoder[1m)[0m: [1;35mViT3DEncoder[0m[1m([0m
    [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
      [1m([0m[1;36m0[0m-[1;36m3[0m[1m)[0m: [1;36m


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m2[0m, [1;36m768[0m[1m][0m[1m)[0m,
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4096[0m, [1;36m768[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4098[0m, [1;36m768[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4098[0m, [1;36m768[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4098[0m, [1;36m768[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4098[0m, [1;36m768[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m

# Masked Image Modeling

In [None]:
# | export


class ViT3DMIMDecoder(nn.Module):
    """Linear head that maps each token back to the voxels of its patch and reassembles the volume."""

    def __init__(self, dim, image_size, in_channels, patch_size):
        """
        Args:
            dim: width of the incoming token embeddings.
            image_size: (z, y, x) size of the volume to reconstruct.
            in_channels: number of channels of the reconstructed volume.
            patch_size: (pz, py, px) size of the patch each token stands for.
        """
        super().__init__()

        self.image_size = image_size
        self.in_channels = in_channels
        self.patch_size = patch_size

        # np.prod returns a NumPy scalar; cast it so nn.Linear stores a plain Python int
        # as out_features.
        out_dim = int(np.prod(self.patch_size)) * int(self.in_channels)

        self.decoder = nn.Linear(dim, out_dim)

    def forward(self, encodings: torch.Tensor):
        """Project tokens to voxels and fold them into a (b, c, z, y, x) volume."""
        # encodings: (b, num_tokens, dim)

        decoded = self.decoder(encodings)
        # (b, num_tokens, new_dim)

        # Tokens are laid out as a flattened (nz, ny, nx) patch grid; each token's features
        # hold one (c, pz, py, px) patch.
        decoded = rearrange(
            decoded,
            "b (nz ny nx) (c pz py px) -> b c (nz pz) (ny py) (nx px)",
            c=self.in_channels,
            pz=self.patch_size[0],
            py=self.patch_size[1],
            px=self.patch_size[2],
            nz=self.image_size[0] // self.patch_size[0],
            ny=self.image_size[1] // self.patch_size[1],
            nx=self.image_size[2] // self.patch_size[2],
        )
        # (b, c, z, y, x)

        return decoded

In [None]:
# Smoke test: 4096 tokens are folded back into a (32, 512, 512) single-channel volume.
test = ViT3DMIMDecoder(768, (32, 512, 512), 1, (8, 16, 16))
display(test)
o = test(torch.randn(2, 4096, 768))
display(o.shape)
# Assert instead of only displaying, so the test run fails on a shape regression.
assert o.shape == (2, 1, 32, 512, 512)


[1;35mViT3DMIMDecoder[0m[1m([0m
  [1m([0mdecoder[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m768[0m, [33mout_features[0m=[1;36m2048[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1[0m, [1;36m32[0m, [1;36m512[0m, [1;36m512[0m[1m][0m[1m)[0m

In [None]:
# | export


class ViT3DMIM(nn.Module):
    """Shared pieces for masked image modeling: the ViT backbone, a reconstruction head, a mask token and the mask sampler."""

    def __init__(self, config: ViT3DConfig):
        super().__init__()

        assert config.num_class_tokens == 0, "MIM does not support class tokens"

        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.in_channels = config.in_channels
        self.mask_ratio = config.mask_ratio

        self.vit = ViT3DModel(config)
        self.decoder = ViT3DMIMDecoder(
            dim=config.dim,
            image_size=config.image_size,
            in_channels=config.in_channels,
            patch_size=config.patch_size,
        )

        # Embedding substituted for every masked patch.
        self.mask_token = nn.Parameter(torch.randn(1, config.dim, 1, 1, 1))

    def mask_image(self, pixel_values: torch.Tensor):
        """Draw an independent random patch mask for every sample of the batch.

        Returns an int8 tensor of shape (b, grid_z, grid_y, grid_x) on ``pixel_values``' device,
        where 1 marks a masked patch. Each sample masks ``int(mask_ratio * num_tokens)`` patches.
        """
        batch_size = pixel_values.shape[0]

        grid_size = tuple(size // patch for size, patch in zip(self.image_size, self.patch_size))
        num_tokens = np.prod(grid_size)
        num_masked = int(self.mask_ratio * num_tokens)

        per_sample_masks = []
        for _ in range(batch_size):
            # Mark the first `num_masked` slots, then shuffle them into random positions.
            flat_mask = torch.zeros(num_tokens, dtype=torch.int8, device=pixel_values.device)
            flat_mask[:num_masked] = 1
            shuffled_mask = flat_mask[torch.randperm(num_tokens)]
            per_sample_masks.append(
                rearrange(
                    shuffled_mask,
                    "(z y x) -> z y x",
                    z=grid_size[0],
                    y=grid_size[1],
                    x=grid_size[2],
                )
            )

        return torch.stack(per_sample_masks, dim=0)

In [None]:
# | export


class ViT3DSimMIM(ViT3DMIM, PyTorchModelHubMixin):
    """SimMIM-style pretraining: mask random patches, reconstruct the volume, score only the masked voxels with L1."""

    def __init__(self, config):
        super().__init__(config)

    @staticmethod
    def loss_fn(pred: torch.Tensor, target: torch.Tensor, reduction="mean"):
        """L1 reconstruction loss between ``pred`` and ``target``."""
        return nn.functional.l1_loss(pred, target, reduction=reduction)

    def forward(self, pixel_values: torch.Tensor, spacings: torch.Tensor):
        """Mask, encode, reconstruct and score a batch.

        Args:
            pixel_values: (b, c, z, y, x) input volumes.
            spacings: (b, 3) spacings forwarded to the backbone.

        Returns:
            ``(decoded, loss, mask)``: the (b, c, z, y, x) reconstruction, the scalar masked L1
            loss, and the (b, z, y, x) voxel-level mask (1 = masked).
        """
        mask_patches = self.mask_image(pixel_values)
        # (b, num_patches_z, num_patches_y, num_patches_x)

        _, encodings = self.vit(pixel_values, spacings, mask_patches, self.mask_token)
        decoded = self.decoder(encodings)
        # (b, c, z, y, x)

        loss = self.loss_fn(decoded, pixel_values, reduction="none")
        # (b, c, z, y, x)

        # Blow the patch-level mask up to voxel resolution.
        mask = repeat(
            mask_patches,
            "b z y x -> b (z pz) (y py) (x px)",
            pz=self.patch_size[0],
            py=self.patch_size[1],
            px=self.patch_size[2],
        )
        # (b, z, y, x)

        # Fix: the mask needs an explicit channel axis. Multiplying (b, c, z, y, x) by
        # (b, z, y, x) aligned the mask's batch axis with the loss' channel axis, which paired
        # every sample with every sample's mask when c == 1 and failed otherwise.
        masked_loss = loss * mask.unsqueeze(1)
        # (b, c, z, y, x)

        # Average over the masked voxels of all channels; the epsilon guards an empty mask.
        loss = masked_loss.sum() / ((mask.sum() + 1e-5) * self.in_channels)

        return decoded, loss, mask

In [None]:
# Smoke test: one SimMIM forward pass on a (32, 128, 128) volume with 80% of the patches masked.
test_config = ViT3DConfig.model_validate(
    {
        "num_class_tokens": 0,
        "attn_drop_prob": 0.2,
        "dim": 768,
        "drop_prob": 0.2,
        "encoder_depth": 4,
        "image_size": (32, 128, 128),
        "in_channels": 1,
        "mlp_ratio": 2,
        "layer_norm_eps": 1e-6,
        "mlp_drop_prob": 0.2,
        "num_heads": 4,
        "patch_size": (8, 16, 16),
        "proj_drop_prob": 0.2,
        "mask_ratio": 0.8,
    }
)

test = ViT3DSimMIM(test_config)
display(test)
o = test(
    torch.randn(2, 1, 32, 128, 128),
    torch.randn(2, 3),
)
display((o[0].shape, o[1], o[2].shape))
# Assert instead of only displaying, so the test run fails on a regression.
assert o[0].shape == (2, 1, 32, 128, 128)
assert o[1].ndim == 0 and torch.isfinite(o[1])
assert o[2].shape == (2, 32, 128, 128)


[1;35mViT3DSimMIM[0m[1m([0m
  [1m([0mvit[1m)[0m: [1;35mViT3DModel[0m[1m([0m
    [1m([0mpatchify[1m)[0m: [1;35mPatchEmbeddings3D[0m[1m([0m
      [1m([0mpatch_embeddings[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m768[0m, [33mkernel_size[0m=[1m([0m[1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m)[0m[1m)[0m
      [1m([0mnormalization[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m768[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
    [1m)[0m
    [1m([0mabsolute_position_embeddings[1m)[0m: [1;35mAbsolutePositionEmbeddings3D[0m[1m([0m[1m)[0m
    [1m([0mpos_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.2[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mencoder[1m)[0m: [1;35mViT3DEncoder[0m[1m([0m
      [1m([0mlayers[1m)[0m: [1;35mModuleLis


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1[0m, [1;36m32[0m, [1;36m128[0m, [1;36m128[0m[1m][0m[1m)[0m,
    [1;35mtensor[0m[1m([0m[1;36m2.2192[0m, [33mgrad_fn[0m=[1m<[0m[1;95mDivBackward0[0m[1m>[0m[1m)[0m,
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m32[0m, [1;36m128[0m, [1;36m128[0m[1m][0m[1m)[0m
[1m)[0m

# Some more tests

### Overfitting tests

In [None]:
from tqdm.auto import tqdm

# Overfitting sanity check: a tiny fixed batch that the model should be able to memorise.
# ("embed_spacing_info" was dropped from the payload: it is not a ViT3DConfig field.)
sample_spacings = torch.tensor([[1, 0.1, 0.1], [2, 0.2, 0.2], [3, 0.3, 0.3]])
sample_batch = torch.rand(3, 1, 16, 128, 128)
sample_config = ViT3DConfig.model_validate(
    {
        "num_class_tokens": 0,
        "attn_drop_prob": 0.2,
        "dim": 384,
        "drop_prob": 0.2,
        "encoder_depth": 4,
        "image_size": (16, 128, 128),
        "in_channels": 1,
        "mlp_ratio": 2,
        "layer_norm_eps": 1e-6,
        "mlp_drop_prob": 0.2,
        "num_heads": 4,
        "patch_size": (8, 16, 16),
        "proj_drop_prob": 0.2,
        "mask_ratio": 0.8,
    }
)

model = ViT3DSimMIM(sample_config)

# Parameter counts of the backbone and of the reconstruction head.
sum(x.numel() for x in model.vit.parameters()), sum(x.numel() for x in model.decoder.parameters())

[1m([0m[1;36m5523076[0m, [1;36m788480[0m[1m)[0m

In [None]:
# Plain SGD whose learning rate is multiplied by 0.9 every 5 scheduler steps.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.5)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=5,
    gamma=0.9,
)

In [None]:
# Use the GPU when one is available; fall back to the CPU so the notebook still runs without CUDA
# (an unconditional .cuda() raises on CPU-only machines).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sample_batch = sample_batch.to(device)
sample_spacings = sample_spacings.to(device)
model = model.to(device)

In [None]:
# Overfit the fixed batch for 200 steps, logging the loss and the current learning rate.
num_steps = 200
for i in tqdm(range(num_steps)):
    optimizer.zero_grad()

    output = model(sample_batch, sample_spacings)
    print(f"Loss: {output[1]:f}\tLR: {scheduler.get_last_lr()[0]:f}")

    # output is (decoded, loss, mask); backpropagate the loss.
    loss = output[1]
    loss.backward()

    optimizer.step()
    scheduler.step()

  0%|          | 0/200 [00:00<?, ?it/s]

Loss: 2.724077	LR: 0.500000
Loss: 2.063609	LR: 0.500000
Loss: 1.932758	LR: 0.500000
Loss: 1.833690	LR: 0.500000
Loss: 1.747474	LR: 0.500000
Loss: 1.622086	LR: 0.450000
Loss: 1.551356	LR: 0.450000
Loss: 1.501874	LR: 0.450000
Loss: 1.431908	LR: 0.450000
Loss: 1.400595	LR: 0.450000
Loss: 1.409469	LR: 0.405000
Loss: 1.366366	LR: 0.405000
Loss: 1.320952	LR: 0.405000
Loss: 1.337093	LR: 0.405000
Loss: 1.329062	LR: 0.405000
Loss: 1.311299	LR: 0.364500
Loss: 1.276818	LR: 0.364500
Loss: 1.262741	LR: 0.364500
Loss: 1.262444	LR: 0.364500
Loss: 1.255624	LR: 0.364500
Loss: 1.252090	LR: 0.328050
Loss: 1.235026	LR: 0.328050
Loss: 1.215526	LR: 0.328050
Loss: 1.215673	LR: 0.328050
Loss: 1.201721	LR: 0.328050
Loss: 1.198486	LR: 0.295245
Loss: 1.184551	LR: 0.295245
Loss: 1.178360	LR: 0.295245
Loss: 1.177834	LR: 0.295245
Loss: 1.177073	LR: 0.295245
Loss: 1.171974	LR: 0.265721
Loss: 1.164572	LR: 0.265721
Loss: 1.160800	LR: 0.265721
Loss: 1.155823	LR: 0.265721
Loss: 1.151032	LR: 0.265721
Loss: 1.154409	LR: 0

In [None]:
# Report every parameter that received no gradient during the last backward pass.
names_without_grad = [name for name, param in model.named_parameters() if param.grad is None]
for name in names_without_grad:
    print(name)

vit.encoder.layers.0.attn.logit_scale
vit.encoder.layers.1.attn.logit_scale
vit.encoder.layers.2.attn.logit_scale
vit.encoder.layers.3.attn.logit_scale


# nbdev

In [None]:
# Run nbdev's export command; presumably this writes the cells tagged `# | export` to the module
# named by the `default_exp` directive at the top of this notebook. Confirm against the nbdev docs.
!nbdev_export