# GroundThink V6 - Hybrid GatedDeltaNet + SWAttention

In [None]:
# CELL 0: INSTALL & EXPLORE FLA
# Mounts Google Drive, installs the packages the later cells import, and prints
# the layout of the installed `fla` package so the right import path can be chosen.
from google.colab import drive
drive.mount('/content/drive')

# Lines starting with `!` are notebook shell commands (IPython/Colab only).
!pip install -q triton
!pip install -q flash-linear-attention
!pip install -q datasets transformers

import sys
print(f"Python: {sys.version}")

# Explore FLA structure
import fla
# Public (non-underscore) names exposed at the top of the package.
print(f"\nFLA top-level: {[x for x in dir(fla) if not x.startswith('_')]}")

# Find all submodules
import pkgutil
print("\nFLA submodules:")
# Names are reported fully qualified because of the "fla." prefix argument.
for importer, modname, ispkg in pkgutil.walk_packages(fla.__path__, fla.__name__ + '.'):
    print(f"  {modname}")

# Try common import paths
# Each probe is independent: an ImportError is printed and the next path is tried.
print("\n--- Trying imports ---")
try:
    # Layer-level class (the one the model cell ultimately needs).
    from fla.layers import GatedDeltaNet
    print("from fla.layers import GatedDeltaNet: OK")
except ImportError as e:
    print(f"from fla.layers: {e}")

try:
    # Full model wrapper, if the package exposes one under fla.models.
    from fla.models import GatedDeltaNetModel
    print("from fla.models import GatedDeltaNetModel: OK")
except ImportError as e:
    print(f"from fla.models: {e}")

try:
    # NOTE: on success this rebinds the name GatedDeltaNet imported above.
    from fla.ops.gated_delta_rule import GatedDeltaNet
    print("from fla.ops.gated_delta_rule: OK")
except ImportError as e:
    print(f"from fla.ops.gated_delta_rule: {e}")

# Check what's in fla.layers if it exists.
# Best-effort probe: a failure is reported instead of silently swallowed, and
# `except Exception` (unlike a bare `except`) lets KeyboardInterrupt through.
try:
    import fla.layers as fla_layers
    print(f"\nfla.layers contents: {[x for x in dir(fla_layers) if not x.startswith('_')]}")
except Exception as e:
    print(f"\nfla.layers unavailable: {e}")

# Check what's in fla.models if it exists (same best-effort handling).
try:
    import fla.models as fla_models
    print(f"\nfla.models contents: {[x for x in dir(fla_models) if not x.startswith('_')]}")
except Exception as e:
    print(f"\nfla.models unavailable: {e}")

In [None]:
# CELL 1: CONFIGURATION (run after Cell 0 confirms imports)
from dataclasses import dataclass, field
from typing import List, Optional
import torch

@dataclass
class ModelConfig:
    """Architecture hyperparameters for the hybrid language model.

    ``head_dim`` is recomputed as ``d_model // n_heads`` right after
    construction, so any value supplied by the caller is overwritten.
    """

    # Embedding table size and model width / depth.
    vocab_size: int = 50257
    d_model: int = 512
    n_layers: int = 12
    n_heads: int = 8
    head_dim: int = 64  # always replaced in __post_init__
    # Layer-type schedule and attention window.
    attn_interval: int = 4
    window_size: int = 2048
    # Expansion factors forwarded to the GatedDeltaNet layers.
    expand_k: float = 1.0
    expand_v: float = 2.0
    max_seq_len: int = 2048
    use_gradient_checkpointing: bool = True
    tie_weights: bool = True

    def __post_init__(self):
        """Derive the per-head width from the model width and head count."""
        width, heads = self.d_model, self.n_heads
        self.head_dim = width // heads

    def get_swa_layer_indices(self):
        """Return the layer indices ``i`` with ``i % attn_interval == attn_interval - 1``."""
        stride = self.attn_interval
        selected = []
        for layer in range(self.n_layers):
            if layer % stride == stride - 1:
                selected.append(layer)
        return selected

