# GroundThink v6 Hybrid Architecture
## GDN + SWA with Retrieval Training

**Key Fixes Applied:**
1. **Gatekeeper GDN**: β bias=-2.0, no floor → sparse selective writes
2. **Sparse SWA Retrieval**: Dedicated query projection with ReLU sparsity
3. **Proper NIAH Test**: Uses MARKER + CUE tokens for retrieval signal
4. **Mixed Training**: LM + synthetic retrieval tasks
5. **Auxiliary Retrieval Loss**: Direct gradient for state→retrieval pathway

---

In [12]:
# =============================================================================
# CELL 0: Configuration & Imports
# =============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from fla.ops.gated_delta_rule import chunk_gated_delta_rule
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Tuple, Any
import math
import time
from collections import defaultdict


@dataclass
class HybridConfig:
    """
    Hybrid architecture configuration.
    
    Layer patterns: 'GS', 'GSG', 'GSGS', etc.
    """
    d_model: int = 256
    n_heads: int = 8
    head_dim: int = 32
    expand_v: float = 2.0
    vocab_size: int = 50257
    
    layer_pattern: str = "GS"
    window_size: int = 64
    
    init_std: float = 0.02
    
    state_accumulation: str = 'weighted'
    state_weight_new: float = 0.5
    
    # Special tokens for retrieval
    marker_token: int = 50251
    cue_token: int = 50250
    
    def __post_init__(self):
        self.head_dim = self.d_model // self.n_heads
        self.value_dim = int(self.head_dim * self.expand_v)
        
    @property
    def n_layers(self) -> int:
        return len(self.layer_pattern)
    
    @property
    def gdn_indices(self) -> List[int]:
        return [i for i, t in enumerate(self.layer_pattern) if t == 'G']
    
    @property
    def swa_indices(self) -> List[int]:
        return [i for i, t in enumerate(self.layer_pattern) if t == 'S']


def count_params(model):
    return sum(p.numel() for p in model.parameters())


print("Configuration loaded.")

Configuration loaded.


In [13]:
# =============================================================================
# CELL 1: Basic Components (RMSNorm, FFN)
# =============================================================================

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x / rms * self.weight


class FFN(nn.Module):
    def __init__(self, cfg: HybridConfig):
        super().__init__()
        hidden = int(cfg.d_model * 4)
        self.norm = RMSNorm(cfg.d_model)
        self.w1 = nn.Linear(cfg.d_model, hidden, bias=False)
        self.w2 = nn.Linear(hidden, cfg.d_model, bias=False)
        self.w3 = nn.Linear(cfg.d_model, hidden, bias=False)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm(x)
        return x + self.w2(F.silu(self.w1(h)) * self.w3(h))


print("Basic components loaded.")

Basic components loaded.


In [14]:
# =============================================================================
# CELL 2: GatedDeltaNetLayer - GATEKEEPER INITIALIZATION
# =============================================================================
#
# KEY FIX: β is LOW by default (bias=-2.0 → sigmoid ≈ 0.12)
#          No floor on β → can reach near 0
#          Model must LEARN to spike β for important tokens
#
# This prevents the state from being flooded with noise.
# =============================================================================

class GatedDeltaNetLayer(nn.Module):
    """
    GDN with Gatekeeper initialization.
    
    Delta Rule: Sₜ = gₜ * Sₜ₋₁ + βₜ * (vₜ ⊗ kₜ)
    
    With β biased low, the model preserves state by default
    and only writes when it learns something is important.
    """
    
    def __init__(self, cfg: HybridConfig, layer_idx: int):
        super().__init__()
        self.cfg = cfg
        self.layer_idx = layer_idx
        
        H, K, V = cfg.n_heads, cfg.head_dim, cfg.value_dim
        
        # QKV projections
        self.q_proj = nn.Linear(cfg.d_model, H * K, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, H * K, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, H * V, bias=False)
        self.o_proj = nn.Linear(H * V, cfg.d_model, bias=False)
        
        # === GATEKEEPER BETA ===
        self.beta_proj = nn.Linear(cfg.d_model, H, bias=True)
        # CRITICAL: Negative bias = closed by default
        # sigmoid(-2.0) ≈ 0.12, preserves 88% of old memory
        nn.init.constant_(self.beta_proj.bias, -2.0)
        
        # Forget gate
        self.g_proj = nn.Linear(cfg.d_model, H, bias=False)
        
        self.norm = RMSNorm(cfg.d_model)
        
    def forward(
        self, 
        x: torch.Tensor, 
        initial_state: Optional[torch.Tensor] = None,
        output_state: bool = True
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Dict[str, Any]]:
        B, T, D = x.shape
        H, K, V = self.cfg.n_heads, self.cfg.head_dim, self.cfg.value_dim
        
        x_norm = self.norm(x)
        
        # Project to Q, K, V
        q = self.q_proj(x_norm).view(B, T, H, K)
        k = self.k_proj(x_norm).view(B, T, H, K)
        v = self.v_proj(x_norm).view(B, T, H, V)
        
        # Normalize K for stability
        k = F.normalize(k.float(), p=2, dim=-1).to(x.dtype)
        
        # === GATEKEEPER BETA ===
        # NO floor! β can reach near 0.
        # With bias=-2.0, default β ≈ 0.12
        beta = torch.sigmoid(self.beta_proj(x_norm))  # [B, T, H] in [0, 1]
        
        # Forget gate (log space for numerical stability)
        g = F.logsigmoid(self.g_proj(x_norm))
        
        # Core delta rule
        output, state = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=initial_state,
            output_final_state=output_state
        )
        
        # Project and residual
        output = output.reshape(B, T, H * V)
        output = self.o_proj(output)
        output = x + output
        
        # Diagnostics
        diagnostics = {
            'beta_mean': beta.mean().item(),
            'beta_std': beta.std().item(),
            'beta_min': beta.min().item(),
            'beta_max': beta.max().item(),
            'g_mean': g.exp().mean().item(),
            'g_std': g.exp().std().item(),
            'state_norm': state.norm().item() if state is not None else 0,
            'state_shape': tuple(state.shape) if state is not None else None,
        }
        
        return output, state, diagnostics


