# Stage 2 LR Sweep (baseline vs elementwise gate)

Sweeps learning-rate multipliers for the baseline and for elementwise attention gating (sigmoid gate on the SDPA output), per Stage 2 of the plan. Each run uses 150 warmup steps and 1,000 total iterations. By default, 2 seeds are run per learning rate, and a run stops early if divergence is detected (NaN/inf or loss above a threshold).

**Important**: This notebook launches `torchrun` subprocesses. Adjust `gpus_per_run` and `cuda_visible_devices` to match your node. Ensure data is at `data/fineweb10B/` (or set a symlink).

In [None]:
%cd /workspace/ese-3060-project

# Sweep configuration (shared by the cells below)

# -- How each run is launched --
torchrun_cmd = "torchrun"
script_path = "/workspace/ese-3060-project/train_gpt.py"  # absolute path to avoid cwd issues
gpus_per_run = 8             # passed as --nproc_per_node; set to your GPU count (e.g., 1 or 8)
cuda_visible_devices = None  # e.g., "0,1,2,3"; None leaves the inherited value untouched

# -- Learning-rate grid: every run trains with base_lr * multiplier --
base_lr = 0.0036
lr_multipliers = [1.0, 1.1, 1.2, 1.3]
seeds = [1337, 2337]

# -- Schedule --
warmup_iters = 150
num_iterations = 1000

# -- Early stopping --
early_stop_patience = 2     # 0 disables; patience counted on val checks
early_stop_min_delta = 0.0  # require this much improvement to reset patience



In [None]:
import os, subprocess, shlex, time, re, json, pathlib, signal

def build_env(attn_gate, lr, seed):
    """Return the environment mapping for a single training run.

    The result is a plain dict holding every variable of the current process
    environment plus the run-specific settings the training script is given
    through environment variables. All numeric values are stringified.

    Args:
        attn_gate: Value stored under ``ATTNGATE`` (passed through unchanged).
        lr: Learning rate, stored under ``LR``.
        seed: Random seed, stored under ``SEED``.

    Returns:
        A new dict; ``os.environ`` itself is not modified.
    """
    overrides = {
        "ATTNGATE": attn_gate,
        "GATEPOS": "sdpa",
        "GATEACT": "sigmoid",
        "LR": str(lr),
        "SEED": str(seed),
        "WARMUP_ITERS": str(warmup_iters),
        "NUM_ITERATIONS": str(num_iterations),
        "EARLY_STOP_PATIENCE": str(early_stop_patience),
        "EARLY_STOP_MIN_DELTA": str(early_stop_min_delta),
    }
    # Only pin devices when the notebook config asks for it; otherwise the
    # inherited CUDA_VISIBLE_DEVICES (if any) is kept as-is.
    if cuda_visible_devices is not None:
        overrides["CUDA_VISIBLE_DEVICES"] = str(cuda_visible_devices)
    return {**os.environ, **overrides}

def _stop_process_group(proc, grace_seconds=30):
    """Terminate the process group led by *proc*, escalating to SIGKILL."""
    for sig in (signal.SIGTERM, signal.SIGKILL):
        try:
            os.killpg(proc.pid, sig)
        except ProcessLookupError:
            return  # the group is already gone
        try:
            proc.wait(timeout=grace_seconds)
            return
        except subprocess.TimeoutExpired:
            continue  # did not exit in time; try the next, stronger signal


def run_experiment(name, attn_gate, lr, seed):
    """Launch one training run via torchrun and stream its output.

    The child's stdout and stderr are merged, echoed line by line as they
    arrive, and collected so the last lines can be returned.

    Args:
        name: Label used in the launch banner and in the returned record.
        attn_gate: Gate setting forwarded to ``build_env``.
        lr: Learning rate forwarded to ``build_env``.
        seed: Random seed forwarded to ``build_env``.

    Returns:
        A dict with ``name``, ``attn_gate``, ``lr``, ``seed``, the process
        ``returncode`` and ``tail`` (the last 10 output lines).

    Raises:
        Whatever interrupted the streaming loop (e.g. ``KeyboardInterrupt``),
        re-raised after the child's process group has been stopped.
    """
    cmd = [
        torchrun_cmd,
        "--standalone",
        f"--nproc_per_node={gpus_per_run}",
        script_path,
    ]
    env = build_env(attn_gate, lr, seed)
    print(
        f"\n>>> Launching {name}\n"
        f"    attn_gate={attn_gate} lr={lr:.6f} seed={seed} nproc={gpus_per_run}"
    )
    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        env=env,
        # Same effect as preexec_fn=os.setsid (the child leads a new session),
        # without running Python code between fork and exec, which is unsafe
        # when the parent has threads.
        start_new_session=True,
        bufsize=1,  # line-buffered so output is echoed as it is produced
    )
    lines = []
    try:
        for line in proc.stdout:
            line = line.rstrip()
            print(line)
            lines.append(line)
        ret = proc.wait()
    except BaseException:
        # The child sits in its own session, so an interrupt delivered to this
        # process never reaches it; stop the whole group explicitly so the
        # training processes do not keep running after the cell is aborted.
        if proc.poll() is None:
            _stop_process_group(proc)
        raise
    finally:
        proc.stdout.close()
    tail = lines[-10:]
    return {"name": name, "attn_gate": attn_gate, "lr": lr, "seed": seed, "returncode": ret, "tail": tail}


