# GroundThink V6 - Hybrid GatedDeltaNet + SWA (WSL Local)

**Gated Delta Rule:** `Sₜ = αₜ Sₜ₋₁ + βₜ Δₜ`
- `αₜ` (gate): rapid forgetting from Mamba2
- `βₜΔₜ` (delta): targeted updates from DeltaNet

**Architecture:** GatedDeltaNet (FLA) + SlidingWindowAttention (flash_attn)

**Required Environment:**
- PyTorch nightly (cu126)
- flash-attn (prebuilt wheel)
- flash-linear-attention 0.4.2+

In [1]:
# CELL 0: VERIFY ENVIRONMENT
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import time
import math
import numpy as np

print(f"Python: {sys.version}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Compute Capability: {torch.cuda.get_device_capability(0)}")

# Verify flash_attn
try:
    from flash_attn import flash_attn_func
    import flash_attn
    print(f"flash_attn: {flash_attn.__version__}")
    FLASH_ATTN_AVAILABLE = True
except ImportError as e:
    print(f"flash_attn: NOT AVAILABLE - {e}")
    FLASH_ATTN_AVAILABLE = False

# Verify FLA
try:
    from fla.layers import GatedDeltaNet
    print("FLA GatedDeltaNet: OK")
except ImportError as e:
    raise ImportError(f"FLA not available: {e}")

# Enable TF32 for Ampere+
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

print("\n✓ Environment ready")

Python: 3.12.3 (main, Jan  8 2026, 11:30:50) [GCC 13.3.0]
PyTorch: 2.11.0.dev20260128+cu128
CUDA available: True
GPU: NVIDIA GeForce RTX 4050 Laptop GPU
Compute Capability: (8, 9)
flash_attn: 2.8.3


  from .autonotebook import tqdm as notebook_tqdm


FLA GatedDeltaNet: OK

✓ Environment ready


In [2]:
# CELL 1: CONFIG
from dataclasses import dataclass, field
from typing import List

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hardware detection
USE_FLASH = False
DTYPE = torch.float32

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    major, minor = torch.cuda.get_device_capability(0)
    print(f"GPU: {props.name} (Compute {major}.{minor}, {props.total_memory/1e9:.1f}GB)")
    
    # FlashAttention requires Ampere+ (sm_80+)
    if major >= 8 and FLASH_ATTN_AVAILABLE:
        USE_FLASH = True
        print("FlashAttention: ENABLED")
    else:
        print(f"FlashAttention: DISABLED (need Ampere+ and flash_attn installed)")
    
    # bfloat16 for Ampere+, float16 for older
    DTYPE = torch.bfloat16 if major >= 8 else torch.float16
    print(f"Training dtype: {DTYPE}")

@dataclass
class ModelConfig:
    vocab_size: int = 50257
    d_model: int = 256        # Small for RTX 4050
    n_layers: int = 12
    n_heads: int = 8
    head_dim: int = 32
    attn_interval: int = 4    # SWA every 4th layer (3:1 ratio)
    window_size: int = 512
    expand_k: float = 1.0
    expand_v: float = 2.0
    use_gradient_checkpointing: bool = True
    tie_weights: bool = True
    
    def __post_init__(self):
        self.head_dim = self.d_model // self.n_heads
    
    def get_swa_layer_indices(self):
        return [i for i in range(self.n_layers) if i % self.attn_interval == (self.attn_interval - 1)]

@dataclass
class TrainConfig:
    dataset_name: str = "HuggingFaceFW/fineweb-edu"
    dataset_subset: str = "sample-10BT"
    target_tokens: int = 10_000_000  # Smaller for local
    batch_size: int = 2
    seq_len: int = 512
    accum_steps: int = 2
    steps: int = 5000
    warmup_ratio: float = 0.1
    lr: float = 3e-4
    weight_decay: float = 0.1
    grad_clip: float = 1.0
    betas: tuple = (0.9, 0.95)
    log_interval: int = 50
    grad_log_interval: int = 500
    niah_checkpoints: List[int] = field(default_factory=lambda: [500, 1000, 2000, 3000, 5000])
    
    @property
    def warmup_steps(self): return int(self.steps * self.warmup_ratio)
    @property
    def effective_batch_size(self): return self.batch_size * self.accum_steps

