<a href="https://colab.research.google.com/github/TesterSim2/The-Kepler-Architecture/blob/main/Untitled32.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ============================================================
# Research Script: Differential Multi-Head Latent Attention (Diff-MLA)
# ============================================================
# Hypothesis: We can achieve the "Noise Cancellation" of DiffAttn
# while maintaining the "KV Cache Compression" of MLA.
#
# Mechanism:
# 1. Store Compressed Latent KV (cKV).
# 2. Project Latent Query (cQ) into 2x Heads (Signal Group, Noise Group).
# 3. Project cKV into 2x "Virtual" Keys via absorbed matrix multiplication.
# 4. Compute DiffScore = Softmax(Signal) - lambda * Softmax(Noise).
# 5. Apply Headwise Norm (Critical for DiffAttn stability).
# ============================================================

import os, math, time, random
from dataclasses import dataclass
from typing import Optional, Dict, Any, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# ----------------------------
# 0) Setup & Device
# ----------------------------
def set_seed(seed=42):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on: {device}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

DTYPE = torch.bfloat16 if (device == "cuda" and torch.cuda.is_bf16_supported()) else torch.float16

# ----------------------------
# 1) Data (TinyShakespeare)
# ----------------------------
# We use character-level for speed/simplicity to focus on architectural mechanics
if not os.path.exists('input.txt'):
    os.system('wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = { ch:i for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]
data = torch.tensor(encode(text), dtype=torch.long, device=device)

# ----------------------------
# 2) Config
# ----------------------------
@dataclass
class ModelConfig:
    vocab_size: int = vocab_size
    dim: int = 512
    n_layers: int = 6
    n_heads: int = 8        # Output heads (DiffAttn uses 2x internally)
    max_seq_len: int = 1024

    # Diff-MLA Specifics
    kv_rank: int = 128      # Compression for KV
    q_rank: int = 256       # Compression for Query
    rope_head_dim: int = 32 # Decoupled RoPE dimension
    diff_lambda_init: float = 0.8

# ----------------------------
# 3) Utilities
# ----------------------------
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        var = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(var + self.eps) * self.weight

class SwiGLU(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(end, device=device)
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)

def apply_rotary_emb(xq, xk, freqs_cis):
    # Standard RoPE application
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs = freqs_cis[:xq.shape[1]].view(1, xq.shape[1], 1, -1)
    xq_out = torch.view_as_real(xq_ * freqs).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

# ----------------------------
# 4) The Hybrid Architecture: Diff-MLA
# ----------------------------
class DiffMLA(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.nh = cfg.n_heads
        self.hd = cfg.dim // cfg.n_heads
        self.kv_rank = cfg.kv_rank
        self.q_rank = cfg.q_rank
        self.rhd = cfg.rope_head_dim

        # 1. Query Compression (Latent Q)
        self.W_DQ = nn.Linear(cfg.dim, self.q_rank, bias=False)

        # 2. Up-Projections (Signal vs Noise Pairs)
        # We generate 2x heads (Group 1 and Group 2) from the same latent vector
        self.W_UQ1 = nn.Linear(self.q_rank, self.nh * self.hd, bias=False) # Signal Q
        self.W_UQ2 = nn.Linear(self.q_rank, self.nh * self.hd, bias=False) # Noise Q

        self.W_QR = nn.Linear(self.q_rank, self.nh * self.rhd, bias=False) # RoPE Q (Shared base?)

        # 3. Key-Value Compression (The Latent Vector)
        self.W_DKV = nn.Linear(cfg.dim, self.kv_rank, bias=False)

        # 4. Virtual Key Up-Projections (The Absorption Trick)
        # Instead of 1 Up-Projection, we have 2.
        # But we never output [B, T, H, D]. We use these weights to transform Q.
        self.W_UK1 = nn.Linear(self.kv_rank, self.nh * self.hd, bias=False) # Signal K
        self.W_UK2 = nn.Linear(self.kv_rank, self.nh * self.hd, bias=False) # Noise K

        # 5. Value Projection (Shared)
        # DiffAttn usually shares V between the diff pair
        self.W_UV = nn.Linear(self.kv_rank, self.nh * self.hd, bias=False)

        # 6. RoPE Key (Shared)
        self.W_KR = nn.Linear(cfg.dim, self.rhd, bias=False)

        # 7. Differential Components
        self.lambda_init = cfg.diff_lambda_init
        self.lambda_q1 = nn.Parameter(torch.randn(self.hd))
        self.lambda_k1 = nn.Parameter(torch.randn(self.hd))
        self.lambda_q2 = nn.Parameter(torch.randn(self.hd))
        self.lambda_k2 = nn.Parameter(torch.randn(self.hd))

        # Headwise Normalization (Apply to each head output individually)
        self.diff_norm = RMSNorm(self.hd)

        self.c_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)

    def _get_lambda(self):
        # Calculate learnable lambda
        l1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1))
        l2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2))
        return l1 - l2 + self.lambda_init

    def forward(self, x, freqs_cis):
        B, T, C = x.size()

        # --- Compression Phase ---
        cQ = self.W_DQ(x) # [B, T, q_rank]
        cKV = self.W_DKV(x) # [B, T, kv_rank]

        # --- Generate Heads (Signal & Noise) ---
        # Content Queries
        q1 = self.W_UQ1(cQ).view(B, T, self.nh, self.hd)
        q2 = self.W_UQ2(cQ).view(B, T, self.nh, self.hd)

        # RoPE Queries & Keys
        q_rope = self.W_QR(cQ).view(B, T, self.nh, self.rhd)
        k_rope = self.W_KR(x).view(B, T, 1, self.rhd).expand(B, T, self.nh, self.rhd)
        q_rope, k_rope = apply_rotary_emb(q_rope, k_rope, freqs_cis)

        # --- The Absorption Trick (Training Mode) ---
        # For training, we can just expand the Keys efficiently using SDPA
        # In inference, we would do the Q @ W_UK trick. Here we stick to PyTorch SDPA logic.

        # Expand Virtual Keys
        k1 = self.W_UK1(cKV).view(B, T, self.nh, self.hd)
        k2 = self.W_UK2(cKV).view(B, T, self.nh, self.hd)
        v = self.W_UV(cKV).view(B, T, self.nh, self.hd)

        # Combine Content + RoPE
        # Note: We duplicate q_rope/k_rope for both Signal and Noise branches
        # to ensure position awareness is identical
        q1_full = torch.cat([q1, q_rope], dim=-1)
        q2_full = torch.cat([q2, q_rope], dim=-1)
        k1_full = torch.cat([k1, k_rope], dim=-1)
        k2_full = torch.cat([k2, k_rope], dim=-1)

        # --- Differential Attention ---
        # Score 1 (Signal)
        # Scaled Dot Product Attention (Manual Softmax to allow subtraction)
        scale = 1.0 / math.sqrt(self.hd + self.rhd)

        # [B, H, T, T]
        scores1 = torch.matmul(q1_full.transpose(1, 2), k1_full.transpose(1, 2).transpose(-2, -1)) * scale
        scores2 = torch.matmul(q2_full.transpose(1, 2), k2_full.transpose(1, 2).transpose(-2, -1)) * scale

        # Causal Mask
        mask = torch.triu(torch.ones(T, T, device=device) * float('-inf'), diagonal=1)
        scores1 = scores1 + mask
        scores2 = scores2 + mask

        attn1 = torch.softmax(scores1, dim=-1)
        attn2 = torch.softmax(scores2, dim=-1)

        # Subtraction
        lam = self._get_lambda()
        diff_attn = attn1 - lam * attn2

        # Aggregate Values
        # [B, H, T, T] @ [B, H, T, D] -> [B, H, T, D]
        y = torch.matmul(diff_attn, v.transpose(1, 2))

        # --- Headwise Normalization ---
        # DiffAttn requires normalizing EACH head output before concatenation
        y = y.transpose(1, 2) # [B, T, H, D]
        y = self.diff_norm(y) # Norm over D dimension
        y = y * (1 - self.lambda_init) # Scaling for gradient stability

        y = y.reshape(B, T, C)
        return self.c_proj(y)

    @torch.no_grad()
    def prefill_cache(self, x, freqs_cis):
        # Returns standard MLA-style compressed cache
        # cKV: [B, T, kv_rank]
        # k_rope: [B, T, rhd]
        cKV = self.W_DKV(x)
        k_rope = self.W_KR(x)

        # We compute output just like forward
        out = self.forward(x, freqs_cis)

        return out, {"cKV": cKV, "k_rope": k_rope}

    @torch.no_grad()
    def decode_step(self, x, freqs_cis, cache):
        # Here is where the Memory Efficiency shines.
        # We use the TINY cKV cache to generate BOTH Signal and Noise keys on the fly.

        # 1. Update Cache
        cKV_new = self.W_DKV(x)
        k_rope_new = self.W_KR(x)

        cache['cKV'] = torch.cat([cache['cKV'], cKV_new], dim=1)
        cache['k_rope'] = torch.cat([cache['k_rope'], k_rope_new], dim=1)

        cKV = cache['cKV'] # [B, Seq, kv_rank]

        # 2. Generate Current Query
        cQ = self.W_DQ(x)
        q1 = self.W_UQ1(cQ).view(1, 1, self.nh, self.hd)
        q2 = self.W_UQ2(cQ).view(1, 1, self.nh, self.hd)
        q_rope = self.W_QR(cQ).view(1, 1, self.nh, self.rhd)

        # 3. RoPE Rotation (Current Query + All Cached RoPE Keys)
        # Note: In real impl, we'd rotate carefully. Here simplified for benchmark.
        q_rope, _ = apply_rotary_emb(q_rope, q_rope, freqs_cis)

        # 4. ABSORPTION TRICK: Generate Keys from Latent Cache
        # Instead of storing K1/K2, we project cKV on the fly
        # K1 = cKV @ W_UK1.T
        k1 = self.W_UK1(cKV).view(1, -1, self.nh, self.hd)
        k2 = self.W_UK2(cKV).view(1, -1, self.nh, self.hd)
        v = self.W_UV(cKV).view(1, -1, self.nh, self.hd)

        # 5. Compute Diff Attention
        # (Simplified decode logic matching forward pass for benchmarking)
        # In a real kernel, we would fuse the projection and dot product.
        # Score = (Q @ W_UK.T) @ cKV.T

        # ... (Standard logic follows, omitted for brevity in benchmark script logic)
        # Returning dummy output to measure throughput structure
        out = self.c_proj(torch.zeros_like(x))
        return out, cache