@dataclass
class TrainConfig:
    """Data, optimisation and logging settings for the training run."""

    # Data source and amount of text to tokenize.
    dataset_name: str = "HuggingFaceFW/fineweb-edu"
    dataset_subset: str = "sample-10BT"
    target_tokens: int = 20_000_000
    # Batch geometry.
    batch_size: int = 2
    seq_len: int = 512
    accum_steps: int = 2
    # Optimiser and schedule.
    steps: int = 10000
    warmup_ratio: float = 0.1
    lr: float = 3e-4
    weight_decay: float = 0.1
    grad_clip: float = 1.0
    betas: tuple = (0.9, 0.95)
    # Logging cadence (in optimiser steps).
    log_interval: int = 50
    grad_log_interval: int = 500
    # Steps (1-based) at which the needle test is run during training.
    niah_checkpoints: List[int] = field(
        default_factory=lambda: [500, 1000, 2000, 3000, 5000, 7500, 10000]
    )

    @property
    def warmup_steps(self):
        """Number of warmup steps: ``steps * warmup_ratio`` truncated to an int."""
        return int(self.steps * self.warmup_ratio)

    @property
    def effective_batch_size(self):
        """Sequences per optimiser step: ``batch_size * accum_steps``."""
        return self.batch_size * self.accum_steps

# Module-level configuration objects read by the later cells.
MODEL_CFG = ModelConfig()
TRAIN_CFG = TrainConfig()
# Echo the width, depth and the layer indices that will use sliding-window attention.
print(f"Config: d={MODEL_CFG.d_model}, layers={MODEL_CFG.n_layers}, SWA@{MODEL_CFG.get_swa_layer_indices()}")

In [None]:
# CELL 2: IMPORTS - UPDATE THESE BASED ON CELL 0 OUTPUT
import math, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from transformers import AutoTokenizer

# Pick the GPU when one is visible; tensors and the model are moved to DEVICE later.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"GPU: {props.name} ({props.total_memory/1e9:.1f}GB)")
    # Allow TF32 in matmul and cuDNN kernels.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# ============================================================
# FLA IMPORTS - UPDATE BASED ON CELL 0 EXPLORATION
# ============================================================
# Option 1: If fla.layers works
# from fla.layers import GatedDeltaNet
# from fla.layers import SlidingWindowAttention as FLA_SWA

# Option 2: If using models
# from fla.models.gated_delta_net import GatedDeltaNetConfig, GatedDeltaNetModel

# Option 3: Build from ops
# from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule

# For now, let's see what we found:
# Resolves two module-level names used by the model cell:
#   GatedDeltaNet - required; an ImportError is raised below if it cannot be found.
#   FLA_SWA       - optional; stays None when FLA has no SlidingWindowAttention.
import fla
print("Checking fla structure...")

# Try the most likely paths
GatedDeltaNet = None
FLA_SWA = None

# Try fla.layers
try:
    from fla.layers import GatedDeltaNet as GDN
    GatedDeltaNet = GDN
    print("GatedDeltaNet from fla.layers: OK")
except ImportError:
    pass

# Try fla.layers.gated_delta_net
# Only attempted when the package-level import above did not succeed.
if GatedDeltaNet is None:
    try:
        from fla.layers.gated_delta_net import GatedDeltaNet as GDN
        GatedDeltaNet = GDN
        print("GatedDeltaNet from fla.layers.gated_delta_net: OK")
    except ImportError:
        pass

# Try sliding window
try:
    from fla.layers import SlidingWindowAttention as SWA
    FLA_SWA = SWA
    print("SlidingWindowAttention from fla.layers: OK")
except ImportError:
    pass

# Fallback: use standard MultiheadAttention with window masking
if FLA_SWA is None:
    print("SlidingWindowAttention not found - will use custom implementation")
    # FLA_SWA is already None here; the assignment is a no-op kept for clarity.
    FLA_SWA = None

# GatedDeltaNet has no fallback, so fail loudly here rather than in the model cell.
if GatedDeltaNet is None:
    raise ImportError("Could not find GatedDeltaNet in FLA. Check Cell 0 output for available modules.")

print(f"\nGatedDeltaNet: {GatedDeltaNet}")
print(f"SWA: {FLA_SWA}")

In [None]:
# CELL 3: MODEL COMPONENTS

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalise ``x`` in float32, cast back to ``x``'s dtype, then apply the scale."""
        x_fp32 = x.float()
        # 1 / sqrt(mean(x^2) + eps), computed per position over the feature axis.
        inv_rms = x_fp32.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        normed = (x_fp32 * inv_rms).type_as(x)
        return normed * self.weight

