# Multi-Head Latent Attention (MLA)

In [1]:
# Building the MLA Module from scratch

import torch
import torch.nn as nn 

class MultiHeadLatentAttention(nn.Module):
    """
    Implementation of Multi-Head Latent Attention (MLA) as described
    in the DeepSeek architecture. This version focuses on the core
    "compress for storage, decompress for use" mechanism for the
    Key and Value matrices.

    Args:
        d_model: Model (embedding) dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        d_latent: Dimension of the shared compressed KV latent vector.
        dropout: Dropout probability applied to the attention weights.
        max_seq_len: Size of the precomputed causal-mask buffer (default 1024,
            the size that used to be hard-coded). Longer sequences are still
            supported; their mask is built on the fly.
    """

    def __init__(self, d_model, num_heads, d_latent, dropout=0.0, max_seq_len=1024):
        super().__init__()
        assert d_model % num_heads == 0 , "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.d_latent = d_latent  # Dimension of the compressed latent space

        # The Query projection remains standard, projecting to the full model dimension.
        self.W_q = nn.Linear(d_model, d_model)

        # KV down-projector: the "compress" step. It projects the input into a
        # small latent space shared by keys and values.
        self.W_dkv = nn.Linear(d_model, d_latent)

        # Key and Value up-projectors: the "decompress" step. They rebuild the
        # full-size (all heads) K and V from the latent vector.
        self.W_uk = nn.Linear(d_latent, d_model)
        self.W_uv = nn.Linear(d_latent, d_model)

        # Final output projection, standard for multi-head attention.
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        # Causal mask (True = future position to hide), precomputed up to max_seq_len.
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(1, 1, max_seq_len, max_seq_len), diagonal=1).bool(),
        )

    def _causal_mask(self, seq_len, device):
        """Return a (1, 1, seq_len, seq_len) bool mask that is True above the diagonal."""
        if seq_len <= self.mask.size(-1):
            return self.mask[:, :, :seq_len, :seq_len]
        # Longer than the precomputed buffer: build the mask on demand instead of
        # failing with a broadcast error in masked_fill.
        return torch.triu(
            torch.ones(1, 1, seq_len, seq_len, device=device), diagonal=1
        ).bool()

    def forward(self, x):
        """
        Apply causal multi-head latent attention.

        Args:
            x: Input of shape (batch, seq_len, d_model).

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.shape

        # 1. Query path (unchanged from standard MHA).
        q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)

        # 2. Key/Value path (the MLA innovation).
        # 2a: Down-project to the latent space. During inference this is the
        #     only KV tensor that would need to be cached.
        c_kv = self.W_dkv(x)  # (batch, seq_len, d_latent)

        # 2b: Up-project to full K and V; these are computed on the fly, not cached.
        k = self.W_uk(c_kv).view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        v = self.W_uv(c_kv).view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)

        # 3. Standard scaled dot-product attention with a causal mask.
        attn_scores = (q @ k.transpose(-2, -1)) / (self.d_head ** 0.5)
        attn_scores = attn_scores.masked_fill(
            self._causal_mask(seq_len, x.device), float('-inf')
        )

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vector = (
            (attn_weights @ v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        )

        # 4. Final output projection.
        return self.W_o(context_vector)
    
# --- Usage Example ---
# Small demo configuration for the plain MLA layer.
d_model, num_heads = 512, 8
d_latent = 128  # Latent dimension must be smaller than d_model
batch_size, seq_len = 4, 64

# Build the layer, then push a random batch through it.
mla_layer = MultiHeadLatentAttention(d_model, num_heads, d_latent)
dummy_input = torch.randn(batch_size, seq_len, d_model)
output = mla_layer(dummy_input)

# Report success and the (unchanged) tensor shapes.
print("\n".join((
    "MLA Layer successful!",
    f"Input shape: {dummy_input.shape}",
    f"Output shape: {output.shape}",
)))

MLA Layer successful!
Input shape: torch.Size([4, 64, 512])
Output shape: torch.Size([4, 64, 512])


# Fused MLA with Decoupled RoPE

In [3]:
# Building the Fused MLA and Decoupled RoPE Module

import torch
import torch.nn as nn
import math

class RotaryPositionalEncoding(nn.Module):
    """
    Helper module to apply Rotary Positional Encoding (RoPE).
    This is not added to the embeddings but applied directly to
    the Query and Key vectors.

    Args:
        d_head: Size of the last dimension of the tensors to rotate. Must be
            even, because consecutive feature pairs are rotated together.
        max_seq_len: Number of positions whose rotation factors are
            precomputed; inputs longer than this are not supported.
        base: Base of the geometric frequency schedule (10000 in the
            original RoPE formulation, which remains the default).

    Raises:
        ValueError: If ``d_head`` is odd.
    """

    def __init__(self, d_head, max_seq_len=2048, base=10000.0):
        super().__init__()
        if d_head % 2 != 0:
            # Fail early instead of with an opaque view() error inside forward().
            raise ValueError(f"d_head must be even for RoPE, got {d_head}")

        # theta_i = base^(-2i / d_head) for i = 0 .. d_head/2 - 1
        theta = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
        self.register_buffer('theta', theta)

        # Frequency terms m * theta_i for every position m: (max_seq_len, d_head/2)
        positions = torch.arange(max_seq_len).unsqueeze(1)
        freqs = positions * self.theta.unsqueeze(0)

        # Unit complex numbers cos(freqs) + i*sin(freqs), used as rotations.
        self.register_buffer("freqs_cis", torch.polar(torch.ones_like(freqs), freqs))

    def forward(self, x):
        """
        Rotate each feature pair of ``x`` by its position-dependent angle.

        Args:
            x: Tensor of shape (batch, num_heads, seq_len, d_head) with
               seq_len <= max_seq_len.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        seq_len = x.shape[2]

        # View consecutive feature pairs as complex numbers:
        # (batch, num_heads, seq_len, d_head/2, 2) -> complex (..., d_head/2)
        x_complex = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))

        # Rotation factors for the current length, broadcast over batch and heads:
        # (1, 1, seq_len, d_head/2)
        freqs_cis = self.freqs_cis[:seq_len, :].unsqueeze(0).unsqueeze(0)

        # Complex multiplication rotates each pair by the angle m * theta_i.
        x_rotated = torch.view_as_real(x_complex * freqs_cis)

        # Merge the (d_head/2, 2) pairs back into d_head and restore the dtype.
        return x_rotated.flatten(3).type_as(x)
    

