### Loss vs. norm plot

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from utils.fitting import Config, build_minima_df
from utils.plotting import plot_minima, plot_parabola_grid, plot_interactive_horizon_scatter

In [None]:
# --- which preprocessed scan to load (these names are spliced into the CSV path) ---
BASE = "lr-bs-scan"
SCALING_PREFIX = "base"  # alternatives: 'x4-width' // 'x8-depth'
POSTFIX = ""
MOMENTUM = 1.0
DECAYED = False
SEEDS = [30]  # None for all // [30] for one

# --- filtering ---
# Four values evenly spaced in log2 between 2**31 and 2**37 (float array).
HORIZONS = np.logspace(31, 37, 4, base=2)
# HORIZONS = [2_868_903_936, 11_475_615_744, 45_902_462_976, 183_609_851_904]
MAX_LOSS = 11.765

# --- fit cfg (every value below is handed to Config unchanged) ---
FROM_FIT = True
OPTIMUM_FROM_CLOSEST = False
C_FIXED = MAX_LOSS
FIT_K, FIT_K_BY = 7, "x"
AVG_REL_FROM, AVG_REL_TO = -1, 1
AVERAGE_H = [2**p for p in range(33, 38)]  # 2**33 .. 2**37
AVERAGE_BS = [2**p for p in range(3, 8)]  # 2**3 .. 2**7; left out: 2**8, 2**9, 2**10
SKIP_FIT = []  # e.g. (2**33, 2**9), (2**33, 2**10)

In [None]:
# Assemble the path of the preprocessed scan CSV from the settings above.
# SEEDS is None  -> the "-seeds" file; exactly one seed -> "-seed-<n>";
# any other number of seeds is rejected.
if SEEDS is None:
    seed_part = "-seeds"
elif len(SEEDS) == 1:
    seed_part = f"-seed-{SEEDS[0]}"
else:
    raise ValueError("Only one seed supported")

# The optional "-decayed" suffix is computed outside the f-string: reusing the
# enclosing quote character inside a replacement field is only valid syntax on
# Python 3.12+ (PEP 701) and is a SyntaxError on older interpreters.
decay_part = "-decayed" if DECAYED else ""
CSV_PATH = (
    f"data/{BASE}-{SCALING_PREFIX}-momentum-{MOMENTUM}{decay_part}"
    f"-preprocessed{seed_part}{POSTFIX}.csv"
)

# Keyword arguments are collected in two dicts and unpacked into Config in the
# same order they were originally listed, so Config receives identical input.
fit_kwargs = dict(
    from_fit=FROM_FIT,
    c_fixed=C_FIXED,
    optimum_from_closest=OPTIMUM_FROM_CLOSEST,
    fit_k=FIT_K,
    fit_k_by=FIT_K_BY,
    avg_rel_from=AVG_REL_FROM,
    avg_rel_to=AVG_REL_TO,
    average_h=AVERAGE_H,
    average_bs=AVERAGE_BS,
    skip_fit=SKIP_FIT,
    strict_avg=True,
)
figure_kwargs = dict(
    bs_size_base=50,
    bs_size_factor=1.85,
    figsize=(9, 8),
    # legend_models_loc="upper center",
    legend_models_bbox=(0.4, .98),
    # legend_bs_loc="lower left",
    # legend_bs_bbox=(0.545, .78),
    legend_bs_bbox=(0.01, .25),
    use_constrained_layout=True,
    line_width=4.,
    legend_fontsize=16,
    axis_label_fontsize=23,
    tick_label_fontsize=23,
)
cfg = Config(
    csv_path=CSV_PATH,
    horizons=HORIZONS,
    max_loss=MAX_LOSS,
    **fit_kwargs,
    **figure_kwargs,
)
# Read the scan named in the config, build the minima table from it and keep
# only the rows with bs >= 32.
df = pd.read_csv(cfg.csv_path)
minima = build_minima_df(df, cfg).query('bs >= 32')
# minima.to_csv(f'data/minima-{SCALING_PREFIX}-from-fit-{FROM_FIT}-c-{C_FIXED}-aver-{AVG_REL_FROM}.csv', index=False)

# Plot the minima and fix the visible window on both axes.
fig, ax = plot_minima(minima, cfg)
ax.set_xlim(left=4.8, right=8.5)
ax.set_ylim(bottom=3.7, top=5.5)

plt.show()
# fig.savefig(f'plots/loss-vs-norm-{SCALING_PREFIX}-from-fit-{FROM_FIT}-momentum-{MOMENTUM}-decayed-{DECAYED}{POSTFIX}.pdf')

