# GroundThink v6 Hybrid Architecture - v5
## TRUE Delta Rule + Retrieval Training

**Critical Fix in v5:**
- FLA's `chunk_gated_delta_rule` processes chunks in parallel, breaking Delta correction
- This version uses **pure PyTorch token-by-token** Delta Rule for correctness
- State stays bounded (no more -55 or 66,958 signal strength)

**Architecture:**
1. GDN: True Delta Rule with forget gate and error correction `S = g * S + β * (v - S·k) ⊗ k`
2. SWA: Sparse retrieval with dedicated query projection
3. Training: Curriculum learning (retrieval warmup → mixed)

---

In [1]:
# =============================================================================
# CELL 0: Configuration & Imports
# =============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from dataclasses import dataclass
from typing import Optional, List, Dict, Tuple, Any
import math
import time

@dataclass
class HybridConfig:
    """Hyperparameters shared by every layer of the hybrid model.

    ``head_dim`` is always recomputed as ``d_model // n_heads`` after
    construction, so a value passed to the constructor is overwritten.
    ``value_dim`` is a derived attribute (not a dataclass field) set to
    ``int(head_dim * expand_v)``.
    """
    d_model: int = 256
    n_heads: int = 8
    head_dim: int = 32  # placeholder; replaced in __post_init__
    expand_v: float = 2.0  # value_dim / head_dim ratio
    vocab_size: int = 50257
    layer_pattern: str = "GS"  # one character per layer
    window_size: int = 64
    init_std: float = 0.02
    state_accumulation: str = 'replace'
    marker_token: int = 50251
    cue_token: int = 50250

    def __post_init__(self):
        # Derive the per-head sizes from the model width.
        per_head = self.d_model // self.n_heads
        self.head_dim = per_head
        self.value_dim = int(per_head * self.expand_v)

