In [None]:
from itertools import product
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from typing import List, Tuple, Optional, Dict, Any, Union, Callable, Iterable

In [None]:
from scaling.utils import (
    get_pareto_frontier, 
    get_final_points_from_curve_set, 
    fit_linear_model,
    functional_form_chin3,
    fit_parametric_form,
)
from scaling.visualize import visualize_train_curves, plot_line_fit

In [None]:
# Load the raw results table from disk; rows are filtered by `method`
# in the following cell.
warmstarting_df = pd.read_parquet("../data/warmstarting_results.parquet")
display(warmstarting_df)

In [None]:
# Retain only the warmstarting results: every row whose `method` is "mup"
# is dropped, everything else is kept.
is_mup = warmstarting_df["method"] == "mup"
warmstarting_df = warmstarting_df.loc[~is_mup]
display(warmstarting_df)

### Visualize raw data

Visualize all learning curves across N, D available in the *training set* to see scaling patterns.

Optionally, consider visualizing for different hyperparameters.

In [None]:
# Axes used by the plots below. `x_col` may be switched to "tokens" to plot
# against data seen instead of compute.
x_col = "flops"  # "tokens"
y_col = "Validation Loss"

# Columns passed to the scaling helpers as the set that identifies one curve.
unique_col_list = [
    "base_N",
    "target_N",
    "tkpm",
    "shrink",
]

In [None]:
# Plot every learning curve together with the final point of each run and the
# Pareto-final points, on a log-scaled x axis.
#
# No `plt.clf()` here: `plt.subplots` already returns a fresh figure, and
# calling `plt.clf()` first makes pyplot create (and the notebook render) an
# extra empty figure when no figure is currently open.
fig, ax = plt.subplots(figsize=(6, 4))

visualize_train_curves(
    ax,
    warmstarting_df,
    unique_col_list,
    x_col=x_col,
    y_col=y_col,
    plot_all_curves=True,
    plot_final=True,
    plot_pareto_final=True,
    # Axis windows chosen by hand for this dataset.
    ylims=(1.5, 2.5),
    xlims=(1e16, 2e19),
    xlog=True,
)

### Approach 1

Fit `C vs N` and `C vs D`, assuming the *best* training run for each `(N, D)`.

Steps:

1. Visualize for `C vs N` and `C vs D`
2. Fit linear model for each of them

In [None]:
x_col = "flops"
y_col = "Validation Loss"

# Reduce the curve set to its final points with `get_pareto=True`; this frame
# is what all the Approach 1 fits below are built on.
_pareto_df = get_final_points_from_curve_set(
    warmstarting_df,
    unique_col_list,
    x_col=x_col,
    y_col=y_col,
    get_pareto=True,
)

display(_pareto_df)

# No `plt.clf()` here: `plt.subplots` already returns a fresh figure, and
# calling `plt.clf()` first would create and render an extra empty figure.
fig, ax = plt.subplots(1, 2, figsize=(9, 4))

# Left panel: C vs N ("target_N"). Right panel: C vs D ("tokens").
# Both panels use identical settings, so they share one call site.
for panel, column in zip(ax, ("target_N", "tokens")):
    visualize_train_curves(
        panel,
        _pareto_df,
        unique_col_list,
        x_col=x_col,
        y_col=column,
        plot_all_curves=False,
        plot_final=True,
        plot_pareto_final=False,
        xlog=True,
        ylog=True,
    )


##### Using the largest model scale as the held-out set

In [None]:
# Split the Pareto frame: the largest `target_N` becomes the held-out set,
# every other scale is used for fitting.

largest_N = sorted(_pareto_df.target_N.unique())[-1]
is_largest = _pareto_df.target_N == largest_N

train_df = _pareto_df.loc[~is_largest]
held_out_df = _pareto_df.loc[is_largest]
held_out_df

In [None]:
# Visualizing scaling law linear fit for C vs N, D, L
#
# One panel per target column. Each panel fits `fit_linear_model` on the
# training points, prints the fit statistics and draws the fitted line along
# with the first held-out point and a few larger compute budgets.

# No `plt.clf()` here: `plt.subplots` already returns a fresh figure, and
# calling `plt.clf()` first would create and render an extra empty figure.
fig, ax = plt.subplots(1, 3, figsize=(15, 4))

# x positions at which the fitted line is extended: the first held-out run's
# compute followed by fixed larger budgets.
x_extrapolate = [held_out_df.flops.values[0], 1e20, 1e21, 1e22, 1e23]

