In [None]:
import os
import torch 
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from math import sqrt

In [None]:
@dataclass
class config:
    """Flat bag of batching, optimisation, runtime and model settings.

    Every field is required (there are no defaults), and the declaration
    order below is the positional order of the generated ``__init__``.
    """

    # --- batching / data ---
    batch_size: int       # number of independent sequences processed in parallel
    block_size: int       # maximum context length for predictions
    vocab_size: int       # author's note: OPTIM 4 (along with grad clipping) brought dt from 95 to 90

    # --- optimisation schedule ---
    max_iters: int
    eval_interval: int
    learning_rate: float
    warmup_steps: int
    max_decay_steps: int

    # --- runtime ---
    device: str
    eval_iters: int
    compile: bool         # no default here; a candidate default was `os.name == 'posix'`
    save_model: bool

    # --- model ---
    latent_dim: int
    n_embd: int
    n_head: int
    n_layer: int
    n_kv_heads: int       # author's note: 6 for MHA, 1 for MQA, or another divisor of n_head for GQA
    dropout: float
    total_batch_size: int

# Concrete settings for the MHLA experiments in this notebook.
# Keywords are grouped in the same order as the fields of `config`.
MLAconfig = config(
    # --- batching / data ---
    batch_size=4,            # number of independent sequences processed in parallel
    block_size=1024,         # maximum context length for predictions
    vocab_size=50304,        # author's note: OPTIM 4 (along with grad clipping) brought dt from 95 to 90

    # --- optimisation schedule ---
    max_iters=500,
    eval_interval=50,
    learning_rate=3e-4,
    warmup_steps=25,
    max_decay_steps=75,

    # --- runtime ---
    device='cuda' if torch.cuda.is_available() else 'cpu',
    eval_iters=200,
    compile=(os.name == 'posix'),   # True only on POSIX systems
    save_model=True,

    # --- model ---
    latent_dim=32,
    n_embd=768,
    n_head=8,
    n_layer=6,
    n_kv_heads=2,            # author's note: 6 for MHA, 1 for MQA, or another divisor of n_head for GQA
    dropout=0.2,
    total_batch_size=2**16,
)