def count_params(model):
    """Return the total number of elements across ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

print("Configuration loaded.")

Configuration loaded.


In [2]:
# =============================================================================
# CELL 1: Basic Components
# =============================================================================

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialised to one.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # eps is added to the mean square before the square root.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        rms = torch.sqrt(mean_square + self.eps)
        return x / rms * self.weight

class FFN(nn.Module):
    """Pre-norm gated feed-forward block with a residual connection.

    Computes ``x + w2(silu(w1(h)) * w3(h))`` with ``h = norm(x)``; the
    hidden width is four times ``cfg.d_model``.
    """

    def __init__(self, cfg: HybridConfig):
        super().__init__()
        width = cfg.d_model
        inner = int(width * 4)
        self.norm = RMSNorm(width)
        self.w1 = nn.Linear(width, inner, bias=False)
        self.w2 = nn.Linear(inner, width, bias=False)
        self.w3 = nn.Linear(width, inner, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.norm(x)
        gate = F.silu(self.w1(normed))
        up = self.w3(normed)
        # Residual connection around the gated projection.
        return x + self.w2(gate * up)

print("Basic components loaded.")

Basic components loaded.


In [3]:
# =============================================================================
# CELL 2: GatedDeltaNetLayer - TRUE DELTA RULE (PURE PYTORCH)
# =============================================================================
#
# WHY NOT USE FLA's chunk_gated_delta_rule?
# - It processes chunks in parallel for speed
# - This breaks the error correction: each token's update MUST depend on
#   the state AFTER all previous tokens, not the initial state
#
# TRUE DELTA RULE: S_t = g*S_{t-1} + β * (v - S_{t-1}·k) ⊗ k
# - If we've already stored v at address k, error ≈ 0, no redundant write
# - This keeps state bounded and memory efficient
#
# =============================================================================

class GatedDeltaNetLayer(nn.Module):
    """Gated delta-rule memory layer evaluated token by token.

    For every time step the layer applies

        S_t = g_t * S_{t-1} + beta_t * (v_t - S_{t-1}·k_t) ⊗ k_t

    and reads the output as ``S_t·q_t``. Keys are L2-normalised; ``beta`` and
    ``g`` are per-head sigmoid gates computed from the normalised input.
    """

    def __init__(self, cfg: HybridConfig, layer_idx: int):
        super().__init__()
        self.cfg = cfg
        self.layer_idx = layer_idx
        H, K, V = cfg.n_heads, cfg.head_dim, cfg.value_dim

        self.q_proj = nn.Linear(cfg.d_model, H * K, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, H * K, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, H * V, bias=False)
        self.o_proj = nn.Linear(H * V, cfg.d_model, bias=False)

        # Beta gate (write strength) - gatekeeper init
        self.beta_proj = nn.Linear(cfg.d_model, H, bias=True)
        nn.init.constant_(self.beta_proj.bias, -2.0)  # sigmoid(-2) ≈ 0.12

        # Forget gate (retention) - high default
        self.g_proj = nn.Linear(cfg.d_model, H, bias=True)
        nn.init.constant_(self.g_proj.bias, 2.0)  # sigmoid(2) ≈ 0.88

        self.norm = RMSNorm(cfg.d_model)

    def forward(self, x, initial_state=None, output_state=True):
        """Run the recurrence over the sequence.

        Args:
            x: Input of shape [B, T, d_model].
            initial_state: Optional starting memory of shape [B, H, K, V];
                zeros are used when it is None.
            output_state: When False, the second return value is None
                instead of the final memory (previously this flag was
                accepted but ignored).

        Returns:
            ``(output, state, diag)`` where ``output`` is ``x`` plus the
            projected read-out, ``state`` is the final memory (or None),
            and ``diag`` is a dict of scalar gate/state statistics.
        """
        B, T, D = x.shape
        H, K, V = self.cfg.n_heads, self.cfg.head_dim, self.cfg.value_dim

        x_norm = self.norm(x)

        q = self.q_proj(x_norm).view(B, T, H, K)
        k = self.k_proj(x_norm).view(B, T, H, K)
        v = self.v_proj(x_norm).view(B, T, H, V)

        # Normalise keys in float32, then cast back to the input dtype.
        k = F.normalize(k.float(), p=2, dim=-1).to(x.dtype)

        beta = torch.sigmoid(self.beta_proj(x_norm))  # [B, T, H]
        g = torch.sigmoid(self.g_proj(x_norm))        # [B, T, H]

        # Initialize state
        if initial_state is None:
            state = torch.zeros(B, H, K, V, device=x.device, dtype=x.dtype)
        else:
            state = initial_state.to(x.dtype)

        # Token-by-token recurrence: each step sees the state left by the
        # previous step, so the loop cannot be reordered.
        outputs = []
        for t in range(T):
            k_t = k[:, t]      # [B, H, K]
            v_t = v[:, t]      # [B, H, V]
            beta_t = beta[:, t]  # [B, H]
            g_t = g[:, t]        # [B, H]

            # 1. Prediction from current state
            prediction = torch.einsum('bhkv,bhk->bhv', state, k_t)

            # 2. Error (what we want - what we'd retrieve)
            error = v_t - prediction

            # 3. Outer product update
            update = torch.einsum('bhv,bhk->bhkv', error, k_t)
            update = beta_t.unsqueeze(-1).unsqueeze(-1) * update

            # 4. Apply forget gate and update
            state = g_t.unsqueeze(-1).unsqueeze(-1) * state + update

            # 5. Output: retrieve using q (after the write for this token)
            q_t = q[:, t]  # [B, H, K]
            out_t = torch.einsum('bhkv,bhk->bhv', state, q_t)
            outputs.append(out_t)

        output = torch.stack(outputs, dim=1)  # [B, T, H, V]
        output = output.reshape(B, T, H * V)
        output = self.o_proj(output)
        output = x + output

        # NOTE: each .item() copies a scalar to the host.
        diag = {
            'beta_mean': beta.mean().item(),
            'beta_max': beta.max().item(),
            'g_mean': g.mean().item(),
            'state_norm': state.norm().item(),
            'state_max': state.abs().max().item(),
        }

        return output, (state if output_state else None), diag

print("GatedDeltaNetLayer loaded (TRUE DELTA RULE).")
print("  - Pure PyTorch, token-by-token")
print("  - Error correction: v - S·k")
print("  - Key normalization for stability")

GatedDeltaNetLayer loaded (TRUE DELTA RULE).
  - Pure PyTorch, token-by-token
  - Error correction: v - S·k
  - Key normalization for stability


In [4]:
# =============================================================================
# CELL 2b: DELTA RULE VALIDATION
# =============================================================================

def validate_delta_rule():
    """Write the same (key, value) pair twice and check the second write is a no-op.

    Returns True when the second error norm is below 0.01 and the state norm
    grows by less than 10%, otherwise False. Progress is printed.
    """
    B, H, K, V = 1, 4, 16, 32
    device = "cuda" if torch.cuda.is_available() else "cpu"

    banner = "=" * 60
    print(banner)
    print("DELTA RULE VALIDATION")
    print(banner)

    memory = torch.zeros(B, H, K, V, device=device)

    # Unit-norm key, random value
    key = F.normalize(torch.randn(B, H, K, device=device), p=2, dim=-1)
    value = torch.randn(B, H, V, device=device)
    beta = torch.ones(B, H, device=device)
    g = torch.ones(B, H, device=device)  # No decay (never applied below)

    write_gain = beta.unsqueeze(-1).unsqueeze(-1)

    def delta_write(mem):
        # One delta-rule step: residual against the current read, then outer-product write.
        residual = value - torch.einsum('bhkv,bhk->bhv', mem, key)
        outer = torch.einsum('bhv,bhk->bhkv', residual, key)
        return residual, mem + write_gain * outer

    # First write
    error_1, memory = delta_write(memory)
    norm_1 = memory.norm().item()

    print(f"\nFirst token:")
    print(f"  Error norm: {error_1.norm().item():.4f}")
    print(f"  State norm: {norm_1:.4f}")

    # Second write with the identical pair
    error_2, memory = delta_write(memory)
    norm_2 = memory.norm().item()

    print(f"\nSecond token (SAME k, v):")
    print(f"  Error norm: {error_2.norm().item():.6f}  ← Should be ~0")
    print(f"  State norm: {norm_2:.4f}")
    print(f"  State growth: {norm_2/norm_1:.4f}x  ← Should be ~1.0")

    passed = error_2.norm().item() < 0.01 and norm_2 / norm_1 < 1.1
    if not passed:
        print(f"\n✗ [FAIL] NOT Delta Rule")
        return False
    print(f"\n✓ [PASS] TRUE DELTA RULE: Redundant info suppressed!")
    return True

# Run validation
validate_delta_rule()

DELTA RULE VALIDATION

First token:
  Error norm: 10.1741
  State norm: 10.1741

Second token (SAME k, v):
  Error norm: 0.000001  ← Should be ~0
  State norm: 10.1741
  State growth: 1.0000x  ← Should be ~1.0

✓ [PASS] TRUE DELTA RULE: Redundant info suppressed!


True

In [5]:
# =============================================================================
# CELL 3: SlidingWindowAttention with State Retrieval
# =============================================================================

class SlidingWindowAttention(nn.Module):
    """Causal sliding-window attention plus a gated read from a GDN state.

    The local path attends to the current token and up to ``window_size``
    previous tokens. When a state tensor is supplied, a separate ReLU-ed
    query reads from it and the result is scaled by a head-averaged sigmoid
    gate before being added to the residual stream.
    """

    def __init__(self, cfg: HybridConfig, layer_idx: int):
        super().__init__()
        self.cfg = cfg
        n_heads, key_dim, val_dim = cfg.n_heads, cfg.head_dim, cfg.value_dim
        width = cfg.d_model
        attn_width = n_heads * key_dim

        # Local attention (values use head_dim, not value_dim)
        self.q_proj = nn.Linear(width, attn_width, bias=False)
        self.k_proj = nn.Linear(width, attn_width, bias=False)
        self.v_proj = nn.Linear(width, attn_width, bias=False)
        self.o_proj = nn.Linear(attn_width, width, bias=False)

        # State retrieval uses its own query projection and output projection
        self.global_q_proj = nn.Linear(width, attn_width, bias=False)
        nn.init.normal_(self.global_q_proj.weight, std=0.02)
        self.retrieval_o_proj = nn.Linear(n_heads * val_dim, width, bias=False)

        # Retrieval gate; bias 1.0 means sigmoid starts above 0.5
        self.gate_proj = nn.Linear(width, n_heads, bias=True)
        nn.init.constant_(self.gate_proj.bias, 1.0)

        self.norm = RMSNorm(width)
        self.scale = key_dim ** -0.5

    def forward(self, x, gdn_state=None):
        B, T, D = x.shape
        cfg = self.cfg
        H, K, V, W = cfg.n_heads, cfg.head_dim, cfg.value_dim, cfg.window_size
        h = self.norm(x)

        def split_heads(proj):
            # [B, T, H*K] -> [B, H, T, K]
            return proj(h).view(B, T, H, K).transpose(1, 2)

        q = split_heads(self.q_proj)
        k = split_heads(self.k_proj)
        v = split_heads(self.v_proj)

        # lag[i, j] = i - j; block future positions (lag < 0) and
        # positions more than W steps back (lag > W).
        pos = torch.arange(T, device=x.device)
        lag = pos.unsqueeze(1) - pos.unsqueeze(0)
        blocked = (lag < 0) | (lag > W)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        scores = scores.masked_fill(blocked.unsqueeze(0).unsqueeze(0), float('-inf'))
        probs = F.softmax(scores, dim=-1)
        local_out = self.o_proj((probs @ v).transpose(1, 2).reshape(B, T, H * K))

        if gdn_state is None:
            # No memory available: the retrieval path contributes nothing.
            retrieval_out = torch.zeros_like(x)
            gate_mean = 0.0
        else:
            # Non-negative queries against the state
            q_g = F.relu(self.global_q_proj(h).view(B, T, H, K).transpose(1, 2))

            read = torch.einsum('bhkv,bhtk->bhtv', gdn_state.to(x.dtype), q_g)
            read = read.transpose(1, 2).reshape(B, T, H * V)
            retrieval_out = self.retrieval_o_proj(read)

            gate = torch.sigmoid(self.gate_proj(h))
            gate_mean = gate.mean().item()
            # The per-head gates are averaged into one scalar per token.
            retrieval_out = gate.mean(dim=-1, keepdim=True) * retrieval_out

        out = x + local_out + retrieval_out
        diag = {'gate_mean': gate_mean, 'local_norm': local_out.norm().item()}
        return out, diag

print("SlidingWindowAttention loaded.")

SlidingWindowAttention loaded.


In [6]:
# =============================================================================
# CELL 4: TransparentHybrid Model
# =============================================================================

class TransparentHybrid(nn.Module):
    """Language model built from a stack of GDN ('G') and SWA ('S') layers.

    Every layer in ``cfg.layer_pattern`` is followed by an FFN. The state
    returned by a 'G' layer is passed as ``initial_state`` to later 'G'
    layers and as ``gdn_state`` to later 'S' layers. The output head shares
    its weight with the token embedding.
    """

    def __init__(self, cfg: HybridConfig):
        super().__init__()
        self.cfg = cfg

        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        nn.init.normal_(self.embed.weight, std=cfg.init_std)

        self.layers = nn.ModuleList()
        self.ffns = nn.ModuleList()

        for i, lt in enumerate(cfg.layer_pattern):
            if lt == 'G':
                self.layers.append(GatedDeltaNetLayer(cfg, i))
            elif lt == 'S':
                self.layers.append(SlidingWindowAttention(cfg, i))
            else:
                # An unknown code used to add an FFN without a layer, which
                # silently misaligned layers, FFNs and pattern codes.
                raise ValueError(
                    f"Unknown layer type {lt!r} at position {i} in "
                    f"layer_pattern {cfg.layer_pattern!r}; expected 'G' or 'S'"
                )
            self.ffns.append(FFN(cfg))

        self.norm_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Weight tying: the output projection reuses the embedding matrix.
        self.lm_head.weight = self.embed.weight

    def forward(self, input_ids, targets=None, return_diagnostics=False):
        """Run the model.

        Args:
            input_ids: Token ids fed to the embedding.
            targets: Optional labels; positions equal to -100 are ignored
                by the cross-entropy loss.
            return_diagnostics: Accepted for call-site compatibility; the
                diagnostics list is always returned.

        Returns:
            ``(logits, loss, all_diag, state)``; ``loss`` is None without
            targets and ``state`` is None if no 'G' layer ran.
        """
        x = self.embed(input_ids)
        state = None
        all_diag = []

        for lt, layer, ffn in zip(self.cfg.layer_pattern, self.layers, self.ffns):
            if lt == 'G':
                x, state, diag = layer(x, initial_state=state)
            else:
                x, diag = layer(x, gdn_state=state)
            x = ffn(x)
            diag['layer'] = lt
            all_diag.append(diag)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)

        return logits, loss, all_diag, state

print("TransparentHybrid loaded.")

TransparentHybrid loaded.


In [7]:
# =============================================================================
# CELL 5: Data Loading
# =============================================================================

from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, Dataset

class TextDataset(Dataset):
    """Fixed-length windows over a flat token list.

    Item ``idx`` holds ``seq_len + 1`` tokens starting at ``idx * seq_len``,
    so consecutive items overlap by one token (inputs are ``item[:-1]`` and
    next-token targets are ``item[1:]`` at the call site).
    """

    def __init__(self, tokens, seq_len=128):
        self.tokens = tokens
        self.seq_len = seq_len

    def __len__(self):
        # One token is reserved for the final target. Clamp at zero so an
        # empty token list yields an empty dataset instead of a negative
        # length (which makes len() raise ValueError).
        return max(0, (len(self.tokens) - 1) // self.seq_len)

    def __getitem__(self, idx):
        start = idx * self.seq_len
        return torch.tensor(self.tokens[start:start + self.seq_len + 1], dtype=torch.long)

def load_data(n_tokens=500_000, seq_len=128, batch_size=16):
    """Tokenise the start of the WikiText-103 train split and wrap it in a DataLoader.

    Non-blank records are encoded with the GPT-2 tokenizer until at least
    ``n_tokens`` tokens are collected; the list is then cut to exactly
    ``n_tokens`` (or fewer if the split runs out). The loader shuffles and
    drops the last partial batch.
    """
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    corpus = load_dataset('wikitext', 'wikitext-103-raw-v1', split='train')

    token_buffer = []
    for record in corpus:
        text = record['text']
        if not text.strip():
            continue  # skip blank records
        token_buffer += tokenizer.encode(text)
        if len(token_buffer) >= n_tokens:
            break
    # Trim the overshoot from the last encoded record.
    del token_buffer[n_tokens:]

    print(f"Loaded {len(token_buffer):,} tokens")
    windows = TextDataset(token_buffer, seq_len)
    return DataLoader(windows, batch_size=batch_size, shuffle=True, drop_last=True)

data_loader = load_data(n_tokens=500_000)

  from .autonotebook import tqdm as notebook_tqdm


Loaded 500,000 tokens


In [8]:
# =============================================================================
# CELL 6: Retrieval Testing
# =============================================================================

def proper_niah_test(model, seq_len=128, needle_pos=32, n_trials=30, randomize_needle=False):
    """Needle-in-a-haystack test with MARKER and CUE tokens.

    Each trial builds one random sequence, places ``cfg.marker_token`` at
    ``needle_pos`` followed by the needle token, puts ``cfg.cue_token`` in
    the last position and checks whether the argmax prediction at that
    last position equals the needle.

    Args:
        model: Module exposing ``cfg`` (``vocab_size``, ``marker_token``,
            ``cue_token``) and returning a 4-tuple whose first element is
            the logits.
        seq_len: Length of each test sequence.
        needle_pos: Index of the marker; the needle sits at ``needle_pos + 1``.
        n_trials: Number of random sequences.
        randomize_needle: When False (default, original behaviour) the
            needle is always ``vocab_size - 3``. Because that answer never
            changes, a model can score 100% by always emitting it, without
            reading the context. Set to True to draw a fresh needle from
            the filler range on every trial so only genuine retrieval
            scores well.

    Returns:
        ``{'accuracy': fraction_correct}``; 0.0 when ``n_trials`` is 0.
    """
    model.eval()
    device = next(model.parameters()).device
    cfg = model.cfg
    filler_hi = cfg.vocab_size - 100  # filler tokens are drawn from [0, filler_hi)

    correct = 0
    with torch.no_grad():
        for _ in range(n_trials):
            if randomize_needle:
                needle_id = torch.randint(0, filler_hi, (1,)).item()
            else:
                needle_id = cfg.vocab_size - 3
            seq = torch.randint(0, filler_hi, (1, seq_len), device=device)
            seq[0, needle_pos] = cfg.marker_token
            seq[0, needle_pos + 1] = needle_id
            seq[0, -1] = cfg.cue_token

            logits, _, _, _ = model(seq)

            pred = logits[0, -1].argmax().item()
            if pred == needle_id:
                correct += 1

    # Guard the division so n_trials=0 reports 0% instead of raising.
    acc = correct / n_trials if n_trials else 0.0
    print(f"  Accuracy: {acc*100:.1f}% ({correct}/{n_trials})")
    return {'accuracy': acc}

def test_niah_by_distance(model, distances=None, n_trials=20, seq_len=128):
    """Run ``proper_niah_test`` once per needle-to-cue distance.

    Args:
        model: Model under test.
        distances: Distances (in tokens) between the needle and the final
            cue; defaults to ``[5, 10, 20, 40, 60, 95]``.
        n_trials: Trials per distance.
        seq_len: Sequence length used both to place the needle and for the
            generated sequences (previously hard-coded to 128 here).
    """
    if distances is None:
        # Built per call to avoid a shared mutable default argument.
        distances = [5, 10, 20, 40, 60, 95]
    print(f"Testing retrieval across distances: {distances}")
    for dist in distances:
        # The marker goes one slot before the needle; clamp to index 2 for
        # distances that do not fit in the sequence.
        needle_pos = max(2, seq_len - dist - 2)
        print(f"  Distance {dist}:  ", end="")
        proper_niah_test(model, seq_len=seq_len, needle_pos=needle_pos, n_trials=n_trials)

def run_full_diagnostic(model, seq_len=128, needle_pos=32):
    """Run one needle sequence through the model and print gate/state statistics.

    Returns the final state tensor produced by the model.
    """
    model.eval()
    device = next(model.parameters()).device
    cfg = model.cfg

    # Same sequence layout as the NIAH test: marker, needle, ..., cue.
    needle_id = cfg.vocab_size - 3
    probe = torch.randint(0, cfg.vocab_size - 100, (1, seq_len), device=device)
    probe[0, needle_pos] = cfg.marker_token
    probe[0, needle_pos + 1] = needle_id
    probe[0, -1] = cfg.cue_token

    with torch.no_grad():
        _, _, layer_diags, final_state = model(probe, return_diagnostics=True)

    peak = final_state.abs().max().item()

    print(f"\n--- Diagnostic ---")
    print(f"  State norm: {final_state.norm().item():.2f}")
    print(f"  State max:  {peak:.2f}")
    for entry in layer_diags:
        if entry['layer'] != 'G':
            print(f"  SWA gate={entry['gate_mean']:.3f}")
            continue
        print(f"  GDN β={entry['beta_mean']:.3f}, g={entry['g_mean']:.3f}")

    # Largest absolute state entry below 10 is treated as "bounded".
    if peak >= 10:
        print(f"  ⚠ State large - check Delta Rule")
    else:
        print(f"  ✓ State bounded - Delta Rule working!")

    return final_state

print("Retrieval testing loaded.")

Retrieval testing loaded.


In [9]:
# =============================================================================
# CELL 7: Training Infrastructure
# =============================================================================

def compute_retrieval_loss(model, seq_len=128):
    """Build a batch of four synthetic needle sequences and return the model's loss.

    Only the final (cue) position carries a label; every other target is
    -100. The needle token is always ``vocab_size - 3``.
    """
    device = next(model.parameters()).device
    cfg = model.cfg
    n_seqs = 4
    answer = cfg.vocab_size - 3

    tokens = torch.randint(0, cfg.vocab_size - 100, (n_seqs, seq_len), device=device)

    # One marker/needle pair per row at an independently drawn position.
    for row in range(n_seqs):
        at = torch.randint(5, seq_len - 10, (1,)).item()
        tokens[row, at] = cfg.marker_token
        tokens[row, at + 1] = answer
    tokens[:, -1] = cfg.cue_token

    labels = torch.full((n_seqs, seq_len), -100, device=device)
    labels[:, -1] = answer

    _, loss, _, _ = model(tokens, targets=labels)
    return loss

def train_curriculum(model, data_loader, steps=1000, warmup_steps=200):
    """Curriculum: retrieval warmup → mixed training.

    For the first ``warmup_steps`` steps only the synthetic retrieval loss
    is optimised. Afterwards each step optimises
    ``lm_loss + 2.0 * ret_loss`` on one language-model batch plus one
    retrieval batch. Gradients are clipped to norm 1.0.

    Args:
        model: Model to train; its parameters are updated in place.
        data_loader: Iterable of token batches; it is restarted when
            exhausted.
        steps: Total number of optimiser steps.
        warmup_steps: Number of retrieval-only steps at the start.

    Returns:
        ``{'lm': [...], 'ret': [...]}`` with one entry per step; the LM
        entry is 0 during warmup.
    """
    device = next(model.parameters()).device
    optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

    # The evaluation helpers call model.eval() and never switch back, so
    # make sure training always runs in train mode.
    model.train()

    lm_iter = iter(data_loader)
    history = {'lm': [], 'ret': []}

    print(f"Training {steps} steps ({warmup_steps} warmup)...")

    for step in range(steps):
        optimizer.zero_grad()

        # Phase 1: Retrieval only
        if step < warmup_steps:
            ret_loss = compute_retrieval_loss(model)
            ret_loss.backward()
            history['ret'].append(ret_loss.item())
            history['lm'].append(0)
        # Phase 2: Mixed
        else:
            try:
                batch = next(lm_iter)
            except StopIteration:
                # Epoch finished: start over.
                lm_iter = iter(data_loader)
                batch = next(lm_iter)

            # Next-token prediction: inputs and targets are offset by one.
            input_ids = batch[:, :-1].to(device)
            targets = batch[:, 1:].to(device)
            _, lm_loss, _, _ = model(input_ids, targets)

            ret_loss = compute_retrieval_loss(model)

            total = lm_loss + 2.0 * ret_loss
            total.backward()

            history['lm'].append(lm_loss.item())
            history['ret'].append(ret_loss.item())

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        if step % 100 == 0:
            phase = "WARMUP" if step < warmup_steps else "MIXED"
            lm = history['lm'][-1]
            ret = history['ret'][-1]
            print(f"[{phase}] {step}: LM={lm:.3f} RET={ret:.3f}")

    return history

print("Training infrastructure loaded.")

Training infrastructure loaded.


In [10]:
# =============================================================================
# CELL 8: Create Model & Validate
# =============================================================================

# Build a two-layer model (one GDN layer followed by one SWA layer), move it
# to the GPU and cast its parameters to bfloat16. Later cells use the
# globals `cfg` and `model` defined here.
cfg = HybridConfig(d_model=256, n_heads=8, layer_pattern="GS")
model = TransparentHybrid(cfg).cuda().bfloat16()

print(f"Model: {cfg.layer_pattern}")
print(f"Parameters: {count_params(model):,}")

# Quick validation: one forward pass on random token ids without gradients,
# then report the state norm and the first layer's gate means.
print("\n--- Pre-training check ---")
x = torch.randint(0, 1000, (1, 128), device='cuda')
with torch.no_grad():
    _, _, diags, state = model(x, return_diagnostics=True)
print(f"Initial state norm: {state.norm().item():.4f}")
# diags[0] belongs to the first layer, which is the 'G' layer for pattern "GS".
print(f"GDN β={diags[0]['beta_mean']:.3f}, g={diags[0]['g_mean']:.3f}")

Model: GS
Parameters: 15,298,072

--- Pre-training check ---
Initial state norm: 4.3750
GDN β=0.130, g=0.867


In [11]:
# =============================================================================
# CELL 9: Train
# =============================================================================

history = train_curriculum(model, data_loader, steps=1000, warmup_steps=200)

Training 1000 steps (200 warmup)...
[WARMUP] 0: LM=0.000 RET=10.812
[WARMUP] 100: LM=0.000 RET=0.013
[MIXED] 200: LM=15.438 RET=0.007
[MIXED] 300: LM=6.781 RET=0.003
[MIXED] 400: LM=6.250 RET=0.002
[MIXED] 500: LM=5.938 RET=0.002
[MIXED] 600: LM=5.938 RET=0.002
[MIXED] 700: LM=5.562 RET=0.002
[MIXED] 800: LM=5.688 RET=0.002
[MIXED] 900: LM=5.438 RET=0.002


In [12]:
# =============================================================================
# CELL 10: Evaluate
# =============================================================================

# Post-training report: runs the four checks below in order on the global
# `model`. The helpers print their own results.
print("=" * 60)
print("POST-TRAINING EVALUATION")
print("=" * 60)

# Needle test at the default position (needle_pos=32, seq_len=128).
print("\n1. NIAH Accuracy:")
proper_niah_test(model, n_trials=30)

# Same test repeated for several needle-to-cue distances.
print("\n2. NIAH by Distance:")
test_niah_by_distance(model)

# Gate means and state magnitude on a single needle sequence.
print("\n3. State Health:")
run_full_diagnostic(model)

# Stand-alone check of the update rule; it does not use `model`.
print("\n4. Delta Rule Validation:")
validate_delta_rule()

POST-TRAINING EVALUATION

1. NIAH Accuracy:
  Accuracy: 100.0% (30/30)

2. NIAH by Distance:
Testing retrieval across distances: [5, 10, 20, 40, 60, 95]
  Distance 5:    Accuracy: 100.0% (20/20)
  Distance 10:    Accuracy: 100.0% (20/20)
  Distance 20:    Accuracy: 100.0% (20/20)
  Distance 40:    Accuracy: 100.0% (20/20)
  Distance 60:    Accuracy: 100.0% (20/20)
  Distance 95:    Accuracy: 100.0% (20/20)

3. State Health:

--- Diagnostic ---
  State norm: 5.66
  State max:  0.29
  GDN β=0.214, g=0.742
  SWA gate=0.676
  ✓ State bounded - Delta Rule working!

4. Delta Rule Validation:
DELTA RULE VALIDATION

First token:
  Error norm: 11.3580
  State norm: 11.3580

Second token (SAME k, v):
  Error norm: 0.000001  ← Should be ~0
  State norm: 11.3580
  State growth: 1.0000x  ← Should be ~1.0

✓ [PASS] TRUE DELTA RULE: Redundant info suppressed!


True

In [13]:
# =============================================================================
# CELL: PROFILING & VALIDATION (Run this to understand what we built)
# =============================================================================
#
# BEFORE scaling, answer these questions:
# 1. How slow is sequential vs chunked?
# 2. What edge cases does Delta Rule handle?
# 3. Where does it break?
#
# =============================================================================

import time

# -----------------------------------------------------------------------------
# SPEED TEST: Sequential Delta Rule
# -----------------------------------------------------------------------------

def _sequential_delta_scan(k, v, beta, g):
    """Run the gated delta-rule recurrence over every time step; return the final state."""
    B, T, H, K = k.shape
    V = v.shape[-1]
    state = torch.zeros(B, H, K, V, device=k.device)
    for t in range(T):
        k_t = k[:, t]
        pred = torch.einsum('bhkv,bhk->bhv', state, k_t)
        error = v[:, t] - pred
        update = torch.einsum('bhv,bhk->bhkv', error, k_t)
        state = g[:, t].unsqueeze(-1).unsqueeze(-1) * state + \
                beta[:, t].unsqueeze(-1).unsqueeze(-1) * update
    return state


def profile_sequential(T_values=(64, 128, 256, 512), B=8, H=8, K=32, V=64, n_runs=5, device=None):
    """Profile the sequential loop we're using.

    For each sequence length, runs two warmup passes and then ``n_runs``
    timed passes of the token-by-token recurrence on random inputs and
    prints the mean wall-clock time and tokens per second.

    Args:
        T_values: Sequence lengths to time.
        B, H, K, V: Batch size, head count, key dim and value dim.
        n_runs: Number of timed repetitions per sequence length.
        device: Device string; defaults to "cuda" when available, else
            "cpu" (the original hard-coded "cuda").
    """
    print("=" * 60)
    print("SPEED: Sequential Delta Rule (what we're using)")
    print("=" * 60)

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    on_cuda = torch.device(device).type == "cuda"

    def sync():
        # CUDA kernels run asynchronously; wait for them before reading the clock.
        if on_cuda:
            torch.cuda.synchronize()

    for T in T_values:
        k = F.normalize(torch.randn(B, T, H, K, device=device), dim=-1)
        v = torch.randn(B, T, H, V, device=device)
        beta = torch.sigmoid(torch.randn(B, T, H, device=device))
        g = torch.sigmoid(torch.randn(B, T, H, device=device))

        # Warmup
        for _ in range(2):
            _sequential_delta_scan(k, v, beta, g)
            sync()

        # Timed
        sync()
        start = time.perf_counter()
        for _ in range(n_runs):
            _sequential_delta_scan(k, v, beta, g)
        sync()
        elapsed = (time.perf_counter() - start) / n_runs

        toks_per_sec = (B * T) / elapsed
        print(f"  T={T:4d}: {elapsed*1000:7.2f} ms | {toks_per_sec:>10,.0f} tok/s")

profile_sequential()

# -----------------------------------------------------------------------------
# SPEED TEST: FLA Chunked Kernel (for comparison)
# -----------------------------------------------------------------------------

try:
    from fla.ops.gated_delta_rule import chunk_gated_delta_rule

    def profile_fla(T_values=(64, 128, 256, 512), B=8, H=8, K=32, V=64, n_runs=5):
        """Time FLA's chunked gated delta-rule kernel on random inputs.

        Uses the same shapes and timing scheme as the sequential profile so
        the printed numbers are comparable. Requires CUDA.
        """
        print("\n" + "=" * 60)
        print("SPEED: FLA Chunked Kernel (what we had before)")
        print("=" * 60)

        device = "cuda"

        for T in T_values:
            q = torch.randn(B, T, H, K, device=device)
            k = F.normalize(torch.randn(B, T, H, K, device=device), dim=-1)
            v = torch.randn(B, T, H, V, device=device)
            beta = torch.sigmoid(torch.randn(B, T, H, device=device))
            # FLA documents the gate `g` as a log-space decay, so pass
            # log(sigmoid(.)) rather than the sigmoid itself; a raw sigmoid
            # (positive values) would be read as a growth factor > 1.
            # NOTE(review): confirm against the installed FLA version.
            g = F.logsigmoid(torch.randn(B, T, H, device=device))

            # Warmup
            for _ in range(2):
                _, _ = chunk_gated_delta_rule(q, k, v, g, beta, output_final_state=True)
                torch.cuda.synchronize()

            # Timed
            torch.cuda.synchronize()
            start = time.perf_counter()
            for _ in range(n_runs):
                _, _ = chunk_gated_delta_rule(q, k, v, g, beta, output_final_state=True)
            torch.cuda.synchronize()
            elapsed = (time.perf_counter() - start) / n_runs

            toks_per_sec = (B * T) / elapsed
            print(f"  T={T:4d}: {elapsed*1000:7.2f} ms | {toks_per_sec:>10,.0f} tok/s")

    profile_fla()

except ImportError:
    print("\nFLA not available for comparison")

# -----------------------------------------------------------------------------
# VALIDATION TESTS
# -----------------------------------------------------------------------------

# Stand-alone checks of the delta-rule update on hand-built states. These
# statements run at cell level and overwrite the globals B, H, K, V, state,
# k and v. All tensors are created on 'cuda'.
print("\n" + "=" * 60)
print("VALIDATION: Delta Rule Correctness Tests")
print("=" * 60)

# Test 1: Identical tokens
# Write the same (k, v) pair twice with full write strength and no decay;
# the second write should see ~zero error and leave the state unchanged.
print("\n--- Test 1: Identical Tokens ---")
B, H, K, V = 1, 4, 32, 64
state = torch.zeros(B, H, K, V, device='cuda')
k = F.normalize(torch.randn(B, H, K, device='cuda'), dim=-1)
v = torch.randn(B, H, V, device='cuda')

pred1 = torch.einsum('bhkv,bhk->bhv', state, k)
error1 = v - pred1
state = state + torch.einsum('bhv,bhk->bhkv', error1, k)
norm1 = state.norm().item()

pred2 = torch.einsum('bhkv,bhk->bhv', state, k)
error2 = v - pred2
state = state + torch.einsum('bhv,bhk->bhkv', error2, k)
norm2 = state.norm().item()

print(f"  Error1: {error1.norm().item():.4f}")
print(f"  Error2: {error2.norm().item():.6f} (should be ~0)")
print(f"  Growth: {norm2/norm1:.4f}x (should be ~1.0)")
print(f"  → {'✓ PASS' if error2.norm().item() < 0.001 else '✗ FAIL'}")

# Test 2: Orthogonal keys
# k1 and k2 are distinct one-hot vectors, so the two writes land in
# different rows of the state and should be read back exactly.
print("\n--- Test 2: Orthogonal Keys ---")
state = torch.zeros(1, 1, 32, 64, device='cuda')
k1 = torch.zeros(1, 1, 32, device='cuda'); k1[0,0,0] = 1.0
k2 = torch.zeros(1, 1, 32, device='cuda'); k2[0,0,1] = 1.0
v1 = torch.randn(1, 1, 64, device='cuda')
v2 = torch.randn(1, 1, 64, device='cuda')

# Write v1 at k1 (plain outer-product write, no error-correction term)
state = state + torch.einsum('bhv,bhk->bhkv', v1, k1)
# Write v2 at k2
state = state + torch.einsum('bhv,bhk->bhkv', v2, k2)

ret1 = torch.einsum('bhkv,bhk->bhv', state, k1)
ret2 = torch.einsum('bhkv,bhk->bhv', state, k2)
# Relative retrieval error for each stored value.
err1 = (ret1 - v1).norm().item() / v1.norm().item()
err2 = (ret2 - v2).norm().item() / v2.norm().item()

print(f"  v1 retrieval error: {err1:.6f}")
print(f"  v2 retrieval error: {err2:.6f}")
print(f"  → {'✓ PASS' if err1 < 0.001 and err2 < 0.001 else '✗ FAIL'}")

# Test 3: Capacity / interference
# 100 delta-rule writes with random unit keys into a state whose key
# dimension is 32, then compare recall of the first and the last pair.
print("\n--- Test 3: Capacity (100 random writes) ---")
state = torch.zeros(1, 1, 32, 64, device='cuda')
keys, values = [], []
for i in range(100):
    k = F.normalize(torch.randn(1, 1, 32, device='cuda'), dim=-1)
    v = torch.randn(1, 1, 64, device='cuda')
    keys.append(k); values.append(v)
    pred = torch.einsum('bhkv,bhk->bhv', state, k)
    error = v - pred
    state = state + torch.einsum('bhv,bhk->bhkv', error, k)

ret_first = torch.einsum('bhkv,bhk->bhv', state, keys[0])
ret_last = torch.einsum('bhkv,bhk->bhv', state, keys[-1])
err_first = (ret_first - values[0]).norm().item() / values[0].norm().item()
err_last = (ret_last - values[-1]).norm().item() / values[-1].norm().item()

print(f"  State norm after 100 writes: {state.norm().item():.2f}")
print(f"  First item error: {err_first:.4f}")
print(f"  Last item error:  {err_last:.4f}")
# This test only reports; it has no pass/fail threshold.
print(f"  → First item degraded (expected with random keys)")

# Fixed summary text; it is not computed from the measurements above.
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
print("""
Speed: Sequential loop is ~10-50x slower than chunked kernel
       This may be acceptable for small T, but scales poorly

