# In[1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F

# In[3]:
class MLA(nn.Module):
    """Multi-head Latent Attention (no RoPE) with a compressed KV cache.

    Keys and values are reconstructed from a shared low-rank latent
    ``C_kv = LayerNorm(W_dkv(x))``; only ``C_kv`` is cached between calls.
    The key up-projection ``W_uk`` is folded ("absorbed") into the query
    projection ``W_q``, so attention scores are computed directly against the
    latent cache without materializing full keys. The result is numerically
    equivalent to standard causal multi-head attention with
    ``Q = W_q(x)``, ``K = W_uk(C_kv)``, ``V = W_uv(C_kv)``.

    Args:
        d_model: model dimension; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        kv_latent_dim: size of the compressed key/value latent.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, kv_latent_dim):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.kv_dim_model = kv_latent_dim  # original attribute name, kept for compatibility
        self.dh = d_model // n_heads  ## dim per head

        ## Projection layers
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_dkv = nn.Linear(d_model, kv_latent_dim, bias=False)
        self.W_uk = nn.Linear(kv_latent_dim, d_model, bias=False)
        self.W_uv = nn.Linear(kv_latent_dim, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.ln = nn.LayerNorm(kv_latent_dim)
        # Most recent absorbed query/key matrix, shape (n_heads, d_model, kv_latent_dim).
        # Refreshed (detached) on every forward for inspection only. It is
        # non-persistent so it never ends up in state_dict.
        self.register_buffer("absorbed_k", None, persistent=False)

    def _absorb_query_key(self):
        """Return per-head ``W_q[h]^T @ W_uk[h]``, shape (n_heads, d_model, latent).

        ``nn.Linear`` computes ``x @ weight.T``, so for head ``h`` (whose
        ``dh`` output rows are ``W_q[h]`` / ``W_uk[h]``):
        ``Q_h @ K_h^T = x @ W_q[h]^T @ W_uk[h] @ C_kv^T``.
        """
        w_q = self.W_q.weight.view(self.n_heads, self.dh, self.d_model)
        w_uk = self.W_uk.weight.view(self.n_heads, self.dh, -1)
        return torch.einsum("hdm,hdl->hml", w_q, w_uk)

    def forward(self, x, kv_cache=None, past_length=0):
        """Attend from the new tokens ``x`` to the cached plus new latents.

        Args:
            x: new tokens, (b, s, d_model).
            kv_cache: optional latent cache from a previous call,
                (b, s_past, kv_latent_dim).
            past_length: kept for backward compatibility and ignored. The
                causal offset is derived from the cache itself
                (``s_total - s``), which is the only value consistent with
                the concatenated cache.

        Returns:
            ``(output, C_kv)``. ``output`` has shape (b, s, d_model).
            ``C_kv`` has shape (b, s_total, kv_latent_dim); pass it back as
            ``kv_cache`` on the next call.
        """
        b, s, _ = x.size()
        ## Recompute the absorbed weights every call so they track W_q / W_uk
        ## (a cached copy goes stale after optimizer steps and holds a freed graph).
        absorbed_k = self._absorb_query_key()
        self.absorbed_k = absorbed_k.detach()

        ## compress x into latent dimension
        new_C_kv = self.ln(self.W_dkv(x))  ## (b, s, latent_dim)
        if kv_cache is None:
            C_kv = new_C_kv
        else:
            C_kv = torch.cat([kv_cache, new_C_kv], dim=1)  ## (b, s_total, latent_dim)
        s_total = C_kv.size(1)

        ## decompress V into d_model, split into heads -> (b, n_heads, s_total, dh)
        V = self.W_uv(C_kv).view(b, s_total, self.n_heads, self.dh).transpose(1, 2)

        ## Attention scores: absorbed query (b, n_heads, s, latent) against C_kv
        q_latent = torch.einsum("bsm,hml->bhsl", x, absorbed_k)
        attn_scores = torch.einsum("bhsl,btl->bhst", q_latent, C_kv)
        attn_scores = attn_scores / (self.dh ** 0.5)

        ## Causal mask: new token i sits at position offset + i within C_kv
        offset = s_total - s
        mask = torch.tril(torch.ones((s, s_total), device=x.device), diagonal=offset)
        attn_scores = attn_scores.masked_fill(mask.view(1, 1, s, s_total) == 0, float("-inf"))

        ## Softmax to get the weights
        attn_weights = F.softmax(attn_scores, dim=-1)

        ## Weighted sum per head, then concatenate heads -> (b, s, d_model)
        context = torch.matmul(attn_weights, V)
        out = context.transpose(1, 2).reshape(b, s, self.d_model)
        return self.W_o(out), C_kv