In [425]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass
from collections import OrderedDict

from ohara.modules.norm import RMSNorm

from ohara.embedings_pos.rotatry import precompute_freqs_cis
from ohara.embedings_pos.rotatry import apply_rope

from torch import Tensor


from rich import print, traceback
traceback.install()


@dataclass
class Config(OrderedDict):
    """Hyper-parameters shared by the attention modules in this notebook.

    ``vocab_size``, ``seq_len``, ``d_model``, ``num_heads`` and ``head_dim``
    have no default and must always be supplied; every other field is optional.
    """

    vocab_size: int
    seq_len: int
    d_model: int

    # DeepSeek-V2 picks num_heads * head_dim larger than d_model
    # (the attention width is expanded to dim * 3.2).
    num_heads: int
    head_dim: int

    num_layers: int = 4
    dropout: float = 0.2
    bias: bool = False
    weight_tying: bool = False

    activation: str = "silu"  # e.g. "relu", "gelu", "silu"
    mlp: str = "GLU"  # "MLP" or "GLU"

    # Rotary embeddings only cover this many channels of each query/key head.
    rope_head_dim: int = None

    # Low-rank bottleneck sizes. The query rank is larger than the key/value
    # rank because the query carries more information;
    # DeepSeek-V2 uses q_lora_rank = 3 * kv_lora_rank.
    kv_lora_rank: int = None
    q_lora_rank: int = None



# ======================================================================================
# ||>>>> Note <<<<||
# --------------------------------------------------------------------------------------
# The reference code does a few things differently from the paper, e.g.
# 1. k_rope is projected from d_model (hidden_dim), while in the paper it comes from compress_kv
# 2. q_rope comes from compress_q (in both paper and code)
# 3. there is a layer norm on the compressed q and kv
# 4. norm is applied to q_nope, q_rope, k_nope and v
#    but not to k_rope (idk whether the rope part of k should be normalized too)
# 5. there is no inference-merged code for MLA
# ======================================================================================