Correctness:
  ✓ Redundant writes suppressed (true Delta Rule)
  ✓ Orthogonal keys stored independently  
  ⚠ Random keys interfere (inherent to associative memory)
  ⚠ Capacity limited by K*V state size

Next steps:
  1. Decide if speed is acceptable for your use case
  2. Consider hybrid: chunked for training, sequential for inference
  3. Consider custom CUDA kernel that does Delta Rule correctly in chunks
""")

SPEED: Sequential Delta Rule (what we're using)
  T=  64:   18.50 ms |     27,680 tok/s
  T= 128:   35.69 ms |     28,689 tok/s
  T= 256:   51.50 ms |     39,765 tok/s
  T= 512:  108.66 ms |     37,696 tok/s

SPEED: FLA Chunked Kernel (what we had before)
  T=  64:    0.60 ms |    856,513 tok/s
  T= 128:    1.08 ms |    949,515 tok/s
  T= 256:    1.78 ms |  1,147,598 tok/s
  T= 512:    3.36 ms |  1,219,835 tok/s

VALIDATION: Delta Rule Correctness Tests

--- Test 1: Identical Tokens ---
  Error1: 15.2943
  Error2: 0.000001 (should be ~0)
  Growth: 1.0000x (should be ~1.0)
  → ✓ PASS

--- Test 2: Orthogonal Keys ---
  v1 retrieval error: 0.000000
  v2 retrieval error: 0.000000
  → ✓ PASS

--- Test 3: Capacity (100 random writes) ---
  State norm after 100 writes: 44.62
  First item error: 1.2811
  Last item error:  0.0000
  → First item degraded (expected with random keys)

SUMMARY

Speed: Sequential loop is ~10-50x slower than chunked kernel
       This may be acceptable for small T, b

In [14]:
# =============================================================================
# TRUE DELTA RULE: PROFILING & VALIDATION SUITE
# =============================================================================
#
# Before scaling, we need to understand:
# 1. HOW SLOW is the sequential loop vs chunked kernel?
# 2. WHAT CASES does the Delta Rule handle correctly?
# 3. WHERE does it break down?
# 4. CAN we optimize without losing correctness?
#
# =============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from typing import Dict, List, Tuple

# Try to import FLA for comparison
# HAS_FLA records whether the chunked kernel could be imported; the FLA
# profiling and the slowdown table below are skipped when it is False, so the
# rest of this cell still runs without the optional dependency.
try:
    from fla.ops.gated_delta_rule import chunk_gated_delta_rule
    HAS_FLA = True
except ImportError:
    HAS_FLA = False
    print("FLA not available - skipping kernel comparison")


# =============================================================================
# SECTION 1: SPEED PROFILING
# =============================================================================

def _sync_device(device) -> None:
    """Block until queued GPU work has finished; do nothing on non-CUDA devices."""
    dev = torch.device(device)
    if dev.type == "cuda":
        torch.cuda.synchronize(dev)


def _run_sequential_delta(k, v, beta, g):
    """Apply the gated Delta Rule one token at a time and return the final state.

    Shapes follow the tensors built by ``profile_sequential_delta``:
    ``k`` is (B, T, H, K), ``v`` is (B, T, H, V), ``beta`` and ``g`` are (B, T, H).
    """
    B, T, H, K = k.shape
    state = k.new_zeros(B, H, K, v.shape[-1])
    # Broadcast the per-head scalars to (B, T, H, 1, 1) once, not once per token.
    decay = g[..., None, None]
    write_strength = beta[..., None, None]
    for t in range(T):
        k_t = k[:, t]
        pred = torch.einsum('bhkv,bhk->bhv', state, k_t)
        error = v[:, t] - pred
        update = torch.einsum('bhv,bhk->bhkv', error, k_t)
        state = decay[:, t] * state + write_strength[:, t] * update
    return state


def profile_sequential_delta(
    batch_sizes: List[int] = [1, 4, 8, 16],
    seq_lens: List[int] = [64, 128, 256, 512],
    n_heads: int = 8,
    head_dim: int = 32,
    value_dim: int = 64,
    n_warmup: int = 3,
    n_runs: int = 10,
    device: str = "cuda",
):
    """
    Profile the sequential Delta Rule implementation.
    
    This is the CRITICAL question: Is token-by-token viable?

    For every (B, T) pair the recurrence
    ``S = g * S + beta * (v - S·k) ⊗ k`` is executed ``n_warmup`` times
    untimed and then ``n_runs`` times timed on freshly sampled random inputs.

    Args:
        batch_sizes: Batch sizes B to benchmark.
        seq_lens: Sequence lengths T to benchmark.
        n_heads: Number of heads H.
        head_dim: Key dimension K.
        value_dim: Value dimension V.
        n_warmup: Untimed iterations run before measuring.
        n_runs: Timed iterations averaged per configuration (must be >= 1).
        device: Torch device string. CUDA synchronisation is only issued for
            CUDA devices, so "cpu" works on machines without a GPU.

    Returns:
        Dict mapping ``(B, T)`` to ``{'avg_ms': float, 'tokens_per_sec': float}``.

    Raises:
        ValueError: If ``n_runs`` is smaller than 1 (no samples to average).
    """
    if n_runs < 1:
        raise ValueError("n_runs must be at least 1")

    print("=" * 70)
    print("SPEED PROFILING: Sequential Delta Rule")
    print("=" * 70)
    print(f"Config: H={n_heads}, K={head_dim}, V={value_dim}")
    print(f"Warmup: {n_warmup}, Runs: {n_runs}")
    print()
    
    results = {}
    
    for B in batch_sizes:
        for T in seq_lens:
            # Create inputs
            k = torch.randn(B, T, n_heads, head_dim, device=device)
            k = F.normalize(k, p=2, dim=-1)
            v = torch.randn(B, T, n_heads, value_dim, device=device)
            beta = torch.sigmoid(torch.randn(B, T, n_heads, device=device))
            g = torch.sigmoid(torch.randn(B, T, n_heads, device=device))
            
            # Warmup
            for _ in range(n_warmup):
                _run_sequential_delta(k, v, beta, g)
                _sync_device(device)
            
            # Timed runs: synchronise on both sides so that only the work of
            # this iteration is inside the measured window.
            times = []
            for _ in range(n_runs):
                _sync_device(device)
                start = time.perf_counter()
                _run_sequential_delta(k, v, beta, g)
                _sync_device(device)
                times.append(time.perf_counter() - start)
            
            avg_s = sum(times) / len(times)
            avg_ms = avg_s * 1000
            tokens_per_sec = (B * T) / avg_s
            
            results[(B, T)] = {
                'avg_ms': avg_ms,
                'tokens_per_sec': tokens_per_sec,
            }
            
            print(f"B={B:2d}, T={T:4d}: {avg_ms:8.2f} ms | {tokens_per_sec:,.0f} tok/s")
    
    return results


def profile_fla_kernel(
    batch_sizes: List[int] = [1, 4, 8, 16],
    seq_lens: List[int] = [64, 128, 256, 512],
    n_heads: int = 8,
    head_dim: int = 32,
    value_dim: int = 64,
    n_warmup: int = 3,
    n_runs: int = 10,
    device: str = "cuda",
):
    """Profile FLA's chunked kernel for comparison.

    For every (B, T) pair, ``chunk_gated_delta_rule`` is called ``n_warmup``
    times untimed and then ``n_runs`` times timed on random inputs.

    Args:
        batch_sizes: Batch sizes B to benchmark.
        seq_lens: Sequence lengths T to benchmark.
        n_heads: Number of heads H.
        head_dim: Query/key dimension K.
        value_dim: Value dimension V.
        n_warmup: Untimed iterations run before measuring.
        n_runs: Timed iterations averaged per configuration (must be >= 1).
        device: Torch device string; NOTE(review): the FLA kernel is assumed
            to require a CUDA device, so synchronisation is unconditional.

    Returns:
        Dict mapping ``(B, T)`` to ``{'avg_ms': float, 'tokens_per_sec': float}``,
        or an empty dict when FLA could not be imported.

    Raises:
        ValueError: If FLA is available and ``n_runs`` is smaller than 1.
    """
    if not HAS_FLA:
        print("FLA not available")
        return {}

    if n_runs < 1:
        raise ValueError("n_runs must be at least 1")
    
    print("=" * 70)
    print("SPEED PROFILING: FLA Chunked Kernel")
    print("=" * 70)
    
    results = {}
    
    for B in batch_sizes:
        for T in seq_lens:
            q = torch.randn(B, T, n_heads, head_dim, device=device)
            k = torch.randn(B, T, n_heads, head_dim, device=device)
            k = F.normalize(k, p=2, dim=-1)
            v = torch.randn(B, T, n_heads, value_dim, device=device)
            beta = torch.sigmoid(torch.randn(B, T, n_heads, device=device))
            # FLA takes the forget gate in LOG space (per-step decay = exp(g)).
            # logsigmoid keeps exp(g) in (0, 1), i.e. the same decay range the
            # sequential benchmark uses; a raw sigmoid here would mean
            # exp(g) > 1 and a state that grows at every step.
            g = F.logsigmoid(torch.randn(B, T, n_heads, device=device))
            
            # Warmup (the first call may also include one-off kernel setup)
            for _ in range(n_warmup):
                _, _ = chunk_gated_delta_rule(q, k, v, g, beta, output_final_state=True)
                torch.cuda.synchronize()
            
            # Timed runs: synchronise on both sides so only this call is measured
            times = []
            for _ in range(n_runs):
                torch.cuda.synchronize()
                start = time.perf_counter()
                _, _ = chunk_gated_delta_rule(q, k, v, g, beta, output_final_state=True)
                torch.cuda.synchronize()
                elapsed = time.perf_counter() - start
                times.append(elapsed)
            
            avg_s = sum(times) / len(times)
            avg_ms = avg_s * 1000
            tokens_per_sec = (B * T) / avg_s
            
            results[(B, T)] = {
                'avg_ms': avg_ms,
                'tokens_per_sec': tokens_per_sec,
            }
            
            print(f"B={B:2d}, T={T:4d}: {avg_ms:8.2f} ms | {tokens_per_sec:,.0f} tok/s")
    
    return results


def compare_speeds():
    """Benchmark both implementations on one grid and print the slowdown.

    The sequential loop is always profiled. When FLA is importable, the
    chunked kernel is profiled on the same (B, T) grid and a per-configuration
    ratio of average times is printed.

    Returns:
        The results dict produced by ``profile_sequential_delta``.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("SPEED COMPARISON: Sequential vs Chunked")
    print(banner)

    # One shared grid so both profilers are measured on identical shapes.
    grid = {
        'batch_sizes': [1, 8],
        'seq_lens': [64, 128, 256, 512],
    }
    seq_results = profile_sequential_delta(**grid)

    # Without FLA there is nothing to compare against.
    if not HAS_FLA:
        return seq_results

    print()
    fla_results = profile_fla_kernel(**grid)

    rule = "-" * 70
    print("\n" + rule)
    print("SLOWDOWN FACTOR (Sequential / Chunked)")
    print(rule)
    for key, seq_stats in seq_results.items():
        fla_stats = fla_results.get(key)
        if fla_stats is None:
            continue
        slowdown = seq_stats['avg_ms'] / fla_stats['avg_ms']
        print(f"B={key[0]:2d}, T={key[1]:4d}: {slowdown:6.1f}x slower")

    return seq_results


