# VAE Hyperparameter Sweep (MLP vs Conv)

This notebook runs a large, randomly sampled hyperparameter sweep over both VAE architectures (MLP and Conv), driving the existing scripts:
- `scripts/train_vae.py`
- `scripts/eval_vae.py`

It keeps the same dataset, chronological split logic, and evaluation process used in the main pipeline.

## Goals
1. Test many combinations (latent size, architecture widths, optimization knobs).
2. Check whether a single configuration clearly outperforms the rest.
3. Quantify whether architecture choice (MLP vs Conv) materially changes outcomes after tuning.


In [None]:
import os
import sys
import json
import time
import hashlib
import random
import shutil
import subprocess
from pathlib import Path
from itertools import product

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

plt.style.use('seaborn-v0_8-whitegrid')

# Record which interpreter and working directory this kernel uses; the same
# interpreter (sys.executable) is used later to launch the train/eval scripts.
for label, value in (
    ('Python:', sys.version),
    ('Executable:', sys.executable),
    ('Notebook CWD:', Path.cwd()),
):
    print(label, value)

In [None]:
# =========================
# Sweep Configuration
# =========================

def find_project_root(start: Path) -> Path:
    """Return the closest directory at or above ``start`` that looks like the repo root.

    A directory qualifies when it contains both a ``pyproject.toml`` and a
    ``scripts`` entry. ``start`` is resolved first, then checked itself before
    each of its parents in turn.

    Raises:
        RuntimeError: if no directory on the path to the filesystem root qualifies.
    """
    origin = start.resolve()
    match = next(
        (
            folder
            for folder in (origin, *origin.parents)
            if (folder / 'pyproject.toml').exists() and (folder / 'scripts').exists()
        ),
        None,
    )
    if match is None:
        raise RuntimeError(f'Could not find project root from: {origin}')
    return match

PROJECT_ROOT = find_project_root(Path.cwd())
PARQUET_PATH = (PROJECT_ROOT / 'data/processed/vae/parquet/AAPL_vsurf_processed.parquet').resolve()
OUTPUT_ROOT = (PROJECT_ROOT / 'artifacts/tuning_sweep').resolve()

# Trial budget controls
SEED = 42
MAX_TRIALS_MLP = 60
MAX_TRIALS_CONV = 60

# Keep training protocol consistent with existing pipeline
EPOCHS = 100
PATIENCE = 20
DEVICE = 'auto'

# Optional cleanup after sweep
KEEP_FULL_ARTIFACTS_FOR_TOP_K = 10

# Rough wall-clock cost of one train+eval trial; only used for the estimate printed below.
MINUTES_PER_TRIAL_ESTIMATE = 2

# Validate with a real exception rather than `assert`: asserts are stripped when
# Python runs with -O, which would let the sweep launch trials against a missing
# file. Checked before creating any output folders.
if not PARQUET_PATH.exists():
    raise FileNotFoundError(f'Missing parquet: {PARQUET_PATH}')
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)

total_trials = MAX_TRIALS_MLP + MAX_TRIALS_CONV
print('Project root:', PROJECT_ROOT)
print('Parquet:', PARQUET_PATH)
print('Output root:', OUTPUT_ROOT)
print('Budget:', total_trials, 'trials total')
print(
    f'Estimated runtime (~{MINUTES_PER_TRIAL_ESTIMATE} min/trial):',
    round(total_trials * MINUTES_PER_TRIAL_ESTIMATE / 60, 2),
    'hours',
)

In [None]:
# =========================
# Search Spaces
# =========================

# ---- Optimisation knobs used by both architectures ----
LR_SPACE = [3e-4, 7e-4, 1e-3, 2e-3]
BETA_SPACE = [0.25, 0.5, 1.0, 2.0]          # passed as --beta (presumably the KL weight)
BATCH_SPACE = [32, 64]
WEIGHT_DECAY_SPACE = [0.0, 1e-6, 1e-5]

# ---- MLP-only knobs ----
MLP_LATENT_SPACE = [4, 8, 12, 16, 24]
MLP_HIDDEN_SPACE = [[128, 64], [256, 128], [384, 192], [512, 256]]  # each entry -> --hidden_dims

