In [8]:
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import seaborn as sns
import matplotlib.gridspec as gridspec
import ast
import sys
sys.path.append('machine-scientist/')
sys.path.append('machine-scientist/Prior/')
from mcmc import *
from parallel import *
from fit_prior import read_prior_par
from sklearn.metrics import mean_squared_error
from sklearn.metrics import root_mean_squared_error
from sklearn.metrics import mean_absolute_error

In [9]:
def clean_index(dataframe):
    dataframe.set_index('Unnamed: 0', inplace=True)
    dataframe.index.name = None
    dataframe= dataframe.reset_index(drop=True)
    return dataframe

def add_bms_pred(dataframe, bms_trace, number_param):
    VARS = ['x1',]
    x = dn[[c for c in VARS]].copy()
    y=dataframe.noise

    if number_param==10:
        prior_par = read_prior_par('machine-scientist/Prior/final_prior_param_sq.named_equations.nv1.np10.2017-10-18 18:07:35.089658.dat')
    elif number_param==20:
        prior_par = read_prior_par('machine-scientist/Prior/final_prior_param_sq.named_equations.nv1.np20.maxs200.2024-05-10 162907.551306.dat')

    #mdl model
    minrow = bms_trace[bms_trace.H == min(bms_trace.H)].iloc[0]
    minH, minexpr, minparvals = minrow.H, minrow.expr, ast.literal_eval(minrow.parvals)

    t = Tree(
        variables=list(x.columns),
        parameters=['a%d' % i for i in range(number_param)],
        x=x, y=y,
        prior_par=prior_par,
        max_size=200,
        from_string=minexpr,
    )

    t.set_par_values(deepcopy(minparvals))

    dplot = deepcopy(dn)
    dplot['ybms'] = t.predict(x)

    return dplot
    
    

In [30]:
#Read NN and BMS data
#n=0;
function='leaky_ReLU' #tanh, leaky_ReLU
realizations=2
N=9
#sigmas=[sigma_y for sigma_y in np.arange(0,0.2,0.02)]
sigmas=[0.06, 0.08, 0.10, 0.12, 0.14, 0.16, 0.18, 0.20]


runid=0
NPAR=10 #10, 20
steps=50000
train_size=60

rmse_nn_train=[];rmse_nn_test=[]
rmse_mdl_train=[];rmse_mdl_test=[]

mae_nn_train=[];mae_nn_test=[]
mae_mdl_train=[];mae_mdl_test=[]

n_index=[]
r_index=[]
sigma_index=[]

for sigma in sigmas:
    for r in range(realizations+1):
        #Read NN data
        if sigma==0.1:
            file_model='NN_no_overfit_' + function + '_sigma_' + '0.10' + '_r_' + str(r) + '.csv'
        else:
        
            file_model='NN_no_overfit_' + function + '_sigma_' + str(sigma) + '_r_' + str(r) + '.csv'
        model_d='../data/trained_nns/' + file_model
        d=pd.read_csv(model_d)

        for n in range(N+1):
            n_index.append(n)
            r_index.append(r)
            sigma_index.append(sigma)
            dn=d[d['rep']==n]
            dn=clean_index(dn)

            #Read BMS data
            if sigma==0.1:
                filename='BMS_'+function+'_n_'+str(n)+'_sigma_'+'0.10'+ '_r_' + str(r) + '_trace_'+str(steps)+'_prior_'+str(NPAR)+ '.csv'
            else:
                filename='BMS_'+function+'_n_'+str(n)+'_sigma_'+str(sigma)+ '_r_' + str(r) + '_trace_'+str(steps)+'_prior_'+str(NPAR)+ '.csv'
        
        #filename='BMS_'+function+'_n_'+str(n)+'_sigma_'+str(sigma)+'_trace_'+str(steps)+'_prior_'+str(NPAR)+'_id_'+str(runid)+'.csv'
            trace=pd.read_csv('../data/MSTraces/' + filename, sep=';', header=None, names=['t', 'H', 'expr', 'parvals', 'kk1', 'kk2','kk3'])
            dplot=add_bms_pred(dn, trace, NPAR)

            #Errors
            rmse_nn_train_i=root_mean_squared_error(dplot.loc[:train_size-1]['ymodel'],dplot.loc[:train_size -1]['y'])
            rmse_nn_train.append(rmse_nn_train_i)
            rmse_nn_test_i=root_mean_squared_error(dplot.loc[train_size-1:]['ymodel'],dplot.loc[train_size -1:]['y'])
            rmse_nn_test.append(rmse_nn_test_i)

            mae_nn_train_i=mean_absolute_error(dplot.loc[:train_size-1]['ymodel'],dplot.loc[:train_size -1]['y'])
            mae_nn_train.append(mae_nn_train_i)
            mae_nn_test_i=mean_absolute_error(dplot.loc[train_size-1:]['ymodel'],dplot.loc[train_size -1:]['y'])
            mae_nn_test.append(mae_nn_test_i)
    
            rmse_mdl_i=mean_squared_error(dplot.ybms,dn.y)
            rmse_mdl_train_i=root_mean_squared_error(dplot.loc[:train_size-1]['ybms'],dn.loc[:train_size-1]['y'])
            rmse_mdl_train.append(rmse_mdl_train_i)
            rmse_mdl_test_i=root_mean_squared_error(dplot.loc[train_size-1:]['ybms'],dn.loc[train_size-1:]['y'])
            rmse_mdl_test.append(rmse_mdl_test_i)

            mae_mdl_train_i=mean_absolute_error(dplot.loc[:train_size-1]['ybms'],dplot.loc[:train_size -1]['y'])
            mae_mdl_train.append(mae_mdl_train_i)
            mae_mdl_test_i=mean_absolute_error(dplot.loc[train_size-1:]['ybms'],dplot.loc[train_size -1:]['y'])
            mae_mdl_test.append(mae_mdl_test_i)