# =============================================================================
# SECTION 2: CORRECTNESS VALIDATION
# =============================================================================

def test_identical_tokens():
    """Test 1: writing the same (k, v) pair twice must be a no-op the second time.

    Returns:
        True when the second write's error is negligible relative to the first
        and the state norm does not grow by more than 1%.
    """
    print("\n" + "=" * 60)
    print("TEST 1: Identical Tokens (Redundancy Suppression)")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    n_batch, n_head, key_dim, val_dim = 1, 4, 32, 64

    def delta_write(memory, key, value):
        # Ungated, full-strength Delta Rule step; also hands back the residual.
        residual = value - torch.einsum('bhkv,bhk->bhv', memory, key)
        correction = torch.einsum('bhv,bhk->bhkv', residual, key)
        return memory + correction, residual

    memory = torch.zeros(n_batch, n_head, key_dim, val_dim, device=device)
    k = F.normalize(torch.randn(n_batch, n_head, key_dim, device=device), dim=-1)
    v = torch.randn(n_batch, n_head, val_dim, device=device)

    # First write stores v under k; second write repeats the exact same pair.
    memory, error1 = delta_write(memory, k, v)
    norm_first = memory.norm().item()
    memory, error2 = delta_write(memory, k, v)
    norm_second = memory.norm().item()

    # Epsilon guards the division should the first error ever be zero.
    error_ratio = error2.norm().item() / (error1.norm().item() + 1e-8)
    growth = norm_second / norm_first

    print(f"  Error1 norm: {error1.norm().item():.4f}")
    print(f"  Error2 norm: {error2.norm().item():.6f}")
    print(f"  Error ratio: {error_ratio:.6f} (should be ~0)")
    print(f"  State growth: {growth:.4f}x (should be ~1.0)")

    passed = error_ratio < 0.001 and growth < 1.01
    print(f"  → {'✓ PASS' if passed else '✗ FAIL'}")
    return passed


