# Force Memory Usage

Testing fixes from compass doc:
1. **Distractor > window_size**: 80 tokens distractor, 64 window → fact outside local attention
2. **local_dropout=0.3**: Bottleneck the SWA shortcut, force state retrieval

**Changes made:**
- `config.py`: Added `local_dropout: float = 0.3`
- `model.py`: `local_out = self.local_dropout(local_out)` after o_proj
- `make_example()`: distractor_tokens=80 (was 400, but 400//5=80 anyway, now explicit)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer
import random

DEVICE = 'cuda'
torch.manual_seed(42)

<torch._C.Generator at 0x77439c594190>

In [18]:
# Import model - reload to pick up changes
import importlib
import config
import model as model_module
importlib.reload(config)
importlib.reload(model_module)

from config import HybridConfig
from model import TransparentHybrid

from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

## Test Data Generator

In [3]:
NAMES = ['Alice', 'Bob', 'Carol', 'Dave', 'Eve', 'Frank', 'Grace', 'Henry']
OBJECTS = ['ball', 'car', 'hat', 'book', 'pen', 'cup', 'ring', 'lamp']
COLORS = ['red', 'blue', 'green', 'yellow', 'purple', 'orange', 'pink', 'white']

def make_example(distractor_tokens=80):
    """Fact at start, distractor BEYOND window_size (64), query at end."""
    name = random.choice(NAMES)
    obj = random.choice(OBJECTS)
    color = random.choice(COLORS)
    
    fact = f'{name} has a {color} {obj}.'
    distractor = ' The sky is blue.' * (distractor_tokens // 5)
    query = f' What does {name} have?'
    answer = f' {color}'
    
    return fact + distractor + query, answer

## State Ablation Test

In [4]:
def test_state_ablation(model, cfg, n_samples=100):
    """Test if zeroing state hurts accuracy."""
    model.eval()
    correct_normal = 0
    correct_zeroed = 0
    
    for _ in range(n_samples):
        text, answer = make_example()
        tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=600)['input_ids'].to(DEVICE)
        answer_id = tokenizer.encode(answer)[0]
        
        # Normal
        with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
            logits, _, _, _ = model(tokens)
        if answer_id in logits[0, -1].topk(5).indices.tolist():
            correct_normal += 1
        
        # Zeroed state
        with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
            x = model.embed(tokens)
            x = model.embed_norm(x)
            state = None
            for i, (layer, ffn) in enumerate(zip(model.layers, model.ffns)):
                if cfg.layer_pattern[i] == 'G':
                    x, state, _ = layer(x, initial_state=state, input_ids=tokens)
                    key_bank = layer.key_bank
                else:
                    x, _ = layer(x, gdn_state=torch.zeros_like(state), input_ids=tokens, key_bank=key_bank)
                x = ffn(x)
            logits_z = model.lm_head(model.norm_f(x))
        if answer_id in logits_z[0, -1].topk(5).indices.tolist():
            correct_zeroed += 1
    
    print(f'Normal: {correct_normal}/{n_samples} ({correct_normal/n_samples:.0%})')
    print(f'Zeroed: {correct_zeroed}/{n_samples} ({correct_zeroed/n_samples:.0%})')
    print(f'Delta: {correct_normal - correct_zeroed}')
    return correct_normal, correct_zeroed

## Baseline (current broken model)

In [7]:
cfg = HybridConfig(
    d_model=256, n_heads=4, head_dim=64, value_dim=64,
    layer_pattern='GS', vocab_size=50257,
    window_size=64, beta_bias=0.0, g_bias=2.0,
    shifted_value=True, beta_floor=0.1
)
model = TransparentHybrid(cfg).to(DEVICE)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

print('Training baseline...')
for step in range(3000):
    text, answer = make_example()
    tokens = tokenizer(text + answer, return_tensors='pt', truncation=True, max_length=600)['input_ids'].to(DEVICE)
    targets = tokens[:, 1:].contiguous()
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model(tokens)
        loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257), targets.reshape(-1))
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step % 500 == 0:
        print(f'Step {step}: loss={loss.item():.3f}')

print('\nBaseline ablation:')
test_state_ablation(model, cfg)

Training baseline...
Step 0: loss=10.855
Step 500: loss=0.039
Step 1000: loss=0.072
Step 1500: loss=0.031
Step 2000: loss=0.017
Step 2500: loss=0.019

Baseline ablation:
Normal: 63/100 (63%)
Zeroed: 65/100 (65%)
Delta: -2


(63, 65)

In [8]:
# DEBUG: Where is information flowing?
# Test: What if we remove the state ENTIRELY and use local_scale=0?
# If model still works, info flows through residual x

# Let's trace a single forward pass
cfg_debug = HybridConfig(
    d_model=256, n_heads=4, head_dim=64, value_dim=64,
    layer_pattern='GS', vocab_size=50257,
    window_size=64, beta_bias=0.0, g_bias=2.0,
    shifted_value=True, beta_floor=0.1,
    local_scale=0.0  # ZERO out local attention entirely
)
model_debug = TransparentHybrid(cfg_debug).to(DEVICE)
opt_debug = torch.optim.AdamW(model_debug.parameters(), lr=1e-3)

print('Training with local_scale=0.0 (no local attention at all)...')
for step in range(3000):
    text, answer = make_example()
    tokens = tokenizer(text + answer, return_tensors='pt', truncation=True, max_length=600)['input_ids'].to(DEVICE)
    targets = tokens[:, 1:].contiguous()
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_debug(tokens)
        loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257), targets.reshape(-1))
    opt_debug.zero_grad()
    loss.backward()
    opt_debug.step()
    if step % 500 == 0:
        print(f'Step {step}: loss={loss.item():.3f}')

print('\nWith local_scale=0.0:')
test_state_ablation(model_debug, cfg_debug)

Training with local_scale=0.0 (no local attention at all)...
Step 0: loss=10.908
Step 500: loss=0.064
Step 1000: loss=0.028
Step 1500: loss=0.028
Step 2000: loss=0.032
Step 2500: loss=0.116

With local_scale=0.0:
Normal: 59/100 (59%)
Zeroed: 65/100 (65%)
Delta: -6


(59, 65)

In [9]:
# REAL TEST: distractor = 400 tokens (way beyond window=64)
# Different query phrasing to prevent pattern matching

def make_hard_example():
    """Fact at start, LONG distractor, paraphrased query at end."""
    name = random.choice(NAMES)
    obj = random.choice(OBJECTS)
    color = random.choice(COLORS)
    
    fact = f'{name} has a {color} {obj}.'
    # 400 tokens of distractor (well beyond 64 window)
    distractor = ' '.join(['The weather is nice today.' for _ in range(60)])
    # Different query phrasing
    query = f" Tell me the color of {name}'s item."
    answer = f' {color}'
    
    return fact + distractor + query, answer

# Test with BRUTAL settings
cfg_hard = HybridConfig(
    d_model=256, n_heads=4, head_dim=64, value_dim=64,
    layer_pattern='GS', vocab_size=50257,
    window_size=64, beta_bias=0.0, g_bias=2.0,
    shifted_value=True, beta_floor=0.1,
    local_scale=0.3
)
model_hard = TransparentHybrid(cfg_hard).to(DEVICE)
opt_hard = torch.optim.AdamW(model_hard.parameters(), lr=1e-3)

print('Training on HARD task (400+ token distractor, paraphrased query)...')
for step in range(3000):
    text, answer = make_hard_example()
    tokens = tokenizer(text + answer, return_tensors='pt', truncation=True, max_length=700)['input_ids'].to(DEVICE)
    targets = tokens[:, 1:].contiguous()
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_hard(tokens)
        loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257), targets.reshape(-1))
    opt_hard.zero_grad()
    loss.backward()
    opt_hard.step()
    if step % 500 == 0:
        print(f'Step {step}: loss={loss.item():.3f}')

# Test with make_hard_example
def test_state_ablation_hard(model, cfg, n_samples=100):
    """Test if zeroing state hurts accuracy on hard examples."""
    model.eval()
    correct_normal = 0
    correct_zeroed = 0
    
    for _ in range(n_samples):
        text, answer = make_hard_example()
        tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=700)['input_ids'].to(DEVICE)
        answer_id = tokenizer.encode(answer)[0]
        
        # Normal
        with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
            logits, _, _, _ = model(tokens)
        if answer_id in logits[0, -1].topk(5).indices.tolist():
            correct_normal += 1
        
        # Zeroed state
        with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
            x = model.embed(tokens)
            x = model.embed_norm(x)
            state = None
            for i, (layer, ffn) in enumerate(zip(model.layers, model.ffns)):
                if cfg.layer_pattern[i] == 'G':
                    x, state, _ = layer(x, initial_state=state, input_ids=tokens)
                    key_bank = layer.key_bank
                else:
                    x, _ = layer(x, gdn_state=torch.zeros_like(state), input_ids=tokens, key_bank=key_bank)
                x = ffn(x)
            logits_z = model.lm_head(model.norm_f(x))
        if answer_id in logits_z[0, -1].topk(5).indices.tolist():
            correct_zeroed += 1
    
    print(f'Normal: {correct_normal}/{n_samples} ({correct_normal/n_samples:.0%})')
    print(f'Zeroed: {correct_zeroed}/{n_samples} ({correct_zeroed/n_samples:.0%})')
    print(f'Delta: {correct_normal - correct_zeroed}')
    return correct_normal, correct_zeroed

print('\nHard task ablation:')
test_state_ablation_hard(model_hard, cfg_hard)

Training on HARD task (400+ token distractor, paraphrased query)...
Step 0: loss=10.909
Step 500: loss=0.028
Step 1000: loss=0.043
Step 1500: loss=0.017
Step 2000: loss=0.018
Step 2500: loss=0.018