# --- MLA ---
class MultiHeadLatentAttention(nn.Module):
    """
    Multi Head Latent Attention
    paper: https://arxiv.org/pdf/2405.04434

    TLDR:
    kv are low rank; this variant of attention projects q, k, v through
    low-rank bottlenecks to save memory.

    Every query/key head is the concatenation of a "nope" slice
    (``head_dim - rope_head_dim`` channels, no positional encoding) and a
    "rope" slice (``rope_head_dim`` channels). The rope key has a single head
    that is shared by all query heads.

    by joey00072 (https://github.com/joey00072)
    """

    def __init__(self, config: Config):
        """Build the projections.

        ``config.rope_head_dim``, ``config.q_lora_rank`` and
        ``config.kv_lora_rank`` are used as layer sizes and therefore must be
        set (their ``Config`` default is ``None``).
        """
        super().__init__()

        self.config = config
        self.dim = config.d_model
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim
        self.q_lora_rank = config.q_lora_rank
        self.kv_lora_rank = config.kv_lora_rank

        # (attention_dim == num_head*head_dim) > d_model in deepseekv2
        self.attention_dim = self.num_heads * self.head_dim
        self.rope_head_dim = config.rope_head_dim
        self.nope_head_dim = config.head_dim - config.rope_head_dim

        # query compression
        self.compress_q_linear = nn.Linear(self.dim, self.q_lora_rank, bias=config.bias)  # W_DQ
        self.decompress_q_nope = nn.Linear(self.q_lora_rank, self.nope_head_dim * self.num_heads, bias=config.bias)
        self.decompress_q_rope = nn.Linear(self.q_lora_rank, self.rope_head_dim * self.num_heads, bias=config.bias)

        # key and value compression
        self.compress_kv_linear = nn.Linear(self.dim, self.kv_lora_rank, bias=config.bias)  # W_DKV
        self.decompress_k_nope = nn.Linear(self.kv_lora_rank, self.nope_head_dim * self.num_heads, bias=config.bias)
        self.decompress_v_linear = nn.Linear(self.kv_lora_rank, self.head_dim * self.num_heads, bias=config.bias)

        # single-head rope key, projected straight from the hidden state
        self.k_rope_linear = nn.Linear(self.dim, self.rope_head_dim, bias=config.bias)

        self.q_norm = RMSNorm(self.q_lora_rank)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        # self.rope_norm = RMSNorm(self.rope_head_dim) # not in deepseekv2

        self.proj = nn.Linear(self.attention_dim, self.dim, bias=config.bias)

    def forward(self, x: Tensor, freqs_cis: Tensor):
        """Causal self-attention over ``x``.

        Args:
            x: input of shape (batch, seq_len, d_model).
            freqs_cis: rotary table, forwarded unchanged to ``apply_rope``.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.shape

        compressed_q = self.compress_q_linear(x)
        norm_q = self.q_norm(compressed_q)
        query_nope: Tensor = self.decompress_q_nope(norm_q)
        query_rope: Tensor = self.decompress_q_rope(norm_q)

        compressed_kv = self.compress_kv_linear(x)
        norm_kv = self.kv_norm(compressed_kv)
        key_nope: Tensor = self.decompress_k_nope(norm_kv)
        value: Tensor = self.decompress_v_linear(norm_kv)

        key_rope: Tensor = self.k_rope_linear(x)
        # norm_rope = self.rope_norm(key_rope)

        # (B, T, H*d) -> (B, H, T, d)
        query_nope = query_nope.view(batch_size, seq_len, self.num_heads, self.nope_head_dim).transpose(1, 2)
        query_rope = query_rope.view(batch_size, seq_len, self.num_heads, self.rope_head_dim).transpose(1, 2)

        key_rope = key_rope.view(batch_size, seq_len, 1, self.rope_head_dim).transpose(1, 2)
        key_nope = key_nope.view(batch_size, seq_len, self.num_heads, self.nope_head_dim).transpose(1, 2)

        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # NOTE(review): apply_rope lives outside this file. Confirm that the
        # (k, q) unpacking order matches what it returns for a (query, key)
        # call, and that it accepts the (batch, heads, seq, dim) layout used here.
        k_rope, q_rope = apply_rope(query_rope, key_rope, cis=freqs_cis)

        # Concatenate the nope and rope slices of every head. `expand` replicates
        # a single-head rope tensor across all heads without copying, and
        # torch.cat keeps the dtype/device of the activations (a bare
        # torch.empty buffer would silently fall back to float32).
        rope_shape = (batch_size, self.num_heads, seq_len, self.rope_head_dim)
        q_recombined = torch.cat([query_nope, q_rope.expand(rope_shape)], dim=-1)
        k_recombined = torch.cat([key_nope, k_rope.expand(rope_shape)], dim=-1)

        output = F.scaled_dot_product_attention(q_recombined, k_recombined, value, is_causal=True)

        # (B, H, T, D) -> (B, T, H, D) -> (B, T, H*D). The transpose is required:
        # viewing the (B, H, T, D) tensor directly would mix heads and positions.
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.attention_dim)

        output = self.proj(output)

        return output


class MLA_Inference(MultiHeadLatentAttention):
    """Inference variant of MLA with the decompression matrices merged.

    ``inference_merge`` folds, per head,

    * the nope query and key decompressions into ``WdQK``
      (heads, q_lora_rank, kv_lora_rank), and
    * the value decompression and the output projection into ``WdVO``
      (heads, kv_lora_rank, d_model),

    so ``forward`` can work directly on the normalised latents.
    """

    def __init__(self, config: Config):
        super().__init__(config)
        self.inference_merged = False

    def inference_merge(self):
        """Pre-multiply the per-head weights and register them as buffers.

        Raises:
            NotImplementedError: if a decompression layer has a bias; the
                merged products below do not account for bias terms.
        """
        fused_layers = (self.decompress_q_nope, self.decompress_k_nope, self.decompress_v_linear)
        if any(layer.bias is not None for layer in fused_layers):
            raise NotImplementedError("inference_merge only supports bias-free decompression layers")

        heads = self.num_heads

        # nn.Linear stores weight as (out_features, in_features) and the output
        # features are laid out head-major (forward views them as (heads, dim)).
        Wd_Qnope = self.decompress_q_nope.weight.detach()  # (H*nope, q_rank)
        Wd_Knope = self.decompress_k_nope.weight.detach()  # (H*nope, kv_rank)
        Wd_V = self.decompress_v_linear.weight.detach()  # (H*D, kv_rank)
        W_proj = self.proj.weight.detach()  # (d_model, H*D)

        # q_h @ k_h^T == norm_q @ (Wq_h^T @ Wk_h) @ norm_kv^T
        Wq_heads = Wd_Qnope.reshape(heads, self.nope_head_dim, self.q_lora_rank).transpose(-2, -1)
        Wk_heads = Wd_Knope.reshape(heads, self.nope_head_dim, self.kv_lora_rank)
        WdQK = Wq_heads @ Wk_heads  # (H, q_rank, kv_rank)

        # proj(concat_h(p_h @ v_h)) == sum_h p_h @ norm_kv @ (Wv_h^T @ Wo_h)
        Wv_heads = Wd_V.reshape(heads, self.head_dim, self.kv_lora_rank).transpose(-2, -1)
        Wo_heads = W_proj.T.reshape(heads, self.head_dim, self.dim)
        WdVO = Wv_heads @ Wo_heads  # (H, kv_rank, d_model)

        self.register_buffer("WdQK", WdQK)
        self.register_buffer("WdVO", WdVO)

        self.inference_merged = True

    def forward(self, x: Tensor, freqs_cis: Tensor):
        """Causal attention using the merged weights.

        Args:
            x: input of shape (batch, seq_len, d_model).
            freqs_cis: rotary table, forwarded unchanged to ``apply_rope``.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        assert self.inference_merged, "model is not merged run .inference_merge() first"

        batch_size, seq_len, _ = x.shape

        # The decompression layers consume the *normalised* latents, so the
        # merged weights must be applied to the normalised latents as well.
        norm_q = self.q_norm(self.compress_q_linear(x))  # (B, T, q_rank)
        norm_kv = self.kv_norm(self.compress_kv_linear(x))  # (B, T, kv_rank)

        query_rope: Tensor = self.decompress_q_rope(norm_q)
        key_rope: Tensor = self.k_rope_linear(x)

        query_rope = query_rope.view(batch_size, seq_len, self.num_heads, self.rope_head_dim).transpose(1, 2)
        key_rope = key_rope.view(batch_size, seq_len, 1, self.rope_head_dim).transpose(1, 2)

        # NOTE(review): apply_rope lives outside this file. Confirm that the
        # (k, q) unpacking order matches what it returns for a (query, key)
        # call, and that it accepts the (batch, heads, seq, dim) layout used here.
        k_rope, q_rope = apply_rope(query_rope, key_rope, cis=freqs_cis)

        # The single-head operand broadcasts over the head axis: (B, H, T, T)
        attn_rope = q_rope @ k_rope.transpose(-2, -1)

        # (B, 1, T, q_rank) @ (H, q_rank, kv_rank) @ (B, 1, kv_rank, T) -> (B, H, T, T)
        attn_nope = norm_q.unsqueeze(1) @ self.WdQK @ norm_kv.unsqueeze(1).transpose(-2, -1)

        # Same 1/sqrt(head_dim) scaling that scaled_dot_product_attention applies by default.
        attn = (attn_rope + attn_nope) * self.head_dim ** -0.5

        # Causal mask, built on the same device as the scores.
        causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=attn.device).tril()
        attn = attn.masked_fill(~causal, float("-inf"))

        # Attend over the shared latent, then decompress + project in one step.
        latent = attn.softmax(dim=-1) @ norm_kv.unsqueeze(1)  # (B, H, T, kv_rank)
        output = torch.einsum("bhtr,hrd->btd", latent, self.WdVO)

        if self.proj.bias is not None:
            output = output + self.proj.bias

        return output
        
        