class DeepSeekAttention(nn.Module):
    """
    The full, state-of-the-art attention mechanism from DeepSeek, combining
    Multi-Head Latent Attention (MLA) with Decoupled Rotary Positional
    Encoding (RoPE).

    For every head, the attention logit is the dot product of the concatenated
    ``[content; position]`` query and key, i.e. the sum of the content and
    position dot products. It is scaled by ``1 / sqrt(d_head + d_rope)``, the
    size of that concatenated vector.

    Args:
        d_model: Model dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        d_latent: Dimension of the compressed content KV latent.
        d_rope: Per-head dimension of the RoPE query/key vectors. It may differ
            from ``d_head`` and must be even (RoPE rotates feature pairs).
        dropout: Dropout probability applied to the attention weights.
        max_seq_len: Longest sequence covered by the causal mask and the
            precomputed RoPE table.

    NOTE(review): DeepSeek-V2 shares a single RoPE key across all heads; here
    ``W_k_pos`` produces an independent ``d_rope``-dim key per head.
    """

    def __init__(self, d_model, num_heads, d_latent, d_rope, dropout=0.0, max_seq_len=2048):
        super().__init__()
        assert d_model % num_heads == 0 , "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.d_latent = d_latent
        self.d_rope = d_rope  # Per-head dimension of the positional vectors

        # --- A: Content path (pure MLA) ---
        self.W_q_content = nn.Linear(d_model, d_model)
        self.W_dkv_content = nn.Linear(d_model, d_latent)
        self.W_uk_content = nn.Linear(d_latent, d_model)
        self.W_uv_content = nn.Linear(d_latent, d_model)

        # --- B: Position path (RoPE applied); d_rope features per head ---
        self.W_k_pos = nn.Linear(d_model, d_rope * num_heads)
        self.W_q_pos = nn.Linear(d_model, d_rope * num_heads)

        # RoPE module that applies the rotations
        self.rope = RotaryPositionalEncoding(d_rope, max_seq_len)

        # --- C: Final output projection ---
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        # Causal mask: True marks future positions to hide.
        self.register_buffer('mask', torch.triu(torch.ones(1, 1, max_seq_len, max_seq_len), diagonal=1).bool())

    def _split_heads(self, t, head_dim):
        """Reshape (batch, seq_len, num_heads * head_dim) -> (batch, num_heads, seq_len, head_dim)."""
        batch_size, seq_len, _ = t.shape
        return t.view(batch_size, seq_len, self.num_heads, head_dim).transpose(1, 2)

    def forward(self, x):
        """
        Apply causal MLA attention with decoupled RoPE.

        Args:
            x: Input of shape (batch, seq_len, d_model), seq_len <= max_seq_len.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.shape

        # --- A: Content path (position-agnostic) ---
        q_c = self._split_heads(self.W_q_content(x), self.d_head)
        c_kv = self.W_dkv_content(x)  # This is what gets cached for the content path.
        k_c = self._split_heads(self.W_uk_content(c_kv), self.d_head)
        v_c = self._split_heads(self.W_uv_content(c_kv), self.d_head)

        # --- B: Position path ---
        # Split with d_rope (not d_head): the projections emit d_rope features
        # per head, so d_head here would break whenever d_rope != d_head.
        q_r = self.rope(self._split_heads(self.W_q_pos(x), self.d_rope))
        k_r = self.rope(self._split_heads(self.W_k_pos(x), self.d_rope))  # cached for the position path

        # --- C: Combine both paths ---
        # content + position dot products equal q.k over the concatenated
        # (d_head + d_rope)-dim vectors, so normalize by that combined size.
        content_scores = q_c @ k_c.transpose(-2, -1)
        position_scores = q_r @ k_r.transpose(-2, -1)
        attn_scores = (content_scores + position_scores) / ((self.d_head + self.d_rope) ** 0.5)

        # --- D: Masking, softmax, output ---
        attn_scores = attn_scores.masked_fill(
            self.mask[:, :, :seq_len, :seq_len], float('-inf')
        )

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Only the content values (v_c) are aggregated; RoPE affects the weights only.
        context_vector = (attn_weights @ v_c).transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        return self.W_o(context_vector)
    
# --- Usage Example ---
# Demo configuration for the fused MLA + decoupled-RoPE layer.
d_model, num_heads = 512, 8
d_latent = 128
d_rope = 64  # Dimension for RoPE, typically d_head or smaller
batch_size, seq_len = 4, 64

# Build the full attention layer and run a random batch through it.
deepseek_attn_layer = DeepSeekAttention(d_model, num_heads, d_latent, d_rope)
dummy_input = torch.randn(batch_size, seq_len, d_model)
output = deepseek_attn_layer(dummy_input)

# One line per message, same as three separate print calls.
print(
    "DeepSeekAttention Layer successful!",
    f"Input shape: {dummy_input.shape}",
    f"Output shape: {output.shape}",
    sep="\n",
)


DeepSeekAttention Layer successful!
Input shape: torch.Size([4, 64, 512])
Output shape: torch.Size([4, 64, 512])