Hard task ablation:
Normal: 64/100 (64%)
Zeroed: 66/100 (66%)
Delta: -2


(64, 66)

In [10]:
# SANITY CHECK: Is the model learning ANYTHING meaningful?
# Test: shuffle answers and see if accuracy drops to ~12.5% (1/8 colors)

print("=== SANITY CHECKS ===")
model_hard.eval()

# 1. Normal accuracy with verbose output
correct = 0
for i in range(20):
    text, answer = make_hard_example()
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=700)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_hard(tokens)
    
    pred_token = logits[0, -1].argmax().item()
    pred_text = tokenizer.decode([pred_token])
    top5 = [tokenizer.decode([t]) for t in logits[0, -1].topk(5).indices.tolist()]
    
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct += 1
        status = "✓"
    else:
        status = "✗"
    
    if i < 5:  # Show first 5
        print(f"{status} Answer: {answer.strip():8s} | Pred: {pred_text:8s} | Top5: {top5}")

print(f"\nAccuracy (top-5): {correct}/20 = {correct/20:.0%}")

# 2. Random baseline: what's the chance of guessing?
print(f"\nRandom baseline: 1/{len(COLORS)} = {1/len(COLORS):.1%}")

# 3. Most common prediction
from collections import Counter
preds = []
for _ in range(100):
    text, answer = make_hard_example()
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=700)['input_ids'].to(DEVICE)
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_hard(tokens)
    preds.append(tokenizer.decode([logits[0, -1].argmax().item()]).strip())

print(f"\nMost common predictions: {Counter(preds).most_common(5)}")

=== SANITY CHECKS ===
✗ Answer: pink     | Pred:  green   | Top5: [' green', ' white', ' red', ' orange', ' purple']
✓ Answer: white    | Pred:  green   | Top5: [' green', ' white', ' red', ' orange', ' purple']
✗ Answer: blue     | Pred:  pink    | Top5: [' pink', ' green', ' yellow', ' purple', ' white']
✗ Answer: orange   | Pred:  green   | Top5: [' green', ' white', 'The', ' red', ' purple']
✓ Answer: red      | Pred:  white   | Top5: [' white', ' red', ' orange', ' green', ' yellow']

Accuracy (top-5): 12/20 = 60%

Random baseline: 1/8 = 12.5%

Most common predictions: [('green', 57), ('white', 30), ('pink', 13)]


In [11]:
# FIX: Train with ANSWER-ONLY loss
# Only compute loss on the answer token, ignore distractors

cfg_answer_only = HybridConfig(
    d_model=256, n_heads=4, head_dim=64, value_dim=64,
    layer_pattern='GS', vocab_size=50257,
    window_size=64, beta_bias=0.0, g_bias=2.0,
    shifted_value=True, beta_floor=0.1,
    local_scale=0.3
)
model_answer_only = TransparentHybrid(cfg_answer_only).to(DEVICE)
opt_answer_only = torch.optim.AdamW(model_answer_only.parameters(), lr=1e-3)

print('Training with ANSWER-ONLY loss (ignoring distractors)...')
for step in range(3000):
    text, answer = make_hard_example()
    
    # Tokenize separately to know where answer starts
    text_tokens = tokenizer(text, return_tensors='pt')['input_ids']
    answer_tokens = tokenizer(answer, return_tensors='pt')['input_ids']
    full_tokens = torch.cat([text_tokens, answer_tokens], dim=1)[:, :700].to(DEVICE)
    
    answer_start = text_tokens.size(1)
    
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_answer_only(full_tokens)
        
        # Only compute loss on answer position
        # logits[0, answer_start-1] predicts answer_tokens[0, 0]
        loss = F.cross_entropy(
            logits[0, answer_start-1:answer_start].reshape(-1, 50257),
            answer_tokens[0, :1].reshape(-1).to(DEVICE)
        )
    
    opt_answer_only.zero_grad()
    loss.backward()
    opt_answer_only.step()
    
    if step % 500 == 0:
        print(f'Step {step}: loss={loss.item():.3f}')

print('\nAnswer-only trained model:')
test_state_ablation_hard(model_answer_only, cfg_answer_only)

# Check predictions
model_answer_only.eval()
correct = 0
preds = []
for _ in range(100):
    text, answer = make_hard_example()
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=700)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_answer_only(tokens)
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct += 1
    preds.append(tokenizer.decode([logits[0, -1].argmax().item()]).strip())

print(f"\nTop-5 accuracy: {correct}%")
print(f"Prediction distribution: {Counter(preds).most_common(8)}")

Training with ANSWER-ONLY loss (ignoring distractors)...
Step 0: loss=11.250
Step 500: loss=0.004
Step 1000: loss=0.002
Step 1500: loss=0.001
Step 2000: loss=0.000
Step 2500: loss=0.000

Answer-only trained model:
Normal: 59/100 (59%)
Zeroed: 60/100 (60%)
Delta: -1

Top-5 accuracy: 68%
Prediction distribution: [('white', 100)]


In [12]:
# PROPER FIX: Combined LM loss + Retrieval loss + State reconstruction loss
# The model must learn language modeling AND retrieval AND prove state stores info

cfg_combined = HybridConfig(
    d_model=256, n_heads=4, head_dim=64, value_dim=64,
    layer_pattern='GS', vocab_size=50257,
    window_size=64, beta_bias=0.0, g_bias=2.0,
    shifted_value=True, beta_floor=0.1,
    local_scale=0.3
)
model_combined = TransparentHybrid(cfg_combined).to(DEVICE)

# State reconstruction head: given final state, predict the color token
# This forces the state to actually encode the fact
state_dim = cfg_combined.n_heads * cfg_combined.head_dim * cfg_combined.value_dim
recon_head = nn.Sequential(
    nn.Linear(state_dim, 256),
    nn.ReLU(),
    nn.Linear(256, len(COLORS))  # 8 classes
).to(DEVICE)

# Color to index mapping
color_to_idx = {c: i for i, c in enumerate(COLORS)}

opt_combined = torch.optim.AdamW(
    list(model_combined.parameters()) + list(recon_head.parameters()), 
    lr=1e-3
)

print('Training with COMBINED loss (LM + answer + state recon)...')
for step in range(5000):
    text, answer = make_hard_example()
    color = answer.strip()
    color_idx = color_to_idx[color]
    
    text_tokens = tokenizer(text, return_tensors='pt')['input_ids']
    answer_tokens = tokenizer(answer, return_tensors='pt')['input_ids']
    full_tokens = torch.cat([text_tokens, answer_tokens], dim=1)[:, :700].to(DEVICE)
    answer_start = text_tokens.size(1)
    
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, state = model_combined(full_tokens)
        
        # 1. LM loss on all tokens (keeps model grounded)
        targets = full_tokens[:, 1:].contiguous()
        lm_loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257), targets.reshape(-1))
        
        # 2. Answer loss (weighted higher)
        answer_loss = F.cross_entropy(
            logits[0, answer_start-1:answer_start].reshape(-1, 50257),
            answer_tokens[0, :1].reshape(-1).to(DEVICE)
        )
        
        # 3. State reconstruction loss: state must encode the color
        state_flat = state.reshape(1, -1).float()
        recon_logits = recon_head(state_flat)
        recon_loss = F.cross_entropy(
            recon_logits, 
            torch.tensor([color_idx], device=DEVICE)
        )
        
        # Combined loss
        loss = lm_loss + 2.0 * answer_loss + 1.0 * recon_loss
    
    opt_combined.zero_grad()
    loss.backward()
    opt_combined.step()
    
    if step % 500 == 0:
        print(f'Step {step}: lm={lm_loss.item():.3f}, ans={answer_loss.item():.3f}, recon={recon_loss.item():.3f}')

print('\nCombined training model:')
test_state_ablation_hard(model_combined, cfg_combined)

# Check if state actually encodes color
model_combined.eval()
state_correct = 0
for _ in range(100):
    text, answer = make_hard_example()
    color = answer.strip()
    color_idx = color_to_idx[color]
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=700)['input_ids'].to(DEVICE)
    
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        _, _, _, state = model_combined(tokens)
    state_flat = state.reshape(1, -1).float()
    recon_pred = recon_head(state_flat).argmax().item()
    if recon_pred == color_idx:
        state_correct += 1

print(f"\nState reconstruction accuracy: {state_correct}%")
print("(This shows if state actually encodes the fact)")

Training with COMBINED loss (LM + answer + state recon)...
Step 0: lm=10.958, ans=11.062, recon=2.094
Step 500: lm=0.028, ans=0.007, recon=0.000
Step 1000: lm=0.028, ans=0.002, recon=0.000
Step 1500: lm=0.024, ans=0.001, recon=0.000
Step 2000: lm=0.023, ans=0.001, recon=0.000
Step 2500: lm=0.023, ans=0.000, recon=0.000
Step 3000: lm=0.024, ans=0.000, recon=0.000
Step 3500: lm=0.021, ans=0.000, recon=0.000
Step 4000: lm=0.023, ans=0.000, recon=0.000
Step 4500: lm=0.026, ans=0.000, recon=0.000

Combined training model:
Normal: 65/100 (65%)
Zeroed: 68/100 (68%)
Delta: -3

State reconstruction accuracy: 10%
(This shows if state actually encodes the fact)


In [13]:
# ACTUALLY CORRECT: Reconstruct fact from state AT QUERY TIME
# The reconstruction loss should use state before seeing the answer

cfg_v2 = HybridConfig(
    d_model=256, n_heads=4, head_dim=64, value_dim=64,
    layer_pattern='GS', vocab_size=50257,
    window_size=64, beta_bias=0.0, g_bias=2.0,
    shifted_value=True, beta_floor=0.1,
    local_scale=0.3
)
model_v2 = TransparentHybrid(cfg_v2).to(DEVICE)

