# Stage 2: LR Sweep (Baseline vs Elementwise)

Runs a short learning-rate (LR) sweep with warmup and early stopping, comparing the baseline (`attn_gate=none`) against elementwise attention gating.

## 0a) (Once per machine) Register venv kernel
If Jupyter isn't using the venv, run this once, then switch kernel to `ese3060 venv`.

In [None]:
%%bash
# Register the project's virtualenv as a Jupyter kernel.
# Abort on the first failing command so "Kernel installed." is only printed
# when activation and installation actually succeeded.
set -e
source /workspace/ese-3060-project/.venv/bin/activate
python -m ipykernel install --user --name ese3060-venv --display-name "ese3060 venv"
echo "Kernel installed."



## 0) Setup
Assumes repo at `/workspace/ese-3060-project` and data symlinked at `data/fineweb10B`. Activate venv before starting JupyterLab.

In [None]:
import os, subprocess, json, glob, re
from datetime import datetime
import pandas as pd
import matplotlib.pyplot as plt
# Use matplotlib's ggplot style for the figures drawn below.
plt.style.use('ggplot')
# Work from the repo root so relative paths (experiments/, logs/) resolve.
%cd /workspace/ese-3060-project
# Path of the results CSV (relative to the repo root) that is loaded later.
RESULTS = os.path.join('experiments','results.csv')


## 1) Launch runs (Stage 2 LR sweep)
Baseline (`attn_gate=none`) vs elementwise gating, with gate position `sdpa` and a sigmoid gate activation. Defaults: 150 warmup iterations, 800 total iterations, and a 600-iteration warmdown chosen to fit the short run.

In [None]:

# Runtime knobs
NPROC = None                # set to int to override; otherwise auto-detect GPU count
BASE_LR = 0.0036            # learning rate that LR_MULTS scale
LR_MULTS = [1.0, 1.1, 1.2, 1.3]
SEEDS = [1337, 2337]
NUM_ITER = 800
WARMUP_ITERS = 150
WARMDOWN_ITERS = 600       # set <= NUM_ITER
EARLY_STOP_PATIENCE = 2    # 0 disables
EARLY_STOP_MIN_DELTA = 0.0
ATTN_GATES = ["none", "elementwise"]
SCRIPT_PATH = "/workspace/ese-3060-project/train_gpt.py"
TORCHRUN = "torchrun"
LAUNCH = False             # set True to launch runs

# torchrun helpers
if NPROC is None:
    # Call nvidia-smi directly instead of through a `| wc -l` shell pipeline:
    # the pipeline reports wc's exit status, so a failing nvidia-smi went
    # unnoticed and whatever text it printed was counted as GPUs.
    try:
        gpu_listing = subprocess.check_output(
            ["nvidia-smi", "--list-gpus"],
            stderr=subprocess.DEVNULL,
            text=True,
        )
    except (OSError, subprocess.CalledProcessError):
        # nvidia-smi missing or exited non-zero: no usable GPU count.
        gpu_count = 0
    else:
        # One non-empty output line per GPU.
        gpu_count = sum(1 for line in gpu_listing.splitlines() if line.strip())
    # Always launch at least one process, even when no GPU was detected.
    NPROC = max(gpu_count, 1)

assert os.path.exists(SCRIPT_PATH), f"Missing train script: {SCRIPT_PATH}"

import shlex, signal

def run_job(attn_gate, lr, seed):
    """Launch one training run through torchrun and stream its output.

    The run is configured entirely through environment variables layered on
    top of the current environment; the module-level knobs (WARMUP_ITERS,
    NUM_ITER, WARMDOWN_ITERS, EARLY_STOP_*) supply the schedule settings.

    Args:
        attn_gate: Value exported as ``ATTNGATE``.
        lr: Learning rate, exported as ``LR``.
        seed: Random seed, exported as ``SEED``.

    Returns:
        The exit code of the torchrun process.
    """
    env = os.environ.copy()
    env.update({
        "ATTNGATE": attn_gate,
        "GATEPOS": "sdpa",
        "GATEACT": "sigmoid",
        "LR": str(lr),
        "SEED": str(seed),
        "WARMUP_ITERS": str(WARMUP_ITERS),
        "NUM_ITERATIONS": str(NUM_ITER),
        "WARMDOWN_ITERS": str(WARMDOWN_ITERS),
        "EARLY_STOP_PATIENCE": str(EARLY_STOP_PATIENCE),
        "EARLY_STOP_MIN_DELTA": str(EARLY_STOP_MIN_DELTA),
    })
    cmd = [TORCHRUN, "--standalone", f"--nproc_per_node={NPROC}", SCRIPT_PATH]
    print(f"\n>>> Launching attn_gate={attn_gate} lr={lr:.5f} seed={seed} nproc={NPROC}")
    # start_new_session=True performs setsid() in the child (the supported
    # replacement for preexec_fn=os.setsid), so the child leads its own
    # process group and the whole worker tree can be signalled at once.
    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        env=env,
        bufsize=1,
        start_new_session=True,
    )
    try:
        # Relay the merged stdout/stderr line by line while the run is live.
        for line in proc.stdout:
            print(line.rstrip())
        ret = proc.wait()
    except BaseException:
        # The child sits in its own session, so a kernel interrupt does not
        # reach it; terminate its process group before re-raising so the
        # training processes are not left running.
        try:
            os.killpg(proc.pid, signal.SIGTERM)
        except ProcessLookupError:
            pass  # already exited
        proc.wait()
        raise
    finally:
        proc.stdout.close()
    print(f"<exit {ret}> attn_gate={attn_gate} lr={lr:.5f} seed={seed}")
    return ret

