## Setup: load and preprocess the dataset

In [8]:
import numpy as np
from scipy import stats, optimize
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import itertools
import copy,re, pdb
%matplotlib inline

In [9]:
# df = pd.read_csv("https://epochai.org/data/epochdb/notable_systems.csv")
# Convert the Google Drive "view" link into a direct-download link.
url = 'https://drive.google.com/file/d/1RLLKPU3bEYK65wlQlU0p20u9M8cHkLMl/view?usp=sharing'
url = 'https://drive.google.com/uc?id=' + url.split('/')[-2]

df = pd.read_csv(url)

# Keep only systems with at least one notability criterion.
# .copy() so the column assignments below act on an independent frame instead of
# a filtered slice (avoids SettingWithCopyWarning / chained-assignment ambiguity).
df = df[~df["Notability criteria"].isna()].copy()

df["compute"] = df["Training compute (FLOP)"]
df["date"] = df["Publication date"]
df["model"] = df["System"]
df["poss1e23"] = df["Possibly over 1e23 FLOP"]
df["poss1e25"] = df["Estimated over 1e25 FLOP"]
df["cost"] = df["Training compute cost (2023 USD)"]
# regex=False: "$" is a regex end-of-string anchor, so with regex=True (the pandas
# default before 2.0) the dollar sign would never be stripped and astype(float) fails.
df["cost"] = (
    df["cost"]
    .str.replace(",", "", regex=False)
    .str.replace("$", "", regex=False)
    .astype(float)
)

df = df[["model", "compute", "date", "cost", "poss1e23", "poss1e25"]]

In [10]:
# Systems excluded from the analysis.
# NOTE(review): the reason for excluding these is not recorded here - document it.
to_remove = ['AlphaGo Zero','AlphaZero']
df = df.loc[~df["model"].isin(to_remove)]

In [11]:
# Manually add recent frontier models that are missing from the dataset.
# Row layout matches df's columns: model, compute (FLOP), date, cost, poss1e23, poss1e25.
to_append = [
  ["Claude 3.5 Sonnet", 4.3e25, "2024-06-21", np.nan, np.nan, np.nan],
  ["GPT-4o Mini", 1.2e25, "2024-07-18", np.nan, np.nan, np.nan],
]

for row in to_append:
  if row[0] not in df["model"].values:
    # Append under a fresh index label. `len(df)` is not safe here: the earlier
    # filtering leaves gaps in the index, so `len(df)` can equal an existing label
    # and `df.loc[len(df)] = row` would silently overwrite that row instead.
    next_label = df.index.max() + 1 if len(df) else 0
    df.loc[next_label] = row

In [None]:
# Hand-entered training-compute estimates (FLOP) for models whose compute is
# missing from the dataset. Values are only filled where compute is currently NaN.
claude_2_compute = df.loc[df["model"] == "Claude 2", "compute"]

to_add_compute = {
    "Claude 3 Opus": 2.5e25,
    "Claude 3 Sonnet": 1.1e25,
    "GPT-4o": 2.9e25,
    "Gemini 1.0 Pro": 2.8e24,
    "Gemini 1.5 Pro": 1.9e25,
    "Reka Core": 8.4e24,
    "GPT-4 Turbo": 2.1e25,  # rough guess
    "GPT-4V": 2.1e25,  # rough guess
    # rough guess: reuse Claude 2's estimate as a scalar (NaN if Claude 2 is absent)
    "Claude 2.1": claude_2_compute.iloc[0] if len(claude_2_compute) else np.nan,
}

for k, v in to_add_compute.items():
  mask = df["model"] == k
  if not mask.any():
    # An absent model used to fall through to the "already has" message below.
    print(f"{k} not found in df; skipping")
  elif df.loc[mask, "compute"].isna().all():
    df.loc[mask, "compute"] = v
  else:
    print(f"{k} already has a compute value")

In [13]:
# Reset the ones we've set: rows that now have a compute value no longer need
# the "possibly over 1e23 / 1e25" flags.
df.loc[~df["compute"].isna(), "poss1e23"] = np.nan
df.loc[~df["compute"].isna(), "poss1e25"] = np.nan

# Set some temporary placeholder values
# TODO: revisit
# df.loc[(df["poss1e25"] == "checked"), "compute"] = 1.01e25  # placeholder
# df.loc[((df["poss1e23"] =="checked") & (df["poss1e25"] != "checked")), "compute"] = 1.01e23  # placeholder

# We want to handle these leading models manually via the above compute estimates.
assert df[(df["poss1e25"] == "checked") & (df["compute"].isna())].size == 0