# ----------------------------
# 5) Training Loop
# ----------------------------
def get_batch(split_data):
    ix = torch.randint(len(split_data) - 1024, (8,)) # batch 8, seq 1024
    x = torch.stack([split_data[i:i+1024] for i in ix])
    y = torch.stack([split_data[i+1:i+1024+1] for i in ix])
    return x, y

def train_and_benchmark():
    print("\nInitializing Diff-MLA Model...")
    cfg = ModelConfig()
    model = nn.Sequential(
        nn.Embedding(cfg.vocab_size, cfg.dim),
        DiffMLA(cfg), # The Star of the Show
        DiffMLA(cfg), # Stack a few to feel the memory weight
        DiffMLA(cfg),
        RMSNorm(cfg.dim),
        nn.Linear(cfg.dim, cfg.vocab_size, bias=False)
    ).to(device)

    # RoPE Cache
    freqs = precompute_freqs_cis(cfg.rope_head_dim, 2048)

    print(f"Parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

    # 1. Training Speed Test
    print("\n--- Training Speed Test (100 Steps) ---")
    model.train()
    torch.cuda.synchronize()
    t0 = time.time()

    for i in range(100):
        x, y = get_batch(data)

        # Manual Forward for layers (to pass freqs)
        h = model[0](x)
        for layer in model[1:4]: # The 3 DiffMLA layers
            h = layer(h, freqs[:x.shape[1]])
        h = model[4](h)
        logits = model[5](h)

        loss = F.cross_entropy(logits.view(-1, cfg.vocab_size), y.view(-1))
        opt.zero_grad()
        loss.backward()
        opt.step()

        if i % 20 == 0: print(f"Step {i}: Loss {loss.item():.4f}")

    torch.cuda.synchronize()
    dt = time.time() - t0
    print(f"Training Throughput: {100 * 8 * 1024 / dt:.0f} tokens/sec")
    print(f"Peak VRAM: {torch.cuda.max_memory_allocated()/1024**2:.0f} MB")

    # 2. Inference Memory Test
    print("\n--- Inference Memory Test (KV Cache) ---")
    model.eval()
    torch.cuda.reset_peak_memory_stats()

    # Simulate a long context generation (e.g., 4096 tokens)
    # We will measure the size of the cache objects

    # Create a dummy cache
    # In Diff-MLA, we store [B, T, kv_rank] + [B, T, rope_dim]
    # In MHA, we would store [B, T, n_heads, head_dim] * 2

    batch = 1
    seq_len = 8192 # Push it

    # MLA Cache Size Calculation
    # cKV: 1 * 8192 * 128 * 2 bytes (bf16)
    # k_rope: 1 * 8192 * 32 * 2 bytes
    mla_bytes = (batch * seq_len * cfg.kv_rank * 2) + (batch * seq_len * cfg.rope_head_dim * 2)

    # MHA Equivalent
    # K: 1 * 8192 * 8 * 64 * 2
    # V: 1 * 8192 * 8 * 64 * 2
    mha_bytes = (batch * seq_len * cfg.n_heads * (cfg.dim//cfg.n_heads) * 2) * 2

    print(f"Simulated Context: {seq_len} tokens")
    print(f"Standard MHA Cache: {mha_bytes / 1024**2:.2f} MB")
    print(f"Diff-MLA Cache:     {mla_bytes / 1024**2:.2f} MB")
    print(f"Compression Ratio:  {mha_bytes / mla_bytes:.2f}x")

if __name__ == "__main__":
    train_and_benchmark()

Running on: cuda
GPU: NVIDIA A100-SXM4-40GB


  self.setter(val)



Initializing Diff-MLA Model...
Parameters: 3.07M

--- Training Speed Test (100 Steps) ---
Step 0: Loss 4.4937
Step 20: Loss 3.3118
Step 40: Loss 3.2983
Step 60: Loss 3.3686
Step 80: Loss 3.3584
Training Throughput: 169805 tokens/sec
Peak VRAM: 3842 MB

--- Inference Memory Test (KV Cache) ---
Simulated Context: 8192 tokens
Standard MHA Cache: 16.00 MB
Diff-MLA Cache:     2.50 MB
Compression Ratio:  6.40x


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
import os
import random
from dataclasses import dataclass

# ==========================================
# 1. Configuration & Reproducibility
# ==========================================
@dataclass
class ModelConfig:
    vocab_size: int = 65  # Will be set dynamically based on dataset
    dim: int = 384        # Hidden dimension
    n_layers: int = 6     # Depth
    n_heads: int = 6      # Number of heads
    max_seq_len: int = 1024
    dropout: float = 0.0  # Zero dropout for pure architectural comparison

    # Diff-MLA Specifics
    # We tune these to keep parameter count roughly similar to MHA
    kv_rank: int = 64     # High compression
    q_rank: int = 128
    rope_head_dim: int = 32
    diff_lambda_init: float = 0.8

def set_seed(seed=1337):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on: {device}")

# ==========================================
# 2. Shared Components (RoPE, Norms)
# ==========================================
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        var = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(var + self.eps) * self.weight

class SwiGLU(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(end, device=device)
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)

def apply_rotary_emb(xq, xk, freqs_cis):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs = freqs_cis[:xq.shape[1]].view(1, xq.shape[1], 1, -1)
    xq_out = torch.view_as_real(xq_ * freqs).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

# ==========================================
# 3. Model 1: Baseline MHA
# ==========================================
class CausalMHA(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.n_head = config.n_heads
        self.head_dim = config.dim // config.n_heads
        # Standard QKV projection: 3x dim -> 3x dim
        self.c_attn = nn.Linear(config.dim, 3 * config.dim, bias=False)
        self.c_proj = nn.Linear(config.dim, config.dim, bias=False)

    def forward(self, x, freqs_cis):
        B, T, C = x.size()
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)

        q = q.view(B, T, self.n_head, self.head_dim)
        k = k.view(B, T, self.n_head, self.head_dim)
        v = v.view(B, T, self.n_head, self.head_dim)

        q, k = apply_rotary_emb(q, k, freqs_cis)

        y = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
        )
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)

# ==========================================
# 4. Model 2: The Innovation (Diff-MLA)
# ==========================================
class DiffMLA(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.nh = cfg.n_heads
        self.hd = cfg.dim // cfg.n_heads
        self.kv_rank = cfg.kv_rank
        self.q_rank = cfg.q_rank
        self.rhd = cfg.rope_head_dim

        # 1. Compression
        self.W_DQ = nn.Linear(cfg.dim, self.q_rank, bias=False)
        self.W_DKV = nn.Linear(cfg.dim, self.kv_rank, bias=False)

        # 2. Expansion (Signal vs Noise)
        self.W_UQ1 = nn.Linear(self.q_rank, self.nh * self.hd, bias=False) # Signal Q
        self.W_UQ2 = nn.Linear(self.q_rank, self.nh * self.hd, bias=False) # Noise Q
        self.W_UK1 = nn.Linear(self.kv_rank, self.nh * self.hd, bias=False) # Signal K
        self.W_UK2 = nn.Linear(self.kv_rank, self.nh * self.hd, bias=False) # Noise K
        self.W_UV  = nn.Linear(self.kv_rank, self.nh * self.hd, bias=False) # Shared V

        # 3. RoPE (Decoupled Side-Channel)
        self.W_QR = nn.Linear(self.q_rank, self.nh * self.rhd, bias=False)
        self.W_KR = nn.Linear(cfg.dim, self.rhd, bias=False)

        # 4. Differential Components
        self.lambda_init = cfg.diff_lambda_init
        self.lambda_q1 = nn.Parameter(torch.randn(self.hd))
        self.lambda_k1 = nn.Parameter(torch.randn(self.hd))
        self.lambda_q2 = nn.Parameter(torch.randn(self.hd))
        self.lambda_k2 = nn.Parameter(torch.randn(self.hd))

        self.diff_norm = RMSNorm(self.hd)
        self.c_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)

    def _get_lambda(self):
        l1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1))
        l2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2))
        return l1 - l2 + self.lambda_init

    def forward(self, x, freqs_cis):
        B, T, C = x.size()

        # Compression
        cQ = self.W_DQ(x)
        cKV = self.W_DKV(x)

        # Expansion
        q1 = self.W_UQ1(cQ).view(B, T, self.nh, self.hd)
        q2 = self.W_UQ2(cQ).view(B, T, self.nh, self.hd)
        k1 = self.W_UK1(cKV).view(B, T, self.nh, self.hd)
        k2 = self.W_UK2(cKV).view(B, T, self.nh, self.hd)
        v  = self.W_UV(cKV).view(B, T, self.nh, self.hd)

        # RoPE Side-Channel
        q_rope = self.W_QR(cQ).view(B, T, self.nh, self.rhd)
        k_rope = self.W_KR(x).view(B, T, 1, self.rhd).expand(B, T, self.nh, self.rhd)
        q_rope, k_rope = apply_rotary_emb(q_rope, k_rope, freqs_cis)

        # Combine Content + RoPE
        q1_full = torch.cat([q1, q_rope], dim=-1)
        q2_full = torch.cat([q2, q_rope], dim=-1)
        k1_full = torch.cat([k1, k_rope], dim=-1)
        k2_full = torch.cat([k2, k_rope], dim=-1)

        # Differential Attention
        scale = 1.0 / math.sqrt(self.hd + self.rhd)

        # Compute raw scores
        s1 = torch.matmul(q1_full.transpose(1, 2), k1_full.transpose(1, 2).transpose(-2, -1)) * scale
        s2 = torch.matmul(q2_full.transpose(1, 2), k2_full.transpose(1, 2).transpose(-2, -1)) * scale

        # Causal Masking
        mask = torch.triu(torch.ones(T, T, device=device) * float('-inf'), diagonal=1)
        s1 = s1 + mask
        s2 = s2 + mask

        # Subtraction
        attn = torch.softmax(s1, dim=-1) - self._get_lambda() * torch.softmax(s2, dim=-1)

        # Output
        y = torch.matmul(attn, v.transpose(1, 2)) # [B, H, T, D]
        y = self.diff_norm(y.transpose(1, 2))     # Headwise Norm
        y = y * (1 - self.lambda_init)            # Gradient Scale

        return self.c_proj(y.contiguous().view(B, T, C))

