## Why Multi-Head Latent Attention?

The main reason for using multi-head latent attention (MLA) is to reduce the memory needed to cache the key and value vectors, that is, the KV cache size.

### Why is it important to reduce the KV cache size?

The KV cache size directly affects the memory usage of the model. The cache holds a key and a value vector for every token in every layer, so it grows linearly with sequence length, batch size, and number of layers. A larger KV cache means higher memory costs and potential out-of-memory errors. Reducing it lets the same hardware serve longer contexts or larger batches.
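To get a feel for the numbers, here is a rough back-of-the-envelope sketch. The function name `kv_cache_bytes` and the configuration below (32 layers, `d_model` of 4096, 2 bytes per value) are hypothetical values chosen for illustration, not taken from any specific model.

```python
def kv_cache_bytes(n_layers, seq_len, d_model, batch_size=1, bytes_per_value=2):
    """Size of a standard KV cache: one key vector and one value vector
    of width d_model per token, per layer."""
    return 2 * n_layers * batch_size * seq_len * d_model * bytes_per_value

# Hypothetical configuration, for illustration only.
size = kv_cache_bytes(n_layers=32, seq_len=4096, d_model=4096)
print(f"{size / 2**30:.1f} GiB")  # 2.0 GiB
```

Doubling the sequence length or the batch size doubles this figure, which is why the cache quickly dominates memory for long contexts.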

#### Why store them at all?

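The key and value vectors are stored in the cache to avoid redundant calculations during autoregressive generation. Every new token attends to all previous tokens, and the keys and values of those previous tokens do not change from one step to the next. Without a cache they would be recomputed at every step; with a cache, only the new token's key and value are computed and appended. This trades extra memory for less computation.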
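The idea can be sketched with a toy single-head attention. This is a minimal illustration, assuming random weights and no batching; the names `attend`, `K_cache`, and `V_cache` are made up for this example.

```python
import torch

torch.manual_seed(0)
d = 8
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))
x = torch.randn(5, d)  # 5 tokens

def attend(q, K, V):
    # q: (1, d), K and V: (t, d)
    weights = torch.softmax(q @ K.T / d ** 0.5, dim=-1)
    return weights @ V

# Without a cache: recompute K and V for the whole prefix at every step.
no_cache = []
for t in range(1, x.size(0) + 1):
    prefix = x[:t]
    no_cache.append(attend(prefix[-1:] @ W_q, prefix @ W_k, prefix @ W_v))

# With a cache: project only the new token and append it to the cache.
K_cache = torch.empty(0, d)
V_cache = torch.empty(0, d)
with_cache = []
for t in range(x.size(0)):
    x_t = x[t:t + 1]
    K_cache = torch.cat([K_cache, x_t @ W_k])
    V_cache = torch.cat([V_cache, x_t @ W_v])
    with_cache.append(attend(x_t @ W_q, K_cache, V_cache))

# Both loops should produce the same outputs up to floating-point error.
same = torch.allclose(torch.cat(no_cache), torch.cat(with_cache), atol=1e-5)
print(same)
```

The cached loop does one key and one value projection per step, while the uncached loop redoes them for the entire prefix.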

#### How does multi-head latent attention help reduce the KV cache size?

In multi-head latent attention, each token's key and value are represented by a single lower-dimensional latent vector, from which the key and value are reconstructed when needed. Only this latent vector is cached per token, instead of a full key vector and a full value vector, which shrinks the KV cache.


#### How does it play out?

Let's say our input matrix has shape [2,768], i.e. 2 tokens with 768-dimensional embeddings.

Normally we would project it into Q, K and V matrices of shape [2,768] each,
and store the key and value matrices in the cache.

With latent attention, the steps are:
1. The query is calculated the same way.
2. The input is first projected down into a latent key-value space (a compressed representation of the key and value data)
    Wdkv -> weight for the key-value down projection [768,512]
    i.e. [2,768] * Wdkv -> [2,512]
    This latent key-value space is then used to calculate the key and value
    by up projection
    Wuv -> weight for the value up projection [512,768]
    i.e. [2,512] * Wuv -> [2,768]
    Wuk -> weight for the key up projection [512,768]
    i.e. [2,512] * Wuk -> [2,768]
    So we obtain both the key and value vectors.
    The number of steps might seem to have increased (it has not, because some of the calculations can be combined), but we no longer
    need to store the key and value vectors in the cache separately; we can store just this latent key-value space.
    Per token, that is 512 cached values instead of 2 * 768 = 1536.
3. Let's see how the calculation of the attention scores changes
    X -> input -> [2,768]
    Q -> X * Wq -> [2,768]
    Ckv -> latent key value space -> X * Wdkv -> [2,512]
    K -> Ckv * Wuk -> [2,768]
    V -> Ckv * Wuv -> [2,768]
    Attention scores -> Q * K.T -> [2,2]
    Q * K.T -> X * Wq * (Ckv * Wuk).T
    Q * K.T -> X * Wq * Wuk.T * Ckv.T (the product of the weights, i.e. Wq * Wuk.T, is fixed once the model is trained, so it can be precomputed)
    Q * K.T -> X * (Wq * Wuk.T) * (X * Wdkv).T
    X * (Wq * Wuk.T) -> absorbed query -> [2,512]
    (X * Wdkv).T -> this is Ckv.T, and Ckv is exactly what is cached
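
The derivation above can be checked numerically. This is a minimal single-head sketch with random weights and the [2,768] and 512 shapes from the example; it leaves out the layer norm and the per-head split used in the full module, and the variable names are chosen for this illustration only.

```python
import torch

torch.manual_seed(0)
d_model, d_latent = 768, 512
f64 = torch.float64  # double precision keeps the comparison tight

X = torch.randn(2, d_model, dtype=f64)
W_q = torch.randn(d_model, d_model, dtype=f64) / d_model ** 0.5
W_dkv = torch.randn(d_model, d_latent, dtype=f64) / d_model ** 0.5
W_uk = torch.randn(d_latent, d_model, dtype=f64) / d_latent ** 0.5

C_kv = X @ W_dkv                        # [2, 512], the only tensor that is cached

# Explicit path: materialise K, then score.
K = C_kv @ W_uk                         # [2, 768]
scores_explicit = (X @ W_q) @ K.T       # [2, 2]

# Absorbed path: fold W_uk into the query side and score against C_kv directly.
W_absorbed = W_q @ W_uk.T               # [768, 512], fixed once the weights are fixed
scores_absorbed = (X @ W_absorbed) @ C_kv.T

print(torch.allclose(scores_explicit, scores_absorbed))
print(2 * d_model, "->", d_latent)      # values cached per token: K and V vs. C_kv
```

The key matrix `K` never has to be built on the absorbed path, which is why caching `C_kv` alone is enough to compute the scores.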

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLA(nn.Module):
    def __init__(self, d_model, n_heads, kv_latent_dim):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.dh = d_model // n_heads # dimension per head

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_dkv = nn.Linear(d_model, kv_latent_dim, bias=False)
        self.W_uk = nn.Linear(kv_latent_dim, d_model, bias=False)
        self.W_uv = nn.Linear(kv_latent_dim, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False) # final output projection

        self.ln1 = nn.LayerNorm(kv_latent_dim)
        # Cache for the absorbed matrices W_q[h]^T @ W_uk[h], one per head (filled on the first forward)
        self.register_buffer('absorbed_k', None)

    def forward(self, x, kv_cache=None, past_length=0):
        B, S, D = x.size()  # batch, current sequence length, model dimension

        if self.absorbed_k is None:
            # nn.Linear stores its weight as (out_features, in_features) and matmul does
            # not transpose anything, so split both weights per head and transpose explicitly.
            # Computed once on first use; it must be recomputed if the weights change.
            w_q = self.W_q.weight.view(self.n_heads, self.dh, D)       # (n_heads, dh, d_model)
            w_uk = self.W_uk.weight.view(self.n_heads, self.dh, -1)    # (n_heads, dh, kv_latent_dim)
            self.absorbed_k = torch.matmul(w_q.transpose(1, 2), w_uk)  # (n_heads, d_model, kv_latent_dim)

        new_ckv = self.ln1(self.W_dkv(x))  # latent KV for the current tokens: (B, S, kv_latent_dim)
        if kv_cache is None:
            kv_cache = new_ckv
        else:
            kv_cache = torch.cat([kv_cache, new_ckv], dim=1)  # (B, S_full, kv_latent_dim)

        s_full = kv_cache.size(1)

        v_full = self.W_uv(kv_cache)  # (B, S_full, D)
        v = v_full.view(B, s_full, self.n_heads, self.dh).transpose(1, 2)  # (B, n_heads, S_full, dh)

        attn_scores = torch.zeros(B, self.n_heads, S, s_full, device=x.device, dtype=x.dtype)

        # Why [S, S_full] and not [S_full, S_full]?
        # S is the number of current tokens and S_full = S + past context length.
        # Each current token needs a score against every previous token and the current
        # tokens; the cached tokens need no new scores of their own.

        for h in range(self.n_heads):
            tmp = torch.matmul(x, self.absorbed_k[h])  # absorbed query for head h: (B, S, kv_latent_dim)
            attn_scores[:, h] = torch.bmm(tmp, kv_cache.transpose(1, 2))  # (B, S, S_full)

        attn_scores = attn_scores / (self.dh ** 0.5)

        # Causal mask of shape [S, S_full], where:
        # - S is the current sequence length
        # - S_full is the total sequence length (past context + current)
        # - past_length = S_full - S (length of the past context)
        #
        # diagonal=past_length shifts the diagonal to the right by past_length, so each
        # current token can attend to all past context tokens plus the usual causal
        # pattern over the current sequence.
        #
        # Example with S=4, S_full=7, past_length=3:
        # 1 1 1 1 0 0 0  # first row: the past context (first 3) and itself
        # 1 1 1 1 1 0 0  # second row: all previous tokens and itself
        # 1 1 1 1 1 1 0  # third row: same pattern
        # 1 1 1 1 1 1 1  # fourth row: everything up to and including itself
        mask = torch.tril(torch.ones(S, s_full, device=x.device), diagonal=past_length)
        attn_scores = attn_scores.masked_fill(mask.view(1,1,S,s_full) == 0, float('-inf'))

        attention_weights = F.softmax(attn_scores, dim=-1)   # (B, n_heads, S, S_full)

        out_heads = []

        for h in range(self.n_heads):
            context_h = torch.matmul(attention_weights[:,h], v[:,h]) # (B,S, dh)
            out_heads.append(context_h)
        
        out = torch.cat(out_heads, dim=-1) # (B,S,D)

        return self.W_o(out), kv_cache

In [14]:
def demo_mla():

    torch.manual_seed(0)
    
    model = MLA(d_model=8,n_heads=2,kv_latent_dim=4)

    x1 = torch.randn(1,5,8) # (batch_size, seq_len, d_model)
    out1, kv1 = model(x1)
    
    print("Step 1:")
    print(f"Output shape: {out1.shape}")
    print(f"KV shape: {kv1.shape}")
    
    x2 = torch.randn(1,1,8) # (batch_size, seq_len, d_model)
    out2, kv2 = model(x2,kv_cache=kv1,past_length=5)
    
    print("Step 2:")
    print(f"Output shape: {out2.shape}")
    print(f"KV shape: {kv2.shape}")
    
demo_mla()

Step 1:
Output shape: torch.Size([1, 5, 8])
KV shape: torch.Size([1, 5, 4])
Step 2:
Output shape: torch.Size([1, 1, 8])
KV shape: torch.Size([1, 6, 4])
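
The causal mask built inside `forward` can also be inspected on its own. The sketch below uses the same `torch.tril` call with the shapes from the two demo steps; the variable names `prefill_mask` and `decode_mask` are made up for this example.

```python
import torch

# Several new tokens on top of a cache: 4 new tokens, 3 already cached.
S, past_length = 4, 3
s_full = past_length + S
prefill_mask = torch.tril(torch.ones(S, s_full), diagonal=past_length)
print(prefill_mask.int())

# Single-token decoding as in step 2 of the demo: 1 new token, 5 cached.
decode_mask = torch.tril(torch.ones(1, 6), diagonal=5)
print(decode_mask.int())
```

In the single-token case the mask is all ones, since the newest token may attend to every cached token and to itself.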
