In [None]:
from itertools import product
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

from scaling.utils import (
    get_pareto_frontier, 
    get_final_points_from_curve_set, 
    fit_linear_model,
    functional_form_chin3,
    fit_parametric_form,
)
from scaling.visualize import visualize_train_curves, plot_line_fit

In [None]:
# Load the xz-compressed, pickled results table.
# NOTE(review): unpickling can execute arbitrary code -- only load files from a trusted source.
results_path = "../data/porian_results.pickle.xz"
porian_df = pd.read_pickle(results_path, compression="xz")
porian_df.head()

In [None]:
from scipy.ndimage import gaussian_filter1d


def gaussian_smoothing(
    data: list | np.ndarray | pd.Series,
    window_size: int=1
) -> list | np.ndarray | pd.Series:
    sigma = window_size / 2.0
    smoothed_data = gaussian_filter1d(data, sigma)

    return smoothed_data

In [None]:
# adding final loss columns
# Each source column holds one curve per run (anything supporting `.iloc`);
# the new scalar column stores that curve's last element.

def _last_value(curve):
    """Return the final element of a per-run curve (positional lookup)."""
    return curve.iloc[-1]


_final_columns = {
    'val/loss_final': 'val/loss',
    'train/loss_final': 'train/loss',
    'val/loss_final_std': 'val/loss_std',
}
for _target_col, _source_col in _final_columns.items():
    porian_df[_target_col] = porian_df[_source_col].apply(_last_value)

porian_df.head(n=3)

In [None]:
# D = seq_len * bs * max_step for every run (plotted as "Tokens (D)" further down).
porian_df["D"] = porian_df["seq_len"].mul(porian_df["bs"]).mul(porian_df["max_step"])

In [None]:
# adding total parameters

# Create the lookup dictionary from the Table 2 in Appendix B (https://arxiv.org/abs/2406.19146)
# Keys are (depth, width); values are scaled by `scale_to_millions_factor` below
# (presumably the table lists counts in millions -- verify against the paper).
param_lookup = {
    (3, 96): {'N_exact': 5.176, 'N_eff': 5.763, 'N_Kaplan': 0.331},
    (4, 128): {'N_exact': 7.508, 'N_eff': 8.552, 'N_Kaplan': 1.049},
    (5, 160): {'N_exact': 9.817, 'N_eff': 11.45, 'N_Kaplan': 1.741},
    (6, 224): {'N_exact': 15.61, 'N_eff': 18.35, 'N_Kaplan': 4.301},
    (8, 288): {'N_exact': 22.51, 'N_eff': 27.21, 'N_Kaplan': 7.963},
    (9, 320): {'N_exact': 28.70, 'N_eff': 34.57, 'N_Kaplan': 12.53},
    (10, 384): {'N_exact': 37.09, 'N_eff': 44.92, 'N_Kaplan': 17.69},
    (12, 480): {'N_exact': 57.43, 'N_eff': 69.18, 'N_Kaplan': 33.18},
    (14, 576): {'N_exact': 84.85, 'N_eff': 101.3, 'N_Kaplan': 55.74},
    (15, 640): {'N_exact': 108.5, 'N_eff': 128.1, 'N_Kaplan': 76.19},
    (18, 704): {'N_exact': 149.1, 'N_eff': 175.0, 'N_Kaplan': 113.5},
    (21, 832): {'N_exact': 221.0, 'N_eff': 256.7, 'N_Kaplan': 178.9},
    (23, 1024): {'N_exact': 347.3, 'N_eff': 395.3, 'N_Kaplan': 295.4},
    (26, 1120): {'N_exact': 455.5, 'N_eff': 514.9, 'N_Kaplan': 398.8},
    (26, 1312): {'N_exact': 612.2, 'N_eff': 681.8, 'N_Kaplan': 545.8},
    (30, 1504): {'N_exact': 902.1, 'N_eff': 994.1, 'N_Kaplan': 825.9}
}

scale_to_millions_factor = 1e6