In [426]:
# Toy sizes: the attention width (num_heads * head_dim) is larger than d_model
# and the query rank is three times the key/value rank.
num_heads, head_dim = 20, 128
d_model = 1024
rope_head_dim = 32
kv_lora_rank = 64
q_lora_rank = kv_lora_rank * 3

config = Config(
    vocab_size=30522,
    seq_len=2048,
    d_model=d_model,
    num_heads=num_heads,
    head_dim=head_dim,
    rope_head_dim=rope_head_dim,
    kv_lora_rank=kv_lora_rank,
    q_lora_rank=q_lora_rank,
)

# Random (batch=2, seq=10, d_model) input plus the table returned by precompute_freqs_cis.
x = torch.randn(2, 10, d_model)
freqs_cis = precompute_freqs_cis(config.rope_head_dim, config.seq_len)


In [427]:
# Reference module and merged-inference variant, then copy the reference
# weights across so both hold identical parameters.
mla, mla_inference = MultiHeadLatentAttention(config), MLA_Inference(config)
mla_inference.load_state_dict(mla.state_dict())


<All keys matched successfully>

In [428]:
batch_size, seq_len, _ = x.size()

# Compress, then normalise, the query and key/value latents.
cq = mla.q_norm(mla.compress_q_linear(x))
ckv = mla.kv_norm(mla.compress_kv_linear(x))

