In [None]:
# | default_exp nets/cait_3d

# Imports

In [None]:
# | export

from functools import wraps

import torch
from einops import rearrange, repeat
from huggingface_hub import PyTorchModelHubMixin
from torch import nn

from vision_architectures.blocks.transformer import Attention1DMLP, Attention1DMLPConfig
from vision_architectures.docstrings import populate_docstring
from vision_architectures.layers.attention import Attention1D, Attention1DConfig
from vision_architectures.utils.activation_checkpointing import ActivationCheckpointing
from vision_architectures.utils.custom_base_model import Field, model_validator
from vision_architectures.utils.rearrange import rearrange_channels

# Config

In [None]:
# | export


class CaiTAttentionWithMLPConfig(Attention1DConfig, Attention1DMLPConfig):
    """Configuration of a single CaiT attention + MLP layer.

    Combines the settings of `Attention1DConfig` and `Attention1DMLPConfig` and adds the epsilon used by the layer's
    LayerNorms.
    """

    # Passed as `eps` to both LayerNorms of CaiTAttentionWithMLP.
    layer_norm_eps: float = Field(1e-6, description="Epsilon value for the layer normalization.")


class CaiTStage1Config(CaiTAttentionWithMLPConfig):
    """Configuration of CaiT stage 1 (self attention among the input tokens, without class tokens)."""

    # Number of stacked CaiTAttentionWithMLP layers built by CaiTStage1; 0 makes the stage a no-op.
    stage1_depth: int = Field(..., description="Number of layers in stage 1.", ge=0)


class CaiTStage2Config(CaiTAttentionWithMLPConfig):
    """Configuration of CaiT stage 2 (cross attention from class tokens to the stage 1 embeddings)."""

    # CaiTStage2 itself does not read this; CaiT1D uses it to size its learnable class tokens and classifier heads.
    num_class_tokens: int = Field(1, description="Number of class tokens to be added in stage 2.", ge=0)
    # Number of stacked CaiTAttentionWithMLP layers built by CaiTStage2; 0 leaves the class tokens unchanged.
    stage2_depth: int = Field(..., description="Number of layers in stage 2.", ge=0)


class CaiTConfig(CaiTStage1Config, CaiTStage2Config):
    """Configuration of the end-to-end CaiT model: the shared attention/MLP settings, the depths of both stages, and
    the number of class tokens."""

    @model_validator(mode="after")
    def validate(self):
        super().validate()
        assert self.stage1_depth + self.stage2_depth > 0, "There should be atleast one layer in the model."
        # The classification head builds one classifier per class token and concatenates their logits; with zero
        # class tokens that concatenation fails at forward time, so reject the config up front.
        assert self.num_class_tokens > 0, "There should be at least one class token in the model."
        return self

# Architecture

### Basic Layers

In [None]:
# | export


