<a href="https://colab.research.google.com/github/Tanya-Verma/Apline_dashboard/blob/main/MLA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# MLA vs Standard Attention - PyTorch Implementation
# For Netflix-style long-context recommendations (128K tokens!)
# Based on your diagram: 32K → 512 latent cache (93% memory reduction)

import torch
import torch.nn as nn

class StandardAttention(nn.Module):
    """Standard multi-head self-attention, the baseline for the MLA comparison.

    The key and value projections each emit ``num_heads * head_dim`` features
    per token, so a KV cache built from this layer has to hold
    ``2 * num_heads * head_dim`` values per token (65,536 with the defaults).

    Args:
        d_model: Size of the input and output feature dimension.
        num_heads: Number of attention heads.
        head_dim: Feature size of each individual head.
    """
    def __init__(self, d_model=1024, num_heads=32, head_dim=1024):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.w_q = nn.Linear(d_model, num_heads * head_dim)
        self.w_k = nn.Linear(d_model, num_heads * head_dim)  # cached per token
        self.w_v = nn.Linear(d_model, num_heads * head_dim)  # cached per token
        self.w_o = nn.Linear(num_heads * head_dim, d_model)

    def forward(self, x):
        """Apply non-causal scaled dot-product self-attention.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.
        """
        batch, seq_len, _ = x.shape

        def split_heads(t):
            # (batch, seq_len, heads * dim) -> (batch, heads, seq_len, dim).
            # Heads must sit on a batch axis so the matmuls below mix tokens
            # within a head rather than heads within a token.
            return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.w_q(x))
        k = split_heads(self.w_k(x))  # this is what a KV cache would store
        v = split_heads(self.w_v(x))  # this is what a KV cache would store

        # (batch, heads, seq_len, seq_len) attention logits, scaled by
        # 1/sqrt(head_dim) to keep the softmax well-conditioned.
        scores = (q @ k.transpose(-2, -1)) * self.head_dim ** -0.5
        weights = torch.softmax(scores, dim=-1)

        # Merge heads back to (batch, seq_len, heads * dim), the width w_o expects.
        context = (weights @ v).transpose(1, 2).reshape(
            batch, seq_len, self.num_heads * self.head_dim
        )
        return self.w_o(context)

class MLADownProject(nn.Module):
    """MLA down-projection (W_down): compress each token to a latent vector.

    Args:
        d_model: Size of the input feature dimension.
        latent_dim: Size of the compressed latent that gets cached per token.
    """
    def __init__(self, d_model=1024, latent_dim=512):
        super().__init__()
        self.w_down = nn.Linear(d_model, latent_dim)  # compress to the latent

    def forward(self, x):
        """Project ``(..., d_model)`` features to ``(..., latent_dim)``."""
        return self.w_down(x)


class MLAAttention(nn.Module):
    """Multi-head latent attention.

    Instead of caching full keys and values, each token is compressed to a
    ``latent_dim``-sized vector; keys and values are rebuilt from that latent
    through per-layer up-projections when attention is computed. Only the
    latent (``latent_dim`` values per token) would need to be cached.

    Args:
        d_model: Size of the input and output feature dimension.
        num_heads: Number of attention heads; each head gets
            ``d_model // num_heads`` features.
        latent_dim: Size of the compressed per-token latent.
    """
    def __init__(self, d_model=1024, num_heads=32, latent_dim=512):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # W_down: token features -> latent
        self.down_proj = MLADownProject(d_model, latent_dim)

        # Up-projections (W_uK, W_uV) rebuild all heads' K and V from the latent
        self.w_uk = nn.Linear(latent_dim, num_heads * self.head_dim)
        self.w_uv = nn.Linear(latent_dim, num_heads * self.head_dim)
        self.w_q = nn.Linear(d_model, num_heads * self.head_dim)
        self.w_o = nn.Linear(num_heads * self.head_dim, d_model)

    def forward(self, x):
        """Apply non-causal scaled dot-product attention via the latent.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.
        """
        batch, seq_len, _ = x.shape

        def split_heads(t):
            # (batch, seq_len, heads * dim) -> (batch, heads, seq_len, dim),
            # so attention mixes tokens within a head.
            return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Step 1: down-project the full feature vector of every token.
        # The latent is the only tensor a KV cache would have to keep.
        latent_cache = self.down_proj.w_down(x)  # (batch, seq_len, latent_dim)

        # Step 2: rebuild K and V from the latent; Q comes straight from x.
        k = split_heads(self.w_uk(latent_cache))
        v = split_heads(self.w_uv(latent_cache))
        q = split_heads(self.w_q(x))

        # Step 3: scaled softmax attention over the sequence axis.
        scores = (q @ k.transpose(-2, -1)) * self.head_dim ** -0.5
        weights = torch.softmax(scores, dim=-1)

        # Merge heads back to (batch, seq_len, heads * dim), the width w_o expects.
        context = (weights @ v).transpose(1, 2).reshape(
            batch, seq_len, self.num_heads * self.head_dim
        )
        return self.w_o(context)

# Demo: run MLA on a short sequence and report the projected cache footprint
# at the 128K-token target context.
if __name__ == "__main__":
    batch_size, d_model = 1, 1024
    context_len = 128000  # target context, used only for the cache estimate

    # Dense attention builds a (seq_len x seq_len) score matrix per head, so
    # the forward pass is demonstrated on a short sequence rather than on the
    # full 128K context.
    demo_len = 256
    x = torch.randn(batch_size, demo_len, d_model)

    # Baseline layer: instantiated only to read its head configuration.
    standard = StandardAttention()

    mla = MLAAttention()
    with torch.no_grad():  # inference-only demo; no autograd graph needed
        output = mla(x)

    # Per-token cache cost: full K and V for the baseline, the latent for MLA.
    bytes_per_value = x.element_size()
    standard_cache = (
        2 * standard.num_heads * standard.head_dim * context_len * bytes_per_value
    )
    mla_cache = mla.latent_dim * context_len * bytes_per_value
    reduction = 1 - mla_cache / standard_cache

    print(
        f"MLA Output: {tuple(output.shape)} | "
        f"Cache @ {context_len:,} tokens: "
        f"{mla_cache / 1e9:.2f}GB vs {standard_cache / 1e9:.2f}GB "
        f"({reduction:.1%} reduction)"
    )