In [None]:
!pip install pwlf #for colab

## Setup and preprocess df

In [None]:
import importlib
import time

# (module path, alias) pairs to import while timing each one.
modules = [
    ('numpy', 'np'),
    ('scipy.stats', 'stats'),
    ('scipy.optimize', 'optimize'), 
    ('matplotlib.pyplot', 'plt'),
    ('pandas', 'pd'),
    ('seaborn', 'sns'),
    ('itertools', 'itertools'),
    ('copy', 'copy'),
    ('re', 're'),
    ('pdb', 'pdb'),
    ('logging', 'logging')
]

for module, alias in modules:
    # perf_counter is a monotonic, high-resolution clock intended for timing.
    start = time.perf_counter()
    # Same effect as `import <module> as <alias>`, without exec on a built string.
    globals()[alias] = importlib.import_module(module)
    end = time.perf_counter()
    print(f"{module}: {end - start:.4f} seconds")

In [8]:
import numpy as np
from scipy import stats, optimize
import matplotlib.pyplot as plt
import pandas as pd #taking long to load here
import seaborn as sns
import itertools
import copy,re, pdb, logging

# Configure root logging once: INFO and above, with a timestamped prefix.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
_LOG_DATEFMT = '%Y-%m-%d %H:%M:%S'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATEFMT)

# Logger used by the notebook cells.
logger = logging.getLogger(__name__)


In [9]:
# df = pd.read_csv("https://epochai.org/data/epochdb/notable_systems.csv")
# Turn the Google Drive "view" link into a direct-download link (the file id is the
# second-to-last path segment).
url = 'https://drive.google.com/file/d/1RLLKPU3bEYK65wlQlU0p20u9M8cHkLMl/view?usp=sharing'
url = 'https://drive.google.com/uc?id=' + url.split('/')[-2]

df = pd.read_csv(url)

# Keep only rows with a notability criterion. The explicit copy avoids
# SettingWithCopyWarning when new columns are added to the filtered frame below.
df = df[~df["Notability criteria"].isna()].copy()

# Short aliases for the columns used throughout the notebook.
df["compute"] = df["Training compute (FLOP)"]
df["date"] = df["Publication date"]
df["model"] = df["System"]
df["poss1e23"] = df["Possibly over 1e23 FLOP"]
df["poss1e25"] = df["Estimated over 1e25 FLOP"]
df["cost"] = df["Training compute cost (2023 USD)"]
# regex=False: "$" must be removed as a literal character, not read as the regex
# end-of-string anchor (the default for `regex` differs between pandas versions).
df["cost"] = (
    df["cost"]
    .str.replace(",", "", regex=False)
    .str.replace("$", "", regex=False)
    .astype(float)
)

df = df[["model", "compute", "date", "cost", "poss1e23", "poss1e25"]]

In [10]:
# Models excluded as outliers.
to_remove = ['AlphaGo Zero','AlphaZero']
is_outlier = df["model"].isin(to_remove)
df = df[~is_outlier]

In [11]:
# Extra rows in column order: model, compute, date, cost, poss1e23, poss1e25.
to_append = [
  ["Claude 3.5 Sonnet", 4.3e25, "2024-06-21", np.nan, np.nan, np.nan],
  ["GPT-4o Mini", 1.2e25, "2024-07-18", np.nan, np.nan, np.nan],
]

for row in to_append:
  if row[0] not in df["model"].values:
    # The index has gaps after the earlier row filters, so len(df) may already be a
    # label in use and `df.loc[len(df)] = row` would overwrite that row. Use a label
    # that is guaranteed to be new instead.
    next_label = 0 if df.empty else df.index.max() + 1
    df.loc[next_label] = row

In [None]:
# Claude 2.1 is assumed to match Claude 2. Take a scalar (not a 1-element array) and
# fall back to NaN when Claude 2 is missing from df.
claude2_compute = df.loc[df["model"] == "Claude 2", "compute"]
claude2_value = claude2_compute.iloc[0] if not claude2_compute.empty else np.nan

# Manual compute estimates (FLOP) for models whose compute is missing in df.
to_add_compute = {
    "Claude 3 Opus": 2.5e25,
    "Claude 3 Sonnet": 1.1e25,
    "GPT-4o": 2.9e25,
    "Gemini 1.0 Pro": 2.8e24,
    "Gemini 1.5 Pro": 1.9e25,
    "Reka Core": 8.4e24,
    "GPT-4 Turbo": 2.1e25,  # rough guess
    "GPT-4V": 2.1e25,  # rough guess
    "Claude 2.1": claude2_value,  # rough guess
}

logger.info('Can add more recent models here')


for k, v in to_add_compute.items():
  model_rows = df["model"] == k
  if not model_rows.any():
    # Previously a missing model hit the ambiguous truth value of an empty array.
    print(f"{k} not found in df; skipping")
    continue
  # Only fill rows whose compute is missing; works for any number of matching rows.
  missing = model_rows & df["compute"].isna()
  if missing.any():
    df.loc[missing, "compute"] = v
  else:
    print(f"{k} already has a compute value")

In [13]:
# Rows with a known compute value no longer need the "possibly over X" flags.
has_compute = df["compute"].notna()
df.loc[has_compute, "poss1e23"] = np.nan
df.loc[has_compute, "poss1e25"] = np.nan

# Set some temporary placeholder values
# TODO: revisit
# df.loc[(df["poss1e25"] == "checked"), "compute"] = 1.01e25  # placeholder
# df.loc[((df["poss1e23"] =="checked") & (df["poss1e25"] != "checked")), "compute"] = 1.01e23  # placeholder

# We want to handle these leading models manually via the above compute estimates.
unresolved_1e25 = df[(df["poss1e25"] == "checked") & (df["compute"].isna())]
assert unresolved_1e25.size == 0

# We sample 1e23-1e25 models with unknown compute from the existing empirical distribution.
# TODO: revisit
poss1e23 = ((df["poss1e23"] == "checked") & (df["poss1e25"] != "checked"))
known_mid_range = df[(df["compute"] >= 1e23) & (df["compute"] < 1e25)]["compute"]
df.loc[poss1e23, "compute"] = known_mid_range.sample(poss1e23.sum(), random_state=0).values

# Derived columns.
df["date"] = pd.to_datetime(df["date"])
df["log_compute"] = np.log10(df["compute"])

publication = df["date"].dt
df["date_float"] = publication.year + publication.month/12

df['year'] = df['date'].dt.year

# Chronological order; rows that still have no compute are dropped.
df = df.sort_values("date")
df.dropna(subset="compute", inplace=True)

In [None]:
# Basic scatterplot of training compute over time (post-2010), a log-linear trend
# line, and two illustrative noisy projections.
historical_data = df[df['date']>'2010-01-01']

fig = sns.scatterplot(data=historical_data, x='date',y='compute')
fig.set(yscale='log')
plt.grid(alpha=0.5)

# Line of best fit for the historical data: regress log(compute) on unix time.
x = historical_data['date'].astype(np.int64) // 10**9  # Convert to unix timestamp
y = historical_data['compute']
z = np.polyfit(x, np.log(y), 1)
p = np.poly1d(z)
plt.plot(historical_data['date'], np.exp(p(x)), 'b--', alpha=0.8)

future_dates = pd.date_range(start='2025-01-01', end='2029-12-31', periods=200)
base = 1e25  # Starting point based on 2024 level
noise = np.random.normal(0, 10, len(future_dates))
years_from_2025 = (future_dates.year - 2025)

