In [None]:
# | default_exp nets/vit_3d

# Imports

In [None]:
# | export

from functools import wraps
from typing import Literal

import torch
from einops import rearrange, repeat
from huggingface_hub import PyTorchModelHubMixin
from torch import nn

from vision_architectures.blocks.transformer import TransformerDecoderBlock1D, TransformerEncoderBlock1D
from vision_architectures.docstrings import populate_docstring
from vision_architectures.layers.embeddings import AbsolutePositionEmbeddings3D, PatchEmbeddings3D
from vision_architectures.utils.activation_checkpointing import ActivationCheckpointing
from vision_architectures.utils.custom_base_model import CustomBaseModel, Field
from vision_architectures.utils.rearrange import rearrange_channels

# Config

In [None]:
# | export


class ViTEncoderDecoderConfig(CustomBaseModel):
    """Transformer-block hyperparameters shared by the encoder and decoder configs in this module."""

    # Required fields (declared with `...` and therefore without a default).
    dim: int = Field(..., description="Dimension of the patches.")
    num_heads: int = Field(..., description="Number of attention heads.")
    mlp_ratio: int = Field(..., description="Ratio of mlp hidden dim to embedding dim.")
    layer_norm_eps: float = Field(..., description="Epsilon value for the layer normalization.")
    # Optional fields: every dropout probability defaults to 0.0.
    attn_drop_prob: float = Field(0.0, description="Dropout probability for attention weights.")
    proj_drop_prob: float = Field(0.0, description="Dropout probability for the projection layer.")
    mlp_drop_prob: float = Field(0.0, description="Dropout probability for the mlp layer.")
    norm_location: Literal["pre", "post"] = Field(
        "pre", description="Location of the normalization layer with respect to the attention."
    )


class ViTEncoderConfig(ViTEncoderDecoderConfig):
    """Config for `ViTEncoder`: the shared block hyperparameters plus the number of encoder blocks."""

    encoder_depth: int = Field(..., description="Number of encoder blocks.")


class ViT3DEncoderWithPatchEmbeddingsConfig(ViTEncoderConfig):
    """Config for `ViT3DEncoderWithPatchEmbeddings`.

    Extends the encoder config with the patch embedding, class token and embedding dropout settings.
    """

    patch_size: tuple[int, int, int] = Field(..., description="Size of the patches to be embedded.")
    in_channels: int = Field(..., description="Number of input channels in the input datapoints.")
    # Required; 0 disables class tokens (the model only creates them when this is > 0).
    num_class_tokens: int = Field(..., description="Number of class tokens to be added.")

    drop_prob: float = Field(0.0, description="Dropout probability for the embedding layer.")


class ViTDecoderConfig(ViTEncoderDecoderConfig):
    """Config for `ViTDecoder`: the shared block hyperparameters plus the number of decoder blocks."""

    # Still a required field; declared through `Field` so it carries a description like every other field.
    decoder_depth: int = Field(..., description="Number of decoder blocks.")

In [None]:
# Smoke test: validate a minimal config and display it.
# Only keys that are actual fields of ViT3DEncoderWithPatchEmbeddingsConfig are passed, so that the
# dictionary matches what ends up in the validated config.
test_config = ViT3DEncoderWithPatchEmbeddingsConfig.model_validate(
    {
        "patch_size": (2, 2, 2),
        "in_channels": 3,
        "dim": 64,
        "num_heads": 8,
        "mlp_ratio": 4,
        "layer_norm_eps": 1e-6,
        "encoder_depth": 3,
        "num_class_tokens": 0,
    }
)
test_config


