<a href="https://colab.research.google.com/github/Arpit-Baranwal/Transformer_Implementation/blob/main/MLA_impleemntation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import math

In [15]:
class RopelessMLA(nn.Module):
    """Multi-head latent attention (MLA) without rotary position embeddings.

    Each token is compressed into a ``kv_latent_dim``-wide latent vector
    (``W_dkv`` followed by LayerNorm); only these latents are returned as the
    cache.  Values are decompressed from the latents with ``W_vk``.  Keys are
    never materialised: the per-head slice of the key decompression matrix
    ``W_uk`` is absorbed into the query, so attention scores are computed
    directly against the latents and equal ``q_h . k_h`` for ``k = W_uk(c_kv)``.

    Args:
        d_model: Width of the input and output features.
        n_heads: Number of attention heads; must divide ``d_model``.
        kv_latent_dim: Width of the compressed per-token KV latent.
    """

    def __init__(self, d_model, n_heads, kv_latent_dim):
        super().__init__()
        # Check divisibility before deriving the per-head width from it.
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.dh = d_model // n_heads   # dimension per head
        self.kv_latent_dim = kv_latent_dim

        self.W_q = nn.Linear(d_model, d_model, bias=False)        # Query Projection
        self.W_dkv = nn.Linear(d_model, kv_latent_dim, bias=False)  # Compress into latent KV space
        self.W_uk = nn.Linear(kv_latent_dim, d_model, bias=False)  # decompress K
        self.W_vk = nn.Linear(kv_latent_dim, d_model, bias=False)  # decompress V
        self.W_o = nn.Linear(d_model, d_model, bias=False)        # Output Projection

        self.ln = nn.LayerNorm(kv_latent_dim)

        # Retained only so that `module.absorbed_k` keeps resolving; it is no
        # longer populated.  A tensor cached here on the first forward pass
        # would go stale as soon as the weights change and would keep the
        # first call's autograd graph alive, so the absorbed projection is
        # derived from the live weights on every call instead.
        self.register_buffer('absorbed_k', None)

    def forward(self, x, kv_cache=None, past_length=0):
        """Attend from the tokens in ``x`` over the cached and new latents.

        Args:
            x: Input of shape ``(B, S, d_model)``.
            kv_cache: Optional latents returned by an earlier call, shape
                ``(B, S_past, kv_latent_dim)``; the new latents are appended
                to it along the sequence axis.
            past_length: Offset of the first query position relative to the
                first key position, used for the causal mask.  When
                ``kv_cache`` is given and this is left at ``0`` the offset is
                taken from the cache length, so every new token can see the
                whole cache.

        Returns:
            Tuple ``(out, c_kv)``: ``out`` has shape ``(B, S, d_model)`` and
            ``c_kv`` is the full latent cache ``(B, S_full, kv_latent_dim)``.
        """
        B, S, D = x.size()

        # compress x into KV space
        new_c_kv = self.ln(self.W_dkv(x))     # (B, S, latent_dim)
        if kv_cache is None:
            c_kv = new_c_kv
        else:
            c_kv = torch.cat([kv_cache, new_c_kv], dim=1)   # (B, S_full, latent_dim)
        S_full = c_kv.size(1)

        # With a cache the queries sit after the cached positions; a zero
        # offset would hide most of the cache from them.
        if kv_cache is not None and past_length == 0:
            past_length = S_full - S

        # queries and decompressed values, split into heads
        q = self.W_q(x).view(B, S, self.n_heads, self.dh).transpose(1, 2)            # (B, n_heads, S, dh)
        v = self.W_vk(c_kv).view(B, S_full, self.n_heads, self.dh).transpose(1, 2)   # (B, n_heads, S_full, dh)

        # Absorb the key decompression into the query: head h owns rows
        # [h*dh, (h+1)*dh) of W_uk.weight, so q_h @ W_uk_h lives in latent
        # space and (q_h @ W_uk_h) . c == q_h . (W_uk_h @ c) == q_h . k_h.
        w_uk = self.W_uk.weight.view(self.n_heads, self.dh, self.kv_latent_dim)
        q_latent = torch.einsum('bhsd,hdl->bhsl', q, w_uk)                # (B, n_heads, S, latent_dim)

        # scores against the latents, shared by all heads via broadcasting
        attn_score = torch.matmul(q_latent, c_kv.transpose(1, 2).unsqueeze(1))  # (B, n_heads, S, S_full)
        attn_score = attn_score / math.sqrt(self.dh)

        # causal mask: query row i (absolute position i + past_length) may
        # only see key columns j <= i + past_length
        query_pos = torch.arange(S, device=x.device).unsqueeze(1) + past_length
        key_pos = torch.arange(S_full, device=x.device).unsqueeze(0)
        attn_score = attn_score.masked_fill(key_pos > query_pos, float('-inf'))

        attn_weights = F.softmax(attn_score, dim=-1)  # (B, n_heads, S, S_full)

        # weighted sum of values per head, then concatenate heads
        context = torch.matmul(attn_weights, v)          # (B, n_heads, S, dh)
        out = context.transpose(1, 2).reshape(B, S, D)   # (B, S, n_heads * dh)

        return self.W_o(out), c_kv

### Demo: Forward Pass and KV-Cache Memory Comparison

In [18]:
def demo():
    """Run one forward pass and print a KV-cache size comparison.

    Builds a RopelessMLA model (d_model=512, 8 heads, latent width 256),
    feeds it a random batch of 5 tokens, and prints the shapes of the output
    and of the returned latent cache.  It then prints the float32 size, in
    KB, of a separate K and V cache versus a single latent cache for a
    1000-token sequence, together with their ratio.
    """
    hidden, latent, bytes_per_float = 512, 256, 4

    mla = RopelessMLA(d_model=hidden, n_heads=8, kv_latent_dim=latent)
    tokens = torch.rand(1, 5, hidden)  # batch=1, seq_len=5, d_model=512
    result, latent_cache = mla(tokens)
    print(f'Output shape: {result.shape}')
    print(f'Cache shape: {latent_cache.shape}')

    # Cache-size arithmetic for an example sequence length.
    seq_len = 1000
    batch_size = 1
    cached_tokens = batch_size * seq_len

    # Standard cache stores K and V (factor 2) at full width.
    full_kb = 2 * cached_tokens * hidden * bytes_per_float / 1024
    # Latent cache stores one compressed vector per token.
    latent_kb = cached_tokens * latent * bytes_per_float / 1024

    print(f'\nMemory comparison for seq_len={seq_len}:')
    print(f'Standard KV cache: {full_kb:.1f} KB')
    print(f'Latent KV cache: {latent_kb:.1f} KB')
    print(f'Compression: {full_kb/latent_kb:.1f}x smaller')


# Entry point: run the demo only when this code is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    demo()

Output shape: torch.Size([1, 5, 512])
Cache shape: torch.Size([1, 5, 256])

Memory comparison for seq_len=1000:
Standard KV cache: 4000.0 KB
Latent KV cache: 1000.0 KB
Compression: 4.0x smaller