# We sample 1e23-1e25 models with unknown compute from the existing empirical distribution.
# TODO: revisit
poss1e23 = ((df["poss1e23"] == "checked") & (df["poss1e25"] != "checked"))
known_1e23_1e25 = df.loc[(df["compute"] >= 1e23) & (df["compute"] < 1e25), "compute"]
n_needed = int(poss1e23.sum())
# Sample with replacement only when there are fewer known values than rows to fill;
# sampling without replacement would raise a ValueError in that case.
df.loc[poss1e23, "compute"] = known_1e23_1e25.sample(
    n_needed, replace=n_needed > len(known_1e23_1e25), random_state=0
).values

df["date"] = pd.to_datetime(df["date"])
df["log_compute"] = np.log10(df["compute"])

# Decimal year: January -> year + 0/12, December -> year + 11/12.
# (month/12 mapped December onto the following year's .0.)
df["date_float"] = df["date"].dt.year + (df["date"].dt.month - 1)/12

df['year'] = df['date'].dt.year

df = df.sort_values("date")
df = df.dropna(subset=["compute"])

In [None]:
# Training compute of notable systems published after 2010, on a log scale.
# sns.scatterplot returns an Axes (not a Figure), so name it accordingly.
ax = sns.scatterplot(data=df[df['date']>'2010-01-01'], x='date',y='compute')
ax.set(yscale='log', xlabel='Publication date', ylabel='Training compute (FLOP)')
ax.grid(alpha=0.5)

## Analysis class

In [None]:
## utils
def exponential_fit(x,a,b):
    '''Exponential model a * exp(b * x); used as the curve_fit target for model counts.'''
    growth = np.exp(b * x)
    return a * growth

def x_transform_for_exp_fit(ref_x,x,inverse=False):
    '''
    Rescale timestamps (seconds) to small values for a numerically stable exponential fit.

    Forward:  (x - min(ref_x)) / 1e7   (7 years of timestamps maps to roughly [0, 22]).
    Inverse:  1e7 * x + min(ref_x).

    Params:
        ref_x: reference array whose minimum is used as the offset.
        x: values to transform.
        inverse: if True, undo the forward transform.
    '''
    offset = ref_x.min()
    scale = 1e7

    if inverse:
        return scale * x + offset
    return (x - offset) / scale


def sample_from_gmm(n_samples,params):
    '''
    Draw samples from a two-component 1-D Gaussian mixture using the global numpy RNG.

    Params:
        n_samples: number of samples to draw.
        params: ((mu_l, mu_h), (var_l, var_h), (w_l, w_h)) - means, variances and
            mixture weights of the lower/higher components. Weights may be Python
            floats, numpy scalars or size-1 arrays and must sum to ~1.

    Returns:
        np.ndarray of shape (1, n_samples) (2-D, as existing callers expect).
    '''
    (mu_l, mu_h), (var_l, var_h), (w_l, w_h) = params
    std_l, std_h = np.sqrt(var_l), np.sqrt(var_h)

    # np.asarray(...).item() accepts plain floats as well as numpy scalars and
    # size-1 arrays; calling `w.item()` directly crashed on Python floats.
    p = [np.asarray(w_l).item(), np.asarray(w_h).item()]
    components = np.random.choice([0,1], size=n_samples, p=p)

    # Draw a candidate from each component for every sample, then keep the one
    # selected by `components`.
    samples = np.where(
        components == 0,
        np.random.normal(loc=mu_l,scale=std_l,size=(1,n_samples)),
        np.random.normal(loc=mu_h,scale=std_h,size=(1,n_samples))
    )

    return samples

In [None]:
random_seed=42

from sklearn.mixture import GaussianMixture; random_seed=42
from scipy.stats import norm
from sklearn.linear_model import LinearRegression
from scipy.optimize import curve_fit
import pwlf

from dataclasses import dataclass

@dataclass
class AnalysisConfig:
    '''
    Date bounds ('YYYY-MM-DD' strings) for the fit and prediction periods.
    NOTE: DataAnalysis currently hard-codes the same dates rather than reading this.
    '''

    # historical period used for fitting
    fit_start_date: str = '2017-01-01'
    fit_stop_date: str = '2024-01-01'

    # future period to extrapolate over
    predict_start_date: str = '2024-01-01'
    predict_stop_date: str = '2030-01-01'

