# V4.3.4 Component Ablation Benchmark

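**SPECTRE-Wave Field LLM** — Isolating the impact of each V4.3.4 fix:
- A) NormalizedExp activation (was ELU+1) — variant `B_normalized_exp_only`
- B) SpectralGate init 10x stronger (0.1 vs 0.01) — variant `C_gate_10x_only`
- C) Kernel damping range expanded (-3.0 vs -1.4) — variant `D_kernel_reach_only`

Note that the fix letters above are not the variant letters. Six variants are trained in total: the three single-fix variants listed above plus `A_v433_baseline`, `E_v434_full`, and `F_standard` (the reference used for the gap analysis).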

**Crash-safe**: All results saved to Google Drive. If Colab disconnects, just re-run all cells — completed variants are skipped automatically.

**Parallel mode**: 3 processes share the T4 GPU → **~1.5hrs** instead of ~5hrs.

**Estimated time**: ~1.5hrs on T4 (parallel) / ~5hrs (sequential)

In [None]:
# Cell 1: Mount Google Drive (persistence layer)
# Everything the benchmark writes (results, checkpoints, cache, monitor data)
# is kept under DRIVE_DIR on the mounted Drive.
import os

from google.colab import drive

drive.mount('/content/drive')

DRIVE_DIR = '/content/drive/MyDrive/wavellm_results'

# Create the results root and its two sub-folders; existing folders are left alone.
for folder in (DRIVE_DIR, f'{DRIVE_DIR}/cache', f'{DRIVE_DIR}/monitor'):
    os.makedirs(folder, exist_ok=True)

print(f'Drive dir: {DRIVE_DIR}')
print(f'Contents: {os.listdir(DRIVE_DIR)}')

In [None]:
# Cell 2: Clone repo + install deps (idempotent — safe to re-run)
#
# Purpose: get the repository onto the Colab VM, install the Python
# dependencies, and replace <repo>/results with a symlink to the Drive folder
# so that anything written under results/ is stored on Drive.
import os
REPO_DIR = '/content/wave-field-llm'
DRIVE_DIR = '/content/drive/MyDrive/wavellm_results'

# First run: clone. Later runs: update the existing checkout in place.
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/Pankh-AI/wave-field-llm.git {REPO_DIR}
else:
    # --ff-only: only fast-forward; git refuses instead of creating a merge.
    !cd {REPO_DIR} && git pull --ff-only

# %cd changes the notebook's working directory (a plain `!cd` would not persist).
%cd {REPO_DIR}
!pip install -q torch datasets tokenizers tqdm

# CRITICAL: Symlink results/ -> Google Drive
# This makes ALL checkpoints, cache, and results auto-persist to Drive
results_link = os.path.join(REPO_DIR, 'results')
# Clear whatever currently occupies results/ so os.symlink() below can create it:
#   - an old symlink (from a previous run of this cell) is simply removed;
#   - a real directory is deleted together with its contents.
# NOTE(review): the islink() test must come first — isdir() follows symlinks,
# so swapping the branches would rmtree through the link.
if os.path.islink(results_link):
    os.unlink(results_link)
elif os.path.isdir(results_link):
    import shutil
    shutil.rmtree(results_link)
os.symlink(DRIVE_DIR, results_link)

# Verify symlink
# realpath() resolves the link; it must land exactly on DRIVE_DIR.
assert os.path.islink(results_link), 'Symlink failed!'
assert os.path.realpath(results_link) == DRIVE_DIR, 'Symlink points to wrong dir!'
print(f'results/ -> {DRIVE_DIR} (symlinked to Drive)')
# Requires DRIVE_DIR/cache to exist already (Cell 1 creates it).
print(f'Cache contents: {os.listdir(os.path.join(DRIVE_DIR, "cache"))}')

In [None]:
# Cell 3: Verify GPU + show resume state
# Fails fast when no CUDA device is attached, then reports which ablation
# variants already have results on Drive and which are still to be trained.
import torch
import json
import os
import glob

DRIVE_DIR = '/content/drive/MyDrive/wavellm_results'

# --- GPU check ---
assert torch.cuda.is_available(), 'No GPU! Go to Runtime > Change runtime type > T4 GPU'
gpu_name = torch.cuda.get_device_name(0)
vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
print(f'GPU: {gpu_name} ({vram_gb:.1f} GB VRAM)')

# --- AMP dtype selection (informational print only; other GPUs print nothing) ---
if any(tag in gpu_name for tag in ('T4', 'V100')):
    print(f'AMP: fp16 + GradScaler (narrow exponent range)')
elif any(tag in gpu_name for tag in ('A100', 'H100')):
    print(f'AMP: bf16, no GradScaler needed')

# --- Resume state: per-variant files (parallel) + summary JSONs (sequential / merged) ---
all_variants = ['A_v433_baseline', 'B_normalized_exp_only', 'C_gate_10x_only',
                'D_kernel_reach_only', 'E_v434_full', 'F_standard']

# A variant counts as done as soon as its own result file exists.
done = {
    key for key in all_variants
    if os.path.exists(os.path.join(DRIVE_DIR, f'v434_ablation_{key}.json'))
}

# The partial JSON (sequential runs) and the final JSON share one layout:
# a 'results' list whose entries carry the variant key under 'ablation'.
partial = os.path.join(DRIVE_DIR, 'v434_ablation_partial.json')
final = os.path.join(DRIVE_DIR, 'v434_ablation.json')
for summary_path in (partial, final):
    if not os.path.exists(summary_path):
        continue
    with open(summary_path) as f:
        data = json.load(f)
    for entry in data.get('results', []):
        label = entry.get('ablation')
        if label:
            done.add(label)

remaining = [v for v in all_variants if v not in done]

n_done, n_total = len(done), len(all_variants)
if n_done == n_total:
    print(f'\nALL {n_done} VARIANTS DONE! Run Cell 7 for results.')
elif not done:
    print(f'\nFresh run — no previous results found on Drive.')
    print(f'Cell 4 will train all {n_total} variants.')
else:
    print(f'\nResuming: {n_done}/{n_total} done, {len(remaining)} remaining')
    print(f'  Done: {sorted(done)}')
    print(f'  Remaining: {remaining}')
    # Show the PPL of every finished variant that has its own result file.
    for key in sorted(done):
        vpath = os.path.join(DRIVE_DIR, f'v434_ablation_{key}.json')
        if not os.path.exists(vpath):
            continue
        with open(vpath) as f:
            r = json.load(f)
        ppl = r.get('best_ppl', '?')
        print(f'    {r.get("ablation_name", key)}: PPL={ppl}')

In [None]:
# Cell 4: Run ablation — PARALLEL MODE (~1.5hrs on T4)
# Up to 3 processes share the GPU at once; each one is handed a comma-separated
# VARIANTS list (2 variants per process on a fresh run).
# S1 model (22M params) uses ~2GB VRAM, so 3 processes fit easily in 16GB.
# Each process saves per-variant result files — crash-safe per-variant.
#
# To run SEQUENTIAL instead (slower but simpler):
#   !cd /content/wave-field-llm && SEED=42 WANDB=0 python benchmarks/benchmark_v434_ablation.py

import subprocess, os, json, glob, time
import sys

REPO = '/content/wave-field-llm'
DRIVE_DIR = '/content/drive/MyDrive/wavellm_results'

all_variants = ['A_v433_baseline', 'B_normalized_exp_only', 'C_gate_10x_only',
                'D_kernel_reach_only', 'E_v434_full', 'F_standard']


def variant_result_path(key):
    """Return the Drive path of the per-variant result file for *key*."""
    return os.path.join(DRIVE_DIR, f'v434_ablation_{key}.json')


def variants_with_result_file():
    """Return the set of variants whose per-variant result file exists."""
    return {key for key in all_variants if os.path.exists(variant_result_path(key))}


# Check what's already done
done = variants_with_result_file()
# Variants recorded in a summary file also count as done: the partial JSON
# (previous sequential run) and the final merged JSON. Cell 3 counts both, so
# skipping the final file here would retrain variants Cell 3 reports as done.
for summary_name in ('v434_ablation_partial.json', 'v434_ablation.json'):
    summary_path = os.path.join(DRIVE_DIR, summary_name)
    if not os.path.exists(summary_path):
        continue
    with open(summary_path) as f:
        for r in json.load(f).get('results', []):
            abl = r.get('ablation')
            if abl:
                done.add(abl)

remaining = [v for v in all_variants if v not in done]

if not remaining:
    print(f'All {len(all_variants)} variants already done! Run Cell 7 for results.')
else:
    print(f'{len(remaining)} variants remaining: {remaining}')
    if done:
        print(f'Skipping (already done): {sorted(done)}')

    # Split remaining round-robin into at most 3 groups, one per process.
    n_procs = min(3, len(remaining))
    groups = [remaining[i::n_procs] for i in range(n_procs)]

    # Launch parallel processes
    procs = []
    env = {**os.environ, 'SEED': '42', 'WANDB': '0', 'MONITOR': '1'}
    for i, group in enumerate(groups):
        variants_str = ','.join(group)
        proc_env = {**env, 'VARIANTS': variants_str}
        log_path = os.path.join(DRIVE_DIR, f'log_proc{i}.txt')
        print(f'  Process {i}: {variants_str} -> {log_path}')
        # The child receives its own copy of the log file descriptor, so the
        # parent's handle can be closed right after the launch (no leaked handles).
        with open(log_path, 'w') as log_file:
            p = subprocess.Popen(
                # sys.executable: run the script with the same interpreter as the notebook.
                [sys.executable, 'benchmarks/benchmark_v434_ablation.py'],
                cwd=REPO, env=proc_env,
                stdout=log_file,
                stderr=subprocess.STDOUT,
            )
        procs.append((i, p, group))

    print(f'\n{len(procs)} processes launched! Polling every 60s...\n')

    # Poll for completion with progress updates
    while any(p.poll() is None for _, p, _ in procs):
        time.sleep(60)
        # Report variants whose result file appeared since the last poll.
        new = variants_with_result_file() - done
        for k in sorted(new):
            try:
                with open(variant_result_path(k)) as f:
                    r = json.load(f)
                label = f'{r.get("ablation_name", k)} — PPL={r.get("best_ppl", "?")}'
            except (OSError, ValueError, AttributeError):
                # The file may still be mid-write (or not a JSON object);
                # fall back to the bare variant key rather than abort polling.
                label = k
            print(f'  DONE: {label}')
        done.update(new)
        running = sum(1 for _, p, _ in procs if p.poll() is None)
        # Count everything known to be done, including variants skipped at start.
        completed = len(done.intersection(all_variants))
        print(f'  [{time.strftime("%H:%M")}] {completed}/{len(all_variants)} done, {running} procs running')

    # Final status
    print(f'\n{"="*50}')
    for i, p, group in procs:
        status = 'OK' if p.returncode == 0 else f'FAIL (code {p.returncode})'
        print(f'  Process {i} ({",".join(group)}): {status}')
        if p.returncode != 0:
            log_path = os.path.join(DRIVE_DIR, f'log_proc{i}.txt')
            print(f'    Check log: {log_path}')
    print(f'\nDone! Run Cell 5 to check results, Cell 7 for final table.')

In [None]:
# Cell 5: Check progress (run anytime — even while Cell 4 is running)
# Read-only: prints a table of finished variants, the tail of each process
# log, and the latest step of the most recently updated training monitor.
import json, os, glob
from collections import deque

DRIVE_DIR = '/content/drive/MyDrive/wavellm_results'


def read_json(path):
    """Return the parsed JSON at *path*, or None if it cannot be read or parsed.

    Files may be mid-write while training is still running, so a failure is
    treated as "not available yet" instead of aborting the whole cell.
    """
    try:
        with open(path) as f:
            return json.load(f)
    except (OSError, ValueError):  # json.JSONDecodeError is a ValueError
        return None


def format_metric(value, spec, width=0, suffix=''):
    """Format a metric that may be a number or a placeholder.

    Numbers are rendered with format-spec *spec* followed by *suffix*; anything
    else ('N/A', '?', None, ...) is shown as text right-aligned to *width*.
    Applying a numeric spec such as '.4f' to a placeholder string would raise.
    """
    if isinstance(value, (int, float)):
        return format(value, spec) + suffix
    return str(value).rjust(width)


# 1. Check per-variant result files (parallel mode)
all_variants = ['A_v433_baseline', 'B_normalized_exp_only', 'C_gate_10x_only',
                'D_kernel_reach_only', 'E_v434_full', 'F_standard']
results = []
done = set()
for key in all_variants:
    r = read_json(os.path.join(DRIVE_DIR, f'v434_ablation_{key}.json'))
    # A missing or still-incomplete file simply means "not done yet".
    if isinstance(r, dict):
        results.append(r)
        done.add(key)

# Also check partial JSON (sequential mode)
data = read_json(os.path.join(DRIVE_DIR, 'v434_ablation_partial.json'))
if isinstance(data, dict):
    for r in data.get('results', []):
        abl = r.get('ablation')
        if abl and abl not in done:
            results.append(r)
            done.add(abl)

remaining = [v for v in all_variants if v not in done]
print(f'=== PROGRESS: {len(done)}/{len(all_variants)} variants done ===')
if remaining:
    print(f'Remaining: {remaining}')

if results:
    print(f'\n{"Variant":<30} {"PPL":>8} {"Acc":>7} {"tok/s":>10}')
    print(f'{"-"*30} {"-"*8} {"-"*7} {"-"*10}')
    for r in results:
        name = str(r.get('ablation_name', r.get('run_name', '?')))
        ppl_s = format_metric(r.get('best_ppl', 'N/A'), '>8.1f', 8)
        acc_s = format_metric(r.get('best_acc', 'N/A'), '>6.1f', 7, '%')
        tps_s = format_metric(r.get('tokens_per_sec', 'N/A'), '>10,', 10)
        print(f'{name:<30} {ppl_s} {acc_s} {tps_s}')

# 2. Check process logs (parallel mode)
for i in range(3):
    log_path = os.path.join(DRIVE_DIR, f'log_proc{i}.txt')
    if not os.path.exists(log_path):
        continue
    with open(log_path) as f:
        # Keep only the last 5 lines instead of loading the whole log.
        tail = deque(f, maxlen=5)
    if tail:
        print(f'\n--- Process {i} log (last 5 lines) ---')
        for line in tail:
            print(f'  {line.rstrip()}')

# 3. Check monitor data for latest/current variant
# "Latest" = the monitor_steps.json modified most recently. Sorting directory
# names alphabetically would always pick the same variant when several run in
# parallel, whether or not it is the one currently training.
monitor_dirs = sorted(glob.glob(f'{DRIVE_DIR}/monitor/*/'))
steps_files = [
    path
    for path in (os.path.join(d, 'monitor_steps.json') for d in monitor_dirs)
    if os.path.exists(path)
]
if steps_files:
    steps_file = max(steps_files, key=os.path.getmtime)
    name = os.path.basename(os.path.dirname(steps_file))
    steps = read_json(steps_file)
    if isinstance(steps, list) and steps and isinstance(steps[-1], dict):
        last = steps[-1]
        print(f'\n=== LIVE MONITOR: {name} (step {last.get("step", "?")}) ===')
        print(f'  Loss: {format_metric(last.get("loss", "?"), ".4f")}')
        print(f'  LR:   {format_metric(last.get("lr", "?"), ".6f")}')
        print(f'  Total steps recorded: {len(steps)}')

In [None]:
# Cell 6: Visualize physics diagnostics (12-panel dashboard per variant)
# Run after each variant completes, or at the end for all variants
import glob, os

DRIVE_DIR = '/content/drive/MyDrive/wavellm_results'
monitor_dirs = sorted(glob.glob(f'{DRIVE_DIR}/monitor/*/'))

if not monitor_dirs:
    print('No monitor data yet. Run Cell 4 first.')
else:
    print(f'Found {len(monitor_dirs)} monitored variants:')
    for mdir in monitor_dirs:
        snap_file = os.path.join(mdir, 'monitor_snapshots.json')
        step_file = os.path.join(mdir, 'monitor_steps.json')
        if os.path.exists(snap_file):
            name = os.path.basename(mdir.rstrip('/'))
            print(f'\n=== Generating dashboard: {name} ===')
            cmd = f'python diagnostics/visualize_monitor.py {snap_file}'
            if os.path.exists(step_file):
                cmd += f' {step_file}'
            !{cmd}
            # Display dashboard inline
            png = snap_file.replace('.json', '_dashboard.png')
            if os.path.exists(png):
                from IPython.display import Image, display
                display(Image(filename=png, width=1200))
            else:
                print(f'  Dashboard PNG not found at {png}')

In [None]:
# Cell 7: Final results table + gap analysis (merges per-variant files)
# Gathers results from per-variant files, the partial JSON and the final JSON
# (first source wins per variant), prints the results table and the gap
# analysis, and writes the merged JSON once every variant is present.
import json, os
import time

DRIVE_DIR = '/content/drive/MyDrive/wavellm_results'
BASELINE_KEY = 'A_v433_baseline'   # reference row for the "Delta" column
STANDARD_KEY = 'F_standard'        # reference row for the "Gap" column


def is_number(value):
    """Return True if *value* is an int or float."""
    return isinstance(value, (int, float))


def format_metric(value, spec, width=0, suffix=''):
    """Format a metric that may be a number or a placeholder.

    Numbers are rendered with format-spec *spec* followed by *suffix*; anything
    else ('N/A', None, ...) is shown as text right-aligned to *width*.
    """
    if is_number(value):
        return format(value, spec) + suffix
    return str(value).rjust(width)


def delta_label(ppl, baseline_ppl, is_baseline=False):
    """Return the 8-character "Delta" cell for one gap-analysis row.

    The baseline row itself is labelled 'base'; other rows show the signed
    percentage change of *ppl* relative to *baseline_ppl*, or 'N/A' when no
    usable baseline PPL is available.
    """
    if is_baseline:
        return '    base'
    if not baseline_ppl:
        return '     N/A'
    return f'{(ppl - baseline_ppl) / baseline_ppl * 100:>+7.1f}%'


# Collect results from all sources: per-variant files, partial JSON, final JSON
all_variants = ['A_v433_baseline', 'B_normalized_exp_only', 'C_gate_10x_only',
                'D_kernel_reach_only', 'E_v434_full', 'F_standard']
# (variant_key, result) pairs. The key is tracked separately so the baseline /
# standard rows are recognised even if a per-variant file lacks an 'ablation' field.
entries = []
done = set()

# Per-variant result files (parallel mode)
for key in all_variants:
    vpath = os.path.join(DRIVE_DIR, f'v434_ablation_{key}.json')
    if os.path.exists(vpath):
        with open(vpath) as f:
            entries.append((key, json.load(f)))
        done.add(key)

# Partial JSON fallback (sequential mode), then final JSON fallback
for summary_name in ('v434_ablation_partial.json', 'v434_ablation.json'):
    summary_path = os.path.join(DRIVE_DIR, summary_name)
    if not os.path.exists(summary_path):
        continue
    with open(summary_path) as f:
        data = json.load(f)
    for r in data.get('results', []):
        abl = r.get('ablation')
        if abl and abl not in done:
            entries.append((abl, r))
            done.add(abl)

results = [r for _, r in entries]

if not results:
    print('No results found. Run Cell 4 first.')
else:
    remaining = [v for v in all_variants if v not in done]
    status = 'FINAL' if not remaining else f'PARTIAL ({len(done)}/{len(all_variants)})'

    print(f'{"="*60}')
    print(f'  V4.3.4 ABLATION RESULTS ({status})')
    print(f'{"="*60}')
    print(f'\n  {"Variant":<30} {"PPL":>8} {"Acc":>7} {"Params":>12} {"tok/s":>10}')
    print(f'  {"-"*30} {"-"*8} {"-"*7} {"-"*12} {"-"*10}')

    # Reference PPLs are looked up before printing, so the gap analysis does
    # not depend on the order in which the results were collected.
    std_ppl = None
    baseline_ppl = None
    for key, r in entries:
        name = str(r.get('ablation_name', r.get('run_name', '?')))
        ppl = r.get('best_ppl', 'N/A')
        ppl_s = format_metric(ppl, '>8.1f', 8)
        acc_s = format_metric(r.get('best_acc', 'N/A'), '>6.1f', 7, '%')
        params_s = format_metric(r.get('params', 'N/A'), '>12,', 12)
        tps_s = format_metric(r.get('tokens_per_sec', 'N/A'), '>10,', 10)
        print(f'  {name:<30} {ppl_s} {acc_s} {params_s} {tps_s}')
        if is_number(ppl):
            if key == STANDARD_KEY:
                std_ppl = ppl
            elif key == BASELINE_KEY:
                baseline_ppl = ppl

    # Gap analysis (needs a non-zero Standard PPL to divide by)
    if std_ppl:
        print(f'\n  --- GAP ANALYSIS (vs Standard PPL {std_ppl:.1f}) ---')
        print(f'  {"Variant":<30} {"PPL":>8} {"Gap":>8} {"Delta":>8}')
        print(f'  {"-"*30} {"-"*8} {"-"*8} {"-"*8}')
        for key, r in entries:
            ppl = r.get('best_ppl')
            if not is_number(ppl):
                continue
            name = str(r.get('ablation_name', r.get('run_name', '?')))
            gap = ppl / std_ppl
            delta = delta_label(ppl, baseline_ppl, is_baseline=(key == BASELINE_KEY))
            print(f'  {name:<30} {ppl:>8.1f} {gap:>7.2f}x {delta}')

    # Save merged final results if all done
    if not remaining:
        merged = {
            'metadata': {
                'benchmark': 'v434_ablation',
                'scale': 'S1',
                'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
            },
            'results': results,
        }
        merged_path = os.path.join(DRIVE_DIR, 'v434_ablation.json')
        with open(merged_path, 'w') as f:
            json.dump(merged, f, indent=2)
        print(f'\n  Merged results saved: {merged_path}')