In [8]:
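# Uncomment the lines below to install the packages this notebook was set up with.
# Note: `shaprpy` (imported in the next cell) is not in this list and has to be
# installed separately.
# %pip install scikit-learn
# %pip install pandas
# %pip install statsmodels
# %pip install xgboost
# %pip install rpy2
# %pip install shap
# %pip install matplotlib
# %pip install keras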

In [9]:
from sklearn.ensemble import RandomForestRegressor
from shaprpy.datasets import load_california_housing

import pandas as pd
import xgboost as xgb
from statsmodels.datasets import get_rdataset
from shaprpy import explain

In [10]:
def xgboost_model(x_train, y_train, from_file, filename=None):
    """Return an xgboost model, either restored from disk or freshly fitted.

    Parameters
    ----------
    x_train
        Feature table. It is fitted on when training; when loading, its
        ``columns`` supply the feature names assigned to the booster.
    y_train
        Training response. Only used when a new model is fitted.
    from_file
        Truthy to load a saved model instead of training one.
    filename
        Path of the saved model. A falsy value (``None``, ``''``) falls back
        to ``'../xgboost_model'``. Ignored when ``from_file`` is falsy.

    Returns
    -------
    An ``xgb.Booster`` when ``from_file`` is truthy, otherwise a fitted
    ``xgb.XGBRegressor`` (``n_estimators=20``, ``verbosity=0``). Note that the
    two branches return different model types.
    """
    if not from_file:
        regressor = xgb.XGBRegressor(n_estimators=20, verbosity=0)
        regressor.fit(x_train, y_train)
        return regressor

    model_path = filename or '../xgboost_model'
    booster = xgb.Booster()
    booster.load_model(model_path)
    # Feature names are taken from the training columns after loading.
    booster.feature_names = x_train.columns.tolist()
    return booster

### Verifying against R results

In [11]:
# Fetch the `airquality` dataset from the R `datasets` package and drop every
# row that contains a missing value.
airquality = get_rdataset("airquality", "datasets").data
data = airquality.dropna()

# Response column and the predictor columns selected from the cleaned frame.
y_var = "Ozone"
x_var = ["Solar.R", "Wind", "Temp", "Month"]

x_train, y_train = data[x_var], data[y_var]

In [12]:
# from_file=True with no filename: xgboost_model falls back to '../xgboost_model',
# so that file must exist relative to the working directory.
model = xgboost_model(x_train, y_train, from_file=True)

In [15]:
# Call explain() with sage=True, passing the observed response alongside the
# training data. The training data is also what gets explained
# (x_explain is x_train).
explanation_SAGE = explain(
    model=model,
    x_train=x_train,
    x_explain=x_train,
    response=y_train,
    sage=True,
    approach='independence',
    phi0=y_train.mean().item(),  # mean of the response as a plain Python scalar
    seed=1,
    verbose=None,
)

print(explanation_SAGE["shapley_values_est"])

   explain_id         none    Solar.R        Wind        Temp     Month
1         NaN -1097.314474  146.53804  442.288023  498.620822  3.502053


In [14]:
# Same model, data, approach and seed as the SAGE call, but without the `sage`
# and `response` arguments. The training rows themselves are explained.
explanation_SHAP = explain(
    model=model,
    x_train=x_train,
    x_explain=x_train,
    approach='independence',
    phi0=float(y_train.mean().item()),  # mean of the response as a Python float
    seed=1,
    verbose=None,
)

print(explanation_SHAP["shapley_values_est"])

     explain_id       none   Solar.R       Wind       Temp     Month
1             1  42.099098  3.030332   4.310215 -12.861986  3.110408
2             2  42.099098 -3.358880   3.691014 -11.697075  2.388401
3             3  42.099099 -2.793629  -9.692032 -18.139897  1.190093
4             4  42.099099 -0.513363  -9.576654 -14.383259  0.620829
5             5  42.099101 -1.579514  -2.804843 -15.939206  1.133936
..          ...        ...       ...        ...        ...       ...
107         107  42.099097 -7.574907  -9.163595 -11.039575 -0.891674
108         108  42.099099  1.318196   2.753162 -14.009171 -2.367950
109         109  42.099099 -0.217088 -10.508347 -15.175035 -1.497503
110         110  42.099100 -2.635532   1.077513 -18.545079 -2.513320
111         111  42.099100  4.344556 -11.707045 -13.445780 -1.158783

[111 rows x 6 columns]
