# GroundThink v6 Hybrid Architecture
## GDN + SWA with Diagnostic Instrumentation

**Structure:**
1. Configuration & Imports
2. Core Components (RMSNorm, FFN)
3. GatedDeltaNetLayer (GDN) - Recurrent memory
4. SlidingWindowAttention (SWA) - Local + Global attention with state cross-attention
5. TransparentHybrid - Main model
6. **Diagnostic Toolkit** - Probe functions for debugging information flow
7. Training Infrastructure
8. Analysis & Execution

---

In [None]:
# =============================================================================
# CELL 0: Configuration & Imports
# =============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
from fla.ops.gated_delta_rule import chunk_gated_delta_rule
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Tuple, Literal, Any
import math
import time
from collections import defaultdict


@dataclass
class HybridConfig:
    """
    Hyper-parameters for the GDN + SWA hybrid model.

    ``layer_pattern`` holds one character per layer, read left to right:
    ``'G'`` selects a GDN layer and ``'S'`` an SWA layer. Examples:
        'GS'       - one GDN followed by one SWA (smallest hybrid)
        'GGSS'     - two GDN, then two SWA
        'GSGSG'    - alternating layers
        'GGGSGGGS' - three GDN layers per SWA layer

    Derived attributes (assigned in ``__post_init__``):
        head_dim  = d_model // n_heads   (any value passed in is overwritten)
        value_dim = int(head_dim * expand_v)
    """
    # --- model dimensions ---
    d_model: int = 256
    n_heads: int = 8
    head_dim: int = 64          # placeholder; recomputed in __post_init__
    expand_v: float = 2.0       # value_dim / head_dim ratio for the GDN state
    vocab_size: int = 50257

    # --- layer layout ('G' / 'S' per layer) ---
    layer_pattern: str = "GS"

    # --- sliding-window attention ---
    window_size: int = 512

    # --- initialisation ---
    init_std: float = 0.02

    # --- how successive GDN states are merged: 'replace', 'avg', 'weighted' ---
    state_accumulation: str = 'weighted'
    state_weight_new: float = 0.5  # share given to the newer state ('weighted' only)

    def __post_init__(self):
        per_head = self.d_model // self.n_heads
        self.head_dim = per_head
        self.value_dim = int(per_head * self.expand_v)

    def _positions_of(self, code: str) -> List[int]:
        """Return the indices in ``layer_pattern`` whose character equals ``code``."""
        return [pos for pos, ch in enumerate(self.layer_pattern) if ch == code]

    @property
    def n_layers(self) -> int:
        """Total number of layers (one per pattern character)."""
        return len(self.layer_pattern)

    @property
    def gdn_indices(self) -> List[int]:
        """Indices of the 'G' layers."""
        return self._positions_of('G')

    @property
    def swa_indices(self) -> List[int]:
        """Indices of the 'S' layers."""
        return self._positions_of('S')

    def layer_type(self, idx: int) -> str:
        """Return the pattern character for layer ``idx``."""
        return self.layer_pattern[idx]


print("Configuration loaded.")

In [None]:
# =============================================================================
# CELL 1: Core Components
# =============================================================================

class RMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learned per-feature scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Statistics are computed in float32, then cast back to x's dtype
        # before the learned scale is applied.
        x_fp32 = x.float()
        mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = (mean_square + self.eps).rsqrt()
        normalised = (x_fp32 * inv_rms).type_as(x)
        return normalised * self.weight


class FFN(nn.Module):
    """Pre-norm SwiGLU feed-forward block; the residual add is done inside forward."""

    def __init__(self, cfg: HybridConfig):
        super().__init__()
        width = cfg.d_model

        # Inner width is ~8/3 of d_model, padded up to the next multiple of 64.
        inner = int(width * 8 / 3)
        inner += -inner % 64

        self.w1 = nn.Linear(width, inner, bias=False)   # gate branch (goes through SiLU)
        self.w3 = nn.Linear(width, inner, bias=False)   # linear branch
        self.w2 = nn.Linear(inner, width, bias=False)   # projection back to d_model
        self.norm = RMSNorm(width)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.norm(x)
        gated = F.silu(self.w1(normed)) * self.w3(normed)
        return x + self.w2(gated)


print("Core components loaded.")

In [None]:
# =============================================================================
# CELL 2: Gated Delta Network Layer
# =============================================================================

class GatedDeltaNetLayer(nn.Module):
    """
    Pre-norm gated-delta-rule layer built directly on the raw FLA kernel.

    The layer projects the normalised input to per-head q / k / v plus two
    per-head scalar gates and hands them to ``chunk_gated_delta_rule``:

        g    - forget gate in log space (logsigmoid, so exp(g) lies in (0, 1))
        beta - write strength, 0.5 + 0.5 * sigmoid(.), so it lies in (0.5, 1)

    Conceptually the kernel maintains an associative memory that is decayed
    by exp(g), written with strength beta and read with q. The exact
    recurrence is whatever ``chunk_gated_delta_rule`` implements - confirm
    against the FLA documentation before relying on a specific formula.

    The kernel output is projected back to d_model and added to the input
    (residual connection).
    """

    def __init__(self, cfg: HybridConfig, layer_idx: int):
        super().__init__()
        self.cfg = cfg
        self.layer_idx = layer_idx

        d_model = cfg.d_model
        n_heads = cfg.n_heads
        qk_width = n_heads * cfg.head_dim
        v_width = n_heads * cfg.value_dim

        # Query / key / value / output projections
        self.q_proj = nn.Linear(d_model, qk_width, bias=False)
        self.k_proj = nn.Linear(d_model, qk_width, bias=False)
        self.v_proj = nn.Linear(d_model, v_width, bias=False)
        self.o_proj = nn.Linear(v_width, d_model, bias=False)

        # One scalar per head and position for each gate
        self.beta_proj = nn.Linear(d_model, n_heads, bias=False)  # write strength
        self.g_proj = nn.Linear(d_model, n_heads, bias=False)     # forget gate (log space)

        self.norm = RMSNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        initial_state: Optional[torch.Tensor] = None,
        output_state: bool = True
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Dict[str, Any]]:
        """
        Args:
            x: Input of shape [B, T, D].
            initial_state: State the kernel starts from, or None
                (expected [B, H, K, V] - confirm against the FLA op).
            output_state: Forwarded to the kernel as ``output_final_state``.

        Returns:
            output: [B, T, D], input plus the projected kernel output.
            state: Whatever the kernel returns as final state (None when it
                returns none).
            diagnostics: Python scalars summarising beta, exp(g) and the
                state; each ``.item()`` call synchronises with the device.
        """
        cfg = self.cfg
        batch, seq_len, _ = x.shape
        n_heads, key_dim, val_dim = cfg.n_heads, cfg.head_dim, cfg.value_dim

        h = self.norm(x)  # pre-norm

        queries = self.q_proj(h).view(batch, seq_len, n_heads, key_dim)
        keys = self.k_proj(h).view(batch, seq_len, n_heads, key_dim)
        values = self.v_proj(h).view(batch, seq_len, n_heads, val_dim)

        # Only the keys are L2-normalised (in float32, then cast back);
        # the queries are passed to the kernel as-is.
        keys = F.normalize(keys.float(), p=2, dim=-1).to(x.dtype)

        # Gates, both [B, T, H]
        beta = 0.5 + 0.5 * torch.sigmoid(self.beta_proj(h))  # in (0.5, 1)
        g = F.logsigmoid(self.g_proj(h))                     # log-space decay, <= 0

        mixed, state = chunk_gated_delta_rule(
            queries, keys, values, g, beta,
            initial_state=initial_state,
            output_final_state=output_state
        )

        # Merge heads, project to d_model, add the residual
        mixed = self.o_proj(mixed.reshape(batch, seq_len, n_heads * val_dim))
        output = x + mixed

        decay = g.exp()  # back from log space for reporting
        has_state = state is not None
        diagnostics = {
            'beta_mean': beta.mean().item(),
            'beta_std': beta.std().item(),
            'beta_min': beta.min().item(),
            'beta_max': beta.max().item(),
            'g_mean': decay.mean().item(),
            'g_std': decay.std().item(),
            'state_norm': state.norm().item() if has_state else 0,
            'state_shape': tuple(state.shape) if has_state else None,
        }

        return output, state, diagnostics


