# Wave Field LLM V4.3.3 — Training & Diagnostics

**Physics-based O(n log n) attention vs Standard O(n²) Transformer**

This notebook trains both architectures on WikiText-103 with full monitoring:
- Live training curves with `tqdm` progress bars
- Physics diagnostics at every checkpoint (kernel health, gate activity, rank)
- Side-by-side comparison tables via `pandas`
- Generation samples at each checkpoint
- Speed & memory scaling benchmarks

**GPU:** T4 (15.6 GB) — supports up to ~120M params with gradient checkpointing

**Architecture:** Wave Field V4.3.3 (SPECTRE-Wave) with:
- Learned feature maps (Hedgehog, ICLR 2024)
- HiPPO kernel init (S4D)
- Content-adaptive spectral gate (SPECTRE)
- Per-group LR (kernel params at 50x)
- `torch.compile` on non-FFT submodules
- cuFFT-optimal padding

In [None]:
!nvidia-smi

In [None]:
# ============================================================
# CELL 1: SETUP — Clone repo + install deps
# ============================================================
import os

# Local checkout location: cloned on the first run, fast-forwarded on later runs.
REPO = '/content/wave-field-llm'
if not os.path.exists(REPO):
    !git clone https://github.com/Pankh-AI/wave-field-llm.git {REPO}
else:
    # --ff-only: refuse to create a merge commit if the local checkout diverged.
    !cd {REPO} && git pull --ff-only

# Everything below uses relative paths (checkpoints/, results/, src/), so the
# working directory must be the repository root.
os.chdir(REPO)
!pip install -q datasets tokenizers matplotlib tqdm pandas ipywidgets
os.makedirs('checkpoints', exist_ok=True)
os.makedirs('results', exist_ok=True)
print(f"Working dir: {os.getcwd()}")
print("Setup complete")

In [None]:
# ============================================================
# CELL 2: IMPORTS + CONFIG
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import time, math, gc, sys, json
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
from tqdm.notebook import tqdm, trange
from IPython.display import display, HTML, clear_output
from collections import defaultdict

sys.path.insert(0, '.')
from src.wave_field_transformer import WaveFieldTransformer

# Pick the accelerator once; later cells read `device`, `gpu_name` and `vram_gb`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'
# Total device memory in decimal GB (bytes / 1e9); 0 when no GPU is present.
vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0

# Banner summarising the runtime environment.
display(HTML(f"""
<div style='background:#1a1a2e;color:#e0e0e0;padding:15px;border-radius:8px;font-family:monospace'>
  <h3 style='color:#00d4ff;margin:0'>Wave Field LLM V4.3.3</h3>
  <p>GPU: <b>{gpu_name}</b> | VRAM: <b>{vram_gb:.1f} GB</b> | PyTorch: <b>{torch.__version__}</b></p>
</div>
"""))

# ============== MODEL CONFIG ==============
# Two scale options — pick one:

SCALE = 'S1'  # 'S1' = ~22M (fast, proof-of-concept) | 'S3' = ~120M (full scale)

# Per-scale hyperparameters. Both presets keep batch_size * grad_accum == 32
# sequences per optimizer step; S3 trades micro-batch size for more accumulation.
CONFIGS = {
    'S1': dict(embed_dim=384, num_layers=8, num_heads=8, ffn_dim=1536,
              field_size=1536, seq_len=512, batch_size=8, grad_accum=4,
              peak_lr=6e-4, total_steps=2000, eval_every=200, ckpt_every=500),
    'S3': dict(embed_dim=768, num_layers=12, num_heads=12, ffn_dim=3072,
              field_size=2048, seq_len=512, batch_size=4, grad_accum=8,
              peak_lr=3e-4, total_steps=6000, eval_every=500, ckpt_every=1000),
}

CFG = CONFIGS[SCALE]
# Requested BPE vocabulary size; the tokenizer's actual size is read back after training.
BPE_VOCAB = 8192
USE_FP16 = True   # T4: fp16 + GradScaler (NOT bf16 — T4 bf16 is emulated)
# Relative to the current working directory.
CKPT_DIR = Path('checkpoints')

print(f"Scale: {SCALE} | embed={CFG['embed_dim']} layers={CFG['num_layers']} "
      f"heads={CFG['num_heads']} ffn={CFG['ffn_dim']} field={CFG['field_size']}")
print(f"Seq: {CFG['seq_len']} | Batch: {CFG['batch_size']} x {CFG['grad_accum']} accum "
      f"= {CFG['batch_size'] * CFG['grad_accum']} effective")
print(f"Steps: {CFG['total_steps']} | Eval every: {CFG['eval_every']} | Ckpt every: {CFG['ckpt_every']}")

In [None]:
# ============================================================
# CELL 3: STANDARD TRANSFORMER BASELINE
# ============================================================

class StandardTransformer(nn.Module):
    """Standard O(n^2) Transformer used as the comparison baseline.

    Built from PyTorch's ``nn.TransformerEncoder`` (pre-norm, GELU) with a
    causal mask, learned positional embeddings, and an output projection whose
    weight is tied to the token embedding.

    Args:
        vocab_size: Number of token ids.
        embedding_dim: Model width.
        num_layers: Number of encoder layers.
        num_heads: Attention heads per layer.
        ffn_dim: Hidden width of each feed-forward block.
        max_seq_len: Size of the positional-embedding table, i.e. the longest
            sequence ``forward`` accepts.
        dropout: Dropout probability for embeddings and encoder layers.
    """

    def __init__(self, vocab_size, embedding_dim=256, num_layers=6,
                 num_heads=8, ffn_dim=1024, max_seq_len=514, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.max_seq_len = max_seq_len
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.positional_embedding = nn.Embedding(max_seq_len, embedding_dim)
        self.dropout = nn.Dropout(dropout)
        layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim, nhead=num_heads, dim_feedforward=ffn_dim,
            dropout=dropout, activation='gelu', batch_first=True, norm_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(embedding_dim)
        self.output_projection = nn.Linear(embedding_dim, vocab_size, bias=False)
        # Weight tying: the output projection shares the token-embedding matrix.
        self.output_projection.weight = self.token_embedding.weight
        self._init_weights()

    def _init_weights(self):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, 0, 0.02)

    def forward(self, input_ids, labels=None, mask=None):
        """Run the causal language model.

        Args:
            input_ids: Token ids of shape (batch, seq) or (seq,); a 1-D input
                is treated as a batch of one.
            labels: Optional target ids with the same shape as ``input_ids``;
                positions equal to -100 are ignored in the loss.
            mask: Accepted for interface compatibility; not used. A causal
                mask is always applied.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape
            (batch, seq, vocab_size) and ``loss`` is ``None`` unless
            ``labels`` is given.

        Raises:
            ValueError: If the sequence is longer than ``max_seq_len``.
        """
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        B, N = input_ids.shape
        # Fail early with a readable message: indexing past the positional
        # table otherwise surfaces as an opaque index error (a device-side
        # assert on CUDA).
        if N > self.max_seq_len:
            raise ValueError(
                f"sequence length {N} exceeds max_seq_len {self.max_seq_len}")
        pos = torch.arange(N, device=input_ids.device).unsqueeze(0).expand(B, -1)
        x = self.token_embedding(input_ids) + self.positional_embedding(pos)
        x = self.dropout(x)
        # Additive mask: -inf strictly above the diagonal blocks attention to
        # future positions.
        causal = torch.triu(torch.full((N, N), float('-inf'),
                            device=input_ids.device), diagonal=1)
        x = self.transformer(x, mask=causal)
        x = self.norm(x)
        logits = self.output_projection(x)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits.view(-1, self.vocab_size),
                                  labels.view(-1), ignore_index=-100)
        return logits, loss

print("StandardTransformer defined")

In [None]:
# ============================================================
# CELL 4: DATA — WikiText-103 (103M tokens, proper scale)
# ============================================================
from datasets import load_dataset
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, decoders

print("Loading WikiText-103...")
ds = load_dataset('wikitext', 'wikitext-103-raw-v1')

# Filter empty lines and headers
# (keeps only lines whose stripped length exceeds 100 characters)
train_texts = [t for t in ds['train']['text'] if len(t.strip()) > 100]
val_texts = [t for t in ds['validation']['text'] if len(t.strip()) > 100]
test_texts = [t for t in ds['test']['text'] if len(t.strip()) > 100]

