In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [11]:
class ropeless(nn.Module):
    """Causal multi-head latent attention (MLA) without rotary embeddings.

    Each token is compressed by ``W_dkv`` (followed by LayerNorm) into a
    ``kv_latent_dim``-sized latent vector, and only that latent is cached.
    Values are up-projected from the latents with ``W_uv``.  Keys are never
    materialised: the per-head block of ``W_uk`` is folded into the projected
    queries, so scores are computed directly against the latents.  This is
    algebraically the same as ``softmax(Q K^T / sqrt(d_head)) V`` with
    ``Q = W_q(x)``, ``K = W_uk(latents)`` and ``V = W_uv(latents)``.

    The folding is redone from the current weights on every call rather than
    cached on the module, so it stays correct while the weights are being
    trained and adds nothing to ``state_dict``.

    Args:
        d_model: Feature size of the input and output; must be divisible by
            ``num_head``.
        kv_latent_dim: Size of the cached per-token latent.
        num_head: Number of attention heads.

    Raises:
        ValueError: If ``num_head`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self,d_model,kv_latent_dim,num_head):
        super().__init__()
        if num_head <= 0 or d_model % num_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_head ({num_head})"
            )

        self.d_model=d_model
        self.kv_latent_dim=kv_latent_dim
        self.num_heads=num_head
        self.dim_each_head=d_model//num_head

        self.W_q=nn.Linear(d_model,d_model,bias=False)          # query projection
        self.W_dkv=nn.Linear(d_model,kv_latent_dim,bias=False)  # down-projection to the shared KV latent
        self.W_uk=nn.Linear(kv_latent_dim,d_model,bias=False)   # key up-projection (folded into the queries)
        self.W_uv=nn.Linear(kv_latent_dim,d_model,bias=False)   # value up-projection
        self.W_o=nn.Linear(d_model,d_model,bias=False)          # output projection

        # Attribute name (including its spelling) is kept so parameter keys
        # in state_dict do not change.
        self.layernom=nn.LayerNorm(kv_latent_dim)

    def forward(self,x,kv_cache=None,past_length=0):
        """Attend from the tokens in ``x`` over the cached and new latents.

        Args:
            x: Tensor of shape ``(batch, new_tokens, d_model)``.
            kv_cache: Optional latents returned by a previous call, shape
                ``(batch, past_tokens, kv_latent_dim)``.  New latents are
                appended along dim 1.
            past_length: Offset of the first token of ``x`` in the full
                sequence.  Query ``i`` may attend to keys
                ``0 .. i + past_length``, so for incremental decoding pass
                ``kv_cache.size(1)``; with the default ``0`` the cached
                positions beyond the causal diagonal are masked out.

        Returns:
            ``(output, cache_kv)`` where ``output`` has shape
            ``(batch, new_tokens, d_model)`` and ``cache_kv`` has shape
            ``(batch, past_tokens + new_tokens, kv_latent_dim)``.
        """
        batch,size1,_=x.size()

        # Compress the new tokens and extend the latent cache.
        new_cache_kv=self.layernom(self.W_dkv(x))
        if kv_cache is None:
            cache_kv=new_cache_kv
        else:
            cache_kv=torch.cat([kv_cache,new_cache_kv],dim=1)
        size_full=cache_kv.size(1)

        # Values: (batch, heads, size_full, dim_each_head).
        v=self.W_uv(cache_kv).view(batch,size_full,self.num_heads,self.dim_each_head).transpose(1,2)

        # Queries: (batch, heads, size1, dim_each_head).
        q=self.W_q(x).view(batch,size1,self.num_heads,self.dim_each_head).transpose(1,2)

        # nn.Linear stores weight as (out_features, in_features), so rows
        # [h*dim_each_head, (h+1)*dim_each_head) of W_uk.weight are head h's
        # key up-projection.  Folding it into q gives
        # q_h @ W_uk_h @ latent^T == q_h @ k_h^T without ever building k.
        w_uk=self.W_uk.weight.view(self.num_heads,self.dim_each_head,self.kv_latent_dim)
        q_latent=torch.matmul(q,w_uk)                                    # (batch, heads, size1, latent)
        attn_score=torch.matmul(q_latent,cache_kv.transpose(1,2).unsqueeze(1))  # (batch, heads, size1, size_full)
        attn_score=attn_score/(self.dim_each_head**0.5)

        # Causal mask: key j is hidden from query i when j > i + past_length.
        query_pos=torch.arange(size1,device=x.device).unsqueeze(1)+past_length
        key_pos=torch.arange(size_full,device=x.device).unsqueeze(0)
        attn_score=attn_score.masked_fill(key_pos>query_pos,float('-inf'))

        attn_weights=F.softmax(attn_score,dim=-1)

        # Weighted sum of values, then merge the heads back into d_model.
        context=torch.matmul(attn_weights,v)                             # (batch, heads, size1, dim_each_head)
        out=context.transpose(1,2).reshape(batch,size1,self.d_model)
        return self.W_o(out),cache_kv
        



In [12]:
# Smoke test: run one forward pass and check the output and cache shapes.
SEED = 0
torch.manual_seed(SEED)  # fixes the random weights and input so re-runs are reproducible

d_model = 512
n_heads = 8
seq_len = 16
batch_size = 2
kv_latent_dim = 128


x = torch.randn(batch_size, seq_len, d_model)


model = ropeless(d_model=d_model, kv_latent_dim=kv_latent_dim, num_head=n_heads)

out, new_cache = model(x)

# Fail loudly if the shapes drift instead of relying on reading the printout.
assert out.shape == (batch_size, seq_len, d_model)
assert new_cache.shape == (batch_size, seq_len, kv_latent_dim)

print(f"Output shape: {out.shape}")        # Expected: (batch_size, seq_len, d_model)
print(f"Cache shape: {new_cache.shape}")   # Expected: (batch_size, seq_len, kv_latent_dim)

Output shape: torch.Size([2, 16, 512])
Cache shape: torch.Size([2, 16, 128])