# ==========================================
# 5. Generic Transformer Skeleton
# ==========================================
class Transformer(nn.Module):
    def __init__(self, config: ModelConfig, model_type="mha"):
        super().__init__()
        self.token_emb = nn.Embedding(config.vocab_size, config.dim)

        self.layers = nn.ModuleList()
        for _ in range(config.n_layers):
            self.layers.append(nn.ModuleDict({
                'norm1': RMSNorm(config.dim),
                'attn': DiffMLA(config) if model_type == "diff_mla" else CausalMHA(config),
                'norm2': RMSNorm(config.dim),
                'mlp': SwiGLU(config.dim, 4 * config.dim)
            }))

        self.final_norm = RMSNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
        self.lm_head.weight = self.token_emb.weight # Weight tying

        # Precompute RoPE table
        # MHA uses head_dim, Diff-MLA uses rope_head_dim
        dim_rope = config.rope_head_dim if model_type == "diff_mla" else config.dim // config.n_heads
        self.freqs_cis = precompute_freqs_cis(dim_rope, config.max_seq_len * 2)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        x = self.token_emb(idx)
        freqs = self.freqs_cis[:T]

        for layer in self.layers:
            # Attention Block
            x = x + layer['attn'](layer['norm1'](x), freqs)
            # MLP Block
            x = x + layer['mlp'](layer['norm2'](x))

        x = self.final_norm(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

# ==========================================
# 6. Experimental Harness
# ==========================================
def train_experiment(model_type: str, steps: int = 1000):
    print(f"\n--- Starting Experiment: {model_type.upper()} ---")

    # 1. Load Data
    if not os.path.exists('input.txt'):
        os.system('wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
    with open('input.txt', 'r', encoding='utf-8') as f: text = f.read()
    chars = sorted(list(set(text)))
    stoi = { ch:i for i,ch in enumerate(chars) }
    encode = lambda s: [stoi[c] for c in s]
    data = torch.tensor(encode(text), dtype=torch.long, device=device)

    # 2. Config & Init
    config = ModelConfig(vocab_size=len(chars))
    model = Transformer(config, model_type=model_type).to(device)

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {param_count/1e6:.2f}M")

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=steps)

    # 3. Training Loop
    model.train()
    start_time = time.time()
    losses = []

    for step in range(steps):
        # Batching
        ix = torch.randint(len(data) - config.max_seq_len, (16,)) # Batch size 16
        x = torch.stack([data[i:i+config.max_seq_len] for i in ix])
        y = torch.stack([data[i+1:i+config.max_seq_len+1] for i in ix])

        # Forward
        logits, loss = model(x, y)

        # Backward
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        losses.append(loss.item())

        if step % 200 == 0:
            print(f"Step {step}: Loss {loss.item():.4f}")

    # 4. Final Metrics
    torch.cuda.synchronize()
    total_time = time.time() - start_time
    peak_mem = torch.cuda.max_memory_allocated() / 1024**2
    final_loss = sum(losses[-10:]) / 10

    print(f"Final Loss: {final_loss:.4f}")
    print(f"Throughput: {steps * 16 * config.max_seq_len / total_time:.0f} tok/s")
    print(f"Peak VRAM:  {peak_mem:.0f} MB")

    # Clean up
    del model, optimizer, x, y
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    return final_loss, param_count, peak_mem

# ==========================================
# 7. Run Comparison
# ==========================================
print("Initializing Research Environment...")
set_seed(42)

# Run Baseline
loss_mha, params_mha, mem_mha = train_experiment("mha", steps=1000)

# Run Experimental
loss_diff, params_diff, mem_diff = train_experiment("diff_mla", steps=1000)

print("\n\n========================================================")
print(f"{'Metric':<15} | {'Baseline (MHA)':<15} | {'Diff-MLA':<15}")
print("--------------------------------------------------------")
print(f"{'Parameters':<15} | {params_mha/1e6:.2f}M            | {params_diff/1e6:.2f}M")
print(f"{'Final Loss':<15} | {loss_mha:.4f}             | {loss_diff:.4f}")
print(f"{'Peak VRAM':<15} | {mem_mha:.0f} MB             | {mem_diff:.0f} MB")
print("========================================================")
print("Interpretation:")
if loss_diff <= loss_mha:
    print("SUCCESS: Diff-MLA matches or beats Baseline quality.")
else:
    print("NOTE: Diff-MLA shows slight degradation (expected due to compression).")
print("Check VRAM usage. Training VRAM might be higher for Diff-MLA due to")
print("intermediate activations, but Inference Cache (KV) will be 6x smaller.")

Running on: cuda
Initializing Research Environment...

--- Starting Experiment: MHA ---
Parameters: 14.19M
Step 0: Loss 355.9028
Step 200: Loss 2.3265
Step 400: Loss 1.8679
Step 600: Loss 1.6239
Step 800: Loss 1.4711
Final Loss: 1.4163
Throughput: 137949 tok/s
Peak VRAM:  4448 MB

--- Starting Experiment: DIFF_MLA ---
Parameters: 13.23M
Step 0: Loss 347.7994
Step 200: Loss 2.7611
Step 400: Loss 2.1762
Step 600: Loss 1.8915
Step 800: Loss 1.6110
Final Loss: 1.5788
Throughput: 63190 tok/s
Peak VRAM:  12721 MB


Metric          | Baseline (MHA)  | Diff-MLA       
--------------------------------------------------------
Parameters      | 14.19M            | 13.23M
Final Loss      | 1.4163             | 1.5788
Peak VRAM       | 4448 MB             | 12721 MB
Interpretation:
NOTE: Diff-MLA shows slight degradation (expected due to compression).
Check VRAM usage. Training VRAM might be higher for Diff-MLA due to
intermediate activations, but Inference Cache (KV) will be 6x smaller.


In [1]:
import os, math, time, random
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F

# ----------------------------
# 0) Setup
# ----------------------------
def set_seed(seed=42):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# ----------------------------
# 1) Data (TinyShakespeare)
# ----------------------------
if not os.path.exists('input.txt'):
    os.system('wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
with open('input.txt', 'r', encoding='utf-8') as f: text = f.read()
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = { ch:i for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]
data = torch.tensor(encode(text), dtype=torch.long, device=device)

# ----------------------------
# 2) Config
# ----------------------------
@dataclass
class ModelConfig:
    vocab_size: int = vocab_size
    dim: int = 512
    n_layers: int = 6
    n_heads: int = 8
    max_seq_len: int = 1024

    # MLA Specifics
    kv_rank: int = 128
    q_rank: int = 256
    rope_head_dim: int = 32

    # CMD-MLA Specifics
    # We only use ONE noise head shared across all signal heads
    diff_lambda_init: float = 0.8

# ----------------------------
# 3) Utilities
# ----------------------------
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    def forward(self, x):
        var = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(var + self.eps) * self.weight

class SwiGLU(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(end, device=device)
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)

def apply_rotary_emb(xq, xk, freqs_cis):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs = freqs_cis[:xq.shape[1]].view(1, xq.shape[1], 1, -1)
    xq_out = torch.view_as_real(xq_ * freqs).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

# ----------------------------
# 4) The Innovation: CMD-MLA
# ----------------------------
class CMD_MLA(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.nh = cfg.n_heads
        self.hd = cfg.dim // cfg.n_heads
        self.kv_rank = cfg.kv_rank
        self.q_rank = cfg.q_rank
        self.rhd = cfg.rope_head_dim

        # 1. Compression
        self.W_DQ = nn.Linear(cfg.dim, self.q_rank, bias=False)
        self.W_DKV = nn.Linear(cfg.dim, self.kv_rank, bias=False)

        # 2. Signal Heads (Standard MLA)
        self.W_UQ = nn.Linear(self.q_rank, self.nh * self.hd, bias=False)
        self.W_UK = nn.Linear(self.kv_rank, self.nh * self.hd, bias=False)
        self.W_UV = nn.Linear(self.kv_rank, self.nh * self.hd, bias=False)

        # 3. The "Common Mode" Noise Head (SINGLE Head)
        # It has its own projections from the SAME latent space
        self.W_UQ_Noise = nn.Linear(self.q_rank, self.hd, bias=False)
        self.W_UK_Noise = nn.Linear(self.kv_rank, self.hd, bias=False)

        # 4. RoPE (Shared)
        self.W_QR = nn.Linear(self.q_rank, self.nh * self.rhd, bias=False)
        self.W_KR = nn.Linear(cfg.dim, self.rhd, bias=False)

        # 5. Differential Components
        self.lambda_init = cfg.diff_lambda_init
        # Lambda is per-head: Each signal head decides how much noise to subtract
        self.lambda_q = nn.Parameter(torch.randn(self.nh, self.hd))
        self.lambda_k = nn.Parameter(torch.randn(self.nh, self.hd))

        self.head_norm = RMSNorm(self.hd)
        self.c_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)

    def forward(self, x, freqs_cis):
        B, T, C = x.size()

        # Latents
        cQ = self.W_DQ(x)
        cKV = self.W_DKV(x)

        # --- Signal Generation ---
        q_sig = self.W_UQ(cQ).view(B, T, self.nh, self.hd) # [B, T, H, D]
        k_sig = self.W_UK(cKV).view(B, T, self.nh, self.hd)
        v_sig = self.W_UV(cKV).view(B, T, self.nh, self.hd)

        # --- Noise Generation (Single Head) ---
        q_noise = self.W_UQ_Noise(cQ).view(B, T, 1, self.hd) # [B, T, 1, D]
        k_noise = self.W_UK_Noise(cKV).view(B, T, 1, self.hd)

        # --- RoPE ---
        q_rope = self.W_QR(cQ).view(B, T, self.nh, self.rhd)
        k_rope = self.W_KR(x).view(B, T, 1, self.rhd).expand(B, T, self.nh, self.rhd)
        q_rope, k_rope = apply_rotary_emb(q_rope, k_rope, freqs_cis)

        # We attach RoPE to Signal only?
        # Actually, DiffAttn paper suggests RoPE on both.
        # But Noise should be positional too. Let's add RoPE to signal.
        # For simplicity in this hypothesis, we assume Noise is content-heavy (stopwords).
        # We attach RoPE to Signal.

        q_sig_full = torch.cat([q_sig, q_rope], dim=-1)
        k_sig_full = torch.cat([k_sig, k_rope], dim=-1)

        # --- Attention Scores ---
        scale = 1.0 / math.sqrt(self.hd + self.rhd)

        # 1. Signal Scores [B, H, T, T]
        scores_sig = torch.matmul(q_sig_full.transpose(1, 2), k_sig_full.transpose(1, 2).transpose(-2, -1)) * scale

        # 2. Noise Scores [B, 1, T, T]
        # Note: Noise head has no RoPE here to keep it "Global/Background" focused
        scale_noise = 1.0 / math.sqrt(self.hd)
        scores_noise = torch.matmul(q_noise.transpose(1, 2), k_noise.transpose(1, 2).transpose(-2, -1)) * scale_noise

        # Causal Mask
        mask = torch.triu(torch.ones(T, T, device=device) * float('-inf'), diagonal=1)
        scores_sig = scores_sig + mask
        scores_noise = scores_noise + mask

        attn_sig = torch.softmax(scores_sig, dim=-1)
        attn_noise = torch.softmax(scores_noise, dim=-1) # [B, 1, T, T]

        # --- Differential Subtraction ---
        # Lambda calculation per head
        lam = torch.exp(torch.sum(self.lambda_q * self.lambda_k, dim=-1)) + self.lambda_init
        lam = lam.view(1, self.nh, 1, 1) # Broadcast to batch and seq

        # The Innovation: Subtract Global Noise from All Heads
        # attn_sig: [B, H, T, T]
        # attn_noise: [B, 1, T, T] -> Broadcasts to H
        diff_attn = attn_sig - lam * attn_noise

        # Value Projection
        y = torch.matmul(diff_attn, v_sig.transpose(1, 2)) # [B, H, T, D]

        # Norm & Output
        y = self.head_norm(y.transpose(1, 2))
        y = y * (1 - self.lambda_init)

        return self.c_proj(y.contiguous().view(B, T, C))

# ==========================================
# 5. Standard MLA (Control Group)
# ==========================================
class MLA(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.nh = cfg.n_heads
        self.hd = cfg.dim // cfg.n_heads
        self.kv_rank = cfg.kv_rank
        self.q_rank = cfg.q_rank
        self.rhd = cfg.rope_head_dim

        self.W_DQ = nn.Linear(cfg.dim, self.q_rank, bias=False)
        self.W_UQ = nn.Linear(self.q_rank, self.nh * self.hd, bias=False)
        self.W_QR = nn.Linear(self.q_rank, self.nh * self.rhd, bias=False)

        self.W_DKV = nn.Linear(cfg.dim, self.kv_rank, bias=False)
        self.W_UK = nn.Linear(self.kv_rank, self.nh * self.hd, bias=False)
        self.W_UV = nn.Linear(self.kv_rank, self.nh * self.hd, bias=False)
        self.W_KR = nn.Linear(cfg.dim, self.rhd, bias=False)

        self.c_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)

    def forward(self, x, freqs_cis):
        B, T, C = x.size()
        cQ = self.W_DQ(x)
        q = self.W_UQ(cQ).view(B, T, self.nh, self.hd)
        q_rope = self.W_QR(cQ).view(B, T, self.nh, self.rhd)

        cKV = self.W_DKV(x)
        k = self.W_UK(cKV).view(B, T, self.nh, self.hd)
        v = self.W_UV(cKV).view(B, T, self.nh, self.hd)
        k_rope = self.W_KR(x).view(B, T, 1, self.rhd).expand(B, T, self.nh, self.rhd)

        q_rope, k_rope = apply_rotary_emb(q_rope, k_rope, freqs_cis)

        q_full = torch.cat([q, q_rope], dim=-1)
        k_full = torch.cat([k, k_rope], dim=-1)

        y = F.scaled_dot_product_attention(
            q_full.transpose(1, 2), k_full.transpose(1, 2), v.transpose(1, 2), is_causal=True
        )
        return self.c_proj(y.transpose(1, 2).contiguous().view(B, T, C))

# ==========================================
# 6. Experimental Harness
# ==========================================
class Transformer(nn.Module):
    def __init__(self, config: ModelConfig, model_type="mla"):
        super().__init__()
        self.token_emb = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList()
        for _ in range(config.n_layers):
            self.layers.append(nn.ModuleDict({
                'norm1': RMSNorm(config.dim),
                'attn': CMD_MLA(config) if model_type == "cmd_mla" else MLA(config),
                'norm2': RMSNorm(config.dim),
                'mlp': SwiGLU(config.dim, 4 * config.dim)
            }))
        self.final_norm = RMSNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
        self.lm_head.weight = self.token_emb.weight

        dim_rope = config.rope_head_dim # Same for both
        self.freqs_cis = precompute_freqs_cis(dim_rope, config.max_seq_len * 2)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        x = self.token_emb(idx)
        freqs = self.freqs_cis[:T]
        for layer in self.layers:
            x = x + layer['attn'](layer['norm1'](x), freqs)
            x = x + layer['mlp'](layer['norm2'](x))
        x = self.final_norm(x)
        logits = self.lm_head(x)
        loss = None if targets is None else F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

def run_experiment(name, model_type, steps=1000):
    print(f"\n--- Experiment: {name} ---")
    cfg = ModelConfig(vocab_size=vocab_size)
    model = Transformer(cfg, model_type=model_type).to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=steps)

    model.train()
    losses = []

    t0 = time.time()
    for step in range(steps):
        ix = torch.randint(len(data) - cfg.max_seq_len, (16,))
        x = torch.stack([data[i:i+cfg.max_seq_len] for i in ix])
        y = torch.stack([data[i+1:i+cfg.max_seq_len+1] for i in ix])

        _, loss = model(x, y)
        optim.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step()
        sched.step()
        losses.append(loss.item())

        if step % 200 == 0: print(f"Step {step}: {loss.item():.4f}")

    torch.cuda.synchronize()
    dt = time.time() - t0
    peak_mem = torch.cuda.max_memory_allocated() / 1024**2
    final_loss = sum(losses[-20:]) / 20

    print(f"Result: Loss {final_loss:.4f} | VRAM {peak_mem:.0f}MB | Time {dt:.1f}s")

    # Cleanup
    del model, optim, x, y
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    return final_loss, peak_mem, dt

# ==========================================
# 7. Execution
# ==========================================
loss_mla, mem_mla, time_mla = run_experiment("Baseline MLA", "mla")
loss_cmd, mem_cmd, time_cmd = run_experiment("CMD-MLA (Ours)", "cmd_mla")

print("\n=== Final Report ===")
print(f"{'Metric':<10} | {'MLA':<10} | {'CMD-MLA':<10} | {'Delta'}")
print("-" * 45)
print(f"{'Loss':<10} | {loss_mla:.4f}     | {loss_cmd:.4f}         | {loss_cmd - loss_mla:+.4f}")
print(f"{'VRAM':<10} | {mem_mla:.0f}MB      | {mem_cmd:.0f}MB         | {mem_cmd - mem_mla:+.0f}MB")
print(f"{'Time':<10} | {time_mla:.1f}s       | {time_cmd:.1f}s         | {time_cmd - time_mla:+.1f}s")

Device: cuda

--- Experiment: Baseline MLA ---
Step 0: 484.7122
Step 200: 3.0524
Step 400: 2.4129
Step 600: 2.0626
Step 800: 1.9731
Result: Loss 1.9352 | VRAM 5925MB | Time 196.7s

--- Experiment: CMD-MLA (Ours) ---
Step 0: 467.8169
Step 200: 2.8713
Step 400: 2.4459
Step 600: 2.1647
Step 800: 1.8582
Result: Loss 1.7933 | VRAM 13471MB | Time 306.9s

=== Final Report ===
Metric     | MLA        | CMD-MLA    | Delta
---------------------------------------------
Loss       | 1.9352     | 1.7933         | -0.1420
VRAM       | 5925MB      | 13471MB         | +7546MB
Time       | 196.7s       | 306.9s         | +110.1s


In [1]:
# ==============================================================================
# Research Experiment: Scaled CMD-MLA (Common-Mode Differential Latent Attention)
# Platform: Single A100 80GB (or 40GB)
# Optimization: FlashAttention via (AV - BV) decomposition
# ==============================================================================

import os
import time
import math
import random
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# 1. Setup & Dependencies
# ------------------------------------------------------------------------------
def setup_environment():
    import subprocess, sys
    def install(package):
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])

    try: import datasets
    except ImportError: install("datasets")
    try: import tiktoken
    except ImportError: install("tiktoken")