print(f"  Train docs: {len(train_texts):,} | Val docs: {len(val_texts):,} | Test: {len(test_texts):,}")
print(f"  Est. train chars: {sum(len(t) for t in train_texts):,}")

# Train BPE tokenizer
print(f"\nTraining BPE ({BPE_VOCAB} vocab)...")
raw_tok = Tokenizer(models.BPE())
raw_tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
raw_tok.decoder = decoders.ByteLevel()
tok_trainer = trainers.BpeTrainer(
    vocab_size=BPE_VOCAB,
    special_tokens=["<pad>", "<unk>", "<bos>", "<eos>"],
    min_frequency=2)
# Only the first 20,000 filtered training lines are used to fit the vocabulary.
raw_tok.train_from_iterator(train_texts[:20000], tok_trainer)
# Use the size the tokenizer reports rather than the requested BPE_VOCAB.
VOCAB_SIZE = raw_tok.get_vocab_size()
print(f"  Vocab: {VOCAB_SIZE}")

# Tokenize into flat streams (batch encode — 10-20x faster than one-by-one)
# Lines are concatenated back to back; no <bos>/<eos> ids are inserted here.
print("Tokenizing (batch mode)...")
BATCH = 4096
train_ids = []
for i in tqdm(range(0, len(train_texts), BATCH), desc='Train', leave=False):
    batch = train_texts[i:i+BATCH]
    for enc in raw_tok.encode_batch(batch):
        if enc.ids:
            train_ids.extend(enc.ids)

val_ids = []
for i in range(0, len(val_texts), BATCH):
    for enc in raw_tok.encode_batch(val_texts[i:i+BATCH]):
        if enc.ids:
            val_ids.extend(enc.ids)

print(f"  Train tokens: {len(train_ids):,} ({len(train_ids)/1e6:.1f}M)")
print(f"  Val tokens:   {len(val_ids):,} ({len(val_ids)/1e6:.1f}M)")

# ---- Data utilities ----
def make_chunks_gpu(token_ids, seq_len):
    """Split a flat token stream into next-token (input, target) chunks.

    The tensors are created on the module-level ``device`` so batches can be
    sliced without per-step host-to-device copies.

    Args:
        token_ids: Flat sequence of integer token ids.
        seq_len: Chunk length in tokens.

    Returns:
        ``(x, y)``, each of shape (n_chunks, seq_len), where ``y`` is ``x``
        shifted one token to the right in the stream. Returns ``(None, None)``
        when the stream is too short (including empty) to form one chunk.
    """
    # One extra token is needed for the shifted targets, hence the "- 1".
    # For an empty stream this floors to -1, so the guard must be "<= 0";
    # checking first also avoids allocating a tensor that would be discarded.
    n_chunks = (len(token_ids) - 1) // seq_len
    if n_chunks <= 0:
        return None, None
    all_ids = torch.tensor(token_ids, dtype=torch.long, device=device)
    usable = n_chunks * seq_len
    x = all_ids[:usable].reshape(n_chunks, seq_len)
    y = all_ids[1:usable + 1].reshape(n_chunks, seq_len)
    return x, y

def make_batches_gpu(x, y, batch_size, shuffle=True):
    """Generate full-size (inputs, targets) batches from pre-chunked tensors.

    Rows are visited in a random permutation when ``shuffle`` is true and in
    their stored order otherwise. A trailing partial batch is dropped, and
    nothing is yielded when there are fewer rows than ``batch_size``.
    """
    num_rows = x.shape[0]
    if num_rows < batch_size:
        return
    if shuffle:
        order = torch.randperm(num_rows, device=x.device)
    else:
        order = torch.arange(num_rows, device=x.device)
    # The last valid start is the one that still leaves a whole batch.
    last_start = num_rows - batch_size
    for start in range(0, last_start + 1, batch_size):
        rows = order[start:start + batch_size]
        yield x[rows], y[rows]

# Pre-chunk for training
# Both splits are chunked once up front and kept as module-level tensors that
# the training and evaluation cells slice from.
seq_len = CFG['seq_len']
train_x, train_y = make_chunks_gpu(train_ids, seq_len)
val_x, val_y = make_chunks_gpu(val_ids, seq_len)
print(f"\nGPU-resident chunks (seq={seq_len}):")
# NOTE(review): make_chunks_gpu can return (None, None) for a stream shorter
# than one chunk; the .shape accesses below would then raise AttributeError.
print(f"  Train: {train_x.shape[0]:,} | Val: {val_x.shape[0]:,}")
print("Data ready")

In [None]:
# ============================================================
# CELL 5: TRAINING ENGINE + PHYSICS MONITOR
# ============================================================

class WarmupCosineScheduler:
    """Linear warmup followed by cosine decay, applied per parameter group.

    Each group's learning rate at construction is taken as its peak. After
    ``k`` calls to :meth:`step` (``k >= 1``) the rate is
    ``peak * k / warmup`` while ``k <= warmup``; afterwards it follows a cosine
    from ``peak`` down to ``min_lr``, reached at ``k == total`` and held there
    for any further steps.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` entry).
        warmup: Number of warmup steps.
        total: Step count at which the cosine reaches ``min_lr``.
        min_lr: Floor of the cosine phase, shared by all groups.
    """

    def __init__(self, optimizer, warmup, total, min_lr=1e-5):
        self.optimizer, self.warmup, self.total = optimizer, warmup, total
        self.min_lr = min_lr
        self.base_lrs = [pg['lr'] for pg in optimizer.param_groups]
        self.step_count = 0
        # step() is normally called after optimizer.step(), so without this
        # the very first update would run at the full peak rate. Start each
        # group at the first warmup value instead.
        if self.warmup > 0:
            first = min(1.0, 1.0 / self.warmup)
            for pg, blr in zip(self.optimizer.param_groups, self.base_lrs):
                pg['lr'] = blr * first

    def _lr_at(self, step_count, base_lr):
        """Return the scheduled rate for one group after ``step_count`` steps."""
        if step_count <= self.warmup:
            return base_lr * (step_count / self.warmup)
        # Clamp progress to 1 so the cosine does not climb back up when
        # training runs past ``total``.
        p = min(1.0, (step_count - self.warmup) / max(1, self.total - self.warmup))
        return self.min_lr + 0.5 * (base_lr - self.min_lr) * (1 + math.cos(math.pi * p))

    def step(self):
        """Advance one step and write the new rate into every param group."""
        self.step_count += 1
        for pg, blr in zip(self.optimizer.param_groups, self.base_lrs):
            pg['lr'] = self._lr_at(self.step_count, blr)

    def get_lr(self):
        """Return the current learning rate of each param group."""
        return [pg['lr'] for pg in self.optimizer.param_groups]


