# Stage 1: Runs + EDA (Gating Screening)

This notebook lets you:
1. Launch the short Stage 1 screening runs (gating variants) with `torchrun`.
2. Parse `experiments/results.csv` and `logs/*.txt`.
3. Plot loss/time and step-time summaries.

> **Safety:** Runs start only if you set `DRY_RUN = False` in the launch cell. Default is safe.

## 0a) (Once per machine) Register and verify venv kernel inside the notebook
If JupyterLab is already running but not using the venv, run the next cell to install the kernel spec, then switch to it via Kernel → Change Kernel → `ese3060 venv`.

In [None]:
%%bash
# Register the project's virtualenv as a Jupyter kernel named "ese3060-venv".
# Abort on the first failing command: without this, a missing venv would make
# `source` fail quietly and the kernel would be registered from whatever
# `python` happens to be on PATH, while still printing the success message.
set -e
# Honour the same PROJ_ROOT override the Python setup cell reads.
PROJ_ROOT="${PROJ_ROOT:-/workspace/ese-3060-project}"
source "${PROJ_ROOT}/.venv/bin/activate"
python -m ipykernel install --user --name ese3060-venv --display-name "ese3060 venv"
echo "Kernel installed. Switch via Kernel -> Change Kernel -> 'ese3060 venv'"


## 0) Setup
- Assumes repo at `/workspace/ese-3060-project` and data symlinked at `data/fineweb10B`.
- Activate venv before starting JupyterLab: `source .venv/bin/activate`.
- Change `PROJECT_ROOT` below if different.

In [None]:
import os, subprocess, shlex, json, glob, re
from datetime import datetime
import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('ggplot')
%matplotlib inline

# Repo location: use $PROJ_ROOT when it is exported, otherwise the default checkout path.
PROJECT_ROOT = os.getenv("PROJ_ROOT", "/workspace/ese-3060-project")

# Training script, results table and log directory all hang off the project root.
SCRIPT, RESULTS, LOG_DIR = (
    os.path.join(PROJECT_ROOT, *parts)
    for parts in (("train_gpt.py",), ("experiments", "results.csv"), ("logs",))
)

# Echo the resolved locations as the cell output for a quick sanity check.
(PROJECT_ROOT, SCRIPT)

## 1) Launch runs (Stage 1 screening)
Configs: baseline, headwise sigmoid, elementwise sigmoid, headwise ns_sigmoid, const sigmoid (all at SDPA position).

In [None]:
# --- Runtime knobs: edit these before launching ---
NPROC = None           # worker processes for torchrun; None means "use the detected GPU count"
NUM_ITER = 1500        # iterations per screening run
VAL_EVERY = 125        # validate every this many steps
SEEDS = [1337, 1338]   # each config is run once per seed
DRY_RUN = True         # True only prints the plan; set False to really launch

# One row per screened variant: (config name, ATTNGATE value, GATEACT value).
# Every variant places the gate at the SDPA position.
_VARIANTS = [
    ("baseline",  "none",        "sigmoid"),
    ("head_sig",  "headwise",    "sigmoid"),
    ("elem_sig",  "elementwise", "sigmoid"),
    ("head_ns",   "headwise",    "ns_sigmoid"),
    ("const_sig", "const",       "sigmoid"),
]

# Environment-variable overrides passed to the training script, keyed by config name.
CONFIGS = {
    name: {"ATTNGATE": gate, "GATEPOS": "sdpa", "GATEACT": act}
    for name, gate, act in _VARIANTS
}

def detect_gpu_count():
    """Return the number of GPUs found, or 0 when no probe succeeds.

    Probes are tried in order: first ``torch.cuda.device_count()``, then the
    number of non-blank lines printed by ``nvidia-smi --list-gpus``. A probe
    that raises for any reason is skipped; the first one that returns wins.
    """
    def _via_torch():
        import torch
        return torch.cuda.device_count()

    def _via_nvidia_smi():
        listing = subprocess.check_output(["nvidia-smi", "--list-gpus"], text=True)
        return sum(1 for entry in listing.splitlines() if entry.strip())

    for probe in (_via_torch, _via_nvidia_smi):
        try:
            return probe()
        except Exception:
            continue
    return 0

# Decide how many worker processes torchrun should start: an explicit NPROC
# wins, otherwise fall back to the detected GPU count.
gpu_count = detect_gpu_count()
effective_nproc = NPROC if NPROC is not None else gpu_count
print(f"Detected GPUs: {gpu_count}; using NPROC={effective_nproc}")
if effective_nproc is None or effective_nproc < 1:
    raise SystemExit("No GPUs detected; set NPROC manually if using a custom setup.")
if gpu_count < 1:
    # Detection found nothing but NPROC was set by hand. The message above
    # advertises exactly this escape hatch, so trust the manual value instead
    # of rejecting it against a detected count of zero.
    print(f"WARNING: GPU detection found no devices; trusting manual NPROC={effective_nproc}.")
elif effective_nproc > gpu_count:
    raise SystemExit(f"Requested NPROC={effective_nproc} exceeds visible GPUs={gpu_count}.")

