In [None]:
# | default_exp nets/vit_3d

# Imports

In [None]:
# | export

from functools import wraps

import torch
from einops import rearrange, repeat
from huggingface_hub import PyTorchModelHubMixin
from torch import nn

from vision_architectures.blocks.transformer import (
    TransformerDecoderBlock3D,
    TransformerDecoderBlock3DConfig,
    TransformerEncoderBlock3D,
    TransformerEncoderBlock3DConfig,
)
from vision_architectures.docstrings import populate_docstring
from vision_architectures.layers.embeddings import (
    AbsolutePositionEmbeddings3D,
    AbsolutePositionEmbeddings3DConfig,
    PatchEmbeddings3D,
    PatchEmbeddings3DConfig,
)
from vision_architectures.utils.activation_checkpointing import ActivationCheckpointing
from vision_architectures.utils.custom_base_model import Field
from vision_architectures.utils.rearrange import rearrange_channels

# Config

In [None]:
# | export


class ViT3DEncoderConfig(TransformerEncoderBlock3DConfig):
    """Configuration for ViT3DEncoder: the per-block transformer encoder fields plus the depth of the stack."""

    # Number of TransformerEncoderBlock3D layers stacked by ViT3DEncoder; every layer is built from this same config.
    encoder_depth: int = Field(..., description="Number of encoder blocks.")


class ViT3DEncoderWithPatchEmbeddingsConfig(ViT3DEncoderConfig, PatchEmbeddings3DConfig):
    """Configuration for ViT3DEncoderWithPatchEmbeddings: encoder fields combined with patch-embedding fields."""

    # None disables absolute position embeddings. The default `{}` enables them with an empty sub-config, so
    # (presumably) AbsolutePositionEmbeddings3D's own defaults apply.
    absolute_position_embeddings_config: AbsolutePositionEmbeddings3DConfig | None = {}
    # Learnable tokens prepended to the patch tokens before encoding; 0 disables them.
    num_class_tokens: int = Field(..., description="Number of class tokens to be added.")


class ViT3DDecoderConfig(TransformerDecoderBlock3DConfig):
    """Configuration for ViT3DDecoder: the per-block transformer decoder fields plus the depth of the stack."""

    # Number of TransformerDecoderBlock3D layers stacked by ViT3DDecoder; every layer is built from this same config.
    decoder_depth: int = Field(..., description="Number of decoder blocks.")


class ViT3DDecoderWithPatchEmbeddingsConfig(ViT3DDecoderConfig, PatchEmbeddings3DConfig):
    """Configuration for ViT3DDecoderWithPatchEmbeddings: decoder fields combined with patch-embedding fields."""

    # None disables absolute position embeddings. The default `{}` enables them with an empty sub-config, so
    # (presumably) AbsolutePositionEmbeddings3D's own defaults apply.
    absolute_position_embeddings_config: AbsolutePositionEmbeddings3DConfig | None = {}
    # Learnable tokens prepended to the query patch tokens before decoding; 0 disables them.
    num_class_tokens: int = Field(..., description="Number of class tokens to be added.")

In [None]:
# Check that an encoder-with-patch-embeddings config validates from a plain dict of encoder + patch fields only.
test_config = ViT3DEncoderWithPatchEmbeddingsConfig.model_validate(
    {
        "patch_size": (2, 2, 2),
        "in_channels": 3,
        "dim": 64,
        "num_heads": 8,
        "mlp_ratio": 4,
        "layer_norm_eps": 1e-6,
        "encoder_depth": 3,
        "num_class_tokens": 2,
    }
)
test_config

# Architecture

## Encoder

In [None]:
# | export