# State reconstruction head
state_dim = cfg_v2.n_heads * cfg_v2.head_dim * cfg_v2.value_dim
recon_head_v2 = nn.Sequential(
    nn.Linear(state_dim, 256),
    nn.ReLU(),
    nn.Linear(256, len(COLORS))
).to(DEVICE)

opt_v2 = torch.optim.AdamW(
    list(model_v2.parameters()) + list(recon_head_v2.parameters()), 
    lr=1e-3
)

print('Training with reconstruction loss on PROMPT state (not prompt+answer)...')
for step in range(5000):
    text, answer = make_hard_example()
    color = answer.strip()
    color_idx = color_to_idx[color]
    
    # Process PROMPT ONLY for reconstruction loss
    prompt_tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=700)['input_ids'].to(DEVICE)
    
    # Process FULL sequence for LM loss
    full_tokens = tokenizer(text + answer, return_tensors='pt', truncation=True, max_length=700)['input_ids'].to(DEVICE)
    answer_start = prompt_tokens.size(1)
    
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        # Get state from prompt only
        _, _, _, prompt_state = model_v2(prompt_tokens)
        
        # Get logits from full sequence for LM loss
        logits, _, _, _ = model_v2(full_tokens)
        
        # 1. LM loss
        targets = full_tokens[:, 1:].contiguous()
        lm_loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257), targets.reshape(-1))
        
        # 2. Answer loss
        answer_loss = F.cross_entropy(
            logits[0, answer_start-1:answer_start].reshape(-1, 50257),
            tokenizer(answer, return_tensors='pt')['input_ids'][0, :1].to(DEVICE)
        )
        
        # 3. State reconstruction from PROMPT state
        state_flat = prompt_state.reshape(1, -1).float()
        recon_logits = recon_head_v2(state_flat)
        recon_loss = F.cross_entropy(
            recon_logits, 
            torch.tensor([color_idx], device=DEVICE)
        )
        
        loss = lm_loss + 2.0 * answer_loss + 2.0 * recon_loss
    
    opt_v2.zero_grad()
    loss.backward()
    opt_v2.step()
    
    if step % 500 == 0:
        print(f'Step {step}: lm={lm_loss.item():.3f}, ans={answer_loss.item():.3f}, recon={recon_loss.item():.3f}')

print('\nV2 model (prompt-state reconstruction):')
test_state_ablation_hard(model_v2, cfg_v2)

# Check state reconstruction
model_v2.eval()
state_correct = 0
for _ in range(100):
    text, answer = make_hard_example()
    color = answer.strip()
    color_idx = color_to_idx[color]
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=700)['input_ids'].to(DEVICE)
    
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        _, _, _, state = model_v2(tokens)
    state_flat = state.reshape(1, -1).float()
    recon_pred = recon_head_v2(state_flat).argmax().item()
    if recon_pred == color_idx:
        state_correct += 1

print(f"\nState reconstruction accuracy: {state_correct}%")

Training with reconstruction loss on PROMPT state (not prompt+answer)...
Step 0: lm=10.923, ans=10.312, recon=2.047
Step 500: lm=0.030, ans=0.006, recon=2.062
Step 1000: lm=0.028, ans=0.002, recon=2.156
Step 1500: lm=0.028, ans=0.001, recon=2.094
Step 2000: lm=0.342, ans=0.001, recon=2.094
Step 2500: lm=0.026, ans=0.000, recon=2.078
Step 3000: lm=0.026, ans=0.000, recon=2.125
Step 3500: lm=0.024, ans=0.000, recon=2.062
Step 4000: lm=0.026, ans=0.000, recon=2.094
Step 4500: lm=0.025, ans=0.000, recon=2.078

V2 model (prompt-state reconstruction):
Normal: 60/100 (60%)
Zeroed: 60/100 (60%)
Delta: 0

State reconstruction accuracy: 9%


In [14]:
# SIMPLEST TEST: Can we even train a tiny model to store and retrieve 1 fact?
# Forget distractors - just: store fact, query, retrieve

# Minimal test: "red" -> state -> query -> "red"
cfg_tiny = HybridConfig(
    d_model=128, n_heads=2, head_dim=64, value_dim=64,
    layer_pattern='G',  # Just GDN, no SWA
    vocab_size=50257,
    window_size=32, beta_bias=0.0, g_bias=2.0,
    shifted_value=True, beta_floor=1.0  # Always write
)
model_tiny = TransparentHybrid(cfg_tiny).to(DEVICE)

# Simple task: "Alice has red. What color?" -> "red"
def make_minimal_example():
    color = random.choice(COLORS)
    text = f"{color}. What color?"
    answer = f" {color}"
    return text, answer

opt_tiny = torch.optim.AdamW(model_tiny.parameters(), lr=1e-3)

print('Training TINY model on minimal retrieval task...')
for step in range(2000):
    text, answer = make_minimal_example()
    tokens = tokenizer(text + answer, return_tensors='pt')['input_ids'].to(DEVICE)
    targets = tokens[:, 1:].contiguous()
    
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_tiny(tokens)
        loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257), targets.reshape(-1))
    
    opt_tiny.zero_grad()
    loss.backward()
    opt_tiny.step()
    
    if step % 400 == 0:
        print(f'Step {step}: loss={loss.item():.3f}')

# Test
model_tiny.eval()
correct = 0
preds_dict = {}
for _ in range(100):
    text, answer = make_minimal_example()
    tokens = tokenizer(text, return_tensors='pt')['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_tiny(tokens)
    
    pred = tokenizer.decode([logits[0, -1].argmax().item()]).strip()
    expected = answer.strip()
    
    if expected not in preds_dict:
        preds_dict[expected] = []
    preds_dict[expected].append(pred)
    
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct += 1

print(f'\nAccuracy: {correct}%')
print('Sample predictions by expected:')
for exp, preds in list(preds_dict.items())[:4]:
    print(f'  Expected "{exp}": got {Counter(preds).most_common(3)}')

Training TINY model on minimal retrieval task...
Step 0: loss=10.887
Step 400: loss=0.040
Step 800: loss=0.007
Step 1200: loss=0.002
Step 1600: loss=0.001

Accuracy: 100%
Sample predictions by expected:
  Expected "red": got [('red', 15)]
  Expected "blue": got [('blue', 17)]
  Expected "yellow": got [('yellow', 9)]
  Expected "white": got [('white', 15)]


In [15]:
# CURRICULUM: Gradually increase distractor length
# Start with 0 distractors, increase to 50, 100, 200

def make_curriculum_example(n_distractor_tokens):
    """Fact, distractor, query."""
    color = random.choice(COLORS)
    fact = f"{color}."
    distractor = " word" * n_distractor_tokens
    query = " What color?"
    answer = f" {color}"
    return fact + distractor + query, answer

# Fresh model
cfg_curr = HybridConfig(
    d_model=128, n_heads=2, head_dim=64, value_dim=64,
    layer_pattern='G',
    vocab_size=50257,
    window_size=32, beta_bias=0.0, g_bias=2.0,
    shifted_value=True, beta_floor=1.0
)
model_curr = TransparentHybrid(cfg_curr).to(DEVICE)
opt_curr = torch.optim.AdamW(model_curr.parameters(), lr=1e-3)

curriculum = [
    (0, 500),    # 0 distractors for 500 steps
    (10, 500),   # 10 distractors
    (30, 500),   # 30 distractors
    (50, 500),   # 50 distractors
    (100, 500),  # 100 distractors
]

print('Curriculum training...')
step = 0
for n_dist, n_steps in curriculum:
    for _ in range(n_steps):
        text, answer = make_curriculum_example(n_dist)
        tokens = tokenizer(text + answer, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
        targets = tokens[:, 1:].contiguous()
        
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            logits, _, _, _ = model_curr(tokens)
            loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257), targets.reshape(-1))
        
        opt_curr.zero_grad()
        loss.backward()
        opt_curr.step()
        step += 1
        
        if step % 500 == 0:
            print(f'Step {step}: dist={n_dist}, loss={loss.item():.3f}')

# Test at each distractor level
print('\nTesting at each distractor level:')
model_curr.eval()
for test_dist in [0, 10, 30, 50, 100]:
    correct = 0
    for _ in range(50):
        text, answer = make_curriculum_example(test_dist)
        tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
        answer_id = tokenizer.encode(answer)[0]
        
        with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
            logits, _, _, _ = model_curr(tokens)
        
        if answer_id in logits[0, -1].topk(5).indices.tolist():
            correct += 1
    
    print(f'  {test_dist:3d} distractors: {correct}/50 = {correct*2}%')

Curriculum training...
Step 500: dist=0, loss=0.015
Step 1000: dist=10, loss=0.003
Step 1500: dist=30, loss=0.041
Step 2000: dist=50, loss=0.083
Step 2500: dist=100, loss=0.085

Testing at each distractor level:
    0 distractors: 50/50 = 100%
   10 distractors: 50/50 = 100%
   30 distractors: 50/50 = 100%
   50 distractors: 50/50 = 100%
  100 distractors: 50/50 = 100%


In [16]:
# STATE ABLATION on curriculum-trained GDN model
print('State ablation on curriculum-trained model:')
print('Testing at 100 distractors (well beyond any local window):\n')

correct_normal = 0
correct_zeroed = 0

