# Stage 3 Full Run: Baseline vs Elementwise (full length, lr=0.00468, 3 seeds)

In [None]:
%%bash
# Register the project's virtualenv as a Jupyter kernel named "ese3060-venv".
# Stop at the first failing command so the success message below is printed
# only when both the activation and the kernel installation succeeded.
set -e
source /workspace/ese-3060-project/.venv/bin/activate
python -m ipykernel install --user --name ese3060-venv --display-name "ese3060 venv"
echo "Kernel installed."

In [None]:
from pathlib import Path
import os, subprocess, json, glob, re
import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('ggplot')
%matplotlib inline
%cd /workspace/ese-3060-project

# Resolve the project root. An explicit PROJ_ROOT environment variable wins;
# otherwise use the working directory, stepping up one level when the notebook
# is executed from inside notebooks/.
cwd = Path.cwd().expanduser().resolve()
if cwd.name == 'notebooks':
    default_root = cwd.parent
else:
    default_root = cwd
PROJ_ROOT = Path(os.environ.get('PROJ_ROOT', default_root)).expanduser().resolve()

# Every other location is derived from the project root.
_experiments_dir = PROJ_ROOT / 'experiments'
SCRIPT_PATH = PROJ_ROOT / 'train_gpt.py'
RESULTS_ALL = _experiments_dir / 'results.csv'
RESULTS = _experiments_dir / 'results_stage3.csv'
CURVES_CSV = _experiments_dir / 'log_curves_stage3.csv'
LOG_DIR = PROJ_ROOT / 'logs'
SPLITTER = PROJ_ROOT / 'scripts' / 'split_results.py'

# Without a stage-specific results file, point RESULTS at the aggregated one
# (when that one exists).
if RESULTS_ALL.exists() and not RESULTS.exists():
    RESULTS = RESULTS_ALL

print(f'Project root: {PROJ_ROOT}')
print(f'Results: {RESULTS}')
print(f'Log dir: {LOG_DIR}')
print(f'Curves CSV: {CURVES_CSV}')

In [None]:
# Runtime knobs
LAUNCH = False             # set True to actually run
TORCHRUN = "torchrun"
NPROC = None                # auto-detect GPU count if None

# Optimisation schedule shared by every run of this stage.
LR = 0.00468
NUM_ITER = 5100             # full length
WARMUP_ITERS = 0
WARMDOWN_ITERS = 1450
VAL_EVERY = 125

# Seeds to repeat each configuration with (3 seeds per config).
SEEDS = [1337, 2337, 3337]

# The two configurations differ only in the ATTNGATE setting; the gate
# position and activation are the same for both.
CONFIGS = {
    label: {"ATTNGATE": gate, "GATEPOS": "sdpa", "GATEACT": "sigmoid"}
    for label, gate in (("baseline", "none"), ("elementwise", "elementwise"))
}

# torchrun helpers
# When NPROC is not set explicitly, use one process per line printed by
# `nvidia-smi --list-gpus`, with a floor of one process.
if NPROC is None:
    try:
        # Call nvidia-smi directly (no shell pipeline) so that a missing binary
        # or a non-zero exit status is seen here instead of being hidden
        # behind the exit status of `wc -l`.
        gpu_listing = subprocess.check_output(["nvidia-smi", "--list-gpus"], text=True)
    except (OSError, subprocess.CalledProcessError):
        # nvidia-smi is unavailable or failed: treat as "no GPU detected".
        gpu_count = 0
    else:
        gpu_count = sum(1 for line in gpu_listing.splitlines() if line.strip())
    NPROC = max(gpu_count, 1)

# Fail early with an explicit exception (an assert would be stripped under -O).
if not SCRIPT_PATH.exists():
    raise FileNotFoundError(f"Missing train script: {SCRIPT_PATH}")