@populate_docstring
class ViT3DEncoder(nn.Module, PyTorchModelHubMixin):
    """Stack of 3D Vision Transformer encoder blocks. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: ViT3DEncoderConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Build `encoder_depth` transformer encoder blocks from the given configuration.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        merged = config | kwargs
        self.config = ViT3DEncoderConfig.model_validate(merged)

        # Every block gets its own freshly dumped copy of the full encoder config.
        blocks = []
        for _ in range(self.config.encoder_depth):
            block_config = self.config.model_dump()
            blocks.append(TransformerEncoderBlock3D(block_config, checkpointing_level=checkpointing_level))
        self.layers = nn.ModuleList(blocks)

        # Used by `forward` to wrap `_forward`; ActivationCheckpointing decides from (5, checkpointing_level) whether
        # the whole encoder pass is actually checkpointed.
        self.checkpointing_level5 = ActivationCheckpointing(5, checkpointing_level)

    @populate_docstring
    def _forward(
        self,
        embeddings: torch.Tensor,
        return_intermediates: bool = False,
        channels_first: bool = True,
        query_grid_shape: tuple[int, int, int] | None = None,
        key_grid_shape: tuple[int, int, int] | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run self attention over the input embeddings through every encoder block in order.

        Args:
            embeddings: {INPUT_3D_OR_1D_DOC}
            return_intermediates: {RETURN_INTERMEDIATES_DOC}
            channels_first: {CHANNELS_FIRST_DOC}
            query_grid_shape: {ROTARY_POSITION_EMBEDDINGS_GRID_SHAPE_DOC}
            key_grid_shape: {ROTARY_POSITION_EMBEDDINGS_GRID_SHAPE_DOC}

        Returns:
            {OUTPUT_3D_OR_1D_DOC} When `return_intermediates` is True, a tuple ``(output, intermediates)`` is returned
            instead, where ``intermediates`` holds each block's output in channels_last format.
        """
        # embeddings: (b, T, dim) or (b, [dim], z, y, x, [dim])
        hidden = embeddings
        if hidden.ndim == 5:
            # Blocks consume grids channels_last: (b, z, y, x, dim)
            hidden = rearrange_channels(hidden, channels_first, False)

        intermediates = []
        for block in self.layers:
            hidden = block(
                qkv=hidden,
                channels_first=False,
                query_grid_shape=query_grid_shape,
                key_grid_shape=key_grid_shape,
            )
            # (b, T, dim) or (b, z, y, x, dim)
            intermediates.append(hidden)

        if hidden.ndim == 5:
            # Restore the caller's channel layout: (b, [dim], z, y, x, [dim])
            hidden = rearrange_channels(hidden, False, channels_first)

        return (hidden, intermediates) if return_intermediates else hidden

    @wraps(_forward)
    def forward(self, *args, **kwargs):
        # Route the whole pass through the level-5 activation checkpointing wrapper.
        return self.checkpointing_level5(self._forward, *args, **kwargs)

In [None]:
# Encoder on a channels-first 3D grid: the output keeps the input layout, intermediates are channels_last.
test_config = ViT3DEncoderConfig.model_validate(
    {
        "dim": 54,
        "num_heads": 6,
        "mlp_ratio": 2,
        "layer_norm_eps": 1e-6,
        "attn_drop_prob": 0.0,
        "proj_drop_prob": 0.0,
        "mlp_drop_prob": 0.0,
        "encoder_depth": 3,
        "rotary_position_embeddings_config": {},
    }
)

test = ViT3DEncoder(test_config)
display(test)
output, intermediates = test(torch.randn(2, 54, 4, 4, 4), return_intermediates=True)
display((output.shape, [x.shape for x in intermediates]))
assert output.shape == (2, 54, 4, 4, 4)
assert [x.shape for x in intermediates] == [(2, 4, 4, 4, 54)] * 3

In [None]:
# Encoder on a flat token sequence: 64 grid tokens for a (4, 4, 4) grid plus one extra token.
test_config = ViT3DEncoderConfig.model_validate(
    {
        "dim": 54,
        "num_heads": 6,
        "mlp_ratio": 2,
        "layer_norm_eps": 1e-6,
        "attn_drop_prob": 0.0,
        "proj_drop_prob": 0.0,
        "mlp_drop_prob": 0.0,
        "encoder_depth": 3,
        "rotary_position_embeddings_config": {},
    }
)

test = ViT3DEncoder(test_config)
display(test)
output, intermediates = test(torch.randn(2, 65, 54), return_intermediates=True, query_grid_shape=(4, 4, 4))
display((output.shape, [x.shape for x in intermediates]))
assert output.shape == (2, 65, 54)
assert [x.shape for x in intermediates] == [(2, 65, 54)] * 3

## Decoder

In [None]:
# | export