# ---- Conv-only knobs ----
CONV_LATENT_SPACE = [4, 8, 12, 16, 24]
CONV_CHANNEL_SPACE = [[16, 32, 64], [24, 48, 96], [32, 64, 128]]    # each entry -> --channels
CONV_FC_SPACE = [128, 256, 512]
CONV_BATCHNORM_SPACE = [True, False]        # False -> --no_batchnorm

# Full Cartesian sizes before sampling:
#   MLP  = 5 * 4 * 4 * 4 * 2 * 3         = 1920 configs
#   Conv = 5 * 3 * 3 * 2 * 4 * 4 * 2 * 3 = 8640 configs

def build_mlp_grid():
    """Enumerate every MLP configuration in the full Cartesian search space.

    Crosses the MLP-specific spaces (latent size, hidden widths) with the
    shared optimisation spaces (lr, beta, batch size, weight decay).

    Returns:
        list of config dicts tagged ``model_type='mlp'``, one per combination,
        in ``itertools.product`` order.
    """
    combos = product(
        MLP_LATENT_SPACE,
        MLP_HIDDEN_SPACE,
        LR_SPACE,
        BETA_SPACE,
        BATCH_SPACE,
        WEIGHT_DECAY_SPACE,
    )
    return [
        dict(
            model_type='mlp',
            latent_dim=latent,
            hidden_dims=widths,
            lr=learning_rate,
            beta=kl_beta,
            batch_size=bs,
            weight_decay=decay,
        )
        for latent, widths, learning_rate, kl_beta, bs, decay in combos
    ]

def build_conv_grid():
    """Enumerate every Conv configuration in the full Cartesian search space.

    Crosses the Conv-specific spaces (latent size, channel stack, FC width,
    batch-norm toggle) with the shared optimisation spaces.

    Returns:
        list of config dicts tagged ``model_type='conv'``, one per combination,
        in ``itertools.product`` order.
    """
    combos = product(
        CONV_LATENT_SPACE,
        CONV_CHANNEL_SPACE,
        CONV_FC_SPACE,
        CONV_BATCHNORM_SPACE,
        LR_SPACE,
        BETA_SPACE,
        BATCH_SPACE,
        WEIGHT_DECAY_SPACE,
    )
    return [
        dict(
            model_type='conv',
            latent_dim=latent,
            channels=chans,
            fc_dim=fc_width,
            batchnorm=use_bn,
            lr=learning_rate,
            beta=kl_beta,
            batch_size=bs,
            weight_decay=decay,
        )
        for latent, chans, fc_width, use_bn, learning_rate, kl_beta, bs, decay in combos
    ]

mlp_grid = build_mlp_grid()
conv_grid = build_conv_grid()

# Full grid sizes before sampling (the next cell draws a subset from each).
for caption, grid_rows in (('MLP full grid size :', mlp_grid), ('Conv full grid size:', conv_grid)):
    print(caption, len(grid_rows))

In [None]:
# =========================
# Deterministic Sampling
# =========================

# One seeded generator is shared by every sampling/shuffling step below. The
# order it is consumed in (MLP sample, then Conv sample, then the final shuffle)
# determines which trials are selected, so keep that order stable for
# reproducible sweeps.
rng = random.Random(SEED)

def sample_trials(grid, max_trials, rng):
    """Pick up to ``max_trials`` configs from ``grid`` without replacement.

    When the grid already fits in the budget, a shallow copy of the whole grid
    is returned and ``rng`` is not consumed. Otherwise every index is shuffled
    with ``rng`` and the first ``max_trials`` entries are taken, so the result
    depends only on the grid order and the generator state.
    """
    if max_trials >= len(grid):
        return list(grid)
    order = list(range(len(grid)))
    rng.shuffle(order)
    chosen = order[:max_trials]
    return [grid[pos] for pos in chosen]

mlp_trials = sample_trials(mlp_grid, MAX_TRIALS_MLP, rng)
conv_trials = sample_trials(conv_grid, MAX_TRIALS_CONV, rng)

# Mix the two architectures into one run order (presumably so an interrupted
# sweep still has results for both).
all_trials = [*mlp_trials, *conv_trials]
rng.shuffle(all_trials)

for caption, count in (
    ('Selected MLP trials :', len(mlp_trials)),
    ('Selected Conv trials:', len(conv_trials)),
    ('Total selected      :', len(all_trials)),
):
    print(caption, count)