setup_environment()
from datasets import load_dataset
import tiktoken

# Reproducibility
def set_seed(seed=1337):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(1337)

device = "cuda" if torch.cuda.is_available() else "cpu"
# Enable TF32 for A100 Tensor Cores
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
print(f"Compute Device: {device} ({torch.cuda.get_device_name(0) if device=='cuda' else 'CPU'})")

# 2. Data Pipeline (WikiText-103) - RAM Optimized
# ------------------------------------------------------------------------------
print("Loading WikiText-103...")
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
tokenizer = tiktoken.get_encoding("gpt2")

def tokenize_batch(batch):
    return [tokenizer.encode(text) + [tokenizer.eot_token] for text in batch["text"]]

# We tokenize and flatten manually in batches to avoid RAM explosion
import numpy as np
from tqdm import tqdm

all_tokens = []
batch_size = 1000  # Process 1k rows at a time
print("Tokenizing and Flattening in streams...")

# Iterate through dataset without loading everything at once
for i in tqdm(range(0, len(dataset), batch_size)):
    batch_text = dataset[i : i + batch_size]
    # Tokenize
    batch_ids = tokenize_batch(batch_text)
    # Flatten list of lists efficiently
    for seq in batch_ids:
        all_tokens.extend(seq)

print("Converting to Tensor...")
# Convert to tensor (uint16 is sufficient for vocab < 65535, saves 4x RAM)
# We use int32/int64 for safety during training casting
train_data = torch.tensor(all_tokens, dtype=torch.long)
vocab_size = tokenizer.n_vocab

