# Critical Batch Size in RLVR

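Analysis notebook for the critical-batch-size (CBS) experiments in reinforcement learning with verifiable rewards (RLVR).
Pulls wandb logs, computes S(B) (training steps to reach a target accuracy) and E(B) (samples processed to reach it),
fits the McCandlish model, and produces publication-quality plots.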

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
from scipy.optimize import curve_fit
from collections import defaultdict
import wandb

# Shared plot styling for every figure in this notebook.
matplotlib.rcParams.update([
    ("font.size", 13),
    ("axes.labelsize", 14),
    ("axes.titlesize", 15),
    ("legend.fontsize", 11),
    ("xtick.labelsize", 12),
    ("ytick.labelsize", 12),
    ("figure.figsize", (7, 5)),
    ("figure.dpi", 150),
])

## 1. Load Experiment Data from Wandb

In [None]:
# wandb location of the CBS sweep.
WANDB_ENTITY = "harvardml"
WANDB_PROJECT = "cbs_rlvr"

# Phase tags of the experiments analysed in this notebook.
PHASE_TAGS = ["p1", "p2a", "p2b", "p2c", "p4_math", "p4_rpp"]

api = wandb.Api()
runs = api.runs("/".join((WANDB_ENTITY, WANDB_PROJECT)))
print(f"Found {len(runs)} runs in {WANDB_ENTITY}/{WANDB_PROJECT}")

In [None]:
def parse_experiment_name(name: str) -> dict:
    """Parse a run name of the form ``cbs_{phase}_np{n_prompts}_nr{n_rollouts}``.

    Args:
        name: wandb run name.

    Returns:
        dict with keys ``name``, ``phase``, ``n_prompts``, ``n_rollouts`` and
        ``total_batch``. Components that cannot be parsed are ``None``.
        ``total_batch`` is ``n_prompts * n_rollouts`` only when both are
        parsed and non-zero. The key is always present, so callers can test
        it instead of hitting a KeyError.
    """
    result = {
        "name": name,
        "phase": None,
        "n_prompts": None,
        "n_rollouts": None,
        "total_batch": None,
    }
    for part in name.split("_"):
        # Accept only purely numeric suffixes. A token such as "nrm" inside a
        # phase tag must not be mistaken for a size (int() would raise).
        suffix = part[2:]
        if part.startswith("np") and suffix.isdecimal():
            result["n_prompts"] = int(suffix)
        elif part.startswith("nr") and suffix.isdecimal():
            result["n_rollouts"] = int(suffix)
    # Phase is everything between the leading 'cbs_' and the first '_np'.
    # Only the prefix is stripped; other 'cbs_' substrings are kept.
    if name.startswith("cbs_") and "_np" in name:
        result["phase"] = name.split("_np", 1)[0].removeprefix("cbs_")
    # Zero or missing sizes leave total_batch as None (no meaningful batch size).
    if result["n_prompts"] and result["n_rollouts"]:
        result["total_batch"] = result["n_prompts"] * result["n_rollouts"]
    return result


def load_run_history(run, keys=None):
    """Return the wandb history of ``run`` as a DataFrame.

    Args:
        run: a wandb run object exposing ``scan_history``.
        keys: metric names to request. When omitted, a fixed set of training
            metrics plus ``_step`` is requested. No ``val-core`` metrics are
            included by default; pass them explicitly if needed.

    Returns:
        DataFrame with one row per record yielded by ``scan_history``.
    """
    if keys is None:
        keys = [
            "training/global_step",
            "critic/score/mean",
            "actor/entropy",
            "perf/time_per_step",
            "perf/throughput",
            "_step",
        ]
    records = run.scan_history(keys=keys, page_size=10000)
    return pd.DataFrame(list(records))

In [None]:
# Load every run whose name parses to one of PHASE_TAGS, with its full history.
all_runs = {}
for run in runs:
    meta = parse_experiment_name(run.name)
    # Skip runs that don't follow the cbs_{phase}_np.._nr.. naming scheme
    # (phase is None) or whose phase is not listed in PHASE_TAGS.
    if meta["phase"] not in PHASE_TAGS:
        continue
    meta["run_id"] = run.id
    meta["state"] = run.state

    # Load full history (no key filter) so val metrics are included
    hist = pd.DataFrame(list(run.scan_history(page_size=10000)))
    meta["history"] = hist
    if run.name in all_runs:
        # Runs are keyed by name, so a re-launched run would silently replace
        # the earlier one; make that visible.
        print(f"  WARNING: duplicate run name {run.name}; keeping the later run")
    all_runs[run.name] = meta
    print(f"  Loaded {run.name}: {len(hist)} rows, state={run.state}")

print(f"\nTotal runs loaded: {len(all_runs)}")

## 2. Extract S(B) and E(B)

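For each run and each target accuracy threshold, find the first training step at which validation accuracy reaches the threshold. This gives $S(B)$. Multiplying by the batch size gives $E(B) = B \cdot S(B)$, the number of samples consumed to reach the target.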

In [None]:
def find_val_accuracy_columns(hist: pd.DataFrame) -> list[str]:
    """Return the history columns that hold validation accuracy.

    A column qualifies when its name contains both ``"val-core"`` and
    ``"mean@"``. The column order of ``hist`` is preserved.
    """
    wanted = ("val-core", "mean@")
    return [col for col in hist.columns if all(tok in col for tok in wanted)]


def get_accuracy_at_step(hist: pd.DataFrame, acc_col: str) -> pd.Series:
    """Return validation accuracy indexed by training step.

    Rows are dropped when the accuracy is NaN (steps without a validation
    pass) or when the step is NaN. Nothing is forward-filled. The result is
    sorted by step, so positional lookups see steps in increasing order.

    Args:
        hist: wandb history DataFrame.
        acc_col: name of the accuracy column.

    Returns:
        Series named ``accuracy`` with index ``step``, or an empty float
        Series if either column is missing.
    """
    step_col = "training/global_step"
    if step_col not in hist.columns or acc_col not in hist.columns:
        return pd.Series(dtype=float)
    # A row may carry the accuracy without a global step (presumably a
    # validation pass logged on its own). A NaN step would later crash int().
    df = hist[[step_col, acc_col]].dropna()
    df = df.rename(columns={step_col: "step", acc_col: "accuracy"})
    df = df.sort_values("step", kind="stable")
    return df.set_index("step")["accuracy"]


def steps_to_threshold(acc_series: pd.Series, threshold: float) -> int | None:
    """Find the first step where accuracy >= threshold."""
    above = acc_series[acc_series >= threshold]
    if len(above) == 0:
        return None
    return int(above.index[0])

In [None]:
def compute_sb_eb(runs_dict: dict, phase: str, acc_col: str | None = None,
                  thresholds: list[float] | None = None) -> pd.DataFrame:
    """
    Compute S(B) and E(B) for runs in a given phase.

    Returns DataFrame with columns:
      n_prompts, n_rollouts, total_batch, threshold, S_B, E_B
    """
    if thresholds is None:
        thresholds = [0.3, 0.4, 0.5, 0.6, 0.65, 0.7]

    rows = []
    for name, meta in runs_dict.items():
        if meta["phase"] != phase:
            continue
        hist = meta["history"]
        if hist.empty:
            continue

        # Auto-detect accuracy column if not specified
        if acc_col is None:
            val_cols = find_val_accuracy_columns(hist)
            if not val_cols:
                continue
            # Pick the GSM8K accuracy column (prefer it over MATH)
            gsm_cols = [c for c in val_cols if "gsm" in c.lower()]
            col = gsm_cols[0] if gsm_cols else val_cols[0]
        else:
            col = acc_col

        acc = get_accuracy_at_step(hist, col)
        if acc.empty:
            continue

        for thresh in thresholds:
            s_b = steps_to_threshold(acc, thresh)
            rows.append({
                "n_prompts": meta["n_prompts"],
                "n_rollouts": meta["n_rollouts"],
                "total_batch": meta["total_batch"],
                "threshold": thresh,
                "S_B": s_b,
                "E_B": s_b * meta["total_batch"] if s_b is not None else None,
                "acc_col": col,
            })

    return pd.DataFrame(rows)

In [None]:
# Phase 1 sweeps n_prompts: tabulate S(B)/E(B) at every default threshold.
sb_eb_p1 = compute_sb_eb(all_runs, "p1")
print("Phase 1 S(B)/E(B):")
sb_eb_p1

## 3. Fit the McCandlish Model

$$S(B) = S_{\min} \cdot \left(1 + \frac{B_{\text{noise}}}{B}\right)$$

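$$E(B) = E_{\min} \cdot \left(1 + \frac{B}{B_{\text{noise}}}\right)$$

Because $E(B) = B \cdot S(B)$, the two curves share their parameters: $E_{\min} = S_{\min} \cdot B_{\text{noise}}$. The fitted $B_{\text{noise}}$ is reported as the critical batch size.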

In [None]:
def mccandlish_S(B, S_min, B_noise):
    """Steps to target under the McCandlish model: ``S_min * (1 + B_noise / B)``."""
    return S_min * (1 + B_noise / B)


def mccandlish_E(B, E_min, B_noise):
    """Samples to target under the McCandlish model: ``E_min * (1 + B / B_noise)``."""
    return E_min * (1 + B / B_noise)


def fit_cbs(df: pd.DataFrame, threshold: float, batch_col: str = "total_batch"):
    """
    Fit B_noise from S(B) data at a given accuracy threshold.

    Args:
        df: table with ``threshold``, ``S_B`` and ``batch_col`` columns (as
            produced by ``compute_sb_eb``).
        threshold: accuracy target whose rows are fitted.
        batch_col: column holding the batch size B.

    Returns:
        dict with ``threshold``, ``S_min``, ``B_noise``, ``R2`` and the
        ``B_vals``/``S_vals`` arrays used (sorted by B). Returns ``None`` when
        ``df`` is empty or lacks the columns, when fewer than 3 points reached
        the target, when fewer than 2 distinct batch sizes are present, or
        when the optimizer fails.
    """
    required = {"threshold", "S_B", batch_col}
    if df.empty or not required.issubset(df.columns):
        return None

    sub = df[(df["threshold"] == threshold) & df["S_B"].notna()]
    if len(sub) < 3:
        return None
    sub = sub.sort_values(batch_col)

    B_vals = sub[batch_col].to_numpy(dtype=float)
    S_vals = sub["S_B"].to_numpy(dtype=float)
    # Two free parameters are not identifiable from a single batch size.
    if np.unique(B_vals).size < 2:
        return None

    try:
        popt, _ = curve_fit(
            mccandlish_S, B_vals, S_vals,
            # Start B_noise in the middle of the sampled range.
            p0=[S_vals.min(), np.median(B_vals)],
            bounds=([0, 0], [np.inf, np.inf]),
            maxfev=10000,
        )
    except (RuntimeError, ValueError):
        # RuntimeError: no convergence. ValueError: e.g. non-finite inputs.
        return None

    S_min_fit, B_noise_fit = popt
    S_pred = mccandlish_S(B_vals, *popt)
    ss_res = np.sum((S_vals - S_pred) ** 2)
    ss_tot = np.sum((S_vals - S_vals.mean()) ** 2)
    r_squared = 1 - ss_res / ss_tot if ss_tot > 0 else 0

    return {
        "threshold": threshold,
        "S_min": S_min_fit,
        "B_noise": B_noise_fit,
        "R2": r_squared,
        "B_vals": B_vals,
        "S_vals": S_vals,
    }

In [None]:
# Fit B_noise at each accuracy threshold; report thresholds that lack data.
thresholds = [0.3, 0.4, 0.5, 0.6, 0.65, 0.7]
fits = {}
for t in thresholds:
    result = fit_cbs(sb_eb_p1, t)
    if result is None:
        print(f"Threshold {t:.0%}: insufficient data to fit")
        continue
    fits[t] = result
    print(f"Threshold {t:.0%}: B_noise={result['B_noise']:.0f}, "
          f"S_min={result['S_min']:.1f}, R²={result['R2']:.4f}")

## 4. Plots

In [None]:
def plot_sb_eb(sb_eb_df: pd.DataFrame, fits: dict, title_suffix: str = ""):
    """
    Draw the classic critical-batch-size figure for a set of fits.

    Left panel: measured steps-to-target S(B) with the fitted McCandlish
    curve. Right panel: samples-to-target E(B) = B * S(B) with the implied
    E(B) curve, using E_min = S_min * B_noise. Each threshold gets its own
    viridis color (ordered by threshold) and a dotted vertical line at its
    fitted B_noise. Both axes use log base 2.

    ``sb_eb_df`` is accepted for interface compatibility but is not read. All
    plotted data comes from ``fits``.

    Returns the matplotlib Figure.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))
    ax_steps, ax_samples = axes

    ordered = sorted(fits.keys())
    denom = max(len(ordered) - 1, 1)
    palette = {thr: plt.cm.viridis(idx / denom) for idx, thr in enumerate(ordered)}

    for thr, fit in fits.items():
        color = palette[thr]
        b_obs = fit["B_vals"]
        s_obs = fit["S_vals"]
        b_noise = fit["B_noise"]

        # Steps to target.
        ax_steps.scatter(b_obs, s_obs, color=color, s=60, zorder=5)
        b_grid = np.geomspace(b_obs.min() * 0.5, b_obs.max() * 2, 200)
        ax_steps.plot(b_grid, mccandlish_S(b_grid, fit["S_min"], b_noise),
                      color=color, alpha=0.7,
                      label=f"acc={thr:.0%} (B*={b_noise:.0f})")
        ax_steps.axvline(b_noise, color=color, ls=":", alpha=0.4)

        # Samples to target.
        ax_samples.scatter(b_obs, s_obs * b_obs, color=color, s=60, zorder=5)
        e_min = fit["S_min"] * b_noise
        ax_samples.plot(b_grid, mccandlish_E(b_grid, e_min, b_noise),
                        color=color, alpha=0.7, label=f"acc={thr:.0%}")
        ax_samples.axvline(b_noise, color=color, ls=":", alpha=0.4)

    for ax in (ax_steps, ax_samples):
        ax.set_xscale("log", base=2)
        ax.set_yscale("log", base=2)
        ax.legend()
        ax.grid(True, alpha=0.3)

    ax_steps.set_xlabel("Batch size B (total samples/step)")
    ax_steps.set_ylabel("Steps to target S(B)")
    ax_steps.set_title(f"S(B): Steps to target{title_suffix}")

    ax_samples.set_xlabel("Batch size B (total samples/step)")
    ax_samples.set_ylabel("Total samples E(B) = B·S(B)")
    ax_samples.set_title(f"E(B): Total samples to target{title_suffix}")

    plt.tight_layout()
    return fig

In [None]:
# Phase 1 CBS plot, drawn only once at least one threshold has been fit.
if not fits:
    print("No fits available yet. Run Phase 1 experiments first.")
else:
    fig = plot_sb_eb(sb_eb_p1, fits, title_suffix=" (Phase 1: vary n_prompts)")
    fig.savefig("cbs_phase1_sb_eb.pdf", bbox_inches="tight")
    plt.show()

## 5. Training Curves: Accuracy vs Steps and vs Total Samples

In [None]:
def plot_training_curves(runs_dict: dict, phase: str, metric_col: str | None = None):
    """
    Plot training curves for all runs in a phase.
    Left: accuracy vs steps. Right: accuracy vs total samples processed.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))

    phase_runs = {k: v for k, v in runs_dict.items() if v["phase"] == phase}
    if not phase_runs:
        print(f"No runs found for phase {phase}")
        return

    sorted_runs = sorted(phase_runs.values(), key=lambda x: x["total_batch"])
    cmap = plt.cm.plasma
    n = len(sorted_runs)

    for i, meta in enumerate(sorted_runs):
        hist = meta["history"]
        if hist.empty:
            continue

        if metric_col is None:
            val_cols = find_val_accuracy_columns(hist)
            gsm_cols = [c for c in val_cols if "gsm" in c.lower()]
            col = gsm_cols[0] if gsm_cols else (val_cols[0] if val_cols else None)
        else:
            col = metric_col

        if col is None or col not in hist.columns:
            continue

        step_col = "training/global_step"
        df = hist[[step_col, col]].dropna(subset=[col])
        if df.empty:
            continue

        steps = df[step_col].values
        acc = df[col].values
        total_samples = steps * meta["total_batch"]
        color = cmap(i / max(n - 1, 1))
        label = f"B={meta['total_batch']} (np={meta['n_prompts']},nr={meta['n_rollouts']})"

        axes[0].plot(steps, acc, color=color, label=label, marker="o", ms=3)
        axes[1].plot(total_samples, acc, color=color, label=label, marker="o", ms=3)

    axes[0].set_xlabel("Training steps")
    axes[0].set_ylabel("Accuracy")
    axes[0].set_title(f"Accuracy vs Steps ({phase})")
    axes[0].legend(fontsize=8)
    axes[0].grid(True, alpha=0.3)

    axes[1].set_xlabel("Total samples processed")
    axes[1].set_ylabel("Accuracy")
    axes[1].set_title(f"Accuracy vs Total Samples ({phase})")
    axes[1].legend(fontsize=8)
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    return fig

In [None]:
# Phase 1 training curves; the plotter returns None when there is no data.
fig = plot_training_curves(all_runs, "p1")
if fig is not None:
    fig.savefig("cbs_phase1_training_curves.pdf", bbox_inches="tight")
    plt.show()

## 6. Two-Axis Decomposition (Phase 2)

In [None]:
# Phase 2B: hold n_prompts fixed, sweep n_rollouts, and fit CBS per threshold.
sb_eb_p2b = compute_sb_eb(all_runs, phase="p2b")

fits_p2b = {}
for t in thresholds:
    result = fit_cbs(sb_eb_p2b, t, batch_col="total_batch")
    if result is None:
        continue
    fits_p2b[t] = result

if fits_p2b:
    fig = plot_sb_eb(sb_eb_p2b, fits_p2b, title_suffix=" (Phase 2B: vary n_rollouts)")
    fig.savefig("cbs_phase2b_sb_eb.pdf", bbox_inches="tight")
    plt.show()

In [None]:
# Phase 2C: Iso-batch decomposition
# All runs are meant to share total_batch=2048 but use different
# (n_prompts, n_rollouts) splits.

def plot_iso_batch(runs_dict: dict, phase: str = "p2c"):
    """
    Plot accuracy vs steps for iso-batch experiments, colored by n_rollouts.

    The title reports the total batch size(s) actually found in the plotted
    runs instead of a hard-coded value, so a run with a different batch size
    is visible (e.g. "B=1024/2048").

    Args:
        runs_dict: run name -> metadata dict.
        phase: phase tag to plot.

    Returns:
        The Figure, or ``None`` when the phase has no run with a parsed
        ``n_rollouts``.
    """
    step_col = "training/global_step"
    phase_runs = [
        meta for meta in runs_dict.values()
        if meta.get("phase") == phase and meta.get("n_rollouts") is not None
    ]
    if not phase_runs:
        print(f"No runs for phase {phase}")
        return None

    sorted_runs = sorted(phase_runs, key=lambda x: x["n_rollouts"])
    cmap = plt.cm.coolwarm
    n = len(sorted_runs)

    fig, ax = plt.subplots(figsize=(8, 5.5))
    plotted_batches = set()

    for i, meta in enumerate(sorted_runs):
        hist = meta["history"]
        if hist.empty or step_col not in hist.columns:
            continue
        val_cols = find_val_accuracy_columns(hist)
        gsm_cols = [c for c in val_cols if "gsm" in c.lower()]
        col = gsm_cols[0] if gsm_cols else (val_cols[0] if val_cols else None)
        if col is None:
            continue

        df = hist[[step_col, col]].dropna(subset=[col])
        if df.empty:
            continue

        color = cmap(i / max(n - 1, 1))
        label = f"np={meta['n_prompts']}, nr={meta['n_rollouts']}"
        ax.plot(df[step_col], df[col], color=color, label=label, marker="o", ms=4)
        if meta.get("total_batch") is not None:
            plotted_batches.add(meta["total_batch"])

    batch_label = "B=" + ("/".join(str(b) for b in sorted(plotted_batches)) or "?")
    ax.set_xlabel("Training steps")
    ax.set_ylabel("Accuracy")
    ax.set_title(f"Iso-batch ({batch_label}): Prompt diversity vs Rollout diversity")
    if ax.get_legend_handles_labels()[0]:
        ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    return fig

# Phase 2C iso-batch figure; skipped when the plotter returns None.
fig = plot_iso_batch(all_runs, "p2c")
if fig is not None:
    fig.savefig("cbs_phase2c_iso_batch.pdf", bbox_inches="tight")
    plt.show()

## 7. Gradient Noise Scale Analysis (Phase 3)

If gradient-noise metrics (`grad_noise/B_noise_estimate` and `grad_noise/grad_norm`) were logged during training, they are plotted here.

In [None]:
def plot_gradient_noise(runs_dict: dict, phase: str = "p1"):
    """
    Plot the estimated B_noise and the gradient norm over training.

    Reads ``grad_noise/B_noise_estimate`` (left panel) and, when present,
    ``grad_noise/grad_norm`` (right panel) from each run's history. Runs are
    colored by increasing batch size.

    Args:
        runs_dict: run name -> metadata dict.
        phase: phase tag to plot.

    Returns:
        The Figure, or ``None`` when no run in the phase logged the noise
        estimate. No figure is created in that case, which lets the caller
        report missing data.
    """
    step_col = "training/global_step"
    noise_col = "grad_noise/B_noise_estimate"
    norm_col = "grad_noise/grad_norm"

    phase_runs = [
        meta for meta in runs_dict.values()
        if meta.get("phase") == phase and meta.get("total_batch") is not None
    ]
    sorted_runs = sorted(phase_runs, key=lambda x: x["total_batch"])
    n = len(sorted_runs)

    # Collect data first so a figure is only created when there is something
    # to draw. Previously a Figure was always returned, even when empty.
    series = []
    for i, meta in enumerate(sorted_runs):
        hist = meta["history"]
        if hist.empty or noise_col not in hist.columns or step_col not in hist.columns:
            continue
        cols = [step_col, noise_col]
        if norm_col in hist.columns:
            cols.append(norm_col)
        df = hist[cols].dropna(subset=[noise_col])
        if df.empty:
            continue
        series.append((i, meta["total_batch"], df))

    if not series:
        return None

    fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))
    cmap = plt.cm.plasma
    for i, total_batch, df in series:
        color = cmap(i / max(n - 1, 1))
        label = f"B={total_batch}"
        axes[0].plot(df[step_col], df[noise_col], color=color, label=label, alpha=0.8)
        if norm_col in df.columns:
            axes[1].plot(df[step_col], df[norm_col], color=color, label=label, alpha=0.8)

    axes[0].set_xlabel("Training step")
    axes[0].set_ylabel("Estimated B_noise")
    axes[0].set_title("Gradient noise scale over training")
    axes[0].legend(fontsize=8)
    axes[0].grid(True, alpha=0.3)

    axes[1].set_xlabel("Training step")
    axes[1].set_ylabel("Gradient norm")
    axes[1].set_title("Gradient norm over training")
    # The grad-norm column is optional, so this panel may have no lines.
    if axes[1].get_legend_handles_labels()[0]:
        axes[1].legend(fontsize=8)
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    return fig

# Gradient-noise figure for Phase 1, or a hint when nothing was logged.
fig = plot_gradient_noise(all_runs, "p1")
if fig is None:
    print("No gradient noise data available. Enable grad noise measurement in training.")
else:
    fig.savefig("cbs_grad_noise.pdf", bbox_inches="tight")
    plt.show()

## 8. CBS Summary Table

In [None]:
def make_cbs_summary(phases: dict[str, dict]) -> pd.DataFrame:
    """Flatten per-phase CBS fits into a display table.

    ``phases`` maps a phase label to ``{threshold: fit_dict}``, where each
    fit dict provides ``B_noise``, ``S_min`` and ``R2``. Each
    (phase, threshold) pair becomes one row of pre-formatted strings.
    """
    table = [
        {
            "Phase": label,
            "Target Accuracy": f"{thr:.0%}",
            "B_noise (CBS)": f"{fit['B_noise']:.0f}",
            "S_min": f"{fit['S_min']:.1f}",
            "R²": f"{fit['R2']:.4f}",
        }
        for label, per_threshold in phases.items()
        for thr, fit in per_threshold.items()
    ]
    return pd.DataFrame(table)

# Phase 2B is included only if at least one of its thresholds was fit.
all_phase_fits = {
    "Phase 1 (vary n_prompts)": fits,
    **({"Phase 2B (vary n_rollouts)": fits_p2b} if fits_p2b else {}),
}

summary = make_cbs_summary(all_phase_fits)
print(summary.to_string(index=False))

## 9. Throughput Analysis

Plot wall-clock time efficiency to complement the sample efficiency analysis.

In [None]:
def plot_wall_clock_efficiency(runs_dict: dict, phase: str, acc_col: str | None = None,
                               threshold: float = 0.5):
    """
    For each batch size, estimate wall-clock time to reach threshold.
    Plot: time-to-target vs batch size (shows if scaling is efficient).
    """
    phase_runs = {k: v for k, v in runs_dict.items() if v["phase"] == phase}
    results = []

    for name, meta in phase_runs.items():
        hist = meta["history"]
        if hist.empty:
            continue

        if acc_col is None:
            val_cols = find_val_accuracy_columns(hist)
            gsm_cols = [c for c in val_cols if "gsm" in c.lower()]
            col = gsm_cols[0] if gsm_cols else (val_cols[0] if val_cols else None)
        else:
            col = acc_col
        if col is None:
            continue

        time_col = "perf/time_per_step"
        step_col = "training/global_step"
        if time_col not in hist.columns:
            continue

        acc_series = get_accuracy_at_step(hist, col)
        s_b = steps_to_threshold(acc_series, threshold)
        if s_b is None:
            continue

        avg_time_per_step = hist[time_col].dropna().mean()
        total_wall_time = s_b * avg_time_per_step

        results.append({
            "total_batch": meta["total_batch"],
            "n_prompts": meta["n_prompts"],
            "n_rollouts": meta["n_rollouts"],
            "steps": s_b,
            "avg_step_time_s": avg_time_per_step,
            "wall_time_h": total_wall_time / 3600,
        })

    if not results:
        print("No wall-clock data available")
        return

    df = pd.DataFrame(results).sort_values("total_batch")

    fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))

    axes[0].plot(df["total_batch"], df["wall_time_h"], "o-", color="steelblue", ms=8)
    axes[0].set_xscale("log", base=2)
    axes[0].set_xlabel("Batch size B")
    axes[0].set_ylabel("Wall-clock time (hours)")
    axes[0].set_title(f"Wall-clock time to {threshold:.0%} accuracy")
    axes[0].grid(True, alpha=0.3)

    axes[1].plot(df["total_batch"], df["avg_step_time_s"], "o-", color="coral", ms=8)
    axes[1].set_xscale("log", base=2)
    axes[1].set_xlabel("Batch size B")
    axes[1].set_ylabel("Avg time per step (seconds)")
    axes[1].set_title("Step time vs batch size")
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    return fig

# Wall-clock time to 50% accuracy across the Phase 1 batch-size sweep.
fig = plot_wall_clock_efficiency(all_runs, "p1", threshold=0.5)
if fig is not None:
    fig.savefig("cbs_wall_clock.pdf", bbox_inches="tight")
    plt.show()