# Please refer to https://github.com/ethancaballero/broken_neural_scaling_laws

# In [None]:
import numpy as np
import scipy.optimize
from scipy.optimize import curve_fit
from scipy.signal import savgol_filter


# for n = 0
def bnsl_with_0_break(_x, a, b, c0):
    """Evaluate the zero-break BNSL, i.e. a plain power law plus an offset.

    Computes ``a + b * _x**(-c0)`` element-wise; ``_x`` may be a scalar or a
    NumPy array.

    Args:
        _x: Input scale value(s).
        a: Additive offset; for ``c0 > 0`` the curve tends to ``a`` as ``_x`` grows.
        b: Coefficient of the power-law term.
        c0: Power-law exponent (the term behaves like ``_x**-c0``).

    Returns:
        The model value(s) at ``_x``.
    """
    power_term = _x**(-c0)
    return a + b * power_term

def bnsl_with_0_break__log(_x, a, b, c0):
    """Return ``log(1 + y)`` where ``y = bnsl_with_0_break(_x, a, b, c0)``.

    This is the model handed to ``curve_fit`` so the fit is done against
    ``log(1 + error)`` rather than the raw error.
    """
    return np.log(bnsl_with_0_break(_x, a, b, c0) + 1)

def bnsl_with_0_break__msle_optim(p, _x, _y):
    """Mean squared log(1 + y) error of the zero-break BNSL.

    Objective for ``scipy.optimize.brute``: ``p`` is the parameter vector
    ``(a, b, c0)``, ``_x``/``_y`` are the observations.

    Returns:
        ``mean((log(model + 1) - log(_y + 1))**2)``.
    """
    a, b, c0 = p
    log_gap = np.log(bnsl_with_0_break(_x, a, b, c0) + 1) - np.log(_y + 1)
    return np.mean(log_gap**2)


# for n = 1
def bnsl_with_1_break(_x, a, b, c0, c1, d1, f1):
    """Evaluate a BNSL with a single break.

    Computes ``a + b * _x**(-c0) * (1 + (_x/d1)**(1/f1))**(-c1 * f1)``.
    For ``f1 > 0`` the extra factor is ~1 when ``_x << d1`` and roughly
    ``(_x/d1)**(-c1)`` when ``_x >> d1``. So ``c1`` changes the log-log slope
    around ``d1``, and a smaller ``f1`` makes that transition sharper.

    Returns:
        The model value(s) at ``_x`` (scalar or NumPy array, like ``_x``).
    """
    power_law = b * _x**(-c0)
    break_factor = (1 + (_x / d1)**(1 / f1))**(-c1 * f1)
    return a + power_law * break_factor

def bnsl_with_1_break__log(_x, a, b, c0, c1, d1, f1):
    """Return ``log(1 + y)`` of the one-break BNSL, for fitting in log space.

    ``y`` is ``bnsl_with_1_break(_x, a, b, c0, c1, d1, f1)``.
    """
    return np.log(bnsl_with_1_break(_x, a, b, c0, c1, d1, f1) + 1)

def bnsl_with_1_break__msle_optim(p, _x, _y):
    """Mean squared log(1 + y) error of the one-break BNSL.

    Objective for ``scipy.optimize.brute``. ``p`` is ``(a, b, c0, c1, d1, f1)``,
    where ``b`` and ``d1`` are given in a warped form. Each is mapped through
    ``v -> 1.25**v - 1 + 1e-8`` before evaluating the model; for ``v = 0``
    the mapped value is ``1e-8`` rather than 0.
    """
    def _unwarp(v):
        return 1.25**v - 1 + 1e-8

    a, b_warped, c0, c1, d1_warped, f1 = p
    pred = bnsl_with_1_break(_x, a, _unwarp(b_warped), c0, c1, _unwarp(d1_warped), f1)
    log_gap = np.log(pred + 1) - np.log(_y + 1)
    return np.mean(log_gap**2)


# for n = 2
def bnsl_with_2_break(_x, a, b, c0, c1, d1, f1, c2, d2, f2):
    """Evaluate a BNSL with two breaks.

    Computes ``a + b * _x**(-c0)`` with the power-law term multiplied by one
    factor ``(1 + (_x/d_i)**(1/f_i))**(-c_i * f_i)`` per break
    (``i = 1, 2``, applied in that order).

    Returns:
        The model value(s) at ``_x`` (scalar or NumPy array, like ``_x``).
    """
    scaled = b * _x**(-c0)
    for c_i, d_i, f_i in ((c1, d1, f1), (c2, d2, f2)):
        scaled = scaled * (1 + (_x / d_i)**(1 / f_i))**(-c_i * f_i)
    return a + scaled

