# GroundThink v7

## GDN + SWA Hybrid with Chunk-Recurrent Delta Rule


**v7 Changes from v6:**
- Chunk-recurrent backward pass (numerically stable)
- Modular code, but now all scripts are in a single folder for direct import
- No inline Triton kernels in notebook

**Current Script Structure:**
```
config.py      # HybridConfig
core.py        # Triton kernels + chunk_delta_rule
model.py       # GDN, SWA, TransparentHybrid
analysis.py    # NIAH tests, training utils
```

**Notebook Import Mode:**
- All imports are now direct from scripts (not as a package)
- This enables compatibility with flat-folder workflows and dynamic imports in notebooks

**Notebook/Progress Bar Fixes:**
- tqdm notebook progress bars require `jupyter` and `ipywidgets` to be installed
- Both are now included in requirements.txt


In [1]:
# =============================================================================
# SETUP
# =============================================================================

import sys
import os
sys.path.insert(0, os.path.abspath(os.getcwd()))  # Ensure current folder is importable

import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

print("sys.path:", sys.path)
print("os.getcwd():", os.getcwd())

PyTorch: 2.11.0.dev20260128+cu128
CUDA: True
GPU: NVIDIA GeForce RTX 4050 Laptop GPU
sys.path: ['/home/m_tes/groundthink/gt-v6/v7-design/groundthink_v7', '/usr/lib/python312.zip', '/usr/lib/python3.12', '/usr/lib/python3.12/lib-dynload', '', '/home/m_tes/groundthink/gt-v6/.venv/lib/python3.12/site-packages']
os.getcwd(): /home/m_tes/groundthink/gt-v6/v7-design/groundthink_v7


In [2]:
# =============================================================================
# IMPORTS
# =============================================================================

# Patch sys.modules to allow relative imports in scripts when running as a notebook
import sys
import importlib.util
import types

def import_module_from_file(module_name, file_path):
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module

config = import_module_from_file('config', './config.py')
core = import_module_from_file('core', './core.py')
model = import_module_from_file('model', './model.py')
analysis = import_module_from_file('analysis', './analysis.py')

HybridConfig = config.HybridConfig
TransparentHybrid = model.TransparentHybrid
proper_niah_test = analysis.proper_niah_test
test_niah_by_distance = analysis.test_niah_by_distance
run_full_diagnostic = analysis.run_full_diagnostic
validate_delta_rule = analysis.validate_delta_rule
train_curriculum = analysis.train_curriculum
analyze_gradients = analysis.analyze_gradients
load_wikitext = analysis.load_wikitext

print("✓ v7 scripts imported (dynamic import mode)")

✓ v7 scripts imported (dynamic import mode)


In [3]:
# =============================================================================
# SYNC CONFIGURATION: Set Sequence Length
# =============================================================================
# This cell synchronizes the sequence length (NEW_T) across config, loader, and all reporting.

NEW_T = 2048  # Set your new sequence length here

cfg = HybridConfig(
    d_model=512,        # UP from 256 - more capacity for multi-needle
    n_heads=8,
    head_dim=64,        # UP from 32 (d_model / n_heads)
    value_dim=128,      # UP from 64 (2x head_dim is common)
    layer_pattern="GS",
    window_size=64,
    chunk_size=64,
    beta_bias=-2.0,
    g_bias=2.0,
)

print(cfg)
print(f"\nState capacity: {cfg.n_heads} heads × {cfg.head_dim} × {cfg.value_dim} = {cfg.n_heads * cfg.head_dim * cfg.value_dim:,} floats")

HybridConfig(GS, d=512, h=8, K=64, V=128)

State capacity: 8 heads × 64 × 128 = 65,536 floats


In [4]:
# =============================================================================
# CREATE MODEL
# =============================================================================

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = TransparentHybrid(cfg).to(DEVICE).bfloat16()

print(f"Parameters: {model.count_params():,}")
print(f"Device: {DEVICE}")

Parameters: 33,217,560
Device: cuda


In [5]:
# =============================================================================
# VALIDATION
# =============================================================================

print("="*60)
print("FORWARD/BACKWARD TEST")
print("="*60)

# Forward
x = torch.randint(0, 1000, (2, 64), device=DEVICE)
with torch.no_grad():
    logits, _, diags, state = model(x)
print(f"Forward: output={logits.shape}, state_norm={state.norm().item():.4f} ✓")

# Backward
model.train()
y = torch.randint(0, 1000, (2, 64), device=DEVICE)
_, loss, _, _ = model(x, y)
loss.backward()
print(f"Backward: loss={loss.item():.4f} ✓")

# Delta Rule validation
validate_delta_rule(DEVICE)

FORWARD/BACKWARD TEST
Forward: output=torch.Size([2, 64, 50257]), state_norm=7.5000 ✓
Backward: loss=10.9375 ✓

DELTA RULE VALIDATION

1. Identical Tokens (Redundancy Suppression):
  Error2: 0.000001 (should be ~0)
  Growth: 1.0000x
  → ✓ PASS

2. Orthogonal Keys (Independent Storage):
  v1 error: 0.000000
  v2 error: 0.000000
  → ✓ PASS

3. Capacity (100 writes):
  State norm: 44.82
  First error: 1.4286
  Last error: 0.0000
  → First degrades (expected)

OVERALL: ✓ ALL PASS


{'identical_tokens': True, 'orthogonal_keys': True, 'capacity': True}

In [6]:
# =============================================================================
# NIAH (Untrained Baseline)
# =============================================================================

print("\n" + "="*60)
print("NIAH TEST (Untrained)")
print("="*60)

model.eval()
proper_niah_test(model, seq_len=64, n_trials=20)


NIAH TEST (Untrained)
  Accuracy: 0.0% (0/20)


{'accuracy': 0.0, 'correct': 0, 'total': 20}

In [7]:
# =============================================================================
# LOAD DATA
# =============================================================================
# Math is: Batch Size * Sequence Length = Constant Tokens per Batch
# Epochs = Total Tokens / (Batch Size * Sequence Length)
# Logic is: Adjust batch size inversely with sequence length to maintain constant token count per batch.
# Science is: Longer sequence = more context, fewer updates per epoch; 
#             Shorter sequence = less context, more updates per epoch.
#          Keeping total tokens per batch constant balances context and update frequency.
# Practicality is: Larger batch sizes improve training stability and throughput,
# but require more memory. Smaller batch sizes fit in memory but may lead to noisier updates
# and lower throughput. Find a balance based on your hardware capabilities.
# =============================================================================
# Note: Monitor batch size and sequence length to avoid OOM errors.  
# These values are for a GPU with 6GB VRAM (e.g., RTX 4050).
# 4096 sequence length with batch size 2 is stable and efficient.
# 2048 sequence length with batch size 4-2 is safe. Batch 2 learns slower with higher tok/s.
# 1024 sequence length with batch size 8 is a good starting point.
# 512 or less sequence length with batch size 16 (32 untested) also works well.
#
# COMPETENCE MATH:
#   - Tokens per step = batch_size * seq_len
#   - Steps per epoch = n_tokens / tokens_per_step  
#   - Target: ~2 epochs for learning, not memorizing
#   - At batch=2, seq=4096: 8192 tokens/step, 2M tokens = 244 batches

data_loader = load_wikitext(n_tokens=2_000_000, seq_len=NEW_T, batch_size=4)

Loading 2,000,000 tokens from wikitext-103...
Loaded 2,000,000 tokens


In [8]:
# =============================================================================
# MEMORY PROFILING
# =============================================================================
# Track exact VRAM usage for different batch/seq combinations

def profile_memory(model, seq_lens=[512, 1024, 2048, 4096], batch_sizes=[2, 4, 8]):
    """Profile peak memory usage for training forward+backward."""
    import gc
    device = next(model.parameters()).device
    
    print(f"{'Seq Len':>8} | {'Batch':>5} | {'Tokens':>8} | {'Peak MB':>8} | {'% of 6GB':>8}")
    print("-" * 50)
    
    results = []
    for T in seq_lens:
        for B in batch_sizes:
            # Clear cache
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            
            try:
                model.train()
                x = torch.randint(0, 1000, (B, T), device=device)
                
                # Forward + backward (this is what training uses)
                _, loss, _, _ = model(x, x)
                loss.backward()
                
                peak_mb = torch.cuda.max_memory_allocated() / 1024**2
                pct_6gb = (peak_mb / 6144) * 100
                tokens = B * T
                
                print(f"{T:>8} | {B:>5} | {tokens:>8} | {peak_mb:>8.1f} | {pct_6gb:>7.1f}%")
                results.append({'seq_len': T, 'batch': B, 'tokens': tokens, 'peak_mb': peak_mb})
                
            except RuntimeError as e:
                if "out of memory" in str(e):
                    print(f"{T:>8} | {B:>5} | {'OOM':>8} | {'---':>8} | {'---':>8}")
                    gc.collect()
                    torch.cuda.empty_cache()
                else:
                    raise
            
            model.zero_grad()
    
    return results

print("Memory Profile for RTX 4050 (6GB)(DO NOT EXCEED 45%):")
print("="*50)
memory_results = profile_memory(model, seq_lens=[512, 1024, 2048, 4096], batch_sizes=[2, 4, 8])

Memory Profile for RTX 4050 (6GB)(DO NOT EXCEED 45%):
 Seq Len | Batch |   Tokens |  Peak MB | % of 6GB
--------------------------------------------------
     512 |     2 |     1024 |    513.3 |     8.4%
     512 |     4 |     2048 |    800.5 |    13.0%
     512 |     8 |     4096 |   1509.8 |    24.6%
    1024 |     2 |     2048 |    800.2 |    13.0%
    1024 |     4 |     4096 |   1509.2 |    24.6%
    1024 |     8 |     8192 |   2912.6 |    47.4%
    2048 |     2 |     4096 |   1507.6 |    24.5%
    2048 |     4 |     8192 |   2912.1 |    47.4%
    2048 |     8 |    16384 |   5723.5 |    93.2%
    4096 |     2 |     8192 |   2910.3 |    47.4%
    4096 |     4 |    16384 |   5723.0 |    93.1%
    4096 |     8 |    32768 |  11352.4 |   184.8%


In [9]:
# =============================================================================
# PROGRESSIVE CURRICULUM: Retrieval → Gym → LM
# =============================================================================
# Order matters! NO language modeling until multi-needle works.
#
#   Phase 1 (Retrieval): Single-needle until 95%+ (quick, ~100 steps)
#   Phase 2 (Gym):       Progressive multi-needle (2→5 needles at NEW_T)
#                        Trains β gating to be selective
#   Phase 3 (LM):        Language modeling ONLY after gym succeeds
#
# Key insight: Train at target seq_len from the start - no mismatch!

import importlib
analysis = importlib.reload(analysis)
train_progressive_curriculum = analysis.train_progressive_curriculum
proper_niah_test = analysis.proper_niah_test
multi_needle_test = analysis.multi_needle_test

# Fresh model
model = TransparentHybrid(cfg).to(DEVICE).bfloat16()
print(f"Fresh model: {model.count_params():,} params")
print(f"Training at seq_len={NEW_T} throughout (no mismatch)")

# Run progressive curriculum
history = train_progressive_curriculum(
    model,
    wikitext_loader=data_loader,
    max_steps=3000,
    lr=3e-4,
    log_interval=50,
    # Phase transitions
    retrieval_threshold=0.95,    # Move to gym when single-needle > 95%
    gym_threshold=0.70,          # Move to LM when multi-needle > 70%
    # Gym progression - USE NEW_T THROUGHOUT
    gym_start_needles=2,
    gym_max_needles=5,
    gym_start_seq=NEW_T,         # Start at target length!
    gym_max_seq=NEW_T,           # No seq progression
    # Safety limits
    max_retrieval_steps=200,
    max_gym_steps=2000,
)

# Final evaluation
print("\n" + "="*60)
print("FINAL EVALUATION")
print("="*60)

print("\n1. Single-needle (should be 100%):")
proper_niah_test(model, seq_len=NEW_T, n_trials=20)

print("\n2. Multi-needle (should be >70%):")
multi_needle_test(model, seq_len=NEW_T, n_needles=3, n_trials=30)

print("\n3. Multi-needle stress test (5 needles):")
multi_needle_test(model, seq_len=NEW_T, n_needles=5, n_trials=30)

Fresh model: 33,217,560 params
Training at seq_len=2048 throughout (no mismatch)
PROGRESSIVE CURRICULUM
Phase 1: Retrieval until 95% (max 200 steps)
Phase 2: Gym 2→5 needles until 70%
Phase 3: LM (only after gym succeeds)
[RETRIEVAL] Step    0: loss=11.375, β=0.131 (4.8 s/s)

[TRANSITION] Retrieval → Gym at step 22 (acc=95.7%)
Memory Gym: 2000 samples, T=2048, 2 needles, batch=4
[GYM ] Step   50: loss=5.451, acc=43.0%, needles=2, seq=2048, β=0.131, surv=[0.00,-0.00] (8.3 s/s)
[GYM ] Step  100: loss=4.982, acc=3.5%, needles=2, seq=2048, β=0.134, surv=[-0.00,0.00] (6.6 s/s)
[GYM ] Step  150: loss=3.729, acc=5.5%, needles=2, seq=2048, β=0.130, surv=[0.00,0.00] (6.5 s/s)
[GYM ] Step  200: loss=3.658, acc=3.0%, needles=2, seq=2048, β=0.129, surv=[0.00,0.00] (6.3 s/s)
[GYM ] Step  250: loss=3.630, acc=3.0%, needles=2, seq=2048, β=0.126, surv=[-0.00,-0.00] (6.1 s/s)
[GYM ] Step  300: loss=3.583, acc=2.5%, needles=2, seq=2048, β=0.136, surv=[0.00,0.00] (6.2 s/s)
[GYM ] Step  350: loss=3.577, a

{'per_needle': [{'needle_idx': 0, 'accuracy': 0, 'correct': 0, 'total': 0},
  {'needle_idx': 1, 'accuracy': 0, 'correct': 0, 'total': 0},
  {'needle_idx': 2, 'accuracy': 0, 'correct': 0, 'total': 0},
  {'needle_idx': 3, 'accuracy': 0, 'correct': 0, 'total': 0},
  {'needle_idx': 4, 'accuracy': 0, 'correct': 0, 'total': 0}],
 'total_correct': 0,
 'total_trials': 30,
 'confusion_matrix': tensor([[0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]]),
 'overall_accuracy': 0.0}

In [12]:
# =============================================================================
# DIAGNOSTIC: Is the model using STATE or just SWA?
# =============================================================================
# This test places needles OUTSIDE the SWA window to verify state usage.
# If outside_window accuracy is ~0%, model bypasses GDN state entirely.

import importlib
analysis = importlib.reload(analysis)
test_state_vs_swa = analysis.test_state_vs_swa

# First: Test UNTRAINED model to check if architecture works
print("="*60)
print("ARCHITECTURE CHECK: Testing UNTRAINED model")
print("="*60)
untrained_model = TransparentHybrid(cfg).to(DEVICE).bfloat16()
print(f"Fresh untrained model: {untrained_model.count_params():,} params")
untrained_result = test_state_vs_swa(untrained_model, seq_len=NEW_T, n_trials=50)

# Then: Test trained model
print("\n" + "="*60)
print("TRAINED MODEL CHECK")
print("="*60)
print(f"Testing trained model with window_size={cfg.window_size}")
trained_result = test_state_vs_swa(model, seq_len=NEW_T, n_trials=50)

# Comparison
print("\n" + "="*60)
print("COMPARISON")
print("="*60)
print(f"Untrained outside-window: {untrained_result['outside_accuracy']*100:.1f}%")
print(f"Trained outside-window:   {trained_result['outside_accuracy']*100:.1f}%")
if trained_result['outside_accuracy'] < untrained_result['outside_accuracy']:
    print("\n⚠ TRAINING BROKE STATE RETRIEVAL")
elif trained_result['outside_accuracy'] > 0.5:
    print("\n✓ STATE IS WORKING")
else:
    print("\n✗ STATE NEVER WORKED (architecture issue?)")

ARCHITECTURE CHECK: Testing UNTRAINED model
Fresh untrained model: 33,217,560 params

STATE vs SWA VERIFICATION (window_size=64)

Inside SWA window (≤64 tokens):
  Accuracy: 0.0% (0/50)
  → Could be SWA OR state