print("GatedDeltaNetLayer loaded.")

In [None]:
# =============================================================================
# CELL 3: SlidingWindowAttention with ALIGNED Retrieval
# =============================================================================
#
# KEY INSIGHT: SWA must query in the same "language" GDN stores in.
#
# GDN stores with: State += β * (v ⊗ k) where k = k_proj(x)
# GDN reads with:  output = State @ q
#
# For retrieval to work: query must align with stored keys.
# 
# FIX: SWA uses GDN's q_proj for retrieval queries (k_proj only defines the
#      storage address), and the same State @ q operation GDN uses internally.
#
# =============================================================================

class SlidingWindowAttention(nn.Module):
    """
    Causal sliding-window attention plus a gated read-out from a GDN state.

    Two contributions are added to the residual stream:
      * local:     softmax attention in which position i sees positions
                   i - window_size .. i inclusive (window_size + 1 keys);
      * retrieval: ``state @ q`` against an externally supplied state matrix,
                   projected to d_model and scaled by a learned gate.

    Retrieval queries are produced by the ``gdn_q_proj`` module passed to
    ``forward`` (this layer's own ``q_proj`` is the fallback), so the state
    is queried with the projection the caller designates rather than with
    a key projection.

    NOTE(review): one state matrix is applied to every query position. If
    the caller passes a state summarising the whole sequence, early
    positions read information written by later tokens - confirm this is
    intended for causal language-model training.
    """

    def __init__(self, cfg: HybridConfig, layer_idx: int):
        super().__init__()
        self.cfg = cfg
        self.layer_idx = layer_idx
        H, K, V = cfg.n_heads, cfg.head_dim, cfg.value_dim

        # === LOCAL ATTENTION ===
        self.q_proj = nn.Linear(cfg.d_model, H * K, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, H * K, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, H * K, bias=False)
        self.o_proj = nn.Linear(H * K, cfg.d_model, bias=False)

        # === RETRIEVAL OUTPUT PROJECTION ===
        # State @ q yields V features per head; map H * V -> d_model.
        self.retrieval_o_proj = nn.Linear(H * V, cfg.d_model, bias=False)
        nn.init.xavier_uniform_(self.retrieval_o_proj.weight, gain=0.5)

        # === GATE: scales retrieval contribution ===
        # Zero weight + bias 1.0 => the gate starts at sigmoid(1) for every
        # input, independent of x.
        self.gate_proj = nn.Linear(cfg.d_model, H, bias=True)
        nn.init.zeros_(self.gate_proj.weight)
        nn.init.constant_(self.gate_proj.bias, 1.0)

        self.norm = RMSNorm(cfg.d_model)
        self.scale = K ** -0.5

    @staticmethod
    def _mean_entropy(probs: torch.Tensor) -> float:
        """Shannon entropy over the last dim, averaged over all other dims."""
        return -(probs * probs.clamp(min=1e-8).log()).sum(-1).mean().item()

    def forward(
        self,
        x: torch.Tensor,
        gdn_state: Optional[torch.Tensor] = None,
        gdn_q_proj: Optional[nn.Module] = None,  # GDN's q_proj for aligned retrieval
        return_attn: bool = False
    ) -> Tuple[Any, ...]:
        """
        Args:
            x: Input of shape [B, T, D].
            gdn_state: State matrix [B, H, K, V] to retrieve from, or None to
                skip retrieval (the retrieval term is then all zeros).
            gdn_q_proj: Module mapping [B, T, D] -> [B, T, H*K] used to build
                retrieval queries; falls back to this layer's ``q_proj``.
            return_attn: Also return attention tensors and gate values.

        Returns:
            ``(out, diagnostics)`` by default, or
            ``(out, diagnostics, attn_weights_local, attn_weights_global,
            gate_values)`` when ``return_attn`` is True.
            ``attn_weights_global`` is None when ``gdn_state`` is None;
            ``gate_values`` has shape [B, H, T, 1] and is a constant 0.5
            placeholder when no state is given.
        """
        B, T, D = x.shape
        H = self.cfg.n_heads
        K = self.cfg.head_dim
        V = self.cfg.value_dim
        W = self.cfg.window_size

        x_norm = self.norm(x)

        # === LOCAL ATTENTION ===
        q = self.q_proj(x_norm).view(B, T, H, K).transpose(1, 2)        # [B, H, T, K]
        k_local = self.k_proj(x_norm).view(B, T, H, K).transpose(1, 2)
        v_local = self.v_proj(x_norm).view(B, T, H, K).transpose(1, 2)

        # True = masked. triu(1) hides the future (j > i); tril(-W - 1) hides
        # keys more than W positions back (j < i - W). The diagonal is always
        # visible, so no softmax row is fully masked.
        mask = torch.ones(T, T, device=x.device, dtype=torch.bool)
        mask = mask.triu(1) | mask.tril(-W - 1)

        attn_local = (q @ k_local.transpose(-2, -1)) * self.scale
        attn_local = attn_local.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
        attn_weights_local = F.softmax(attn_local, dim=-1)
        local_out = attn_weights_local @ v_local

        local_out = local_out.transpose(1, 2).reshape(B, T, H * K)
        local_out = self.o_proj(local_out)

        # === ALIGNED RETRIEVAL FROM GDN STATE ===
        retrieval_out = torch.zeros(B, T, D, device=x.device, dtype=x.dtype)
        attn_weights_global = None
        gate_values = torch.full((B, H, T, 1), 0.5, device=x.device, dtype=x.dtype)

        if gdn_state is not None:
            state = gdn_state.to(x.dtype)  # [B, H, K, V]

            # === QUERY PROJECTION ===
            if gdn_q_proj is not None:
                # ALIGNED: query with the projection supplied by the caller
                q_ret = gdn_q_proj(x_norm).view(B, T, H, K).transpose(1, 2)  # [B, H, T, K]
            else:
                # FALLBACK: the local queries are exactly q_proj(x_norm)
                # reshaped to [B, H, T, K]; reuse them instead of projecting again.
                q_ret = q

            # q_ret is deliberately left un-normalised here.

            # === RETRIEVAL ===
            # State: [B, H, K, V], q_ret: [B, H, T, K] -> [B, H, T, V]
            retrieved = torch.einsum('bhkv,bhtk->bhtv', state, q_ret)

            # Diagnostic-only "attention" over the K state rows: |q| weighted
            # by the norm of each state row, softmax-normalised. It does not
            # feed into the output.
            state_k_norms = state.norm(dim=-1)  # [B, H, K]
            attn_weights_global = torch.einsum('bhtk,bhk->bhtk', q_ret.abs(), state_k_norms)
            attn_weights_global = F.softmax(attn_weights_global, dim=-1)  # [B, H, T, K]

            # Project retrieved values to output dimension
            retrieved = retrieved.transpose(1, 2).reshape(B, T, H * V)
            retrieval_out = self.retrieval_o_proj(retrieved)

            # Learned gate: per-head values are reported, but the output is
            # scaled by their mean over heads (one scalar per position).
            gate = torch.sigmoid(self.gate_proj(x_norm))          # [B, T, H]
            gate_values = gate.transpose(1, 2).unsqueeze(-1)      # [B, H, T, 1]

            gate_scale = gate.mean(dim=-1, keepdim=True)          # [B, T, 1]
            retrieval_out = gate_scale * retrieval_out

        # === COMBINE: residual + local + gated retrieval ===
        out = x + local_out + retrieval_out

        # Diagnostics (each .item() synchronises with the device)
        diagnostics = {
            'local_attn_entropy': self._mean_entropy(attn_weights_local),
            'gate_mean': gate_values.mean().item(),
            'gate_std': gate_values.std().item(),
            'gate_min': gate_values.min().item(),
            'gate_max': gate_values.max().item(),
            'retrieval_norm': retrieval_out.norm().item(),
            'local_norm': local_out.norm().item(),
            'global_attn_entropy': (
                self._mean_entropy(attn_weights_global)
                if attn_weights_global is not None else 0
            ),
        }

        if return_attn:
            return out, diagnostics, attn_weights_local, attn_weights_global, gate_values
        return out, diagnostics