MODEL_CFG = ModelConfig()
TRAIN_CFG = TrainConfig()
print(f"\nConfig: d={MODEL_CFG.d_model}, layers={MODEL_CFG.n_layers}, SWA@{MODEL_CFG.get_swa_layer_indices()}")

GPU: NVIDIA GeForce RTX 4050 Laptop GPU (Compute 8.9, 6.4GB)
FlashAttention: ENABLED
Training dtype: torch.bfloat16

Config: d=256, layers=12, SWA@[3, 7, 11]


In [3]:
# CELL 2: MODEL COMPONENTS
from xml.parsers.expat import model
from transformers import AutoTokenizer
from fla.layers import GatedDeltaNet

if USE_FLASH:
    from flash_attn import flash_attn_func

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    
    def forward(self, x):
        norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return (x.float() * norm).type_as(x) * self.weight


class SwiGLUFFN(nn.Module):
    def __init__(self, d_model, expansion=8/3):
        super().__init__()
        hidden = ((int(d_model * expansion) + 63) // 64) * 64
        self.w1 = nn.Linear(d_model, hidden, bias=False)
        self.w3 = nn.Linear(d_model, hidden, bias=False)
        self.w2 = nn.Linear(hidden, d_model, bias=False)
        self.norm = RMSNorm(d_model)
    
    def forward(self, x):
        h = self.norm(x)
        return x + self.w2(F.silu(self.w1(h)) * self.w3(h))


class SlidingWindowAttention(nn.Module):
    """SWA with KV-Cache for inference."""
    def __init__(self, d_model, n_heads, window_size, layer_idx=0):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.window_size = window_size
        
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
    
    def forward(self, x, past_key_values=None, use_cache=False):
        B, T, D = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim)
        
        current_cache = None
        if use_cache:
            if past_key_values is not None:
                pk, pv = past_key_values
                k = torch.cat([pk, k], dim=1)
                v = torch.cat([pv, v], dim=1)
            current_cache = (k[:, -self.window_size:].detach(), v[:, -self.window_size:].detach())
        
        # Inference mode with cache
        if use_cache and past_key_values is not None:
            q_t, k_t, v_t = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
            out = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=False)
            out = out.transpose(1, 2)
        elif USE_FLASH:
            # Training with FlashAttention
            out = flash_attn_func(q, k, v, causal=True, window_size=(self.window_size, 0))
        else:
            # Manual sliding window fallback
            q_t, k_t, v_t = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
            mask = torch.ones(T, T, device=x.device, dtype=torch.bool).triu(1)
            mask |= torch.ones(T, T, device=x.device, dtype=torch.bool).tril(-self.window_size - 1)
            attn = (q_t @ k_t.transpose(-2, -1)) / (self.head_dim ** 0.5)
            attn = attn.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
            out = (F.softmax(attn, dim=-1) @ v_t).transpose(1, 2)
        
        return self.out_proj(out.reshape(B, T, D)), current_cache


class HybridBlock(nn.Module):
    """GatedDeltaNet or SlidingWindowAttention block."""
    def __init__(self, d_model, is_attention, n_heads=8, window_size=512,
                 expand_k=1.0, expand_v=2.0, layer_idx=0):
        super().__init__()
        self.is_attention = is_attention
        self.layer_idx = layer_idx
        self.norm = RMSNorm(d_model)
        
        if is_attention:
            self.layer = SlidingWindowAttention(d_model, n_heads, window_size, layer_idx)
        else:
            # GatedDeltaNet: Sₜ = αₜ Sₜ₋₁ + βₜ Δₜ
            self.layer = GatedDeltaNet(
                hidden_size=d_model,
                expand_k=expand_k,
                expand_v=expand_v,
                layer_idx=layer_idx
            )
    
    def forward(self, x, past_state=None, use_cache=False):
        residual = x
        x = self.norm(x)
        new_state = None
        
        if self.is_attention:
            x, new_state = self.layer(x, past_key_values=past_state, use_cache=use_cache)
        else:
            # GatedDeltaNet always returns (output, state) tuple
            if use_cache:
                x, new_state = self.layer(x, initial_state=past_state, use_cache=True, output_final_state=True)
            else:
                out = self.layer(x)
                if isinstance(out, tuple):
                    x = out[0]
                    new_state = out[-1] if len(out) > 1 else None
                else:
                    x = out
        
        return residual + x, new_state