# (exponential growth rate, marker colour, legend label); both scenarios share one noise draw.
scenarios = (
    (3.0, 'red', 'Projected - business as usual'),
    (0.4, 'green', 'Projected - inference scaling'),
)
for growth_rate, colour, label in scenarios:
    future_compute = base * np.exp(growth_rate * years_from_2025) * (1 + noise)
    plt.scatter(future_dates, future_compute, alpha=0.3, color=colour, label=label)

plt.legend()
plt.xlim([pd.Timestamp('2020-01-01'),pd.Timestamp('2030-01-01')])

# Horizontal guide lines at each power of ten from 1e25 to 1e30.
for exp in range(25,31):
    plt.axhline(y=10**exp,color='gray',linestyle='--',alpha=0.6)



In [None]:
FLOP_dollar=2e25/100e6 #FLOP per dollar conversion ~2023 (GPT-4 was ~2e25 FLOP for estimated $1e8)


# Historical compute converted to dollars at the fixed FLOP/$ rate above.
fig = sns.scatterplot(data=df[df['date']>'2010-01-01'], x='date',y=(1/FLOP_dollar)*df['compute'])
fig.set(yscale='log')
plt.grid(alpha=0.5)

# Reference dollar levels: (value, legend label, colour).
reference_lines = [
    (1e14, 'World GDP', 'red'),
    (30e12, 'US GDP', 'orange'),
    (40e9, 'Meta R&D budget 2023', 'green'),
    (100e6, 'GPT-4 training cost (est)', 'purple'),
]
for level, name, colour in reference_lines:
    plt.axhline(y=level,label=name,color=colour,linestyle='--',alpha=0.8)

# Add future projections
future_dates = pd.date_range(start='2024-01-01', end='2029-12-31', periods=500)
base = (1/FLOP_dollar) * 2e25  # Starting point based on 2024 level
noise = np.random.normal(0, 10, len(future_dates))
years_from_2024 = (future_dates.year - 2024)

growth_rate = 3.0  # Exponential growth rate
growth_factor = np.exp(growth_rate * years_from_2024)
future_costs = base * growth_factor * (1 + noise)
plt.scatter(future_dates, future_costs, alpha=0.3, color='red', label='Projected - business as usual')

#growth_rate = 0.4
#future_costs = base * np.exp(growth_rate * years_from_2024) * (1 + noise)
#plt.scatter(future_dates, future_costs, alpha=0.3, color='green', label='Projected - inference scaling')



plt.legend()
plt.xlim([pd.Timestamp('2020-01-01'),pd.Timestamp('2030-01-01')])
plt.ylabel("Training compute cost ($)")

In [None]:


# For each selected year, split that year's models into `group_param` quantile groups
# by compute and record each group's share (%) of the year's total compute.
year_filter=[2020,2021,2022,2023]
group_param=5
group_labels=[f'Group {i}' for i in range(group_param)]
table=pd.DataFrame(index=group_labels,columns=year_filter)


for year in df['date'].dt.year.unique():
    if year in year_filter:
        year_data = df[df['date'].dt.year == year]
        yearly_total = year_data['compute'].sum()
        print(f"\nYear: {year}")
        sorted_year_data=year_data.sort_values(by='compute',ascending=False)['compute']
        # Integer labels 0..group_param-1; 0 is the lowest-compute quantile.
        grouped_data=pd.qcut(sorted_year_data,q=group_param,labels=False)
        for group, label in enumerate(group_labels):
            group_data = sorted_year_data[grouped_data == group]
            group_share = group_data.sum() / yearly_total * 100
            table.loc[label,year] = group_share
            print(f"Group {group}: {group_share:.1f}% of total compute")



# Plot pie chart of latest year's data
latest_year = max(year_filter)
latest_data = table[latest_year]
plt.figure(figsize=(8,8))
plt.pie(latest_data, labels=group_labels, autopct='%1.1f%%')
plt.title(f'Share of Total Compute by Group ({latest_year})')



        
    



    


In [None]:
from sklearn.linear_model import LinearRegression

# Fixed FLOP-per-dollar rate (2e25 FLOP per $100M) and its inverse.
FLOP_dollar_2024 = 2e25/100e6
dollar_FLOP_2024 = 1/FLOP_dollar_2024

# Group post-2010 rows by publication year; rows outside the filter get no key and are dropped.
year_grouped_df=df.groupby(df['date'][df['date']>'2010-01-01'].dt.year)
aggregate_compute=year_grouped_df['compute'].sum()
aggregate_compute_cost=aggregate_compute*dollar_FLOP_2024
log_aggregate_compute=np.log10(aggregate_compute)
log_aggregate_compute_cost=np.log10(aggregate_compute_cost)

# Use the aggregated Series' own index for the x values so x and y are aligned by
# construction (rather than relying on the ordering of `groups.keys()`).
years = aggregate_compute.index.to_numpy()

#plot
# Plot historical data
plt.figure(figsize=(10,6))
plt.scatter(years, log_aggregate_compute_cost, label='Historical data')

# Fit exponential for extrapolation
# Linear regression (in log10 space)
x = years.reshape(-1, 1)
y = log_aggregate_compute_cost.values
reg = LinearRegression().fit(x, y)

# Generate future years for extrapolation.
# x.max() is a scalar; the builtin max(x) on a 2-D array returned a 1-element array,
# which np.arange only accepted through a deprecated array-to-scalar conversion.
future_years = np.arange(x.max(), 2030).reshape(-1, 1)

# Get predictions
future_predictions = reg.predict(future_years)


# Plot extrapolation
plt.plot(future_years, future_predictions, '--', label='Extrapolation')
plt.xlabel('Year')
# The plotted quantity is the log10 of the dollar cost, not of raw compute.
plt.ylabel('Log10(Total compute cost, USD)')
plt.legend()
plt.grid(True)

In [None]:
total_compute_2028 = 1e30
cost_2024 = total_compute_2028 * dollar_FLOP_2024
print(f"With 2024 FLOP/dollar costs, the cost of {total_compute_2028} FLOP is approx {cost_2024/1e12:,.2f} trillion USD")

#case 1 - ~ 9 models with 1e29, 100 models with 1e27 
#case 1 - ~ 9 models with 

#case 2 - ~10000 models with 1e26, 0 models above that

#case 3 - 1 model 1e29, 10 models 1e28, 100 models 1e27, 1000 models 1e26 etc. 

# One 2x2 figure for the cumulative ("T-m") curves and one for the KDEs, a panel per year.
years_to_iter=[2020,2021,2022,2023]
fig,axs=plt.subplots(nrows=2,ncols=2,figsize=(8,6)); axs_ravel=axs.ravel()
kde_fig,kde_axs=plt.subplots(nrows=2,ncols=2,figsize=(8,6)); kde_axs_ravel=kde_axs.ravel()

def percentage_formatter(x,pos):
        """Tick formatter: render a fraction of the yearly total as a percentage.

        The plotted values are fractions (1.0 == the whole year's compute), so they are
        scaled by 100 before the '%' sign is appended; previously 1.0 was shown as
        '1.000000%'. `pos` is the tick position required by matplotlib and is unused.
        """
        return f'{100*x:g}%'



