In [None]:
from itertools import product
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from scaling.utils import (
    get_pareto_frontier, 
    get_final_points_from_curve_set, 
    fit_linear_model,
    functional_form_chin3,
    fit_parametric_form,
    fit_parametric_form_stable,
    functional_form_chin3_stable,
)
from scaling.visualize import visualize_train_curves, plot_line_fit, plot_isoflops

In [None]:
# Columns whose combination identifies one training run (one learning curve).
unique_col_list = ["base_N", "target_N", "tkpm", "shrink"]

def preprocess_warmstarting(df, y_col_to_smooth=None, smoothing_window=100):
    """Optionally smooth, then rescale tokens/flops of every run in ``df``.

    Rows are grouped by ``unique_col_list`` and each group is handled as one
    run, sorted by its "flops" column.

    Args:
        df: frame holding the columns in ``unique_col_list`` plus "tokens"
            and "flops" (and ``y_col_to_smooth`` when it is given).
        y_col_to_smooth: column to overwrite with a Gaussian-weighted rolling
            mean (``std = smoothing_window / 10``, ``min_periods=1``), or
            ``None`` to skip smoothing.
        smoothing_window: rolling window length used for the smoothing.

    Returns:
        The processed runs concatenated in group order, keeping the original
        row labels. Runs whose largest "tokens" value is more than 1% away
        from ``target_N * tkpm`` are reported with ``print`` and left out.
        An empty ``DataFrame`` is returned when no run survives.
    """
    processed_runs = []
    for run_key, run_df in df.groupby(unique_col_list):
        _df = run_df.sort_values(by="flops")

        if y_col_to_smooth is not None:
            # Overwrites the column in place of adding a "_smoothed" copy.
            _df[y_col_to_smooth] = (
                _df[y_col_to_smooth]
                .rolling(smoothing_window, win_type='gaussian', min_periods=1)
                .mean(std=smoothing_window / 10)
            )

        # Token budget the run was meant to reach, taken from its last row.
        last_row = _df.iloc[-1]
        max_intended_tokens = last_row["target_N"] * last_row["tkpm"]
        max_logged_tokens = _df["tokens"].max()
        if abs((max_intended_tokens - max_logged_tokens) / max_logged_tokens) > 0.01:
            print("Wrong tkpm: ", run_key)
            continue
        # Stretch the token axis so that it ends exactly on the intended budget.
        _df["tokens"] = np.round(max_intended_tokens / max_logged_tokens * _df["tokens"])

        # Same for flops, with the budget computed as 6 * tokens * target_N.
        max_intended_flops = 6. * max_intended_tokens * _df["target_N"]
        _df["flops"] = np.round(max_intended_flops / _df["flops"].max() * _df["flops"])

        processed_runs.append(_df)

    # Concatenate once at the end: growing a frame inside the loop re-copies
    # every previously processed row for each new run.
    if not processed_runs:
        return pd.DataFrame()
    return pd.concat(processed_runs)

In [None]:
# Read the MLP results table from its parquet file and show it.
# NOTE(review): a later cell reads `warmstarting_df`, which this cell does not
# create; the disabled line below hints it was once derived here - confirm.
# warmstarting_df = preprocess_warmstarting(warmstarting_df)
mlp_df = pd.read_parquet("../data/mlp_results.parquet")
display(mlp_df)

In [None]:
def get_loss_at_flops(
    df: pd.DataFrame,
    flop_intervals: list[float],
    y_col: str,
    unique_col_list: "list[str] | None" = None,
    add_base_compute=False,
    base_tkpm: float = 20.,
) -> pd.Series:
    """Interpolate the best learning curve in ``df`` at the given flop values.

    ``df`` is split into learning curves by ``unique_col_list``. The curve
    with the lowest final ``y_col`` value (last row after sorting by "flops")
    is selected, and its ``y_col`` is linearly interpolated along the flops
    axis at every entry of ``flop_intervals``.

    Args:
        df: frame with a "flops" column, ``y_col`` and the grouping columns
            (plus "base_N" when ``add_base_compute`` is set).
        flop_intervals: flop values at which the loss is requested.
        y_col: name of the loss column; rows where it is NaN are ignored.
        unique_col_list: columns that together identify one learning curve.
            Required, despite the ``None`` default kept for call compatibility.
        add_base_compute: when true, shift each curve's flops by
            ``6 * base_tkpm * base_N**2`` before comparing/interpolating.
        base_tkpm: tokens-per-parameter used in that shift.

    Returns:
        A Series indexed by ``flop_intervals``. With pandas' default
        forward-only fill, requests below the curve's first flops value stay
        NaN and requests above its last value take the final loss.

    Raises:
        ValueError: if ``unique_col_list`` is missing or ``df`` contains no
            curve with at least one non-NaN ``y_col`` value.
    """
    if unique_col_list is None:
        raise ValueError(
            "unique_col_list must list the columns identifying a single learning curve"
        )

    x_col = "flops"
    best_learning_curve = None
    best_final_loss = float('inf')

    for _, curve_df in df.groupby(unique_col_list):
        _df = curve_df.dropna(subset=[y_col]).sort_values(by=x_col)
        if _df.empty:
            # Nothing logged for this curve; iloc[-1] below would raise.
            continue
        if add_base_compute:
            base_flops = 6. * base_tkpm * _df.iloc[0]['base_N']**2
            _df[x_col] = _df[x_col] + base_flops
        final_loss = _df.iloc[-1][y_col]
        if final_loss < best_final_loss:
            best_final_loss = final_loss
            best_learning_curve = pd.Series(
                data=_df[y_col].values,
                index=_df[x_col].values
            )

    if best_learning_curve is None:
        raise ValueError(f"no learning curve with finite '{y_col}' values found")

    # Insert the requested flop values as gaps to be filled by interpolation.
    for flop in flop_intervals:
        if flop not in best_learning_curve.index:
            best_learning_curve.loc[flop] = np.nan
    best_learning_curve = best_learning_curve.sort_index()
    # method='index' interpolates against the flops stored in the index;
    # method='linear' would ignore them and treat the points as equally spaced.
    best_learning_curve = best_learning_curve.interpolate(method='index')
    return best_learning_curve.loc[flop_intervals]