@populate_docstring
class ViT3DDecoder(nn.Module, PyTorchModelHubMixin):
    """Vision Transformer decoder. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: ViT3DDecoderConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize the ViT3DDecoder.

        Builds `decoder_depth` TransformerDecoderBlock3D layers, each configured from a dump of the full decoder config.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = ViT3DDecoderConfig.model_validate(config | kwargs)

        self.layers = nn.ModuleList(
            [
                TransformerDecoderBlock3D(self.config.model_dump(), checkpointing_level=checkpointing_level)
                for _ in range(self.config.decoder_depth)
            ]
        )

        # Used by `forward` to wrap `_forward`; ActivationCheckpointing decides from (5, checkpointing_level) whether
        # the whole decoder pass is actually checkpointed.
        self.checkpointing_level5 = ActivationCheckpointing(5, checkpointing_level)

    @populate_docstring
    def _forward(
        self,
        q: torch.Tensor,
        kv: torch.Tensor,
        return_intermediates: bool = False,
        channels_first: bool = True,
        q_grid_shape: tuple[int, int, int] | None = None,
        kv_grid_shape: tuple[int, int, int] | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Pass the input embeddings through the ViT decoder (self attention + cross attention).

        The channel layout of `q` and `kv` is normalized independently, based on each tensor's own rank.

        Args:
            q: Input to the query matrix. {INPUT_3D_OR_1D_DOC}
            kv: Input to the key and value matrices. {INPUT_3D_OR_1D_DOC}
            return_intermediates: {RETURN_INTERMEDIATES_DOC}
            channels_first: {CHANNELS_FIRST_DOC}
            q_grid_shape: {ROTARY_POSITION_EMBEDDINGS_GRID_SHAPE_DOC}
            kv_grid_shape: {ROTARY_POSITION_EMBEDDINGS_GRID_SHAPE_DOC}

        Returns:
            {OUTPUT_3D_OR_1D_DOC} The output follows the layout of `q`. If `return_intermediates` is True, returns a
            tuple of the output embeddings and a list of intermediate embeddings in channels_last format.
        """
        # q: (b, T, dim) or (b, [dim], q_z, q_y, q_x, [dim])
        # kv: (b, T, dim) or (b, [dim], kv_z, kv_y, kv_x, [dim])

        # Decide per tensor: rearranging kv based on q's rank would mangle a 3D kv (or skip a 5D one).
        if q.ndim == 5:
            q = rearrange_channels(q, channels_first, False)
            # (b, q_z, q_y, q_x, dim)
        if kv.ndim == 5:
            kv = rearrange_channels(kv, channels_first, False)
            # (b, kv_z, kv_y, kv_x, dim)

        embeddings = q

        layer_outputs = []
        for decoder_layer in self.layers:
            # NOTE(review): the keyword names below are dictated by TransformerDecoderBlock3D's signature, which is
            # defined elsewhere; keep them in sync with it.
            embeddings = decoder_layer(
                q=embeddings,
                kv=kv,
                channels_first=False,
                q_grid_shape=q_grid_shape,
                k2_grid_shape=kv_grid_shape,
            )
            # (b, T, dim) or (b, q_z, q_y, q_x, dim)

            layer_outputs.append(embeddings)

        if embeddings.ndim == 5:
            embeddings = rearrange_channels(embeddings, False, channels_first)
            # (b, [dim], q_z, q_y, q_x, [dim])

        if return_intermediates:
            return embeddings, layer_outputs
        return embeddings

    @wraps(_forward)
    def forward(self, *args, **kwargs):
        # Route the whole pass through the level-5 activation checkpointing wrapper.
        return self.checkpointing_level5(self._forward, *args, **kwargs)

In [None]:
# Decoder with channels-first grid queries (4^3) attending to a channels-first grid of keys/values (6^3).
test_config = ViT3DDecoderConfig.model_validate(
    {
        "dim": 54,
        "num_heads": 6,
        "mlp_ratio": 2,
        "layer_norm_eps": 1e-6,
        "attn_drop_prob": 0.0,
        "proj_drop_prob": 0.0,
        "mlp_drop_prob": 0.0,
        "decoder_depth": 5,
        "rotary_position_embeddings_config": {},
    }
)

test = ViT3DDecoder(test_config)
display(test)
output, intermediates = test(torch.randn(2, 54, 4, 4, 4), torch.randn(2, 54, 6, 6, 6), return_intermediates=True)
display((output.shape, [x.shape for x in intermediates]))
assert output.shape == (2, 54, 4, 4, 4)
assert [x.shape for x in intermediates] == [(2, 4, 4, 4, 54)] * 5

