In [None]:
from itertools import product
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from typing import List, Tuple, Optional, Dict, Any, Union, Callable, Iterable

In [None]:
from scaling.utils import (
    get_pareto_frontier, 
    get_final_points_from_curve_set, 
    fit_linear_model,
    functional_form_chin3,
    fit_parametric_form,
    fit_parametric_form_stable,
    functional_form_chin3_stable,
)
from scaling.visualize import visualize_train_curves, plot_line_fit, plot_isoflops

In [None]:
# Columns that together identify one training run (one learning curve).
unique_col_list = ["base_N", "target_N", "tkpm", "shrink"]

def preprocess_warmstarting(df, y_col_to_smooth=None, smoothing_window=100):
    """Rescale every learning curve so it ends exactly at its intended budget.

    The frame is split into one curve per ``unique_col_list`` combination and
    each curve is sorted by ``flops``. For every curve:

    * optionally, ``y_col_to_smooth`` is overwritten with a Gaussian rolling
      mean (window ``smoothing_window``, std ``smoothing_window / 10``);
    * the intended token budget is ``target_N * tkpm`` (read from the last
      row); the curve is dropped, with a printed notice, when the largest
      logged ``tokens`` value differs from it by more than 1 %;
    * ``tokens`` is rescaled so its maximum equals the intended budget and
      ``flops`` so its maximum equals ``6 * tokens * target_N``; both are
      rounded to whole numbers.

    Args:
        df: Frame holding the ``unique_col_list`` columns plus ``tokens`` and
            ``flops`` (and ``y_col_to_smooth`` when smoothing is requested).
        y_col_to_smooth: Column to smooth in place, or ``None`` for no smoothing.
        smoothing_window: Rolling window length used for smoothing.

    Returns:
        The retained curves concatenated into one DataFrame (original row
        index kept); an empty DataFrame when no curve is retained.
    """
    # Collect the curves and concatenate once: concatenating inside the loop
    # copies the accumulated frame on every iteration.
    curves = []
    for group_key, group_df in df.groupby(unique_col_list):
        curve = group_df.sort_values(by="flops")

        if y_col_to_smooth is not None:
            curve[y_col_to_smooth] = (
                curve[y_col_to_smooth]
                .rolling(smoothing_window, win_type='gaussian', min_periods=1)
                .mean(std=smoothing_window / 10)
            )

        # Intended budget vs. what was actually logged.
        last_row = curve.iloc[-1]
        max_intended_tokens = last_row["target_N"] * last_row["tkpm"]
        logged_max_tokens = curve["tokens"].max()
        rel_err = abs((max_intended_tokens - logged_max_tokens) / logged_max_tokens)
        # `not <=` (instead of `>`) also rejects a NaN error, e.g. when no
        # tokens were logged, rather than propagating NaNs downstream.
        if not rel_err <= 0.01:
            print("Wrong tkpm: ", group_key)
            continue
        curve["tokens"] = np.round(max_intended_tokens / logged_max_tokens * curve["tokens"])

        max_intended_flops = 6. * max_intended_tokens * curve["target_N"]
        curve["flops"] = np.round(max_intended_flops / curve["flops"].max() * curve["flops"])

        curves.append(curve)

    if not curves:
        return pd.DataFrame()
    return pd.concat(curves)

In [None]:
# Read the logged results and rescale every curve to its intended
# token / FLOP budget in one step, then show the resulting frame.
warmstarting_df = preprocess_warmstarting(
    pd.read_parquet("../data/warmstarting_results.parquet")
)
display(warmstarting_df)


In [None]:
# Keep only the warm-starting results: drop every row whose method is "mup".

warmstarting_df = warmstarting_df[warmstarting_df["method"].ne("mup")]
display(warmstarting_df)

### Visualize raw data

Visualize all learning curves across N, D available in the *training set* to see scaling patterns.

Optionally, consider visualizing for different hyperparameters.

In [None]:

# Columns plotted on the x / y axes by the cells below.
x_col = "flops"  # "tokens"
y_col = "Validation Loss"

# One point per curve (presumably its final logged value — see
# scaling.utils), without Pareto filtering.
final_points = get_final_points_from_curve_set(
    warmstarting_df, unique_col_list, x_col=x_col, y_col=y_col, get_pareto=False
)

In [None]:
# No plt.clf() here: with no current figure it would first create (and the
# notebook would then render) an extra, empty figure.
fig, ax = plt.subplots(figsize=(6, 4))

# Draw every learning curve plus the final / Pareto-final points; the helper's
# return value is kept so the next cell can display it.
pareto_df = visualize_train_curves(
    ax, 
    warmstarting_df,
    unique_col_list,
    x_col=x_col,
    y_col=y_col,
    plot_all_curves=True,
    plot_final=True,
    plot_pareto_final=True,
    ylims=(1.5, 2.5),
    xlims=(1e16, 2e19),
    xlog=True,
)

In [None]:
# Show the frame returned by visualize_train_curves for the plot above.
display(pareto_df)

### Approach 1

Fit for `C vs N` and `C vs D`, assuming *best* training run for each `(N, D)`.

Steps:

1. Visualize for `C vs N` and `C vs D`
2. Fit linear model for each of them

In [None]:
x_col = "flops"
y_col = "Validation Loss"

# Pareto-filtered points of the warm-starting curves; reused by the
# held-out splits in the cells below.
_pareto_df = get_final_points_from_curve_set(
    warmstarting_df,
    unique_col_list,
    x_col=x_col,
    y_col=y_col,
    get_pareto=True,
)

display(_pareto_df.head())

# No plt.clf() here: it would open an extra, empty figure before subplots().
fig, ax = plt.subplots(1, 3, figsize=(18, 5))

# C vs N
plot_line_fit(
    ax[0],
    X = _pareto_df[x_col],
    Y = _pareto_df["target_N"]
)
ax[0].set_ylabel("Target N")

# C vs D
plot_line_fit(
    ax[1],
    X = _pareto_df[x_col],
    Y = _pareto_df["tokens"],
)
ax[1].set_ylabel("Tokens")

# C vs Loss
plot_line_fit(
    ax[2],
    X = _pareto_df[x_col],
    Y = _pareto_df[y_col],
)
ax[2].set_ylabel(y_col)

fig.supxlabel("FLOPs")

# Drop the per-panel legends. A separate loop name keeps `ax` bound to the
# array of axes instead of rebinding it to the last panel.
for panel_ax in ax.flat:
    leg = panel_ax.get_legend()
    if leg:
        leg.remove()


##### Using the largest model scale as the held-out set

In [None]:
# Split the Pareto points: every run at the largest target_N is held out,
# all remaining runs form the training set.

_largest_target_N = sorted(_pareto_df.target_N.unique())[-1]
_is_largest = _pareto_df.target_N == _largest_target_N
train_df = _pareto_df.loc[~_is_largest]
held_out_df = _pareto_df.loc[_is_largest]
held_out_df

In [None]:
# Visualizing scaling law linear fit for C vs N, D, L

# No plt.clf() here: it would open an extra, empty figure before subplots().
fig, ax = plt.subplots(1, 3, figsize=(15, 4.5));

# Compute (x axis) is shared by all three fits.
X = train_df[x_col].values

# (column fitted against compute, y-axis label): C vs N, C vs D, C vs Loss.
panels = [("target_N", "Target N"), ("tokens", "Tokens"), (y_col, y_col)]
for panel_ax, (target_col, y_label) in zip(ax, panels):
    Y = train_df[target_col].values
    slope, intercept, r_value, p_value, std_err = fit_linear_model(X, Y)
    print(f"Slope: {slope}, Intercept: {intercept}, R^2: {r_value**2}")

    plot_line_fit(
        panel_ax,
        X,
        Y,
        slope,
        intercept,
        # first held-out budget followed by four larger reference budgets
        x_extrapolate=[held_out_df.flops.values[0], 1e20, 1e21, 1e22, 1e23],
        # observed held-out value passed alongside the extrapolation budgets
        y_extrapolate=[held_out_df[target_col].values[0]],
    )
    panel_ax.set_ylabel(y_label)

fig.supxlabel("FLOPs")

# Replace the per-panel legends with one figure-level legend. A separate
# loop name keeps `ax` bound to the array of axes.
handles, labels = ax[0].get_legend_handles_labels()
for panel_ax in ax.flat:
    leg = panel_ax.get_legend()
    if leg:
        leg.remove()
fig.legend(handles, labels, loc="upper center", ncol=4);

##### Using the top-2 model scales as the held-out set

In [None]:
# Split the Pareto points: runs at the two largest target_N values are held
# out, every smaller scale is used for fitting.

_sorted_target_sizes = sorted(_pareto_df.target_N.unique())
train_df = _pareto_df.loc[_pareto_df.target_N.isin(_sorted_target_sizes[:-2])]
held_out_df = _pareto_df.loc[_pareto_df.target_N.isin(_sorted_target_sizes[-2:])]
# Keep only held-out runs whose tkpm also occurs in the training split.
held_out_df = held_out_df.loc[held_out_df.tkpm.isin(train_df.tkpm.unique())]
held_out_df

In [None]:
# Visualizing scaling law linear fit for C vs N, D, L

# No plt.clf() here: it would open an extra, empty figure before subplots().
fig, ax = plt.subplots(1, 3, figsize=(15, 4.5));

# Compute (x axis) is shared by all three fits.
X = train_df[x_col].values

# (column fitted against compute, y-axis label): C vs N, C vs D, C vs Loss.
panels = [("target_N", "Target N"), ("tokens", "Tokens"), (y_col, y_col)]
for panel_ax, (target_col, y_label) in zip(ax, panels):
    Y = train_df[target_col].values
    slope, intercept, r_value, p_value, std_err = fit_linear_model(X, Y)
    print(f"Slope: {slope}, Intercept: {intercept}, R^2: {r_value**2}")

    plot_line_fit(
        panel_ax,
        X,
        Y,
        slope,
        intercept,
        # every held-out budget followed by four larger reference budgets
        x_extrapolate=held_out_df.flops.values.tolist() + [1e20, 1e21, 1e22, 1e23],
        # observed held-out values passed alongside the extrapolation budgets
        y_extrapolate=held_out_df[target_col].values,
    )
    panel_ax.set_ylabel(y_label)

fig.supxlabel("FLOPs")

# Replace the per-panel legends with one figure-level legend. A separate
# loop name keeps `ax` bound to the array of axes.
handles, labels = ax[0].get_legend_handles_labels()
for panel_ax in ax.flat:
    leg = panel_ax.get_legend()
    if leg:
        leg.remove()
fig.legend(handles, labels, loc="upper center", ncol=4);

#### Fixed Base Model

In [None]:
# TODO: visualize the Pareto front and show which algorithm each point belongs to (I think we just have muP and Paws...)

In [None]:
# Re-read the results so the rows an earlier cell filtered out
# (method == "mup") are available again.
# NOTE: x_col / y_col are taken from the cells above.
warmstarting_df = pd.read_parquet(
    "../data/warmstarting_results.parquet",
)
warmstarting_df = preprocess_warmstarting(warmstarting_df)
base_model_sizes = warmstarting_df['base_N'].unique()

# One panel per base model size: `fig` holds the learning curves,
# `fig2` the linear fits through the Pareto points.
fig, axes = plt.subplots(1, len(base_model_sizes), figsize=(5 * len(base_model_sizes), 4.5));
fig2, axes2 = plt.subplots(1, len(base_model_sizes), figsize=(5 * len(base_model_sizes), 4.5));
# plt.subplots returns a bare Axes (not an array) for a single column.
axes = np.atleast_1d(axes)
axes2 = np.atleast_1d(axes2)

for i, base_model_size in enumerate(base_model_sizes):
    base_size_df = warmstarting_df[warmstarting_df['base_N'] == base_model_size]

    axes[i].set_title(f"Base model size: {(base_model_size/1_000_000):3.1f}M")

    # One colour per shrink factor, taken from a blue gradient.
    shrink_factors = sorted(base_size_df['shrink'].unique())
    colors = plt.cm.Blues(np.linspace(0, 1, len(shrink_factors)))
    for j, shrink in enumerate(shrink_factors):
        shrink_df = base_size_df[base_size_df['shrink']==shrink]
        # Raw learning curves are only drawn for these two shrink factors.
        if shrink in [0.0, 0.2]:
            visualize_train_curves(
                axes[i], 
                shrink_df,
                unique_col_list,
                x_col=x_col,
                y_col=y_col,
                plot_all_curves=True,
                plot_final=True,
                plot_pareto_final=True,
                ylims=(1.5, 2.5),
                xlims=(1e16, 2e19),
                xlog=True,
                style={"color": plt.get_cmap("tab10").colors[j], "label": f"shrink={shrink}"})

        pareto_shrink_df = get_final_points_from_curve_set(
            shrink_df,
            unique_col_list,
            x_col=x_col,
            y_col=y_col,
            get_pareto=True,
        )

        X = pareto_shrink_df[x_col].values
        Y = pareto_shrink_df[y_col].values

        slope, intercept, r_value, p_value, std_err = fit_linear_model(X, Y)
        # shrink == 0 is highlighted in red; every other factor uses its
        # shade from the blue gradient.
        plot_line_fit(
            axes2[i],
            X,
            Y,
            slope,
            intercept,
            style={"color": "red" if shrink == 0. else colors[j], "label": f"shrink={shrink}"})

# TODO: possibly assign colors based on identifier columns; fit the loss for
# each shrink factor / for all approaches.

# Rebuild each legend with one entry per distinct label (the plotting helpers
# add the same label several times).
for ax in (*axes2.flat, *axes.flat):
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys())

## Fixed Growth Factor

In [None]:
def filter_pairs(df, jump_size=1):
    """Keep rows whose (base_N, target_N) lie ``jump_size`` steps apart.

    All distinct values found in the ``base_N`` and ``target_N`` columns are
    sorted into one ladder of sizes; a row is kept when its
    ``(base_N, target_N)`` equals ``(ladder[k], ladder[k + jump_size])`` for
    some ``k``.

    Args:
        df: Frame with ``base_N`` and ``target_N`` columns.
        jump_size: Number of ladder steps between base and target size.
            ``0`` keeps rows with ``base_N == target_N``.

    Returns:
        The matching rows of ``df`` (same columns; empty when nothing matches
        or ``df`` itself is empty).
    """
    distinct_values = sorted(
        pd.unique(df[['base_N', 'target_N']].values.ravel())
    )
    # Slice by an explicit length: `[:-jump_size]` is empty for jump_size == 0.
    n_pairs = max(len(distinct_values) - jump_size, 0)
    pairs = set(zip(distinct_values[:n_pairs], distinct_values[jump_size:]))
    # Plain set lookup per row instead of a row-wise DataFrame.apply, which is
    # slow and returns a DataFrame (not a boolean Series) for an empty frame.
    keep = np.fromiter(
        ((base, target) in pairs for base, target in zip(df['base_N'], df['target_N'])),
        dtype=bool,
        count=len(df),
    )
    return df[keep]

y_col = "Validation Loss"
# NOTE: x_col is taken from the cells above.

num_jump_sizes = 1
# `fig` holds the linear fits, `fig2` the raw learning curves; one panel per
# jump size. np.atleast_1d keeps indexing valid for a single panel.
fig, axes = plt.subplots(1, num_jump_sizes, figsize=((5 * num_jump_sizes), 4.5), layout='constrained');
axes = np.atleast_1d(axes)
fig2, axes2 = plt.subplots(1, num_jump_sizes, figsize=((5 * num_jump_sizes), 4.5), layout='constrained');
axes2 = np.atleast_1d(axes2)
# TODO: a separate muP baseline overlay (fit over method == 'mup' rows) was
# disabled here; re-add it with its own fit if needed.

# Only these shrink factors are fitted and plotted.
shrinks_to_plot = [0.0, 1.0, 0.4]

for i in range(num_jump_sizes):

    axes[i].set_title(f"{(2)**(i+1)}x growth factor")
    # Runs whose base and target sizes are (i + 1) steps apart.
    jump_df = filter_pairs(warmstarting_df, jump_size=i+1)

    shrink_factors = sorted(jump_df['shrink'].unique())
    colors = plt.cm.Blues(np.linspace(0, 1, len(shrink_factors)))
    for j, shrink in enumerate(shrink_factors):
        if shrink not in shrinks_to_plot:
            continue
        shrink_df = jump_df[jump_df['shrink']==shrink]
        pareto_shrink_df = get_final_points_from_curve_set(
            shrink_df,
            unique_col_list,
            x_col=x_col,
            y_col=y_col,
            get_pareto=True,
        )

        X = pareto_shrink_df[x_col].values
        Y = pareto_shrink_df[y_col].values

        visualize_train_curves(
            axes2[i], 
            shrink_df,
            unique_col_list,
            x_col=x_col,
            y_col=y_col,
            plot_all_curves=True,
            plot_final=True,
            plot_pareto_final=True,
            ylims=(1.5, 2.5),
            xlims=(1e16, 2e19),
            xlog=True,
            style={"color": plt.get_cmap("tab10").colors[j], "label": f"shrink={shrink}"})

        slope, intercept, r_value, p_value, std_err = fit_linear_model(X, Y)

        # NOTE(review): the label prints `intercept` as the multiplicative
        # coefficient of C^slope — confirm whether fit_linear_model returns it
        # in log space (it would then need exponentiating).
        if shrink != 0.0:
            style={"color": colors[j], "label": r"$\lambda_{shrink}" + rf"={shrink}$ ($L={intercept:.2f}\cdot C^" + r"{" f"{slope:.3f}" + r"}$)"}
        else:
            # shrink == 0 is labelled as the muP baseline and drawn in red
            style={"color": "red", "label": rf"$\mu P$ ($L={intercept:.2f}\cdot C^" + r"{" f"{slope:.3f}" + r"}$)"}

        plot_line_fit(
            axes[i],
            X,
            Y,
            slope,
            intercept,
            style=style,
            x_max_plot=1e19
        )

fig.supxlabel("FLOPs")
fig.supylabel(y_col)

# Rebuild each legend with one entry per distinct label.
for ax in (*axes.flat, *axes2.flat):
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys())

# set axis to linear scale
for ax in axes.flat:
    ax.set_xscale('linear') 
    ax.set_yscale('linear') 
    

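### Approach 3

Fit a parametric loss surface $L(N, D)$ to the muP runs and derive the compute-optimal $N_{opt}(C)$ and $D_{opt}(C)$ from the fitted parameters.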


In [None]:
# Rows whose method is 'mup'; reused by the cells below.
mup_df = warmstarting_df[warmstarting_df['method']=='mup']
df = get_final_points_from_curve_set(
    mup_df,
    unique_col_list,
    x_col="flops",
    y_col="Validation Loss",
    get_pareto=False,
)

# The logged `tokens` column is not used here (and nothing is assigned to `_`,
# which IPython reserves for the last cell output): D is the intended token
# budget target_N * tkpm.
N = df["target_N"].values
_D = df["target_N"] * df["tkpm"]
y = df["Validation Loss"].values

# Best (lowest) loss per (N, D), sorted by N then D.
_df = (
    pd.DataFrame.from_dict({
        "N": N,
        "D": _D,
        "Loss": y
    })
    .groupby(by=["N", "D"]).min().reset_index()
    .sort_values(by=["N", "D"])
)

# Inputs for the parametric fit: rows of (N, D) and the matching losses.
data_X = _df[["N", "D"]].values
data_y = _df["Loss"].values

In [None]:
# Per-parameter grids of starting values for the parametric fit.
_coef_grid = np.linspace(0, 15, 5)        # shared by a and b
_exponent_grid = np.linspace(0., 1., 10)  # shared by alpha and beta
_e_grid = np.linspace(-1., 1., 5)         # e

# Cartesian product in parameter order (a, alpha, b, beta, e).
initialization = list(
    product(_coef_grid, _exponent_grid, _coef_grid, _exponent_grid, _e_grid)
)

# Smaller grid for quick experiments:
# initialization = list(product(
#     [0, 5,],  # a
#     [0, 0.5,],  # alpha
#     [0, 5, ],  # b
#     [0, 0.5,],  # beta
#     [-1, -0.5,]  # e
# ))


In [None]:
# Fit the parametric form to the (N, D) -> loss table built above, passing the
# grid of starting points; the helper returns the best parameter vector and
# its objective value (presumably the best over all starting points — see
# scaling.utils).
best_params, best_loss = fit_parametric_form_stable(
    functional_form_chin3_stable,
    data_X, 
    data_y, 
    initialization
)

# Parameter order is (a, alpha, b, beta, e).
print(f"Best Loss: {best_loss}")
print(f"a: {best_params[0]}, alpha={best_params[1]}\nb: {best_params[2]}, beta={best_params[3]}\ne={best_params[4]}")

In [None]:
# Fitted parameters in the order (a, alpha, b, beta, e); a, b and e are
# exponentiated to obtain A, B and E.
_a, alpha, _b, beta, _e = best_params
A, B, E = np.exp(_a), np.exp(_b), np.exp(_e)

# Exponents of the compute-optimal curves plotted below:
# N_opt = G * (C / 6) ** a  and  D_opt = G ** -1 * (C / 6) ** b.
_exponent_sum = alpha + beta
a = beta / _exponent_sum
b = alpha / _exponent_sum
print(a, b)

G = ((alpha*A) / (beta*B)) ** (1 / _exponent_sum)
print(G)

In [None]:
# Pareto points of the muP runs (mup_df comes from the Approach 3 data cell);
# no held-out split is made here: every point is kept in train_df.
x_col = "flops"
y_col = "Validation Loss"

_pareto_df = get_final_points_from_curve_set(
    mup_df,
    unique_col_list,
    x_col=x_col,
    y_col=y_col,
    get_pareto=True
)
train_df = _pareto_df
# Disabled alternative: hold out the largest target_N.
# train_df = _pareto_df.loc[_pareto_df.target_N != sorted(_pareto_df.target_N.unique())[-1]]
# held_out_df = _pareto_df.loc[_pareto_df.target_N == sorted(_pareto_df.target_N.unique())[-1]]
# held_out_df

In [None]:
# plotting for N_opt and D_opt

# Two panels (one per fitted curve). No plt.clf() here: it would open an
# extra, empty figure before subplots().
fig, ax = plt.subplots(1, 2, figsize=(10, 4));

X = train_df[x_col].values
# Log-spaced grid: both axes are logarithmic, so a linearly spaced grid would
# put almost all samples in the top decade of compute.
_x_plot = np.geomspace(X.min(), X.max(), 100)

# C vs N: N_opt = G * (C / 6) ** a

Y = train_df["target_N"].values
ax[0].scatter(X, Y, label="raw data")
ax[0].plot(
    _x_plot, 
    G * (_x_plot / 6) ** a,
    color="red", 
    label="fitted line"
    )
ax[0].loglog()
ax[0].set_ylabel("Target N")
ax[0].legend()

# C vs D: D_opt = G ** -1 * (C / 6) ** b

Y = train_df["tokens"].values
ax[1].scatter(X, Y, label="raw data")
ax[1].plot(
    _x_plot, 
    G**-1 * (_x_plot / 6) ** b,
    color="red", 
    label="fitted line"
    )
ax[1].loglog()
ax[1].set_ylabel("Tokens")
ax[1].legend()

fig.supxlabel("FLOPs");

# Isoflops

In [None]:

    
def get_loss_at_flops(df: pd.DataFrame, flop_intervals: list[float], y_col: str, unique_col_list: Optional[list[str]] = None, add_base_compute=False) -> pd.Series:
    """Loss of the best run in ``df``, interpolated at the given flop values.

    ``df`` is split into runs by ``unique_col_list``; the run whose last
    (highest-``flops``) ``y_col`` value is lowest is selected, and its curve is
    linearly interpolated *with respect to the flops axis* at every value in
    ``flop_intervals``.

    Args:
        df: Learning-curve rows with a ``flops`` column, ``y_col`` and the
            ``unique_col_list`` columns (plus ``base_N`` when
            ``add_base_compute`` is set).
        flop_intervals: Flop values at which the loss is wanted.
        y_col: Name of the loss column.
        unique_col_list: Columns that identify a single run. Required.
        add_base_compute: When true, shift each run's flops by
            ``6 * 20 * base_N**2`` (i.e. ``6 * N * D`` with ``D = 20 * N``)
            before comparing and interpolating.

    Returns:
        Series indexed by ``flop_intervals``. Values before the first logged
        point stay NaN; the whole Series is NaN when ``df`` contains no run
        with a non-null ``y_col``.

    Raises:
        ValueError: If ``unique_col_list`` is not given.
    """
    if unique_col_list is None:
        raise ValueError("unique_col_list must name the columns that identify a single run")

    x_col = "flops"
    best_learning_curve = None
    best_final_loss = float('inf')

    for _, run_df in df.groupby(unique_col_list):
        run_df = run_df.dropna(subset=[y_col]).sort_values(by=x_col)
        if run_df.empty:
            # run without a single logged loss value
            continue
        if add_base_compute:
            base_flops = 6. * 20. * run_df.iloc[0]['base_N']**2
            run_df[x_col] += base_flops
        final_loss = run_df.iloc[-1][y_col]
        if final_loss < best_final_loss:
            best_final_loss = final_loss
            best_learning_curve = pd.Series(
                data=run_df[y_col].values,
                index=run_df[x_col].values
            )

    if best_learning_curve is None:
        return pd.Series(np.nan, index=list(flop_intervals), dtype=float)

    # Insert the requested flop values that were not logged as NaN
    # placeholders (dict.fromkeys drops repeated requests, keeping order).
    missing = [
        flop for flop in dict.fromkeys(flop_intervals)
        if flop not in best_learning_curve.index
    ]
    if missing:
        best_learning_curve = pd.concat(
            [best_learning_curve, pd.Series(np.nan, index=missing, dtype=float)]
        )
    best_learning_curve = best_learning_curve.sort_index()
    # method='index' interpolates along the flops values themselves;
    # method='linear' would ignore the index and treat all points as
    # equally spaced.
    best_learning_curve = best_learning_curve.interpolate(method='index')
    return best_learning_curve.loc[flop_intervals]
    

In [None]:
SHRINK = 0.4
ADD_BASE_COMPUTE = False
TKPM = 20.
MIN_FLOPS_SCALE_FACTOR = 10
# NOTE: y_col is taken from the cells above.

warmstarting_tkpms_df = warmstarting_df[warmstarting_df['tkpm']==TKPM]

target_models = sorted(warmstarting_tkpms_df['target_N'].unique())
target_models = target_models[1:-1]  # skip the smallest and the largest model
fig, axes = plt.subplots(1, len(target_models), figsize=(5 * len(target_models), 5), layout='constrained');
# plt.subplots returns a bare Axes (not an array) for a single column.
axes = np.atleast_1d(axes)
for i, target_model in enumerate(target_models):
    target_model_df = warmstarting_tkpms_df[warmstarting_tkpms_df['target_N']==target_model]
    no_growth_df = target_model_df[target_model_df['method']=='mup']

    # Seven evenly spaced budgets between max_flops / MIN_FLOPS_SCALE_FACTOR
    # and the largest budget of the 'mup' runs.
    max_flops = no_growth_df['flops'].max()
    min_flops = max_flops / MIN_FLOPS_SCALE_FACTOR
    flop_intervals = np.linspace(min_flops, max_flops, 7)
    flops_df = pd.DataFrame()

    # add growth factor == 1
    no_growth_df = no_growth_df[no_growth_df["base_N"]==no_growth_df["base_N"].max()]
    flops_df[1.] = get_loss_at_flops(no_growth_df, flop_intervals, y_col, unique_col_list)

    # add growth factor > 1: warm-started runs at the selected shrink factor.
    # Both conditions are applied together; filtering in two statements that
    # each start from target_model_df would discard the shrink selection.
    shrink_target_model_df = target_model_df[
        (target_model_df['shrink'] == SHRINK)
        & (target_model_df['method'] == 'warmstart')
    ]

    base_models = sorted(shrink_target_model_df['base_N'].unique(), reverse=True)
    for base_model in base_models:
        growth_df = shrink_target_model_df[shrink_target_model_df['base_N']==base_model]

        growth_factor = target_model / base_model
        flops_df[growth_factor] = get_loss_at_flops(growth_df, flop_intervals, y_col, unique_col_list, add_base_compute=ADD_BASE_COMPUTE)
    axes[i].set_title(f"Target N: {(target_model/1_000_000):3.1f}M")
    plot_isoflops(
        axes[i],
        flops_df,
        disable_y_label=(i == len(target_models) - 1),
    )

# add figure wide axis labels
fig.supxlabel("Growth Factor", fontsize=15)
fig.supylabel(y_col, fontsize=15);


In [None]:
SHRINK = 0.4
ADD_BASE_COMPUTE = False
TKPM = 30.
MIN_FLOPS_SCALE_FACTOR = 10
# NOTE: y_col is taken from the cells above.

warmstarting_tkpms_df = warmstarting_df[warmstarting_df['tkpm']==TKPM]

target_models = sorted(warmstarting_tkpms_df['target_N'].unique())
target_models = target_models[1:-1]  # skip the smallest and the largest model
fig, axes = plt.subplots(1, len(target_models), figsize=(5 * len(target_models), 5), layout='constrained');
# plt.subplots returns a bare Axes (not an array) for a single column.
axes = np.atleast_1d(axes)
for i, target_model in enumerate(target_models):
    target_model_df = warmstarting_tkpms_df[warmstarting_tkpms_df['target_N']==target_model]
    no_growth_df = target_model_df[target_model_df['method']=='mup']

    # Seven evenly spaced budgets between max_flops / MIN_FLOPS_SCALE_FACTOR
    # and the largest budget of the 'mup' runs.
    max_flops = no_growth_df['flops'].max()
    min_flops = max_flops / MIN_FLOPS_SCALE_FACTOR
    flop_intervals = np.linspace(min_flops, max_flops, 7)
    flops_df = pd.DataFrame()

    # add growth factor == 1
    no_growth_df = no_growth_df[no_growth_df["base_N"]==no_growth_df["base_N"].max()]
    flops_df[1.] = get_loss_at_flops(no_growth_df, flop_intervals, y_col, unique_col_list)

    # add growth factor > 1: warm-started runs at the selected shrink factor.
    # Both conditions are applied together; filtering in two statements that
    # each start from target_model_df would discard the shrink selection.
    shrink_target_model_df = target_model_df[
        (target_model_df['shrink'] == SHRINK)
        & (target_model_df['method'] == 'warmstart')
    ]

    base_models = sorted(shrink_target_model_df['base_N'].unique(), reverse=True)
    for base_model in base_models:
        growth_df = shrink_target_model_df[shrink_target_model_df['base_N']==base_model]

        growth_factor = target_model / base_model
        flops_df[growth_factor] = get_loss_at_flops(growth_df, flop_intervals, y_col, unique_col_list, add_base_compute=ADD_BASE_COMPUTE)
    axes[i].set_title(f"Target N: {(target_model/1_000_000):3.1f}M")
    plot_isoflops(
        axes[i],
        flops_df,
        disable_y_label=(i == len(target_models) - 1),
    )

# add figure wide axis labels
fig.supxlabel("Growth Factor", fontsize=15)
fig.supylabel(y_col, fontsize=15);
