In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as stats
import scipy.optimize as optimize

## Toy example: maximum-likelihood fit of an exponential distribution

In [None]:
def NLL_from_data(lambda_,data):
    """Negative log-likelihood of `data` under an exponential with rate `lambda_`.

    Each observation x contributes log(lambda_) - lambda_*x to the
    log-likelihood; the contributions are summed and the sign is flipped so
    the result can be handed to a minimiser.
    """
    per_point_loglik = np.log(lambda_) - lambda_*data
    return -np.sum(per_point_loglik)


def NLL_from_pdf(pdf):
    """Negative log-likelihood computed from already-evaluated density values.

    `pdf` holds the density at each observation; the result is minus the sum
    of their logarithms.
    """
    return -np.sum(np.log(pdf))

In [None]:
# Seed NumPy's global RNG before sampling.
np.random.seed(42)

lambda_ = 3.0
size = 100

# Draw `size` points from an exponential with rate lambda_ (SciPy is parameterised by scale = 1/rate).
data = stats.expon.rvs(size=size, scale=1/lambda_)
# Density of every sampled point under that same distribution.
pdf = stats.expon.pdf(data, scale=1/lambda_)

# Both routes to the NLL are printed side by side for comparison.
print(NLL_from_data(data=data, lambda_=lambda_), NLL_from_pdf(pdf=pdf))

In [None]:

# `args` carries the fixed extra arguments of the objective (here: the sample).
# optimize.minimize only varies the objective's first positional argument, i.e. lambda_.
p0 = [1.0]
bounds = [(1e-6, None)]  # lower bound keeps lambda_ positive so log(lambda_) stays finite
options = {'maxiter': 1000}
est_1 = optimize.minimize(
    NLL_from_data,
    x0=p0,
    args=(data,),
    method='L-BFGS-B',
    bounds=bounds,
    options=options,
).x[0]

## Truncated exponential fit

In [None]:

#taken from 'normed fit data' in other file
# One array per observed histogram; each has one entry per label in FLOP_bins.
# The plotting cell labels row i as year 2022+i, and the fitting cell uses the
# position of the first exact 0 in a row to decide where the fitted pdf is cut off.
DATA = [np.array([8.94813851e-01, 1.05365097e-01, 0, 0,
        0, 0]),
 np.array([7.14377891e-01, 2.14385034e-01, 7.15299320e-02, 0,
        0, 0]),
 np.array([6.55267360e-01, 2.06996136e-01, 1.38031332e-01, 0,
        0, 0])]

# Text labels for the histogram bins (used as x tick labels in the plots).
FLOP_bins = ['23-24', '24-25', '25-26', '26-27', '27-28', '28-29']
# Integer positions 1..len(FLOP_bins); not referenced by the other visible cells,
# which build their own bin centres with np.linspace(0.1, 1.1, len(FLOP_bins)).
mapped_FLOP_bins = np.arange(1,len(FLOP_bins)+1) #this x range is quite arbitrary
assert len(FLOP_bins)==len(mapped_FLOP_bins)

In [None]:
def discret_exp_dist(lambda_,start=0,end=1,bin_width=0.1):
    """Discretise an exponential distribution into equal-width bins.

    Arguments:
        lambda_: rate of the exponential (SciPy's scale is 1/lambda_).
        start, end: range to discretise.
        bin_width: nominal width of each bin.

    Returns:
        1-D array with one probability mass per bin. The number of bins is
        (end-start)/bin_width truncated to an integer, and the bins tile
        [start, end] exactly.

    Each mass is the exact CDF difference over the bin. The previous
    implementation summed pdf samples on an np.linspace grid (spacing
    (end-start)/(res-1)) while weighting them with dx=(end-start)/res, which
    biased every bin slightly.
    """
    # A small tolerance stops float noise (e.g. 0.3/0.1 == 2.9999999999999996)
    # from silently dropping the last bin; genuine non-multiples still truncate.
    n_bins = int(np.floor((end-start)/bin_width + 1e-9))

    bin_edges = np.linspace(start,end,n_bins+1)
    cdf_at_edges = stats.expon.cdf(bin_edges,scale=1/lambda_)

    # P(lb <= X < ub) = F(ub) - F(lb) for each consecutive pair of edges.
    pmf = np.diff(cdf_at_edges)

    return pmf


