# Introduction to `explainy`
In this notebook, we walk through the main functionality of the `explainy` library.

In [1]:
import pandas as pd
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

In [2]:
pip install explainy

Note: you may need to restart the kernel to use updated packages.


`explainy` allows you to create machine learning model explanations based on four different explanation characteristics:

-   **global**: explanation of system functionality
-   **local**: explanation of decision rationale
-   **contrastive**: tracing of decision path
-   **non-contrastive**: parameter weighting

The explanation algorithms in `explainy` can be categorized as follows:

| | non-contrastive | contrastive |
| --- | :---: | :---: |
| **global** | Permutation Feature Importance | Surrogate Model |
| **local** | SHAP Values | Counterfactual Example |


In [3]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

diabetes = load_diabetes()

X_train, X_test, y_train, y_test = train_test_split(
    diabetes.data, diabetes.target, random_state=0
)
X_test = pd.DataFrame(X_test, columns=diabetes.feature_names)
y_test = pd.DataFrame(y_test)

model = RandomForestRegressor(random_state=0).fit(X_train, y_train)

In [4]:
from explainy.explanations.permutation_explanation import PermutationExplanation

# Explain test sample 1, limiting the explanation to 4 features
sample_index = 1
number_of_features = 4

explainer = PermutationExplanation(
    X_test, y_test, model, number_of_features
)
explanation = explainer.explain(sample_index)
print(explanation)

# Plot the explanation as a bar chart
explainer.plot(kind='bar')

ModuleNotFoundError: No module named 'explainy.explanations'

In [None]:
# Re-plot the permutation explainer from the previous cell, this time as a box plot
explainer.plot(kind='box')

Generate explanations with different numbers of features to explain the outcome.
Since `PermutationExplanation` is a global explanation method, all samples have the same feature-importance explanation.

In [None]:
# Global, Non-contrastive
# A global method gives every sample the same explanation, so a single sample
# suffices; only the number of features shown is varied.
# `list_number_of_features` is the single source of truth for the sweep
# (previously it was defined but the loop used a different hard-coded list).
list_number_of_features = [2, 4, 6, 8]
sample_index = 0

for number_of_features in list_number_of_features:
    explainer = PermutationExplanation(
        X_test, y_test, model, number_of_features
    )
    explanation = explainer.explain(sample_index)
    explainer.plot(kind='box')
    print(explanation)
    print('\n')

Let's use `ShapExplanation` to create local explanations for each sample individually.

In [None]:
from explainy.explanations.shap_explanation import ShapExplanation

# Local, Non-contrastive: explanations are per sample,
# so explain the first four test samples one at a time
number_of_features = 4
for sample_index in range(4):
    explainer = ShapExplanation(
        X_test, y_test, model, number_of_features
    )
    explanation = explainer.explain(sample_index)
    explainer.plot(sample_index)
    # Extra newlines separate consecutive per-sample explanations
    print(explanation, end='\n\n\n')


In [None]:
from explainy.explanations.surrogate_model_explanation import SurrogateModelExplanation

# Global, Contrastive: surrogate-model explanations of test sample 0,
# once with 2 features and once with 4
list_number_of_features = [2, 4]
sample_index = 0
for number_of_features in list_number_of_features:
    explainer = SurrogateModelExplanation(
        X_test, y_test, model, number_of_features
    )
    explanation = explainer.explain(sample_index)
    explainer.plot(sample_index)
    # Extra newlines separate consecutive explanations
    print(explanation, end='\n\n\n')


In [None]:
from explainy.explanations.counterfactual_explanation import CounterfactualExplanation

# Local, Contrastive
# Define the feature counts here rather than relying on `list_number_of_features`
# left over from the surrogate-model cell (hidden state); [2, 4] matches the value
# that cell sets, so Restart & Run All produces the same explanations.
# The former `number_of_features = 6` was dead: the loop overwrote it immediately.
list_number_of_features = [2, 4]
for number_of_features in list_number_of_features:
    for sample_index in [1, 2, 3]:
        explainer = CounterfactualExplanation(
            X_test, y_test, model, number_of_features
        )
        explanation = explainer.explain(sample_index)
        explainer.plot(sample_index)
        print(explanation)
        print('\n')


Finally, we compare the explanations produced by all four algorithms for the same sample:

In [None]:
number_of_features = 6
sample_index = 1

# Same sample and feature budget, one explainer per quadrant of the
# global/local x contrastive/non-contrastive table
for explanation_cls in (
    PermutationExplanation,
    SurrogateModelExplanation,
    ShapExplanation,
    CounterfactualExplanation,
):
    explainer = explanation_cls(X_test, y_test, model, number_of_features)
    explanation = explainer.explain(sample_index)
    explainer.plot(sample_index)
    # Extra newlines separate consecutive explanations
    print(explanation, end='\n\n\n')