In [None]:
# | default_exp blocks/transformer

# Imports

In [None]:
# | export


from typing import Literal

import torch
from einops import rearrange
from torch import nn

from vision_architectures.layers.attention import Attention1D, Attention1DConfig, Attention3D, Attention3DConfig
from vision_architectures.layers.embeddings import RelativePositionEmbeddings
from vision_architectures.utils.activation_checkpointing import ActivationCheckpointing
from vision_architectures.utils.activations import get_act_layer
from vision_architectures.utils.custom_base_model import CustomBaseModel
from vision_architectures.utils.rearrange import make_channels_first, make_channels_last
from vision_architectures.utils.residuals import Residual

# Configs

In [None]:
# | export


class Attention1DMLPConfig(CustomBaseModel):
    """Configuration for the token-wise feed-forward network (``Attention1DMLP``)."""

    dim: int  # input and output feature size
    mlp_ratio: int = 4  # hidden layer size is dim * mlp_ratio
    activation: str = "gelu"  # activation name, resolved with get_act_layer
    mlp_drop_prob: float = 0.0  # dropout probability applied after the second linear layer


class Attention3DMLPConfig(Attention1DMLPConfig):
    """Configuration for ``Attention3DMLP``; same fields as ``Attention1DMLPConfig``."""

    pass


class Attention1DWithMLPConfig(Attention1DConfig, Attention1DMLPConfig):
    """Combined attention + MLP configuration for ``Attention1DWithMLP`` and the 1D transformer blocks."""

    # int, or a pair -- presumably (dim_qk, dim_v) as interpreted by Attention1DConfig; TODO confirm.
    # NOTE(review): Attention1DMLPConfig.dim is int-only, so building the MLP from this config
    # with a tuple dim will presumably fail validation -- confirm whether tuples are supported here.
    dim: int | tuple[int, int]
    # "pre": normalise sub-layer inputs; "post": normalise sub-layer outputs before the residual add.
    norm_location: Literal["pre", "post"] = "post"
    layer_norm_eps: float = 1e-6  # eps passed to every nn.LayerNorm in the block


class Attention3DWithMLPConfig(Attention3DConfig, Attention3DMLPConfig):
    """Combined attention + MLP configuration for ``Attention3DWithMLP`` and the 3D transformer blocks."""

    # int, or a pair -- presumably (dim_qk, dim_v) as interpreted by Attention3DConfig; TODO confirm.
    dim: int | tuple[int, int]
    # "pre": normalise sub-layer inputs; "post": normalise sub-layer outputs before the residual add.
    norm_location: Literal["pre", "post"] = "post"
    layer_norm_eps: float = 1e-6  # eps passed to every nn.LayerNorm in the block

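# Architecture

Note on `norm_location`: with `"pre"`, each sub-layer's *input* is layer-normalised; with `"post"`, each sub-layer's *output* is layer-normalised **before** it enters the residual connection (the residual branch itself is never normalised). This differs from the original Transformer's post-norm, which normalises the sum after the residual add.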

In [None]:
# | export


class Attention1DMLP(nn.Module):
    """Token-wise two-layer feed-forward network used after attention.

    Expands the last axis from ``dim`` to ``dim * mlp_ratio``, applies the configured
    activation, projects back to ``dim`` and applies dropout.

    Args:
        config: ``Attention1DMLPConfig`` (or a dict of its fields); merged with ``kwargs``.
        checkpointing_level: activation-checkpointing level. Each half of the MLP is wrapped
            by the level-1 wrapper and the whole MLP by the level-2 wrapper (presumably each
            wrapper is active when ``checkpointing_level`` reaches its level -- see
            ``ActivationCheckpointing``).

    Input / output: ``(b, T, dim)``.
    """

    def __init__(self, config: Attention1DMLPConfig = {}, checkpointing_level: int = 0, **kwargs):
        super().__init__()

        self.config = Attention1DMLPConfig.model_validate(config | kwargs)
        cfg = self.config
        hidden_dim = cfg.dim * cfg.mlp_ratio

        # Registration order (dense1, act, dense2, dropout, ...) is kept stable for repr/state_dict.
        self.dense1 = nn.Linear(cfg.dim, hidden_dim)
        self.act = cfg.activation if isinstance(cfg.activation, nn.Module) else get_act_layer(cfg.activation)
        self.dense2 = nn.Linear(hidden_dim, cfg.dim)
        self.dropout = nn.Dropout(cfg.mlp_drop_prob)

        self.checkpointing_level1 = ActivationCheckpointing(1, checkpointing_level)
        self.checkpointing_level2 = ActivationCheckpointing(2, checkpointing_level)

    def _expand(self, hidden_states):
        # dim -> dim * mlp_ratio, then activation
        return self.act(self.dense1(hidden_states))

    def _contract(self, hidden_states):
        # dim * mlp_ratio -> dim, then dropout
        return self.dropout(self.dense2(hidden_states))

    def _forward(self, hidden_states: torch.Tensor):
        # hidden_states: (b, T, dim)
        expanded = self.checkpointing_level1(self._expand, hidden_states)
        return self.checkpointing_level1(self._contract, expanded)

    def forward(self, *args, **kwargs):
        return self.checkpointing_level2(self._forward, *args, **kwargs)

In [None]:
test = Attention1DMLP(dim=64, activation="relu", mlp_drop_prob=0.2)
display(test)
x = torch.randn(2, 28, 64)
out = test(x)
display(out.shape)
# The MLP is token-wise and maps dim -> dim, so the shape must be preserved.
assert out.shape == x.shape, "Attention1DMLP must preserve the input shape"


[1;35mAttention1DMLP[0m[1m([0m
  [1m([0mdense1[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m256[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
  [1m([0mdense2[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m256[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mdropout[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.2[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m, [33mcheckpointing_level[0m=[1;36m1[0m[1m)[0m
  [1m([0mcheckpointing_level2[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m, [33mcheckpointing_level[0m=[1;36m2[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m28[0m, [1;36m64[0m[1m][0m[1m)[0m

In [None]:
# | export


class Attention3DMLP(Attention1DMLP):
    """Feed-forward MLP for 3D token grids.

    Uses the same layers as ``Attention1DMLP``. Accepts channels-first
    ``(b, dim, z, y, x)`` input (the default) or channels-last ``(b, z, y, x, dim)``
    input when ``channels_first=False``; the output uses the same layout as the input.
    """

    def __init__(self, config: Attention3DMLPConfig = {}, checkpointing_level: int = 0, **kwargs):
        super().__init__(config, checkpointing_level, **kwargs)

    def _forward(self, hidden_states: torch.Tensor, channels_first: bool = True):
        # hidden_states: (b, dim, z, y, x) if channels_first else (b, z, y, x, dim)

        if channels_first:
            hidden_states = make_channels_last(hidden_states)

        # The parent's linear layers act on the last (channel) axis.
        hidden_states = super()._forward(hidden_states)

        if channels_first:
            hidden_states = make_channels_first(hidden_states)

        return hidden_states

    def forward(self, *args, **kwargs):
        # Whole-MLP checkpointing belongs to level 2, exactly as in Attention1DMLP.forward.
        # The two halves are already wrapped at level 1 inside the parent's _forward; wrapping
        # the whole MLP at level 1 too would nest level-1 checkpoints and leave level 2 unused.
        return self.checkpointing_level2(self._forward, *args, **kwargs)

In [None]:
test = Attention3DMLP(dim=64)
display(test)
x = torch.randn(2, 64, 4, 4, 4)
out = test(x)
display(out.shape)
# Channels-first in, channels-first out, with the channel size unchanged.
assert out.shape == x.shape, "Attention3DMLP must preserve the input shape"


[1;35mAttention3DMLP[0m[1m([0m
  [1m([0mdense1[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m256[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mact[1m)[0m: [1;35mGELU[0m[1m([0m[33mapproximate[0m=[32m'none'[0m[1m)[0m
  [1m([0mdense2[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m256[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mdropout[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m, [33mcheckpointing_level[0m=[1;36m1[0m[1m)[0m
  [1m([0mcheckpointing_level2[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m, [33mcheckpointing_level[0m=[1;36m2[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m][0m[1m)[0m

In [None]:
# | export


class Attention1DWithMLP(nn.Module):
    """Attention sub-layer followed by an MLP sub-layer, each with a residual connection.

    ``config.norm_location == "pre"`` normalises the sub-layer inputs; ``"post"`` normalises
    the sub-layer outputs before they are combined with the residual branch.

    Args:
        config: ``Attention1DWithMLPConfig`` (or a dict of its fields); merged with ``kwargs``.
        relative_position_bias: optional module forwarded to ``Attention1D``.
        logit_scale: optional value forwarded to ``Attention1D``.
        checkpointing_level: activation-checkpointing level, shared with the sub-modules;
            the whole block is wrapped by the level-3 wrapper.
    """

    def __init__(
        self,
        config: Attention1DWithMLPConfig = {},
        relative_position_bias: RelativePositionEmbeddings | None = None,
        logit_scale: float | None = None,
        checkpointing_level: int = 0,
        **kwargs
    ):
        super().__init__()

        self.config = Attention1DWithMLPConfig.model_validate(config | kwargs)
        cfg = self.config

        # Registration order (attn, layernorm1, mlp, layernorm2, ...) is kept stable for repr/state_dict.
        self.attn = Attention1D(
            cfg,
            relative_position_bias=relative_position_bias,
            logit_scale=logit_scale,
            checkpointing_level=checkpointing_level,
        )
        self.layernorm1 = nn.LayerNorm(cfg.dim_qk, eps=cfg.layer_norm_eps)
        self.mlp = Attention1DMLP(cfg, checkpointing_level=checkpointing_level)
        self.layernorm2 = nn.LayerNorm(cfg.dim_qk, eps=cfg.layer_norm_eps)

        self.residual = Residual()

        self.checkpointing_level3 = ActivationCheckpointing(3, checkpointing_level)

    def _forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
        # query / key / value: (b, T, dim); result: (b, T, dim)
        use_pre_norm = self.config.norm_location == "pre"
        use_post_norm = self.config.norm_location == "post"

        # --- attention sub-layer: the residual branch carries the un-normalised query ---
        skip = query
        if use_pre_norm:
            query, key, value = self.layernorm1(query), self.layernorm1(key), self.layernorm1(value)
        attn_out = self.attn(query, key, value)
        if use_post_norm:
            attn_out = self.layernorm1(attn_out)
        hidden_states = self.residual(skip, attn_out)

        # --- MLP sub-layer ---
        skip = hidden_states
        mlp_in = self.layernorm2(hidden_states) if use_pre_norm else hidden_states
        mlp_out = self.mlp(mlp_in)
        if use_post_norm:
            mlp_out = self.layernorm2(mlp_out)
        return self.residual(skip, mlp_out)

    def forward(self, *args, **kwargs):
        return self.checkpointing_level3(self._forward, *args, **kwargs)

In [None]:
test = Attention1DWithMLP(dim=54, num_heads=3)
x = torch.randn(2, 64, 54)  # used as query, key and value

display(test)
out = test(x, x, x)
display(out.shape)
assert out.shape == x.shape, "Attention1DWithMLP must preserve the query shape"


[1;35mAttention1DWithMLP[0m[1m([0m
  [1m([0mattn[1m)[0m: [1;35mAttention1D[0m[1m([0m
    [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mlayernorm1[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m54[0m,[1m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m

In [None]:
# | export


class Attention3DWithMLP(nn.Module):
    """3D attention + MLP block operating on ``(z, y, x)`` token grids.

    Inputs and outputs are channels-first ``(b, dim, z, y, x)`` by default, or channels-last
    ``(b, z, y, x, dim)`` when ``channels_first=False``. Internally everything runs
    channels-last so the LayerNorms act on the channel axis.

    Args:
        config: ``Attention3DWithMLPConfig`` (or a dict of its fields); merged with ``kwargs``.
        relative_position_bias: optional module forwarded to ``Attention3D``.
        logit_scale: optional value forwarded to ``Attention3D``.
        checkpointing_level: activation-checkpointing level, shared with the sub-modules;
            the whole block is wrapped by the level-3 wrapper.
    """

    def __init__(
        self,
        config: Attention3DWithMLPConfig = {},
        relative_position_bias: RelativePositionEmbeddings | None = None,
        logit_scale: float | None = None,
        checkpointing_level: int = 0,
        **kwargs
    ):
        super().__init__()

        self.config = Attention3DWithMLPConfig.model_validate(config | kwargs)
        cfg = self.config

        # Registration order (attn, layernorm1, mlp, layernorm2, ...) is kept stable for repr/state_dict.
        self.attn = Attention3D(
            cfg,
            relative_position_bias=relative_position_bias,
            logit_scale=logit_scale,
            checkpointing_level=checkpointing_level,
        )
        self.layernorm1 = nn.LayerNorm(cfg.dim_qk, eps=cfg.layer_norm_eps)
        self.mlp = Attention3DMLP(cfg, checkpointing_level=checkpointing_level)
        self.layernorm2 = nn.LayerNorm(cfg.dim_qk, eps=cfg.layer_norm_eps)

        self.residual = Residual()

        self.checkpointing_level3 = ActivationCheckpointing(3, checkpointing_level)

    def _forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        channels_first: bool = True,
    ):
        # Each input is (b, dim, z, y, x) if channels_first else (b, z, y, x, dim)
        if channels_first:
            query, key, value = (
                rearrange(t, "b d z y x -> b z y x d").contiguous() for t in (query, key, value)
            )
        # from here on: (b, tokens_z, tokens_y, tokens_x, dim)

        use_pre_norm = self.config.norm_location == "pre"
        use_post_norm = self.config.norm_location == "post"

        # --- attention sub-layer: the residual branch carries the un-normalised query ---
        skip = query
        if use_pre_norm:
            query, key, value = self.layernorm1(query), self.layernorm1(key), self.layernorm1(value)
        attn_out = self.attn(query, key, value, channels_first=False)
        if use_post_norm:
            attn_out = self.layernorm1(attn_out)
        hidden_states = self.residual(skip, attn_out)

        # --- MLP sub-layer ---
        skip = hidden_states
        mlp_in = self.layernorm2(hidden_states) if use_pre_norm else hidden_states
        mlp_out = self.mlp(mlp_in, channels_first=False)
        if use_post_norm:
            mlp_out = self.layernorm2(mlp_out)
        hidden_states = self.residual(skip, mlp_out)

        if channels_first:
            hidden_states = rearrange(hidden_states, "b z y x d -> b d z y x").contiguous()
        return hidden_states

    def forward(self, *args, **kwargs):
        return self.checkpointing_level3(self._forward, *args, **kwargs)

In [None]:
test = Attention3DWithMLP(dim=54, num_heads=3)
x = torch.randn(2, 54, 4, 4, 4)  # channels-first; used as query, key and value

display(test)
out = test(x, x, x, channels_first=True)
display(out.shape)
assert out.shape == x.shape, "Attention3DWithMLP must preserve the query shape"


[1;35mAttention3DWithMLP[0m[1m([0m
  [1m([0mattn[1m)[0m: [1;35mAttention3D[0m[1m([0m
    [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mlayernorm1[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m54[0m,[1m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m54[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m][0m[1m)[0m

In [None]:
# | export


class TransformerEncoderBlock1D(Attention1DWithMLP):
    """Self-attention encoder block: the same tensor is used as query, key and value."""

    def forward(self, qkv: torch.Tensor):
        # qkv: (b, num_tokens, dim); result has the same shape
        query = key = value = qkv
        return super().forward(query, key, value)

In [None]:
test = TransformerEncoderBlock1D(dim=54, num_heads=6, mlp_ratio=2)
display(test)
x = torch.randn(2, 64, 54)
out = test(x)
display(out.shape)
assert out.shape == x.shape, "TransformerEncoderBlock1D must preserve the input shape"


[1;35mTransformerEncoderBlock1D[0m[1m([0m
  [1m([0mattn[1m)[0m: [1;35mAttention1D[0m[1m([0m
    [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mlayernorm1[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m54[

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m

In [None]:
# | export


class TransformerEncoderBlock3D(Attention3DWithMLP):
    """Self-attention encoder block over a 3D token grid (query = key = value)."""

    def forward(self, qkv: torch.Tensor, channels_first: bool = True):
        # qkv: (b, dim, z, y, x) if channels_first else (b, z, y, x, dim); result has the same layout
        query = key = value = qkv
        return super().forward(query, key, value, channels_first)

In [None]:
test = TransformerEncoderBlock3D(dim=54, num_heads=6, mlp_ratio=2)
display(test)
x = torch.randn(2, 54, 4, 4, 4)  # channels-first
out = test(x)
display(out.shape)
assert out.shape == x.shape, "TransformerEncoderBlock3D must preserve the input shape"


[1;35mTransformerEncoderBlock3D[0m[1m([0m
  [1m([0mattn[1m)[0m: [1;35mAttention3D[0m[1m([0m
    [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mlayernorm1[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m54[

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m54[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m][0m[1m)[0m

In [None]:
# | export


class TransformerDecoderBlock1D(nn.Module):
    """Transformer decoder block: self-attention, cross-attention and MLP sub-layers.

    Each sub-layer has its own LayerNorm and a residual connection; ``config.norm_location``
    selects whether the norm is applied to the sub-layer input ("pre") or output ("post").

    Args:
        config: ``Attention1DWithMLPConfig`` (or a dict of its fields); merged with ``kwargs``.
        checkpointing_level: activation-checkpointing level, shared with the attention and
            MLP sub-modules; the whole block is wrapped by the level-3 wrapper.
    """

    def __init__(self, config: Attention1DWithMLPConfig = {}, checkpointing_level: int = 0, **kwargs):
        super().__init__()

        self.config = Attention1DWithMLPConfig.model_validate(config | kwargs)

        dim_qk = self.config.dim_qk
        layer_norm_eps = self.config.layer_norm_eps

        # Sub-modules get the full config (as in Attention1DWithMLP) so every configured option
        # is honoured -- previously only a hand-picked subset was forwarded and e.g. the MLP
        # `activation` was silently ignored. They also get checkpointing_level so the lower
        # checkpointing levels apply inside this block too.
        self.attn1 = Attention1D(self.config, checkpointing_level=checkpointing_level)
        self.layernorm1 = nn.LayerNorm(dim_qk, eps=layer_norm_eps)
        self.attn2 = Attention1D(self.config, checkpointing_level=checkpointing_level)
        self.layernorm2 = nn.LayerNorm(dim_qk, eps=layer_norm_eps)
        self.mlp = Attention1DMLP(self.config, checkpointing_level=checkpointing_level)
        self.layernorm3 = nn.LayerNorm(dim_qk, eps=layer_norm_eps)

        self.residual = Residual()

        self.checkpointing_level3 = ActivationCheckpointing(3, checkpointing_level)

    def _forward(self, q: torch.Tensor, kv: torch.Tensor):
        # q: (b, num_tokens_in_q, dim)
        # kv: (b, num_tokens_in_kv, dim)
        # result: (b, num_tokens_in_q, dim)

        # --- self-attention over q ---
        res_connection1 = q

        if self.config.norm_location == "pre":
            q = self.layernorm1(q)
            # NOTE(review): the cross-attention memory kv is normalised with the self-attention
            # norm (layernorm1) rather than a norm of its own -- confirm this sharing is intended.
            kv = self.layernorm1(kv)

        hidden_states = self.attn1(q, q, q)

        if self.config.norm_location == "post":
            hidden_states = self.layernorm1(hidden_states)

        hidden_states = self.residual(res_connection1, hidden_states)

        # --- cross-attention: queries from the decoder stream, keys/values from kv ---
        res_connection2 = hidden_states

        if self.config.norm_location == "pre":
            hidden_states = self.layernorm2(hidden_states)

        hidden_states = self.attn2(hidden_states, kv, kv)

        if self.config.norm_location == "post":
            hidden_states = self.layernorm2(hidden_states)

        hidden_states = self.residual(res_connection2, hidden_states)

        # --- MLP ---
        res_connection3 = hidden_states

        if self.config.norm_location == "pre":
            hidden_states = self.layernorm3(hidden_states)

        hidden_states = self.mlp(hidden_states)

        if self.config.norm_location == "post":
            hidden_states = self.layernorm3(hidden_states)

        hidden_states = self.residual(res_connection3, hidden_states)

        return hidden_states

    def forward(self, *args, **kwargs):
        return self.checkpointing_level3(self._forward, *args, **kwargs)

In [None]:
test = TransformerDecoderBlock1D(dim=52, num_heads=4, mlp_ratio=2)
display(test)
q = torch.randn(2, 64, 52)
kv = torch.randn(2, 32, 52)  # different length from q to exercise real cross-attention
out = test(q, kv)
display(out.shape)
assert out.shape == q.shape, "decoder output must follow the query shape"


[1;35mTransformerDecoderBlock1D[0m[1m([0m
  [1m([0mattn1[1m)[0m: [1;35mAttention1D[0m[1m([0m
    [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mlayernorm1[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m52

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m52[0m[1m][0m[1m)[0m

In [None]:
# | export


class TransformerDecoderBlock3D(nn.Module):
    """3D transformer decoder block: self-attention, cross-attention and MLP sub-layers.

    Inputs and outputs are channels-first ``(b, dim, z, y, x)`` by default, or channels-last
    ``(b, z, y, x, dim)`` when ``channels_first=False``; computation runs channels-last.

    Args:
        config: ``Attention3DWithMLPConfig`` (or a dict of its fields); merged with ``kwargs``.
        checkpointing_level: activation-checkpointing level, shared with the attention and
            MLP sub-modules; the whole block is wrapped by the level-3 wrapper.
    """

    def __init__(self, config: Attention3DWithMLPConfig = {}, checkpointing_level: int = 0, **kwargs):
        super().__init__()

        self.config = Attention3DWithMLPConfig.model_validate(config | kwargs)

        dim_qk = self.config.dim_qk
        layer_norm_eps = self.config.layer_norm_eps

        # Sub-modules get the full config (as in Attention3DWithMLP) so every configured option
        # is honoured -- previously only a hand-picked subset was forwarded and e.g. the MLP
        # `activation` was silently ignored. They also get checkpointing_level so the lower
        # checkpointing levels apply inside this block too.
        self.attn1 = Attention3D(self.config, checkpointing_level=checkpointing_level)
        self.layernorm1 = nn.LayerNorm(dim_qk, eps=layer_norm_eps)
        self.attn2 = Attention3D(self.config, checkpointing_level=checkpointing_level)
        self.layernorm2 = nn.LayerNorm(dim_qk, eps=layer_norm_eps)
        self.mlp = Attention3DMLP(self.config, checkpointing_level=checkpointing_level)
        self.layernorm3 = nn.LayerNorm(dim_qk, eps=layer_norm_eps)

        self.residual = Residual()

        self.checkpointing_level3 = ActivationCheckpointing(3, checkpointing_level)

    def _forward(self, q: torch.Tensor, kv: torch.Tensor, channels_first: bool = True):
        # Each is (b, dim, z, y, x) if channels_first else (b, z, y, x, dim)

        if channels_first:
            q = rearrange(q, "b d z y x -> b z y x d").contiguous()
            kv = rearrange(kv, "b d z y x -> b z y x d").contiguous()
        # from here on: (b, tokens_z, tokens_y, tokens_x, dim)

        # --- self-attention over q ---
        res_connection1 = q

        if self.config.norm_location == "pre":
            q = self.layernorm1(q)
            # NOTE(review): the cross-attention memory kv is normalised with the self-attention
            # norm (layernorm1) rather than a norm of its own -- confirm this sharing is intended.
            kv = self.layernorm1(kv)

        hidden_states: torch.Tensor = self.attn1(q, q, q, channels_first=False)

        if self.config.norm_location == "post":
            hidden_states = self.layernorm1(hidden_states)

        hidden_states = self.residual(res_connection1, hidden_states)

        # --- cross-attention: queries from the decoder stream, keys/values from kv ---
        res_connection2 = hidden_states

        if self.config.norm_location == "pre":
            hidden_states = self.layernorm2(hidden_states)

        hidden_states = self.attn2(hidden_states, kv, kv, channels_first=False)

        if self.config.norm_location == "post":
            hidden_states = self.layernorm2(hidden_states)

        hidden_states = self.residual(res_connection2, hidden_states)

        # --- MLP ---
        res_connection3 = hidden_states

        if self.config.norm_location == "pre":
            hidden_states = self.layernorm3(hidden_states)

        hidden_states = self.mlp(hidden_states, channels_first=False)

        if self.config.norm_location == "post":
            hidden_states = self.layernorm3(hidden_states)

        hidden_states = self.residual(res_connection3, hidden_states)

        if channels_first:
            hidden_states = rearrange(hidden_states, "b z y x d -> b d z y x").contiguous()
            # (b, dim, tokens_z, tokens_y, tokens_x)

        return hidden_states

    def forward(self, *args, **kwargs):
        return self.checkpointing_level3(self._forward, *args, **kwargs)

In [None]:
test = TransformerDecoderBlock3D(dim=52, num_heads=4, mlp_ratio=2)
display(test)
q = torch.randn(2, 52, 4, 4, 4)  # channels-first
kv = torch.randn(2, 52, 4, 4, 4)
out = test(q, kv)
display(out.shape)
assert out.shape == q.shape, "decoder output must follow the query shape"


[1;35mTransformerDecoderBlock3D[0m[1m([0m
  [1m([0mattn1[1m)[0m: [1;35mAttention3D[0m[1m([0m
    [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mlayernorm1[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m52

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m52[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m][0m[1m)[0m

# nbdev

In [None]:
# Write every `# | export` cell of this notebook into the library module named by `default_exp`.
!nbdev_export