for idx,year in enumerate(years_to_iter):
        ax,kde_ax=axs_ravel[idx], kde_axs_ravel[idx]
        total_compute=aggregate_compute[aggregate_compute.index==year].values
        cost_2023=total_compute*dollar_FLOP_2024
        datapoints_year=df[df['date'].dt.year==year]['compute']
        mean_log_compute=np.log10(datapoints_year).mean()

        #prep data: everything is normalised by the year's total compute
        sorted_computes=np.sort(datapoints_year)
        norm_factor=total_compute[0]
        norm_sorted_computes=sorted_computes/norm_factor
        cumsum=np.cumsum(sorted_computes)
        norm_cumsum=cumsum/norm_factor



        #T-m plot: cumulative share of compute vs. individual model size
        ax.plot(norm_sorted_computes,norm_cumsum)
        ax.scatter(norm_sorted_computes, norm_cumsum, alpha=0.5, color='blue', s=30,marker='x')

        ax.grid(True,alpha=0.3)
        ax.set_xscale('log'); ax.set_yscale('log')
        #ax.set_xlim([1e18,1e27])
        ax.set_xlabel('individual model size'); ax.set_ylabel('Total training compute')
        ax.set_title(f'Year: {year}')
        ax.text(0.05, 0.95, f'Total compute: {total_compute[0]:.2e} FLOP', 
                transform=ax.transAxes, verticalalignment='top')
        ax.axhline(y=norm_cumsum[-1],color='r',linestyle='--')
        ax.axvline(x=1,color='g',linestyle='--',alpha=0.5)
        ax.text(1,ax.get_ylim()[0],f'{norm_factor:.2e}',
                rotation=90,fontsize=8,verticalalignment='top')
        ax.yaxis.set_major_formatter(percentage_formatter)

        #KDE plot of log10(normalised model compute)
        kde=stats.gaussian_kde(np.log10(norm_sorted_computes))
        x_range=np.logspace(np.log10(norm_sorted_computes).min(),np.log10(1))
        kde_ax.plot(x_range,kde(np.log10(x_range)))
        kde_ax.set_xscale('log')
        kde_ax.set_title(f'Year: {year}')
        kde_ax.grid(alpha=0.5)

        kde_ax.axvline(x=1,color='g',linestyle='--',alpha=0.5)
        # Anchor the annotation to the KDE axes' own lower y-limit; the T-m axes' limit
        # (log-scaled, different range) was used here before.
        kde_ax.text(1,kde_ax.get_ylim()[0],f'{norm_factor:.2e}',
                rotation=90,fontsize=8,verticalalignment='top')
        if idx>=2: kde_ax.set_xlabel('Model compute (normalised by total)')

fig.tight_layout(pad=2.0)
kde_fig.tight_layout(pad=2.0)

In [None]:
# Enumerate every multiset of N integer log-compute values in [a, b] whose sum is T.
T=245
N=10
a,b=23,26

# Candidate values a, a+1, ..., b (as floats)
possible_values = np.arange(a, b+1).astype(float)
all_combinations = list(itertools.combinations_with_replacement(possible_values, N))

# Keep only the combinations that hit the target sum
valid_combinations = [combo for combo in all_combinations if np.sum(combo)==T]

valid_distributions = np.array(valid_combinations)

print(valid_distributions)

## Analysis class

In [None]:
# utils functions

 
def exponential_fit(x,a,b):
    """Evaluate the exponential model ``a * exp(b * x)``; `a` is the scale, `b` the rate."""
    growth = np.exp(b*x)
    return a*growth

def x_transform_for_exp_fit(ref_x,x,inverse=False):
    '''
    Shift and rescale timestamps so an exponential fit is numerically stable
    (roughly onto the interval [0,50]).

    Forward:  (x - ref_x.min()) / 1e7
    Inverse:  1e7 * x + ref_x.min()   (used when inverse=True)
    '''
    offset = ref_x.min()
    scale = 1e7

    if inverse:
        return scale*x + offset

    return (x-offset)/scale


def sample_from_gmm(n_samples,params):
    """Draw samples from a two-component 1-D Gaussian mixture.

    params is (means, variances, weights), each a pair ordered (lower, higher).
    The weights must expose ``.item()`` (numpy scalars/arrays). A full set of
    draws is made from both components and one is kept per sample according to a
    random component choice. Returns an array of shape (1, n_samples).
    """
    means,variances,weights = params
    mean_lo,mean_hi = means
    var_lo,var_hi = variances
    weight_lo,weight_hi = weights
    sd_lo,sd_hi = np.sqrt(var_lo),np.sqrt(var_hi)

    # 0 -> lower component, 1 -> higher component
    picks = np.random.choice([0,1], size=n_samples,p=[weight_lo.item(),weight_hi.item()])

    draws_lo = np.random.normal(loc=mean_lo,scale=sd_lo,size=(1,n_samples))
    draws_hi = np.random.normal(loc=mean_hi,scale=sd_hi,size=(1,n_samples))

    return np.where(picks == 0, draws_lo, draws_hi)

def gmm_density_plot(params,x_min,x_max):
    """Evaluate a two-component Gaussian-mixture density on a 1000-point grid.

    params is (means, variances, weights), each indexable with two entries.
    Returns (x, pdf), both column vectors of shape (1000, 1), with x spanning
    [x_min, x_max].
    """
    means,variances,weights = params
    components = zip(means, (variances[0],variances[1]), (weights[0],weights[1]))

    grid = np.linspace(x_min, x_max, 1000).reshape(-1,1)
    density = np.zeros_like(grid)

    # Accumulate the weighted normal pdf of each component.
    for mean, variance, weight in components:
        density += weight*stats.norm.pdf(grid, mean, np.sqrt(variance))

    return grid,density

# Quick example: equal-weight mixture with means 1.5 / 2.5 and variance 0.5 on [0, 5].
x,fitted_pdf = gmm_density_plot([np.array([1.5,2.5]),np.array([0.5,0.5]),np.array([0.5,0.5])],0,5)

In [None]:
from sklearn.mixture import GaussianMixture; 
from scipy.stats import norm
from sklearn.linear_model import LinearRegression
from scipy.optimize import curve_fit
import pwlf, random

# Seed both the stdlib and the numpy global RNGs for reproducible runs.
random_seed=42
for _seed_rng in (random.seed, np.random.seed):
    _seed_rng(random_seed)

from dataclasses import dataclass