class GroundThinkLM(nn.Module):
    """Hybrid LM: GatedDeltaNet + SlidingWindowAttention"""
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        nn.init.normal_(self.embed.weight, mean=0.0, std=0.02)
        print(f"New embed std: {self.embed.weight.std().item():.4f}")

        swa_indices = set(cfg.get_swa_layer_indices())
        self._swa_indices = swa_indices
        
        self.blocks = nn.ModuleList()
        self.ffns = nn.ModuleList()
        for i in range(cfg.n_layers):
            self.blocks.append(HybridBlock(
                cfg.d_model, is_attention=(i in swa_indices),
                n_heads=cfg.n_heads, window_size=cfg.window_size,
                expand_k=cfg.expand_k, expand_v=cfg.expand_v, layer_idx=i
            ))
            self.ffns.append(SwiGLUFFN(cfg.d_model))
        
        self.norm_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        if cfg.tie_weights:
            self.lm_head.weight = self.embed.weight
        else:
            nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.02)
            
    def forward(self, input_ids, targets=None, past_states=None, use_cache=False):
        x = self.embed(input_ids)
        new_states = [] if use_cache else None
        
        for i, (block, ffn) in enumerate(zip(self.blocks, self.ffns)):
            layer_past = past_states[i] if (past_states is not None and len(past_states) > i) else None
            
            if self.cfg.use_gradient_checkpointing and self.training and not use_cache and i in self._swa_indices:
                x = checkpoint(self._fwd_block, block, ffn, x, use_reentrant=False)
            else:
                x, layer_new_state = block(x, layer_past, use_cache)
                x = ffn(x)
                if use_cache:
                    new_states.append(layer_new_state)
        
        logits = self.lm_head(self.norm_f(x))
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) if targets is not None else None
        return logits, loss, new_states
    
    @staticmethod
    def _fwd_block(block, ffn, x):
        x, _ = block(x, None, False)
        return ffn(x)
    
    def get_layer_types(self):
        return ['SWA' if i in self._swa_indices else 'GDN' for i in range(self.cfg.n_layers)]
    
    def count_parameters(self):
        c = {'embed': sum(p.numel() for p in self.embed.parameters()), 'gdn': 0, 'swa': 0, 'ffn': 0}
        for i, (b, f) in enumerate(zip(self.blocks, self.ffns)):
            bp = sum(p.numel() for p in b.parameters())
            fp = sum(p.numel() for p in f.parameters())
            c['swa' if i in self._swa_indices else 'gdn'] += bp
            c['ffn'] += fp
        c['total'] = sum(c.values())
        return c

print("Model components defined")

Model components defined


In [6]:
# Rebuild model reference
model = GroundThinkLM(MODEL_CFG).to(DEVICE).to(DTYPE)

# Now probe GatedDeltaNet state in eval mode (where it returns state)
from fla.layers import GatedDeltaNet

gdn = GatedDeltaNet(
    hidden_size=256,
    num_heads=8,
    head_dim=32,
    mode='chunk',
    layer_idx=0
).cuda().bfloat16()

gdn.eval()
x = torch.randn(1, 128, 256, device='cuda', dtype=torch.bfloat16)

with torch.no_grad():
    # Try to get state
    out = gdn(x, use_cache=True, output_attentions=True)
    
