## Step 1: Writing the code for Multi-Head Latent Attention (MLA)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm

A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "c:\Users\ayush\AppData\Local\Programs\Python\Python39\lib\runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "c:\Users\ayush\AppData\Local\Programs\Python\Python39\lib\runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "c:\Users\ayush\AppData\Local\Programs\Python\Python39\lib\site-packages\ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "c:\Users\ayush\AppData\Local\Pr

In [15]:
class RopelessMLA(nn.Module):
    """Multi-Head Latent Attention without rotary position embeddings.

    Instead of caching per-head keys and values, every token is compressed into a
    single ``kv_latent_dim``-wide latent vector ``c_kv`` (this is the KV cache).
    Values are re-expanded from the latent on demand. The query projection
    ``W_q`` and the key up-projection ``W_uk`` are folded ("absorbed") into one
    matrix per head, so attention scores are computed directly against the
    latent cache without materialising K.

    Args:
        d_model: model (embedding) width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        kv_latent_dim: width of the compressed latent KV representation.
    """

    def __init__(self, d_model, n_heads, kv_latent_dim):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.dh = d_model // n_heads  # dimension per head

        # Projection layers (nn.Linear stores weight as (out_features, in_features)).
        self.W_q = nn.Linear(d_model, d_model, bias=False)          # query projection
        self.W_dkv = nn.Linear(d_model, kv_latent_dim, bias=False)  # compress into latent KV space
        self.W_uk = nn.Linear(kv_latent_dim, d_model, bias=False)   # decompress K
        self.W_uv = nn.Linear(kv_latent_dim, d_model, bias=False)   # decompress V
        self.W_o = nn.Linear(d_model, d_model, bias=False)          # final output projection

        self.ln = nn.LayerNorm(kv_latent_dim)
        # Detached copy of the most recently used absorbed matrix,
        # shape (n_heads, d_model, kv_latent_dim); kept for inspection only.
        # Non-persistent because it is derived from W_q / W_uk and must not be
        # written to (or required by) state_dict.
        self.register_buffer("absorbed_k", None, persistent=False)

    def _absorbed_qk(self):
        """Return the per-head product W_q[h]^T @ W_uk[h], shape (n_heads, d_model, kv_latent_dim).

        W_q[h] / W_uk[h] are the ``dh`` output rows belonging to head h, so
        q_h . k_h = x @ W_q[h]^T @ W_uk[h] @ c_kv^T.
        """
        w_q = self.W_q.weight.view(self.n_heads, self.dh, self.d_model)  # (H, dh, D)
        w_uk = self.W_uk.weight.view(self.n_heads, self.dh, -1)           # (H, dh, latent)
        return torch.einsum("hjd,hjl->hdl", w_q, w_uk)                    # (H, D, latent)

    def forward(self, x, kv_cache=None, past_length=0):
        """Attend from the new tokens ``x`` to all cached tokens plus the new ones.

        Args:
            x: new input tokens, shape (B, S, d_model).
            kv_cache: latent cache returned by a previous call, shape
                (B, S_past, kv_latent_dim), or None on the first call.
            past_length: number of tokens already in ``kv_cache``. The causal-mask
                offset is derived from the cache itself; a non-zero value that
                disagrees with it raises ValueError.

        Returns:
            Tuple ``(output, c_kv)``: output of shape (B, S, d_model) and the
            updated latent cache of shape (B, S_past + S, kv_latent_dim).

        Raises:
            ValueError: if ``past_length`` is non-zero and does not match the
                cache length.
        """
        B, S, D = x.size()

        # Recomputed on every call: a cached version would go stale once the
        # weights change, and would re-use a freed autograd graph on the next
        # backward pass.
        absorbed = self._absorbed_qk()
        self.absorbed_k = absorbed.detach()

        # Compress x into the latent KV space and append it to the cache.
        new_c_kv = self.ln(self.W_dkv(x))  # (B, S, latent)
        c_kv = new_c_kv if kv_cache is None else torch.cat((kv_cache, new_c_kv), dim=1)
        S_full = c_kv.size(1)              # c_kv: (B, S_full, latent)
        offset = S_full - S                # absolute position of the first new token
        if past_length and past_length != offset:
            raise ValueError(
                f"past_length={past_length} does not match cached length {offset}"
            )

        # Attention scores: map queries straight into latent space, then dot with the cache.
        q_latent = torch.einsum("bsd,hdl->bhsl", x, absorbed)          # (B, H, S, latent)
        attn_scores = torch.einsum("bhsl,btl->bhst", q_latent, c_kv)   # (B, H, S, S_full)
        attn_scores = attn_scores / (self.dh ** 0.5)                   # scaled dot-product

        # Causal mask: new token i sits at absolute position offset + i and may
        # attend to keys 0 .. offset + i.
        mask = torch.ones(S, S_full, device=x.device).tril(diagonal=offset).bool()
        attn_scores = attn_scores.masked_fill(~mask, float("-inf"))
        attn_weights = torch.softmax(attn_scores, dim=-1)              # (B, H, S, S_full)

        # Decompress V to full d_model and split into heads.
        v = self.W_uv(c_kv).view(B, S_full, self.n_heads, self.dh).transpose(1, 2)  # (B, H, S_full, dh)
        context = torch.matmul(attn_weights, v)                        # (B, H, S, dh)
        out = context.transpose(1, 2).reshape(B, S, D)                 # concatenate heads
        return self.W_o(out), c_kv  # final output projection, updated latent cache