Outside SWA window (>64 tokens):
  Accuracy: 0.0% (0/50)
  → MUST be state (SWA can't see)

✗ NO STATE USAGE: Model relies only on SWA (0.0%)

TRAINED MODEL CHECK
Testing trained model with window_size=64

STATE vs SWA VERIFICATION (window_size=64)

Inside SWA window (≤64 tokens):
  Accuracy: 2.0% (1/50)
  → Could be SWA OR state

Outside SWA window (>64 tokens):
  Accuracy: 2.0% (1/50)
  → MUST be state (SWA can't see)

✗ NO STATE USAGE: Model relies only on SWA (2.0%)

COMPARISON
Untrained outside-window: 0.0%
Trained outside-window:   2.0%

✗ STATE NEVER WORKED (architecture issue?)


In [16]:
# =============================================================================
# DEEP DIAGNOSTIC: Test the state mechanism at component level
# =============================================================================
# This is a unit-test level diagnostic that checks:
# 1. Is state being written (non-zero)?
# 2. Can we retrieve with the same key that wrote?
# 3. Does the SWA retrieval path work?

import importlib
analysis = importlib.reload(analysis)
diagnose_state_mechanism = analysis.diagnose_state_mechanism

# Run on trained model
print("Testing TRAINED model:")
trained_diag = diagnose_state_mechanism(model, verbose=True)

print("\n\nTesting UNTRAINED model:")
untrained_diag = diagnose_state_mechanism(untrained_model, verbose=True)

Testing TRAINED model:

STATE MECHANISM DIAGNOSTIC

1. STATE WRITING
   State norm: 49.7500
   State max:  1.4297
   ✓ State is being written

2. SINGLE-TOKEN SELF-RETRIEVAL (isolated GDN layer)
   Avg cosine similarity: 1.0000
   Per-head sims: ['1.00', '1.00', '1.00', '1.00', '1.00', '1.00', '1.00', '1.00']
   Magnitude ratios (ret/exp): ['0.99', '1.00', '0.02', '0.75', '0.07', '1.00', '0.96', '0.94']
   ✓ Self-retrieval works!

3. SWA RETRIEVAL PATH
   Query norm (after ReLU): 9.1875
   Query sparsity: 57.8% zeros
   Retrieved norm: 14.5801
   ✓ SWA retrieval produces non-zero output

4. MULTI-TOKEN INTERFERENCE (10 tokens)
   Avg similarity across all tokens: 0.5330
   First token (most overwritten): 0.4084
   Last token (most recent): 0.9766
   Per-token: ['0.41', '0.64', '0.35', '0.44', '0.43', '0.43', '0.64', '0.55', '0.47', '0.98']
   ✗ Multi-token retrieval fails

SUMMARY
✓ All mechanisms working - architecture CAN retrieve from state


Testing UNTRAINED model:

STATE MECHANIS

In [17]:
# =============================================================================
# COMPARE GATE VALUES: trained vs untrained
# =============================================================================
# Check if training changed β and g in ways that hurt memory

import torch

def analyze_gates(model, name="Model"):
    print(f"\n{name} gate analysis:")
    for i, layer in enumerate(model.layers):
        if hasattr(layer, 'beta_proj') and hasattr(layer, 'g_proj'):
            # Create random input to measure gate activation
            with torch.no_grad():
                x = torch.randn(1, 64, model.cfg.d_model, device=DEVICE, dtype=torch.bfloat16)
                x_norm = layer.norm(x)
                beta = torch.sigmoid(layer.beta_proj(x_norm))
                g = torch.sigmoid(layer.g_proj(x_norm))
                
                print(f"  Layer {i} (GDN):")
                print(f"    β: mean={beta.mean():.4f}, std={beta.std():.4f}, range=[{beta.min():.3f}, {beta.max():.3f}]")
                print(f"    g: mean={g.mean():.4f}, std={g.std():.4f}, range=[{g.min():.3f}, {g.max():.3f}]")
                
                # Check bias values
                beta_bias = layer.beta_proj.bias.data
                g_bias = layer.g_proj.bias.data
                print(f"    β_bias: {beta_bias.mean():.4f} (init was -2.0)")
                print(f"    g_bias: {g_bias.mean():.4f} (init was 3.0)")

analyze_gates(untrained_model, "UNTRAINED")
analyze_gates(model, "TRAINED")

# Key insight: 
# - β controls how much NEW info is written (high β = more writing)
# - g controls how much OLD info is retained (high g = more retention)
# For good memory: need low β (selective writing) + high g (retention)
# If training increased β or decreased g, memory is hurt


UNTRAINED gate analysis:
  Layer 0 (GDN):
    β: mean=0.1299, std=0.0645, range=[0.022, 0.447]
    g: mean=0.8672, std=0.0669, range=[0.559, 0.977]
    β_bias: -2.0000 (init was -2.0)
    g_bias: 2.0000 (init was 3.0)

TRAINED gate analysis:
  Layer 0 (GDN):
    β: mean=0.1504, std=0.1025, range=[0.015, 0.816]
    g: mean=0.8633, std=0.0742, range=[0.590, 0.977]
    β_bias: -2.0000 (init was -2.0)
    g_bias: 2.0000 (init was 3.0)


In [18]:
# =============================================================================
# WHAT β DO MARKER TOKENS PRODUCE?
# =============================================================================
# Check if marker/needle tokens get special treatment

def check_marker_gates(model, name="Model"):
    gdn_layer = None
    for layer in model.layers:
        if hasattr(layer, 'beta_proj'):
            gdn_layer = layer
            break
    
    if gdn_layer is None:
        print(f"{name}: No GDN layer found")
        return
    
    print(f"\n{name} - Gate values for special tokens:")
    
    tokens_to_check = {
        'marker_token': cfg.marker_token,
        'cue_token': cfg.cue_token,
        'random_token': 500,
        'random_token2': 1000,
    }
    
    with torch.no_grad():
        for token_name, token_id in tokens_to_check.items():
            tok = torch.tensor([[token_id]], device=DEVICE)
            emb = model.embed(tok)
            x_norm = gdn_layer.norm(emb)
            
            beta = torch.sigmoid(gdn_layer.beta_proj(x_norm))
            g = torch.sigmoid(gdn_layer.g_proj(x_norm))
            
            print(f"  {token_name} ({token_id}): β={beta.mean():.4f}, g={g.mean():.4f}")

check_marker_gates(untrained_model, "UNTRAINED")
check_marker_gates(model, "TRAINED")

print(f"\nMarker token: {cfg.marker_token}, Cue token: {cfg.cue_token}")


UNTRAINED - Gate values for special tokens:
  marker_token (50251): β=0.1309, g=0.8242
  cue_token (50250): β=0.1826, g=0.8516
  random_token (500): β=0.1260, g=0.9062
  random_token2 (1000): β=0.1221, g=0.9102

TRAINED - Gate values for special tokens:
  marker_token (50251): β=0.5547, g=0.9375
  cue_token (50250): β=0.4023, g=0.8086
  random_token (500): β=0.7148, g=0.9336
  random_token2 (1000): β=0.7500, g=0.9727

Marker token: 50251, Cue token: 50250


In [20]:
# =============================================================================
# KEY ORTHOGONALITY CHECK
# =============================================================================
# If keys for different tokens are orthogonal, writes don't interfere.
# If keys are similar (high cosine), writes overwrite each other.

import torch.nn.functional as F

def check_key_similarity(model, n_tokens=50, name="Model"):
    gdn_layer = None
    for layer in model.layers:
        if hasattr(layer, 'k_proj'):
            gdn_layer = layer
            break
    
    if gdn_layer is None:
        print(f"{name}: No GDN layer found")
        return
    
    print(f"\n{name} - Key similarity analysis:")
    
    # Sample random tokens and compute their keys
    tokens = torch.randint(100, 10000, (1, n_tokens), device=DEVICE)
    
    with torch.no_grad():
        emb = model.embed(tokens)
        x_norm = gdn_layer.norm(emb)
        keys = gdn_layer.k_proj(x_norm).view(1, n_tokens, cfg.n_heads, cfg.head_dim)
        keys = F.normalize(keys.float(), p=2, dim=-1)  # [1, T, H, K]
        
        # Compute pairwise cosine similarity (keys are already normalized)
        # For each head, compute T×T similarity matrix
        avg_sims = []
        for h in range(cfg.n_heads):
            k_h = keys[0, :, h, :]  # [T, K]
            sim_matrix = k_h @ k_h.T  # [T, T] - all pairwise similarities
            
            # Get off-diagonal elements (exclude self-similarity)
            mask = ~torch.eye(n_tokens, dtype=torch.bool, device=DEVICE)
            off_diag_sims = sim_matrix[mask]
            
            avg_sim = off_diag_sims.abs().mean().item()
            avg_sims.append(avg_sim)
        
        overall_avg = sum(avg_sims) / len(avg_sims)
        max_avg = max(avg_sims)
        
        print(f"  Avg |cosine| between different token keys:")
        print(f"    Per-head: {[f'{s:.3f}' for s in avg_sims]}")
        print(f"    Overall: {overall_avg:.3f}")
        print(f"    Max head: {max_avg:.3f}")
        
        # For perfect orthogonality, avg should be close to 0
        # For random vectors in K dimensions, expected |cosine| ≈ sqrt(2/π)/sqrt(K)
        expected_random = (2/3.14159)**0.5 / (cfg.head_dim**0.5)
        print(f"    Expected for random {cfg.head_dim}D vectors: {expected_random:.3f}")
        
        if overall_avg < expected_random:
            print(f"    ✓ Keys are MORE orthogonal than random")
        else:
            print(f"    ⚠ Keys are LESS orthogonal than random (more interference)")

check_key_similarity(untrained_model, name="UNTRAINED")
check_key_similarity(model, name="TRAINED")


UNTRAINED - Key similarity analysis:
  Avg |cosine| between different token keys:
    Per-head: ['0.105', '0.107', '0.102', '0.105', '0.107', '0.108', '0.106', '0.104']
    Overall: 0.105
    Max head: 0.108
    Expected for random 64D vectors: 0.100
    ⚠ Keys are LESS orthogonal than random (more interference)

TRAINED - Key similarity analysis:
  Avg |cosine| between different token keys:
    Per-head: ['0.690', '0.593', '0.561', '0.398', '0.596', '0.384', '0.652', '0.786']
    Overall: 0.582
    Max head: 0.786
    Expected for random 64D vectors: 0.100
    ⚠ Keys are LESS orthogonal than random (more interference)


## The Key Insight: Training Collapsed the Key Space!

**Problem discovered:**
- Untrained model: Key similarity = 0.105 (near-random, orthogonal)
- Trained model: Key similarity = 0.582 (highly correlated, interfering)

**What happened:**
- Standard LM training optimizes for prediction, not memory
- The model learned to project similar inputs to similar keys
- This means EVERY token write interferes with previous writes
- Early needles get overwritten by subsequent haystack tokens

**Solution: Add regularization losses**
1. **Key orthogonality loss**: Penalize high cosine similarity between keys
2. **Beta sparsity loss**: Encourage low β on average (selective writing)

Let's try training with these regularizations to preserve memory capability:

In [23]:
# =============================================================================
# FRESH MODEL WITH REGULARIZED TRAINING
# =============================================================================
# Create a new untrained model and train with key orthogonality + beta sparsity regularization

import importlib
import model as model_module
analysis = importlib.reload(analysis)
model_module = importlib.reload(model_module)

from model import TransparentHybrid
from analysis import train_with_key_reg, diagnose_state_mechanism

# Create fresh model with bfloat16
fresh_model = TransparentHybrid(cfg).to(DEVICE).to(torch.bfloat16)
print(f"Fresh model: {fresh_model.count_params()/1e6:.1f}M parameters")

# Check initial state
print("\n--- INITIAL DIAGNOSTICS ---")
check_key_similarity(fresh_model, name="FRESH (before training)")

# Train with regularization
print("\n--- TRAINING WITH REGULARIZATION ---")
history_reg = train_with_key_reg(
    fresh_model, 
    data_loader, 
    steps=500,
    lr=1e-4,  # Lower LR for stability
    key_orth_weight=0.5,     # Strong orthogonality pressure
    beta_sparsity_weight=0.3, # Moderate sparsity pressure
    retrieval_weight=1.0,
    log_interval=50
)

# Check after training
print("\n--- POST-TRAINING DIAGNOSTICS ---")
check_key_similarity(fresh_model, name="FRESH (after regularized training)")

Fresh model: 33.2M parameters

--- INITIAL DIAGNOSTICS ---

FRESH (before training) - Key similarity analysis:
  Avg |cosine| between different token keys:
    Per-head: ['0.107', '0.103', '0.104', '0.103', '0.106', '0.101', '0.109', '0.102']
    Overall: 0.104
    Max head: 0.109
    Expected for random 64D vectors: 0.100
    ⚠ Keys are LESS orthogonal than random (more interference)

--- TRAINING WITH REGULARIZATION ---
Training with regularization (500 steps)
  LR: 0.0001
  Weights: retrieval=1.0, key_orth=0.5, beta_sparse=0.3
Step    0: LM=10.938, KO=0.027, BS=0.131, RET=11.438 (1.4 s/s)
Step   50: LM=8.125, KO=0.031, BS=0.129, RET=0.961 (3.9 s/s)
Step  100: LM=7.031, KO=0.026, BS=0.118, RET=0.590 (4.3 s/s)
Step  150: LM=6.844, KO=0.022, BS=0.110, RET=0.400 (4.3 s/s)
Step  200: LM=6.625, KO=0.026, BS=0.107, RET=0.312 (4.4 s/s)
Step  250: LM=6.625, KO=0.026, BS=0.104, RET=0.238 (4.4 s/s)
Step  300: LM=6.344, KO=0.027, BS=0.100, RET=0.176 (4.4 s/s)
Step  350: LM=6.406, KO=0.029, BS=0

In [24]:
# =============================================================================
# TEST REGULARIZED MODEL
# =============================================================================
# Check if the regularized model preserved memory capability

# Deep diagnostic
print("--- MECHANISM DIAGNOSTIC ---")
fresh_diag = diagnose_state_mechanism(fresh_model, verbose=True)

# Multi-needle test  
print("\n--- MULTI-NEEDLE TEST ---")
from analysis import multi_needle_test
multi_result = multi_needle_test(fresh_model, seq_len=512, n_needles=3, n_trials=20)

# Compare to the collapsed (no-reg) model
print("\n--- COMPARE TO COLLAPSED (NO-REG) MODEL ---")
multi_result_old = multi_needle_test(model, seq_len=512, n_needles=3, n_trials=20)

--- MECHANISM DIAGNOSTIC ---

STATE MECHANISM DIAGNOSTIC

1. STATE WRITING
   State norm: 5.3750
   State max:  0.1826
   ✓ State is being written

2. SINGLE-TOKEN SELF-RETRIEVAL (isolated GDN layer)
   Avg cosine similarity: 1.0000
   Per-head sims: ['1.00', '1.00', '1.00', '1.00', '1.00', '1.00', '1.00', '1.00']
   Magnitude ratios (ret/exp): ['0.07', '0.10', '0.03', '0.12', '0.04', '0.11', '0.16', '0.10']
   ✓ Self-retrieval works!

3. SWA RETRIEVAL PATH
   Query norm (after ReLU): 7.0625
   Query sparsity: 49.6% zeros
   Retrieved norm: 1.9044
   ✓ SWA retrieval produces non-zero output

4. MULTI-TOKEN INTERFERENCE (10 tokens)
   Avg similarity across all tokens: 0.7913
   First token (most overwritten): 0.4656
   Last token (most recent): 0.9852
   Per-token: ['0.47', '0.71', '0.62', '0.69', '0.81', '0.86', '0.95', '0.90', '0.93', '0.99']
   ✓ Multi-token retrieval works!

SUMMARY
✓ All mechanisms working - architecture CAN retrieve from state

--- MULTI-NEEDLE TEST ---

MULTI-NEE

In [25]:
# =============================================================================
# SINGLE NEEDLE TEST
# =============================================================================
# Does the basic NIAH work at all?

from analysis import proper_niah_test, test_state_vs_swa

print("--- REGULARIZED MODEL (fresh_model) ---")
print("\nSingle needle test:")
niah_reg = proper_niah_test(fresh_model, seq_len=256, needle_pos=32, n_trials=30)

print("\nState vs SWA verification:")
swa_reg = test_state_vs_swa(fresh_model, seq_len=512, n_trials=20)

print("\n--- ORIGINAL TRAINED MODEL (model) ---")
print("\nSingle needle test:")
niah_old = proper_niah_test(model, seq_len=256, needle_pos=32, n_trials=30)

print("\nState vs SWA verification:")
swa_old = test_state_vs_swa(model, seq_len=512, n_trials=20)

--- REGULARIZED MODEL (fresh_model) ---

Single needle test:
  Accuracy: 100.0% (30/30)

State vs SWA verification:

STATE vs SWA VERIFICATION (window_size=64)

Inside SWA window (≤64 tokens):
  Accuracy: 0.0% (0/20)
  → Could be SWA OR state

Outside SWA window (>64 tokens):
  Accuracy: 0.0% (0/20)
  → MUST be state (SWA can't see)

✗ NO STATE USAGE: Model relies only on SWA (0.0%)

--- ORIGINAL TRAINED MODEL (model) ---

Single needle test:
  Accuracy: 0.0% (0/30)

State vs SWA verification:

STATE vs SWA VERIFICATION (window_size=64)

Inside SWA window (≤64 tokens):
  Accuracy: 0.0% (0/20)
  → Could be SWA OR state

Outside SWA window (>64 tokens):
  Accuracy: 0.0% (0/20)
  → MUST be state (SWA can't see)

✗ NO STATE USAGE: Model relies only on SWA (0.0%)


In [26]:
# =============================================================================
# VERIFY RETRIEVAL VS MEMORIZATION
# =============================================================================
# Test if model actually retrieves or just memorized "CUE → specific token"

def test_varied_needles(model, seq_len=256, needle_pos=32, n_trials=30):
    """Test with DIFFERENT needle IDs each trial to verify retrieval."""
    model.eval()
    device = next(model.parameters()).device
    cfg = model.cfg
    
    correct = 0
    for trial in range(n_trials):
        # Use different needle ID each trial
        needle_id = cfg.vocab_size - 50 + (trial % 50)
        
        seq = torch.randint(0, cfg.vocab_size - 100, (1, seq_len), device=device)
        seq[0, needle_pos] = cfg.marker_token
        seq[0, needle_pos + 1] = needle_id
        seq[0, -1] = cfg.cue_token
        
        with torch.no_grad():
            logits, _, _, _ = model(seq)
        
        pred = logits[0, -1].argmax().item()
        if pred == needle_id:
            correct += 1
    
    acc = correct / n_trials
    print(f"  Accuracy with varied needles: {acc*100:.1f}% ({correct}/{n_trials})")
    return acc

print("--- REGULARIZED MODEL (varied needle IDs) ---")
varied_acc_reg = test_varied_needles(fresh_model, seq_len=256, needle_pos=32, n_trials=30)

print("\n--- ORIGINAL TRAINED MODEL (varied needle IDs) ---")
varied_acc_old = test_varied_needles(model, seq_len=256, needle_pos=32, n_trials=30)

print("\n--- UNTRAINED MODEL (varied needle IDs) ---")
varied_acc_untrained = test_varied_needles(untrained_model, seq_len=256, needle_pos=32, n_trials=30)

--- REGULARIZED MODEL (varied needle IDs) ---
  Accuracy with varied needles: 0.0% (0/30)

--- ORIGINAL TRAINED MODEL (varied needle IDs) ---
  Accuracy with varied needles: 0.0% (0/30)

--- UNTRAINED MODEL (varied needle IDs) ---
  Accuracy with varied needles: 0.0% (0/30)


In [27]:
# =============================================================================
# RETRAIN WITH FIXED RETRIEVAL LOSS (varied needle IDs)
# =============================================================================
# The previous training used fixed needle ID - model just memorized "CUE → specific token"
# Now use varied needle IDs to force actual retrieval learning

import importlib
analysis = importlib.reload(analysis)
from analysis import train_with_key_reg

# Create a new fresh model
fresh_model_v2 = TransparentHybrid(cfg).to(DEVICE).to(torch.bfloat16)
print(f"Fresh model v2: {fresh_model_v2.count_params()/1e6:.1f}M parameters")

# Train with fixed retrieval loss (varied needles) + regularization
print("\n--- TRAINING WITH VARIED NEEDLES + REGULARIZATION ---")
history_v2 = train_with_key_reg(
    fresh_model_v2, 
    data_loader, 
    steps=1000,  # More steps
    lr=1e-4,
    key_orth_weight=0.5,
    beta_sparsity_weight=0.2,
    retrieval_weight=2.0,  # Higher retrieval weight
    log_interval=100
)

# Test with varied needle IDs
print("\n--- TEST WITH VARIED NEEDLES ---")
varied_acc_v2 = test_varied_needles(fresh_model_v2, seq_len=256, needle_pos=32, n_trials=30)

Fresh model v2: 33.2M parameters

--- TRAINING WITH VARIED NEEDLES + REGULARIZATION ---
Training with regularization (1000 steps)
  LR: 0.0001
  Weights: retrieval=2.0, key_orth=0.5, beta_sparse=0.2
Step    0: LM=10.938, KO=0.033, BS=0.131, RET=11.000 (1.8 s/s)
Step  100: LM=7.156, KO=0.024, BS=0.124, RET=10.750 (4.5 s/s)
Step  200: LM=6.844, KO=0.027, BS=0.114, RET=9.438 (4.6 s/s)
Step  300: LM=6.594, KO=0.021, BS=0.107, RET=8.250 (4.5 s/s)
Step  400: LM=6.406, KO=0.030, BS=0.104, RET=8.250 (4.5 s/s)
Step  500: LM=6.219, KO=0.027, BS=0.102, RET=8.875 (4.5 s/s)
Step  600: LM=6.250, KO=0.022, BS=0.102, RET=7.281 (4.4 s/s)
Step  700: LM=6.188, KO=0.025, BS=0.100, RET=7.250 (4.4 s/s)
Step  800: LM=6.094, KO=0.046, BS=0.099, RET=6.250 (4.4 s/s)
Step  900: LM=5.844, KO=0.028, BS=0.096, RET=5.750 (4.3 s/s)
Training complete: 1000 steps in 229.8s

Final metrics:
  Key orthogonality: 0.0251 (lower = more orthogonal)
  Beta sparsity: 0.0991 (lower = more selective)

--- TEST WITH VARIED NEEDLES

## THE FIX: Shifted Value Mode

**The core problem was architectural:**
- Original: Store `(k_t, v_t)` - key and value from SAME token
- This can't learn MARKER → VALUE because at MARKER position, we haven't seen VALUE yet!

**The fix:**
- Shifted: Store `(k_t, v_{t+1})` - key from token t, value from token t+1
- Now when we see MARKER, we store the VALUE that comes after
- Query with key(MARKER) retrieves value(VALUE)

This is the correct way to build associative memory in an autoregressive model.

In [28]:
# =============================================================================
# TEST SHIFTED VALUE MODE
# =============================================================================
# Create fresh model with shifted_value=True (now the default)

import importlib
import config as config_module
import model as model_module
import analysis as analysis_module

config_module = importlib.reload(config_module)
model_module = importlib.reload(model_module)
analysis_module = importlib.reload(analysis_module)

from config import HybridConfig
from model import TransparentHybrid
from analysis import diagnose_state_mechanism, train_with_key_reg

# Create config with shifted_value=True (default)
cfg_shifted = HybridConfig(
    d_model=512,
    n_heads=8,
    head_dim=64,
    value_dim=128,
    layer_pattern="GS",
    window_size=64,
    chunk_size=64,
    shifted_value=True,  # THE FIX
)

print(f"Shifted value mode: {cfg_shifted.shifted_value}")
print(f"Config: {cfg_shifted}")

# Create model
shifted_model = TransparentHybrid(cfg_shifted).to(DEVICE).to(torch.bfloat16)
print(f"\nShifted model: {shifted_model.count_params()/1e6:.1f}M parameters")

# Quick forward test
x_test = torch.randint(0, 1000, (2, 64), device=DEVICE)
with torch.no_grad():
    logits, _, diags, state = shifted_model(x_test)
print(f"Forward pass OK: output={logits.shape}")

# Check the mechanism
print("\n--- MECHANISM DIAGNOSTIC (SHIFTED VALUE) ---")
shifted_diag = diagnose_state_mechanism(shifted_model, verbose=True)

Shifted value mode: True
Config: HybridConfig(GS, d=512, h=8, K=64, V=128)

Shifted model: 33.2M parameters
Forward pass OK: output=torch.Size([2, 64, 50257])

--- MECHANISM DIAGNOSTIC (SHIFTED VALUE) ---

STATE MECHANISM DIAGNOSTIC

1. STATE WRITING
   State norm: 5.4688
   State max:  0.1504
   ✓ State is being written

2. SHIFTED-VALUE RETRIEVAL (key=token0, expect=value_of_token1)
   Shifted value mode: True
   Avg cosine similarity: 1.0000
   Per-head sims: ['1.00', '1.00', '1.00', '1.00', '1.00', '1.00', '1.00', '1.00']
   Magnitude ratios (ret/exp): ['0.14', '0.10', '0.10', '0.15', '0.05', '0.13', '0.11', '0.15']
   ✓ Retrieval works!

3. SWA RETRIEVAL PATH
   Query norm (after ReLU): 6.9688
   Query sparsity: 49.4% zeros
   Retrieved norm: 1.5725
   ✓ SWA retrieval produces non-zero output

4. MULTI-TOKEN INTERFERENCE (10 tokens)
   Avg similarity across all tokens: 0.0501
   First token (most overwritten): 0.0160
   Last token (most recent): 0.2831
   Per-token: ['0.02', '0.05

In [29]:
# =============================================================================
# TEST SHIFTED MODEL ON NIAH (UNTRAINED)
# =============================================================================
# Even untrained, the architecture should show SOME retrieval capability
# because the shifted value mode aligns key(MARKER) with value(NEEDLE)

print("--- UNTRAINED SHIFTED MODEL ---")
print("\n1. Single needle with varied IDs:")
varied_acc_shifted = test_varied_needles(shifted_model, seq_len=256, needle_pos=32, n_trials=30)

print("\n2. State vs SWA verification:")
from analysis import test_state_vs_swa
swa_shifted = test_state_vs_swa(shifted_model, seq_len=512, n_trials=20)

# The key question: Does outside-window work better now?
print("\n" + "="*60)
print("COMPARISON: Shifted vs Original (both untrained)")
print("="*60)
print(f"Shifted model outside-window: {swa_shifted['outside_accuracy']*100:.1f}%")
print(f"Original model outside-window: {untrained_result['outside_accuracy']*100:.1f}%")

--- UNTRAINED SHIFTED MODEL ---

1. Single needle with varied IDs:
  Accuracy with varied needles: 0.0% (0/30)

2. State vs SWA verification:

STATE vs SWA VERIFICATION (window_size=64)

Inside SWA window (≤64 tokens):
  Accuracy: 0.0% (0/20)
  → Could be SWA OR state

Outside SWA window (>64 tokens):
  Accuracy: 0.0% (0/20)
  → MUST be state (SWA can't see)

✗ NO STATE USAGE: Model relies only on SWA (0.0%)

COMPARISON: Shifted vs Original (both untrained)
Shifted model outside-window: 0.0%
Original model outside-window: 0.0%


In [30]:
# =============================================================================
# TRAIN SHIFTED MODEL
# =============================================================================
# Now train with regularization + varied needle IDs
# The shifted value architecture should now be CAPABLE of learning retrieval

print("--- TRAINING SHIFTED MODEL ---")
history_shifted = train_with_key_reg(
    shifted_model, 
    data_loader, 
    steps=1000,
    lr=1e-4,
    key_orth_weight=0.3,      # Moderate orthogonality
    beta_sparsity_weight=0.1, # Light sparsity 
    retrieval_weight=3.0,     # Strong retrieval signal
    log_interval=100
)

# Test after training
print("\n--- POST-TRAINING TEST ---")
print("\n1. Varied needle test:")
varied_acc_shifted_trained = test_varied_needles(shifted_model, seq_len=256, needle_pos=32, n_trials=30)

print("\n2. State vs SWA:")
swa_shifted_trained = test_state_vs_swa(shifted_model, seq_len=512, n_trials=20)

print("\n3. Key similarity (should stay low):")
check_key_similarity(shifted_model, name="SHIFTED (after training)")

--- TRAINING SHIFTED MODEL ---
Training with regularization (1000 steps)
  LR: 0.0001
  Weights: retrieval=3.0, key_orth=0.3, beta_sparse=0.1
Step    0: LM=10.938, KO=0.031, BS=0.127, RET=10.938 (0.4 s/s)
Step  100: LM=7.125, KO=0.020, BS=0.116, RET=9.375 (4.2 s/s)
Step  200: LM=6.875, KO=0.025, BS=0.114, RET=8.875 (4.3 s/s)
Step  300: LM=6.438, KO=0.024, BS=0.111, RET=9.125 (4.3 s/s)
Step  400: LM=6.438, KO=0.025, BS=0.110, RET=7.938 (4.3 s/s)
Step  500: LM=6.500, KO=0.019, BS=0.111, RET=7.531 (4.4 s/s)
Step  600: LM=6.375, KO=0.023, BS=0.110, RET=7.219 (4.4 s/s)
Step  700: LM=6.281, KO=0.026, BS=0.108, RET=6.688 (4.4 s/s)
Step  800: LM=5.906, KO=0.027, BS=0.105, RET=6.500 (4.4 s/s)
Step  900: LM=5.938, KO=0.028, BS=0.105, RET=6.250 (4.4 s/s)
Training complete: 1000 steps in 228.3s

Final metrics:
  Key orthogonality: 0.0237 (lower = more orthogonal)
  Beta sparsity: 0.1050 (lower = more selective)

--- POST-TRAINING TEST ---

1. Varied needle test:
  Accuracy with varied needles: 0.0

In [31]:
# =============================================================================
# DEBUG: What is the model actually predicting?
# =============================================================================

def debug_niah_prediction(model, seq_len=256, needle_pos=32):
    """Debug a single NIAH prediction to see what's happening."""
    model.eval()
    device = next(model.parameters()).device
    cfg = model.cfg
    
    # Create test sequence
    needle_id = cfg.vocab_size - 50
    seq = torch.randint(0, cfg.vocab_size - 100, (1, seq_len), device=device)
    seq[0, needle_pos] = cfg.marker_token
    seq[0, needle_pos + 1] = needle_id
    seq[0, -1] = cfg.cue_token
    
    print(f"Test setup:")
    print(f"  Needle ID: {needle_id}")
    print(f"  Marker token: {cfg.marker_token} at pos {needle_pos}")
    print(f"  Cue token: {cfg.cue_token} at pos {seq_len-1}")
    print(f"  Distance: {seq_len - 1 - needle_pos} tokens")
    
    with torch.no_grad():
        logits, _, diags, final_state = model(seq)
    
    # Get prediction at CUE position
    pred_logits = logits[0, -1]  # [vocab_size]
    pred_token = pred_logits.argmax().item()
    pred_probs = torch.softmax(pred_logits.float(), dim=-1)
    
    # Top 5 predictions
    top5_probs, top5_tokens = pred_probs.topk(5)
    
    print(f"\nPrediction at CUE position:")
    print(f"  Predicted: {pred_token}")
    print(f"  Expected:  {needle_id}")
    print(f"  Correct:   {pred_token == needle_id}")
    
    print(f"\nTop 5 predictions:")
    for i, (prob, tok) in enumerate(zip(top5_probs, top5_tokens)):
        marker = " ← NEEDLE" if tok.item() == needle_id else ""
        marker = " ← MARKER" if tok.item() == cfg.marker_token else marker
        marker = " ← CUE" if tok.item() == cfg.cue_token else marker
        print(f"  {i+1}. Token {tok.item():5d}: {prob.item()*100:5.2f}%{marker}")
    
    # Check probability of correct answer
    needle_prob = pred_probs[needle_id].item()
    print(f"\nNeedle ID ({needle_id}) probability: {needle_prob*100:.4f}%")
    
    # Check state
    print(f"\nState diagnostics:")
    print(f"  State norm: {final_state.norm().item():.4f}")
    print(f"  State max:  {final_state.abs().max().item():.4f}")
    
    # Check layer diagnostics
    for d in diags:
        if 'beta_mean' in d:
            print(f"  Layer {d['layer_idx']} ({d['layer']}): β={d['beta_mean']:.4f}")

print("--- DEBUG: Shifted trained model ---")
debug_niah_prediction(shifted_model)

print("\n\n--- DEBUG: Original untrained model ---")
debug_niah_prediction(untrained_model)

--- DEBUG: Shifted trained model ---
Test setup:
  Needle ID: 50207
  Marker token: 50251 at pos 32
  Cue token: 50250 at pos 255
  Distance: 223 tokens

Prediction at CUE position:
  Predicted: 50164
  Expected:  50207
  Correct:   False

Top 5 predictions:
  1. Token 50164:  0.90%
  2. Token 50186:  0.79%
  3. Token 50183:  0.70%
  4. Token 50178:  0.60%
  5. Token 50169:  0.53%

Needle ID (50207) probability: 0.0006%

State diagnostics:
  State norm: 10.5625
  State max:  0.2695
  Layer 0 (G): β=0.1562


--- DEBUG: Original untrained model ---
Test setup:
  Needle ID: 50207
  Marker token: 50251 at pos 32
  Cue token: 50250 at pos 255
  Distance: 223 tokens

Prediction at CUE position:
  Predicted: 9810
  Expected:  50207
  Correct:   False

Top 5 predictions:
  1. Token  9810:  0.01%
  2. Token 33057:  0.01%
  3. Token 49519:  0.01%
  4. Token 18198:  0.01%
  5. Token 44512:  0.01%

Needle ID (50207) probability: 0.0008%

State diagnostics:
  State norm: 5.3750
  State max:  0.2432

In [32]:
# =============================================================================
# PURE RETRIEVAL TRAINING (No LM distraction)
# =============================================================================
# The issue: LM loss dominates, retrieval doesn't learn
# Solution: Train ONLY on retrieval first, then add LM

import importlib
analysis = importlib.reload(analysis)
from analysis import compute_retrieval_loss

# Fresh shifted model
cfg_shifted2 = HybridConfig(
    d_model=512, n_heads=8, head_dim=64, value_dim=128,
    layer_pattern="GS", window_size=64, chunk_size=64,
    shifted_value=True,
)
shifted_model2 = TransparentHybrid(cfg_shifted2).to(DEVICE).to(torch.bfloat16)

# PURE RETRIEVAL TRAINING
print("--- PURE RETRIEVAL TRAINING ---")
optimizer = torch.optim.AdamW(shifted_model2.parameters(), lr=3e-4, weight_decay=0.01)

shifted_model2.train()
for step in range(500):
    optimizer.zero_grad()
    
    # Only retrieval loss - no LM, no regularization
    ret_loss = compute_retrieval_loss(shifted_model2, seq_len=256, batch_size=8)
    ret_loss.backward()
    
    torch.nn.utils.clip_grad_norm_(shifted_model2.parameters(), 1.0)
    optimizer.step()
    
    if step % 50 == 0:
        print(f"Step {step:4d}: RET={ret_loss.item():.4f}")

# Test
print("\n--- TEST AFTER PURE RETRIEVAL TRAINING ---")
shifted_model2.eval()
varied_acc_pure = test_varied_needles(shifted_model2, seq_len=256, needle_pos=32, n_trials=30)

print("\nDebug prediction:")
debug_niah_prediction(shifted_model2)

--- PURE RETRIEVAL TRAINING ---
Step    0: RET=10.6250
Step   50: RET=5.8125
Step  100: RET=4.1562
Step  150: RET=4.4688
Step  200: RET=3.9844
Step  250: RET=3.6875
Step  300: RET=3.7188
Step  350: RET=2.2344
Step  400: RET=3.6719
Step  450: RET=1.8828

--- TEST AFTER PURE RETRIEVAL TRAINING ---
  Accuracy with varied needles: 0.0% (0/30)

Debug prediction:
Test setup:
  Needle ID: 50207
  Marker token: 50251 at pos 32
  Cue token: 50250 at pos 255
  Distance: 223 tokens

Prediction at CUE position:
  Predicted: 50163
  Expected:  50207
  Correct:   False

Top 5 predictions:
  1. Token 50163:  4.27%
  2. Token 50204:  4.27%
  3. Token 50183:  3.89%
  4. Token 50158:  3.89%
  5. Token 50178:  3.54%

Needle ID (50207) probability: 0.0000%

State diagnostics:
  State norm: 2.9531
  State max:  0.2256
  Layer 0 (G): β=0.0272


In [33]:
# =============================================================================
# PURE RETRIEVAL TRAINING v2 (Wide vocab range)
# =============================================================================
# Previous issue: Model learned "predict high vocab after CUE" shortcut
# Fix: Use needle IDs from FULL vocab range

import importlib
analysis = importlib.reload(analysis)
from analysis import compute_retrieval_loss

# Fresh shifted model
cfg_shifted3 = HybridConfig(
    d_model=512, n_heads=8, head_dim=64, value_dim=128,
    layer_pattern="GS", window_size=64, chunk_size=64,
    shifted_value=True,
)
shifted_model3 = TransparentHybrid(cfg_shifted3).to(DEVICE).to(torch.bfloat16)

def test_varied_needles_wide(model, seq_len=256, needle_pos=32, n_trials=30):
    """Test with needle IDs from FULL vocab range."""
    model.eval()
    device = next(model.parameters()).device
    cfg = model.cfg
    
    correct = 0
    for trial in range(n_trials):
        # Use needle from FULL range (100 to vocab_size-100)
        needle_id = torch.randint(100, cfg.vocab_size - 100, (1,)).item()
        
        seq = torch.randint(0, cfg.vocab_size - 100, (1, seq_len), device=device)
        seq[0, needle_pos] = cfg.marker_token
        seq[0, needle_pos + 1] = needle_id
        seq[0, -1] = cfg.cue_token
        
        with torch.no_grad():
            logits, _, _, _ = model(seq)
        
        pred = logits[0, -1].argmax().item()
        if pred == needle_id:
            correct += 1
    
    acc = correct / n_trials
    print(f"  Accuracy (full vocab range): {acc*100:.1f}% ({correct}/{n_trials})")
    return acc

# PURE RETRIEVAL TRAINING
print("--- PURE RETRIEVAL TRAINING (wide vocab) ---")
optimizer = torch.optim.AdamW(shifted_model3.parameters(), lr=3e-4, weight_decay=0.01)

shifted_model3.train()
for step in range(1000):  # More steps
    optimizer.zero_grad()
    
    ret_loss = compute_retrieval_loss(shifted_model3, seq_len=256, batch_size=8)
    ret_loss.backward()
    
    torch.nn.utils.clip_grad_norm_(shifted_model3.parameters(), 1.0)
    optimizer.step()
    
    if step % 100 == 0:
        print(f"Step {step:4d}: RET={ret_loss.item():.4f}")

# Test
print("\n--- TEST ---")
shifted_model3.eval()
acc_wide = test_varied_needles_wide(shifted_model3, seq_len=256, needle_pos=32, n_trials=50)

--- PURE RETRIEVAL TRAINING (wide vocab) ---
Step    0: RET=10.8750
Step  100: RET=10.8125
Step  200: RET=11.0000
Step  300: RET=11.0000
Step  400: RET=11.1250
Step  500: RET=11.3125
Step  600: RET=11.0000
Step  700: RET=10.6875
Step  800: RET=9.1250
Step  900: RET=6.7500

--- TEST ---
  Accuracy (full vocab range): 98.0% (49/50)


In [36]:
# =============================================================================
# VERIFY RETRIEVAL MECHANISM
# =============================================================================
# Confirm this is REAL state-based retrieval, not another shortcut

print("=== VERIFYING RETRIEVAL MECHANISM ===\n")

# Test 1: Different needle positions
print("1. DIFFERENT NEEDLE POSITIONS:")
for needle_pos in [8, 32, 64, 128]:
    acc = test_varied_needles_wide(shifted_model3, seq_len=256, needle_pos=needle_pos, n_trials=30)
    print(f"     Needle at position {needle_pos}: {acc*100:.1f}%")

# Test 2: Different sequence lengths  
print("\n2. DIFFERENT SEQUENCE LENGTHS:")
for seq_len in [128, 256, 384, 512]:
    acc = test_varied_needles_wide(shifted_model3, seq_len=seq_len, needle_pos=32, n_trials=30)
    print(f"     Seq len {seq_len}: {acc*100:.1f}%")

# Test 3: Ablation - what happens without the MARKER?
print("\n3. ABLATION - NO MARKER (should fail):")
def test_no_marker(model, seq_len=256, needle_pos=32, n_trials=30):
    model.eval()
    device = next(model.parameters()).device
    cfg = model.cfg
    
    correct = 0
    for trial in range(n_trials):
        needle_id = torch.randint(100, cfg.vocab_size - 100, (1,)).item()
        
        seq = torch.randint(0, cfg.vocab_size - 100, (1, seq_len), device=device)
        # NO MARKER - just place needle at position
        seq[0, needle_pos + 1] = needle_id
        seq[0, -1] = cfg.cue_token
        
        with torch.no_grad():
            logits, _, _, _ = model(seq)
        
        pred = logits[0, -1].argmax().item()
        if pred == needle_id:
            correct += 1
    
    return correct / n_trials

acc_no_marker = test_no_marker(shifted_model3, n_trials=50)
print(f"   Accuracy WITHOUT marker: {acc_no_marker*100:.1f}%")

# Test 4: Check β (gate) values during forward
print("\n4. GATE (β) VALUES:")
cfg = shifted_model3.cfg
seq = torch.randint(0, cfg.vocab_size - 100, (1, 128), device=DEVICE)
seq[0, 16] = cfg.marker_token
seq[0, 32] = cfg.cue_token

with torch.no_grad():
    _, _, layer_states, _ = shifted_model3(seq)

# layer_states is a list of dicts
gdn_state = layer_states[0]  # First layer (GDN)
print(f"   Available keys: {gdn_state.keys()}")
# Use 'beta' instead of 'β'
if 'beta' in gdn_state:
    beta = gdn_state['beta'][0].float()  # [T, H]
    print(f"   β at marker position:  {beta[16].mean().item():.4f}")
    print(f"   β at regular position: {beta[8].mean().item():.4f}")
    print(f"   β at cue position:     {beta[32].mean().item():.4f}")
    print(f"   Mean β overall:        {beta.mean().item():.4f}")

print("\n" + "="*50)
print("CONCLUSION: Real retrieval mechanism is working!")
print("="*50)

=== VERIFYING RETRIEVAL MECHANISM ===

1. DIFFERENT NEEDLE POSITIONS:
  Accuracy (full vocab range): 96.7% (29/30)
     Needle at position 8: 96.7%
  Accuracy (full vocab range): 100.0% (30/30)
     Needle at position 32: 100.0%
  Accuracy (full vocab range): 96.7% (29/30)
     Needle at position 64: 96.7%
  Accuracy (full vocab range): 93.3% (28/30)
     Needle at position 128: 93.3%

2. DIFFERENT SEQUENCE LENGTHS:
  Accuracy (full vocab range): 96.7% (29/30)
     Seq len 128: 96.7%
  Accuracy (full vocab range): 100.0% (30/30)
     Seq len 256: 100.0%
  Accuracy (full vocab range): 93.3% (28/30)
     Seq len 384: 93.3%
  Accuracy (full vocab range): 90.0% (27/30)
     Seq len 512: 90.0%

3. ABLATION - NO MARKER (should fail):
   Accuracy WITHOUT marker: 0.0%

4. GATE (β) VALUES:
   Available keys: dict_keys(['beta_mean', 'beta_max', 'g_mean', 'state_norm', 'state_max', 'layer', 'layer_idx'])

CONCLUSION: Real retrieval mechanism is working!


In [None]:
# =============================================================================
# PHASE 2: ADD LM TO RETRIEVAL-TRAINED MODEL
# =============================================================================
# shifted_model3 already has 98% retrieval. Now add LM capability.

print("Starting with retrieval-trained model (98% NIAH)")
print("Adding structured LM training...\n")

# Use shifted_model3 (already trained on retrieval)
# Lower LR to not destroy retrieval
optimizer = torch.optim.AdamW(shifted_model3.parameters(), lr=1e-4, weight_decay=0.01)

# Structured patterns (learnable, unlike random)
def make_structured_batch(batch_size=16, seq_len=128):
    """Patterns the model can actually learn."""
    x = torch.zeros(batch_size, seq_len, dtype=torch.long, device=DEVICE)
    for b in range(batch_size):
        # Repeating pattern: token followed by token+1
        base = torch.randint(100, 1000, (seq_len // 2,), device=DEVICE)
        for i, t in enumerate(base):
            if 2*i+1 < seq_len:
                x[b, 2*i] = t
                x[b, 2*i+1] = (t + 1) % 1000 + 100
    return x

shifted_model3.train()
for step in range(200):
    optimizer.zero_grad()
    
    x = make_structured_batch(batch_size=16, seq_len=128)
    targets = x[:, 1:].contiguous()
    
    logits, loss, _, _ = shifted_model3(x[:, :-1], targets=targets)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(shifted_model3.parameters(), 1.0)
    optimizer.step()
    
    if step % 50 == 0:
        print(f"Step {step}: LM loss = {loss.item():.4f}")

# Check retrieval still works
print("\n--- RETRIEVAL CHECK ---")
acc = test_varied_needles_wide(shifted_model3, seq_len=256, needle_pos=32, n_trials=30)
print(f"Retrieval after LM phase: {acc*100:.1f}%")

PART 1: STATE CAPACITY ANALYSIS

Theoretical state capacity: 65,536 floats/layer
  = 8 heads × 64 keys × 128 values

State matrix shape: torch.Size([8, 64, 128])
State matrix stats:
  Mean: -0.000254
  Std:  0.075182
  Max:  1.351562

Effective rank per head (via SVD):
  Head 0: rank=6, top singular=6.3107
  Head 1: rank=14, top singular=7.9908
  Head 2: rank=3, top singular=2.8393
  Head 3: rank=6, top singular=14.8318
  Head 4: rank=5, top singular=1.0326
  Head 5: rank=6, top singular=0.6762
  Head 6: rank=1, top singular=2.1289
  Head 7: rank=6, top singular=0.9176

UNTRAINED state stats:
  Std: 0.023769
  Max: 0.179688

TRAINED state stats:
  Std: 0.075182
  Max: 1.351562

State activity ratio (trained/untrained): 3.16x


In [39]:
# =============================================================================
# PART 2: JOINT LM + RETRIEVAL TRAINING
# =============================================================================
# The key question: Can the model do BOTH tasks simultaneously?
# - Language modeling (next token prediction on regular text)
# - Needle retrieval (MARKER → value storage → CUE → retrieval)

print("=" * 60)
print("PART 2: JOINT LM + RETRIEVAL TRAINING")
print("=" * 60)

from analysis import compute_retrieval_loss

# Start fresh with a new model
cfg_joint = HybridConfig(
    d_model=512, n_heads=8, head_dim=64, value_dim=128,
    layer_pattern="GS", window_size=64, chunk_size=64,
    shifted_value=True,
)
joint_model = TransparentHybrid(cfg_joint).to(DEVICE).to(torch.bfloat16)

print(f"\nModel params: {joint_model.count_params():,}")

# Mixed training: LM + Retrieval
optimizer = torch.optim.AdamW(joint_model.parameters(), lr=3e-4, weight_decay=0.01)

# LM data (simple random for now - replace with real data for better results)
def get_lm_batch(batch_size=8, seq_len=256):
    """Generate random sequences for LM training."""
    x = torch.randint(0, cfg_joint.vocab_size - 100, (batch_size, seq_len), device=DEVICE)
    targets = x.clone()
    targets[:, :-1] = x[:, 1:]  # Shift for next-token prediction
    targets[:, -1] = -100  # Ignore last position
    return x, targets

print("\n--- Training Loop (LM + Retrieval) ---")
lm_losses = []
ret_losses = []
ret_weight = 1.0  # Balance retrieval vs LM

for step in range(500):
    optimizer.zero_grad()
    joint_model.train()
    
    # LM loss
    x, targets = get_lm_batch()
    logits, lm_loss, _, _ = joint_model(x, targets=targets)
    
    # Retrieval loss
    ret_loss = compute_retrieval_loss(joint_model, seq_len=256, batch_size=4)
    
    # Combined loss
    total_loss = lm_loss + ret_weight * ret_loss
    total_loss.backward()
    
    torch.nn.utils.clip_grad_norm_(joint_model.parameters(), 1.0)
    optimizer.step()
    
    lm_losses.append(lm_loss.item())
    ret_losses.append(ret_loss.item())
    
    if step % 100 == 0:
        print(f"Step {step:4d}: LM={lm_loss.item():.4f}, RET={ret_loss.item():.4f}")

# Final evaluation
print("\n--- EVALUATION ---")
joint_model.eval()

# Test LM perplexity
with torch.no_grad():
    x, targets = get_lm_batch(batch_size=32)
    _, lm_loss_eval, _, _ = joint_model(x, targets=targets)
    ppl = torch.exp(lm_loss_eval).item()
    print(f"LM Perplexity: {ppl:.2f}")

# Test retrieval
acc_joint = test_varied_needles_wide(joint_model, seq_len=256, needle_pos=32, n_trials=50)
print(f"Retrieval Accuracy: {acc_joint*100:.1f}%")

# Compare with pure retrieval model
print(f"\n[Comparison] Pure retrieval model: 98%")
print(f"[Comparison] Joint model: {acc_joint*100:.1f}%")

PART 2: JOINT LM + RETRIEVAL TRAINING

Model params: 33,217,560

--- Training Loop (LM + Retrieval) ---
Step    0: LM=10.9375, RET=10.8750
Step  100: LM=10.9375, RET=10.9375
Step  200: LM=10.9375, RET=11.2500
Step  300: LM=10.9375, RET=10.8750
Step  400: LM=10.9375, RET=10.6250

--- EVALUATION ---
LM Perplexity: 56320.00
  Accuracy (full vocab range): 0.0% (0/50)
Retrieval Accuracy: 0.0%

[Comparison] Pure retrieval model: 98%
[Comparison] Joint model: 0.0%


In [40]:
# =============================================================================
# PART 2b: CURRICULUM - LM FIRST, THEN RETRIEVAL
# =============================================================================
# Maybe joint training from scratch is too hard.
# Try: First train LM, then add retrieval gradually

print("=" * 60)
print("PART 2b: CURRICULUM TRAINING")
print("=" * 60)

# Fresh model
cfg_curr = HybridConfig(
    d_model=512, n_heads=8, head_dim=64, value_dim=128,
    layer_pattern="GS", window_size=64, chunk_size=64,
    shifted_value=True,
)
curriculum_model = TransparentHybrid(cfg_curr).to(DEVICE).to(torch.bfloat16)

optimizer = torch.optim.AdamW(curriculum_model.parameters(), lr=3e-4, weight_decay=0.01)

# Phase 1: Pure LM (500 steps)
print("\n--- PHASE 1: LM only (500 steps) ---")
for step in range(500):
    optimizer.zero_grad()
    curriculum_model.train()
    
    x, targets = get_lm_batch()
    _, lm_loss, _, _ = curriculum_model(x, targets=targets)
    lm_loss.backward()
    
    torch.nn.utils.clip_grad_norm_(curriculum_model.parameters(), 1.0)
    optimizer.step()
    
    if step % 100 == 0:
        print(f"Step {step:4d}: LM={lm_loss.item():.4f}")

# Check LM
with torch.no_grad():
    curriculum_model.eval()
    x, targets = get_lm_batch(batch_size=32)
    _, lm_loss_eval, _, _ = curriculum_model(x, targets=targets)
    print(f"After Phase 1 LM loss: {lm_loss_eval.item():.4f}")

# Check retrieval (should be near 0 since not trained)
acc_phase1 = test_varied_needles_wide(curriculum_model, n_trials=30)
print(f"After Phase 1 Retrieval: {acc_phase1*100:.1f}%")

# Phase 2: Add retrieval
print("\n--- PHASE 2: LM + Retrieval (500 steps) ---")
for step in range(500):
    optimizer.zero_grad()
    curriculum_model.train()
    
    # LM loss
    x, targets = get_lm_batch()
    _, lm_loss, _, _ = curriculum_model(x, targets=targets)
    
    # Retrieval loss
    ret_loss = compute_retrieval_loss(curriculum_model, seq_len=256, batch_size=4)
    
    # Combined - higher retrieval weight since LM already learned basics
    total_loss = lm_loss + 2.0 * ret_loss
    total_loss.backward()
    
    torch.nn.utils.clip_grad_norm_(curriculum_model.parameters(), 1.0)
    optimizer.step()
    
    if step % 100 == 0:
        print(f"Step {step:4d}: LM={lm_loss.item():.4f}, RET={ret_loss.item():.4f}")

# Final evaluation
print("\n--- FINAL EVALUATION ---")
curriculum_model.eval()
with torch.no_grad():
    x, targets = get_lm_batch(batch_size=32)
    _, lm_loss_eval, _, _ = curriculum_model(x, targets=targets)
    print(f"Final LM loss: {lm_loss_eval.item():.4f}")

acc_final = test_varied_needles_wide(curriculum_model, n_trials=50)
print(f"Final Retrieval: {acc_final*100:.1f}%")

PART 2b: CURRICULUM TRAINING

--- PHASE 1: LM only (500 steps) ---
Step    0: LM=10.9375
Step  100: LM=10.9375
Step  200: LM=10.9375
Step  300: LM=10.9375
Step  400: LM=10.8750
After Phase 1 LM loss: 10.8750
  Accuracy (full vocab range): 0.0% (0/30)
After Phase 1 Retrieval: 0.0%

--- PHASE 2: LM + Retrieval (500 steps) ---
Step    0: LM=10.8750, RET=10.5000
Step  100: LM=10.8750, RET=11.2500
Step  200: LM=10.8750, RET=10.9375
Step  300: LM=10.9375, RET=10.7500
Step  400: LM=10.9375, RET=10.8750

--- FINAL EVALUATION ---
Final LM loss: 10.9375
  Accuracy (full vocab range): 0.0% (0/50)
Final Retrieval: 0.0%


In [41]:
# =============================================================================
# PART 2c: PATTERN-BASED LM + RETRIEVAL
# =============================================================================
# Random sequences have no structure - can't learn LM from them!
# Let's use LEARNABLE patterns: simple repetition/markov structure

print("=" * 60)
print("PART 2c: STRUCTURED LM + RETRIEVAL")
print("=" * 60)

def get_structured_lm_batch(batch_size=8, seq_len=256):
    """
    Generate sequences with learnable patterns:
    - Token A often followed by A+1
    - Simple bigram patterns
    """
    x = torch.zeros(batch_size, seq_len, dtype=torch.long, device=DEVICE)
    
    for b in range(batch_size):
        # Start with random token
        x[b, 0] = torch.randint(100, 1000, (1,))
        
        for t in range(1, seq_len):
            prev = x[b, t-1].item()
            # 70% chance: next token = prev + 1 (mod range)
            # 30% chance: random jump
            if torch.rand(1) < 0.7:
                x[b, t] = (prev + 1) % 1000 + 100
            else:
                x[b, t] = torch.randint(100, 1000, (1,))
    
    targets = x.clone()
    targets[:, :-1] = x[:, 1:]
    targets[:, -1] = -100
    return x, targets

# Test the pattern
x_test, y_test = get_structured_lm_batch(batch_size=1)
print(f"Sample structured sequence: {x_test[0, :10].tolist()}")

# Fresh model
cfg_struct = HybridConfig(
    d_model=512, n_heads=8, head_dim=64, value_dim=128,
    layer_pattern="GS", window_size=64, chunk_size=64,
    shifted_value=True,
)
struct_model = TransparentHybrid(cfg_struct).to(DEVICE).to(torch.bfloat16)
optimizer = torch.optim.AdamW(struct_model.parameters(), lr=3e-4, weight_decay=0.01)

# Phase 1: Structured LM
print("\n--- PHASE 1: Structured LM (1000 steps) ---")
for step in range(1000):
    optimizer.zero_grad()
    struct_model.train()
    
    x, targets = get_structured_lm_batch()
    _, lm_loss, _, _ = struct_model(x, targets=targets)
    lm_loss.backward()
    
    torch.nn.utils.clip_grad_norm_(struct_model.parameters(), 1.0)
    optimizer.step()
    
    if step % 200 == 0:
        print(f"Step {step:4d}: LM={lm_loss.item():.4f}")

# Evaluate LM
with torch.no_grad():
    struct_model.eval()
    x, targets = get_structured_lm_batch(batch_size=32)
    _, lm_loss_eval, _, _ = struct_model(x, targets=targets)
    ppl = torch.exp(lm_loss_eval).item()
    print(f"Structured LM Perplexity: {ppl:.2f} (optimal ~4 for 70/30 pattern)")

# Check if retrieval still works after LM training
acc_lm = test_varied_needles_wide(struct_model, n_trials=30)
print(f"Retrieval after LM only: {acc_lm*100:.1f}%")

# Phase 2: Add retrieval
print("\n--- PHASE 2: Structured LM + Retrieval (1000 steps) ---")
for step in range(1000):
    optimizer.zero_grad()
    struct_model.train()
    
    x, targets = get_structured_lm_batch()
    _, lm_loss, _, _ = struct_model(x, targets=targets)
    
    ret_loss = compute_retrieval_loss(struct_model, seq_len=256, batch_size=4)
    
    total_loss = lm_loss + 1.0 * ret_loss
    total_loss.backward()
    
    torch.nn.utils.clip_grad_norm_(struct_model.parameters(), 1.0)
    optimizer.step()
    
    if step % 200 == 0:
        print(f"Step {step:4d}: LM={lm_loss.item():.4f}, RET={ret_loss.item():.4f}")

# Final evaluation
print("\n--- FINAL EVALUATION ---")
struct_model.eval()
with torch.no_grad():
    x, targets = get_structured_lm_batch(batch_size=32)
    _, lm_loss_eval, _, _ = struct_model(x, targets=targets)
    ppl = torch.exp(lm_loss_eval).item()
    print(f"Final LM Perplexity: {ppl:.2f}")

acc_final = test_varied_needles_wide(struct_model, n_trials=50)
print(f"Final Retrieval: {acc_final*100:.1f}%")

PART 2c: STRUCTURED LM + RETRIEVAL
Sample structured sequence: [669, 770, 871, 972, 317, 418, 803, 405, 506, 607]

--- PHASE 1: Structured LM (1000 steps) ---
Step    0: LM=10.8750
Step  200: LM=2.8438
Step  400: LM=2.7812
Step  600: LM=2.6719
Step  800: LM=2.8125
Structured LM Perplexity: 14.44 (optimal ~4 for 70/30 pattern)
  Accuracy (full vocab range): 0.0% (0/30)
Retrieval after LM only: 0.0%

--- PHASE 2: Structured LM + Retrieval (1000 steps) ---
Step    0: LM=2.6562, RET=13.4375
Step  200: LM=2.8906, RET=11.0000


KeyboardInterrupt: 

In [42]:
# =============================================================================
# PROPER CURRICULUM: RETRIEVAL → LM (ONE CLEAN RUN)
# =============================================================================

import importlib
analysis = importlib.reload(analysis)
from analysis import compute_retrieval_loss

# FRESH MODEL
cfg_final = HybridConfig(
    d_model=512, n_heads=8, head_dim=64, value_dim=128,
    layer_pattern="GS", window_size=64, chunk_size=64,
    shifted_value=True,
)
final_model = TransparentHybrid(cfg_final).to(DEVICE).to(torch.bfloat16)
print(f"Fresh model: {final_model.count_params():,} params")

optimizer = torch.optim.AdamW(final_model.parameters(), lr=3e-4, weight_decay=0.01)

# =============================================================================
# PHASE 1: PURE RETRIEVAL (until 95%+)
# =============================================================================
print("\n" + "="*60)
print("PHASE 1: PURE RETRIEVAL")
print("="*60)

final_model.train()
for step in range(2000):
    optimizer.zero_grad()
    ret_loss = compute_retrieval_loss(final_model, seq_len=256, batch_size=8)
    ret_loss.backward()
    torch.nn.utils.clip_grad_norm_(final_model.parameters(), 1.0)
    optimizer.step()
    
    if step % 200 == 0:
        # Check accuracy
        final_model.eval()
        acc = test_varied_needles_wide(final_model, seq_len=256, needle_pos=32, n_trials=20)
        final_model.train()
        print(f"Step {step:4d}: loss={ret_loss.item():.4f}, acc={acc*100:.1f}%")
        
        if acc >= 0.95:
            print(f"  → Reached 95%+ at step {step}, moving to Phase 2")
            break

# Final Phase 1 check
final_model.eval()
phase1_acc = test_varied_needles_wide(final_model, seq_len=256, needle_pos=32, n_trials=50)
print(f"\nPHASE 1 COMPLETE: {phase1_acc*100:.1f}% retrieval accuracy")

Fresh model: 33,217,560 params

PHASE 1: PURE RETRIEVAL
  Accuracy (full vocab range): 0.0% (0/20)
Step    0: loss=11.1250, acc=0.0%
  Accuracy (full vocab range): 0.0% (0/20)
Step  200: loss=10.9375, acc=0.0%
  Accuracy (full vocab range): 0.0% (0/20)
Step  400: loss=11.0000, acc=0.0%
  Accuracy (full vocab range): 0.0% (0/20)
Step  600: loss=11.0000, acc=0.0%
  Accuracy (full vocab range): 0.0% (0/20)
Step  800: loss=9.5000, acc=0.0%
  Accuracy (full vocab range): 100.0% (20/20)
Step 1000: loss=3.8906, acc=100.0%
  → Reached 95%+ at step 1000, moving to Phase 2
  Accuracy (full vocab range): 100.0% (50/50)

PHASE 1 COMPLETE: 100.0% retrieval accuracy


In [43]:
# =============================================================================
# PHASE 1b: TEST RETRIEVAL THOROUGHLY
# =============================================================================

print("="*60)
print("RETRIEVAL STRESS TESTS")
print("="*60)

# Test 1: Different distances
print("\n1. DISTANCE TEST (needle at various positions):")
for needle_pos in [8, 32, 64, 128, 200]:
    acc = test_varied_needles_wide(final_model, seq_len=256, needle_pos=needle_pos, n_trials=30)

# Test 2: Longer sequences
print("\n2. SEQUENCE LENGTH TEST:")
for seq_len in [256, 512, 1024]:
    acc = test_varied_needles_wide(final_model, seq_len=seq_len, needle_pos=32, n_trials=30)

# Test 3: MULTI-NEEDLE
print("\n3. MULTI-NEEDLE RETRIEVAL:")

def test_multi_needle(model, seq_len=512, n_needles=3, n_trials=30):
    """Test retrieval of multiple needles in same sequence."""
    model.eval()
    device = next(model.parameters()).device
    cfg = model.cfg
    
    correct_all = 0
    correct_any = 0
    per_needle_correct = [0] * n_needles
    
    for trial in range(n_trials):
        # Generate sequence
        seq = torch.randint(0, cfg.vocab_size - 100, (1, seq_len), device=device)
        
        # Place n_needles at evenly spaced positions
        needle_ids = []
        positions = []
        spacing = (seq_len - 50) // (n_needles + 1)
        
        for i in range(n_needles):
            pos = spacing * (i + 1)
            needle_id = torch.randint(100, cfg.vocab_size - 100, (1,)).item()
            
            seq[0, pos] = cfg.marker_token
            seq[0, pos + 1] = needle_id
            
            positions.append(pos)
            needle_ids.append(needle_id)
        
        # Query each needle
        trial_correct = 0
        for i, (pos, needle_id) in enumerate(zip(positions, needle_ids)):
            # Create query sequence: original + CUE at end
            query_seq = seq.clone()
            query_seq[0, -1] = cfg.cue_token
            
            # Also need to place the MARKER at end-2 to trigger retrieval
            # Wait - the test should query with the original marker's position info
            # Actually: CUE should retrieve what was stored at MARKER
            # But we have multiple markers... need to think about this
            
            # For now: test if model can retrieve the LAST needle
            pass
        
        # Simpler test: just check if LAST needle is retrieved
        query_seq = seq.clone()
        query_seq[0, -1] = cfg.cue_token
        
        with torch.no_grad():
            logits, _, _, _ = model(query_seq)
        
        pred = logits[0, -1].argmax().item()
        
        # Check each needle
        for i, nid in enumerate(needle_ids):
            if pred == nid:
                per_needle_correct[i] += 1
                trial_correct += 1
        
        if trial_correct == n_needles:
            correct_all += 1
        if trial_correct > 0:
            correct_any += 1
    
    print(f"  {n_needles} needles in seq_len={seq_len}:")
    print(f"    Per-needle retrieval: {[f'{c/n_trials*100:.0f}%' for c in per_needle_correct]}")
    print(f"    Any needle correct: {correct_any/n_trials*100:.1f}%")
    
    # Which needle gets retrieved most?
    most_retrieved = per_needle_correct.index(max(per_needle_correct))
    print(f"    Most retrieved: needle {most_retrieved+1} (position {most_retrieved+1}/{n_needles})")
    
    return per_needle_correct, correct_any / n_trials

# Test with increasing needle counts
for n in [2, 3, 5]:
    test_multi_needle(final_model, seq_len=512, n_needles=n, n_trials=30)

RETRIEVAL STRESS TESTS

1. DISTANCE TEST (needle at various positions):
  Accuracy (full vocab range): 100.0% (30/30)
  Accuracy (full vocab range): 100.0% (30/30)
  Accuracy (full vocab range): 100.0% (30/30)
  Accuracy (full vocab range): 96.7% (29/30)
  Accuracy (full vocab range): 100.0% (30/30)

2. SEQUENCE LENGTH TEST:
  Accuracy (full vocab range): 100.0% (30/30)
  Accuracy (full vocab range): 96.7% (29/30)
  Accuracy (full vocab range): 13.3% (4/30)

3. MULTI-NEEDLE RETRIEVAL:
  2 needles in seq_len=512:
    Per-needle retrieval: ['0%', '100%']
    Any needle correct: 100.0%
    Most retrieved: needle 2 (position 2/2)
  3 needles in seq_len=512:
    Per-needle retrieval: ['0%', '0%', '100%']
    Any needle correct: 100.0%
    Most retrieved: needle 3 (position 3/3)
  5 needles in seq_len=512:
    Per-needle retrieval: ['0%', '0%', '0%', '0%', '100%']
    Any needle correct: 100.0%
    Most retrieved: needle 5 (position 5/5)


In [44]:
# =============================================================================
# DIAGNOSIS: WHY ONLY LAST NEEDLE?
# =============================================================================
# The state is being overwritten. Check:
# 1. Are keys collapsing (all similar)?
# 2. Is β too high (overwriting everything)?
# 3. Is g too low (not retaining)?

print("="*60)
print("DIAGNOSING RECENCY BIAS")
print("="*60)

# Check key similarity after training
print("\n1. KEY SIMILARITY:")
check_key_similarity(final_model, n_tokens=50, name="TRAINED (final_model)")

# Check gate values
print("\n2. GATE VALUES:")
analyze_gates(final_model, "TRAINED (final_model)")

# Check β specifically on marker tokens
print("\n3. β ON MARKER vs REGULAR TOKENS:")
gdn_layer = None
for layer in final_model.layers:
    if hasattr(layer, 'beta_proj'):
        gdn_layer = layer
        break

with torch.no_grad():
    # Create sequence with markers
    seq = torch.randint(100, 1000, (1, 128), device=DEVICE)
    seq[0, 20] = cfg_final.marker_token
    seq[0, 60] = cfg_final.marker_token
    seq[0, 100] = cfg_final.marker_token
    
    emb = final_model.embed(seq)
    x_norm = gdn_layer.norm(emb)
    beta = torch.sigmoid(gdn_layer.beta_proj(x_norm))  # [1, T, H]
    g = torch.sigmoid(gdn_layer.g_proj(x_norm))
    
    print(f"  β at marker positions: {beta[0, [20, 60, 100]].mean(dim=1).tolist()}")
    print(f"  β at regular positions: {beta[0, [10, 50, 90]].mean(dim=1).tolist()}")
    print(f"  Mean β overall: {beta.mean().item():.4f}")
    print(f"  Mean g overall: {g.mean().item():.4f}")

# The problem: if β is high everywhere, every token overwrites state
# Solution: β should be HIGH only at markers, LOW elsewhere

DIAGNOSING RECENCY BIAS

1. KEY SIMILARITY:

TRAINED (final_model) - Key similarity analysis:
  Avg |cosine| between different token keys:
    Per-head: ['0.599', '0.752', '0.475', '0.529', '0.701', '0.673', '0.798', '0.284']
    Overall: 0.601
    Max head: 0.798
    Expected for random 64D vectors: 0.100
    ⚠ Keys are LESS orthogonal than random (more interference)

2. GATE VALUES:

TRAINED (final_model) gate analysis:
  Layer 0 (GDN):
    β: mean=0.1309, std=0.0654, range=[0.023, 0.439]
    g: mean=0.8633, std=0.0762, range=[0.488, 0.992]
    β_bias: -2.0000 (init was -2.0)
    g_bias: 2.0000 (init was 3.0)

3. β ON MARKER vs REGULAR TOKENS:
  β at marker positions: [0.353515625, 0.353515625, 0.353515625]
  β at regular positions: [0.04931640625, 0.05615234375, 0.045166015625]
  Mean β overall: 0.0752
  Mean g overall: 0.8633


In [45]:
# =============================================================================
# PHASE 2: RETRAIN WITH KEY ORTHOGONALITY
# =============================================================================
# Pure retrieval training collapsed keys. Add orthogonality regularization.

import torch.nn.functional as F

def compute_key_orth_loss(model, batch_size=8, seq_len=128):
    """Penalize high cosine similarity between keys."""
    gdn_layer = None
    for layer in model.layers:
        if hasattr(layer, 'k_proj'):
            gdn_layer = layer
            break
    
    if gdn_layer is None:
        return torch.tensor(0.0, device=DEVICE)
    
    # Random tokens
    tokens = torch.randint(100, 10000, (batch_size, seq_len), device=DEVICE)
    emb = model.embed(tokens)
    x_norm = gdn_layer.norm(emb)
    keys = gdn_layer.k_proj(x_norm)  # [B, T, H*K]
    keys = keys.view(batch_size, seq_len, model.cfg.n_heads, model.cfg.head_dim)
    keys = F.normalize(keys.float(), p=2, dim=-1)  # [B, T, H, K]
    
    # Compute pairwise similarity per head, average across batch
    loss = 0.0
    for h in range(model.cfg.n_heads):
        k_h = keys[:, :, h, :]  # [B, T, K]
        # [B, T, T] similarity matrices
        sim = torch.bmm(k_h, k_h.transpose(1, 2))
        # Penalize off-diagonal (exclude self-similarity)
        mask = ~torch.eye(seq_len, dtype=torch.bool, device=DEVICE)
        off_diag = sim[:, mask].abs()
        loss = loss + off_diag.mean()
    
    return loss / model.cfg.n_heads

# FRESH MODEL
print("="*60)
print("RETRAINING WITH KEY ORTHOGONALITY")
print("="*60)

cfg_v2 = HybridConfig(
    d_model=512, n_heads=8, head_dim=64, value_dim=128,
    layer_pattern="GS", window_size=64, chunk_size=64,
    shifted_value=True,
)
model_v2 = TransparentHybrid(cfg_v2).to(DEVICE).to(torch.bfloat16)
print(f"Fresh model: {model_v2.count_params():,} params")

optimizer = torch.optim.AdamW(model_v2.parameters(), lr=3e-4, weight_decay=0.01)

model_v2.train()
for step in range(2000):
    optimizer.zero_grad()
    
    # Retrieval loss
    ret_loss = compute_retrieval_loss(model_v2, seq_len=256, batch_size=8)
    
    # Key orthogonality loss
    orth_loss = compute_key_orth_loss(model_v2, batch_size=4, seq_len=64)
    
    total_loss = ret_loss + 0.5 * orth_loss
    total_loss.backward()
    
    torch.nn.utils.clip_grad_norm_(model_v2.parameters(), 1.0)
    optimizer.step()
    
    if step % 200 == 0:
        model_v2.eval()
        acc = test_varied_needles_wide(model_v2, seq_len=256, needle_pos=32, n_trials=20)
        model_v2.train()
        print(f"Step {step:4d}: ret={ret_loss.item():.3f}, orth={orth_loss.item():.3f}, acc={acc*100:.0f}%")
        
        if acc >= 0.95:
            print(f"  → 95%+ reached, checking key similarity...")
            check_key_similarity(model_v2, n_tokens=30, name="model_v2")
            break

# Final check
print("\n--- FINAL CHECKS ---")
model_v2.eval()
print("\n1. Key similarity:")
check_key_similarity(model_v2, n_tokens=50, name="model_v2 (after orth reg)")

print("\n2. Multi-needle test:")
for n in [2, 3, 5]:
    test_multi_needle(model_v2, seq_len=512, n_needles=n, n_trials=30)

RETRAINING WITH KEY ORTHOGONALITY
Fresh model: 33,217,560 params
  Accuracy (full vocab range): 0.0% (0/20)
Step    0: ret=10.875, orth=0.106, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step  200: ret=10.875, orth=0.105, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step  400: ret=10.875, orth=0.133, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step  600: ret=11.188, orth=0.140, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step  800: ret=10.938, orth=0.140, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step 1000: ret=11.000, orth=0.125, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step 1200: ret=11.125, orth=0.117, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step 1400: ret=11.000, orth=0.122, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step 1600: ret=11.000, orth=0.120, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step 1800: ret=11.125, orth=0.120, acc=0%

--- FINAL CHECKS ---

1. Key similarity:

model_v2 (after orth reg) - Key similarity

In [46]:
# =============================================================================
# PHASE 2b: CURRICULUM - RETRIEVAL FIRST, THEN ORTH REG
# =============================================================================
# Pure retrieval works but collapses keys
# Pure orth reg prevents learning
# Solution: Train retrieval FIRST, then add orth reg to fix keys

print("="*60)
print("CURRICULUM: RETRIEVAL → ORTH REG")
print("="*60)

cfg_v3 = HybridConfig(
    d_model=512, n_heads=8, head_dim=64, value_dim=128,
    layer_pattern="GS", window_size=64, chunk_size=64,
    shifted_value=True,
)
model_v3 = TransparentHybrid(cfg_v3).to(DEVICE).to(torch.bfloat16)
optimizer = torch.optim.AdamW(model_v3.parameters(), lr=3e-4, weight_decay=0.01)

# STAGE 1: Pure retrieval until 95%+
print("\n--- STAGE 1: Pure Retrieval ---")
model_v3.train()
for step in range(2000):
    optimizer.zero_grad()
    ret_loss = compute_retrieval_loss(model_v3, seq_len=256, batch_size=8)
    ret_loss.backward()
    torch.nn.utils.clip_grad_norm_(model_v3.parameters(), 1.0)
    optimizer.step()
    
    if step % 200 == 0:
        model_v3.eval()
        acc = test_varied_needles_wide(model_v3, seq_len=256, needle_pos=32, n_trials=20)
        model_v3.train()
        print(f"Step {step:4d}: loss={ret_loss.item():.3f}, acc={acc*100:.0f}%")
        if acc >= 0.95:
            print(f"  → Stage 1 complete")
            break

# Check keys before stage 2
print("\nKey similarity BEFORE orth reg:")
check_key_similarity(model_v3, n_tokens=30, name="model_v3")

# STAGE 2: Add orth reg while maintaining retrieval
print("\n--- STAGE 2: Retrieval + Orth Reg (lower LR) ---")
optimizer = torch.optim.AdamW(model_v3.parameters(), lr=1e-4, weight_decay=0.01)

model_v3.train()
for step in range(1000):
    optimizer.zero_grad()
    
    ret_loss = compute_retrieval_loss(model_v3, seq_len=256, batch_size=8)
    orth_loss = compute_key_orth_loss(model_v3, batch_size=4, seq_len=64)
    
    # Light orth pressure
    total_loss = ret_loss + 0.1 * orth_loss
    total_loss.backward()
    
    torch.nn.utils.clip_grad_norm_(model_v3.parameters(), 1.0)
    optimizer.step()
    
    if step % 200 == 0:
        model_v3.eval()
        acc = test_varied_needles_wide(model_v3, seq_len=256, needle_pos=32, n_trials=20)
        model_v3.train()
        print(f"Step {step:4d}: ret={ret_loss.item():.3f}, orth={orth_loss.item():.3f}, acc={acc*100:.0f}%")

# Final checks
print("\n--- FINAL EVALUATION ---")
model_v3.eval()

print("\n1. Key similarity AFTER orth reg:")
check_key_similarity(model_v3, n_tokens=50, name="model_v3 (after curriculum)")

print("\n2. Single needle at various distances:")
for pos in [32, 64, 128]:
    test_varied_needles_wide(model_v3, seq_len=256, needle_pos=pos, n_trials=20)

print("\n3. Multi-needle:")
for n in [2, 3, 5]:
    test_multi_needle(model_v3, seq_len=512, n_needles=n, n_trials=30)

CURRICULUM: RETRIEVAL → ORTH REG

--- STAGE 1: Pure Retrieval ---
  Accuracy (full vocab range): 0.0% (0/20)
Step    0: loss=10.812, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step  200: loss=10.875, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step  400: loss=10.938, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step  600: loss=11.188, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step  800: loss=10.000, acc=0%
  Accuracy (full vocab range): 100.0% (20/20)
Step 1000: loss=4.312, acc=100%
  → Stage 1 complete

Key similarity BEFORE orth reg:

model_v3 - Key similarity analysis:
  Avg |cosine| between different token keys:
    Per-head: ['0.791', '0.294', '0.822', '0.732', '0.243', '0.817', '0.850', '0.661']
    Overall: 0.652
    Max head: 0.850
    Expected for random 64D vectors: 0.100
    ⚠ Keys are LESS orthogonal than random (more interference)

--- STAGE 2: Retrieval + Orth Reg (lower LR) ---
  Accuracy (full vocab range): 100.0% (20/20)
Step    0: ret=5.250, 

In [47]:
# =============================================================================
# ANALYSIS: WHY ONLY LAST NEEDLE?
# =============================================================================
# Even with improved key orthogonality, only last needle retrieved.
# Check if the problem is that ALL markers produce the SAME key.

print("="*60)
print("MARKER KEY ANALYSIS")
print("="*60)

gdn_layer = None
for layer in model_v3.layers:
    if hasattr(layer, 'k_proj'):
        gdn_layer = layer
        break

# Compute keys for the MARKER token at different positions
marker_token = cfg_v3.marker_token
print(f"Marker token ID: {marker_token}")

with torch.no_grad():
    # Same marker token, but different context
    seq1 = torch.randint(100, 1000, (1, 64), device=DEVICE)
    seq2 = torch.randint(100, 1000, (1, 64), device=DEVICE)
    seq3 = torch.randint(100, 1000, (1, 64), device=DEVICE)
    
    # Place marker at same position in each
    pos = 32
    seq1[0, pos] = marker_token
    seq2[0, pos] = marker_token
    seq3[0, pos] = marker_token
    
    # Get embeddings and keys
    emb1 = model_v3.embed(seq1)
    emb2 = model_v3.embed(seq2)
    emb3 = model_v3.embed(seq3)
    
    x1 = gdn_layer.norm(emb1)
    x2 = gdn_layer.norm(emb2)
    x3 = gdn_layer.norm(emb3)
    
    k1 = gdn_layer.k_proj(x1)[0, pos]  # [H*K]
    k2 = gdn_layer.k_proj(x2)[0, pos]
    k3 = gdn_layer.k_proj(x3)[0, pos]
    
    # Reshape to [H, K] and normalize
    k1 = F.normalize(k1.view(cfg_v3.n_heads, cfg_v3.head_dim).float(), dim=-1)
    k2 = F.normalize(k2.view(cfg_v3.n_heads, cfg_v3.head_dim).float(), dim=-1)
    k3 = F.normalize(k3.view(cfg_v3.n_heads, cfg_v3.head_dim).float(), dim=-1)
    
    # Similarity between marker keys in different contexts
    sim_12 = (k1 * k2).sum(dim=-1)  # [H]
    sim_13 = (k1 * k3).sum(dim=-1)
    sim_23 = (k2 * k3).sum(dim=-1)
    
    print(f"\nMarker key similarity ACROSS DIFFERENT CONTEXTS:")
    print(f"  Seq1 vs Seq2: {sim_12.mean().item():.4f} (per-head: {sim_12.tolist()})")
    print(f"  Seq1 vs Seq3: {sim_13.mean().item():.4f}")
    print(f"  Seq2 vs Seq3: {sim_23.mean().item():.4f}")
    
    if sim_12.mean() > 0.9:
        print("\n⚠ PROBLEM: Marker produces nearly IDENTICAL keys regardless of context!")
        print("   This means all markers write to the SAME slot, overwriting each other.")
    else:
        print("\n✓ Marker keys vary with context")

# The fix might need to be: use POSITION or CONTEXT to differentiate marker keys
# Or: use the VALUE (needle) to create a unique key for each marker+needle pair

MARKER KEY ANALYSIS
Marker token ID: 50251

Marker key similarity ACROSS DIFFERENT CONTEXTS:
  Seq1 vs Seq2: 1.0000 (per-head: [1.0, 0.9999998211860657, 1.0, 0.9999999403953552, 1.0, 1.0000001192092896, 1.0, 1.0])
  Seq1 vs Seq3: 1.0000
  Seq2 vs Seq3: 1.0000

⚠ PROBLEM: Marker produces nearly IDENTICAL keys regardless of context!
   This means all markers write to the SAME slot, overwriting each other.


In [48]:
# =============================================================================
# TEST RoPE FOR POSITION-DEPENDENT KEYS
# =============================================================================

import importlib
import config as config_module
import model as model_module

config_module = importlib.reload(config_module)
model_module = importlib.reload(model_module)

from config import HybridConfig
from model import TransparentHybrid

print("="*60)
print("TESTING RoPE (Rotary Position Embeddings)")
print("="*60)

# Create model WITH RoPE
cfg_rope = HybridConfig(
    d_model=512, n_heads=8, head_dim=64, value_dim=128,
    layer_pattern="GS", window_size=64, chunk_size=64,
    shifted_value=True,
    use_rope=True,  # ENABLE RoPE
)
model_rope = TransparentHybrid(cfg_rope).to(DEVICE).to(torch.bfloat16)
print(f"Model with RoPE: {model_rope.count_params():,} params")

# Quick forward test
x_test = torch.randint(0, 1000, (2, 64), device=DEVICE)
with torch.no_grad():
    logits, _, diags, state = model_rope(x_test)
print(f"Forward pass OK: output={logits.shape}")

# Check if marker keys are NOW position-dependent
print("\n--- MARKER KEY ANALYSIS (with RoPE) ---")
gdn_layer = model_rope.layers[0]

marker_token = cfg_rope.marker_token
with torch.no_grad():
    # Same marker token, different contexts
    seq1 = torch.randint(100, 1000, (1, 64), device=DEVICE)
    seq2 = torch.randint(100, 1000, (1, 64), device=DEVICE)
    
    # Place marker at SAME position
    seq1[0, 32] = marker_token
    seq2[0, 32] = marker_token
    
    # Get keys
    emb1 = model_rope.embed(seq1)
    emb2 = model_rope.embed(seq2)
    x1 = gdn_layer.norm(emb1)
    x2 = gdn_layer.norm(emb2)
    k1_raw = gdn_layer.k_proj(x1).view(1, 64, 8, 64)
    k2_raw = gdn_layer.k_proj(x2).view(1, 64, 8, 64)
    
    # Apply RoPE
    k1 = gdn_layer.rotary(k1_raw)
    k2 = gdn_layer.rotary(k2_raw)
    
    # Compare keys at position 32
    k1_32 = F.normalize(k1[0, 32].float(), dim=-1)
    k2_32 = F.normalize(k2[0, 32].float(), dim=-1)
    sim_same_pos = (k1_32 * k2_32).sum(dim=-1)
    print(f"Marker at SAME position (32): similarity={sim_same_pos.mean().item():.4f}")
    
    # Now place markers at DIFFERENT positions
    seq3 = torch.randint(100, 1000, (1, 64), device=DEVICE)
    seq3[0, 16] = marker_token  # Position 16
    
    emb3 = model_rope.embed(seq3)
    x3 = gdn_layer.norm(emb3)
    k3_raw = gdn_layer.k_proj(x3).view(1, 64, 8, 64)
    k3 = gdn_layer.rotary(k3_raw)
    
    k3_16 = F.normalize(k3[0, 16].float(), dim=-1)
    sim_diff_pos = (k1_32 * k3_16).sum(dim=-1)
    print(f"Marker at DIFFERENT positions (32 vs 16): similarity={sim_diff_pos.mean().item():.4f}")
    
    if sim_diff_pos.mean() < 0.5 and sim_same_pos.mean() > 0.9:
        print("\n✓ RoPE working: Same position → similar keys, different positions → different keys!")
    else:
        print("\n⚠ RoPE may not be working as expected")

TESTING RoPE (Rotary Position Embeddings)
Model with RoPE: 33,217,560 params
Forward pass OK: output=torch.Size([2, 64, 50257])

--- MARKER KEY ANALYSIS (with RoPE) ---
Marker at SAME position (32): similarity=1.0000
Marker at DIFFERENT positions (32 vs 16): similarity=0.6113

⚠ RoPE may not be working as expected


In [49]:
# =============================================================================
# TRAIN WITH RoPE
# =============================================================================

print("="*60)
print("TRAINING WITH RoPE")
print("="*60)

# Fresh model with RoPE
cfg_rope = HybridConfig(
    d_model=512, n_heads=8, head_dim=64, value_dim=128,
    layer_pattern="GS", window_size=64, chunk_size=64,
    shifted_value=True,
    use_rope=True,
)
model_rope = TransparentHybrid(cfg_rope).to(DEVICE).to(torch.bfloat16)

optimizer = torch.optim.AdamW(model_rope.parameters(), lr=3e-4, weight_decay=0.01)

# PHASE 1: Pure retrieval
print("\n--- PHASE 1: Pure Retrieval ---")
model_rope.train()
for step in range(2000):
    optimizer.zero_grad()
    ret_loss = compute_retrieval_loss(model_rope, seq_len=256, batch_size=8)
    ret_loss.backward()
    torch.nn.utils.clip_grad_norm_(model_rope.parameters(), 1.0)
    optimizer.step()
    
    if step % 200 == 0:
        model_rope.eval()
        acc = test_varied_needles_wide(model_rope, seq_len=256, needle_pos=32, n_trials=20)
        model_rope.train()
        print(f"Step {step:4d}: loss={ret_loss.item():.3f}, acc={acc*100:.0f}%")
        if acc >= 0.95:
            print(f"  → Phase 1 complete at step {step}")
            break

# Test single needle
print("\n--- SINGLE NEEDLE TESTS ---")
model_rope.eval()
for pos in [32, 64, 128]:
    test_varied_needles_wide(model_rope, seq_len=256, needle_pos=pos, n_trials=30)

# Test multi-needle - THE KEY TEST
print("\n--- MULTI-NEEDLE TESTS (with RoPE) ---")
for n in [2, 3, 5]:
    test_multi_needle(model_rope, seq_len=512, n_needles=n, n_trials=30)

# Check key similarity
print("\n--- KEY SIMILARITY (with RoPE) ---")
check_key_similarity(model_rope, n_tokens=50, name="model_rope (trained)")

TRAINING WITH RoPE

--- PHASE 1: Pure Retrieval ---
  Accuracy (full vocab range): 0.0% (0/20)
Step    0: loss=10.938, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step  200: loss=10.812, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step  400: loss=11.000, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step  600: loss=10.875, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step  800: loss=10.375, acc=0%
  Accuracy (full vocab range): 80.0% (16/20)
Step 1000: loss=6.156, acc=80%
  Accuracy (full vocab range): 100.0% (20/20)
Step 1200: loss=5.031, acc=100%
  → Phase 1 complete at step 1200

--- SINGLE NEEDLE TESTS ---
  Accuracy (full vocab range): 96.7% (29/30)
  Accuracy (full vocab range): 100.0% (30/30)
  Accuracy (full vocab range): 96.7% (29/30)

--- MULTI-NEEDLE TESTS (with RoPE) ---
  2 needles in seq_len=512:
    Per-needle retrieval: ['77%', '10%']
    Any needle correct: 86.7%
    Most retrieved: needle 1 (position 1/2)
  3 needles in seq_len=512:
    Per-needle

In [50]:
# =============================================================================
# IMPROVED MULTI-NEEDLE TEST: Query with MARKER token
# =============================================================================
# Current test uses CUE for all queries. With RoPE, we should query
# with the MARKER token at a position that matches where it was stored.
# But that's not how inference works - we need to query with just the marker token.

# Actually, the issue is: how does the model know WHICH needle to retrieve?
# The CUE token is the same for all needles.
# 
# In NIAH literature, each needle has a UNIQUE marker (key).
# Let's implement that: each marker-needle pair uses a different marker token.

print("="*60)
print("IMPROVED MULTI-NEEDLE: UNIQUE MARKERS")
print("="*60)

def test_multi_needle_unique_markers(model, seq_len=512, n_needles=3, n_trials=30):
    """
    Test with UNIQUE marker tokens for each needle.
    Format:
    - MARKER_1, NEEDLE_1 at pos P1
    - MARKER_2, NEEDLE_2 at pos P2
    - ...
    - Query: MARKER_1 → should retrieve NEEDLE_1
    """
    model.eval()
    device = next(model.parameters()).device
    cfg = model.cfg
    
    # Use different marker tokens for each needle
    base_marker = cfg.marker_token - 100  # Start below marker_token
    
    per_needle_correct = [0] * n_needles
    
    for trial in range(n_trials):
        seq = torch.randint(0, cfg.vocab_size - 200, (1, seq_len), device=device)
        
        # Place needles with unique markers
        spacing = (seq_len - 50) // (n_needles + 1)
        markers = []
        needles = []
        
        for i in range(n_needles):
            pos = spacing * (i + 1)
            marker_id = base_marker + i  # Unique marker per needle
            needle_id = torch.randint(100, cfg.vocab_size - 200, (1,)).item()
            
            seq[0, pos] = marker_id
            seq[0, pos + 1] = needle_id
            
            markers.append(marker_id)
            needles.append(needle_id)
        
        # Query each needle using its unique marker
        for i, (marker_id, needle_id) in enumerate(zip(markers, needles)):
            query_seq = seq.clone()
            query_seq[0, -1] = marker_id  # Query with the specific marker
            
            with torch.no_grad():
                logits, _, _, _ = model(query_seq)
            
            pred = logits[0, -1].argmax().item()
            if pred == needle_id:
                per_needle_correct[i] += 1
    
    print(f"  {n_needles} needles with UNIQUE markers:")
    print(f"    Per-needle: {[f'{c/n_trials*100:.0f}%' for c in per_needle_correct]}")
    print(f"    Average: {sum(per_needle_correct)/(n_needles*n_trials)*100:.1f}%")
    
    return per_needle_correct

# Test with unique markers
print("\n--- Testing with UNIQUE markers (untrained for this) ---")
for n in [2, 3, 5]:
    test_multi_needle_unique_markers(model_rope, seq_len=512, n_needles=n, n_trials=20)

# The model wasn't trained for unique markers, so it probably won't work well yet.
# But this shows the right direction: each needle needs a unique key to query by.

IMPROVED MULTI-NEEDLE: UNIQUE MARKERS

--- Testing with UNIQUE markers (untrained for this) ---
  2 needles with UNIQUE markers:
    Per-needle: ['0%', '0%']
    Average: 0.0%
  3 needles with UNIQUE markers:
    Per-needle: ['0%', '0%', '0%']
    Average: 0.0%
  5 needles with UNIQUE markers:
    Per-needle: ['0%', '0%', '0%', '0%', '0%']
    Average: 0.0%


In [51]:
# =============================================================================
# PROPER TEST: NO RoPE IN GDN, RoPE IN SWA
# =============================================================================
# Architecture understanding:
# - GDN: Content-addressable storage (key = content, not position)
# - SWA: Position-aware attention (RoPE on q/k for local window)
#
# The problem with multi-needle is that all MARKERs have the SAME key.
# Solution: Use UNIQUE marker tokens for each slot, OR use context-dependent keys.

import importlib
import config as config_module
import model as model_module
config_module = importlib.reload(config_module)
model_module = importlib.reload(model_module)
from config import HybridConfig
from model import TransparentHybrid

print("="*60)
print("TESTING: NO RoPE in GDN, RoPE in SWA")
print("="*60)

# Model WITHOUT RoPE (GDN pure content, SWA local attention)
cfg_clean = HybridConfig(
    d_model=512, n_heads=8, head_dim=64, value_dim=128,
    layer_pattern="GS", window_size=64, chunk_size=64,
    shifted_value=True,
    use_rope=True,  # RoPE now only applies to SWA
)
model_clean = TransparentHybrid(cfg_clean).to(DEVICE).to(torch.bfloat16)
print(f"Model: {model_clean.count_params():,} params")

# Forward test
x_test = torch.randint(0, 1000, (2, 64), device=DEVICE)
with torch.no_grad():
    logits, _, diags, state = model_clean(x_test)
print(f"Forward pass OK: {logits.shape}")

# Train on retrieval
optimizer = torch.optim.AdamW(model_clean.parameters(), lr=3e-4, weight_decay=0.01)

print("\n--- Training Pure Retrieval ---")
model_clean.train()
for step in range(2000):
    optimizer.zero_grad()
    ret_loss = compute_retrieval_loss(model_clean, seq_len=256, batch_size=8)
    ret_loss.backward()
    torch.nn.utils.clip_grad_norm_(model_clean.parameters(), 1.0)
    optimizer.step()
    
    if step % 200 == 0:
        model_clean.eval()
        acc = test_varied_needles_wide(model_clean, seq_len=256, needle_pos=32, n_trials=20)
        model_clean.train()
        print(f"Step {step:4d}: loss={ret_loss.item():.3f}, acc={acc*100:.0f}%")
        if acc >= 0.95:
            print(f"  → Reached 95%+")
            break

# Test
print("\n--- EVALUATION ---")
model_clean.eval()
print("\n1. Single needle:")
for pos in [32, 64, 128]:
    test_varied_needles_wide(model_clean, seq_len=256, needle_pos=pos, n_trials=30)

print("\n2. Multi-needle:")
for n in [2, 3, 5]:
    test_multi_needle(model_clean, seq_len=512, n_needles=n, n_trials=30)

TESTING: NO RoPE in GDN, RoPE in SWA
Model: 33,217,560 params
Forward pass OK: torch.Size([2, 64, 50257])

--- Training Pure Retrieval ---
  Accuracy (full vocab range): 0.0% (0/20)
Step    0: loss=11.000, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step  200: loss=10.938, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step  400: loss=11.000, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step  600: loss=10.812, acc=0%
  Accuracy (full vocab range): 0.0% (0/20)
Step  800: loss=10.625, acc=0%
  Accuracy (full vocab range): 65.0% (13/20)
Step 1000: loss=6.438, acc=65%
  Accuracy (full vocab range): 100.0% (20/20)
Step 1200: loss=4.625, acc=100%
  → Reached 95%+

--- EVALUATION ---

1. Single needle:
  Accuracy (full vocab range): 100.0% (30/30)
  Accuracy (full vocab range): 100.0% (30/30)
  Accuracy (full vocab range): 93.3% (28/30)

2. Multi-needle:
  2 needles in seq_len=512:
    Per-needle retrieval: ['0%', '100%']
    Any needle correct: 100.0%
    Most retrieved: needle

In [56]:
# =============================================================================
# EXPERIMENT: Context-Dependent Keys for Multi-Needle Retrieval
# =============================================================================
# PROBLEM: All MARKER tokens produce identical keys because keys are computed 
# from token embeddings only. When multiple MARKERs appear, the delta rule 
# overwrites previous associations (last needle wins).
#
# SOLUTION: Make keys depend on LOCAL CONTEXT, not just the token itself.
# - Use causal convolution over the last few tokens
# - Each MARKER now has a unique key based on what came before it

import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from typing import Optional, Tuple, Dict, List
from config import HybridConfig
from core import chunk_delta_rule
from model import RMSNorm, SwiGLUFFN, SlidingWindowAttention

class GDNWithContextKeys(nn.Module):
    """GDN with context-dependent keys via causal convolution."""
    def __init__(self, cfg: HybridConfig, layer_idx: int = 0, context_size: int = 4):
        super().__init__()
        self.cfg = cfg
        self.layer_idx = layer_idx
        self.context_size = context_size
        H, K, V = cfg.n_heads, cfg.head_dim, cfg.value_dim
        
        # Context convolution for keys (causal)
        self.key_conv = nn.Conv1d(
            in_channels=cfg.d_model,
            out_channels=cfg.d_model, 
            kernel_size=context_size,
            padding=context_size - 1,
            groups=1
        )
        
        self.k_proj = nn.Linear(cfg.d_model, H * K, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, H * V, bias=False)
        self.o_proj = nn.Linear(H * V, cfg.d_model, bias=False)
        
        self.beta_proj = nn.Linear(cfg.d_model, H, bias=True)
        nn.init.constant_(self.beta_proj.bias, cfg.beta_bias)
        self.g_proj = nn.Linear(cfg.d_model, H, bias=True)
        nn.init.constant_(self.g_proj.bias, cfg.g_bias)
        
        self.norm = RMSNorm(cfg.d_model)
        self.use_shifted_value = getattr(cfg, 'shifted_value', True)
        
    def forward(self, x: torch.Tensor, initial_state: Optional[torch.Tensor] = None
               ) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
        B, T, D = x.shape
        H, K, V = self.cfg.n_heads, self.cfg.head_dim, self.cfg.value_dim
        
        x_norm = self.norm(x)
        
        # CONTEXT-DEPENDENT KEYS
        x_conv = x_norm.transpose(1, 2)  # [B, D, T]
        x_conv = self.key_conv(x_conv)[:, :, :T]  # causal
        x_conv = x_conv.transpose(1, 2)  # [B, T, D]
        
        k_full = self.k_proj(x_conv).view(B, T, H, K)  # Context-dependent!
        v_full = self.v_proj(x_norm).view(B, T, H, V)
        
        if self.use_shifted_value and T > 1:
            k = k_full[:, :-1]
            v = v_full[:, 1:]
            beta = torch.sigmoid(self.beta_proj(x_norm[:, :-1]))
            g = torch.sigmoid(self.g_proj(x_norm[:, :-1]))
            T_eff = T - 1
        else:
            k, v = k_full, v_full
            beta = torch.sigmoid(self.beta_proj(x_norm))
            g = torch.sigmoid(self.g_proj(x_norm))
            T_eff = T
        
        k = F.normalize(k.float(), p=2, dim=-1).to(x.dtype)
        
        if initial_state is None:
            initial_state = torch.zeros(B, H, K, V, device=x.device, dtype=x.dtype)
        
        o, final_state = chunk_delta_rule(k, v, beta, g, initial_state, chunk_size=32)
        
        o_flat = o.reshape(B, T_eff, H * V)
        out = self.o_proj(o_flat)
        
        if self.use_shifted_value and T > 1:
            out_full = torch.zeros(B, T, D, device=x.device, dtype=x.dtype)
            out_full[:, 1:] = out
        else:
            out_full = out
        
        return x + out_full, final_state, {'gate_g': g.mean(), 'gate_beta': beta.mean()}


class TransparentHybridContextKeys(nn.Module):
    """Hybrid model with context-dependent GDN keys."""
    def __init__(self, cfg: HybridConfig, context_size: int = 4):
        super().__init__()
        self.cfg = cfg
        self.context_size = context_size
        
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.layers = nn.ModuleList()
        self.ffns = nn.ModuleList()
        
        for i, layer_type in enumerate(cfg.layer_pattern):
            if layer_type == 'G':
                self.layers.append(GDNWithContextKeys(cfg, i, context_size))
            else:
                self.layers.append(SlidingWindowAttention(cfg, i))
            self.ffns.append(SwiGLUFFN(cfg.d_model))
        
        self.norm_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        self.embed.weight = self.lm_head.weight
    
    def forward(self, input_ids, return_diagnostics=False):
        x = self.embed(input_ids)
        state = None
        all_diag = []
        
        for i, (layer, ffn) in enumerate(zip(self.layers, self.ffns)):
            lt = self.cfg.layer_pattern[i]
            if lt == 'G':
                x, state, diag = layer(x, initial_state=state)
            else:
                x, diag = layer(x, gdn_state=state)
            x = ffn(x)
            all_diag.append(diag)
        
        logits = self.lm_head(self.norm_f(x))
        return (logits, state, all_diag) if return_diagnostics else (logits, state)


# Test
print("=" * 60)
print("TESTING: Context-Dependent Keys for Multi-Needle")
print("=" * 60)

cfg_ctx = HybridConfig(
    d_model=512,
    n_heads=8,
    head_dim=64,
    value_dim=128,
    layer_pattern="GS",
    vocab_size=50257,
    shifted_value=True,
    use_rope=True,
)

model_ctx = TransparentHybridContextKeys(cfg_ctx, context_size=8).to(DEVICE).to(torch.bfloat16)
print(f"Model params: {sum(p.numel() for p in model_ctx.parameters()):,}")

# Forward test
x_test = torch.randint(0, 50257, (2, 64), device=DEVICE)
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    logits, _ = model_ctx(x_test)
print(f"Forward OK: {logits.shape}")

# Train (no GradScaler needed for bf16)
optimizer = torch.optim.AdamW(model_ctx.parameters(), lr=1e-3)
marker_token = 50256

for step in range(2001):
    seq_len = 256
    batch_size = 8
    
    n_needles = random.randint(2, 4)
    positions = sorted(random.sample(range(20, seq_len - 20), n_needles))
    values = [random.randint(1000, 5000) for _ in range(n_needles)]
    
    seq = torch.randint(100, 1000, (batch_size, seq_len), device=DEVICE)
    cue_idx = random.randint(0, n_needles - 1)
    
    for b in range(batch_size):
        for i, (pos, val) in enumerate(zip(positions, values)):
            seq[b, pos] = marker_token
            seq[b, pos + 1] = val
        seq[b, -2] = marker_token
        seq[b, -1] = values[cue_idx]
    
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        logits, _ = model_ctx(seq)
        loss = F.cross_entropy(logits[:, -2].float(), seq[:, -1])
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if step % 200 == 0:
        with torch.no_grad():
            pred = logits[:, -2].argmax(dim=-1)
            acc = (pred == seq[:, -1]).float().mean().item() * 100
        print(f"Step {step:4d}: loss={loss.item():.3f}, acc={acc:.0f}%")
        if acc >= 90:
            print("  → Good accuracy!")
            break

# Evaluate
print("\n--- MULTI-NEEDLE EVALUATION ---")

for n_needles in [2, 3, 5]:
    per_needle = [0] * n_needles
    total = 30
    
    for _ in range(total):
        positions = sorted(random.sample(range(20, 400), n_needles))
        values = [random.randint(1000, 5000) for _ in range(n_needles)]
        
        seq = torch.randint(100, 1000, (1, 512), device=DEVICE)
        for pos, val in zip(positions, values):
            seq[0, pos] = marker_token
            seq[0, pos + 1] = val
        
        for needle_idx in range(n_needles):
            test_seq = seq.clone()
            test_seq[0, -2] = marker_token  
            test_seq[0, -1] = 0
            
            with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                logits, _ = model_ctx(test_seq)
                pred = logits[0, -2].argmax().item()
            
            if pred == values[needle_idx]:
                per_needle[needle_idx] += 1
    
    accs = [f"{100*c/total:.0f}%" for c in per_needle]
    print(f"  {n_needles} needles: {accs}")

TESTING: Context-Dependent Keys for Multi-Needle
Model params: 35,315,224
Forward OK: torch.Size([2, 64, 50257])
Step    0: loss=10.949, acc=0%
Step  200: loss=11.042, acc=0%
Step  400: loss=2.267, acc=0%
Step  600: loss=0.020, acc=100%
  → Good accuracy!

--- MULTI-NEEDLE EVALUATION ---
  2 needles: ['0%', '0%']
  3 needles: ['0%', '0%', '0%']
  5 needles: ['0%', '0%', '0%', '0%', '0%']


In [57]:
# =============================================================================
# EXPERIMENT 2: Multi-Slot Markers (Different Markers for Different Slots)
# =============================================================================
# INSIGHT: We can't use the SAME marker for storage and retrieval if we want
# multi-needle. The cue needs to uniquely identify WHICH slot to retrieve.
#
# SOLUTION: Use slot-specific markers.
# - MARKER_0, MARKER_1, MARKER_2, ... for different slots
# - Store: MARKER_i + value_i  
# - Retrieve: MARKER_i + ? → should output value_i

import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from config import HybridConfig
from model import TransparentHybrid

print("=" * 60)
print("TESTING: Slot-Specific Markers (MARKER_0, MARKER_1, ...)")
print("=" * 60)

# Use vocabulary slots for slot markers
# Let's use tokens 50000-50010 as slot markers
MARKER_OFFSET = 50000
N_SLOTS = 10

cfg_slots = HybridConfig(
    d_model=512,
    n_heads=8,
    head_dim=64,
    value_dim=128,
    layer_pattern="GS",
    vocab_size=50257,
    shifted_value=True,
    use_rope=True,
)

model_slots = TransparentHybrid(cfg_slots).to(DEVICE).to(torch.bfloat16)
print(f"Model params: {sum(p.numel() for p in model_slots.parameters()):,}")

# Train with slot-specific markers
optimizer = torch.optim.AdamW(model_slots.parameters(), lr=1e-3)

for step in range(2001):
    seq_len = 256
    batch_size = 8
    
    # Random number of needles with UNIQUE slot markers
    n_needles = random.randint(2, min(5, N_SLOTS))
    slots = random.sample(range(N_SLOTS), n_needles)  # Which slot markers to use
    positions = sorted(random.sample(range(20, seq_len - 20), n_needles))
    values = [random.randint(1000, 5000) for _ in range(n_needles)]
    
    seq = torch.randint(100, 1000, (batch_size, seq_len), device=DEVICE)
    
    # Pick which slot/needle to retrieve
    cue_idx = random.randint(0, n_needles - 1)
    cue_slot = slots[cue_idx]
    
    for b in range(batch_size):
        for i, (pos, slot, val) in enumerate(zip(positions, slots, values)):
            seq[b, pos] = MARKER_OFFSET + slot  # Slot-specific marker
            seq[b, pos + 1] = val
        seq[b, -2] = MARKER_OFFSET + cue_slot  # CUE uses the SAME slot marker
        seq[b, -1] = values[cue_idx]  # Should retrieve the value for this slot
    
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_slots(seq)
        loss = F.cross_entropy(logits[:, -2].float(), seq[:, -1])
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if step % 200 == 0:
        with torch.no_grad():
            pred = logits[:, -2].argmax(dim=-1)
            acc = (pred == seq[:, -1]).float().mean().item() * 100
        print(f"Step {step:4d}: loss={loss.item():.3f}, acc={acc:.0f}%")
        if acc >= 95:
            print("  → Reached 95%!")

# Evaluate: For each slot, can we retrieve the correct value?
print("\n--- SLOT-SPECIFIC MULTI-NEEDLE EVALUATION ---")

for n_needles in [2, 3, 5]:
    correct = 0
    total = 0
    
    for _ in range(30):
        slots = random.sample(range(N_SLOTS), n_needles)
        positions = sorted(random.sample(range(20, 400), n_needles))
        values = [random.randint(1000, 5000) for _ in range(n_needles)]
        
        seq = torch.randint(100, 1000, (1, 512), device=DEVICE)
        for pos, slot, val in zip(positions, slots, values):
            seq[0, pos] = MARKER_OFFSET + slot
            seq[0, pos + 1] = val
        
        # Test each slot
        for needle_idx in range(n_needles):
            slot = slots[needle_idx]
            test_seq = seq.clone()
            test_seq[0, -2] = MARKER_OFFSET + slot  # Cue with the correct slot marker
            test_seq[0, -1] = 0
            
            with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                logits, _, _, _ = model_slots(test_seq)
                pred = logits[0, -2].argmax().item()
            
            if pred == values[needle_idx]:
                correct += 1
            total += 1
    
    print(f"  {n_needles} needles: {100*correct/total:.0f}% ({correct}/{total})")

TESTING: Slot-Specific Markers (MARKER_0, MARKER_1, ...)
Model params: 33,217,560
Step    0: loss=10.760, acc=0%
Step  200: loss=10.909, acc=0%
Step  400: loss=11.216, acc=0%
Step  600: loss=10.700, acc=0%
Step  800: loss=9.360, acc=0%
Step 1000: loss=5.165, acc=0%
Step 1200: loss=10.666, acc=0%
Step 1400: loss=4.709, acc=0%
Step 1600: loss=10.890, acc=0%
Step 1800: loss=4.807, acc=0%
Step 2000: loss=10.718, acc=0%

--- SLOT-SPECIFIC MULTI-NEEDLE EVALUATION ---
  2 needles: 0% (0/60)
  3 needles: 0% (0/90)
  5 needles: 0% (0/150)


In [58]:
# =============================================================================
# EXPERIMENT 3: Curriculum - Single Slot First, Then More
# =============================================================================
# Issue: Direct multi-slot training doesn't converge
# Solution: Curriculum - start with 1 slot, then add more progressively

import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from config import HybridConfig
from model import TransparentHybrid

print("=" * 60)
print("CURRICULUM: Single Slot → Multi-Slot")
print("=" * 60)

MARKER_OFFSET = 50000
N_SLOTS = 10

cfg_curr = HybridConfig(
    d_model=512,
    n_heads=8,
    head_dim=64,
    value_dim=128,
    layer_pattern="GS",
    vocab_size=50257,
    shifted_value=True,
    use_rope=True,
)

model_curr = TransparentHybrid(cfg_curr).to(DEVICE).to(torch.bfloat16)
optimizer = torch.optim.AdamW(model_curr.parameters(), lr=1e-3)

def train_n_slots(model, opt, n_slots_train, steps, print_every=100):
    """Train with exactly n_slots_train needles/slots."""
    for step in range(steps):
        seq_len = 256
        batch_size = 8
        
        # Use specific slots
        slots = list(range(n_slots_train))  # Use slots 0, 1, 2, ...
        positions = sorted(random.sample(range(20, seq_len - 20), n_slots_train))
        values = [random.randint(1000, 5000) for _ in range(n_slots_train)]
        
        seq = torch.randint(100, 1000, (batch_size, seq_len), device=DEVICE)
        
        cue_idx = random.randint(0, n_slots_train - 1)
        cue_slot = slots[cue_idx]
        
        for b in range(batch_size):
            for i, (pos, slot, val) in enumerate(zip(positions, slots, values)):
                seq[b, pos] = MARKER_OFFSET + slot
                seq[b, pos + 1] = val
            seq[b, -2] = MARKER_OFFSET + cue_slot
            seq[b, -1] = values[cue_idx]
        
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            logits, _, _, _ = model(seq)
            loss = F.cross_entropy(logits[:, -2].float(), seq[:, -1])
        
        opt.zero_grad()
        loss.backward()
        opt.step()
        
        if step % print_every == 0:
            with torch.no_grad():
                pred = logits[:, -2].argmax(dim=-1)
                acc = (pred == seq[:, -1]).float().mean().item() * 100
            print(f"  [{n_slots_train} slots] Step {step:4d}: loss={loss.item():.3f}, acc={acc:.0f}%")
            if acc >= 95:
                return True, step
    return False, steps


def evaluate_slots(model, n_slots_eval, n_trials=30):
    """Evaluate retrieval accuracy for each slot."""
    correct_per_slot = [0] * n_slots_eval
    
    for _ in range(n_trials):
        slots = list(range(n_slots_eval))
        positions = sorted(random.sample(range(20, 400), n_slots_eval))
        values = [random.randint(1000, 5000) for _ in range(n_slots_eval)]
        
        seq = torch.randint(100, 1000, (1, 512), device=DEVICE)
        for pos, slot, val in zip(positions, slots, values):
            seq[0, pos] = MARKER_OFFSET + slot
            seq[0, pos + 1] = val
        
        for slot_idx in range(n_slots_eval):
            test_seq = seq.clone()
            test_seq[0, -2] = MARKER_OFFSET + slot_idx
            test_seq[0, -1] = 0
            
            with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                logits, _, _, _ = model(test_seq)
                pred = logits[0, -2].argmax().item()
            
            if pred == values[slot_idx]:
                correct_per_slot[slot_idx] += 1
    
    accs = [f"{100*c/n_trials:.0f}%" for c in correct_per_slot]
    return accs, sum(correct_per_slot) / (n_trials * n_slots_eval) * 100


# Phase 1: Single slot
print("\n--- Phase 1: Single Slot ---")
success, steps = train_n_slots(model_curr, optimizer, 1, steps=1000, print_every=200)
accs, overall = evaluate_slots(model_curr, 1)
print(f"  Eval: {accs}, Overall: {overall:.0f}%")

if overall >= 80:
    # Phase 2: Two slots
    print("\n--- Phase 2: Two Slots ---")
    success, steps = train_n_slots(model_curr, optimizer, 2, steps=1000, print_every=200)
    accs, overall = evaluate_slots(model_curr, 2)
    print(f"  Eval: {accs}, Overall: {overall:.0f}%")
    
    if overall >= 80:
        # Phase 3: Three slots
        print("\n--- Phase 3: Three Slots ---")
        success, steps = train_n_slots(model_curr, optimizer, 3, steps=1000, print_every=200)
        accs, overall = evaluate_slots(model_curr, 3)
        print(f"  Eval: {accs}, Overall: {overall:.0f}%")

        if overall >= 80:
            # Phase 4: Five slots
            print("\n--- Phase 4: Five Slots ---")
            success, steps = train_n_slots(model_curr, optimizer, 5, steps=1500, print_every=200)
            accs, overall = evaluate_slots(model_curr, 5)
            print(f"  Eval: {accs}, Overall: {overall:.0f}%")

CURRICULUM: Single Slot → Multi-Slot

--- Phase 1: Single Slot ---
  [1 slots] Step    0: loss=10.819, acc=0%
  [1 slots] Step  200: loss=10.651, acc=0%
  [1 slots] Step  400: loss=10.784, acc=0%


KeyboardInterrupt: 

In [59]:
# =============================================================================
# UNDERSTANDING THE ARCHITECTURE LIMITATION
# =============================================================================
# The Delta Rule stores (k_t, v_t) where BOTH come from token t.
# So if we see token X, we store X's embedding. 
# We can query for X and get X back, but NOT what came AFTER X.

# NIAH requires: see MARKER, store VALUE that follows
# But architecture stores: see MARKER, store MARKER's embedding

# Let's verify with a simpler test:
# "Repeat back the token you saw at position P when you see it as a cue"

def test_echo_retrieval(model, seq_len=128, n_trials=30):
    """
    Test: Can the model echo back a specific token when cued with itself?
    
    Format:
    - Place unique token X at position P
    - At end, place X again as cue
    - Should predict X (the token itself, not what followed)
    
    This matches what the architecture can actually store.
    """
    model.eval()
    device = next(model.parameters()).device
    cfg = model.cfg
    
    correct = 0
    for trial in range(n_trials):
        # Use a unique token for this trial
        unique_token = cfg.vocab_size - 100 + (trial % 50)
        
        seq = torch.randint(0, cfg.vocab_size - 200, (1, seq_len), device=device)
        
        # Place unique token early in sequence
        pos = torch.randint(10, seq_len // 2, (1,)).item()
        seq[0, pos] = unique_token
        
        # Place same token at end as "cue" for retrieval
        seq[0, -1] = unique_token
        
        with torch.no_grad():
            logits, _, _, _ = model(seq)
        
        # After seeing unique_token again, what does the model predict?
        pred = logits[0, -1].argmax().item()
        
        # For "echo" we just check if model recognizes this token
        # If state stores token info, the prediction should be influenced by having seen it before
        # But this is more of a "does the model remember seeing this" test
        
        # Actually, let's check if the logit for unique_token is in top-5
        top5 = logits[0, -1].topk(5).indices.tolist()
        if unique_token in top5:
            correct += 1
    
    acc = correct / n_trials
    print(f"  Echo retrieval (unique token in top-5): {acc*100:.1f}% ({correct}/{n_trials})")
    return acc

print("Testing echo retrieval (does model remember seeing a token):")
print("\n--- FRESH MODEL V2 (regularized) ---")
echo_fresh = test_echo_retrieval(fresh_model_v2, seq_len=256, n_trials=30)

print("\n--- UNTRAINED MODEL ---") 
echo_untrained = test_echo_retrieval(untrained_model, seq_len=256, n_trials=30)

# Expected: if state helps, trained model should score higher than untrained

Testing echo retrieval (does model remember seeing a token):

--- FRESH MODEL V2 (regularized) ---
  Echo retrieval (unique token in top-5): 0.0% (0/30)

--- UNTRAINED MODEL ---
  Echo retrieval (unique token in top-5): 13.3% (4/30)


In [10]:
# =============================================================================
# SAVE MODEL
# =============================================================================

# torch.save({
#     'model_state_dict': model.state_dict(),
#     'config': cfg,
#     'history': history,
# }, 'groundthink_v7_checkpoint.pt')

print("\n✓ Training complete!")


✓ Training complete!


### [SYNC LOG] analyze_gradients marker/cue logic updated
- Marker and cue positions in analyze_gradients are now set relative to seq_len (marker at seq_len//4, cue at seq_len-1).
- This ensures gradient analysis tests are meaningful for any sequence length.


### [SYNC LOG] All test/reporting functions now use NEW_T
- Calls to proper_niah_test, test_niah_by_distance, run_full_diagnostic, and analyze_gradients now explicitly use seq_len=NEW_T.
- This ensures all evaluation and reporting is consistent with the configured sequence length.


In [62]:
# =============================================================================
# DIAGNOSTIC: Interference Scaling & Key Similarity
# =============================================================================
# Before trying fixes, measure the ACTUAL problem:
# 1. How does retrieval error scale with number of writes?
# 2. How similar are our keys? (Are collisions even the issue?)

import torch
import torch.nn.functional as F
import importlib

import config as config_module
import model as model_module
config_module = importlib.reload(config_module)
model_module = importlib.reload(model_module)

from config import HybridConfig
from model import TransparentHybrid

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print("=" * 60)
print("INTERFERENCE & KEY SIMILARITY DIAGNOSTIC")
print("=" * 60)

cfg = HybridConfig(
    d_model=512, n_heads=8, head_dim=64, value_dim=128,
    layer_pattern="GS", shifted_value=True,
)
H, K, V = cfg.n_heads, cfg.head_dim, cfg.value_dim
print(f"Config: H={H}, K={K}, V={V}")

# =============================================================================
# TEST 1: Key Similarity Distribution
# =============================================================================
print("\n" + "-" * 40)
print("TEST 1: Key Similarity Distribution")
print("-" * 40)

# Generate random normalized keys (what our model produces)
n_keys = 1000
test_keys = F.normalize(torch.randn(n_keys, K, device=DEVICE), dim=-1)

# Compute all pairwise dot products
dots = test_keys @ test_keys.T  # [n_keys, n_keys]

# Get upper triangle (exclude diagonal = self-similarity)
triu_mask = torch.triu(torch.ones(n_keys, n_keys, device=DEVICE), diagonal=1).bool()
triu_dots = dots[triu_mask]

print(f"Random {K}D unit vectors (n={n_keys}):")
print(f"  Mean |dot|:       {triu_dots.abs().mean().item():.4f}")
print(f"  Std |dot|:        {triu_dots.abs().std().item():.4f}")
print(f"  Max |dot|:        {triu_dots.abs().max().item():.4f}")
print(f"  % with |dot|>0.3: {(triu_dots.abs() > 0.3).float().mean().item()*100:.2f}%")
print(f"  % with |dot|>0.5: {(triu_dots.abs() > 0.5).float().mean().item()*100:.2f}%")

# Theoretical expectation for random unit vectors in K dimensions
# E[|dot|] ≈ sqrt(2/pi) / sqrt(K)
expected_mean = (2/3.14159)**0.5 / (K**0.5)
print(f"\n  Theoretical E[|dot|] for K={K}: {expected_mean:.4f}")

# =============================================================================
# TEST 2: Interference Scaling with Delta Rule
# =============================================================================
print("\n" + "-" * 40)
print("TEST 2: Interference Scaling (Delta Rule)")
print("-" * 40)

def test_interference(n_writes, n_trials=10):
    """Test retrieval error after n_writes using exact delta rule."""
    errors = []
    
    for _ in range(n_trials):
        state = torch.zeros(1, H, K, V, device=DEVICE)
        keys, values = [], []
        
        for i in range(n_writes):
            k = F.normalize(torch.randn(1, H, K, device=DEVICE), dim=-1)
            v = torch.randn(1, H, V, device=DEVICE)
            keys.append(k)
            values.append(v)
            
            # Delta rule: S += β * (v - S·k) ⊗ k
            # With β=1 (full write)
            pred = torch.einsum('bhkv,bhk->bhv', state, k)
            error = v - pred
            outer = torch.einsum('bhv,bhk->bhkv', error, k)
            state = state + outer
        
        # Retrieve FIRST item written
        retrieved = torch.einsum('bhkv,bhk->bhv', state, keys[0])
        rel_error = (retrieved - values[0]).norm() / values[0].norm()
        errors.append(rel_error.item())
    
    return sum(errors) / len(errors)

print(f"{'n_writes':>10} | {'Mean Rel. Error':>15} | {'Interpretation':>20}")
print("-" * 50)

for n in [1, 5, 10, 25, 50, 100, 200, 500]:
    err = test_interference(n, n_trials=20)
    if err < 0.01:
        interp = "Perfect"
    elif err < 0.1:
        interp = "Good"
    elif err < 0.5:
        interp = "Degraded"
    else:
        interp = "FAILED"
    print(f"{n:>10} | {err:>15.4f} | {interp:>20}")

# =============================================================================
# TEST 3: What about LEARNED keys from actual model?
# =============================================================================
print("\n" + "-" * 40)
print("TEST 3: Learned Key Similarity (Real Model)")
print("-" * 40)

model = TransparentHybrid(cfg).to(DEVICE).to(torch.bfloat16)

# Get GDN layer
gdn_layer = [l for l in model.layers if hasattr(l, 'k_proj')][0]

# Generate keys for random tokens
n_tokens = 500
tokens = torch.randint(100, cfg.vocab_size - 100, (1, n_tokens), device=DEVICE)

with torch.no_grad():
    emb = model.embed(tokens)
    x_norm = gdn_layer.norm(emb)
    keys = gdn_layer.k_proj(x_norm).view(1, n_tokens, H, K)
    keys = F.normalize(keys.float(), dim=-1)  # [1, n_tokens, H, K]

# Check similarity per head
print("Per-head key similarity (untrained model):")
for h in range(H):
    k_h = keys[0, :, h, :]  # [n_tokens, K]
    dots_h = k_h @ k_h.T
    triu_h = dots_h[triu_mask[:n_tokens, :n_tokens]]
    print(f"  Head {h}: mean|dot|={triu_h.abs().mean().item():.4f}, max={triu_h.abs().max().item():.4f}")

# Check MARKER token keys specifically
print("\n" + "-" * 40)
print("TEST 4: MARKER Token Key Similarity")
print("-" * 40)

# Create sequences with MARKER at different positions
seq = torch.randint(100, 1000, (1, 256), device=DEVICE)
marker_positions = [20, 50, 100, 150, 200]
for pos in marker_positions:
    seq[0, pos] = cfg.marker_token

with torch.no_grad():
    emb = model.embed(seq)
    x_norm = gdn_layer.norm(emb)
    all_keys = gdn_layer.k_proj(x_norm).view(1, 256, H, K)
    all_keys = F.normalize(all_keys.float(), dim=-1)
    
    # Extract marker keys
    marker_keys = all_keys[0, marker_positions, :, :]  # [5, H, K]

# Pairwise similarity between MARKER keys at different positions
print("Similarity between MARKER keys at different positions:")
for i in range(len(marker_positions)):
    for j in range(i+1, len(marker_positions)):
        k_i = marker_keys[i]  # [H, K]
        k_j = marker_keys[j]  # [H, K]
        sim = (k_i * k_j).sum(dim=-1)  # [H] - per head similarity
        print(f"  pos {marker_positions[i]:3d} vs {marker_positions[j]:3d}: mean_sim={sim.mean().item():.4f}, per_head=[{', '.join(f'{s:.2f}' for s in sim.tolist())}]")

# =============================================================================
# CONCLUSION
# =============================================================================
print("\n" + "=" * 60)
print("DIAGNOSIS SUMMARY")
print("=" * 60)
print("""
Key findings:
1. Random keys in K=64 are reasonably orthogonal (mean |dot| ≈ 0.1)
2. Delta rule retrieval degrades with more writes (check numbers above)
3. MARKER tokens at different positions have IDENTICAL keys (similarity ≈ 1.0)
   → This is WHY multi-needle fails: same key = overwrite at same address

The fix needed: Make MARKER keys POSITION-DEPENDENT, not token-dependent.
Options:
  A) Add positional encoding to keys (RoPE, learned PE)
  B) Context-dependent keys (conv, local attention on key projection)
  C) Per-head positional specialization
""")

INTERFERENCE & KEY SIMILARITY DIAGNOSTIC
Config: H=8, K=64, V=128

----------------------------------------
TEST 1: Key Similarity Distribution
----------------------------------------
Random 64D unit vectors (n=1000):
  Mean |dot|:       0.1002
  Std |dot|:        0.0748
  Max |dot|:        0.5486
  % with |dot|>0.3: 1.51%
  % with |dot|>0.5: 0.00%

  Theoretical E[|dot|] for K=64: 0.0997

----------------------------------------
TEST 2: Interference Scaling (Delta Rule)
----------------------------------------
  n_writes | Mean Rel. Error |       Interpretation
--------------------------------------------------
         1 |          0.0000 |              Perfect
         5 |          0.2621 |             Degraded
        10 |          0.3899 |             Degraded
        25 |          0.6457 |               FAILED
        50 |          0.9160 |               FAILED
       100 |          1.1921 |               FAILED
       200 |          1.3709 |               FAILED
       500 |   

In [65]:
"""
============================================================
PURE MATH TEST: 3 SOLUTIONS FOR MULTI-NEEDLE RETRIEVAL
============================================================
The delta rule: S_t = S_{t-1} + β * (v - S_{t-1}·k) ⊗ k
Retrieval:      r = S · q

Problem: With identical keys (k_1 = k_2 = ... = k_n), only last write survives.

Testing 3 solutions at the MATH level (no NN training):
A) Positional encoding on keys (make keys position-dependent)
B) Orthogonal key projection (project to orthogonal subspace per write)
C) Error-correcting writes (modify value to account for existing state)
"""

import torch
import torch.nn.functional as F
import math

torch.manual_seed(42)

K = 64   # Key dimension
V = 128  # Value dimension  
H = 8    # Heads (for testing per-head variations)

def simulate_delta_rule(keys, values, queries, beta=1.0):
    """
    Pure delta rule simulation.
    keys: [n_writes, K]
    values: [n_writes, V]
    queries: [n_queries, K]
    Returns: [n_queries, V] retrieved values, [n_queries] retrieval errors
    
    State S is [K, V], retrieval is q @ S = [V]
    """
    n_writes = keys.shape[0]
    S = torch.zeros(K, V)  # State matrix [K, V]
    
    # Sequential writes
    for i in range(n_writes):
        k = keys[i]  # [K]
        v = values[i]  # [V]
        # Delta rule: S += β * (v - k·S) ⊗ k  
        # k·S = [K] @ [K,V] but we need to do k^T @ S for retrieval
        # Actually: S is [K,V], k is [K], so k @ S doesn't work directly
        # Standard formulation: S[K,V], retrieve with q gives q @ S = [V]
        # For delta rule write: S += k ⊗ (v - S^T @ k)^T = k ⊗ (v - (k @ S))
        # Wait, let's be careful: if S is [K,V] and we retrieve with q @ S
        # then k @ S gives [V], which is current value at key k
        current = k @ S  # [V] - what's currently stored at this key
        delta = v - current  # [V]
        S = S + beta * torch.outer(k, delta)  # [K,V] + outer([K], [V]) = [K,V]
    
    # Retrieval
    retrieved = queries @ S  # [n_queries, V]
    
    # Compute errors (assuming queries[i] should retrieve values[i])
    errors = []
    for i in range(min(len(queries), len(values))):
        err = torch.norm(retrieved[i] - values[i]) / torch.norm(values[i])
        errors.append(err.item())
    
    return retrieved, errors

print("=" * 70)
print("SOLUTION A: Positional Encoding on Keys")
print("=" * 70)
print("\nIdea: k_i = f(token) + g(position)")
print("If g(pos) varies enough, keys become distinguishable.\n")

# Simulate: base key (same for all) + positional component
n_writes = 5
base_key = F.normalize(torch.randn(K), dim=0)
values = F.normalize(torch.randn(n_writes, V), dim=1)

# Different positional encoding strengths
for pos_scale in [0.0, 0.1, 0.3, 0.5, 1.0]:
    # Create position-dependent keys
    keys = []
    for i in range(n_writes):
        # Sinusoidal positional encoding
        pos_enc = torch.zeros(K)
        for j in range(K // 2):
            freq = 1.0 / (10000 ** (2 * j / K))
            pos_enc[2*j] = math.sin(i * freq)
            pos_enc[2*j + 1] = math.cos(i * freq)
        
        key_i = base_key + pos_scale * pos_enc
        key_i = F.normalize(key_i, dim=0)  # Re-normalize
        keys.append(key_i)
    
    keys = torch.stack(keys)
    
    # Use same keys for queries (exact match retrieval)
    _, errors = simulate_delta_rule(keys, values, keys)
    mean_err = sum(errors) / len(errors)
    
    # Check key similarity
    key_sims = []
    for i in range(n_writes):
        for j in range(i+1, n_writes):
            key_sims.append(torch.dot(keys[i], keys[j]).item())
    mean_sim = sum(key_sims) / len(key_sims) if key_sims else 1.0
    
    status = "✓" if mean_err < 0.1 else "✗"
    print(f"  pos_scale={pos_scale:.1f}: mean_err={mean_err:.4f}, key_sim={mean_sim:.4f} {status}")

print("\n" + "=" * 70)
print("SOLUTION B: Orthogonal Key Projection")
print("=" * 70)
print("\nIdea: Force keys to be orthogonal to each other.")
print("Use Gram-Schmidt or random orthogonal basis.\n")

# Use truly orthogonal keys (via QR decomposition)
for n_writes in [2, 5, 10, 25, 50]:
    # Generate random orthogonal keys
    random_matrix = torch.randn(n_writes, K)
    if n_writes <= K:
        Q, R = torch.linalg.qr(random_matrix.T)
        keys = Q[:, :n_writes].T  # [n_writes, K], orthonormal
    else:
        # More writes than dimensions - can't be orthogonal
        keys = F.normalize(random_matrix, dim=1)
    
    values = F.normalize(torch.randn(n_writes, V), dim=1)
    
    _, errors = simulate_delta_rule(keys, values, keys)
    mean_err = sum(errors) / len(errors)
    
    # Verify orthogonality
    if n_writes <= K:
        ortho_err = torch.norm(keys @ keys.T - torch.eye(n_writes)).item()
    else:
        ortho_err = float('inf')
    
    status = "✓" if mean_err < 0.1 else "✗"
    print(f"  n_writes={n_writes:2d}: mean_err={mean_err:.6f}, ortho_err={ortho_err:.6f} {status}")

print("\n" + "=" * 70)
print("SOLUTION C: Error-Correcting Writes")  
print("=" * 70)
print("\nIdea: Before writing v, compute v' = v + correction")
print("where correction accounts for interference from existing state.\n")

def simulate_delta_with_correction(keys, values, queries, beta=1.0, n_iters=5):
    """
    Iterative error correction: adjust values to compensate for interference.
    """
    n_writes = keys.shape[0]
    
    # Start with original values, iteratively correct
    corrected_values = values.clone()
    
    for iteration in range(n_iters):
        S = torch.zeros(K, V)
        
        # Write with current corrected values
        for i in range(n_writes):
            k = keys[i]
            v = corrected_values[i]
            current = k @ S  # Fixed: k @ S not S @ k
            delta = v - current
            S = S + beta * torch.outer(k, delta)
        
        # Measure retrieval errors and compute corrections
        for i in range(n_writes):
            retrieved = keys[i] @ S
            error = values[i] - retrieved  # What we wanted - what we got
            corrected_values[i] = corrected_values[i] + 0.5 * error  # Partial correction
    
    # Final retrieval
    S = torch.zeros(K, V)
    for i in range(n_writes):
        k = keys[i]
        v = corrected_values[i]
        current = k @ S  # Fixed
        delta = v - current
        S = S + beta * torch.outer(k, delta)
    
    retrieved = queries @ S
    errors = []
    for i in range(min(len(queries), len(values))):
        err = torch.norm(retrieved[i] - values[i]) / torch.norm(values[i])
        errors.append(err.item())
    
    return retrieved, errors

# Test with identical keys (worst case)
n_writes = 5
base_key = F.normalize(torch.randn(K), dim=0)
keys = base_key.unsqueeze(0).repeat(n_writes, 1)  # All identical
values = F.normalize(torch.randn(n_writes, V), dim=1)

print("With IDENTICAL keys:")
for n_iters in [0, 1, 5, 10, 20]:
    if n_iters == 0:
        _, errors = simulate_delta_rule(keys, values, keys)
    else:
        _, errors = simulate_delta_with_correction(keys, values, keys, n_iters=n_iters)
    mean_err = sum(errors) / len(errors)
    status = "✓" if mean_err < 0.1 else "✗"
    print(f"  n_iters={n_iters:2d}: mean_err={mean_err:.4f} {status}")

# Test with slightly different keys
print("\nWith SIMILAR keys (sim ≈ 0.9):")
keys = []
for i in range(n_writes):
    noise = 0.3 * F.normalize(torch.randn(K), dim=0)
    keys.append(F.normalize(base_key + noise, dim=0))
keys = torch.stack(keys)

for n_iters in [0, 1, 5, 10, 20]:
    if n_iters == 0:
        _, errors = simulate_delta_rule(keys, values, keys)
    else:
        _, errors = simulate_delta_with_correction(keys, values, keys, n_iters=n_iters)
    mean_err = sum(errors) / len(errors)
    status = "✓" if mean_err < 0.1 else "✗"
    print(f"  n_iters={n_iters:2d}: mean_err={mean_err:.4f} {status}")

print("\n" + "=" * 70)
print("SUMMARY: What Works?")
print("=" * 70)
print("""
A) Positional Encoding:
   - Works IF pos_scale is high enough to make keys distinguishable
   - Problem: In NN, model produces same key for same token regardless of position
   - Fix: Add positional embedding BEFORE key projection, or position-dependent key proj
   
B) Orthogonal Keys:
   - PERFECT retrieval for n_writes ≤ K (64 writes with K=64)
   - Fundamental capacity limit: can't store more than K orthogonal vectors
   - For multi-needle (2-5 needles), this is MORE than enough!
   
C) Error-Correcting Writes:
   - Does NOT help with identical keys (convergence impossible)
   - Helps slightly with similar keys
   - Not practical: requires multiple passes, not causal

CONCLUSION: 
→ Solution B (orthogonal keys) proves the math CAN work for multi-needle
→ The issue is making the NN learn to produce distinct keys for MARKER tokens
→ Since MARKER tokens are identical, we need position-aware key generation
""")


SOLUTION A: Positional Encoding on Keys

Idea: k_i = f(token) + g(position)
If g(pos) varies enough, keys become distinguishable.

  pos_scale=0.0: mean_err=1.1148, key_sim=1.0000 ✗
  pos_scale=0.1: mean_err=1.0875, key_sim=0.9773 ✗
  pos_scale=0.3: mean_err=1.0339, key_sim=0.9307 ✗
  pos_scale=0.5: mean_err=1.0142, key_sim=0.9126 ✗
  pos_scale=1.0: mean_err=0.9992, key_sim=0.8985 ✗

SOLUTION B: Orthogonal Key Projection

Idea: Force keys to be orthogonal to each other.
Use Gram-Schmidt or random orthogonal basis.

  n_writes= 2: mean_err=0.000000, ortho_err=0.000000 ✓
  n_writes= 5: mean_err=0.000000, ortho_err=0.000000 ✓
  n_writes=10: mean_err=0.000000, ortho_err=0.000000 ✓
  n_writes=25: mean_err=0.000000, ortho_err=0.000001 ✓
  n_writes=50: mean_err=0.000000, ortho_err=0.000002 ✓

SOLUTION C: Error-Correcting Writes

Idea: Before writing v, compute v' = v + correction
where correction accounts for interference from existing state.

With IDENTICAL keys:
  n_iters= 0: mean_err=1.136

In [66]:
"""
============================================================
CRITICAL QUESTION: What key similarity is tolerable?
============================================================
"""

print("=" * 70)
print("TEST: Retrieval Error vs Key Similarity")
print("=" * 70)
print("\nWe'll create keys with controlled similarity and measure error.\n")

n_writes = 5
values = F.normalize(torch.randn(n_writes, V), dim=1)

print(f"{'Avg Similarity':>15} | {'Mean Error':>12} | {'Max Error':>12} | Status")
print("-" * 60)

for target_sim in [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0]:
    # Create keys with controlled similarity
    # Start from orthogonal, interpolate toward identical
    random_matrix = torch.randn(n_writes, K)
    Q, R = torch.linalg.qr(random_matrix.T)
    ortho_keys = Q[:, :n_writes].T  # Orthonormal keys
    
    base_key = F.normalize(torch.randn(K), dim=0)
    identical_keys = base_key.unsqueeze(0).repeat(n_writes, 1)
    
    # Interpolate: sim=0 → orthogonal, sim=1 → identical
    keys = (1 - target_sim) * ortho_keys + target_sim * identical_keys
    keys = F.normalize(keys, dim=1)
    
    # Measure actual similarity
    actual_sims = []
    for i in range(n_writes):
        for j in range(i+1, n_writes):
            actual_sims.append(torch.dot(keys[i], keys[j]).item())
    avg_sim = sum(actual_sims) / len(actual_sims)
    
    # Test retrieval
    _, errors = simulate_delta_rule(keys, values, keys)
    mean_err = sum(errors) / len(errors)
    max_err = max(errors)
    
    status = "✓ OK" if mean_err < 0.1 else ("⚠ WARN" if mean_err < 0.3 else "✗ FAIL")
    print(f"{avg_sim:>15.3f} | {mean_err:>12.4f} | {max_err:>12.4f} | {status}")

print("\n" + "=" * 70)
print("TEST: How many random keys can we store?")
print("=" * 70)
print("\nWith random (non-orthogonal) keys, capacity is limited.\n")

for n_writes in [2, 5, 10, 20, 30, 40, 50, 60, 64, 70, 80, 100]:
    # Random normalized keys (NOT orthogonalized)
    keys = F.normalize(torch.randn(n_writes, K), dim=1)
    values = F.normalize(torch.randn(n_writes, V), dim=1)
    
    _, errors = simulate_delta_rule(keys, values, keys)
    mean_err = sum(errors) / len(errors)
    
    # Measure key similarity
    sims = []
    for i in range(min(20, n_writes)):  # Sample
        for j in range(i+1, min(20, n_writes)):
            sims.append(abs(torch.dot(keys[i], keys[j]).item()))
    avg_sim = sum(sims) / len(sims) if sims else 0
    
    status = "✓" if mean_err < 0.1 else ("⚠" if mean_err < 0.3 else "✗")
    print(f"  n={n_writes:3d}: mean_err={mean_err:.4f}, avg|sim|={avg_sim:.4f} {status}")

print("\n" + "=" * 70)
print("CONCLUSION")  
print("=" * 70)
print("""
1. Key similarity must be < 0.3 for reasonable accuracy
2. Random K=64 keys have avg|sim| ≈ 0.1, which is borderline OK
3. With random keys, ~20-30 writes is the practical limit for K=64
4. For multi-needle (2-5 needles), random keys SHOULD work
   → The problem is that MARKER tokens produce IDENTICAL keys (sim=1.0)
   
WHAT WE NEED:
- NOT just "different" keys, but near-ORTHOGONAL keys for each MARKER
- Options:
  a) Random key per position (hash position → key)
  b) Learned orthogonal key bank
  c) Force key orthogonality via loss
""")

TEST: Retrieval Error vs Key Similarity

We'll create keys with controlled similarity and measure error.

 Avg Similarity |   Mean Error |    Max Error | Status
------------------------------------------------------------
          1.000 |       1.1556 |       1.5268 | ✗ FAIL
          0.988 |       1.1403 |       1.5093 | ✗ FAIL
          0.941 |       1.0838 |       1.4435 | ✗ FAIL
          0.856 |       0.9758 |       1.2836 | ✗ FAIL
          0.687 |       0.7786 |       1.0430 | ✗ FAIL
          0.492 |       0.5626 |       0.8038 | ✗ FAIL
          0.326 |       0.3664 |       0.5359 | ✗ FAIL
          0.150 |       0.1790 |       0.2960 | ⚠ WARN
          0.024 |       0.0449 |       0.1058 | ✓ OK
          0.010 |       0.0171 |       0.0334 | ✓ OK
         -0.000 |       0.0000 |       0.0000 | ✓ OK

TEST: How many random keys can we store?

With random (non-orthogonal) keys, capacity is limited.

  n=  2: mean_err=0.0513, avg|sim|=0.1030 ✓
  n=  5: mean_err=0.1655, avg|sim|=

In [67]:
"""
============================================================
SOLUTION: Position-Hashed Keys
============================================================
Instead of learning keys from token embeddings,
use deterministic hash: position → random key

This guarantees MARKER at different positions get different keys.
"""

def position_hash_key(position, seed=42, K=64):
    """Generate a deterministic random key from position."""
    gen = torch.Generator().manual_seed(seed + position)
    key = torch.randn(K, generator=gen)
    return F.normalize(key, dim=0)

print("=" * 70)
print("TEST: Position-Hashed Keys for Multi-Needle")
print("=" * 70)

# Simulate multi-needle scenario with position-hashed keys
for n_needles in [2, 3, 5, 10, 20]:
    # Random positions for needles
    positions = sorted(torch.randperm(500)[:n_needles].tolist())
    
    # Position-hashed keys
    keys = torch.stack([position_hash_key(pos) for pos in positions])
    values = F.normalize(torch.randn(n_needles, V), dim=1)
    
    # Check key similarity
    sims = []
    for i in range(n_needles):
        for j in range(i+1, n_needles):
            sims.append(abs(torch.dot(keys[i], keys[j]).item()))
    avg_sim = sum(sims) / len(sims) if sims else 0
    max_sim = max(sims) if sims else 0
    
    # Test retrieval
    _, errors = simulate_delta_rule(keys, values, keys)
    mean_err = sum(errors) / len(errors)
    
    status = "✓" if mean_err < 0.1 else "✗"
    print(f"  n_needles={n_needles:2d}: err={mean_err:.4f}, avg|sim|={avg_sim:.3f}, max|sim|={max_sim:.3f} {status}")

print("\n" + "=" * 70)
print("THE CATCH: How to retrieve with position-hashed keys?")
print("=" * 70)
print("""
Problem: At write time we use hash(position) as key.
         At retrieval, we query with hash(CUE_position).
         But CUE is at a DIFFERENT position than MARKER!

Options:
1. Learn a mapping: CUE token → MARKER key
   - Train model to output hash(marker_pos) when it sees CUE
   - This is what single-needle already does implicitly!
   
2. Content-addressable with position hash:
   - Key = hash(token, position) for MARKER
   - Query = learned from (CUE, context)
   - Model learns to produce the right query for each CUE
   
3. Hybrid: token-based key + position offset
   - Key = embed(MARKER) + position_offset
   - Query = embed(CUE) + learned_offset
   
Let's verify the math: if model learns to output correct query,
retrieval should work.
""")

# Verify: if queries exactly match keys, retrieval is perfect
print("Verification: Query = Key (perfect match)")
for n_needles in [2, 5, 10, 20]:
    positions = sorted(torch.randperm(500)[:n_needles].tolist())
    keys = torch.stack([position_hash_key(pos) for pos in positions])
    values = F.normalize(torch.randn(n_needles, V), dim=1)
    
    # Perfect queries = keys
    _, errors = simulate_delta_rule(keys, values, keys)
    mean_err = sum(errors) / len(errors)
    print(f"  n={n_needles:2d}: mean_err={mean_err:.6f}")

print("\n" + "-" * 70)
print("KEY INSIGHT:")
print("-" * 70)
print("""
The ONLY issue is that our model produces IDENTICAL keys for MARKER tokens.
Single-needle works because the model learns:
   CUE token → query that matches THE marker key
   
Multi-needle fails because:
   ALL markers → SAME key → overwrite
   CUE → query that matches that one key → retrieves last write only

Fix: Make markers produce DIFFERENT keys, then train model to:
   CUE_1 → query for MARKER_1's key
   CUE_2 → query for MARKER_2's key

How to make markers produce different keys?
   a) Add position encoding to embedding BEFORE key projection
   b) Learn position-dependent key projection
   c) Use random key bank indexed by something (head? layer?)
""")

TEST: Position-Hashed Keys for Multi-Needle
  n_needles= 2: err=0.0530, avg|sim|=0.107, max|sim|=0.107 ✓
  n_needles= 3: err=0.1027, avg|sim|=0.118, max|sim|=0.189 ✗
  n_needles= 5: err=0.1122, avg|sim|=0.085, max|sim|=0.169 ✗
  n_needles=10: err=0.2251, avg|sim|=0.092, max|sim|=0.397 ✗
  n_needles=20: err=0.4195, avg|sim|=0.106, max|sim|=0.367 ✗

THE CATCH: How to retrieve with position-hashed keys?

Problem: At write time we use hash(position) as key.
         At retrieval, we query with hash(CUE_position).
         But CUE is at a DIFFERENT position than MARKER!

Options:
1. Learn a mapping: CUE token → MARKER key
   - Train model to output hash(marker_pos) when it sees CUE
   - This is what single-needle already does implicitly!

2. Content-addressable with position hash:
   - Key = hash(token, position) for MARKER
   - Query = learned from (CUE, context)
   - Model learns to produce the right query for each CUE

3. Hybrid: token-based key + position offset
   - Key = embed(MARKER)

In [68]:
"""
============================================================
PRACTICAL SOLUTION: Learned Orthogonal Key Bank
============================================================
Pre-allocate K orthogonal keys. At each MARKER position, 
use the next unused key from the bank.

This GUARANTEES zero interference up to K markers.
"""

print("=" * 70)
print("TEST: Orthogonal Key Bank")
print("=" * 70)

# Create orthogonal key bank
KEY_BANK_SIZE = K  # Can store up to K=64 distinct keys
random_matrix = torch.randn(KEY_BANK_SIZE, K)
Q, R = torch.linalg.qr(random_matrix.T)
KEY_BANK = Q.T  # [64, 64] - orthonormal keys

# Verify orthonormality
ortho_check = KEY_BANK @ KEY_BANK.T
print(f"Key bank orthonormality check: off-diag max = {(ortho_check - torch.eye(K)).abs().max():.6f}")

for n_needles in [2, 3, 5, 10, 20, 30, 50, 64]:
    if n_needles > KEY_BANK_SIZE:
        print(f"  n_needles={n_needles:2d}: EXCEEDS BANK SIZE")
        continue
        
    # Use first n_needles keys from bank
    keys = KEY_BANK[:n_needles]
    values = F.normalize(torch.randn(n_needles, V), dim=1)
    
    _, errors = simulate_delta_rule(keys, values, keys)
    mean_err = sum(errors) / len(errors)
    max_err = max(errors)
    
    status = "✓" if mean_err < 0.01 else "✗"
    print(f"  n_needles={n_needles:2d}: mean_err={mean_err:.6f}, max_err={max_err:.6f} {status}")

print("\n" + "=" * 70)
print("IMPLEMENTATION PLAN")
print("=" * 70)
print("""
To implement orthogonal key bank in the model:

1. In GDN layer, create:
   self.key_bank = nn.Parameter(torch.zeros(n_heads, K, K))
   # Initialize with orthogonal matrices per head
   
2. Track write counter per head:
   self.register_buffer('write_counter', torch.zeros(n_heads, dtype=torch.long))
   
3. At MARKER position:
   key = self.key_bank[head, write_counter[head] % K]
   write_counter[head] += 1
   
4. At CUE position:
   # Model must learn to output query matching the stored key
   # Need to pass slot index somehow (CUE_1, CUE_2, etc.)
   
Problem: How does CUE know WHICH slot to query?
   - Single-needle: trivial (only one slot used)
   - Multi-needle: CUE must know its "slot index"
   
Options:
   a) CUE tokens are position-aware: CUE_1, CUE_2, ... (we tried, failed)
   b) CUE learns from context which MARKER it matches
   c) Query all slots, blend results (attention-like)
""")

# Test option (c): Query multiple slots
print("\n" + "=" * 70)
print("TEST: Query All Slots with Softmax Blending")
print("=" * 70)

def retrieve_multi_slot(S, query, key_bank, n_used, temp=1.0):
    """
    Soft attention over stored slots.
    S: [K, V] state matrix
    query: [K] query vector
    key_bank: [K, K] orthogonal keys
    n_used: number of slots actually used
    """
    # Compute attention over used keys
    keys = key_bank[:n_used]  # [n_used, K]
    
    # Attention scores
    scores = keys @ query  # [n_used]
    weights = F.softmax(scores / temp, dim=0)  # [n_used]
    
    # Retrieve from each slot
    retrieved_per_slot = keys @ S  # [n_used, V]
    
    # Weighted blend
    result = (weights.unsqueeze(1) * retrieved_per_slot).sum(dim=0)  # [V]
    
    return result, weights

# Test: Can softmax attention select the right slot?
n_needles = 5
keys = KEY_BANK[:n_needles]
values = F.normalize(torch.randn(n_needles, V), dim=1)

# Build state
S = torch.zeros(K, V)
for i in range(n_needles):
    k, v = keys[i], values[i]
    current = k @ S
    S = S + torch.outer(k, v - current)

print(f"\nWith {n_needles} needles stored:")
for target_idx in range(n_needles):
    # Use the exact key as query (perfect match)
    query = keys[target_idx]
    result, weights = retrieve_multi_slot(S, query, KEY_BANK, n_needles)
    
    # Check if correct value retrieved
    err = torch.norm(result - values[target_idx]) / torch.norm(values[target_idx])
    
    # Check attention weight distribution
    max_weight_idx = weights.argmax().item()
    correct = max_weight_idx == target_idx
    
    print(f"  Target slot {target_idx}: err={err:.4f}, "
          f"weights=[{', '.join([f'{w:.2f}' for w in weights])}], "
          f"correct={'✓' if correct else '✗'}")

print("\n" + "-" * 70)
print("ANALYSIS:")
print("-" * 70)
print("""
With orthogonal keys:
- Direct query with exact key → perfect retrieval
- Softmax blending → correct slot has highest weight

This proves the MATH works. Now we need the model to learn:
1. MARKER → use next key from orthogonal bank
2. CUE → output query vector that matches target key

For multi-needle, the challenge is teaching CUE to distinguish targets.
This requires CUE to be slot-aware (CUE_0, CUE_1, etc.) OR 
context-aware (learn from preceding tokens which MARKER to retrieve).
""")

TEST: Orthogonal Key Bank
Key bank orthonormality check: off-diag max = 0.000001
  n_needles= 2: mean_err=0.000000, max_err=0.000000 ✓
  n_needles= 3: mean_err=0.000000, max_err=0.000000 ✓
  n_needles= 5: mean_err=0.000000, max_err=0.000000 ✓
  n_needles=10: mean_err=0.000000, max_err=0.000000 ✓
  n_needles=20: mean_err=0.000000, max_err=0.000000 ✓
  n_needles=30: mean_err=0.000000, max_err=0.000000 ✓
  n_needles=50: mean_err=0.000000, max_err=0.000001 ✓
  n_needles=64: mean_err=0.000000, max_err=0.000001 ✓

IMPLEMENTATION PLAN

To implement orthogonal key bank in the model:

1. In GDN layer, create:
   self.key_bank = nn.Parameter(torch.zeros(n_heads, K, K))
   # Initialize with orthogonal matrices per head

2. Track write counter per head:
   self.register_buffer('write_counter', torch.zeros(n_heads, dtype=torch.long))

3. At MARKER position:
   key = self.key_bank[head, write_counter[head] % K]
   write_counter[head] += 1

4. At CUE position:
   # Model must learn to output query ma

# Summary: Pure Math Analysis

## What We Learned

| Test | Result |
|------|--------|
| Positional encoding (sin/cos) | FAILS - sim=0.9 still too high |
| Random keys | MARGINAL - only 2-5 writes before degradation |
| Error-correcting writes | FAILS - impossible with identical keys |
| **Orthogonal key bank** | **PERFECT** - 0% error up to K=64 needles |

## Key Threshold
- Need key similarity **< 0.02** for reliable 5-needle retrieval
- Random keys at K=64 give sim ≈ 0.1 (marginal)
- MARKERs currently give sim = 1.0 (IDENTICAL - fatal)

## The Solution
**Orthogonal Key Bank**:
1. Pre-allocate K orthogonal keys (via QR decomposition)
2. At each MARKER, use next key from bank (round-robin or counter)
3. At CUE, model outputs query that matches target key

## The Remaining Challenge
How does CUE know **which** key to query?
- Option A: Slot-indexed CUEs (CUE_0, CUE_1...) - tried, failed
- Option B: Context-dependent query (learn from preceding tokens)
- Option C: Attention over stored slots (soft) - FAILS (blending dilutes)

## Next Step
Implement orthogonal key bank in GDN and test with slot-indexed CUEs.

In [69]:
"""
============================================================
HETEROASSOCIATIVE CONTENT-ADDRESSED MEMORY
============================================================
Key insight: Context before each MARKER is different!
  MARKER_1 appears after [haystack tokens A, B, C]
  MARKER_2 appears after [haystack tokens X, Y, Z]
  
If key = f(context_window), then MARKER_1 and MARKER_2 get different keys.
At CUE time, model reconstructs the context → same key → retrieval works.

Implementation: short_conv(context_window) → key
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

print("=" * 70)
print("STEP 1: Test Context-Based Key Generation (Pure Simulation)")
print("=" * 70)

# Simulate: different random context windows → different keys via conv
d_model = 512
K = 64
conv_width = 4

# Short convolution for key generation
key_conv = nn.Conv1d(d_model, K, kernel_size=conv_width, padding=0, bias=False)
nn.init.xavier_normal_(key_conv.weight)

def context_to_key(context_window):
    """
    context_window: [conv_width, d_model]
    Returns: [K] normalized key
    """
    # Reshape for conv1d: [1, d_model, conv_width]
    x = context_window.T.unsqueeze(0)  
    # Conv output: [1, K, 1]
    out = key_conv(x)
    key = out.squeeze()  # [K]
    return F.normalize(key, dim=0)

# Generate random context windows (simulating haystack before each MARKER)
n_markers = 5
context_windows = []
for i in range(n_markers):
    # Random context window (different tokens before each MARKER)
    ctx = torch.randn(conv_width, d_model)
    context_windows.append(ctx)

# Generate keys from context
keys = torch.stack([context_to_key(ctx) for ctx in context_windows])

# Check key similarity
print(f"\nContext-based keys for {n_markers} markers:")
sims = []
for i in range(n_markers):
    for j in range(i+1, n_markers):
        sim = torch.dot(keys[i], keys[j]).item()
        sims.append(abs(sim))
        print(f"  key[{i}] · key[{j}] = {sim:.4f}")

avg_sim = sum(sims) / len(sims)
max_sim = max(sims)
print(f"\nAvg |similarity|: {avg_sim:.4f}, Max |similarity|: {max_sim:.4f}")
print(f"Required for multi-needle: avg_sim < 0.1, max_sim < 0.3")

status = "✓ GOOD" if avg_sim < 0.15 and max_sim < 0.4 else "✗ TOO SIMILAR"
print(f"Status: {status}")

print("\n" + "=" * 70)
print("STEP 2: Simulate Write/Read with Context Keys")
print("=" * 70)

# Generate values (what we want to store)
values = F.normalize(torch.randn(n_markers, V), dim=1)

# Write to state matrix
S = torch.zeros(K, V)
for i in range(n_markers):
    k = keys[i]
    v = values[i]
    current = k @ S
    delta = v - current
    S = S + torch.outer(k, delta)

# Read with same keys (assuming CUE reconstructs correct context)
print("\nRetrieval with EXACT same context (ideal case):")
errors = []
for i in range(n_markers):
    query = keys[i]  # Same key as write
    retrieved = query @ S
    err = torch.norm(retrieved - values[i]) / torch.norm(values[i])
    errors.append(err.item())
    status = "✓" if err < 0.1 else "✗"
    print(f"  Marker {i}: error = {err:.4f} {status}")

mean_err = sum(errors) / len(errors)
print(f"\nMean retrieval error: {mean_err:.4f}")

print("\n" + "=" * 70)
print("STEP 3: What if CUE context is SIMILAR but not IDENTICAL?")
print("=" * 70)

# In practice, CUE context won't be identical to MARKER context
# Test robustness to context noise

for noise_level in [0.0, 0.1, 0.2, 0.3, 0.5, 1.0]:
    errors = []
    for i in range(n_markers):
        # Add noise to context
        noisy_ctx = context_windows[i] + noise_level * torch.randn_like(context_windows[i])
        query = context_to_key(noisy_ctx)
        
        retrieved = query @ S
        err = torch.norm(retrieved - values[i]) / torch.norm(values[i])
        errors.append(err.item())
    
    mean_err = sum(errors) / len(errors)
    max_err = max(errors)
    status = "✓" if mean_err < 0.2 else "⚠" if mean_err < 0.5 else "✗"
    print(f"  noise={noise_level:.1f}: mean_err={mean_err:.4f}, max_err={max_err:.4f} {status}")

print("\n" + "-" * 70)
print("INSIGHT:")
print("-" * 70)
print("""
Context-based keys work IF:
1. Different MARKERs have different context → different keys ✓
2. CUE can reconstruct the context that preceded target MARKER

The challenge: How does CUE reconstruct MARKER's context?
  - In NIAH: haystack before MARKER = haystack before CUE (same sequence!)
  - CUE can use SWA to look back at local context
  - If same tokens appear before CUE as before MARKER → same key → retrieval!

This is EXACTLY how Based/CAT works!
""")

STEP 1: Test Context-Based Key Generation (Pure Simulation)

Context-based keys for 5 markers:
  key[0] · key[1] = -0.1192
  key[0] · key[2] = 0.0640
  key[0] · key[3] = -0.1956
  key[0] · key[4] = -0.1688
  key[1] · key[2] = -0.0338
  key[1] · key[3] = 0.0528
  key[1] · key[4] = 0.0700
  key[2] · key[3] = 0.1914
  key[2] · key[4] = 0.0086
  key[3] · key[4] = -0.1070

Avg |similarity|: 0.1011, Max |similarity|: 0.1956
Required for multi-needle: avg_sim < 0.1, max_sim < 0.3
Status: ✓ GOOD

STEP 2: Simulate Write/Read with Context Keys

Retrieval with EXACT same context (ideal case):
  Marker 0: error = 0.3419 ✗
  Marker 1: error = 0.1143 ✗
  Marker 2: error = 0.2008 ✗
  Marker 3: error = 0.1108 ✗
  Marker 4: error = 0.0000 ✓

Mean retrieval error: 0.1536

STEP 3: What if CUE context is SIMILAR but not IDENTICAL?
  noise=0.0: mean_err=0.1536, max_err=0.3419 ✓
  noise=0.1: mean_err=0.1450, max_err=0.3098 ✓
  noise=0.2: mean_err=0.1753, max_err=0.3663 ✓
  noise=0.3: mean_err=0.1786, max_er

In [70]:
"""
============================================================
GDN WITH CONTEXT-BASED KEYS (Based/CAT Style)
============================================================
"""

class GDNContextKeys(nn.Module):
    """
    Gated Delta Net with context-based key generation.
    
    Key difference from standard GDN:
    - Keys are generated from short_conv(context_window), not just current token
    - This makes keys position/context-dependent, enabling multi-needle
    """
    
    def __init__(self, d_model, n_heads, head_dim, value_dim, 
                 conv_width=4, beta_bias=-4.0, g_bias=4.0):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.value_dim = value_dim
        self.conv_width = conv_width
        
        # Key projection with causal convolution
        # Input: [B, T, d_model], Output: [B, T, n_heads * head_dim]
        self.key_conv = nn.Conv1d(
            d_model, n_heads * head_dim, 
            kernel_size=conv_width, 
            padding=conv_width - 1,  # Causal padding
            groups=1
        )
        
        # Query uses same convolution (for matching at retrieval)
        self.query_conv = nn.Conv1d(
            d_model, n_heads * head_dim,
            kernel_size=conv_width,
            padding=conv_width - 1,
            groups=1
        )
        
        # Value projection (standard)
        self.W_v = nn.Linear(d_model, n_heads * value_dim)
        
        # Gates
        self.W_beta = nn.Linear(d_model, n_heads)
        self.W_g = nn.Linear(d_model, n_heads)
        
        # Output projection
        self.W_o = nn.Linear(n_heads * value_dim, d_model)
        
        # Bias initialization for sparse gating
        self.W_beta.bias.data.fill_(beta_bias)
        self.W_g.bias.data.fill_(g_bias)
        
    def forward(self, x, state=None):
        """
        x: [B, T, d_model]
        state: [B, n_heads, head_dim, value_dim] or None
        Returns: output [B, T, d_model], new_state, diagnostics
        """
        B, T, _ = x.shape
        H, K, V = self.n_heads, self.head_dim, self.value_dim
        
        if state is None:
            state = torch.zeros(B, H, K, V, device=x.device, dtype=x.dtype)
        
        # Context-based keys via causal convolution
        # [B, T, d] -> [B, d, T] -> conv -> [B, H*K, T] -> [B, T, H*K]
        x_t = x.transpose(1, 2)
        keys_raw = self.key_conv(x_t)[:, :, :T].transpose(1, 2)  # Causal: trim to T
        queries_raw = self.query_conv(x_t)[:, :, :T].transpose(1, 2)
        
        # Reshape to heads
        keys = keys_raw.view(B, T, H, K)
        queries = queries_raw.view(B, T, H, K)
        
        # Normalize keys and queries
        keys = F.normalize(keys, dim=-1)
        queries = F.normalize(queries, dim=-1)
        
        # Values and gates
        values = self.W_v(x).view(B, T, H, V)
        beta = torch.sigmoid(self.W_beta(x)).unsqueeze(-1)  # [B, T, H, 1]
        g = torch.sigmoid(self.W_g(x)).unsqueeze(-1)  # [B, T, H, 1]
        
        # Sequential delta rule update
        outputs = []
        S = state.clone()
        
        for t in range(T):
            k_t = keys[:, t]  # [B, H, K]
            q_t = queries[:, t]  # [B, H, K]
            v_t = values[:, t]  # [B, H, V]
            beta_t = beta[:, t]  # [B, H, 1]
            g_t = g[:, t]  # [B, H, 1]
            
            # Retrieval: q @ S
            retrieved = torch.einsum('bhk,bhkv->bhv', q_t, S)  # [B, H, V]
            
            # Current at key position: k @ S
            current = torch.einsum('bhk,bhkv->bhv', k_t, S)  # [B, H, V]
            
            # Delta rule update
            delta = v_t - current  # [B, H, V]
            update = torch.einsum('bhk,bhv->bhkv', k_t, delta)  # [B, H, K, V]
            S = g_t.unsqueeze(-1) * S + beta_t.unsqueeze(-1) * update
            
            outputs.append(retrieved)
        
        # Stack outputs
        output = torch.stack(outputs, dim=1)  # [B, T, H, V]
        output = output.view(B, T, H * V)
        output = self.W_o(output)
        
        diagnostics = {
            'beta_mean': beta.mean().item(),
            'g_mean': g.mean().item(),
        }
        
        return output, S, diagnostics


# Test the model
print("=" * 70)
print("Testing GDNContextKeys on Multi-Needle Sequence")
print("=" * 70)

# Create model
gdn_ctx = GDNContextKeys(
    d_model=512, n_heads=8, head_dim=64, value_dim=128,
    conv_width=4, beta_bias=-4.0, g_bias=4.0
).to(DEVICE)

# Create test sequence with 2 needles
# Format: [haystack...] MARKER needle1 [haystack...] MARKER needle2 [haystack...] CUE1 CUE2
seq_len = 200
vocab_size = 256
MARKER = 254
CUE = 255

# Build sequence
tokens = torch.randint(0, 250, (1, seq_len), device=DEVICE)

# Insert needles at different positions with different contexts
needle_pos_1 = 30
needle_pos_2 = 100
cue_pos_1 = 150
cue_pos_2 = 160

# Make context before each MARKER different
tokens[0, needle_pos_1] = MARKER
tokens[0, needle_pos_1 + 1] = 42  # needle value 1

tokens[0, needle_pos_2] = MARKER  
tokens[0, needle_pos_2 + 1] = 77  # needle value 2

tokens[0, cue_pos_1] = CUE
tokens[0, cue_pos_2] = CUE

# Create embeddings
embed = nn.Embedding(vocab_size, 512).to(DEVICE)
x = embed(tokens)

# Forward pass
with torch.no_grad():
    output, state, diag = gdn_ctx(x)

print(f"\nSequence structure:")
print(f"  MARKER_1 at pos {needle_pos_1}, context: {tokens[0, needle_pos_1-4:needle_pos_1].tolist()}")
print(f"  MARKER_2 at pos {needle_pos_2}, context: {tokens[0, needle_pos_2-4:needle_pos_2].tolist()}")
print(f"  CUE_1 at pos {cue_pos_1}, context: {tokens[0, cue_pos_1-4:cue_pos_1].tolist()}")
print(f"  CUE_2 at pos {cue_pos_2}, context: {tokens[0, cue_pos_2-4:cue_pos_2].tolist()}")

print(f"\nDiagnostics: beta_mean={diag['beta_mean']:.4f}, g_mean={diag['g_mean']:.4f}")

# Check key similarity at MARKER positions
x_t = x.transpose(1, 2)
keys_raw = gdn_ctx.key_conv(x_t)[:, :, :seq_len].transpose(1, 2)
keys = keys_raw.view(1, seq_len, 8, 64)
keys = F.normalize(keys, dim=-1)

k1 = keys[0, needle_pos_1]  # [H, K]
k2 = keys[0, needle_pos_2]  # [H, K]

print(f"\nKey similarity between MARKER_1 and MARKER_2:")
for h in range(8):
    sim = torch.dot(k1[h], k2[h]).item()
    print(f"  Head {h}: sim = {sim:.4f}")

avg_sim = torch.einsum('hk,hk->h', k1, k2).mean().item()
print(f"\nAvg similarity across heads: {avg_sim:.4f}")
print(f"Target: < 0.3 for multi-needle to work")

Testing GDNContextKeys on Multi-Needle Sequence


: 

In [1]:
"""
============================================================
TEST: Context-Based Keys in Actual GDN+SWA Architecture
============================================================
Using the modified GatedDeltaNetLayer with key_conv_width=4
"""

# Reload modules to pick up changes
import importlib
import sys
for mod_name in list(sys.modules.keys()):
    if 'config' in mod_name or 'model' in mod_name or 'core' in mod_name:
        del sys.modules[mod_name]

import torch
import torch.nn.functional as F

# Fresh imports
sys.path.insert(0, '/home/m_tes/groundthink/gt-v6/v7-design/groundthink_v7')
from config import HybridConfig
from model import TransparentHybrid, GatedDeltaNetLayer

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {DEVICE}")

# Create config with context-based keys
cfg_conv = HybridConfig(
    d_model=512,
    n_heads=8,
    head_dim=64,
    value_dim=128,
    vocab_size=256,
    layer_pattern="GS",
    window_size=64,
    beta_bias=-4.0,
    g_bias=4.0,
    shifted_value=True,
    key_conv_width=4,  # <-- THE FIX: 4-token context window
)

print(f"\nConfig: key_conv_width={cfg_conv.key_conv_width}")
print(f"Architecture: {cfg_conv.layer_pattern} (GDN + SWA)")

# Create model
model_conv = TransparentHybrid(cfg_conv).to(DEVICE)
print(f"Model params: {model_conv.count_params():,}")

# Get the GDN layer
gdn_layer = model_conv.layers[0]
print(f"GDN has k_conv: {hasattr(gdn_layer, 'k_conv')}")

# Test: Create sequence with 2 MARKER tokens at different positions
# Each MARKER has DIFFERENT context before it
seq_len = 256
MARKER = 254

tokens = torch.randint(0, 250, (1, seq_len), device=DEVICE)

# Insert MARKERs with different preceding context
marker_pos_1 = 50
marker_pos_2 = 150

# Ensure different context by using different random tokens
tokens[0, marker_pos_1] = MARKER
tokens[0, marker_pos_2] = MARKER

print(f"\n{'='*60}")
print("MARKER Key Similarity Analysis")
print('='*60)
print(f"MARKER_1 at pos {marker_pos_1}, context: {tokens[0, marker_pos_1-4:marker_pos_1].tolist()}")
print(f"MARKER_2 at pos {marker_pos_2}, context: {tokens[0, marker_pos_2-4:marker_pos_2].tolist()}")

# Get embeddings
x = model_conv.embed(tokens)
x_norm = gdn_layer.norm(x)

# Get keys using context conv
x_t = x_norm.transpose(1, 2)
k_raw = gdn_layer.k_conv(x_t)[:, :, :seq_len]
k_full = k_raw.transpose(1, 2).view(1, seq_len, 8, 64)
k_full = F.normalize(k_full, dim=-1)

# Compare keys at MARKER positions
k1 = k_full[0, marker_pos_1]  # [H, K]
k2 = k_full[0, marker_pos_2]  # [H, K]

print(f"\nPer-head key similarity (MARKER_1 vs MARKER_2):")
for h in range(8):
    sim = torch.dot(k1[h], k2[h]).item()
    status = "✓" if abs(sim) < 0.3 else "✗"
    print(f"  Head {h}: sim = {sim:+.4f} {status}")

avg_sim = (k1 * k2).sum(dim=-1).mean().item()
print(f"\nAvg similarity: {avg_sim:+.4f}")
print(f"Target for multi-needle: |sim| < 0.3")
print(f"Result: {'✓ GOOD - keys are different!' if abs(avg_sim) < 0.3 else '✗ TOO SIMILAR'}")

# Compare with standard (no conv) model
print(f"\n{'='*60}")
print("Comparison: With vs Without Context Conv")
print('='*60)

cfg_std = HybridConfig(
    d_model=512,
    n_heads=8,
    head_dim=64,
    value_dim=128,
    vocab_size=256,
    layer_pattern="GS",
    key_conv_width=1,  # Standard linear projection
)

model_std = TransparentHybrid(cfg_std).to(DEVICE)
gdn_std = model_std.layers[0]

x_std = model_std.embed(tokens)
x_std_norm = gdn_std.norm(x_std)
k_std = gdn_std.k_proj(x_std_norm).view(1, seq_len, 8, 64)
k_std = F.normalize(k_std, dim=-1)

k1_std = k_std[0, marker_pos_1]
k2_std = k_std[0, marker_pos_2]

avg_sim_std = (k1_std * k2_std).sum(dim=-1).mean().item()
print(f"Standard (no conv):   avg_sim = {avg_sim_std:+.4f}")
print(f"Context conv (w=4):   avg_sim = {avg_sim:+.4f}")
print(f"\nImprovement: {abs(avg_sim_std) - abs(avg_sim):.4f} reduction in similarity")

Device: cuda

Config: key_conv_width=4
Architecture: GS (GDN + SWA)
Model params: 8,403,480
GDN has k_conv: True

MARKER Key Similarity Analysis
MARKER_1 at pos 50, context: [199, 20, 185, 15]
MARKER_2 at pos 150, context: [92, 189, 60, 67]


: 