print("SlidingWindowAttention loaded (ALIGNED RETRIEVAL via q_proj).")
print("  - Uses GDN's q_proj for retrieval queries (not k_proj)")
print("  - This matches how GDN internally retrieves from state")
print("  - Retrieval: q_proj(query) · k_proj(stored) -> correct interaction")

In [None]:
# =============================================================================
# CELL 4: TransparentHybrid Model (with aligned retrieval support)
# =============================================================================
#
# CHANGE: SWA layers now receive gdn_q_proj (the first GDN layer's q_proj) for aligned retrieval
#
# =============================================================================

class TransparentHybrid(nn.Module):
    """
    GDN + SWA language model assembled from ``cfg.layer_pattern``.

    Information flow:
        - every layer ('G' or 'S') is followed by its own FFN;
        - each GDN layer starts from the state accumulated so far and
          returns a new state, which is merged according to
          ``cfg.state_accumulation``;
        - each SWA layer reads from the accumulated state, building its
          retrieval queries with the first GDN layer's ``q_proj``;
        - the output head shares its weight with the token embedding.

    NOTE(review): the state a GDN layer returns is requested with
    ``output_state=True`` and is then used both as the initial state of the
    next GDN layer and as the memory every position of later SWA layers
    reads from - confirm that letting this state reach earlier positions
    is intended for causal training.
    """

    def __init__(self, cfg: HybridConfig):
        super().__init__()
        self.cfg = cfg

        # Token embedding
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        nn.init.normal_(self.embed.weight, std=cfg.init_std)

        # One mixer layer + one FFN per pattern character
        layer_classes = {'G': GatedDeltaNetLayer, 'S': SlidingWindowAttention}
        self.layers = nn.ModuleList()
        self.ffns = nn.ModuleList()
        for idx, code in enumerate(cfg.layer_pattern):
            if code not in layer_classes:
                raise ValueError(f"Unknown layer type: {code}")
            self.layers.append(layer_classes[code](cfg, idx))
            self.ffns.append(FFN(cfg))

        self.norm_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight  # weight tying

    def _accumulate_state(
        self,
        accumulated: Optional[torch.Tensor],
        new_state: torch.Tensor
    ) -> torch.Tensor:
        """Merge ``new_state`` into ``accumulated`` per ``cfg.state_accumulation``.

        The first state (``accumulated is None``) is returned unchanged
        without looking at the strategy. Raises ValueError for an unknown
        strategy name.
        """
        if accumulated is None:
            return new_state

        mode = self.cfg.state_accumulation
        if mode == 'replace':
            return new_state

        if mode == 'avg':
            new_share = 0.5
        elif mode == 'weighted':
            new_share = self.cfg.state_weight_new
        else:
            raise ValueError(f"Unknown state accumulation strategy: {mode}")

        return (1 - new_share) * accumulated + new_share * new_state

    def _get_gdn_q_proj(self) -> Optional[nn.Module]:
        """Return the ``q_proj`` of the first GDN layer, or None if there is none.

        The query projection (not the key projection) is handed to SWA
        layers: keys decide where a value is written, queries decide what
        is read back, so SWA reads with the projection GDN itself reads with.
        """
        return next(
            (m.q_proj for m in self.layers if isinstance(m, GatedDeltaNetLayer)),
            None,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        return_diagnostics: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[Dict]], Optional[torch.Tensor]]:
        """
        Run the full model.

        Args:
            input_ids: Token ids fed to the embedding.
            targets: Optional token ids; when given, a cross-entropy loss
                over the flattened logits is returned.
            return_diagnostics: Return the per-layer diagnostics list
                instead of None in the third slot.

        Returns:
            (logits, loss, diagnostics, final_state) where ``loss`` is None
            without targets and ``final_state`` is the accumulated GDN state
            (None if the pattern contains no GDN layer).
        """
        hidden = self.embed(input_ids)
        accumulated_state = None
        per_layer = []

        # Shared query projection for every SWA layer's state read-out
        gdn_q_proj = self._get_gdn_q_proj()

        stack = zip(self.cfg.layer_pattern, self.layers, self.ffns)
        for idx, (code, mixer, ffn) in enumerate(stack):
            if code == 'G':
                hidden, state, diag = mixer(
                    hidden, initial_state=accumulated_state, output_state=True
                )
                accumulated_state = self._accumulate_state(accumulated_state, state)
                diag['layer_type'] = 'GDN'
            elif code == 'S':
                hidden, diag = mixer(
                    hidden, gdn_state=accumulated_state, gdn_q_proj=gdn_q_proj
                )
                diag['layer_type'] = 'SWA'

            hidden = ffn(hidden)
            diag['layer_idx'] = idx
            per_layer.append(diag)

        logits = self.lm_head(self.norm_f(hidden))

        loss = None
        if targets is not None:
            flat_logits = logits.view(-1, logits.size(-1))
            loss = F.cross_entropy(flat_logits, targets.view(-1))

        diagnostics_out = per_layer if return_diagnostics else None
        return logits, loss, diagnostics_out, accumulated_state


print("TransparentHybrid loaded (with aligned retrieval support).")

In [None]:
# =============================================================================
# CELL 5: DIAGNOSTIC TOOLKIT
# =============================================================================
# These functions probe the information flow to identify bottlenecks.
# Run these AFTER training to understand why NIAH fails.
#
# Decision Tree:
#   1. probe_gdn_state_content() -> Is the needle IN the GDN state?
#      - Yes -> Architecture working, tune gate/fusion
#      - No  -> Is it being written? Check beta values
#               - No write -> Fix GDN write mechanism
#               - Written but lost -> Fix SWA query mechanism
#
#   2. visualize_swa_state_attention() -> Is SWA attending to state?
#      - Check attention entropy and head specialization
#
#   3. trace_needle_pipeline() -> End-to-end similarity tracking
# =============================================================================

def _probe_key_retrieval_norm(
    layer: nn.Module,
    token_embed: torch.Tensor,
    state: torch.Tensor,
    n_heads: int,
    head_dim: int,
) -> float:
    """Project an embedding through ``layer.k_proj``, L2-normalise it per head,
    read ``state`` with it and return the norm of what comes back."""
    key = layer.k_proj(token_embed).view(n_heads, head_dim)
    key = F.normalize(key.float(), p=2, dim=-1)
    retrieved = torch.einsum('hk,bhkv->bhv', key, state.float())
    return retrieved.norm().item()