In [None]:
# =========================
# Runner Utilities
# =========================

def trial_key(cfg):
    """Return a short deterministic fingerprint of a trial configuration.

    The config is serialised to JSON with sorted keys, so dict insertion order
    does not matter. The result is MD5-hashed and truncated to 12 hex
    characters. It is an identifier for resume/dedup bookkeeping, not a
    security measure.
    """
    canonical = json.dumps(cfg, sort_keys=True)
    digest = hashlib.md5(canonical.encode('utf-8'))
    return digest.hexdigest()[:12]

def run_cmd_live(cmd, cwd, log_path, stage_name, heartbeat_sec=20.0):
    """
    Run command with live streaming output to notebook and file.

    stdout and stderr are merged. Each output line is echoed as
    ``[stage_name] <line>`` and written to ``log_path`` (overwritten). If the
    child stays silent for ``heartbeat_sec`` seconds, a "still running" notice
    is printed; the notice is not captured in the returned output.

    If this call is interrupted (e.g. KeyboardInterrupt from the notebook), the
    child process is killed instead of being left running in the background.

    Args:
        cmd: argv list (executed without a shell).
        cwd: working directory for the child.
        log_path: file receiving the combined output.
        stage_name: label prefixed to every echoed line.
        heartbeat_sec: seconds of silence between heartbeat notices.

    Returns:
        (returncode, combined_output), where combined_output is the captured
        lines joined with newlines (no trailing newline).
    """
    import queue
    import threading

    env = dict(os.environ)
    env["PYTHONUNBUFFERED"] = "1"
    env["PYTHONIOENCODING"] = "utf-8"

    print(f"\n[{stage_name}] START")
    print(f"[{stage_name}] CWD: {cwd}")
    print(f"[{stage_name}] CMD: {' '.join(cmd)}")

    lines = []
    with open(log_path, "w", encoding="utf-8") as lf:
        proc = subprocess.Popen(
            cmd,
            cwd=str(cwd),
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            encoding="utf-8",
            errors="replace",
            bufsize=1,
            env=env,
        )

        # readline() blocks until the child emits a line, so a loop that reads
        # directly can never notice silence and the heartbeat would never fire.
        # A daemon thread does the blocking reads and hands lines over through a
        # queue; None marks end-of-stream.
        line_q = queue.Queue()

        def _pump():
            try:
                for raw_line in proc.stdout:
                    line_q.put(raw_line)
            finally:
                line_q.put(None)

        reader = threading.Thread(target=_pump, name=f"{stage_name}-reader", daemon=True)
        reader.start()

        try:
            while True:
                try:
                    item = line_q.get(timeout=heartbeat_sec)
                except queue.Empty:
                    print(f"[{stage_name}] ... still running ...")
                    continue
                if item is None:
                    break
                msg = item.rstrip("\n")
                lines.append(msg)
                lf.write(msg + "\n")
                print(f"[{stage_name}] {msg}")
            returncode = proc.wait()
        finally:
            if proc.poll() is None:
                # Only reachable when the loop above raised (e.g. the cell was
                # interrupted): don't leave an orphaned training process behind.
                proc.kill()
                proc.wait()
            reader.join(timeout=1.0)
            # Close the pipe only once the reader is done with it; closing a
            # stream another thread is blocked on could stall this cleanup.
            if not reader.is_alive() and proc.stdout is not None:
                proc.stdout.close()

    print(f"[{stage_name}] END (code={returncode})")
    return returncode, "\n".join(lines)