class PhysicsMonitor:
    """Lightweight diagnostics for Wave Field models.

    Each call to :meth:`snapshot` records, per layer:
    - Kernel params (frequency, softplus damping, phase)
    - Mean gate bias
    - Off-diagonal field-coupling mass
    - Summed gradient norms of the kernel params
    - Optionally, output rank (SVD) and mean activation norm

    Snapshots accumulate in ``self.snapshots``; :meth:`summary_df` flattens
    them into a DataFrame.
    """

    def __init__(self):
        self.snapshots = []

    @torch.no_grad()
    def snapshot(self, model, step, sample_x=None):
        """Capture physics state. Call during eval (model.eval()).

        Args:
            model: Model exposing ``layers``, each with an ``attention``
                module carrying the wave parameters read below. When
                ``sample_x`` is given, ``token_embedding``,
                ``positional_encoding`` and ``dropout`` are used as well.
            step: Training step to tag the snapshot with.
            sample_x: Optional batch of token ids; only the first two rows are
                pushed through the layers to measure rank and norms.

        Returns:
            The snapshot dict, which is also appended to ``self.snapshots``.
        """
        snap = {'step': step, 'layers': []}

        for li, layer in enumerate(model.layers):
            attn = layer.attention
            layer_data = {'layer': li}

            # 1. Kernel parameters (damping is reported after softplus)
            freq = attn.wave_frequency.detach().cpu().numpy()
            damp = F.softplus(attn.wave_damping).detach().cpu().numpy()
            phase = attn.wave_phase.detach().cpu().numpy()
            layer_data['freq'] = freq.tolist()
            layer_data['damp'] = damp.tolist()
            layer_data['phase'] = phase.tolist()

            # 2. Gate bias (should stay ~2.0 for healthy training)
            # assumes the fused projection bias is laid out as [q, k, v, gate]
            # blocks of size D, so the gate part starts at 3*D — TODO confirm
            D = attn.embedding_dim
            gate_bias = attn.qkvg_proj.bias[3*D:].detach().cpu()
            layer_data['gate_bias_mean'] = gate_bias.mean().item()

            # 3. Field coupling strength: softmax mass off the diagonal
            coupling = F.softmax(attn.field_coupling.detach().cpu(), dim=-1)
            off_diag = coupling.clone()
            off_diag.fill_diagonal_(0)
            layer_data['cross_coupling'] = off_diag.sum().item()

            # 4. Gradient norms (if available). This is the sum of the three
            # per-parameter norms; it stays 0.0 when gradients are None,
            # e.g. after zero_grad(set_to_none=True).
            kernel_grad_norm = 0.0
            for name in ['wave_frequency', 'wave_damping', 'wave_phase']:
                p = getattr(attn, name)
                if p.grad is not None:
                    kernel_grad_norm += p.grad.norm().item()
            layer_data['kernel_grad_norm'] = kernel_grad_norm

            snap['layers'].append(layer_data)

        # 5. Output rank from a sample forward pass
        if sample_x is not None:
            x = model.token_embedding(sample_x[:2])
            pos = model.positional_encoding(sample_x.shape[1], sample_x.device)
            x = x + pos.unsqueeze(0)
            x = model.dropout(x)
            for li, layer_mod in enumerate(model.layers):
                x = layer_mod(x)
                # Rank: number of singular values > 1% of max (first sample only)
                s = torch.linalg.svdvals(x[0].float())
                rank = (s > s[0] * 0.01).sum().item()
                snap['layers'][li]['output_rank'] = rank
                snap['layers'][li]['norm_mean'] = x.norm(dim=-1).mean().item()

        self.snapshots.append(snap)
        return snap

    def summary_df(self):
        """Return a pandas DataFrame with one row per (snapshot, layer, head).

        Per-layer scalars are repeated on every head row. Missing values
        default to 0 for gate bias, coupling and kernel gradient, and to -1
        for ``output_rank`` and ``norm_mean``.
        """
        rows = []
        for snap in self.snapshots:
            for ld in snap['layers']:
                for h, freq in enumerate(ld['freq']):
                    rows.append({
                        'step': snap['step'],
                        'layer': ld['layer'],
                        'head': h,
                        'freq': freq,
                        'damp': ld['damp'][h],
                        'phase': ld['phase'][h],
                        'gate_bias': ld.get('gate_bias_mean', 0),
                        'cross_coupling': ld.get('cross_coupling', 0),
                        'kernel_grad': ld.get('kernel_grad_norm', 0),
                        'output_rank': ld.get('output_rank', -1),
                        'norm_mean': ld.get('norm_mean', -1),
                    })
        return pd.DataFrame(rows)


def evaluate(model, val_x, val_y, batch_size, vocab_size):
    """Score ``model`` on at most the first 500 validation chunks.

    Puts the model in eval mode (and leaves it there), runs unshuffled batches
    under autocast, and returns ``(mean batch loss, perplexity, accuracy %)``.
    Perplexity is ``exp`` of the mean loss with the loss capped at 20.
    """
    model.eval()
    limit = min(val_x.shape[0], 500)
    loss_sum = 0
    hits = 0
    seen = 0
    num_batches = 0
    with torch.no_grad():
        batches = make_batches_gpu(val_x[:limit], val_y[:limit],
                                   batch_size, shuffle=False)
        for inputs, targets in batches:
            with torch.amp.autocast('cuda', enabled=USE_FP16):
                logits, _ = model(inputs)
                batch_loss = F.cross_entropy(
                    logits.reshape(-1, vocab_size), targets.reshape(-1))
            loss_sum += batch_loss.item()
            # Greedy next-token prediction compared position by position.
            hits += (logits.argmax(dim=-1) == targets).sum().item()
            seen += targets.numel()
            num_batches += 1
    # max(..., 1) keeps the divisions defined when no batch was produced.
    mean_loss = loss_sum / max(num_batches, 1)
    perplexity = math.exp(min(mean_loss, 20))
    accuracy = hits / max(seen, 1) * 100
    return mean_loss, perplexity, accuracy


@torch.no_grad()
def generate_text(model, tokenizer, prompt="The", max_tokens=80,
                  temperature=0.8, top_k=40):
    """Autoregressive text generation.

    Args:
        model: Callable returning ``(logits, loss)`` for a (1, seq) id tensor.
        tokenizer: Object with ``encode(text).ids`` and ``decode(ids)``.
        prompt: Text to continue.
        max_tokens: Number of tokens to append.
        temperature: Softmax temperature. Values ``<= 0`` select greedy
            (argmax) decoding instead of dividing by zero.
        top_k: Keep only the ``top_k`` highest logits before sampling;
            ``0`` or less disables the filter.

    Returns:
        The decoded prompt plus continuation, or ``"(empty prompt)"`` when the
        prompt encodes to no tokens.
    """
    model.eval()
    ids = tokenizer.encode(prompt).ids
    if not ids:
        return "(empty prompt)"
    input_ids = torch.tensor([ids], device=device)
    # Context limit: prefer the real positional-table size when the model has one.
    max_ctx = getattr(model, 'max_seq_len', 2048)
    if hasattr(model, 'positional_embedding'):
        max_ctx = model.positional_embedding.weight.shape[0]

    greedy = temperature <= 0
    for _ in range(max_tokens):
        ctx = input_ids[:, -max_ctx:] if input_ids.shape[1] > max_ctx else input_ids
        with torch.amp.autocast('cuda', enabled=USE_FP16):
            logits, _ = model(ctx)
        # Sample in fp32: autocast may hand back half-precision logits, which
        # are coarse for a softmax over the whole vocabulary.
        logits = logits[0, -1].float()
        if greedy:
            next_id = logits.argmax(dim=-1, keepdim=True)
        else:
            logits = logits / temperature
            if top_k > 0:
                top_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Mask everything below the k-th largest logit.
                logits[logits < top_vals[-1]] = float('-inf')
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, 1)
        input_ids = torch.cat([input_ids, next_id.unsqueeze(0)], dim=1)
    return tokenizer.decode(input_ids[0].tolist())


print("Training engine ready")

In [None]:
# ============================================================
# CELL 6: VISUALIZATION HELPERS
# ============================================================

def plot_training_curves(wave_history, std_history, title_suffix=''):
    """Draw perplexity, accuracy and throughput for both models side by side.

    Each history is a list of dicts with at least ``step``, ``ppl`` and
    ``acc``; throughput points are drawn only for entries whose ``tps`` is
    positive. The figure is saved to ``results/training_curves.png`` and shown.
    """
    fig, axes = plt.subplots(1, 3, figsize=(16, 4))

    def draw(ax, history, value_of, fmt, label, keep=None):
        # Plot one model's series on `ax`; skip silently when there is nothing to draw.
        if not history:
            return
        points = history if keep is None else [h for h in history if keep(h)]
        if not points:
            return
        ax.plot([h['step'] for h in points],
                [value_of(h) for h in points],
                fmt, label=label, linewidth=2, markersize=4)

    def has_throughput(entry):
        return entry.get('tps', 0) > 0

    # (value getter, row filter, wave legend label, y label, title, log-scale y?)
    panels = (
        (lambda h: h['ppl'], None, 'Wave V4.3.3',
         'Val PPL', f'Perplexity {title_suffix}', True),
        (lambda h: h['acc'], None, 'Wave',
         'Accuracy (%)', 'Next-token Accuracy', False),
        (lambda h: h['tps']/1e3, has_throughput, 'Wave',
         'K tok/s', 'Throughput', False),
    )

    for ax, (value_of, keep, wave_label, ylabel, panel_title, log_y) in zip(axes, panels):
        draw(ax, wave_history, value_of, 'r-o', wave_label, keep)
        draw(ax, std_history, value_of, 'b-s', 'Standard', keep)
        ax.set_xlabel('Step')
        ax.set_ylabel(ylabel)
        ax.set_title(panel_title)
        ax.legend()
        ax.grid(True, alpha=0.3)
        if log_y:
            ax.set_yscale('log')

    plt.tight_layout()
    plt.savefig('results/training_curves.png', dpi=150, bbox_inches='tight')
    plt.show()