class DataAnalysis():

    def __init__(self,df,window_freq='year'):

        self.df = df
        self.df['date'] = pd.to_datetime(self.df['date'])
        self.working_df = None

        self.start_time = '2017-01-01'
        self.stop_time = '2024-01-01'
        self.predict_start_time = '2024-01-01'
        self.predict_stop_time = '2030-01-01'
        self.window_freq = window_freq
        self.window_size = 'year'

        if self.window_freq=='quarter':
            times = pd.date_range(self.start_time,self.stop_time,freq='QS')
            predict_times = pd.date_range(self.predict_start_time,self.predict_stop_time,freq='QS') 
        elif self.window_freq=='biannual':
            times = pd.date_range(start=self.start_time,end=self.stop_time,freq='6MS')[1:-1] #indexing filters out startyear-01-01, endyear-01-01
            predict_times = pd.date_range(self.predict_start_time,self.predict_stop_time,freq='6M')[1:-1] 
        elif self.window_freq=='year':
            times = pd.date_range(start=self.start_time,end=self.stop_time,freq='AS-JUL')
            predict_times = pd.date_range(self.predict_start_time,self.predict_stop_time,freq='AS-JUL')
        else:
            raise ValueError('')
        
        if self.window_size=='year':
            times_lb = times - pd.DateOffset(months=6)
            times_ub = times + pd.DateOffset(months=6)

        #can use these to quickly filter df and get 
        self.window_times = times
        self.window_times_lb = times_lb
        self.window_times_ub = times_ub

        self.predict_times = predict_times

    def time_truncate_df(self,start='2017-01-01',end='2024-01-01'):

        self.working_df = self.df[(self.df['date']>'2017-01-01') & (self.df['date']<'2024-01-01')]

    def window_filter(self,window_time):

        '''
            Filter df based on window time and window size
            NOTE: We're hard coding in year long dataframe window size now

        
        '''

        filtering_condition = (self.working_df['date'] >= (window_time-pd.DateOffset(months=6))) & (self.working_df['date'] < (window_time+pd.DateOffset(months=6)))

        return filtering_condition 

    def fit_distributions(self,fit_type,plot=False):
        '''
        May want to look at doing a fit to the rolling windows
        
        '''

        FIT_TYPES = ['gaussian','gaussian mixture']
        if fit_type not in FIT_TYPES:
            raise ValueError(f'Invalid fit_type. Types: {FIT_TYPES}') 
        self.fit_type=fit_type

        params = {t:None for t in self.window_times}


        if fit_type=='gaussian':
            
            for t,t_lb,t_ub in list(zip(self.window_times,self.window_times_lb,self.window_times_ub)):
                date_filt_condition = (self.working_df['date']>=t_lb) & (self.working_df['date'] < t_ub)
                date_filt_df = self.working_df[date_filt_condition]
                log_compute_data = date_filt_df['log_compute']

                mean = log_compute_data.mean()
                std = log_compute_data.mean()
                params[t] = {'mean':mean,'std':std}

                if plot: 
                    fig,ax=plt.subplots()
                    plus_minus = "\u00B1"
                    sns.kdeplot(log_compute_data,label=f'timestamp: {t.date()} {plus_minus} 6mo ',linewidth=2,ax=ax)
  
                    mean = log_compute_data.mean()
                    std = np.sqrt(log_compute_data.var()) #simple for now
                    x=np.linspace(10,30,1000)
                    ax.plot(x,norm.pdf(x,loc=mean,scale=std))
                    ax.grid(); ax.legend(loc='upper left')

        
        if fit_type == 'gaussian mixture':

            for t in self.window_times:
                date_filt_condition = self.window_filter(window_time=t)
                date_filt_df = self.working_df[date_filt_condition]
                log_compute_data = date_filt_df['log_compute']

                gmm = GaussianMixture(n_components=2,random_state=2)
                gmm.fit(log_compute_data.to_numpy().reshape(-1,1))
                means,covariances,weights = gmm.means_,gmm.covariances_,gmm.weights_
                params[t] = {'means':means,
                            'covars':covariances,
                            'weights':weights}

        
        self.fitted_params = params

        return params
    
    def extrapolate_distributions(self):

        '''
        NOTE: Gaussian mixture covar and weight is very hacky right now
        '''
        
        if self.fit_type=='gaussian':
            
            #linear extrap means
            fit_dates = [t for t in self.fitted_params.keys()]
            fit_dates_float = np.array([t.timestamp() for t in fit_dates])
            means = np.array([self.fitted_params[t]['mean'] for t in self.fitted_params.keys()])
            predicted_dates_float = np.array([t.timestamp() for t in self.predict_times])


            model=LinearRegression()
            model.fit(fit_dates_float.reshape(-1,1),means)
            predicted_means = model.predict(predicted_dates_float.reshape(-1,1))
            retr_means = model.predict(fit_dates_float.reshape(-1,1))

            #sample std
            std_bounds = (1.1,1.6)
            predicted_stds = np.random.uniform(low=std_bounds[0],high=std_bounds[1],size=(predicted_means.shape))
            retr_stds = np.empty_like(retr_means); retr_stds.fill(np.nan) #not retrodicting params for now

            predicted_params = {t:(mu,std) for t,mu,std in list(zip(self.predict_times,predicted_means,predicted_stds))}
            retr_params = {t:(mu,std) for t,mu,std in list(zip(self.window_times,retr_means,retr_stds))}

            self.distribution_parameter_model = model #not good atm code atm 


        elif self.fit_type=='gaussian mixture':

            ##hacky parameterisation
            lower_var_bounds = (0.5,1.4) #heuristic set (0.5,1.4)
            upper_var_bounds = (0.5,1.4) #exluding ~3.5 var for upper dist in 2023. Heuristic set (0.5,1.2)
            lower_weights_bound = (0.15,0.35) #heuristically set (0.15,0.35)

            ##get data to fit
            lower_idx = [self.fitted_params[t]['means'].argmin() for t in self.fitted_params.keys()]
            higher_idx = [self.fitted_params[t]['means'].argmax() for t in self.fitted_params.keys()]

            lower_dist_means = np.array([self.fitted_params[t]['means'][idx] for t,idx in list(zip(self.fitted_params.keys(),lower_idx))])
            higher_dist_means = np.array([self.fitted_params[t]['means'][idx] for t,idx in list(zip(self.fitted_params.keys(),higher_idx))])

            lower_dist_covars = np.array([self.fitted_params[t]['covars'][idx] for t,idx in list(zip(self.fitted_params.keys(),lower_idx))])
            higher_dist_covars = np.array([self.fitted_params[t]['covars'][idx] for t,idx in list(zip(self.fitted_params.keys(),higher_idx))])
            
            lower_dist_weights = np.array([self.fitted_params[t]['weights'][idx] for t,idx in list(zip(self.fitted_params.keys(),lower_idx))])
            higher_dist_weights = np.array([self.fitted_params[t]['weights'][idx] for t,idx in list(zip(self.fitted_params.keys(),higher_idx))])


            fit_times_float = np.array([t.timestamp() for t in self.window_times])
            predict_times_float = np.array([t.timestamp() for t in self.predict_times])



            ##predict/retrodict means
            lower_dist_mean_model = pwlf.PiecewiseLinFit(fit_times_float,lower_dist_means)
            lower_dist_mean_model.fit(n_segments=2)
            pred_lower_means = lower_dist_mean_model.predict(predict_times_float)
            retr_lower_means = lower_dist_mean_model.predict(fit_times_float)

            higher_dist_mean_model = pwlf.PiecewiseLinFit(fit_times_float,higher_dist_means)
            higher_dist_mean_model.fit(n_segments=2)
            pred_higher_means = higher_dist_mean_model.predict(predict_times_float)
            retr_higher_means = higher_dist_mean_model.predict(fit_times_float)

            ##extrapolate vars
            pred_lower_vars = np.random.uniform(low=lower_var_bounds[0],high=lower_var_bounds[1],size=(pred_lower_means.shape))
            pred_higher_vars = np.random.uniform(low=upper_var_bounds[0],high=upper_var_bounds[1],size=(pred_higher_means.shape))
            retr_lower_vars = np.empty_like(retr_lower_means); retr_lower_vars.fill(np.nan)
            retr_higher_vars = np.empty_like(retr_higher_means); retr_higher_vars.fill(np.nan)

            ##extraplate weights 
            pred_lower_weights = np.random.uniform(low=lower_weights_bound[0],high=lower_weights_bound[1],size=(pred_lower_means.shape))
            retr_lower_weights = np.empty_like(retr_lower_means); retr_lower_weights.fill(np.nan)

            ##set predicted params state var
            predicted_params = {t:((mu1,mu2),(var1,var2),(w1,w2)) for
                                     t,mu1,mu2,var1,var2,w1,w2 in
                                     list(zip(self.predict_times,
                                              pred_lower_means,pred_higher_means,
                                              pred_lower_vars,pred_higher_vars,
                                              pred_lower_weights,1-pred_lower_weights))
                                    }
            
            retr_params = {t:((mu1,mu2),(var1,var2),(w1,w2)) for
                                        t,mu1,mu2,var1,var2,w1,w2 in
                                        list(zip(self.window_times,
                                                retr_lower_means,retr_higher_means,
                                                retr_lower_vars,retr_higher_vars,
                                                retr_lower_weights,1-retr_lower_weights))
                                        }
            

        else:
            pass

        self.predicted_params = predicted_params
        self.retrodicted_params = retr_params

        return predicted_params
    
    def model_counts(self,counts_fit_type):


        ## COULD add a kinked exponential
        COUNT_FIT_TYPES=['linear','exponential','kinked linear']

        if counts_fit_type not in COUNT_FIT_TYPES: raise ValueError(f'Expected fit in {COUNT_FIT_TYPES}')
    
        
        #time_data is bad var name but leftover from old code
        time_data = {t:{'size':None,
                   } 
                   for t in self.window_times}
        
        for t,t_lb,t_ub in list(zip(self.window_times,self.window_times_lb,self.window_times_ub)):

            if t_lb < pd.Timestamp(self.start_time) or t_ub > pd.Timestamp(self.stop_time): 
                print(f'Skipping {t} - window not in range')
                time_data[t]['size']=None

                continue
            else:
                date_filt_condition = (self.working_df['date']>=t_lb) & (self.working_df['date'] < t_ub)
                date_tmp_df = self.working_df[date_filt_condition] #filtered df
                time_data[t]['size']=date_tmp_df.shape[0]

        #perform fitting
    
        fit_counts = np.array([t['size'] for t in time_data.values()])
        fit_times_float = np.array([t.timestamp() for t in self.window_times])
        predict_times_float = np.array([t.timestamp() for t in self.predict_times])
 
        if counts_fit_type=='linear':
            model = LinearRegression()
            model.fit(fit_times_float.reshape(-1,1),fit_counts)
            predicted_counts = model.predict(predict_times_float.reshape(-1,1))
            retr_counts = model.predict(fit_times_float.reshape(-1,1)).astype('int')

        elif counts_fit_type=='exponential':
            transformed_fit_x = x_transform_for_exp_fit(ref_x=fit_times_float,x=fit_times_float)
            popt,pcov = curve_fit(exponential_fit,transformed_fit_x,fit_counts)    
            a,b = popt; model = popt
            transformed_pred_x = x_transform_for_exp_fit(ref_x=fit_times_float,x=predict_times_float)
            predicted_counts = exponential_fit(transformed_pred_x,a=a,b=b)
            retr_counts = exponential_fit(transformed_fit_x,a=a,b=b)

        elif counts_fit_type=='kinked linear':
            model = pwlf.PiecewiseLinFit(fit_times_float,fit_counts)
            breakpoints = model.fit(2)
            predicted_counts = model.predict(predict_times_float)
            retr_counts = model.predict(fit_times_float)

        else:
            pass

        
        #set state vars
        self.fit_counts = fit_counts
        self.predicted_counts = predicted_counts.astype('int')
        self.count_fit_type = counts_fit_type
        self.retr_counts = retr_counts
        self.counts_model = model

        return predicted_counts
    
    def count_threshold_models(self):

        thresholds = np.arange(23,30+1)
        threshold_counts_df = pd.DataFrame(columns=thresholds,index=self.predict_times)

        #not doing rollouts yet
        for pred_t,params,counts in list(zip(self.predict_times,self.predicted_params.values(),self.predicted_counts)):
            
            if self.fit_type == 'gaussian':
                mu,sigma = params
                log_compute_samples = norm.rvs(loc=mu,scale=sigma,size=counts)
                for thr in thresholds:
                    n_exceed = log_compute_samples[log_compute_samples>=thr].size
                    threshold_counts_df.at[pred_t,thr] = n_exceed

            if self.fit_type== 'gaussian mixture':
                mus,vars,ws = params
                mu_l,mu_h = mus
                var_l,var_h = vars
                w_l,w_h = ws 
                std_l,std_h = np.sqrt(var_l),np.sqrt(var_h)

                log_compute_samples = sample_from_gmm(n_samples=counts,params=params)
                for thr in thresholds:
                    n_exceed = log_compute_samples[log_compute_samples>=thr].size
                    threshold_counts_df.at[pred_t,thr] = n_exceed
        

        self.threshold_counts = threshold_counts_df
        return threshold_counts_df
    
    def verify_with_retrodiction(self):

        '''
        Params:
            n_years_retr: Retrodict n years back

        Return:

        Notes:
            - Don't think this is adapted for rolling windows yet (?)
            - A lot of this isn't 'true' retrodiction - we use some fitted/observed values
        '''

        retrodict_times = self.window_times #retrodict for fitted time stamps
        retrodict_times_float = np.array([t.timestamp() for t in retrodict_times])
        retr_counts = self.retr_counts


        thresholds = [23,24]
        predicted_past_counts = pd.DataFrame(index=retrodict_times,columns=thresholds)
        observed_past_counts = pd.DataFrame(index=retrodict_times,columns=thresholds)
        percent_error_df = pd.DataFrame(np.nan,index=retrodict_times,columns=thresholds)



        #pretty inefficient way to do it right now
        for idx,t in enumerate(retrodict_times):

            ##generate distributions and counts
            count = int(retr_counts[idx])

            if self.fit_type=='gaussian':
                mean = self.retrodicted_params[t][0]
                std = self.retrodicted_params[t][1]

                if np.isnan(std): std = self.working_df[self.window_filter(t)]['log_compute'].std() #USE EMPIRICAL STDS

                pred_log_compute_data = norm.rvs(loc=mean,scale=std,size=count)

            elif self.fit_type=='gaussian mixture':
            
                #unpack retrodicted params
                mus,vars,ws = self.retrodicted_params[t]
                mu_l,mu_h = mus
                var_l,var_h = vars
                w_l,w_h = ws
 
                if np.isnan(var_l): var_l = self.fitted_params[t]['covars'][0] #use fitted vars
                if np.isnan(var_h): var_h = self.fitted_params[t]['covars'][1] #use fitted vars
                if np.isnan(w_l): w_l = self.fitted_params[t]['weights'][0] #use fitted gmm weights
                if np.isnan(w_h): w_h = 1-w_l #use fitted gmm weights

                w_l,w_h = np.array([w_l]), np.array([w_h]) #for type compatibility with other parts of workflow

                params_retr = ((mu_l,mu_h),(var_l,var_h),(w_l,w_h))
                pred_log_compute_data = sample_from_gmm(n_samples=count,params=params_retr)

            ##get obs log compute data
            obs_log_compute_data = self.working_df[self.window_filter(t)]['log_compute']

            ##do threshold counts
            for thr in thresholds:
                #pred
                thr_count_pr = pred_log_compute_data[pred_log_compute_data>=thr].size
                predicted_past_counts.at[t,thr] = thr_count_pr

                #obs
                thr_count_ob = obs_log_compute_data[obs_log_compute_data>=thr].size
                observed_past_counts.at[t,thr] = thr_count_ob

        abs_diff = np.abs(observed_past_counts-predicted_past_counts)
        obs_df_safe = observed_past_counts.replace(0,np.nan) #for safe division
        percent_error_df = (abs_diff/obs_df_safe)*100
             
             
        self.predicted_past_counts = predicted_past_counts
        self.observed_past_counts = observed_past_counts

        return predicted_past_counts,observed_past_counts,percent_error_df

    def verify_with_training_compute(self):
         
        return None
    