In [16]:
model=RopelessMLA(d_model=512,n_heads=8,kv_latent_dim=256)
x=torch.randn(1,5,512) ## (batch=1, seq_len=5, d_model=512)
out,cache=model(x)
assert out.shape == (1, 5, 512)    # output keeps the input shape
assert cache.shape == (1, 5, 256)  # one kv_latent_dim-wide latent per token

## Step 2: Memory Testing


In [17]:
def demo():
    """Run one forward pass and compare standard vs. latent KV-cache sizes.

    Prints the output/cache shapes for a (1, 5, 512) input, then a back-of-envelope
    per-layer cache-size estimate for a hypothetical batch of 2 sequences of
    10 tokens stored in float32.
    """
    d_model, n_heads, kv_latent_dim = 512, 8, 256
    model = RopelessMLA(d_model=d_model, n_heads=n_heads, kv_latent_dim=kv_latent_dim)
    x = torch.randn(1, 5, d_model)  # (batch=1, seq_len=5, d_model=512)
    out, cache = model(x)
    print(f"Output: {out.shape}, Cache: {cache.shape}")

    ## Memory consumption (per layer, in KiB)
    batch, seq_len, bytes_per_elem = 2, 10, 4  # hypothetical workload; 4 bytes = float32
    # Standard MHA caches both K and V, each d_model wide.
    std_size = 2 * batch * seq_len * d_model * bytes_per_elem / 1024
    # MLA caches a single kv_latent_dim-wide latent per token.
    latent_size = batch * seq_len * kv_latent_dim * bytes_per_elem / 1024
    print(f"Memory: Standard={std_size:.1f}kb ,Latent={latent_size:.1f}KB,Reduction={std_size/latent_size}")


if __name__ =="__main__":
    demo()

Output: torch.Size([1, 5, 512]), Cache: torch.Size([1, 5, 256])
Memory: Standard=80.0kb ,Latent=20.0KB,Reduction=4.0


## Step 3: Cache testing — single new inference

In [18]:
def demo_cache_usage():
    """Show that the latent KV cache grows by one row when a single token is appended.

    Runs a 5-token prompt through a tiny model, then feeds one more token together
    with the returned cache. Asserts the expected shapes and that the previously
    cached rows are carried over unchanged.
    """
    torch.manual_seed(0)
    model=RopelessMLA(d_model=8,n_heads=2,kv_latent_dim=4)

    ## ----- Step 1: initial input (sequence of 5 tokens) -----
    x_1=torch.randn(1,5,8) ## (batch=1, tokens=5, d_model=8)
    out1,cache1=model(x_1)
    print("Step  : Inital input")
    print(f"Output shape: {out1.shape}")
    print(f"cache shape:  {cache1.shape}")
    assert out1.shape == (1, 5, 8)
    assert cache1.shape == (1, 5, 4)

    ## ----- Step 2: append 1 token, reusing the cache -----
    x_2=torch.randn(1,1,8) ## (batch=1, tokens=1, d_model=8)
    out2,cache2=model(x_2,kv_cache=cache1,past_length=5)

    print("Step  : Appended input")
    print(f"Output shape: {out2.shape}")
    print(f"cache shape:  {cache2.shape}")
    assert out2.shape == (1, 1, 8)
    assert cache2.shape == (1, 6, 4)
    assert torch.equal(cache2[:, :5], cache1)  # old cache rows are carried over unchanged
demo_cache_usage()

Step  : Inital input
Output shape: torch.Size([1, 5, 8])
cache shape:  torch.Size([1, 5, 4])
Step  : Appended input
Output shape: torch.Size([1, 1, 8])
cache shape:  torch.Size([1, 6, 4])


## Step 4: Cache testing — multiple new inferences

In [22]:
def demo_kv_cache_growth(num_initial_tokens=5,num_new_tokens=3):
    """Prefill a prompt, then decode one token at a time, printing the latent cache shape.

    Args:
        num_initial_tokens: length of the prompt fed in the first (prefill) call.
        num_new_tokens: number of single-token decode steps run afterwards.
    """
    torch.manual_seed(0)
    width = 8
    mla = RopelessMLA(d_model=width, n_heads=2, kv_latent_dim=4)

    # Prefill: the whole prompt goes through in a single call and seeds the cache.
    prompt = torch.randn(1, num_initial_tokens, width)
    _, latent_cache = mla(prompt)
    print(f"Initial input of {num_initial_tokens} tokens -> cache shape:", latent_cache.shape)

    # Decode: each step feeds one fresh token together with everything cached so far.
    for step in range(1, num_new_tokens + 1):
        next_token = torch.randn(1, 1, width)
        cached_len = latent_cache.size(1)
        _, latent_cache = mla(next_token, kv_cache=latent_cache, past_length=cached_len)
        print(f"Step {step}: Added  1 Token -> cache shape", latent_cache.shape)
demo_kv_cache_growth(num_initial_tokens=50,num_new_tokens=4)

Initial input of 50 tokens -> cache shape: torch.Size([1, 50, 4])
Step 1: Added  1 Token -> cache shape torch.Size([1, 51, 4])
Step 2: Added  1 Token -> cache shape torch.Size([1, 52, 4])
Step 3: Added  1 Token -> cache shape torch.Size([1, 53, 4])
Step 4: Added  1 Token -> cache shape torch.Size([1, 54, 4])