errors_df=pd.DataFrame({'sigma':sigma_index, 'mae_nn_train':mae_nn_train, 'mae_nn_test':mae_nn_test, 'mae_mdl_train':mae_mdl_train, 
                        'mae_mdl_test':mae_mdl_test, 'rmse_nn_train':rmse_nn_train, 'rmse_nn_test': rmse_nn_test, 
                        'rmse_mdl_train':rmse_mdl_train, 'rmse_mdl_test': rmse_mdl_test, 'n':n_index, 'r': r_index})

In [31]:
display(errors_df)

display(errors_df[errors_df['sigma']==0.12])

mean_errors_df=errors_df.groupby(['sigma'],as_index=False)[['mae_nn_train', 'mae_nn_test', 'mae_mdl_train','mae_mdl_test',
                                                            'rmse_nn_train', 'rmse_nn_test', 'rmse_mdl_train','rmse_mdl_test']].mean()
display(mean_errors_df)

std_errors_df=errors_df.groupby(['sigma'],as_index=False)[['mae_nn_train', 'mae_nn_test', 'mae_mdl_train','mae_mdl_test',
                                                           'rmse_nn_train', 'rmse_nn_test', 'rmse_mdl_train','rmse_mdl_test']].std()

std_errors_df=std_errors_df.rename(columns={'mae_nn_train':'mae_nn_train_std', 'mae_nn_test':'mae_nn_test_std', 
                                            'mae_mdl_train':'mae_mdl_train_std','mae_mdl_test':'mae_mdl_test_std',
                                            'rmse_nn_train': 'rmse_nn_train_std', 'rmse_nn_test':'rmse_nn_test_std' , 
                                            'rmse_mdl_train':'rmse_mdl_train_std','rmse_mdl_test':'rmse_mdl_test_std'})

std_errors_df[['mae_nn_train_sdm', 'mae_nn_test_sdm', 'mae_mdl_train_sdm', 'mae_mdl_test_sdm',
               'rmse_nn_train_sdm', 'rmse_nn_test_sdm', 'rmse_mdl_train_sdm', 'rmse_mdl_test_sdm']]=\
std_errors_df[['mae_nn_train_std', 'mae_nn_test_std', 'mae_mdl_train_std', 'mae_mdl_test_std',
               'rmse_nn_train_std', 'rmse_nn_test_std', 'rmse_mdl_train_std', 'rmse_mdl_test_std']]/np.sqrt(N*(realizations+1))
display(std_errors_df)


errors_statistics_df=mean_errors_df.join(std_errors_df.set_index('sigma'), on='sigma')

display(errors_statistics_df)

#save error dataframes
errors_statistics_df.to_csv('../data/'+ 'errors_statistics' + str(function) + '.csv')