def plot_physics_diagnostics(monitor, title='Wave Field Physics'):
    """Render a 2x3 grid of diagnostics from a monitor's snapshots.

    Panels: per-head frequency and damping for the first and last layer,
    then per-layer output rank, gate bias, cross-coupling and activation
    norm over training steps. Prints a notice and returns when the monitor
    has no snapshots. The figure is saved to
    ``results/physics_diagnostics.png`` and shown.
    """
    if not monitor.snapshots:
        print("No snapshots yet.")
        return

    df = monitor.summary_df()
    n_layers = df['layer'].nunique()

    fig, axes = plt.subplots(2, 3, figsize=(18, 8))
    fig.suptitle(title, fontsize=14, fontweight='bold')

    def finish(ax, ylabel, panel_title, **legend_kwargs):
        # Shared axis decoration for every panel.
        ax.set_xlabel('Step')
        ax.set_ylabel(ylabel)
        ax.set_title(panel_title)
        ax.legend(fontsize=7, **legend_kwargs)
        ax.grid(True, alpha=0.3)

    def per_head_panel(ax, column, ylabel, panel_title):
        # One line per head for the first and last layer; only heads 0-2 get
        # a legend entry to keep the legend short.
        for layer_idx in (0, n_layers - 1):
            layer_rows = df[df['layer'] == layer_idx]
            for head in layer_rows['head'].unique():
                head_rows = layer_rows[layer_rows['head'] == head]
                tag = f'L{layer_idx}H{head}' if head < 3 else None
                ax.plot(head_rows['step'], head_rows[column], '-', alpha=0.6,
                        label=tag)
        finish(ax, ylabel, panel_title, ncol=2)

    def per_layer_lines(ax, rows, column, fmt, label_prefix, **plot_kwargs):
        # Per-layer values are repeated on every head row, so take the first
        # per (step, layer) and draw one line per layer.
        collapsed = rows.groupby(['step', 'layer'])[column].first().reset_index()
        for li in collapsed['layer'].unique():
            one_layer = collapsed[collapsed['layer'] == li]
            ax.plot(one_layer['step'], one_layer[column], fmt,
                    label=f'{label_prefix}{li}', **plot_kwargs)

    # 1. Frequency evolution
    per_head_panel(axes[0, 0], 'freq', 'Frequency', 'Kernel Frequencies')

    # 2. Damping evolution
    per_head_panel(axes[0, 1], 'damp', 'Damping (softplus)', 'Kernel Damping')

    # 3. Output rank per layer (negative values mark "not measured")
    rank_ax = axes[0, 2]
    per_layer_lines(rank_ax, df[df['output_rank'] >= 0], 'output_rank',
                    '-o', 'Layer ', markersize=3)
    finish(rank_ax, 'Effective Rank', 'Output Rank (collapse = 1)')

    # 4. Gate bias, with the reference line drawn after the data lines
    gate_ax = axes[1, 0]
    per_layer_lines(gate_ax, df, 'gate_bias', '-', 'L')
    gate_ax.axhline(y=2.0, color='gray', linestyle='--', alpha=0.5, label='init=2.0')
    finish(gate_ax, 'Mean Gate Bias', 'Gate Bias Health')

    # 5. Cross-coupling
    coupling_ax = axes[1, 1]
    per_layer_lines(coupling_ax, df, 'cross_coupling', '-', 'L')
    finish(coupling_ax, 'Total Cross-Coupling', 'Field Coupling Activity')

    # 6. Norm growth (non-positive values mark "not measured")
    norm_ax = axes[1, 2]
    per_layer_lines(norm_ax, df[df['norm_mean'] > 0], 'norm_mean', '-', 'L')
    finish(norm_ax, 'Mean Norm', 'Activation Norms')

    plt.tight_layout()
    plt.savefig('results/physics_diagnostics.png', dpi=150, bbox_inches='tight')
    plt.show()


def display_comparison_table(wave_history, std_history):
    """Show a side-by-side pandas table of metrics.

    Builds one row per step that appears in either history. A metric missing
    for a model at that step is rendered as ``'-'``.

    Args:
        wave_history: List of metric dicts (``step``, ``ppl``, ``acc``,
            ``val_loss``) for the Wave Field model.
        std_history: Same structure for the standard Transformer.
    """
    rows = []
    all_steps = sorted(set(
        [h['step'] for h in wave_history] +
        [h['step'] for h in std_history]
    ))
    wave_map = {h['step']: h for h in wave_history}
    std_map = {h['step']: h for h in std_history}
    for step in all_steps:
        w = wave_map.get(step, {})
        s = std_map.get(step, {})
        rows.append({
            'Step': step,
            'Wave PPL': f"{w['ppl']:.1f}" if 'ppl' in w else '-',
            'Std PPL': f"{s['ppl']:.1f}" if 'ppl' in s else '-',
            'Wave Acc%': f"{w['acc']:.1f}" if 'acc' in w else '-',
            'Std Acc%': f"{s['acc']:.1f}" if 'acc' in s else '-',
            'Wave Loss': f"{w['val_loss']:.3f}" if 'val_loss' in w else '-',
            'Std Loss': f"{s['val_loss']:.3f}" if 'val_loss' in s else '-',
        })
    # The caption previously rendered a mis-encoded dash; use a real em dash.
    display(pd.DataFrame(rows).style.set_caption(
        'Wave Field vs Standard — Training Progress'))


print("Visualization helpers ready")

In [None]:
# ============================================================
# CELL 7: CREATE MODELS + PARAM COUNT
# ============================================================
# Start from a clean allocator so the VRAM figure below reflects these models.
gc.collect()
torch.cuda.empty_cache()

# Wave Field V4.3.3 (SPECTRE)
wave_model = WaveFieldTransformer(
    vocab_size=VOCAB_SIZE,
    embedding_dim=CFG['embed_dim'],
    num_layers=CFG['num_layers'],
    num_heads=CFG['num_heads'],
    ffn_dim=CFG['ffn_dim'],
    field_size=CFG['field_size'],
    max_seq_len=CFG['seq_len'] + 2,
    dropout=0.1,
    use_checkpoint=True,
    interference_interval=3,
    n_components=1,
    local_window=0,
    device=device
).to(device)

# Standard Transformer baseline: same width, depth, heads, FFN size and dropout.
std_model = StandardTransformer(
    vocab_size=VOCAB_SIZE,
    embedding_dim=CFG['embed_dim'],
    num_layers=CFG['num_layers'],
    num_heads=CFG['num_heads'],
    ffn_dim=CFG['ffn_dim'],
    max_seq_len=CFG['seq_len'] + 2,
    dropout=0.1
).to(device)

# torch.compile for Wave (non-FFT submodules only).
# Best effort: any failure is reported in the summary table, not raised.
try:
    wave_model.compile_model(mode='reduce-overhead')
    compile_status = 'enabled'
except Exception as e:
    compile_status = f'skipped ({e})'

wave_params = sum(p.numel() for p in wave_model.parameters())
std_params = sum(p.numel() for p in std_model.parameters())

# Summary banner; the heading previously rendered a mis-encoded dash.
display(HTML(f"""
<div style='background:#1a1a2e;color:#e0e0e0;padding:15px;border-radius:8px;font-family:monospace'>
  <h3 style='color:#00d4ff;margin:0 0 10px 0'>Model Summary — Scale {SCALE}</h3>
  <table style='color:#e0e0e0;border-collapse:collapse;width:100%'>
    <tr style='border-bottom:1px solid #444'>
      <th style='text-align:left;padding:5px'>Model</th>
      <th style='text-align:right;padding:5px'>Params</th>
      <th style='text-align:right;padding:5px'>Complexity</th>
      <th style='text-align:right;padding:5px'>torch.compile</th>
    </tr>
    <tr>
      <td style='padding:5px;color:#ff6b6b'>Wave Field V4.3.3</td>
      <td style='padding:5px;text-align:right'><b>{wave_params:,}</b> ({wave_params/1e6:.1f}M)</td>
      <td style='padding:5px;text-align:right'>O(n log n)</td>
      <td style='padding:5px;text-align:right'>{compile_status}</td>
    </tr>
    <tr>
      <td style='padding:5px;color:#4ecdc4'>Standard Transformer</td>
      <td style='padding:5px;text-align:right'><b>{std_params:,}</b> ({std_params/1e6:.1f}M)</td>
      <td style='padding:5px;text-align:right'>O(n^2)</td>
      <td style='padding:5px;text-align:right'>N/A</td>
    </tr>
  </table>
  <p style='margin:10px 0 0 0;color:#888'>Overhead: Wave has +{(wave_params-std_params)/1e6:.1f}M from feature maps, spectral gate, interference</p>
</div>
"""))