def bnsl_with_2_break__log(_x, a, b, c0, c1, d1, f1, c2, d2, f2):
    """Return ``log(1 + y)`` of the two-break BNSL, for fitting in log space.

    ``y`` is ``bnsl_with_2_break(_x, a, b, c0, c1, d1, f1, c2, d2, f2)``.
    """
    return np.log(bnsl_with_2_break(_x, a, b, c0, c1, d1, f1, c2, d2, f2) + 1)

def bnsl_with_2_break__msle_optim(p, _x, _y):
    """Mean squared log(1 + y) error of the two-break BNSL.

    ``p`` is ``(a, b, c0, c1, d1, f1, c2, d2, f2)``. ``b``, ``d1`` and ``d2``
    are given in a warped form and mapped through ``v -> 1.25**v - 1 + 1e-8``
    before the model is evaluated; for ``v = 0`` the mapped value is ``1e-8``
    rather than 0.
    """
    def _unwarp(v):
        return 1.25**v - 1 + 1e-8

    a, b_w, c0, c1, d1_w, f1, c2, d2_w, f2 = p
    pred = bnsl_with_2_break(
        _x, a, _unwarp(b_w), c0, c1, _unwarp(d1_w), f1, c2, _unwarp(d2_w), f2
    )
    log_gap = np.log(pred + 1) - np.log(_y + 1)
    return np.mean(log_gap**2)

# In [None]:
split_ratio = [0.6, 0.65, 0.7, 0.75, 0.8, 0.85]


def _smooth_errors(y, window_length=11, polyorder=3):
    """Savitzky-Golay smooth ``y``, shrinking the window for short inputs.

    ``savgol_filter`` (default ``mode='interp'``) rejects a window longer than
    the data and needs ``window_length > polyorder``. The window is reduced
    to the largest odd length that fits. If none fits, ``y`` is returned
    unsmoothed (as a float array).
    """
    y = np.asarray(y, dtype=float)
    window = min(window_length, len(y))
    if window % 2 == 0:
        window -= 1  # keep the window odd so it stays centred on each sample
    if window <= polyorder:
        return y
    return savgol_filter(y, window_length=window, polyorder=polyorder)


def _fit_bnsl_0_break(x_fit, y_fit):
    """Fit ``(a, b, c0)`` of ``bnsl_with_0_break``: grid search, then log-space least squares."""
    p_grid = (slice(0.0, 1., 0.1), slice(0, 1, 0.1), slice(0, 1, 0.1))
    p0 = scipy.optimize.brute(bnsl_with_0_break__msle_optim, p_grid, args=(x_fit, y_fit),
                              full_output=False, finish=None, Ns=1, workers=-1)
    popt, _ = scipy.optimize.curve_fit(
        bnsl_with_0_break__log, x_fit, np.log(y_fit + 1), p0=p0, maxfev=1000000000,
        bounds=([0.0, -np.inf, -np.inf], [1, np.inf, np.inf]),
    )
    return popt


def _fit_bnsl_1_break(x_fit, y_fit):
    """Fit ``(a, b, c0, c1, d1, f1)`` of ``bnsl_with_1_break``: grid search, then least squares."""
    p_grid = (slice(0.0, 1., 0.1), slice(0, 1, 0.1), slice(0, 1, 0.1),
              slice(0, 1, 0.1), slice(5, 15, 1.0), slice(0, 1, 0.1))
    res = scipy.optimize.brute(bnsl_with_1_break__msle_optim, p_grid, args=(x_fit, y_fit),
                               full_output=False, finish=None, Ns=1, workers=-1)
    a, b, c0, c1, d1, f1 = res
    # The brute-force objective scores b and d1 in a warped space
    # (v -> 1.25**v - 1 + 1e-8); map them back so p0 matches the model.
    b = 1.25**b - 1 + 1e-8
    d1 = 1.25**d1 - 1 + 1e-8
    lower = [0.0, -np.inf, -np.inf, -np.inf, 5, -np.inf]
    upper = [1, np.inf, np.inf, np.inf, 15, np.inf]
    # curve_fit rejects an initial guess outside its bounds, so clamp it.
    p0 = np.clip([a, b, c0, c1, d1, f1], lower, upper)
    popt, _ = scipy.optimize.curve_fit(
        bnsl_with_1_break__log, x_fit, np.log(y_fit + 1), p0=p0, maxfev=1000000000,
        bounds=(lower, upper),
    )
    return popt


