In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [14]:
class RopeLessMLA(nn.Module):
    """Multi-head latent attention (MLA) without rotary position embeddings.

    Keys and values are never cached at full width. Each token is compressed
    to a ``kv_latent_dim`` vector (``W_dkv`` followed by LayerNorm) and only
    that latent is cached. Keys are not materialised at all: the key
    up-projection ``W_uk`` is folded into the query side, so attention scores
    are computed directly against the latent cache. Values are decompressed
    from the latent with ``W_uv``.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        kv_latent_dim: width of the cached latent vector per token.
    """

    def __init__(self, d_model, n_heads, kv_latent_dim):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.kv_latent_dim = kv_latent_dim
        self.head_dim = d_model // n_heads
        self.dh = self.head_dim  # short alias used in shape comments below

        # Projections
        self.W_q   = nn.Linear(d_model, d_model, bias=False)
        self.W_dkv = nn.Linear(d_model, kv_latent_dim, bias=False)   # compress to latent
        self.W_uk  = nn.Linear(kv_latent_dim, d_model, bias=False)   # latent -> keys (model dim)
        self.W_uv  = nn.Linear(kv_latent_dim, d_model, bias=False)   # latent -> values (model dim)
        self.W_o   = nn.Linear(d_model, d_model, bias=False)

        self.ln = nn.LayerNorm(kv_latent_dim)

        # Detached snapshot of the per-head absorbed key matrix, shape
        # (H, dh, kv_latent_dim). It is refreshed on every forward call and is
        # kept only for inspection; it is never used in the computation.
        self.register_buffer("absorbed_k", None, persistent=False)

    def _shape_heads(self, x):  # (B,S,D) -> (B,H,S,dh)
        """Split the last dimension into heads and move the head axis forward."""
        B, S, _ = x.shape
        return x.view(B, S, self.n_heads, self.dh).permute(0, 2, 1, 3).contiguous()

    def forward(self, x, kv_cache=None, past_length=0):
        """Causal attention over the latent cache.

        Args:
            x: (B, S, D) new tokens.
            kv_cache: (B, S_cached, kv_latent_dim) latents returned by a
                previous call, or None when there is no history.
            past_length: accepted for backward compatibility and ignored.
                The causal offset is derived from ``kv_cache`` itself, so a
                caller cannot get a wrong mask by omitting or mis-stating it.

        Returns:
            (out, c_kv): ``out`` is (B, S, D); ``c_kv`` is the updated latent
            cache (B, S_cached + S, kv_latent_dim) to pass as ``kv_cache`` on
            the next call.
        """
        B, S, _ = x.size()

        # Build latent KV for the new tokens and append it to the cache.
        new_c_kv = self.ln(self.W_dkv(x))  # (B, S, kv)
        c_kv = new_c_kv if kv_cache is None else torch.cat([kv_cache, new_c_kv], dim=1)
        S_full = c_kv.size(1)

        # Project queries and split to heads.
        q = self._shape_heads(self.W_q(x))  # (B, H, S, dh)

        # Absorb the key up-projection into the query side. With
        # k_h = c_kv @ W_uk_h^T we have q_h @ k_h^T = (q_h @ W_uk_h) @ c_kv^T,
        # so the per-head matrix is just the matching row block of W_uk.
        # q already contains W_q, so W_q must not be multiplied in again.
        # This is a view of the live weight: it is always current and
        # gradients flow to W_uk on every step.
        absorbed_k = self.W_uk.weight.view(self.n_heads, self.dh, self.kv_latent_dim)
        self.absorbed_k = absorbed_k.detach().clone()

        q_latent = torch.matmul(q, absorbed_k)                 # (B, H, S, kv)
        keys_t = c_kv.transpose(1, 2).unsqueeze(1)             # (B, 1, kv, S_full)
        attn_scores = torch.matmul(q_latent, keys_t)           # (B, H, S, S_full)
        attn_scores = attn_scores / (self.dh ** 0.5)

        # Causal mask: the query at local index i sits at absolute position
        # (S_full - S) + i and may attend to keys at positions <= its own.
        offset = S_full - S
        query_pos = torch.arange(S, device=x.device).unsqueeze(1) + offset  # (S, 1)
        key_pos = torch.arange(S_full, device=x.device).unsqueeze(0)        # (1, S_full)
        attn_scores = attn_scores.masked_fill(key_pos > query_pos, float("-inf"))

        attn_weights = F.softmax(attn_scores, dim=-1)  # (B, H, S, S_full)

        # Decompress values from the latent and split to heads.
        v = self._shape_heads(self.W_uv(c_kv))         # (B, H, S_full, dh)

        # Weighted sum for all heads at once, then merge heads back to D.
        context = torch.matmul(attn_weights, v)        # (B, H, S, dh)
        out = context.permute(0, 2, 1, 3).reshape(B, S, self.d_model)  # (B, S, D)

        return self.W_o(out), c_kv

### Memory test: standard KV cache vs. latent KV cache

In [15]:
def demo(d_model=512, n_heads=8, kv_latent_dim=256, est_batch=2, est_tokens=10):
    """Smoke-test RopeLessMLA and print a KV-cache memory comparison.

    The forward pass runs on a small (1, 5, d_model) random input just to show
    the output and cache shapes. The memory estimate is a separate calculation
    for a hypothetical workload of ``est_batch`` sequences of ``est_tokens``
    tokens each, with no per-layer factor, using the element size of the
    cache tensor that the model actually returned.

    Args:
        d_model: model width passed to the model and used in the estimate.
        n_heads: number of attention heads passed to the model.
        kv_latent_dim: latent width passed to the model and used in the estimate.
        est_batch: batch size assumed by the memory estimate.
        est_tokens: cached sequence length assumed by the memory estimate.
    """
    model = RopeLessMLA(d_model=d_model, n_heads=n_heads, kv_latent_dim=kv_latent_dim)
    x = torch.randn(1, 5, d_model)
    with torch.no_grad():  # shapes only; no gradients needed
        out, cache = model(x)

    print(f"output : {out.shape}, cache : {cache.shape}")

    bytes_per_elem = cache.element_size()  # 4 for float32
    # Standard KV cache: B * 2 (K and V) * T * D * bytes
    std_size = est_batch * 2 * est_tokens * d_model * bytes_per_elem / 1024
    # Latent cache: B * T * latent_dim * bytes
    latent_size = est_batch * est_tokens * kv_latent_dim * bytes_per_elem / 1024
    print(f"Memory: Standard={std_size:.2f} KB, Latent={latent_size:.2f} KB")


if __name__ == "__main__":
    demo()

output : torch.Size([1, 5, 512]), cache : torch.Size([1, 5, 256])
Memory: Standard=80.00 KB, Latent=20.00 KB