class DataAnalysis():
    """Fit, extrapolate and sanity-check windowed distributions of log training compute.

    Intended call order (each step stores state on the instance that later steps read):
    ``truncate_df_dates`` -> ``fit_distributions`` -> ``extrapolate_distributions`` ->
    ``model_counts`` -> ``generate_simulated_data`` -> ``count_threshold_models`` /
    ``count_frontier_models`` -> ``verify_with_retrodiction``.

    The class relies on names defined at notebook level: ``pd``, ``np``, ``plt``,
    ``sns``, ``norm``, ``GaussianMixture``, ``LinearRegression``, ``curve_fit``,
    ``pwlf``, ``random_seed`` and the helpers ``exponential_fit``,
    ``x_transform_for_exp_fit`` and ``sample_from_gmm``.
    """

    def __init__(self,df,window_freq='year'):
        """Store the data and build the fit / prediction time grids.

        Params:
            df: DataFrame with at least a 'date' column (and 'log_compute' for fitting).
            window_freq: spacing of window centres - 'quarter', 'biannual' or 'year'.

        Raises:
            ValueError: if window_freq is not one of the supported values.
        """

        # Work on a private copy so the date conversion never mutates the caller's
        # frame (and never triggers SettingWithCopyWarning when given a slice).
        self.df = df.copy()
        self.df['date'] = pd.to_datetime(self.df['date'])
        self.working_df = None

        self.start_time = '2017-01-01'
        self.stop_time = '2024-01-01'
        self.predict_start_time = '2024-01-01'
        self.predict_stop_time = '2030-01-01'
        self.window_freq = window_freq
        self.window_size = 'year'

        if self.window_freq=='quarter':
            times = pd.date_range(self.start_time,self.stop_time,freq='QS')
            predict_times = pd.date_range(self.predict_start_time,self.predict_stop_time,freq='QS') 
        elif self.window_freq=='biannual':
            times = pd.date_range(start=self.start_time,end=self.stop_time,freq='6MS')[1:-1] #indexing filters out startyear-01-01, endyear-01-01
            # '6MS' (month start) to match the fit grid; '6M' (month end) was used here
            # before, which put prediction times on a different day-of-month grid.
            predict_times = pd.date_range(self.predict_start_time,self.predict_stop_time,freq='6MS')[1:-1] 
        elif self.window_freq=='year':
            times = pd.date_range(start=self.start_time,end=self.stop_time,freq='AS-JUL')
            predict_times = pd.date_range(self.predict_start_time,self.predict_stop_time,freq='AS-JUL')
        else:
            raise ValueError(f"Invalid window_freq {window_freq!r}. Expected one of ['quarter', 'biannual', 'year']")

        #can use these to quickly filter df and get 
        self.fit_times = times

        self.predict_times = predict_times

    def truncate_df_dates(self,start='2017-01-01',end='2024-01-01'):
        """Set self.working_df to the rows with start < date < end (both exclusive).

        The bounds were previously hard-coded and the arguments ignored; the defaults
        reproduce the old behaviour.
        """

        in_range = (self.df['date']>start) & (self.df['date']<end)
        self.working_df = self.df[in_range].copy()

    def window_filter(self,window_time):

        '''
            Take window time and return filter for dates d such that d in [window_time-0.5*window_size,window_time+0.5*window_size)
            e.g: Given time t, find all df entries within [t-6mo,t+6mo) for year window time

            Raises ValueError for a window_size other than 'year' (the only one implemented).
        '''

        if self.window_size=='year':
            half_window = pd.DateOffset(months=6)
            filtering_condition = (self.working_df['date'] >= (window_time-half_window)) & (self.working_df['date'] < (window_time+half_window))
        else:
            # Previously fell through to an UnboundLocalError.
            raise ValueError(f'Unsupported window_size {self.window_size!r}')

        return filtering_condition 

    def fit_distributions(self,fit_type,plot=False):
        '''

        Help: Fit a distribution to the log_compute values inside the window around
        each time in self.fit_times.

        Params:
            fit_type: 'gaussian' (sample mean / sample std) or 'gaussian mixture'
                      (2-component sklearn GaussianMixture).
            plot: for 'gaussian' only, draw a KDE of the data with the fitted normal pdf.

        Return:
            dict keyed by fit time. 'gaussian': {'mean','std'};
            'gaussian mixture': {'means','covars','weights'} as returned by sklearn.
            Also stored as self.fitted_params; self.fit_type is set.

        NOTE: May want to look at doing a fit to the rolling windows
        
        '''

        FIT_TYPES = ['gaussian','gaussian mixture']
        if fit_type not in FIT_TYPES:
            raise ValueError(f'Invalid fit_type. Types: {FIT_TYPES}') 
        self.fit_type=fit_type

        params = {t:None for t in self.fit_times}

        for t in self.fit_times:
            date_filt_condition = self.window_filter(window_time=t)
            date_filt_df = self.working_df[date_filt_condition]
            log_compute_data = date_filt_df['log_compute']

            if fit_type=='gaussian':
                mean = log_compute_data.mean()
                # Sample standard deviation (this used to be assigned the mean by mistake).
                std = log_compute_data.std()
                params[t] = {'mean':mean,'std':std}

                if plot: 
                    fig,ax=plt.subplots()
                    plus_minus = "\u00B1"
                    sns.kdeplot(log_compute_data,label=f'timestamp: {t.date()} {plus_minus} 6mo ',linewidth=2,ax=ax)
                    x=np.linspace(10,30,1000)
                    ax.plot(x,norm.pdf(x,loc=mean,scale=std))
                    ax.grid(); ax.legend(loc='upper left')

            elif fit_type == 'gaussian mixture':
                gmm = GaussianMixture(n_components=2,
                                    random_state=random_seed,
                                    max_iter=100,
                                    n_init=10)
                gmm.fit(log_compute_data.to_numpy().reshape(-1,1))
                means,covariances,weights = gmm.means_,gmm.covariances_,gmm.weights_
                params[t] = {'means':means,
                            'covars':covariances,
                            'weights':weights}

        
        self.fitted_params = params

        return params
    
    def extrapolate_distributions(self):

        '''

        Help: Extrapolate distribution parameters to self.predict_times 

        Means are extrapolated with a fitted trend (linear regression for 'gaussian',
        2-segment piecewise-linear fits for 'gaussian mixture'). Spreads and mixture
        weights are NOT fitted: they are drawn uniformly from the heuristic bounds
        hard-coded below.

        Sets self.predicted_params and self.retrodicted_params (dicts keyed by time;
        retrodicted spreads/weights are NaN placeholders) and returns predicted_params.

        NOTE: gmm extrapolation of covars and weights is very hacky right now
        '''
        
        if self.fit_type=='gaussian':
            
            #linear extrap means
            fit_dates = [t for t in self.fitted_params.keys()]
            fit_dates_float = np.array([t.timestamp() for t in fit_dates])
            means = np.array([self.fitted_params[t]['mean'] for t in self.fitted_params.keys()])
            predicted_dates_float = np.array([t.timestamp() for t in self.predict_times])


            model=LinearRegression()
            model.fit(fit_dates_float.reshape(-1,1),means)
            predicted_means = model.predict(predicted_dates_float.reshape(-1,1))
            retr_means = model.predict(fit_dates_float.reshape(-1,1))

            #sample std
            std_bounds = (1.1,1.6)
            predicted_stds = np.random.uniform(low=std_bounds[0],high=std_bounds[1],size=(predicted_means.shape))
            retr_stds = np.full_like(retr_means,np.nan) #not retrodicting params for now

            predicted_params = {t:(mu,std) for t,mu,std in zip(self.predict_times,predicted_means,predicted_stds)}
            retr_params = {t:(mu,std) for t,mu,std in zip(self.fit_times,retr_means,retr_stds)}

            self.distribution_parameter_model = model #not good atm code atm 


        elif self.fit_type=='gaussian mixture':

            lower_dist_mean_truncate = True #drop the last fit point when fitting the lower-component mean
            lower_var_bounds = (0.5,1.4) #heuristically set (0.5,1.4)
            upper_var_bounds = (0.5,1.4) #exluding ~3.5 var for upper dist in 2023. Heuristic set (0.5,1.2)
            lower_weights_bound = (0.15,0.40) #heuristically set (0.15,0.35)

            ##get data to fit: per fit time, which component has the lower / higher mean
            lower_idx = [self.fitted_params[t]['means'].argmin() for t in self.fitted_params.keys()]
            higher_idx = [self.fitted_params[t]['means'].argmax() for t in self.fitted_params.keys()]

            lower_dist_means = np.array([self.fitted_params[t]['means'][idx] for t,idx in zip(self.fitted_params.keys(),lower_idx)])
            higher_dist_means = np.array([self.fitted_params[t]['means'][idx] for t,idx in zip(self.fitted_params.keys(),higher_idx)])

            fit_times_float = np.array([t.timestamp() for t in self.fit_times])
            predict_times_float = np.array([t.timestamp() for t in self.predict_times])


            ##predict/retrodict means
            if not lower_dist_mean_truncate:
                lower_dist_mean_model = pwlf.PiecewiseLinFit(fit_times_float,lower_dist_means)
            else: 
                lower_dist_mean_model = pwlf.PiecewiseLinFit(fit_times_float[:-1],lower_dist_means[:-1])

            lower_dist_mean_model.fit(n_segments=2)
            pred_lower_means = lower_dist_mean_model.predict(predict_times_float)
            retr_lower_means = lower_dist_mean_model.predict(fit_times_float)

            higher_dist_mean_model = pwlf.PiecewiseLinFit(fit_times_float,higher_dist_means)
            higher_dist_mean_model.fit(n_segments=2)
            pred_higher_means = higher_dist_mean_model.predict(predict_times_float)
            retr_higher_means = higher_dist_mean_model.predict(fit_times_float)

            ##extrapolate vars
            pred_lower_vars = np.random.uniform(low=lower_var_bounds[0],high=lower_var_bounds[1],size=(pred_lower_means.shape))
            pred_higher_vars = np.random.uniform(low=upper_var_bounds[0],high=upper_var_bounds[1],size=(pred_higher_means.shape))
            retr_lower_vars = np.full_like(retr_lower_means,np.nan) #empty as we're just taking fitted vars 
            retr_higher_vars = np.full_like(retr_higher_means,np.nan) #empty as we're just taking fitted vars

            ##extrapolate weights 
            pred_lower_weights = np.random.uniform(low=lower_weights_bound[0],high=lower_weights_bound[1],size=(pred_lower_means.shape))
            retr_lower_weights = np.full_like(retr_lower_means,np.nan) #empty as we're just taking fitted weights

            ##set predicted params state var
            predicted_params = {t:((mu1,mu2),(var1,var2),(w1,w2)) for
                                     t,mu1,mu2,var1,var2,w1,w2 in
                                     zip(self.predict_times,
                                         pred_lower_means,pred_higher_means,
                                         pred_lower_vars,pred_higher_vars,
                                         pred_lower_weights,1-pred_lower_weights)
                                    }
            
            retr_params = {t:((mu1,mu2),(var1,var2),(w1,w2)) for
                                        t,mu1,mu2,var1,var2,w1,w2 in
                                        zip(self.fit_times,
                                            retr_lower_means,retr_higher_means,
                                            retr_lower_vars,retr_higher_vars,
                                            retr_lower_weights,1-retr_lower_weights)
                                        }
            

        else:
            # Previously `pass`, which then failed with an UnboundLocalError below.
            raise ValueError(f'Unsupported fit_type {self.fit_type!r}')

        self.predicted_params = predicted_params
        self.retrodicted_params = retr_params

        return predicted_params
    
    def model_counts(self,counts_fit_type):
        """Fit a trend to the number of models per window and extrapolate it.

        Params:
            counts_fit_type: 'linear', 'exponential' or 'kinked linear'
                             (2-segment piecewise linear).

        Sets self.fit_counts, self.predicted_counts (cast to int), self.retr_counts
        (one value per entry of self.fit_times), self.count_fit_type and
        self.counts_model.

        Return:
            the raw (un-cast) predicted counts for self.predict_times.
        """

        ## COULD add a kinked exponential
        COUNT_FIT_TYPES=['linear','exponential','kinked linear']

        if counts_fit_type not in COUNT_FIT_TYPES: raise ValueError(f'Expected fit in {COUNT_FIT_TYPES}')
    
        # Observed model count per window, skipping windows that do not fit in the data range.
        # NOTE(review): the `>=` below also skips a window whose upper edge equals
        # stop_time exactly - confirm that excluding it is intended.
        counts_data = {}
        half_window = pd.DateOffset(months=6)

        for t in self.fit_times:

            if (t-half_window) < pd.Timestamp(self.start_time) or (t+half_window) >= pd.Timestamp(self.stop_time): 
                print(f'Skip {t} for model counts - cannot form a window of size "{self.window_size}" around this time')
                continue
            else:
                date_filt_condition = self.window_filter(window_time=t)
                date_tmp_df = self.working_df[date_filt_condition] #filtered df
                counts_data[t]=date_tmp_df.shape[0]

        #perform fitting
        fit_counts = np.array([v for v in counts_data.values()])
        fit_times_float_c = np.array([t.timestamp() for t in counts_data.keys()]) #fit to times that have counts value. _c appended as this is JUST for fitting counts model
        fit_times_float = np.array([t.timestamp() for t in self.fit_times])
        predict_times_float = np.array([t.timestamp() for t in self.predict_times])

        if counts_fit_type=='linear':
            model = LinearRegression()
            model.fit(fit_times_float_c.reshape(-1,1),fit_counts)
            predicted_counts = model.predict(predict_times_float.reshape(-1,1))
            retr_counts = model.predict(fit_times_float.reshape(-1,1)).astype('int')

        elif counts_fit_type=='exponential':
            transformed_fit_x = x_transform_for_exp_fit(ref_x=fit_times_float,x=fit_times_float_c)
            popt,pcov = curve_fit(exponential_fit,transformed_fit_x,fit_counts)    
            a,b = popt; model = popt
            transformed_pred_x = x_transform_for_exp_fit(ref_x=fit_times_float,x=predict_times_float)
            predicted_counts = exponential_fit(transformed_pred_x,a=a,b=b)
            # Retrodict for EVERY fit time (including skipped windows) so there is one
            # value per entry of self.fit_times, as the linear branch already does.
            transformed_retr_x = x_transform_for_exp_fit(ref_x=fit_times_float,x=fit_times_float)
            retr_counts = exponential_fit(transformed_retr_x,a=a,b=b)

        elif counts_fit_type=='kinked linear':
            # Fit on the times that actually have a count; using all fit times gave
            # x and y of different lengths whenever a window was skipped.
            model = pwlf.PiecewiseLinFit(fit_times_float_c,fit_counts)
            breakpoints = model.fit(2)
            predicted_counts = model.predict(predict_times_float)
            retr_counts = model.predict(fit_times_float)

        #sanity check (the old assert referenced a non-existent attribute `fit_times_float`)
        assert len(retr_counts) == len(self.fit_times), f'retr_counts={retr_counts}, fit_times={self.fit_times}'

        #set state vars
        self.fit_counts = fit_counts
        self.predicted_counts = predicted_counts.astype('int')
        self.count_fit_type = counts_fit_type
        self.retr_counts = retr_counts
        self.counts_model = model

        return predicted_counts

    def generate_simulated_data(self):
        """
        Generated simulated log_compute samples based on choice of compute distribution and count model

        For each prediction time, draws `predicted_counts` samples from the extrapolated
        distribution. Results are stored in self.simulated_data (dict keyed by time)
        and also returned (previously nothing was returned).
        """

        self.simulated_data={}

        for pred_t,params,counts in zip(self.predict_times,self.predicted_params.values(),self.predicted_counts):
            if self.fit_type == 'gaussian':
                mu,sigma = params
                self.simulated_data[pred_t] = norm.rvs(loc=mu,scale=sigma,size=counts)

            elif self.fit_type == 'gaussian mixture':
                self.simulated_data[pred_t] = sample_from_gmm(n_samples=counts,params=params)

        return self.simulated_data

    def count_threshold_models(self):
        """Count simulated models at or above each log10-compute threshold 23..30.

        Returns (and stores as self.threshold_counts) a DataFrame indexed by prediction
        time with one column per threshold.
        """

        thresholds = np.arange(23,30+1)
        threshold_counts_df = pd.DataFrame(columns=thresholds,index=self.predict_times)

        for t in self.predict_times:
            log_compute_samples = self.simulated_data[t]
            for thr in thresholds:
                n_exceed = log_compute_samples[log_compute_samples>=thr].size
                threshold_counts_df.at[t,thr] = n_exceed


        self.threshold_counts = threshold_counts_df
        return threshold_counts_df
    
    def count_frontier_models(self,frontier_thresholds=[0.5,1,1.5],past_counts=True):
        ''' 
        Do counting for frontier connected threshold: for each time, the number of
        samples whose log compute is within `thr` of the running maximum seen so far.

        Params:
            frontier_thresholds: distances (in log10 compute) from the running max.
            past_counts: also compute the same counts on observed data for self.fit_times.

        Return:
            (frontier_counts_df, frontier_counts_df_past); the second element is None
            when past_counts is False (this used to raise UnboundLocalError).

        Careful with interpretation of this given window size = year, but window freq != year sometimes
        '''
    
        frontier_counts_df = pd.DataFrame(columns=frontier_thresholds,index=self.predict_times)

        running_max=0
        for t in self.predict_times:
            # The samples and running max do not depend on the threshold, so compute once per time.
            log_compute_samples = self.simulated_data[t]
            max_val = log_compute_samples.max()
            if max_val>running_max: running_max=max_val
            for thr in frontier_thresholds:
                #we want number of samples within thr of max
                n_exceed = log_compute_samples[(running_max-log_compute_samples)<=thr].size
                frontier_counts_df.at[t,thr] = n_exceed

        self.frontier_counts_df = frontier_counts_df

        #do frontier connected counts for past counts 
        frontier_counts_df_past = None
        if past_counts:
            frontier_counts_df_past = pd.DataFrame(columns=frontier_thresholds,index=self.fit_times)

            running_max=0
            for t in self.fit_times:
                log_compute_data = self.working_df[self.window_filter(t)]['log_compute']
                max_val = log_compute_data.max()
                if max_val>running_max: running_max=max_val
                for thr in frontier_thresholds:
                    #we want number of samples within thr of max
                    n_exceed = log_compute_data[(running_max-log_compute_data)<=thr].size
                    frontier_counts_df_past.at[t,thr] = n_exceed
            
            self.past_frontier_counts_df = frontier_counts_df_past
            

        return frontier_counts_df,frontier_counts_df_past
                  
    def verify_with_retrodiction(self):

        '''
        Help: Re-simulate the fitted period and compare threshold counts (1e23, 1e24)
        against the observed data.

        Params:

        Return:
            predicted_past_counts, observed_past_counts, diff (= predicted - observed),
            percent_error_df (|diff| / observed * 100, NaN where observed is 0)

        NOTE:
            - A lot of this isn't 'true' retrodiction - we're using some fitted parameters rather than retrodicted parameters
        '''

        retrodict_times = self.fit_times #retrodict for fitted time stamps
        retr_counts = self.retr_counts

        assert len(retr_counts) == len(retrodict_times), f'retr_counts={retr_counts}, retrodict_times={retrodict_times}'


        thresholds = [23,24]
        predicted_past_counts = pd.DataFrame(index=retrodict_times,columns=thresholds)
        observed_past_counts = pd.DataFrame(index=retrodict_times,columns=thresholds)



        #pretty inefficient way to do it right now
        for idx,t in enumerate(retrodict_times):

            ##generate distributions and counts
            count = int(retr_counts[idx])

            if self.fit_type=='gaussian':
                mean = self.retrodicted_params[t][0]
                std = self.retrodicted_params[t][1]

                if np.isnan(std): std = self.working_df[self.window_filter(t)]['log_compute'].std() #USE EMPIRICAL STDS

                pred_log_compute_data = norm.rvs(loc=mean,scale=std,size=count)

            elif self.fit_type=='gaussian mixture':
            
                #unpack retrodicted params
                mus,vars,ws = self.retrodicted_params[t]
                mu_l,mu_h = mus
                var_l,var_h = vars
                w_l,w_h = ws

                # The retrodicted means are ordered (lower, higher), but sklearn's fitted
                # components are in arbitrary order, so look the fitted variance/weight up
                # by lower/higher-mean index (positions 0 and 1 were used blindly before).
                fitted = self.fitted_params[t]
                lo_idx = fitted['means'].argmin()
                hi_idx = fitted['means'].argmax()
 
                if np.isnan(var_l): var_l = float(np.ravel(fitted['covars'][lo_idx])[0]) #use fitted vars
                if np.isnan(var_h): var_h = float(np.ravel(fitted['covars'][hi_idx])[0]) #use fitted vars
                if np.isnan(w_l): w_l = fitted['weights'][lo_idx] #use fitted gmm weights
                if np.isnan(w_h): w_h = 1-w_l #use fitted gmm weights

                w_l,w_h = np.array([w_l]), np.array([w_h]) #for type compatibility with other parts of workflow

                params_retr = ((mu_l,mu_h),(var_l,var_h),(w_l,w_h))
                pred_log_compute_data = sample_from_gmm(n_samples=count,params=params_retr)

            ##get obs log compute data
            obs_log_compute_data = self.working_df[self.window_filter(t)]['log_compute']

            ##do threshold counts
            for thr in thresholds:
                #pred
                thr_count_pr = pred_log_compute_data[pred_log_compute_data>=thr].size
                predicted_past_counts.at[t,thr] = thr_count_pr

                #obs
                thr_count_ob = obs_log_compute_data[obs_log_compute_data>=thr].size
                observed_past_counts.at[t,thr] = thr_count_ob

        diff = predicted_past_counts - observed_past_counts
        obs_df_safe = observed_past_counts.replace(0,np.nan) #for safe division
        percent_error_df = (np.abs(diff)/obs_df_safe)*100
             
             
        self.predicted_past_counts = predicted_past_counts
        self.observed_past_counts = observed_past_counts

        return predicted_past_counts,observed_past_counts,diff,percent_error_df

    def verify_with_training_compute(self):
         
        '''
        Help: Compare predictions with extrapolations of training compute 
        NOTE: Not implemented for DataAnalysis class yet (always returns None)

        '''
        
        return None