In [None]:
# Enumerate every run of the sweep: gate setting (baseline "none" first, then
# "elementwise") x LR multiplier x seed, in that nesting order.
experiments = [
    {
        "name": f"{gate}-lr{mult:.2f}-seed{run_seed}",
        "attn_gate": gate,
        "lr": base_lr * mult,
        "seed": run_seed,
    }
    for gate in ("none", "elementwise")
    for mult in lr_multipliers
    for run_seed in seeds
]

print(f"Prepared {len(experiments)} experiments:")
for spec in experiments:
    print(" -", spec["name"])


In [None]:
# Launch the sweep sequentially. The guard is off by default because each
# entry starts a full training run; switch it to True to execute.
results = []
if False:  # set to True to launch runs (expensive!)
    for spec in experiments:
        # Append as each run finishes so completed results are kept even if a
        # later run is interrupted.
        results.append(run_experiment(**spec))
    print(json.dumps(results, indent=2))


In [None]:
# Plot averaged val loss curves per configuration (baseline vs elementwise), LR variants on same graph
import glob
import matplotlib.pyplot as plt

# Directory searched for "*.txt" training logs; relative to the current working directory.
log_dir = "logs"

def parse_hparams(lines):
    """Extract the hyperparameter JSON blob from the lines of a training log.

    The blob is everything between the first ``hyperparameters:`` line and the
    first following line that starts with ``Running pytorch`` (or the end of
    the input if no such line exists).

    Args:
        lines: Log lines as strings, with or without their line endings.

    Returns:
        The decoded JSON value, or ``None`` if the header line is missing or
        the blob is not valid JSON.
    """
    # Compare without trailing whitespace so the header is found whether the
    # lines end in "\n", "\r\n", or nothing (pre-stripped / final line).
    start = next(
        (i for i, ln in enumerate(lines) if ln.rstrip() == "hyperparameters:"),
        None,
    )
    if start is None:
        return None
    blob = []
    for ln in lines[start + 1:]:
        if ln.startswith("Running pytorch"):
            break
        blob.append(ln)
    try:
        # Joining with a newline keeps tokens separated even when the lines
        # were stripped; extra whitespace between JSON tokens is harmless.
        return json.loads("\n".join(blob))
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return None

# Matches validation lines: captures the current step, the total step count
# and the numeric val_loss value.
val_re = re.compile(r"step:(\d+)/(\d+).*val_loss:([\d\.eE+-]+)")

def parse_log(path):
    """Read one training log and pull out its hyperparameters and val losses.

    Args:
        path: Path of the log file to read.

    Returns:
        A ``(hparams, vals)`` tuple: ``hparams`` is whatever ``parse_hparams``
        returns for the file's lines, and ``vals`` is a list of
        ``(step, val_loss)`` pairs in file order, one per line matching
        ``val_re``.
    """
    with open(path) as fh:
        content = fh.readlines()
    hparams = parse_hparams(content)
    hits = (val_re.search(row) for row in content)
    vals = [(int(hit.group(1)), float(hit.group(3))) for hit in hits if hit]
    return hparams, vals

# Gather the val-loss curves of every log that belongs to this Stage 2 sweep,
# keyed by (attn_gate, learning_rate); each key maps to one curve per run.
grouped = {}
for path in glob.glob(os.path.join(log_dir, "*.txt")):
    hparams, vals = parse_log(path)
    if not (hparams and vals):
        continue  # unparseable header or no validation lines
    gate_matches = (
        hparams.get("gate_pos") == "sdpa"
        and hparams.get("gate_act") == "sigmoid"
    )
    schedule_matches = (
        hparams.get("warmup_iters") == warmup_iters
        and hparams.get("num_iterations") == num_iterations
    )
    variant_matches = hparams.get("attn_gate") in ("none", "elementwise")
    if not (gate_matches and schedule_matches and variant_matches):
        continue  # log comes from a different configuration
    key = (hparams["attn_gate"], float(hparams["learning_rate"]))
    grouped.setdefault(key, []).append(vals)

# One figure per attn_gate setting; within a figure, one curve per learning
# rate showing the per-step mean val loss over all runs collected for it.
if grouped:
    for attn in sorted({gate for gate, _ in grouped}):
        plt.figure(figsize=(8,5))
        for lr in sorted(rate for gate, rate in grouped if gate == attn):
            runs = grouped[(attn, lr)]
            # Bucket losses by step, then average each bucket across runs.
            per_step = {}
            for run in runs:
                for step, loss in run:
                    per_step.setdefault(step, []).append(loss)
            steps = sorted(per_step)
            mean_loss = [sum(per_step[s]) / len(per_step[s]) for s in steps]
            plt.plot(steps, mean_loss, label=f"lr={lr:.5f} (n={len(runs)})")
        plt.title(f"Val loss vs step — attn_gate={attn}")
        plt.xlabel("step")
        plt.ylabel("val_loss")
        plt.legend()
        plt.grid(True, alpha=0.3)
    # Render all figures together once every one has been built.
    plt.show()
else:
    print("No matching logs found. Ensure runs completed and warmup/num_iterations match notebook settings.")