def fit_compute_scaling(budget_arr, error_arr, n, split_ratios=None, *, return_all=False):
    """Fit a BNSL on a prefix of a compute-scaling curve and extrapolate to the rest.

    For each split ratio ``r``, the first ``int(len(budget_arr) * r)`` points
    are used for fitting, with their errors Savitzky-Golay smoothed first.
    The fitted law is then evaluated on the remaining budgets.

    Args:
        budget_arr: Compute budgets (x values), aligned with ``error_arr``.
        error_arr: Observed errors for each budget.
        n: Number of breaks; 0 (plain power law) and 1 are supported.
        split_ratios: Fractions of the curve to fit on; defaults to the
            module-level ``split_ratio`` list.
        return_all: If True, return ``{ratio: predictions}`` for every split.

    Returns:
        Predicted errors on the held-out budgets of the last split ratio,
        or a dict over all split ratios when ``return_all`` is True.

    Raises:
        NotImplementedError: for ``n == 2``.
        ValueError: for any other unsupported ``n``, mismatched input shapes,
            an empty ``split_ratios``, or a split that leaves nothing to fit.
    """
    if n == 2:
        raise NotImplementedError(
            "n=2 is not handled here; fit with bnsl_with_2_break__msle_optim directly")
    if n not in (0, 1):
        raise ValueError(f"unsupported number of breaks n={n!r}; expected 0 or 1")

    budget_arr = np.asarray(budget_arr, dtype=float)
    error_arr = np.asarray(error_arr, dtype=float)
    if budget_arr.shape != error_arr.shape:
        raise ValueError("budget_arr and error_arr must have the same shape")

    ratios = list(split_ratio if split_ratios is None else split_ratios)
    if not ratios:
        raise ValueError("split_ratios must not be empty")

    if n == 0:
        fit, model = _fit_bnsl_0_break, bnsl_with_0_break
    else:
        fit, model = _fit_bnsl_1_break, bnsl_with_1_break

    predictions = {}
    y_pred = None
    for split_point in ratios:
        # One cut index for both arrays keeps the (x, y) pairs aligned.
        cut = int(len(budget_arr) * split_point)
        if cut == 0:
            raise ValueError(f"split ratio {split_point} leaves no points to fit")
        x_fit, x_test = budget_arr[:cut], budget_arr[cut:]
        y_fit_smooth = _smooth_errors(error_arr[:cut])
        popt = fit(x_fit, y_fit_smooth)
        y_pred = model(x_test, *popt)
        predictions[split_point] = y_pred

    return predictions if return_all else y_pred

def fit_model_scaling(params_arr, error_arr, n_holdout=1):
    """Fit a zero-break BNSL (plain power law) and extrapolate to held-out points.

    All but the last ``n_holdout`` points are used for fitting. The fitted
    model is then evaluated at the held-out ``params_arr`` values.

    Args:
        params_arr: Model sizes (x values), aligned with ``error_arr``; the
            trailing entries are the ones extrapolated to.
        error_arr: Observed errors for each model size.
        n_holdout: Number of trailing points to hold out (default 1).

    Returns:
        NumPy array of predicted errors at ``params_arr[-n_holdout:]``.

    Raises:
        ValueError: if the inputs differ in shape, ``n_holdout < 1``, or no
            points remain for fitting.
    """
    params_arr = np.asarray(params_arr, dtype=float)
    error_arr = np.asarray(error_arr, dtype=float)
    if params_arr.shape != error_arr.shape:
        raise ValueError("params_arr and error_arr must have the same shape")
    if n_holdout < 1 or len(params_arr) <= n_holdout:
        raise ValueError(
            f"need 1 <= n_holdout < len(params_arr); got n_holdout={n_holdout}, "
            f"len(params_arr)={len(params_arr)}")

    x_fit, y_fit = params_arr[:-n_holdout], error_arr[:-n_holdout]
    x_test = params_arr[-n_holdout:]

    # Coarse grid search on the mean squared log(1 + error) for a starting point...
    p_grid = (slice(0.0, 1., 0.1), slice(0, 1, 0.1), slice(0, 1, 0.1))
    p0 = scipy.optimize.brute(bnsl_with_0_break__msle_optim, p_grid, args=(x_fit, y_fit),
                              full_output=False, finish=None, Ns=1, workers=-1)

    # ...then refine with least squares in the same log(1 + error) space.
    popt, _ = scipy.optimize.curve_fit(
        bnsl_with_0_break__log, x_fit, np.log(y_fit + 1), p0=p0, maxfev=1000000000,
        bounds=([0.0, -np.inf, -np.inf], [1, np.inf, np.inf]),
    )
    return bnsl_with_0_break(x_test, *popt)