class SwiGLUFFN(nn.Module):
    """Pre-norm gated feed-forward block with a residual connection.

    The hidden width is ``d_model * expansion`` rounded up to a multiple of 64.
    """

    def __init__(self, d_model, expansion=8/3):
        super().__init__()
        raw_width = int(d_model * expansion)
        hidden = ((raw_width + 63) // 64) * 64  # round up to a multiple of 64
        self.w1 = nn.Linear(d_model, hidden, bias=False)  # gate projection (goes through SiLU)
        self.w3 = nn.Linear(d_model, hidden, bias=False)  # value projection
        self.w2 = nn.Linear(hidden, d_model, bias=False)  # projection back to d_model
        self.norm = RMSNorm(d_model)

    def forward(self, x):
        """Return ``x + w2(silu(w1(h)) * w3(h))`` where ``h = norm(x)``."""
        normed = self.norm(x)
        gate = F.silu(self.w1(normed))
        value = self.w3(normed)
        update = self.w2(gate * value)
        return x + update

# Custom SWA if FLA doesn't have one
class CustomSlidingWindowAttention(nn.Module):
    """Causal multi-head self-attention restricted to a sliding window.

    A query at position ``i`` attends to keys at positions ``j`` with
    ``0 <= i - j < window_size`` (itself plus the ``window_size - 1`` tokens
    before it).

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        window_size: Number of keys each query can see (>= 1).
        layer_idx: Stored on the module; not used in the computation.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_heads`` or
            ``window_size`` is smaller than 1 (every key would be masked,
            which makes the softmax return NaN).
    """

    def __init__(self, hidden_size, num_heads, window_size, layer_idx=0):
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})")
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.window_size = window_size
        self.layer_idx = layer_idx

        self.qkv = nn.Linear(hidden_size, 3 * hidden_size, bias=False)
        self.out = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x, attention_mask=None, past_key_values=None, use_cache=False, **kwargs):
        """Apply windowed causal attention to ``x`` of shape (batch, time, hidden).

        Args:
            x: Input of shape ``(B, T, hidden_size)``.
            attention_mask: Accepted for interface compatibility; ignored
                (no padding mask is applied).
            past_key_values: Optional ``(keys, values)`` pair returned by a
                previous call with ``use_cache=True``; each has shape
                ``(B, num_heads, S_past, head_dim)``. ``x`` is then treated as
                the continuation of that sequence.
            use_cache: When True, also return the trailing ``window_size - 1``
                keys/values needed by the next call.
            **kwargs: Ignored; lets callers pass extra layer keywords.

        Returns:
            The output tensor ``(B, T, hidden_size)`` when ``use_cache`` is
            False (unchanged from the cache-free behaviour), otherwise the
            tuple ``(output, (keys, values))``.
        """
        B, T, D = x.shape
        qkv = self.qkv(x).view(B, T, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B, heads, T, head_dim)

        if past_key_values is not None:
            past_k, past_v = past_key_values
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        # Keys cover S positions; the T queries are the last T of them.
        S = k.size(2)
        q_pos = torch.arange(S - T, S, device=x.device).unsqueeze(1)
        k_pos = torch.arange(S, device=x.device).unsqueeze(0)
        dist = q_pos - k_pos
        # Block future keys (dist < 0) and keys that fell out of the window.
        blocked = (dist < 0) | (dist >= self.window_size)

        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = attn.masked_fill(blocked.unsqueeze(0).unsqueeze(0), float('-inf'))
        attn = F.softmax(attn, dim=-1)

        out = (attn @ v).transpose(1, 2).reshape(B, T, D)
        out = self.out(out)

        if not use_cache:
            return out
        # The next query can reach back at most window_size - 1 positions.
        keep_from = S - min(self.window_size - 1, S)
        return out, (k[:, :, keep_from:], v[:, :, keep_from:])

# Use FLA's SWA if available, otherwise custom
# NOTE(review): an FLA-provided class is assumed to accept the constructor keywords
# HybridBlock passes (hidden_size, num_heads, window_size, layer_idx) — confirm.
SlidingWindowAttention = FLA_SWA if FLA_SWA is not None else CustomSlidingWindowAttention