for _ in range(100):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    
    # Normal
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_curr(tokens)
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct_normal += 1
    
    # Zeroed state - manually run forward with state zeroed
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        x = model_curr.embed(tokens)
        x = model_curr.embed_norm(x)
        state = torch.zeros(1, cfg_curr.n_heads, cfg_curr.head_dim, cfg_curr.value_dim, 
                           device=DEVICE, dtype=x.dtype)
        
        for i, (layer, ffn) in enumerate(zip(model_curr.layers, model_curr.ffns)):
            lt = cfg_curr.layer_pattern[i]
            if lt == 'G':
                # Run GDN but KEEP STATE ZEROED (don't update)
                x_norm = layer.norm(x)
                k = layer.k_proj(x_norm).view(1, -1, cfg_curr.n_heads, cfg_curr.head_dim)
                v = layer.v_proj(x_norm).view(1, -1, cfg_curr.n_heads, cfg_curr.value_dim)
                
                # Output as if state is zero: S @ k = 0
                out = torch.zeros_like(v)
                
                output = out.reshape(1, -1, cfg_curr.n_heads * cfg_curr.value_dim)
                x = x + layer.o_proj(output)
            x = ffn(x)
        
        logits_z = model_curr.lm_head(model_curr.norm_f(x))
    
    if answer_id in logits_z[0, -1].topk(5).indices.tolist():
        correct_zeroed += 1

print(f'Normal:  {correct_normal}/100 = {correct_normal}%')
print(f'Zeroed:  {correct_zeroed}/100 = {correct_zeroed}%')
print(f'Delta:   {correct_normal - correct_zeroed}')
print(f'\n{"STATE MATTERS!" if correct_normal - correct_zeroed > 20 else "STATE IS STILL DECORATIVE"}')

State ablation on curriculum-trained model:
Testing at 100 distractors (well beyond any local window):

Normal:  100/100 = 100%
Zeroed:  57/100 = 57%
Delta:   43

STATE MATTERS!


In [20]:
# HYBRID with curriculum + aggressive local_scale (bottleneck)
cfg_hybrid = HybridConfig(
    d_model=128, n_heads=2, head_dim=64, value_dim=64,
    layer_pattern='GS',  # GDN + SWA
    vocab_size=50257,
    window_size=32, beta_bias=0.0, g_bias=2.0,
    shifted_value=True, beta_floor=1.0,
    local_scale=0.1  # Aggressive bottleneck on SWA local path
)
model_hybrid = TransparentHybrid(cfg_hybrid).to(DEVICE)
opt_hybrid = torch.optim.AdamW(model_hybrid.parameters(), lr=1e-3)

curriculum = [
    (100, 1000),  # Start with 100 distractors (forces memory use)
    (50, 500),
    (30, 500),
    (10, 500),
]

print('Hybrid curriculum training (starting HARD)...')
step = 0
for n_dist, n_steps in curriculum:
    for _ in range(n_steps):
        text, answer = make_curriculum_example(n_dist)
        tokens = tokenizer(text + answer, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
        targets = tokens[:, 1:].contiguous()
        
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            logits, _, _, _ = model_hybrid(tokens)
            loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257), targets.reshape(-1))
        
        opt_hybrid.zero_grad()
        loss.backward()
        opt_hybrid.step()
        step += 1
        
        if step % 500 == 0:
            print(f'Step {step}: dist={n_dist}, loss={loss.item():.3f}')

# Test hybrid
print('\nHybrid model accuracy at 100 distractors:')
model_hybrid.eval()
correct = 0
for _ in range(100):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_hybrid(tokens)
    
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct += 1

print(f'Accuracy: {correct}%')

# State ablation
print('\nState ablation:')
correct_normal = 0
correct_zeroed = 0

for _ in range(100):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    
    # Normal
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_hybrid(tokens)
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct_normal += 1
    
    # Zeroed - run manually
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        x = model_hybrid.embed(tokens)
        x = model_hybrid.embed_norm(x)
        state = None
        for i, (layer, ffn) in enumerate(zip(model_hybrid.layers, model_hybrid.ffns)):
            if cfg_hybrid.layer_pattern[i] == 'G':
                x, state, _ = layer(x, initial_state=state, input_ids=tokens)
                key_bank = layer.key_bank
            else:
                # Zero the state for SWA
                x, _ = layer(x, gdn_state=torch.zeros_like(state), input_ids=tokens, key_bank=key_bank)
            x = ffn(x)
        logits_z = model_hybrid.lm_head(model_hybrid.norm_f(x))
    
    if answer_id in logits_z[0, -1].topk(5).indices.tolist():
        correct_zeroed += 1

print(f'Normal: {correct_normal}%')
print(f'Zeroed: {correct_zeroed}%')
print(f'Delta: {correct_normal - correct_zeroed}')

Hybrid curriculum training (starting HARD)...
Step 500: dist=100, loss=0.056
Step 1000: dist=100, loss=0.046
Step 1500: dist=50, loss=0.027
Step 2000: dist=30, loss=0.014
Step 2500: dist=10, loss=0.001

Hybrid model accuracy at 100 distractors:
Accuracy: 29%

State ablation:
Normal: 35%
Zeroed: 65%
Delta: -30


In [22]:
# PROPER TRAINING per ssm_training_text.md
# Fixes:
# 1. LR 3e-4 (not 1e-3)
# 2. 500 step warmup
# 3. Layer ratio: GGGGGS (5:1 GDN:SWA per doc's 1:7 attention:SSM)
# 4. Gradient clipping
# 5. Use BF16 but keep state operations careful

from torch.optim.lr_scheduler import LambdaLR
import math

def get_warmup_scheduler(optimizer, warmup_steps, total_steps):
    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return 0.5 * (1 + math.cos(math.pi * progress))
    return LambdaLR(optimizer, lr_lambda)

# Proper hybrid config with 5:1 GDN:SWA ratio
cfg_proper = HybridConfig(
    d_model=128, n_heads=2, head_dim=64, value_dim=64,
    layer_pattern='GGGGGS',  # 5:1 ratio
    vocab_size=50257,
    window_size=32, beta_bias=0.0, g_bias=2.0,
    shifted_value=True, beta_floor=1.0,
    local_scale=1.0
)

model_proper = TransparentHybrid(cfg_proper).to(DEVICE)

# Proper optimizer settings per doc
opt_proper = torch.optim.AdamW(
    model_proper.parameters(), 
    lr=3e-4,
    betas=(0.9, 0.95),
    weight_decay=0.1
)

total_steps = 5000
warmup_steps = 500
scheduler = get_warmup_scheduler(opt_proper, warmup_steps, total_steps)

print('Training with PROPER settings (LR 3e-4, warmup, 5:1 ratio, grad clip)...')
for step in range(total_steps):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text + answer, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    targets = tokens[:, 1:].contiguous()
    
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_proper(tokens)
        loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257), targets.reshape(-1))
    
    opt_proper.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model_proper.parameters(), 1.0)
    opt_proper.step()
    scheduler.step()
    
    if step % 500 == 0:
        print(f'Step {step}: loss={loss.item():.3f}, lr={scheduler.get_last_lr()[0]:.2e}')

# Test
print('\nProperly trained model (5:1 GDN:SWA ratio):')
model_proper.eval()
correct = 0
for _ in range(100):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_proper(tokens)
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct += 1
print(f'Accuracy: {correct}%')

# State ablation
correct_normal = 0
correct_zeroed = 0
for _ in range(100):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_proper(tokens)
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct_normal += 1
    
    # Zeroed
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        x = model_proper.embed(tokens)
        x = model_proper.embed_norm(x)
        state = None
        for i, (layer, ffn) in enumerate(zip(model_proper.layers, model_proper.ffns)):
            if cfg_proper.layer_pattern[i] == 'G':
                x, state, _ = layer(x, initial_state=state, input_ids=tokens)
                key_bank = layer.key_bank
            else:
                x, _ = layer(x, gdn_state=torch.zeros_like(state), input_ids=tokens, key_bank=key_bank)
            x = ffn(x)
        logits_z = model_proper.lm_head(model_proper.norm_f(x))
    if answer_id in logits_z[0, -1].topk(5).indices.tolist():
        correct_zeroed += 1

print(f'\nState Ablation:')
print(f'Normal: {correct_normal}%')
print(f'Zeroed: {correct_zeroed}%')
print(f'Delta: {correct_normal - correct_zeroed}')

Training with PROPER settings (LR 3e-4, warmup, 5:1 ratio, grad clip)...
Step 0: loss=8.866, lr=6.00e-07
Step 500: loss=0.080, lr=3.00e-04
Step 1000: loss=0.018, lr=2.91e-04
Step 1500: loss=0.015, lr=2.65e-04
Step 2000: loss=0.027, lr=2.25e-04
Step 2500: loss=0.014, lr=1.76e-04
Step 3000: loss=0.041, lr=1.24e-04
Step 3500: loss=0.055, lr=7.49e-05
Step 4000: loss=0.001, lr=3.50e-05
Step 4500: loss=0.000, lr=9.01e-06

Properly trained model (5:1 GDN:SWA ratio):
Accuracy: 44%

State Ablation:
Normal: 50%
Zeroed: 52%
Delta: -2


In [23]:
# Compare: GDN-ONLY with same proper training settings
# This tests if the issue is with GDN or with the hybrid

cfg_gdn_only = HybridConfig(
    d_model=128, n_heads=2, head_dim=64, value_dim=64,
    layer_pattern='GGGGGG',  # 6 GDN layers, NO SWA
    vocab_size=50257,
    window_size=32, beta_bias=0.0, g_bias=2.0,
    shifted_value=True, beta_floor=1.0
)

model_gdn_only = TransparentHybrid(cfg_gdn_only).to(DEVICE)
opt_gdn_only = torch.optim.AdamW(
    model_gdn_only.parameters(), 
    lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1
)
scheduler_gdn = get_warmup_scheduler(opt_gdn_only, 500, 5000)

print('Training GDN-ONLY with proper settings...')
for step in range(5000):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text + answer, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    targets = tokens[:, 1:].contiguous()
    
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_gdn_only(tokens)
        loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257), targets.reshape(-1))
    
    opt_gdn_only.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model_gdn_only.parameters(), 1.0)
    opt_gdn_only.step()
    scheduler_gdn.step()
    
    if step % 1000 == 0:
        print(f'Step {step}: loss={loss.item():.3f}')