In [None]:
# Decoder on flat token sequences, with grid shapes supplied explicitly for the rotary embeddings.
test_config = ViT3DDecoderConfig.model_validate(
    {
        "dim": 54,
        "num_heads": 6,
        "mlp_ratio": 2,
        "layer_norm_eps": 1e-6,
        "attn_drop_prob": 0.0,
        "proj_drop_prob": 0.0,
        "mlp_drop_prob": 0.0,
        "decoder_depth": 5,
        "rotary_position_embeddings_config": {},
    }
)

test = ViT3DDecoder(test_config)
display(test)
output, intermediates = test(
    torch.randn(2, 64, 54),
    torch.randn(2, 216, 54),
    return_intermediates=True,
    q_grid_shape=(4, 4, 4),
    kv_grid_shape=(6, 6, 6),
)
display((output.shape, [x.shape for x in intermediates]))
assert output.shape == (2, 64, 54)
assert [x.shape for x in intermediates] == [(2, 64, 54)] * 5

# Models

In [None]:
# | export


@populate_docstring
class ViT3DEncoderWithPatchEmbeddings(nn.Module, PyTorchModelHubMixin):
    """Patchification of input array followed by a ViT encoder. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: ViT3DEncoderWithPatchEmbeddingsConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize the ViT3DEncoderWithPatchEmbeddings.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = ViT3DEncoderWithPatchEmbeddingsConfig.model_validate(config | kwargs)

        self.patchify = PatchEmbeddings3D(self.config, checkpointing_level=checkpointing_level)

        # Absolute position embeddings are optional: a None sub-config disables them.
        self.absolute_position_embeddings = None
        if self.config.absolute_position_embeddings_config is not None:
            self.absolute_position_embeddings = AbsolutePositionEmbeddings3D(
                self.config.absolute_position_embeddings_config
            )

        # Learnable class tokens, shared across the batch; `self.class_tokens` only exists when num_class_tokens > 0.
        self.num_class_tokens = self.config.num_class_tokens
        if self.num_class_tokens > 0:
            self.class_tokens = nn.Parameter(torch.randn(1, self.config.num_class_tokens, self.config.dim))
        self.encoder = ViT3DEncoder(self.config, checkpointing_level=checkpointing_level)

    @populate_docstring
    def forward(
        self,
        pixel_values: torch.Tensor,
        spacings: torch.Tensor | None = None,
        channels_first: bool = True,
        return_intermediates: bool = False,
    ) -> (
        tuple[torch.Tensor, torch.Tensor | None] | tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor]]
    ):
        """Patchify the input datapoint and then pass through the ViT encoder (self attention).

        Args:
            pixel_values: {INPUT_3D_DOC}
            spacings: {SPACINGS_DOC}
            channels_first: {CHANNELS_FIRST_DOC}
            return_intermediates: {RETURN_INTERMEDIATES_DOC}

        Returns:
            A tuple ``(encoded, class_tokens)``.

            - ``encoded`` holds the encoded patch tokens as a flattened sequence of shape (b, num_tokens, dim), where
              num_tokens is the product of the patch-grid dimensions. It is not reshaped back into a 3D grid.
            - ``class_tokens`` has shape (b, num_class_tokens, dim), or is None when `num_class_tokens` is 0.

            If `return_intermediates` is True, a third element is appended: the list of per-block encoder outputs.
            Each has shape (b, num_class_tokens + num_tokens, dim), so they still include the class tokens.
        """
        # pixel_values: (b, [c], z, y, x, [c])
        # spacings: (b, 3)

        pixel_values = rearrange_channels(pixel_values, channels_first, True)
        # (b, c, z, y, x)

        embeddings = self.patchify(pixel_values, channels_first=True)
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        if self.absolute_position_embeddings is not None:
            embeddings = self.absolute_position_embeddings(embeddings, spacings=spacings, channels_first=True)
            # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        # Captured before flattening so the encoder can recover the patch grid (class tokens are not counted in it).
        query_grid_shape = (embeddings.shape[2], embeddings.shape[3], embeddings.shape[4])
        embeddings = rearrange(embeddings, "b e nz ny nx -> b (nz ny nx) e").contiguous()
        # (b, num_tokens, dim)

        class_tokens = None
        if self.num_class_tokens > 0:
            class_tokens = repeat(self.class_tokens, "1 n d -> b n d", b=embeddings.shape[0])
            embeddings = torch.cat([class_tokens, embeddings], dim=1)
            # (b, num_class_tokens + num_tokens, dim)

        encoded, layer_outputs = self.encoder(embeddings, return_intermediates=True, query_grid_shape=query_grid_shape)
        # encoded: (b, (num_class_tokens +) num_tokens, dim)
        # layer_outputs: list of (b, (num_class_tokens +) num_tokens, dim)

        # Split the encoded class tokens back off the front of the sequence.
        if self.num_class_tokens > 0:
            class_tokens = encoded[:, : self.num_class_tokens]
            encoded = encoded[:, self.num_class_tokens :]

        if return_intermediates:
            return encoded, class_tokens, layer_outputs
        return encoded, class_tokens

In [None]:
# Full patchify + encode pass on a single-channel volume, with spacings feeding the absolute position embeddings.
test_config = ViT3DEncoderWithPatchEmbeddingsConfig.model_validate(
    {
        "num_class_tokens": 2,
        "attn_drop_prob": 0.2,
        "dim": 768,
        "encoder_depth": 4,
        "in_channels": 1,
        "num_heads": 4,
        "patch_size": (8, 16, 16),
        "proj_drop_prob": 0.2,
        "rotary_position_embeddings_config": {},
    }
)

test = ViT3DEncoderWithPatchEmbeddings(test_config)
display(test)
encoded, class_tokens, intermediates = test(
    torch.randn(2, 1, 32, 512, 512),
    torch.randn(2, 3),
    return_intermediates=True,
)
display((encoded.shape, class_tokens.shape, [x.shape for x in intermediates]))
# A (32, 512, 512) volume with (8, 16, 16) patches gives 4 * 32 * 32 = 4096 patch tokens.
# Intermediates still carry the 2 class tokens.
assert encoded.shape == (2, 4096, 768)
assert class_tokens.shape == (2, 2, 768)
assert [x.shape for x in intermediates] == [(2, 4098, 768)] * 4

In [None]:
# | export


@populate_docstring
class ViT3DDecoderWithPatchEmbeddings(nn.Module, PyTorchModelHubMixin):
    """Patchification of input array followed by a ViT Decoder. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: ViT3DDecoderWithPatchEmbeddingsConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize the ViT3DDecoderWithPatchEmbeddings.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = ViT3DDecoderWithPatchEmbeddingsConfig.model_validate(config | kwargs)

        self.patchify = PatchEmbeddings3D(self.config, checkpointing_level=checkpointing_level)

        # Absolute position embeddings are optional: a None sub-config disables them.
        self.absolute_position_embeddings = None
        if self.config.absolute_position_embeddings_config is not None:
            self.absolute_position_embeddings = AbsolutePositionEmbeddings3D(
                self.config.absolute_position_embeddings_config
            )

        # Learnable class tokens, shared across the batch; `self.class_tokens` only exists when num_class_tokens > 0.
        self.num_class_tokens = self.config.num_class_tokens
        if self.num_class_tokens > 0:
            self.class_tokens = nn.Parameter(torch.randn(1, self.config.num_class_tokens, self.config.dim))
        self.decoder = ViT3DDecoder(self.config, checkpointing_level=checkpointing_level)

    @populate_docstring
    def forward(
        self,
        pixel_values: torch.Tensor,
        kv: torch.Tensor | None = None,
        spacings: torch.Tensor | None = None,
        channels_first: bool = True,
        return_intermediates: bool = False,
    ) -> (
        tuple[torch.Tensor, torch.Tensor | None] | tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor]]
    ):
        """Patchify the input datapoint and then pass it through the ViT decoder (self attention + cross attention).

        Args:
            pixel_values: {INPUT_3D_DOC}
            kv: {INPUT_3D_DOC} Keys/values for cross attention, e.g. the cache from the encoder. Despite the None
                default it is required, and it must be a 5D grid.
            spacings: {SPACINGS_DOC}
            channels_first: {CHANNELS_FIRST_DOC}
            return_intermediates: {RETURN_INTERMEDIATES_DOC}

        Returns:
            A tuple ``(decoded, class_tokens)``.

            - ``decoded`` holds the decoded query patch tokens as a flattened sequence of shape (b, num_q_tokens, dim).
              It is not reshaped back into a 3D grid.
            - ``class_tokens`` has shape (b, num_class_tokens, dim), or is None when `num_class_tokens` is 0.

            If `return_intermediates` is True, a third element is appended: the list of per-block decoder outputs in
            channels_last format, which still include the class tokens.

        Raises:
            ValueError: If `kv` is None.
        """
        # pixel_values: (b, [c], z, y, x, [c])
        # kv: (b, [dim], kv_z, kv_y, kv_x, [dim])
        # spacings: (b, 3)

        # `kv` keeps a None default only for positional-argument compatibility; fail clearly instead of deep inside.
        if kv is None:
            raise ValueError("`kv` is required: ViT3DDecoderWithPatchEmbeddings cross-attends to it.")

        pixel_values = rearrange_channels(pixel_values, channels_first, True)
        # (b, c, z, y, x)
        kv = rearrange_channels(kv, channels_first, True)
        # (b, dim, kv_z, kv_y, kv_x)

        embeddings = self.patchify(pixel_values, channels_first=True)
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        if self.absolute_position_embeddings is not None:
            embeddings = self.absolute_position_embeddings(embeddings, spacings=spacings, channels_first=True)
            # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        # Grid shapes are captured before flattening so the decoder can recover them (class tokens not counted).
        q_grid_shape = (embeddings.shape[2], embeddings.shape[3], embeddings.shape[4])
        embeddings = rearrange(embeddings, "b e nz ny nx -> b (nz ny nx) e").contiguous()
        # (b, num_q_tokens, dim)
        kv_grid_shape = (kv.shape[2], kv.shape[3], kv.shape[4])
        kv = rearrange(kv, "b e nz ny nx -> b (nz ny nx) e").contiguous()
        # (b, num_kv_tokens, dim)

        class_tokens = None
        if self.num_class_tokens > 0:
            class_tokens = repeat(self.class_tokens, "1 n d -> b n d", b=embeddings.shape[0])
            embeddings = torch.cat([class_tokens, embeddings], dim=1)
            # (b, num_class_tokens + num_q_tokens, dim)

        encoded, layer_outputs = self.decoder(
            q=embeddings, kv=kv, return_intermediates=True, q_grid_shape=q_grid_shape, kv_grid_shape=kv_grid_shape
        )
        # encoded: (b, (num_class_tokens +) num_q_tokens, dim)
        # layer_outputs: list of (b, (num_class_tokens +) num_q_tokens, dim)

        # Split the decoded class tokens back off the front of the sequence.
        if self.num_class_tokens > 0:
            class_tokens = encoded[:, : self.num_class_tokens]
            encoded = encoded[:, self.num_class_tokens :]

        if return_intermediates:
            return encoded, class_tokens, layer_outputs
        return encoded, class_tokens