def build_train_cmd(cfg, train_dir):
    """Build the argv for ``scripts/train_vae.py`` for a single trial.

    Sweep-wide settings (parquet path, epochs, patience, seed, device) come
    from the module constants; the per-trial values come from ``cfg``.
    Architecture-specific flags are appended last: ``--hidden_dims`` for MLP.
    Otherwise ``--channels`` and ``--fc_dim`` are added, plus ``--no_batchnorm``
    when ``cfg['batchnorm']`` is falsy. The script path is relative, so the
    command has to run with the project root as its working directory.
    """
    flag_values = [
        ('--parquet', PARQUET_PATH),
        ('--model_type', cfg['model_type']),
        ('--latent_dim', cfg['latent_dim']),
        ('--epochs', EPOCHS),
        ('--batch_size', cfg['batch_size']),
        ('--lr', cfg['lr']),
        ('--beta', cfg['beta']),
        ('--weight_decay', cfg['weight_decay']),
        ('--patience', PATIENCE),
        ('--seed', SEED),
        ('--device', DEVICE),
        ('--output_dir', train_dir),
    ]

    argv = [sys.executable, 'scripts/train_vae.py']
    for flag, value in flag_values:
        argv += [flag, str(value)]

    if cfg['model_type'] == 'mlp':
        argv += ['--hidden_dims', *map(str, cfg['hidden_dims'])]
        return argv

    argv += ['--channels', *map(str, cfg['channels'])]
    argv += ['--fc_dim', str(cfg['fc_dim'])]
    if not cfg['batchnorm']:
        argv.append('--no_batchnorm')
    return argv

def build_eval_cmd(train_dir, eval_dir):
    """Build the argv for ``scripts/eval_vae.py``.

    Evaluates the ``vae_checkpoint.pt`` produced in ``train_dir`` against the
    sweep's parquet file, writes results into ``eval_dir``, and passes
    ``--n_plot_samples 0`` to the script.
    """
    ckpt_path = train_dir / 'vae_checkpoint.pt'
    argv = [sys.executable, 'scripts/eval_vae.py']
    argv += ['--checkpoint', str(ckpt_path)]
    argv += ['--parquet', str(PARQUET_PATH)]
    argv += ['--output_dir', str(eval_dir)]
    argv += ['--device', DEVICE]
    argv += ['--n_plot_samples', '0']
    return argv

def parse_metrics(eval_dir):
    """Load the metrics JSON written by ``scripts/eval_vae.py``.

    Args:
        eval_dir: the trial's eval output directory (``Path`` or ``str``).

    Returns:
        The decoded contents of ``eval_dir / 'test_metrics.json'`` when the file
        exists and holds a JSON object. Otherwise an empty dict, so a missing,
        corrupt, or unexpected file is handled uniformly by callers using
        ``dict.get``; a warning is printed for the corrupt/unexpected cases.
    """
    p = Path(eval_dir) / 'test_metrics.json'
    try:
        raw = p.read_text(encoding='utf-8')
    except FileNotFoundError:
        return {}

    # A truncated/corrupt file must not abort the whole sweep loop.
    try:
        data = json.loads(raw)
    except json.JSONDecodeError as err:
        print(f'[parse_metrics] Could not decode {p}: {err}')
        return {}

    if not isinstance(data, dict):
        print(f'[parse_metrics] Expected a JSON object in {p}, got {type(data).__name__}')
        return {}
    return data

