In [1]:
#This script takes every simulation with all points and calculates rmse and mae in each case

In [1]:
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import seaborn as sns
import matplotlib.gridspec as gridspec
import ast
import sys
sys.path.append('machine-scientist/')
sys.path.append('machine-scientist/Prior/')
from mcmc import *
from parallel import *
from fit_prior import read_prior_par
from sklearn.metrics import mean_squared_error
from sklearn.metrics import root_mean_squared_error
from sklearn.metrics import mean_absolute_error

In [2]:
def clean_index(dataframe):
    dataframe.set_index('Unnamed: 0', inplace=True)
    dataframe.index.name = None
    dataframe= dataframe.reset_index(drop=True)
    return dataframe

def add_bms_pred(dataframe, bms_trace, number_param):
    VARS = ['x1',]
    x = dn[[c for c in VARS]].copy()
    y=dataframe.noise

    if number_param==10:
        prior_par = read_prior_par('machine-scientist/Prior/final_prior_param_sq.named_equations.nv1.np10.2017-10-18 18:07:35.089658.dat')
    elif number_param==20:
        prior_par = read_prior_par('machine-scientist/Prior/final_prior_param_sq.named_equations.nv1.np20.maxs200.2024-05-10 162907.551306.dat')

    #mdl model
    minrow = bms_trace[bms_trace.H == min(bms_trace.H)].iloc[0]
    minH, minexpr, minparvals = minrow.H, minrow.expr, ast.literal_eval(minrow.parvals)

    t = Tree(
        variables=list(x.columns),
        parameters=['a%d' % i for i in range(number_param)],
        x=x, y=y,
        prior_par=prior_par,
        max_size=200,
        from_string=minexpr,
    )

    t.set_par_values(deepcopy(minparvals))

    dplot = deepcopy(dn)
    dplot['ybms'] = t.predict(x)

    return dplot
    

In [3]:
#Read NN and BMS data
functions=['leaky_ReLU', 'tanh'] #tanh, leaky_ReLU
realizations=2
N=9

sigmas=[0.0, 0.02, 0.04,0.06, 0.08, 0.10, 0.12, 0.14, 0.16, 0.18, 0.20]
resolution='4e-3x' #0.5x, 1x, 2x, 4e-3x
resolutions={'0.5x':'0.1', '1x':'0.05' , '2x': '0.025' , '4e-3x':'0.004' }

runid=0
NPAR=10 #10, 20
steps=50000


rmse_nn_train=[];rmse_nn_test=[]
rmse_mdl_train=[];rmse_mdl_test=[]

mae_nn_train=[];mae_nn_test=[]
mae_mdl_train=[];mae_mdl_test=[]

n_index=[];r_index=[];sigma_index=[];function_index=[]

#Put mae and rmse of each simulation (on nn and bms) in a dataframe
for function in functions:

    for sigma in sigmas:

        for r in range(realizations+1):
            
            file_model='NN_no_overfit_' + function + '_sigma_' + str(sigma) + '_r_' + str(r) + '.csv'
            model_d='../data/' + resolution + '_resolution/trained_nns/' + file_model
            d=pd.read_csv(model_d)

            n_points=int(len(d.index)/10)
            train_fraction=3/4;train_size=int(n_points*train_fraction)
            print(train_size)

            for n in range(N+1):
                n_index.append(n);r_index.append(r);sigma_index.append(sigma);function_index.append(function)
            
                dn=d[d['rep']==n]
                dn=clean_index(dn)

                #Read BMS data
                filename='BMS_'+function+'_n_'+str(n)+'_sigma_'+str(sigma)+ '_r_' + str(r) + '_res_' + resolutions[resolution] + '_trace_'\
                +str(steps)+'_prior_'+str(NPAR)+ '.csv'

                print(function, sigma, n, r)
                
                try:
                    trace=pd.read_csv('../data/MSTraces/' + resolution + '_resolution/' + filename, sep=';', header=None, names=['t', 'H', 'expr', 'parvals', 'kk1', 'kk2','kk3'])
                    dplot=add_bms_pred(dn, trace, NPAR)
                except FileNotFoundError:
                    dplot = deepcopy(dn) #If no bms errors available, fill the dataframe with zeros
                    dplot['ybms'] = [0] * len(dplot)
                    #display(dplot)
                    
                

                #Errors
            
                #nns
                rmse_nn_train_i=root_mean_squared_error(dplot.loc[:train_size-1]['ymodel'],dplot.loc[:train_size -1]['y'])
                rmse_nn_train.append(rmse_nn_train_i)
            
                rmse_nn_test_i=root_mean_squared_error(dplot.loc[train_size-1:]['ymodel'],dplot.loc[train_size -1:]['y'])
                rmse_nn_test.append(rmse_nn_test_i)

                mae_nn_train_i=mean_absolute_error(dplot.loc[:train_size-1]['ymodel'],dplot.loc[:train_size -1]['y'])
                mae_nn_train.append(mae_nn_train_i)
            
                mae_nn_test_i=mean_absolute_error(dplot.loc[train_size-1:]['ymodel'],dplot.loc[train_size -1:]['y'])
                mae_nn_test.append(mae_nn_test_i)


                try:
                    rmse_mdl_train_i=root_mean_squared_error(dplot.loc[:train_size-1]['ybms'],dn.loc[:train_size-1]['y'])
                except ValueError:
                    rmse_mdl_train_i=np.inf
                rmse_mdl_train.append(rmse_mdl_train_i)

                try:
                    rmse_mdl_test_i=root_mean_squared_error(dplot.loc[train_size-1:]['ybms'],dn.loc[train_size-1:]['y'])
                except (ValueError, RuntimeWarning) as e:
                    rmse_mdl_test_i=np.inf
                
                rmse_mdl_test.append(rmse_mdl_test_i)

                try:
                    mae_mdl_train_i=mean_absolute_error(dplot.loc[:train_size-1]['ybms'],dplot.loc[:train_size -1]['y'])
                except ValueError:
                    mae_mdl_train_i=np.inf
                mae_mdl_train.append(mae_mdl_train_i)

                try:
                    mae_mdl_test_i=mean_absolute_error(dplot.loc[train_size-1:]['ybms'],dplot.loc[train_size -1:]['y'])
                except ValueError:
                    mae_mdl_test_i=np.inf
                
                mae_mdl_test.append(mae_mdl_test_i)