In [None]:
# Full patchify + decode pass: query patches come from the volume, keys/values are a channels-first grid of
# embeddings (e.g. an encoder output). Absolute position embeddings are disabled here.
test_config = ViT3DDecoderWithPatchEmbeddingsConfig.model_validate(
    {
        "num_class_tokens": 2,
        "attn_drop_prob": 0.2,
        "dim": 768,
        "decoder_depth": 4,
        "in_channels": 1,
        "num_heads": 4,
        "patch_size": (8, 16, 16),
        "proj_drop_prob": 0.2,
        "rotary_position_embeddings_config": {},
        "absolute_position_embeddings_config": None,
    }
)

test = ViT3DDecoderWithPatchEmbeddings(test_config)
display(test)
decoded, class_tokens, intermediates = test(
    torch.randn(2, 1, 16, 128, 128),
    torch.randn(2, 768, 8, 8, 8),
    return_intermediates=True,
)
display((decoded.shape, class_tokens.shape, [x.shape for x in intermediates]))
# A (16, 128, 128) volume with (8, 16, 16) patches gives 2 * 8 * 8 = 128 query tokens.
# Intermediates still carry the 2 class tokens.
assert decoded.shape == (2, 128, 768)
assert class_tokens.shape == (2, 2, 768)
assert [x.shape for x in intermediates] == [(2, 130, 768)] * 4

# nbdev

In [None]:
# Regenerate the Python module from this notebook's "# | export" cells (target set by "# | default_exp").
!nbdev_export