# (target column, y-axis label, extra keyword arguments for plot_line_fit)
panels = [
    ("target_N", "Parameters (N)", {}),  # C vs N
    ("tokens", "Tokens (D)", {}),  # C vs D
    (y_col, y_col, {"ylog": False}),  # C vs Loss, drawn with ylog disabled
]

for axis, (column, y_label, plot_kwargs) in zip(ax, panels):
    X = train_df[x_col].values
    Y = train_df[column].values
    slope, intercept, r_value, p_value, std_err = fit_linear_model(X, Y)
    print(f"Slope: {slope}, Intercept: {intercept}, R^2: {r_value**2}")

    plot_line_fit(
        axis,
        X,
        Y,
        slope,
        intercept,
        # Fresh list per call so the helper cannot affect later panels.
        x_extrapolate=list(x_extrapolate),
        y_extrapolate=[held_out_df[column].values[0]],
        **plot_kwargs,
    )
    axis.set_xlabel("FLOPs")
    axis.set_ylabel(y_label)

##### Using the top-2 model scales as the held-out set

In [None]:
# Split the Pareto frame: the two largest `target_N` values are held out,
# all smaller scales are used for fitting.

scales = sorted(_pareto_df.target_N.unique())
train_scales, held_out_scales = scales[:-2], scales[-2:]

train_df = _pareto_df.loc[_pareto_df.target_N.isin(train_scales)]
held_out_df = _pareto_df.loc[_pareto_df.target_N.isin(held_out_scales)]
# Keep only held-out rows whose `tkpm` value also occurs in the training split.
held_out_df = held_out_df.loc[held_out_df.tkpm.isin(train_df.tkpm.unique())]
held_out_df

In [None]:
# Visualizing scaling law linear fit for C vs N, D, L
#
# One panel per target column. Each panel fits `fit_linear_model` on the
# training points, prints the fit statistics and draws the fitted line along
# with all held-out points and a few larger compute budgets.

# No `plt.clf()` here: `plt.subplots` already returns a fresh figure, and
# calling `plt.clf()` first would create and render an extra empty figure.
fig, ax = plt.subplots(1, 3, figsize=(18, 4))

# x positions at which the fitted line is extended: the compute of every
# held-out run followed by fixed larger budgets.
x_extrapolate = held_out_df.flops.values.tolist() + [1e20, 1e21, 1e22, 1e23]

# (target column, y-axis label, extra keyword arguments for plot_line_fit)
panels = [
    ("target_N", "Target N", {}),  # C vs N
    ("tokens", "Tokens", {}),  # C vs D
    (y_col, y_col, {"ylog": False}),  # C vs Loss, drawn with ylog disabled
]

for axis, (column, y_label, plot_kwargs) in zip(ax, panels):
    X = train_df[x_col].values
    Y = train_df[column].values
    slope, intercept, r_value, p_value, std_err = fit_linear_model(X, Y)
    print(f"Slope: {slope}, Intercept: {intercept}, R^2: {r_value**2}")

    plot_line_fit(
        axis,
        X,
        Y,
        slope,
        intercept,
        # Fresh list per call so the helper cannot affect later panels.
        x_extrapolate=list(x_extrapolate),
        y_extrapolate=held_out_df[column].values,
        **plot_kwargs,
    )
    axis.set_xlabel("FLOPs")
    axis.set_ylabel(y_label)

### Approach 3


In [None]:
# Build the (N, D) -> Loss table that the parametric fit consumes.
# Pareto filtering is disabled here (get_pareto=False).
df = get_final_points_from_curve_set(
    warmstarting_df,
    unique_col_list,
    x_col="flops",
    y_col="Validation Loss",
    get_pareto=False,
)

# The logged "tokens" column is deliberately not used: D is derived as
# target_N * tkpm instead. It is also no longer bound to `_`, which is
# IPython's last-output variable and stops updating once assigned.
# NOTE(review): presumably tkpm is a tokens-per-parameter multiplier — confirm.
N = df["target_N"].values
_D = df["target_N"] * df["tkpm"]
y = df["Validation Loss"].values

# Keep the lowest loss for each distinct (N, D) pair, ordered by N then D.
# The sort is chained rather than done in place so the cell builds `_df` in a
# single expression.
_df = (
    pd.DataFrame({"N": N, "D": _D, "Loss": y})
    .groupby(by=["N", "D"])
    .min()
    .reset_index()
    .sort_values(by=["N", "D"])
)

