# Demonstration of `Attention` mechanism in `Transformer` model

This implementation depends only on PyTorch (`torch`) and follows the reference implementation of [Mistral-7B-v0.1](https://github.com/mistralai/mistral-src/blob/main/mistral/model.py).
The attention mechanism in this notebook is **Grouped-Query Attention**.
Thus, when `n_kv_heads` is 1, it is equivalent to the **Multi-Query Attention**; when `n_kv_heads` is equal to `n_heads`, it is equivalent to the **Multi-Head Attention**.

In [1]:
from dataclasses import dataclass
from simple_parsing.helpers import Serializable
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

In [2]:
# utility functions for Rotary Positional Embedding in the Attention layer
def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor:
    """Build the table of complex rotations used by rotary positional embedding.

    Args:
        dim: Size of the feature axis that will be rotated; consecutive pairs
            of features share one rotation frequency, so ``dim // 2``
            frequencies are produced.
        end: Number of positions to tabulate (positions ``0 .. end - 1``).
        theta: Base of the geometric progression of frequencies.

    Returns:
        A complex tensor of shape ``(end, dim // 2)`` whose entry ``[t, i]``
        is the unit complex number with angle ``t * theta ** (-2 * i / dim)``.
    """
    n_pairs = dim // 2
    # One exponent per feature pair: 0/dim, 2/dim, 4/dim, ...
    exponents = torch.arange(0, dim, 2)[:n_pairs].float() / dim
    inv_freq = 1.0 / (theta ** exponents)

    positions = torch.arange(end, device=inv_freq.device)
    # angles[t, i] = position t times frequency i
    angles = torch.outer(positions, inv_freq).float()

    # Unit magnitude, so only the phase carries information.
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)