def test_orthogonal_keys():
    """Test 2: values written under orthogonal keys must not disturb each other.

    Returns:
        True when both values are recalled with relative error below 0.001.
    """
    print("\n" + "=" * 60)
    print("TEST 2: Orthogonal Keys (Independent Storage)")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    B, H, K, V = 1, 1, 32, 64

    # Two distinct standard-basis vectors: exactly orthogonal, unit norm.
    basis = torch.eye(K, device=device)
    k1 = basis[0].reshape(B, H, K)
    k2 = basis[1].reshape(B, H, K)

    v1 = torch.randn(B, H, V, device=device)
    v2 = torch.randn(B, H, V, device=device)

    # Write v1 at k1, then v2 at k2, with the ungated Delta Rule.
    state = torch.zeros(B, H, K, V, device=device)
    for key, value in ((k1, v1), (k2, v2)):
        residual = value - torch.einsum('bhkv,bhk->bhv', state, key)
        state = state + torch.einsum('bhv,bhk->bhkv', residual, key)

    def relative_recall_error(key, target):
        # Read the state with `key` and compare against the stored value.
        recalled = torch.einsum('bhkv,bhk->bhv', state, key)
        return (recalled - target).norm().item() / target.norm().item()

    error_v1 = relative_recall_error(k1, v1)
    error_v2 = relative_recall_error(k2, v2)

    print(f"  v1 retrieval error: {error_v1:.6f} (should be ~0)")
    print(f"  v2 retrieval error: {error_v2:.6f} (should be ~0)")

    passed = error_v1 < 0.001 and error_v2 < 0.001
    print(f"  → {'✓ PASS' if passed else '✗ FAIL'}")
    return passed