#### Model predictions

In [None]:
##study one config
random_seed=42
random.seed(random_seed)
np.random.seed(random_seed)

# Explicit copy: tmp_df is a filtered view of df, and assigning a column on it would
# otherwise raise SettingWithCopyWarning (and may not behave as intended).
tmp_df = df[(df['date']>'2017-01-01') & (df['date']<'2024-01-01')].copy()
tmp_df['date'] = pd.to_datetime(tmp_df['date'])

window_freq = 'year' 
dist_fit='gaussian'
count_fit='linear'

# Run the full pipeline once for this (distribution fit, count fit) combination.
analysis = DataAnalysis(df=tmp_df,window_freq=window_freq)
analysis.truncate_df_dates()
params = analysis.fit_distributions(fit_type=dist_fit,plot=False)
predicted_params = analysis.extrapolate_distributions()
predicted_counts = analysis.model_counts(counts_fit_type=count_fit)
simulated_data = analysis.generate_simulated_data()
threshold_counts = analysis.count_threshold_models()
frontier_connected_counts,past_frontier_counts = analysis.count_frontier_models(frontier_thresholds=[0.5,1,1.5])

#diff = pred - obs
pred_past_counts,obs_past_counts,diff,percent_error_df = analysis.verify_with_retrodiction()

In [None]:
# all model predictions