def _lookup_param_count(depth, width, kind):
    """Return the scaled parameter count for a (depth, width) pair.

    Yields NaN when the architecture (or the requested count kind) is not in
    `param_lookup`. The previous `1e6 * dict.get(..., None)` form raised a
    TypeError on unknown architectures instead of producing a missing value.
    """
    entry = param_lookup.get((depth, width))
    if entry is None or kind not in entry:
        return np.nan
    return scale_to_millions_factor * entry[kind]


# Add the three new columns to your dataframe
for _count_kind in ('N_exact', 'N_eff', 'N_Kaplan'):
    porian_df[_count_kind] = [
        _lookup_param_count(depth, width, _count_kind)
        for depth, width in zip(porian_df['depth'], porian_df['width'])
    ]


porian_df.head(3)

In [None]:
# main params to consider
# DATASET = "rw"
DATASET = "owt2"
DECAY = "chinchilla"

# Keep only the runs matching both the chosen dataset and the chosen decay schedule.
_selected_rows = (porian_df['dataset'] == DATASET) & (porian_df['decay'] == DECAY)
porian_df = porian_df[_selected_rows]

In [None]:
# N is taken from the N_eff column; C = 6 * N * D.
# `.assign` returns a fresh frame rather than writing columns into the
# boolean-filtered frame produced earlier, which avoids SettingWithCopyWarning.
# `reset_index()` turns the old index into an "index" column that serves as the
# per-run key (see `unique_col_list`).
# NOTE: re-running this cell adds a further index column.
porian_df = (
    porian_df
    .assign(N=lambda frame: frame["N_eff"])
    .assign(C=lambda frame: 6 * frame["N"] * frame["D"])
    .reset_index()
)

In [None]:
# Scatter every run in the (N, D) plane on log-log axes.
fig, ax = plt.subplots()
ax.scatter(x=porian_df["N"], y=porian_df["D"])
ax.set(xlabel="N", ylabel="D")
ax.loglog()

In [None]:
# Columns identifying one curve: every entry is a single training run,
# so the "index" column alone is the key.
unique_col_list = ["index"]

# Axes shared by the Pareto extraction and the fits below.
x_col, y_col = 'C', 'val/loss_final'

In [None]:
# Final point of every run, restricted to the Pareto frontier (get_pareto=True)
# in the (x_col, y_col) plane.
_pareto_df = get_final_points_from_curve_set(
    porian_df,
    unique_col_list,
    x_col=x_col,
    y_col=y_col,
    get_pareto=True,
)

# No `plt.clf()` here: with no current figure it would create (and render)
# an extra empty figure before the real one.
fig, ax = plt.subplots(1, 2, figsize=(9, 4))

# Left panel: x_col vs N, right panel: x_col vs D, both on log-log axes.
for _panel, _quantity in zip(ax, ("N", "D")):
    visualize_train_curves(
        _panel,
        _pareto_df,
        unique_col_list,
        x_col=x_col,
        y_col=_quantity,
        plot_all_curves=False,
        plot_final=True,
        plot_pareto_final=False,
        xlog=True,
        ylog=True,
    )

In [None]:
# Optional filter (currently disabled): uncomment to keep only rows with C > 4e15
# in both the Pareto frame and the full results frame.
# _pareto_df = _pareto_df[_pareto_df["C"] > 4e15]
# porian_df = porian_df[porian_df["C"] > 4e15]

# Approach 1: linear scaling-law fits on the Pareto-frontier runs

In [None]:
# Visualizing scaling law linear fit for C vs N, D, L

# No `plt.clf()` here: with no current figure it would create (and render)
# an extra empty figure before the real one.
fig, ax = plt.subplots(1, 3, figsize=(15, 4))

# (column to fit against x_col, y-axis label), one entry per panel.
_fit_targets = [
    ("N", "Parameters (N)"),  # C vs N
    ("D", "Tokens (D)"),      # C vs D
    (y_col, y_col),           # C vs Loss
]

for _panel, (_target_col, _ylabel) in zip(ax, _fit_targets):
    X = _pareto_df[x_col].values
    Y = _pareto_df[_target_col].values
    # NOTE(review): presumably fit_linear_model fits in log-log space -- confirm in scaling.utils.
    slope, intercept, r_value, p_value, std_err = fit_linear_model(X, Y)
    print(f"Slope: {slope}, Intercept: {intercept}, R^2: {r_value**2}")

    plot_line_fit(
        _panel,
        X,
        Y,
        slope,
        intercept,
    )
    _panel.set_xlabel("FLOPs")
    _panel.set_ylabel(_ylabel)

