# GroundThink v6 Hybrid Architecture
## GDN (Triton) + SWA (flash_attn) with Complete Analysis Suite

---

### Architecture Overview

**GatedDeltaNet (GDN):**
- TRUE Delta Rule: `S_t = g_t * S_{t-1} + β_t * k_t ⊗ (v_t - S_{t-1}·k_t)` (state is `[K, V]`, so the key is the left factor of the outer product)
- Error correction prevents redundant writes and state explosion
- Triton kernels for forward + backward (validated)
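
The recurrence above can be sketched in a few lines of NumPy. This is a sequential, single-head reference written for illustration, not the notebook's Triton kernel; the function name `delta_rule_reference` and the toy inputs are assumptions.

```python
import numpy as np

def delta_rule_reference(k, v, beta, g, state=None):
    """Sequential gated delta rule for one head (illustrative reference).

    k: [T, K] keys, v: [T, V] values, beta and g: [T] gates in (0, 1].
    Returns (out [T, V], final state [K, V]).
    """
    T, K = k.shape
    V = v.shape[1]
    S = np.zeros((K, V)) if state is None else state.copy()
    out = np.empty((T, V))
    for t in range(T):
        pred = S.T @ k[t]                  # what the state currently returns for k_t
        error = v[t] - pred                # only the unexplained part gets written
        S = g[t] * S + beta[t] * np.outer(k[t], error)
        out[t] = S.T @ k[t]                # read back with the same key
    return out, S

# Write the same (key, value) pair twice with g = 1 and beta = 1.
k = np.array([[1.0, 0.0], [1.0, 0.0]])
v = np.array([[3.0, -1.0, 2.0], [3.0, -1.0, 2.0]])
out, S = delta_rule_reference(k, v, beta=np.ones(2), g=np.ones(2))
```

On the second step the prediction already equals the value, so the error is zero and the state is left unchanged. That is the "prevents redundant writes" property in its simplest form.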

**SlidingWindowAttention (SWA):**
- Local attention via flash_attn with window_size parameter
- **State retrieval**: Queries GDN state for global context (enables NIAH)
- PyTorch fallback if flash_attn unavailable
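
To make the window semantics concrete, here is a small NumPy sketch of a causal sliding-window mask in which each query sees itself plus the `W` previous tokens. The helper name `sliding_window_mask` is an assumption for illustration; the notebook builds the equivalent mask with `triu`/`tril` in its PyTorch fallback.

```python
import numpy as np

def sliding_window_mask(T, W):
    """Boolean [T, T] mask, True where attention is blocked."""
    i = np.arange(T)[:, None]   # query positions
    j = np.arange(T)[None, :]   # key positions
    future = j > i              # causal: no looking ahead
    too_old = j < i - W         # older than the W previous tokens
    return future | too_old

mask = sliding_window_mask(T=6, W=2)
visible = (~mask).sum(axis=1)   # number of keys each query can attend to
```

Once the sequence is longer than the window, every query attends to exactly `W + 1` keys, which is why anything older has to come from the GDN state.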

**Hybrid Information Flow:**
```
GDN layers: Compress sequence into state S_t [H, K, V]
     ↓ (state flows)
SWA layers: Query state for retrieval + local attention
```
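
The "query state for retrieval" step amounts to a matrix product between per-token queries and the stored state. The sketch below uses NumPy and toy shapes; the stored association and the query vectors are made up for illustration, and the ReLU on the queries follows the notebook's SWA layer.

```python
import numpy as np

H, K, V, T = 1, 4, 3, 2
key = np.array([0.0, 1.0, 0.0, 0.0])
value = np.array([5.0, -2.0, 1.0])

# Pretend an earlier GDN layer stored one key -> value association.
state = np.zeros((H, K, V))
state[0] = np.outer(key, value)

queries = np.zeros((H, T, K))
queries[0, 0] = key                              # matches the stored key
queries[0, 1] = np.array([1.0, 0.0, 0.0, 0.0])   # orthogonal to it
queries = np.maximum(queries, 0.0)               # ReLU: only positive activations retrieve

# [H, K, V] state read by [H, T, K] queries -> [H, T, V]
retrieved = np.einsum('hkv,htk->htv', state, queries)
```

The matching query recovers the stored value regardless of how far back it was written, while the orthogonal query retrieves nothing.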

### Notebook Structure

| Cell | Content |
|------|---------|
| 0 | Environment & Imports |
| 1 | Configuration |
| 2 | Triton Kernels (Forward + Backward) |
| 3 | Model Components |
| 4 | Model Assembly |
| 5 | Data Loading |
| 6 | NIAH Testing Suite |
| 7 | Delta Rule Validation Suite |
| 8 | Training Infrastructure |
| 9 | Gradient Analysis |
| 10 | Performance Profiling |
| 11 | Triton Cache Management |
| 12 | Quick Start / Validation |
| 13 | Training Execution |
| 14 | Post-Training Evaluation |

---

**Recent edits:**
### Cell 13: Training Execution

```python
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)  # adjusted from 1.0 to 2.0
```  
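
For readers unfamiliar with norm-based clipping, the following pure-Python sketch shows what raising the threshold changes. The helper `clip_by_global_norm` is hypothetical and only mirrors the general behavior of `torch.nn.utils.clip_grad_norm_` (scale all gradients by `max_norm / (total_norm + 1e-6)`, capped at 1); the gradient values are invented.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Scale all gradients so their combined L2 norm is at most max_norm."""
    total = math.sqrt(sum(g * g for grad in grads for g in grad))
    scale = min(1.0, max_norm / (total + 1e-6))
    return [[g * scale for g in grad] for grad in grads], total

grads = [[3.0, 0.0], [0.0, 4.0]]   # two "parameters", global norm = 5
clipped_1, norm = clip_by_global_norm(grads, max_norm=1.0)
clipped_2, _ = clip_by_global_norm(grads, max_norm=2.0)

norm_1 = math.sqrt(sum(g * g for grad in clipped_1 for g in grad))
norm_2 = math.sqrt(sum(g * g for grad in clipped_2 for g in grad))
```

With the threshold at 2.0, the same oversized gradient keeps twice the step length it would have had at 1.0, while its direction is unchanged.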
### Cell 1: Configuration

```python
    g_bias: float = 4.0       # high retention  # adjusted from +2.0 to +4.0
```
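
A quick way to read the bias change is to convert it into a retention factor and a half-life. The sketch below assumes the gate sits at its bias value (that is, the projection contributes roughly zero at initialization) and ignores writes. Both are simplifications, and the helper names are illustrative.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def half_life(g):
    """Steps until g**n drops to 0.5, ignoring any writes to the state."""
    return math.log(0.5) / math.log(g)

retention = {bias: sigmoid(bias) for bias in (2.0, 4.0)}
half_lives = {bias: half_life(g) for bias, g in retention.items()}

for bias in (2.0, 4.0):
    print(f"g_bias={bias:+.1f}: g={retention[bias]:.3f}, "
          f"half-life={half_lives[bias]:.1f} steps")
```

Under these assumptions the state's half-life grows from a handful of steps to a few dozen, which matters when the needle sits far from the query in NIAH tests.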

### Cell 13: Training Execution

```python
    lr=2e-4,  # adjusted from 3e-4 to 2e-4 to prevent collapse in the mixed phase (see the curriculum learning literature)
```
### Both changes below made at the same time:
### Cell 2: Triton Kernels (Backward Pass)

```python
         # Reconstruct state_prev via division (numerical stability floor)
         # safe_denom floor: raised from 1e-6 to 1e-3
         # safe_g floor:     raised from 1e-6 to 1e-3
```
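
The trade-off behind the floor can be seen in a small NumPy sketch. It inverts a single update step with a floored divisor. For simplicity it assumes the rank-1 term `outer` is known exactly, whereas the kernel has to reconstruct it as well. The function name and test values are illustrative.

```python
import numpy as np

def reconstruct_prev(state, outer, beta, g, floor=1e-3):
    """Invert S_t = g * S_prev + beta * outer, flooring the divisor."""
    safe_g = max(g, floor)
    return (state - beta * outer) / safe_g

rng = np.random.default_rng(0)
S_prev = rng.standard_normal((4, 3))
outer = rng.standard_normal((4, 3))
beta = 0.1

# g well above the floor: the inversion recovers S_prev.
g_ok = 0.5
rec_ok = reconstruct_prev(g_ok * S_prev + beta * outer, outer, beta, g_ok)

# g below the floor: the division uses 1e-3 instead of 1e-4,
# so the result is S_prev scaled by g / floor = 0.1.
g_tiny = 1e-4
rec_tiny = reconstruct_prev(g_tiny * S_prev + beta * outer, outer, beta, g_tiny)
```

A larger floor caps error amplification at `1 / floor`, at the cost of a biased reconstruction whenever the gate really does fall below it. With `g_bias = 4.0` the gate starts near 0.98, so the floor should rarely be hit early in training.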
### Cell 3: Model Components (RMSNorm)

```python
    def __init__(self, dim: int, eps: float = 1e-3):  # adjusted from 1e-6 to 1e-3
```
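
The effect of the larger `eps` depends on the scale of the activations. The pure-Python sketch below applies the same formula as the notebook's `RMSNorm` (without the learned weight) to a small-magnitude and a unit-magnitude vector; the input values are made up.

```python
import math

def rms_norm(x, eps):
    """x / sqrt(mean(x^2) + eps), without the learned weight."""
    rms = math.sqrt(sum(xi * xi for xi in x) / len(x) + eps)
    return [xi / rms for xi in x]

small = [0.01, -0.01, 0.01, -0.01]   # mean(x^2) = 1e-4, below eps = 1e-3
unit = [1.0, -1.0, 1.0, -1.0]        # mean(x^2) = 1, far above eps

out_small_new = rms_norm(small, eps=1e-3)
out_small_old = rms_norm(small, eps=1e-6)
out_unit_new = rms_norm(unit, eps=1e-3)
```

With `eps=1e-3`, inputs whose mean square is below about `1e-3` are no longer rescaled to unit RMS but damped instead, while unit-scale inputs are essentially unaffected.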


In [299]:
# =============================================================================
# CELL 0: Environment & Imports
# =============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
import triton
import triton.language as tl
from dataclasses import dataclass
from typing import Optional, List, Dict, Tuple, Any
import math
import time
import json
from pathlib import Path

# Environment check
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")
print(f"Triton: {triton.__version__}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Check flash_attn
try:
    from flash_attn import flash_attn_func
    FLASH_ATTN_AVAILABLE = True
    print("flash_attn: ✓ Available")
except ImportError:
    FLASH_ATTN_AVAILABLE = False
    print("flash_attn: ✗ Using PyTorch fallback")

# Check FLA (for comparison)
try:
    from fla.ops.gated_delta_rule import chunk_gated_delta_rule
    HAS_FLA = True
    print("FLA: ✓ Available (for profiling comparison)")
except ImportError:
    HAS_FLA = False
    print("FLA: Not available")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"\n✓ Environment ready. Device: {DEVICE}")

PyTorch: 2.11.0.dev20260128+cu128
CUDA: 12.8
Triton: 3.6.0
GPU: NVIDIA GeForce RTX 4050 Laptop GPU
GPU Memory: 6.4 GB
flash_attn: ✓ Available
FLA: ✓ Available (for profiling comparison)

✓ Environment ready. Device: cuda


In [300]:
# =============================================================================
# CELL 1: Configuration
# =============================================================================

@dataclass
class HybridConfig:
    """
    Configuration for TransparentHybrid model.
    
    Key Parameters:
        layer_pattern: String of 'G' (GDN) and 'S' (SWA)
            Examples: "GS", "GGS", "GGSG", "GGGSGGGS"
        
        head_dim (K): Key dimension for GDN state matrix
        value_dim (V): Value dimension for GDN state matrix
        
        State matrix shape: [B, H, K, V]
        Theoretical capacity: K * V floats per head
        Effective capacity: Much less with random keys (interference)
    
    Gate Initialization (CRITICAL):
        beta_bias=-2.0: sigmoid(-2) ≈ 0.12 → sparse writes (gatekeeper)
        g_bias=+4.0:    sigmoid(+4) ≈ 0.98 → high retention (raised from +2.0 ≈ 0.88)
    """
    # Model dimensions
    d_model: int = 256
    n_heads: int = 8
    head_dim: int = 32        # K dimension for GDN
    value_dim: int = 64       # V dimension for GDN (typically 2x head_dim)
    vocab_size: int = 50257
    
    # Layer pattern
    layer_pattern: str = "GS"
    
    # SWA config
    window_size: int = 64
    
    # Initialization
    init_std: float = 0.02
    
    # Gate biases (GDN)
    beta_bias: float = -2.0   # sparse writes
    g_bias: float = 4.0       # high retention ### adjusted from +2.0 to +4.0
    
    # Special tokens for NIAH testing
    marker_token: int = 50251
    cue_token: int = 50250
    
    def __post_init__(self):
        if self.head_dim is None:
            self.head_dim = self.d_model // self.n_heads
        if self.value_dim is None:
            self.value_dim = self.head_dim * 2
    
    @property
    def n_layers(self) -> int:
        return len(self.layer_pattern)
    
    @property
    def state_capacity(self) -> int:
        """Theoretical state capacity per head (K * V floats)."""
        return self.head_dim * self.value_dim
    
    def describe(self) -> str:
        gdn = sum(1 for c in self.layer_pattern if c == 'G')
        swa = sum(1 for c in self.layer_pattern if c == 'S')
        return (f"HybridConfig: {self.layer_pattern} ({gdn} GDN + {swa} SWA)\n"
                f"  d_model={self.d_model}, n_heads={self.n_heads}\n"
                f"  GDN state: [{self.n_heads}, {self.head_dim}, {self.value_dim}]\n"
                f"  SWA window: {self.window_size}\n"
                f"  State capacity/head: {self.state_capacity} floats")


def count_params(model):
    return sum(p.numel() for p in model.parameters())


# Default config for quick testing
DEFAULT_CFG = HybridConfig(d_model=256, n_heads=8, layer_pattern="GS")
print(DEFAULT_CFG.describe())

HybridConfig: GS (1 GDN + 1 SWA)
  d_model=256, n_heads=8
  GDN state: [8, 32, 64]
  SWA window: 64
  State capacity/head: 2048 floats


In [301]:
# =============================================================================
# CELL 2: Triton Kernels (Forward + Backward)
# =============================================================================
#
# DESIGN NOTES:
#   1. Forward and backward are SEPARATE kernels (not fused)
#      - Explicit state mutation semantics
#      - Cleaner autograd integration
#      - Triton compiles/caches separately anyway
#
#   2. State reconstruction in backward via division:
#      state_prev = (state - β * outer) / g
#      - Unstable when g ≈ 0 or g ≈ β; guarded by the safe_g / safe_denom floors
#      - Better approach: checkpoint forward states (future work)
#
#   3. Memory layout: Always .contiguous().float() at entry
#      - Non-contiguous tensors cause silent errors
#
#   4. Grid size: (B, H) - one program per batch×head
#      - CUDA grid limits are generous (dim 0 up to 2^31-1, dim 1 up to 65535);
#        the practical constraint is that each program walks T sequentially
#
# =============================================================================

@triton.jit
def delta_rule_fwd_kernel(
    K_ptr, V_ptr, Beta_ptr, G_ptr, State_ptr, Out_ptr,
    stride_k_b, stride_k_t, stride_k_h, stride_k_d,
    stride_v_b, stride_v_t, stride_v_h, stride_v_d,
    stride_bg_b, stride_bg_t, stride_bg_h,
    stride_s_b, stride_s_h, stride_s_k, stride_s_v,
    stride_o_b, stride_o_t, stride_o_h, stride_o_v,
    T,
    K_DIM: tl.constexpr,
    V_DIM: tl.constexpr,
):
    """
    TRUE Delta Rule forward kernel.
    
    For each timestep t:
        1. pred = S_{t-1} · k_t           (what we'd retrieve)
        2. error = v_t - pred              (correction needed)
        3. outer = k_t ⊗ error            (rank-1 update)
        4. S_t = g_t * S_{t-1} + β_t * outer
        5. out_t = S_t · k_t              (output)
    """
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    
    k_offs = tl.arange(0, K_DIM)
    v_offs = tl.arange(0, V_DIM)
    
    # Load initial state
    state_base = State_ptr + pid_b * stride_s_b + pid_h * stride_s_h
    state_ptrs = state_base + k_offs[:, None] * stride_s_k + v_offs[None, :] * stride_s_v
    state = tl.load(state_ptrs).to(tl.float32)
    
    for t in range(T):
        # Load k_t, v_t, beta_t, g_t
        k_base = K_ptr + pid_b * stride_k_b + t * stride_k_t + pid_h * stride_k_h
        k_t = tl.load(k_base + k_offs * stride_k_d).to(tl.float32)
        
        v_base = V_ptr + pid_b * stride_v_b + t * stride_v_t + pid_h * stride_v_h
        v_t = tl.load(v_base + v_offs * stride_v_d).to(tl.float32)
        
        bg_offset = pid_b * stride_bg_b + t * stride_bg_t + pid_h * stride_bg_h
        beta_t = tl.load(Beta_ptr + bg_offset).to(tl.float32)
        g_t = tl.load(G_ptr + bg_offset).to(tl.float32)
        
        # TRUE Delta Rule
        pred = tl.sum(state * k_t[:, None], axis=0)  # S·k
        error = v_t - pred                            # error correction
        outer = k_t[:, None] * error[None, :]        # k ⊗ error
        state = g_t * state + beta_t * outer         # update
        
        # Output
        out_t = tl.sum(state * k_t[:, None], axis=0)
        out_base = Out_ptr + pid_b * stride_o_b + t * stride_o_t + pid_h * stride_o_h
        tl.store(out_base + v_offs * stride_o_v, out_t)
    
    # Store final state
    tl.store(state_ptrs, state)


@triton.jit
def delta_rule_bwd_kernel(
    K_ptr, V_ptr, Beta_ptr, G_ptr, State_in_ptr,
    dOut_ptr, dState_out_ptr,
    dK_ptr, dV_ptr, dBeta_ptr, dG_ptr, dState_in_ptr,
    stride_k_b, stride_k_t, stride_k_h, stride_k_d,
    stride_v_b, stride_v_t, stride_v_h, stride_v_d,
    stride_bg_b, stride_bg_t, stride_bg_h,
    stride_s_b, stride_s_h, stride_s_k, stride_s_v,
    stride_o_b, stride_o_t, stride_o_h, stride_o_v,
    stride_dk_b, stride_dk_t, stride_dk_h, stride_dk_d,
    stride_dv_b, stride_dv_t, stride_dv_h, stride_dv_d,
    stride_dbg_b, stride_dbg_t, stride_dbg_h,
    stride_ds_b, stride_ds_h, stride_ds_k, stride_ds_v,
    T,
    K_DIM: tl.constexpr,
    V_DIM: tl.constexpr,
):
    """
    Backward kernel using adjoint method with state reconstruction.
    
    Gradient terms for dk (3 terms - CRITICAL):
        1. dk_from_output = S · dout        (from out = S·k)
        2. dk_from_outer  = β * dS · error  (from outer = k ⊗ error)
        3. dk_from_pred   = -β * S_prev · (dS·k)  (from error = v - S·k)
    
    The third term was missing in the initial version (caused a 4.7% gradient error).
    """
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    
    k_offs = tl.arange(0, K_DIM)
    v_offs = tl.arange(0, V_DIM)
    
    # Phase 1: Forward pass to get final state
    state_base = State_in_ptr + pid_b * stride_s_b + pid_h * stride_s_h
    state_ptrs = state_base + k_offs[:, None] * stride_s_k + v_offs[None, :] * stride_s_v
    state = tl.load(state_ptrs).to(tl.float32)
    
    for t in range(T):
        k_base = K_ptr + pid_b * stride_k_b + t * stride_k_t + pid_h * stride_k_h
        k_t = tl.load(k_base + k_offs * stride_k_d).to(tl.float32)
        v_base = V_ptr + pid_b * stride_v_b + t * stride_v_t + pid_h * stride_v_h
        v_t = tl.load(v_base + v_offs * stride_v_d).to(tl.float32)
        bg_offset = pid_b * stride_bg_b + t * stride_bg_t + pid_h * stride_bg_h
        beta_t = tl.load(Beta_ptr + bg_offset).to(tl.float32)
        g_t = tl.load(G_ptr + bg_offset).to(tl.float32)
        
        pred = tl.sum(state * k_t[:, None], axis=0)
        error = v_t - pred
        outer = k_t[:, None] * error[None, :]
        state = g_t * state + beta_t * outer
    
    # Phase 2: Backward pass (reverse time)
    dstate_out_base = dState_out_ptr + pid_b * stride_ds_b + pid_h * stride_ds_h
    dstate_out_ptrs = dstate_out_base + k_offs[:, None] * stride_ds_k + v_offs[None, :] * stride_ds_v
    dstate = tl.load(dstate_out_ptrs).to(tl.float32)
    
    for t_rev in range(T):
        t = T - 1 - t_rev
        
        # Reload inputs for timestep t
        k_base = K_ptr + pid_b * stride_k_b + t * stride_k_t + pid_h * stride_k_h
        k_t = tl.load(k_base + k_offs * stride_k_d).to(tl.float32)
        v_base = V_ptr + pid_b * stride_v_b + t * stride_v_t + pid_h * stride_v_h
        v_t = tl.load(v_base + v_offs * stride_v_d).to(tl.float32)
        bg_offset = pid_b * stride_bg_b + t * stride_bg_t + pid_h * stride_bg_h
        beta_t = tl.load(Beta_ptr + bg_offset).to(tl.float32)
        g_t = tl.load(G_ptr + bg_offset).to(tl.float32)
        
        dout_base = dOut_ptr + pid_b * stride_o_b + t * stride_o_t + pid_h * stride_o_h
        dout_t = tl.load(dout_base + v_offs * stride_o_v).to(tl.float32)
        
        # Reconstruct state_prev via division (numerical stability floor).
        # Uses out_t = (g_t - β_t) * pred + β_t * v_t, which holds only for ||k_t|| = 1
        # (keys are L2-normalized in GatedDeltaNetLayer).
        out_t = tl.sum(state * k_t[:, None], axis=0)
        denom = g_t - beta_t
        safe_denom = tl.where(tl.abs(denom) > 1e-3, denom, 1e-3)  # floor raised from 1e-6 to 1e-3
        pred_before = (out_t - beta_t * v_t) / safe_denom         # (see also the RMSNorm eps
        error_t = v_t - pred_before                               #  change in Cell 3)
        outer_t = k_t[:, None] * error_t[None, :]
        safe_g = tl.where(g_t > 1e-3, g_t, 1e-3)  # floor raised from 1e-6 to 1e-3
        state_prev = (state - beta_t * outer_t) / safe_g
        
        # Accumulate dstate from output gradient
        dstate = dstate + k_t[:, None] * dout_t[None, :]
        
        # Gradient for v: dv = β * (dS · k)
        dv_t = beta_t * tl.sum(dstate * k_t[:, None], axis=0)
        dv_base = dV_ptr + pid_b * stride_dv_b + t * stride_dv_t + pid_h * stride_dv_h
        tl.store(dv_base + v_offs * stride_dv_d, dv_t)
        
        # Gradient for beta: dβ = sum(dS * outer)
        dbeta_t = tl.sum(dstate * outer_t)
        dbeta_base = dBeta_ptr + pid_b * stride_dbg_b + t * stride_dbg_t + pid_h * stride_dbg_h
        tl.store(dbeta_base, dbeta_t)
        
        # Gradient for g: dg = sum(dS * S_prev)
        dg_t = tl.sum(dstate * state_prev)
        dg_base = dG_ptr + pid_b * stride_dbg_b + t * stride_dbg_t + pid_h * stride_dbg_h
        tl.store(dg_base, dg_t)
        
        # Gradient for k: ALL 3 TERMS
        dk_from_output = tl.sum(state * dout_t[None, :], axis=1)
        dk_from_outer = beta_t * tl.sum(dstate * error_t[None, :], axis=1)
        dstate_dot_k = tl.sum(dstate * k_t[:, None], axis=0)
        dk_from_pred = -beta_t * tl.sum(state_prev * dstate_dot_k[None, :], axis=1)
        dk_t = dk_from_output + dk_from_outer + dk_from_pred
        dk_base = dK_ptr + pid_b * stride_dk_b + t * stride_dk_t + pid_h * stride_dk_h
        tl.store(dk_base + k_offs * stride_dk_d, dk_t)
        
        # Propagate dstate backward
        dstate_k = tl.sum(dstate * k_t[:, None], axis=0)
        dstate = g_t * dstate - beta_t * k_t[:, None] * dstate_k[None, :]
        state = state_prev
    
    # Store gradient w.r.t initial state
    dstate_in_base = dState_in_ptr + pid_b * stride_ds_b + pid_h * stride_ds_h
    dstate_in_ptrs = dstate_in_base + k_offs[:, None] * stride_ds_k + v_offs[None, :] * stride_ds_v
    tl.store(dstate_in_ptrs, dstate)


print("Triton kernels defined (forward + backward).")

Triton kernels defined (forward + backward).


In [302]:
# =============================================================================
# CELL 2b: Triton Wrappers & Autograd
# =============================================================================

def triton_delta_rule(k, v, beta, g, initial_state=None):
    """Forward pass wrapper. Always works in float32 internally."""
    B, T, H, K_DIM = k.shape
    V_DIM = v.shape[-1]
    device = k.device
    orig_dtype = k.dtype
    
    # CRITICAL: contiguous + float32
    k = k.contiguous().float()
    v = v.contiguous().float()
    beta = beta.contiguous().float()
    g = g.contiguous().float()
    
    if initial_state is None:
        state = torch.zeros(B, H, K_DIM, V_DIM, device=device, dtype=torch.float32)
    else:
        state = initial_state.contiguous().float().clone()
    
    out = torch.empty(B, T, H, V_DIM, device=device, dtype=torch.float32)
    
    # Launch kernel
    delta_rule_fwd_kernel[(B, H)](
        k, v, beta, g, state, out,
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        beta.stride(0), beta.stride(1), beta.stride(2),
        state.stride(0), state.stride(1), state.stride(2), state.stride(3),
        out.stride(0), out.stride(1), out.stride(2), out.stride(3),
        T, K_DIM, V_DIM,
    )
    
    return out.to(orig_dtype), state.to(orig_dtype)


def triton_delta_rule_backward(k, v, beta, g, initial_state, d_out, d_state_out):
    """Backward pass wrapper."""
    B, T, H, K_DIM = k.shape
    V_DIM = v.shape[-1]
    
    k = k.contiguous().float()
    v = v.contiguous().float()
    beta = beta.contiguous().float()
    g = g.contiguous().float()
    initial_state = initial_state.contiguous().float()
    d_out = d_out.contiguous().float()
    d_state_out = d_state_out.contiguous().float()
    
    d_k = torch.empty_like(k)
    d_v = torch.empty_like(v)
    d_beta = torch.empty_like(beta)
    d_g = torch.empty_like(g)
    d_state_in = torch.empty_like(initial_state)
    
    delta_rule_bwd_kernel[(B, H)](
        k, v, beta, g, initial_state,
        d_out, d_state_out,
        d_k, d_v, d_beta, d_g, d_state_in,
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        beta.stride(0), beta.stride(1), beta.stride(2),
        initial_state.stride(0), initial_state.stride(1), initial_state.stride(2), initial_state.stride(3),
        d_out.stride(0), d_out.stride(1), d_out.stride(2), d_out.stride(3),
        d_k.stride(0), d_k.stride(1), d_k.stride(2), d_k.stride(3),
        d_v.stride(0), d_v.stride(1), d_v.stride(2), d_v.stride(3),
        d_beta.stride(0), d_beta.stride(1), d_beta.stride(2),
        d_state_in.stride(0), d_state_in.stride(1), d_state_in.stride(2), d_state_in.stride(3),
        T, K_DIM, V_DIM,
    )
    
    return d_k, d_v, d_beta, d_g, d_state_in


class DeltaRuleFunction(torch.autograd.Function):
    """Autograd wrapper connecting Triton forward + backward."""
    @staticmethod
    def forward(ctx, k, v, beta, g, initial_state):
        ctx.save_for_backward(k, v, beta, g, initial_state)
        output, final_state = triton_delta_rule(k, v, beta, g, initial_state)
        return output, final_state
    
    @staticmethod
    def backward(ctx, d_output, d_final_state):
        k, v, beta, g, initial_state = ctx.saved_tensors
        d_k, d_v, d_beta, d_g, d_initial_state = triton_delta_rule_backward(
            k, v, beta, g, initial_state, d_output, d_final_state
        )
        return d_k, d_v, d_beta, d_g, d_initial_state


print("Triton wrappers + autograd ready.")

Triton wrappers + autograd ready.


In [303]:
# =============================================================================
# CELL 3: Model Components
# =============================================================================

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""
    def __init__(self, dim: int, eps: float = 1e-3):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x / rms * self.weight


class SwiGLUFFN(nn.Module):
    """SwiGLU Feed-Forward Network."""
    def __init__(self, d_model: int, expansion: float = 8/3):
        super().__init__()
        hidden = ((int(d_model * expansion) + 63) // 64) * 64  # Round to 64
        self.w1 = nn.Linear(d_model, hidden, bias=False)
        self.w3 = nn.Linear(d_model, hidden, bias=False)
        self.w2 = nn.Linear(hidden, d_model, bias=False)
        self.norm = RMSNorm(d_model)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm(x)
        return x + self.w2(F.silu(self.w1(h)) * self.w3(h))


class GatedDeltaNetLayer(nn.Module):
    """
    GDN layer with Triton kernels.
    
    TRUE Delta Rule: S_t = g_t * S_{t-1} + β_t * k_t ⊗ (v_t - S_{t-1}·k_t)
    
    Returns:
        output: [B, T, D] with residual connection
        state: [B, H, K, V] final state
        diag: dict with diagnostic info
    """
    def __init__(self, cfg: HybridConfig, layer_idx: int):
        super().__init__()
        self.cfg = cfg
        self.layer_idx = layer_idx
        H, K, V = cfg.n_heads, cfg.head_dim, cfg.value_dim
        
        # Projections
        self.k_proj = nn.Linear(cfg.d_model, H * K, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, H * V, bias=False)
        self.o_proj = nn.Linear(H * V, cfg.d_model, bias=False)
        
        # Gates with biased initialization
        self.beta_proj = nn.Linear(cfg.d_model, H, bias=True)
        nn.init.constant_(self.beta_proj.bias, cfg.beta_bias)  # sparse writes
        
        self.g_proj = nn.Linear(cfg.d_model, H, bias=True)
        nn.init.constant_(self.g_proj.bias, cfg.g_bias)  # high retention
        
        self.norm = RMSNorm(cfg.d_model)
        
    def forward(self, x, initial_state=None):
        B, T, D = x.shape
        H, K, V = self.cfg.n_heads, self.cfg.head_dim, self.cfg.value_dim
        
        x_norm = self.norm(x)
        
        # Projections
        k = self.k_proj(x_norm).view(B, T, H, K)
        v = self.v_proj(x_norm).view(B, T, H, V)
        
        # CRITICAL: L2 normalize keys
        k = F.normalize(k.float(), p=2, dim=-1).to(x.dtype)
        
        # Gates
        beta = torch.sigmoid(self.beta_proj(x_norm))  # [B, T, H]
        g = torch.sigmoid(self.g_proj(x_norm))        # [B, T, H]
        
        # Initialize state
        if initial_state is None:
            state = torch.zeros(B, H, K, V, device=x.device, dtype=x.dtype)
        else:
            state = initial_state.to(x.dtype)
        
        # Triton kernel
        out, new_state = DeltaRuleFunction.apply(k, v, beta, g, state)
        
        # Output projection + residual
        output = out.to(x.dtype).reshape(B, T, H * V)
        output = x + self.o_proj(output)
        
        # Diagnostics
        diag = {
            'beta_mean': beta.mean().item(),
            'beta_max': beta.max().item(),
            'g_mean': g.mean().item(),
            'state_norm': new_state.norm().item(),
            'state_max': new_state.abs().max().item(),
        }
        
        return output, new_state, diag


class SlidingWindowAttention(nn.Module):
    """
    SWA with state retrieval from GDN.
    
    Two pathways:
        1. Local attention (flash_attn or PyTorch fallback)
        2. State retrieval: queries GDN state for global context
    
    This is the key feature enabling NIAH - the SWA can "see" information
    stored in GDN state from earlier in the sequence.
    """
    def __init__(self, cfg: HybridConfig, layer_idx: int):
        super().__init__()
        self.cfg = cfg
        H, K, V = cfg.n_heads, cfg.head_dim, cfg.value_dim
        self.head_dim = cfg.d_model // cfg.n_heads
        self.scale = self.head_dim ** -0.5  # attention head dim, matching flash_attn's default softmax scale
        
        # Local attention projections
        self.q_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.o_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        
        # State retrieval (dedicated projection)
        self.global_q_proj = nn.Linear(cfg.d_model, H * K, bias=False)
        nn.init.normal_(self.global_q_proj.weight, std=cfg.init_std)
        self.retrieval_o_proj = nn.Linear(H * V, cfg.d_model, bias=False)
        
        # Retrieval gate (starts open)
        self.gate_proj = nn.Linear(cfg.d_model, H, bias=True)
        nn.init.constant_(self.gate_proj.bias, 1.0)
        
        self.norm = RMSNorm(cfg.d_model)
        
    def forward(self, x, gdn_state=None):
        B, T, D = x.shape
        H = self.cfg.n_heads
        K, V, W = self.cfg.head_dim, self.cfg.value_dim, self.cfg.window_size
        
        x_norm = self.norm(x)
        
        # === Local Attention ===
        q = self.q_proj(x_norm).view(B, T, H, self.head_dim)
        k = self.k_proj(x_norm).view(B, T, H, self.head_dim)
        v = self.v_proj(x_norm).view(B, T, H, self.head_dim)
        
        if FLASH_ATTN_AVAILABLE:
            local_out = flash_attn_func(q, k, v, causal=True, window_size=(W, 0))
            local_out = local_out.reshape(B, T, D)
        else:
            # PyTorch fallback
            q = q.transpose(1, 2)  # [B, H, T, head_dim]
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            
            mask = torch.ones(T, T, device=x.device, dtype=torch.bool).triu(1)
            mask |= torch.ones(T, T, device=x.device, dtype=torch.bool).tril(-W - 1)
            
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
            local_out = (F.softmax(attn, dim=-1) @ v).transpose(1, 2).reshape(B, T, D)
        
        local_out = self.o_proj(local_out)
        
        # === State Retrieval ===
        retrieval_out = torch.zeros_like(x)
        gate_mean = 0.0
        
        if gdn_state is not None:
            # Query GDN state for global retrieval
            q_g = self.global_q_proj(x_norm).view(B, T, H, K).transpose(1, 2)  # [B, H, T, K]
            q_g = F.relu(q_g)  # Sparse queries (only positive activations retrieve)
            
            # Retrieve from state: [B, H, K, V] @ [B, H, T, K]^T -> [B, H, T, V]
            retrieved = torch.einsum('bhkv,bhtk->bhtv', gdn_state.to(x.dtype), q_g)
            retrieved = retrieved.transpose(1, 2).reshape(B, T, H * V)
            retrieval_out = self.retrieval_o_proj(retrieved)
            
            # Gate modulates retrieval
            gate = torch.sigmoid(self.gate_proj(x_norm))  # [B, T, H]
            gate_mean = gate.mean().item()
            retrieval_out = gate.mean(dim=-1, keepdim=True) * retrieval_out
        
        out = x + local_out + retrieval_out
        
        diag = {
            'gate_mean': gate_mean,
            'local_norm': local_out.norm().item(),
            'retrieval_norm': retrieval_out.norm().item() if gdn_state is not None else 0.0,
        }
        
        return out, diag


print("Model components loaded (GDN, SWA, FFN, RMSNorm).")

Model components loaded (GDN, SWA, FFN, RMSNorm).


In [304]:
# =============================================================================
# CELL 4: Model Assembly - TransparentHybrid
# =============================================================================

class TransparentHybrid(nn.Module):
    """
    GDN + SWA hybrid model with full visibility into state flow.
    
    Information Flow:
        - GDN layers compress sequence into state S_t [H, K, V]
        - State flows to subsequent SWA layers for retrieval
        - SWA provides precision retrieval (window + global via state)
    
    Layer pattern examples:
        "GS"       - 2 layers: GDN, SWA (minimal)
        "GGS"      - 3 layers: 2 GDN, 1 SWA
        "GGSG"     - 4 layers: GDN, GDN, SWA, GDN
        "GGGSGGGS" - 8 layers: sparse SWA placement
    """
    def __init__(self, cfg: HybridConfig):
        super().__init__()
        self.cfg = cfg
        
        # Token embedding
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        nn.init.normal_(self.embed.weight, std=cfg.init_std)
        
        # Build layers
        self.layers = nn.ModuleList()
        self.ffns = nn.ModuleList()
        
        for i, lt in enumerate(cfg.layer_pattern):
            if lt == 'G':
                self.layers.append(GatedDeltaNetLayer(cfg, i))
            elif lt == 'S':
                self.layers.append(SlidingWindowAttention(cfg, i))
            else:
                raise ValueError(f"Unknown layer type: {lt}")
            self.ffns.append(SwiGLUFFN(cfg.d_model))
        
        # Output
        self.norm_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        
        # Weight tying
        self.lm_head.weight = self.embed.weight
        
    # NOTE: forward() collects diagnostics in a memory-safe way: it stores plain
    # Python floats and returns the final state detached from the graph.
        
    def forward(self, input_ids, targets=None, return_diagnostics=False):
        """
        Args:
            input_ids: [B, T] token indices
            targets: [B, T] optional target indices for loss
            return_diagnostics: whether to return layer diagnostics
            
        Returns:
            logits: [B, T, vocab_size]
            loss: scalar if targets provided
            all_diag: list of layer diagnostics (DETACHED from graph)
            state: final GDN state (DETACHED from graph)
        """
        x = self.embed(input_ids)
        state = None
        all_diag = []
        
        # Track the "current" GDN state for SWA layers to use
        current_gdn_state = None
        
        for i, (layer, ffn) in enumerate(zip(self.layers, self.ffns)):
            lt = self.cfg.layer_pattern[i]
            if lt == 'G':
                x, state, diag = layer(x, initial_state=state)
                current_gdn_state = state # Keep attached for SWA next step
                
                # FIXED: Copy ALL keys required by run_full_diagnostic
                if return_diagnostics:
                    safe_diag = {
                        'layer': lt,
                        'layer_idx': i,
                        'beta_mean': diag['beta_mean'],
                        'beta_max': diag.get('beta_max', 0.0),   # ADDED
                        'g_mean': diag['g_mean'],
                        'state_norm': diag['state_norm'],
                        'state_max': diag.get('state_max', 0.0)  # ADDED
                    }
                    all_diag.append(safe_diag)
                
            else: # SWA
                x, diag = layer(x, gdn_state=current_gdn_state)
                
                if return_diagnostics:
                    safe_diag = {
                        'layer': lt,
                        'layer_idx': i,
                        'gate_mean': diag['gate_mean'],
                        'local_norm': diag['local_norm'],
                        'retrieval_norm': diag['retrieval_norm']
                    }
                    all_diag.append(safe_diag)
            
            x = ffn(x)
        
        x = self.norm_f(x)
        logits = self.lm_head(x)
        
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), 
                targets.view(-1), 
                ignore_index=-100
            )
        
        # Detach the returned state so callers that keep it do not hold on to the autograd graph
        return logits, loss, all_diag, state.detach() if state is not None else None

print("TransparentHybrid model ready.")

TransparentHybrid model ready.
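
The forward pass returns the final GDN state detached from the autograd graph. A minimal sketch of why that matters when a state is carried from one chunk to the next. `toy_step` is a hypothetical stand-in recurrence for illustration, not the GDN layer:

```python
import torch

def toy_step(x, state, w):
    """Hypothetical recurrent step: new_state = tanh(state + x @ w)."""
    return torch.tanh(state + x @ w)

torch.manual_seed(0)
w = torch.randn(4, 4, requires_grad=True)
state = torch.zeros(2, 4)

for chunk in range(3):
    x = torch.randn(2, 4)
    new_state = toy_step(x, state, w)
    loss = new_state.pow(2).mean()
    loss.backward()              # gradients flow through the current chunk only
    state = new_state.detach()   # cut the graph so earlier chunks can be freed

print(state.requires_grad, state.grad_fn)  # False None
```

Inside a single forward call the state stays attached between layers (see `current_gdn_state`), so the SWA retrieval path can still send gradients back into the GDN layers. Only the value handed back to the caller is cut.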


In [305]:
# =============================================================================
# CELL 5: Data Loading
# =============================================================================

from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, Dataset

class TextDataset(Dataset):
    """Simple dataset for language modeling."""
    def __init__(self, tokens, seq_len=128):
        self.tokens = tokens
        self.seq_len = seq_len
        
    def __len__(self):
        return (len(self.tokens) - 1) // self.seq_len
    
    def __getitem__(self, idx):
        start = idx * self.seq_len
        return torch.tensor(self.tokens[start:start + self.seq_len + 1], dtype=torch.long)


def load_data(n_tokens=500_000, seq_len=128, batch_size=16):
    """Load wikitext data for training."""
    print(f"Loading {n_tokens:,} tokens from wikitext-103...")
    
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    dataset = load_dataset('wikitext', 'wikitext-103-raw-v1', split='train')
    
    all_tokens = []
    for item in dataset:
        if item['text'].strip():
            all_tokens.extend(tokenizer.encode(item['text']))
            if len(all_tokens) >= n_tokens:
                break
    
    all_tokens = all_tokens[:n_tokens]
    print(f"Loaded {len(all_tokens):,} tokens")
    
    ds = TextDataset(all_tokens, seq_len)
    return DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=True)


# Uncomment to load data:
# data_loader = load_data(n_tokens=500_000, seq_len=128, batch_size=16)
print("Data loading utilities ready.")

Data loading utilities ready.
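
To make the windowing arithmetic in `TextDataset` concrete, here is a dependency-free sketch on ten toy token ids. The helper `window` is a hypothetical name that mirrors `__getitem__`. The input/target split is the one applied later in the training loop:

```python
def window(tokens, seq_len, idx):
    """Mirror TextDataset.__getitem__: seq_len + 1 tokens starting at idx * seq_len."""
    start = idx * seq_len
    return tokens[start:start + seq_len + 1]

tokens = list(range(10))                     # toy "token ids" 0..9
seq_len = 4
n_windows = (len(tokens) - 1) // seq_len     # same formula as __len__

for idx in range(n_windows):
    chunk = window(tokens, seq_len, idx)
    inputs, targets = chunk[:-1], chunk[1:]  # next-token shift
    print(idx, inputs, targets)
# 0 [0, 1, 2, 3] [1, 2, 3, 4]
# 1 [4, 5, 6, 7] [5, 6, 7, 8]
```

Each item holds `seq_len + 1` tokens so that the shifted inputs and targets are both `seq_len` long. Consecutive windows share one boundary token, and trailing tokens that do not fill a window (token 9 here) are dropped.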


In [306]:
# =============================================================================
# CELL 6: NIAH Testing Suite
# =============================================================================
#
# Needle-In-A-Haystack tests measure the model's ability to:
#   1. Store information at an early position (MARKER + NEEDLE)
#   2. Retrieve it when cued at a later position (CUE → predict NEEDLE)
#
# This tests the GDN → SWA retrieval pathway.
#
# =============================================================================

def proper_niah_test(model, seq_len=128, needle_pos=32, n_trials=30):
    """NIAH test with MARKER + CUE tokens."""
    model.eval()
    device = next(model.parameters()).device
    cfg = model.cfg
    
    correct = 0
    for _ in range(n_trials):
        needle_id = cfg.vocab_size - 3  # Use a specific token as needle
        seq = torch.randint(0, cfg.vocab_size - 100, (1, seq_len), device=device)
        
        # Place MARKER + NEEDLE at needle_pos
        seq[0, needle_pos] = cfg.marker_token
        seq[0, needle_pos + 1] = needle_id
        
        # Place CUE at end
        seq[0, -1] = cfg.cue_token
        
        with torch.no_grad():
            logits, _, _, _ = model(seq)
        
        # Check if model predicts needle_id after CUE
        pred = logits[0, -1].argmax().item()
        if pred == needle_id:
            correct += 1
    
    acc = correct / n_trials
    print(f"  Accuracy: {acc*100:.1f}% ({correct}/{n_trials})")
    return {'accuracy': acc, 'correct': correct, 'total': n_trials}


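def test_niah_by_distance(model, distances=(5, 10, 20, 40, 60, 95), n_trials=20, seq_len=128):
    """
    Test retrieval across varying distances.
    
    Distance = (seq_len - 1) - (needle_pos + 1) = how far back the NEEDLE token
    (stored at needle_pos + 1, right after the MARKER) is from the CUE.
    If a requested distance does not fit, needle_pos is clamped to 2 and the
    actual distance is smaller than the printed label.
    """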
    print(f"\nNIAH by Distance (seq_len={seq_len}):")
    results = {}
    
    for dist in distances:
        needle_pos = max(2, seq_len - dist - 2)
        print(f"  Distance {dist:3d} (pos={needle_pos:3d}): ", end="")
        result = proper_niah_test(model, seq_len=seq_len, needle_pos=needle_pos, n_trials=n_trials)
        results[dist] = result
    
    return results


def run_full_diagnostic(model, seq_len=128, needle_pos=32):
    """
    Comprehensive diagnostic with state health check.
    
    Checks:
        - State norm (should be bounded, not exploding)
        - Gate values (β and g activation patterns)
        - Retrieval pathway activation
    """
    model.eval()
    device = next(model.parameters()).device
    cfg = model.cfg
    
    # Create test sequence
    needle_id = cfg.vocab_size - 3
    seq = torch.randint(0, cfg.vocab_size - 100, (1, seq_len), device=device)
    seq[0, needle_pos] = cfg.marker_token
    seq[0, needle_pos + 1] = needle_id
    seq[0, -1] = cfg.cue_token
    
    with torch.no_grad():
        logits, _, diags, state = model(seq, return_diagnostics=True)
    
    print(f"\n{'='*60}")
    print("DIAGNOSTIC REPORT")
    print(f"{'='*60}")
    
    # State health
    print(f"\nState Health:")
    print(f"  State norm: {state.norm().item():.2f}")
    print(f"  State max:  {state.abs().max().item():.2f}")
    
    if state.abs().max().item() < 10:
        print(f"  ✓ State bounded - Delta Rule working!")
    elif state.abs().max().item() < 100:
        print(f"  ⚠ State moderately large - monitor during training")
    else:
        print(f"  ✗ State explosion detected - check Delta Rule!")
    
    # Layer diagnostics
    print(f"\nLayer Diagnostics:")
    for i, d in enumerate(diags):
        if d['layer'] == 'G':
            print(f"  Layer {i} (GDN): β={d['beta_mean']:.3f} (max={d['beta_max']:.3f}), "
                  f"g={d['g_mean']:.3f}, state_norm={d['state_norm']:.2f}")
        else:
            print(f"  Layer {i} (SWA): gate={d['gate_mean']:.3f}, "
                  f"local={d['local_norm']:.2f}, retrieval={d['retrieval_norm']:.2f}")
    
    # Prediction check
    pred = logits[0, -1].argmax().item()
    print(f"\nPrediction: {pred} (target: {needle_id}) - {'✓' if pred == needle_id else '✗'}")
    
    return state, diags


print("NIAH testing suite loaded.")

NIAH testing suite loaded.
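
The sequence layout used by the NIAH tests can be sketched without the model. The token ids below are hypothetical placeholders. The real `marker_token` and `cue_token` come from the config and are not shown in this section:

```python
import random

def build_niah_sequence(seq_len, needle_pos, vocab_size, marker_token, cue_token, seed=0):
    """Random haystack with MARKER at needle_pos, NEEDLE right after it, CUE at the end."""
    rng = random.Random(seed)
    needle_id = vocab_size - 3
    seq = [rng.randrange(0, vocab_size - 100) for _ in range(seq_len)]
    seq[needle_pos] = marker_token
    seq[needle_pos + 1] = needle_id
    seq[-1] = cue_token
    return seq, needle_id

# Hypothetical ids, chosen above the haystack range so random filler cannot collide with them
VOCAB, MARKER, CUE = 50257, 50251, 50252

seq, needle = build_niah_sequence(128, needle_pos=32, vocab_size=VOCAB,
                                  marker_token=MARKER, cue_token=CUE)
distance = (len(seq) - 1) - (32 + 1)   # positions between NEEDLE and CUE
print(seq[32], seq[33], seq[-1], distance)  # 50251 50254 50252 94
```

The test counts a trial as correct when the argmax of the logits at the CUE position equals the needle id, so an untrained model is expected to score near zero.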


In [307]:
# =============================================================================
# CELL 7: Delta Rule Validation Suite
# =============================================================================
#
# These tests validate the correctness of the TRUE Delta Rule implementation.
# Understanding these behaviors is CRITICAL for debugging and tuning.
#
# =============================================================================

def test_identical_tokens():
    """
    TEST: Identical tokens should produce zero error on second write.
    
    This is the KEY property of Delta Rule vs naive outer product.
    If this fails, you don't have TRUE Delta Rule.
    """
    print("\n" + "=" * 60)
    print("TEST: Identical Tokens (Redundancy Suppression)")
    print("=" * 60)
    
    B, H, K, V = 1, 4, 32, 64
    device = DEVICE
    
    state = torch.zeros(B, H, K, V, device=device)
    k = F.normalize(torch.randn(B, H, K, device=device), dim=-1)
    v = torch.randn(B, H, V, device=device)
    
    # First write
    pred1 = torch.einsum('bhkv,bhk->bhv', state, k)
    error1 = v - pred1
    update1 = torch.einsum('bhv,bhk->bhkv', error1, k)
    state = state + update1
    norm1 = state.norm().item()
    
    # Second write (SAME k, v)
    pred2 = torch.einsum('bhkv,bhk->bhv', state, k)
    error2 = v - pred2
    update2 = torch.einsum('bhv,bhk->bhkv', error2, k)
    state = state + update2
    norm2 = state.norm().item()
    
    error_ratio = error2.norm().item() / (error1.norm().item() + 1e-8)
    growth = norm2 / norm1
    
    print(f"  Error1 norm: {error1.norm().item():.4f}")
    print(f"  Error2 norm: {error2.norm().item():.6f} (should be ~0)")
    print(f"  Error ratio: {error_ratio:.6f}")
    print(f"  State growth: {growth:.4f}x (should be ~1.0)")
    
    passed = error_ratio < 0.001 and growth < 1.01
    print(f"  → {'✓ PASS' if passed else '✗ FAIL'}")
    return passed


def test_orthogonal_keys():
    """
    TEST: Orthogonal keys should store independently without interference.
    
    This tests the theoretical best case for associative memory.
    """
    print("\n" + "=" * 60)
    print("TEST: Orthogonal Keys (Independent Storage)")
    print("=" * 60)
    
    B, H, K, V = 1, 1, 32, 64
    device = DEVICE
    
    state = torch.zeros(B, H, K, V, device=device)
    
    # Two orthogonal keys (unit vectors)
    k1 = torch.zeros(B, H, K, device=device)
    k1[0, 0, 0] = 1.0
    k2 = torch.zeros(B, H, K, device=device)
    k2[0, 0, 1] = 1.0
    
    v1 = torch.randn(B, H, V, device=device)
    v2 = torch.randn(B, H, V, device=device)
    
    # Write v1 at k1
    pred = torch.einsum('bhkv,bhk->bhv', state, k1)
    error = v1 - pred
    update = torch.einsum('bhv,bhk->bhkv', error, k1)
    state = state + update
    
    # Write v2 at k2
    pred = torch.einsum('bhkv,bhk->bhv', state, k2)
    error = v2 - pred
    update = torch.einsum('bhv,bhk->bhkv', error, k2)
    state = state + update
    
    # Retrieve
    retrieved_v1 = torch.einsum('bhkv,bhk->bhv', state, k1)
    retrieved_v2 = torch.einsum('bhkv,bhk->bhv', state, k2)
    
    error_v1 = (retrieved_v1 - v1).norm().item() / v1.norm().item()
    error_v2 = (retrieved_v2 - v2).norm().item() / v2.norm().item()
    
    print(f"  v1 retrieval error: {error_v1:.6f} (should be ~0)")
    print(f"  v2 retrieval error: {error_v2:.6f} (should be ~0)")
    
    passed = error_v1 < 0.001 and error_v2 < 0.001
    print(f"  → {'✓ PASS' if passed else '✗ FAIL'}")
    return passed


def test_interference():
    """
    TEST: Similar (non-orthogonal) keys cause interference.
    
    This documents EXPECTED behavior - not a failure!
    Understanding interference is key to capacity planning.
    """
    print("\n" + "=" * 60)
    print("TEST: Key Interference (Expected Behavior)")
    print("=" * 60)
    
    B, H, K, V = 1, 1, 32, 64
    device = DEVICE
    
    state = torch.zeros(B, H, K, V, device=device)
    
    # Two similar keys (high dot product)
    k1 = F.normalize(torch.randn(B, H, K, device=device), dim=-1)
    k2 = k1 + 0.1 * F.normalize(torch.randn(B, H, K, device=device), dim=-1)
    k2 = F.normalize(k2, dim=-1)
    
    dot_product = (k1 * k2).sum().item()
    
    v1 = torch.randn(B, H, V, device=device)
    v2 = torch.randn(B, H, V, device=device)
    
    # Write v1 at k1
    pred = torch.einsum('bhkv,bhk->bhv', state, k1)
    error = v1 - pred
    update = torch.einsum('bhv,bhk->bhkv', error, k1)
    state = state + update
    
    # Retrieve v1 BEFORE writing v2
    retrieved_v1_before = torch.einsum('bhkv,bhk->bhv', state, k1)
    error_before = (retrieved_v1_before - v1).norm().item() / v1.norm().item()
    
    # Write v2 at k2 (similar key)
    pred = torch.einsum('bhkv,bhk->bhv', state, k2)
    error = v2 - pred
    update = torch.einsum('bhv,bhk->bhkv', error, k2)
    state = state + update
    
    # Retrieve v1 AFTER writing v2
    retrieved_v1_after = torch.einsum('bhkv,bhk->bhv', state, k1)
    error_after = (retrieved_v1_after - v1).norm().item() / v1.norm().item()
    
    print(f"  Key similarity (dot product): {dot_product:.4f}")
    print(f"  v1 error BEFORE v2 write: {error_before:.6f}")
    print(f"  v1 error AFTER v2 write:  {error_after:.4f}")
    print(f"  Interference occurred: {error_after > error_before}")
    print(f"  → This is EXPECTED: Similar keys interfere")
    return True


def test_capacity_limit():
    """
    TEST: State behavior under many writes.
    
    Shows how state norm grows and early items degrade.
    """
    print("\n" + "=" * 60)
    print("TEST: Capacity Limit (State Saturation)")
    print("=" * 60)
    
    B, H, K, V = 1, 1, 32, 64
    device = DEVICE
    
    n_writes = [10, 50, 100, 200]
    results = []
    
    for n in n_writes:
        state = torch.zeros(B, H, K, V, device=device)
        keys, values = [], []
        
        for i in range(n):
            k = F.normalize(torch.randn(B, H, K, device=device), dim=-1)
            v = torch.randn(B, H, V, device=device)
            keys.append(k)
            values.append(v)
            
            pred = torch.einsum('bhkv,bhk->bhv', state, k)
            error = v - pred
            update = torch.einsum('bhv,bhk->bhkv', error, k)
            state = state + update
        
        # Test retrieval of first and last items
        retrieved_first = torch.einsum('bhkv,bhk->bhv', state, keys[0])
        error_first = (retrieved_first - values[0]).norm().item() / values[0].norm().item()
        
        retrieved_last = torch.einsum('bhkv,bhk->bhv', state, keys[-1])
        error_last = (retrieved_last - values[-1]).norm().item() / values[-1].norm().item()
        
        print(f"  n={n:3d}: state_norm={state.norm().item():.2f}, "
              f"first_err={error_first:.4f}, last_err={error_last:.4f}")
        
        results.append({'n': n, 'state_norm': state.norm().item(),
                       'first_error': error_first, 'last_error': error_last})
    
    print(f"\n  → State grows with writes, early items degrade (expected)")
    return results


def test_forget_gate():
    """
    TEST: Forget gate controls decay.
    
    Shows how different g values affect information retention.
    """
    print("\n" + "=" * 60)
    print("TEST: Forget Gate Effect")
    print("=" * 60)
    
    B, H, K, V = 1, 1, 32, 64
    device = DEVICE
    
    k = F.normalize(torch.randn(B, H, K, device=device), dim=-1)
    v = torch.randn(B, H, V, device=device)
    
    for g_val in [1.0, 0.9, 0.5, 0.1]:
        state = torch.zeros(B, H, K, V, device=device)
        g = torch.full((B, H), g_val, device=device)
        
        # Write once
        pred = torch.einsum('bhkv,bhk->bhv', state, k)
        error = v - pred
        update = torch.einsum('bhv,bhk->bhkv', error, k)
        state = g.unsqueeze(-1).unsqueeze(-1) * state + update
        
        # Apply 10 more "noise" steps
        k_noise = F.normalize(torch.randn(B, H, K, device=device), dim=-1)
        v_zero = torch.zeros(B, H, V, device=device)
        
        for _ in range(10):
            pred = torch.einsum('bhkv,bhk->bhv', state, k_noise)
            error = v_zero - pred
            update = torch.einsum('bhv,bhk->bhkv', error, k_noise)
            state = g.unsqueeze(-1).unsqueeze(-1) * state + update
        
        # Try to retrieve original
        retrieved = torch.einsum('bhkv,bhk->bhv', state, k)
        retention = (retrieved * v).sum().item() / (v.norm().item() ** 2)
        
        print(f"  g={g_val:.1f}: retention after 10 steps = {retention:.4f}")
    
    print(f"\n  → Lower g = faster decay")
    return True


def run_all_validations():
    """Run complete Delta Rule validation suite."""
    print("\n" + "#" * 70)
    print("# DELTA RULE VALIDATION SUITE")
    print("#" * 70)
    
    results = {
        'identical_tokens': test_identical_tokens(),
        'orthogonal_keys': test_orthogonal_keys(),
        'interference': test_interference(),
        'capacity': test_capacity_limit(),
        'forget_gate': test_forget_gate(),
    }
    
    print("\n" + "=" * 60)
    print("VALIDATION SUMMARY")
    print("=" * 60)
    for name, passed in results.items():
        if isinstance(passed, bool):
            status = "✓ PASS" if passed else "✗ FAIL"
        else:
            status = "INFO"
        print(f"  {name}: {status}")
    
    return results


print("Delta Rule validation suite loaded.")

Delta Rule validation suite loaded.
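
The tests above write with an implicit β = 1, so a repeated write has exactly zero error. A small sketch of the same property for β < 1, assuming the update from the architecture overview with g = 1 and a unit-norm key: the residual for a repeated (k, v) pair shrinks by a factor of (1 - β) per write, while a plain outer-product accumulation keeps growing.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
K, V, beta, n_writes = 32, 64, 0.5, 5
k = F.normalize(torch.randn(K), dim=-1)   # unit-norm key
v = torch.randn(V)

S_delta = torch.zeros(K, V)   # delta-rule state
S_naive = torch.zeros(K, V)   # plain outer-product accumulation, for contrast

for _ in range(n_writes):
    err = v - S_delta.T @ k                          # what the state still gets wrong for k
    S_delta = S_delta + beta * torch.outer(k, err)   # write only the correction
    S_naive = S_naive + beta * torch.outer(k, v)     # write the full value every time

residual = ((v - S_delta.T @ k).norm() / v.norm()).item()
naive_scale = ((S_naive.T @ k).norm() / v.norm()).item()

print(f"delta residual: {residual:.5f} (theory: {(1 - beta) ** n_writes:.5f})")
print(f"naive readout:  {naive_scale:.2f}x the stored value")
```

With these settings the delta-rule readout converges toward `v`, whereas the naive readout is `n_writes * beta = 2.5` times `v` and would keep growing with further writes.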


In [308]:
# =============================================================================
# CELL 8: Training Infrastructure
# =============================================================================
#
# CURRICULUM LEARNING:
#   Phase 1 (warmup): Pure retrieval loss
#       - Teaches the model to use MARKER/CUE pattern
#       - Establishes GDN → SWA retrieval pathway
#   
#   Phase 2 (mixed): LM loss + weighted retrieval loss
#       - Maintains retrieval capability
#       - Adds language modeling objective
#
# This two-phase approach prevents the model from "forgetting" how to 
# retrieve once LM training begins.
#
# =============================================================================

def compute_retrieval_loss(model, seq_len=128, batch_size=4):
    """
    Synthetic retrieval task for gradient signal.
    
    Creates sequences with MARKER + NEEDLE → CUE pattern.
    Loss is only computed at the CUE position.
    """
    device = next(model.parameters()).device
    cfg = model.cfg
    
    needle_id = cfg.vocab_size - 3
    tokens = torch.randint(0, cfg.vocab_size - 100, (batch_size, seq_len), device=device)
    
    # Place MARKER + NEEDLE at random positions
    for i in range(batch_size):
        pos = torch.randint(5, seq_len - 10, (1,)).item()
        tokens[i, pos] = cfg.marker_token
        tokens[i, pos + 1] = needle_id
    
    # Place CUE at end
    tokens[:, -1] = cfg.cue_token
    
    # Target: predict needle_id after CUE
    targets = torch.full((batch_size, seq_len), -100, device=device)  # -100 = ignore
    targets[:, -1] = needle_id
    
    _, loss, _, _ = model(tokens, targets=targets)
    return loss


def train_curriculum(model, data_loader, steps=1000, warmup_steps=200,
                     lr=3e-4, retrieval_weight=2.0, log_interval=100):
    """
    Curriculum training: retrieval warmup → mixed LM/retrieval.
    
    Args:
        model: TransparentHybrid
        data_loader: DataLoader for LM data
        steps: total training steps
        warmup_steps: steps for pure retrieval training
        lr: learning rate
        retrieval_weight: weight for retrieval loss in mixed phase
        log_interval: steps between log messages
    """
    device = next(model.parameters()).device
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    
    lm_iter = iter(data_loader)
    history = {'step': [], 'lm': [], 'ret': [], 'phase': []}
    
    print(f"Training {steps} steps ({warmup_steps} warmup)")
    print(f"  LR: {lr}, Retrieval weight: {retrieval_weight}")
    print("="*60)
    
    model.train()
    start_time = time.time()
    
    for step in range(steps):
        optimizer.zero_grad()
        
        # Phase 1: Pure retrieval (warmup)
        if step < warmup_steps:
            ret_loss = compute_retrieval_loss(model)
            ret_loss.backward()
            history['ret'].append(ret_loss.item())
            history['lm'].append(0)
            history['phase'].append('warmup')
        
        # Phase 2: Mixed LM + retrieval
        else:
            try:
                batch = next(lm_iter)
            except StopIteration:
                lm_iter = iter(data_loader)
                batch = next(lm_iter)
            
            input_ids = batch[:, :-1].to(device)
            targets = batch[:, 1:].to(device)
            _, lm_loss, _, _ = model(input_ids, targets)
            
            ret_loss = compute_retrieval_loss(model)
            
            total = lm_loss + retrieval_weight * ret_loss
            total.backward()
            
            history['lm'].append(lm_loss.item())
            history['ret'].append(ret_loss.item())
            history['phase'].append('mixed')
        
        history['step'].append(step)
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)  # max norm raised from 1.0 to 2.0
        optimizer.step()
        
        # Logging
        if step % log_interval == 0:
            phase = "WARMUP" if step < warmup_steps else "MIXED"
            lm = history['lm'][-1]
            ret = history['ret'][-1]
            elapsed = time.time() - start_time
            steps_per_sec = (step + 1) / elapsed
            print(f"[{phase:6s}] Step {step:5d}: LM={lm:6.3f} RET={ret:6.3f} ({steps_per_sec:.1f} steps/s)")
    
    total_time = time.time() - start_time
    print("="*60)
    print(f"Training complete: {steps} steps in {total_time:.1f}s ({steps/total_time:.1f} steps/s)")
    
    return history


print("Training infrastructure loaded.")

Training infrastructure loaded.
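
`compute_retrieval_loss` relies on `ignore_index=-100` so that only the CUE position contributes to the loss. A short check of that behaviour on random logits (toy shapes and a made-up needle id, not the model's):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, vocab = 2, 6, 11
logits = torch.randn(B, T, vocab)
needle_id = 7                                   # toy value for illustration

targets = torch.full((B, T), -100)              # -100 marks positions to ignore
targets[:, -1] = needle_id                      # supervise the last position only

masked = F.cross_entropy(logits.view(-1, vocab), targets.view(-1), ignore_index=-100)
last_only = F.cross_entropy(logits[:, -1], torch.full((B,), needle_id))

print(torch.allclose(masked, last_only))  # True
```

Because the mean is taken over non-ignored positions only, the retrieval loss has the scale of a single-token cross-entropy per sequence. That is worth keeping in mind when choosing `retrieval_weight` relative to the per-token LM loss.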


In [309]:
# =============================================================================
# CELL 9: Gradient Analysis
# =============================================================================
#
# Understanding gradient flow is CRITICAL for debugging training issues.
#
# Watch for:
#   - NaN/Inf in gradients (numerical instability)
#   - Very small gradients in early layers (vanishing)
#   - Very large gradients in specific components (exploding)
#   - Zero gradients in expected-trainable parameters
#
# =============================================================================

def analyze_gradients(model, seq_len=64, verbose=True):
    """
    Analyze gradient flow through the model.
    
    Returns dict of gradient norms by parameter name.
    """
    device = next(model.parameters()).device
    model.train()
    
    # Create retrieval task input
    x = torch.randint(0, model.cfg.vocab_size - 100, (2, seq_len), device=device)
    x[:, 10] = model.cfg.marker_token
    x[:, 11] = model.cfg.vocab_size - 3  # needle
    x[:, -1] = model.cfg.cue_token
    
    targets = torch.full((2, seq_len), -100, device=device)
    targets[:, -1] = model.cfg.vocab_size - 3
    
    # Forward + backward
    model.zero_grad()
    _, loss, _, _ = model(x, targets)
    loss.backward()
    
    # Collect gradient info
    grad_info = {}
    has_nan = False
    has_inf = False
    has_zero = []
    
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norm = param.grad.norm().item()
            grad_info[name] = grad_norm
            
            if torch.isnan(param.grad).any():
                has_nan = True
            if torch.isinf(param.grad).any():
                has_inf = True
            if grad_norm < 1e-10:
                has_zero.append(name)
    
    if verbose:
        print(f"\n{'='*60}")
        print("GRADIENT ANALYSIS")
        print(f"{'='*60}")
        print(f"\nLoss: {loss.item():.4f}")
        print(f"\nGradient norms by component:")
        
        # Show projections, embeddings and the LM head, sorted by parameter name
        for name, norm in sorted(grad_info.items()):
            if 'proj' in name or 'embed' in name or 'lm_head' in name:
                print(f"  {name:40s}: {norm:.6f}")
        
        print(f"\n  NaN in gradients: {'✗ YES' if has_nan else '✓ NO'}")
        print(f"  Inf in gradients: {'✗ YES' if has_inf else '✓ NO'}")
        
        if has_zero:
            print(f"  Zero gradients: {len(has_zero)} parameters")
            for name in has_zero[:5]:
                print(f"    - {name}")
        else:
            print(f"  Zero gradients: ✓ NONE")
    
    return {
        'grad_norms': grad_info,
        'has_nan': has_nan,
        'has_inf': has_inf,
        'has_zero': has_zero,
        'loss': loss.item(),
    }


print("Gradient analysis loaded.")

Gradient analysis loaded.
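
`analyze_gradients` only inspects parameters whose `.grad` is not `None`, so a parameter that never takes part in the forward pass is skipped rather than reported as zero. A minimal sketch of a complementary check, using a hypothetical toy module:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 4)
        self.unused = nn.Linear(4, 4)   # never called in forward

    def forward(self, x):
        return self.used(x)

model = Toy()
model(torch.randn(2, 4)).sum().backward()

# Trainable parameters that received no gradient at all
no_grad = [n for n, p in model.named_parameters() if p.requires_grad and p.grad is None]
print(no_grad)  # ['unused.weight', 'unused.bias']
```

Running a check like this next to the norm report separates "disconnected from the loss" from "connected but vanishing", which call for different fixes.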


In [310]:
# =============================================================================
# CELL 10: Performance Profiling
# =============================================================================
#
# Profile:
#   1. Triton kernel alone
#   2. Full model forward pass
#   3. Comparison with FLA (if available)
#
# Key metrics:
#   - ms per forward pass
#   - tokens per second
#   - speedup vs PyTorch/FLA
#
# =============================================================================

def profile_triton_kernel(batch_sizes=[1, 4, 8], seq_lens=[64, 128, 256, 512],
                          n_heads=8, head_dim=32, value_dim=64,
                          n_warmup=3, n_runs=10):
    """Profile Triton Delta Rule kernel."""
    print("="*70)
    print("TRITON KERNEL PROFILING")
    print(f"Config: H={n_heads}, K={head_dim}, V={value_dim}")
    print("="*70)
    
    device = DEVICE
    results = {}
    
    for B in batch_sizes:
        for T in seq_lens:
            k = F.normalize(torch.randn(B, T, n_heads, head_dim, device=device), dim=-1)
            v = torch.randn(B, T, n_heads, value_dim, device=device)
            beta = torch.sigmoid(torch.randn(B, T, n_heads, device=device))
            g = torch.sigmoid(torch.randn(B, T, n_heads, device=device))
            
            # Warmup
            for _ in range(n_warmup):
                triton_delta_rule(k, v, beta, g)
                torch.cuda.synchronize()
            
            # Timed runs
            times = []
            for _ in range(n_runs):
                torch.cuda.synchronize()
                start = time.perf_counter()
                triton_delta_rule(k, v, beta, g)
                torch.cuda.synchronize()
                times.append(time.perf_counter() - start)
            
            avg_ms = sum(times) / len(times) * 1000
            tokens_per_sec = (B * T) / (sum(times) / len(times))
            
            results[(B, T)] = {'avg_ms': avg_ms, 'tokens_per_sec': tokens_per_sec}
            print(f"B={B:2d}, T={T:4d}: {avg_ms:8.2f} ms | {tokens_per_sec:>12,.0f} tok/s")
    
    return results


def profile_fla_kernel(batch_sizes=[1, 4, 8], seq_lens=[64, 128, 256, 512],
                       n_heads=8, head_dim=32, value_dim=64,
                       n_warmup=3, n_runs=10):
    """Profile FLA chunked kernel for comparison."""
    if not HAS_FLA:
        print("FLA not available for comparison")
        return {}
    
    print("\n" + "="*70)
    print("FLA CHUNKED KERNEL (for comparison)")
    print("="*70)
    
    device = DEVICE
    results = {}
    
    for B in batch_sizes:
        for T in seq_lens:
            q = torch.randn(B, T, n_heads, head_dim, device=device)
            k = F.normalize(torch.randn(B, T, n_heads, head_dim, device=device), dim=-1)
            v = torch.randn(B, T, n_heads, value_dim, device=device)
            beta = torch.sigmoid(torch.randn(B, T, n_heads, device=device))
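            # Assumption: FLA documents g as a log-space decay, so pass logsigmoid
            # rather than sigmoid; verify against the installed FLA version
            g = F.logsigmoid(torch.randn(B, T, n_heads, device=device))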
            
            # Warmup
            for _ in range(n_warmup):
                _, _ = chunk_gated_delta_rule(q, k, v, g, beta, output_final_state=True)
                torch.cuda.synchronize()
            
            # Timed
            times = []
            for _ in range(n_runs):
                torch.cuda.synchronize()
                start = time.perf_counter()
                _, _ = chunk_gated_delta_rule(q, k, v, g, beta, output_final_state=True)
                torch.cuda.synchronize()
                times.append(time.perf_counter() - start)
            
            avg_ms = sum(times) / len(times) * 1000
            tokens_per_sec = (B * T) / (sum(times) / len(times))
            
            results[(B, T)] = {'avg_ms': avg_ms, 'tokens_per_sec': tokens_per_sec}
            print(f"B={B:2d}, T={T:4d}: {avg_ms:8.2f} ms | {tokens_per_sec:>12,.0f} tok/s")
    
    return results


def profile_full_model(model, batch_sizes=[1, 4], seq_lens=[64, 128, 256],
                       n_warmup=3, n_runs=10):
    """Profile full model forward pass."""
    print("\n" + "="*70)
    print("FULL MODEL PROFILING")
    print("="*70)
    
    device = next(model.parameters()).device
    model.eval()
    results = {}
    
    for B in batch_sizes:
        for T in seq_lens:
            x = torch.randint(0, model.cfg.vocab_size, (B, T), device=device)
            
            # Warmup
            for _ in range(n_warmup):
                with torch.no_grad():
                    model(x)
                torch.cuda.synchronize()
            
            # Timed
            times = []
            for _ in range(n_runs):
                torch.cuda.synchronize()
                start = time.perf_counter()
                with torch.no_grad():
                    model(x)
                torch.cuda.synchronize()
                times.append(time.perf_counter() - start)
            
            avg_ms = sum(times) / len(times) * 1000
            tokens_per_sec = (B * T) / (sum(times) / len(times))
            
            results[(B, T)] = {'avg_ms': avg_ms, 'tokens_per_sec': tokens_per_sec}
            print(f"B={B:2d}, T={T:4d}: {avg_ms:8.2f} ms | {tokens_per_sec:>12,.0f} tok/s")
    
    return results


print("Performance profiling loaded.")

Performance profiling loaded.
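
The profiling functions above time with `time.perf_counter()` between explicit synchronizations and report the mean. An alternative sketch, not the notebook's method: a hypothetical `time_fn` helper that uses CUDA events when a GPU is present, falls back to wall-clock time on CPU, and reports the median to reduce the influence of outlier runs.

```python
import statistics
import time
import torch

def time_fn(fn, n_warmup=3, n_runs=10):
    """Return the median runtime of fn() in milliseconds."""
    use_cuda = torch.cuda.is_available()
    for _ in range(n_warmup):
        fn()
    if use_cuda:
        torch.cuda.synchronize()

    samples = []
    for _ in range(n_runs):
        if use_cuda:
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()                  # wait until both events have completed
            samples.append(start.elapsed_time(end))   # elapsed_time returns milliseconds
        else:
            t0 = time.perf_counter()
            fn()
            samples.append((time.perf_counter() - t0) * 1000)
    return statistics.median(samples)

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(256, 256, device=device)
ms = time_fn(lambda: a @ a)
print(f"{ms:.3f} ms per matmul on {device}")
```

The same helper could wrap `lambda: triton_delta_rule(k, v, beta, g)` to compare kernels under identical timing rules.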


In [311]:
# =============================================================================
# CELL 11: Triton Cache Management
# =============================================================================
#
# TRITON JIT COMPILATION:
#   - Kernels are compiled at runtime targeting active GPU architecture
#   - First launch incurs ~100-500ms compile time
#   - Compiled kernels are cached in ~/.triton/cache/
#
# CACHE INVALIDATION TRIGGERS:
#   - Triton version change
#   - CUDA version change
#   - GPU architecture change
#   - Kernel source change (even whitespace!)
#
# DEPLOYMENT STRATEGY:
#   1. Run warmup script to compile all needed configs
#   2. Ship cache with project: export TRITON_CACHE_DIR=/path/to/.triton_cache
#
# =============================================================================

def warmup_triton_cache(configs=None):
    """
    Pre-compile Triton kernels for common configurations.
    
    Run this before training to avoid JIT compilation delays.
    """
    print("="*60)
    print("TRITON CACHE WARMUP")
    print("="*60)
    
    if configs is None:
        configs = [
            (1, 8, 8, 32, 64),    # Single sample, very short
            (1, 64, 8, 32, 64),   # Single sample, short
            (1, 128, 8, 32, 64),  # Single sample, typical
            (4, 128, 8, 32, 64),  # Small batch
            (8, 128, 8, 32, 64),  # Medium batch
            (16, 128, 8, 32, 64), # Large batch
            (4, 256, 8, 32, 64),  # Longer sequence
            (4, 512, 8, 32, 64),  # Long sequence
        ]
    
    device = DEVICE
    
    for B, T, H, K, V in configs:
        print(f"  Compiling B={B}, T={T}, H={H}, K={K}, V={V}...", end=" ")
        
        k = F.normalize(torch.randn(B, T, H, K, device=device), dim=-1)
        v = torch.randn(B, T, H, V, device=device)
        beta = torch.sigmoid(torch.randn(B, T, H, device=device))
        g = torch.sigmoid(torch.randn(B, T, H, device=device))
        state = torch.zeros(B, H, K, V, device=device)
        
        # Forward
        out, final_state = triton_delta_rule(k, v, beta, g, state)
        
        # Backward
        d_out = torch.randn_like(out)
        d_state = torch.randn_like(final_state)
        triton_delta_rule_backward(k, v, beta, g, state, d_out, d_state)
        
        torch.cuda.synchronize()
        print("✓")
    
    print("\nCache warmed.")
    print("\nTriton cache location: ~/.triton/cache/")
    print("\nTo ship cache with project:")
    print("  export TRITON_CACHE_DIR=/path/to/project/.triton_cache")


def check_triton_cache():
    """Check Triton cache status (honors TRITON_CACHE_DIR if it is set)."""
    import os
    from pathlib import Path
    
    # Use the override from the deployment strategy above, else the default location
    cache_dir = os.environ.get("TRITON_CACHE_DIR") or os.path.expanduser("~/.triton/cache")
    
    if os.path.exists(cache_dir):
        files = [f for f in Path(cache_dir).rglob("*") if f.is_file()]
        total_size = sum(f.stat().st_size for f in files)
        print(f"Triton cache: {len(files)} files, {total_size/1e6:.1f} MB")
        print(f"Location: {cache_dir}")
    else:
        print(f"Triton cache not found at {cache_dir}")


print("Triton cache management loaded.")

Triton cache management loaded.


In [312]:
# =============================================================================
# CELL 12: Quick Start / Validation
# =============================================================================

def quick_start():
    """Quick validation that everything works."""
    print("\n" + "="*70)
    print("GROUNDTHINK V6 - QUICK START VALIDATION")
    print("="*70)
    
    # 1. Create model
    cfg = HybridConfig(d_model=256, n_heads=8, layer_pattern="GS")
    model = TransparentHybrid(cfg).to(DEVICE).bfloat16()
    
    print(f"\n1. Model Created")
    print(f"   Pattern: {cfg.layer_pattern}")
    print(f"   Parameters: {count_params(model):,}")
    
    # 2. Test forward pass
    print(f"\n2. Forward Pass Test")
    x = torch.randint(0, 1000, (2, 64), device=DEVICE)
    with torch.no_grad():
        logits, _, diags, state = model(x, return_diagnostics=True)
    print(f"   Output shape: {logits.shape}")
    print(f"   State norm: {state.norm().item():.4f}")
    print(f"   GDN: β={diags[0]['beta_mean']:.3f}, g={diags[0]['g_mean']:.3f}")
    print(f"   ✓ Forward pass OK")
    
    # 3. Test backward pass
    print(f"\n3. Backward Pass Test")
    model.train()
    x = torch.randint(0, 1000, (2, 64), device=DEVICE)
    y = torch.randint(0, 1000, (2, 64), device=DEVICE)
    _, loss, _, _ = model(x, y)
    loss.backward()
    print(f"   Loss: {loss.item():.4f}")
    print(f"   ✓ Backward pass OK")
    
    # 4. Delta Rule validation
    print(f"\n4. Delta Rule Test")
    passed = test_identical_tokens()
    
    # 5. NIAH test (untrained)
    print(f"\n5. NIAH Test (untrained model)")
    model.eval()
    proper_niah_test(model, seq_len=64, needle_pos=20, n_trials=10)
    
    print("\n" + "="*70)
    print("✓ QUICK START COMPLETE - Model ready for training")
    print("="*70)
    
    return model, cfg


# Run quick start
model, cfg = quick_start()


GROUNDTHINK V6 - QUICK START VALIDATION

1. Model Created
   Pattern: GS
   Parameters: 14,741,016

2. Forward Pass Test
   Output shape: torch.Size([2, 64, 50257])
   State norm: 6.4062
   GDN: β=0.124, g=0.980
   ✓ Forward pass OK

3. Backward Pass Test
   Loss: 10.8125
   ✓ Backward pass OK

4. Delta Rule Test

TEST: Identical Tokens (Redundancy Suppression)
  Error1 norm: 15.9750
  Error2 norm: 0.000001 (should be ~0)
  Error ratio: 0.000000
  State growth: 1.0000x (should be ~1.0)
  → ✓ PASS

5. NIAH Test (untrained model)
  Accuracy: 0.0% (0/10)

✓ QUICK START COMPLETE - Model ready for training
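
The identical-tokens result in step 4 follows directly from the Delta Rule. A small pure-Python sketch of the update `S_t = g_t * S_{t-1} + β_t * (v_t - S_{t-1}·k_t) ⊗ k_t` for a single head shows why. It is an illustration, not the notebook's `test_identical_tokens`; it assumes β = 1, g = 1, and a unit-norm key, and `delta_rule_step` is a hypothetical name.

```python
import math


def delta_rule_step(S, k, v, beta=1.0, g=1.0):
    """One Delta Rule step on a K x V state stored as nested lists.

    Returns the new state and the prediction error v - S^T k.
    """
    K, V = len(k), len(v)
    pred = [sum(S[i][j] * k[i] for i in range(K)) for j in range(V)]
    err = [v[j] - pred[j] for j in range(V)]
    # Decay the old state by g, then write the error along the key direction.
    S_new = [[g * S[i][j] + beta * k[i] * err[j] for j in range(V)]
             for i in range(K)]
    return S_new, err


def l2(values):
    return math.sqrt(sum(x * x for x in values))


def frobenius(S):
    return l2([x for row in S for x in row])


k = [0.6, 0.8, 0.0]            # unit-norm key (0.36 + 0.64 = 1)
v = [1.0, -2.0, 0.5, 3.0]      # arbitrary value vector
S0 = [[0.0] * len(v) for _ in k]

S1, err1 = delta_rule_step(S0, k, v)   # first write: error equals v
S2, err2 = delta_rule_step(S1, k, v)   # same token again: nothing left to correct

print(f"Error1 norm: {l2(err1):.4f}")
print(f"Error2 norm: {l2(err2):.6f}")
print(f"State growth: {frobenius(S2) / frobenius(S1):.4f}x")
```

After the first write the state already maps `k` to `v`, so the second error vanishes up to floating-point rounding and the state norm does not grow.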


In [313]:
# =============================================================================
# CELL 13: Training Execution
# =============================================================================

# Load data
data_loader = load_data(n_tokens=500_000, seq_len=128, batch_size=16)

# Create fresh model for training
cfg = HybridConfig(d_model=256, n_heads=8, layer_pattern="GS")
model = TransparentHybrid(cfg).to(DEVICE).bfloat16()

print(f"\nModel: {cfg.layer_pattern}")
print(f"Parameters: {count_params(model):,}")

# Pre-training check
print("\n--- Pre-training State ---")
x = torch.randint(0, 1000, (1, 128), device=DEVICE)
with torch.no_grad():
    _, _, diags, state = model(x, return_diagnostics=True)
print(f"Initial state norm: {state.norm().item():.4f}")
print(f"GDN: β={diags[0]['beta_mean']:.3f}, g={diags[0]['g_mean']:.3f}")

Loading 500,000 tokens from wikitext-103...
Loaded 500,000 tokens

Model: GS
Parameters: 14,741,016

--- Pre-training State ---
Initial state norm: 4.7188
GDN: β=0.123, g=0.980
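
The initial gate value can be cross-checked against the configuration. A sketch, assuming the gate is a sigmoid of `g_bias` plus an input-dependent term that is near zero at initialization (the projection itself is not shown in this section). `half_life` and `retained` are hypothetical helpers that treat the gate as constant and ignore new writes.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def half_life(g):
    """Steps until a constant gate g halves an entry that is never rewritten."""
    return math.log(0.5) / math.log(g)


def retained(g, steps):
    """Fraction of an entry that survives `steps` applications of gate g."""
    return g ** steps


g_bias = 4.0                      # Cell 1 configuration value
g_init = sigmoid(g_bias)          # compare with the reported g=0.980
g_reported = 0.980

print(f"sigmoid(g_bias)    = {g_init:.3f}")
print(f"half-life at 0.980 = {half_life(g_reported):.1f} tokens")
print(f"kept after 128     = {retained(g_reported, 128):.3f}")
```

Under these assumptions, an unrefreshed state entry halves roughly every 34 tokens, which is worth keeping in mind when reading retrieval results at `seq_len=128`.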


In [314]:
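def warmup_kernels(model, cfg):
    """Run a few forward/backward passes so the Triton kernels compile before training."""
    print("Status: Priming Triton Kernels (Burn-in Phase)...")
    model.train()
    # Allocate the dummy batch directly on DEVICE, consistent with the other cells
    dummy_x = torch.randint(0, cfg.vocab_size, (1, 64), device=DEVICE)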
    
    for i in range(5):
        # Flexible Unpacking: Take the first element (logits), ignore the rest
        outputs = model(dummy_x)
        logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
        
        loss = logits.sum()
        loss.backward()
        model.zero_grad()
        torch.cuda.synchronize()
        print(f"  Warmup Iteration {i+1}/5 Complete")
    
    print("Status: Kernels hot. Ready for training.")

# Execute
warmup_kernels(model, cfg)

Status: Priming Triton Kernels (Burn-in Phase)...
  Warmup Iteration 1/5 Complete
  Warmup Iteration 2/5 Complete
  Warmup Iteration 3/5 Complete
  Warmup Iteration 4/5 Complete
  Warmup Iteration 5/5 Complete
Status: Kernels hot. Ready for training.


In [315]:
# Train!
history = train_curriculum(
    model, 
    data_loader, 
    steps=2000, 
    warmup_steps=200,
    lr=2e-4,  # lowered from 3e-4 to prevent collapse in the mixed phase (see the curriculum-learning literature)
    retrieval_weight=2.0,
    log_interval=100
)

Training 2000 steps (200 warmup)
  LR: 0.0002, Retrieval weight: 2.0
[WARMUP] Step     0: LM= 0.000 RET=10.750 (20.4 steps/s)
[WARMUP] Step   100: LM= 0.000 RET= 0.178 (49.8 steps/s)
[MIXED ] Step   200: LM=13.500 RET= 0.029 (50.4 steps/s)
[MIXED ] Step   300: LM= 7.531 RET= 0.014 (33.1 steps/s)
[MIXED ] Step   400: LM= 7.438 RET= 0.014 (28.4 steps/s)
[MIXED ] Step   500: LM= 7.219 RET= 0.013 (26.3 steps/s)
[MIXED ] Step   600: LM= 7.188 RET= 0.013 (26.2 steps/s)
[MIXED ] Step   700: LM= 7.344 RET= 0.014 (25.1 steps/s)
[MIXED ] Step   800: LM= 7.281 RET= 0.014 (24.3 steps/s)
[MIXED ] Step   900: LM= 7.281 RET= 0.014 (23.7 steps/s)
[MIXED ] Step  1000: LM= 7.281 RET= 0.013 (23.2 steps/s)
[MIXED ] Step  1100: LM= 7.312 RET= 0.014 (22.8 steps/s)
[MIXED ] Step  1200: LM= 7.312 RET= 0.013 (22.9 steps/s)
[MIXED ] Step  1300: LM= 7.281 RET= 0.013 (22.5 steps/s)
[MIXED ] Step  1400: LM= 7.438 RET= 0.014 (22.3 steps/s)
[MIXED ] Step  1500: LM= 7.438 RET= 0.014 (22.1 steps/s)
[MIXED ] Step  1600
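
The LM column is easier to interpret against two reference points. A sketch, assuming LM is a mean per-token cross-entropy in nats. That assumption is consistent with the untrained quick-start loss of 10.8125 sitting next to ln(50257). `uniform_baseline` and `perplexity` are hypothetical helper names.

```python
import math

VOCAB_SIZE = 50257  # last dimension of the logits printed in the quick start


def uniform_baseline(vocab_size):
    """Cross-entropy (nats) of a model that guesses uniformly over the vocabulary."""
    return math.log(vocab_size)


def perplexity(loss_nats):
    """Convert a mean cross-entropy in nats to perplexity."""
    return math.exp(loss_nats)


baseline = uniform_baseline(VOCAB_SIZE)
for step, lm_loss in [(200, 13.500), (300, 7.531), (1000, 7.281)]:
    side = "above" if lm_loss > baseline else "below"
    print(f"step {step:4d}: LM={lm_loss:.3f} ({side} uniform baseline {baseline:.3f}), "
          f"perplexity={perplexity(lm_loss):.0f}")
```

Read this way, the first mixed-phase step (LM=13.5) is worse than uniform guessing, which is plausible right after retrieval-only warmup. The later plateau near 7.3 corresponds to a perplexity on the order of 1,450.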

In [316]:
# =============================================================================
# CELL 14: Post-Training Evaluation
# =============================================================================

print("="*60)
print("POST-TRAINING EVALUATION")
print("="*60)

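# Switch to inference mode before evaluating, as quick_start() does before its NIAH test
model.eval()

# 1. NIAH Accuracy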
print("\n1. NIAH Accuracy:")
proper_niah_test(model, seq_len=128, n_trials=30)

# 2. NIAH by Distance
print("\n2. NIAH by Distance:")
test_niah_by_distance(model, seq_len=128)

# 3. Full Diagnostic
print("\n3. State Health:")
run_full_diagnostic(model, seq_len=128)

# 4. Delta Rule Validation (post-training)
print("\n4. Delta Rule Validation:")
test_identical_tokens()

# 5. Gradient Analysis
print("\n5. Gradient Analysis:")
analyze_gradients(model)

print("\n" + "="*60)
print("EVALUATION COMPLETE")
print("="*60)

POST-TRAINING EVALUATION

1. NIAH Accuracy:
  Accuracy: 100.0% (30/30)

2. NIAH by Distance:

NIAH by Distance (seq_len=128):
  Distance   5 (pos=121):   Accuracy: 100.0% (20/20)
  Distance  10 (pos=116):   Accuracy: 100.0% (20/20)
  Distance  20 (pos=106):   Accuracy: 100.0% (20/20)
  Distance  40 (pos= 86):   Accuracy: 100.0% (20/20)
  Distance  60 (pos= 66):   Accuracy: 100.0% (20/20)
  Distance  95 (pos= 31):   Accuracy: 100.0% (20/20)
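
Every row of the distance table satisfies `pos = seq_len - 2 - distance`. The sketch below encodes that relation. The offset of 2 is inferred from the printed numbers only; it is an assumption about `test_niah_by_distance`, whose source is not shown here, and `needle_position` is a hypothetical helper.

```python
def needle_position(seq_len, distance, tail=2):
    """Needle index for a given distance, measured back from index seq_len - tail.

    tail=2 is inferred from the printed table, not taken from the test's source.
    """
    pos = seq_len - tail - distance
    if pos < 0:
        raise ValueError("distance does not fit inside the sequence")
    return pos


# (distance, pos) pairs copied from the table above
reported = {5: 121, 10: 116, 20: 106, 40: 86, 60: 66, 95: 31}
for distance, pos in reported.items():
    print(distance, pos, needle_position(128, distance) == pos)
```

The same helper can be used to pick positions for other sequence lengths, provided the offset really is fixed at 2.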

3. State Health:

DIAGNOSTIC REPORT

State Health:
  State norm: 7.75
  State max:  0.50
  ✓ State bounded - Delta Rule working!

Layer Diagnostics:
  Layer 0 (GDN): β=0.112 (max=0.365), g=0.980, state_norm=7.75
  Layer 1 (SWA): gate=0.922, local=81.50, retrieval=48.50

Prediction: 50254 (target: 50254) - ✓

4. Delta Rule Validation:

TEST: Identical Tokens (Redundancy Suppression)
  Error1 norm: 15.1082
  Error2 norm: 0.000001 (should be ~0)
  Error ratio: 0.000000
  State growth: 1.0000x (should be ~1.0)
  → ✓ PASS

5. Gradient Analysis:

GRADIENT

In [317]:
# =============================================================================
# CELL 15: Performance Profiling (Optional)
# =============================================================================

# Runs the profiling helpers defined earlier; comment out any step to skip it.

# 1. Triton kernel profiling
triton_results = profile_triton_kernel()

# 2. FLA comparison (if available)
fla_results = profile_fla_kernel()

# 3. Full model profiling
model_results = profile_full_model(model)

# 4. Speedup analysis (values > 1 mean the Triton kernel is faster than FLA)
if fla_results:
    print("\n" + "-"*60)
    print("SPEEDUP: Triton vs FLA")
    print("-"*60)
    for key in triton_results:
        if key in fla_results:
            speedup = fla_results[key]['avg_ms'] / triton_results[key]['avg_ms']
            print(f"B={key[0]:2d}, T={key[1]:4d}: {speedup:.2f}x")

print("Profiling cell ready. Uncomment to run.")

TRITON KERNEL PROFILING
Config: H=8, K=32, V=64
B= 1, T=  64:     0.13 ms |      493,280 tok/s
B= 1, T= 128:     0.26 ms |      496,505 tok/s
B= 1, T= 256:     0.64 ms |      397,564 tok/s
B= 1, T= 512:     0.53 ms |      970,825 tok/s
B= 4, T=  64:     0.24 ms |    1,055,400 tok/s
B= 4, T= 128:     0.35 ms |    1,482,763 tok/s
B= 4, T= 256:     0.44 ms |    2,305,276 tok/s
B= 4, T= 512:     0.56 ms |    3,628,184 tok/s
B= 8, T=  64:     0.16 ms |    3,117,177 tok/s
B= 8, T= 128:     0.26 ms |    3,960,540 tok/s
B= 8, T= 256:     0.52 ms |    3,946,534 tok/s
B= 8, T= 512:     0.84 ms |    4,871,025 tok/s
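
The tok/s column is consistent with tokens per call divided by latency. A short sketch that recomputes it from the printed milliseconds; `tokens_per_second` is a hypothetical helper. Because the latencies are printed to two decimals, small deviations from the reported figures are expected.

```python
def tokens_per_second(batch, seq_len, latency_ms):
    """Throughput implied by one kernel call over batch * seq_len tokens."""
    return batch * seq_len / (latency_ms / 1000.0)


# (B, T, latency in ms, reported tok/s) copied from the table above
rows = [
    (1, 64, 0.13, 493_280),
    (4, 512, 0.56, 3_628_184),
    (8, 512, 0.84, 4_871_025),
]

for B, T, ms, reported in rows:
    estimate = tokens_per_second(B, T, ms)
    rel_diff = abs(estimate - reported) / reported
    print(f"B={B:2d}, T={T:4d}: recomputed {estimate:,.0f} tok/s, "
          f"reported {reported:,} ({rel_diff:.1%} apart)")
```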

FLA CHUNKED KERNEL (for comparison)
B= 1, T=  64:     0.40 ms |      158,239 tok/s
B= 1, T= 128:     0.37 ms |      349,867 tok/s
B= 1, T= 256:     0.35 ms |      731,251 tok/s
B= 1, T= 512:     0.40 ms |    1,272,767 tok/s
B= 4, T=  64:     0.36 ms |      709,553 tok/s
B= 4, T= 128:     0.37 ms |    1,396,387 tok/s
B= 4, T= 256:     0.38 ms |    2,687,615 tok/s
B= 4, T= 512:     0.71