In [None]:
# CELL: Transparent Hybrid Architecture with Diagnostics
# Uses raw fla.ops, exposes all state, configurable layers

import torch
import torch.nn as nn
import torch.nn.functional as F
from fla.ops.gated_delta_rule import chunk_gated_delta_rule
from dataclasses import dataclass
from typing import Optional, List, Dict, Literal
import math

@dataclass
class HybridConfig:
    """Hyperparameters for the GDN/SWA hybrid model.

    Two values are derived after construction:

    - ``head_dim`` is always reset to ``d_model // n_heads``, so any value
      passed for it is overwritten.
    - ``value_dim`` is set to ``int(head_dim * expand_v)``; it is a plain
      attribute, not a dataclass field.
    """
    d_model: int = 256
    n_heads: int = 8
    head_dim: int = 32        # placeholder only; replaced in __post_init__
    expand_v: float = 2.0     # value width as a multiple of head_dim
    vocab_size: int = 50257

    # One character per layer: 'G' = GDN, 'S' = SWA
    # e.g. "GS", "GGS", "GGSG", "GGGSGGGS"
    layer_pattern: str = "GS"

    # Sliding-window attention span
    window_size: int = 1024

    # Std-dev used for embedding initialisation
    init_std: float = 0.02

    def __post_init__(self):
        per_head = self.d_model // self.n_heads
        self.head_dim = per_head
        self.value_dim = int(per_head * self.expand_v)

    @property
    def n_layers(self):
        """Number of layers, i.e. the length of ``layer_pattern``."""
        return len(self.layer_pattern)

    def layer_type(self, idx: int) -> str:
        """Return the pattern character ('G' or 'S') for layer ``idx``."""
        return self.layer_pattern[idx]


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dim with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Statistics are taken in float32; the normalised tensor is cast back
        # to x's dtype before the per-feature scale is applied.
        x32 = x.float()
        mean_sq = x32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        normed = (x32 * inv_rms).type_as(x)
        return normed * self.weight


class GatedDeltaNetLayer(nn.Module):
    """
    Pre-norm token mixer built directly on the raw ``chunk_gated_delta_rule``
    op, wrapped in a residual connection.

    What this layer feeds the op, per head:
    - q, k of width head_dim and v of width value_dim (k is L2-normalised)
    - g: forget gate in log space, logsigmoid(g_proj(x))
    - beta: write strength, sigmoid(beta_proj(x)), in (0, 1)

    NOTE(review): the recurrence itself lives inside fla and is not visible
    here. The Gated DeltaNet paper states it as
        S_t = alpha_t * S_{t-1} (I - beta_t k_t k_t^T) + beta_t v_t k_t^T
    with alpha_t = exp(g_t), i.e. with an erase term in addition to plain
    "decay + write". Confirm against the installed fla version.
    """
    def __init__(self, cfg: HybridConfig, layer_idx: int):
        super().__init__()
        self.cfg = cfg
        self.layer_idx = layer_idx

        key_width = cfg.n_heads * cfg.head_dim
        value_width = cfg.n_heads * cfg.value_dim

        # q/k/v/o projections
        self.q_proj = nn.Linear(cfg.d_model, key_width, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, key_width, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, value_width, bias=False)
        self.o_proj = nn.Linear(value_width, cfg.d_model, bias=False)

        # One scalar gate per head and position
        self.beta_proj = nn.Linear(cfg.d_model, cfg.n_heads, bias=False)  # write strength
        self.g_proj = nn.Linear(cfg.d_model, cfg.n_heads, bias=False)     # forget gate (log space)

        self.norm = RMSNorm(cfg.d_model)

    def forward(self, x: torch.Tensor,
                initial_state: Optional[torch.Tensor] = None,
                output_state: bool = True) -> tuple:
        """
        Args:
            x: [B, T, D]
            initial_state: state handed to the op as ``initial_state``
                ([B, H, K, V] per the original notes), or None
            output_state: forwarded to the op as ``output_final_state``

        Returns:
            output: [B, T, D] (input plus the projected op output)
            state: whatever the op returns as its second value
            diagnostics: dict of Python scalars describing the gates and state
        """
        batch, seq_len, _ = x.shape
        n_heads = self.cfg.n_heads
        key_dim = self.cfg.head_dim
        val_dim = self.cfg.value_dim

        h = self.norm(x)

        def split_heads(t, width):
            return t.view(batch, seq_len, n_heads, width)

        q = split_heads(self.q_proj(h), key_dim)
        k = split_heads(self.k_proj(h), key_dim)
        v = split_heads(self.v_proj(h), val_dim)

        # Only k is L2-normalised (in float32, then cast back); q is left as is.
        k = F.normalize(k.float(), p=2, dim=-1).to(x.dtype)

        beta = torch.sigmoid(self.beta_proj(h))   # [B, T, H]
        g = F.logsigmoid(self.g_proj(h))          # [B, T, H], log space

        output, state = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=initial_state,
            output_final_state=output_state,
        )

        # Merge heads, project back to model width, add the residual.
        merged = output.reshape(batch, seq_len, n_heads * val_dim)
        output = x + self.o_proj(merged)

        # Scalars for inspection; each .item() forces a device sync.
        forget = g.exp()  # gate converted out of log space
        if state is None:
            state_norm, state_shape = 0, None
        else:
            state_norm, state_shape = state.norm().item(), tuple(state.shape)

        diagnostics = {
            'beta_mean': beta.mean().item(),
            'beta_std': beta.std().item(),
            'g_mean': forget.mean().item(),
            'g_std': forget.std().item(),
            'state_norm': state_norm,
            'state_shape': state_shape,
        }

        return output, state, diagnostics