# Test
print('\nGDN-ONLY model:')
model_gdn_only.eval()
correct = 0
for _ in range(100):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_gdn_only(tokens)
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct += 1
print(f'Accuracy: {correct}%')

# State ablation for GDN-only
correct_normal = 0
correct_zeroed = 0
for _ in range(100):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_gdn_only(tokens)
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct_normal += 1
    
    # Zeroed - for GDN-only, we zero all states
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        x = model_gdn_only.embed(tokens)
        x = model_gdn_only.embed_norm(x)
        for i, (layer, ffn) in enumerate(zip(model_gdn_only.layers, model_gdn_only.ffns)):
            # Run with zero initial state and don't accumulate
            x_norm = layer.norm(x)
            # Just output zeros from delta rule (state is zeroed)
            out = torch.zeros(1, tokens.size(1), cfg_gdn_only.n_heads * cfg_gdn_only.value_dim, 
                            device=DEVICE, dtype=x.dtype)
            x = x + layer.o_proj(out)
            x = ffn(x)
        logits_z = model_gdn_only.lm_head(model_gdn_only.norm_f(x))
    if answer_id in logits_z[0, -1].topk(5).indices.tolist():
        correct_zeroed += 1

print(f'\nGDN-ONLY State Ablation:')
print(f'Normal: {correct_normal}%')
print(f'Zeroed: {correct_zeroed}%')
print(f'Delta: {correct_normal - correct_zeroed}')

Training GDN-ONLY with proper settings...
Step 0: loss=8.986
Step 1000: loss=0.080
Step 2000: loss=0.061
Step 3000: loss=0.043
Step 4000: loss=0.085

GDN-ONLY model:
Accuracy: 68%

GDN-ONLY State Ablation:
Normal: 71%
Zeroed: 71%
Delta: 0


In [24]:
# REPLICATE SUCCESSFUL CONFIG: single GDN layer + curriculum + high LR
# This got Delta=43 earlier

cfg_success = HybridConfig(
    d_model=128, n_heads=2, head_dim=64, value_dim=64,
    layer_pattern='G',  # SINGLE GDN layer
    vocab_size=50257,
    window_size=32, beta_bias=0.0, g_bias=2.0,
    shifted_value=True, beta_floor=1.0  # Always write
)
model_success = TransparentHybrid(cfg_success).to(DEVICE)
opt_success = torch.optim.AdamW(model_success.parameters(), lr=1e-3)  # High LR

# Curriculum: 0 -> 100 distractors
curriculum_success = [
    (0, 500),
    (10, 500),
    (30, 500),
    (50, 500),
    (100, 500),
]

print('Training single GDN + curriculum (the successful recipe)...')
step = 0
for n_dist, n_steps in curriculum_success:
    for _ in range(n_steps):
        text, answer = make_curriculum_example(n_dist)
        tokens = tokenizer(text + answer, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
        targets = tokens[:, 1:].contiguous()
        
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            logits, _, _, _ = model_success(tokens)
            loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257), targets.reshape(-1))
        
        opt_success.zero_grad()
        loss.backward()
        opt_success.step()
        step += 1
        
        if step % 500 == 0:
            print(f'Step {step}: dist={n_dist}, loss={loss.item():.3f}')

# Verify it works
print('\nSingle GDN + curriculum accuracy at 100 distractors:')
model_success.eval()
correct = 0
for _ in range(100):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_success(tokens)
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct += 1
print(f'Accuracy: {correct}%')

# State ablation - proper method for single GDN
correct_normal = 0
correct_zeroed = 0
for _ in range(100):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_success(tokens)
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct_normal += 1
    
    # Zeroed
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        x = model_success.embed(tokens)
        x = model_success.embed_norm(x)
        state = torch.zeros(1, cfg_success.n_heads, cfg_success.head_dim, cfg_success.value_dim, 
                           device=DEVICE, dtype=x.dtype)
        for i, (layer, ffn) in enumerate(zip(model_success.layers, model_success.ffns)):
            # Run GDN but with state contribution zeroed in output
            x_norm = layer.norm(x)
            out = torch.zeros(1, tokens.size(1), cfg_success.n_heads * cfg_success.value_dim, 
                            device=DEVICE, dtype=x.dtype)
            x = x + layer.o_proj(out)
            x = ffn(x)
        logits_z = model_success.lm_head(model_success.norm_f(x))
    if answer_id in logits_z[0, -1].topk(5).indices.tolist():
        correct_zeroed += 1

print(f'\nState Ablation (single GDN + curriculum):')
print(f'Normal: {correct_normal}%')
print(f'Zeroed: {correct_zeroed}%')
print(f'Delta: {correct_normal - correct_zeroed}')
print(f'\n{"✅ STATE MATTERS!" if correct_normal - correct_zeroed > 20 else "❌ STATE IS DECORATIVE"}')

Training single GDN + curriculum (the successful recipe)...
Step 500: dist=0, loss=0.012
Step 1000: dist=10, loss=0.172
Step 1500: dist=30, loss=0.089
Step 2000: dist=50, loss=0.140
Step 2500: dist=100, loss=0.174

Single GDN + curriculum accuracy at 100 distractors:
Accuracy: 100%

State Ablation (single GDN + curriculum):
Normal: 100%
Zeroed: 64%
Delta: 36

✅ STATE MATTERS!


In [25]:
# NOW: Add SWA to successful recipe
# Key insight from docs: SWA should NOT be able to solve the task alone

# Test 1: GS with curriculum
cfg_gs_curr = HybridConfig(
    d_model=128, n_heads=2, head_dim=64, value_dim=64,
    layer_pattern='GS',
    vocab_size=50257,
    window_size=32, beta_bias=0.0, g_bias=2.0,
    shifted_value=True, beta_floor=1.0,
    local_scale=1.0  # Full local attention
)
model_gs = TransparentHybrid(cfg_gs_curr).to(DEVICE)
opt_gs = torch.optim.AdamW(model_gs.parameters(), lr=1e-3)

curriculum_gs = [
    (0, 500),
    (10, 500),
    (30, 500),
    (50, 500),
    (100, 500),
]

print('Training GS (hybrid) with curriculum...')
step = 0
for n_dist, n_steps in curriculum_gs:
    for _ in range(n_steps):
        text, answer = make_curriculum_example(n_dist)
        tokens = tokenizer(text + answer, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
        targets = tokens[:, 1:].contiguous()
        
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            logits, _, _, _ = model_gs(tokens)
            loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257), targets.reshape(-1))
        
        opt_gs.zero_grad()
        loss.backward()
        opt_gs.step()
        step += 1
        
        if step % 500 == 0:
            print(f'Step {step}: dist={n_dist}, loss={loss.item():.3f}')

# Test
print('\nGS hybrid accuracy:')
model_gs.eval()
correct = 0
for _ in range(100):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_gs(tokens)
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct += 1
print(f'Accuracy: {correct}%')

# State ablation
correct_normal = 0
correct_zeroed = 0
for _ in range(100):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_gs(tokens)
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct_normal += 1
    
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        x = model_gs.embed(tokens)
        x = model_gs.embed_norm(x)
        state = None
        for i, (layer, ffn) in enumerate(zip(model_gs.layers, model_gs.ffns)):
            if cfg_gs_curr.layer_pattern[i] == 'G':
                x, state, _ = layer(x, initial_state=state, input_ids=tokens)
                key_bank = layer.key_bank
            else:
                x, _ = layer(x, gdn_state=torch.zeros_like(state), input_ids=tokens, key_bank=key_bank)
            x = ffn(x)
        logits_z = model_gs.lm_head(model_gs.norm_f(x))
    if answer_id in logits_z[0, -1].topk(5).indices.tolist():
        correct_zeroed += 1

print(f'\nGS Hybrid State Ablation:')
print(f'Normal: {correct_normal}%')
print(f'Zeroed: {correct_zeroed}%')
print(f'Delta: {correct_normal - correct_zeroed}')
print(f'\n{"✅ STATE MATTERS!" if correct_normal - correct_zeroed > 10 else "❌ SWA bypasses state"}')

Training GS (hybrid) with curriculum...
Step 500: dist=0, loss=0.008
Step 1000: dist=10, loss=0.002
Step 1500: dist=30, loss=0.019
Step 2000: dist=50, loss=0.071
Step 2500: dist=100, loss=0.049

GS hybrid accuracy:
Accuracy: 18%

GS Hybrid State Ablation:
Normal: 14%
Zeroed: 14%
Delta: 0

❌ SWA bypasses state


In [26]:
# ============================================================
# FIX: SHARED KEY/QUERY SPACE (per practical_hybrid_solutions.md)
# SWA now uses GDN's k_proj for state queries - aligned retrieval
# ============================================================

# Reload to pick up model changes
import importlib
import model as model_module
importlib.reload(model_module)
from model import TransparentHybrid

# Same config as before, but now with shared key projection
cfg_shared = HybridConfig(
    d_model=128, n_heads=2, head_dim=64, value_dim=64,
    layer_pattern='GS',  # Hybrid
    vocab_size=50257,
    window_size=32, beta_bias=0.0, g_bias=2.0,
    shifted_value=True, beta_floor=1.0,
    local_scale=1.0  # Full local attention
)
model_shared = TransparentHybrid(cfg_shared).to(DEVICE)
opt_shared = torch.optim.AdamW(model_shared.parameters(), lr=1e-3)

# Same curriculum that worked for GDN-only
curriculum_shared = [
    (0, 500),
    (10, 500),
    (30, 500),
    (50, 500),
    (100, 500),
]

