# GroundThink V6 - Hybrid GatedDeltaNet + SWA

**Gated Delta Rule:** `Sₜ = αₜ Sₜ₋₁ + βₜ Δₜ`, with `Δₜ = (vₜ − αₜ Sₜ₋₁ kₜ) kₜᵀ`
- `αₜ ∈ (0, 1)` (decay gate): rapid forgetting, as in Mamba2
- `βₜ Δₜ` (delta update): targeted overwrite of the value stored under key `kₜ`, as in DeltaNet

Architecture: GatedDeltaNet (FLA) + SlidingWindowAttention (flash_attn)

In [1]:
# CELL 0: STABILIZED INSTALL
# Mounts Google Drive (CELL 9 saves checkpoints under /content/drive), pins protobuf,
# installs the FLA / flash-attn / vLLM stack, then smoke-tests the key imports.
# NOTE: the `!pip` lines are IPython shell escapes; this cell only runs inside Jupyter/Colab.
from google.colab import drive
drive.mount('/content/drive')

# 1. Force a compatible protobuf version before anything else
!pip install -q "protobuf<6.0dev,>=3.20.3" jedi

# 2. Install the Hybrid Stack
!pip install -q triton flash-linear-attention
!pip install -q flash-attn --no-build-isolation
!pip install -q vllm datasets transformers

import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import time
import numpy as np

# Verify vLLM is actually available for your stateful design.
# Failure is reported but non-fatal; nothing else visible in this notebook imports vllm.
try:
    import vllm
    print(f"vLLM {vllm.__version__}: Ready")
except ImportError:
    print("vLLM: standard import failed, but state-logic can still be validated.")

# Hard dependency: the recurrent (non-attention) layers of the hybrid model.
from fla.layers import GatedDeltaNet
print("GatedDeltaNet: OK")

Mounted at /content/drive
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m65.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m287.4/287.4 kB[0m [31m28.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m437.3/437.3 kB[0m [31m41.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.4/8.4 MB[0m [31m19.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for flash-attn (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.9/87.9 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m495.4/495.4 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m355.0/355.0 kB[0m [31m29.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90

Why we needed to fix this for vLLM:
In a production vLLM deployment, memory is managed via PagedAttention. Resolving the protobuf dependency conflict now keeps `vllm` and its dependencies importable in this runtime; a protobuf version mismatch otherwise makes the vLLM entrypoints fail at import time. That matters later, when this model's weights are exported (e.g. to `.safetensors`) and served through vLLM.

**Verifying the "proper design".** Since the imports now pass, the next critical check is the architecture's *state unification*: `HybridBlock` must handle two different per-layer state shapes.
- **Gated DeltaNet state:** a single recurrent tensor $S \in \mathbb{R}^{B \times H \times d_k \times d_v}$, the linear-attention memory. In general $d_k \neq d_v$ here, since `expand_v = 2`.
- **SWA state:** a tuple of tensors $(K, V)$, each in $\mathbb{R}^{B \times W \times H \times d_h}$. These are the last $W$ keys and values of the sliding window, stored in the `(batch, time, heads, head_dim)` layout used by `SlidingWindowAttention`.

In [2]:
# CELL 1: CONFIG
from dataclasses import dataclass, field
from typing import List
import torch

# HARDWARE AWARENESS: choose the device, the attention backend (USE_FLASH) and the
# compute dtype (DTYPE).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
USE_FLASH = False
# Default for CPU-only runtimes so DTYPE is always defined; overridden below on GPU.
DTYPE = torch.float32

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    major, minor = torch.cuda.get_device_capability(0)
    print(f"GPU: {props.name} (Compute {major}.{minor}, {props.total_memory/1e9:.1f}GB)")

    # FlashAttention only supports Compute 8.0+ (Ampere)
    if major >= 8:
        try:
            from flash_attn import flash_attn_func
            USE_FLASH = True
            print("FlashAttention: ENABLED (Ampere+)")
        except ImportError:
            print("FlashAttention: Module not found, using fallback.")
    else:
        # Any pre-Ampere GPU lands here (T4/Turing, V100/Volta, ...), so report the
        # actual capability instead of assuming a T4.
        print(f"FlashAttention: DISABLED (Compute {major}.{minor} < 8.0, e.g. T4/Turing. Using manual fallback)")

    # Pre-Ampere GPUs (e.g. T4) lack native bfloat16 compute; float16 is much faster there.
    DTYPE = torch.float16 if major < 8 else torch.bfloat16
    print(f"Optimized DType: {DTYPE}")

@dataclass
class ModelConfig:
    """Architecture hyper-parameters for GroundThinkLM.

    ``head_dim`` is derived, not free: ``__post_init__`` always overwrites it with
    ``d_model // n_heads``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``, or if
            ``attn_interval`` is not positive.
    """
    vocab_size: int = 50257
    d_model: int = 256
    n_layers: int = 12
    n_heads: int = 8
    head_dim: int = 64  # overwritten in __post_init__ with d_model // n_heads
    attn_interval: int = 4  # SWA every N layers (3:1 ratio)
    window_size: int = 512
    expand_k: float = 1.0
    expand_v: float = 2.0
    use_gradient_checkpointing: bool = True
    tie_weights: bool = True

    def __post_init__(self):
        # Fail fast: a non-divisible width would silently truncate head_dim, and a
        # non-positive interval breaks the modulo in get_swa_layer_indices.
        if self.n_heads <= 0 or self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by a positive n_heads ({self.n_heads})"
            )
        if self.attn_interval <= 0:
            raise ValueError(f"attn_interval must be positive, got {self.attn_interval}")
        self.head_dim = self.d_model // self.n_heads

    def get_swa_layer_indices(self):
        """Return the layer indices that use sliding-window attention.

        These are the last layer of every group of ``attn_interval`` layers, e.g.
        [3, 7, 11] for 12 layers with interval 4.
        """
        return [i for i in range(self.n_layers) if i % self.attn_interval == (self.attn_interval - 1)]

@dataclass
class TrainConfig:
    """Data, optimisation and logging hyper-parameters for a training run.

    ``warmup_steps`` and ``effective_batch_size`` are derived read-only properties.
    """
    dataset_name: str = "HuggingFaceFW/fineweb-edu"
    dataset_subset: str = "sample-10BT"
    target_tokens: int = 20_000_000   # size of the in-memory token corpus
    batch_size: int = 2
    seq_len: int = 512
    accum_steps: int = 2              # micro-batches per optimizer step
    steps: int = 10000
    warmup_ratio: float = 0.1
    lr: float = 3e-4
    weight_decay: float = 0.1
    grad_clip: float = 1.0
    betas: tuple = (0.9, 0.95)
    log_interval: int = 50
    grad_log_interval: int = 500
    # Bound .copy of a template list: every instance gets its own fresh list.
    niah_checkpoints: List[int] = field(default_factory=[500, 1000, 2000, 3000, 5000, 7500, 10000].copy)

    @property
    def warmup_steps(self):
        """Number of linear-warmup steps, ``int(steps * warmup_ratio)``."""
        return int(self.steps * self.warmup_ratio)

    @property
    def effective_batch_size(self):
        """Sequences consumed per optimizer step (``batch_size * accum_steps``)."""
        return self.batch_size * self.accum_steps

# Global configs shared by every later cell (CELL 5 mutates MODEL_CFG.vocab_size).
MODEL_CFG = ModelConfig()
TRAIN_CFG = TrainConfig()
_swa_layers = MODEL_CFG.get_swa_layer_indices()
print(f"Config: d={MODEL_CFG.d_model}, layers={MODEL_CFG.n_layers}, SWA@{_swa_layers}")

GPU: NVIDIA A100-SXM4-40GB (Compute 8.0, 42.5GB)
FlashAttention: ENABLED (Ampere+)
Optimized DType: torch.bfloat16
Config: d=256, layers=12, SWA@[3, 7, 11]


In [3]:
# CELL 2: IMPORTS
import math, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from transformers import AutoTokenizer

# Re-detect the device and enable TF32 for float32 matmuls / convolutions.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"GPU: {props.name} ({props.total_memory/1e9:.1f}GB)")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

from fla.layers import GatedDeltaNet
print("GatedDeltaNet loaded")

# FLASH_ATTN uses the same gate as USE_FLASH in the config cell: the flash_attn import
# can succeed on pre-Ampere GPUs (e.g. T4), but FlashAttention only supports compute
# capability 8.0+, so availability alone is not enough.
try:
    from flash_attn import flash_attn_func
    FLASH_ATTN = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
    if FLASH_ATTN:
        print("flash_attn loaded")
    else:
        print("flash_attn installed but GPU is pre-Ampere - using PyTorch SWA")
except ImportError:
    FLASH_ATTN = False
    print("flash_attn unavailable - using PyTorch SWA")

GPU: NVIDIA A100-SXM4-40GB (42.5GB)
GatedDeltaNet loaded
flash_attn loaded


The Final "vLLM-Ready" Architecture (Cell 3 Fix)
In your uploaded notebook, the SlidingWindowAttention and HybridBlock were "stateless" during inference—they ignored past_key_values. This version fixes that by implementing a Rolling KV-Cache for SWA and passing the Recurrent State for Gated DeltaNet.

In [4]:
# CELL 3: MODEL COMPONENTS (FIXED)

class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-channel gain.

    There is no mean-centering and no bias. Statistics are computed in float32, then
    cast back to the input dtype before applying the gain.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normed = (x32 * inv_rms).type_as(x)
        return normed * self.weight

class SwiGLUFFN(nn.Module):
    """Pre-norm SwiGLU feed-forward sublayer that includes its own residual add.

    It computes ``x + W2(silu(W1 h) * W3 h)`` with ``h = RMSNorm(x)``. The hidden width
    is ``int(d_model * expansion)`` rounded up to the next multiple of 64.
    """

    def __init__(self, d_model, expansion=8/3):
        super().__init__()
        raw_hidden = int(d_model * expansion)
        hidden = -(-raw_hidden // 64) * 64  # ceil-divide, then scale back up
        # Submodule creation order is kept (w1, w3, w2, norm) so seeded init is unchanged.
        self.w1 = nn.Linear(d_model, hidden, bias=False)
        self.w3 = nn.Linear(d_model, hidden, bias=False)
        self.w2 = nn.Linear(hidden, d_model, bias=False)
        self.norm = RMSNorm(d_model)

    def forward(self, x):
        normed = self.norm(x)
        gated = F.silu(self.w1(normed)) * self.w3(normed)
        return x + self.w2(gated)

class SlidingWindowAttention(nn.Module):
    """Causal sliding-window multi-head attention with an optional rolling KV cache.

    A query at absolute position i attends to keys at positions [i - window_size, i],
    i.e. window_size + 1 keys. This matches flash_attn's ``window_size=(window_size, 0)``.

    Args:
        d_model: model width, split evenly into ``n_heads`` heads.
        n_heads: number of attention heads.
        window_size: number of past positions visible to each query.
        layer_idx: accepted for constructor symmetry with GatedDeltaNet; unused here.

    ``forward(x, past_key_values=None, use_cache=False) -> (out, cache)``:
        - ``x`` has shape (B, T, d_model).
        - ``past_key_values`` is the ``(k, v)`` tuple from a previous call, each of
          shape (B, P, n_heads, head_dim).
        - When ``use_cache`` is True, ``cache`` holds the last ``window_size``
          keys/values (detached). Otherwise it is None.
    """

    def __init__(self, d_model, n_heads, window_size, layer_idx=0):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.window_size = window_size

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def _window_mask(self, q_len, k_len, device):
        """Return a (q_len, k_len) bool mask that is True where attention is allowed.

        The queries are the LAST q_len of the k_len key positions. The mask is therefore
        causal and window-limited both for full sequences (q_len == k_len) and for
        chunks appended after a cached prefix.
        """
        q_pos = torch.arange(k_len - q_len, k_len, device=device).unsqueeze(1)
        k_pos = torch.arange(k_len, device=device).unsqueeze(0)
        return (k_pos <= q_pos) & (k_pos >= q_pos - self.window_size)

    def forward(self, x, past_key_values=None, use_cache=False):
        B, T, D = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim)

        cached = use_cache and past_key_values is not None
        current_cache = None
        if use_cache:
            if past_key_values is not None:
                pk, pv = past_key_values
                k = torch.cat([pk, k], dim=1)
                v = torch.cat([pv, v], dim=1)
            # Keep only the last window_size positions. The start index is clamped so
            # that window_size == 0 keeps nothing (k[:, -0:] would keep everything).
            start = max(k.shape[1] - self.window_size, 0)
            current_cache = (k[:, start:].detach(), v[:, start:].detach())

        # flash_attn only handles the uncached square case, on CUDA, in fp16/bf16.
        use_flash_kernel = (
            not cached
            and USE_FLASH
            and q.is_cuda
            and q.dtype in (torch.float16, torch.bfloat16)
        )
        if use_flash_kernel:
            out = flash_attn_func(q, k, v, causal=True, window_size=(self.window_size, 0))
        else:
            # One explicit causal + window mask serves full-sequence training, cached
            # single-token decode (all True) and cached multi-token chunks alike.
            mask = self._window_mask(T, k.shape[1], x.device)
            out = F.scaled_dot_product_attention(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=mask
            ).transpose(1, 2)

        return self.out_proj(out.reshape(B, T, D)), current_cache

class HybridBlock(nn.Module):
    """Pre-norm residual wrapper around one token mixer: SWA or GatedDeltaNet.

    ``forward(x, past_state=None, use_cache=False)`` returns
    ``(x + mixer(norm(x)), new_state)``. ``new_state`` is:
        - the SWA ``(k, v)`` cache tuple for attention layers,
        - the state reported by GatedDeltaNet for recurrent layers,
        - None when ``use_cache`` is False.
    """

    def __init__(self, d_model, is_attention, n_heads=8, window_size=512,
                 expand_k=1.0, expand_v=2.0, layer_idx=0):
        super().__init__()
        self.is_attention = is_attention
        self.norm = RMSNorm(d_model)

        if is_attention:
            self.layer = SlidingWindowAttention(d_model, n_heads, window_size, layer_idx)
        else:
            # NOTE(review): in recent FLA releases, `use_gate` toggles the gated output
            # norm (output gate), while the decay α is always part of the layer. Also,
            # n_heads is not forwarded, so FLA's own head defaults apply. Confirm both
            # against the installed fla version.
            self.layer = GatedDeltaNet(
                hidden_size=d_model, expand_k=expand_k, expand_v=expand_v,
                layer_idx=layer_idx, use_gate=True
            )

    def forward(self, x, past_state=None, use_cache=False):
        residual = x
        h = self.norm(x)
        new_state = None

        if self.is_attention:
            h, new_state = self.layer(h, past_key_values=past_state, use_cache=use_cache)
        elif use_cache:
            out = self.layer(h, initial_state=past_state, use_cache=True, output_final_state=True)
            if not isinstance(out, tuple):
                raise RuntimeError(
                    "GatedDeltaNet returned a bare tensor with use_cache=True; "
                    "no recurrent state is available for cached generation."
                )
            # Some FLA versions return (o, state); others return (o, attn, cache).
            # Take the output first and the state last instead of assuming arity 2.
            h, new_state = out[0], out[-1]
            if new_state is None:
                # Fail loudly rather than decode statelessly (wrong outputs).
                raise RuntimeError(
                    "GatedDeltaNet returned no recurrent state; this FLA version likely "
                    "manages state through a `past_key_values` Cache object instead of "
                    "`initial_state`/`output_final_state`."
                )
        else:
            out = self.layer(h)
            h = out[0] if isinstance(out, tuple) else out

        return residual + h, new_state

class GroundThinkLM(nn.Module):
    """Hybrid language model: stacked HybridBlocks (GDN or SWA), each followed by a SwiGLU FFN.

    Layers listed by ``cfg.get_swa_layer_indices()`` use sliding-window attention; all
    others use GatedDeltaNet. With ``cfg.tie_weights`` the LM head shares the embedding
    matrix.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)

        swa_indices = set(cfg.get_swa_layer_indices())
        self._swa_indices = swa_indices

        self.blocks = nn.ModuleList()
        self.ffns = nn.ModuleList()
        for i in range(cfg.n_layers):
            self.blocks.append(HybridBlock(
                cfg.d_model, is_attention=(i in swa_indices),
                n_heads=cfg.n_heads, window_size=cfg.window_size,
                expand_k=cfg.expand_k, expand_v=cfg.expand_v, layer_idx=i))
            self.ffns.append(SwiGLUFFN(cfg.d_model))

        self.norm_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        if cfg.tie_weights:
            self.lm_head.weight = self.embed.weight

    def forward(self, input_ids, targets=None, past_states=None, use_cache=False):
        """Run the model and return ``(logits, loss, new_states)``.

        Args:
            input_ids: token ids, (B, T).
            targets: optional next-token ids, (B, T); ``loss`` is None without them.
            past_states: optional per-layer state list indexed by layer. SWA layers take
                ``(k, v)`` tuples and GDN layers take whatever their block returned.
            use_cache: when True, ``new_states`` is a per-layer list; otherwise None.
        """
        x = self.embed(input_ids)
        new_states = [] if use_cache else None

        for i, (block, ffn) in enumerate(zip(self.blocks, self.ffns)):
            layer_past = past_states[i] if (past_states is not None and len(past_states) > i) else None

            # Gradient checkpointing: training only (it discards cache outputs), and
            # applied to SWA layers only.
            if self.cfg.use_gradient_checkpointing and self.training and not use_cache and i in self._swa_indices:
                x = checkpoint(self._fwd_block, block, ffn, x, use_reentrant=False)
            else:
                x, layer_new_state = block(x, layer_past, use_cache)
                x = ffn(x)
                if use_cache:
                    new_states.append(layer_new_state)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous logits/targets are accepted.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

        return logits, loss, new_states

    @staticmethod
    def _fwd_block(block, ffn, x):
        """Stateless block + FFN step, used as the gradient-checkpointed function."""
        x, _ = block(x, None, False)
        return ffn(x)

    def get_layer_types(self):
        """Return 'SWA' or 'GDN' for each layer, in order."""
        return ['SWA' if i in self._swa_indices else 'GDN' for i in range(self.cfg.n_layers)]

    def count_parameters(self):
        """Return parameter counts per group plus the true model total.

        'total' is summed over ``self.parameters()``, which de-duplicates tied weights. It
        therefore also covers norm_f and an untied lm_head, which the per-group buckets
        do not.
        """
        c = {'embed': sum(p.numel() for p in self.embed.parameters()), 'gdn': 0, 'swa': 0, 'ffn': 0}
        for i, (b, f) in enumerate(zip(self.blocks, self.ffns)):
            bp, fp = sum(p.numel() for p in b.parameters()), sum(p.numel() for p in f.parameters())
            c['swa' if i in self._swa_indices else 'gdn'] += bp
            c['ffn'] += fp
        c['total'] = sum(p.numel() for p in self.parameters())
        return c

print("Model defined with Hybrid State Management (vLLM Ready)")

Model defined with Hybrid State Management (vLLM Ready)


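Why this fulfills your requirements:
- **Gated Delta Rule:** GatedDeltaNet learns two per-token quantities. The decay $\alpha_t$ provides rapid forgetting (Mamba2-style), and the write strength $\beta_t$ drives targeted delta updates $\beta_t\Delta_t$ (DeltaNet-style). In recent FLA releases the decay gate is always part of the layer; `use_gate=True` additionally enables the gated output normalization rather than $\alpha_t$ itself.
- **vLLM compliance:** During generation the model no longer recomputes the whole sequence for every new token. The SWA layers keep a fixed-size sliding-window KV cache, and the Delta layers keep a constant-size recurrent state. This is analogous to how vLLM manages per-sequence state for high-throughput inference.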

In [5]:
# CELL 4: MONITORING

def print_gradient_summary(model):
    """Print the mean and max gradient L2 norm for each parameter family.

    Families:
        - 'embed': names containing 'embed'.
        - 'ffn': names containing 'ffn'.
        - 'gdn' / 'swa': ``blocks.<idx>.*`` names, split by ``model._swa_indices``.

    Parameters without a gradient, or outside these families, are skipped.
    """
    buckets = {family: [] for family in ('embed', 'gdn', 'swa', 'ffn')}
    for pname, param in model.named_parameters():
        if param.grad is None:
            continue
        grad_norm = param.grad.norm().item()
        if 'embed' in pname:
            family = 'embed'
        elif 'ffn' in pname:
            family = 'ffn'
        elif 'blocks' in pname:
            layer = int(pname.split('.')[1])
            family = 'swa' if layer in model._swa_indices else 'gdn'
        else:
            continue
        buckets[family].append(grad_norm)

    print("Gradients:")
    for family, norms in buckets.items():
        if not norms:
            continue
        print(f"  {family}: mean={np.mean(norms):.3f} max={np.max(norms):.2f}")

def needle_test(model, tokenizer, seq_len=512, n_trials=50, needle_token=50250, device="cuda"):
    """Measure how strongly the model predicts a planted needle token at the end of random text.

    For each trial, a random sequence of ids in [1000, 10000) gets ``needle_token``
    written at a random interior position. The model's next-token probability for the
    needle at the last position is recorded. The model is left in eval mode.

    Returns:
        dict with 'mean', the average needle probability, and 'ratio', that mean
        divided by uniform chance ``1 / tokenizer.vocab_size``.
    """
    model.eval()
    # Keep the needle up to 64 tokens away from both ends. Shrink the margin for short
    # sequences so the randint range stays non-empty; a fixed 64 raised for seq_len <= 128.
    margin = min(64, (seq_len - 1) // 2)
    # Autocast applies on CUDA only; on other devices it is explicitly disabled.
    dev_type = torch.device(device).type
    probs = []
    with torch.no_grad():
        for _ in range(n_trials):
            tokens = torch.randint(1000, 10000, (1, seq_len), device=device)
            pos = torch.randint(margin, seq_len - margin, (1,)).item()
            tokens[0, pos] = needle_token
            with torch.amp.autocast(dev_type, dtype=torch.bfloat16, enabled=dev_type == "cuda"):
                logits, _, _ = model(tokens)
            probs.append(F.softmax(logits[0, -1].float(), dim=-1)[needle_token].item())
    rc = 1.0 / tokenizer.vocab_size
    return {'mean': np.mean(probs), 'ratio': np.mean(probs) / rc}

def probe_layers(model, needle_id=50250, seq_len=512, pos=256, device="cuda"):
    """Print, layer by layer, how similar the needle position's hidden state is to the needle embedding.

    One random sequence (needle at ``pos``) is pushed manually through every block and
    FFN without cache. After each layer, the cosine similarity between ``hidden[0, pos]``
    and the needle's embedding row is printed. The model is left in eval mode.
    """
    model.eval()
    seq = torch.randint(1000, 10000, (1, seq_len), device=device)
    seq[0, pos] = needle_id
    with torch.no_grad():
        hidden = model.embed(seq)
        needle_vec = model.embed.weight[needle_id].float()
        print("Needle rep:")
        for idx, layer_pair in enumerate(zip(model.blocks, model.ffns)):
            block, ffn = layer_pair
            hidden = ffn(block(hidden, None, False)[0])
            cos = F.cosine_similarity(hidden[0, pos].float(), needle_vec, dim=0).item()
            kind = 'SWA' if idx in model._swa_indices else 'GDN'
            print(f"  L{idx:2d}[{kind}]: {cos:+.3f}")

print("Monitoring ready")

Monitoring ready


In [None]:
# CELL 5: DATA
from datasets import load_dataset
from tqdm import tqdm

# Tokenizer: GPT-2 BPE. It has no pad token, so EOS doubles as padding.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
# Keep the model's embedding/LM-head size in sync with the tokenizer (the model is built later).
MODEL_CFG.vocab_size = tokenizer.vocab_size

# Stream documents, append EOS after each one, and stop once target_tokens are collected.
print(f"Streaming {TRAIN_CFG.dataset_name}...")
ds = load_dataset(TRAIN_CFG.dataset_name, name=TRAIN_CFG.dataset_subset, split="train", streaming=True)
buf = []
pbar = tqdm(total=TRAIN_CFG.target_tokens, unit="tok")
for doc in ds:
    doc_toks = tokenizer.encode(doc['text'])
    doc_toks.append(tokenizer.eos_token_id)
    buf.extend(doc_toks)
    pbar.update(len(doc_toks))
    if len(buf) >= TRAIN_CFG.target_tokens:
        break
pbar.close()

# Trim the overshoot from the final document, then free the Python list and stream handle.
all_tokens = torch.tensor(buf[:TRAIN_CFG.target_tokens], dtype=torch.long)
del buf, ds
print(f"Loaded {len(all_tokens):,} tokens")

def get_batch():
    """Sample a random batch of (inputs, next-token targets) from ``all_tokens`` onto DEVICE.

    Both tensors have shape (TRAIN_CFG.batch_size, TRAIN_CFG.seq_len); targets are the
    inputs shifted one position to the left.
    """
    span = TRAIN_CFG.seq_len
    starts = torch.randint(len(all_tokens) - span - 1, (TRAIN_CFG.batch_size,))
    inputs = torch.stack([all_tokens[s:s + span] for s in starts])
    targets = torch.stack([all_tokens[s + 1:s + span + 1] for s in starts])
    return inputs.to(DEVICE), targets.to(DEVICE)

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Streaming HuggingFaceFW/fineweb-edu...


README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/2410 [00:00<?, ?it/s]

  0%|          | 846/20000000 [00:03<19:50:20, 280.02tok/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1055 > 1024). Running this sequence through the model will result in indexing errors
 43%|████▎     | 8673791/20000000 [00:29<00:34, 326610.26tok/s]

In [None]:
# CELL 6: BUILD
print("Building...")
model = GroundThinkLM(MODEL_CFG)
model = model.to(DEVICE).to(torch.bfloat16)
p = model.count_parameters()
gdn_m, swa_m, ffn_m = (p[key] / 1e6 for key in ('gdn', 'swa', 'ffn'))
print(f"Params: {p['total']/1e6:.2f}M (GDN:{gdn_m:.1f}M, SWA:{swa_m:.1f}M, FFN:{ffn_m:.1f}M)")
print(f"Layers: {model.get_layer_types()}")

# Smoke test: one forward + backward pass to check shapes, dtypes and peak CUDA memory.
x, y = get_batch()
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
    _, loss, _ = model(x, y)
loss.backward()
print(f"Test: loss={loss.item():.4f}, mem={torch.cuda.max_memory_allocated()/1e9:.2f}GB")
model.zero_grad()

In [None]:
# CELL 7: TRAIN
# Autocast runs in bfloat16, matching the bf16 model built in the BUILD cell and the
# bf16 autocast used by needle_test. float16 autocast would mix fp16 activations with
# bf16 weights and, without a GradScaler, risk gradient underflow.
opt = torch.optim.AdamW(model.parameters(), lr=TRAIN_CFG.lr, betas=TRAIN_CFG.betas, weight_decay=TRAIN_CFG.weight_decay)
losses, niah_traj = [], []
start = time.time()

print(f"\nTRAINING {TRAIN_CFG.steps} steps\n")
model.train()

for step in range(TRAIN_CFG.steps):
    # Linear warmup to TRAIN_CFG.lr, then cosine decay toward 0 over the remaining steps.
    lr = TRAIN_CFG.lr * (step+1)/TRAIN_CFG.warmup_steps if step < TRAIN_CFG.warmup_steps else \
         TRAIN_CFG.lr * 0.5 * (1 + math.cos(math.pi * (step-TRAIN_CFG.warmup_steps)/(TRAIN_CFG.steps-TRAIN_CFG.warmup_steps)))
    for pg in opt.param_groups: pg['lr'] = lr

    # Gradient accumulation: average the loss over accum_steps micro-batches.
    acc_loss = 0
    for _ in range(TRAIN_CFG.accum_steps):
        x, y = get_batch()
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            _, loss, _ = model(x, y)
        (loss / TRAIN_CFG.accum_steps).backward()
        acc_loss += loss.item()

    torch.nn.utils.clip_grad_norm_(model.parameters(), TRAIN_CFG.grad_clip)
    opt.step()
    opt.zero_grad()
    losses.append(acc_loss / TRAIN_CFG.accum_steps)

    if step % TRAIN_CFG.log_interval == 0:
        avg = np.mean(losses[-50:]) if len(losses) >= 50 else np.mean(losses)
        tps = (step+1) * TRAIN_CFG.effective_batch_size * TRAIN_CFG.seq_len / (time.time()-start)
        print(f"[{step:5d}] loss={avg:.4f} lr={lr:.2e} {tps:,.0f}tok/s")

    if (step+1) % TRAIN_CFG.grad_log_interval == 0:
        print_gradient_summary(model)

    if (step+1) in TRAIN_CFG.niah_checkpoints:
        n = needle_test(model, tokenizer, TRAIN_CFG.seq_len, 30, device=DEVICE)
        niah_traj.append((step+1, n['ratio']))
        print(f"  NIAH@{step+1}: {n['ratio']:.2f}x")
        model.train()  # needle_test leaves the model in eval mode

print(f"\nDone in {(time.time()-start)/60:.1f}min")
print(f"Loss: {np.mean(losses[:50]):.4f} -> {np.mean(losses[-50:]):.4f}")

In [None]:
# CELL 8: EVAL
print("\nFINAL EVAL")
for ctx_len in (128, 256, 512, 1024):
    res = needle_test(model, tokenizer, ctx_len, 50, device=DEVICE)
    print(f"  NIAH@{ctx_len}: {res['ratio']:.2f}x")
probe_layers(model, device=DEVICE)

# LM passes if the mean loss over the first 50 steps exceeds the last 50 by more than 2.
loss_drop = np.mean(losses[:50]) - np.mean(losses[-50:])
lm_verdict = 'PASS' if loss_drop > 2 else 'MARGINAL'
# NIAH passes if any training-time checkpoint beat uniform chance.
niah_verdict = 'PASS' if any(ratio > 1 for _, ratio in niah_traj) else 'FAIL'
print(f"\nLM: {lm_verdict}")
print(f"NIAH: {niah_verdict}")

In [None]:
# CELL 8.5: STATEFUL GENERATION TEST
def generate_text(model, tokenizer, prompt="The Gated Delta rule ensures", max_new_tokens=25, temperature=0.8):
    """Generate text with a prefill pass followed by token-by-token cached decoding.

    The prompt is run once with ``use_cache=True`` to build the per-layer states. Each
    later step feeds only the newest token plus those states.

    Args:
        temperature: softmax temperature for sampling; ``<= 0`` selects greedy argmax.

    Returns:
        ``prompt`` concatenated with the decoded continuation. Prints the prompt, the
        per-layer state kinds and the continuation.
    """
    model.eval()
    device = next(model.parameters()).device
    # bf16 autocast matches the model's bf16 weights; autocast is enabled on CUDA only.
    amp_kwargs = dict(device_type=device.type, dtype=torch.bfloat16, enabled=device.type == "cuda")

    def _sample(logits):
        """Pick the next token (shape (B, 1)) from the last-position logits."""
        last = logits[:, -1, :].float()
        if temperature <= 0:
            return last.argmax(dim=-1, keepdim=True)
        return torch.multinomial(F.softmax(last / temperature, dim=-1), 1)

    def _state_kind(state):
        if isinstance(state, tuple):
            return 'KV-Tuple'
        return 'None' if state is None else 'Recurrent-Tensor'

    # 1. Prefill: Process the initial prompt to build memory
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    with torch.no_grad(), torch.amp.autocast(**amp_kwargs):
        logits, _, past_states = model(input_ids, use_cache=True)
        next_token = _sample(logits)

    generated = [next_token.item()]
    print(f"\nPrompt: {prompt}")
    print(f"State Types: {[_state_kind(s) for s in past_states]}")

    # 2. Decode: Generate token-by-token using ONLY the cache
    for _ in range(max_new_tokens - 1):
        with torch.no_grad(), torch.amp.autocast(**amp_kwargs):
            # Pass ONLY the single newest token + the hybrid state list
            logits, _, past_states = model(next_token, past_states=past_states, use_cache=True)
            next_token = _sample(logits)
            generated.append(next_token.item())

    result = tokenizer.decode(generated)
    print(f"Generated: {result}")
    return prompt + result

# Run the proof: prefill the default prompt once, then decode token-by-token from the
# returned per-layer states. The returned string is shown as the cell's output.
generate_text(model, tokenizer)

In [None]:
# CELL 9: SAVE
from datetime import datetime
import os

# Timestamped checkpoint on the mounted Google Drive: weights, config, loss curve and
# NIAH trajectory. Note: 'cfg' is the ModelConfig object, so loading needs that class.
rid = datetime.now().strftime("%Y%m%d_%H%M%S")
export_dir = "/content/drive/MyDrive/groundthink/colab-exports"
os.makedirs(export_dir, exist_ok=True)
path = os.path.join(export_dir, f"v6_{rid}.pt")
payload = {'state': model.state_dict(), 'cfg': MODEL_CFG, 'losses': losses, 'niah': niah_traj}
torch.save(payload, path)
print(f"Saved: {path}")