def test_interference():
    """Test 3: show that writing under a nearby key degrades an earlier item.

    Purely informational: the measurements are printed and the function
    always returns True.
    """
    print("\n" + "=" * 60)
    print("TEST 3: Key Interference (Expected Behavior)")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    B, H, K, V = 1, 1, 32, 64

    def unit_key():
        return F.normalize(torch.randn(B, H, K, device=device), dim=-1)

    def read(memory, key):
        return torch.einsum('bhkv,bhk->bhv', memory, key)

    def write(memory, key, value):
        # Ungated, full-strength Delta Rule step.
        residual = value - read(memory, key)
        return memory + torch.einsum('bhv,bhk->bhkv', residual, key)

    def relative_error(recalled, target):
        return (recalled - target).norm().item() / target.norm().item()

    # k2 is k1 nudged by a small random direction, so the two overlap heavily.
    k1 = unit_key()
    k2 = F.normalize(k1 + 0.1 * unit_key(), dim=-1)
    dot_product = (k1 * k2).sum().item()

    v1 = torch.randn(B, H, V, device=device)
    v2 = torch.randn(B, H, V, device=device)

    state = torch.zeros(B, H, K, V, device=device)

    # Store v1 and measure how well it reads back on its own.
    state = write(state, k1, v1)
    error_before = relative_error(read(state, k1), v1)

    # Store v2 under the similar key, then re-measure v1.
    state = write(state, k2, v2)
    error_after = relative_error(read(state, k1), v1)

    print(f"  Key similarity (dot product): {dot_product:.4f}")
    print(f"  v1 error BEFORE v2 write: {error_before:.6f}")
    print(f"  v1 error AFTER v2 write:  {error_after:.4f}")
    print(f"  Interference occurred: {error_after > error_before}")

    # Documents expected behaviour; there is no failure condition.
    print(f"  → This is EXPECTED: Similar keys interfere")
    return True