print("GatedDeltaNetLayer loaded (GATEKEEPER).")
print("  - β bias = -2.0 → default β ≈ 0.12")
print("  - No β floor → sparse selective writes")

GatedDeltaNetLayer loaded (GATEKEEPER).
  - β bias = -2.0 → default β ≈ 0.12
  - No β floor → sparse selective writes


In [15]:
# =============================================================================
# CELL 3: SlidingWindowAttention - DEDICATED SPARSE RETRIEVAL
# =============================================================================
#
# KEY FIX: SWA has its OWN query projection for state retrieval.
#          ReLU on retrieval query forces sparsity (pointy attention).
#
# This decouples GDN's write optimization from SWA's read optimization.
# =============================================================================

class SlidingWindowAttention(nn.Module):
    """
    SWA with dedicated sparse retrieval from GDN state.
    """
    
    def __init__(self, cfg: HybridConfig, layer_idx: int):
        super().__init__()
        self.cfg = cfg
        self.layer_idx = layer_idx
        H, K, V = cfg.n_heads, cfg.head_dim, cfg.value_dim
        
        # === LOCAL ATTENTION ===
        self.q_proj = nn.Linear(cfg.d_model, H * K, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, H * K, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, H * K, bias=False)
        self.o_proj = nn.Linear(H * K, cfg.d_model, bias=False)
        
        # === DEDICATED GLOBAL RETRIEVAL ===
        self.global_q_proj = nn.Linear(cfg.d_model, H * K, bias=False)
        nn.init.normal_(self.global_q_proj.weight, std=0.01)
        
        self.retrieval_o_proj = nn.Linear(H * V, cfg.d_model, bias=False)
        nn.init.xavier_uniform_(self.retrieval_o_proj.weight, gain=0.5)
        
        # Gate for retrieval contribution
        self.gate_proj = nn.Linear(cfg.d_model, H, bias=True)
        nn.init.zeros_(self.gate_proj.weight)
        nn.init.constant_(self.gate_proj.bias, 0.0)
        
        self.norm = RMSNorm(cfg.d_model)
        self.scale = K ** -0.5
        
    def forward(
        self, 
        x: torch.Tensor,
        gdn_state: Optional[torch.Tensor] = None,
        return_attn: bool = False
    ) -> Tuple[torch.Tensor, Dict[str, Any], ...]:
        B, T, D = x.shape
        H = self.cfg.n_heads
        K = self.cfg.head_dim
        V = self.cfg.value_dim
        W = self.cfg.window_size
        
        x_norm = self.norm(x)
        
        # === LOCAL ATTENTION ===
        q = self.q_proj(x_norm).view(B, T, H, K).transpose(1, 2)
        k_local = self.k_proj(x_norm).view(B, T, H, K).transpose(1, 2)
        v_local = self.v_proj(x_norm).view(B, T, H, K).transpose(1, 2)
        
        mask = torch.ones(T, T, device=x.device, dtype=torch.bool)
        mask = mask.triu(1) | mask.tril(-W - 1)
        
        attn_local = (q @ k_local.transpose(-2, -1)) * self.scale
        attn_local = attn_local.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
        attn_weights_local = F.softmax(attn_local, dim=-1)
        local_out = attn_weights_local @ v_local
        
        local_out = local_out.transpose(1, 2).reshape(B, T, H * K)
        local_out = self.o_proj(local_out)
        
        # === DEDICATED SPARSE RETRIEVAL ===
        retrieval_out = torch.zeros(B, T, D, device=x.device, dtype=x.dtype)
        attn_weights_global = None
        gate_values = torch.full((B, H, T, 1), 0.5, device=x.device, dtype=x.dtype)
        
        if gdn_state is not None:
            state = gdn_state.to(x.dtype)  # [B, H, K, V]
            
            # === OWN QUERY with ReLU SPARSITY ===
            q_global = self.global_q_proj(x_norm).view(B, T, H, K).transpose(1, 2)
            q_global = F.relu(q_global)  # Sparsity: zero out negative values
            q_global = F.normalize(q_global.float(), p=2, dim=-1).to(x.dtype)
            
            # === LINEAR RETRIEVAL: State @ q ===
            retrieved = torch.einsum('bhkv,bhtk->bhtv', state, q_global)
            
            # For diagnostics
            state_k_norms = state.norm(dim=-1)
            attn_weights_global = torch.einsum('bhtk,bhk->bhtk', q_global.abs(), state_k_norms)
            attn_weights_global = F.softmax(attn_weights_global / 0.1, dim=-1)
            
            # Project retrieved values
            retrieved = retrieved.transpose(1, 2).reshape(B, T, H * V)
            retrieval_out = self.retrieval_o_proj(retrieved)
            
            # Learned gate
            gate_logits = self.gate_proj(x_norm)
            gate = torch.sigmoid(gate_logits)
            gate_values = gate.transpose(1, 2).unsqueeze(-1)
            
            gate_scale = gate.mean(dim=-1, keepdim=True)
            retrieval_out = gate_scale * retrieval_out
        
        # === COMBINE ===
        out = x + local_out + retrieval_out
        
        # Diagnostics
        diagnostics = {
            'local_attn_entropy': -(attn_weights_local * attn_weights_local.clamp(min=1e-8).log()).sum(-1).mean().item(),
            'gate_mean': gate_values.mean().item(),
            'gate_std': gate_values.std().item(),
            'retrieval_norm': retrieval_out.norm().item(),
            'local_norm': local_out.norm().item(),
            'global_attn_entropy': (
                -(attn_weights_global * attn_weights_global.clamp(min=1e-8).log()).sum(-1).mean().item() 
                if attn_weights_global is not None else 0
            ),
        }
        
        if return_attn:
            return out, diagnostics, attn_weights_local, attn_weights_global, gate_values
        return out, diagnostics