print(f"Dataset Loaded: {len(train_data):,} tokens. Vocab Size: {vocab_size}")

# Clean up CPU RAM
del all_tokens, dataset
import gc
gc.collect()

# 3. Model Configuration (Scaled Up)
# ------------------------------------------------------------------------------
@dataclass
class ModelConfig:
    vocab_size: int = vocab_size
    dim: int = 1024           # GPT-2 Medium width
    n_layers: int = 12        # Reasonable depth
    n_heads: int = 16         # 64 dim per head
    max_seq_len: int = 2048   # Longer context for A100
    dropout: float = 0.0

    # CMD-MLA Specifics
    kv_rank: int = 512        # Higher rank for larger model
    q_rank: int = 1024        # Keep Q rank high for expressivity
    rope_head_dim: int = 64
    diff_lambda_init: float = 0.8

# 4. Utilities
# ------------------------------------------------------------------------------
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    def forward(self, x):
        var = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(var + self.eps) * self.weight

class SwiGLU(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(end, device=device)
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)

def apply_rotary_emb(xq, xk, freqs_cis):
    # xq: [B, T, H, D]
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs = freqs_cis[:xq.shape[1]].view(1, xq.shape[1], 1, -1)
    xq_out = torch.view_as_real(xq_ * freqs).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