# VRAM sanity check: one forward pass at the training batch shape.
# NOTE(review): this runs with autograd enabled and without model.eval(),
# and measures the forward pass only (no backward).
torch.cuda.reset_peak_memory_stats()
dummy = torch.randint(0, VOCAB_SIZE, (CFG['batch_size'], CFG['seq_len']), device=device)
with torch.amp.autocast('cuda', enabled=USE_FP16):
    _ = wave_model(dummy)
peak = torch.cuda.max_memory_allocated() / 1e9
print(f"Wave forward pass peak VRAM: {peak:.2f} GB / {vram_gb:.1f} GB")
del dummy, _
gc.collect()
torch.cuda.empty_cache()

In [None]:
# ============================================================
# CELL 8: TRAIN WAVE FIELD V4.3.3
# ============================================================
# Free cached memory and reset the peak counter so the reported peak VRAM
# covers this training run only.
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# V4.3.2 per-group LR: kernel params at 50x, QKV at 3x
wave_optimizer = wave_model.configure_optimizer(
    base_lr=CFG['peak_lr'], kernel_lr_mult=50.0, qk_lr_mult=3.0)
# Warmup is 10% of the run, but never fewer than 100 steps.
wave_scheduler = WarmupCosineScheduler(
    wave_optimizer, warmup=max(CFG['total_steps'] // 10, 100),
    total=CFG['total_steps'])
wave_scaler = torch.amp.GradScaler('cuda', enabled=USE_FP16)
wave_monitor = PhysicsMonitor()

wave_history = []
batch_size = CFG['batch_size']
grad_accum = CFG['grad_accum']
total_steps = CFG['total_steps']
eval_every = CFG['eval_every']
ckpt_every = CFG['ckpt_every']

# Initial eval (step 0 baseline; 'tps' is 0 because nothing has been trained yet)
vl, vp, va = evaluate(wave_model, val_x, val_y, batch_size, VOCAB_SIZE)
wave_history.append({'step': 0, 'val_loss': vl, 'ppl': vp, 'acc': va, 'tps': 0})
wave_monitor.snapshot(wave_model, 0, sample_x=train_x[:2])
print(f"Wave init: PPL {vp:.1f} | Acc {va:.1f}%")

# Training loop
# `step` counts optimizer updates; `accum_count` counts micro-batches that
# contributed gradients.
step = 0
accum_count = 0
best_wave_ppl = float('inf')
t0 = time.time()
# evaluate() left the model in eval mode; switch back before training.
wave_model.train()
wave_optimizer.zero_grad(set_to_none=True)

pbar = tqdm(total=total_steps, desc='Wave Training', unit='step')

# The outer loop restarts the (reshuffled) batch generator until enough
# optimizer steps have been taken.
# NOTE(review): if make_batches_gpu yields nothing (fewer chunks than
# batch_size), or every loss is NaN, `step` never advances and this loop
# does not terminate.
while step < total_steps:
    for x, y in make_batches_gpu(train_x, train_y, batch_size):
        if step >= total_steps:
            break

        with torch.amp.autocast('cuda', enabled=USE_FP16):
            logits, _ = wave_model(x)
            loss = F.cross_entropy(logits.reshape(-1, VOCAB_SIZE), y.reshape(-1))
            # Divide so the accumulated gradient is the mean over micro-batches.
            loss_scaled = loss / grad_accum

        # Skip NaN micro-batches before backward; accum_count is not advanced,
        # so the update still waits for grad_accum finite micro-batches.
        if torch.isnan(loss):
            continue

        wave_scaler.scale(loss_scaled).backward()
        accum_count += 1

        if accum_count % grad_accum == 0:
            # Unscale first so clipping sees true gradient magnitudes.
            wave_scaler.unscale_(wave_optimizer)
            torch.nn.utils.clip_grad_norm_(wave_model.parameters(), 1.0)
            wave_scaler.step(wave_optimizer)
            wave_scaler.update()
            wave_scheduler.step()
            wave_optimizer.zero_grad(set_to_none=True)
            step += 1

            # Update progress bar
            # elapsed is wall-clock since t0 and includes time spent in
            # evaluation, generation and checkpointing, so tok/s is a
            # whole-run average rather than pure training throughput.
            elapsed = time.time() - t0
            tok_seen = step * batch_size * grad_accum * CFG['seq_len']
            tps = tok_seen / elapsed if elapsed > 0 else 0
            lr_now = wave_scheduler.get_lr()[0]
            # 'loss' is the unscaled loss of the last micro-batch only.
            pbar.set_postfix({
                'loss': f'{loss.item():.3f}',
                'lr': f'{lr_now:.1e}',
                'tok/s': f'{tps/1e3:.0f}K',
                'VRAM': f'{torch.cuda.memory_allocated()/1e9:.1f}G'
            })
            pbar.update(1)

            # Eval + diagnostics checkpoint
            if step % eval_every == 0 or step == total_steps:
                wave_model.eval()
                vl, vp, va = evaluate(wave_model, val_x, val_y, batch_size, VOCAB_SIZE)
                # NOTE(review): gradients were just cleared with
                # set_to_none=True, so the monitor's kernel gradient norm is
                # recorded as 0.0 here.
                wave_monitor.snapshot(wave_model, step, sample_x=train_x[:2])
                wave_history.append({
                    'step': step, 'val_loss': vl, 'ppl': vp, 'acc': va,
                    'tps': tps, 'tok_M': tok_seen / 1e6, 'time_s': elapsed
                })
                mark = ''
                # Keep a weights-only copy of the best model by validation PPL.
                if vp < best_wave_ppl:
                    best_wave_ppl = vp
                    torch.save(wave_model.state_dict(),
                               CKPT_DIR / f'wave_{SCALE}_best.pt')
                    mark = ' *BEST'
                tqdm.write(f"  [Wave] Step {step}/{total_steps} | "
                           f"PPL {vp:.1f} | Acc {va:.1f}% | "
                           f"{tok_seen/1e6:.1f}M tok | "
                           f"{tps/1e3:.0f}K tok/s{mark}")

                # Generation sample
                sample = generate_text(wave_model, raw_tok,
                                       prompt='The world', max_tokens=50)
                tqdm.write(f"  [Gen] {sample[:150]}")
                # evaluate()/generate_text() leave the model in eval mode.
                wave_model.train()

            # Save checkpoint
            # NOTE(review): the scheduler state (step_count) is not saved, so
            # resuming from this file alone cannot restore the LR schedule.
            if step % ckpt_every == 0:
                torch.save({
                    'model': wave_model.state_dict(),
                    'optimizer': wave_optimizer.state_dict(),
                    'scaler': wave_scaler.state_dict(),
                    'step': step, 'history': wave_history
                }, CKPT_DIR / f'wave_{SCALE}_step{step}.pt')

pbar.close()
wave_peak_mem = torch.cuda.max_memory_allocated() / 1e9
wave_total_time = time.time() - t0
print(f"\nWave training complete: {wave_total_time:.0f}s | Peak VRAM: {wave_peak_mem:.2f} GB")
print(f"Best PPL: {best_wave_ppl:.1f}")

In [None]:
# ============================================================
# CELL 9: WAVE PHYSICS DIAGNOSTICS
# ============================================================
# Plot the six-panel diagnostics; the title previously rendered a mis-encoded dash.
plot_physics_diagnostics(wave_monitor, title=f'Wave Field V4.3.3 Physics — {SCALE}')

# Kernel parameter table at final step
df = wave_monitor.summary_df()
# Keep only the rows from the most recent snapshot.
final = df[df['step'] == df['step'].max()]
# The reversed colormap on 'damp' inverts the colour ramp relative to the one
# used for 'output_rank'.
display(final[['layer', 'head', 'freq', 'damp', 'phase',
               'output_rank', 'norm_mean']].style.set_caption(
    'Final Kernel Parameters').format({
        'freq': '{:.3f}', 'damp': '{:.3f}', 'phase': '{:.3f}',
        'output_rank': '{:.0f}', 'norm_mean': '{:.1f}'
    }).background_gradient(subset=['damp'], cmap='RdYlGn_r')
     .background_gradient(subset=['output_rank'], cmap='RdYlGn'))

In [None]:
# ============================================================
# CELL 10: TRAIN STANDARD TRANSFORMER
# ============================================================
# Trains std_model with AdamW + warmup/cosine schedule, AMP grad scaling and
# gradient accumulation; evaluates, samples text and checkpoints periodically.

# Release cached memory and reset the peak counter so the peak-VRAM number
# printed at the end is not inflated by an earlier, higher peak.
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

std_optimizer = torch.optim.AdamW(
    std_model.parameters(), lr=CFG['peak_lr'], weight_decay=0.01)
std_warmup_steps = max(CFG['total_steps'] // 10, 100)
std_scheduler = WarmupCosineScheduler(
    std_optimizer, warmup=std_warmup_steps, total=CFG['total_steps'])
std_scaler = torch.amp.GradScaler('cuda', enabled=USE_FP16)

# Evaluation before any optimizer update, stored as the step-0 history entry.
std_val_loss, std_val_ppl, std_val_acc = evaluate(
    std_model, val_x, val_y, batch_size, VOCAB_SIZE)
std_history = [{
    'step': 0, 'val_loss': std_val_loss, 'ppl': std_val_ppl,
    'acc': std_val_acc, 'tps': 0,
}]
print(f"Std init: PPL {std_val_ppl:.1f} | Acc {std_val_acc:.1f}%")

# Loop state: `step` counts optimizer updates, `accum_count` counts
# micro-batches that went through backward().
step = 0
accum_count = 0
best_std_ppl = float('inf')
t0 = time.time()
std_model.train()
std_optimizer.zero_grad(set_to_none=True)

pbar = tqdm(total=total_steps, desc='Standard Training', unit='step')

# Outer loop restarts the batch generator until enough updates were made.
while step < total_steps:
    for x, y in make_batches_gpu(train_x, train_y, batch_size):
        if step >= total_steps:
            break

        with torch.amp.autocast('cuda', enabled=USE_FP16):
            logits, _ = std_model(x)
            loss = F.cross_entropy(logits.reshape(-1, VOCAB_SIZE), y.reshape(-1))
            # Divide so the accumulated gradient matches one averaged batch.
            micro_loss = loss / grad_accum

        # A NaN loss skips this micro-batch: no backward pass, and it does
        # not count toward the accumulation window.
        if torch.isnan(loss):
            continue

        std_scaler.scale(micro_loss).backward()
        accum_count += 1

        # Still accumulating: everything below runs once per optimizer update.
        if accum_count % grad_accum != 0:
            continue

        # Unscale first so clipping operates on the true gradient magnitudes.
        std_scaler.unscale_(std_optimizer)
        torch.nn.utils.clip_grad_norm_(std_model.parameters(), 1.0)
        std_scaler.step(std_optimizer)
        std_scaler.update()
        std_scheduler.step()
        std_optimizer.zero_grad(set_to_none=True)
        step += 1

        # Throughput bookkeeping (tokens processed per wall-clock second).
        elapsed = time.time() - t0
        tok_seen = step * batch_size * grad_accum * CFG['seq_len']
        tps = tok_seen / elapsed if elapsed > 0 else 0
        current_lr = std_scheduler.get_lr()[0]
        progress_stats = {
            'loss': f'{loss.item():.3f}',
            'lr': f'{current_lr:.1e}',
            'tok/s': f'{tps/1e3:.0f}K',
            'VRAM': f'{torch.cuda.memory_allocated()/1e9:.1f}G'
        }
        pbar.set_postfix(progress_stats)
        pbar.update(1)

        # Periodic evaluation; the final step is always evaluated.
        is_eval_step = step % eval_every == 0 or step == total_steps
        if is_eval_step:
            std_model.eval()
            std_val_loss, std_val_ppl, std_val_acc = evaluate(
                std_model, val_x, val_y, batch_size, VOCAB_SIZE)
            std_history.append({
                'step': step, 'val_loss': std_val_loss, 'ppl': std_val_ppl,
                'acc': std_val_acc, 'tps': tps,
                'tok_M': tok_seen / 1e6, 'time_s': elapsed
            })

            # Keep the weights with the lowest validation perplexity so far.
            improved = std_val_ppl < best_std_ppl
            if improved:
                best_std_ppl = std_val_ppl
                torch.save(std_model.state_dict(),
                           CKPT_DIR / f'std_{SCALE}_best.pt')
            best_tag = ' *BEST' if improved else ''

            status_line = (f"  [Std] Step {step}/{total_steps} | "
                           f"PPL {std_val_ppl:.1f} | Acc {std_val_acc:.1f}% | "
                           f"{tok_seen/1e6:.1f}M tok | "
                           f"{tps/1e3:.0f}K tok/s{best_tag}")
            tqdm.write(status_line)

            # Short generation sample as a qualitative sanity check.
            preview = generate_text(std_model, raw_tok,
                                    prompt='The world', max_tokens=50)
            tqdm.write(f"  [Gen] {preview[:150]}")
            std_model.train()

        # Periodic full checkpoint (model + optimizer + scaler + history).
        if step % ckpt_every == 0:
            checkpoint_state = {
                'model': std_model.state_dict(),
                'optimizer': std_optimizer.state_dict(),
                'scaler': std_scaler.state_dict(),
                'step': step, 'history': std_history
            }
            torch.save(checkpoint_state, CKPT_DIR / f'std_{SCALE}_step{step}.pt')

pbar.close()
std_peak_mem = torch.cuda.max_memory_allocated() / 1e9
std_total_time = time.time() - t0
print(f"\nStandard training complete: {std_total_time:.0f}s | Peak VRAM: {std_peak_mem:.2f} GB")
print(f"Best PPL: {best_std_ppl:.1f}")

In [None]:
# ============================================================
# CELL 11: TRAINING COMPARISON — CURVES + TABLE
# ============================================================
# Side-by-side view of both runs: training curves, the per-checkpoint
# comparison table, and an HTML summary card with the headline numbers.
plot_training_curves(wave_history, std_history, title_suffix=f'({SCALE})')
display_comparison_table(wave_history, std_history)

# Summary card
# The heading uses the HTML entity &mdash; for the dash: the previous literal
# character had degraded to mojibake and would have been rendered as garbage.
display(HTML(f"""
<div style='background:#1a1a2e;color:#e0e0e0;padding:15px;border-radius:8px;font-family:monospace;margin-top:10px'>
  <h3 style='color:#00d4ff;margin:0 0 10px 0'>Training Summary &mdash; {SCALE}</h3>
  <table style='color:#e0e0e0;border-collapse:collapse;width:100%'>
    <tr style='border-bottom:1px solid #444'>
      <th style='text-align:left;padding:5px'>Metric</th>
      <th style='text-align:right;padding:5px;color:#ff6b6b'>Wave V4.3.3</th>
      <th style='text-align:right;padding:5px;color:#4ecdc4'>Standard</th>
    </tr>
    <tr><td style='padding:5px'>Best PPL</td>
        <td style='padding:5px;text-align:right'><b>{best_wave_ppl:.1f}</b></td>
        <td style='padding:5px;text-align:right'><b>{best_std_ppl:.1f}</b></td></tr>
    <tr><td style='padding:5px'>Parameters</td>
        <td style='padding:5px;text-align:right'>{wave_params/1e6:.1f}M</td>
        <td style='padding:5px;text-align:right'>{std_params/1e6:.1f}M</td></tr>
    <tr><td style='padding:5px'>Peak VRAM</td>
        <td style='padding:5px;text-align:right'>{wave_peak_mem:.2f} GB</td>
        <td style='padding:5px;text-align:right'>{std_peak_mem:.2f} GB</td></tr>
    <tr><td style='padding:5px'>Training Time</td>
        <td style='padding:5px;text-align:right'>{wave_total_time:.0f}s</td>
        <td style='padding:5px;text-align:right'>{std_total_time:.0f}s</td></tr>
  </table>
</div>
"""))

In [None]:
# ============================================================
# CELL 12: GENERATION COMPARISON
# ============================================================
# Samples a continuation from each model for a fixed set of prompts and shows
# the results side by side.

# Both training loops finish by calling .train() on their model, so without
# this the samples below would be drawn with dropout (and any other
# train-only behavior) active. eval() is idempotent, so this is harmless if
# generate_text already switches modes itself (not visible from this cell).
wave_model.eval()
std_model.eval()

prompts = [
    "The president of the United States",
    "In the beginning, there was",
    "Scientists discovered that",
    "The most important thing about",
]


def _sample_both(prompt):
    """Return one table row with both models' continuations of `prompt`."""
    wave_text = generate_text(wave_model, raw_tok, prompt=prompt,
                              max_tokens=60, temperature=0.8)
    std_text = generate_text(std_model, raw_tok, prompt=prompt,
                             max_tokens=60, temperature=0.8)
    # Truncate so one long sample cannot blow up the table layout.
    return {
        'Prompt': prompt,
        'Wave V4.3.3': wave_text[:200],
        'Standard': std_text[:200],
    }


gen_rows = [_sample_both(prompt) for prompt in prompts]

gen_df = pd.DataFrame(gen_rows)
display(gen_df.style.set_caption('Generation Comparison (temp=0.8, top_k=40)')
        .set_properties(**{'white-space': 'pre-wrap', 'max-width': '400px'}))

In [None]:
# ============================================================
# CELL 13: SPEED & MEMORY SCALING BENCHMARK
# ============================================================
# Measures forward-pass latency and peak VRAM of freshly built Standard and
# Wave models over a range of sequence lengths, then tabulates and plots them.

# Drop training-only state to free VRAM for benchmarking. Popping from
# globals() instead of `del` keeps the cell re-runnable: a second run no
# longer raises NameError for names that are already gone.
for _name in ('wave_optimizer', 'wave_scaler', 'std_optimizer', 'std_scaler'):
    globals().pop(_name, None)
gc.collect()
torch.cuda.empty_cache()

SPEED_SEQ_LENS = [256, 512, 1024, 2048, 4096, 8192]
SPEED_BATCH = 2
SPEED_WARMUP_ITERS = 3    # untimed passes before measuring
SPEED_TIMED_ITERS = 10    # passes averaged into the reported latency


def _benchmark_forward(build_model, seq_len, label):
    """Time inference forward passes of a freshly built model.

    Args:
        build_model: zero-argument callable returning the model to benchmark.
        seq_len: length of the random token sequences fed to the model.
        label: short name used in the message printed on failure.

    Returns:
        (milliseconds per forward pass, peak VRAM in GB), or (inf, inf) when
        a RuntimeError (e.g. CUDA out of memory) occurs at any stage.

    The model and input live only in this function's scope, so they are
    released on return even when the run fails. Previously a failed run
    skipped its `del`, leaving the model resident and inflating the memory
    measured for the next configuration.
    """
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    try:
        model = build_model().to(device).eval()
        tokens = torch.randint(0, VOCAB_SIZE, (SPEED_BATCH, seq_len), device=device)
        with torch.no_grad(), torch.amp.autocast('cuda', enabled=USE_FP16):
            for _ in range(SPEED_WARMUP_ITERS):
                model(tokens)
        # CUDA kernels run asynchronously: synchronize around the timed region.
        torch.cuda.synchronize()
        start = time.perf_counter()
        with torch.no_grad(), torch.amp.autocast('cuda', enabled=USE_FP16):
            for _ in range(SPEED_TIMED_ITERS):
                model(tokens)
        torch.cuda.synchronize()
        ms_per_pass = (time.perf_counter() - start) / SPEED_TIMED_ITERS * 1000
        peak_gb = torch.cuda.max_memory_allocated() / 1e9
        return ms_per_pass, peak_gb
    except RuntimeError as err:
        # Still reported as "OOM" in the table, but surface the real reason so
        # a non-memory failure (e.g. a shape error) is not silently mislabeled.
        tqdm.write(f"  [{label}] seq_len={seq_len} failed: "
                   f"{type(err).__name__}: {str(err)[:200]}")
        return float('inf'), float('inf')


speed_results = []

for sl in tqdm(SPEED_SEQ_LENS, desc='Speed benchmark'):
    row = {'seq_len': sl}

    # Standard
    row['std_ms'], row['std_mem'] = _benchmark_forward(
        lambda: StandardTransformer(
            vocab_size=VOCAB_SIZE, embedding_dim=CFG['embed_dim'],
            num_layers=CFG['num_layers'], num_heads=CFG['num_heads'],
            ffn_dim=CFG['ffn_dim'], max_seq_len=sl+2, dropout=0.0
        ),
        sl, 'Std')

    # Wave (field size is at least twice the sequence length)
    fs = max(sl * 2, CFG['field_size'])
    row['wave_ms'], row['wave_mem'] = _benchmark_forward(
        lambda: WaveFieldTransformer(
            vocab_size=VOCAB_SIZE, embedding_dim=CFG['embed_dim'],
            num_layers=CFG['num_layers'], num_heads=CFG['num_heads'],
            ffn_dim=CFG['ffn_dim'], field_size=fs,
            max_seq_len=sl+2, dropout=0.0, use_checkpoint=False,
            interference_interval=3, n_components=1, local_window=0,
            device=device
        ),
        sl, 'Wave')

    gc.collect()
    torch.cuda.empty_cache()

    # Speedup is only defined when both models produced a timing.
    if math.isfinite(row['std_ms']) and math.isfinite(row['wave_ms']):
        row['speedup'] = row['std_ms'] / row['wave_ms']
    else:
        row['speedup'] = None
    speed_results.append(row)

speed_df = pd.DataFrame(speed_results)


# Display table
def fmt_ms(v):
    """Format a latency in ms; the inf sentinel is shown as 'OOM'."""
    return 'OOM' if v == float('inf') else f'{v:.1f}'


def fmt_gb(v):
    """Format a memory figure in GB; the inf sentinel is shown as 'OOM'."""
    return 'OOM' if v == float('inf') else f'{v:.2f}'


def fmt_speedup(v):
    """Format a speedup factor; missing values are shown as '-'.

    pandas stores None as NaN once the column also holds floats, and NaN is
    truthy, so a plain truthiness test would render 'nanx' here.
    """
    return f'{v:.2f}x' if pd.notna(v) else '-'


disp = speed_df.copy()
for _col, _fmt in (('std_ms', fmt_ms), ('wave_ms', fmt_ms),
                   ('std_mem', fmt_gb), ('wave_mem', fmt_gb),
                   ('speedup', fmt_speedup)):
    disp[_col] = disp[_col].apply(_fmt)
# Rename by key rather than by position so labels cannot drift out of sync
# with the column order.
disp = disp.rename(columns={
    'seq_len': 'Seq Len', 'std_ms': 'Std (ms)', 'std_mem': 'Std VRAM',
    'wave_ms': 'Wave (ms)', 'wave_mem': 'Wave VRAM', 'speedup': 'Speedup',
})
precision_label = 'fp16' if USE_FP16 else 'fp32'
display(disp.style.set_caption(
    f'Speed & Memory Benchmark (B={SPEED_BATCH}, {precision_label})'))


# Plot
def _plot_finite(ax, column, fmt, label):
    """Plot `column` against seq_len using only rows with a finite value.

    Each series is filtered on its own, so a model that still runs at a
    length where the other one failed keeps its data points.
    """
    ok = speed_df[np.isfinite(speed_df[column])]
    if not ok.empty:
        ax.plot(ok['seq_len'], ok[column], fmt, label=label, linewidth=2)


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

_plot_finite(ax1, 'std_ms', 'b-o', 'Standard O(n^2)')
_plot_finite(ax1, 'wave_ms', 'r-o', 'Wave O(n log n)')
ax1.set_xlabel('Sequence Length'); ax1.set_ylabel('Forward Pass (ms)')
ax1.set_title('Speed Scaling')
if ax1.get_legend_handles_labels()[0]:  # avoid an empty-legend warning
    ax1.legend()
ax1.grid(True, alpha=0.3); ax1.set_yscale('log'); ax1.set_xscale('log', base=2)

_plot_finite(ax2, 'std_mem', 'b-o', 'Standard O(n^2)')
_plot_finite(ax2, 'wave_mem', 'r-o', 'Wave O(n log n)')
# Horizontal reference line at the GPU's total memory.
ax2.axhline(y=vram_gb, color='gray', linestyle='--', alpha=0.7, label=f'{gpu_name} ({vram_gb:.0f} GB)')
ax2.set_xlabel('Sequence Length'); ax2.set_ylabel('Peak VRAM (GB)')
ax2.set_title('Memory Scaling'); ax2.legend()
ax2.grid(True, alpha=0.3); ax2.set_xscale('log', base=2)

plt.tight_layout()
plt.savefig('results/speed_memory_scaling.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ============================================================
# CELL 14: KERNEL HEALTH — FROZEN OR EVOLVING?
# ============================================================
# Compares each (layer, head) kernel's freq/damp between the earliest and the
# latest recorded step, reports how many heads barely moved, and draws
# per-head heatmaps of the absolute change.
df = wave_monitor.summary_df()
n_snapshot_steps = len(df['step'].unique())

if n_snapshot_steps < 2:
    print("Need at least 2 snapshots for kernel health analysis.")
else:
    first_step, last_step = df['step'].min(), df['step'].max()
    init_df = df[df['step'] == first_step]
    final_df = df[df['step'] == last_step]

    # One row per (layer, head); remaining columns get _init / _final suffixes.
    merged = init_df.merge(final_df, on=['layer', 'head'],
                           suffixes=('_init', '_final'))

    # Absolute parameter drift between the two snapshots.
    for param in ('freq', 'damp'):
        delta = merged[f'{param}_final'] - merged[f'{param}_init']
        merged[f'{param}_change'] = delta.abs()

    avg_freq_change = merged['freq_change'].mean()
    avg_damp_change = merged['damp_change'].mean()

    # A head counts as frozen when its frequency moved less than the threshold.
    frozen_threshold = 0.01
    n_frozen_freq = (merged['freq_change'] < frozen_threshold).sum()
    total_heads = len(merged)

    # Healthy when fewer than 30% of the heads are frozen.
    is_healthy = n_frozen_freq < total_heads * 0.3
    health, color = ('HEALTHY', '#00ff00') if is_healthy else ('FROZEN', '#ff4444')

    display(HTML(f"""
    <div style='background:#1a1a2e;color:#e0e0e0;padding:15px;border-radius:8px;font-family:monospace'>
      <h3 style='color:{color};margin:0'>Kernel Health: {health}</h3>
      <p>Avg frequency change: <b>{avg_freq_change:.4f}</b> | 
         Avg damping change: <b>{avg_damp_change:.4f}</b></p>
      <p>Frozen heads: {n_frozen_freq}/{total_heads} (threshold: change &lt; {frozen_threshold})</p>
      <p style='color:#888;margin-top:5px'>V4.3.2 fix: kernel_lr = 50x base_lr prevents freezing.</p>
    </div>
    """))

    # Heatmaps: one panel per parameter, layers as rows and heads as columns.
    fig, axes = plt.subplots(1, 2, figsize=(14, 4))
    heatmap_panels = (
        ('freq_change', 'Frequency Change (init -> final)'),
        ('damp_change', 'Damping Change (init -> final)'),
    )
    for ax, (change_col, panel_title) in zip(axes, heatmap_panels):
        grid = merged.pivot(index='layer', columns='head', values=change_col)
        image = ax.imshow(grid.values, cmap='YlOrRd', aspect='auto')
        ax.set_xlabel('Head')
        ax.set_ylabel('Layer')
        ax.set_title(panel_title)
        plt.colorbar(image, ax=ax)

    plt.tight_layout()
    plt.savefig('results/kernel_health.png', dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# ============================================================
# CELL 15: SAVE ALL RESULTS + FINAL VERDICT
# ============================================================
# Collects configuration, per-model results, the speed benchmark and the
# physics snapshots into one JSON file, then renders a verdict card.


def _json_safe(value):
    """Return a copy of `value` with non-finite floats replaced by None.

    Recurses into dicts, lists and tuples (tuples come back as lists, which is
    how json serializes them anyway); every other object is returned as is.
    json.dump would otherwise emit Infinity / NaN tokens, which are not valid
    JSON and are rejected by strict parsers. Such values do occur here: the
    speed benchmark stores float('inf') for failed runs.
    """
    if isinstance(value, float):
        return value if math.isfinite(value) else None
    if isinstance(value, dict):
        return {key: _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    return value


all_results = {
    'config': {
        'scale': SCALE,
        'embed_dim': CFG['embed_dim'],
        'num_layers': CFG['num_layers'],
        'num_heads': CFG['num_heads'],
        'ffn_dim': CFG['ffn_dim'],
        'field_size': CFG['field_size'],
        'seq_len': CFG['seq_len'],
        'vocab_size': VOCAB_SIZE,
        'total_steps': CFG['total_steps'],
        'wave_params': wave_params,
        'std_params': std_params,
        'gpu': gpu_name,
        'vram_gb': vram_gb,
    },
    'wave': {
        'best_ppl': best_wave_ppl,
        'peak_mem_gb': wave_peak_mem,
        'time_s': wave_total_time,
        'history': wave_history,
    },
    'standard': {
        'best_ppl': best_std_ppl,
        'peak_mem_gb': std_peak_mem,
        'time_s': std_total_time,
        'history': std_history,
    },
    'speed_benchmark': speed_results,
    'physics_snapshots': wave_monitor.snapshots,
}

# default=str stringifies anything json cannot encode natively; the
# sanitizing pass turns inf/nan floats into null so the file is valid JSON.
with open(f'results/colab_{SCALE}_results.json', 'w', encoding='utf-8') as f:
    json.dump(_json_safe(all_results), f, indent=2, default=str)

print(f"Results saved: results/colab_{SCALE}_results.json")
print(f"Checkpoints in: {CKPT_DIR}/")
print(f"Plots in: results/")

# Final verdict: Wave wins outright with the lower best perplexity, counts as
# competitive within 1.5x of the Standard model's, otherwise Standard leads.
if best_wave_ppl < best_std_ppl:
    verdict = 'Wave WINS'
    color = '#00ff00'
elif best_wave_ppl < best_std_ppl * 1.5:
    verdict = 'COMPETITIVE'
    color = '#ffaa00'
else:
    verdict = 'Standard leads (need more data/scale)'
    color = '#ff6666'

display(HTML(f"""
<div style='background:#0a0a1a;color:#e0e0e0;padding:20px;border-radius:12px;
            font-family:monospace;border:2px solid {color};margin-top:10px'>
  <h2 style='color:{color};margin:0 0 10px 0;text-align:center'>{verdict}</h2>
  <table style='color:#e0e0e0;width:100%;border-collapse:collapse'>
    <tr><td style='padding:5px'>Wave V4.3.3 Best PPL</td>
        <td style='padding:5px;text-align:right;font-size:1.2em'><b>{best_wave_ppl:.1f}</b></td></tr>
    <tr><td style='padding:5px'>Standard Best PPL</td>
        <td style='padding:5px;text-align:right;font-size:1.2em'><b>{best_std_ppl:.1f}</b></td></tr>
    <tr style='border-top:1px solid #444'>
        <td style='padding:5px'>Wave Params</td>
        <td style='padding:5px;text-align:right'>{wave_params/1e6:.1f}M</td></tr>
    <tr><td style='padding:5px'>Standard Params</td>
        <td style='padding:5px;text-align:right'>{std_params/1e6:.1f}M</td></tr>
    <tr><td style='padding:5px'>Wave Complexity</td>
        <td style='padding:5px;text-align:right'>O(n log n)</td></tr>
    <tr><td style='padding:5px'>Standard Complexity</td>
        <td style='padding:5px;text-align:right'>O(n^2)</td></tr>
  </table>
  <p style='color:#888;text-align:center;margin:10px 0 0 0'>
    Scale {SCALE} | {CFG['embed_dim']}d {CFG['num_layers']}L {CFG['num_heads']}H | 
    {gpu_name} | WikiText-103
  </p>
</div>
"""))