In [None]:
# Restrict to the 2017-2023 fit period. .copy() so the 'date' conversion below writes
# to an independent frame rather than a slice of df (avoids SettingWithCopyWarning).
tmp_df = df[(df['date']>'2017-01-01') & (df['date']<'2024-01-01')].copy()
tmp_df['date'] = pd.to_datetime(tmp_df['date'])

window_freq = 'year'
dist_fit='gaussian mixture'
count_fit='linear'

analysis = DataAnalysis(df=tmp_df,window_freq=window_freq)
analysis.time_truncate_df()
params = analysis.fit_distributions(fit_type=dist_fit,plot=False)
predicted_params = analysis.extrapolate_distributions()
predicted_counts = analysis.model_counts(counts_fit_type=count_fit)
threshold_counts = analysis.count_threshold_models()
pred_past_counts,obs_past_counts,percent_error_df = analysis.verify_with_retrodiction()

In [None]:
import itertools, warnings

# Run the full pipeline for every (distribution fit, count fit) combination.
distribution_fits = ['gaussian','gaussian mixture']
model_count_fits = ['linear','exponential','kinked linear']
window_freq = 'year'

# Keep every combination's outputs. Previously each iteration overwrote the
# previous one, so only the last combination's results survived the loop.
sweep_results = {}