# Approach 3: parametric fit of the loss as a function of N and D

In [None]:
# Final point of every run (no Pareto filtering) as input for the parametric fit.
# x_col / y_col are passed through so this call agrees with the Pareto extraction
# above and with the `df[y_col]` target read below. The previous hard-coded
# y_col="loss" named a different column than the one actually fitted.
df = get_final_points_from_curve_set(
    porian_df,
    unique_col_list,
    x_col=x_col,
    y_col=y_col,
    get_pareto=False,
)

N = df["N"].values
D = df["D"].values
y = df[y_col].values

# One row per distinct (N, D) pair, keeping the lowest loss observed for it.
_df = (
    pd.DataFrame({"N": N, "D": D, "Loss": y})
    .groupby(by=["N", "D"])
    .min()
    .reset_index()
    .sort_values(by=["N", "D"])
)

# Design matrix with columns (N, D) and the matching loss targets.
data_X = _df[["N", "D"]].values
data_y = _df["Loss"].values

In [None]:
# Starting points for the optimiser: the full Cartesian grid over
# (a, alpha, b, beta, e), five values per parameter.
_coefficient_grid = np.linspace(0, 10, 5)  # used for a and b
_unit_grid = np.linspace(0., 1., 5)        # used for alpha, beta and e

initialization = list(
    product(
        _coefficient_grid,  # a
        _unit_grid,         # alpha
        _coefficient_grid,  # b
        _unit_grid,         # beta
        _unit_grid,         # e
    )
)

In [None]:
# Fit the parametric form from every starting point in `initialization`;
# the helper returns the best parameter vector and its loss.
_fit_result = fit_parametric_form(functional_form_chin3, data_X, data_y, initialization)
best_params, best_loss = _fit_result

print(f"Best Loss: {best_loss}")
print(
    f"a: {best_params[0]}, alpha={best_params[1]}\n"
    f"b: {best_params[2]}, beta={best_params[3]}\n"
    f"e={best_params[4]}"
)

In [None]:
# Unpack the fitted vector; _a, _b and _e are exponentiated to obtain A, B and E.
_a, alpha, _b, beta, _e = best_params
A, B, E = np.exp(_a), np.exp(_b), np.exp(_e)

# Exponents and prefactor used by the plots below:
#   N = G * (C / 6) ** a    and    D = (C / 6) ** b / G
_exponent_sum = alpha + beta
a = beta / _exponent_sum
b = alpha / _exponent_sum

print(a, b)

_coefficient_ratio = (alpha * A) / (beta * B)
G = _coefficient_ratio ** (1 / _exponent_sum)
print(G)

In [None]:
train_df = _pareto_df
# held_out_df = _pareto_df.loc[_pareto_df.target_N == sorted(_pareto_df.target_N.unique())[-1]]
# held_out_df

# No `plt.clf()` here: with no current figure it would create (and render)
# an extra empty figure before the real one.
fig, ax = plt.subplots(1, 2, figsize=(10, 4))

# (column, y-axis label, curve derived from the parametric fit), one per panel.
_allocation_panels = [
    # C vs N
    ("N", "Parameters (N)", lambda flops: G * (flops / 6) ** a),
    # C vs D
    ("D", "Tokens (D)", lambda flops: G**-1 * (flops / 6) ** b),
]

for _panel, (_quantity, _ylabel, _predict) in zip(ax, _allocation_panels):
    X = train_df[x_col].values
    Y = train_df[_quantity].values

    _panel.scatter(X, Y, label="raw data")
    _x_plot = np.linspace(X.min(), X.max(), 100)
    _panel.plot(
        _x_plot,
        _predict(_x_plot),
        color="red",
        label="fitted line",
    )
    _panel.loglog()
    _panel.set_xlabel("FLOPs")
    _panel.set_ylabel(_ylabel)
    # The labels above were previously never displayed because no legend was drawn.
    _panel.legend()