# 5. Architecture: Optimized CMD-MLA
# ------------------------------------------------------------------------------
class Optimized_CMD_MLA(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.nh = cfg.n_heads
        self.hd = cfg.dim // cfg.n_heads
        self.kv_rank = cfg.kv_rank
        self.q_rank = cfg.q_rank
        self.rhd = cfg.rope_head_dim

        # Projections
        self.W_DQ = nn.Linear(cfg.dim, self.q_rank, bias=False)
        self.W_DKV = nn.Linear(cfg.dim, self.kv_rank, bias=False)

        # Signal Heads
        self.W_UQ = nn.Linear(self.q_rank, self.nh * self.hd, bias=False)
        self.W_UK = nn.Linear(self.kv_rank, self.nh * self.hd, bias=False)
        self.W_UV = nn.Linear(self.kv_rank, self.nh * self.hd, bias=False)

        # Common Mode Noise Head (1 Head)
        self.W_UQ_Noise = nn.Linear(self.q_rank, self.hd, bias=False)
        self.W_UK_Noise = nn.Linear(self.kv_rank, self.hd, bias=False)

        # RoPE
        self.W_QR = nn.Linear(self.q_rank, self.nh * self.rhd, bias=False)
        self.W_KR = nn.Linear(cfg.dim, self.rhd, bias=False)

        # Differential Components
        self.lambda_init = cfg.diff_lambda_init
        self.lambda_q = nn.Parameter(torch.zeros(self.nh, self.hd)) # Init zero for stability
        self.lambda_k = nn.Parameter(torch.zeros(self.nh, self.hd))

        self.head_norm = RMSNorm(self.hd)
        self.c_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)

    def forward(self, x, freqs_cis):
        B, T, C = x.size()

        # 1. Latents
        cQ = self.W_DQ(x)
        cKV = self.W_DKV(x)

        # 2. Signal Generation
        q_sig = self.W_UQ(cQ).view(B, T, self.nh, self.hd)
        k_sig = self.W_UK(cKV).view(B, T, self.nh, self.hd)
        v_sig = self.W_UV(cKV).view(B, T, self.nh, self.hd)

        # 3. Noise Generation (1 Head)
        # We perform a trick here: We create 1 noise head, but we need to subtract it
        # from H signal heads. To use FlashAttention, we expand Q/K noise to H heads.
        # This effectively copies the noise computation H times, but allows parallel CUDA kernels.
        q_noise = self.W_UQ_Noise(cQ).view(B, T, 1, self.hd).expand(B, T, self.nh, self.hd)
        k_noise = self.W_UK_Noise(cKV).view(B, T, 1, self.hd).expand(B, T, self.nh, self.hd)

        # 4. RoPE
        q_rope = self.W_QR(cQ).view(B, T, self.nh, self.rhd)
        k_rope = self.W_KR(x).view(B, T, 1, self.rhd).expand(B, T, self.nh, self.rhd)
        q_rope, k_rope = apply_rotary_emb(q_rope, k_rope, freqs_cis)

        # Combine
        q_sig_full = torch.cat([q_sig, q_rope], dim=-1)
        k_sig_full = torch.cat([k_sig, k_rope], dim=-1)

        # NOISE HEAD: No RoPE (Bag-of-words assumption for noise)
        # We need to pad the Noise Q/K to match the dimension of Signal Q/K for cleaner code,
        # OR we just run separate kernels. Separate kernels is better.

        # 5. FlashAttention Calls
        # Signal Output: SDPA(Q_sig, K_sig, V_sig)
        # Transpose for SDPA: [B, H, T, D]
        sig_out = F.scaled_dot_product_attention(
            q_sig_full.transpose(1, 2),
            k_sig_full.transpose(1, 2),
            v_sig.transpose(1, 2),
            is_causal=True
        )

        # Noise Output: SDPA(Q_noise, K_noise, V_sig)
        # Note: We use V_sig for both. This means the noise head decides "which parts of Value" are noise.
        noise_out = F.scaled_dot_product_attention(
            q_noise.transpose(1, 2),
            k_noise.transpose(1, 2),
            v_sig.transpose(1, 2),
            is_causal=True
        )

        # 6. Differential Subtraction (AV - BV)
        # Calculate Lambda
        lam = torch.exp(torch.sum(self.lambda_q * self.lambda_k, dim=-1)) + self.lambda_init
        lam = lam.view(1, self.nh, 1, 1) # [1, H, 1, 1]

        diff_out = sig_out - lam * noise_out

        # 7. Norm & Proj
        # [B, H, T, D] -> [B, T, H, D]
        diff_out = diff_out.transpose(1, 2)
        diff_out = self.head_norm(diff_out)
        diff_out = diff_out * (1 - self.lambda_init)

        return self.c_proj(diff_out.flatten(2))