for dist_fit, count_fit in itertools.product(distribution_fits, model_count_fits):
    with warnings.catch_warnings():
        # NOTE(review): silences *all* warnings raised during the sweep.
        warnings.simplefilter("ignore")
        analysis = DataAnalysis(df=tmp_df,window_freq=window_freq)
        analysis.time_truncate_df()
        analysis.fit_distributions(fit_type=dist_fit,plot=False)
        analysis.extrapolate_distributions()
        analysis.model_counts(counts_fit_type=count_fit)
        threshold_counts = analysis.count_threshold_models()
        pred_past_counts,obs_past_counts,percent_error_df = analysis.verify_with_retrodiction()

    sweep_results[(dist_fit, count_fit)] = {
        'analysis': analysis,
        'threshold_counts': threshold_counts,
        'pred_past_counts': pred_past_counts,
        'obs_past_counts': obs_past_counts,
        'percent_error': percent_error_df,
    }

## Explore Gaussian mixture model (GMM) fits per year

In [None]:
# 2017-2023 subset; .copy() so the assignment below does not write to a slice of df.
tmp_df = df[(df['date']>'2017-01-01') & (df['date']<'2024-01-01')].copy()
tmp_df['date'] = pd.to_datetime(tmp_df['date'])

In [None]:
from sklearn.mixture import GaussianMixture; random_seed=42