[1;35mViT3DEncoderWithPatchEmbeddingsConfig[0m[1m([0m
    [33mdim[0m=[1;36m64[0m,
    [33mnum_heads[0m=[1;36m8[0m,
    [33mmlp_ratio[0m=[1;36m4[0m,
    [33mlayer_norm_eps[0m=[1;36m1e[0m[1;36m-06[0m,
    [33mattn_drop_prob[0m=[1;36m0[0m[1;36m.0[0m,
    [33mproj_drop_prob[0m=[1;36m0[0m[1;36m.0[0m,
    [33mmlp_drop_prob[0m=[1;36m0[0m[1;36m.0[0m,
    [33mnorm_location[0m=[32m'pre'[0m,
    [33mencoder_depth[0m=[1;36m3[0m,
    [33mpatch_size[0m=[1m([0m[1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m)[0m,
    [33min_channels[0m=[1;36m3[0m,
    [33mnum_class_tokens[0m=[1;36m0[0m,
    [33mdrop_prob[0m=[1;36m0[0m[1;36m.0[0m
[1m)[0m

# Architecture

### Encoder

In [None]:
# | export


@populate_docstring
class ViTEncoder(nn.Module, PyTorchModelHubMixin):
    """Vision Transformer encoder. {CLASS_DESCRIPTION_1D_DOC}"""

    @populate_docstring
    def __init__(self, config: ViTEncoderConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Build the stack of encoder blocks described by the config.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        # The config and any extra keyword arguments are combined and validated together.
        self.config = ViTEncoderConfig.model_validate(config | kwargs)

        # One block per unit of depth; each block receives its own dump of the validated config.
        blocks = []
        for _ in range(self.config.encoder_depth):
            block = TransformerEncoderBlock1D(self.config.model_dump(), checkpointing_level=checkpointing_level)
            blocks.append(block)
        self.layers = nn.ModuleList(blocks)

        self.checkpointing_level5 = ActivationCheckpointing(5, checkpointing_level)

    @populate_docstring
    def _forward(
        self, embeddings: torch.Tensor, return_intermediates: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run the input embeddings through every encoder block in order (self attention).

        Args:
            embeddings: {INPUT_1D_DOC}
            return_intermediates: {RETURN_INTERMEDIATES_DOC}

        Returns:
            {OUTPUT_1D_DOC} If `return_intermediates` is True, returns a tuple of the output embeddings and a list of
            intermediate embeddings.
        """
        # embeddings: (b, num_tokens, dim)

        hidden = embeddings
        intermediates = []
        for block in self.layers:
            hidden = block(hidden)
            # (b, num_tokens, dim)

            # Record the output of every block; the last entry is the final output.
            intermediates.append(hidden)

        if not return_intermediates:
            return hidden
        return hidden, intermediates

    @wraps(_forward)
    def forward(self, *args, **kwargs):
        # Route the call to `_forward` through the level-5 activation checkpointing wrapper.
        return self.checkpointing_level5(self._forward, *args, **kwargs)

In [None]:
# Smoke test: build a small encoder and run a random batch of 64 tokens through it.
test_config = ViTEncoderConfig.model_validate(
    {
        "dim": 54,
        "num_heads": 6,
        "mlp_ratio": 2,
        "layer_norm_eps": 1e-6,
        "attn_drop_prob": 0.0,
        "proj_drop_prob": 0.0,
        "mlp_drop_prob": 0.0,
        "encoder_depth": 3,
    }
)

test = ViTEncoder(test_config)
display(test)
encoded, intermediates = test(torch.randn(2, 64, 54), return_intermediates=True)
display((encoded.shape, [x.shape for x in intermediates]))

# Assert the shapes instead of only displaying them, so a regression fails the cell.
assert encoded.shape == (2, 64, 54)
assert len(intermediates) == test_config.encoder_depth
assert all(x.shape == (2, 64, 54) for x in intermediates)


[1;35mViTEncoder[0m[1m([0m
  [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m2[0m[1m)[0m: [1;36m3[0m x [1;35mTransformerEncoderBlock1D[0m[1m([0m
      [1m([0mattn[1m)[0m: [1;35mAttention1D[0m[1m([0m
        [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m(

[1m([0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m, [1m[[0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m[1m][0m[1m)[0m

### Decoder

In [None]:
# | export


@populate_docstring
class ViTDecoder(nn.Module, PyTorchModelHubMixin):
    """Vision Transformer decoder. {CLASS_DESCRIPTION_1D_DOC}"""

    @populate_docstring
    def __init__(self, config: ViTDecoderConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Build the stack of decoder blocks described by the config.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        # The config and any extra keyword arguments are combined and validated together.
        self.config = ViTDecoderConfig.model_validate(config | kwargs)

        # One block per unit of depth; each block receives its own dump of the validated config.
        blocks = []
        for _ in range(self.config.decoder_depth):
            block = TransformerDecoderBlock1D(self.config.model_dump(), checkpointing_level=checkpointing_level)
            blocks.append(block)
        self.layers = nn.ModuleList(blocks)

        self.checkpointing_level5 = ActivationCheckpointing(5, checkpointing_level)

    @populate_docstring
    def _forward(
        self, q: torch.Tensor, kv: torch.Tensor, return_intermediates: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run the query embeddings through every decoder block in order (self attention + cross attention).

        Args:
            q: Input to the query matrix. {INPUT_1D_DOC}
            kv: Input to the key and value matrices. {INPUT_1D_DOC}
            return_intermediates: {RETURN_INTERMEDIATES_DOC}

        Returns:
            {OUTPUT_1D_DOC} If `return_intermediates` is True, returns a tuple of the output embeddings and a list of
            intermediate embeddings.
        """
        # q: (b, num_q_tokens, dim)
        # kv: (b, num_kv_tokens, dim)

        hidden = q
        intermediates = []
        for block in self.layers:
            # Every block receives the same `kv`; only the query stream is updated.
            hidden = block(hidden, kv)
            # (b, num_q_tokens, dim)

            intermediates.append(hidden)

        if not return_intermediates:
            return hidden
        return hidden, intermediates

    @wraps(_forward)
    def forward(self, *args, **kwargs):
        # Route the call to `_forward` through the level-5 activation checkpointing wrapper.
        return self.checkpointing_level5(self._forward, *args, **kwargs)

In [None]:
# Smoke test: build a small decoder and run 64 random query tokens against 128 key/value tokens.
test_config = ViTDecoderConfig.model_validate(
    {
        "dim": 54,
        "num_heads": 6,
        "mlp_ratio": 2,
        "layer_norm_eps": 1e-6,
        "attn_drop_prob": 0.0,
        "proj_drop_prob": 0.0,
        "mlp_drop_prob": 0.0,
        "decoder_depth": 5,
    }
)

test = ViTDecoder(test_config)
display(test)
decoded, intermediates = test(torch.randn(2, 64, 54), torch.randn(2, 128, 54), return_intermediates=True)
display((decoded.shape, [x.shape for x in intermediates]))

# Assert the shapes instead of only displaying them, so a regression fails the cell.
# The output follows the query token count (64), not the key/value token count (128).
assert decoded.shape == (2, 64, 54)
assert len(intermediates) == test_config.decoder_depth
assert all(x.shape == (2, 64, 54) for x in intermediates)


[1;35mViTDecoder[0m[1m([0m
  [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m4[0m[1m)[0m: [1;36m5[0m x [1;35mTransformerDecoderBlock1D[0m[1m([0m
      [1m([0mattn1[1m)[0m: [1;35mAttention1D[0m[1m([0m
        [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m

# Models

In [None]:
# | export


@populate_docstring
class ViT3DEncoderWithPatchEmbeddings(nn.Module, PyTorchModelHubMixin):
    """Patchification of input array followed by a ViT encoder. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: ViT3DEncoderWithPatchEmbeddingsConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize the ViT3DEncoderWithPatchEmbeddings.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = ViT3DEncoderWithPatchEmbeddingsConfig.model_validate(config | kwargs)

        self.patchify = PatchEmbeddings3D(
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            dim=self.config.dim,
            checkpointing_level=checkpointing_level,
        )
        self.absolute_position_embeddings = AbsolutePositionEmbeddings3D(dim=self.config.dim, learnable=False)
        self.pos_drop = nn.Dropout(self.config.drop_prob)
        self.num_class_tokens = self.config.num_class_tokens
        # The class token parameter only exists when class tokens are requested.
        if self.num_class_tokens > 0:
            self.class_tokens = nn.Parameter(torch.randn(1, self.config.num_class_tokens, self.config.dim))
        self.encoder = ViTEncoder(self.config, checkpointing_level=checkpointing_level)

    @populate_docstring
    def forward(
        self,
        pixel_values: torch.Tensor,
        spacings: torch.Tensor,
        channels_first: bool = True,
        return_intermediates: bool = False,
    ) -> (
        tuple[torch.Tensor, torch.Tensor | None]
        | tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor]]
    ):
        """Patchify the input datapoint and then pass through the ViT encoder (self attention).

        Args:
            pixel_values: {INPUT_3D_DOC}
            spacings: {SPACINGS_DOC}
            channels_first: {CHANNELS_FIRST_DOC}
            return_intermediates: {RETURN_INTERMEDIATES_DOC}

        Returns:
            A tuple `(encoded, class_tokens)`, or `(encoded, class_tokens, layer_outputs)` if `return_intermediates`
            is True.

            - encoded: {OUTPUT_1D_DOC} The class tokens have been removed from it.
            - class_tokens: Encoded class tokens of shape (b, num_class_tokens, dim), or None if the model was
              configured with `num_class_tokens=0`.
            - layer_outputs: Output of every encoder block, in order. Unlike `encoded`, these still contain the
              class tokens at the front of the token axis.
        """
        # pixel_values: (b, [c], z, y, x, [c])
        # spacings: (b, 3)

        pixel_values = rearrange_channels(pixel_values, channels_first, True)
        # (b, c, z, y, x)

        embeddings = self.patchify(pixel_values, channels_first=True)
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        embeddings = self.absolute_position_embeddings(embeddings, spacings=spacings, channels_first=True)
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        embeddings = rearrange(embeddings, "b e nz ny nx -> b (nz ny nx) e").contiguous()
        # (b, num_tokens, dim)

        embeddings = self.pos_drop(embeddings)
        # (b, num_tokens, dim)

        if self.num_class_tokens > 0:
            # Broadcast the learned class tokens over the batch and prepend them to the patch tokens.
            input_class_tokens = repeat(self.class_tokens, "1 n d -> b n d", b=embeddings.shape[0])
            embeddings = torch.cat([input_class_tokens, embeddings], dim=1)
            # (b, num_class_tokens + num_tokens, dim)

        encoded, layer_outputs = self.encoder(embeddings, return_intermediates=True)
        # encoded: (b, (num_class_tokens +) num_tokens, dim)
        # layer_outputs: list of (b, (num_class_tokens +) num_tokens, dim)

        # Split the encoded class tokens from the encoded patch tokens; layer_outputs are left untouched.
        class_tokens = None
        if self.num_class_tokens > 0:
            class_tokens = encoded[:, : self.num_class_tokens]
            encoded = encoded[:, self.num_class_tokens :]

        if return_intermediates:
            return encoded, class_tokens, layer_outputs
        return encoded, class_tokens

In [None]:
# Smoke test: run a random (32, 512, 512) single-channel volume through the full model.
# Only keys that are actual fields of ViT3DEncoderWithPatchEmbeddingsConfig are passed.
test_config = ViT3DEncoderWithPatchEmbeddingsConfig.model_validate(
    {
        "num_class_tokens": 2,
        "attn_drop_prob": 0.2,
        "dim": 768,
        "drop_prob": 0.2,
        "encoder_depth": 4,
        "in_channels": 1,
        "mlp_ratio": 2,
        "layer_norm_eps": 1e-6,
        "mlp_drop_prob": 0.2,
        "num_heads": 4,
        "patch_size": (8, 16, 16),
        "proj_drop_prob": 0.2,
    }
)

test = ViT3DEncoderWithPatchEmbeddings(test_config)
display(test)
encoded, class_tokens, intermediates = test(
    torch.randn(2, 1, 32, 512, 512),
    torch.randn(2, 3),
    return_intermediates=True,
)
display((encoded.shape, class_tokens.shape, [x.shape for x in intermediates]))

# Assert the shapes instead of only displaying them, so a regression fails the cell.
num_patch_tokens = (32 // 8) * (512 // 16) * (512 // 16)
assert encoded.shape == (2, num_patch_tokens, 768)
assert class_tokens.shape == (2, 2, 768)
assert len(intermediates) == test_config.encoder_depth
# Intermediate outputs still carry the two class tokens in front of the patch tokens.
assert all(x.shape == (2, num_patch_tokens + 2, 768) for x in intermediates)


[1;35mViT3DEncoderWithPatchEmbeddings[0m[1m([0m
  [1m([0mpatchify[1m)[0m: [1;35mPatchEmbeddings3D[0m[1m([0m
    [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m768[0m, [33mkernel_size[0m=[1m([0m[1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m)[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m768[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
    [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mabsolute_position_embeddings[1m)[0m: [1;35mAbsolutePositionEmbeddings3D[0m[1m([0m[1m)[0m
  [1m([0mpos_drop[1m)[0m: [1;


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4096[0m, [1;36m768[0m[1m][0m[1m)[0m,
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m2[0m, [1;36m768[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4098[0m, [1;36m768[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4098[0m, [1;36m768[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4098[0m, [1;36m768[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4098[0m, [1;36m768[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m

# nbdev

In [None]:
# Runs the nbdev export command in a shell. Presumably this writes the `# | export` cells to the module
# named by the `default_exp` directive at the top of the notebook - confirm against the nbdev docs.
!nbdev_export