print("SlidingWindowAttention loaded (SPARSE RETRIEVAL).")
print("  - Dedicated global_q_proj for state retrieval")
print("  - ReLU sparsity on retrieval query")

SlidingWindowAttention loaded (SPARSE RETRIEVAL).
  - Dedicated global_q_proj for state retrieval
  - ReLU sparsity on retrieval query


In [16]:
# =============================================================================
# CELL 4: TransparentHybrid Model (FIXED ACCUMULATION)
# =============================================================================

class TransparentHybrid(nn.Module):
    """
    Hybrid GDN + SWA model.
    """
    
    def __init__(self, cfg: HybridConfig):
        super().__init__()
        self.cfg = cfg
        
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        nn.init.normal_(self.embed.weight, std=cfg.init_std)
        
        self.layers = nn.ModuleList()
        self.ffns = nn.ModuleList()
        
        for i, layer_type in enumerate(cfg.layer_pattern):
            if layer_type == 'G':
                self.layers.append(GatedDeltaNetLayer(cfg, i))
            elif layer_type == 'S':
                self.layers.append(SlidingWindowAttention(cfg, i))
            else:
                raise ValueError(f"Unknown layer type: {layer_type}")
            self.ffns.append(FFN(cfg))
        
        self.norm_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight
        
    def _accumulate_state(
        self, 
        accumulated: Optional[torch.Tensor], 
        new_state: torch.Tensor
    ) -> torch.Tensor:
        if accumulated is None:
            return new_state
        
        strategy = self.cfg.state_accumulation
        
        # --- FIX: Residual Sum is safer for depth ---
        if strategy == 'weighted':
             # Treat 'weighted' as residual sum with a scale factor if desired,
             # but simple addition is robust.
             return accumulated + (self.cfg.state_weight_new * new_state)
             
        elif strategy == 'replace':
            return new_state
        elif strategy == 'sum': # Added explicit sum
             return accumulated + new_state
        else:
            # Fallback to sum if unsure, averaging destroys info in deep nets
            return accumulated + new_state
        
    def forward(
        self, 
        input_ids: torch.Tensor, 
        targets: Optional[torch.Tensor] = None,
        return_diagnostics: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[Dict]], Optional[torch.Tensor]]:
        x = self.embed(input_ids)
        accumulated_state = None
        all_diagnostics = []
        
        for i, (layer, ffn) in enumerate(zip(self.layers, self.ffns)):
            layer_type = self.cfg.layer_pattern[i]
            
            if layer_type == 'G':
                x, state, diag = layer(x, initial_state=accumulated_state, output_state=True)
                accumulated_state = self._accumulate_state(accumulated_state, state)
                diag['layer_type'] = 'GDN'
            elif layer_type == 'S':
                x, diag = layer(x, gdn_state=accumulated_state)
                diag['layer_type'] = 'SWA'
            
            x = ffn(x)
            diag['layer_idx'] = i
            all_diagnostics.append(diag)
        
        x = self.norm_f(x)
        logits = self.lm_head(x)
        
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        
        if return_diagnostics:
            return logits, loss, all_diagnostics, accumulated_state
        return logits, loss, None, accumulated_state


print("TransparentHybrid loaded (FIXED).")

TransparentHybrid loaded (FIXED).


In [17]:
# =============================================================================
# CELL 5: Diagnostic Toolkit (FIXED)
# =============================================================================