def run_trial(cfg, trial_idx, total_trials):
    """Train, then evaluate, one sampled configuration and return its result row.

    Creates ``OUTPUT_ROOT/<trial_idx:04d>_<model_type>_<key>/`` with ``train/``,
    ``eval/`` and ``logs/`` subfolders and runs ``scripts/train_vae.py``. Only
    if training exits with code 0 does it run ``scripts/eval_vae.py``.

    Args:
        cfg: one configuration dict from the search grid.
        trial_idx: index used in the folder name and progress banner.
        total_trials: denominator shown in the progress banner.

    Returns:
        A flat dict with bookkeeping fields (status, return codes, runtimes,
        output folders) and every key of ``cfg``. When evaluation exits 0, it
        also holds the metrics read from ``test_metrics.json``. ``status`` is one of:
        ``'train_failed'``, ``'eval_failed'``, ``'metrics_missing'`` (eval exited 0
        but MAE/RMSE are absent) or ``'ok'``.
    """
    key = trial_key(cfg)
    trial_name = f"{trial_idx:04d}_{cfg['model_type']}_{key}"
    trial_dir = OUTPUT_ROOT / trial_name
    train_dir = trial_dir / 'train'
    eval_dir = trial_dir / 'eval'
    logs_dir = trial_dir / 'logs'

    for sub_dir in (train_dir, eval_dir, logs_dir):
        sub_dir.mkdir(parents=True, exist_ok=True)

    train_cmd = build_train_cmd(cfg, train_dir)
    eval_cmd = build_eval_cmd(train_dir, eval_dir)

    print("\n" + "=" * 90)
    print(f"TRIAL {trial_idx}/{total_trials} :: {trial_name}")
    print(f"CONFIG :: {json.dumps(cfg, sort_keys=True)}")
    print("=" * 90)

    t0 = time.time()

    train_t0 = time.time()
    tr_code, tr_combined = run_cmd_live(
        train_cmd,
        PROJECT_ROOT,
        logs_dir / 'train_output.txt',
        f"{trial_name}::TRAIN",
    )
    train_sec = time.time() - train_t0

    result = {
        'trial_idx': trial_idx,
        'trial_name': trial_name,
        'trial_key': key,
        'model_type': cfg['model_type'],
        'status': 'train_failed' if tr_code != 0 else 'train_ok',
        'train_returncode': tr_code,
        'eval_returncode': None,
        'runtime_sec': None,
        'train_runtime_sec': round(train_sec, 3),
        'eval_runtime_sec': None,
        'train_dir': str(train_dir),
        'eval_dir': str(eval_dir),
        **cfg,
    }

    if tr_code != 0:
        (logs_dir / 'train_stderr.txt').write_text(tr_combined, encoding='utf-8')
    else:
        # Drop a metrics file left by an earlier run in this folder so it can
        # never be read back as if this evaluation had produced it.
        (eval_dir / 'test_metrics.json').unlink(missing_ok=True)

        eval_t0 = time.time()
        ev_code, ev_combined = run_cmd_live(
            eval_cmd,
            PROJECT_ROOT,
            logs_dir / 'eval_output.txt',
            f"{trial_name}::EVAL",
        )
        eval_sec = time.time() - eval_t0

        result['eval_returncode'] = ev_code
        result['eval_runtime_sec'] = round(eval_sec, 3)

        if ev_code != 0:
            result['status'] = 'eval_failed'
            (logs_dir / 'eval_stderr.txt').write_text(ev_combined, encoding='utf-8')
        else:
            metrics = parse_metrics(eval_dir)
            result.update({
                'elbo': metrics.get('elbo'),
                'recon_loss': metrics.get('recon_loss'),
                'kl_loss': metrics.get('kl_loss'),
                'mse_original': metrics.get('mse_original'),
                'mae_original': metrics.get('mae_original'),
                'rmse_original': metrics.get('rmse_original'),
                'n_samples': metrics.get('n_samples'),
            })
            # Exit code 0 alone is not enough. The resume logic treats 'ok' as
            # done, and the analysis ranks and plots MAE/RMSE. A trial without
            # them is flagged so it is excluded from analysis and re-run on resume.
            has_metrics = result['mae_original'] is not None and result['rmse_original'] is not None
            result['status'] = 'ok' if has_metrics else 'metrics_missing'

    result['runtime_sec'] = round(time.time() - t0, 3)

    def _fmt(value, spec):
        # metrics.get() yields None for keys absent from the file; formatting
        # None with a float spec would raise and abort the whole sweep.
        return 'n/a' if value is None else format(value, spec)

    eval_txt = 'n/a' if result['eval_runtime_sec'] is None else f"{result['eval_runtime_sec']:.1f}s"
    print(
        f"[{trial_name}] STATUS={result['status']} | total={result['runtime_sec']:.1f}s"
        f" | train={result['train_runtime_sec']:.1f}s | eval={eval_txt}"
    )
    if result['status'] == 'ok':
        print(
            f"[{trial_name}] MAE={_fmt(result['mae_original'], '.6f')}, "
            f"RMSE={_fmt(result['rmse_original'], '.6f')}, "
            f"ELBO={_fmt(result['elbo'], '.6f')}"
        )

    return result

In [None]:
# =========================
# Execute Sweep (resume-safe)
# =========================

RESULTS_CSV = OUTPUT_ROOT / 'results.csv'

# trial_key is a 12-char hex digest prefix. Without an explicit dtype, pandas
# infers a number for all-digit keys and drops leading zeros
# (e.g. '012345678901' -> 12345678901). Completed trials would then stop
# matching trial_key(cfg) and be re-run on resume.
KEY_DTYPES = {'trial_key': str}