@torch.no_grad()
def probe_gdn_state_content(
    model: TransparentHybrid, 
    input_ids: torch.Tensor, 
    target_token_pos: int = 32,
    verbose: bool = True
) -> Dict[str, Any]:
    """
    Probe 1: Check if specific token information is encoded in GDN state.

    Replays the model layer by layer. After every GDN layer, the needle
    token's embedding is turned into a key with that layer's ``k_proj`` and
    used to read the state the layer just returned; the norm of the read-out
    is compared with the same read-out for one randomly drawn token
    (signal-to-noise ratio = needle norm / random norm).

    SWA layers are run with the same ``gdn_q_proj`` argument that
    ``TransparentHybrid.forward`` passes, so the hidden states seen by later
    GDN layers match a real forward pass.

    NOTE(review): the GDN layer builds its keys from ``norm(x)`` of its
    input hidden state, whereas this probe projects the raw embedding
    without that RMSNorm, so the probe key only approximates the stored
    key. The baseline is a single unseeded random token, so the SNR varies
    between calls.

    Side effect: the model is switched to eval mode and left there.

    Args:
        model: The hybrid model
        input_ids: Input sequence [1, T] with a "needle" token at target_token_pos
        target_token_pos: Position of the needle token to probe for
        verbose: Print results

    Returns:
        Dict with 'layers' (one entry per layer; GDN entries carry the
        retrieval norms, SNR and gate statistics) and 'summary' (avg/max SNR
        and a 'needle_stored' flag; empty when there is no GDN layer)
    """
    model.eval()
    results = {'layers': [], 'summary': {}}
    n_heads, head_dim = model.cfg.n_heads, model.cfg.head_dim

    x = model.embed(input_ids)
    accumulated_state = None
    needle_id = input_ids[0, target_token_pos].item()

    # Same query projection the model's own forward pass hands to SWA layers
    gdn_q_proj = model._get_gdn_q_proj() if hasattr(model, '_get_gdn_q_proj') else None

    if verbose:
        print(f"\n{'='*60}")
        print(f"PROBE 1: GDN State Content Analysis")
        print(f"{'='*60}")
        print(f"Needle token ID: {needle_id} at position {target_token_pos}")
        print()

    for i, (layer, ffn) in enumerate(zip(model.layers, model.ffns)):
        layer_type = model.cfg.layer_pattern[i]

        if layer_type == 'G':
            x, layer_state, diag = layer(x, initial_state=accumulated_state, output_state=True)
            accumulated_state = model._accumulate_state(accumulated_state, layer_state)

            # How much comes back when the state is read with the needle's key?
            needle_retrieval_norm = _probe_key_retrieval_norm(
                layer, model.embed.weight[needle_id], layer_state, n_heads, head_dim
            )

            # Baseline: the same read-out for one random vocabulary token
            rand_token = torch.randint(0, model.cfg.vocab_size, (1,), device=input_ids.device).item()
            rand_retrieval_norm = _probe_key_retrieval_norm(
                layer, model.embed.weight[rand_token], layer_state, n_heads, head_dim
            )

            snr = needle_retrieval_norm / (rand_retrieval_norm + 1e-8)

            layer_result = {
                'layer_idx': i,
                'layer_type': 'GDN',
                'state_norm': layer_state.norm().item(),
                'state_shape': tuple(layer_state.shape),
                'needle_retrieval_norm': needle_retrieval_norm,
                'random_retrieval_norm': rand_retrieval_norm,
                'signal_to_noise': snr,
                'beta_mean': diag['beta_mean'],
                'beta_std': diag['beta_std'],
                'g_mean': diag['g_mean'],
            }
            results['layers'].append(layer_result)

            if verbose:
                status = "✓" if snr > 1.0 else "✗"
                print(f"  [GDN Layer {i}] {status}")
                print(f"      State: norm={layer_state.norm().item():.4f}, shape={tuple(layer_state.shape)}")
                print(f"      Needle Retrieval: {needle_retrieval_norm:.6f}")
                print(f"      Random Retrieval: {rand_retrieval_norm:.6f}")
                print(f"      Signal-to-Noise:  {snr:.4f} {'(GOOD)' if snr > 1.0 else '(WEAK)'}")
                print(f"      β={diag['beta_mean']:.3f}±{diag['beta_std']:.3f}, g={diag['g_mean']:.3f}")
                print()

        elif layer_type == 'S':
            # Pass gdn_q_proj so this replay matches TransparentHybrid.forward
            x, _ = layer(x, gdn_state=accumulated_state, gdn_q_proj=gdn_q_proj)
            results['layers'].append({'layer_idx': i, 'layer_type': 'SWA'})

        x = ffn(x)

    # Summary over GDN layers only
    gdn_snrs = [r['signal_to_noise'] for r in results['layers'] if r['layer_type'] == 'GDN']
    if gdn_snrs:
        avg_snr = sum(gdn_snrs) / len(gdn_snrs)
        max_snr = max(gdn_snrs)
        results['summary'] = {
            'avg_snr': avg_snr,
            'max_snr': max_snr,
            'needle_stored': max_snr > 1.0,
        }

        if verbose:
            print(f"  Summary: avg_SNR={avg_snr:.4f}, max_SNR={max_snr:.4f}")
            if max_snr > 1.0:
                print(f"  → Needle IS stored in GDN state")
            else:
                print(f"  → Needle NOT effectively stored (check beta/write mechanism)")

    return results


@torch.no_grad()
def visualize_swa_state_attention(
    model: nn.Module,
    input_ids: torch.Tensor,
    target_swa_layer: Optional[int] = None,
    verbose: bool = True
) -> Dict:
    """
    Probe 2: Analyze SWA's attention to GDN state.

    Replays the model once, layer by layer. GDN layers only update the
    accumulated state. Each SWA layer (or only ``target_swa_layer`` when
    given) is called with ``return_attn=True`` and the global attention of
    the final token of batch element 0 is summarised per head: entropy,
    largest weight and its slot. ``focus_ratio`` is
    ``1 - avg_entropy / log(K)`` (0.0 when K == 1, where the maximum entropy
    is zero). SWA layers that return no global attention are skipped in the
    results.

    Retrieval queries use ``model._get_gdn_q_proj()`` when the model
    provides it, matching the model's own forward pass.

    Side effect: the model is switched to eval mode and left there.

    Args:
        model: Model exposing ``embed``, ``layers``, ``ffns``,
            ``cfg.layer_pattern`` and ``_accumulate_state``
        input_ids: Input token ids; only batch element 0 is analysed
        target_swa_layer: Layer index to analyse, or None for every SWA layer
        verbose: Print results

    Returns:
        Dict with 'layers' (one entry per analysed SWA layer) and 'summary'
        (average focus ratio / gate and two threshold flags; empty when no
        layer was analysed)
    """
    model.eval()

    # Query projection for aligned retrieval (None if the model has no hook)
    gdn_q_proj = model._get_gdn_q_proj() if hasattr(model, '_get_gdn_q_proj') else None

    results = {'layers': [], 'summary': {}}

    # A single pass is enough: the state each SWA layer needs is accumulated
    # on the way, so no separate warm-up pass is required.
    x = model.embed(input_ids)
    accumulated_state = None

    for i, (layer, ffn) in enumerate(zip(model.layers, model.ffns)):
        layer_type = model.cfg.layer_pattern[i]

        if layer_type == 'G':
            x, state, _ = layer(x, initial_state=accumulated_state, output_state=True)
            accumulated_state = model._accumulate_state(accumulated_state, state)
            x = ffn(x)
            continue

        # SWA layer that was not asked for: just advance the hidden state
        if target_swa_layer is not None and i != target_swa_layer:
            x, _ = layer(x, gdn_state=accumulated_state, gdn_q_proj=gdn_q_proj)
            x = ffn(x)
            continue

        # Get detailed attention for this SWA layer
        out, diag, attn_local, attn_global, gate = layer(
            x, 
            gdn_state=accumulated_state,
            gdn_q_proj=gdn_q_proj,  # Pass for aligned retrieval
            return_attn=True
        )

        # No state available yet (no GDN layer ran before this SWA layer)
        if attn_global is None:
            x = ffn(out)
            continue

        # attn_global: [B, H, T, K] - attention over K state dimensions
        H = attn_global.shape[1]
        K = attn_global.shape[3]

        # Focus on final token's attention to state
        final_attn = attn_global[0, :, -1, :]  # [H, K]

        # Per-head analysis
        head_analysis = []
        for h in range(H):
            attn_h = final_attn[h]  # [K]
            entropy = -(attn_h * attn_h.clamp(min=1e-8).log()).sum().item()
            max_attn, max_slot = attn_h.max(dim=0)
            head_analysis.append({
                'head': h,
                'entropy': entropy,
                'max_attn': max_attn.item(),
                'max_slot': max_slot.item(),
            })

        # Average entropy across heads, relative to a uniform distribution
        avg_entropy = sum(h['entropy'] for h in head_analysis) / H
        max_entropy = math.log(K)  # Uniform distribution
        # Guard K == 1, where log(K) == 0 would divide by zero
        focus_ratio = 1 - (avg_entropy / max_entropy) if max_entropy > 0 else 0.0

        layer_result = {
            'layer_idx': i,
            'layer_type': 'SWA',
            'local_attn_shape': tuple(attn_local.shape),
            'global_attn_shape': tuple(attn_global.shape),
            'gate_mean': gate.mean().item(),
            'gate_std': gate.std().item(),
            'avg_global_entropy': avg_entropy,
            'max_possible_entropy': max_entropy,
            'focus_ratio': focus_ratio,
            'per_head': head_analysis,
        }
        results['layers'].append(layer_result)

        if verbose:
            status = "✓" if focus_ratio > 0.2 else "✗"
            print(f"\n  [SWA Layer {i}] {status}")
            print(f"      Global Attn Shape: {tuple(attn_global.shape)}")
            print(f"      Gate: mean={gate.mean().item():.3f}, std={gate.std().item():.3f}")
            print(f"      Avg Entropy: {avg_entropy:.4f} / {max_entropy:.4f} (max)")
            focus_status = "FOCUSED" if focus_ratio > 0.2 else "DIFFUSE"
            print(f"      Focus Ratio: {focus_ratio:.4f} ({focus_status})")
            print(f"      Per-head (final token -> state):")
            for h in head_analysis:
                print(f"        H{h['head']}: entropy={h['entropy']:.3f}, max={h['max_attn']:.3f}@slot{h['max_slot']}")

        x = ffn(out)

    # Summary
    if results['layers']:
        avg_focus = sum(l['focus_ratio'] for l in results['layers']) / len(results['layers'])
        avg_gate = sum(l['gate_mean'] for l in results['layers']) / len(results['layers'])
        results['summary'] = {
            'avg_focus_ratio': avg_focus,
            'avg_gate': avg_gate,
            'attention_focused': avg_focus > 0.2,
            'using_global': avg_gate > 0.3,
        }

        if verbose:
            print(f"\n  Summary: avg_focus={avg_focus:.4f}, avg_gate={avg_gate:.4f}")
            if avg_focus < 0.2:
                print("  → SWA attention is DIFFUSE (not learning to query)")
            else:
                print("  → SWA attention is FOCUSED (learning to query specific slots)")
            if avg_gate > 0.3:
                print("  → SWA is USING global state (gate > 0.3)")
            else:
                print("  → SWA is IGNORING global state (gate < 0.3)")

    return results