# Decompress the position-free ("nope") slices and split them per head:
# (B, T, H*d) -> (B, T, H, d) -> (B, H, T, d)
q_nope = mla.decompress_q_nope(cq).view(batch_size, seq_len, mla.num_heads, mla.nope_head_dim).transpose(1, 2)
k_nope = mla.decompress_k_nope(ckv).view(batch_size, seq_len, mla.num_heads, mla.nope_head_dim).transpose(1, 2)

print(q_nope.shape)
print(k_nope.shape)

In [429]:
# Lower-triangular boolean mask, one copy per head (broadcasts over the batch).
mask = torch.ones(mla.num_heads, seq_len, seq_len, dtype=torch.bool).tril()

# Raw nope scores -> hide the positions above the diagonal -> row-wise softmax.
attn_nope = torch.matmul(q_nope, k_nope.transpose(-2, -1))
attn_nope = attn_nope.masked_fill(~mask, float("-inf")).softmax(dim=-1)


In [430]:

# Detached copies of the nope decompression weights, free to reshape without
# touching the module parameters.
Wd_Qnope = mla.decompress_q_nope.weight.clone().detach()
Wd_Knope = mla.decompress_k_nope.weight.clone().detach()

In [431]:
# nn.Linear keeps its weight as (out_features, in_features).
out_shape, in_shape = Wd_Qnope.size()
print(in_shape, out_shape)



In [432]:
print(Wd_Qnope.T.shape)

# Random latent with the same shape as cq, pushed through the nope decompression.
xx = torch.rand_like(cq)
q_nope_x = mla.decompress_q_nope(xx)
q_nope = q_nope_x  # .view(batch_size, seq_len, mla.num_heads, mla.nope_head_dim)

Wd_Qnope = mla.decompress_q_nope.weight.detach().clone()

# Take the row from `xx` -- the tensor q_nope_x was computed from -- so both
# sides of the comparison project the same input. (Reading it from `cq`
# compared the projections of two different inputs and could never match.)
inp = xx[0][0].reshape(1, -1)
o1 = inp @ Wd_Qnope.T  # .reshape(mla.num_heads, mla.nope_head_dim)

# A bias-free Linear is a right-multiplication by the transposed weight.
assert o1.shape == q_nope[0][0].reshape(1, -1).shape
assert torch.allclose(o1, q_nope[0][0].reshape(1, -1), atol=1e-6)

In [433]:
# With bias disabled, the Linear layer equals a matmul with the transposed weight ...
WqT = mla.decompress_q_nope.weight.T
assert torch.allclose(mla.decompress_q_nope(xx), torch.matmul(xx, WqT))

# ... and this still holds after splitting the features into (heads, nope_head_dim).
olx = mla.decompress_q_nope(xx).reshape(batch_size, seq_len, mla.num_heads, mla.nope_head_dim)
ocx = torch.matmul(xx, WqT).reshape(batch_size, seq_len, mla.num_heads, mla.nope_head_dim)
assert torch.allclose(olx, ocx)


In [434]:
si, so = WqT.shape
idx = 4

# Output of head `idx` taken from the full projection ...
hl1 = olx[..., idx, :]
# ... versus projecting with only the weight columns that belong to that head.
hc1 = torch.matmul(xx, WqT.reshape(si, num_heads, -1).select(1, idx))
eo = torch.einsum("bld,dj->blj", xx, WqT.reshape(si, num_heads, -1).select(1, idx))

assert torch.allclose(hl1, eo)


In [435]:
WqT.reshape(si, num_heads, -1).permute(1, 0, 2).shape

torch.Size([20, 192, 96])

In [436]:
# Per-head query decompression weights: (in, H*d) -> (in, H, d) -> (H, in, d)
H_dQ = mla.decompress_q_nope.weight.T.reshape(si, num_heads, -1).permute(1, 0, 2)

# Reference: the ordinary Linear output split into heads.
olx = mla.decompress_q_nope(xx).reshape(batch_size, seq_len, mla.num_heads, mla.nope_head_dim)
# Same result obtained by contracting against the per-head weights.
eo = torch.einsum("bld,hdj->blhj", xx, H_dQ)

print(olx.shape, eo.shape)
assert torch.allclose(olx, eo)

