# Stage 2.5: Baseline vs Elementwise (1500 iters, warmdown=1450, lr=0.00468)

In [None]:
%%bash
# Stop at the first failing command so the success message below is printed
# only when the venv was activated and the kernel spec was actually installed.
set -e
source /workspace/ese-3060-project/.venv/bin/activate
# Register the activated interpreter as a user-level Jupyter kernel.
python -m ipykernel install --user --name ese3060-venv --display-name "ese3060 venv"
echo "Kernel installed."

In [None]:
import glob
import json
import os
import re
import subprocess
import sys
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
plt.style.use('ggplot')
%matplotlib inline
%cd /workspace/ese-3060-project

PROJECT_ROOT = os.environ.get('PROJ_ROOT', '/workspace/ese-3060-project')
SCRIPT_PATH = os.path.join(PROJECT_ROOT, 'train_gpt.py')
RESULTS_ALL = os.path.join(PROJECT_ROOT, 'experiments', 'results.csv')
RESULTS = os.path.join(PROJECT_ROOT, 'experiments', 'results_stage2_5.csv')
LOG_DIR = os.path.join(PROJECT_ROOT, 'logs')
SPLITTER = os.path.join(PROJECT_ROOT, 'scripts', 'split_results.py')
PROJECT_ROOT, SCRIPT_PATH

In [None]:
# Runtime knobs
NPROC = None                # None means: pick the process count automatically
LR = 0.00468
SEEDS = [1337, 2337]        # every config is run once per seed
NUM_ITER = 1500
WARMUP_ITERS = 0
WARMDOWN_ITERS = 1450
VAL_EVERY = 125
# The two configs share gate position and activation; only the gate type differs.
CONFIGS = {
    label: {"ATTNGATE": gate, "GATEPOS": "sdpa", "GATEACT": "sigmoid"}
    for label, gate in (("baseline", "none"), ("elementwise", "elementwise"))
}
TORCHRUN = "torchrun"
LAUNCH = False             # flip to True to really start the training runs

# torchrun helpers
if NPROC is None:
    # Call nvidia-smi directly instead of through a `| wc -l` shell pipeline:
    # a missing binary or a non-zero exit is then detected, rather than the
    # lines of an error message being counted as GPUs.
    try:
        listing = subprocess.check_output(
            ["nvidia-smi", "--list-gpus"], stderr=subprocess.DEVNULL
        ).decode(errors="replace")
    except (OSError, subprocess.CalledProcessError):
        gpu_count = 0
    else:
        # NOTE(review): assumes each device line starts with "GPU " — confirm
        # against the nvidia-smi version on the target machine.
        gpu_count = sum(1 for line in listing.splitlines() if line.startswith("GPU "))
    # Always launch at least one process, even when no GPU was detected.
    NPROC = max(gpu_count, 1)

# Fail early, before any launch, if the training script is not where expected.
assert os.path.exists(SCRIPT_PATH), f"Missing train script: {SCRIPT_PATH}"

def run_job(name, cfg, seed):
    """Launch one training run through torchrun, or just announce it in dry-run mode.

    The run is configured purely through environment variables layered on top
    of the current environment: the gate settings from ``cfg`` plus the
    notebook-level knobs (LR, NUM_ITER, WARMUP_ITERS, WARMDOWN_ITERS,
    VAL_EVERY) and ``seed``.

    Args:
        name: Label of the configuration; used only in messages.
        cfg: Mapping with the keys "ATTNGATE", "GATEPOS" and "GATEACT".
        seed: Seed exported to the run as the SEED environment variable.

    Returns:
        0, both when LAUNCH is false (nothing is executed) and after a
        successful run.

    Raises:
        KeyError: If ``cfg`` lacks one of the required keys (also in dry-run mode).
        RuntimeError: If the torchrun process exits with a non-zero code.
    """
    env = os.environ.copy()
    env.update({
        "ATTNGATE": cfg["ATTNGATE"],
        "GATEPOS": cfg["GATEPOS"],
        "GATEACT": cfg["GATEACT"],
        "LR": str(LR),
        "SEED": str(seed),
        "NUM_ITERATIONS": str(NUM_ITER),
        "WARMUP_ITERS": str(WARMUP_ITERS),
        "WARMDOWN_ITERS": str(WARMDOWN_ITERS),
        "VAL_EVERY": str(VAL_EVERY),
    })
    cmd = [TORCHRUN, "--standalone", f"--nproc_per_node={NPROC}", SCRIPT_PATH]
    print(f"\n>>> Launching {name} seed={seed} lr={LR:.5f} nproc={NPROC}")
    if not LAUNCH:
        # Dry run: the config was validated above, but nothing is executed.
        return 0
    proc = subprocess.run(cmd, env=env)
    if proc.returncode != 0:
        raise RuntimeError(f"Run failed: {name} seed {seed} rc={proc.returncode}")
    # Return the exit code on success too, matching the dry-run return value.
    return proc.returncode

# Run every config once per seed. With LAUNCH false, run_job only prints the
# launch line for each combination.
for cfg_name, cfg_env in CONFIGS.items():
    for seed in SEEDS:
        run_job(cfg_name, cfg_env, seed)

