# Multi-Query Attention

In [57]:
import math
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

In [58]:
# Toy hyper-parameters shared by the attention variants in this notebook.
batch_size, block_size = 4, 8  # sequences per batch, tokens per sequence
n_embed, n_head = 48, 6  # embedding width, number of query heads
dropout, bias = 0.2, False

# Random float32 input of shape (batch_size, block_size, n_embed), drawn
# uniformly from [0, 1) by NumPy and then cast down from float64.
x = torch.tensor(np.random.rand(batch_size, block_size, n_embed)).float()

## LL1

In [56]:
class MultiHeadAttention_LL1(nn.Module):
    """Causal self-attention whose key/value heads are shared by query heads.

    The ``n_head`` query heads are split into ``n_query_groups`` groups and
    every group shares one key head and one value head.
    ``n_query_groups == 1`` is multi-query attention and
    ``n_query_groups == n_head`` is ordinary multi-head attention.

    Args:
        block_size: Longest sequence length the causal mask is built for.
        n_embed: Embedding width of the input and of the output.
        n_head: Number of query heads; must divide ``n_embed``.
        n_query_groups: Number of key/value heads; must divide ``n_head``.
        dropout: Dropout probability applied to the attention weights and to
            the projected output.
        bias: Whether the two linear layers carry a bias term.
    """

    def __init__(
        self,
        block_size: int,
        n_embed: int,
        n_head: int,
        n_query_groups: int,
        dropout: float,
        bias: bool,
    ):
        super().__init__()
        assert n_embed % n_head == 0
        # The fused projection is laid out per group as
        # (q_per_kv queries, 1 key, 1 value), which only tiles evenly when
        # the query heads split evenly across the groups.
        assert n_head % n_query_groups == 0, (
            "n_head must be divisible by n_query_groups"
        )
        self.block_size = block_size
        self.n_embed = n_embed
        self.n_head = n_head
        self.n_query_groups = n_query_groups
        self.hs = n_embed // n_head  # size of a single head
        self.dropout = dropout
        self.bias = bias

        # One fused projection producing n_head query heads plus
        # n_query_groups key heads and n_query_groups value heads.
        self.c_attn = nn.Linear(
            n_embed, (n_head + 2 * n_query_groups) * self.hs, bias=bias
        )
        self.c_proj = nn.Linear(n_embed, n_embed, bias=bias)
        self.dropout_attn = nn.Dropout(dropout)
        self.dropout_residual = nn.Dropout(dropout)

        # Lower-triangular matrix of ones, shaped (1, 1, block_size,
        # block_size) so it broadcasts over the batch and head dimensions.
        ltm = torch.tril(torch.ones(block_size, block_size)).view(
            1, 1, block_size, block_size
        )
        self.register_buffer("causal_mask", ltm)

    def forward(self, x) -> torch.Tensor:
        """Apply causal attention to ``x``.

        Args:
            x: Tensor of shape ``(B, T, C)`` with ``T <= block_size`` and
                ``C == n_embed``.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.size()
        assert T <= self.block_size, "sequence is longer than block_size"

        q_per_kv = self.n_head // self.n_query_groups
        qkv = self.c_attn(x)
        qkv = qkv.view(B, T, self.n_query_groups, q_per_kv + 2, self.hs)
        qkv = qkv.permute(0, 2, 3, 1, 4)  # B, n_query_groups, q_per_kv + 2, T, hs
        q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)

        q = q.reshape(B, -1, T, self.hs)  # B, n_head, T, hs
        k = k.reshape(B, -1, T, self.hs)  # B, n_query_groups, T, hs
        v = v.reshape(B, -1, T, self.hs)  # B, n_query_groups, T, hs
        if q_per_kv > 1:
            # Repeat each shared key/value head once per query head of its
            # group so head i of q lines up with head i of k and v.
            k = k.repeat_interleave(q_per_kv, dim=1)
            v = v.repeat_interleave(q_per_kv, dim=1)

        # Scale the logits by 1/sqrt(head_size) before the softmax.
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.hs))
        # Slice the mask to the actual sequence length; the full
        # block_size x block_size mask only broadcasts when T == block_size.
        attn = attn.masked_fill(
            self.causal_mask[:, :, :T, :T] == 0, float("-inf")
        )
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout_attn(attn)

        y = attn @ v  # B, n_head, T, hs
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # merge the heads
        y = self.dropout_residual(self.c_proj(y))
        return y


# Smoke test: a single key/value group (n_query_groups=1) is multi-query
# attention. The cell displays the shape of the layer's output for x.
MultiHeadAttention_LL1(
    block_size, n_embed, n_head, n_query_groups=1, dropout=0.2, bias=False
)(x).shape

torch.Size([4, 8, 48])

## LL2 

In [66]:
class CausalAttention_LL2(nn.Module):
    """Causal grouped-query self-attention built on the fused PyTorch kernel.

    The ``n_head`` query heads are split into ``n_query_groups`` groups that
    each share one key head and one value head (``n_query_groups == 1`` is
    multi-query attention, ``n_query_groups == n_head`` is multi-head
    attention).

    Args:
        block_size: Side length of the registered ``causal_mask`` buffer.
        n_embed: Embedding width of the input and of the output.
        n_head: Number of query heads; must divide ``n_embed``.
        dropout: Dropout probability for the attention weights (training
            only) and for the projected output.
        bias: Whether the two linear layers carry a bias term.
        n_query_groups: Number of key/value heads; must divide ``n_head``.
    """

    def __init__(
        self,
        block_size: int,
        n_embed: int,
        n_head: int,
        dropout: float,
        bias: bool,
        n_query_groups: int = 1,
    ):
        super().__init__()
        assert n_embed % n_head == 0
        # The fused projection is laid out per group as
        # (q_per_kv queries, 1 key, 1 value), so the groups must tile evenly.
        assert n_head % n_query_groups == 0, (
            "n_head must be divisible by n_query_groups"
        )

        self.block_size = block_size
        self.n_embed = n_embed
        self.dropout = dropout
        self.bias = bias
        self.n_head = n_head
        self.hs = n_embed // n_head  # size of a single head
        self.n_query_groups = n_query_groups

        # n_head query heads + n_query_groups key heads + as many value heads.
        shape = (n_head + 2 * n_query_groups) * self.hs

        self.c_attn = nn.Linear(n_embed, shape, bias=bias)
        self.c_proj = nn.Linear(n_embed, n_embed, bias=bias)
        # forward() hands attention dropout to scaled_dot_product_attention
        # and does not call this module; it is left in place so the module's
        # attributes stay as they were.
        self.dropout_attn = nn.Dropout(dropout)
        self.dropout_residual = nn.Dropout(dropout)

        # forward() relies on is_causal=True rather than this buffer; it is
        # still registered so existing state_dicts keep loading.
        ltm = torch.tril(torch.ones(block_size, block_size)).view(
            1, 1, block_size, block_size
        )
        self.register_buffer("causal_mask", ltm)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal attention to ``x``.

        Args:
            x: Tensor of shape ``(B, T, C)`` with ``C == n_embed``.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.size()
        assert C == self.n_embed

        q_per_kv = self.n_head // self.n_query_groups
        qkv = self.c_attn(x)
        qkv = qkv.view(B, T, self.n_query_groups, q_per_kv + 2, self.hs)
        qkv = qkv.permute(0, 2, 3, 1, 4)  # B, n_query_groups, q_per_kv + 2, T, hs
        q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)

        q = q.reshape(B, -1, T, self.hs)  # B, n_head, T, hs
        k = k.reshape(B, -1, T, self.hs)  # B, n_query_groups, T, hs
        v = v.reshape(B, -1, T, self.hs)  # B, n_query_groups, T, hs
        if q_per_kv > 1:
            # Repeat each shared key/value head once per query head of its
            # group so q, k and v all carry n_head heads.
            k = k.repeat_interleave(q_per_kv, dim=1)
            v = v.repeat_interleave(q_per_kv, dim=1)

        # scaled_dot_product_attention already returns softmax(QK^T / sqrt(hs)) V
        # (its default scale is 1/sqrt(head_size)), so the result must NOT be
        # multiplied by v a second time.
        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            is_causal=True,
            dropout_p=self.dropout if self.training else 0.0,
        )  # B, n_head, T, hs

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # merge the heads
        y = self.dropout_residual(self.c_proj(y))

        return y


# Smoke test with the default n_query_groups; the cell displays the shape of
# the layer's output for x.
CausalAttention_LL2(
    block_size,
    n_embed,
    n_head,
    dropout=0.2,
    bias=False,
)(x).shape

torch.Size([4, 8, 48])

## LL2.2

In [91]:
class CausalAttention_LL22(nn.Module):
    """Causal grouped-query self-attention with an explicit masked softmax.

    The ``n_head`` query heads are split into ``n_query_groups`` groups that
    each share one key head and one value head (``n_query_groups == 1`` is
    multi-query attention, ``n_query_groups == n_head`` is multi-head
    attention).

    Args:
        block_size: Longest sequence length the causal mask is built for.
        n_embed: Embedding width of the input and of the output.
        n_head: Number of query heads; must divide ``n_embed``.
        dropout: Dropout probability applied to the attention weights and to
            the projected output.
        bias: Whether the two linear layers carry a bias term.
        n_query_groups: Number of key/value heads; must divide ``n_head``.
    """

    def __init__(
        self,
        block_size: int,
        n_embed: int,
        n_head: int,
        dropout: float,
        bias: bool,
        n_query_groups: int = 1,
    ):
        super().__init__()
        assert n_embed % n_head == 0
        # The fused projection is laid out per group as
        # (q_per_kv queries, 1 key, 1 value), so the groups must tile evenly.
        assert n_head % n_query_groups == 0, (
            "n_head must be divisible by n_query_groups"
        )

        self.block_size = block_size
        self.n_embed = n_embed
        self.n_head = n_head
        self.dropout = dropout
        self.hs = n_embed // n_head  # size of a single head
        self.n_query_groups = n_query_groups

        # n_head query heads + n_query_groups key heads + as many value heads.
        shape = (self.n_head + 2 * self.n_query_groups) * self.hs
        self.c_attn = nn.Linear(n_embed, shape, bias=bias)
        self.c_proj = nn.Linear(n_embed, n_embed, bias=bias)
        self.dropout_attn = nn.Dropout(dropout)
        self.dropout_residual = nn.Dropout(dropout)

        # Lower-triangular matrix of ones, shaped (1, 1, block_size,
        # block_size) so it broadcasts over the batch and head dimensions.
        ltm = torch.tril(torch.ones(block_size, block_size)).view(
            1, 1, block_size, block_size
        )
        self.register_buffer("causal_mask", ltm)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal attention to ``x``.

        Args:
            x: Tensor of shape ``(B, T, C)`` with ``T <= block_size`` and
                ``C == n_embed``.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.size()
        assert C == self.n_embed
        assert T <= self.block_size, "sequence is longer than block_size"

        q_per_kv = self.n_head // self.n_query_groups
        qkv = self.c_attn(x)
        qkv = qkv.view(B, T, self.n_query_groups, q_per_kv + 2, self.hs)
        qkv = qkv.permute(0, 2, 3, 1, 4)  # B, n_query_groups, q_per_kv + 2, T, hs

        q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
        q = q.reshape(B, -1, T, self.hs)  # B, n_head, T, hs
        k = k.reshape(B, -1, T, self.hs)  # B, n_query_groups, T, hs
        v = v.reshape(B, -1, T, self.hs)  # B, n_query_groups, T, hs
        if q_per_kv > 1:
            # Repeat each shared key/value head once per query head of its
            # group so head i of q lines up with head i of k and v.
            k = k.repeat_interleave(q_per_kv, dim=1)
            v = v.repeat_interleave(q_per_kv, dim=1)

        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.hs))
        # Slice the mask to the actual sequence length; the full
        # block_size x block_size mask only broadcasts when T == block_size.
        attn = attn.masked_fill(
            self.causal_mask[:, :, :T, :T] == 0, float("-inf")
        )
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout_attn(attn)

        y = attn @ v  # B, n_head, T, hs
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # merge the heads
        y = self.dropout_residual(self.c_proj(y))

        return y


# Smoke test with the default n_query_groups; the cell displays the shape of
# the layer's output for x.
CausalAttention_LL22(
    block_size,
    n_embed,
    n_head,
    dropout=0.2,
    bias=False,
)(x).shape

torch.Size([4, 8, 48])

## LL29

In [107]:
class CausalAttention_LL29(nn.Module):
    """Causal grouped-query self-attention built on the fused PyTorch kernel.

    The ``n_head`` query heads are split into ``n_query_groups`` groups that
    each share one key head and one value head (``n_query_groups == 1`` is
    multi-query attention, ``n_query_groups == n_head`` is multi-head
    attention). ``forward`` prints the q/k/v shapes as a tracing aid.

    Args:
        block_size: Side length of the registered ``causal_mask`` buffer.
        n_embed: Embedding width of the input and of the output.
        n_head: Number of query heads; must divide ``n_embed``.
        dropout: Dropout probability for the attention weights (training
            only) and for the projected output.
        bias: Whether the two linear layers carry a bias term.
        n_query_groups: Number of key/value heads; must divide ``n_head``.
    """

    def __init__(
        self,
        block_size: int,
        n_embed: int,
        n_head: int,
        dropout: float,
        bias: bool,
        n_query_groups: int = 1,
    ):
        super().__init__()
        assert n_embed % n_head == 0
        # The fused projection is laid out per group as
        # (q_per_kv queries, 1 key, 1 value), so the groups must tile evenly.
        assert n_head % n_query_groups == 0, (
            "n_head must be divisible by n_query_groups"
        )

        self.block_size = block_size
        self.hs = n_embed // n_head  # size of a single head
        self.n_embed = n_embed
        self.dropout = dropout
        self.n_query_groups = n_query_groups
        self.n_head = n_head

        # n_head query heads + n_query_groups key heads + as many value heads.
        shape = (self.n_head + 2 * self.n_query_groups) * self.hs
        self.c_attn = nn.Linear(n_embed, shape, bias=bias)
        self.c_proj = nn.Linear(n_embed, n_embed, bias=bias)
        # forward() hands attention dropout to scaled_dot_product_attention
        # and does not call this module; it is left in place so the module's
        # attributes stay as they were.
        self.dropout_attn = nn.Dropout(dropout)
        self.dropout_residual = nn.Dropout(dropout)

        # forward() relies on is_causal=True rather than this buffer; it is
        # still registered so existing state_dicts keep loading.
        ltm = torch.tril(torch.ones(block_size, block_size)).view(
            1, 1, block_size, block_size
        )
        self.register_buffer("causal_mask", ltm)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal attention to ``x``.

        Args:
            x: Tensor of shape ``(B, T, C)`` with ``C == n_embed``.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.size()
        assert C == self.n_embed

        q_per_kv = self.n_head // self.n_query_groups
        qkv = self.c_attn(x)
        qkv = qkv.view(B, T, self.n_query_groups, q_per_kv + 2, self.hs)
        qkv = qkv.permute(0, 2, 3, 1, 4)  # B, n_query_groups, q_per_kv + 2, T, hs
        q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)

        q = q.reshape(B, -1, T, self.hs)  # B, n_head, T, hs
        k = k.reshape(B, -1, T, self.hs)  # B, n_query_groups, T, hs
        v = v.reshape(B, -1, T, self.hs)  # B, n_query_groups, T, hs
        # Trace the shapes before k/v are repeated, i.e. as produced by the
        # fused projection.
        print(f"{q.size()=} {k.size()=} {v.size()=}")

        if q_per_kv > 1:
            # Repeat each shared key/value head once per query head of its
            # group so q, k and v all carry n_head heads.
            k = k.repeat_interleave(q_per_kv, dim=1)
            v = v.repeat_interleave(q_per_kv, dim=1)

        # scaled_dot_product_attention already returns softmax(QK^T / sqrt(hs)) V
        # (its default scale is 1/sqrt(head_size)), so the result must NOT be
        # multiplied by v a second time.
        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            is_causal=True,
            dropout_p=self.dropout if self.training else 0.0,
        )  # B, n_head, T, hs

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # merge the heads
        y = self.dropout_residual(self.c_proj(y))

        return y


# Smoke test with a single key/value group; forward() prints the q/k/v shapes
# and the cell displays the shape of the layer's output for x.
CausalAttention_LL29(
    block_size, n_embed, n_head, dropout=0.2, bias=False, n_query_groups=1
)(x).shape

q.size()=torch.Size([4, 6, 8, 8]) k.size()=torch.Size([4, 1, 8, 8]) v.size()=torch.Size([4, 1, 8, 8])


torch.Size([4, 8, 48])