@torch.no_grad()
def probe_gdn_state_content(
    model: TransparentHybrid, 
    input_ids: torch.Tensor, 
    target_token_pos: int = 32,
    verbose: bool = True
) -> Dict[str, Any]:
    """Check if needle is stored in GDN state."""
    model.eval()
    results = {'layers': [], 'summary': {}}
    
    x = model.embed(input_ids)
    accumulated_state = None
    needle_id = input_ids[0, target_token_pos].item()
    
    if verbose:
        print(f"Needle token ID: {needle_id} at position {target_token_pos}")
    
    for i, (layer, ffn) in enumerate(zip(model.layers, model.ffns)):
        layer_type = model.cfg.layer_pattern[i]
        
        if layer_type == 'G':
            x, layer_state, diag = layer(x, initial_state=accumulated_state, output_state=True)
            accumulated_state = model._accumulate_state(accumulated_state, layer_state)
            
            # Probe: Can we retrieve the needle?
            needle_embed = model.embed.weight[needle_id]
            q_needle = layer.q_proj(needle_embed.unsqueeze(0))
            q_needle = q_needle.view(1, model.cfg.n_heads, model.cfg.head_dim)
            
            # --- FIX: Cast query to match state dtype (likely float32) ---
            q_needle = q_needle.to(layer_state.dtype) 
            
            retrieved = torch.einsum('bhkv,bhk->bhv', layer_state, q_needle)
            needle_retrieval = retrieved.norm().item()
            
            # Random baseline
            random_id = torch.randint(0, model.cfg.vocab_size, (1,)).item()
            random_embed = model.embed.weight[random_id]
            q_random = layer.q_proj(random_embed.unsqueeze(0))
            q_random = q_random.view(1, model.cfg.n_heads, model.cfg.head_dim)
            
            # --- FIX: Cast random query to match state dtype ---
            q_random = q_random.to(layer_state.dtype)
            
            retrieved_random = torch.einsum('bhkv,bhk->bhv', layer_state, q_random)
            random_retrieval = retrieved_random.norm().item()
            
            snr = needle_retrieval / (random_retrieval + 1e-8)
            
            layer_result = {
                'layer_idx': i,
                'layer_type': 'GDN',
                'state_norm': layer_state.norm().item(),
                'needle_retrieval': needle_retrieval,
                'random_retrieval': random_retrieval,
                'signal_to_noise': snr,
                'beta_mean': diag['beta_mean'],
                'beta_std': diag['beta_std'],
                'g_mean': diag['g_mean'],
            }
            results['layers'].append(layer_result)
            
            if verbose:
                status = "✓" if snr > 1.0 else "✗"
                snr_quality = "GOOD" if snr > 1.0 else "WEAK"
                print(f"\n  [GDN Layer {i}] {status}")
                print(f"      State norm: {layer_state.norm().item():.4f}")
                print(f"      Needle Retrieval: {needle_retrieval:.6f}")
                print(f"      Random Retrieval: {random_retrieval:.6f}")
                print(f"      Signal-to-Noise:  {snr:.4f} ({snr_quality})")
                print(f"      β={diag['beta_mean']:.3f}±{diag['beta_std']:.3f}, g={diag['g_mean']:.3f}")
        else:
            x, _ = layer(x, gdn_state=accumulated_state)
        
        x = ffn(x)
    
    gdn_results = [r for r in results['layers'] if r['layer_type'] == 'GDN']
    if gdn_results:
        avg_snr = sum(r['signal_to_noise'] for r in gdn_results) / len(gdn_results)
        max_snr = max(r['signal_to_noise'] for r in gdn_results)
        results['summary'] = {'avg_snr': avg_snr, 'max_snr': max_snr, 'needle_stored': max_snr > 1.0}
        
        if verbose:
            print(f"\n  Summary: avg_SNR={avg_snr:.4f}, max_SNR={max_snr:.4f}")
            print(f"  → Needle {'IS' if max_snr > 1.0 else 'NOT'} stored in GDN state")
    
    return results


@torch.no_grad()
def run_full_diagnostic(model, seq_len=128, needle_pos=32):
    """Run complete diagnostic suite."""
    model.eval()
    device = next(model.parameters()).device
    
    print(f"\n{'='*60}")
    print(f"FULL DIAGNOSTIC SUITE")
    print(f"{'='*60}")
    print(f"Sequence length: {seq_len}")
    print(f"Needle position: {needle_pos}")
    
    # Create test sequence
    seq = torch.randint(0, model.cfg.vocab_size - 100, (1, seq_len), device=device)
    needle_id = 50000
    seq[0, needle_pos] = needle_id
    
    print(f"\n--- GDN State Analysis ---")
    gdn_result = probe_gdn_state_content(model, seq, needle_pos)
    
    return gdn_result


print("Diagnostic toolkit loaded (FIXED).")

Diagnostic toolkit loaded (FIXED).


In [18]:
# =============================================================================
# CELL 6: Data Loading
# =============================================================================

from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, Dataset

class TextDataset(Dataset):
    def __init__(self, tokens, seq_len=128):
        self.tokens = tokens
        self.seq_len = seq_len
        
    def __len__(self):
        return (len(self.tokens) - 1) // self.seq_len
    
    def __getitem__(self, idx):
        start = idx * self.seq_len
        chunk = self.tokens[start:start + self.seq_len + 1]
        return torch.tensor(chunk, dtype=torch.long)


def load_data(n_tokens=2_000_000, seq_len=128, batch_size=16):
    """Load and tokenize training data."""
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    
    print("Loading dataset...")
    dataset = load_dataset('wikitext', 'wikitext-103-raw-v1', split='train')
    
    print("Tokenizing...")
    all_tokens = []
    for item in dataset:
        if item['text'].strip():
            tokens = tokenizer.encode(item['text'])
            all_tokens.extend(tokens)
            if len(all_tokens) >= n_tokens:
                break
    
    all_tokens = all_tokens[:n_tokens]
    print(f"Total tokens: {len(all_tokens):,}")
    
    ds = TextDataset(all_tokens, seq_len=seq_len)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=True)
    
    return loader, tokenizer


# Load data
data_loader, tokenizer = load_data(n_tokens=2_000_000, seq_len=128, batch_size=16)
print(f"Batches per epoch: {len(data_loader)}")

Loading tokenizer...
Loading dataset...
Tokenizing...
Total tokens: 2,000,000
Batches per epoch: 976


In [19]:
# =============================================================================
# CELL 7: Retrieval Training & Testing Infrastructure
# =============================================================================
#
# This cell provides:
# 1. Proper NIAH test with MARKER + CUE tokens
# 2. Synthetic retrieval data generator
# 3. Auxiliary retrieval loss computation
#
# =============================================================================