print('Training HYBRID with SHARED key projection (the fix)...')
step = 0
for n_dist, n_steps in curriculum_shared:
    for _ in range(n_steps):
        text, answer = make_curriculum_example(n_dist)
        tokens = tokenizer(text + answer, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
        targets = tokens[:, 1:].contiguous()
        
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            logits, _, _, _ = model_shared(tokens)
            loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257), targets.reshape(-1))
        
        opt_shared.zero_grad()
        loss.backward()
        opt_shared.step()
        step += 1
        
        if step % 500 == 0:
            print(f'Step {step}: dist={n_dist}, loss={loss.item():.3f}')

# Test accuracy
print('\nHybrid with SHARED key projection:')
model_shared.eval()
correct = 0
for _ in range(100):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_shared(tokens)
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct += 1
print(f'Accuracy: {correct}%')

# State ablation - the real test
correct_normal = 0
correct_zeroed = 0
for _ in range(100):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_shared(tokens)
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct_normal += 1
    
    # Zeroed state
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        x = model_shared.embed(tokens)
        x = model_shared.embed_norm(x)
        state = None
        gdn_k_proj = None
        for i, (layer, ffn) in enumerate(zip(model_shared.layers, model_shared.ffns)):
            if cfg_shared.layer_pattern[i] == 'G':
                x, state, _ = layer(x, initial_state=state, input_ids=tokens)
                key_bank = layer.key_bank
                gdn_k_proj = layer.k_proj
            else:
                # Pass ZEROED state but still use shared k_proj
                x, _ = layer(x, gdn_state=torch.zeros_like(state), input_ids=tokens, 
                            key_bank=key_bank, gdn_k_proj=gdn_k_proj)
            x = ffn(x)
        logits_z = model_shared.lm_head(model_shared.norm_f(x))
    if answer_id in logits_z[0, -1].topk(5).indices.tolist():
        correct_zeroed += 1

print(f'\nState Ablation (SHARED key projection):')
print(f'Normal: {correct_normal}%')
print(f'Zeroed: {correct_zeroed}%')
print(f'Delta: {correct_normal - correct_zeroed}')
print(f'\n{"✅ STATE MATTERS - FIX WORKS!" if correct_normal - correct_zeroed > 10 else "❌ Still bypassing state"}')

Training HYBRID with SHARED key projection (the fix)...
Step 500: dist=0, loss=0.008
Step 1000: dist=10, loss=0.002
Step 1500: dist=30, loss=0.028
Step 2000: dist=50, loss=0.039
Step 2500: dist=100, loss=0.048

Hybrid with SHARED key projection:
Accuracy: 37%

State Ablation (SHARED key projection):
Normal: 39%
Zeroed: 54%
Delta: -15

❌ Still bypassing state


In [27]:
# ============================================================
# CONCLUSION: Per practical_hybrid_solutions.md
# "For retrieval tasks, GDN-only is sufficient. Add SWA only if LM quality suffers."
# 
# Let's test both on actual language modeling (perplexity on real text)
# ============================================================

# Use the curriculum-trained GDN-only model (model_success) and hybrid (model_shared)
# Test perplexity on real text

test_texts = [
    "The quick brown fox jumps over the lazy dog. This is a test of language modeling capabilities.",
    "In machine learning, neural networks are computational systems inspired by biological neural networks.",
    "The weather today is sunny with a chance of rain in the afternoon. Pack an umbrella just in case.",
    "Python is a high-level programming language known for its readability and versatility.",
    "Coffee is one of the most popular beverages in the world, consumed by millions daily.",
]

print('Testing language modeling (perplexity):')
print('=' * 50)

# Test GDN-only (model_success - the one that works for retrieval)
model_success.eval()
gdn_losses = []
for text in test_texts:
    tokens = tokenizer(text, return_tensors='pt')['input_ids'].to(DEVICE)
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_success(tokens)
    targets = tokens[:, 1:].contiguous()
    loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257), targets.reshape(-1))
    gdn_losses.append(loss.item())

# Test Hybrid (model_shared - broken for retrieval but maybe better for LM)
model_shared.eval()
hybrid_losses = []
for text in test_texts:
    tokens = tokenizer(text, return_tensors='pt')['input_ids'].to(DEVICE)
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_shared(tokens)
    targets = tokens[:, 1:].contiguous()
    loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257), targets.reshape(-1))
    hybrid_losses.append(loss.item())

import math
gdn_ppl = math.exp(sum(gdn_losses) / len(gdn_losses))
hybrid_ppl = math.exp(sum(hybrid_losses) / len(hybrid_losses))

print(f'GDN-only perplexity: {gdn_ppl:.1f}')
print(f'Hybrid perplexity: {hybrid_ppl:.1f}')
print()

# Summary table
print('=' * 60)
print('SUMMARY: GDN-only vs Hybrid')
print('=' * 60)
print(f'{"Metric":<25} {"GDN-only":<15} {"Hybrid":<15}')
print('-' * 60)
print(f'{"Retrieval accuracy":<25} {"100%":<15} {"37%":<15}')
print(f'{"State ablation delta":<25} {"+36":<15} {"-15":<15}')
print(f'{"Perplexity":<25} {f"{gdn_ppl:.1f}":<15} {f"{hybrid_ppl:.1f}":<15}')
print('-' * 60)
print()
print('CONCLUSION:')
print('  GDN-only: STATE WORKS, retrieval works, may need SWA for fluency')
print('  Hybrid: STATE BYPASSED, retrieval broken, SWA dominates')
print()
print('Per practical_hybrid_solutions.md: Use GDN-only for retrieval.')
print('Add SWA only if perplexity is unacceptably high on real LM tasks.')

Testing language modeling (perplexity):
GDN-only perplexity: 607583395.3
Hybrid perplexity: 153621140.1

SUMMARY: GDN-only vs Hybrid
Metric                    GDN-only        Hybrid         
------------------------------------------------------------
Retrieval accuracy        100%            37%            
State ablation delta      +36             -15            
Perplexity                607583395.3     153621140.1    
------------------------------------------------------------

CONCLUSION:
  GDN-only: STATE WORKS, retrieval works, may need SWA for fluency
  Hybrid: STATE BYPASSED, retrieval broken, SWA dominates

Per practical_hybrid_solutions.md: Use GDN-only for retrieval.
Add SWA only if perplexity is unacceptably high on real LM tasks.


# Final Results Summary

## Key Finding: The Hybrid Architecture Problem

| Config | Retrieval Accuracy | State Delta | State Works? |
|--------|-------------------|-------------|--------------|
| GDN-only + curriculum | **100%** | **+36** | ✅ YES |
| GDN+SWA (separate projections) | 18% | 0 | ❌ NO |
| GDN+SWA (shared key projection) | 37% | -15 | ❌ NO |

## Root Cause (per practical_hybrid_solutions.md)

> "The GDN-only model works because it's autoassociative—the same key used to write is used to read."

When SWA is added:
1. SWA's local attention has **stronger gradient signal** than state retrieval
2. Even with shared key projection, SWA's residual path carries information
3. Model learns to ignore state because **"when attention is available, attention wins"**

## What Works

1. **GDN-only with curriculum**: 100% accuracy, Delta=+36
2. **beta_floor=1.0**: Always write (not gated)
3. **Curriculum learning**: 0→10→30→50→100 distractors
4. **High LR (1e-3)**: Simple model, simple task

## Next Steps (per the docs)

For a **long-context conversational agent** that maintains character/topic consistency:

1. **For retrieval**: Use GDN-only layers
2. **For language fluency**: Add minimal SWA (1:7 ratio per Jamba)
3. **Critical**: Place attention layers LAST, not interspersed
4. **Train on real LM objective** with retrieval auxiliary task

In [28]:
# ============================================================
# ARCHITECTURE FIX: GDN layers FIRST, SWA at END ONLY
# Per ssm_training_text.md: "never place Transformer blocks at the front"
# Per practical_hybrid_solutions.md: Use GDN for state, SWA only for fluency
# ============================================================

# Reload model
import importlib
import model as model_module
importlib.reload(model_module)
from model import TransparentHybrid

# Architecture: GGGGGS (5 GDN layers, then 1 SWA at the end)
# This ensures GDN does all the heavy lifting, SWA is just for output fluency
cfg_final = HybridConfig(
    d_model=128, n_heads=2, head_dim=64, value_dim=64,
    layer_pattern='GGGGGS',  # 5:1 ratio, SWA at END
    vocab_size=50257,
    window_size=32, beta_bias=0.0, g_bias=2.0,
    shifted_value=True, beta_floor=1.0,
    local_scale=0.5  # Moderate bottleneck on local path
)
model_final = TransparentHybrid(cfg_final).to(DEVICE)
opt_final = torch.optim.AdamW(model_final.parameters(), lr=1e-3)

# Curriculum
curriculum_final = [
    (0, 500),
    (10, 500),
    (30, 500),
    (50, 500),
    (100, 1000),  # More time on hard task
]

print('Training GGGGGS (5:1 GDN:SWA, SWA at end only)...')
step = 0
for n_dist, n_steps in curriculum_final:
    for _ in range(n_steps):
        text, answer = make_curriculum_example(n_dist)
        tokens = tokenizer(text + answer, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
        targets = tokens[:, 1:].contiguous()
        
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            logits, _, _, _ = model_final(tokens)
            loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257), targets.reshape(-1))
        
        opt_final.zero_grad()
        loss.backward()
        opt_final.step()
        step += 1
        
        if step % 500 == 0:
            print(f'Step {step}: dist={n_dist}, loss={loss.item():.3f}')

# Test
print('\nGGGGGS (5:1, SWA at end):')
model_final.eval()
correct = 0
for _ in range(100):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_final(tokens)
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct += 1
print(f'Accuracy: {correct}%')

