## Getting started
Load the needed libraries and set up the dataset.

In [1]:
import numpy as np
import pandas as pd
import os
import bellatrex as btrex
# Notebook-wide switch; presumably intended for the `plot_gui` argument of the
# plotting calls further down (it is not read anywhere in this cell).
PLOT_GUI = False

# Directory the notebook is being executed from.
root_folder = os.getcwd()

# Report the installed Bellatrex version first, then the working directory.
print(btrex.__version__)
print(root_folder)

0.2.2
c:\Users\u0135479\Documents\GitHub\Bellatrex


In [2]:
from sksurv.ensemble import RandomSurvivalForest
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import RandomForestRegressor

from sklearn.model_selection import train_test_split


from bellatrex.datasets import load_mtr_data, load_mlc_data
from bellatrex.datasets import load_survival_data, load_binary_data, load_regression_data
from bellatrex.utilities import get_auto_setup

# Pick exactly one loader; the active one selects the prediction task used in
# the rest of the notebook. Swap the comment marker to try another task.
X, y = load_binary_data(return_X_y=True)
# X, y = load_regression_data(return_X_y=True)
# X, y = load_survival_data(return_X_y=True)
# X, y = load_mlc_data(return_X_y=True)
# X, y = load_mtr_data(return_X_y=True)

# Hold out 30% of the samples for testing; the seed is fixed so the split is
# the same on every run.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=0,
)

Depending on the prediction task, we train a compatible Random Forest model.

In [3]:
from sksurv.ensemble import RandomSurvivalForest
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import RandomForestRegressor

from bellatrex import BellatrexExplain
from bellatrex.wrapper_class import EnsembleWrapper, tree_list_to_model, tree_to_dict

# Detect the prediction task from the target `y`. Not strictly necessary, but
# convenient while switching between many prediction tasks.
SETUP = get_auto_setup(y)
print('Detected prediction task \'SETUP\':', SETUP)

# Instantiate the original R(S)F estimator; works best with some pruning.
# Note: the task name is compared by equality / list membership. The previous
# `SETUP.lower() in 'survival'` was a substring test on a string, which would
# also accept values such as '' or 'viva'.
if SETUP.lower() == 'survival':
    clf = RandomSurvivalForest(n_estimators=100, min_samples_split=10,
                               n_jobs=-2, random_state=0)

elif SETUP.lower() in ['binary', 'multi-label']:
    clf = RandomForestClassifier(n_estimators=100, min_samples_split=5,
                                 n_jobs=-2, random_state=0)

elif SETUP.lower() in ['regression', 'multi-target']:
    clf = RandomForestRegressor(n_estimators=100, min_samples_split=5,
                                n_jobs=-2, random_state=0)

else:
    # Fail loudly here rather than with a NameError on `clf.fit` below (or,
    # worse, by silently reusing a stale `clf` left over from an earlier run).
    raise ValueError(f"Unsupported prediction task: {SETUP!r}")

clf.fit(X_train, y_train)
print('Model fitting complete.')

Detected prediction task 'SETUP': binary
Model fitting complete.


If the RF model to be explained was trained externally, it can be packed with `pack_trained_ensemble`.
The packed, pretrained model can later be loaded into Bellatrex by calling `EnsembleWrapper`, as in the following lines:

In [7]:
from bellatrex.wrapper_class import EnsembleWrapper, pack_trained_ensemble

# Pack the fitted ensemble `clf` into a lightweight dictionary; the printed
# type confirms that a plain `dict` comes back.
clf_packed = pack_trained_ensemble(clf)
print(type(clf_packed))

# Load the packed, pre-trained model through EnsembleWrapper(), which makes it
# compatible with Bellatrex. `clf2` is the wrapped counterpart of `clf`.
clf2 = EnsembleWrapper(clf_packed)
print(type(clf2))

<class 'dict'>
<class 'bellatrex.wrapper_class.EnsembleWrapper'>


## Building Explanations through a LocalLy AccuraTe Rule EXtractor:
Now we can fit Bellatrex on the training data and run it on a few test samples.
After fitting and tuning the explainer to a specific test instance, you can call:
- `plot_overview()` to get a representation of the tree learners and of the selected rules;
    a GUI is available for this plotting method (set the parameter `plot_gui=True`).
- `plot_visuals()` to visualise the selected rules in a more user-friendly way.

In [None]:
from bellatrex import BellatrexExplain

# Fit the Bellatrex explainer (not the forest itself) on the training data.
# `p_grid` lists the candidate numbers of clusters to try.
# The explainer accepts the trained model `clf`; the wrapped dictionary model
# `clf2` built in the previous cell can be passed in its place.
Btrex_fitted = BellatrexExplain(clf, set_up='auto',
                                p_grid={"n_clusters": [1, 2, 3]},
                                verbose=3).fit(X_train, y_train)

# Training-set predictions, passed to plot_visuals() as `preds_distr`.
# They do not depend on the test sample, so compute them once outside the loop.
# NOTE: `predict_proba(...)[:, 1]` takes the second probability column, which
# assumes the binary task loaded above; adapt this line for other tasks.
y_train_pred = clf.predict_proba(X_train)[:, 1]

N_TEST_SAMPLES = 2
for i in range(N_TEST_SAMPLES):

    print(f"Explaining sample i={i}")

    # Tune the explainer for the i-th test sample.
    tuned_method = Btrex_fitted.explain(X_test, i)

    # Honour the notebook-wide PLOT_GUI switch instead of a hard-coded False.
    tuned_method.plot_overview(plot_gui=PLOT_GUI,
                               show=True)

    tuned_method.plot_visuals(plot_max_depth=5,
                              preds_distr=y_train_pred,
                              conf_level=0.9,
                              tot_digits=4)