@torch.no_grad()
def trace_needle_pipeline(
    model: TransparentHybrid,
    input_ids: torch.Tensor,
    needle_pos: int = 32,
    query_pos: int = -1,
    verbose: bool = True
) -> Dict[str, Any]:
    """
    Probe 3: Trace needle information through the entire forward pass.

    Replays the model manually (embed, then each layer followed by its FFN)
    and, around every layer, measures the cosine similarity between the
    hidden state at ``query_pos`` and the needle token's input embedding.

    Look for:
        - Similarity INCREASE at SWA layers (retrieval working)
        - State containing needle at GDN layers

    Side effect: the model is switched to eval mode and is NOT switched back.

    Args:
        model: The hybrid model
        input_ids: Input sequence [1, T] with needle at needle_pos
        needle_pos: Position of needle token (negative values count from the end)
        query_pos: Position to track (default: -1 = final position)
        verbose: Print results

    Returns:
        Dict with:
            'trajectory': list of per-stage dicts. The first entry is the
                embedding stage (layer_idx -1); each later entry holds
                pre_sim / post_sim / delta for one layer, plus
                state_needle_norm for GDN layers (None otherwise).
            'summary': initial/final similarity, total delta, the layers with
                the largest gain and loss, and 'retrieval_working'
                (total_delta > 0).
    """
    model.eval()

    T = input_ids.shape[1]
    # Normalise negative indices so the printed positions and distance are
    # meaningful; tensor indexing itself would accept either form.
    if query_pos < 0:
        query_pos += T
    if needle_pos < 0:
        needle_pos += T

    needle_id = input_ids[0, needle_pos].item()
    needle_embed = model.embed.weight[needle_id].float()

    if verbose:
        print(f"\n{'='*60}")
        print(f"PROBE 3: Needle Pipeline Trace")
        print(f"{'='*60}")
        print(f"Needle: token {needle_id} @ pos {needle_pos}")
        print(f"Query:  pos {query_pos}")
        print(f"Distance: {query_pos - needle_pos} tokens\n")

    x = model.embed(input_ids)
    accumulated_state = None

    # Baseline: similarity of the raw query-position embedding to the needle.
    init_sim = F.cosine_similarity(x[0, query_pos].float(), needle_embed, dim=0).item()
    trajectory = [{
        'stage': 'embed',
        'layer_idx': -1,
        'layer_type': 'EMBED',
        'similarity': init_sim,
        'delta': 0.0,
    }]

    for i, (layer, ffn) in enumerate(zip(model.layers, model.ffns)):
        layer_type = model.cfg.layer_pattern[i]

        pre_sim = F.cosine_similarity(x[0, query_pos].float(), needle_embed, dim=0).item()
        state_needle_norm = None

        if layer_type == 'G':
            x, layer_state, _ = layer(x, initial_state=accumulated_state, output_state=True)
            accumulated_state = model._accumulate_state(accumulated_state, layer_state)

            # Query this layer's state with the needle's own (normalised) key:
            # the norm of what comes back indicates how strongly it is stored.
            needle_key = layer.k_proj(needle_embed.to(x.dtype)).view(model.cfg.n_heads, model.cfg.head_dim)
            needle_key = F.normalize(needle_key.float(), p=2, dim=-1)
            retrieved = torch.einsum('hk,bhkv->bhv', needle_key, layer_state.float())
            state_needle_norm = retrieved.norm().item()

        elif layer_type == 'S':
            x, _ = layer(x, gdn_state=accumulated_state)

        x = ffn(x)

        post_sim = F.cosine_similarity(x[0, query_pos].float(), needle_embed, dim=0).item()
        delta = post_sim - pre_sim

        trajectory.append({
            'stage': f'layer_{i}',
            'layer_idx': i,
            'layer_type': layer_type,
            'pre_sim': pre_sim,
            'post_sim': post_sim,
            'delta': delta,
            'state_needle_norm': state_needle_norm,
        })

        if verbose:
            arrow = "↑" if delta > 0.01 else ("↓" if delta < -0.01 else "→")
            type_str = 'GDN' if layer_type == 'G' else 'SWA'
            # `is not None` (not truthiness): a norm of exactly 0.0 is a real,
            # informative reading and must still be printed.
            state_str = f", state_needle={state_needle_norm:.4f}" if state_needle_norm is not None else ""
            print(f"  L{i:2d} [{type_str}]: {pre_sim:+.4f} {arrow} {post_sim:+.4f} (Δ{delta:+.4f}){state_str}")

    # Summary. With zero layers only the embed entry exists, so fall back to it
    # instead of failing on a missing 'post_sim' or an empty max().
    layer_steps = trajectory[1:] or trajectory
    final_sim = trajectory[-1].get('post_sim', init_sim)
    total_delta = final_sim - init_sim

    # max/min with key= return the first extreme entry, matching a
    # first-match search without relying on float equality.
    max_gain_layer = max(layer_steps, key=lambda t: t['delta'])
    max_loss_layer = min(layer_steps, key=lambda t: t['delta'])
    max_gain = max_gain_layer['delta']
    max_loss = max_loss_layer['delta']

    results = {
        'trajectory': trajectory,
        'summary': {
            'initial_sim': init_sim,
            'final_sim': final_sim,
            'total_delta': total_delta,
            'max_gain': max_gain,
            'max_gain_layer': max_gain_layer['layer_idx'],
            'max_gain_type': max_gain_layer['layer_type'],
            'max_loss': max_loss,
            'max_loss_layer': max_loss_layer['layer_idx'],
            'retrieval_working': total_delta > 0,
        }
    }

    if verbose:
        print(f"\n  Summary:")
        print(f"    Initial → Final: {init_sim:+.4f} → {final_sim:+.4f} (Δ{total_delta:+.4f})")
        print(f"    Max gain: {max_gain:+.4f} at L{max_gain_layer['layer_idx']} [{max_gain_layer['layer_type']}]")
        print(f"    Max loss: {max_loss:+.4f} at L{max_loss_layer['layer_idx']} [{max_loss_layer['layer_type']}]")
        if total_delta > 0:
            print(f"    → Needle info IS reaching query position")
        else:
            print(f"    → Needle info NOT reaching query position")

    return results