def truncated_exp_dist_pdf(lambda_,t,start=0,stop=1,size=100):
    '''
    Generate a truncated exponential distribution (continuous) on a grid.

    Arguments:
        lambda_: rate of the exponential (SciPy's scale is 1/lambda_).
        t: threshold; the density is set to 0 for grid points from the threshold on.
        start, stop: ends of the evaluation grid.
        size: number of grid points.

    Returns:
        (data, full_pdf): the np.linspace(start, stop, size) grid and the
        density at each grid point. Below the threshold the exponential pdf is
        renormalised by CDF(t) - CDF(start); elsewhere it is 0.
    '''
    data = np.linspace(start=start,stop=stop,num=size)

    # Number of leading grid points treated as lying below the threshold.
    # The fraction is measured from `start` (not from 0) so that grids which
    # do not begin at 0 are cut at the right place.
    threshold_idx = int(((t-start)/(stop-start))*size)
    # Clamp: a negative index would slice from the END of the grid, and an
    # index beyond `size` simply means "keep everything".
    threshold_idx = min(max(threshold_idx,0),size)
    data_t = data[:threshold_idx]

    full_pdf = np.zeros(len(data))
    if threshold_idx > 0:
        expon_pdf_t = stats.expon.pdf(x=data_t,scale=1/lambda_) #pdf from start -> threshold
        norm_factor = stats.expon.cdf(t,scale=1/lambda_) - stats.expon.cdf(start,scale=1/lambda_)
        full_pdf[:threshold_idx] = expon_pdf_t/norm_factor
    # With no points below the threshold the density stays all-zero and the
    # (zero) normalisation factor is never divided by.

    return data,full_pdf

def truncated_exp_dist_pmf(pdf,data,t,bin_width=0.1):
    """Integrate a gridded pdf into equal-width bins.

    Arguments:
        pdf: density values, one per entry of `data`.
        data: grid the density was evaluated on; only data[0] and data[-1]
            are read, as the start and stop of the binned range.
        t: truncation threshold; must sit on a bin edge so that no bin
            straddles the zero and non-zero parts of the density.
        bin_width: width of each bin.

    Returns:
        (bin_mps, pmf): bin midpoints and the probability mass in each bin,
        both of length n_bins.

    Raises:
        AssertionError: if `t` is not (numerically) a whole number of bins
            away from data[0].
    """
    start,stop=data[0],data[-1]

    #ensure that bins do not straddle non-zero part of pmf and zero part
    # (compared with a tolerance: e.g. 0.3/0.1 is 2.9999999999999996 in floats)
    div = (t-start)/bin_width
    assert np.isclose(div,np.round(div)), "threshold t must lie on a bin edge"

    # Same tolerance for the bin count, and the edges are built from that count
    # so bin_mps and pmf always have the same length (np.arange with a float
    # step can yield one element more than int((stop-start)/bin_width)).
    n_bins = int(np.floor((stop-start)/bin_width + 1e-9))
    bin_lbs = start + bin_width*np.arange(n_bins)
    bin_ubs = bin_lbs+bin_width
    bin_mps = (bin_lbs+bin_ubs)/2

    res = len(pdf) #number of pdf values
    dx = (stop-start)/res #nominal width attributed to each pdf sample
    pmf = np.zeros(n_bins)

    for i in range(n_bins):
        # Both slice ends use the same integer-first formula so consecutive
        # bins share a boundary index and no sample is skipped.
        stt_idx = int((i*res)/n_bins)
        end_idx = int(((i+1)*res)/n_bins)
        pmf[i] = np.sum(pdf[stt_idx:end_idx]*dx)

    return bin_mps,pmf