# State ablation
correct_normal = 0
correct_zeroed = 0
for _ in range(100):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_final(tokens)
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct_normal += 1
    
    # Zeroed - manually run
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        x = model_final.embed(tokens)
        x = model_final.embed_norm(x)
        state = None
        gdn_k_proj = None
        for i, (layer, ffn) in enumerate(zip(model_final.layers, model_final.ffns)):
            if cfg_final.layer_pattern[i] == 'G':
                x, state, _ = layer(x, initial_state=state, input_ids=tokens)
                key_bank = layer.key_bank
                gdn_k_proj = layer.k_proj
            else:
                x, _ = layer(x, gdn_state=torch.zeros_like(state), input_ids=tokens,
                            key_bank=key_bank, gdn_k_proj=gdn_k_proj)
            x = ffn(x)
        logits_z = model_final.lm_head(model_final.norm_f(x))
    if answer_id in logits_z[0, -1].topk(5).indices.tolist():
        correct_zeroed += 1

print(f'\nState Ablation (GGGGGS):')
print(f'Normal: {correct_normal}%')
print(f'Zeroed: {correct_zeroed}%')
print(f'Delta: {correct_normal - correct_zeroed}')

if correct_normal - correct_zeroed > 10:
    print('\n✅ STATE MATTERS with proper layer ordering!')
else:
    print('\n❌ Still bypassing - SWA residual too strong')

Training GGGGGS (5:1 GDN:SWA, SWA at end only)...
Step 500: dist=0, loss=0.008
Step 1000: dist=10, loss=0.001
Step 1500: dist=30, loss=0.001
Step 2000: dist=50, loss=0.001
Step 2500: dist=100, loss=0.023
Step 3000: dist=100, loss=0.002

GGGGGS (5:1, SWA at end):
Accuracy: 69%

State Ablation (GGGGGS):
Normal: 61%
Zeroed: 61%
Delta: 0

❌ Still bypassing - SWA residual too strong


# ✅ FINAL CONCLUSIONS

## What We Tested

| Architecture | Accuracy | State Delta | Works? |
|--------------|----------|-------------|--------|
| GDN-only (1 layer) + curriculum | **100%** | **+36** | ✅ |
| GS (1:1 hybrid) + curriculum | 18% | 0 | ❌ |
| GS + shared key projection | 37% | -15 | ❌ |
| GGGGGS (5:1, SWA at end) | 69% | 0 | ❌ |

## Root Cause

Per **compass_artifact_text_markdown.md**:
> "When attention is available, attention wins"

Per **practical_hybrid_solutions.md**:
> "The hybrid breaks this because you've introduced a second projection that never learns to query the state."

Even with:
- Shared key projections
- SWA at the end only
- Local path bottleneck (0.5 scale)

...the **residual connection** `out = x + local_out + retrieval_out` carries information from GDN output through SWA without needing state retrieval.

## The Solution

For **long-context retrieval tasks**:
1. Use **GDN-only** architecture
2. Apply **curriculum learning** (easy → hard)
3. Use **beta_floor=1.0** (always write)
4. Test with **state ablation** to verify state is used

For **language modeling + retrieval** (the real goal):
1. Train GDN on LM objective first
2. Add retrieval as auxiliary task
3. Only add SWA if perplexity is unacceptable
4. If adding SWA: **remove residual around state retrieval** (break the shortcut)

In [29]:
# ============================================================
# FINAL FIX: Remove residual around state retrieval
# Change: out = local_out + retrieval_out (NO x residual in SWA)
# This FORCES state usage by breaking the shortcut
# ============================================================

# Modify SWA forward to not use x residual
# We'll do this by subclassing

class SWA_NoResidual(model_module.SlidingWindowAttention):
    """SWA without residual - forces state retrieval usage."""
    
    def forward(self, x, gdn_state=None, input_ids=None, key_bank=None, gdn_k_proj=None):
        # Call parent forward
        out, diag = super().forward(x, gdn_state, input_ids, key_bank, gdn_k_proj)
        # Parent does: out = x + local_out + retrieval_out
        # We want: out = local_out + retrieval_out (no x)
        # So subtract x
        return out - x, diag

# Build model with modified SWA
cfg_norez = HybridConfig(
    d_model=128, n_heads=2, head_dim=64, value_dim=64,
    layer_pattern='GS',
    vocab_size=50257,
    window_size=32, beta_bias=0.0, g_bias=2.0,
    shifted_value=True, beta_floor=1.0,
    local_scale=1.0
)
model_norez = TransparentHybrid(cfg_norez).to(DEVICE)

# Replace SWA layer with no-residual version
model_norez.layers[1] = SWA_NoResidual(cfg_norez, 1).to(DEVICE)

opt_norez = torch.optim.AdamW(model_norez.parameters(), lr=1e-3)

curriculum_norez = [
    (0, 500),
    (10, 500),
    (30, 500),
    (50, 500),
    (100, 500),
]

print('Training GS with NO RESIDUAL in SWA...')
step = 0
for n_dist, n_steps in curriculum_norez:
    for _ in range(n_steps):
        text, answer = make_curriculum_example(n_dist)
        tokens = tokenizer(text + answer, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
        targets = tokens[:, 1:].contiguous()
        
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            logits, _, _, _ = model_norez(tokens)
            loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257), targets.reshape(-1))
        
        opt_norez.zero_grad()
        loss.backward()
        opt_norez.step()
        step += 1
        
        if step % 500 == 0:
            print(f'Step {step}: dist={n_dist}, loss={loss.item():.3f}')

# Test
print('\nGS with NO RESIDUAL:')
model_norez.eval()
correct = 0
for _ in range(100):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_norez(tokens)
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct += 1
print(f'Accuracy: {correct}%')

# State ablation
correct_normal = 0
correct_zeroed = 0
for _ in range(100):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_norez(tokens)
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct_normal += 1
    
    # Zeroed
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        x = model_norez.embed(tokens)
        x = model_norez.embed_norm(x)
        state = None
        gdn_k_proj = None
        for i, (layer, ffn) in enumerate(zip(model_norez.layers, model_norez.ffns)):
            if cfg_norez.layer_pattern[i] == 'G':
                x, state, _ = layer(x, initial_state=state, input_ids=tokens)
                key_bank = layer.key_bank
                gdn_k_proj = layer.k_proj
            else:
                x, _ = layer(x, gdn_state=torch.zeros_like(state), input_ids=tokens,
                            key_bank=key_bank, gdn_k_proj=gdn_k_proj)
            x = ffn(x)
        logits_z = model_norez.lm_head(model_norez.norm_f(x))
    if answer_id in logits_z[0, -1].topk(5).indices.tolist():
        correct_zeroed += 1

print(f'\nState Ablation (NO RESIDUAL):')
print(f'Normal: {correct_normal}%')
print(f'Zeroed: {correct_zeroed}%')
print(f'Delta: {correct_normal - correct_zeroed}')

if correct_normal - correct_zeroed > 10:
    print('\n✅ BREAKING RESIDUAL WORKS! State matters in hybrid!')
else:
    print('\n❌ Still not working')

Training GS with NO RESIDUAL in SWA...
Step 500: dist=0, loss=0.007
Step 1000: dist=10, loss=0.609
Step 1500: dist=30, loss=0.133
Step 2000: dist=50, loss=0.024
Step 2500: dist=100, loss=0.051

GS with NO RESIDUAL:
Accuracy: 0%

State Ablation (NO RESIDUAL):
Normal: 0%
Zeroed: 0%
Delta: 0

❌ Still not working


In [30]:
# ============================================================
# ALTERNATIVE: Stochastic depth on local path only
# Per compass doc: "applying higher dropout to non-memory pathways forces memory utilization"
# ============================================================