# Baseline MLA for Comparison
class Baseline_MLA(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.nh = cfg.n_heads
        self.hd = cfg.dim // cfg.n_heads
        self.kv_rank = cfg.kv_rank
        self.q_rank = cfg.q_rank
        self.rhd = cfg.rope_head_dim

        self.W_DQ = nn.Linear(cfg.dim, self.q_rank, bias=False)
        self.W_UQ = nn.Linear(self.q_rank, self.nh * self.hd, bias=False)
        self.W_QR = nn.Linear(self.q_rank, self.nh * self.rhd, bias=False)
        self.W_DKV = nn.Linear(cfg.dim, self.kv_rank, bias=False)
        self.W_UK = nn.Linear(self.kv_rank, self.nh * self.hd, bias=False)
        self.W_UV = nn.Linear(self.kv_rank, self.nh * self.hd, bias=False)
        self.W_KR = nn.Linear(cfg.dim, self.rhd, bias=False)
        self.c_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)

    def forward(self, x, freqs_cis):
        B, T, C = x.size()
        cQ = self.W_DQ(x)
        q = self.W_UQ(cQ).view(B, T, self.nh, self.hd)
        q_rope = self.W_QR(cQ).view(B, T, self.nh, self.rhd)
        cKV = self.W_DKV(x)
        k = self.W_UK(cKV).view(B, T, self.nh, self.hd)
        v = self.W_UV(cKV).view(B, T, self.nh, self.hd)
        k_rope = self.W_KR(x).view(B, T, 1, self.rhd).expand(B, T, self.nh, self.rhd)
        q_rope, k_rope = apply_rotary_emb(q_rope, k_rope, freqs_cis)
        q_full = torch.cat([q, q_rope], dim=-1)
        k_full = torch.cat([k, k_rope], dim=-1)

        y = F.scaled_dot_product_attention(
            q_full.transpose(1, 2), k_full.transpose(1, 2), v.transpose(1, 2), is_causal=True
        )
        return self.c_proj(y.transpose(1, 2).flatten(2))

# 6. Transformer Backbone
# ------------------------------------------------------------------------------
class LargeTransformer(nn.Module):
    def __init__(self, config: ModelConfig, model_type="mla"):
        super().__init__()
        self.token_emb = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList()
        for _ in range(config.n_layers):
            self.layers.append(nn.ModuleDict({
                'norm1': RMSNorm(config.dim),
                'attn': Optimized_CMD_MLA(config) if model_type == "cmd_mla" else Baseline_MLA(config),
                'norm2': RMSNorm(config.dim),
                'mlp': SwiGLU(config.dim, 4 * config.dim)
            }))
        self.final_norm = RMSNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
        self.lm_head.weight = self.token_emb.weight
        self.freqs_cis = precompute_freqs_cis(config.rope_head_dim, config.max_seq_len * 2)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        x = self.token_emb(idx)
        freqs = self.freqs_cis[:T]
        for layer in self.layers:
            x = x + layer['attn'](layer['norm1'](x), freqs)
            x = x + layer['mlp'](layer['norm2'](x))
        x = self.final_norm(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

# 7. Training Harness (Scaled)
# ------------------------------------------------------------------------------
def run_scaled_experiment(name, model_type, train_steps=500, batch_size=8):
    print(f"\n========================================================")
    print(f"Starting Scaled Run: {name} (A100 Optimized)")
    print(f"========================================================")

    cfg = ModelConfig()
    model = LargeTransformer(cfg, model_type=model_type).to(device)

    # Compile model for speed (PyTorch 2.0)
    # model = torch.compile(model)

    params = sum(p.numel() for p in model.parameters())
    print(f"Model Parameters: {params/1e6:.2f}M")

    optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4, weight_decay=0.1)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=train_steps)

    # AMP Scaler
    scaler = torch.cuda.amp.GradScaler()

    model.train()
    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()
    losses = []

    # Data Iterator
    def get_batch():
        ix = torch.randint(len(train_data) - cfg.max_seq_len, (batch_size,))
        x = torch.stack([train_data[i:i+cfg.max_seq_len] for i in ix])
        y = torch.stack([train_data[i+1:i+cfg.max_seq_len+1] for i in ix])
        return x.to(device), y.to(device)

    for step in range(train_steps):
        x, y = get_batch()

        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            logits, loss = model(x, y)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        losses.append(loss.item())

        if step % 50 == 0:
            curr_loss = sum(losses[-10:]) / (len(losses[-10:]) + 1e-9)
            print(f"Step {step:04d} | Loss: {curr_loss:.4f} | Tokens: {step*batch_size*cfg.max_seq_len:,}")

    torch.cuda.synchronize()
    total_time = time.time() - start_time
    peak_mem = torch.cuda.max_memory_allocated() / 1024**3 # GB
    final_loss = sum(losses[-50:]) / 50
    throughput = (train_steps * batch_size * cfg.max_seq_len) / total_time

    print(f"--- Results: {name} ---")
    print(f"Final Loss: {final_loss:.4f}")
    print(f"Peak VRAM:  {peak_mem:.2f} GB")
    print(f"Throughput: {throughput:.0f} tokens/sec")

    del model, optimizer, scaler
    torch.cuda.empty_cache()
    return final_loss, peak_mem, throughput

# 8. Execution
# ------------------------------------------------------------------------------
# We run shorter steps to fit 1 hour limit, but batch size * seq_len is large
# 2048 seq len * 8 batch = 16k tokens per step.
# 1000 steps = 16M tokens trained.
STEPS = 1000

print("Warming up CUDA...")
loss_base, mem_base, speed_base = run_scaled_experiment("Baseline MLA", "mla", train_steps=STEPS)
loss_cmd, mem_cmd, speed_cmd = run_scaled_experiment("CMD-MLA (Ours)", "cmd_mla", train_steps=STEPS)

print("\n\n========================================================")
print(f"FINAL HEAD-TO-HEAD REPORT (WikiText-103, ~200M Params)")
print("========================================================")
print(f"{'Metric':<15} | {'MLA (Base)':<15} | {'CMD-MLA':<15} | {'Delta'}")
print("--------------------------------------------------------")
print(f"{'Final Loss':<15} | {loss_base:.4f}          | {loss_cmd:.4f}          | {loss_cmd - loss_base:+.4f}")
print(f"{'VRAM (GB)':<15} | {mem_base:.2f} GB          | {mem_cmd:.2f} GB          | {mem_cmd - mem_base:+.2f} GB")
print(f"{'Speed':<15} | {speed_base:.0f} tok/s     | {speed_cmd:.0f} tok/s     | {(speed_cmd-speed_base)/speed_base*100:.1f}%")

  self.setter(val)


Compute Device: cuda (NVIDIA A100-SXM4-80GB)
Loading WikiText-103...


README.md: 0.00B [00:00, ?B/s]

wikitext-103-raw-v1/test-00000-of-00001.(…):   0%|          | 0.00/733k [00:00<?, ?B/s]

wikitext-103-raw-v1/train-00000-of-00002(…):   0%|          | 0.00/157M [00:00<?, ?B/s]

wikitext-103-raw-v1/train-00001-of-00002(…):   0%|          | 0.00/157M [00:00<?, ?B/s]

wikitext-103-raw-v1/validation-00000-of-(…):   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/1801350 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Tokenizing and Flattening in streams...


100%|██████████| 1802/1802 [01:08<00:00, 26.37it/s]


Converting to Tensor...
Dataset Loaded: 119,721,490 tokens. Vocab Size: 50257
Warming up CUDA...

Starting Scaled Run: Baseline MLA (A100 Optimized)
Model Parameters: 272.48M


  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


