In [1]:
import torch 
import torch.nn as nn

In [2]:
class MultiHeadLatentAttention(nn.Module):
    """Causal multi-head self-attention with keys/values rebuilt from a shared latent.

    Queries are projected directly from the input. Keys and values are not:
    each token is first down-projected to a ``d_latent``-wide vector
    (``W_dkv``), and that single latent is then up-projected separately into
    keys (``W_uk``) and values (``W_uv``). Attention is strictly causal: a
    position never attends to later positions. The latent is recomputed on
    every call; nothing is cached between calls.

    Args:
        d_model: Width of the input and output features. Must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads; each head works on
            ``d_model // num_heads`` features.
        d_latent: Width of the shared key/value latent.
        dropout: Dropout probability applied to the attention weights.
        max_seq_len: Side length of the causal mask precomputed at
            construction time. Longer sequences are still accepted; their
            mask is built on the fly in ``forward``.

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, d_latent, dropout=0.0, max_seq_len=1024):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.d_latent = d_latent

        # Query path: direct projection of the input.
        self.W_q = nn.Linear(d_model, d_model)
        # Key/value path: one shared down-projection to the latent ...
        self.W_dkv = nn.Linear(d_model, d_latent)

        # ... followed by separate up-projections for keys and values.
        self.W_uk = nn.Linear(d_latent, d_model)
        self.W_uv = nn.Linear(d_latent, d_model)

        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

        # True strictly above the diagonal, i.e. at (query i, key j) with j > i.
        # Kept as a regular buffer under the same name so existing state dicts
        # remain loadable and the mask follows the module across devices.
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(1, 1, max_seq_len, max_seq_len), diagonal=1).bool(),
        )

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, d_head)."""
        return t.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)

    def _causal_mask(self, seq_len, device):
        """Return a (1, 1, seq_len, seq_len) bool mask that is True for future positions."""
        if seq_len <= self.mask.size(-1):
            return self.mask[:, :, :seq_len, :seq_len]
        # The sequence is longer than the precomputed buffer: slicing would
        # yield a mask that is too small and masked_fill would fail, so build
        # a mask of the right size for this call instead.
        return torch.triu(
            torch.ones(1, 1, seq_len, seq_len, device=device), diagonal=1
        ).bool()

    def forward(self, x):
        """Apply causal latent attention.

        Args:
            x: Tensor of shape (batch, seq_len, d_model).

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.shape

        q = self._split_heads(self.W_q(x), batch_size, seq_len)
        c_kv = self.W_dkv(x)  # shape: (batch, seq_len, d_latent)

        k = self._split_heads(self.W_uk(c_kv), batch_size, seq_len)
        v = self._split_heads(self.W_uv(c_kv), batch_size, seq_len)

        # Scaled dot-product scores, shape (batch, num_heads, seq_len, seq_len).
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_head ** 0.5)

        # Apply the causal mask: -inf scores become exactly zero weight after softmax.
        attn_scores = attn_scores.masked_fill(
            self._causal_mask(seq_len, x.device), float('-inf')
        )

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Merge the heads back into a single d_model-wide feature per token.
        context_vector = (
            (attn_weights @ v)
            .transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_len, self.d_model)
        )

        output = self.W_o(context_vector)
        return output
        



In [3]:
# Seed PyTorch's global RNG so that the randomly initialised weights, the
# random dummy input and the dropout pattern are the same on every
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Hyperparameters of the demo layer.
d_model = 512     # feature width of the layer's input and output
num_heads = 8     # must divide d_model evenly
d_latent = 128    # width of the shared key/value latent

# Shape of the dummy input batch.
batch_size = 4
seq_len = 64



In [4]:
# Build the attention layer under test; attention-weight dropout is set to 10%.
mla_layer = MultiHeadLatentAttention(
    d_model=d_model,
    num_heads=num_heads,
    d_latent=d_latent,
    dropout=0.1,
)

In [5]:
# Random stand-in batch of shape (batch_size, seq_len, d_model) for a smoke test.
dummy_input = torch.randn((batch_size, seq_len, d_model))

In [6]:
# Run one forward pass through the layer. No .eval() call is visible in this
# notebook, so unless the mode was changed elsewhere the module is in its
# default training mode and dropout is active for this pass.
output = mla_layer(dummy_input)


In [7]:
# Check the shape contract before reporting success, so the message below is
# only printed when the layer really returned one output vector per input token.
assert output.shape == dummy_input.shape, (
    f"expected output shape {tuple(dummy_input.shape)}, got {tuple(output.shape)}"
)

print("MLA Layer successful!")
print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")
     

MLA Layer successful!
Input shape: torch.Size([4, 64, 512])
Output shape: torch.Size([4, 64, 512])