errors_df=pd.DataFrame({'sigma':sigma_index, 'function':function_index, 'mae_nn_train':mae_nn_train, 'mae_nn_test':mae_nn_test, 'mae_mdl_train':mae_mdl_train, 
                        'mae_mdl_test':mae_mdl_test, 'rmse_nn_train':rmse_nn_train, 'rmse_nn_test': rmse_nn_test, 
                        'rmse_mdl_train':rmse_mdl_train, 'rmse_mdl_test': rmse_mdl_test, 'n':n_index, 'r': r_index})
errors_df.to_csv('../data/all_errors_' + resolution + '.csv')
display(errors_df)

750
leaky_ReLU 0.0 0 0
leaky_ReLU 0.0 1 0
leaky_ReLU 0.0 2 0
leaky_ReLU 0.0 3 0
leaky_ReLU 0.0 4 0
leaky_ReLU 0.0 5 0
leaky_ReLU 0.0 6 0
leaky_ReLU 0.0 7 0
leaky_ReLU 0.0 8 0
leaky_ReLU 0.0 9 0
750
leaky_ReLU 0.0 0 1
leaky_ReLU 0.0 1 1
leaky_ReLU 0.0 2 1
leaky_ReLU 0.0 3 1
leaky_ReLU 0.0 4 1
leaky_ReLU 0.0 5 1
leaky_ReLU 0.0 6 1
leaky_ReLU 0.0 7 1
leaky_ReLU 0.0 8 1
leaky_ReLU 0.0 9 1
750
leaky_ReLU 0.0 0 2
leaky_ReLU 0.0 1 2
leaky_ReLU 0.0 2 2
leaky_ReLU 0.0 3 2
leaky_ReLU 0.0 4 2
leaky_ReLU 0.0 5 2
leaky_ReLU 0.0 6 2
leaky_ReLU 0.0 7 2
leaky_ReLU 0.0 8 2
leaky_ReLU 0.0 9 2
750
leaky_ReLU 0.02 0 0
leaky_ReLU 0.02 1 0
leaky_ReLU 0.02 2 0
leaky_ReLU 0.02 3 0
leaky_ReLU 0.02 4 0
leaky_ReLU 0.02 5 0
leaky_ReLU 0.02 6 0
leaky_ReLU 0.02 7 0
leaky_ReLU 0.02 8 0
leaky_ReLU 0.02 9 0
750
leaky_ReLU 0.02 0 1
leaky_ReLU 0.02 1 1
leaky_ReLU 0.02 2 1
leaky_ReLU 0.02 3 1
leaky_ReLU 0.02 4 1
leaky_ReLU 0.02 5 1
leaky_ReLU 0.02 6 1
leaky_ReLU 0.02 7 1
leaky_ReLU 0.02 8 1
leaky_ReLU 0.02 9 1
750
leaky_

Unnamed: 0,sigma,function,mae_nn_train,mae_nn_test,mae_mdl_train,mae_mdl_test,rmse_nn_train,rmse_nn_test,rmse_mdl_train,rmse_mdl_test,n,r
0,0.0,leaky_ReLU,0.001531,0.040837,0.000496,0.021230,0.001936,0.049012,0.000634,0.029018,0,0
1,0.0,leaky_ReLU,0.002255,0.155396,0.001845,1.389800,0.003058,0.164442,0.002424,2.676518,1,0
2,0.0,leaky_ReLU,0.003467,0.036872,0.002618,0.068700,0.005003,0.043094,0.004424,0.080606,2,0
3,0.0,leaky_ReLU,0.001850,0.056598,0.000623,0.298642,0.002887,0.060429,0.000860,0.347085,3,0
4,0.0,leaky_ReLU,0.001471,0.235353,0.001347,inf,0.002086,0.279269,0.001780,inf,4,0
...,...,...,...,...,...,...,...,...,...,...,...,...
655,0.2,tanh,0.027551,0.067124,0.024260,0.117640,0.034761,0.083256,0.037562,0.123875,5,2
656,0.2,tanh,0.028916,0.120084,0.026511,0.384823,0.037839,0.138045,0.034651,0.451462,6,2
657,0.2,tanh,0.035341,0.063014,0.012995,0.077759,0.042194,0.078681,0.019436,0.091211,7,2
658,0.2,tanh,0.026682,0.505530,0.019877,0.094024,0.035372,0.565568,0.022791,0.096903,8,2
