# GroundThink V6 - Hybrid GatedDeltaNet + SWA (WSL Local)

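**Gated Delta Rule:** shorthand `Sₜ = αₜ Sₜ₋₁ + βₜ Δₜ`; full form `Sₜ = αₜ Sₜ₋₁ (I − βₜ kₜ kₜᵀ) + βₜ vₜ kₜᵀ`
- `αₜ` (gate): data-dependent decay from Mamba2, which lets the state be forgotten rapidly
- `βₜΔₜ` (delta): DeltaNet-style targeted update that moves the value stored under key `kₜ` toward `vₜ`, with step size `βₜ`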

**Architecture:** GatedDeltaNet (FLA) + SlidingWindowAttention (flash_attn)

**Required Environment:**
- PyTorch nightly (cu126)
- flash-attn (prebuilt wheel)
- flash-linear-attention 0.4.2+

In [None]:
# CELL 0: VERIFY ENVIRONMENT
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import time
import math
import numpy as np

# Report interpreter / PyTorch / CUDA versions, then probe the optional flash_attn
# kernels and the required FLA GatedDeltaNet layer.
print(f"Python: {sys.version}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Compute Capability: {torch.cuda.get_device_capability(0)}")

# Verify flash_attn
# Optional: a failed import only clears FLASH_ATTN_AVAILABLE, which CELL 1 reads
# when deciding USE_FLASH.
try:
    from flash_attn import flash_attn_func
    import flash_attn
    print(f"flash_attn: {flash_attn.__version__}")
    FLASH_ATTN_AVAILABLE = True
except ImportError as e:
    print(f"flash_attn: NOT AVAILABLE - {e}")
    FLASH_ATTN_AVAILABLE = False

# Verify FLA
# Required: every non-SWA layer is a GatedDeltaNet, so a missing FLA install aborts here.
try:
    from fla.layers import GatedDeltaNet
    print("FLA GatedDeltaNet: OK")
except ImportError as e:
    raise ImportError(f"FLA not available: {e}")

# Enable TF32 for Ampere+
# These flags let float32 matmuls/convolutions use TF32 where the GPU supports it.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

print("\n✓ Environment ready")

In [None]:
# CELL 1: CONFIG
from dataclasses import dataclass, field
from typing import List

# Global device used by data loading and evaluation helpers.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hardware detection
# Defaults apply on CPU-only machines: no FlashAttention, float32 compute.
USE_FLASH = False
DTYPE = torch.float32

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    major, minor = torch.cuda.get_device_capability(0)
    print(f"GPU: {props.name} (Compute {major}.{minor}, {props.total_memory/1e9:.1f}GB)")
    
    # FlashAttention requires Ampere+ (sm_80+)
    # FLASH_ATTN_AVAILABLE comes from the import probe in CELL 0.
    if major >= 8 and FLASH_ATTN_AVAILABLE:
        USE_FLASH = True
        print("FlashAttention: ENABLED")
    else:
        print(f"FlashAttention: DISABLED (need Ampere+ and flash_attn installed)")
    
    # bfloat16 for Ampere+, float16 for older
    # DTYPE is used both as the autocast dtype and (CELL 5) as the parameter dtype.
    DTYPE = torch.bfloat16 if major >= 8 else torch.float16
    print(f"Training dtype: {DTYPE}")

@dataclass
class ModelConfig:
    """Architecture hyperparameters for the hybrid GatedDeltaNet + SWA language model.

    ``head_dim`` is always derived in ``__post_init__`` as ``d_model // n_heads``;
    any value passed for it is overwritten.

    Raises:
        ValueError: if ``n_heads`` is not positive, ``d_model`` is not divisible by
            ``n_heads`` (the attention layers reshape d_model into n_heads x head_dim),
            or ``attn_interval`` < 1.
    """
    vocab_size: int = 50257
    d_model: int = 256        # Small for RTX 4050
    n_layers: int = 12
    n_heads: int = 8
    head_dim: int = 32        # derived from d_model // n_heads in __post_init__
    attn_interval: int = 4    # SWA every 4th layer (3:1 ratio)
    window_size: int = 512
    expand_k: float = 1.0
    expand_v: float = 2.0
    use_gradient_checkpointing: bool = True
    tie_weights: bool = True

    def __post_init__(self):
        # Fail at construction instead of deep inside a .view() / modulo later on.
        if self.n_heads <= 0 or self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be a positive multiple of n_heads ({self.n_heads})"
            )
        if self.attn_interval < 1:
            raise ValueError(f"attn_interval must be >= 1, got {self.attn_interval}")
        self.head_dim = self.d_model // self.n_heads

    def get_swa_layer_indices(self):
        """Return indices of sliding-window-attention layers: the last layer of every
        ``attn_interval``-sized group (e.g. 3, 7, 11 for interval 4 and 12 layers)."""
        return list(range(self.attn_interval - 1, self.n_layers, self.attn_interval))

@dataclass
class TrainConfig:
    """Data, batching, optimisation and evaluation settings for the local training run.

    Two read-only values are derived from the fields: ``warmup_steps`` and
    ``effective_batch_size``.
    """
    # data source
    dataset_name: str = "HuggingFaceFW/fineweb-edu"
    dataset_subset: str = "sample-10BT"
    target_tokens: int = 10_000_000  # Smaller for local
    # batching
    batch_size: int = 2
    seq_len: int = 512
    accum_steps: int = 2
    # schedule and optimiser
    steps: int = 5000
    warmup_ratio: float = 0.1
    lr: float = 3e-4
    weight_decay: float = 0.1
    grad_clip: float = 1.0
    betas: tuple = (0.9, 0.95)
    # logging and evaluation cadence
    log_interval: int = 50
    grad_log_interval: int = 500
    niah_checkpoints: List[int] = field(
        default_factory=lambda: [500, 1000, 2000, 3000, 5000]
    )

    @property
    def warmup_steps(self):
        """Linear-warmup length: ``steps * warmup_ratio`` truncated to an int."""
        return int(self.warmup_ratio * self.steps)

    @property
    def effective_batch_size(self):
        """Sequences contributing to one optimiser step (micro-batch x accumulation)."""
        return self.accum_steps * self.batch_size

# Module-level config instances read by every later cell.
# MODEL_CFG.vocab_size is overwritten from the tokenizer in CELL 4.
MODEL_CFG = ModelConfig()
TRAIN_CFG = TrainConfig()
print(f"\nConfig: d={MODEL_CFG.d_model}, layers={MODEL_CFG.n_layers}, SWA@{MODEL_CFG.get_swa_layer_indices()}")

In [None]:
# CELL 2: MODEL COMPONENTS
from transformers import AutoTokenizer
from fla.layers import GatedDeltaNet

if USE_FLASH:
    from flash_attn import flash_attn_func

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    No mean-centering and no bias. The statistic is computed in float32, the
    normalised tensor is cast back to the input dtype, and then the gain is applied.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.square().mean(dim=-1, keepdim=True) + self.eps)
        normed = (x32 * inv_rms).to(x.dtype)
        return normed * self.weight


class SwiGLUFFN(nn.Module):
    """Pre-norm SwiGLU feed-forward sublayer that includes its own residual add.

    Output is ``x + W2(silu(W1 n) * W3 n)`` with ``n = RMSNorm(x)``. The hidden width
    is ``int(d_model * expansion)`` rounded up to the next multiple of 64.
    """
    def __init__(self, d_model, expansion=8/3):
        super().__init__()
        raw_width = int(d_model * expansion)
        hidden = -(-raw_width // 64) * 64  # ceil-divide, then scale back: round up to 64
        # Creation order (w1, w3, w2, norm) is kept so parameter init/order is unchanged.
        self.w1 = nn.Linear(d_model, hidden, bias=False)
        self.w3 = nn.Linear(d_model, hidden, bias=False)
        self.w2 = nn.Linear(hidden, d_model, bias=False)
        self.norm = RMSNorm(d_model)

    def forward(self, x):
        normed = self.norm(x)
        gated = F.silu(self.w1(normed)) * self.w3(normed)
        return x + self.w2(gated)


class SlidingWindowAttention(nn.Module):
    """Causal multi-head sliding-window attention with an optional rolling KV cache.

    A query at position p attends to keys at positions p - window_size .. p, i.e.
    window_size + 1 keys (the same window as flash_attn's ``window_size=(W, 0)``).

    forward(x, past_key_values=None, use_cache=False) -> (out, cache)
        x: (B, T, d_model).
        past_key_values: optional (k, v) previously returned by this layer, each
            (B, <=window_size, n_heads, head_dim).
        cache: with use_cache, the last window_size keys/values (detached); else None.

    ``layer_idx`` is accepted for interface symmetry and is not used.
    """
    def __init__(self, d_model, n_heads, window_size, layer_idx=0):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.window_size = window_size

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def _window_mask(self, q_len, k_len, device):
        """Boolean (q_len, k_len) mask, True where attention is allowed.

        The queries are the LAST q_len of the k_len key positions (cached keys come
        first), so the same mask serves full-sequence, prefill and cached decoding.
        """
        q_pos = torch.arange(k_len - q_len, k_len, device=device).unsqueeze(1)
        k_pos = torch.arange(k_len, device=device).unsqueeze(0)
        return (k_pos <= q_pos) & (k_pos >= q_pos - self.window_size)

    def _masked_sdpa(self, q, k, v):
        """Sliding-window attention via SDPA; q/k/v are (B, len, n_heads, head_dim)."""
        mask = self._window_mask(q.size(1), k.size(1), q.device)
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=mask
        )
        return out.transpose(1, 2)

    def forward(self, x, past_key_values=None, use_cache=False):
        B, T, D = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim)

        current_cache = None
        if use_cache:
            if past_key_values is not None:
                pk, pv = past_key_values
                k = torch.cat([pk, k], dim=1)
                v = torch.cat([pv, v], dim=1)
            # Keeping window_size keys is enough: the next query needs at most W past keys.
            current_cache = (k[:, -self.window_size:].detach(), v[:, -self.window_size:].detach())

        if use_cache and past_key_values is not None:
            # Inference with cache. An explicit window mask (instead of is_causal=False)
            # keeps this correct when several new tokens arrive at once: each new token
            # must not see the later new tokens.
            out = self._masked_sdpa(q, k, v)
        elif USE_FLASH:
            # Training with FlashAttention
            out = flash_attn_func(q, k, v, causal=True, window_size=(self.window_size, 0))
        else:
            # Fallback: same causal sliding window, via SDPA with a boolean mask.
            out = self._masked_sdpa(q, k, v)

        return self.out_proj(out.reshape(B, T, D)), current_cache


class HybridBlock(nn.Module):
    """Pre-norm residual block wrapping either GatedDeltaNet or SlidingWindowAttention.

    forward(x, past_state=None, use_cache=False) returns
    ``(x + mixer(RMSNorm(x)), new_state)``, where ``new_state`` is whatever
    state/cache the wrapped layer hands back (None if it returns none).
    """
    def __init__(self, d_model, is_attention, n_heads=8, window_size=512,
                 expand_k=1.0, expand_v=2.0, layer_idx=0):
        super().__init__()
        self.is_attention = is_attention
        self.layer_idx = layer_idx
        self.norm = RMSNorm(d_model)

        if is_attention:
            self.layer = SlidingWindowAttention(d_model, n_heads, window_size, layer_idx)
        else:
            # Gated delta rule: Sₜ = αₜ Sₜ₋₁ (I − βₜ kₜ kₜᵀ) + βₜ vₜ kₜᵀ
            # NOTE(review): n_heads is not forwarded here, so the GDN head count/size come
            # from GatedDeltaNet's own defaults; depending on the installed FLA version,
            # expand_k may also not be a recognised argument. TODO confirm resulting dims.
            self.layer = GatedDeltaNet(
                hidden_size=d_model,
                expand_k=expand_k,
                expand_v=expand_v,
                layer_idx=layer_idx
            )

    @staticmethod
    def _split_layer_output(out):
        """Normalise a mixer's return value to ``(hidden, state)``.

        Accepts a bare tensor, a 2-tuple ``(hidden, state)``, or a 3-tuple
        ``(hidden, attentions, cache)``. FLA layers appear to use the 3-tuple form;
        TODO confirm for the installed version.
        """
        if not isinstance(out, tuple):
            return out, None
        if len(out) >= 3:
            return out[0], out[2]
        return out[0], (out[1] if len(out) == 2 else None)

    def forward(self, x, past_state=None, use_cache=False):
        residual = x
        h = self.norm(x)

        if self.is_attention:
            h, new_state = self.layer(h, past_key_values=past_state, use_cache=use_cache)
        elif use_cache:
            # NOTE(review): initial_state/output_final_state are passed as-is; newer FLA
            # versions may expect a past_key_values Cache instead. TODO confirm.
            out = self.layer(h, initial_state=past_state, use_cache=True, output_final_state=True)
            h, new_state = self._split_layer_output(out)
        else:
            # Unpacking robustly: a plain `x, state = out` fails on 3-tuple outputs.
            h, new_state = self._split_layer_output(self.layer(h))

        return residual + h, new_state


class GroundThinkLM(nn.Module):
    """Hybrid LM: GatedDeltaNet + SlidingWindowAttention.

    embed -> n_layers x (HybridBlock, SwiGLUFFN) -> final RMSNorm -> lm_head.
    Layers whose index is in ``cfg.get_swa_layer_indices()`` use sliding-window
    attention; all others use GatedDeltaNet. With ``cfg.tie_weights`` the LM head
    shares the embedding matrix.
    """
    def __init__(self, cfg: "ModelConfig"):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)

        swa_indices = set(cfg.get_swa_layer_indices())
        self._swa_indices = swa_indices

        self.blocks = nn.ModuleList()
        self.ffns = nn.ModuleList()
        for i in range(cfg.n_layers):
            self.blocks.append(HybridBlock(
                cfg.d_model, is_attention=(i in swa_indices),
                n_heads=cfg.n_heads, window_size=cfg.window_size,
                expand_k=cfg.expand_k, expand_v=cfg.expand_v, layer_idx=i
            ))
            self.ffns.append(SwiGLUFFN(cfg.d_model))

        self.norm_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        if cfg.tie_weights:
            self.lm_head.weight = self.embed.weight

    def forward(self, input_ids, targets=None, past_states=None, use_cache=False):
        """Run the model.

        Args:
            input_ids: integer token ids (callers here pass (B, T); TODO confirm others).
            targets: optional ids with the same element count as the logits' leading dims.
            past_states: optional per-layer states from a previous ``use_cache`` call.
            use_cache: collect and return per-layer states.

        Returns:
            (logits, loss, new_states): ``loss`` is mean cross-entropy or None;
            ``new_states`` is a per-layer list when ``use_cache`` else None.
        """
        x = self.embed(input_ids)
        new_states = [] if use_cache else None

        for i, (block, ffn) in enumerate(zip(self.blocks, self.ffns)):
            layer_past = past_states[i] if (past_states is not None and len(past_states) > i) else None

            # NOTE(review): checkpointing is applied to SWA layers only; GDN layers always
            # keep their activations. TODO confirm this is intentional.
            if self.cfg.use_gradient_checkpointing and self.training and not use_cache and i in self._swa_indices:
                x = checkpoint(self._fwd_block, block, ffn, x, use_reentrant=False)
            else:
                x, layer_new_state = block(x, layer_past, use_cache)
                x = ffn(x)
                if use_cache:
                    new_states.append(layer_new_state)

        logits = self.lm_head(self.norm_f(x))
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets are accepted as well.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss, new_states

    @staticmethod
    def _fwd_block(block, ffn, x):
        """Stateless block + FFN step, used as the gradient-checkpointing function."""
        x, _ = block(x, None, False)
        return ffn(x)

    def get_layer_types(self):
        """Per-layer labels: 'SWA' for sliding-window attention, 'GDN' otherwise."""
        return ['SWA' if i in self._swa_indices else 'GDN' for i in range(self.cfg.n_layers)]

    def count_parameters(self):
        """Parameter counts: 'embed', 'gdn', 'swa', 'ffn' buckets plus 'total'.

        'total' counts every unique parameter of the model, including norm_f and an
        untied lm_head (neither is in a bucket); a tied lm_head is counted once.
        """
        c = {'embed': sum(p.numel() for p in self.embed.parameters()), 'gdn': 0, 'swa': 0, 'ffn': 0}
        for i, (b, f) in enumerate(zip(self.blocks, self.ffns)):
            c['swa' if i in self._swa_indices else 'gdn'] += sum(p.numel() for p in b.parameters())
            c['ffn'] += sum(p.numel() for p in f.parameters())
        # self.parameters() de-duplicates shared tensors, so tied weights count once.
        c['total'] = sum(p.numel() for p in self.parameters())
        return c

# Cell-completion marker: the model classes above are now defined.
print("Model components defined")

In [None]:
# CELL 3: MONITORING

def print_gradient_summary(model):
    """Print mean/max per-parameter gradient L2 norms, grouped by component.

    Groups are 'embed', 'gdn', 'swa' and 'ffn'. Parameters without a gradient, and
    parameters whose name matches none of the groups (e.g. norm_f), are skipped.
    Block parameters (``blocks.<idx>.…``) go to 'swa' or 'gdn' according to
    ``model._swa_indices``. Empty groups are not printed.
    """
    buckets = {group: [] for group in ('embed', 'gdn', 'swa', 'ffn')}
    for pname, param in model.named_parameters():
        if param.grad is None:
            continue
        gnorm = param.grad.norm().item()
        if 'embed' in pname:
            group = 'embed'
        elif 'ffn' in pname:
            group = 'ffn'
        elif 'blocks' in pname:
            block_idx = int(pname.split('.')[1])
            group = 'swa' if block_idx in model._swa_indices else 'gdn'
        else:
            continue
        buckets[group].append(gnorm)

    print("Gradients:")
    for group, norms in buckets.items():
        if not norms:
            continue
        print(f"  {group}: mean={np.mean(norms):.3f} max={np.max(norms):.2f}")


def needle_test(model, tokenizer, seq_len=512, n_trials=50, needle_token=50250, device="cuda", dtype=None):
    """Probability that the final position predicts a planted needle token.

    Each trial builds random ids in [1000, 10000), plants ``needle_token`` at a
    random position away from both ends, and records softmax(logits[0, -1])[needle].

    Args:
        dtype: autocast dtype; defaults to the notebook-wide ``DTYPE``.

    Returns:
        {'mean': average probability, 'ratio': mean / (1 / tokenizer.vocab_size)}.
        The model is left in eval mode.
    """
    if dtype is None:
        dtype = DTYPE
    # Keep the needle away from the edges, but shrink the margin for short inputs:
    # a fixed 64 made torch.randint's [low, high) range empty for seq_len <= 128.
    margin = min(64, seq_len // 4)
    # Autocast must target the device the tokens live on (not always 'cuda').
    device_type = torch.device(device).type
    model.eval()
    probs = []
    with torch.no_grad():
        for _ in range(n_trials):
            tokens = torch.randint(1000, 10000, (1, seq_len), device=device)
            pos = torch.randint(margin, seq_len - margin, (1,)).item()
            tokens[0, pos] = needle_token
            with torch.amp.autocast(device_type, dtype=dtype):
                logits, _, _ = model(tokens)
            probs.append(F.softmax(logits[0, -1].float(), dim=-1)[needle_token].item())
    rc = 1.0 / tokenizer.vocab_size
    return {'mean': np.mean(probs), 'ratio': np.mean(probs) / rc}


def probe_layers(model, needle_id=50250, seq_len=512, pos=256, device="cuda"):
    """Trace how much of the needle's input embedding survives each layer.

    Feeds random ids in [1000, 10000) with ``needle_id`` planted at ``pos``, and
    after every (block, FFN) pair prints the cosine similarity between the hidden
    state at ``pos`` and the needle's embedding row. The model is left in eval mode.
    """
    model.eval()
    ids = torch.randint(1000, 10000, (1, seq_len), device=device)
    ids[0, pos] = needle_id
    with torch.no_grad():
        hidden = model.embed(ids)
        target = model.embed.weight[needle_id].float()
        print("Needle representation through layers:")
        for layer_no, (blk, ff) in enumerate(zip(model.blocks, model.ffns)):
            hidden, _ = blk(hidden, None, False)
            hidden = ff(hidden)
            cos = F.cosine_similarity(hidden[0, pos].float(), target, dim=0).item()
            kind = 'GDN' if layer_no not in model._swa_indices else 'SWA'
            print(f"  L{layer_no:2d}[{kind}]: {cos:+.3f}")

# Cell-completion marker: the monitoring helpers above are now defined.
print("Monitoring functions ready")

In [None]:
# CELL 4: DATA LOADING
from datasets import load_dataset
from tqdm import tqdm

# GPT-2 BPE tokenizer; EOS doubles as the pad token.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
# MODEL_CFG is mutated here, so this cell must run before the model is built (CELL 5).
MODEL_CFG.vocab_size = tokenizer.vocab_size

print(f"Streaming {TRAIN_CFG.dataset_name}...")
ds = load_dataset(TRAIN_CFG.dataset_name, name=TRAIN_CFG.dataset_subset, split="train", streaming=True)

# Concatenate documents (each followed by an EOS separator) into one flat token
# stream until at least target_tokens ids are collected. If the stream ends first,
# all_tokens is simply shorter. The progress bar may overshoot its total on the last doc.
buf = []
pbar = tqdm(total=TRAIN_CFG.target_tokens, unit="tok", desc="Tokenizing")
for ex in ds:
    toks = tokenizer.encode(ex['text']) + [tokenizer.eos_token_id]
    buf.extend(toks)
    pbar.update(len(toks))
    if len(buf) >= TRAIN_CFG.target_tokens:
        break
pbar.close()

# 1-D int64 tensor holding the whole training corpus; get_batch samples windows from it.
all_tokens = torch.tensor(buf[:TRAIN_CFG.target_tokens], dtype=torch.long)
del buf, ds
print(f"Loaded {len(all_tokens):,} tokens")

def get_batch(tokens=None, batch_size=None, seq_len=None, device=None):
    """Sample a random batch of (input, next-token target) windows.

    Every argument defaults to the notebook globals (all_tokens, TRAIN_CFG.batch_size,
    TRAIN_CFG.seq_len, DEVICE), so existing ``get_batch()`` calls are unchanged.

    Returns:
        (x, y): int tensors of shape (batch_size, seq_len) on ``device``, y = x shifted by one.

    Raises:
        ValueError: if ``tokens`` has fewer than seq_len + 1 elements.
    """
    tokens = all_tokens if tokens is None else tokens
    batch_size = TRAIN_CFG.batch_size if batch_size is None else batch_size
    seq_len = TRAIN_CFG.seq_len if seq_len is None else seq_len
    device = DEVICE if device is None else device

    # Valid starts are 0 .. len - seq_len - 1 (y needs one extra token). randint's
    # upper bound is exclusive, so it is len - seq_len; the former `- 1` skipped the
    # last valid window.
    n_starts = len(tokens) - seq_len
    if n_starts < 1:
        raise ValueError(f"need at least seq_len + 1 = {seq_len + 1} tokens, got {len(tokens)}")
    ix = torch.randint(n_starts, (batch_size,))
    x = torch.stack([tokens[i:i + seq_len] for i in ix])
    y = torch.stack([tokens[i + 1:i + seq_len + 1] for i in ix])
    return x.to(device), y.to(device)

In [None]:
# CELL 5: BUILD MODEL
print("Building model...")
# NOTE(review): `.to(DTYPE)` casts every parameter to bf16/fp16, so the optimizer in
# CELL 6 updates low-precision weights directly, and no GradScaler is used for the
# fp16 case. TODO confirm this is intended rather than fp32 weights + autocast.
model = GroundThinkLM(MODEL_CFG).to(DEVICE).to(DTYPE)

p = model.count_parameters()
print(f"Parameters: {p['total']/1e6:.2f}M")
print(f"  GDN: {p['gdn']/1e6:.2f}M, SWA: {p['swa']/1e6:.2f}M, FFN: {p['ffn']/1e6:.2f}M")
print(f"Layers: {model.get_layer_types()}")

# Test forward/backward
# Smoke test on one real batch to surface shape/kernel errors before training.
print("\nTesting forward/backward...")
x, y = get_batch()
with torch.amp.autocast('cuda', dtype=DTYPE):
    _, loss, _ = model(x, y)
loss.backward()
print(f"Forward OK: loss={loss.item():.4f}")
# The torch.cuda memory-stat calls below assume a CUDA device is present.
print(f"Peak memory: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
# Drop the smoke-test gradients and reset the peak counter before real training.
model.zero_grad()
torch.cuda.reset_peak_memory_stats()

In [None]:
# CELL 6: TRAINING LOOP
opt = torch.optim.AdamW(model.parameters(), lr=TRAIN_CFG.lr, betas=TRAIN_CFG.betas, weight_decay=TRAIN_CFG.weight_decay)
losses, niah_traj = [], []
start = time.time()

print(f"\nTRAINING {TRAIN_CFG.steps} steps (effective batch={TRAIN_CFG.effective_batch_size})\n")
model.train()

for step in range(TRAIN_CFG.steps):
    # LR schedule: linear warmup up to TRAIN_CFG.lr, then cosine decay toward 0.
    if step < TRAIN_CFG.warmup_steps:
        lr = TRAIN_CFG.lr * (step + 1) / TRAIN_CFG.warmup_steps
    else:
        progress = (step - TRAIN_CFG.warmup_steps) / (TRAIN_CFG.steps - TRAIN_CFG.warmup_steps)
        lr = TRAIN_CFG.lr * 0.5 * (1 + math.cos(math.pi * progress))
    for pg in opt.param_groups:
        pg['lr'] = lr

    # Gradient accumulation: dividing each micro-batch loss by accum_steps makes the
    # accumulated gradient that of the mean loss over the micro-batches.
    acc_loss = 0.0
    for _ in range(TRAIN_CFG.accum_steps):
        x, y = get_batch()
        with torch.amp.autocast('cuda', dtype=DTYPE):
            _, loss, _ = model(x, y)
        (loss / TRAIN_CFG.accum_steps).backward()
        acc_loss += loss.item()

    # Gradient diagnostics must run while .grad is populated: zero_grad(set_to_none=True)
    # sets grads to None and print_gradient_summary skips None grads, so logging after
    # it printed an empty summary. The norms shown here are pre-clipping.
    if (step + 1) % TRAIN_CFG.grad_log_interval == 0:
        print_gradient_summary(model)

    torch.nn.utils.clip_grad_norm_(model.parameters(), TRAIN_CFG.grad_clip)
    opt.step()
    opt.zero_grad(set_to_none=True)
    losses.append(acc_loss / TRAIN_CFG.accum_steps)

    # Logging: running mean over the most recent (up to) 50 step losses.
    if step % TRAIN_CFG.log_interval == 0:
        avg = np.mean(losses[-50:])
        elapsed = time.time() - start
        tps = (step + 1) * TRAIN_CFG.effective_batch_size * TRAIN_CFG.seq_len / elapsed
        print(f"[{step:5d}/{TRAIN_CFG.steps}] loss={avg:.4f} lr={lr:.2e} {tps:,.0f} tok/s")

    # NIAH checkpoint (needle_test switches to eval mode; train mode is restored after).
    if (step + 1) in TRAIN_CFG.niah_checkpoints:
        n = needle_test(model, tokenizer, TRAIN_CFG.seq_len, 30, device=DEVICE)
        niah_traj.append((step + 1, n['ratio']))
        status = "PASS" if n['ratio'] > 1.0 else "FAIL"
        print(f"  >>> NIAH@{step+1}: {n['ratio']:.2f}x random [{status}]")
        model.train()

# Summary
elapsed = time.time() - start
print(f"\n{'='*60}")
print(f"Training complete in {elapsed/60:.1f} minutes")
print(f"Loss: {np.mean(losses[:50]):.4f} -> {np.mean(losses[-50:]):.4f}")
print(f"NIAH trajectory: {niah_traj}")

In [None]:
# CELL 7: FINAL EVALUATION
print("\n" + "="*60)
print("FINAL EVALUATION")
print("="*60)

# NIAH at multiple lengths
# Each call compares the final-position probability of the needle token with the
# uniform-chance baseline 1/vocab_size; a ratio above 1.0 is labelled PASS.
for L in [128, 256, 512]:
    n = needle_test(model, tokenizer, L, 50, device=DEVICE)
    status = "PASS" if n['ratio'] > 1.0 else "FAIL"
    print(f"NIAH@{L}: {n['ratio']:.2f}x random [{status}]")

# Layer probing
print()
probe_layers(model, device=DEVICE)

# Verdict
# lm_pass: mean loss of the first 50 steps exceeds that of the last 50 by more than 2.0.
# niah_pass: any in-training NIAH checkpoint from CELL 6 beat chance.
# NOTE(review): the final-evaluation NIAH results printed above do not feed into this verdict.
lm_pass = np.mean(losses[:50]) - np.mean(losses[-50:]) > 2.0
niah_pass = any(r > 1.0 for _, r in niah_traj)

print(f"\nVerdict:")
print(f"  LM Training: {'PASS' if lm_pass else 'MARGINAL'}")
print(f"  NIAH: {'PASS' if niah_pass else 'FAIL'}")

In [None]:
# CELL 8: SAVE CHECKPOINT
import os
from pathlib import Path

# Checkpoints go to ./checkpoints relative to the current working directory.
save_dir = Path("./checkpoints")
save_dir.mkdir(exist_ok=True)

ckpt_path = save_dir / "groundthink_v6_final.pt"
# NOTE(review): 'config' stores the ModelConfig dataclass object itself (pickled), so
# loading needs ModelConfig importable and, on PyTorch versions where torch.load
# defaults to weights_only=True, probably weights_only=False. Consider
# dataclasses.asdict(MODEL_CFG) instead. TODO confirm against the loader.
torch.save({
    'model_state_dict': model.state_dict(),
    'config': MODEL_CFG,
    'losses': losses,
    'niah_trajectory': niah_traj,
}, ckpt_path)

print(f"Saved checkpoint: {ckpt_path}")
print(f"Size: {ckpt_path.stat().st_size / 1e6:.1f} MB")