In [437]:
# Same per-head split for the key decompression weights.
print(mla.decompress_k_nope.weight.T.shape)
xi, xo = mla.decompress_k_nope.weight.T.shape
H_dKV = mla.decompress_k_nope.weight.T.reshape(xi, num_heads, -1).permute(1, 0, 2)
print(H_dKV.shape)

In [438]:
(H_dQ.shape, H_dKV.transpose(-2, -1).shape)


(torch.Size([20, 192, 96]), torch.Size([20, 96, 64]))

In [439]:
# Per-head product of the query and (transposed) key decompression weights.
H_dQK = torch.matmul(H_dQ, H_dKV.transpose(-2, -1))

H_dQK.shape

torch.Size([20, 192, 64])

In [440]:
cq.size()

torch.Size([2, 10, 192])

## Merge Wv and Wo

In [441]:
# Fresh random input with the shapes used so far.
x = torch.randn(batch_size, seq_len, d_model)
batch_size, seq_len, _ = x.size()

# Normalised latents.
cq = mla.q_norm(mla.compress_q_linear(x))
ckv = mla.kv_norm(mla.compress_kv_linear(x))

# Per-head nope queries / keys: (B, T, H*d) -> (B, H, T, d)
q_nope = mla.decompress_q_nope(cq).view(batch_size, seq_len, mla.num_heads, mla.nope_head_dim).transpose(1, 2)
k_nope = mla.decompress_k_nope(ckv).view(batch_size, seq_len, mla.num_heads, mla.nope_head_dim).transpose(1, 2)

# Work on copies that are cut off from the autograd graph.
q = q_nope.detach().clone()
k = k_nope.detach().clone()

# Unscaled, unmasked scores.
attn = torch.matmul(q, k.transpose(-2, -1))

attn.shape

torch.Size([2, 20, 10, 10])

In [514]:

Wv = mla.decompress_v_linear.weight
# NOTE: nn.Linear stores its weight as (out_features, in_features), so `wv_in`
# actually holds the output width and `wv_out` the input width of the layer.
wv_in, wv_out = Wv.shape

# Per-head value weights: (in, H*D) -> (in, H, D) -> (H, in, D)
h_Wv = Wv.T.reshape(wv_out, mla.num_heads, -1).permute(1, 0, 2)

# Reference values: plain matmul, then split into heads.
value = torch.matmul(ckv, Wv.T).view(batch_size, seq_len, mla.num_heads, mla.head_dim).transpose(1, 2)

print(f"ckv.shape: {ckv.shape}, h_Wv.shape: {h_Wv.shape}")
# Same values produced directly per head.
v_shaped = torch.einsum("bld,hdj->bhlj", ckv, h_Wv)
print(v_shaped.shape)

attn_out = torch.matmul(attn, value)
assert torch.allclose(attn_out, torch.matmul(attn, v_shaped))
print(f"attn_out.shape: {attn_out.shape}")

In [491]:
Wo = mla.proj.weight
# nn.Linear stores its weight as (out_features, in_features), so `o_in` holds
# the output width (d_model) and `o_out` the input width (heads * head_dim).
o_in, o_out = Wo.shape
print(o_in, o_out)

# The rows of Wo.T follow the flattened (head, head_dim) order produced by the
# reshape below, so the per-head split is (heads, head_dim, d_model).
# reshape(-1, heads, d_model) would interleave rows of different heads.
h_Wo = Wo.T.reshape(mla.num_heads, -1, o_in)
print(f"h_Wo.shape: {h_Wo.shape}")

# Reference: flatten the heads and apply the projection as one matmul.
output = attn_out.transpose(1, 2).reshape(batch_size, seq_len, num_heads * head_dim) @ Wo.T
print(f"output.shape: {output.shape}")

# Same projection, contracting head and head_dim against the per-head weights.
# (The previous "blhi,ij->blj" equation had four subscripts for a 3-D operand
# and raised before the assertion could run.)
flat_output = torch.einsum("blhi,hij->blj", attn_out.transpose(1, 2), h_Wo)
print("x", flat_output.shape)
assert torch.allclose(output, flat_output, atol=1e-4)

In [474]:
attn_out.permute(0, 2, 1, 3).shape

torch.Size([2, 10, 20, 128])

In [478]:
torch.arange(20).view(2, 5, 2)

tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5],
         [ 6,  7],
         [ 8,  9]],

        [[10, 11],
         [12, 13],
         [14, 15],
         [16, 17],
         [18, 19]]])