Unnamed: 0,sigma,mae_nn_train,mae_nn_test,mae_mdl_train,mae_mdl_test,rmse_nn_train,rmse_nn_test,rmse_mdl_train,rmse_mdl_test,n,r
0,0.06,0.018732,0.083284,0.022398,0.067949,0.024027,0.090971,0.028203,0.072539,0,0
1,0.06,0.011968,0.244179,0.024748,0.308734,0.015620,0.262822,0.030642,0.327105,1,0
2,0.06,0.026609,0.082431,0.040904,0.279303,0.033940,0.099486,0.045787,0.314320,2,0
3,0.06,0.029613,0.328843,0.010539,0.087249,0.053144,0.337817,0.015333,0.113934,3,0
4,0.06,0.014908,0.054975,0.017585,0.053622,0.017927,0.065063,0.021159,0.062952,4,0
...,...,...,...,...,...,...,...,...,...,...,...
235,0.20,0.074124,0.187386,0.022778,0.125844,0.111458,0.211543,0.024829,0.152622,5,2
236,0.20,0.117989,0.063681,0.126591,0.238867,0.142708,0.068140,0.142359,0.242984,6,2
237,0.20,0.088039,0.255640,0.023182,0.032335,0.124684,0.268635,0.029091,0.038541,7,2
238,0.20,0.127539,0.209643,0.098171,0.143706,0.168219,0.219296,0.114703,0.160794,8,2


Unnamed: 0,sigma,mae_nn_train,mae_nn_test,mae_mdl_train,mae_mdl_test,rmse_nn_train,rmse_nn_test,rmse_mdl_train,rmse_mdl_test,n,r
90,0.12,0.024105,0.064049,0.026392,0.059776,0.029364,0.069824,0.033791,0.063594,0,0
91,0.12,0.027819,0.215542,0.025033,0.318061,0.03447,0.235767,0.033035,0.335922,1,0
92,0.12,0.0333,0.129659,0.043621,0.264625,0.03809,0.159067,0.04918,0.301438,2,0
93,0.12,0.069332,0.468655,0.058642,0.410887,0.100397,0.47824,0.080193,0.420592,3,0
94,0.12,0.045996,0.170471,0.028619,0.119142,0.058408,0.185553,0.038973,0.121501,4,0
95,0.12,0.034485,0.179752,0.019287,0.167781,0.043346,0.202527,0.021804,0.188699,5,0
96,0.12,0.032384,0.060878,0.04783,0.052848,0.041411,0.069149,0.055138,0.05314,6,0
97,0.12,0.047963,0.224947,0.031236,0.05925,0.063136,0.23798,0.04044,0.078333,7,0
98,0.12,0.053189,0.266779,0.101127,0.162117,0.067625,0.278649,0.11359,0.177441,8,0
99,0.12,0.030545,0.066875,0.041367,0.038006,0.04134,0.079204,0.051557,0.038415,9,0


Unnamed: 0,sigma,mae_nn_train,mae_nn_test,mae_mdl_train,mae_mdl_test,rmse_nn_train,rmse_nn_test,rmse_mdl_train,rmse_mdl_test
0,0.06,0.022926,0.138846,0.021438,0.393304,0.030396,0.151697,0.026201,0.656841
1,0.08,0.030155,0.133946,0.028633,0.161897,0.03915,0.14599,0.036684,0.180181
2,0.1,0.043011,0.170464,0.039026,0.197208,0.055253,0.185085,0.048659,0.21471
3,0.12,0.090469,0.672278,0.041327,0.174468,0.16904,0.685238,0.050646,0.191976
4,0.14,0.060481,0.250801,0.053748,0.26882,0.082092,0.2637,0.064954,0.284991
5,0.16,0.074424,0.38387,0.056738,0.195993,0.109797,0.40031,0.068208,0.208397
6,0.18,0.072528,0.362814,0.054151,0.267464,0.101848,0.377574,0.067815,0.286291
7,0.2,0.077405,0.174527,0.070661,0.243512,0.102722,0.188062,0.083693,0.258529