### Individual fits for diagnostics

In [None]:
# Draw the diagnostic grid for the scan loaded above.
fig, axes = plot_parabola_grid(df, cfg)

# Use one common y-range on every panel of the grid.
shared_ylim = (3.75, 6.)
for ax in axes.flat:
    ax.set_ylim(*shared_ylim)

plt.show()
# fig.savefig(f'plots/fits-{SCALING_PREFIX}-from-fit-{FROM_FIT}-momentum-{MOMENTUM}-decayed-{DECAYED}{POSTFIX}.pdf')

### 3D plot

In [None]:
# Options for the interactive scatter at a single horizon (2**33), listed in
# the order they are passed on.
scatter_kwargs = {
    "horizon": 2**33,
    "loss_range": (3.5, 8.0),
    "fig_height_px": 800,
    "fig_width_px": 950,
    "norm_col": "output_norm",
}
fig = plot_interactive_horizon_scatter(df, **scatter_kwargs)

### Norm variation with LR


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Scatter of train loss against output norm for one (batch size, horizon)
# slice of the base scan, coloured by log2 of the learning rate.
df = pd.read_csv("data/lr-bs-scan-base-momentum-1.0-preprocessed-seed-30.csv")
# Slice: bs == 128 and horizon == 2**33 (the same values are quoted in the title below).
df_ = df.query('(bs == 128) and (horizon == 2**33)')

# Find the optimum point (minimum loss)
# opt_idx is the index label of the lowest-train_loss row within the slice.
opt_idx = df_['train_loss'].idxmin()
opt_norm = df_.loc[opt_idx, 'output_norm']
opt_loss = df_.loc[opt_idx, 'train_loss']

fig = plt.figure(figsize=(8, 6))
# Colour encodes log2(lr); the colorbar gets its label here and again (with a
# font size) further down.
scatter = plt.scatter(df_['output_norm'], df_['train_loss'], c=np.log2(df_['lr']), cmap='viridis')
cbar = plt.colorbar(scatter, label=r'$\log_{2}(\eta)$')

# Add dashed lines to emphasize the optimum
# NOTE: the 'Optimum' label only becomes visible if the plt.legend call below is re-enabled.
plt.axvline(x=opt_norm, color='black', linestyle='--', linewidth=2, alpha=0.7, label='Optimum')
plt.axhline(y=opt_loss, color='black', linestyle='--', linewidth=2, alpha=0.7)

# # Highlight the optimum point
# plt.scatter(opt_norm, opt_loss, color='red', s=100, marker='*', edgecolor='black', linewidth=1, zorder=5)

# Set axis labels with font size
plt.xlabel(r'||$W_\mathrm{out}$||$_{\mathrm{RMS} \to \infty}$', fontsize=16)
plt.ylabel('Train Loss', fontsize=16)

# Set tick label sizes
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

# Set colorbar label and tick sizes
cbar.ax.tick_params(labelsize=14)
cbar.set_label(r'$\log_{2}(\eta)$', fontsize=16)

# Base-2 log x-axis; applied after the tick font sizes were set above.
plt.xscale('log', base=2)
# Title values are hard-coded; keep them in sync with the query string above.
plt.title(r'$B=128$ samples, $D=2^{33}$ tokens', fontsize=18)
# plt.legend(loc='upper right')
plt.tight_layout()
plt.show()

# fig.savefig('plots/lr-norm-variation.pdf')

### Norm growth for different (LR, B)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from utils.plotting import plot_norm_vs_horizon_by_lr_bs

In [None]:
# Load the base scan and keep only the rows recorded after step 1.
scan_csv = "data/lr-bs-scan-base-momentum-1.0-preprocessed-seed-30.csv"
df = pd.read_csv(scan_csv).query('step > 1')

In [None]:
# Learning rates and batch sizes to draw, written as powers of two.
lr_values = [2.0**p for p in (-7, -5, -3, -1)]  # Adjust based on your data
bs_values = [2**p for p in (5, 7, 9, 11)]  # 32, 128, 512, 2048; adjust based on your data
x_col = 'horizon'

fig, ax = plot_norm_vs_horizon_by_lr_bs(
    df,
    x_col=x_col,
    lr_values=lr_values,
    bs_values=bs_values,
)
plt.show()

# fig.savefig(f'plots/norm-evolution-vs-{x_col}-base.pdf')

### Optimal LR vs. B fit

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
### This is from presaved minima files
### Ignore this cell if you want to plot for minima dataframe derived above

import glob

import pandas as pd