print(f"Output tuple length: {len(out)}")
for i, item in enumerate(out):
    if item is None:
        print(f"  [{i}]: None")
    elif hasattr(item, 'shape'):
        print(f"  [{i}]: shape={item.shape}, dtype={item.dtype}")
    else:
        print(f"  [{i}]: type={type(item)}")
        # Dig deeper
        if hasattr(item, '__dict__'):
            for k, v in vars(item).items():
                if hasattr(v, 'shape'):
                    print(f"       .{k}: shape={v.shape}")
                elif v is not None:
                    print(f"       .{k}: {type(v)}")

New embed std: 0.0200
Output tuple length: 3
  [0]: shape=torch.Size([1, 128, 256]), dtype=torch.bfloat16
  [1]: None
  [2]: None


In [8]:
import fla.ops as fla_ops
print(dir(fla_ops))

# Look for gated_delta or delta_net ops
import fla.ops.gated_delta_rule as gdr
print(dir(gdr))

# Check the actual chunk function signature
import inspect
if hasattr(gdr, 'chunk_gated_delta_rule'):
    sig = inspect.signature(gdr.chunk_gated_delta_rule)
    print("\nchunk_gated_delta_rule parameters:")
    for name, param in sig.parameters.items():
        print(f"  {name}: {param.default if param.default is not inspect.Parameter.empty else 'REQUIRED'}")

['__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'abc', 'attn', 'based', 'chunk_abc', 'chunk_comba', 'chunk_delta_rule', 'chunk_dplr_delta_rule', 'chunk_gated_delta_rule', 'chunk_gla', 'chunk_gsa', 'chunk_iplr_delta_rule', 'chunk_kda', 'chunk_lightning_attn', 'chunk_linear_attn', 'chunk_log_linear_attn', 'chunk_mesa_net', 'chunk_retention', 'chunk_rwkv6', 'chunk_rwkv7', 'chunk_simple_gla', 'comba', 'common', 'delta_rule', 'deltaformer', 'forgetting_attn', 'fused_chunk_based', 'fused_chunk_delta_rule', 'fused_chunk_gla', 'fused_chunk_linear_attn', 'fused_chunk_retention', 'fused_chunk_simple_gla', 'fused_recurrent_comba', 'fused_recurrent_delta_rule', 'fused_recurrent_dplr_delta_rule', 'fused_recurrent_gated_delta_rule', 'fused_recurrent_gla', 'fused_recurrent_gsa', 'fused_recurrent_hgrn', 'fused_recurrent_iplr_delta_rule', 'fused_recurrent_kda', 'fused_recurrent_lightning_attn', 'fused_recurrent_linear_att

In [9]:
from fla.ops.gated_delta_rule import chunk_gated_delta_rule
import torch
import torch.nn.functional as F

B, T, H, D = 1, 128, 8, 32  # batch, seq, heads, head_dim
V = D * 2  # value dim (expand_v=2)

q = torch.randn(B, T, H, D, device='cuda', dtype=torch.bfloat16)
k = F.normalize(torch.randn(B, T, H, D, device='cuda', dtype=torch.float32), p=2, dim=-1).to(torch.bfloat16)
v = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16)
beta = torch.rand(B, T, H, device='cuda', dtype=torch.bfloat16).sigmoid()
g = F.logsigmoid(torch.rand(B, T, H, device='cuda', dtype=torch.bfloat16))  # gate in log space

# WITH output_final_state=True
output, final_state = chunk_gated_delta_rule(q, k, v, g, beta, output_final_state=True)

print(f"Output: {output.shape}")
print(f"Final state: {final_state.shape if final_state is not None else None}")
print(f"State dtype: {final_state.dtype if final_state is not None else None}")

Output: torch.Size([1, 128, 8, 64])
Final state: torch.Size([1, 8, 32, 64])
State dtype: torch.float32


In [7]:
# Sanity check the model
import torch
import torch.nn.functional as F

# 1. Check embedding -> logit path
model.eval()
x = torch.randint(0, 1000, (1, 16), device='cuda')
with torch.no_grad():
    logits, _, _ = model(x)