years = tmp_df['year'].unique()

years_choice = [2018,2019,2020,2021,2022,2023]
nrows,ncols = (3,2)

INDIVIDUAL_PLOTS = True  # True: one panel per year; False: all yearly KDEs on one axis
GMM_FIT = True           # overlay a 2-component GMM on each yearly KDE (individual plots only)
GMM_params = {year:{} for year in years_choice}

if INDIVIDUAL_PLOTS:
    # Size the figure up front so tight_layout() below lays out the final size
    # (resizing after tight_layout undid the layout).
    fig,axs = plt.subplots(ncols=ncols,nrows=nrows,figsize=(8,8))
    axs = axs.ravel()
    for ax,year in zip(axs,years_choice):
        log_compute_data = tmp_df[tmp_df['year']==year]['log_compute']
        sns.kdeplot(log_compute_data,ax=ax,alpha=0.5,label=f'{year} KDE')

        if GMM_FIT:
            gmm = GaussianMixture(n_components=2,random_state=random_seed)
            gmm.fit(log_compute_data.to_numpy().reshape(-1,1))
            means,covariances,weights = gmm.means_,gmm.covariances_,gmm.weights_
            x = np.linspace(14,28,1000)
            # score_samples returns log-densities, so exponentiate to get the pdf.
            ax.plot(x,np.exp(gmm.score_samples(x.reshape(-1,1))),label=f'gmm fit, n = {len(log_compute_data)}')
            GMM_params[year]['means'] = means
            GMM_params[year]['covariances'] = covariances
            GMM_params[year]['weights'] = weights

        ax.set_ylim([0,0.4])
        # Grid/legend after plotting: legend() on an axis with no labelled artists only warns.
        ax.grid()
        ax.legend()