# Find all CSV files starting with "data/minima".
# glob.glob returns matches in arbitrary (filesystem-dependent) order; sorting
# makes the concatenation order, and therefore the 'first' aggregates below,
# reproducible across machines.
minima_files = sorted(glob.glob("data/minima*.csv"))
print(f"Found {len(minima_files)} minima files:")
for file in minima_files:
    print(f"  {file}")
if not minima_files:
    # Fail with an actionable message instead of pandas' generic
    # "No objects to concatenate" raised by pd.concat on an empty list.
    raise ValueError("No files matching 'data/minima*.csv' were found; nothing to combine")

# Read all dataframes (in a comprehension, so the notebook-level `df` holding
# the scan data is not overwritten by the last minima file)
dataframes = [pd.read_csv(file) for file in minima_files]

# Combine all dataframes into one long table
combined_df = pd.concat(dataframes, ignore_index=True)

# Group by horizon and bs, then calculate mean and std of lr and keep other columns
# We'll take the first value for non-lr columns (assuming they're consistent)
minima = combined_df.groupby(['horizon', 'bs']).agg({
    'lr': ['mean', 'std'],
    **{col: 'first' for col in combined_df.columns if col not in ['lr', 'horizon', 'bs']}
}).reset_index()

# Flatten the multi-level column names: ('lr', 'mean') -> 'lr_mean',
# ('x', 'first') -> 'x_first'; the group keys have an empty second level and
# keep their plain name.
minima.columns = [f'{col[0]}_{col[1]}' if col[1] else col[0] for col in minima.columns]

# Restore the plain names for the mean lr and the loss.
# NOTE: every other pass-through column keeps its '_first' suffix, and the
# spread of lr across files stays available as 'lr_std'.
minima = minima.rename(columns={'lr_mean': 'lr', 'train_loss_first': 'train_loss'})
print(f"\nCombined dataframe shape: {minima.shape}")
print(f"Columns: {list(minima.columns)}")

In [None]:
import matplotlib.pyplot as plt
from utils.plotting import plot_lr_bs_fit

# Restrict to 32 <= bs <= 2048 and horizon >= 2**31 before fitting.
minima = (
    minima
    .query('(bs >= 32) and (bs <= 2048)')
    .query('horizon >= 2**31')
)

# Which columns of `minima` hold each quantity, followed by the styling options.
column_kwargs = dict(lr_col='lr', bs_col='bs', horizon_col='horizon', loss_col='train_loss')
style_kwargs = dict(
    marker_size=15,
    legend_marker_size=160,
    best_marker_size=1200,
    legend_fontsize=16,
    axis_label_fontsize=23,
    tick_label_fontsize=23,
    figsize=(9, 8),
    min_alpha=0.2,
)
res = plot_lr_bs_fit(minima, **column_kwargs, **style_kwargs)
fig = res["fig"]
plt.show()
# fig.savefig(f'plots/lr-vs-bs-base-from-fit-True.pdf')
# fig.savefig(f'plots/lr-vs-bs-{SCALING_PREFIX}-from-fit-{FROM_FIT}.pdf')

### Optimal B vs D fit

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Load the dataset
src_path = "data/_minima_more_horizons.csv"
df = pd.read_csv(src_path)

# Validate required columns up front, reporting what the file actually contains
req = ["horizon", "train_loss", "bs"]
missing = [name for name in req if name not in df.columns]
if len(missing) > 0:
    raise ValueError(f"Missing required columns: {missing}. Available: {list(df.columns)}")

# Numeric copies of horizon/bs for grouping and logs; unparseable entries become NaN
for src_col in ("horizon", "bs"):
    df[f"{src_col}_num"] = pd.to_numeric(df[src_col], errors="coerce")

# Drop rows where any of the three quantities is missing
df_clean = df.dropna(subset=["horizon_num", "train_loss", "bs_num"]).copy()

# One row per horizon: the one with the lowest train_loss
idx = df_clean.groupby("horizon_num")["train_loss"].idxmin()
best = df_clean.loc[idx].copy()

# Logs need strictly positive horizon and batch size
is_positive = best[["horizon_num", "bs_num"]].gt(0).all(axis=1)
best = best[is_positive].copy()

# Straight-line least squares in log2 space: log2(bs) = b*log2(horizon) + c
x = best["horizon_num"].to_numpy(dtype=float)
y = best["bs_num"].to_numpy(dtype=float)
lx = np.log2(x)
ly = np.log2(y)
n = len(lx)
xbar = lx.mean()
ybar = ly.mean()
dx = lx - xbar
dy = ly - ybar
Sxx = np.sum(dx**2)
Sxy = np.sum(dx*dy)
b = Sxy / Sxx
c = ybar - b*xbar
res = ly - (b*lx + c)
RSS = np.sum(res**2)
TSS = np.sum(dy**2)
if TSS > 0:
    R2 = 1 - RSS/TSS