if RESULTS_CSV.exists():
    existing = pd.read_csv(RESULTS_CSV, dtype=KEY_DTYPES)
    print(f'Loaded existing results: {len(existing)} rows')

    # Keep only latest record per trial_key if duplicates exist
    if 'trial_key' in existing.columns and not existing.empty:
        existing = existing.sort_values('trial_idx', kind='stable').drop_duplicates(subset=['trial_key'], keep='last')

    # Only successful trials are considered done
    done_keys = set(existing.loc[existing['status'] == 'ok', 'trial_key'].astype(str).tolist())
else:
    existing = pd.DataFrame()
    done_keys = set()
    print('No existing results found; starting fresh.')

# Continue numbering after the highest index already used. len(existing) is not
# safe here: deduplication shrinks `existing`, so it could hand out an index
# that an earlier row already carries.
if 'trial_idx' in existing.columns and not existing.empty:
    start_idx = int(existing['trial_idx'].max())
else:
    start_idx = 0

pending = [cfg for cfg in all_trials if trial_key(cfg) not in done_keys]
print(f'Pending trials: {len(pending)} / {len(all_trials)}')

new_rows = []
for i, cfg in enumerate(pending, start=1):
    row = run_trial(cfg, start_idx + i, len(all_trials))
    new_rows.append(row)

    # incremental checkpoint save, so an interrupted sweep can resume
    cur_df = pd.DataFrame(new_rows)
    # Skip the empty frame on a fresh start; recent pandas warns about
    # concatenating empty frames and may change the resulting dtypes.
    frames = [cur_df] if existing.empty else [existing, cur_df]
    out_df = pd.concat(frames, ignore_index=True)

    # keep latest row per trial_key
    if 'trial_key' in out_df.columns and not out_df.empty:
        out_df = out_df.sort_values('trial_idx', kind='stable').drop_duplicates(subset=['trial_key'], keep='last')

    out_df.to_csv(RESULTS_CSV, index=False)

results = pd.read_csv(RESULTS_CSV, dtype=KEY_DTYPES) if RESULTS_CSV.exists() else pd.DataFrame(new_rows)
print('Sweep complete. Rows in results:', len(results))
print('Successful trials:', int((results['status'] == 'ok').sum()) if 'status' in results.columns else 0)

In [None]:
# =========================
# Optional: prune non-top trial folders
# =========================

results = pd.read_csv(RESULTS_CSV)
ok = results[results['status'] == 'ok'].copy()
ok = ok.sort_values(['mae_original', 'rmse_original', 'elbo'], ascending=[True, True, True])
keep_trials = set(ok.head(KEEP_FULL_ARTIFACTS_FOR_TOP_K)['trial_name'].tolist())

# Only folders this sweep recorded in results.csv may be deleted. Anything else
# under OUTPUT_ROOT (other tools' output, __pycache__, a trial whose row was
# never saved) is left alone.
known_trials = set(results['trial_name'].astype(str))

deleted = 0
if ok.empty:
    # With no successful trial, the failed-trial logs are the only diagnostics left.
    print('No successful trials; skipping prune so failed-trial logs remain available.')
else:
    # Materialise the listing first: deleting entries while iterating a
    # directory handle can skip or repeat entries.
    for d in list(OUTPUT_ROOT.iterdir()):
        if not d.is_dir() or d.name in keep_trials or d.name not in known_trials:
            continue
        try:
            shutil.rmtree(d)
            deleted += 1
        except OSError as e:
            print('Could not delete', d, e)

print(f'Removed {deleted} non-top trial directories. Kept top {len(keep_trials)} full artifacts.')

In [None]:
# =========================
# Analysis
# =========================

results = pd.read_csv(RESULTS_CSV)
ok = results.loc[results['status'].eq('ok')].copy()

if ok.empty:
    raise RuntimeError('No successful trials found.')

# Derived columns: errors scaled by 100 ("vol points") and runtime in minutes.
ok = ok.assign(
    mae_vp=ok['mae_original'] * 100,
    rmse_vp=ok['rmse_original'] * 100,
    runtime_min=ok['runtime_sec'] / 60.0,
)

leaderboard_cols = ['trial_name', 'model_type', 'mae_original', 'rmse_original', 'elbo', 'runtime_sec']
display(ok[leaderboard_cols].sort_values('mae_original').head(20))