print(f"Logits shape: {logits.shape}")
print(f"Logits mean: {logits.mean().item():.4f}, std: {logits.std().item():.4f}")
print(f"Expected init loss: {np.log(MODEL_CFG.vocab_size):.4f}")

# 2. Check if GatedDeltaNet is actually processing (not just passing through)
gdn_block = model.blocks[0]  # First GDN block
x_in = model.embed(x)
x_out, _ = gdn_block(x_in, None, False)
diff = (x_out - x_in).abs().mean().item()
print(f"GDN block change magnitude: {diff:.6f}")

# 3. Check SWA block
swa_idx = list(model._swa_indices)[0]
swa_block = model.blocks[swa_idx]
x_out_swa, _ = swa_block(x_in, None, False)
diff_swa = (x_out_swa - x_in).abs().mean().item()
print(f"SWA block change magnitude: {diff_swa:.6f}")

Logits shape: torch.Size([1, 16, 50257])
Logits mean: 0.0005, std: 0.3203
Expected init loss: 10.8249
GDN block change magnitude: 0.028198
SWA block change magnitude: 0.106445


In [32]:
# 1. Check if Triton kernels are being called
import os
os.environ['TRITON_PRINT_AUTOTUNING'] = '1'  # Will print when triton tunes

# Run a forward pass and watch for triton output
x = torch.randint(0, 1000, (1, 64), device='cuda')
with torch.no_grad():
    out = model(x)

# 2. Check triton is importable and what FLA is using
import triton
print(f"Triton version: {triton.__version__}")

# 3. Check what mode GatedDeltaNet is actually using
gdn_layer = model.blocks[0].layer  # The actual GatedDeltaNet
print(f"GDN mode: {gdn_layer.mode}")

# 4. Check the weight scales - this is likely the real problem
print("\nWeight statistics:")
print(f"Embedding std: {model.embed.weight.std().item():.4f}")
print(f"LM head std: {model.lm_head.weight.std().item():.4f}")

for i, block in enumerate(model.blocks[:3]):
    if hasattr(block.layer, 'q_proj'):
        print(f"Block {i} q_proj std: {block.layer.q_proj.weight.std().item():.4f}")
    if hasattr(block.layer, 'o_proj'):
        print(f"Block {i} o_proj std: {block.layer.o_proj.weight.std().item():.4f}")

Triton version: 3.6.0
GDN mode: chunk

Weight statistics:
Embedding std: 1.0000
LM head std: 1.0000
Block 0 q_proj std: 0.0361
Block 0 o_proj std: 0.0104
Block 1 q_proj std: 0.0361
Block 1 o_proj std: 0.0104
Block 2 q_proj std: 0.0361
Block 2 o_proj std: 0.0104


In [24]:
# Check loss calculation
x, y = get_batch()
with torch.amp.autocast('cuda', dtype=DTYPE):
    logits, loss, _ = model(x, y)

print(f"Logits shape: {logits.shape}")
print(f"Logits range: [{logits.min().item():.2f}, {logits.max().item():.2f}]")
print(f"Loss: {loss.item():.4f}")
print(f"Expected random loss: {np.log(MODEL_CFG.vocab_size):.4f}")

# Check if logits are sane
probs = F.softmax(logits[0, 0].float(), dim=-1)
print(f"Prob sum: {probs.sum().item():.4f}")
print(f"Max prob: {probs.max().item():.6f}")

Logits shape: torch.Size([2, 512, 50257])
Logits range: [-82.50, 274.00]
Loss: 235.8252
Expected random loss: 10.8249
Prob sum: 1.0000
Max prob: 1.000000


In [11]:
# CELL 3: MONITORING