@populate_docstring
class CaiTAttentionWithMLP(nn.Module):
    """Attention layer used in the CaiT 3D model. Introduces learnable gamma scaling of hidden states after the self
    attention and MLP layers. {CLASS_DESCRIPTION_1D_DOC}

    The block is pre-norm with two scaled residual branches:
    x = q + gamma1 * attention(norm1(q), norm1(kv), norm1(kv)) followed by out = x + gamma2 * mlp(norm2(x)).
    """

    @populate_docstring
    def __init__(self, config: CaiTAttentionWithMLPConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initializes the CaiT 3D attention layer.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = CaiTAttentionWithMLPConfig.model_validate(config | kwargs)

        self.mhsa = Attention1D(self.config, checkpointing_level=checkpointing_level)
        self.gamma1 = nn.Parameter(torch.empty(1, 1, self.config.dim))
        self.layernorm1 = nn.LayerNorm(self.config.dim, eps=self.config.layer_norm_eps)
        self.mlp = Attention1DMLP(self.config, checkpointing_level=checkpointing_level)
        self.gamma2 = nn.Parameter(torch.empty(1, 1, self.config.dim))
        self.layernorm2 = nn.LayerNorm(self.config.dim, eps=self.config.layer_norm_eps)

        # Per-channel gammas start close to zero, so each residual branch initially contributes very little.
        nn.init.uniform_(self.gamma1, a=-1e-4, b=1e-4)
        nn.init.uniform_(self.gamma2, a=-1e-4, b=1e-4)

        self.checkpointing_level3 = ActivationCheckpointing(3, checkpointing_level)

    @populate_docstring
    def _forward(self, q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
        """Pass the input q and kv tensors through the q, k, and v matrices and then pass them through the CaiT
        attention layer.

        Both q and kv are layer-normalized before attention. The MLP acts on the layer-normalized residual stream.
        The output has the same shape as q.

        Args:
            q: {INPUT_1D_DOC}
            kv: {INPUT_1D_DOC}

        Returns:
            {OUTPUT_1D_DOC}
        """
        # q: (b, num_tokens_in_q, dim)
        # kv: (b, num_tokens_in_kv, dim)

        # Attention branch (pre-norm). Keys/values are normalized as well as queries; for self attention (kv is q)
        # the normalized tensor is simply reused.
        normed_q = self.layernorm1(q)
        normed_kv = normed_q if kv is q else self.layernorm1(kv)
        attention_out = self.mhsa(normed_q, normed_kv, normed_kv)
        hidden_states = q + self.gamma1 * attention_out
        # (b, num_tokens_in_q, dim)

        # MLP branch (pre-norm): the MLP consumes the *normalized* residual stream.
        mlp_out = self.mlp(self.layernorm2(hidden_states))
        hidden_states = hidden_states + self.gamma2 * mlp_out
        # (b, num_tokens_in_q, dim)

        return hidden_states

    @wraps(_forward)
    def forward(self, *args, **kwargs):
        return self.checkpointing_level3(self._forward, *args, **kwargs)

In [None]:
test = CaiTAttentionWithMLP(dim=52, num_heads=4, mlp_ratio=2)
display(test)
test_output = test(torch.randn(2, 64, 52), torch.randn(2, 64, 52))
assert test_output.shape == (2, 64, 52)
# Cross attention: the output follows the query length, not the key/value length.
cross_output = test(torch.randn(2, 3, 52), torch.randn(2, 64, 52))
assert cross_output.shape == (2, 3, 52)
display(test_output.shape)


[1;35mCaiTAttentionWithMLP[0m[1m([0m
  [1m([0mmhsa[1m)[0m: [1;35mAttention1D[0m[1m([0m
    [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menable

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m52[0m[1m][0m[1m)[0m

### Stages

In [None]:
# | export


@populate_docstring
class CaiTStage1(nn.Module, PyTorchModelHubMixin):
    """CaiT stage 1. Performs self attention without class tokens focusing on learning features among tokens.
    {CLASS_DESCRIPTION_1D_DOC}"""

    @populate_docstring
    def __init__(self, config: CaiTStage1Config = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize the CaiTStage1.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = CaiTStage1Config.model_validate(config | kwargs)

        # `stage1_depth` identically configured self-attention layers.
        self.layers = nn.ModuleList(
            [
                CaiTAttentionWithMLP(self.config, checkpointing_level=checkpointing_level)
                for _ in range(self.config.stage1_depth)
            ]
        )

        self.checkpointing_level4 = ActivationCheckpointing(4, checkpointing_level)

    @populate_docstring
    def _forward(
        self, embeddings: torch.Tensor, return_intermediates: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Pass the input embeddings through the CaiT stage 1 layers.

        Args:
            embeddings: {INPUT_1D_DOC}
            return_intermediates: {RETURN_INTERMEDIATES_DOC}

        Returns:
            {OUTPUT_1D_DOC} If `return_intermediates` is True, returns a tuple of the output embeddings and a list of
            intermediate layer outputs (one entry per layer). With zero layers the input is returned unchanged."""
        # embeddings: (b, num_tokens, dim)

        layer_outputs = []
        for encoder_layer in self.layers:
            # Self attention: the same tensor provides the queries and the keys/values.
            embeddings = encoder_layer(embeddings, embeddings)
            # (b, num_tokens, dim)

            # Only keep references to intermediates when the caller asked for them.
            if return_intermediates:
                layer_outputs.append(embeddings)

        if return_intermediates:
            return embeddings, layer_outputs
        return embeddings

    @wraps(_forward)
    def forward(self, *args, **kwargs):
        return self.checkpointing_level4(self._forward, *args, **kwargs)

In [None]:
test = CaiTStage1(dim=52, num_heads=4, stage1_depth=3)
display(test)
test_output = test(torch.randn(2, 64, 52))
assert test_output.shape == (2, 64, 52)
display(test_output.shape)


[1;35mCaiTStage1[0m[1m([0m
  [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m2[0m[1m)[0m: [1;36m3[0m x [1;35mCaiTAttentionWithMLP[0m[1m([0m
      [1m([0mmhsa[1m)[0m: [1;35mAttention1D[0m[1m([0m
        [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m52[0m[1m][0m[1m)[0m

In [None]:
# | export


@populate_docstring
class CaiTStage2(nn.Module, PyTorchModelHubMixin):
    """CaiT stage 2. Performs cross attention between class tokens and learned features from stage 1.
    {CLASS_DESCRIPTION_1D_DOC}"""

    @populate_docstring
    def __init__(self, config: CaiTStage2Config = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize the CaiTStage2.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = CaiTStage2Config.model_validate(config | kwargs)

        # `stage2_depth` identically configured cross-attention layers.
        self.layers = nn.ModuleList(
            [
                CaiTAttentionWithMLP(self.config, checkpointing_level=checkpointing_level)
                for _ in range(self.config.stage2_depth)
            ]
        )

        self.checkpointing_level4 = ActivationCheckpointing(4, checkpointing_level)

    @populate_docstring
    def _forward(
        self, class_tokens: torch.Tensor, embeddings: torch.Tensor, return_intermediates: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Pass the input embeddings through the CaiT stage 2 layers.

        Only the class tokens are updated; the embeddings act as keys/values and are the same for every layer.

        Args:
            class_tokens: {INPUT_1D_DOC}
            embeddings: {INPUT_1D_DOC}
            return_intermediates: {RETURN_INTERMEDIATES_DOC}

        Returns:
            {OUTPUT_1D_DOC} The output has the same shape as `class_tokens`. If `return_intermediates` is True,
            returns a tuple of the output class embeddings and a list of intermediate layer outputs (one per layer)."""
        # class_tokens: (b, num_class_tokens, dim)
        # embeddings: (b, num_tokens, dim)

        class_embeddings = class_tokens

        layer_outputs = []
        for encoder_layer in self.layers:
            class_embeddings = encoder_layer(class_embeddings, embeddings)
            # (b, num_class_tokens, dim)

            # Only keep references to intermediates when the caller asked for them.
            if return_intermediates:
                layer_outputs.append(class_embeddings)

        if return_intermediates:
            return class_embeddings, layer_outputs
        return class_embeddings

    @wraps(_forward)
    def forward(self, *args, **kwargs):
        return self.checkpointing_level4(self._forward, *args, **kwargs)

In [None]:
test = CaiTStage2(dim=52, num_heads=4, stage2_depth=3, num_class_tokens=2)
display(test)
# Feed as many class tokens as the config declares.
test_output = test(torch.randn(2, 2, 52), torch.randn(2, 64, 52))
assert test_output.shape == (2, 2, 52)
display(test_output.shape)


[1;35mCaiTStage2[0m[1m([0m
  [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m2[0m[1m)[0m: [1;36m3[0m x [1;35mCaiTAttentionWithMLP[0m[1m([0m
      [1m([0mmhsa[1m)[0m: [1;35mAttention1D[0m[1m([0m
        [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1[0m, [1;36m52[0m[1m][0m[1m)[0m

# Models

In [None]:
# | export


@populate_docstring
class CaiT1D(nn.Module, PyTorchModelHubMixin):
    """End-to-end CaiT model for classification. {CLASS_DESCRIPTION_1D_DOC}"""

    @populate_docstring
    def __init__(self, config: CaiTConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize the CaiT Model.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = CaiTConfig.model_validate(config | kwargs)

        # Learnable class tokens, repeated across the batch in the forward pass.
        self.class_tokens = nn.Parameter(torch.randn(1, self.config.num_class_tokens, self.config.dim))

        # Forward the checkpointing level so the stages (and their layers) honour it too.
        self.self_attention = CaiTStage1(self.config, checkpointing_level=checkpointing_level)
        self.class_attention = CaiTStage2(self.config, checkpointing_level=checkpointing_level)
        # One scalar-output classifier per class token. Read `dim` from the validated config: the raw `config`
        # argument may be a plain dict (or the default `{}` when configuring through kwargs).
        self.classifiers = nn.ModuleList(
            [nn.Linear(self.config.dim, 1) for _ in range(self.config.num_class_tokens)]
        )

        self.checkpointing_level5 = ActivationCheckpointing(5, checkpointing_level)

    @populate_docstring
    def _forward(self, tokens: torch.Tensor, return_intermediates: bool = False) -> torch.Tensor | tuple:
        """Pass the input embeddings through the CaiT layers. Expects flattened input.

        Args:
            tokens: {INPUT_1D_DOC}
            return_intermediates: {RETURN_INTERMEDIATES_DOC}

        Returns:
            Class logits of shape (b, num_class_tokens), one logit per class token. If `return_intermediates` is True,
            returns a tuple of the class logits, the final class embeddings of shape (b, num_class_tokens, dim), and a
            two-element list holding the per-layer outputs of stage 1 and of stage 2.
        """
        # tokens: (b, num_embedding_tokens, dim)

        embeddings, layer_outputs1 = self.self_attention(tokens, return_intermediates=True)
        # embeddings: (b, num_embedding_tokens, dim)
        # layer_outputs1: list of (b, num_embedding_tokens, dim)

        class_tokens = repeat(self.class_tokens, "1 n d -> b n d", b=embeddings.shape[0])
        # (b, num_class_tokens, dim)

        class_embeddings, layer_outputs2 = self.class_attention(class_tokens, embeddings, return_intermediates=True)
        # class_embeddings: (b, num_class_tokens, dim)
        # layer_outputs2: list of (b, num_class_tokens, dim)

        # Each class token has its own classifier; the (b, 1) logits are concatenated along the token axis.
        class_logits = torch.cat(
            [classifier(class_embeddings[:, i]) for i, classifier in enumerate(self.classifiers)],
            dim=1,
        )
        # (b, num_class_tokens)

        if return_intermediates:
            return class_logits, class_embeddings, [layer_outputs1, layer_outputs2]
        return class_logits

    @wraps(_forward)
    def forward(self, *args, **kwargs):
        return self.checkpointing_level5(self._forward, *args, **kwargs)

In [None]:
test_config = CaiTConfig.model_validate(
    {
        "num_class_tokens": 3,
        "attn_drop_prob": 0.2,
        "dim": 768,
        "drop_prob": 0.2,
        "stage1_depth": 2,
        "stage2_depth": 2,
        "mlp_ratio": 2,
        "layer_norm_eps": 1e-6,
        "mlp_drop_prob": 0.2,
        "num_heads": 4,
        "proj_drop_prob": 0.2,
    }
)

test = CaiT1D(test_config)
display(test)
o = test(
    torch.randn(2, 4096, 768),
    return_intermediates=True,
)
logits, class_embeddings, (stage1_outputs, stage2_outputs) = o
# Check the documented output shapes instead of only displaying them.
n_cls, dim = test_config.num_class_tokens, test_config.dim
assert logits.shape == (2, n_cls)
assert class_embeddings.shape == (2, n_cls, dim)
assert [x.shape for x in stage1_outputs] == [(2, 4096, dim)] * test_config.stage1_depth
assert [x.shape for x in stage2_outputs] == [(2, n_cls, dim)] * test_config.stage2_depth
display((logits.shape, class_embeddings.shape, [[x.shape for x in stage1_outputs], [x.shape for x in stage2_outputs]]))


[1;35mCaiT1D[0m[1m([0m
  [1m([0mself_attention[1m)[0m: [1;35mCaiTStage1[0m[1m([0m
    [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
      [1m([0m[1;36m0[0m-[1;36m1[0m[1m)[0m: [1;36m2[0m x [1;35mCaiTAttentionWithMLP[0m[1m([0m
        [1m([0mmhsa[1m)[0m: [1;35mAttention1D[0m[1m([0m
          [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m768[0m, [33mout_features[0m=[1;36m768[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m768[0m, [33mout_features[0m=[1;36m768[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m768[0m, [33mout_features[0m=[1;36m768[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m768[0m, [33mout_features[0m=[1;36m768[0m, [33mbias[0m=[3;9


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m[1m][0m[1m)[0m,
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m, [1;36m768[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1m[[0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4096[0m, [1;36m768[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4096[0m, [1;36m768[0m[1m][0m[1m)[0m[1m][0m,
        [1m[[0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m, [1;36m768[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m, [1;36m768[0m[1m][0m[1m)[0m[1m][0m
    [1m][0m
[1m)[0m

In [None]:
# | export


@populate_docstring
class CaiT3D(CaiT1D):
    """End-to-end CaiT model for classification. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def _forward(
        self, tokens: torch.Tensor, channels_first: bool = True, return_intermediates: bool = False
    ) -> torch.Tensor | tuple:
        """Flatten a 3D grid of tokens into a sequence and pass it through the CaiT layers.

        Args:
            tokens: {INPUT_3D_DOC}
            channels_first: {CHANNELS_FIRST_DOC}
            return_intermediates: {RETURN_INTERMEDIATES_DOC}

        Returns:
            Same as CaiT1D: class logits of shape (b, num_class_tokens), or, if `return_intermediates` is True, a
            tuple of the logits, the final class embeddings and the per-stage layer outputs. The stage 1 outputs
            hold z * y * x tokens each.
        """
        # tokens: (b, [dim], z, y, x, [dim])

        tokens = rearrange_channels(tokens, channels_first, False)
        # tokens: (b, z, y, x, dim)
        tokens = rearrange(tokens, "b z y x dim -> b (z y x) dim").contiguous()
        # (b, T, dim) with T = z * y * x

        return super()._forward(tokens, return_intermediates=return_intermediates)

    # Re-wrap so `forward` advertises this class's signature (including `channels_first`) instead of CaiT1D's.
    @wraps(_forward)
    def forward(self, *args, **kwargs):
        return self.checkpointing_level5(self._forward, *args, **kwargs)

In [None]:
test_config = CaiTConfig.model_validate(
    {
        "num_class_tokens": 3,
        "attn_drop_prob": 0.2,
        "dim": 768,
        "drop_prob": 0.2,
        "stage1_depth": 2,
        "stage2_depth": 2,
        "mlp_ratio": 2,
        "layer_norm_eps": 1e-6,
        "mlp_drop_prob": 0.2,
        "num_heads": 4,
        "proj_drop_prob": 0.2,
    }
)

test = CaiT3D(test_config)
display(test)
o = test(
    torch.randn(2, 768, 16, 16, 16),
    return_intermediates=True,
)
logits, class_embeddings, (stage1_outputs, stage2_outputs) = o
# A 16x16x16 grid flattens to 4096 tokens.
n_cls, dim = test_config.num_class_tokens, test_config.dim
assert logits.shape == (2, n_cls)
assert class_embeddings.shape == (2, n_cls, dim)
assert [x.shape for x in stage1_outputs] == [(2, 16 * 16 * 16, dim)] * test_config.stage1_depth
assert [x.shape for x in stage2_outputs] == [(2, n_cls, dim)] * test_config.stage2_depth
display((logits.shape, class_embeddings.shape, [[x.shape for x in stage1_outputs], [x.shape for x in stage2_outputs]]))


[1;35mCaiT3D[0m[1m([0m
  [1m([0mself_attention[1m)[0m: [1;35mCaiTStage1[0m[1m([0m
    [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
      [1m([0m[1;36m0[0m-[1;36m1[0m[1m)[0m: [1;36m2[0m x [1;35mCaiTAttentionWithMLP[0m[1m([0m
        [1m([0mmhsa[1m)[0m: [1;35mAttention1D[0m[1m([0m
          [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m768[0m, [33mout_features[0m=[1;36m768[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m768[0m, [33mout_features[0m=[1;36m768[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m768[0m, [33mout_features[0m=[1;36m768[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m768[0m, [33mout_features[0m=[1;36m768[0m, [33mbias[0m=[3;9


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m[1m][0m[1m)[0m,
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m, [1;36m768[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1m[[0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4096[0m, [1;36m768[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4096[0m, [1;36m768[0m[1m][0m[1m)[0m[1m][0m,
        [1m[[0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m, [1;36m768[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m, [1;36m768[0m[1m][0m[1m)[0m[1m][0m
    [1m][0m
[1m)[0m

# nbdev

In [None]:
# Regenerate the library module(s) from the cells marked `# | export` (nbdev CLI).
!nbdev_export