# Per-architecture aggregates (named aggregation: output column -> (source column, func)).
per_arch_aggs = dict(
    n_trials=('trial_name', 'count'),
    best_mae=('mae_original', 'min'),
    median_mae=('mae_original', 'median'),
    best_rmse=('rmse_original', 'min'),
    median_rmse=('rmse_original', 'median'),
    mean_runtime_min=('runtime_min', 'mean'),
)
summary = ok.groupby('model_type').agg(**per_arch_aggs).reset_index()

print('Architecture summary:')
display(summary)

In [None]:
# Distribution plots: one histogram panel per error metric, overlaid by architecture.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

arch_palette = [('mlp', '#1f77b4'), ('conv', '#2ca02c')]
panels = [
    (axes[0], 'mae_vp', 'MAE'),
    (axes[1], 'rmse_vp', 'RMSE'),
]

for ax, metric_col, metric_name in panels:
    for arch, shade in arch_palette:
        values = ok.loc[ok['model_type'] == arch, metric_col]
        ax.hist(values, bins=20, alpha=0.6, label=arch.upper(), color=shade)
    ax.set_title(f'{metric_name} distribution (vol points)')
    ax.set_xlabel(f'{metric_name} (vol points)')
    ax.set_ylabel('Count')
    ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# Shared hyperparameter sensitivity
fig, axes = plt.subplots(2, 2, figsize=(12, 9))

# Plot each architecture as its own series (same colors as the distribution
# plots) so the scatter panels get a legend. A single scatter call with a
# per-point color array has no legend entries, leaving the colors unexplained.
model_palette = {'mlp': '#1f77b4', 'conv': '#2ca02c'}
for model, color in model_palette.items():
    subset = ok[ok['model_type'] == model]
    axes[0, 0].scatter(subset['lr'], subset['mae_vp'], color=color, alpha=0.7, label=model.upper())
    axes[0, 1].scatter(subset['beta'], subset['mae_vp'], color=color, alpha=0.7, label=model.upper())

axes[0, 0].set_xscale('log')
axes[0, 0].set_title('MAE vs Learning Rate')
axes[0, 0].set_xlabel('lr')
axes[0, 0].set_ylabel('MAE (vp)')
axes[0, 0].legend()

axes[0, 1].set_title('MAE vs Beta')
axes[0, 1].set_xlabel('beta')
axes[0, 1].set_ylabel('MAE (vp)')
axes[0, 1].legend()

ok.boxplot(column='mae_vp', by='batch_size', ax=axes[1, 0])
axes[1, 0].set_title('MAE by Batch Size')
axes[1, 0].set_xlabel('batch_size')
axes[1, 0].set_ylabel('MAE (vp)')

ok.boxplot(column='mae_vp', by='weight_decay', ax=axes[1, 1])
axes[1, 1].set_title('MAE by Weight Decay')
axes[1, 1].set_xlabel('weight_decay')
axes[1, 1].set_ylabel('MAE (vp)')

# pandas' grouped boxplot adds a figure-level "Boxplot grouped by ..." title; clear it.
plt.suptitle('')
plt.tight_layout()
plt.show()

In [None]:
# Top configurations, ranked by MAE, then RMSE, then ELBO (all ascending).
ranking = ['mae_original', 'rmse_original', 'elbo']


def _best_of(model_type, k=10):
    """Return the ``k`` best successful trials of one architecture under ``ranking``."""
    subset = ok[ok['model_type'] == model_type]
    return subset.sort_values(ranking).head(k)


top_mlp = _best_of('mlp')
top_conv = _best_of('conv')

mlp_cols = ['trial_name', 'mae_original', 'rmse_original', 'elbo', 'latent_dim',
            'hidden_dims', 'lr', 'beta', 'batch_size', 'weight_decay']
conv_cols = ['trial_name', 'mae_original', 'rmse_original', 'elbo', 'latent_dim',
             'channels', 'fc_dim', 'batchnorm', 'lr', 'beta', 'batch_size', 'weight_decay']

print('Top 10 MLP configs')
display(top_mlp[mlp_cols])

print('Top 10 Conv configs')
display(top_conv[conv_cols])

best = ok.sort_values(ranking).iloc[0]
print('Overall best trial')
display(best.to_frame().T)