#### What is Multi-head Latent Attention

I won't go over what DeepSeek is; you can read about that on their website https://deepseek.ai/. What caught my attention was the Multi-Head Latent Attention used in the model in place of the traditional attention mechanism.

Multi-Head Latent Attention (MLA) is a key innovation in DeepSeek V2. In traditional multi-head attention (MHA), the full keys and values of every token have to be cached during generation. MLA instead compresses keys and values jointly into one small latent vector per token, caches only that vector (plus a small shared key that carries the positional encoding), and projects it back up to per-head keys and values when attention is computed. Queries go through a similar low-rank compression, which does not shrink the cache but reduces activation memory during training. According to the paper, MLA performs better than MHA while needing a significantly smaller KV cache, which is what makes inference with long contexts cheaper in memory.

In this post I take a stab at implementing MLA in PyTorch. It took me some time to decipher the maths and the tensor shapes, but I think I have a good grasp of it now. I refer throughout to the [DeepSeek-V2 paper](https://arxiv.org/pdf/2405.04434) from DeepSeek AI.

The following images from the paper illustrate MLA and the math behind it:

![Multi-head latent attention](../../images/mla.JPG)

![maths behind MHLA](../../images/mla-formulas.JPG)


#### Notes on symbols and variables used in the DeepSeek-V2 paper

All dimensions below are for a single token; the subscript t denotes token t. The batch and sequence-length dimensions are omitted for simplicity.

| symbol | dimensions | remarks |
|--------|----------|-----------|
| $h_t$  | d | attention input for token t; d is the model (embedding) dimension, often called `d_model` in transformer code. This post assumes d = $d_h$ x $n_h$ |
| $d_h$  | scalar | number of dimensions per attention head |
| $n_h$  | scalar | number of attention heads |
| $d_c$  | scalar | KV compression dimension |
| $d_c^{'}$ | scalar | query compression dimension, where $d_c^{'}$ (<<$d_h$ x $n_h$) |
| $d_h^R$ | scalar | per-head dimension of the decoupled queries and key that carry RoPE |
| $C_t^{KV}$ | $d_c$ | compressed latent vector for keys and values, where $d_c$ (<<$d_h$ x $n_h$) |
| $W^{DKV}$ | $d_c$ x d | down projection matrix for keys and values |
| $W^{UK}$  | $d_h n_h$ x $d_c$ | up projection matrix for keys |
| $W^{UV}$  | $d_h n_h$ x $d_c$ | up projection matrix for values |
| $W^{KR}$  | $d_h^R$ x d | matrix to produce the decoupled key |
| $K_t^R$   | $d_h^R$ | decoupled key that carries RoPE, shared by all heads |
| $W^{DQ}$  | $d_c^{'}$ x d | down projection matrix for queries |
| $C_t^Q$   | $d_c^{'}$ | compressed latent vector for queries |
| $W^{UQ}$  | $d_h n_h$ x $d_c^{'}$ | up projection matrix for queries |
| $q_t^C$   | $d_h n_h$ | queries after up projection ($d_h$ per head) |
| $W^{QR}$  | $d_h^R n_h$ x $d_c^{'}$ | matrix to produce the decoupled queries |
| $q_t^R$   | $d_h^R n_h$ | decoupled multi-head queries that carry RoPE ($d_h^R$ per head) |


After concatenating the RoPE parts with the queries and keys, each head's query and key has $d_h$ + $d_h^R$ dimensions.
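
For reference, here are the formulas from the image written out in the notation of the table. This is my own transcription, so treat it as a reading aid rather than a replacement for the paper. The symbols $k_t^C$ and $v_t^C$ (up-projected keys and values) and $o_{t,i}$ (output of head i) are not listed in the table; the subscript i indexes heads and [ · ; · ] denotes concatenation. The final output projection is left out because the code in this post does not implement it.

$$
\begin{aligned}
C_t^{KV} &= W^{DKV} h_t \\
k_t^C &= W^{UK} C_t^{KV} \\
K_t^R &= \mathrm{RoPE}(W^{KR} h_t) \\
k_{t,i} &= [k_{t,i}^C ; K_t^R] \\
v_t^C &= W^{UV} C_t^{KV} \\
C_t^Q &= W^{DQ} h_t \\
q_t^C &= W^{UQ} C_t^Q \\
q_t^R &= \mathrm{RoPE}(W^{QR} C_t^Q) \\
q_{t,i} &= [q_{t,i}^C ; q_{t,i}^R] \\
o_{t,i} &= \sum_{j=1}^{t} \mathrm{Softmax}_j\left(\frac{q_{t,i}^T k_{j,i}}{\sqrt{d_h + d_h^R}}\right) v_{j,i}^C
\end{aligned}
$$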

I have implemented both traditional multi-head self-attention and MLA in PyTorch below.

In [15]:
import torch
from torch.nn import functional as F
from torch import nn

In [16]:
# original attention mechanism from vaswani et al.
batch_size = 4
seq_len = 10
n_embed = 1024
n_heads = 16
d_heads = n_embed // n_heads
latent = torch.randn(batch_size, seq_len, n_embed)

# one fused linear layer produces the queries, keys and values;
# a second linear layer on top of it would add parameters but no expressivity
qkv_proj = nn.Linear(n_embed, 3 * n_embed)
q, k, v = qkv_proj(latent).chunk(3, dim=-1)

q = q.view(batch_size, seq_len, n_heads, d_heads).transpose(1, 2)
k = k.view(batch_size, seq_len, n_heads, d_heads).transpose(1, 2)
v = v.view(batch_size, seq_len, n_heads, d_heads).transpose(1, 2)

# scale by the square root of the per-head dimension, as in Vaswani et al.
weights = torch.matmul(q, k.transpose(-2, -1)) / (d_heads ** 0.5)
weights = F.softmax(weights, dim=-1)
output = torch.matmul(weights, v)

output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, n_embed)

print(" # number of KV cache elements per token per layer", n_heads * d_heads * 2)
print(output.shape)

 # number of KV cache elements per token per layer 2048
torch.Size([4, 10, 1024])
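
Scaled dot-product attention divides the scores by the square root of the per-head dimension (`d_heads`), not of the full embedding size. A quick way to sanity-check a hand-written version is to compare it with PyTorch's built-in `F.scaled_dot_product_attention` (available from PyTorch 2.0), which uses the same default scale. The sizes below are small illustrative values, not the ones from the cell above.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
batch_size, n_heads, seq_len, d_heads = 2, 4, 5, 8  # small illustrative sizes

# random per-head queries, keys and values: (batch, heads, seq, head_dim)
q = torch.randn(batch_size, n_heads, seq_len, d_heads)
k = torch.randn(batch_size, n_heads, seq_len, d_heads)
v = torch.randn(batch_size, n_heads, seq_len, d_heads)

# hand-written attention, scaled by sqrt(d_heads)
weights = F.softmax(torch.matmul(q, k.transpose(-2, -1)) / (d_heads ** 0.5), dim=-1)
manual = torch.matmul(weights, v)

# built-in attention with its default scale of 1 / sqrt(head_dim)
builtin = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual, builtin, atol=1e-4))
```

If the two disagree, the scaling factor or the softmax dimension is the first place to look.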


In [4]:
# RoPE encoding
import torch
from torch.nn import functional as F
from torch import nn
class RoPE(nn.Module):
    def __init__(self, seq, n_embed=1024, theta=10000.0):
        # seq: longest sequence to cache, n_embed: size of the vector being rotated
        super().__init__()
        assert n_embed % 2 == 0, "RoPE needs an even dimension"
        # one frequency per pair of channels: theta^(-2i / n_embed)
        self.theta = 1 / (theta ** (torch.arange(0, n_embed, 2).float() / n_embed))
        self.seq = seq
        self.build_rope_cache(self.seq)

    def build_rope_cache(self, seq):
        pos = torch.arange(0, seq, dtype=self.theta.dtype)
        thetas = torch.einsum("i, j -> ij", pos, self.theta).float()  # (seq, n_embed // 2)
        rope = torch.stack([torch.cos(thetas), torch.sin(thetas)], dim=-1)  # (seq, n_embed // 2, 2)
        self.register_buffer("rope", rope, persistent=False)

    def forward(self, x):
        # x: (batch_size, seq_len, n_heads, head_dim) with head_dim == n_embed
        b, s, nhead, hdim = x.shape

        x = x.view(b, s, nhead, hdim // 2, 2)
        # take the first s positions and broadcast over batch and heads,
        # so every head is rotated with the same frequencies
        rope = self.rope[:s].view(1, s, 1, hdim // 2, 2)
        x_out = torch.stack(
            [
                x[..., 0] * rope[..., 0]
                - x[..., 1] * rope[..., 1],
                x[..., 1] * rope[..., 0]
                + x[..., 0] * rope[..., 1],
            ],
            -1,
        )
        x_out = x_out.flatten(3)
        return x_out.type_as(x)

In [17]:
class MultiHeadLatentAttn(nn.Module):
    def __init__(self, d_model, num_heads, d_c, d_r_h, d_hat_c, max_seq_len=1024):

        super().__init__()
        assert d_model % num_heads == 0, "d_model should be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.d_r_h = d_r_h

        # define weights
        self.w_DQ = nn.Linear(d_model, d_hat_c)  # query compression: d_model --> d_hat_c
        self.w_UQ = nn.Linear(d_hat_c, d_model)  # query up projection: d_hat_c --> d_model
        self.w_QR = nn.Linear(d_hat_c, d_r_h * self.num_heads)  # decoupled RoPE queries: d_hat_c --> num_heads * d_r_h
        self.w_KR = nn.Linear(d_model, d_r_h)  # decoupled RoPE key, shared by all heads: d_model --> d_r_h

        self.w_DKV = nn.Linear(d_model, d_c)  # compress keys and values into one latent
        self.w_UK = nn.Linear(d_c, d_model)  # key up projection: d_c --> d_model
        self.w_UV = nn.Linear(d_c, d_model)  # value up projection: d_c --> d_model
        # queries and keys must be rotated with the same frequencies, so both
        # tables are built for d_r_h; max_seq_len replaces the global seq_len
        self.rope_query = RoPE(max_seq_len, d_r_h)
        self.rope_key = RoPE(max_seq_len, d_r_h)

    def forward(self, latent):
        # latent: (batch_size, seq_len, d_model); no causal mask is applied in this demo
        batch_size, seq_len, d = latent.shape

        ct_Q = self.w_DQ(latent)  # (batch_size, seq_len, d_hat_c)
        qt_C = self.w_UQ(ct_Q)  # (batch_size, seq_len, d_model)
        qt_C = qt_C.view(batch_size, seq_len, self.num_heads, self.head_dim)  # split into heads

        qt_R = self.w_QR(ct_Q)  # (batch_size, seq_len, num_heads * d_r_h)
        qt_R = qt_R.view(batch_size, seq_len, self.num_heads, self.d_r_h)  # split into heads
        qt_R = self.rope_query(qt_R)

        qt_cat = torch.cat([qt_C, qt_R], dim=-1)  # (batch_size, seq_len, num_heads, head_dim + d_r_h)
        qt_cat = qt_cat.transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim + d_r_h)
        #print("qt_cat", qt_cat.shape)

        ct_KV = self.w_DKV(latent)  # (batch_size, seq_len, d_c), the latent that would be cached

        kt_C = self.w_UK(ct_KV)  # (batch_size, seq_len, d_model)
        kt_C = kt_C.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim)
        vt_C = self.w_UV(ct_KV)  # (batch_size, seq_len, d_model)
        vt_C = vt_C.view(batch_size, seq_len, self.num_heads, self.head_dim)
        vt_C = vt_C.transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim)
        #print("vt_C", vt_C.shape)

        kt_R = self.w_KR(latent)  # (batch_size, seq_len, d_r_h)
        kt_R = kt_R.unsqueeze(2)  # add a head dimension: (batch_size, seq_len, 1, d_r_h)
        kt_R = self.rope_key(kt_R)
        kt_R = kt_R.transpose(1, 2).expand(-1, self.num_heads, -1, -1)  # share across heads: (batch_size, num_heads, seq_len, d_r_h)

        kt_cat = torch.cat([kt_C, kt_R], dim=-1)  # (batch_size, num_heads, seq_len, head_dim + d_r_h)
        #print("kt_cat", kt_cat.shape)

        weights = F.softmax(torch.matmul(qt_cat, kt_cat.transpose(-1, -2)) / ((self.d_r_h + self.head_dim) ** 0.5), dim=-1)
        #print("weights", weights.shape)

        attn = torch.matmul(weights, vt_C)  # (batch_size, num_heads, seq_len, head_dim)

        return attn
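
The decoupled RoPE branch only makes sense if queries and keys are rotated with the same frequencies, because that is what makes their dot product depend on the relative distance between two positions rather than on the absolute positions. The sketch below checks this property numerically. `rotate` is a hypothetical stand-alone helper written for this check (it pairs consecutive channels in the same way as the RoPE module), and the vector size and positions are arbitrary.

```python
import torch

def rotate(x, pos, theta=10000.0):
    """Rotate consecutive channel pairs of x (last dim even) as if x sat at position `pos`."""
    dim = x.shape[-1]
    # one frequency per channel pair: theta^(-2i / dim)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=x.dtype) / dim))
    angles = pos * freqs
    cos, sin = torch.cos(angles), torch.sin(angles)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.stack([x_even * cos - x_odd * sin, x_odd * cos + x_even * sin], dim=-1)
    return out.flatten(-2)

torch.manual_seed(0)
q = torch.randn(32, dtype=torch.float64)
k = torch.randn(32, dtype=torch.float64)

# the same offset of 3 positions at two different absolute locations
score_near = torch.dot(rotate(q, 5), rotate(k, 2))
score_far = torch.dot(rotate(q, 105), rotate(k, 102))

print(torch.allclose(score_near, score_far))
```

The two scores agree because each channel pair is rotated by an angle proportional to the position, so only the difference of the two angles survives in the dot product. If the query and key used different frequency tables, this cancellation would no longer hold.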

In [19]:
num_heads = 16
d_model = 1024  
batch_size = 4
seq_len = 10 

# d_c = 4 * head_dim and d_r_h = head_dim / 2 follow the DeepSeek-V2 paper; d_hat_c = 8 is my own choice
head_dim = d_model // num_heads
d_r_h = head_dim // 2
d_c =  4 * head_dim
d_hat_c = 8

latent = torch.randn(batch_size, seq_len, d_model)
print("latent shape", latent.shape)


multi_head_attn = MultiHeadLatentAttn(
    d_model=d_model,
    num_heads=num_heads,
    d_c=d_c,
    d_r_h=d_r_h,
    d_hat_c=d_hat_c
)
out = multi_head_attn(latent)
#print(out.shape)
# merge heads: (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, d_model)
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
print(out.shape)
print(" # number of KV cache elements per token per layer (~ 9/2 * head_dim)", d_r_h + d_c)

latent shape torch.Size([4, 10, 1024])
torch.Size([4, 10, 1024])
 # number of KV cache elements per token per layer (~ 9/2 * head_dim) 288
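
To put the two printed cache sizes side by side, here is a small sketch that recomputes them from the configuration used in this post. The two helper functions are hypothetical names introduced only for this comparison.

```python
def mha_cache_per_token(n_heads, head_dim):
    # MHA caches a full key and a full value for every head
    return 2 * n_heads * head_dim

def mla_cache_per_token(d_c, d_r_h):
    # MLA caches the KV latent plus the single shared RoPE key
    return d_c + d_r_h

n_heads, head_dim = 16, 64
d_c, d_r_h = 4 * head_dim, head_dim // 2

mha = mha_cache_per_token(n_heads, head_dim)
mla = mla_cache_per_token(d_c, d_r_h)
print(mha, mla, round(mha / mla, 2))  # 2048 288 7.11
```

The MLA count does not depend on the number of heads, so the gap widens as `n_heads` grows.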


#### Final Remarks

This is an excerpt from the paper: "For DeepSeek-V2, we design an innovative attention mechanism called Multi-head Latent Attention (MLA). Equipped with low-rank key-value joint compression, MLA achieves better performance than MHA, but requires a significantly smaller amount of KV cache. We introduce its architecture in the following, and also provide a comparison between MLA and MHA in Appendix D.2."

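During inference, MLA caches only $C_t^{KV}$ and $K_t^R$, so the KV cache holds ($d_c$ + $d_h^R$) x L elements per token for a model with L layers. With $d_c$ = 4$d_h$ and $d_h^R$ = $d_h$/2, as used above, this is (9/2)$d_h$ x L, compared with 2 x $n_h$ x $d_h$ x L for MHA.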

![KV Cache](../../images/kv_cache_per_token.JPG)
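
Caching only the latent raises the question of whether the keys have to be rebuilt for every cached token at each decoding step. The paper notes that the key up projection can be absorbed into the query side, so scores can be computed directly against the cached latents. Below is a minimal numerical sketch of that identity for the content (non-RoPE) part of the score only. It assumes bias-free projections as in the paper's formulas (the `nn.Linear` layers in my implementation above do have a bias), and all names and sizes are illustrative.

```python
import torch

torch.manual_seed(0)
seq_len, n_heads, head_dim, d_c = 6, 4, 8, 16  # small illustrative sizes
d_model = n_heads * head_dim

W_UK = torch.randn(d_model, d_c, dtype=torch.float64)    # key up projection, no bias
c_kv = torch.randn(seq_len, d_c, dtype=torch.float64)    # cached latents, one per token
q = torch.randn(n_heads, head_dim, dtype=torch.float64)  # content query of the current token

# (a) explicit path: rebuild per-head keys from the cache, then score
k = (c_kv @ W_UK.T).view(seq_len, n_heads, head_dim)
scores_explicit = torch.einsum("hd,shd->hs", q, k)

# (b) absorbed path: fold W_UK into the query, then score against the latents directly
W_UK_heads = W_UK.view(n_heads, head_dim, d_c)           # per-head slices of W_UK
q_absorbed = torch.einsum("hd,hdc->hc", q, W_UK_heads)   # (n_heads, d_c)
scores_absorbed = torch.einsum("hc,sc->hs", q_absorbed, c_kv)

print(torch.allclose(scores_explicit, scores_absorbed))
```

Path (b) touches each cached token only through its `d_c`-dimensional latent, which is why the up-projected keys never need to be stored. The RoPE part of the score is handled separately through the cached $K_t^R$.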

In the toy configuration above, that is 288 cached elements per token per layer instead of 2048, roughly a 7x reduction. A smaller KV cache lowers the memory footprint of inference, which leaves room for longer contexts or larger batches on the same hardware.
