# GroundThink V6 - Hybrid GatedDeltaNet + SWAttention

**Architecture:** Uses the FLA library's `GatedDeltaNet` and `SlidingWindowAttention` layers directly.

**Key features:**
- FLA kernels (NVIDIA recommended)
- Ready for the vLLM `IsHybrid` protocol (state management, introspection methods)
- Configurable layer pattern via `attn_interval`
- Gradient checkpointing on SWA layers
- Needle-in-a-haystack (NIAH) testing and gradient monitoring carried over from V5

In [1]:
# CELL 0: ENVIRONMENT SETUP
# Mount Google Drive at /content/drive (the export cell writes below this path).
# `google.colab` is only importable inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# IPython shell escapes (not plain Python): install the packages later cells import.
# NOTE(review): presumably `flash-linear-attention` is what provides the `fla` module - confirm.
!pip install -q flash-linear-attention
!pip install -q datasets transformers

# Report the interpreter version for the run log.
import sys
print(f"Python: {sys.version}")

ValueError: mount failed

In [3]:
# CELL 1: CONFIGURATION
from dataclasses import dataclass, field
from typing import List, Optional
import torch

@dataclass
class ModelConfig:
    """Architecture hyperparameters for the GroundThink V6 hybrid model.

    ``head_dim`` is derived: ``__post_init__`` always overwrites it with
    ``d_model // n_heads``, so a value passed to the constructor (or the
    declared default) is not kept.

    Raises:
        ValueError: if ``n_heads`` or ``attn_interval`` is not positive, or if
            ``d_model`` is not divisible by ``n_heads``.
    """
    vocab_size: int = 50257
    d_model: int = 256
    n_layers: int = 12
    n_heads: int = 8
    head_dim: int = 64  # placeholder; recomputed in __post_init__
    attn_interval: int = 4  # SWA every N layers (4 = 3:1 ratio)
    window_size: int = 2048
    expand_k: float = 1.0
    expand_v: float = 2.0
    max_seq_len: int = 2048
    use_gradient_checkpointing: bool = True
    tie_weights: bool = True

    def __post_init__(self):
        # Explicit raises instead of `assert`, which is stripped under `python -O`.
        if self.n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {self.n_heads}")
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})")
        # attn_interval is used as a modulus below; 0 would only fail later
        # with an unhelpful ZeroDivisionError.
        if self.attn_interval <= 0:
            raise ValueError(f"attn_interval must be positive, got {self.attn_interval}")
        self.head_dim = self.d_model // self.n_heads

    def get_swa_layer_indices(self) -> List[int]:
        """Return the layer indices that use sliding-window attention.

        These are the last layer of every group of ``attn_interval`` layers,
        e.g. ``[3, 7, 11]`` for 12 layers and an interval of 4.
        """
        last_in_group = self.attn_interval - 1
        return [i for i in range(self.n_layers)
                if i % self.attn_interval == last_in_group]

@dataclass
class TrainConfig:
    """Data, optimisation and logging settings for one training run."""
    # Data source and how many tokens to pull from it.
    dataset_name: str = "HuggingFaceFW/fineweb-edu"
    dataset_subset: str = "sample-10BT"
    target_tokens: int = 20_000_000
    # Micro-batch shape and number of micro-batches per optimizer step.
    batch_size: int = 2
    seq_len: int = 512
    accum_steps: int = 2
    # Schedule length and optimizer hyperparameters.
    steps: int = 10000
    warmup_ratio: float = 0.1
    lr: float = 3e-4
    weight_decay: float = 0.1
    grad_clip: float = 1.0
    betas: tuple = (0.9, 0.95)
    # Logging cadence (in steps) and the steps at which the NIAH probe runs.
    log_interval: int = 50
    grad_log_interval: int = 500
    niah_checkpoints: List[int] = field(
        default_factory=lambda: [500, 1000, 2000, 3000, 5000, 7500, 10000]
    )

    @property
    def warmup_steps(self) -> int:
        """Number of warmup steps: ``steps * warmup_ratio`` truncated to an int."""
        scaled = self.steps * self.warmup_ratio
        return int(scaled)

    @property
    def effective_batch_size(self) -> int:
        """Sequences consumed per optimizer step (micro-batch size x accumulation)."""
        per_micro_batch = self.batch_size
        return per_micro_batch * self.accum_steps