def run_job(name, cfg, seed):
    """Launch one training run through torchrun, or dry-run it.

    The run is configured entirely through environment variables layered on
    top of the current process environment: the gate settings from ``cfg``,
    the seed, and the module-level knobs (LR, NUM_ITER, WARMUP_ITERS,
    WARMDOWN_ITERS, VAL_EVERY).

    Args:
        name: Label of the configuration; used only in messages.
        cfg: Mapping with the keys "ATTNGATE", "GATEPOS" and "GATEACT".
        seed: Seed exported to the run as SEED.

    Returns:
        0, both for a dry run (LAUNCH is false) and for a run that finished
        with exit status 0.

    Raises:
        RuntimeError: If the launched process exits with a non-zero status.
    """
    env = os.environ.copy()
    env.update({
        "ATTNGATE": cfg["ATTNGATE"],
        "GATEPOS": cfg["GATEPOS"],
        "GATEACT": cfg["GATEACT"],
        "LR": str(LR),
        "SEED": str(seed),
        "NUM_ITERATIONS": str(NUM_ITER),
        "WARMUP_ITERS": str(WARMUP_ITERS),
        "WARMDOWN_ITERS": str(WARMDOWN_ITERS),
        "VAL_EVERY": str(VAL_EVERY),
    })
    cmd = [TORCHRUN, "--standalone", f"--nproc_per_node={NPROC}", str(SCRIPT_PATH)]
    print(f"\n>>> Launching {name} seed={seed} lr={LR:.5f} nproc={NPROC}")
    if not LAUNCH:
        # Dry run: the launch line is printed but nothing is executed.
        return 0
    proc = subprocess.run(cmd, env=env)
    if proc.returncode != 0:
        raise RuntimeError(f"Run failed: {name} seed {seed} rc={proc.returncode}")
    # Return the exit status on the success path too, so both paths yield an int.
    return proc.returncode

# Sweep every (configuration, seed) pair. run_job only prints the launch line
# unless LAUNCH is True.
for cfg_name, cfg_env in CONFIGS.items():
    for seed in SEEDS:
        run_job(cfg_name, cfg_env, seed)

if LAUNCH:
    # Refresh splits (stage1=1500, stage2=800; stage3 remains in main results unless you add a stage3 filter later)
    # Kept best-effort (check=False) so a splitter problem does not abort the
    # notebook after the training runs have finished, but a failure is
    # reported instead of being silently ignored.
    split_proc = subprocess.run(
        ["python", str(SPLITTER), "--stage1-iters", "1500", "--stage2-iters", "800"],
        check=False,
    )
    if split_proc.returncode != 0:
        print(f"Warning: {SPLITTER} exited with rc={split_proc.returncode}; split result files may be stale.")

print(f"Done (LAUNCH={LAUNCH})")

In [None]:
# Load results
def _stage3_rows(frame):
    """Best-effort filter of aggregated results down to this stage's runs.

    Keeps the rows whose ``num_iterations``, ``learning_rate`` and
    ``warmdown_iters`` match NUM_ITER, LR and WARMDOWN_ITERS. A criterion whose
    column is absent from *frame* is skipped (and reported) rather than
    raising.
    """
    keep = pd.Series(True, index=frame.index)
    if 'num_iterations' in frame.columns:
        keep &= frame['num_iterations'] == NUM_ITER
    if 'learning_rate' in frame.columns:
        # Tolerance comparison, equivalent to matching at 5 decimals, instead
        # of exact float equality.
        keep &= (frame['learning_rate'] - LR).abs() < 5e-6
    if 'warmdown_iters' in frame.columns:
        keep &= frame['warmdown_iters'] == WARMDOWN_ITERS
    missing = [c for c in ('num_iterations', 'learning_rate', 'warmdown_iters') if c not in frame.columns]
    if missing:
        print('Stage3 filter skipped missing columns:', missing)
    return frame[keep]


if RESULTS.exists():
    df = pd.read_csv(RESULTS)
    # RESULTS may have been pointed at the aggregated file as a fallback; in
    # that case it contains other runs too and must be filtered as well.
    if RESULTS == RESULTS_ALL:
        df = _stage3_rows(df)
elif RESULTS_ALL.exists():
    # filter best-effort for stage3 params
    df = _stage3_rows(pd.read_csv(RESULTS_ALL))
else:
    df = pd.DataFrame()
df.head() if not df.empty else df

In [None]:
# Aggregate comparison with deltas (elementwise vs baseline)
def _usable_std(value):
    """Return *value*, or NaN when it is missing or exactly zero."""
    if pd.notna(value) and value != 0:
        return value
    return float('nan')


def _z_score(delta, std):
    """Return delta / std, or NaN when *std* is NaN or zero."""
    if std != std or std == 0:
        return float('nan')
    return delta / std


