# V4.3.4 Component Ablation Benchmark

**SPECTRE-Wave Field LLM** — Isolating the impact of each V4.3.4 fix:
- A) NormalizedExp activation (was ELU+1)
- B) SpectralGate init 10x stronger (0.1 vs 0.01)
- C) Kernel damping range expanded (-3.0 vs -1.4)

**Crash-safe**: All results saved to Google Drive. If Colab disconnects, just re-run all cells — completed variants are skipped automatically.

**Estimated time**: ~5hrs on T4 (6 variants x ~50min each)

In [None]:
# Cell 1: Mount Google Drive (persistence layer)
# Results, checkpoints, cache, and monitor data are all kept under DRIVE_DIR.
import os

from google.colab import drive

drive.mount('/content/drive')

DRIVE_DIR = '/content/drive/MyDrive/wavellm_results'

# Create the results root and its two sub-folders; exist_ok=True makes
# re-running this cell a no-op when they are already there.
for _folder in (DRIVE_DIR, f'{DRIVE_DIR}/cache', f'{DRIVE_DIR}/monitor'):
    os.makedirs(_folder, exist_ok=True)

print(f'Drive dir: {DRIVE_DIR}')
print(f'Contents: {os.listdir(DRIVE_DIR)}')

In [None]:
# Cell 2: Clone repo + install deps (idempotent — safe to re-run)
import os
REPO_DIR = '/content/wave-field-llm'
DRIVE_DIR = '/content/drive/MyDrive/wavellm_results'

if not os.path.exists(REPO_DIR):
    !git clone https://github.com/Pankh-AI/wave-field-llm.git {REPO_DIR}
else:
    !cd {REPO_DIR} && git pull --ff-only

%cd {REPO_DIR}
!pip install -q torch datasets tokenizers tqdm

# CRITICAL: Symlink results/ -> Google Drive
# This makes ALL checkpoints, cache, and results auto-persist to Drive
results_link = os.path.join(REPO_DIR, 'results')
if os.path.islink(results_link):
    os.unlink(results_link)
elif os.path.isdir(results_link):
    import shutil
    shutil.rmtree(results_link)
os.symlink(DRIVE_DIR, results_link)

# Verify symlink
assert os.path.islink(results_link), 'Symlink failed!'
assert os.path.realpath(results_link) == DRIVE_DIR, 'Symlink points to wrong dir!'
print(f'results/ -> {DRIVE_DIR} (symlinked to Drive)')
print(f'Cache contents: {os.listdir(os.path.join(DRIVE_DIR, "cache"))}')

In [None]:
# Cell 3: Verify GPU + show resume state
import json
import os

import torch

DRIVE_DIR = '/content/drive/MyDrive/wavellm_results'

# Variant count shown in the progress messages below.
# NOTE(review): the notebook header mentions 6 variants while these messages
# say 7 — confirm the real count against the benchmark script.
TOTAL_VARIANTS = 7


def _print_variant_ppls(rows):
    """Print one indented `name: PPL=value` line for every result dict in *rows*."""
    for row in rows:
        label = row.get('ablation_name', '?')
        best = row.get('best_ppl', 'N/A')
        print(f'  {label}: PPL={best}')


# GPU check: stop right here when the runtime has no CUDA device.
assert torch.cuda.is_available(), 'No GPU! Go to Runtime > Change runtime type > T4 GPU'
gpu_name = torch.cuda.get_device_name(0)
vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
print(f'GPU: {gpu_name} ({vram_gb:.1f} GB VRAM)')

# AMP dtype selection (informational only; any other GPU name prints nothing)
if any(tag in gpu_name for tag in ('T4', 'V100')):
    print('AMP: fp16 + GradScaler (narrow exponent range)')
elif any(tag in gpu_name for tag in ('A100', 'H100')):
    print('AMP: bf16, no GradScaler needed')

# Resume state check: the final file takes precedence over the partial one.
partial = os.path.join(DRIVE_DIR, 'v434_ablation_partial.json')
final = os.path.join(DRIVE_DIR, 'v434_ablation.json')

if os.path.exists(final):
    with open(final) as f:
        data = json.load(f)
    print(f'\nFINAL RESULTS ALREADY EXIST ({len(data["results"])} variants):')
    _print_variant_ppls(data['results'])
    print('\nRe-running Cell 4 will skip all variants (already done).')
elif os.path.exists(partial):
    with open(partial) as f:
        data = json.load(f)
    done = data.get('completed_variants', [])
    print(f'\nResuming! {len(done)}/{TOTAL_VARIANTS} variants already done: {done}')
    _print_variant_ppls(data.get('results', []))
    print(f'\nCell 4 will resume from variant {len(done)+1}/{TOTAL_VARIANTS}.')
else:
    print('\nFresh run - no previous results found on Drive.')
    print(f'Cell 4 will train all {TOTAL_VARIANTS} variants from scratch (~5hrs on T4).')