def test_capacity_limit():
    """Test 4: measure recall of the first and last item as more pairs are written.

    For each write count, a fresh state receives that many random (k, v)
    pairs via the ungated Delta Rule; the relative recall error of the first
    and last pair is then recorded.

    Returns:
        A list with one dict per write count, holding 'n', 'state_norm',
        'first_error' and 'last_error'.
    """
    print("\n" + "=" * 60)
    print("TEST 4: Capacity Limit (State Saturation)")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    B, H, K, V = 1, 1, 32, 64

    # The state holds K * V = 32 * 64 = 2048 floats, but random
    # (non-orthogonal) keys leave far less usable capacity than that.

    def recall_error(memory, key, target):
        recalled = torch.einsum('bhkv,bhk->bhv', memory, key)
        return (recalled - target).norm().item() / target.norm().item()

    results = []
    for n in (10, 50, 100, 200):
        state = torch.zeros(B, H, K, V, device=device)
        pairs = []

        # Write n random (k, v) pairs, remembering them for the recall check.
        for _ in range(n):
            key = F.normalize(torch.randn(B, H, K, device=device), dim=-1)
            value = torch.randn(B, H, V, device=device)
            pairs.append((key, value))

            residual = value - torch.einsum('bhkv,bhk->bhv', state, key)
            state = state + torch.einsum('bhv,bhk->bhkv', residual, key)

        # Oldest and newest items bracket the degradation.
        error_first = recall_error(state, *pairs[0])
        error_last = recall_error(state, *pairs[-1])

        results.append({
            'n': n,
            'state_norm': state.norm().item(),
            'first_error': error_first,
            'last_error': error_last,
        })

        print(f"  n={n:3d}: state_norm={state.norm().item():.2f}, "
              f"first_err={error_first:.4f}, last_err={error_last:.4f}")

    print(f"\n  → State grows with writes, early items degrade")
    return results