class SlidingWindowAttention(nn.Module):
    """
    Pre-norm causal sliding-window attention with an optional read from a
    GDN state, wrapped in a residual connection.

    Two contributions are summed before the output projection:
    1. Local softmax attention: each query sees itself and up to
       ``window_size`` earlier positions.
    2. If ``gdn_state`` is given: a linear read-out ``q @ state`` projected
       from value_dim back to head_dim (no softmax involved).

    ``state_k_proj`` is created but not used in ``forward``.

    NOTE(review): the same ``gdn_state`` tensor is read by every query
    position. If the caller passes a state produced after a full pass over
    the sequence, early positions would read a summary that includes later
    tokens. Confirm this is intended for causal LM training.
    """
    def __init__(self, cfg: HybridConfig, layer_idx: int):
        super().__init__()
        self.cfg = cfg
        self.layer_idx = layer_idx

        width = cfg.d_model

        # Attention projections (heads are carved out of d_model in forward)
        self.q_proj = nn.Linear(width, width, bias=False)
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.o_proj = nn.Linear(width, width, bias=False)

        # Projections for the GDN-state read path
        self.state_k_proj = nn.Linear(cfg.head_dim, cfg.head_dim, bias=False)
        self.state_v_proj = nn.Linear(cfg.value_dim, cfg.head_dim, bias=False)  # V -> K width

        self.norm = RMSNorm(width)
        self.scale = cfg.head_dim ** -0.5

    def forward(self, x: torch.Tensor,
                gdn_state: Optional[torch.Tensor] = None) -> tuple:
        """
        Args:
            x: [B, T, D]
            gdn_state: optional state to read from; it is matmul'd as
                ``q @ gdn_state`` with q of shape [B, H, T, K]

        Returns:
            output: [B, T, D]
            diagnostics: dict with
                'local_attn_entropy': mean entropy of the local attention rows
                'state_contribution': ||state_out|| / ||local_out||
                    (0.0 when no state is supplied)
        """
        batch, seq_len, width = x.shape
        n_heads = self.cfg.n_heads
        head_dim = self.cfg.head_dim
        window = self.cfg.window_size

        h = self.norm(x)

        def to_heads(t):
            # [B, T, H*K] -> [B, H, T, K]
            return t.view(batch, seq_len, n_heads, head_dim).transpose(1, 2)

        q = to_heads(self.q_proj(h))
        k = to_heads(self.k_proj(h))
        v = to_heads(self.v_proj(h))

        # blocked[i, j] is True where query i may NOT see key j:
        # j in the future (lag < 0) or more than `window` steps back.
        pos = torch.arange(seq_len, device=x.device)
        lag = pos.unsqueeze(1) - pos.unsqueeze(0)
        blocked = (lag < 0) | (lag > window)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        scores = scores.masked_fill(blocked[None, None], float('-inf'))
        attn_weights_local = F.softmax(scores, dim=-1)
        local_out = torch.matmul(attn_weights_local, v)  # [B, H, T, K]

        # Optional state read, simply added to the local result.
        state_attn_weight = 0.0
        out = local_out
        if gdn_state is not None:
            state_retrieved = torch.matmul(q, gdn_state.to(q.dtype))
            state_out = self.state_v_proj(state_retrieved)
            state_attn_weight = (state_out.norm() / (local_out.norm() + 1e-8)).item()
            out = local_out + state_out

        # Merge heads, project, add the residual.
        merged = out.transpose(1, 2).reshape(batch, seq_len, width)
        out = x + self.o_proj(merged)

        plogp = attn_weights_local * attn_weights_local.clamp(min=1e-8).log()
        diagnostics = {
            'local_attn_entropy': -plogp.sum(-1).mean().item(),
            'state_contribution': state_attn_weight,
        }

        return out, diagnostics