# Default configuration objects shared by all later cells.
MODEL_CFG = ModelConfig()
TRAIN_CFG = TrainConfig()

# Echo the settings that matter most for this run as a single banner.
print("\n".join([
    "=" * 60,
    "GROUNDTHINK V6 CONFIG",
    "=" * 60,
    f"d_model:       {MODEL_CFG.d_model}",
    f"n_layers:      {MODEL_CFG.n_layers}",
    f"attn_interval: {MODEL_CFG.attn_interval} -> SWA at {MODEL_CFG.get_swa_layer_indices()}",
    f"window_size:   {MODEL_CFG.window_size}",
    f"batch:         {TRAIN_CFG.effective_batch_size}",
    f"steps:         {TRAIN_CFG.steps}",
]))

GROUNDTHINK V6 CONFIG
d_model:       256
n_layers:      12
attn_interval: 4 -> SWA at [3, 7, 11]
window_size:   2048
batch:         4
steps:         10000


In [4]:
# CELL 2: IMPORTS & DEVICE
import math
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from transformers import AutoTokenizer

# Pick the compute device once; on a GPU also report it and enable TF32 math.
if torch.cuda.is_available():
    DEVICE = "cuda"
    props = torch.cuda.get_device_properties(0)
    print(f"GPU: {props.name} ({props.total_memory/1e9:.1f}GB)")
    # Allow TF32 for both cuDNN ops and CUDA matmuls.
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
else:
    DEVICE = "cpu"

# FLA layer classes wrapped by HybridBlock below; this cell fails if `fla` is not installed.
# NOTE(review): confirm the installed `fla.layers` exports `SlidingWindowAttention`.
from fla.layers import GatedDeltaNet
from fla.layers import SlidingWindowAttention as FLA_SWA
print("FLA loaded")

ModuleNotFoundError: No module named 'fla'

In [None]:
# CELL 3: MODEL COMPONENTS

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    Args:
        dim: size of the last dimension (length of the gain vector).
        eps: constant added to the mean square before taking the reciprocal root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Statistics are computed in float32, then cast back to the input dtype
        # before the gain is applied.
        x32 = x.float()
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normed = (x32 * inv_rms).type_as(x)
        return normed * self.weight