def run_config(cfg_name, cfg_env, seed, nproc, num_iter, val_every, dry_run=True):
    """Announce one training run and, unless ``dry_run`` is set, launch it.

    Args:
        cfg_name: Label printed in the run banner and used in error messages.
        cfg_env: Mapping of environment-variable overrides for this config
            (e.g. ``ATTNGATE``/``GATEPOS``/``GATEACT``). Applied last, so it
            takes precedence over the ``SEED``/``NUM_ITER``/``VAL_EVERY``
            values derived from the other arguments.
        seed: Value exported as ``SEED``.
        nproc: Passed to torchrun as ``--nproc_per_node``.
        num_iter: Value exported as ``NUM_ITER``.
        val_every: Value exported as ``VAL_EVERY``.
        dry_run: When true, only print the banner and overrides, then return.

    Returns:
        None.

    Raises:
        RuntimeError: If the launched torchrun process exits non-zero.
    """
    # Local import: the notebook's import cell only brings in `datetime`.
    from datetime import timezone

    # Work on a copy so the kernel's own environment is never modified.
    env = os.environ.copy()
    env.update({
        "SEED": str(seed),
        "NUM_ITER": str(num_iter),
        "VAL_EVERY": str(val_every),
    })
    # subprocess requires string values; coerce in case a config uses ints/bools.
    env.update({key: str(value) for key, value in cfg_env.items()})

    # `datetime.utcnow()` is deprecated; take an aware UTC time and drop the
    # tzinfo so the printed stamp keeps its "<iso>Z" shape.
    stamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    print(f"[{stamp}Z] {cfg_name} seed={seed}")

    # Use .get so a config that omits one of the gate keys prints null instead
    # of raising KeyError before anything is launched.
    shown_keys = ['ATTNGATE', 'GATEPOS', 'GATEACT', 'SEED', 'NUM_ITER', 'VAL_EVERY']
    print("    env overrides:", json.dumps({k: env.get(k) for k in shown_keys}, indent=2))
    if dry_run:
        return

    # The command is only needed for a real launch, so build it after the guard.
    cmd = ["torchrun", "--standalone", f"--nproc_per_node={nproc}", SCRIPT]
    result = subprocess.run(cmd, env=env)
    if result.returncode != 0:
        raise RuntimeError(f"Run failed: {cfg_name}, seed {seed}, rc={result.returncode}")

# Expand the sweep into an explicit (config, overrides, seed) plan: configs in
# declaration order, and within each config every seed in order.
launch_plan = [(name, overrides, s) for name, overrides in CONFIGS.items() for s in SEEDS]

for cfg_name, cfg_env, seed in launch_plan:
    run_config(cfg_name, cfg_env, seed, effective_nproc, NUM_ITER, VAL_EVERY, dry_run=DRY_RUN)

print(f"Done (DRY_RUN={DRY_RUN})")


## 2) Load results and logs
Run this after at least one job has completed. If `experiments/results.csv` doesn’t exist yet, the cell falls back to an empty DataFrame so the later cells still run.

In [None]:
# Read the results table when it exists; otherwise carry on with an empty frame.
df = pd.read_csv(RESULTS) if os.path.exists(RESULTS) else pd.DataFrame()

# Preview: first rows when there is data, the (empty) frame itself otherwise.
df if df.empty else df.head()

### Filter to Stage 1 runs (short runs: `num_iterations` ≤ 2000)

In [None]:
# Iteration count per run; when the column is missing, substitute zeros so
# every row passes the filter below.
iteration_counts = df.get("num_iterations", pd.Series([0]*len(df)))

# Keep only the short screening runs, as an independent copy of those rows.
is_short_run = iteration_counts <= 2000
stage1 = df[is_short_run].copy()
stage1.head()

### Aggregate metrics by config

In [None]:
# Summarise the Stage 1 runs per gating configuration: run count, mean and sd
# of the final validation loss, mean best validation loss, and mean step time.
group_cols = ["attn_gate", "gate_act", "gate_pos"]
if stage1.empty:
    agg = pd.DataFrame()
else:
    agg = stage1.groupby(group_cols).agg(
        runs=("run_id", "count"),
        mean_final_val=("final_val_loss", "mean"),
        std_final_val=("final_val_loss", "std"),
        mean_best_val=("best_val_loss", "mean"),
        mean_ms_step=("ms_per_step", "mean"),
    ).reset_index()

# IPython only renders a bare expression when it is the last top-level
# statement of the cell; inside the if/else it was silently discarded.
agg

### Plots: best val loss and step time