In [508]:
# Toy check: flattening the last two axes before a matmul gives the same result
# as contracting both axes against the weight reshaped to (4, 5, 10).
x = torch.rand(2, 3, 4, 5)
w = torch.rand(20, 10)

out = torch.matmul(x.flatten(start_dim=2), w)

# einsum
w2 = w.view(4, 5, 10)
out2 = torch.einsum("blcv,cvj->blj", x, w2)

assert torch.allclose(out, out2)

In [535]:
# Heads next to the feature axis: (B, H, T, D) -> (B, T, H, D)
vv = attn_out.transpose(1, 2)
print(f"attn_out.shape: {vv.shape}, mla.proj.weight.T.shape: {mla.proj.weight.T.shape}")
w = mla.proj.weight.T

# Reference: flatten the heads and project with a single matmul. The leading
# sizes come from batch_size / seq_len instead of the literals 2 and 10, so the
# cell keeps working when the input shape changes.
out = vv.reshape(batch_size, seq_len, -1) @ w
print(out.shape)

# Per-head view of the projection weight: (H*D, d_model) -> (H, D, d_model)
w2 = w.reshape(num_heads, head_dim, -1)
print(f"vv.shape: {vv.shape}, w2.shape: {w2.shape}")

# Contract head (c) and head_dim (v) directly on the (B, H, T, D) tensor.
out2 = torch.einsum("bclv,cvj->blj", attn_out, w2)
print(out2.shape)
assert torch.allclose(out, out2)

# Final

In [536]:
# Self-contained setup: sizes, config, random input, rope table and both modules.
num_heads, head_dim = 20, 128
d_model = 1024
rope_head_dim = 32
kv_lora_rank = 64
q_lora_rank = kv_lora_rank * 3

config = Config(
    vocab_size=30522,
    seq_len=2048,
    d_model=d_model,
    num_heads=num_heads,
    head_dim=head_dim,
    rope_head_dim=rope_head_dim,
    kv_lora_rank=kv_lora_rank,
    q_lora_rank=q_lora_rank,
)

x = torch.randn(2, 10, d_model)
freqs_cis = precompute_freqs_cis(config.rope_head_dim, config.seq_len)

# Both modules start from the same weights.
mla, mla_inference = MultiHeadLatentAttention(config), MLA_Inference(config)
mla_inference.load_state_dict(mla.state_dict())

<All keys matched successfully>

In [550]:

def test_mla_inference():
    """Smoke-test the rope branch of the attention scores.

    Builds a reference module and a merged-inference module with identical
    weights, projects a random input to the rope queries/keys and checks that
    the single-head rope key broadcasts against every query head.
    Returns nothing; the shapes are printed and asserted.
    """
    d_model = 1024
    num_heads = 20
    head_dim = 128
    kv_lora_rank = 64
    q_lora_rank = 3 * kv_lora_rank
    rope_head_dim = 32

    config = Config(
        vocab_size=30522,
        d_model=d_model,
        seq_len=2048,
        num_heads=num_heads,
        head_dim=head_dim,
        q_lora_rank=q_lora_rank,
        kv_lora_rank=kv_lora_rank,
        rope_head_dim=rope_head_dim,
    )

    x = torch.randn(2, 10, d_model)
    mla = MultiHeadLatentAttention(config)
    mla_inference = MLA_Inference(config)

    # Both modules must expose the same parameter names.
    mla_inference.load_state_dict(mla.state_dict())

    batch_size, seq_len, _ = x.shape

    cq = mla.q_norm(mla.compress_q_linear(x))

    q_rope = mla.decompress_q_rope(cq)
    k_rope = mla.k_rope_linear(x)

    # (B, T, H*r) -> (B, H, T, r); the key has a single shared head
    q_rope = q_rope.view(batch_size, seq_len, mla.num_heads, mla.rope_head_dim).transpose(1, 2)
    k_rope = k_rope.view(batch_size, seq_len, 1, mla.rope_head_dim).transpose(1, 2)

    # k_rope = torch.repeat_interleave(k_rope, mla.num_heads, dim=1)
    print(f"q_rope.shape: {q_rope.shape}, k_rope.shape: {k_rope.shape}")
    attn_rope = q_rope @ k_rope.transpose(-2, -1)

    # Broadcasting replicates the single key head across all query heads.
    assert attn_rope.shape == (batch_size, mla.num_heads, seq_len, seq_len)


test_mla_inference()