class SwiGLUFFN(nn.Module):
    """Pre-norm SwiGLU feed-forward sub-layer that adds its output to its input.

    Args:
        d_model: input/output width.
        expansion: hidden width as a multiple of ``d_model`` before rounding.
    """

    def __init__(self, d_model: int, expansion: float = 8/3):
        super().__init__()
        # Round the hidden width up to the next multiple of 64.
        unrounded = int(d_model * expansion)
        hidden = -(-unrounded // 64) * 64

        self.w1 = nn.Linear(d_model, hidden, bias=False)
        self.w3 = nn.Linear(d_model, hidden, bias=False)
        self.w2 = nn.Linear(hidden, d_model, bias=False)
        self.norm = RMSNorm(d_model)

        # Same init order as the layers were created; the output projection
        # uses a smaller std (0.02 / sqrt(2)).
        for proj, std in ((self.w1, 0.02),
                          (self.w3, 0.02),
                          (self.w2, 0.02 / math.sqrt(2))):
            nn.init.normal_(proj.weight, std=std)

    def forward(self, x):
        normed = self.norm(x)
        gate = F.silu(self.w1(normed))
        up = self.w3(normed)
        return x + self.w2(gate * up)


class HybridBlock(nn.Module):
    """Pre-norm residual block around either an FLA GatedDeltaNet or an FLA SWA layer.

    Args:
        d_model: hidden width passed to the wrapped layer as ``hidden_size``.
        is_attention: build the sliding-window attention layer when True,
            otherwise the GatedDeltaNet layer.
        n_heads, window_size: forwarded to the attention layer only.
        expand_k, expand_v: forwarded to the GatedDeltaNet layer only.
        layer_idx: forwarded to the wrapped layer.

    NOTE(review): the keyword names passed to the FLA constructors
    (``expand_k``, ``window_size``, ...) should be checked against the
    installed FLA version; they cannot be verified from this file.
    """
    def __init__(self, d_model: int, is_attention: bool, n_heads: int = 8,
                 window_size: int = 2048, expand_k: float = 1.0,
                 expand_v: float = 2.0, layer_idx: int = 0):
        super().__init__()
        self.is_attention = is_attention
        self.layer_idx = layer_idx
        self.norm = RMSNorm(d_model)

        if is_attention:
            self.layer = FLA_SWA(
                hidden_size=d_model, num_heads=n_heads,
                window_size=window_size, layer_idx=layer_idx)
        else:
            self.layer = GatedDeltaNet(
                hidden_size=d_model, expand_k=expand_k,
                expand_v=expand_v, layer_idx=layer_idx)

    @staticmethod
    def _unpack_layer_output(out):
        """Split a wrapped layer's return value into ``(hidden_states, state)``.

        FLA layers return a tuple whose first element is the hidden states and
        whose last element is the cache (``(hidden, attentions, cache)``); a
        bare tensor and a 2-tuple ``(hidden, state)`` are accepted as well.
        """
        if isinstance(out, (tuple, list)):
            if not out:
                raise ValueError("wrapped layer returned an empty sequence")
            state = out[-1] if len(out) > 1 else None
            return out[0], state
        return out, None

    def forward(self, x, attention_mask=None, past_state=None, use_cache=False):
        """Apply norm -> layer -> residual add.

        Args:
            x: hidden states.
            attention_mask: forwarded to the attention layer only; the
                GatedDeltaNet path never receives it.
            past_state: previous cache/state for this layer, if any.
            use_cache: request an updated state from the wrapped layer.

        Returns:
            ``(hidden_states, new_state)``; ``new_state`` is ``None`` unless
            ``use_cache`` is True.
        """
        residual = x
        x = self.norm(x)

        if use_cache:
            # FLA layers take the cache as `past_key_values`. An unknown
            # `past_state=` keyword would not reach the layer's cache handling.
            # NOTE(review): FLA caches are typically indexed by layer_idx;
            # confirm that a per-layer `past_state` is what the layer expects.
            if self.is_attention:
                out = self.layer(x, attention_mask=attention_mask,
                                 past_key_values=past_state, use_cache=True)
            else:
                out = self.layer(x, past_key_values=past_state, use_cache=True)
        elif self.is_attention and attention_mask is not None:
            out = self.layer(x, attention_mask=attention_mask)
        else:
            out = self.layer(x)

        # The layer returns a tuple, not a tensor; take the hidden states for
        # the residual add and keep the cache only when it was requested.
        x, new_state = self._unpack_layer_output(out)
        if not use_cache:
            new_state = None

        return residual + x, new_state


class GroundThinkLM(nn.Module):
    """Hybrid language model: a stack of (HybridBlock, SwiGLUFFN) pairs.

    Layers listed by ``cfg.get_swa_layer_indices()`` use sliding-window
    attention; all others use GatedDeltaNet. The output projection shares the
    embedding matrix when ``cfg.tie_weights`` is set.
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        nn.init.normal_(self.embed.weight, std=0.02)

        swa_indices = set(cfg.get_swa_layer_indices())
        self._swa_indices = swa_indices
        self._gdn_indices = set(range(cfg.n_layers)) - swa_indices

        self.blocks = nn.ModuleList()
        self.ffns = nn.ModuleList()

        for i in range(cfg.n_layers):
            is_attn = i in swa_indices
            self.blocks.append(HybridBlock(
                d_model=cfg.d_model, is_attention=is_attn,
                n_heads=cfg.n_heads, window_size=cfg.window_size,
                expand_k=cfg.expand_k, expand_v=cfg.expand_v, layer_idx=i))
            self.ffns.append(SwiGLUFFN(cfg.d_model))

        self.norm_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        if cfg.tie_weights:
            self.lm_head.weight = self.embed.weight

    def forward(self, input_ids, targets=None, attention_mask=None,
                past_states=None, use_cache=False):
        """Run the model.

        Args:
            input_ids: token ids fed to the embedding.
            targets: optional token ids; when given, a cross-entropy loss over
                the flattened logits is returned.
            attention_mask: forwarded to every block.
            past_states: optional per-layer states, indexed by layer number.
            use_cache: collect one state per layer in ``new_states``.

        Returns:
            ``(logits, loss, new_states)``; ``loss`` is ``None`` without
            targets and ``new_states`` is ``None`` unless ``use_cache``.
        """
        x = self.embed(input_ids)
        new_states = [] if use_cache else None

        # The checkpointed path discards the block's state, so it is only used
        # when no cache was requested. Otherwise `new_states` would silently
        # miss the SWA layers and no longer line up with layer indices.
        checkpoint_swa = (self.cfg.use_gradient_checkpointing
                          and self.training and not use_cache)

        for i, (block, ffn) in enumerate(zip(self.blocks, self.ffns)):
            past_state = past_states[i] if past_states else None

            if checkpoint_swa and i in self._swa_indices:
                x = checkpoint(self._forward_block_train, block, ffn, x, attention_mask,
                               use_reentrant=False)
                continue

            x, state = block(x, attention_mask, past_state, use_cache)
            x = ffn(x)
            if use_cache:
                new_states.append(state)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss, new_states

    @staticmethod
    def _forward_block_train(block, ffn, x, attention_mask):
        """Block + FFN without cache; the unit wrapped by gradient checkpointing."""
        x, _ = block(x, attention_mask, None, False)
        x = ffn(x)
        return x

    # vLLM IsHybrid protocol
    # NOTE(review): the dtypes and shapes reported below are fixed formulas in
    # this file; whether they match the real FLA cache layout is not verifiable
    # from here.
    @classmethod
    def get_state_dtype_from_config(cls, config):
        """Return the state dtype per layer family (always bfloat16 here)."""
        return {'gdn': torch.bfloat16, 'swa': torch.bfloat16}

    @classmethod
    def get_state_shape_from_config(cls, config, batch_size=1):
        """Return the declared state shapes for GDN and SWA (key/value) layers."""
        head_dim = config.d_model // config.n_heads
        gdn_state_dim = int(config.d_model * config.expand_v)
        return {
            'gdn': (batch_size, gdn_state_dim),
            'swa_k': (batch_size, config.n_heads, config.window_size, head_dim),
            'swa_v': (batch_size, config.n_heads, config.window_size, head_dim),
        }

    def get_layer_types(self):
        """Return ``'swa'`` or ``'gdn'`` for each layer, in layer order."""
        return ['swa' if i in self._swa_indices else 'gdn' for i in range(self.cfg.n_layers)]

    def count_parameters(self):
        """Count parameters per component.

        Returns a dict with ``embed``, ``gdn``, ``swa``, ``ffn``, ``other`` and
        ``total``. ``total`` counts every unique parameter of the module, so it
        includes the final norm and an untied output head (reported under
        ``other``) and counts a tied embedding/head matrix once.
        """
        counts = {'embed': sum(p.numel() for p in self.embed.parameters()),
                  'gdn': 0, 'swa': 0, 'ffn': 0}
        for i, (block, ffn) in enumerate(zip(self.blocks, self.ffns)):
            bp = sum(p.numel() for p in block.parameters())
            fp = sum(p.numel() for p in ffn.parameters())
            if i in self._swa_indices:
                counts['swa'] += bp
            else:
                counts['gdn'] += bp
            counts['ffn'] += fp

        categorized = sum(counts.values())
        # parameters() de-duplicates shared tensors, so tied weights count once.
        total = sum(p.numel() for p in self.parameters())
        counts['other'] = total - categorized
        counts['total'] = total
        return counts

# Cell-completion marker for the notebook log.
print("Model components defined")

In [None]:
# CELL 4: MONITORING

def print_gradient_summary(model):
    """Print mean/max gradient norms grouped by component.

    Parameters are bucketed by substring of their name, checked in this order:
    ``embed``, ``norm``, ``ffn``, then ``blocks`` (split into ``swa``/``gdn``
    using ``model._swa_indices`` and the layer index in the name). Parameters
    without a gradient, or matching no bucket, are skipped. A bucket whose
    largest norm exceeds 5.0 is flagged with ``WARNING``.

    Must be called while gradients are still populated (i.e. before they are
    cleared for the next step).
    """
    agg = {'embed': [], 'gdn': [], 'swa': [], 'ffn': [], 'norm': []}
    swa_idx = model._swa_indices

    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        norm = param.grad.norm().item()
        if 'embed' in name:
            agg['embed'].append(norm)
        elif 'norm' in name:
            agg['norm'].append(norm)
        elif 'ffn' in name:
            agg['ffn'].append(norm)
        elif 'blocks' in name:
            # Expected form: "blocks.<layer index>.<...>". Only the two errors
            # this parse can raise are tolerated; anything else propagates.
            try:
                idx = int(name.split('.')[1])
            except (IndexError, ValueError):
                continue
            bucket = 'swa' if idx in swa_idx else 'gdn'
            agg[bucket].append(norm)

    print("\nGradient Norms:")
    for comp, vals in agg.items():
        if vals:
            m, mx = np.mean(vals), np.max(vals)
            flag = " WARNING" if mx > 5.0 else ""
            print(f"  {comp:<6} mean={m:6.3f} max={mx:6.2f}{flag}")


def needle_test(model, tokenizer, seq_len=512, n_trials=50, needle_token=50250, device="cuda"):
    """Needle-in-a-haystack probe.

    For each trial a random sequence (token ids in [1000, 10000)) gets
    ``needle_token`` written at a random interior position; the probability
    the model assigns to ``needle_token`` at the last position is recorded.

    Args:
        model: callable returning ``(logits, loss, states)``; put into eval
            mode here and left in eval mode on return.
        tokenizer: only ``vocab_size`` is read (for the chance baseline).
        seq_len: length of each random sequence.
        n_trials: number of sequences to score.
        needle_token: id of the needle token.
        device: device for the random tokens.

    Returns:
        dict with ``mean`` and ``std`` of the needle probability,
        ``random_chance`` (``1 / vocab_size``) and ``ratio`` (mean / chance).
    """
    model.eval()
    random_chance = 1.0 / tokenizer.vocab_size
    probs = []

    # Keep the needle away from both ends. The fixed 64-token margin leaves an
    # empty range for seq_len <= 128 (torch.randint then raises), so short
    # sequences fall back to a quarter-length margin.
    margin = 64 if seq_len > 128 else seq_len // 4

    # Autocast on the device actually in use; only enabled for CUDA, which is
    # the only case in which the hard-coded 'cuda' autocast was ever active.
    device_type = torch.device(device).type
    use_autocast = device_type == "cuda"

    with torch.no_grad():
        for _ in range(n_trials):
            tokens = torch.randint(1000, 10000, (1, seq_len), device=device)
            pos = torch.randint(margin, seq_len - margin, (1,)).item()
            tokens[0, pos] = needle_token

            with torch.amp.autocast(device_type, dtype=torch.bfloat16, enabled=use_autocast):
                logits, _, _ = model(tokens)

            p = F.softmax(logits[0, -1].float(), dim=-1)[needle_token].item()
            probs.append(p)

    mean_prob = np.mean(probs)
    return {
        'mean': mean_prob, 'std': np.std(probs),
        'random_chance': random_chance, 'ratio': mean_prob / random_chance
    }


def probe_layer_representations(model, tokenizer, needle_id=50250, seq_len=512,
                                 needle_pos=256, device="cuda"):
    """Print how similar the needle position stays to the needle embedding per layer.

    A random sequence (ids in [1000, 10000)) with ``needle_id`` at
    ``needle_pos`` is pushed through each (block, ffn) pair; after every pair
    the cosine similarity between the hidden state at ``needle_pos`` and the
    needle's input embedding is printed with a bar. ``tokenizer`` is accepted
    but not used. Leaves the model in eval mode and returns nothing.
    """
    model.eval()
    tokens = torch.randint(1000, 10000, (1, seq_len), device=device)
    tokens[0, needle_pos] = needle_id

    with torch.no_grad():
        hidden = model.embed(tokens)
        needle_embed = model.embed.weight[needle_id].float()

        print(f"\nNeedle ({needle_id}) representation through layers:")
        swa_layers = model._swa_indices

        for i, (block, ffn) in enumerate(zip(model.blocks, model.ffns)):
            # No mask, no past state, no cache: a plain forward through the pair.
            hidden, _unused_state = block(hidden, None, None, False)
            hidden = ffn(hidden)

            at_needle = hidden[0, needle_pos].float()
            sim = F.cosine_similarity(at_needle, needle_embed, dim=0).item()

            if i in swa_layers:
                layer_type = "SWA"
            else:
                layer_type = "GDN"

            # Map similarity in [-1, 1] onto 0..20 bar characters.
            bar_len = int(max(0, (sim + 1) * 10))
            bar = "#" * bar_len
            print(f"  L{i:2d} [{layer_type}]: {sim:+.3f} {bar}")

# Cell-completion marker for the notebook log.
print("Monitoring functions defined")

In [None]:
# CELL 5: DATA LOADING
from datasets import load_dataset
from tqdm import tqdm

# Tokenizer: GPT-2, with EOS reused as the padding token. The model's vocab size
# is synced to the tokenizer before the model is built.
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
MODEL_CFG.vocab_size = tokenizer.vocab_size

# Stream the dataset and tokenize documents until the token budget is reached.
print(f"Streaming: {TRAIN_CFG.dataset_name}/{TRAIN_CFG.dataset_subset}")
dataset = load_dataset(TRAIN_CFG.dataset_name, name=TRAIN_CFG.dataset_subset,
                       split="train", streaming=True)

token_buffer = []
with tqdm(total=TRAIN_CFG.target_tokens, desc="Tokenizing", unit="tok") as pbar:
    for example in dataset:
        # Each document is followed by an EOS token.
        tokens = tokenizer.encode(example['text']) + [tokenizer.eos_token_id]
        token_buffer.extend(tokens)
        pbar.update(len(tokens))
        if len(token_buffer) >= TRAIN_CFG.target_tokens:
            break

# Keep exactly the budgeted number of tokens as one flat tensor, then free the
# Python-side buffer and the dataset handle.
all_tokens = torch.tensor(token_buffer[:TRAIN_CFG.target_tokens], dtype=torch.long)
del token_buffer, dataset
import gc
gc.collect()

print(f"Loaded {len(all_tokens):,} tokens")

def get_batch(batch_size=TRAIN_CFG.batch_size, seq_len=TRAIN_CFG.seq_len):
    """Sample a random training batch from ``all_tokens``.

    Returns ``(x, y)`` on ``DEVICE``: ``x`` holds ``batch_size`` windows of
    ``seq_len`` tokens and ``y`` the same windows shifted one token forward.
    The defaults are read from ``TRAIN_CFG`` once, when the function is defined.
    """
    starts = torch.randint(len(all_tokens) - seq_len - 1, (batch_size,)).tolist()
    inputs = torch.stack([all_tokens[s:s + seq_len] for s in starts])
    targets = torch.stack([all_tokens[s + 1:s + seq_len + 1] for s in starts])
    return inputs.to(DEVICE), targets.to(DEVICE)

In [None]:
# CELL 6: BUILD MODEL
# Build the model, move it to the device, then cast every parameter to bfloat16.
print("Building model...")
model = GroundThinkLM(MODEL_CFG)
model = model.to(DEVICE)
model = model.to(torch.bfloat16)

# Parameter breakdown and layer layout.
params = model.count_parameters()
print("\n".join([
    f"Parameters: {params['total']/1e6:.2f}M",
    f"  Embed: {params['embed']/1e6:.2f}M",
    f"  GDN:   {params['gdn']/1e6:.2f}M",
    f"  SWA:   {params['swa']/1e6:.2f}M",
    f"  FFN:   {params['ffn']/1e6:.2f}M",
    f"SWA at layers: {sorted(model._swa_indices)}",
    f"Layer types: {model.get_layer_types()}",
]))

# Smoke test: one forward/backward pass on a real batch, then drop the gradients
# so training starts clean.
print("\nTesting forward/backward...")
x, y = get_batch()
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
    logits, loss, _ = model(x, y)
loss.backward()
print(f"Forward: loss={loss.item():.4f}")
print(f"Backward: OK")
print(f"Peak memory: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
model.zero_grad()

In [None]:
# CELL 7: TRAINING
# AdamW over all parameters; the learning rate is overwritten every step by the
# schedule inside the loop.
optimizer = torch.optim.AdamW(model.parameters(), lr=TRAIN_CFG.lr,
                               betas=TRAIN_CFG.betas, weight_decay=TRAIN_CFG.weight_decay)

losses, grad_norms, niah_trajectory = [], [], []
random_chance = 1.0 / tokenizer.vocab_size
start = time.time()

print("\n" + "=" * 60)
print(f"TRAINING ({TRAIN_CFG.steps} steps)")
print(f"  Effective batch: {TRAIN_CFG.effective_batch_size}")
print(f"  Tokens/step: {TRAIN_CFG.effective_batch_size * TRAIN_CFG.seq_len:,}")
print("=" * 60 + "\n")

model.train()
optimizer.zero_grad()

for step in range(TRAIN_CFG.steps):
    # LR schedule: linear warmup, then cosine decay to zero.
    if step < TRAIN_CFG.warmup_steps:
        lr = TRAIN_CFG.lr * (step + 1) / TRAIN_CFG.warmup_steps
    else:
        progress = (step - TRAIN_CFG.warmup_steps) / (TRAIN_CFG.steps - TRAIN_CFG.warmup_steps)
        lr = TRAIN_CFG.lr * 0.5 * (1 + math.cos(math.pi * progress))

    for pg in optimizer.param_groups:
        pg['lr'] = lr

    # Accumulation: each micro-batch loss is scaled so the summed gradient is
    # the mean over the accumulated micro-batches.
    accum_loss = 0.0
    for _ in range(TRAIN_CFG.accum_steps):
        x, y = get_batch()
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            _, loss, _ = model(x, y)
            loss_scaled = loss / TRAIN_CFG.accum_steps
        loss_scaled.backward()
        accum_loss += loss.item()

    # The per-component gradient report has to run while gradients still
    # exist and before clipping rescales them. After optimizer.zero_grad()
    # every .grad is cleared and the summary would print no components.
    if (step + 1) % TRAIN_CFG.grad_log_interval == 0:
        print_gradient_summary(model)

    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), TRAIN_CFG.grad_clip)
    optimizer.step()
    optimizer.zero_grad()

    avg_loss = accum_loss / TRAIN_CFG.accum_steps
    losses.append(avg_loss)
    grad_norms.append(total_norm.item())

    if step % TRAIN_CFG.log_interval == 0:
        # Mean of the last (up to) 50 recorded losses.
        avg = np.mean(losses[-50:])
        elapsed = time.time() - start
        tps = (step + 1) * TRAIN_CFG.effective_batch_size * TRAIN_CFG.seq_len / elapsed
        print(f"[{step:5d}/{TRAIN_CFG.steps}] loss={avg:.4f} | grad={total_norm.item():.3f} | lr={lr:.2e} | {tps:,.0f} tok/s")

    if (step + 1) in TRAIN_CFG.niah_checkpoints:
        niah = needle_test(model, tokenizer, TRAIN_CFG.seq_len, n_trials=30, device=DEVICE)
        ratio = niah['ratio']
        niah_trajectory.append((step + 1, ratio))
        status = "PASS" if ratio > 1.0 else "MARGINAL" if ratio > 0.5 else "FAIL"
        print(f"  >>> NIAH@{step+1}: {ratio:.2f}x random [{status}]")
        # needle_test switches the model to eval mode; switch back.
        model.train()

# Summary
elapsed = time.time() - start
print("\n" + "=" * 60)
print("TRAINING COMPLETE")
print("=" * 60)
print(f"Time: {elapsed/60:.1f} min")
print(f"Speed: {TRAIN_CFG.steps * TRAIN_CFG.effective_batch_size * TRAIN_CFG.seq_len / elapsed:,.0f} tok/s")
print(f"Initial loss: {np.mean(losses[:50]):.4f}")
print(f"Final loss: {np.mean(losses[-50:]):.4f}")
print(f"\nNIAH Trajectory:")
for step, ratio in niah_trajectory:
    bar = "#" * int(min(ratio * 5, 20))
    print(f"  {step:>6}: {ratio:.2f}x {bar}")

In [None]:
# CELL 8: FINAL EVALUATION
print("\n" + "=" * 60, "FINAL EVALUATION", "=" * 60, sep="\n")

# Needle probe at several sequence lengths.
print("\nNeedle-in-a-Haystack:")
for length in (128, 256, 512, 1024):
    niah = needle_test(model, tokenizer, length, n_trials=50, device=DEVICE)
    if niah['ratio'] > 1.0:
        status = "PASS"
    elif niah['ratio'] > 0.5:
        status = "MARGINAL"
    else:
        status = "FAIL"
    print(f"  NIAH@{length}: {niah['ratio']:.2f}x random [{status}] (P={niah['mean']:.2e})")

# Layer-by-layer similarity trace for the needle token.
probe_layer_representations(model, tokenizer, device=DEVICE)

print("\n" + "=" * 60, "VERDICT", "=" * 60, sep="\n")

# Loss drop compares the mean of the first 50 recorded losses with the last 50;
# the NIAH verdict passes if any checkpoint during training beat random chance.
loss_drop = np.mean(losses[:50]) - np.mean(losses[-50:])
niah_pass = bool([r for _, r in niah_trajectory if r > 1.0])

print(f"Language Modeling: {'PASS' if loss_drop > 2.0 else 'MARGINAL'} (drop={loss_drop:.2f})")
print(f"NIAH Retrieval: {'PASS' if niah_pass else 'FAIL'}")

In [None]:
# CELL 9: EXPORT
import csv
from datetime import datetime

import os

run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
export_dir = "/content/drive/MyDrive/groundthink/colab-exports"
# Neither torch.save nor open() creates missing parent directories, so make
# sure the export folder exists before writing anything into it.
os.makedirs(export_dir, exist_ok=True)

# Checkpoint: weights plus the configs and training curves needed to
# interpret them.
# NOTE(review): the config dataclass instances are pickled as objects, so
# loading this file requires the same class definitions to be importable.
ckpt_path = f"{export_dir}/v6_{run_id}.pt"
torch.save({
    'model_state_dict': model.state_dict(),
    'model_config': MODEL_CFG,
    'train_config': TRAIN_CFG,
    'losses': losses,
    'niah_trajectory': niah_trajectory,
}, ckpt_path)
print(f"Checkpoint: {ckpt_path}")

# CSV: a single summary row for this run, with a fresh NIAH measurement at
# sequence length 512.
csv_path = f"{export_dir}/v6_{run_id}.csv"
final_niah = needle_test(model, tokenizer, 512, n_trials=50, device=DEVICE)

with open(csv_path, 'w', newline='') as f:
    w = csv.writer(f)
    w.writerow(['run_id', 'd_model', 'n_layers', 'attn_interval', 'window_size',
                'seq_len', 'steps', 'initial_loss', 'final_loss', 'niah_ratio'])
    w.writerow([run_id, MODEL_CFG.d_model, MODEL_CFG.n_layers, MODEL_CFG.attn_interval,
                MODEL_CFG.window_size, TRAIN_CFG.seq_len, TRAIN_CFG.steps,
                f"{np.mean(losses[:50]):.4f}", f"{np.mean(losses[-50:]):.4f}",
                f"{final_niah['ratio']:.2f}x"])
print(f"Results: {csv_path}")