In [None]:
# Bar chart of the mean best validation loss per gating config.
if not agg.empty:
    cfg_cols = ["attn_gate", "gate_act", "gate_pos"]
    # The error bars must describe the plotted quantity. `agg` only carries the
    # sd of the *final* val loss, so compute the sd of the *best* val loss here
    # and align it to `agg` by config.
    best_sd = (
        stage1.groupby(cfg_cols)["best_val_loss"]
        .std()
        .reset_index(name="best_val_sd")
    )
    plot_agg = agg.merge(best_sd, on=cfg_cols, how="left")
    tick_labels = plot_agg["attn_gate"].astype(str) + " / " + plot_agg["gate_act"].astype(str)

    fig, ax = plt.subplots(figsize=(6,4))
    # With a single run per config the sd is NaN and no error bar is drawn.
    ax.bar(plot_agg.index, plot_agg["mean_best_val"], yerr=plot_agg["best_val_sd"], capsize=4)
    ax.set_xticks(plot_agg.index)
    ax.set_xticklabels(tick_labels, rotation=30, ha='right')
    ax.set_ylabel("Best val loss (mean ± sd)")
    ax.set_title("Best val loss by gating config")
    plt.tight_layout()
else:
    print("No data to plot yet.")

In [None]:
# Bar chart of the mean step time per gating config.
if agg.empty:
    print("No data to plot yet.")
else:
    fig, ax = plt.subplots(figsize=(6,4))
    bar_positions = agg.index
    ax.bar(bar_positions, agg["mean_ms_step"])
    ax.set_xticks(bar_positions)
    # Tick text is "<attn_gate> / <gate_act>" for each aggregated row.
    config_labels = agg[["attn_gate","gate_act"]].agg(' / '.join, axis=1)
    ax.set_xticklabels(config_labels, rotation=30, ha='right')
    ax.set_ylabel("ms/step (mean)")
    plt.tight_layout()

### Parse val_loss curves from logs and plot val_loss vs train_time
Uses the first run per config (the lexicographically smallest `run_id`) for a quick visual.

In [None]:
# Matches validation lines of the form
#   step:<step>/<total> val_loss:<float> train_time:<int>ms
VAL_RE = re.compile(r"step:(\d+)/(\d+) val_loss:([0-9.]+) train_time:(\d+)ms")

def parse_log(path):
    """Extract the validation checkpoints recorded in one training log.

    Args:
        path: Path to a log text file.

    Returns:
        A list of ``(step, total_steps, val_loss, train_time_ms)`` tuples, one
        per line matching ``VAL_RE``, in file order. Empty if nothing matches.
    """
    rows = []
    # Use a context manager so the handle is closed even if parsing raises.
    # Decode as UTF-8 with replacement: the pattern is pure ASCII, so stray
    # bytes elsewhere in the log must not abort the whole parse.
    with open(path, encoding="utf-8", errors="replace") as fh:
        for line in fh:
            m = VAL_RE.search(line)
            if m is None:
                continue
            step, total, vloss, t_ms = m.groups()
            rows.append((int(step), int(total), float(vloss), int(t_ms)))
    return rows

# Collect every validation checkpoint from every log into one long table.
log_rows = []
# Sort so the row order is reproducible; glob order is filesystem-dependent.
for path in sorted(glob.glob(os.path.join(LOG_DIR, "*.txt"))):
    # The run id is the file name minus its extension. splitext strips only the
    # trailing ".txt", whereas str.replace would also eat one inside the name.
    rid = os.path.splitext(os.path.basename(path))[0]
    for step, total, vloss, t_ms in parse_log(path):
        log_rows.append({"run_id": rid, "step": step, "total": total, "val_loss": vloss, "train_time_ms": t_ms})

# Pin the columns so the frame has a stable schema even when no log matched.
log_df = pd.DataFrame(log_rows, columns=["run_id", "step", "total", "val_loss", "train_time_ms"])
log_df.head() if not log_df.empty else log_df

In [None]:
# Validation-loss curves against wall-clock training time, one line per config.
cfg_cols = ["attn_gate", "gate_act", "gate_pos"]
curves = pd.DataFrame()
if not log_df.empty and not stage1.empty:
    # Join against the Stage 1 subset, not the full results table, so longer
    # runs that share a config cannot be picked as the "first run".
    # One metadata row per run_id keeps the join from duplicating curve points.
    run_meta = stage1[["run_id"] + cfg_cols].drop_duplicates("run_id")
    # Log run ids come from file names (strings); cast both sides to str so a
    # numeric run_id column in the CSV cannot break the merge.
    run_meta = run_meta.assign(run_id=run_meta["run_id"].astype(str))
    curves = log_df.assign(run_id=log_df["run_id"].astype(str)).merge(
        run_meta, on="run_id", how="inner"
    )

if curves.empty:
    print("No curves to plot yet.")
else:
    # Pick the first run per config: the smallest run_id within each group.
    first_runs = curves.groupby(cfg_cols)['run_id'].transform('min') == curves['run_id']
    plot_df = curves[first_runs]

    fig, ax = plt.subplots(figsize=(7,5))
    for (gate, act, pos), sub in plot_df.groupby(cfg_cols):
        # Order points by step so the line is drawn left to right.
        sub = sub.sort_values("step")
        ax.plot(sub["train_time_ms"]/1000.0, sub["val_loss"], label=f"{gate}/{act}")
    ax.set_xlabel("Train time (s)")
    ax.set_ylabel("Val loss")
    ax.set_title("Val loss vs train time (first run per config)")
    ax.legend()
    plt.tight_layout()