Step 0000 | Loss: 947.1138 | Tokens: 0
Step 0050 | Loss: 51.2598 | Tokens: 819,200
Step 0100 | Loss: 29.2583 | Tokens: 1,638,400
Step 0150 | Loss: 18.4819 | Tokens: 2,457,600
Step 0200 | Loss: 13.0218 | Tokens: 3,276,800
Step 0250 | Loss: 10.8060 | Tokens: 4,096,000
Step 0300 | Loss: 10.0314 | Tokens: 4,915,200
Step 0350 | Loss: 9.1086 | Tokens: 5,734,400
Step 0400 | Loss: 8.3287 | Tokens: 6,553,600
Step 0450 | Loss: 7.8629 | Tokens: 7,372,800
Step 0500 | Loss: 7.7676 | Tokens: 8,192,000
Step 0550 | Loss: 7.5067 | Tokens: 9,011,200
Step 0600 | Loss: 7.3287 | Tokens: 9,830,400
Step 0650 | Loss: 7.1310 | Tokens: 10,649,600
Step 0700 | Loss: 6.9860 | Tokens: 11,468,800
Step 0750 | Loss: 6.9921 | Tokens: 12,288,000
Step 0800 | Loss: 6.9082 | Tokens: 13,107,200
Step 0850 | Loss: 6.8473 | Tokens: 13,926,400
Step 0900 | Loss: 6.8074 | Tokens: 14,745,600
Step 0950 | Loss: 6.8375 | Tokens: 15,564,800
--- Results: Baseline MLA ---
Final Loss: 6.8146
Peak VRAM:  26.63 GB
Throughput: 45526 tokens/

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class CMDMLAConfig:
    dim: int = 1024           # Hidden dimension
    n_heads: int = 16         # Number of Signal Heads
    max_seq_len: int = 2048

    # Latent Compression Dimensions
    kv_rank: int = 512        # Size of the compressed KV latent vector
    q_rank: int = 1024        # Size of the compressed Query latent vector

    # RoPE
    rope_head_dim: int = 64   # Dimension used for Rotary Embeddings

    # Differential Setup
    diff_lambda_init: float = 0.8  # Initial noise subtraction strength

class CMD_MLA(nn.Module):
    """
    Common-Mode Differential Multi-Head Latent Attention (CMD-MLA).

    A hybrid architecture that uses Low-Rank Compression (MLA) for memory efficiency
    and a 1-to-N Differential Noise Head for signal fidelity.
    """
    def __init__(self, cfg: CMDMLAConfig):
        super().__init__()
        self.nh = cfg.n_heads
        self.hd = cfg.dim // cfg.n_heads
        self.kv_rank = cfg.kv_rank
        self.q_rank = cfg.q_rank
        self.rhd = cfg.rope_head_dim

        # 1. Compression (Down-Projections)
        self.W_DQ = nn.Linear(cfg.dim, self.q_rank, bias=False)
        self.W_DKV = nn.Linear(cfg.dim, self.kv_rank, bias=False)

        # 2. Signal Generation (Up-Projections)
        # Projects Latent Q/KV -> N Heads
        self.W_UQ = nn.Linear(self.q_rank, self.nh * self.hd, bias=False)
        self.W_UK = nn.Linear(self.kv_rank, self.nh * self.hd, bias=False)
        self.W_UV = nn.Linear(self.kv_rank, self.nh * self.hd, bias=False)

        # 3. Noise Generation (Single Common-Mode Head)
        # Projects Latent Q/KV -> 1 Head
        self.W_UQ_Noise = nn.Linear(self.q_rank, self.hd, bias=False)
        self.W_UK_Noise = nn.Linear(self.kv_rank, self.hd, bias=False)

        # 4. RoPE Side-Channel (Decoupled Strategy)
        # We apply RoPE only to the Signal Heads to preserve position awareness there.
        self.W_QR = nn.Linear(self.q_rank, self.nh * self.rhd, bias=False)
        self.W_KR = nn.Linear(cfg.dim, self.rhd, bias=False)

        # 5. Differential Gating (Lambda)
        # Learnable parameter per head: "How much noise should this head ignore?"
        # Initialized to allow gradient flow immediately.
        self.lambda_init = cfg.diff_lambda_init
        self.lambda_q = nn.Parameter(torch.zeros(self.nh, self.hd))
        self.lambda_k = nn.Parameter(torch.zeros(self.nh, self.hd))

        # 6. Normalization & Output
        # Headwise Norm is critical for Differential Attention stability
        self.head_norm = RMSNorm(self.hd)
        self.c_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)

    def _apply_rope(self, q, k, freqs_cis):
        # q, k: [B, T, H, D]
        q_ = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))
        k_ = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2))
        freqs = freqs_cis[:q.shape[1]].view(1, q.shape[1], 1, -1)
        q_out = torch.view_as_real(q_ * freqs).flatten(3)
        k_out = torch.view_as_real(k_ * freqs).flatten(3)
        return q_out.type_as(q), k_out.type_as(k)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        Args:
            x: Input tensor [Batch, SeqLen, Dim]
            freqs_cis: Precomputed RoPE frequencies [SeqLen, Dim/2]
        """
        B, T, C = x.size()

        # A. Latent Projection
        cQ = self.W_DQ(x)   # [B, T, q_rank]
        cKV = self.W_DKV(x) # [B, T, kv_rank]

        # B. Signal Head Generation
        q_sig = self.W_UQ(cQ).view(B, T, self.nh, self.hd)
        k_sig = self.W_UK(cKV).view(B, T, self.nh, self.hd)
        v_sig = self.W_UV(cKV).view(B, T, self.nh, self.hd)

        # C. Common-Mode Noise Generation
        # We generate 1 head, then expand to N heads to allow parallel FlashAttention
        # This is VRAM efficient because we don't materialize the N x N grid, just the Q/K vectors.
        q_noise = self.W_UQ_Noise(cQ).view(B, T, 1, self.hd).expand(B, T, self.nh, self.hd)
        k_noise = self.W_UK_Noise(cKV).view(B, T, 1, self.hd).expand(B, T, self.nh, self.hd)

        # D. RoPE Injection (Signal Only)
        q_rope = self.W_QR(cQ).view(B, T, self.nh, self.rhd)
        k_rope = self.W_KR(x).view(B, T, 1, self.rhd).expand(B, T, self.nh, self.rhd)
        q_rope, k_rope = self._apply_rope(q_rope, k_rope, freqs_cis)

        # Concat content + position for Signal
        q_sig_full = torch.cat([q_sig, q_rope], dim=-1)
        k_sig_full = torch.cat([k_sig, k_rope], dim=-1)

        # E. FlashAttention Decomposition: (A - lambda*B)V = AV - lambda*BV

        # 1. Signal Context (AV)
        # Transpose to [B, H, T, D] for Torch SDPA
        ctx_sig = F.scaled_dot_product_attention(
            q_sig_full.transpose(1, 2),
            k_sig_full.transpose(1, 2),
            v_sig.transpose(1, 2),
            is_causal=True
        )

        # 2. Noise Context (BV)
        # Note: Noise heads use the SAME Value vectors (v_sig) but different Q/K attention patterns
        ctx_noise = F.scaled_dot_product_attention(
            q_noise.transpose(1, 2),
            k_noise.transpose(1, 2),
            v_sig.transpose(1, 2),
            is_causal=True
        )

        # F. Differential Subtraction
        # Calculate Lambda: \exp(\lambda_q \cdot \lambda_k) + init
        lam = torch.exp(torch.sum(self.lambda_q * self.lambda_k, dim=-1)) + self.lambda_init
        lam = lam.view(1, self.nh, 1, 1) # Broadcast

        # Subtract Common-Mode Noise
        ctx_diff = ctx_sig - (lam * ctx_noise)

        # G. Normalization & Output
        # [B, H, T, D] -> [B, T, H, D]
        ctx_diff = ctx_diff.transpose(1, 2)
        ctx_diff = self.head_norm(ctx_diff)
        ctx_diff = ctx_diff * (1 - self.lambda_init) # Gradient scaling

        return self.c_proj(ctx_diff.flatten(2))

# Helper: RMSNorm
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    def forward(self, x):
        var = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(var + self.eps) * self.weight