In [1]:
# | default_exp layers/attention

# Imports

In [2]:
# | export

from functools import partial
from typing import Literal

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

from vision_architectures.layers.embeddings import RelativePositionEmbeddings
from vision_architectures.utils.activations import get_act_layer

# Architecture

In [3]:
# | export


class Attention1D(nn.Module):
    """
    Performs attention (MHA, GQA, and MQA) on 1D sequences.

    Queries and keys are L2-normalized along the per-head dimension before the dot product, so the attention
    logits are cosine similarities multiplied by `logit_scale`.

    Parameters:
        - dim: Input dimensions of k, q, and v. Attention happens at dim=dim_qk.
            int: dimension of q, k, and v
            tuple[int, int]: first int is dimension of q and k, second int is dimension of v
        - num_q_heads: number of query heads. Must divide dim_qk and be a multiple of ratio_q_to_kv_heads
        - ratio_q_to_kv_heads: number of query heads per key and value head (1 => plain MHA)
        - relative_position_bias: RelativePositionEmbeddings object. It is called without arguments on every
            forward pass and the result is handed to the attention op as `attn_mask`
        - logit_scale: scale of the logits. Defaults to 1/sqrt(d), d being the per-head dimension. A provided
            value is stored as-is (an nn.Module is called without arguments on every forward pass)
        - logit_scale_learnable: whether the logit scale is learnable. Only used when logit_scale is None
        - attn_drop_prob: dropout probability for attention weights (only applied in training mode)
        - proj_drop_prob: dropout probability for projection layer
    """

    def __init__(
        self,
        dim: int | tuple[int, int],
        num_q_heads: int,
        ratio_q_to_kv_heads: int = 1,  # num q heads per kv head
        relative_position_bias: RelativePositionEmbeddings | None = None,
        logit_scale=None,
        logit_scale_learnable: bool = False,
        attn_drop_prob=0.0,
        proj_drop_prob=0.0,
    ):
        super().__init__()

        gqa_mqa_enabled = ratio_q_to_kv_heads != 1
        if gqa_mqa_enabled:
            assert torch.__version__ >= "2.5", "Need PyTorch version >= 2.5 for GQA and MQA"

        if isinstance(dim, int):
            dim_qk, dim_v = dim, dim
        else:
            dim_qk, dim_v = dim

        assert ratio_q_to_kv_heads >= 1, "ratio of query heads to key and value heads must be at least 1"
        assert dim_qk % num_q_heads == 0, "dimension must be divisible by number of heads"
        # Validate against the ratio itself: checking against the floored num_kv_heads would accept e.g.
        # 6 query heads with a ratio of 4, which later fails with a per-head dimension mismatch in forward().
        assert (
            num_q_heads % ratio_q_to_kv_heads == 0
        ), "number of query heads must be divisible by ratio_q_to_kv_heads"

        num_kv_heads = num_q_heads // ratio_q_to_kv_heads

        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.gqa_mqa_enabled = gqa_mqa_enabled

        self.per_head_dim = int(dim_qk // num_q_heads)

        # Every key / value head has the same width as a query head.
        # With the checks above this equals dim_qk // ratio_q_to_kv_heads.
        dim_kv_projected = self.per_head_dim * num_kv_heads

        self.W_q = nn.Linear(dim_qk, dim_qk)
        self.W_k = nn.Linear(dim_qk, dim_kv_projected)
        self.W_v = nn.Linear(dim_v, dim_kv_projected)
        self.attn_drop_prob = attn_drop_prob
        self.proj = nn.Linear(dim_qk, dim_qk)
        self.proj_drop = nn.Dropout(proj_drop_prob)

        if logit_scale is None:
            self.logit_scale = nn.Parameter(
                torch.tensor([self.per_head_dim**-0.5]), requires_grad=logit_scale_learnable
            )
        else:
            self.logit_scale = logit_scale

        self.relative_position_bias = relative_position_bias

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
        """
        Parameters: T => number of tokens, b => batch size
            - query: (b, T_q, dim_qk)
            - key: (b, T_kv, dim_qk)
            - value: (b, T_kv, dim_v)

        Returns:
            Tensor of shape (b, T_q, dim_qk)
        """

        query = self.W_q(query)
        key = self.W_k(key)
        value = self.W_v(value)

        rearrange_partial = partial(rearrange, pattern="b T (num_heads d) -> b num_heads T d")
        query = rearrange_partial(query, num_heads=self.num_q_heads)
        key = rearrange_partial(key, num_heads=self.num_kv_heads)
        value = rearrange_partial(value, num_heads=self.num_kv_heads)
        # query: (b, num_q_heads, T_q, per_head_dim)
        # key: (b, num_kv_heads, T_kv, per_head_dim)
        # value: (b, num_kv_heads, T_kv, per_head_dim)

        if isinstance(self.logit_scale, nn.Module):
            logit_scale = self.logit_scale()
        else:
            logit_scale = self.logit_scale

        query_normalized = F.normalize(query, dim=-1)
        key_normalized = F.normalize(key, dim=-1)

        query_normalized_and_scaled = query_normalized * logit_scale  # Scale the query beforehand

        relative_position_bias = None
        if self.relative_position_bias is not None:
            relative_position_bias = self.relative_position_bias()

        # `enable_gqa` only exists from PyTorch 2.5 onwards, so it is passed only when GQA / MQA is actually
        # requested; plain MHA then keeps working on the older versions the constructor still allows.
        sdpa_kwargs = {"enable_gqa": True} if self.gqa_mqa_enabled else {}

        output = F.scaled_dot_product_attention(
            query_normalized_and_scaled,
            key_normalized,
            value,
            attn_mask=relative_position_bias,  # Use this as a way to introduce relative position bias
            # The functional op applies dropout regardless of module mode, so disable it explicitly in eval
            dropout_p=self.attn_drop_prob if self.training else 0.0,
            is_causal=False,
            scale=1.0,  # Already scaled the vectors
            **sdpa_kwargs,
        )
        # (b, num_q_heads, T_q, per_head_dim)

        output = rearrange(output, "b num_heads T d -> b T (num_heads d)")
        # (b, T_q, dim_qk)

        output = self.proj(output)
        output = self.proj_drop(output)
        # (b, T_q, dim_qk)

        return output

In [4]:
# Smoke test for Attention1D: 6 query heads with 2 query heads per key/value head,
# q/k width 30 and value width 60, and a fixed logit scale of 4.0.
# Query has 64 tokens while key/value have 32; the displayed output shape follows the query.
test = Attention1D((30, 60), 6, ratio_q_to_kv_heads=2, logit_scale=4.0)
q = torch.randn(2, 64, 30)
k = torch.randn(2, 32, 30)
v = torch.randn(2, 32, 60)

display(test)
display(test(q, k, v).shape)


[1;35mAttention1D[0m[1m([0m
  [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m30[0m, [33mout_features[0m=[1;36m30[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m30[0m, [33mout_features[0m=[1;36m15[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m60[0m, [33mout_features[0m=[1;36m15[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m30[0m, [33mout_features[0m=[1;36m30[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m30[0m[1m][0m[1m)[0m

In [5]:
# | export


class Attention3D(Attention1D):
    """
    Attention1D applied to 3D volumes: the three spatial axes are flattened into one token axis, 1D attention
    is run on the resulting sequences, and the output is folded back to the spatial shape of the query.
    """

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, channels_first: bool = True):
        """
        Parameters: z => depth, y => height, x => width, b => batch size
            - query: (b, dim_qk, z_q, y_q, x_q) if channels_first else (b, z_q, y_q, x_q, dim_qk)
            - key: (b, dim_qk, z_k, y_k, x_k) if channels_first else (b, z_k, y_k, x_k, dim_qk)
            - value: (b, dim_v, z_k, y_k, x_k) if channels_first else (b, z_k, y_k, x_k, dim_v)
            - channels_first: if True, BCDHW expected, else BDHWC

        Returns:
            Tensor in the same layout as the inputs, with the spatial shape of the query.
        """

        # The spatial axes sit right after the channel axis (channels first) or right after the batch axis
        first_spatial_axis = 2 if channels_first else 1
        depth, height, width = query.shape[first_spatial_axis : first_spatial_axis + 3]

        if channels_first:
            to_tokens, to_volume = "b d z y x -> b (z y x) d", "b (z y x) d -> b d z y x"
        else:
            to_tokens, to_volume = "b z y x d -> b (z y x) d", "b (z y x) d -> b z y x d"

        query_tokens, key_tokens, value_tokens = (
            rearrange(tensor, to_tokens) for tensor in (query, key, value)
        )
        # each: (b, num_tokens, dim)

        attended = super().forward(query_tokens, key_tokens, value_tokens)
        # (b, z_q * y_q * x_q, dim_qk)

        return rearrange(attended, to_volume, z=depth, y=height, x=width)

In [6]:
# Smoke test for Attention3D in channels-first layout: the query volume is 4x4x4 while the
# key/value volumes are 2x4x4, so the token counts differ between query and key/value.
# The displayed output shape matches the query volume.
test = Attention3D((30, 60), 6, ratio_q_to_kv_heads=2, logit_scale=4.0)
q = torch.randn(2, 30, 4, 4, 4)
k = torch.randn(2, 30, 2, 4, 4)
v = torch.randn(2, 60, 2, 4, 4)

display(test)
display(test(q, k, v, channels_first=True).shape)


[1;35mAttention3D[0m[1m([0m
  [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m30[0m, [33mout_features[0m=[1;36m30[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m30[0m, [33mout_features[0m=[1;36m15[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m60[0m, [33mout_features[0m=[1;36m15[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m30[0m, [33mout_features[0m=[1;36m30[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m30[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m][0m[1m)[0m

In [7]:
# | export


class Attention1DMLP(nn.Module):
    """
    Two-layer feed-forward block: Linear(dim -> dim * mlp_ratio), activation, Linear(dim * mlp_ratio -> dim),
    followed by dropout.
    Parameters:
        - dim: size of the last axis of the input and of the output
        - mlp_ratio: multiplier applied to dim to get the hidden width
        - activation: an nn.Module used as-is, otherwise a value forwarded to get_act_layer
        - mlp_drop_prob: dropout probability applied to the output
    """

    def __init__(self, dim, mlp_ratio, activation='gelu', mlp_drop_prob=0.0):
        super().__init__()

        hidden_dim = dim * mlp_ratio

        self.dense1 = nn.Linear(dim, hidden_dim)
        self.act = activation if isinstance(activation, nn.Module) else get_act_layer(activation)
        self.dense2 = nn.Linear(hidden_dim, dim)
        self.dropout = nn.Dropout(mlp_drop_prob)

    def forward(self, hidden_states: torch.Tensor):
        """
        Parameters:
            - hidden_states: (b, T, dim)

        Returns:
            Tensor of shape (b, T, dim)
        """
        expanded = self.act(self.dense1(hidden_states))
        # (b, T, dim * mlp_ratio)

        projected = self.dense2(expanded)
        # (b, T, dim)

        return self.dropout(projected)

In [8]:
# Smoke test for Attention1DMLP: width 64, hidden width 64 * 4, ReLU activation, dropout 0.2.
# The displayed output shape equals the input shape.
test = Attention1DMLP(64, 4, "relu", 0.2)
display(test)
display(test(torch.randn(2, 28, 64)).shape)


[1;35mAttention1DMLP[0m[1m([0m
  [1m([0mdense1[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m256[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
  [1m([0mdense2[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m256[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mdropout[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.2[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m28[0m, [1;36m64[0m[1m][0m[1m)[0m

In [9]:
# | export


class Attention3DMLP(Attention1DMLP):
    """
    Attention1DMLP for 3D volumes. The parent block acts on the last axis, so channels-first inputs are moved
    to channels-last for the computation and moved back afterwards.
    """

    def forward(self, hidden_states: torch.Tensor, channels_first: bool = True):
        """
        Parameters:
            - hidden_states: (b, dim, z, y, x) if channels_first else (b, z, y, x, dim)
            - channels_first: layout of hidden_states

        Returns:
            Tensor in the same layout and shape as hidden_states
        """

        if not channels_first:
            # Already channels-last: the parent block can be applied directly
            return super().forward(hidden_states)

        channels_last = rearrange(hidden_states, "b d z y x -> b z y x d")
        transformed = super().forward(channels_last)
        return rearrange(transformed, "b z y x d -> b d z y x")

In [10]:
# Smoke test for Attention3DMLP with a channels-first (b, dim, z, y, x) input
# (channels_first is left at its default of True); the displayed output shape equals the input shape.
test = Attention3DMLP(64, 4, "relu", 0.2)
display(test)
display(test(torch.randn(2, 64, 4, 4, 4)).shape)


[1;35mAttention3DMLP[0m[1m([0m
  [1m([0mdense1[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m64[0m, [33mout_features[0m=[1;36m256[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
  [1m([0mdense2[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m256[0m, [33mout_features[0m=[1;36m64[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mdropout[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.2[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m][0m[1m)[0m

In [11]:
# | export


class Attention1DWithMLP(nn.Module):
    """
    Attention1D followed by Attention1DMLP, each wrapped with a LayerNorm and a residual connection.
    Parameters:
        - dim: int for a shared q / k / v dimension, or (dim_qk, dim_v)
        - num_q_heads: number of query heads
        - ratio_q_to_kv_heads: number of query heads per key and value head
        - mlp_ratio: hidden width multiplier of the MLP
        - qkv_relative_position_bias: forwarded to Attention1D as relative_position_bias
        - qk_scale: forwarded to Attention1D as logit_scale
        - qk_scale_learnable: forwarded to Attention1D as logit_scale_learnable
        - activation: activation of the MLP
        - norm_location: "pre" normalizes the inputs of the attention and of the MLP, "post" normalizes
            their outputs before the residual is added
        - layer_norm_eps: eps of the LayerNorm layers
        - attn_drop_prob: dropout probability for attention weights
        - proj_drop_prob: dropout probability for the attention output projection
        - mlp_drop_prob: dropout probability for the MLP output
    """

    def __init__(
        self,
        dim: int | tuple[int, int],
        num_q_heads: int,
        ratio_q_to_kv_heads: int = 1,
        mlp_ratio: int = 4,
        qkv_relative_position_bias=None,
        qk_scale: float = None,
        qk_scale_learnable: bool = False,
        activation="gelu",
        norm_location: Literal["pre", "post"] = "post",
        layer_norm_eps: float = 1e-6,
        attn_drop_prob: float = 0.0,
        proj_drop_prob: float = 0.0,
        mlp_drop_prob: float = 0.0,
    ):
        super().__init__()

        self.norm_location = norm_location

        if isinstance(dim, int):
            dim_qk, dim_v = dim, dim
        else:
            dim_qk, dim_v = dim

        self.attn = Attention1D(
            dim=dim,
            num_q_heads=num_q_heads,
            ratio_q_to_kv_heads=ratio_q_to_kv_heads,
            relative_position_bias=qkv_relative_position_bias,
            logit_scale=qk_scale,
            logit_scale_learnable=qk_scale_learnable,
            attn_drop_prob=attn_drop_prob,
            proj_drop_prob=proj_drop_prob,
        )
        self.layernorm1 = nn.LayerNorm(dim_qk, eps=layer_norm_eps)

        # layernorm1 is sized for dim_qk, so it cannot pre-normalize a value tensor of a different width.
        # A dedicated norm is created only for that (previously failing) configuration; in every other case
        # the attribute is None and layernorm1 keeps being shared, leaving the set of parameters unchanged.
        self.layernorm1_value = None
        if norm_location == "pre" and dim_v != dim_qk:
            self.layernorm1_value = nn.LayerNorm(dim_v, eps=layer_norm_eps)

        self.mlp = Attention1DMLP(
            dim_qk, mlp_ratio=mlp_ratio, activation=activation, mlp_drop_prob=mlp_drop_prob
        )
        self.layernorm2 = nn.LayerNorm(dim_qk, eps=layer_norm_eps)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
        """
        Parameters: T => number of tokens, b => batch size
            - query: (b, T_q, dim_qk)
            - key: (b, T_kv, dim_qk)
            - value: (b, T_kv, dim_v)

        Returns:
            Tensor of shape (b, T_q, dim_qk)
        """
        res_connection1 = query
        # (b, T_q, dim_qk)

        if self.norm_location == "pre":
            value_norm = self.layernorm1 if self.layernorm1_value is None else self.layernorm1_value
            query = self.layernorm1(query)
            key = self.layernorm1(key)
            value = value_norm(value)
            # shapes unchanged

        hidden_states = self.attn(query, key, value)
        # (b, T_q, dim_qk)

        if self.norm_location == "post":
            hidden_states = self.layernorm1(hidden_states)
            # (b, T_q, dim_qk)

        hidden_states = hidden_states + res_connection1
        res_connection2 = hidden_states
        # (b, T_q, dim_qk)

        if self.norm_location == "pre":
            hidden_states = self.layernorm2(hidden_states)
            # (b, T_q, dim_qk)

        hidden_states = self.mlp(hidden_states)
        # (b, T_q, dim_qk)

        if self.norm_location == "post":
            hidden_states = self.layernorm2(hidden_states)
            # (b, T_q, dim_qk)

        hidden_states = hidden_states + res_connection2
        # (b, T_q, dim_qk)

        return hidden_states

In [12]:
# Smoke test for Attention1DWithMLP with default settings (width 54, 3 heads), used as
# self-attention: the same tensor is passed as query, key and value.
test = Attention1DWithMLP(54, 3)
output = torch.randn(2, 64, 54)

display(test)
display(test(output, output, output).shape)


[1;35mAttention1DWithMLP[0m[1m([0m
  [1m([0mattn[1m)[0m: [1;35mAttention1D[0m[1m([0m
    [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mlayernorm1[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m54[0m,[1m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m

In [13]:
# | export


class Attention3DWithMLP(nn.Module):
    """
    Attention3D followed by Attention3DMLP, each wrapped with a LayerNorm and a residual connection.
    Internally everything is computed in channels-last layout.
    Parameters:
        - dim: int for a shared q / k / v dimension, or (dim_qk, dim_v)
        - num_q_heads: number of query heads
        - ratio_q_to_kv_heads: number of query heads per key and value head
        - mlp_ratio: hidden width multiplier of the MLP
        - qkv_relative_position_bias: forwarded to Attention3D as relative_position_bias
        - qk_scale: forwarded to Attention3D as logit_scale
        - qk_scale_learnable: forwarded to Attention3D as logit_scale_learnable
        - activation: activation of the MLP
        - norm_location: "pre" normalizes the inputs of the attention and of the MLP, "post" normalizes
            their outputs before the residual is added
        - layer_norm_eps: eps of the LayerNorm layers
        - attn_drop_prob: dropout probability for attention weights
        - proj_drop_prob: dropout probability for the attention output projection
        - mlp_drop_prob: dropout probability for the MLP output
    """

    def __init__(
        self,
        dim: int | tuple[int, int],
        num_q_heads: int,
        ratio_q_to_kv_heads: int = 1,
        mlp_ratio: int = 4,
        qkv_relative_position_bias=None,
        qk_scale: float = None,
        qk_scale_learnable: bool = False,
        activation="gelu",
        norm_location: Literal["pre", "post"] = "post",
        layer_norm_eps: float = 1e-6,
        attn_drop_prob: float = 0.0,
        proj_drop_prob: float = 0.0,
        mlp_drop_prob: float = 0.0,
    ):
        super().__init__()

        self.norm_location = norm_location

        if isinstance(dim, int):
            dim_qk, dim_v = dim, dim
        else:
            dim_qk, dim_v = dim

        self.attn = Attention3D(
            dim=dim,
            num_q_heads=num_q_heads,
            ratio_q_to_kv_heads=ratio_q_to_kv_heads,
            relative_position_bias=qkv_relative_position_bias,
            logit_scale=qk_scale,
            logit_scale_learnable=qk_scale_learnable,
            attn_drop_prob=attn_drop_prob,
            proj_drop_prob=proj_drop_prob,
        )
        self.layernorm1 = nn.LayerNorm(dim_qk, eps=layer_norm_eps)

        # layernorm1 is sized for dim_qk, so it cannot pre-normalize a value tensor of a different width.
        # A dedicated norm is created only for that (previously failing) configuration; in every other case
        # the attribute is None and layernorm1 keeps being shared, leaving the set of parameters unchanged.
        self.layernorm1_value = None
        if norm_location == "pre" and dim_v != dim_qk:
            self.layernorm1_value = nn.LayerNorm(dim_v, eps=layer_norm_eps)

        self.mlp = Attention3DMLP(
            dim_qk, mlp_ratio=mlp_ratio, activation=activation, mlp_drop_prob=mlp_drop_prob
        )
        self.layernorm2 = nn.LayerNorm(dim_qk, eps=layer_norm_eps)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, channels_first: bool = True):
        """
        Parameters: z => depth, y => height, x => width, b => batch size
            - query: (b, dim_qk, z_q, y_q, x_q) if channels_first else (b, z_q, y_q, x_q, dim_qk)
            - key: (b, dim_qk, z_k, y_k, x_k) if channels_first else (b, z_k, y_k, x_k, dim_qk)
            - value: (b, dim_v, z_k, y_k, x_k) if channels_first else (b, z_k, y_k, x_k, dim_v)
            - channels_first: if True, BCDHW expected, else BDHWC

        Returns:
            Tensor in the same layout as the inputs, with the spatial shape of the query and dim_qk channels
        """

        if channels_first:
            query = rearrange(query, "b d z y x -> b z y x d")
            key = rearrange(key, "b d z y x -> b z y x d")
            value = rearrange(value, "b d z y x -> b z y x d")
            # (b, tokens_z, tokens_y, tokens_x, dim)

        res_connection1 = query
        # (b, tokens_z, tokens_y, tokens_x, dim_qk)

        if self.norm_location == "pre":
            value_norm = self.layernorm1 if self.layernorm1_value is None else self.layernorm1_value
            query = self.layernorm1(query)
            key = self.layernorm1(key)
            value = value_norm(value)
            # shapes unchanged

        hidden_states = self.attn(query, key, value, channels_first=False)
        # (b, tokens_z, tokens_y, tokens_x, dim_qk)

        if self.norm_location == "post":
            hidden_states = self.layernorm1(hidden_states)
            # (b, tokens_z, tokens_y, tokens_x, dim_qk)

        hidden_states = hidden_states + res_connection1
        res_connection2 = hidden_states
        # (b, tokens_z, tokens_y, tokens_x, dim_qk)

        if self.norm_location == "pre":
            hidden_states = self.layernorm2(hidden_states)
            # (b, tokens_z, tokens_y, tokens_x, dim_qk)

        hidden_states = self.mlp(hidden_states, channels_first=False)
        # (b, tokens_z, tokens_y, tokens_x, dim_qk)

        if self.norm_location == "post":
            hidden_states = self.layernorm2(hidden_states)
            # (b, tokens_z, tokens_y, tokens_x, dim_qk)

        hidden_states = hidden_states + res_connection2
        # (b, tokens_z, tokens_y, tokens_x, dim_qk)

        if channels_first:
            hidden_states = rearrange(hidden_states, "b z y x d -> b d z y x")
            # (b, dim_qk, tokens_z, tokens_y, tokens_x)

        return hidden_states

In [14]:
# Smoke test for Attention3DWithMLP with default settings (width 54, 3 heads), used as
# self-attention on a channels-first 4x4x4 volume: the same tensor is query, key and value.
test = Attention3DWithMLP(54, 3)
output = torch.randn(2, 54, 4, 4, 4)

display(test)
display(test(output, output, output, channels_first=True).shape)


[1;35mAttention3DWithMLP[0m[1m([0m
  [1m([0mattn[1m)[0m: [1;35mAttention3D[0m[1m([0m
    [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mlayernorm1[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m54[0m,[1m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m54[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m][0m[1m)[0m

# nbdev

In [15]:
# Runs the nbdev_export shell command; presumably this writes the cells marked `# | export`
# to the module named by the `default_exp` directive in the first cell (confirm against the nbdev setup).
!nbdev_export