import itertools, warnings, tabulate

random_seed=42
random.seed(random_seed)
np.random.seed(random_seed)

distribution_fits = ['gaussian']
model_count_fits = ['linear','exponential','kinked linear']
window_freq = 'year'

# every (distribution fit, count fit) pairing - shared by the row labels and the loop
fit_combinations = list(itertools.product(distribution_fits,model_count_fits))
index = [f"{dist_fit}-{count_fit}" for dist_fit,count_fit in fit_combinations]

# one row per pairing, one column per year; the years are read from the
# `analysis` object that already exists when this cell starts
all_retrodictions_df = pd.DataFrame(index=index,columns=analysis.fit_times.year)
all_predictions_df = pd.DataFrame(index=index,columns=analysis.predict_times.year)

for dist_fit,count_fit in fit_combinations:
    row_label = f'{dist_fit}-{count_fit}'
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        # rerun the whole workflow for this pairing
        analysis = DataAnalysis(df=tmp_df,window_freq=window_freq)
        analysis.truncate_df_dates()
        _ = analysis.fit_distributions(fit_type=dist_fit,plot=False)
        _ = analysis.extrapolate_distributions() 
        _ = analysis.model_counts(counts_fit_type=count_fit)
        _ = analysis.generate_simulated_data()
        threshold_counts = analysis.count_threshold_models()
        pred_past_counts,obs_past_counts,diff,percent_error_df = analysis.verify_with_retrodiction()

        # each cell holds a (threshold 23, threshold 24) pair of predicted-minus-observed counts
        diffs = list(zip(diff.loc[:,23],diff.loc[:,24]))
        all_retrodictions_df.loc[row_label] = diffs

        # each cell holds a (threshold 25, threshold 26) pair of predicted counts
        all_predictions_df.loc[row_label] = list(zip(threshold_counts.loc[:,25],threshold_counts.loc[:,26]))