def print_gradient_summary(model):
    agg = {'embed': [], 'gdn': [], 'swa': [], 'ffn': []}
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        n = p.grad.norm().item()
        if 'embed' in name:
            agg['embed'].append(n)
        elif 'ffn' in name:
            agg['ffn'].append(n)
        elif 'blocks' in name:
            idx = int(name.split('.')[1])
            agg['swa' if idx in model._swa_indices else 'gdn'].append(n)
    print("Gradients:")
    for k, v in agg.items():
        if v:
            print(f"  {k}: mean={np.mean(v):.3f} max={np.max(v):.2f}")


def needle_test(model, tokenizer, seq_len=512, n_trials=50, needle_token=50250, device="cuda"):
    model.eval()
    probs = []
    with torch.no_grad():
        for _ in range(n_trials):
            tokens = torch.randint(1000, 10000, (1, seq_len), device=device)
            pos = torch.randint(64, seq_len - 64, (1,)).item()
            tokens[0, pos] = needle_token
            with torch.amp.autocast('cuda', dtype=DTYPE):
                logits, _, _ = model(tokens)
            probs.append(F.softmax(logits[0, -1].float(), dim=-1)[needle_token].item())
    rc = 1.0 / tokenizer.vocab_size
    return {'mean': np.mean(probs), 'ratio': np.mean(probs) / rc}


def probe_layers(model, needle_id=50250, seq_len=512, pos=256, device="cuda"):
    model.eval()
    tokens = torch.randint(1000, 10000, (1, seq_len), device=device)
    tokens[0, pos] = needle_id
    with torch.no_grad():
        x = model.embed(tokens)
        emb = model.embed.weight[needle_id].float()
        print("Needle representation through layers:")
        for i, (b, f) in enumerate(zip(model.blocks, model.ffns)):
            x, _ = b(x, None, False)
            x = f(x)
            sim = F.cosine_similarity(x[0, pos].float(), emb, dim=0).item()
            ltype = 'SWA' if i in model._swa_indices else 'GDN'
            print(f"  L{i:2d}[{ltype}]: {sim:+.3f}")

print("Monitoring functions ready")

Monitoring functions ready


In [7]:
# CELL 4: DATA LOADING
from datasets import load_dataset
from tqdm import tqdm

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
MODEL_CFG.vocab_size = tokenizer.vocab_size

print(f"Streaming {TRAIN_CFG.dataset_name}...")
ds = load_dataset(TRAIN_CFG.dataset_name, name=TRAIN_CFG.dataset_subset, split="train", streaming=True)

buf = []
pbar = tqdm(total=TRAIN_CFG.target_tokens, unit="tok", desc="Tokenizing")
for ex in ds:
    toks = tokenizer.encode(ex['text']) + [tokenizer.eos_token_id]
    buf.extend(toks)
    pbar.update(len(toks))
    if len(buf) >= TRAIN_CFG.target_tokens:
        break
pbar.close()

all_tokens = torch.tensor(buf[:TRAIN_CFG.target_tokens], dtype=torch.long)
del buf, ds
print(f"Loaded {len(all_tokens):,} tokens")

def get_batch():
    ix = torch.randint(len(all_tokens) - TRAIN_CFG.seq_len - 1, (TRAIN_CFG.batch_size,))
    x = torch.stack([all_tokens[i:i+TRAIN_CFG.seq_len] for i in ix])
    y = torch.stack([all_tokens[i+1:i+TRAIN_CFG.seq_len+1] for i in ix])
    return x.to(DEVICE), y.to(DEVICE)

Streaming HuggingFaceFW/fineweb-edu...


Tokenizing:   0%|          | 846/10000000 [00:04<14:49:03, 187.45tok/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1055 > 1024). Running this sequence through the model will result in indexing errors
Tokenizing: 10000241tok [00:24, 400932.79tok/s]                            


Loaded 10,000,000 tokens


In [16]:
# CELL 5: BUILD MODEL
print("Building model...")
model = GroundThinkLM(MODEL_CFG).to(DEVICE).to(DTYPE)

p = model.count_parameters()
print(f"Parameters: {p['total']/1e6:.2f}M")
print(f"  GDN: {p['gdn']/1e6:.2f}M, SWA: {p['swa']/1e6:.2f}M, FFN: {p['ffn']/1e6:.2f}M")
print(f"Layers: {model.get_layer_types()}")