def _delta_bar(delta, z, ylabel):
    """Draw a single-bar chart of *delta*, labelled with its z-score."""
    figure, axes = plt.subplots(figsize=(6,4))
    axes.bar([0], [delta])
    axes.set_xticks([0])
    axes.set_xticklabels(['elementwise - baseline'])
    axes.set_ylabel(ylabel)
    z_text = f'z={z:.2f}' if z==z else 'z=NA'
    axes.bar_label(axes.containers[0], labels=[z_text])
    axes.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    return figure, axes


if df.empty:
    print('No stage3 results found; ensure runs are logged.')
else:
    # One row per attn_gate value: run count plus mean/std of the metrics.
    agg = (
        df.groupby('attn_gate')
        .agg(
            runs=('run_id','count'),
            mean_best_val=('best_val_loss','mean'),
            std_best_val=('best_val_loss','std'),
            mean_final_val=('final_val_loss','mean'),
            mean_ms_step=('ms_per_step','mean'),
            std_ms_step=('ms_per_step','std'),
        )
        .reset_index()
    )
    display(agg)
    if not {'none', 'elementwise'} <= set(agg['attn_gate']):
        print('Baseline or elementwise missing; cannot compute deltas.')
    else:
        _by_gate = agg.set_index('attn_gate')
        base = _by_gate.loc['none']
        elem = _by_gate.loc['elementwise']

        # Deltas are elementwise minus baseline; z divides by the baseline std.
        delta_loss = elem['mean_best_val'] - base['mean_best_val']
        std_loss = _usable_std(base['std_best_val'])
        z_loss = _z_score(delta_loss, std_loss)

        delta_ms = elem['mean_ms_step'] - base['mean_ms_step']
        std_ms = _usable_std(base['std_ms_step'])
        z_ms = _z_score(delta_ms, std_ms)

        fig, ax = _delta_bar(delta_loss, z_loss, 'Δ best val loss (lower better)')
        fig, ax = _delta_bar(delta_ms, z_ms, 'Δ ms/step (lower better)')

In [None]:
# Parse logs and plot mean curves per attn_gate
# A validation line carries "step:<cur>/<total>", "val_loss:<float>" and
# "train_time:<int>ms", in that order.
VAL_RE = re.compile(r"step:(\d+)/(\d+).*val_loss:([0-9.]+).*train_time:(\d+)ms")

# Index the results once by run_id (first row wins for duplicates) instead of
# re-filtering the frame for every log file. An empty/column-less df simply
# yields no known runs instead of raising KeyError.
if 'run_id' in df.columns:
    run_meta = (
        df.assign(run_id=df['run_id'].astype(str))
        .drop_duplicates('run_id')
        .set_index('run_id')
    )
else:
    run_meta = pd.DataFrame()

rows = []
log_paths = sorted(LOG_DIR.glob('*.txt')) if LOG_DIR.exists() else []
for path in log_paths:
    rid = path.stem  # the log file name (without .txt) is matched against run_id
    if rid not in run_meta.index:
        continue
    cfg = run_meta.loc[rid]
    with open(path, encoding='utf-8', errors='replace') as f:
        for line in f:
            m = VAL_RE.search(line)
            if m is None:
                continue
            rows.append({
                'run_id': rid,
                'step': int(m.group(1)),
                'val_loss': float(m.group(3)),
                'train_time_ms': int(m.group(4)),
                'attn_gate': cfg['attn_gate'],
                'learning_rate': cfg['learning_rate'],
            })
curves = pd.DataFrame(rows)

if not curves.empty:
    fig, ax = plt.subplots(figsize=(7,5))
    for gate, sub in curves.groupby('attn_gate'):
        # Average across runs per validation step. Grouping on train_time_ms
        # would average nothing, because different runs essentially never
        # share an identical millisecond timestamp.
        mean_curve = sub.groupby('step')[['train_time_ms', 'val_loss']].mean().sort_index()
        ax.plot(mean_curve['train_time_ms']/1000.0, mean_curve['val_loss'], label=gate)
    ax.set_xlabel('train_time (s)')
    ax.set_ylabel('val_loss')
    ax.set_title('Mean val_loss vs train time (full run)')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.show()
else:
    print('No curves to plot.')

In [None]:
# Export curves
# Write the parsed validation curves to CURVES_CSV, creating its parent
# directory when needed; with nothing parsed, only report that.
if curves.empty:
    print('curves is empty; nothing to export.')
else:
    CURVES_CSV.parent.mkdir(parents=True, exist_ok=True)
    curves.to_csv(CURVES_CSV, index=False)
    print(f"Saved curves to {CURVES_CSV}")