In [1]:
import re

from symbolic_pursuit.models import SymbolicRegressor  # our symbolic model class
from sklearn.metrics import mean_squared_error # we are going to assess the quality of the model based on the generalization MSE
from sympy import init_printing # We use sympy to display mathematical expressions
import numpy as np # we use numpy to deal with arrays
import lime
import lime.lime_tabular
init_printing()  # turn on sympy's pretty printing so symbolic expressions render nicely in cell outputs

In [2]:
def f(X):
    """Black-box function to be explained: f(x) = x_1 + 2*x_2 + 3*x_3.

    X is indexed column-wise (X[:, j]), so it is expected to be a 2-D array of
    shape (n_samples, >= 3); one output value is returned per row.
    """
    col_1, col_2, col_3 = X[:, 0], X[:, 1], X[:, 2]
    return col_1 + 2 * col_2 + 3 * col_3

# Number of input features that f reads (x_1, x_2, x_3).
dim_X = 3

In [3]:
# Seed NumPy's global generator once, before the first random draw, so that a
# fresh "Restart & Run All" reproduces X and X_test (and presumably LIME's
# sampling and the symbolic fit too, if they draw from the global NumPy RNG —
# confirm).
SEED = 42
np.random.seed(SEED)

n_pts = 100  # number of points the explainers are fitted on
X = np.random.uniform(0, 1, (n_pts, dim_X))

In [4]:
# Held-out points at which the LIME and symbolic-model importances are compared.
n_test = 10
X_test = np.random.uniform(low=0, high=1, size=(n_test, dim_X))

In [5]:
def order_weights(exp_list, n_features=None):
    """Reorder a LIME ``Explanation.as_list()`` into a list indexed by feature.

    Each entry of ``exp_list`` is a ``(description, weight)`` pair whose
    description contains a 1-based feature name ``x_<k>``. The name may be
    surrounded by other text, e.g. a threshold such as ``"0.25 < x_2 <= 0.50"``.
    The weight is stored at position ``k - 1`` of the result.

    Parameters
    ----------
    exp_list : iterable of (str, float)
        Pairs as returned by LIME's ``Explanation.as_list()``.
    n_features : int, optional
        Length of the returned list. Defaults to the notebook-level ``dim_X``.

    Returns
    -------
    list
        ``n_features`` weights; features not mentioned in ``exp_list`` get 0.

    Raises
    ------
    ValueError
        If a description contains no ``x_<k>`` name, or ``k`` is not in
        ``1..n_features``.
    """
    if n_features is None:
        n_features = dim_X
    ordered_weights = [0] * n_features
    for description, weight in exp_list:
        # Capture the whole index: the old "[0]" parsing kept only the first
        # digit, so "x_10" was mistaken for "x_1".
        match = re.search(r'x_(\d+)', description)
        if match is None:
            raise ValueError(
                "no feature name 'x_<k>' in LIME description {!r}".format(description))
        feature_id = int(match.group(1))
        # Guard explicitly: feature_id == 0 would otherwise silently write to slot -1.
        if not 1 <= feature_id <= n_features:
            raise ValueError(
                "feature id {} out of range 1..{}".format(feature_id, n_features))
        ordered_weights[feature_id - 1] = weight
    return ordered_weights

In [6]:
# Explain f with LIME at every test point and store each explanation's weights
# in x_1..x_dim_X order.
# NOTE(review): LIME's default discretize_continuous=True presumably makes these
# weights coefficients of binned indicator features (e.g. "x_1 <= 0.25") rather
# than slopes of f, which would explain the negative values printed below;
# confirm before comparing them with the symbolic model's importances.
lime_weight_list = []
explainer = lime.lime_tabular.LimeTabularExplainer(
    X,
    feature_names=["x_" + str(k) for k in range(1, dim_X + 1)],
    class_names=['f'],
    verbose=True,  # prints Intercept / Prediction_local / Right for every instance
    mode='regression',
)

for i in range(n_test):
    exp = explainer.explain_instance(
        X_test[i],
        f,
        num_features=dim_X,
    )
    lime_weight_list.append(
        order_weights(exp.as_list())
    )

print(lime_weight_list)

Intercept 3.2798211806407935
Prediction_local [1.65404743]
Right: 1.8251336731013608
Intercept 2.614071907176627
Prediction_local [3.73197779]
Right: 3.915166083674069
Intercept 3.531281560606753
Prediction_local [0.96230586]
Right: 1.0680580888815143
Intercept 3.0625264408852044
Prediction_local [2.37679817]
Right: 2.5153466906319384
Intercept 3.2086059357743575
Prediction_local [1.84577565]
Right: 1.9418805655150169
Intercept 2.9638014720699735
Prediction_local [2.61361529]
Right: 2.345125005682301
Intercept 2.4382167227734373
Prediction_local [4.3320286]
Right: 3.862974480778787
Intercept 2.5028096225875585
Prediction_local [4.03790463]
Right: 4.384884034895277
Intercept 2.5931288423211623
Prediction_local [3.74344057]
Right: 4.008337975873024
Intercept 2.6118316446665713
Prediction_local [3.71640377]
Right: 3.675022879439817
[[-0.4649202728998083, 0.3027631241814751, -1.4636166047972394], [-0.4745913065092781, 1.0894478160279923, 0.5030493759701495], [-0.16628404660065857, -0.95340

In [7]:
# Fit the symbolic-pursuit surrogate. fit() receives the black box f itself
# (not precomputed labels) together with the training points X; the log below
# shows the per-term optimisation and the final closed-form expression.
symbolic_model = SymbolicRegressor()
symbolic_model.fit(f, X)

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Now working on term number  1 .
Now working on hyperparameter tree number  1 .
         Current function value: 1.124145
         Iterations: 13
         Function evaluations: 1172
         Gradient evaluations: 145
Now working on hyperparameter tree number  2 .
Optimization terminated successfully.
         Current function value: 0.000000
         Iterations: 79
         Function evaluations: 1800
         Gradient evaluations: 180
The algorithm stopped because the desired precision was achieved.
The tree number  2  was selected as the best.
Backfitting complete.
The current model has the following expression:  6.48080192147146*[ReLU(P1)]**1.00001818129391*hyper((5.22588880302877e-6,), (1.46065927793992, 1.46065927793992), 1.0/[ReLU(P1)])
The current value of the loss is:  2.0027648368941944e-10 .
----------------------------------------------------------------------------------------

In [8]:
# Feature importances of the fitted symbolic model at each test point.
# Presumably ordered by input column (x_1, x_2, x_3), as the ~[1, 2, 3] output suggests.
symbolic_weight_list = [
    symbolic_model.get_feature_importance(X_test[k])
    for k in range(n_test)
]

In [9]:
# For the linear f defined above, the printed importances come out close to its coefficients [1, 2, 3].
print(symbolic_weight_list)

[[0.999982498536494, 1.99998695088047, 3.00001826682352], [0.999999496915246, 2.00002094801116, 3.00006926316279], [0.999961904584880, 1.99994576252512, 2.99995648351119], [0.999990329610961, 2.00000261320133, 3.00004176060114], [0.999984146340451, 1.99999024652456, 3.00002321035201], [0.999988743092238, 1.99999944012905, 3.00003700093270], [0.999999233084327, 2.00002042034353, 3.00006847165136], [1.00000170250490, 2.00002535923889, 3.00007588008785], [0.999999957797835, 2.00002186978646, 3.00007064584318], [0.999998246853151, 2.00001844785953, 3.00006551288804]]