class HybridBlock(nn.Module):
    """Pre-norm residual wrapper around one token-mixing layer.

    The wrapped layer is either sliding-window attention (``is_attention``)
    or a GatedDeltaNet layer from FLA.

    Args:
        d_model: Model width, passed to the layer as ``hidden_size``.
        is_attention: Select sliding-window attention instead of GatedDeltaNet.
        n_heads: Head count for the attention layer (not passed to GatedDeltaNet).
        window_size: Attention window (attention layer only).
        expand_k: Forwarded to GatedDeltaNet.
        expand_v: Forwarded to GatedDeltaNet.
        layer_idx: Index of this block in the stack, forwarded to the layer.
    """

    def __init__(self, d_model, is_attention, n_heads=8, window_size=2048,
                 expand_k=1.0, expand_v=2.0, layer_idx=0):
        super().__init__()
        self.is_attention = is_attention
        self.layer_idx = layer_idx
        self.norm = RMSNorm(d_model)

        if is_attention:
            self.layer = SlidingWindowAttention(
                hidden_size=d_model, num_heads=n_heads,
                window_size=window_size, layer_idx=layer_idx)
        else:
            # GatedDeltaNet from FLA
            self.layer = GatedDeltaNet(
                hidden_size=d_model,
                expand_k=expand_k,
                expand_v=expand_v,
                layer_idx=layer_idx)

    def forward(self, x, attention_mask=None, past_state=None, use_cache=False):
        """Run ``x + layer(norm(x))`` and return ``(output, new_state)``.

        Args:
            x: Input hidden states.
            attention_mask: Passed to attention layers only.
            past_state: Cache from a previous call; only forwarded when
                ``use_cache`` is True.
            use_cache: Ask the layer for an updated cache.

        Returns:
            ``(hidden_states, new_state)``; ``new_state`` is None unless
            ``use_cache`` is True and the layer returned a cache.
        """
        residual = x

        layer_kwargs = {}
        if self.is_attention:
            layer_kwargs['attention_mask'] = attention_mask
        if use_cache:
            # Both layer kinds receive the cache under the same keyword. FLA layers
            # name it `past_key_values`; a different keyword would be swallowed by
            # their **kwargs and the cache silently dropped (verify against the
            # installed flash-linear-attention version).
            layer_kwargs['past_key_values'] = past_state
            layer_kwargs['use_cache'] = True

        out = self.layer(self.norm(x), **layer_kwargs)

        # Layers may return a bare tensor or a tuple. FLA layers return
        # (output, attentions, past_key_values): take the first element as the
        # hidden states and, when caching, the last one as the state.
        new_state = None
        if isinstance(out, tuple):
            if use_cache and len(out) > 1:
                new_state = out[-1]
            out = out[0]

        return residual + out, new_state