print(tabulate.tabulate(all_retrodictions_df,headers='keys',tablefmt='pretty'))
print('''
    Interpreting the table:
    - Rows - distribution fit, count fit combination
    - Columns - Retrodicted counts 
    - Values - (predicted - observed) counts for 1e23 and 1e24 thresholds
    - E.g: The gaussian mixture model with linear model count fit underpredicted by 3 for 1e23 threshold
        and underpredicted by 2 for 1e24 threshold for 2021

''')

print(tabulate.tabulate(all_predictions_df,headers='keys',tablefmt='pretty'))
print('''
    Interpreting the table:
    - Rows - distribution fit, count fit combination
    - Columns - Predicted counts
    - Values - Predicted counts for models exceeding 1e25 and 1e26 thresholds
    - E.g: The gaussian mixture model with linear model count fit predicts 16 models exceeding 1e25 FLOP and 
      2 models exceeding 1e26 FLOP in 2027 
      
''')



## Explore gmms

In [None]:
random_seed=42
random.seed(random_seed)
np.random.seed(random_seed)

# Keep rows whose `date` falls strictly between 2017-01-01 and 2024-01-01, then
# parse the dates. `.assign` returns a new frame, so nothing is written into a
# filtered slice of `df` (the pattern that raises pandas' SettingWithCopyWarning).
tmp_df = (
    df[(df['date']>'2017-01-01') & (df['date']<'2024-01-01')]
    .assign(date=lambda frame: pd.to_datetime(frame['date']))
)

window_freq = 'year' 
dist_fit='gaussian mixture'
count_fit='linear' 

#execute workflow 
analysis = DataAnalysis(df=tmp_df,window_freq=window_freq)
analysis.truncate_df_dates()
params = analysis.fit_distributions(fit_type=dist_fit,plot=False)
predicted_params = analysis.extrapolate_distributions()
predicted_counts = analysis.model_counts(counts_fit_type=count_fit)
# NOTE(review): the gaussian workflows earlier in the notebook call
# generate_simulated_data() before count_threshold_models(); this one does not -
# confirm that is intentional.
threshold_counts = analysis.count_threshold_models()

In [None]:
fitted_gmm_params = analysis.fitted_params
predicted_gmm_params = analysis.predicted_params

#plot KDEs and fitted distributions
ncols,nrows = (2,4)
fig,axs = plt.subplots(ncols=ncols,nrows=nrows,figsize=(8,10))
axs = axs.ravel()

# The loop names avoid `time` and `vars` on purpose: at notebook scope those would
# rebind the `time` module imported at the top of the notebook and the `vars` builtin.
for idx,(fit_time,gmm_params) in enumerate(fitted_gmm_params.items()):
    ax = axs[idx]
    means,covars,weights = gmm_params['means'],gmm_params['covars'],gmm_params['weights']
    means = (means[0].item(),means[1].item()) #unwrap the two component means to plain scalars

    ax.set_title(f'{fit_time.year}')
    x,fitted_pdf = gmm_density_plot(params=(means,covars,weights),x_min=10,x_max=30)
    log_compute_data = analysis.working_df[analysis.window_filter(fit_time)]['log_compute']

    #fitted mixture density overlaid on the empirical KDE of the same window
    ax.plot(x.squeeze(),fitted_pdf.squeeze(),linewidth=2,label=f'fitted gmm, n={len(log_compute_data)}')
    sns.kdeplot(log_compute_data,ax=ax,linewidth=2,alpha=0.7,label='empirical kde')

    ax.grid(); ax.legend()

#panels beyond the number of fitted windows have nothing to show (and no legend entries), so hide them
for ax in axs[len(fitted_gmm_params):]:
    ax.set_visible(False)
fig.tight_layout()

#plot the two weighted mixture components separately for each fitted window
ncols,nrows = (4,2)
#12x6 inches is the size the figure is shown at, so create it at that size directly
fig,axs = plt.subplots(ncols=ncols,nrows=nrows,figsize=(12,6))
axs = axs.ravel()

# The loop names avoid `time` and `vars` on purpose: at notebook scope those would
# rebind the `time` module imported at the top of the notebook and the `vars` builtin.
for idx,(fit_time,gmm_params) in enumerate(fitted_gmm_params.items()):
    ax = axs[idx]
    means,covars,weights = gmm_params['means'],gmm_params['covars'],gmm_params['weights']
    means = (means[0].item(),means[1].item()) #unwrap the two component means to plain scalars
    ax.set_title(f'{fit_time.year}')

    #"lower"/"upper" are decided by component mean, not by position in the params
    lower_idx,upper_idx = means.index(min(means)),means.index(max(means))

    x = np.linspace(10,30,1000)
    #each curve is weight * normal pdf; covars are treated as variances, hence the sqrt
    pdf_lower = weights[lower_idx]*stats.norm.pdf(x,means[lower_idx],np.sqrt(covars[lower_idx]))
    pdf_upper = weights[upper_idx]*stats.norm.pdf(x,means[upper_idx],np.sqrt(covars[upper_idx]))
    ax.plot(x,pdf_lower.squeeze(),label='lower dist',c='blue')
    ax.plot(x,pdf_upper.squeeze(),label='upper dist',c='red')

    ax.grid();ax.legend();ax.set_ylim(0,0.5)