def test_forget_gate():
    """Test 5: print how much of a stored value survives for several gate values.

    One (k, v) pair is written, then ten further gated steps write a zero
    value under a random "noise" key. Retention is the projection of the
    recalled vector onto v, normalised by ||v||^2. Informational only: the
    function always returns True.
    """
    print("\n" + "=" * 60)
    print("TEST 5: Forget Gate Effect")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    B, H, K, V = 1, 1, 32, 64

    k = F.normalize(torch.randn(B, H, K, device=device), dim=-1)
    v = torch.randn(B, H, V, device=device)
    # Target for the "empty" steps; identical for every gate value.
    v_zero = torch.zeros(B, H, V, device=device)

    def gated_step(memory, decay, key, value):
        # The residual is computed on the pre-decay state, then the decayed
        # state and the full-strength correction are summed.
        residual = value - torch.einsum('bhkv,bhk->bhv', memory, key)
        return decay * memory + torch.einsum('bhv,bhk->bhkv', residual, key)

    for g_val in (1.0, 0.9, 0.5, 0.1):
        # Gate already shaped to broadcast against the (B, H, K, V) state.
        decay = torch.full((B, H, 1, 1), g_val, device=device)

        # Write once
        state = gated_step(torch.zeros(B, H, K, V, device=device), decay, k, v)

        # Apply 10 more "empty" steps, all under one fresh noise key
        k_noise = F.normalize(torch.randn(B, H, K, device=device), dim=-1)
        for _ in range(10):
            state = gated_step(state, decay, k_noise, v_zero)

        # Try to retrieve original
        retrieved = torch.einsum('bhkv,bhk->bhv', state, k)
        retention = (retrieved * v).sum().item() / (v.norm().item() ** 2)

        print(f"  g={g_val:.1f}: retention after 10 steps = {retention:.4f}")

    print(f"\n  → Lower g = faster decay")
    return True


# =============================================================================
# SECTION 3: RUN ALL TESTS
# =============================================================================

def run_all_validations():
    """Run every validation test in order and print a one-line status for each.

    Returns:
        Dict mapping test name to whatever that test function returned.
        Any truthy return value is reported as a pass in the summary.
    """
    print("\n" + "#" * 70)
    print("# COMPLETE VALIDATION SUITE")
    print("#" * 70)

    # (name, callable) pairs; executed top to bottom.
    suite = (
        ('identical_tokens', test_identical_tokens),
        ('orthogonal_keys', test_orthogonal_keys),
        ('interference', test_interference),
        ('capacity', test_capacity_limit),
        ('forget_gate', test_forget_gate),
    )
    results = {name: check() for name, check in suite}

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for name, passed in results.items():
        if passed:
            status = "✓ PASS"
        else:
            status = "✗ FAIL/INFO"
        print(f"  {name}: {status}")

    return results


if __name__ == "__main__":
    # Run speed comparison
    # (benchmarks go first; compare_speeds uses the profilers' default
    # device, which is "cuda")
    compare_speeds()
    
    # Run all validations
    # (each test selects CUDA when available and otherwise runs on CPU)
    run_all_validations()


SPEED COMPARISON: Sequential vs Chunked
SPEED PROFILING: Sequential Delta Rule
Config: H=8, K=32, V=64
Warmup: 3, Runs: 10

B= 1, T=  64:    10.94 ms | 5,850 tok/s
B= 1, T= 128:    23.21 ms | 5,516 tok/s
B= 1, T= 256:    46.63 ms | 5,491 tok/s
B= 1, T= 512:    78.86 ms | 6,492 tok/s
B= 8, T=  64:    13.97 ms | 36,663 tok/s
B= 8, T= 128:    26.90 ms | 38,067 tok/s
B= 8, T= 256:    60.47 ms | 33,871 tok/s
B= 8, T= 512:   113.94 ms | 35,950 tok/s

SPEED PROFILING: FLA Chunked Kernel
B= 1, T=  64:     1.21 ms | 52,874 tok/s
B= 1, T= 128:     1.06 ms | 120,433 tok/s
B= 1, T= 256:     1.06 ms | 242,520 tok/s
B= 1, T= 512:     0.81 ms | 630,671 tok/s
B= 8, T=  64:     0.65 ms | 793,513 tok/s
B= 8, T= 128:     0.95 ms | 1,077,741 tok/s
B= 8, T= 256:     1.34 ms | 1,526,117 tok/s
B= 8, T= 512:     1.93 ms | 2,119,349 tok/s

----------------------------------------------------------------------
SLOWDOWN FACTOR (Sequential / Chunked)
--------------------------------------------------------------