# Test forward/backward
print("\nTesting forward/backward...")
x, y = get_batch()
with torch.amp.autocast('cuda', dtype=DTYPE):
    _, loss, _ = model(x, y)
loss.backward()
print(f"Forward OK: loss={loss.item():.4f}")
print(f"Peak memory: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
model.zero_grad()
torch.cuda.reset_peak_memory_stats()

Building model...
Parameters: 48.71M
  GDN: 28.57M, SWA: 0.79M, FFN: 6.49M
Layers: ['GDN', 'GDN', 'GDN', 'SWA', 'GDN', 'GDN', 'GDN', 'SWA', 'GDN', 'GDN', 'GDN', 'SWA']

Testing forward/backward...
Forward OK: loss=235.1660
Peak memory: 2.45GB


In [18]:
# Check loss calculation
x, y = get_batch()
with torch.amp.autocast('cuda', dtype=DTYPE):
    logits, loss, _ = model(x, y)

print(f"Logits shape: {logits.shape}")
print(f"Logits range: [{logits.min().item():.2f}, {logits.max().item():.2f}]")
print(f"Loss: {loss.item():.4f}")
print(f"Expected random loss: {np.log(MODEL_CFG.vocab_size):.4f}")

# Check if logits are sane
probs = F.softmax(logits[0, 0].float(), dim=-1)
print(f"Prob sum: {probs.sum().item():.4f}")
print(f"Max prob: {probs.max().item():.6f}")

Logits shape: torch.Size([2, 512, 50257])
Logits range: [-86.50, 272.00]
Loss: 235.0752
Expected random loss: 10.8249
Prob sum: 1.0000
Max prob: 1.000000


In [14]:
# CELL 6: TRAINING LOOP
opt = torch.optim.AdamW(model.parameters(), lr=TRAIN_CFG.lr, betas=TRAIN_CFG.betas, weight_decay=TRAIN_CFG.weight_decay)
losses, niah_traj = [], []
start = time.time()

print(f"\nTRAINING {TRAIN_CFG.steps} steps (effective batch={TRAIN_CFG.effective_batch_size})\n")
model.train()

for step in range(TRAIN_CFG.steps):
    # LR schedule: warmup then cosine decay
    if step < TRAIN_CFG.warmup_steps:
        lr = TRAIN_CFG.lr * (step + 1) / TRAIN_CFG.warmup_steps
    else:
        progress = (step - TRAIN_CFG.warmup_steps) / (TRAIN_CFG.steps - TRAIN_CFG.warmup_steps)
        lr = TRAIN_CFG.lr * 0.5 * (1 + math.cos(math.pi * progress))
    for pg in opt.param_groups:
        pg['lr'] = lr
    
    # Gradient accumulation
    acc_loss = 0
    for _ in range(TRAIN_CFG.accum_steps):
        x, y = get_batch()
        with torch.amp.autocast('cuda', dtype=DTYPE):
            _, loss, _ = model(x, y)
        (loss / TRAIN_CFG.accum_steps).backward()
        acc_loss += loss.item()
    
    torch.nn.utils.clip_grad_norm_(model.parameters(), TRAIN_CFG.grad_clip)
    opt.step()
    opt.zero_grad()
    losses.append(acc_loss / TRAIN_CFG.accum_steps)
    
    # Logging
    if step % TRAIN_CFG.log_interval == 0:
        avg = np.mean(losses[-50:]) if len(losses) >= 50 else np.mean(losses)
        elapsed = time.time() - start
        tps = (step + 1) * TRAIN_CFG.effective_batch_size * TRAIN_CFG.seq_len / elapsed
        print(f"[{step:5d}/{TRAIN_CFG.steps}] loss={avg:.4f} lr={lr:.2e} {tps:,.0f} tok/s")
    
    if (step + 1) % TRAIN_CFG.grad_log_interval == 0:
        print_gradient_summary(model)
    
    # NIAH checkpoint
    if (step + 1) in TRAIN_CFG.niah_checkpoints:
        n = needle_test(model, tokenizer, TRAIN_CFG.seq_len, 30, device=DEVICE)
        niah_traj.append((step + 1, n['ratio']))
        status = "PASS" if n['ratio'] > 1.0 else "FAIL"
        print(f"  >>> NIAH@{step+1}: {n['ratio']:.2f}x random [{status}]")
        model.train()

# Summary
elapsed = time.time() - start
print(f"\n{'='*60}")
print(f"Training complete in {elapsed/60:.1f} minutes")
print(f"Loss: {np.mean(losses[:50]):.4f} -> {np.mean(losses[-50:]):.4f}")
print(f"NIAH trajectory: {niah_traj}")


TRAINING 5000 steps (effective batch=4)

[    0/5000] loss=233.8555 lr=6.00e-07 2,756 tok/s
[   50/5000] loss=232.5334 lr=3.06e-05 8,188 tok/s
[  100/5000] loss=210.8668 lr=6.06e-05 8,748 tok/s
[  150/5000] loss=64.4376 lr=9.06e-05 8,521 tok/s
[  200/5000] loss=35.3392 lr=1.21e-04 8,606 tok/s
[  250/5000] loss=32.2790 lr=1.51e-04 8,813 tok/s
[  300/5000] loss=31.0059 lr=1.81e-04 8,979 tok/s
[  350/5000] loss=30.5083 lr=2.11e-04 8,990 tok/s
[  400/5000] loss=30.2468 lr=2.41e-04 9,079 tok/s
[  450/5000] loss=29.6533 lr=2.71e-04 9,111 tok/s
Gradients:
  >>> NIAH@500: 0.00x random [FAIL]
[  500/5000] loss=29.0275 lr=3.00e-04 7,580 tok/s
[  550/5000] loss=28.3613 lr=3.00e-04 7,618 tok/s
[  600/5000] loss=27.9643 lr=3.00e-04 7,711 tok/s
[  650/5000] loss=27.6132 lr=2.99e-04 7,801 tok/s
[  700/5000] loss=27.3086 lr=2.99e-04 7,847 tok/s
[  750/5000] loss=26.9528 lr=2.98e-04 7,927 tok/s
[  800/5000] loss=26.8805 lr=2.97e-04 7,958 tok/s
[  850/5000] loss=26.6274 lr=2.96e-04 8,007 tok/s
[  900/5

In [None]:
# CELL 7: FINAL EVALUATION
print("\n" + "="*60)
print("FINAL EVALUATION")
print("="*60)

# NIAH at multiple lengths
for L in [128, 256, 512]:
    n = needle_test(model, tokenizer, L, 50, device=DEVICE)
    status = "PASS" if n['ratio'] > 1.0 else "FAIL"
    print(f"NIAH@{L}: {n['ratio']:.2f}x random [{status}]")

# Layer probing
print()
probe_layers(model, device=DEVICE)

# Verdict
lm_pass = np.mean(losses[:50]) - np.mean(losses[-50:]) > 2.0
niah_pass = any(r > 1.0 for _, r in niah_traj)

print(f"\nVerdict:")
print(f"  LM Training: {'PASS' if lm_pass else 'MARGINAL'}")
print(f"  NIAH: {'PASS' if niah_pass else 'FAIL'}")

In [None]:
# CELL 8: SAVE CHECKPOINT
import os
from pathlib import Path

save_dir = Path("./checkpoints")
save_dir.mkdir(exist_ok=True)

ckpt_path = save_dir / "groundthink_v6_final.pt"
torch.save({
    'model_state_dict': model.state_dict(),
    'config': MODEL_CFG,
    'losses': losses,
    'niah_trajectory': niah_traj,
}, ckpt_path)

print(f"Saved checkpoint: {ckpt_path}")
print(f"Size: {ckpt_path.stat().st_size / 1e6:.1f} MB")