def run_full_diagnostic(
    model: TransparentHybrid,
    seq_len: int = 128,
    needle_pos: int = 32,
    needle_token: int = 50000,
    device: str = 'cuda'
) -> Dict[str, Any]:
    """
    Execute all three diagnostic probes on one synthetic needle sequence.

    Builds a random haystack with ``needle_token`` written at ``needle_pos``,
    runs the GDN-state, SWA-attention and pipeline-trace probes on it, prints
    a one-paragraph diagnosis derived from their summaries, and returns the
    raw probe outputs keyed by probe name.
    """
    # Haystack ids are drawn from [1000, 10000); the needle overwrites one slot.
    haystack = torch.randint(1000, 10000, (1, seq_len), device=device)
    haystack[0, needle_pos] = needle_token

    hash_rule = "#" * 60
    print("\n" + hash_rule)
    print("# FULL DIAGNOSTIC SUITE")
    print(hash_rule)
    print(f"Sequence length: {seq_len}")
    print(f"Needle token: {needle_token} @ position {needle_pos}")
    print(f"Query position: {seq_len - 1} (final)")

    # The probes print their own reports, so their call order is kept fixed.
    gdn_report = probe_gdn_state_content(model, haystack, target_token_pos=needle_pos)
    swa_report = visualize_swa_state_attention(model, haystack)
    trace_report = trace_needle_pipeline(model, haystack, needle_pos=needle_pos)

    equals_rule = "=" * 60
    print("\n" + equals_rule)
    print("DIAGNOSIS")
    print(equals_rule)

    stored = gdn_report['summary'].get('needle_stored', False)
    focused = swa_report['summary'].get('attention_focused', False)
    global_used = swa_report['summary'].get('using_global', False)
    retrieving = trace_report['summary'].get('retrieval_working', False)

    # Walk the pipeline front to back: storage, then attention focus, then
    # end-to-end retrieval, and only then whether the gate admits the state.
    if not stored:
        verdict = (
            "✗ Problem: GDN NOT storing needle",
            "  → Check beta (write strength) values",
            "  → Check beta_proj initialization",
            "  → May need stronger write gate bias",
        )
    elif not focused:
        verdict = (
            "✗ Problem: SWA attention is DIFFUSE",
            "  → Needle is stored but SWA can't find it",
            "  → Check state_k_proj / state_v_proj initialization",
            "  → Consider attentional reading instead of additive fusion",
        )
    elif retrieving:
        verdict = (
            "✓ Architecture appears to be WORKING",
            "  Focus on: Gate tuning, fusion weights, training dynamics",
        )
    elif not global_used:
        verdict = (
            "✗ Problem: SWA IGNORING global state",
            "  → Gate is too low",
            "  → Consider initializing gate bias positive",
            "  → Or switch to learned gating mechanism",
        )
    else:
        verdict = ("? Unclear failure mode - review individual probe results",)

    for line in verdict:
        print(line)

    return {
        'probe1_gdn_state': gdn_report,
        'probe2_swa_attention': swa_report,
        'probe3_pipeline': trace_report,
    }


# Announce the diagnostic entry points once this cell has finished running.
print("\n".join([
    "Diagnostic Toolkit loaded.",
    "  - probe_gdn_state_content()",
    "  - visualize_swa_state_attention()",
    "  - trace_needle_pipeline()",
    "  - run_full_diagnostic()",
]))

In [None]:
# =============================================================================
# CELL 6: Training Infrastructure
# =============================================================================

def simple_niah(
    model: TransparentHybrid,
    seq_len: int = 128,
    needle_pos: int = 32,
    needle_token: int = 50000,
    n_trials: int = 20
) -> List[Dict[str, float]]:
    """
    Needle-In-A-Haystack test.

    For each trial, builds a random sequence (ids in [1000, 10000)) with
    ``needle_token`` written at ``needle_pos``, runs the model, and reads the
    probability the model assigns to the needle token at the final position.
    That probability is compared against a uniform baseline of
    ``1 / vocab_size``.

    Side effect: the model is switched to eval mode and is NOT switched back;
    callers that continue training must call ``model.train()`` themselves.

    Args:
        model: The hybrid model (inputs are created on its parameters' device)
        seq_len: Length of each random test sequence
        needle_pos: Index at which the needle token is written
        needle_token: Token id used as the needle
        n_trials: Number of independent random sequences to evaluate

    Returns:
        One dict per trial with keys 'needle_prob', 'ratio' (needle_prob
        divided by the uniform baseline) and 'state_norm' (0.0 when the model
        returns no state). Empty list when ``n_trials <= 0``.
    """
    model.eval()

    # Loop invariants: look these up once rather than on every trial.
    device = next(model.parameters()).device
    random_baseline = 1.0 / model.cfg.vocab_size

    results = []

    with torch.no_grad():
        for _ in range(n_trials):
            tokens = torch.randint(1000, 10000, (1, seq_len), device=device)
            tokens[0, needle_pos] = needle_token

            logits, _, _, state = model(tokens, return_diagnostics=True)

            # Softmax in float32 so the probabilities are not computed in the
            # model's (possibly reduced) precision.
            final_probs = F.softmax(logits[0, -1].float(), dim=-1)
            needle_prob = final_probs[needle_token].item()

            results.append({
                'needle_prob': needle_prob,
                'ratio': needle_prob / random_baseline,
                'state_norm': state.norm().item() if state is not None else 0.0,
            })

    # Without trials there is nothing to average; avoid dividing by zero.
    if not results:
        print("NIAH: no trials run")
        return results

    avg_ratio = sum(r['ratio'] for r in results) / len(results)
    print(f"NIAH: {avg_ratio:.4f}x random ({'PASS' if avg_ratio > 1.0 else 'FAIL'})")
    return results