Unnamed: 0,sigma,mae_nn_train_std,mae_nn_test_std,mae_mdl_train_std,mae_mdl_test_std,rmse_nn_train_std,rmse_nn_test_std,rmse_mdl_train_std,rmse_mdl_test_std,mae_nn_train_sdm,mae_nn_test_sdm,mae_mdl_train_sdm,mae_mdl_test_sdm,rmse_nn_train_sdm,rmse_nn_test_sdm,rmse_mdl_train_sdm,rmse_mdl_test_sdm
0,0.06,0.014663,0.086976,0.009031,1.350356,0.020391,0.08817,0.009934,2.679587,0.002822,0.016739,0.001738,0.259876,0.003924,0.016968,0.001912,0.515687
1,0.08,0.013408,0.077703,0.014847,0.128635,0.017032,0.079828,0.019986,0.135553,0.00258,0.014954,0.002857,0.024756,0.003278,0.015363,0.003846,0.026087
2,0.1,0.023206,0.135997,0.021902,0.141163,0.02713,0.13714,0.026403,0.148304,0.004466,0.026173,0.004215,0.027167,0.005221,0.026393,0.005081,0.028541
3,0.12,0.08318,0.725744,0.02343,0.144657,0.192881,0.7205,0.026478,0.167788,0.016008,0.13967,0.004509,0.027839,0.03712,0.13866,0.005096,0.032291
4,0.14,0.036071,0.283822,0.030355,0.130661,0.056115,0.284359,0.034811,0.134604,0.006942,0.054622,0.005842,0.025146,0.010799,0.054725,0.006699,0.025904
5,0.16,0.05095,0.450818,0.03649,0.131344,0.107624,0.451061,0.042382,0.135475,0.009805,0.08676,0.007022,0.025277,0.020712,0.086807,0.008156,0.026072
6,0.18,0.049512,0.43219,0.030026,0.148638,0.088546,0.431265,0.036281,0.154245,0.009529,0.083175,0.005779,0.028605,0.017041,0.082997,0.006982,0.029685
7,0.2,0.031661,0.164236,0.036684,0.126286,0.042746,0.171907,0.041819,0.1271,0.006093,0.031607,0.00706,0.024304,0.008227,0.033084,0.008048,0.02446


Unnamed: 0,sigma,mae_nn_train,mae_nn_test,mae_mdl_train,mae_mdl_test,rmse_nn_train,rmse_nn_test,rmse_mdl_train,rmse_mdl_test,mae_nn_train_std,...,rmse_mdl_train_std,rmse_mdl_test_std,mae_nn_train_sdm,mae_nn_test_sdm,mae_mdl_train_sdm,mae_mdl_test_sdm,rmse_nn_train_sdm,rmse_nn_test_sdm,rmse_mdl_train_sdm,rmse_mdl_test_sdm
0,0.06,0.022926,0.138846,0.021438,0.393304,0.030396,0.151697,0.026201,0.656841,0.014663,...,0.009934,2.679587,0.002822,0.016739,0.001738,0.259876,0.003924,0.016968,0.001912,0.515687
1,0.08,0.030155,0.133946,0.028633,0.161897,0.03915,0.14599,0.036684,0.180181,0.013408,...,0.019986,0.135553,0.00258,0.014954,0.002857,0.024756,0.003278,0.015363,0.003846,0.026087
2,0.1,0.043011,0.170464,0.039026,0.197208,0.055253,0.185085,0.048659,0.21471,0.023206,...,0.026403,0.148304,0.004466,0.026173,0.004215,0.027167,0.005221,0.026393,0.005081,0.028541
3,0.12,0.090469,0.672278,0.041327,0.174468,0.16904,0.685238,0.050646,0.191976,0.08318,...,0.026478,0.167788,0.016008,0.13967,0.004509,0.027839,0.03712,0.13866,0.005096,0.032291
4,0.14,0.060481,0.250801,0.053748,0.26882,0.082092,0.2637,0.064954,0.284991,0.036071,...,0.034811,0.134604,0.006942,0.054622,0.005842,0.025146,0.010799,0.054725,0.006699,0.025904
5,0.16,0.074424,0.38387,0.056738,0.195993,0.109797,0.40031,0.068208,0.208397,0.05095,...,0.042382,0.135475,0.009805,0.08676,0.007022,0.025277,0.020712,0.086807,0.008156,0.026072
6,0.18,0.072528,0.362814,0.054151,0.267464,0.101848,0.377574,0.067815,0.286291,0.049512,...,0.036281,0.154245,0.009529,0.083175,0.005779,0.028605,0.017041,0.082997,0.006982,0.029685
7,0.2,0.077405,0.174527,0.070661,0.243512,0.102722,0.188062,0.083693,0.258529,0.031661,...,0.041819,0.1271,0.006093,0.031607,0.00706,0.024304,0.008227,0.033084,0.008048,0.02446