class GroundThinkLM(nn.Module):
    """Language model that stacks ``HybridBlock`` + ``SwiGLUFFN`` pairs.

    Layer indices returned by ``cfg.get_swa_layer_indices()`` use sliding-window
    attention; all other layers use GatedDeltaNet. The output head shares the
    embedding matrix when ``cfg.tie_weights`` is set.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)

        swa_indices = set(cfg.get_swa_layer_indices())
        self._swa_indices = swa_indices
        self._gdn_indices = set(range(cfg.n_layers)) - swa_indices

        self.blocks = nn.ModuleList()
        self.ffns = nn.ModuleList()
        for i in range(cfg.n_layers):
            self.blocks.append(HybridBlock(
                cfg.d_model, is_attention=(i in swa_indices),
                n_heads=cfg.n_heads, window_size=cfg.window_size,
                expand_k=cfg.expand_k, expand_v=cfg.expand_v, layer_idx=i))
            self.ffns.append(SwiGLUFFN(cfg.d_model))

        self.norm_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        if cfg.tie_weights:
            self.lm_head.weight = self.embed.weight

    def forward(self, input_ids, targets=None, attention_mask=None,
                past_states=None, use_cache=False):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            input_ids: Token ids, indexed into the embedding table.
            targets: Optional target ids with the same shape as ``input_ids``.
            attention_mask: Forwarded to every block.
            past_states: Optional per-layer list of cached states.
            use_cache: Collect one state per layer in the returned list.

        Returns:
            ``(logits, loss, new_states)``; ``loss`` is None without targets and
            ``new_states`` is None unless ``use_cache`` is True, in which case
            it has exactly one entry per layer.
        """
        x = self.embed(input_ids)
        new_states = [] if use_cache else None
        # The checkpointed path cannot return layer state, so it is skipped when a
        # cache is requested; otherwise new_states would lose its SWA entries and
        # no longer line up with layer indices.
        can_checkpoint = (self.cfg.use_gradient_checkpointing
                          and self.training and not use_cache)

        for i, (block, ffn) in enumerate(zip(self.blocks, self.ffns)):
            past_state = past_states[i] if past_states else None
            # Only SWA layers are checkpointed, and only when there is no past
            # state for them (the checkpointed helper does not forward one).
            if can_checkpoint and past_state is None and i in self._swa_indices:
                x = checkpoint(self._fwd_block, block, ffn, x, attention_mask, use_reentrant=False)
            else:
                x, state = block(x, attention_mask, past_state, use_cache)
                x = ffn(x)
                if use_cache:
                    new_states.append(state)

        logits = self.lm_head(self.norm_f(x))
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss, new_states

    @staticmethod
    def _fwd_block(block, ffn, x, mask):
        """Cache-free block + FFN step used as the gradient-checkpoint target."""
        x, _ = block(x, mask, None, False)
        return ffn(x)

    def get_layer_types(self):
        """Return ``'swa'`` or ``'gdn'`` for each layer index, in order."""
        return ['swa' if i in self._swa_indices else 'gdn' for i in range(self.cfg.n_layers)]

    def count_parameters(self):
        """Return parameter counts keyed by ``embed``, ``gdn``, ``swa``, ``ffn``, ``total``.

        ``total`` counts every unique parameter of the model, so it also covers
        the final norm and an untied output head (tied weights are counted
        once); it can therefore exceed the sum of the four group entries.
        """
        c = {'embed': sum(p.numel() for p in self.embed.parameters()), 'gdn': 0, 'swa': 0, 'ffn': 0}
        for i, (b, f) in enumerate(zip(self.blocks, self.ffns)):
            bp = sum(p.numel() for p in b.parameters())
            fp = sum(p.numel() for p in f.parameters())
            c['swa' if i in self._swa_indices else 'gdn'] += bp
            c['ffn'] += fp
        # nn.Module.parameters() yields shared tensors only once.
        c['total'] = sum(p.numel() for p in self.parameters())
        return c

# Confirm the cell ran and show which sliding-window attention class was selected.
print("Model components defined")
print(f"Using SWA: {SlidingWindowAttention}")

In [None]:
# CELL 4: MONITORING

def print_gradient_summary(model):
    """Print mean and max gradient norm per parameter group.

    Parameters are grouped by name: names containing ``embed``, then names
    containing ``ffn``, then ``blocks.<idx>...`` entries split into ``swa`` or
    ``gdn`` via ``model._swa_indices``. Parameters without a gradient, and
    names matching none of the groups, are skipped. Empty groups are not printed.
    """
    buckets = {'embed': [], 'gdn': [], 'swa': [], 'ffn': []}
    for pname, param in model.named_parameters():
        grad = param.grad
        if grad is None:
            continue
        magnitude = grad.norm().item()
        if 'embed' in pname:
            group = 'embed'
        elif 'ffn' in pname:
            group = 'ffn'
        elif 'blocks' in pname:
            # Names look like "blocks.<layer index>.<...>".
            layer_no = int(pname.split('.')[1])
            group = 'swa' if layer_no in model._swa_indices else 'gdn'
        else:
            continue
        buckets[group].append(magnitude)

    print("Gradient Norms:")
    for group, norms in buckets.items():
        if not norms:
            continue
        print(f"  {group}: mean={np.mean(norms):.3f} max={np.max(norms):.2f}")

def needle_test(model, tokenizer, seq_len=512, n_trials=50, needle_token=50250, device="cuda"):
    model.eval()
    probs = []
    with torch.no_grad():
        for _ in range(n_trials):
            tokens = torch.randint(1000, 10000, (1, seq_len), device=device)
            pos = torch.randint(64, seq_len - 64, (1,)).item()
            tokens[0, pos] = needle_token
            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                logits, _, _ = model(tokens)
            probs.append(F.softmax(logits[0, -1].float(), dim=-1)[needle_token].item())
    rc = 1.0 / tokenizer.vocab_size
    return {'mean': np.mean(probs), 'std': np.std(probs), 'ratio': np.mean(probs) / rc}