else:
    R2 = np.nan

# Standard errors (OLS assumptions in log space); undefined without at least
# one residual degree of freedom
if n > 2:
    s2 = RSS / (n - 2)
    se_b = np.sqrt(s2 / Sxx)
    se_c = np.sqrt(s2 * (1.0/n + (xbar**2)/Sxx))
else:
    s2 = se_b = se_c = np.nan

# Back to the original scale: a = 2^c, with the delta-method SE ln(2) * a * se_c
a = 2**c
se_a = np.nan
if np.isfinite(se_c):
    se_a = np.log(2.0) * a * se_c

# Fitted curve on 200 log-spaced points spanning the observed horizons
xs = np.logspace(np.log2(x.min()), np.log2(x.max()), num=200, base=2.0)
ys = a * np.power(xs, b)

# Helper formatter for "estimate ± s.e."
def pm(val, se, sig=3):
    """Format an estimate and its standard error as ``"<val> ± <se>"``.

    Args:
        val: The point estimate.
        se: Its standard error.
        sig: Number of digits after the decimal point in both numbers.

    Returns:
        ``"n/a"`` if either number is NaN or infinite. Otherwise both numbers
        are rendered with ``sig`` decimals: in scientific notation when
        ``abs(val) >= 1000`` or ``0 < abs(val) < 0.01``, in fixed-point
        notation for everything else (including exactly zero).
    """
    if not (np.isfinite(val) and np.isfinite(se)):
        return "n/a"
    magnitude = abs(val)
    # Fixed-point with a few decimals is unreadable for very large estimates
    # and collapses very small ones to "0.000 ± 0.000", so those switch to
    # scientific notation. The standard error follows the estimate's notation
    # so the two numbers stay directly comparable.
    spec = "e" if (magnitude >= 1000 or 0 < magnitude < 0.01) else "f"
    return f"{val:.{sig}{spec}} ± {se:.{sig}{spec}}"

# Legend text: the fitted model followed by its parameters and R², one per line
legend_lines = [
    r"Fit: $B^\ast$ = a * Horizon^b",
    f"b (exponent) = {pm(b, se_b)}",
    f"a (scale) = {pm(a, se_a)}",
    f"R² = {R2:.3f}",
]
fit_label = "\n".join(legend_lines)

# Plot the selected points and the fitted power law on base-2 log axes
fig = plt.figure(figsize=(6, 6), dpi=150)
plt.scatter(x, y)
plt.plot(xs, ys, label=fit_label)
for set_scale in (plt.xscale, plt.yscale):
    set_scale("log", base=2)
plt.xlabel("Horizon [tokens]")
plt.ylabel(r"$B^\ast$ [samples]")
plt.legend(loc="upper left")
plt.tight_layout()

plt.show()
# fig.savefig('plots/bs-vs-horizon-log2-log2-fit-more-horizons.pdf')


### First horizon to reach optimal norm

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
from utils.plotting import plot_global_two_param_fit

df = pd.read_csv("data/lr-bs-scan-base-momentum-1.0-preprocessed-seed-30.csv")

# All options for the global fit plot, kept in the order they are passed on.
two_param_fit_kwargs = {
    "band_center_log2": 7,
    "band_eps": 0.2,
    "bs_min": 32,
    "bs_max": 1024,
    "target_log2_horizons": list(range(31, 38)),  # 31 .. 37
    "legend_fontsize": 16,
    "axis_label_fontsize": 23,
    "tick_label_fontsize": 23,
    "marker_size": 16,
    "star_size": 30,
    "figsize": (9, 8),
    "title": "Base scaling, momentum 1.0",
    "horizon_min": 2**31,
    "horizon_max": 2**37,
    "A_fixed": None,
    "B_fixed": None,
    "overlay_solid_A": -1,
    "overlay_solid_B": 1.5,
}
res = plot_global_two_param_fit(df, **two_param_fit_kwargs)

# res['fig'].savefig(f'plots/horizon-to-norm.pdf')

### Scaling up plots

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from utils.fitting import Config, build_minima_df
from utils.plotting import plot_minima_at_horizon_across_models, plot_parabola_grid