else:
    fig,ax=plt.subplots(figsize=(8,8))
    for year in years:
        if year not in years_choice: continue
        sns.kdeplot(tmp_df[tmp_df['year']==year]['log_compute'],ax=ax,alpha=0.5,label=year)
    ax.legend()

fig.tight_layout()

In [None]:
## Visualise the per-year GMM parameters, labelling components "lower"/"upper" by mean.

fig,(ax1,ax2,ax3) = plt.subplots(ncols=3)

# Index of the lower-mean / higher-mean component for each year.
low_idx = [GMM_params[y]['means'].argmin() for y in years_choice]
high_idx = [GMM_params[y]['means'].argmax() for y in years_choice]

def _component_values(key, idxs):
    '''Per-year GMM_params[year][key] entry for the component chosen in idxs.'''
    return [GMM_params[y][key][i] for y, i in zip(years_choice, idxs)]

lower_dist_means = _component_values('means', low_idx)
upper_dist_means = _component_values('means', high_idx)

lower_dist_vars = _component_values('covariances', low_idx)
upper_dist_vars = _component_values('covariances', high_idx)

lower_dist_weight = _component_values('weights', low_idx)
upper_dist_weight = _component_values('weights', high_idx)

# Global default marker for the scatter calls below (persists for later cells).
plt.rcParams['scatter.marker'] = 'x'

# means panel
ax1.set_title('means')
ax1.scatter(years_choice,lower_dist_means,c='b',label='lower dist')
ax1.scatter(years_choice,upper_dist_means,c='r',label='upper dist')
ax1.grid(); ax1.legend()

# variances panel
ax2.set_title('vars')
ax2.scatter(years_choice,lower_dist_vars,c='b',label='low dist')
ax2.scatter(years_choice,upper_dist_vars,c='r',label='upper dist')
ax2.grid(); ax2.legend()
ax2.set_ylim([0.5,1.5])

# weights panel (upper weight = 1 - lower weight, so only the lower one is shown)
ax3.set_title('weights')
ax3.scatter(years_choice,lower_dist_weight,c='b',label='low dist')
ax3.grid(); ax3.legend()
ax3.set_ylim([0,1.0])
# lower std bound: [0.5,1.4]
# upper std bound: [0.5,1.2]

fig.set_size_inches(w=10,h=6)
fig.tight_layout()

In [None]:
## extrapolate gmms and visualise

# Seed the global RNG (random_seed is set in the GMM exploration cell) so the uniform
# var/weight draws below are reproducible. A bare np.random.seed() reseeds from OS
# entropy and produced different extrapolations on every run.
np.random.seed(random_seed)

import pwlf

PLOT_EXTRAPOLATED_PARAMS = True

fit_years = years_choice
predict_years = np.arange(2024,2029+1) #extrapolate