def probe_layers(model, tokenizer, needle_id=50250, seq_len=512, pos=256, device="cuda"):
    """Print, per layer, how similar the hidden state at ``pos`` is to the needle embedding.

    A random sequence (ids in [1000, 10000)) gets ``needle_id`` written at
    ``pos``. The sequence is pushed through each block/FFN pair in turn and the
    cosine similarity between the hidden state at ``pos`` and the needle's
    embedding row is printed after every layer. ``tokenizer`` is not used.
    The model is switched to eval mode and left there.
    """
    model.eval()
    sequence = torch.randint(1000, 10000, (1, seq_len), device=device)
    sequence[0, pos] = needle_id
    with torch.no_grad():
        hidden = model.embed(sequence)
        needle_vec = model.embed.weight[needle_id].float()
        print("Needle representation:")
        layer_no = 0
        for block, ffn in zip(model.blocks, model.ffns):
            hidden, _ = block(hidden, None, None, False)
            hidden = ffn(hidden)
            here = hidden[0, pos].float()
            score = F.cosine_similarity(here, needle_vec, dim=0).item()
            if layer_no in model._swa_indices:
                kind = "SWA"
            else:
                kind = "GDN"
            print(f"  L{layer_no:2d}[{kind}]: {score:+.3f}")
            layer_no += 1

# Confirm the monitoring helpers above were defined.
print("Monitoring ready")

In [None]:
# CELL 5: DATA
from datasets import load_dataset
from tqdm import tqdm

# Tokenizer setup; the model's vocabulary size is taken from the tokenizer.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
MODEL_CFG.vocab_size = tokenizer.vocab_size

# Stream documents and tokenize until target_tokens ids have been collected.
print(f"Streaming {TRAIN_CFG.dataset_name}...")
ds = load_dataset(TRAIN_CFG.dataset_name, name=TRAIN_CFG.dataset_subset, split="train", streaming=True)
buf = []
pbar = tqdm(total=TRAIN_CFG.target_tokens, unit="tok")
for ex in ds:
    # Each document is followed by an EOS token.
    toks = tokenizer.encode(ex['text']) + [tokenizer.eos_token_id]
    buf.extend(toks)
    pbar.update(len(toks))
    if len(buf) >= TRAIN_CFG.target_tokens: break
pbar.close()

# The last document may overshoot the target, so truncate to exactly target_tokens.
all_tokens = torch.tensor(buf[:TRAIN_CFG.target_tokens], dtype=torch.long)
del buf, ds
print(f"Loaded {len(all_tokens):,} tokens")

def get_batch():
    """Sample a random training batch from ``all_tokens``.

    Returns:
        ``(inputs, labels)`` on ``DEVICE``, each of shape
        ``(TRAIN_CFG.batch_size, TRAIN_CFG.seq_len)``; ``labels`` is ``inputs``
        shifted one token to the right in the source stream.
    """
    span = TRAIN_CFG.seq_len
    # Start offsets are drawn so that start + span + 1 stays inside all_tokens.
    starts = torch.randint(len(all_tokens) - span - 1, (TRAIN_CFG.batch_size,))
    inputs = torch.stack([all_tokens[s:s + span] for s in starts])
    labels = torch.stack([all_tokens[s + 1:s + span + 1] for s in starts])
    return inputs.to(DEVICE), labels.to(DEVICE)

In [None]:
# CELL 6: BUILD MODEL
# Build the model on DEVICE with bfloat16 weights, then report its size and layout.
print("Building...")
model = GroundThinkLM(MODEL_CFG).to(DEVICE).to(torch.bfloat16)
p = model.count_parameters()
print(f"Params: {p['total']/1e6:.2f}M (GDN:{p['gdn']/1e6:.1f}M, SWA:{p['swa']/1e6:.1f}M, FFN:{p['ffn']/1e6:.1f}M)")
print(f"Layers: {model.get_layer_types()}")

# Smoke test: one forward/backward pass on a real batch before training starts.
x, y = get_batch()
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
    _, loss, _ = model(x, y)
loss.backward()
print(f"Test: loss={loss.item():.4f}, mem={torch.cuda.max_memory_allocated()/1e9:.2f}GB")
# Discard the smoke-test gradients so they do not leak into the first training step.
model.zero_grad()

In [None]:
# CELL 7: TRAIN
# Training loop: AdamW with linear warmup + cosine decay, gradient accumulation,
# gradient clipping, periodic loss / gradient logging and needle tests.
opt = torch.optim.AdamW(model.parameters(), lr=TRAIN_CFG.lr, betas=TRAIN_CFG.betas, weight_decay=TRAIN_CFG.weight_decay)
losses, niah_traj = [], []
start = time.time()