class DataLoader:
    """
    Random-window batch sampler over a flat token tensor.

    Each call to ``get_batch`` draws ``batch_size`` start offsets uniformly at
    random (with replacement) and returns the ``seq_len``-token windows
    starting there, together with the same windows shifted one token to the
    right as next-token targets. Nothing is streamed or iterated in order.
    """

    def __init__(
        self,
        token_tensor: torch.Tensor,
        batch_size: int = 4,
        seq_len: int = 128
    ):
        # Sliced along dim 0 only; expected to be a 1-D tensor of token ids.
        self.tokens = token_tensor
        self.batch_size = batch_size
        self.seq_len = seq_len

    def get_batch(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample one (inputs, targets) batch.

        Returns:
            (x, y), each of shape [batch_size, seq_len], where
            y[b, t] is the token that follows x[b, t] in the source tensor.

        Raises:
            ValueError: if the token tensor holds fewer than seq_len + 1
                tokens, so no full input/target window exists.
        """
        # A window starting at s needs tokens s .. s + seq_len (inclusive) for
        # the shifted targets, so valid starts are 0 .. len - seq_len - 1.
        # torch.randint's upper bound is exclusive, hence `len - seq_len`.
        n_starts = len(self.tokens) - self.seq_len
        if n_starts < 1:
            raise ValueError(
                f"Need at least seq_len + 1 = {self.seq_len + 1} tokens, "
                f"got {len(self.tokens)}"
            )

        # Convert to plain ints once so slicing does not go through tensor indices.
        starts = torch.randint(0, n_starts, (self.batch_size,)).tolist()
        x = torch.stack([self.tokens[s:s + self.seq_len] for s in starts])
        y = torch.stack([self.tokens[s + 1:s + self.seq_len + 1] for s in starts])
        return x, y


def _lr_at_step(step: int, steps: int, lr: float, warmup_steps: int) -> float:
    """Learning rate for `step`: linear warmup up to `lr`, then cosine decay."""
    if step < warmup_steps:
        return lr * (step + 1) / warmup_steps
    # max(1, ...) only matters if called with steps == warmup_steps; inside the
    # training loop the denominator is always >= 1 when this branch is reached.
    progress = (step - warmup_steps) / max(1, steps - warmup_steps)
    return lr * 0.5 * (1 + math.cos(math.pi * progress))


def _tail_mean(values: List[float], window: int = 50) -> float:
    """Mean of the last `window` values (fewer if not yet available); NaN if empty."""
    tail = values[-window:]
    return sum(tail) / len(tail) if tail else float('nan')


def train(
    model: TransparentHybrid,
    data_loader: DataLoader,
    steps: int = 10000,
    lr: float = 3e-4,
    warmup_steps: int = 200,
    log_every: int = 100,
    niah_every: int = 500,
    niah_seq_len: int = 128,
    niah_needle_pos: int = 32,
) -> Dict[str, List]:
    """
    Training loop with monitoring.

    Optimises with AdamW (betas 0.9/0.95, weight decay 0.1), gradient-norm
    clipping at 1.0, and a linear-warmup + cosine-decay learning rate that is
    written into the optimizer's param groups every step.

    Args:
        model: The hybrid model; left in train mode on return
        data_loader: Source of (inputs, targets) batches via ``get_batch()``
        steps: Total optimisation steps
        lr: Peak learning rate reached at the end of warmup
        warmup_steps: Number of linear warmup steps
        log_every: Print a progress line every this many steps
        niah_every: Run the needle-in-a-haystack check every this many steps
        niah_seq_len: Sequence length used for the NIAH check
        niah_needle_pos: Needle position used for the NIAH check

    Returns:
        History dict. 'loss' and 'lr' hold one value per step; 'gdn_beta',
        'gdn_g', 'state_norm' and 'swa_gate' hold per-step values from the
        FIRST GDN / SWA layer's diagnostics (present only if such a layer
        exists); 'niah_ratio' holds (step, avg_ratio) tuples.
    """
    from torch.optim import AdamW

    opt = AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1)

    history = defaultdict(list)

    print(f"\nTraining {steps} steps")
    print("="*60)

    model.train()
    start = time.time()

    for step in range(steps):
        # LR schedule: linear warmup then cosine decay
        current_lr = _lr_at_step(step, steps, lr, warmup_steps)
        for pg in opt.param_groups:
            pg['lr'] = current_lr

        # Forward
        x, y = data_loader.get_batch()
        logits, loss, diags, state = model(x, y, return_diagnostics=True)

        # Backward
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        # Track
        history['loss'].append(loss.item())
        history['lr'].append(current_lr)

        # Extract diagnostics (only the first layer of each type is recorded)
        gdn_diags = [d for d in diags if d['layer_type'] == 'GDN']
        swa_diags = [d for d in diags if d['layer_type'] == 'SWA']

        if gdn_diags:
            history['gdn_beta'].append(gdn_diags[0]['beta_mean'])
            history['gdn_g'].append(gdn_diags[0]['g_mean'])
            history['state_norm'].append(gdn_diags[0]['state_norm'])
        if swa_diags:
            history['swa_gate'].append(swa_diags[0]['gate_mean'])

        # Log
        if step % log_every == 0:
            # Floor the elapsed time so a coarse clock cannot cause a
            # division by zero on the very first log line.
            elapsed = max(time.time() - start, 1e-9)
            tps = (step + 1) * data_loader.batch_size * data_loader.seq_len / elapsed
            avg_loss = _tail_mean(history['loss'])

            gdn_str = f"β={gdn_diags[0]['beta_mean']:.3f} g={gdn_diags[0]['g_mean']:.3f}" if gdn_diags else ""
            swa_str = f"gate={swa_diags[0]['gate_mean']:.2f}" if swa_diags else ""

            print(f"[{step:5d}] loss={avg_loss:.3f} lr={current_lr:.2e} | {gdn_str} {swa_str} | {tps:,.0f} tok/s")

        # NIAH check (evaluation mode for the check, then back to training)
        if (step + 1) % niah_every == 0:
            model.eval()
            niah = simple_niah(model, seq_len=niah_seq_len, needle_pos=niah_needle_pos, n_trials=30)
            avg_ratio = sum(r['ratio'] for r in niah) / len(niah)
            history['niah_ratio'].append((step + 1, avg_ratio))
            model.train()

    elapsed = time.time() - start
    print(f"\nTraining complete in {elapsed/60:.1f} min")
    # Average over the losses actually recorded: dividing by a fixed 50
    # under-reports the loss for runs shorter than 50 steps.
    print(f"Final loss: {_tail_mean(history['loss']):.3f}")

    return dict(history)


# Confirmation that the training helpers defined in this cell are available.
print("Training infrastructure loaded.")

In [None]:
# =============================================================================
# CELL 7: Build & Initialize Model
# =============================================================================

# Build the model used by every later cell. The names `cfg` and `model`
# defined here are notebook globals that the following cells read.
print("="*60)
print("Building TransparentHybrid")
print("="*60)

# Configuration: one GDN layer followed by one SWA layer ("GS").
# With d_model=256 and n_heads=8, HybridConfig.__post_init__ sets head_dim=32.
cfg = HybridConfig(
    d_model=256,
    n_heads=8,
    layer_pattern="GS",
    window_size=128,
    state_accumulation='weighted',
    state_weight_new=0.5,
)

# Move to GPU and cast parameters to bfloat16.
model = TransparentHybrid(cfg).cuda().bfloat16()
# count_params is defined elsewhere in the notebook; it is indexed below with
# the keys 'total', 'gdn', 'swa', 'ffn' and 'embed'.
params = count_params(model)

print(f"\nArchitecture: {cfg.layer_pattern}")
# gdn_indices / swa_indices are presumably HybridConfig properties defined
# beyond the part of the class shown here — TODO confirm.
print(f"  GDN layers: {cfg.gdn_indices}")
print(f"  SWA layers: {cfg.swa_indices}")
print(f"\nParameters:")
print(f"  Total: {params['total']/1e6:.2f}M")
print(f"  GDN:   {params['gdn']/1e6:.2f}M")
print(f"  SWA:   {params['swa']/1e6:.2f}M")
print(f"  FFN:   {params['ffn']/1e6:.2f}M")
print(f"  Embed: {params['embed']/1e6:.2f}M")

# Quick forward pass test on random ids, with random targets so a loss is
# returned as well.
# NOTE(review): this runs with autograd enabled, and the results stay bound to
# notebook globals (x, y, logits, loss, diags, state) after the cell finishes.
x = torch.randint(0, 1000, (1, 128), device='cuda')
y = torch.randint(0, 1000, (1, 128), device='cuda')
logits, loss, diags, state = model(x, y, return_diagnostics=True)

print(f"\nForward pass OK:")
print(f"  Logits: {logits.shape}")
print(f"  Loss: {loss.item():.4f}")
print(f"  State: {state.shape if state is not None else None}")

In [None]:
# =============================================================================
# CELL 8: Pre-Training Diagnostics (Optional)
# =============================================================================
# Run this BEFORE training to establish baseline behavior.
# Compare with post-training diagnostics to see what changed.

# Record the untrained model's behaviour. `pre_diag` is read again by the
# pre-vs-post comparison cell, so this cell must run before training if that
# comparison is wanted.
print("Pre-training diagnostic baseline:")
print()

# Quick NIAH (prints its own PASS/FAIL line; leaves the model in eval mode)
niah_pre = simple_niah(model, seq_len=128, needle_pos=32, n_trials=20)

# Full diagnostic: all three probes on one synthetic needle sequence
pre_diag = run_full_diagnostic(model, seq_len=128, needle_pos=32)

In [None]:
# =============================================================================
# CELL 9: Load Data
# =============================================================================

# Tokenize a prefix of a streamed text dataset into one flat token tensor on
# the GPU and wrap it in the notebook's DataLoader.
from datasets import load_dataset
from transformers import AutoTokenizer

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Reuse the EOS token as the pad token.
tokenizer.pad_token = tokenizer.eos_token

print("Loading data...")
# streaming=True: documents are iterated lazily rather than downloaded up front.
ds = load_dataset("HuggingFaceFW/fineweb-edu", "sample-10BT", split="train", streaming=True)

# Buffer tokens
token_buffer = []
target_tokens = 2_000_000  # 2M tokens

# Accumulate whole documents until the target is reached; the overshoot from
# the last document is trimmed when the tensor is built below.
# NOTE(review): documents are concatenated back to back and no separator token
# is inserted here — confirm whether an EOS between documents is wanted.
for doc in ds:
    toks = tokenizer.encode(doc['text'])
    token_buffer.extend(toks)
    if len(token_buffer) >= target_tokens:
        break

token_tensor = torch.tensor(token_buffer[:target_tokens], device='cuda')
print(f"Loaded {len(token_tensor):,} tokens")

# Create data loader (random 128-token windows, 4 per batch)
data_loader = DataLoader(token_tensor, batch_size=4, seq_len=128)

In [None]:
# =============================================================================
# CELL 10: Training
# =============================================================================

# Main training run: 20k steps with the first 2k (10%) as linear warmup,
# a progress line every 100 steps and a NIAH check every 500 steps.
# `history` holds the per-step metrics returned by train().
history = train(
    model,
    data_loader,
    steps=20000,
    lr=3e-4,
    warmup_steps=2000,
    log_every=100,
    niah_every=500,
    niah_seq_len=128,
    niah_needle_pos=32,
)

In [None]:
# =============================================================================
# CELL 11: POST-TRAINING DIAGNOSTICS
# =============================================================================
# This is the critical cell. Run after training to identify the failure mode.

print("\n" + "#"*60)
print("# POST-TRAINING DIAGNOSTIC ANALYSIS")
print("#"*60)

# Final NIAH at multiple positions
# Sweeps the needle from near the start (16) to near the end (96) of a
# 128-token sequence. simple_niah prints its own summary line per call; the
# line printed here repeats the average at two decimals with the position.
print("\nNIAH at different needle positions:")
model.eval()
for needle_pos in [16, 32, 64, 96]:
    niah = simple_niah(model, seq_len=128, needle_pos=needle_pos, n_trials=30)
    avg = sum(r['ratio'] for r in niah) / len(niah)
    print(f"  needle@{needle_pos}: {avg:.2f}x random")

# Full diagnostic suite
# Uses the same seq_len / needle_pos as the pre-training baseline call so the
# two results can be compared directly.
post_diag = run_full_diagnostic(model, seq_len=128, needle_pos=32)

In [None]:
# =============================================================================
# CELL 12: Compare Pre vs Post Training
# =============================================================================

# Side-by-side comparison of the summary metrics from the pre-training and
# post-training diagnostic runs. Requires both `pre_diag` and `post_diag`
# to exist in the notebook namespace.
print("\n" + "="*60)
print("PRE vs POST TRAINING COMPARISON")
print("="*60)

# GDN State Storage
# .get(..., 0) falls back to 0 when the probe summary has no 'max_snr' entry.
pre_snr = pre_diag['probe1_gdn_state']['summary'].get('max_snr', 0)
post_snr = post_diag['probe1_gdn_state']['summary'].get('max_snr', 0)
print(f"\nGDN State SNR:")
print(f"  Pre:  {pre_snr:.4f}")
print(f"  Post: {post_snr:.4f}")
print(f"  Change: {post_snr - pre_snr:+.4f}")

# SWA Attention Focus
# The SWA probe only fills its summary when at least one layer was analysed,
# hence the default.
pre_focus = pre_diag['probe2_swa_attention']['summary'].get('avg_focus_ratio', 0)
post_focus = post_diag['probe2_swa_attention']['summary'].get('avg_focus_ratio', 0)
print(f"\nSWA Focus Ratio:")
print(f"  Pre:  {pre_focus:.4f}")
print(f"  Post: {post_focus:.4f}")
print(f"  Change: {post_focus - pre_focus:+.4f}")

# Gate Usage
# Note the different fallback here: 0.5 rather than 0.
pre_gate = pre_diag['probe2_swa_attention']['summary'].get('avg_gate', 0.5)
post_gate = post_diag['probe2_swa_attention']['summary'].get('avg_gate', 0.5)
print(f"\nSWA Gate (global usage):")
print(f"  Pre:  {pre_gate:.4f}")
print(f"  Post: {post_gate:.4f}")
print(f"  Change: {post_gate - pre_gate:+.4f}")

# Pipeline
# total_delta is final minus initial cosine similarity at the query position,
# so it is signed; all three lines are printed with an explicit sign.
pre_delta = pre_diag['probe3_pipeline']['summary'].get('total_delta', 0)
post_delta = post_diag['probe3_pipeline']['summary'].get('total_delta', 0)
print(f"\nPipeline Similarity Delta:")
print(f"  Pre:  {pre_delta:+.4f}")
print(f"  Post: {post_delta:+.4f}")
print(f"  Change: {post_delta - pre_delta:+.4f}")

In [None]:
# =============================================================================
# CELL 13: Next Steps (Based on Diagnosis)
# =============================================================================
# 
# Based on the diagnostic results, here are the recommended fixes:
#
# IF GDN NOT STORING (SNR < 1.0):
#   - Increase beta initialization: nn.init.constant_(beta_proj.bias, 1.0)
#   - Or scale beta output: beta = 0.5 + 0.5 * sigmoid(beta_proj(x))
#
# IF SWA ATTENTION DIFFUSE (focus_ratio < 0.2):
#   - Reinitialize state_k_proj with larger gain:
#     nn.init.xavier_uniform_(state_k_proj.weight, gain=2.0)
#   - Or add temperature scaling to global attention
#
# IF SWA IGNORING GLOBAL (gate < 0.3):
#   - Initialize gate bias positive: nn.init.constant_(gate_proj.bias, 1.0)
#   - Or switch from interpolation to concatenation + learned projection
#
# IF ALL METRICS OK BUT NIAH STILL FAILS:
#   - The fusion mechanism itself may be the bottleneck
#   - Consider replacing additive fusion with cross-attention reading
#   - Or add a dedicated "retrieval head" that directly queries state
#
# =============================================================================

# Closing message; this cell performs no computation of its own.
print("Diagnostic complete. Review results above to determine next steps.")