In [None]:
# Cell 4: Run ablation (~5hrs on T4, crash-safe)
# - MONITOR=1 (default): WaveFieldMonitor captures physics diagnostics
# - Results saved to Drive after EACH variant (crash-safe)
# - On reconnect: re-run Cells 1-4, completed variants are skipped
# - The command changes into the repo checkout first, then launches the
#   benchmark script with SEED=42 and WANDB=0 set in its environment.
#   NOTE(review): WANDB=0 presumably disables Weights & Biases logging —
#   confirm in benchmarks/benchmark_v434_ablation.py.
!cd /content/wave-field-llm && SEED=42 WANDB=0 python benchmarks/benchmark_v434_ablation.py

In [None]:
# Cell 5: Check progress (run anytime — even while Cell 4 is running in another tab)
import json, os, glob

DRIVE_DIR = '/content/drive/MyDrive/wavellm_results'


def fmt_metric(value, spec, width=0, suffix=''):
    """Format one metric for display without ever raising on odd values.

    Args:
        value: The metric as read from JSON. int/float values are formatted
            numerically; anything else (a placeholder string such as 'N/A',
            or None from a JSON null) is shown via str().
        spec: Format spec applied to numeric values, e.g. '>8.1f'.
        width: Column width used to right-align non-numeric values.
        suffix: Text appended after a numeric value only (e.g. '%').

    Returns:
        The formatted string.
    """
    if isinstance(value, (int, float)):
        return format(value, spec) + suffix
    # Numeric specs such as '.4f' raise on str/None, so fall back to plain text.
    return str(value).rjust(width)


# 1. Check which variants are done (the final file wins over the partial one)
partial = os.path.join(DRIVE_DIR, 'v434_ablation_partial.json')
final = os.path.join(DRIVE_DIR, 'v434_ablation.json')
path = final if os.path.exists(final) else partial

if os.path.exists(path):
    with open(path) as f:
        data = json.load(f)
    results = data.get('results', [])
    # Compare paths directly instead of substring-matching the file name.
    status = 'FINAL' if path == final else 'IN PROGRESS'
    print(f'=== {status}: {len(results)} variants done ===')
    print(f'{"Variant":<30} {"PPL":>8} {"Acc":>7} {"tok/s":>10}')
    print(f'{"-"*30} {"-"*8} {"-"*7} {"-"*10}')
    for r in results:
        name = str(r.get('ablation_name', r.get('run_name', '?')))
        ppl_s = fmt_metric(r.get('best_ppl', 'N/A'), '>8.1f', 8)
        acc_s = fmt_metric(r.get('best_acc', 'N/A'), '>6.1f', 7, '%')
        tps_s = fmt_metric(r.get('tokens_per_sec', 'N/A'), '>10,', 10)
        print(f'{name:<30} {ppl_s} {acc_s} {tps_s}')
else:
    print('No results yet - first variant still training.')

# 2. Check monitor data for latest/current variant
# sorted() orders by directory name, so "latest" is the last name alphabetically.
monitor_dirs = sorted(glob.glob(f'{DRIVE_DIR}/monitor/*/'))
if monitor_dirs:
    latest = monitor_dirs[-1]
    name = os.path.basename(latest.rstrip('/'))
    steps_file = os.path.join(latest, 'monitor_steps.json')
    if os.path.exists(steps_file):
        try:
            with open(steps_file) as f:
                steps = json.load(f)
        except json.JSONDecodeError:
            # This cell may run while training is rewriting the file; a
            # half-written file is reported instead of crashing the cell.
            steps = None
            print(f'\nMonitor file for {name} could not be parsed '
                  '(possibly mid-write) - re-run this cell in a moment.')
        if steps:
            last = steps[-1]
            loss_s = fmt_metric(last.get('loss', '?'), '.4f')
            lr_s = fmt_metric(last.get('lr', '?'), '.6f')
            print(f'\n=== LIVE MONITOR: {name} (step {last.get("step", "?")}) ===')
            print(f'  Loss: {loss_s}')
            print(f'  LR:   {lr_s}')
            print(f'  Total steps recorded: {len(steps)}')

In [None]:
# Cell 6: Visualize physics diagnostics (12-panel dashboard per variant)
# Run after each variant completes, or at the end for all variants
import glob, os

DRIVE_DIR = '/content/drive/MyDrive/wavellm_results'
monitor_dirs = sorted(glob.glob(f'{DRIVE_DIR}/monitor/*/'))

if not monitor_dirs:
    print('No monitor data yet. Run Cell 4 first.')