data_X = _df[["N", "D"]].values
data_y = _df["Loss"].values

In [None]:
# Starting points for the parametric fit: the full Cartesian product of one
# candidate list per parameter, in the order (a, alpha, b, beta, e).
a_grid = [0, 5, 10, 15, 20, 25]
alpha_grid = [0, 0.5, 1.0, 1.5, 2.0]
b_grid = [0, 5, 10, 15, 20, 25]
beta_grid = [0, 0.5, 1.0, 1.5, 2.0]
e_grid = [-1, -0.5, 0, 0.5, 1]

initialization = list(product(a_grid, alpha_grid, b_grid, beta_grid, e_grid))

# Smaller grid for quick smoke runs (first two candidates of every list):
# initialization = list(product(
#     [0, 5,],  # a
#     [0, 0.5,],  # alpha
#     [0, 5, ],  # b
#     [0, 0.5,],  # beta
#     [-1, -0.5,]  # e
# ))


In [None]:
# Fit `functional_form_chin3` to the (N, D) -> Loss data, handing the helper
# the whole grid of starting points. It returns a (parameters, loss) pair,
# reported below in the order a, alpha, b, beta, e.
fit_result = fit_parametric_form(functional_form_chin3, data_X, data_y, initialization)
best_params, best_loss = fit_result

print(f"Best Loss: {best_loss}")
print(f"a: {best_params[0]}, alpha={best_params[1]}\nb: {best_params[2]}, beta={best_params[3]}\ne={best_params[4]}")

In [None]:
# Unpack the fitted parameters; `_a`, `_b` and `_e` are exponentiated to get
# the coefficients A, B and E.
_a, alpha, _b, beta, _e = best_params
A, B, E = (np.exp(log_coeff) for log_coeff in (_a, _b, _e))

# Exponents of the compute-optimal power laws plotted further down:
#   N_opt = G * (C / 6) ** a      D_opt = G ** -1 * (C / 6) ** b
exponent_sum = alpha + beta
a = beta / exponent_sum
b = alpha / exponent_sum

print(a, b)

# Shared prefactor of the two power laws.
G = ((alpha * A) / (beta * B)) ** (1 / exponent_sum)
print(G)

In [None]:
# Rebuild the Pareto frame and the training split (all scales except the
# largest `target_N`) so the Approach 3 curves are compared on those points.
x_col = "flops"
y_col = "Validation Loss"

_pareto_df = get_final_points_from_curve_set(
    warmstarting_df,
    unique_col_list,
    x_col=x_col,
    y_col=y_col,
    get_pareto=True,
)

largest_N = sorted(_pareto_df.target_N.unique())[-1]
train_df = _pareto_df.loc[_pareto_df.target_N != largest_N]
# `held_out_df` is intentionally left untouched here; to rebuild it for the
# largest scale, uncomment:
# held_out_df = _pareto_df.loc[_pareto_df.target_N == sorted(_pareto_df.target_N.unique())[-1]]
# held_out_df

In [None]:
# Plotting N_opt and D_opt: overlay the Approach 3 compute-optimal power laws
# on the Pareto training points.
#
# Changes from the earlier version of this cell:
#   * no `plt.clf()` — `plt.subplots` already returns a fresh figure, and the
#     extra call created and rendered an empty figure;
#   * two panels instead of three — the third axes was never drawn on;
#   * axis labels and a legend, so the `label=` values are actually shown.
fig, ax = plt.subplots(1, 2, figsize=(10, 4))

# Both panels share the same compute values on the x axis.
X = train_df[x_col].values
_x_plot = np.linspace(X.min(), X.max(), 100)

# (observed column, predicted curve over _x_plot, y-axis label)
# NOTE(review): the division by 6 presumably comes from C ≈ 6·N·D — confirm.
panels = (
    ("target_N", G * (_x_plot / 6) ** a, "Parameters (N)"),  # C vs N
    ("tokens", G**-1 * (_x_plot / 6) ** b, "Tokens (D)"),  # C vs D
)

for axis, (column, fitted, y_label) in zip(ax, panels):
    Y = train_df[column].values
    axis.scatter(X, Y, label="raw data")
    axis.plot(_x_plot, fitted, color="red", label="fitted line")
    axis.loglog()
    axis.set_xlabel("FLOPs")
    axis.set_ylabel(y_label)
    axis.legend()