def compute_predicted_pmf(lambda_,bin_centers,threshold_bin):

    '''
    Determines a truncated exponential pmf given a rate parameter and bins to integrate over.
    Arguments:
        lambda_: rate of the exponential (the pdf is evaluated with scale=1/lambda_)
        bin_centers: equally spaced bin centres; each bin spans centre -/+ half the spacing
        threshold_bin: centre of the first bin that should carry no mass; the pdf
            is cut to 0 at that bin's lower edge

    Returns:
        array with one probability mass per entry of bin_centers

    Raises:
        AssertionError: if the gridded pdf does not integrate to ~1 (to one decimal)
    '''

    bin_width = (np.diff(bin_centers))[0] #should be const.

    #determine useful bin properties
    bin_lbs = bin_centers-(bin_width/2)
    bin_ubs = bin_centers+(bin_width/2)

    #determine where pdf should go to 0
    threshold = threshold_bin-(bin_width/2) #states where pdf should go to 0

    #producing pdf over most of input space
    start=0
    stop=2
    size=1000; assert np.log10(size)%1==0 #in this implementation we require size to be power of 10
    data,pdf = truncated_exp_dist_pdf(lambda_=lambda_,t=threshold,start=start,stop=stop,size=size)

    # Rounding the mean grid spacing to log10(size) decimals turns
    # (stop-start)/(size-1) into the nominal (stop-start)/size step.
    dx=np.round(np.mean(np.diff(data)),int(np.log10(len(data)))) #assuming data is of form 10^n

    #check that prob mass sums to 1
    total_mass = np.sum(pdf*dx)
    assert np.round(total_mass,1)==1, f"truncated pdf integrates to {total_mass}, expected ~1"

    pred_pmf = np.zeros(len(bin_centers))
    for idx,(lb,ub) in enumerate(zip(bin_lbs,bin_ubs)):
        # nearest grid points to the bin edges delimit the slice to integrate
        idx_a=np.argmin(np.abs(data-lb))
        idx_b=np.argmin(np.abs(data-ub))
        pred_pmf[idx]=np.sum(pdf[idx_a:idx_b]*dx)

    return pred_pmf

def MSE_LOSS(lambda_,bin_centers,threshold_bin,true_pmf):
    """Mean squared error between the predicted pmf for `lambda_` and `true_pmf`.

    The prediction comes from compute_predicted_pmf with the same bin layout
    and threshold bin; this is the objective handed to the optimiser.
    """
    residuals = compute_predicted_pmf(lambda_,bin_centers,threshold_bin) - true_pmf
    return np.mean(residuals**2)

In [None]:
#MSE FIT for optimal lambdas
import scipy.optimize as optimize

# One fitted rate per histogram in DATA, stored as the optimiser's 1-element array.
OPT_LAMBDA = []

for pmf in DATA:
    true_pmf = pmf

    # Bin layout on the mapped x axis (one centre per FLOP bin).
    bin_centers = np.linspace(0.1, 1.1, len(FLOP_bins)) #mapped FLOP_bins
    bin_width = np.diff(bin_centers)[0]

    # Centre of the first bin whose observed mass is exactly 0.
    threshold_bin = bin_centers[np.flatnonzero(true_pmf == 0)[0]]

    initial_lambda = [2.0]
    args = (bin_centers, threshold_bin, true_pmf) #args to minimize that are fixed

    result = optimize.minimize(MSE_LOSS, x0=initial_lambda, args=args)

    opt_lambda = result.x
    OPT_LAMBDA.append(opt_lambda)


In [None]:
# One row per histogram in DATA: observed pmf on the left, best-fit
# truncated-exponential pmf on the right.
n_rows = len(DATA)
# squeeze=False keeps axs two-dimensional even when DATA has a single row.
fig,axs=plt.subplots(nrows=n_rows,ncols=2,figsize=(8,8),squeeze=False)

bin_centers = np.linspace(0.1,1.1,len(FLOP_bins)) #mapped FLOP_bins
width = np.diff(bin_centers)[0]

#plot params
alpha=0.6
y_lim = [0,1]

for row_idx,ax_row in enumerate(axs):
    year = 2022+row_idx

    true_pmf = DATA[row_idx]
    opt_lambda = OPT_LAMBDA[row_idx]
    # centre of the first bin with zero observed mass, as used during fitting
    threshold_bin = bin_centers[(np.where(true_pmf==0)[0][0])]
    best_fit_pred_pmf = compute_predicted_pmf(opt_lambda,bin_centers,threshold_bin)

    for col_idx,ax in enumerate(ax_row):
        if col_idx==0: ax.bar(bin_centers,true_pmf,edgecolor='black',width=width,label=f'true pmf {year}',alpha=alpha,color='tab:blue')
        # raw f-string: '\l' is not a valid escape sequence in a normal string literal
        if col_idx==1: ax.bar(bin_centers,best_fit_pred_pmf,edgecolor='black',width=width,label=rf'pred pmf {year}, $\lambda={np.round(opt_lambda[0],1)}$',alpha=alpha,color='red')

        ax.set_ylim(y_lim)
        ax.legend()
        ax.grid(alpha=0.6)

        # only the bottom row carries the bin labels
        if row_idx==n_rows-1:
            ax.set_xticks(bin_centers)
            ax.set_xticklabels(FLOP_bins,rotation=45)
            ax.set_xlabel('FLOP bins',fontsize=12)
        else:
            ax.set_xticklabels([])



