In [None]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from typing import List, Tuple, Optional, Dict, Any, Union, Callable, Iterable

In [None]:
from scaling.utils import (
    get_pareto_frontier, get_final_points_from_curve_set, fit_linear_model
)
from scaling.visualize import visualize_train_curves, plot_line_fit

In [None]:
# Load the warm-starting experiment results from the repo-relative data
# directory and show the frame for a quick sanity check.
warmstarting_df = pd.read_parquet("../data/warmstarting_results.parquet")
display(warmstarting_df)

### Create held-out set

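Hold out the largest one or two model/data scales so the fitted scaling laws can be checked against them. The cell below holds out only the largest `target_N`.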

In [None]:
# Distinct values of each sweep axis present in the results.
# `targets` and `tkpms` are sorted ascending; `bases` is left in the order
# returned by `unique()`.
bases = warmstarting_df["base_N"].unique()
targets = sorted(warmstarting_df["target_N"].unique())
tkpms = sorted(warmstarting_df["tkpm"].unique())

# The last entry of the sorted list (i.e. the largest target_N) is held out.
held_out_target = targets[-1]
held_out_target

In [None]:
# Row masks used to build the two splits.
_not_held_out = warmstarting_df["target_N"] != held_out_target
_is_mup = warmstarting_df["method"] == "mup"
_is_held_out = warmstarting_df["target_N"] == held_out_target

# Training split: only "mup" runs, at every target size except the held-out one.
train_df = warmstarting_df.loc[_not_held_out & _is_mup]

# Held-out split: every run at the held-out target size.
# NOTE(review): unlike train_df, this split is NOT restricted to method == "mup",
# so held-out statistics span all methods — confirm this asymmetry is intended.
held_out_df = warmstarting_df.loc[_is_held_out]

display(held_out_df)

In [None]:
# Lowest validation loss observed among the held-out rows.
held_out_df.loc[:, "Validation Loss"].min()

### Visualize raw data

Plot every learning curve in the *training set*, across all available model sizes (N) and dataset sizes (D), to see the scaling patterns.

Optionally, repeat the visualization for different hyperparameter settings.

In [None]:
# Columns handed to the scaling helpers as `unique_col_list`
# (presumably one curve per unique combination of these — confirm in scaling.utils).
unique_col_list = [
    "base_N",
    "target_N",
    "tkpm",
    "shrink",
]

# Axes for the learning-curve plots. "tokens" is the alternative x-axis.
x_col = "flops"
y_col = "Validation Loss"

In [None]:
# Plot the training-set learning curves on a single panel.
# The figure is created directly with plt.subplots(); a preceding plt.clf()
# would act on (or implicitly create) the *current* pyplot figure instead,
# leaving a stray empty figure or clearing an unrelated one.
fig, ax = plt.subplots(figsize=(6, 4))

visualize_train_curves(
    ax,
    train_df,
    unique_col_list,
    x_col=x_col,
    y_col=y_col,
    # Draw every curve and also request the final / Pareto-final overlays.
    plot_all_curves=True,
    plot_final=True,
    plot_pareto_final=True,
    # Fixed viewing window; the x-axis is logarithmic.
    ylims=(1.5, 2.5),
    xlims=(1e16, 2e19),
    xlog=True,
)

### Approach 1

Fit `C vs N` and `C vs D`, taking only the *best* training run for each `(N, D)`.

Steps:

1. Visualize `C vs N` and `C vs D`.
2. Fit a linear model to each of them (and to `C vs Loss`).

In [None]:
# Approach 1 is always expressed against compute, whatever x-axis was chosen
# for the raw-curve plot above.
x_col = "flops"

# Final points of the training curves with get_pareto=True.
# The loss column comes from the shared `y_col` setting so that it matches the
# column later read back from `_pareto_df` when fitting compute vs loss.
_pareto_df = get_final_points_from_curve_set(
    train_df,
    unique_col_list,
    x_col=x_col,
    y_col=y_col,
    get_pareto=True,
)

# One panel per quantity plotted against compute: model size (left), tokens (right).
# No plt.clf() beforehand: it would touch or create a different figure than this one.
fig, ax = plt.subplots(1, 2, figsize=(9, 4))

for _panel, _quantity in zip(ax, ("target_N", "tokens")):
    visualize_train_curves(
        _panel,
        _pareto_df,
        unique_col_list,
        x_col=x_col,
        y_col=_quantity,
        # Only the final points are requested; both axes are logarithmic.
        plot_all_curves=False,
        plot_final=True,
        plot_pareto_final=False,
        xlog=True,
        ylog=True,
    )


In [None]:
# Visualizing scaling law linear fit for C vs N, D, L

# Compute values (x-axis) at which every fitted line is extrapolated.
# NOTE(review): 4.471914e+19 looks like the held-out run's compute budget — confirm.
_extrapolation_flops = [4.471914e+19, 1e20, 1e21, 1e22, 1e23]

# One entry per panel: (column fitted against compute, y-axis label,
# held-out reference value passed to plot_line_fit as y_extrapolate).
_fit_specs = [
    ("target_N", "Target N", held_out_target),                      # C vs N
    ("tokens", "Tokens", held_out_df.tokens.max()),                 # C vs D
    (y_col, "Validation Loss", held_out_df[y_col].min()),           # C vs Loss
]

# No plt.clf() beforehand: it would touch or create a different figure than this one.
fig, ax = plt.subplots(1, 3, figsize=(15, 4))

for _panel, (_fit_col, _ylabel, _held_out_value) in zip(ax, _fit_specs):
    # X/Y (and the fit results) stay module-level so later cells can still
    # inspect the values from the last fit (C vs Loss), as before.
    X = _pareto_df[x_col].values
    Y = _pareto_df[_fit_col].values
    slope, intercept, r_value, p_value, std_err = fit_linear_model(X, Y)
    print(f"Slope: {slope}, Intercept: {intercept}, R^2: {r_value**2}")

    plot_line_fit(
        _panel,
        X,
        Y,
        slope,
        intercept,
        # Fresh list per call so the helper cannot alter the shared constant.
        x_extrapolate=list(_extrapolation_flops),
        y_extrapolate=[_held_out_value],
    )
    _panel.set_xlabel("FLOPs")
    _panel.set_ylabel(_ylabel)

### Approach 3