if LAUNCH:
    # Run every gate x LR multiplier x seed combination sequentially and
    # remember the ones that exit non-zero instead of discarding exit codes.
    failed = []
    for attn in ATTN_GATES:
        for mult in LR_MULTS:
            lr = BASE_LR * mult
            for seed in SEEDS:
                ret = run_job(attn, lr, seed)
                if ret != 0:
                    failed.append((attn, lr, seed, ret))
    # Summarize failures at the end; the sweep itself is never aborted.
    if failed:
        print(f"{len(failed)} run(s) exited non-zero:")
        for attn, lr, seed, ret in failed:
            print(f"  attn_gate={attn} lr={lr:.5f} seed={seed} exit={ret}")
else:
    print("Launch skipped (set LAUNCH=True to run)")



## 2) Load results

In [None]:

# Read the results table when the CSV exists; otherwise start from an empty frame.
df = pd.read_csv(RESULTS) if os.path.exists(RESULTS) else pd.DataFrame()
# Cell output: a preview of the table, or the empty frame itself.
df if df.empty else df.head()



### Filter to Stage 2 runs (short LR sweep)

In [None]:

# Keep only the rows whose schedule matches this notebook's Stage 2 knobs.
stage2 = df.copy()
for col, expected in (('num_iterations', NUM_ITER), ('warmup_iters', WARMUP_ITERS)):
    if col in stage2.columns:
        stage2 = stage2[stage2[col] == expected]
    else:
        # A missing required column means no row can match. Select zero rows
        # explicitly: comparing a scalar default yields a plain bool, and
        # indexing the frame with it raises KeyError (e.g. when df is empty).
        stage2 = stage2.iloc[0:0]
# warmdown_iters is optional; filter on it only when the column is present.
if 'warmdown_iters' in stage2.columns:
    stage2 = stage2[stage2['warmdown_iters'] == WARMDOWN_ITERS]
stage2.head()



### Aggregate metrics by config

In [None]:

if stage2.empty:
    print('No stage2 rows found; run jobs first.')
else:
    # One summary row per (gate, learning rate): run count, mean/std of the
    # best validation loss, and mean step time.
    group_cols = ['attn_gate','learning_rate']
    per_config = stage2.groupby(group_cols)
    agg = per_config.agg(
        runs=('run_id','count'),
        mean_best_val=('best_val_loss','mean'),
        std_best_val=('best_val_loss','std'),
        mean_ms_step=('ms_per_step','mean'),
    )
    agg = agg.reset_index()
    display(agg)
    if not agg.empty:
        # Mean best validation loss against learning rate, one line per gate.
        fig, ax = plt.subplots(figsize=(6,4))
        for gate_name, gate_rows in agg.groupby('attn_gate'):
            x_lr = gate_rows['learning_rate']
            y_loss = gate_rows['mean_best_val']
            ax.plot(x_lr, y_loss, marker='o', label=gate_name)
        ax.set_xlabel('learning_rate')
        ax.set_ylabel('best_val_loss')
        ax.legend()
        ax.grid(True, alpha=0.3)
        plt.show()



### Parse val_loss curves from logs (first run per config)

In [None]:

# Matches validation log lines; captures step, total steps, val_loss and train_time (ms).
VAL_RE = re.compile(r"step:(\d+)/(\d+).*val_loss:([0-9.]+).*train_time:(\d+)ms")
rows = []
plotted_configs = set()
# Without these columns no log can be mapped to a config, so skip parsing
# rather than fail with a KeyError on an empty or partial results table.
if {'run_id', 'attn_gate', 'learning_rate'}.issubset(stage2.columns):
    # Walk runs in results order and keep only the first run per
    # (attn_gate, learning_rate) that has a parseable log. Mixing several
    # seeds into a single line would make the curve jump back in time.
    run_configs = zip(stage2['run_id'].astype(str), stage2['attn_gate'], stage2['learning_rate'])
    for run_id, gate, lr in run_configs:
        config = (gate, lr)
        if config in plotted_configs:
            continue
        path = os.path.join('logs', f'{run_id}.txt')
        if not os.path.exists(path):
            continue
        run_rows = []
        with open(path) as f:
            for line in f:
                m = VAL_RE.search(line)
                if m:
                    run_rows.append({
                        'run_id': run_id,
                        'step': int(m.group(1)),
                        'val_loss': float(m.group(3)),
                        'train_time_ms': int(m.group(4)),
                        'attn_gate': gate,
                        'learning_rate': lr,
                    })
        # Only claim the config once a log actually produced data points, so
        # a log without val_loss lines does not hide a later usable run.
        if run_rows:
            plotted_configs.add(config)
            rows.extend(run_rows)
curves = pd.DataFrame(rows)
if not curves.empty:
    gates = curves['attn_gate'].unique()
    fig, axes = plt.subplots(1, len(gates), figsize=(12,4), sharey=True)
    # plt.subplots returns a bare Axes (not an array) for a single panel.
    if len(gates)==1:
        axes=[axes]
    # One panel per gate, one validation-loss curve per learning rate.
    for ax, (gate, sub) in zip(axes, curves.groupby('attn_gate')):
        for lr, sublr in sub.groupby('learning_rate'):
            ax.plot(sublr['train_time_ms']/1000.0, sublr['val_loss'], label=f"lr={lr:.5f}")
        ax.set_title(f"attn_gate={gate}")
        ax.set_xlabel('train_time (s)')
        ax.set_ylabel('val_loss')
        ax.legend()
        ax.grid(True, alpha=0.3)
    plt.show()
else:
    print('No matching logs parsed; run jobs first.')