fig.tight_layout()

In [None]:
# PLOT_LAMBDA gates the optional lambda-vs-year figure further down in this cell.
PLOT_LAMBDA=False
# NOTE(review): PLOT_PREDICTED_DIST is not read by any of the visible code.
PLOT_PREDICTED_DIST=False


# Same bin labels as defined earlier in the notebook (re-declared here).
FLOP_bins = ['23-24', '24-25', '25-26', '26-27', '27-28', '28-29']
start_year = 2025
end_year = 2029
# Years to predict, end_year inclusive.
future_years = np.arange(start_year,end_year+1)
# Years of the fitted histograms.
years = np.array([2022,2023,2024])

# Shift both year axes so 2022 maps to 0 before fitting the trend in lambda.
mapped_years = years - 2022 #0 indexing
mapped_future_years = future_years - 2022
# Each OPT_LAMBDA entry is the optimiser's result.x array; [0] extracts the fitted value.
OPT_LAMBDA_arr = np.array([elem[0] for elem in OPT_LAMBDA])


def exp_func(x,a,b):
    """Exponential decay a*exp(-b*x), the trend model fitted to the lambdas."""
    decay = np.exp(-b*x)
    return a*decay

# Fit the decay trend to the fitted lambdas and extrapolate it to the future years.
popt,pcov = optimize.curve_fit(exp_func,mapped_years,OPT_LAMBDA_arr,p0=(1,1))
PRED_LAMBDAS = exp_func(mapped_future_years,*popt)


#opt lambda extrapolation
if PLOT_LAMBDA:
    fig,ax=plt.subplots()

    all_years = np.concatenate([years,future_years])
    lambdas = np.concatenate([OPT_LAMBDA_arr,PRED_LAMBDAS])
    
    ax.plot(all_years,lambdas)

    #ax.set_xlim([0,10])
    ax.set_ylim([0,11])

#threshold bin extrapolation
threshold_bins = ['25-26','26-27','26-27']
future_threshold_bins = ['26-27','27-28','27-28','28-29','28-29'] #manually extrapolating right now - that's enough. Assuming that frontier training compute growing by an OOM every two years
# The list above is hand-written, so guard against start_year/end_year being changed without it.
assert len(future_threshold_bins)==len(future_years), "need one threshold bin per future year"

#model number extrapolation
future_model_numbers = [75,110,180,280,400] #just vibing it right now - more interested in the distributions


fig,axs = plt.subplots(nrows=len(future_years),figsize=(6,10))
axs = np.atleast_1d(axs) #plt.subplots returns a bare Axes when there is a single row
bin_centers = np.linspace(0.1,1.1,len(FLOP_bins))
width = np.diff(bin_centers)[0]

for idx,year in enumerate(future_years):
    # use the lambda extrapolated for THIS year (previously every year used PRED_LAMBDAS[0])
    pred_lambda = PRED_LAMBDAS[idx]
    threshold_bin = future_threshold_bins[idx]
    ax = axs[idx]

    # position of the threshold label within FLOP_bins -> its mapped bin centre
    threshold_bin_idx = FLOP_bins.index(threshold_bin)
    mapped_threshold_bin = bin_centers[threshold_bin_idx]
    predicted_pmf = compute_predicted_pmf(pred_lambda,bin_centers,mapped_threshold_bin)

    ax.bar(bin_centers,predicted_pmf,width=width,edgecolor='black')
    ax.set_title(f'{year} predicted pmf')
    ax.set_ylim([0,1])
    ax.grid(alpha=0.5)
    ax.set_xticks(bin_centers)
    ax.set_xticklabels(FLOP_bins)

fig.tight_layout()

# Empty frame (bins x future years) created here; nothing in the visible code fills it.
PREDICTED_HISTOGRAM_DATA_DF = pd.DataFrame(index=FLOP_bins,columns=future_years)