def apply_rotary_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key features by position-dependent complex phases.

    The last axis of ``q`` and ``k`` is read as consecutive (real, imaginary)
    pairs, each pair is multiplied by the matching entry of ``freqs_cis``, and
    the result is flattened back to the original layout.

    Args:
        q: Query tensor whose last axis has even length.
        k: Key tensor whose last axis has even length.
        freqs_cis: Complex rotation table; a singleton axis is inserted at
            position 1 so it broadcasts across the heads axis of ``q``/``k``.

    Returns:
        ``(q_rotated, k_rotated)`` with the same shapes and dtypes as the
        inputs. The rotation itself is computed in float32.
    """
    # (seq, pairs) -> (seq, 1, pairs) so one table serves every head.
    rotation = freqs_cis[:, None, :]

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        # Pair up neighbouring features and view each pair as one complex number.
        pairs = t.float().reshape(*t.shape[:-1], -1, 2)
        rotated = torch.view_as_complex(pairs) * rotation
        # Back to interleaved real layout, then restore the caller's dtype.
        return torch.view_as_real(rotated).flatten(-2).to(t.dtype)

    return _rotate(q), _rotate(k)

In [3]:
# Smoke test for `apply_rotary_emb`: random q/k plus a matching rotation table.
# Only shapes are inspected; the values are random.
_q = torch.randn(2, 4, 3, 8) # (batch, seq_len, n_heads, head_dim)
_k = torch.randn(2, 4, 3, 8) # same layout as _q
_freqs_cis = precompute_freqs_cis(8, 4, 10) # (dim=head_dim, end=seq_len, theta)
# Shape is (seq_len, head_dim // 2) = (4, 4); not displayed by the notebook
# because it is not the last expression of the cell.
_freqs_cis.shape
q_out, k_out = apply_rotary_emb(_q, _k, _freqs_cis)
# The rotation must preserve the input shapes.
q_out.shape, k_out.shape

(torch.Size([2, 4, 3, 8]), torch.Size([2, 4, 3, 8]))

In [4]:
@dataclass
class AttnArgs(Serializable):
    """Hyperparameters for the attention demo in this notebook.

    Only some fields are read by the code visible here (`block_size`,
    `emb_dim`, `head_dim`, `n_heads`, `n_kv_heads`, `p_drop`, `rope_theta`);
    the remaining ones are presumably kept for parity with a full
    transformer configuration.
    """

    # Maximum sequence length; side length of the causal mask buffer.
    block_size: int
    # Width of the input/output token embeddings.
    emb_dim: int
    # Not read by the attention layer in this notebook; presumably the number
    # of transformer layers.
    n_layers: int
    # Feature size of each attention head.
    head_dim: int
    # Not read by the attention layer in this notebook; presumably the
    # feed-forward hidden width.
    hidden_dim: int
    # Number of query heads.
    n_heads: int
    # Number of key/value heads shared among the query heads.
    n_kv_heads: int
    # Not read by the attention layer in this notebook; presumably the epsilon
    # of a normalization layer.
    norm_eps: float
    # Dropout probability applied inside the attention layer.
    p_drop: float

    # Base passed as `theta` to `precompute_freqs_cis`; must be set to a
    # number before it is used there (the default None is not a valid base).
    rope_theta: Optional[float] = None

In [5]:
# Small dummy configuration used to exercise the attention layer below.
config = AttnArgs(
    block_size=4,  # longest sequence the causal mask covers
    emb_dim=8,
    n_layers=3,
    head_dim=6,  # even, as required by the pairwise rotary embedding
    hidden_dim=16,
    n_heads=9,  # 9 query heads ...
    n_kv_heads=3,  # ... share 3 key/value heads, i.e. 3 query heads per group
    norm_eps=1e-6,
    p_drop=0.1,
    rope_theta=10000,
)

Below we define an `Attention` layer with grouped-query attention mechanism.
The `n_heads` is the number of query heads, and the `n_kv_heads` is the number of key-value heads.
Thus, the number of groups is `n_kv_heads`, and each group contains `n_heads // n_kv_heads` query heads that share one key-value head.

In [6]:
class GroupedQueryAttention(nn.Module):
    """Causal self-attention with grouped key/value heads and rotary embeddings.

    Queries use ``n_heads`` heads while keys and values use ``n_kv_heads``
    heads; every key/value head is repeated ``n_heads // n_kv_heads`` times so
    that it is shared by a group of consecutive query heads.
    """

    def __init__(self, config: "AttnArgs") -> None:
        """Create the projections and the causal mask.

        Args:
            config: Object exposing ``n_heads``, ``n_kv_heads``, ``head_dim``,
                ``emb_dim``, ``block_size`` and ``p_drop``.

        Raises:
            ValueError: If ``n_heads`` is not a multiple of ``n_kv_heads``.
        """
        super().__init__()
        self.config = config

        self.n_heads: int = config.n_heads
        self.n_kv_heads: int = config.n_kv_heads
        self.head_dim: int = config.head_dim

        # Fail fast: with a non-divisible head count the repeated K/V tensors
        # would not line up with the query heads and the forward pass would
        # only fail later with an opaque shape error.
        if self.n_heads % self.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be divisible by "
                f"n_kv_heads ({self.n_kv_heads})"
            )

        # Number of query heads that share one key/value head.
        self.repeats = self.n_heads // self.n_kv_heads
        # Standard 1/sqrt(head_dim) scaling of the attention logits.
        self.scale = config.head_dim ** -0.5

        self.q_proj = nn.Linear(
            config.emb_dim, config.n_heads * config.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            config.emb_dim, config.n_kv_heads * config.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            config.emb_dim, config.n_kv_heads * config.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            config.n_heads * config.head_dim, config.emb_dim, bias=False
        )

        self.p_drop = config.p_drop

        # Lower-triangular mask of ones, shaped (1, 1, block_size, block_size)
        # so it broadcasts over batch and heads.
        self.register_buffer(
            "attn_bias",
            torch.tril(
                torch.ones(config.block_size, config.block_size)
            ).view(1, 1, config.block_size, config.block_size)
        )

    def forward(self, x: torch.Tensor, freq_cis: torch.Tensor) -> torch.Tensor:
        """Apply causal grouped-query attention to ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, emb_dim)``.
            freq_cis: Complex rotary table that `apply_rotary_emb` broadcasts
                against the per-head features, i.e. shape
                ``(seq_len, head_dim // 2)``.

        Returns:
            Tensor of shape ``(batch, seq_len, emb_dim)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the size of the causal mask
                (``block_size``).
        """
        B, T, _ = x.size()

        # Slicing the mask with T > block_size silently yields a smaller mask;
        # reject that instead of failing on (or, for block_size == 1, silently
        # broadcasting) a mask that does not cover the sequence.
        max_len = self.attn_bias.size(-1)
        if T > max_len:
            raise ValueError(
                f"sequence length ({T}) exceeds block_size ({max_len})"
            )

        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Split the projected features into heads: (B, T, heads, head_dim).
        q = q.view(B, T, self.n_heads, self.head_dim)
        k = k.view(B, T, self.n_kv_heads, self.head_dim)
        v = v.view(B, T, self.n_kv_heads, self.head_dim)

        # Rotary embedding is applied while the sequence axis is still axis 1.
        q, k = apply_rotary_emb(q, k, freq_cis)
        # (B, T, heads, head_dim) -> (B, heads, T, head_dim)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        # Repeat each K/V head so consecutive query heads share it.
        k = torch.repeat_interleave(k, repeats=self.repeats, dim=1)
        v = torch.repeat_interleave(v, repeats=self.repeats, dim=1)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Positions above the diagonal (the future) get -inf before softmax.
        attn = attn.masked_fill(
            self.attn_bias[:, :, :T, :T] == 0, float("-inf")
        )
        attn = attn.softmax(dim=-1)
        attn = F.dropout(attn, self.p_drop, training=self.training)
        # Merge heads back: (B, heads, T, head_dim) -> (B, T, heads * head_dim).
        y = (attn @ v).transpose(1, 2).contiguous().reshape(B, T, -1)
        y = F.dropout(y, self.p_drop, training=self.training)
        return self.o_proj(y)

In [7]:
# Instantiate the attention layer from the dummy configuration.
model = GroupedQueryAttention(config)

In [8]:
# Run one forward pass on random input and check the output shape.
# The module is in training mode by default, so dropout (p_drop) is active
# here; call `model.eval()` first for a deterministic result.
x = torch.randn(2, 4, 8)  # (batch, seq_len, emb_dim)
# Rotation table for 4 positions over the per-head dimension.
freqs_cis = precompute_freqs_cis(config.head_dim, 4, config.rope_theta)
y = model(x, freqs_cis)
# The layer maps back to emb_dim, so the output shape equals the input shape.
y.shape

torch.Size([2, 4, 8])