lower_var_bounds = (0.5,1.4) #heuristic set (0.5,1.4)
upper_var_bounds = (0.5,1.4) #excluding ~3.5 var for upper dist in 2023. Heuristic set (0.5,1.2)
lower_weights_bound = (0.15,0.35) # vibes based
lower_weights_bound_2 = (0.5,0.5) #heuristic set (0.25,0.15)

#lower dist: 2-segment piecewise-linear fit of the yearly means; vars/weights drawn uniformly
fit_data =  np.array(lower_dist_means)
lower_dist_mean_model = pwlf.PiecewiseLinFit(fit_years,fit_data)
lower_dist_mean_model.fit(n_segments=2)
pred_lower_means = lower_dist_mean_model.predict(predict_years)
pred_lower_vars = np.random.uniform(low=lower_var_bounds[0],high=lower_var_bounds[1],size=len(predict_years))
pred_lower_weights = np.random.uniform(low=lower_weights_bound[0],high=lower_weights_bound[1],size=len(predict_years))
pred_lower_weights_2 = np.linspace(lower_weights_bound_2[0],lower_weights_bound_2[1],num=len(predict_years))


#upper dist: same mean model; weights are the complement of the lower weights
fit_data = np.array(upper_dist_means)
upper_dist_mean_model = pwlf.PiecewiseLinFit(fit_years,fit_data)
upper_dist_mean_model.fit(n_segments=2)
pred_upper_means = upper_dist_mean_model.predict(predict_years)
pred_upper_vars = np.random.uniform(low=upper_var_bounds[0],high=upper_var_bounds[1],size=len(predict_years))
pred_upper_weights = 1 - pred_lower_weights
pred_upper_weights_2 = 1 - pred_lower_weights_2


if PLOT_EXTRAPOLATED_PARAMS:
    extrap_marker='o'
    fit_marker='x'
    low_color='b'
    upper_color='r'

    all_years = np.concatenate([fit_years,predict_years])
    fig,axs = plt.subplots(nrows=1,ncols=3,figsize=(10,4))
    for ax in axs:
        ax.grid()
        # Fix tick positions before assigning labels; matplotlib expects fixed
        # labels to be paired with fixed tick locations.
        ax.set_xticks(all_years)
        ax.set_xticklabels(all_years,rotation=45)
    ax1,ax2,ax3=axs
    ax1.set_title('means'); ax2.set_title('std devs'); ax3.set_title('weights')

    # Means: fitted (x) vs extrapolated (o)
    ax1.scatter(fit_years, lower_dist_means, marker=fit_marker, c=low_color)
    ax1.scatter(predict_years, pred_lower_means, marker=extrap_marker, c=low_color)

    ax1.scatter(fit_years, upper_dist_means, marker=fit_marker, c=upper_color)
    ax1.scatter(predict_years, pred_upper_means, marker=extrap_marker, c=upper_color)

    # Standard deviations
    ax2.scatter(fit_years, np.sqrt(lower_dist_vars), marker=fit_marker, c=low_color)
    ax2.scatter(predict_years, np.sqrt(pred_lower_vars), marker=extrap_marker, c=low_color)

    ax2.scatter(fit_years, np.sqrt(upper_dist_vars), marker=fit_marker, c=upper_color)
    ax2.scatter(predict_years, np.sqrt(pred_upper_vars), marker=extrap_marker, c=upper_color)

    # Weights
    ax3.scatter(fit_years, lower_dist_weight, marker=fit_marker, c=low_color)
    ax3.scatter(predict_years, pred_lower_weights, marker=extrap_marker, c=low_color)

    ax3.scatter(fit_years, upper_dist_weight, marker=fit_marker, c=upper_color)
    ax3.scatter(predict_years, pred_upper_weights, marker=extrap_marker, c=upper_color)

    fig.tight_layout()


# Extrapolated mixture pdfs, one panel per predicted year.
x=np.linspace(14,35,1000)

nrows,ncols=(3,2)
fig,axs=plt.subplots(nrows=nrows,ncols=ncols,figsize=(6,8))
axs = axs.ravel()

# NOTE(review): these pdfs use pred_*_weights_2 (fixed 0.5/0.5), not the random
# pred_*_weights shown in the weights panel above - confirm this is intended.
for ax,year,mu_l,mu_u,var_l,var_u,w_l,w_u in zip(
    axs,
    predict_years,
    pred_lower_means,pred_upper_means,
    pred_lower_vars,pred_upper_vars,
    pred_lower_weights_2,pred_upper_weights_2,
):
    std_l,std_u = np.sqrt(var_l),np.sqrt(var_u)
    pdf = w_l*norm.pdf(x,loc=mu_l,scale=std_l) + w_u*norm.pdf(x,loc=mu_u,scale=std_u)
    ax.plot(x,pdf,label=f'{year}')
    ax.grid();ax.legend()

fig.tight_layout()
