In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class config:
    """Typed declaration of the hyperparameters used by the model and training run.

    Only annotations are declared here: no field is given a default, so none
    of these attributes exists until something assigns it (where that happens
    is not visible in this part of the file). Inline notes are carried over
    from the original author; the section headings are inferred from field
    names only and should be confirmed against the code that reads them.
    """

    # --- batching / vocabulary ---
    batch_size: int  # how many independent sequences are processed in parallel
    block_size: int  # maximum context length for predictions
    vocab_size: int  # author's note: "OPTIM 4 (along with grad clipping) brought dt from 95 to 90"

    # --- schedule ---
    max_iters: int
    eval_interval: int
    learning_rate: float
    warmup_steps: int
    max_decay_steps: int

    # --- runtime ---
    device: str
    eval_iters: int
    compile: bool  # author's earlier idea for a default: False if os.name != 'posix' else True
    save_model: bool

    # --- architecture ---
    kv_latent_dim: int
    q_latent_dim: int
    n_embd: int
    n_head: int
    n_layer: int
    # Author's note: 6 for MHA (presumably the n_head in use -- confirm),
    # 1 for MQA, or another divisor of n_head for GQA.
    n_kv_heads: int
    dropout: float
    total_batch_size: int

In [None]:
class MHLA(nn.Module):
    def __init__(self, config:config):
        super().__init__()

        self.config = config

        # Projection layers
        self.W_dq  = nn.Linear(config.n_embd,        config.q_latent_dim,  bias=False)  # Query down projection
        self.W_uq  = nn.Linear(config.q_latent_dim,  config.n_embd,        bias=False)  # Query up projection
        self.W_dkv = nn.Linear(config.n_embd,        config.kv_latent_dim, bias=False)  # Compress into latent KV space
        self.W_uk  = nn.Linear(config.kv_latent_dim, config.n_embd,        bias=False)  # Decompress K
        self.W_uv  = nn.Linear(config.kv_latent_dim, config.n_embd,        bias=False)  # Decompress V
        self.W_o   = nn.Linear(config.n_embd,        config.n_embd,        bias=False)  # Final output projection

        self.ln = nn.LayerNorm(config.kv_latent_dim)
        self.register_buffer('absorbed_k', None)  # Holds W_q @ W_uk

    def forward(self, x, kv_cache=None, past_length=0):
        B, S, D = x.size()

        # Compute absorbed_k once: W_q @ W_uk, shape: (D, latent_dim)
        if self.absorbed_k is None:
            absorbed = torch.matmul(self.W_q.weight, self.W_uk.weight)
            self.absorbed_k = absorbed.view(self.n_heads, self.dh, -1)