In [None]:
# Configuration for the iso-FLOP comparison across growth factors.
SHRINK = 0.4                  # shrink value of the warm-started runs to plot
ADD_BASE_COMPUTE = False      # forwarded to get_loss_at_flops for warm-started runs
TKPM = 20.                    # only runs with this tkpm value are used
MIN_FLOPS_SCALE_FACTOR = 10   # lowest flop level = max flops / this factor
NUM_FLOP_LEVELS = 7           # number of evenly spaced flop levels per panel

# NOTE(review): `warmstarting_df` and `y_col` are read below but are not defined
# in the cells visible here; they must come from an earlier cell for a fresh
# Restart & Run All to succeed - confirm.
warmstarting_tkpms_df = warmstarting_df[warmstarting_df['tkpm']==TKPM]

target_models = sorted(warmstarting_tkpms_df['target_N'].unique())
# Drops the smallest AND the largest target size.
# NOTE(review): the previous comment only mentioned the smallest - confirm.
target_models = target_models[1:-1]
if not target_models:
    raise ValueError("no target models left to plot after dropping the smallest and largest")

# squeeze=False keeps `axes` two-dimensional, so indexing also works when
# there is exactly one panel.
fig, axes = plt.subplots(
    1,
    len(target_models),
    figsize=(5 * len(target_models), 5),
    layout='constrained',
    squeeze=False,
)
for i, target_model in enumerate(target_models):
    ax = axes[0, i]
    target_model_df = warmstarting_tkpms_df[warmstarting_tkpms_df['target_N']==target_model]
    no_growth_df = target_model_df[target_model_df['method']=='mup']

    # calculate flop intervals from the no-growth runs
    max_flops = no_growth_df['flops'].max()
    min_flops = max_flops / MIN_FLOPS_SCALE_FACTOR
    flop_intervals = np.linspace(min_flops, max_flops, NUM_FLOP_LEVELS)
    flops_df = pd.DataFrame()

    # growth factor == 1: no-growth runs with the largest base_N
    no_growth_df = no_growth_df[no_growth_df["base_N"]==no_growth_df["base_N"].max()]
    flops_df[1.] = get_loss_at_flops(no_growth_df, flop_intervals, y_col, unique_col_list)

    # growth factor > 1: warm-started runs with the selected shrink value.
    # Both conditions are combined; previously the method filter was applied to
    # the unfiltered frame again, which silently discarded the SHRINK filter.
    is_selected_warmstart = (
        (target_model_df['shrink'] == SHRINK)
        & (target_model_df['method'] == 'warmstart')
    )
    shrink_target_model_df = target_model_df[is_selected_warmstart]

    base_models = sorted(shrink_target_model_df['base_N'].unique(), reverse=True)
    for base_model in base_models:
        # Already restricted to the current target_N, so no further filter is needed.
        growth_df = shrink_target_model_df[shrink_target_model_df['base_N']==base_model]

        growth_factor = target_model / base_model
        flops_df[growth_factor] = get_loss_at_flops(
            growth_df, flop_intervals, y_col, unique_col_list, add_base_compute=ADD_BASE_COMPUTE
        )

    ax.set_title(f"Target N: {(target_model/1_000_000):3.1f}M")
    plot_isoflops(
        ax,
        flops_df,
        disable_y_label=(i == len(target_models) - 1),
    )

# add figure wide axis labels
fig.supxlabel("Growth Factor", fontsize=15)
fig.supylabel(y_col, fontsize=15);