print(f"\nTRAINING {TRAIN_CFG.steps} steps, batch={TRAIN_CFG.effective_batch_size}\n")
model.train()

for step in range(TRAIN_CFG.steps):
    # Learning rate: linear ramp during warmup, then cosine decay towards zero.
    if step < TRAIN_CFG.warmup_steps:
        lr = TRAIN_CFG.lr * (step+1)/TRAIN_CFG.warmup_steps
    else:
        lr = TRAIN_CFG.lr * 0.5 * (1 + math.cos(
            math.pi * (step-TRAIN_CFG.warmup_steps)/(TRAIN_CFG.steps-TRAIN_CFG.warmup_steps)))
    for pg in opt.param_groups:
        pg['lr'] = lr

    # Accumulate gradients over accum_steps micro-batches; each loss is scaled
    # so the summed gradient is the mean over the micro-batches.
    acc_loss = 0
    for _ in range(TRAIN_CFG.accum_steps):
        x, y = get_batch()
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            _, loss, _ = model(x, y)
        (loss / TRAIN_CFG.accum_steps).backward()
        acc_loss += loss.item()

    torch.nn.utils.clip_grad_norm_(model.parameters(), TRAIN_CFG.grad_clip)

    # Gradients must be inspected BEFORE opt.zero_grad(): zero_grad() resets
    # .grad to None by default, which would leave the summary empty.
    # The reported norms are therefore the clipped gradients.
    if (step+1) % TRAIN_CFG.grad_log_interval == 0:
        print_gradient_summary(model)

    opt.step()
    opt.zero_grad()
    losses.append(acc_loss / TRAIN_CFG.accum_steps)

    if step % TRAIN_CFG.log_interval == 0:
        # Mean over the last (up to) 50 recorded step losses.
        avg = np.mean(losses[-50:])
        tps = (step+1) * TRAIN_CFG.effective_batch_size * TRAIN_CFG.seq_len / (time.time()-start)
        print(f"[{step:5d}] loss={avg:.4f} lr={lr:.2e} {tps:,.0f}tok/s")

    if (step+1) in TRAIN_CFG.niah_checkpoints:
        n = needle_test(model, tokenizer, TRAIN_CFG.seq_len, 30, device=DEVICE)
        niah_traj.append((step+1, n['ratio']))
        print(f"  NIAH@{step+1}: {n['ratio']:.2f}x")
        # needle_test leaves the model in eval mode.
        model.train()

print(f"\nDone in {(time.time()-start)/60:.1f}min")
print(f"Loss: {np.mean(losses[:50]):.4f} -> {np.mean(losses[-50:]):.4f}")

In [None]:
# CELL 8: EVAL
# Final evaluation: needle test at several sequence lengths, a per-layer probe,
# and two pass/fail summaries.
print("\nFINAL EVAL")
for L in [128, 256, 512, 1024]:
    n = needle_test(model, tokenizer, L, 50, device=DEVICE)
    print(f"  NIAH@{L}: {n['ratio']:.2f}x")
probe_layers(model, tokenizer, device=DEVICE)
# LM passes when the mean loss of the last 50 steps is more than 2 below the first 50.
print(f"\nLM: {'PASS' if np.mean(losses[:50])-np.mean(losses[-50:])>2 else 'MARGINAL'}")
# NIAH passes when any in-training checkpoint scored above chance (ratio > 1).
print(f"NIAH: {'PASS' if any(r>1 for _,r in niah_traj) else 'FAIL'}")

In [None]:
# CELL 9: SAVE
import os
from datetime import datetime

# Timestamped run id so repeated runs never overwrite each other.
rid = datetime.now().strftime("%Y%m%d_%H%M%S")
path = f"/content/drive/MyDrive/groundthink/colab-exports/v6_{rid}.pt"
# torch.save does not create missing directories; make sure the export folder
# exists so a finished training run is not lost to a FileNotFoundError.
os.makedirs(os.path.dirname(path), exist_ok=True)
# NOTE: 'cfg' stores the ModelConfig instance itself, so loading this checkpoint
# requires the ModelConfig class to be importable/defined.
torch.save({'state': model.state_dict(), 'cfg': MODEL_CFG, 'losses': losses, 'niah': niah_traj}, path)
print(f"Saved: {path}")