In [None]:
# Preprocessed scan file for each model key; each one is loaded into a DataFrame.
MODEL_CSVS = {
    "base": "data/lr-bs-scan-base-momentum-1.0-preprocessed-seed-30.csv",
    "x4-w": "data/lr-bs-scan-x4-width-momentum-1.0-preprocessed-seeds.csv",
    "x8-d": "data/lr-bs-scan-x8-depth-momentum-1.0-preprocessed-seeds.csv",
    "x12-w": "data/lr-bs-scan-x12-width-momentum-1.0-preprocessed-seeds.csv",
    "x32-d": "data/lr-bs-scan-x32-depth-momentum-1.0-preprocessed-seeds.csv",
}
MODELS = {name: pd.read_csv(path) for name, path in MODEL_CSVS.items()}

# Per-model (color, marker, legend label); every model shares a solid line
# style and full opacity.
_STYLE_SPECS = {
    "base": ("black", "o", "d=256, L=4 (69M)"),
    "x4-w": ("lightcoral", "s", "d=1024, L=4 (314M)"),
    "x12-w": ("brown", "D", "d=3072, L=4 (1.3B)"),
    "x8-d": ("lightblue", "^", "d=256, L=32 (91M)"),
    "x32-d": ("steelblue", "v", "d=256, L=128 (168M)"),
}
STYLES = {
    key: {
        "color": color,
        "marker": marker,
        "linestyle": "solid",
        "alpha": 1.0,
        "legend_label": label,
    }
    for key, (color, marker, label) in _STYLE_SPECS.items()
}

# --- filtering ---
HORIZON = 2**33
MAX_LOSS = 11.765

# --- fit cfg (every value below is handed to Config unchanged) ---
FROM_FIT, OPTIMUM_FROM_CLOSEST = True, False
C_FIXED = MAX_LOSS
FIT_K, FIT_K_BY = 7, "x"
AVG_REL_FROM, AVG_REL_TO = -1, 1
AVERAGE_H = [2**33]
AVERAGE_BS = [2**p for p in (5, 6, 7)]  # left out: 2**8, 2**9, 2**10
SKIP_FIT = []  # e.g. (2**33, 2**9), (2**33, 2**10)
POSTFIX = ""

# Keyword arguments are collected in two dicts and unpacked into Config in the
# same order they were originally listed, so Config receives identical input.
# csv_path is left empty: the per-model frames are loaded separately.
fit_kwargs = dict(
    from_fit=FROM_FIT,
    c_fixed=C_FIXED,
    optimum_from_closest=OPTIMUM_FROM_CLOSEST,
    fit_k=FIT_K,
    fit_k_by=FIT_K_BY,
    avg_rel_from=AVG_REL_FROM,
    avg_rel_to=AVG_REL_TO,
    average_h=AVERAGE_H,
    average_bs=AVERAGE_BS,
    skip_fit=SKIP_FIT,
    strict_avg=True,
)
figure_kwargs = dict(
    bs_size_base=50,
    bs_size_factor=1.85,
    figsize=(9, 8),
    model_styles=STYLES,
    # legend_models_loc="upper left",
    legend_models_bbox=(0.27, .99),
    # legend_bs_loc="lower left",
    legend_bs_bbox=(0.01, .25),
    use_constrained_layout=True,
    line_width=4.,
    legend_fontsize=16,
    axis_label_fontsize=23,
    tick_label_fontsize=23,
)
cfg = Config(
    csv_path="",
    horizons=[HORIZON],
    max_loss=MAX_LOSS,
    **fit_kwargs,
    **figure_kwargs,
)

# Keep only batch sizes within [32, 1024] for every model, then build one
# minima table per model with the shared config.
MODELS = {name: scan.query('(bs >= 32) and (bs <= 1024)') for name, scan in MODELS.items()}
minima_by_model = {name: build_minima_df(scan, cfg) for name, scan in MODELS.items()}

In [None]:
# Compare the minima of all models at the single configured horizon.
fig, ax = plot_minima_at_horizon_across_models(
    minima_by_model,
    cfg,
    horizon=HORIZON,
)
ax.set_xlim(left=4.7, right=8.3)
plt.show()
# fig.savefig(f'plots/loss-vs-norm-scaling-for-horizon-{HORIZON}-from-fit-{FROM_FIT}.pdf')

In [None]:
# Diagnostic grid for one model, selected by its key in MODELS.
MODEL_KEY = 'x32-d'
fig, axes = plot_parabola_grid(MODELS[MODEL_KEY], cfg)

# Use one common y-range on every panel of the grid.
shared_ylim = (3.4, 5.5)
for ax in axes.flat:
    ax.set_ylim(*shared_ylim)

plt.show()

# fig.savefig(f'plots/fits-{MODEL_KEY}-from-fit-{FROM_FIT}.pdf')