if LAUNCH:
    # after runs, refresh splits
    # Use the kernel's own interpreter rather than whatever "python" resolves
    # to on PATH, so the splitter runs in the same environment as this notebook.
    split_proc = subprocess.run(
        [sys.executable, SPLITTER, "--stage1-iters", "1500", "--stage2-iters", "800"],
        check=False,
    )
    # The refresh stays best-effort, but a failure is reported instead of
    # passing unnoticed and leaving stale split files behind.
    if split_proc.returncode != 0:
        print(f"WARNING: split refresh failed (rc={split_proc.returncode}): {SPLITTER}")

print("Done (LAUNCH=" + str(LAUNCH) + ")")

In [None]:
# Load results: use the stage 2.5 split when it exists, otherwise the combined
# results file, otherwise start from an empty frame.
results_path = next(
    (candidate for candidate in (RESULTS, RESULTS_ALL) if os.path.exists(candidate)),
    None,
)
df = pd.DataFrame() if results_path is None else pd.read_csv(results_path)
# Cell output: a preview of the data, or the empty frame itself.
df if df.empty else df.head()

In [None]:
# Summarise runs per (attn_gate, learning_rate) and plot the mean best
# validation loss with its standard deviation as error bars.
stage = df.copy()
if stage.empty:
    print("No stage 2.5 data yet; run jobs first.")
else:
    summary_spec = {
        "runs": ("run_id", "count"),
        "mean_best_val": ("best_val_loss", "mean"),
        "std_best_val": ("best_val_loss", "std"),
        "mean_ms_step": ("ms_per_step", "mean"),
    }
    grouped = stage.groupby(["attn_gate", "learning_rate"])
    agg = grouped.agg(**summary_spec).reset_index()
    display(agg)
    if not agg.empty:
        fig, ax = plt.subplots(figsize=(6,4))
        # One error-bar series per gate type.
        for gate, sub in agg.groupby('attn_gate'):
            ax.errorbar(
                sub['learning_rate'],
                sub['mean_best_val'],
                yerr=sub['std_best_val'],
                marker='o',
                capsize=4,
                label=gate,
            )
        ax.set_xlabel('learning_rate')
        ax.set_ylabel('best_val_loss')
        ax.legend()
        ax.grid(True, alpha=0.3)
        plt.show()

In [None]:
# Parse logs and plot curves
# A validation line is expected to contain, in this order,
# "step:<i>/<n>", "val_loss:<number>" and "train_time:<int>ms".
VAL_RE = re.compile(r"step:(\d+)/(\d+).*val_loss:([0-9.]+).*train_time:(\d+)ms")
rows = []
# A log is attributed to a config through the run_id column, so without that
# column (e.g. an empty results frame) no log can be matched and the scan is skipped.
can_match = os.path.isdir(LOG_DIR) and 'run_id' in stage.columns
# Compare ids as strings: file stems are strings even if the CSV column was
# parsed as a number.
known_ids = stage['run_id'].astype(str) if can_match else None
for path in sorted(Path(LOG_DIR).glob('*.txt')) if can_match else []:
    run_id = path.stem
    cfg = stage[known_ids == run_id]
    if cfg.empty:
        # Log of a run that is not part of the loaded results.
        continue
    cfg = cfg.iloc[0]
    with open(path) as f:
        for line in f:
            m = VAL_RE.search(line)
            if m:
                rows.append({
                    'run_id': run_id,
                    'step': int(m.group(1)),
                    'val_loss': float(m.group(3)),
                    'train_time_ms': int(m.group(4)),
                    'attn_gate': cfg['attn_gate'],
                    'learning_rate': cfg['learning_rate'],
                })
curves = pd.DataFrame(rows)
if not curves.empty:
    gate_groups = list(curves.groupby('attn_gate'))
    # squeeze=False always yields a 2-D axes array, so one gate needs no special case.
    fig, axes = plt.subplots(1, max(len(gate_groups), 1), figsize=(12,4), sharey=True, squeeze=False)
    for ax, (gate, sub) in zip(axes[0], gate_groups):
        # One line per run: runs sharing a learning rate (different seeds)
        # must not be interleaved into a single zig-zag curve.
        for (lr, run), subrun in sub.groupby(['learning_rate', 'run_id']):
            subrun = subrun.sort_values('train_time_ms')
            ax.plot(subrun['train_time_ms']/1000.0, subrun['val_loss'], label=f"lr={lr:.5f} run={run}")
        ax.set_title(f"attn_gate={gate}")
        ax.set_xlabel('train_time (s)')
        ax.set_ylabel('val_loss')
        ax.legend()
        ax.grid(True, alpha=0.3)
    plt.show()
else:
    print('No matching logs parsed; run jobs first.')

In [None]:
# Export curves
# Write the parsed validation curves to experiments/log_curves_stage2_5.csv
# under the project root; skipped when the parsing cell has not produced data.
if 'curves' in locals() and not curves.empty:
    # PROJECT_ROOT is the name defined in the setup cell (PROJ_ROOT is only the
    # environment variable it is read from).
    export_path = Path(PROJECT_ROOT) / 'experiments' / 'log_curves_stage2_5.csv'
    export_path.parent.mkdir(parents=True, exist_ok=True)
    curves.to_csv(export_path, index=False)
    print(f'Saved curves to {export_path}')
else:
    print('curves is empty; nothing to export.')