class FFN(nn.Module):
    """Pre-norm SwiGLU feed-forward block with a residual connection."""
    def __init__(self, cfg: HybridConfig):
        super().__init__()
        # About 8/3 * d_model hidden units, rounded up to a multiple of 64.
        raw_hidden = int(cfg.d_model * 8/3)
        hidden = 64 * ((raw_hidden + 63) // 64)

        self.w1 = nn.Linear(cfg.d_model, hidden, bias=False)  # gate branch
        self.w3 = nn.Linear(cfg.d_model, hidden, bias=False)  # linear branch
        self.w2 = nn.Linear(hidden, cfg.d_model, bias=False)  # down-projection
        self.norm = RMSNorm(cfg.d_model)

    def forward(self, x):
        h = self.norm(x)
        gate = F.silu(self.w1(h))
        up = self.w3(h)
        return x + self.w2(gate * up)


class TransparentHybrid(nn.Module):
    """
    Configurable GDN + SWA hybrid language model with per-layer diagnostics.

    The stack follows ``cfg.layer_pattern`` ('G' = GatedDeltaNetLayer,
    'S' = SlidingWindowAttention). Every mixer is followed by its own FFN,
    and the LM head shares its weight with the token embedding.

    State hand-off: each GDN layer returns a state. These are averaged into
    one running state, which is passed as ``initial_state`` to the next GDN
    layer and as ``gdn_state`` to SWA layers.

    NOTE(review): the running state is taken after a GDN layer has consumed
    the whole sequence, yet later layers apply it at every position,
    including position 0. That looks like it lets early positions read
    information from later tokens, which would make the next-token training
    loss optimistic. Confirm whether this is intended.
    """
    def __init__(self, cfg: HybridConfig):
        super().__init__()
        self.cfg = cfg

        # Token embedding
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        nn.init.normal_(self.embed.weight, std=cfg.init_std)

        # One mixer + one FFN per pattern character
        self.layers = nn.ModuleList()
        self.ffns = nn.ModuleList()

        for i, layer_type in enumerate(cfg.layer_pattern):
            if layer_type == 'G':
                self.layers.append(GatedDeltaNetLayer(cfg, i))
            elif layer_type == 'S':
                self.layers.append(SlidingWindowAttention(cfg, i))
            else:
                raise ValueError(f"Unknown layer type: {layer_type}")
            self.ffns.append(FFN(cfg))

        self.norm_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight  # tie weights

        # Layer indices by kind, kept for inspection
        self.gdn_indices = [i for i, t in enumerate(cfg.layer_pattern) if t == 'G']
        self.swa_indices = [i for i, t in enumerate(cfg.layer_pattern) if t == 'S']

        print(f"Architecture: {cfg.layer_pattern}")
        print(f"  GDN layers: {self.gdn_indices}")
        print(f"  SWA layers: {self.swa_indices}")

    @staticmethod
    def _merge_state(accumulated, new_state):
        """Fold a GDN layer's state into the running state (equal-weight average)."""
        if accumulated is None:
            return new_state
        return 0.5 * accumulated + 0.5 * new_state

    def forward(self, input_ids: torch.Tensor,
                targets: Optional[torch.Tensor] = None,
                return_diagnostics: bool = False) -> tuple:
        """
        Args:
            input_ids: [B, T] token ids
            targets: [B, T] token ids for the loss, or None
            return_diagnostics: whether to return diagnostics and final state

        Returns:
            A 4-tuple ``(logits, loss, diagnostics, state)``:
            logits: [B, T, vocab_size]
            loss: float32 scalar cross-entropy if targets given, else None
            diagnostics: list of per-layer dicts if return_diagnostics, else None
            state: running GDN state if return_diagnostics, else None
        """
        x = self.embed(input_ids)

        accumulated_state = None
        all_diagnostics = []

        for i, (layer, ffn) in enumerate(zip(self.layers, self.ffns)):
            layer_type = self.cfg.layer_pattern[i]

            if layer_type == 'G':
                x, state, diag = layer(x, initial_state=accumulated_state, output_state=True)
                accumulated_state = self._merge_state(accumulated_state, state)
                diag['layer_type'] = 'GDN'

            elif layer_type == 'S':
                x, diag = layer(x, gdn_state=accumulated_state)
                diag['layer_type'] = 'SWA'

            x = ffn(x)
            diag['layer_idx'] = i
            all_diagnostics.append(diag)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # Upcast before the loss: with a bf16 model the cross-entropy
            # would otherwise be computed and returned in bf16, which
            # quantises the loss value.
            flat_logits = logits.float().reshape(-1, logits.size(-1))
            loss = F.cross_entropy(flat_logits, targets.reshape(-1))

        if return_diagnostics:
            return logits, loss, all_diagnostics, accumulated_state
        return logits, loss, None, None

    def probe_state(self, input_ids: torch.Tensor,
                    needle_pos: int,
                    query_pos: int) -> dict:
        """
        Diagnostic: read the running GDN state with the needle token's key.

        After each GDN layer, the embedding of the token at
        ``input_ids[0, needle_pos]`` is passed through that layer's pre-norm
        and ``k_proj`` (the same path the layer uses for its keys),
        L2-normalised, and contracted with the running state. For layers
        after the first this is only an approximation: the layer's real
        input is the residual stream, not the raw embedding.

        Args:
            input_ids: [B, T] token ids; the needle is read from batch row 0
            needle_pos: position of the needle token
            query_pos: currently unused; kept for interface compatibility

        Returns:
            dict with 'per_layer' (one entry per layer; GDN entries carry
            'state_norm' and 'retrieved_norm') and 'final_state_norm'.

        The module is switched to eval mode for the probe and put back into
        its previous train/eval mode afterwards.
        """
        was_training = self.training
        self.eval()
        results = {'per_layer': []}
        try:
            with torch.no_grad():
                x = self.embed(input_ids)
                accumulated_state = None
                needle_embed = self.embed.weight[input_ids[0, needle_pos]]

                for i, (layer, ffn) in enumerate(zip(self.layers, self.ffns)):
                    layer_type = self.cfg.layer_pattern[i]

                    if layer_type == 'G':
                        x, state, _ = layer(x, initial_state=accumulated_state, output_state=True)
                        accumulated_state = self._merge_state(accumulated_state, state)

                        # Mirror the layer's own key path: norm -> k_proj -> L2 norm.
                        needle_key = layer.k_proj(layer.norm(needle_embed))
                        needle_key = needle_key.view(self.cfg.n_heads, self.cfg.head_dim)
                        needle_key = F.normalize(needle_key.float(), p=2, dim=-1)

                        # Contract the key with the state over the K axis.
                        retrieved = torch.einsum(
                            'hk,bhkv->bhv', needle_key, accumulated_state.float()
                        )

                        results['per_layer'].append({
                            'layer': i,
                            'type': 'GDN',
                            'state_norm': state.norm().item(),
                            'retrieved_norm': retrieved.norm().item(),
                        })

                    elif layer_type == 'S':
                        x, _ = layer(x, gdn_state=accumulated_state)
                        results['per_layer'].append({
                            'layer': i,
                            'type': 'SWA',
                        })

                    x = ffn(x)

                results['final_state_norm'] = (
                    accumulated_state.norm().item() if accumulated_state is not None else 0
                )
        finally:
            # Do not leave a training model silently stuck in eval mode.
            self.train(was_training)

        return results


def count_params(model):
    """Count parameters: overall total plus GDN, SWA and FFN subtotals.

    A layer is attributed to 'gdn' or 'swa' by the character at its index in
    ``model.cfg.layer_pattern`` ('G' or 'S').
    """
    def numel(modules):
        return sum(p.numel() for m in modules for p in m.parameters())

    mixers = {'G': [], 'S': []}
    for idx, layer in enumerate(model.layers):
        kind = model.cfg.layer_pattern[idx]
        if kind in mixers:
            mixers[kind].append(layer)

    return {
        'total': sum(p.numel() for p in model.parameters()),
        'gdn': numel(mixers['G']),
        'swa': numel(mixers['S']),
        'ffn': numel(model.ffns),
    }


# ============ TEST IT ============
print("="*60, "Building TransparentHybrid", "="*60, sep="\n")

# Smallest interesting stack: one GDN layer followed by one SWA layer.
cfg = HybridConfig(d_model=512, n_heads=8, layer_pattern="GS", window_size=128)

model = TransparentHybrid(cfg).cuda().bfloat16()
params = count_params(model)
print(
    f"Parameters: {params['total']/1e6:.2f}M "
    f"(GDN: {params['gdn']/1e6:.2f}M, SWA: {params['swa']/1e6:.2f}M)"
)

# Smoke-test one forward pass on random token ids (inputs and targets).
x = torch.randint(0, 1000, (1, 128), device='cuda')
y = torch.randint(0, 1000, (1, 128), device='cuda')

logits, loss, diagnostics, state = model(x, y, return_diagnostics=True)
print(f"\nForward pass:")
print(f"  Logits: {logits.shape}")
# The printed reference value is log(vocab_size), the loss of a uniform guess.
print(f"  Loss: {loss.item():.4f} (expected ~{math.log(cfg.vocab_size):.2f})")
print(f"  Final state: {state.shape if state is not None else None}")

print(f"\nPer-layer diagnostics:")
for d in diagnostics:
    print(f"  Layer {d['layer_idx']} [{d['layer_type']}]: {d}")

In [None]:
# 1. NIAH test - does the needle actually get stored and retrieved?
def simple_niah(model, seq_len=128, needle_pos=32, n_trials=20):
    """Put a rare token early, see if model predicts it at the end.

    For each trial a random sequence is built with token id 50000 placed at
    ``needle_pos``; the probability the model assigns to that token at the
    last position is compared with the uniform baseline ``1 / vocab_size``.

    Args:
        model: callable as ``model(tokens, return_diagnostics=True)`` and
            returning ``(logits, loss, diagnostics, state)``; must expose
            ``cfg.vocab_size`` and at least one parameter (used to pick the
            device).
        seq_len: length of each random sequence
        needle_pos: index at which the needle token is written
        n_trials: number of independent random sequences

    Returns:
        list of dicts with 'needle_prob', 'ratio' and 'state_norm'
        ('state_norm' is 0 when the model returns no state).

    Raises:
        ValueError: if the needle token id is outside the model vocabulary.

    The model is evaluated in eval mode and put back into its previous
    train/eval mode before returning.
    """
    needle_token = 50000  # rare token
    vocab_size = model.cfg.vocab_size
    if needle_token >= vocab_size:
        raise ValueError(
            f"needle token {needle_token} is outside the vocabulary (size {vocab_size})"
        )

    # Build inputs where the model lives instead of assuming CUDA.
    device = next(model.parameters()).device
    random_baseline = 1.0 / vocab_size

    was_training = model.training
    model.eval()
    results = []
    try:
        with torch.no_grad():
            for _ in range(n_trials):
                # Random tokens with needle inserted
                tokens = torch.randint(1000, 10000, (1, seq_len), device=device)
                tokens[0, needle_pos] = needle_token

                logits, _, _, state = model(tokens, return_diagnostics=True)

                # Check: does the final position predict the needle?
                final_probs = F.softmax(logits[0, -1].float(), dim=-1)
                needle_prob = final_probs[needle_token].item()

                results.append({
                    'needle_prob': needle_prob,
                    'ratio': needle_prob / random_baseline,
                    # state is None when the pattern has no GDN layer
                    'state_norm': state.norm().item() if state is not None else 0,
                })
    finally:
        model.train(was_training)

    # With zero trials there is nothing to average; report NaN, don't crash.
    avg_ratio = sum(r['ratio'] for r in results) / len(results) if results else float('nan')
    print(f"NIAH (untrained): {avg_ratio:.4f}x random")
    print(f"  (>1.0 means model finds needle, <1.0 means no retrieval)")
    return results

# Baseline NIAH on the untrained model
niah_results = simple_niah(model)

# 2. Probe: where is needle info stored?
print("\n" + "="*60, "Probing state for needle information", "="*60, sep="\n")

# Random sequence with the needle token (id 50000) written at position 32
tokens = torch.randint(1000, 10000, (1, 128), device='cuda')
tokens[0, 32] = 50000

probe = model.probe_state(tokens, needle_pos=32, query_pos=127)
print(f"Final state norm: {probe['final_state_norm']:.4f}")
for layer in probe['per_layer']:
    print(f"  {layer}")

In [None]:
# TRAINING WITH MONITORING
from torch.optim import AdamW
from datasets import load_dataset
from transformers import AutoTokenizer
import time

# Tokenizer: GPT-2 BPE, with EOS reused as the pad token
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

print("Loading data...")
ds = load_dataset(
    "HuggingFaceFW/fineweb-edu",
    "sample-10BT",
    split="train",
    streaming=True,
)

# Stream documents until enough tokens are buffered. Documents are
# concatenated back to back with no separator token in between.
target_tokens = 2_000_000  # 2M for quick test
token_buffer = []

for doc in ds:
    toks = tokenizer.encode(doc['text'])
    token_buffer += toks
    if len(token_buffer) >= target_tokens:
        break

# Trim to the target and keep the whole corpus on the GPU.
token_tensor = torch.tensor(token_buffer[:target_tokens], device='cuda')
print(f"Loaded {len(token_tensor):,} tokens")

def get_batch(batch_size=4, seq_len=128):
    """Sample random contiguous windows from the global ``token_tensor``.

    Returns ``(x, y)``, each of shape [batch_size, seq_len], where ``y`` is
    the same window as ``x`` shifted one token to the right.
    """
    max_start = len(token_tensor) - seq_len - 1
    starts = torch.randint(0, max_start, (batch_size,)).tolist()
    x = torch.stack([token_tensor[s:s + seq_len] for s in starts])
    y = torch.stack([token_tensor[s + 1:s + seq_len + 1] for s in starts])
    return x, y

# Training config
STEPS = 20000
LR = 3e-4
BATCH = 4
SEQ_LEN = 128
LOG_EVERY = 100
NIAH_EVERY = 500
WARMUP_STEPS = 200  # length of the linear LR warmup


def _first_diag(diags, layer_type):
    """Return the first diagnostics dict of the given layer type, or {} if none."""
    return next((d for d in diags if d['layer_type'] == layer_type), {})


# Optimizer (the learning rate is overwritten every step by the schedule below)
opt = AdamW(model.parameters(), lr=LR, betas=(0.9, 0.95), weight_decay=0.1)

# Tracking
history = {
    'loss': [],
    'niah_ratio': [],
    'gdn_beta': [],
    'gdn_g': [],
    'state_norm': [],
    'swa_state_contrib': [],
}

print(f"\nTraining {STEPS} steps, batch={BATCH}, seq_len={SEQ_LEN}")
print("="*60)

model.train()
start = time.time()

for step in range(STEPS):
    # LR schedule: linear warmup then cosine decay to zero
    if step < WARMUP_STEPS:
        lr = LR * (step + 1) / WARMUP_STEPS
    else:
        progress = (step - WARMUP_STEPS) / (STEPS - WARMUP_STEPS)
        lr = LR * 0.5 * (1 + math.cos(math.pi * progress))
    for pg in opt.param_groups:
        pg['lr'] = lr

    # Forward
    x, y = get_batch(BATCH, SEQ_LEN)
    logits, loss, diags, state = model(x, y, return_diagnostics=True)

    # Backward
    opt.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()

    # Track
    history['loss'].append(loss.item())

    # First GDN / SWA diagnostics. A pattern without one of the two layer
    # kinds yields NaN for its metrics rather than an IndexError.
    gdn_diag = _first_diag(diags, 'GDN')
    swa_diag = _first_diag(diags, 'SWA')
    beta_mean = gdn_diag.get('beta_mean', math.nan)
    g_mean = gdn_diag.get('g_mean', math.nan)
    state_norm = gdn_diag.get('state_norm', math.nan)
    state_contrib = swa_diag.get('state_contribution', math.nan)

    history['gdn_beta'].append(beta_mean)
    history['gdn_g'].append(g_mean)
    history['state_norm'].append(state_norm)
    history['swa_state_contrib'].append(state_contrib)

    # Log
    if step % LOG_EVERY == 0:
        elapsed = time.time() - start
        tps = (step + 1) * BATCH * SEQ_LEN / elapsed
        recent = history['loss'][-50:]
        avg_loss = sum(recent) / len(recent)

        print(f"[{step:4d}] loss={avg_loss:.3f} lr={lr:.2e} | "
              f"β={beta_mean:.3f} g={g_mean:.3f} "
              f"state={state_norm:.1f} swa_state={state_contrib:.2f} | "
              f"{tps:,.0f} tok/s")

    # NIAH check
    if (step + 1) % NIAH_EVERY == 0:
        model.eval()
        niah = simple_niah(model, seq_len=SEQ_LEN, needle_pos=32, n_trials=30)
        avg_ratio = sum(r['ratio'] for r in niah) / len(niah)
        history['niah_ratio'].append((step + 1, avg_ratio))
        status = "PASS" if avg_ratio > 1.0 else "FAIL"
        print(f"  >>> NIAH@{step+1}: {avg_ratio:.2f}x random [{status}]")
        model.train()

# Final summary
print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)
elapsed = time.time() - start
print(f"Time: {elapsed/60:.1f} min")
# Average over the losses actually present, so a run shorter than 50 steps
# is not divided by a fixed 50.
tail = history['loss'][-50:]
if tail:
    print(f"Loss: {history['loss'][0]:.2f} -> {sum(tail)/len(tail):.2f}")
print(f"NIAH trajectory: {history['niah_ratio']}")

# Final NIAH at multiple positions
print("\nFinal NIAH at different needle positions:")
model.eval()
for needle_pos in [16, 32, 64, 96]:
    niah = simple_niah(model, seq_len=128, needle_pos=needle_pos, n_trials=30)
    avg = sum(r['ratio'] for r in niah) / len(niah)
    print(f"  needle@{needle_pos}: {avg:.2f}x random")