class RetrievalDataGenerator:
    """
    Generates synthetic retrieval training examples.
    
    Format: [context] MARKER value [distractor] CUE -> value
    """
    
    def __init__(self, cfg: HybridConfig):
        self.vocab_size = cfg.vocab_size
        self.marker_token = cfg.marker_token
        self.cue_token = cfg.cue_token
        
    def generate_batch(
        self,
        batch_size: int,
        seq_len: int,
        device: str = 'cuda',
        min_distance: int = 10,
        max_distance: int = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Generate retrieval training batch.
        
        Returns:
            input_ids: [B, seq_len]
            targets: [B, seq_len] with -100 for non-retrieval positions
            needle_ids: [B] the tokens to retrieve
        """
        if max_distance is None:
            max_distance = seq_len - 10
        
        # Random base sequence
        input_ids = torch.randint(
            0, self.vocab_size - 100,
            (batch_size, seq_len),
            device=device
        )
        
        # Random needles
        needle_ids = torch.randint(
            1000, self.vocab_size - 100,
            (batch_size,),
            device=device
        )
        
        # Targets: only care about retrieval position
        targets = torch.full(
            (batch_size, seq_len),
            -100,
            device=device
        )
        
        for b in range(batch_size):
            # Varied distance for curriculum
            distance = torch.randint(min_distance, max_distance, (1,)).item()
            marker_pos = max(2, seq_len - distance - 3)
            
            # Place MARKER, needle, and CUE
            input_ids[b, marker_pos] = self.marker_token
            input_ids[b, marker_pos + 1] = needle_ids[b]
            input_ids[b, -2] = self.cue_token
            
            # Target: position after CUE should predict needle
            targets[b, -1] = needle_ids[b]
        
        return input_ids, targets, needle_ids


def compute_auxiliary_retrieval_loss(
    model,
    batch_size: int = 8,
    seq_len: int = 128,
    device: str = 'cuda',
) -> torch.Tensor:
    """
    Auxiliary loss: Direct gradient for state → retrieval pathway.
    
    This explicitly trains:
    1. GDN to store needle when MARKER is seen
    2. State to retain needle across distractors
    3. SWA to retrieve when CUE is seen
    """
    retrieval_gen = RetrievalDataGenerator(model.cfg)
    input_ids, targets, needles = retrieval_gen.generate_batch(
        batch_size, seq_len, device
    )
    
    logits, _, _, _ = model(input_ids, return_diagnostics=False)
    
    # Loss only on retrieval positions (where targets != -100)
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        targets.view(-1),
        ignore_index=-100
    )
    
    return loss


def proper_niah_test(
    model,
    seq_len: int = 128,
    needle_pos: int = 32,
    n_trials: int = 30,
) -> Dict:
    """
    NIAH test with proper retrieval cue.
    
    Format: [hay] MARKER needle [hay] CUE -> should predict needle
    """
    model.eval()
    device = next(model.parameters()).device
    cfg = model.cfg
    
    results = []
    
    for trial in range(n_trials):
        needle_id = torch.randint(1000, cfg.vocab_size - 100, (1,)).item()
        
        # Build sequence
        seq = torch.randint(0, cfg.vocab_size - 100, (1, seq_len), device=device)
        seq[0, needle_pos - 1] = cfg.marker_token
        seq[0, needle_pos] = needle_id
        seq[0, -2] = cfg.cue_token
        
        with torch.no_grad():
            logits, _, _, _ = model(seq)
        
        pred_probs = F.softmax(logits[0, -1, :], dim=-1)
        needle_prob = pred_probs[needle_id].item()
        needle_rank = (pred_probs > needle_prob).sum().item() + 1
        
        random_id = torch.randint(1000, cfg.vocab_size - 100, (1,)).item()
        random_prob = pred_probs[random_id].item()
        
        ratio = needle_prob / (random_prob + 1e-10)
        
        results.append({
            'needle_id': needle_id,
            'needle_prob': needle_prob,
            'needle_rank': needle_rank,
            'ratio': ratio,
            'success': ratio > 1.0,
        })
    
    avg_ratio = sum(r['ratio'] for r in results) / len(results)
    avg_rank = sum(r['needle_rank'] for r in results) / len(results)
    success_rate = sum(r['success'] for r in results) / len(results)
    
    status = "PASS" if avg_ratio > 1.0 else "FAIL"
    print(f"NIAH: {avg_ratio:.4f}x random ({status})")
    print(f"  Avg rank: {avg_rank:.0f}/{cfg.vocab_size}, Success: {success_rate*100:.1f}%")
    
    return {'avg_ratio': avg_ratio, 'avg_rank': avg_rank, 'success_rate': success_rate}


def test_niah_by_distance(
    model,
    distances: List[int] = [5, 10, 20, 40, 60, 95],
    n_trials: int = 20,
):
    """Test NIAH at increasing distances to find capacity limit."""
    model.eval()
    print("\nNIAH by distance:")
    print("-" * 40)
    
    results = []
    seq_len = 128
    
    for dist in distances:
        needle_pos = max(2, seq_len - dist - 2)
        result = proper_niah_test(model, seq_len=seq_len, needle_pos=needle_pos, n_trials=n_trials)
        results.append({'distance': dist, **result})
    
    return results


print("Retrieval infrastructure loaded.")
print("  - RetrievalDataGenerator: Creates synthetic retrieval tasks")
print("  - compute_auxiliary_retrieval_loss(): Direct retrieval gradient")
print("  - proper_niah_test(): NIAH with MARKER + CUE")
print("  - test_niah_by_distance(): Find capacity limits")

Retrieval infrastructure loaded.
  - RetrievalDataGenerator: Creates synthetic retrieval tasks
  - compute_auxiliary_retrieval_loss(): Direct retrieval gradient
  - proper_niah_test(): NIAH with MARKER + CUE
  - test_niah_by_distance(): Find capacity limits


In [20]:
# =============================================================================
# CELL 8: Mixed Training Loop (LM + Retrieval + Auxiliary)
# =============================================================================

def train_mixed(
    model,
    lm_data_loader,
    steps: int = 20000,
    lr: float = 3e-4,
    warmup_steps: int = 2000,
    retrieval_ratio: float = 0.1,
    auxiliary_weight: float = 0.1,
    log_every: int = 100,
    niah_every: int = 1000,
    device: str = 'cuda',
):
    """
    Training with mixed objectives:
    1. LM loss (standard next-token prediction)
    2. Retrieval loss (synthetic MARKER/CUE tasks)
    3. Auxiliary retrieval loss (direct state→retrieval gradient)
    
    Args:
        retrieval_ratio: Fraction of batches that are pure retrieval
        auxiliary_weight: Weight for auxiliary loss added to LM batches
    """
    model.train()
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = CosineAnnealingLR(optimizer, T_max=steps, eta_min=lr * 0.01)
    
    retrieval_gen = RetrievalDataGenerator(model.cfg)
    lm_iter = iter(lm_data_loader)
    
    history = {'lm_loss': [], 'ret_loss': [], 'aux_loss': [], 'niah': []}
    start_time = time.time()
    
    for step in range(steps):
        # Warmup
        if step < warmup_steps:
            lr_scale = (step + 1) / warmup_steps
            for pg in optimizer.param_groups:
                pg['lr'] = lr * lr_scale
        
        optimizer.zero_grad()
        
        # Decide batch type
        if torch.rand(1).item() < retrieval_ratio:
            # Pure retrieval batch
            input_ids, targets, _ = retrieval_gen.generate_batch(16, 128, device)
            logits, _, diags, _ = model(input_ids, return_diagnostics=True)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-100
            )
            history['ret_loss'].append(loss.item())
            batch_type = 'RET'
        else:
            # LM batch + auxiliary loss
            try:
                batch = next(lm_iter)
            except StopIteration:
                lm_iter = iter(lm_data_loader)
                batch = next(lm_iter)
            
            input_ids = batch[:, :-1].to(device)
            targets = batch[:, 1:].to(device)
            
            logits, lm_loss, diags, _ = model(input_ids, targets, return_diagnostics=True)
            
            # Add auxiliary retrieval loss
            aux_loss = compute_auxiliary_retrieval_loss(model, batch_size=8, seq_len=128, device=device)
            
            loss = lm_loss + auxiliary_weight * aux_loss
            
            history['lm_loss'].append(lm_loss.item())
            history['aux_loss'].append(aux_loss.item())
            batch_type = 'LM+AUX'
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        if step >= warmup_steps:
            scheduler.step()
        
        # Logging
        if step % log_every == 0:
            lm_avg = sum(history['lm_loss'][-100:]) / max(1, len(history['lm_loss'][-100:]))
            ret_avg = sum(history['ret_loss'][-100:]) / max(1, len(history['ret_loss'][-100:]))
            aux_avg = sum(history['aux_loss'][-100:]) / max(1, len(history['aux_loss'][-100:]))
            
            beta = diags[0].get('beta_mean', 0) if diags else 0
            g = diags[0].get('g_mean', 0) if diags else 0
            gate = diags[1].get('gate_mean', 0) if diags and len(diags) > 1 else 0
            
            current_lr = optimizer.param_groups[0]['lr']
            
            print(f"[{step:5d}] LM={lm_avg:.3f} RET={ret_avg:.3f} AUX={aux_avg:.3f} | "
                  f"β={beta:.3f} g={g:.3f} gate={gate:.2f} | lr={current_lr:.2e}")
        
        # NIAH check
        if step % niah_every == 0 and step > 0:
            model.eval()
            niah_result = proper_niah_test(model, seq_len=128, needle_pos=32, n_trials=20)
            history['niah'].append(niah_result['avg_ratio'])
            model.train()
    
    print(f"\nTraining complete in {(time.time() - start_time)/60:.1f} min")
    print(f"Final LM loss: {sum(history['lm_loss'][-100:])/100:.3f}")
    
    return history


print("Mixed training loop loaded.")
print("  - LM loss + Auxiliary retrieval loss on LM batches")
print("  - Pure retrieval batches mixed in")

Mixed training loop loaded.
  - LM loss + Auxiliary retrieval loss on LM batches
  - Pure retrieval batches mixed in


In [21]:
# =============================================================================
# CELL 9: Model Instantiation
# =============================================================================

cfg = HybridConfig(
    d_model=256,
    n_heads=8,
    layer_pattern="GS",
    window_size=64,
    state_accumulation='weighted',
    state_weight_new=0.5,
)

model = TransparentHybrid(cfg).cuda().bfloat16()
params = count_params(model)

print(f"\nArchitecture: {cfg.layer_pattern}")
print(f"  GDN layers: {cfg.gdn_indices}")
print(f"  SWA layers: {cfg.swa_indices}")
print(f"  Parameters: {params:,}")
print(f"  State shape: [B, {cfg.n_heads}, {cfg.head_dim}, {cfg.value_dim}]")

# Quick forward test
x = torch.randint(0, 1000, (1, 128), device='cuda')
with torch.no_grad():
    logits, _, diags, state = model(x, return_diagnostics=True)

print(f"\nInitial diagnostics:")
print(f"  GDN β={diags[0]['beta_mean']:.3f} (should be ~0.12)")
print(f"  GDN g={diags[0]['g_mean']:.3f}")
print(f"  SWA gate={diags[1]['gate_mean']:.3f}")


Architecture: GS
  GDN layers: [0]
  SWA layers: [1]
  Parameters: 15,298,064
  State shape: [B, 8, 32, 64]

Initial diagnostics:
  GDN β=0.128 (should be ~0.12)
  GDN g=0.500
  SWA gate=0.500


In [22]:
# =============================================================================
# CELL 10: Pre-Training Baseline
# =============================================================================

print("#" * 60)
print("# PRE-TRAINING BASELINE")
print("#" * 60)

print("\n1. NIAH with proper retrieval cue:")
baseline_niah = proper_niah_test(model, seq_len=128, needle_pos=32, n_trials=30)

print("\n2. NIAH by distance:")
baseline_distances = test_niah_by_distance(model, distances=[5, 10, 20, 40, 60, 95], n_trials=20)

print("\n3. GDN state analysis:")
run_full_diagnostic(model, seq_len=128, needle_pos=32)

############################################################
# PRE-TRAINING BASELINE
############################################################

1. NIAH with proper retrieval cue:
NIAH: 1.1766x random (PASS)
  Avg rank: 28167/50257, Success: 56.7%

2. NIAH by distance:

NIAH by distance:
----------------------------------------
NIAH: 1.0494x random (PASS)
  Avg rank: 29253/50257, Success: 40.0%
NIAH: 1.1925x random (PASS)
  Avg rank: 25554/50257, Success: 45.0%
NIAH: 0.9634x random (FAIL)
  Avg rank: 28194/50257, Success: 35.0%
NIAH: 1.1655x random (PASS)
  Avg rank: 23925/50257, Success: 50.0%
NIAH: 1.1483x random (PASS)
  Avg rank: 24883/50257, Success: 50.0%
NIAH: 0.9975x random (FAIL)
  Avg rank: 22382/50257, Success: 40.0%

3. GDN state analysis:

FULL DIAGNOSTIC SUITE
Sequence length: 128
Needle position: 32

--- GDN State Analysis ---
Needle token ID: 50000 at position 32

  [GDN Layer 0] ✗
      State norm: 1.6587
      Needle Retrieval: 0.012147
      Random Retrieval: 0.028

{'layers': [{'layer_idx': 0,
   'layer_type': 'GDN',
   'state_norm': 1.658738613128662,
   'needle_retrieval': 0.012146753259003162,
   'random_retrieval': 0.028279220685362816,
   'signal_to_noise': 0.429529126663628,
   'beta_mean': 0.1337890625,
   'beta_std': 0.06640625,
   'g_mean': 0.50390625}],
 'summary': {'avg_snr': 0.429529126663628,
  'max_snr': 0.429529126663628,
  'needle_stored': False}}

In [23]:
# =============================================================================
# CELL 11: Training
# =============================================================================

print("#" * 60)
print("# MIXED TRAINING: LM + Retrieval + Auxiliary")
print("#" * 60)

history = train_mixed(
    model,
    data_loader,
    steps=20000,
    lr=3e-4,
    warmup_steps=2000,
    retrieval_ratio=0.1,      # 10% pure retrieval batches
    auxiliary_weight=0.1,     # Auxiliary loss weight on LM batches
    log_every=100,
    niah_every=1000,
    device='cuda',
)

############################################################
# MIXED TRAINING: LM + Retrieval + Auxiliary
############################################################
[    0] LM=10.875 RET=0.000 AUX=10.875 | β=0.134 g=0.500 gate=0.50 | lr=1.50e-07
[  100] LM=10.871 RET=10.906 AUX=10.894 | β=0.133 g=0.498 gate=0.50 | lr=1.51e-05
[  200] LM=10.784 RET=10.892 AUX=10.873 | β=0.133 g=0.500 gate=0.50 | lr=3.01e-05
[  300] LM=10.471 RET=10.895 AUX=10.890 | β=0.136 g=0.500 gate=0.50 | lr=4.51e-05
[  400] LM=9.954 RET=10.889 AUX=10.884 | β=0.139 g=0.498 gate=0.49 | lr=6.01e-05
[  500] LM=9.436 RET=10.891 AUX=10.885 | β=0.143 g=0.496 gate=0.49 | lr=7.51e-05
[  600] LM=8.836 RET=10.902 AUX=10.891 | β=0.130 g=0.504 gate=0.49 | lr=9.01e-05
[  700] LM=8.229 RET=10.906 AUX=10.920 | β=0.155 g=0.496 gate=0.47 | lr=1.05e-04
[  800] LM=7.858 RET=10.908 AUX=10.889 | β=0.166 g=0.494 gate=0.48 | lr=1.20e-04
[  900] LM=7.626 RET=10.916 AUX=10.880 | β=0.177 g=0.494 gate=0.52 | lr=1.35e-04
[ 1000] LM=7.453 RET

In [24]:
# =============================================================================
# CELL 12: Post-Training Evaluation
# =============================================================================

print("#" * 60)
print("# POST-TRAINING EVALUATION")
print("#" * 60)

model.eval()

print("\n1. NIAH with proper cue:")
final_niah = proper_niah_test(model, seq_len=128, needle_pos=32, n_trials=30)

print("\n2. NIAH by distance:")
final_distances = test_niah_by_distance(model, distances=[5, 10, 20, 40, 60, 95], n_trials=20)

print("\n3. Comparison:")
print(f"  Baseline NIAH: {baseline_niah['avg_ratio']:.4f}x")
print(f"  Final NIAH:    {final_niah['avg_ratio']:.4f}x")
if baseline_niah['avg_ratio'] > 0:
    print(f"  Improvement:   {final_niah['avg_ratio'] / baseline_niah['avg_ratio']:.2f}x")

print("\n4. GDN state analysis:")
run_full_diagnostic(model, seq_len=128, needle_pos=32)

print("\n5. Training history:")
print(f"  Final LM loss: {sum(history['lm_loss'][-100:])/100:.3f}")
print(f"  Final RET loss: {sum(history['ret_loss'][-100:])/max(1,len(history['ret_loss'][-100:])):.3f}")
print(f"  Final AUX loss: {sum(history['aux_loss'][-100:])/100:.3f}")

############################################################
# POST-TRAINING EVALUATION
############################################################

1. NIAH with proper cue:
NIAH: 1.0847x random (PASS)
  Avg rank: 27809/50257, Success: 40.0%

2. NIAH by distance:

NIAH by distance:
----------------------------------------
NIAH: 1.2785x random (PASS)
  Avg rank: 23196/50257, Success: 40.0%
NIAH: 1.3466x random (PASS)
  Avg rank: 21868/50257, Success: 60.0%
NIAH: 1.1938x random (PASS)
  Avg rank: 21348/50257, Success: 50.0%
NIAH: 1.1246x random (PASS)
  Avg rank: 25958/50257, Success: 35.0%
NIAH: 1.3827x random (PASS)
  Avg rank: 23393/50257, Success: 65.0%
NIAH: 1.0713x random (PASS)
  Avg rank: 25967/50257, Success: 50.0%

3. Comparison:
  Baseline NIAH: 1.1766x
  Final NIAH:    1.0847x
  Improvement:   0.92x

4. GDN state analysis:

FULL DIAGNOSTIC SUITE
Sequence length: 128
Needle position: 32

--- GDN State Analysis ---
Needle token ID: 50000 at position 32

  [GDN Layer 0] ✓
     

In [25]:
# =============================================================================
# CELL 13: Curriculum Training (The Fix)
# =============================================================================

def train_curriculum(model, lm_loader, device='cuda'):
    print(f"\n{'='*60}")
    print("PHASE 1: RETRIEVAL BOOTCAMP (Force the Circuit)")
    print(f"{'='*60}")
    
    # Setup for pure retrieval
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    retrieval_gen = RetrievalDataGenerator(model.cfg)
    
    # Trackers
    losses = []
    
    # --- PHASE 1: PURE RETRIEVAL (1000 steps) ---
    model.train()
    for step in range(1001):
        optimizer.zero_grad()
        
        # 100% Synthetic Data
        input_ids, targets, _ = retrieval_gen.generate_batch(
            batch_size=32,       # Larger batch for stable gradients
            seq_len=128, 
            device=device,
            min_distance=5,      # Start easy
            max_distance=90      # Go hard
        )
        
        logits, _, diags, _ = model(input_ids, return_diagnostics=True)
        
        # Calculate Loss
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)), 
            targets.view(-1), 
            ignore_index=-100
        )
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        losses.append(loss.item())
        
        if step % 100 == 0:
            avg_loss = sum(losses[-10:]) / 10
            beta = diags[0]['beta_mean']
            print(f"[Bootcamp {step}] Loss: {avg_loss:.4f} | β: {beta:.3f}")
            
            # Early exit if solved
            if avg_loss < 0.5:
                print(">> Circuit formed! Moving to Phase 2.")
                break

    print(f"\n{'='*60}")
    print("PHASE 2: LANGUAGE INTEGRATION (Preserve the Circuit)")
    print(f"{'='*60}")
    
    # Re-initialize optimizer for mixed phase
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) # Lower LR to protect circuit
    lm_iter = iter(lm_loader)
    
    for step in range(2000): # Short fine-tuning
        optimizer.zero_grad()
        
        # 50/50 Mix
        if step % 2 == 0:
            # Retrieval Task (Maintain Memory)
            input_ids, targets, _ = retrieval_gen.generate_batch(16, 128, device)
            logits, _, _, _ = model(input_ids, return_diagnostics=True)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)
            task = "RET"
        else:
            # Language Task (Learn Syntax)
            try:
                batch = next(lm_iter)
            except StopIteration:
                lm_iter = iter(lm_loader)
                batch = next(lm_iter)
            
            input_ids = batch[:, :-1].to(device)
            targets = batch[:, 1:].to(device)
            logits, loss, _, _ = model(input_ids, targets, return_diagnostics=True)
            task = "LM "
            
        loss.backward()
        optimizer.step()
        
        if step % 200 == 0:
            print(f"[Integration {step}] Task: {task} | Loss: {loss.item():.4f}")

    return model

# --- EXECUTE ---
print("Starting Curriculum Training...")
model = TransparentHybrid(cfg).cuda().bfloat16() # Reset model
train_curriculum(model, data_loader)

print("\nRunning Final Diagnostics...")
run_full_diagnostic(model, seq_len=128, needle_pos=32)

Starting Curriculum Training...

PHASE 1: RETRIEVAL BOOTCAMP (Force the Circuit)
[Bootcamp 0] Loss: 1.0875 | β: 0.133
[Bootcamp 100] Loss: 10.8812 | β: 0.132
[Bootcamp 200] Loss: 10.8625 | β: 0.135
[Bootcamp 300] Loss: 10.8562 | β: 0.137
[Bootcamp 400] Loss: 10.8875 | β: 0.140
[Bootcamp 500] Loss: 10.8812 | β: 0.142
[Bootcamp 600] Loss: 10.8562 | β: 0.141
[Bootcamp 700] Loss: 10.8438 | β: 0.141
[Bootcamp 800] Loss: 10.6375 | β: 0.138
[Bootcamp 900] Loss: 10.3688 | β: 0.133
[Bootcamp 1000] Loss: 9.9688 | β: 0.122

PHASE 2: LANGUAGE INTEGRATION (Preserve the Circuit)
[Integration 0] Task: RET | Loss: 10.1250
[Integration 200] Task: RET | Loss: 10.3125
[Integration 400] Task: RET | Loss: 10.3750
[Integration 600] Task: RET | Loss: 9.7500
[Integration 800] Task: RET | Loss: 9.7500
[Integration 1000] Task: RET | Loss: 9.8750
[Integration 1200] Task: RET | Loss: 10.4375
[Integration 1400] Task: RET | Loss: 9.5625
[Integration 1600] Task: RET | Loss: 10.1875
[Integration 1800] Task: RET | Los

{'layers': [{'layer_idx': 0,
   'layer_type': 'GDN',
   'state_norm': 1.5464245080947876,
   'needle_retrieval': 0.03408018499612808,
   'random_retrieval': 0.02842489629983902,
   'signal_to_noise': 1.1989550514831808,
   'beta_mean': 0.1005859375,
   'beta_std': 0.053955078125,
   'g_mean': 0.447265625}],
 'summary': {'avg_snr': 1.1989550514831808,
  'max_snr': 1.1989550514831808,
  'needle_stored': True}}