else:
    print(f'Found {len(monitor_dirs)} monitored variants:')
    for mdir in monitor_dirs:
        snap_file = os.path.join(mdir, 'monitor_snapshots.json')
        step_file = os.path.join(mdir, 'monitor_steps.json')
        if os.path.exists(snap_file):
            name = os.path.basename(mdir.rstrip('/'))
            print(f'\n=== Generating dashboard: {name} ===')
            cmd = f'python diagnostics/visualize_monitor.py {snap_file}'
            if os.path.exists(step_file):
                cmd += f' {step_file}'
            !{cmd}
            # Display dashboard inline
            png = snap_file.replace('.json', '_dashboard.png')
            if os.path.exists(png):
                from IPython.display import Image, display
                display(Image(filename=png, width=1200))
            else:
                print(f'  Dashboard PNG not found at {png}')

In [None]:
# Cell 7: Final results table + gap analysis
import json, os

DRIVE_DIR = '/content/drive/MyDrive/wavellm_results'
results_path = os.path.join(DRIVE_DIR, 'v434_ablation.json')
partial_path = os.path.join(DRIVE_DIR, 'v434_ablation_partial.json')


def fmt_metric(value, spec, width=0, suffix=''):
    """Format one table cell without raising on non-numeric values.

    int/float values are formatted with *spec* and get *suffix* appended;
    anything else (e.g. 'N/A', or None from a JSON null) is rendered with
    str() and right-aligned to *width*.
    """
    if isinstance(value, (int, float)):
        return format(value, spec) + suffix
    return str(value).rjust(width)


def gap_rows(results, std_ppl, baseline_key='A_v433_baseline'):
    """Build the gap-analysis rows for every result with a numeric best_ppl.

    The baseline PPL is looked up before any row is built, so the Delta
    column does not depend on the order of *results*.

    Args:
        results: List of result dicts (keys read: 'best_ppl', 'ablation',
            'ablation_name', 'run_name').
        std_ppl: Non-zero PPL of the standard model; each gap is ppl / std_ppl.
        baseline_key: Value of 'ablation' that marks the baseline variant.

    Returns:
        List of (name, ppl, gap, delta) tuples, where delta is an 8-character
        string: '    base' for the baseline row, a signed percentage relative
        to the baseline for other rows, or '     n/a' when no usable baseline
        PPL exists.
    """
    numeric = [r for r in results if isinstance(r.get('best_ppl'), (int, float))]
    baseline_ppl = next(
        (r['best_ppl'] for r in numeric if r.get('ablation') == baseline_key),
        None,
    )
    rows = []
    for r in numeric:
        ppl = r['best_ppl']
        name = str(r.get('ablation_name', r.get('run_name', '?')))
        if r.get('ablation') == baseline_key:
            delta = '    base'
        elif baseline_ppl:
            # Truthiness check also avoids dividing by a zero baseline.
            delta = f'{(ppl - baseline_ppl) / baseline_ppl * 100:>+7.1f}%'
        else:
            delta = '     n/a'
        rows.append((name, ppl, ppl / std_ppl, delta))
    return rows


path = results_path if os.path.exists(results_path) else partial_path
if not os.path.exists(path):
    print('No results found. Run Cell 4 first.')
else:
    with open(path) as f:
        data = json.load(f)
    # .get keeps a file without a 'results' key from crashing the cell.
    results = data.get('results', [])
    # Compare paths directly instead of substring-matching the file name.
    status = 'FINAL' if path == results_path else 'PARTIAL'

    print(f'{"="*60}')
    print(f'  V4.3.4 ABLATION RESULTS ({status})')
    print(f'{"="*60}')
    print(f'\n  {"Variant":<30} {"PPL":>8} {"Acc":>7} {"Params":>12} {"tok/s":>10}')
    print(f'  {"-"*30} {"-"*8} {"-"*7} {"-"*12} {"-"*10}')

    std_ppl = None
    for r in results:
        name = str(r.get('ablation_name', r.get('run_name', '?')))
        ppl = r.get('best_ppl', 'N/A')
        ppl_s = fmt_metric(ppl, '>8.1f', 8)
        acc_s = fmt_metric(r.get('best_acc', 'N/A'), '>6.1f', 7, '%')
        params_s = fmt_metric(r.get('params', 'N/A'), '>12,', 12)
        tps_s = fmt_metric(r.get('tokens_per_sec', 'N/A'), '>10,', 10)
        print(f'  {name:<30} {ppl_s} {acc_s} {params_s} {tps_s}')
        # Remember the standard model's PPL as the reference for the gap table.
        if r.get('ablation') == 'F_standard' and isinstance(ppl, (int, float)):
            std_ppl = ppl

    # Gap analysis (skipped when no standard PPL is available, or it is 0)
    if std_ppl:
        print(f'\n  --- GAP ANALYSIS (vs Standard PPL {std_ppl:.1f}) ---')
        print(f'  {"Variant":<30} {"PPL":>8} {"Gap":>8} {"Delta":>8}')
        print(f'  {"-"*30} {"-"*8} {"-"*8} {"-"*8}')
        for name, ppl, gap, delta in gap_rows(results, std_ppl):
            print(f'  {name:<30} {ppl:>8.1f} {gap:>7.2f}x {delta}')