In [None]:
class MHLA(nn.Module):
    """Causal latent attention with absorbed projection weights and a latent KV cache.

    Keys and values are never materialised: the input is compressed to a
    ``latent_dim``-wide vector ``c_kv`` (which is also what gets cached), and
    the Q/K and V/O projection pairs are folded ("absorbed") into two small
    matrices that act directly on ``x`` and ``c_kv``.

    Reads ``n_embd``, ``n_head``, ``latent_dim``, ``block_size`` and
    ``dropout`` from ``config``.

    NOTE(review): heads are formed by slicing the feature axis of ``x`` (and of
    the output) into ``n_head`` chunks, not by splitting per-head Q/K/V
    projections, so the result is not numerically equivalent to un-absorbed
    multi-head attention built from the same weights. Confirm this is intended.
    """

    def __init__(self, config:config):
        super().__init__()
        self.config = config
        assert config.n_embd % config.n_head == 0, "num of heads must be a divisor of n_embd"
        self.head_size = config.n_embd // config.n_head
        # Projection layers
        self.W_dq  = nn.Linear(config.n_embd,     config.latent_dim, bias=False)  # Query down projection
        self.W_uq  = nn.Linear(config.latent_dim, config.n_embd,     bias=False)  # Query up projection
        self.W_dkv = nn.Linear(config.n_embd,     config.latent_dim, bias=False)  # Compress into latent KV space
        self.W_uk  = nn.Linear(config.latent_dim, config.n_embd,     bias=False)  # Decompress K
        self.W_uv  = nn.Linear(config.latent_dim, config.n_embd,     bias=False)  # Decompress V
        self.W_o   = nn.Linear(config.n_embd,     config.n_embd,     bias=False)  # Final output projection
        self.dropout = nn.Dropout(config.dropout)
        # Inference-time caches of the absorbed weights. They are derived from the
        # parameters, so they are kept out of the state_dict (persistent=False);
        # being buffers, they still follow the module across .to(device/dtype).
        self.register_buffer('k_abs', None, persistent=False)
        self.register_buffer('v_abs', None, persistent=False)
        # Lower-triangular causal mask, shape (1, 1, block_size, block_size).
        self.register_buffer('tril', torch.tril(torch.ones(config.block_size, config.block_size)).unsqueeze(0).unsqueeze(0))

    def _absorb_weights(self):
        """Fold the projection pairs into the two matrices used by ``forward``.

        Returns:
            (k_abs, v_abs) with shapes (1, nh, hs, nl) and (1, nh, nl, hs).
        """
        nh, nl, hs = self.config.n_head, self.config.latent_dim, self.head_size
        # Q K^T = (x W_dq^T W_uq^T) (c_kv W_uk^T)^T = x [W_dq^T W_uq^T W_uk] c_kv^T
        # (C,nl) x ((nl,C) x (C,nl)) = (C,nl); the inner product is only (nl,nl).
        k_absorbed = self.W_dq.weight.T @ (self.W_uq.weight.T @ self.W_uk.weight)
        # V W_o^T = c_kv [W_uv^T W_o^T] : (nl,C) x (C,C) = (nl,C)
        v_absorbed = self.W_uv.weight.T @ self.W_o.weight.T
        k_abs = k_absorbed.view(nh, hs, nl).unsqueeze(0)                   # (1, nh, hs, nl)
        v_abs = v_absorbed.view(nl, nh, hs).transpose(0, 1).unsqueeze(0)   # (1, nh, nl, hs)
        return k_abs, v_abs

    def forward(self, x:torch.Tensor, kv_cache=None) -> "tuple[torch.Tensor, torch.Tensor]":
        """Run causal attention over ``x``, optionally extending a latent cache.

        Args:
            x: input of shape (B, T, C) with C == n_embd.
            kv_cache: latent cache of shape (B, T_past, latent_dim) returned by a
                previous call, or None to start a new sequence.

        Returns:
            (y, c_kv): ``y`` is the attention output of shape (B, T, C) and
            ``c_kv`` is the updated latent cache of shape (B, T_past + T, latent_dim).

        Raises:
            ValueError: if the total sequence length exceeds ``block_size``.
        """
        B, T, C = x.size()
        nh, hs = self.config.n_head, self.head_size

        if self.training or torch.is_grad_enabled():
            # The weights may still change and autograd needs a fresh graph on
            # every step, so never reuse absorbed matrices here; also drop any
            # cache left over from an earlier eval pass because it is now stale.
            if self.k_abs is not None or self.v_abs is not None:
                self.k_abs = None
                self.v_abs = None
            k_abs, v_abs = self._absorb_weights()
        else:
            # Pure inference: absorb once and reuse.
            # NOTE: if weights are replaced while in this mode (e.g. load_state_dict),
            # reset k_abs / v_abs to None to force a rebuild.
            if self.k_abs is None or self.v_abs is None:
                self.k_abs, self.v_abs = self._absorb_weights()
            k_abs, v_abs = self.k_abs, self.v_abs

        new_c_kv = self.W_dkv(x)  # down projection : (B,T,C) -> (B,T,nl)
        if kv_cache is None:
            c_kv = new_c_kv  # initiate cache
        else:
            c_kv = torch.cat([kv_cache, new_c_kv], dim=1)  # append to cache along time
        S = c_kv.size(1)  # total key length = cached + new positions
        if S > self.config.block_size:
            raise ValueError(
                f"sequence length {S} (cache + new tokens) exceeds block_size {self.config.block_size}"
            )

        # Heads are slices of the input features: (B,T,C) -> (B,T,nh,hs) -> (B,nh,T,hs)
        q = x.view(B, T, nh, hs).transpose(1, 2)

        # attn = q * k_abs * c_kv^T
        # (B,nh,T,hs) x (1,nh,hs,nl) x (B,1,nl,S) = (B,nh,T,S)
        attn = (q @ k_abs @ c_kv.transpose(1, 2).unsqueeze(1)) / sqrt(hs)
        # The T new queries sit at absolute positions S-T .. S-1, so take those
        # rows of the causal mask (with no cache this is the usual [:T, :T] block).
        causal = self.tril[..., S - T:S, :S]
        attn = attn.masked_fill(causal == 0, float('-inf'))
        attn = self.dropout(F.softmax(attn, dim=-1))  # (B, nh, T, S)

        # output = attn * c_kv * v_abs
        # (B,nh,T,S) x (B,1,S,nl) x (1,nh,nl,hs) = (B,nh,T,hs)
        y = attn @ c_kv.unsqueeze(1) @ v_abs
        y = self.dropout(y.transpose(1, 2).contiguous().view(B, T, C))

        return y, c_kv

In [26]:
# Smoke test: push a random batch through one attention layer.
mla = MHLA(MLAconfig)
x = torch.randn(64, 19, MLAconfig.n_embd)  # (batch, seq, n_embd)
# forward returns (output, latent KV cache); unpack so `y` is the output tensor
y, c_kv = mla(x)

In [24]:
# Shape check of the attention output; for the (64, 19, 768) input above the
# expected result is torch.Size([64, 19, 768]).
# NOTE(review): MHLA.forward returns a (y, c_kv) tuple, so `y` must be the
# unpacked output tensor here; a tuple has no `.shape`.
y.shape

torch.Size([64, 19, 768])