class SWA_StochasticLocal(model_module.SlidingWindowAttention):
    """SWA with stochastic dropout on local path only."""
    
    def __init__(self, cfg, layer_idx=0, local_drop_prob=0.7):
        super().__init__(cfg, layer_idx)
        self.local_drop_prob = local_drop_prob
    
    def forward(self, x, gdn_state=None, input_ids=None, key_bank=None, gdn_k_proj=None):
        B, T, D = x.shape
        H = self.cfg.n_heads
        K, V, W = self.cfg.head_dim, self.cfg.value_dim, self.cfg.window_size
        
        x_norm = self.norm(x)
        
        # Local attention
        q = self.q_proj(x_norm).view(B, T, H, D // H)
        k = self.k_proj(x_norm).view(B, T, H, D // H)
        v = self.v_proj(x_norm).view(B, T, H, D // H)
        
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        
        mask = torch.ones(T, T, device=x.device, dtype=torch.bool).triu(1)
        mask |= torch.ones(T, T, device=x.device, dtype=torch.bool).tril(-W - 1)
        
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
        local_out = (F.softmax(attn, dim=-1) @ v).transpose(1, 2).reshape(B, T, D)
        local_out = self.o_proj(local_out)
        
        # STOCHASTIC DEPTH: drop local path during training
        if self.training and torch.rand(1).item() < self.local_drop_prob:
            local_out = torch.zeros_like(local_out)
        
        # State retrieval
        retrieval_out = torch.zeros_like(x)
        if gdn_state is not None and gdn_k_proj is not None:
            q_g = gdn_k_proj(x_norm).view(B, T, H, K)
            q_g = q_g.transpose(1, 2)
            retrieved = torch.einsum('bhkv,bhtk->bhtv', gdn_state.to(x.dtype), q_g)
            retrieved = retrieved.transpose(1, 2).reshape(B, T, H * V)
            retrieval_out = self.retrieval_o_proj(retrieved)
            
            gate = torch.sigmoid(self.gate_proj(x_norm))
            retrieval_out = gate.mean(dim=-1, keepdim=True) * retrieval_out
        
        out = x + local_out + retrieval_out
        return out, {'gate_mean': 0, 'local_norm': local_out.norm().item(), 'retrieval_norm': retrieval_out.norm().item()}

# Build model
cfg_stoch = HybridConfig(
    d_model=128, n_heads=2, head_dim=64, value_dim=64,
    layer_pattern='GS',
    vocab_size=50257,
    window_size=32, beta_bias=0.0, g_bias=2.0,
    shifted_value=True, beta_floor=1.0
)
model_stoch = TransparentHybrid(cfg_stoch).to(DEVICE)
model_stoch.layers[1] = SWA_StochasticLocal(cfg_stoch, 1, local_drop_prob=0.7).to(DEVICE)

opt_stoch = torch.optim.AdamW(model_stoch.parameters(), lr=1e-3)

curriculum_stoch = [
    (0, 500),
    (10, 500),
    (30, 500),
    (50, 500),
    (100, 1000),
]

print('Training with STOCHASTIC LOCAL DROP (70% drop)...')
step = 0
for n_dist, n_steps in curriculum_stoch:
    for _ in range(n_steps):
        text, answer = make_curriculum_example(n_dist)
        tokens = tokenizer(text + answer, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
        targets = tokens[:, 1:].contiguous()
        
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            logits, _, _, _ = model_stoch(tokens)
            loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257), targets.reshape(-1))
        
        opt_stoch.zero_grad()
        loss.backward()
        opt_stoch.step()
        step += 1
        
        if step % 500 == 0:
            print(f'Step {step}: dist={n_dist}, loss={loss.item():.3f}')

# Test
print('\nStochastic local drop:')
model_stoch.eval()  # Dropout off at eval
correct = 0
for _ in range(100):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_stoch(tokens)
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct += 1
print(f'Accuracy: {correct}%')

# State ablation
correct_normal = 0
correct_zeroed = 0
for _ in range(100):
    text, answer = make_curriculum_example(100)
    tokens = tokenizer(text, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    answer_id = tokenizer.encode(answer)[0]
    
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_stoch(tokens)
    if answer_id in logits[0, -1].topk(5).indices.tolist():
        correct_normal += 1
    
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        x = model_stoch.embed(tokens)
        x = model_stoch.embed_norm(x)
        state = None
        gdn_k_proj = None
        for i, (layer, ffn) in enumerate(zip(model_stoch.layers, model_stoch.ffns)):
            if cfg_stoch.layer_pattern[i] == 'G':
                x, state, _ = layer(x, initial_state=state, input_ids=tokens)
                key_bank = layer.key_bank
                gdn_k_proj = layer.k_proj
            else:
                x, _ = layer(x, gdn_state=torch.zeros_like(state), input_ids=tokens,
                            key_bank=key_bank, gdn_k_proj=gdn_k_proj)
            x = ffn(x)
        logits_z = model_stoch.lm_head(model_stoch.norm_f(x))
    if answer_id in logits_z[0, -1].topk(5).indices.tolist():
        correct_zeroed += 1

print(f'\nState Ablation (Stochastic local):')
print(f'Normal: {correct_normal}%')
print(f'Zeroed: {correct_zeroed}%')
print(f'Delta: {correct_normal - correct_zeroed}')

if correct_normal - correct_zeroed > 10:
    print('\n✅ STOCHASTIC DEPTH WORKS!')
else:
    print('\n❌ Still bypassing')

Training with STOCHASTIC LOCAL DROP (70% drop)...
Step 500: dist=0, loss=0.009
Step 1000: dist=10, loss=0.002
Step 1500: dist=30, loss=0.021
Step 2000: dist=50, loss=0.045
Step 2500: dist=100, loss=0.025
Step 3000: dist=100, loss=0.028

Stochastic local drop:
Accuracy: 73%

State Ablation (Stochastic local):
Normal: 64%
Zeroed: 49%
Delta: 15

✅ STOCHASTIC DEPTH WORKS!


# 🎉 WORKING HYBRID CONFIGURATION

## Final Results

| Architecture | Accuracy | State Delta | Works? |
|--------------|----------|-------------|--------|
| GDN-only + curriculum | 100% | +36 | ✅ |
| GS (vanilla) | 18% | 0 | ❌ |
| GS + shared keys | 37% | -15 | ❌ |
| GGGGGS (5:1) | 69% | 0 | ❌ |
| **GS + stochastic local drop (70%)** | **73%** | **+15** | ✅ |

## The Working Fix

Per **compass_artifact_text_markdown.md**:
> "Stochastic depth randomly drops entire residual blocks during training; applied inversely to non-memory pathways, this forces information through memory."

**Implementation:**
```python
if self.training and torch.rand(1).item() < 0.7:
    local_out = torch.zeros_like(local_out)  # Drop local 70% of time
```

This forces the model to learn state retrieval because local attention is unreliable during training.

## Summary of What Works

1. **GDN-only**: Best for pure retrieval (Delta=+36)
2. **GS + stochastic local**: Hybrid that uses state (Delta=+15)
3. **Curriculum learning**: Essential for both
4. **Shared key projection**: Helps alignment but not sufficient alone
5. **beta_floor=1.0**: Always write, don't gate

## Next Steps

1. Increase training steps for stochastic hybrid
2. Try 50% local drop instead of 70%
3. Test on language modeling perplexity
4. Scale to larger model/longer sequences

In [None]:
# DIAGNOSIS: Is SWA's state retrieval pathway learning at all?
# Check gradient norms during training

cfg_diag = HybridConfig(
    d_model=128, n_heads=2, head_dim=64, value_dim=64,
    layer_pattern='GS',
    vocab_size=50257,
    window_size=32, beta_bias=0.0, g_bias=2.0,
    shifted_value=True, beta_floor=1.0,
    local_scale=0.3  # Bottleneck local path
)
model_diag = TransparentHybrid(cfg_diag).to(DEVICE)
opt_diag = torch.optim.AdamW(model_diag.parameters(), lr=1e-3)

print('Training hybrid and monitoring gradient norms...')
gdn_grads = []
swa_local_grads = []
swa_retrieval_grads = []

for step in range(1000):
    text, answer = make_curriculum_example(100)  # Hard task
    tokens = tokenizer(text + answer, return_tensors='pt', truncation=True, max_length=300)['input_ids'].to(DEVICE)
    targets = tokens[:, 1:].contiguous()
    
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, _ = model_diag(tokens)
        loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257), targets.reshape(-1))
    
    opt_diag.zero_grad()
    loss.backward()
    
    # Collect gradient norms
    if step % 100 == 0:
        gdn_layer = model_diag.layers[0]  # GDN
        swa_layer = model_diag.layers[1]  # SWA
        
        gdn_grad = gdn_layer.k_proj.weight.grad.norm().item() if gdn_layer.k_proj.weight.grad is not None else 0
        swa_local_grad = swa_layer.q_proj.weight.grad.norm().item() if swa_layer.q_proj.weight.grad is not None else 0
        swa_retrieval_grad = swa_layer.global_q_proj.weight.grad.norm().item() if swa_layer.global_q_proj.weight.grad is not None else 0
        
        gdn_grads.append(gdn_grad)
        swa_local_grads.append(swa_local_grad)
        swa_retrieval_grads.append(swa_retrieval_grad)
        
        print(f'Step {step}: GDN k_proj={gdn_grad:.4f}, SWA local={swa_local_grad:.4f}, SWA retrieval={swa_retrieval_grad:.4f}')
    
    opt_diag.step()

print(f'\nAverage gradient norms:')
print(f'  GDN k_proj: {sum(gdn_grads)/len(gdn_grads):.4f}')
print(f'  SWA local: {sum(swa_local_grads)/len(swa_local_grads):.4f}')
print(f'  SWA retrieval: {sum(swa_retrieval_grads)/len(swa_retrieval_grads):.4f}')

ratio = sum(swa_local_grads)/max(sum(swa_retrieval_grads), 1e-8)
print(f'\nLocal/Retrieval gradient ratio: {ratio:.1f}x')
print('(If ratio >> 1, local path dominates learning)')

## Fix 1: Auxiliary Reconstruction Loss

Force state to encode retrievable info by adding loss that reconstructs input from state.

In [None]:
# Train with reconstruction loss: state must be able to reconstruct early tokens
cfg2 = HybridConfig(
    d_model=256, n_heads=4, head_dim=64, value_dim=64,
    layer_pattern='GS', vocab_size=50257,
    window_size=64, beta_bias=0.0, g_bias=2.0,
    shifted_value=True, beta_floor=0.1
)
model2 = TransparentHybrid(cfg2).to(DEVICE)

# Add reconstruction head: state -> predict early tokens
recon_head = nn.Linear(cfg2.n_heads * cfg2.head_dim * cfg2.value_dim, 50257).to(DEVICE)

opt2 = torch.optim.AdamW(list(model2.parameters()) + list(recon_head.parameters()), lr=1e-3)

print('Training with reconstruction loss...')
for step in range(3000):
    text, answer = make_example()
    tokens = tokenizer(text + answer, return_tensors='pt', truncation=True, max_length=600)['input_ids'].to(DEVICE)
    targets = tokens[:, 1:].contiguous()
    
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits, _, _, state = model2(tokens)
        lm_loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257), targets.reshape(-1))
        
        # Reconstruction loss: predict first 10 tokens from final state
        state_flat = state.reshape(1, -1)  # [1, H*K*V]
        recon_logits = recon_head(state_flat.float())  # [1, vocab]
        # Target: average of first 10 token embeddings -> predict first token
        first_token = tokens[0, 0]
        recon_loss = F.cross_entropy(recon_logits, first_token.unsqueeze(0))
        
        loss = lm_loss + 0.5 * recon_loss
    
    opt2.zero_grad()
    loss.backward()
    opt2.step()
    if step % 500 == 0:
        print(f'Step {step}: lm={lm_loss.item():.3f}, recon={recon_loss.item():.3f}')

print('\nWith reconstruction loss:')
test_state_ablation(model2, cfg2)