#panels beyond the number of fitted windows stay empty, so hide them
for ax in axs[len(fitted_gmm_params):]:
    ax.set_visible(False)

fig.tight_layout()


#visualise fitted and extrapolated params
#three panels, left to right: component means, variances, weights
fig,(ax1,ax2,ax3) = plt.subplots(ncols=3)
fit_times = list(fitted_gmm_params.keys())
predict_times = list(predicted_gmm_params.keys())
low_color = 'b'; high_color = 'r'
fit_marker = 'o'; predict_marker = 'x'

#plot fitted params (o's)
low_idx = [fitted_gmm_params[t]['means'].argmin() for t in fit_times] #get the idx for the lower dist (by mean)
high_idx = [fitted_gmm_params[t]['means'].argmax() for t in fit_times] #get the idx for the upper dist (by mean)

lower_dist_means = [fitted_gmm_params[t]['means'][idx] for t,idx in zip(fit_times,low_idx)]
upper_dist_means = [fitted_gmm_params[t]['means'][idx] for t,idx in zip(fit_times,high_idx)]

lower_dist_vars = [fitted_gmm_params[t]['covars'][idx] for t,idx in zip(fit_times,low_idx)]
upper_dist_vars = [fitted_gmm_params[t]['covars'][idx] for t,idx in zip(fit_times,high_idx)]

lower_dist_weight = [fitted_gmm_params[t]['weights'][idx] for t,idx in zip(fit_times,low_idx)]
upper_dist_weight = [fitted_gmm_params[t]['weights'][idx] for t,idx in zip(fit_times,high_idx)]

ax1.scatter(fit_times,lower_dist_means,c=low_color,marker=fit_marker,label='lower dist')
ax1.scatter(fit_times,upper_dist_means,c=high_color,marker=fit_marker,label='upper dist')

ax2.scatter(fit_times,lower_dist_vars,c=low_color,marker=fit_marker,label='fitted lower dist')
ax2.scatter(fit_times,upper_dist_vars,c=high_color,marker=fit_marker,label='fitted upper dist')

ax3.scatter(fit_times,lower_dist_weight,c=low_color,marker=fit_marker,label='lower dist')
ax3.scatter(fit_times,upper_dist_weight,c=high_color,marker=fit_marker,label='upper dist')


#plot predicted params (x's)
# NOTE(review): each predicted entry is read positionally as [0]=means, [1]=variances,
# [2]=weights, with component 0 taken as the lower dist and component 1 as the upper -
# confirm against extrapolate_distributions (the fitted params above are sorted by mean instead).
low_dist_means = [predicted_gmm_params[t][0][0] for t in predict_times]
high_dist_means = [predicted_gmm_params[t][0][1] for t in predict_times]

low_dist_vars = [predicted_gmm_params[t][1][0] for t in predict_times]
high_dist_vars = [predicted_gmm_params[t][1][1] for t in predict_times]

low_dist_weight = [predicted_gmm_params[t][2][0] for t in predict_times]
high_dist_weight = [predicted_gmm_params[t][2][1] for t in predict_times]

ax1.scatter(predict_times,low_dist_means,c=low_color,marker=predict_marker,label='lower dist')
ax1.scatter(predict_times,high_dist_means,c=high_color,marker=predict_marker,label='upper dist')

ax2.scatter(predict_times,low_dist_vars,c=low_color,marker=predict_marker,label='predicted lower dist')
ax2.scatter(predict_times,high_dist_vars,c=high_color,marker=predict_marker,label='predicted upper dist')

ax3.scatter(predict_times,low_dist_weight,c=low_color,marker=predict_marker,label='lower dist')
ax3.scatter(predict_times,high_dist_weight,c=high_color,marker=predict_marker,label='upper dist')


#say which parameter each panel shows
ax1.set_ylabel('component mean')
ax2.set_ylabel('component variance')
ax3.set_ylabel('component weight')

for ax in (ax1,ax2,ax3): 
    ax.grid()

#a single legend (on the middle panel) covers the colour/marker coding of all three
ax2.legend(loc=(0.5,0.6))


fig.set_size_inches(12,4)
fig.suptitle('Fitted and Predicted GMM Parameters')
fig.tight_layout()



In [None]:
## Explore the 2023 fit
## 2023 captures a nice skew
## also gmm interpretation is nice - two distributions evolving together 

##note - the covars/weights below were picked by hand and fit better than the automatic method used so far
import copy

window_2023 = pd.Timestamp('2023-07-01')

#deep copy, so the hand-tuning below does not modify the entries of fitted_gmm_params
gmm_params_2023 = copy.deepcopy(fitted_gmm_params[window_2023])
print(gmm_params_2023)
means = gmm_params_2023['means']
covars = gmm_params_2023['covars']
weights = gmm_params_2023['weights']

#hand-picked overrides (covars 1.7,3 were tried earlier)
covars[0] = 0.5
covars[1] = 3
weights[0] = 0.8
weights[1] = 1-weights[0] #the two weights sum to 1

x,fitted_pdf = gmm_density_plot(params=(means,covars,weights),x_min=10,x_max=30)
log_compute_data = analysis.working_df[analysis.window_filter(window_2023)]['log_compute']

fig,ax=plt.subplots()
ax.plot(x.squeeze(),fitted_pdf.squeeze(),linewidth=2,label='fitted gmm',c='tab:blue')
#sns.kdeplot(log_compute_data,ax=ax,linewidth=2,alpha=0.7,label='empirical kde',c='tab:red')
ax.legend()
ax.grid()
ax.set_ylim([0,0.4])




Handfitting the gmm, we see:
- we can get a better fit by hand than our automatic process did
  - So imperfect fit might not be a result of inflexibility of distribution, but suboptimal fitting process instead.
  - Intuitively, what does our fit mean?
    - We've got two model families, call them lower and upper for now.
    - In 2023, the upper distribution was dominant, with a weight of 80%. That is, we should sample 80% of the time from the upper distribution.
    - In 2023, the upper distribution was significantly more concentrated, with a variance of ~55% of the lower distribution's variance. So the lower distribution was more spread out over models.
    - This could fit the idea of a recent class of models that are highly 'compute specialised' -> all being trained in the ~1e21 -> 1e25 FLOP range (?)
  - Fitted parameter trends
    - Around 2021 we see the weights of the upper distribution pick up, a change in the trend from pre-2020 where the weights were similar, and favouring the lower distribution
    - Again - around 2021, the upper distribution begins to clearly dominate.
    - An increase in the lower dist variance contributes to non-bumpy skew in 2023 distribution. This means we don't see the 'two clear modes', just one instead.
      - So *whilst* pre-2023 distributions show a clear double-bump, if we want to achieve a single-bump + minor skew, we set a low variance for the lower-weighted distribution. 
    - 2023 surprises me for the variances and weights
      - variances - both *jump* relative to pre-2023.
      - Weights - again, jumps a noticeable amount.
    - AGAIN - let's note, these may not be the best fits. E.g: see our better fit for the 2023 distribution.







- What are the implications of this on future distributions?
  - Both classes of models will see mean compute rise
  - I can see "upper class" spread out a bit, as foundation model compute range becomes more broad
  - I can see upper class take increasing weight