In [1]:
from numpy.random import RandomState

# Global random seed for the whole notebook; change it here to get a different
# (but still reproducible) train/test split and forest.
SEED = 1994

# NOTE: this single RandomState instance is shared by every cell that passes
# `random_state=seed`. scikit-learn draws from (and thereby advances) a
# RandomState passed as `random_state`, so the numbers below are only
# reproducible when the notebook is run top-to-bottom on a fresh kernel
# (Restart & Run All); re-running a single cell gives a different split/forest.
seed = RandomState(SEED)

In [2]:
from sklearn import datasets, model_selection, ensemble, metrics

In [3]:
import pprint

def print_sample(feature_names, sample):
    """Print a single sample as one ``name: value`` line per feature.

    Args:
        feature_names: Iterable of feature labels (e.g. ``data.feature_names``).
        sample: Iterable of feature values, in the same order as ``feature_names``.

    Labels and values are paired with ``zip``, so if the two iterables differ in
    length the surplus items of the longer one are silently ignored. With no
    pairs at all, a single empty line is printed.
    """
    print(*(f'{label}: {val}' for label, val in zip(feature_names, sample)), sep='\n')

## Classification
Contrastive explanation for a single instance of the [Iris](https://archive.ics.uci.edu/ml/datasets/iris) data set: why did the model predict one class (the *fact*) rather than another (the *foil*)?

---

**1. Train a (black-box) model on the Iris data**

In [4]:
# Load the Iris data and hold out 20% of it for evaluation.
data = datasets.load_iris()
x_train, x_test, y_train, y_test = model_selection.train_test_split(
    data.data,
    data.target,
    train_size=0.80,
    random_state=seed,
)

# Fit a random forest; the explainer below only interacts with it through
# its prediction function, i.e. as a black box.
model = ensemble.RandomForestClassifier(random_state=seed)
model.fit(x_train, y_train)

# Class-frequency-weighted F1 score on the held-out split.
print('Classifier performance (F1):',
      metrics.f1_score(y_true=y_test, y_pred=model.predict(x_test), average='weighted'))

Classifier performance (F1): 0.9333333333333333




**2. Perform contrastive explanation**

In [6]:
# Import the contrastive explanation library
import contrastive_explanation as ce

# Select the sample to explain (the 'questioned data point'): why did the model
# predict the fact (the class it actually predicted) instead of the foil (another class)?
sample = x_test[0]
print_sample(data.feature_names, sample)

# Create a domain mapper, which maps the explanation onto meaningful labels:
# feature names for the inputs and class names (`contrast_names`) for the outcomes
dm = ce.domain_mappers.DomainMapperTabular(x_train,
                                           feature_names=data.feature_names,
                                           contrast_names=data.target_names)

# Create the contrastive explanation object (default is a Foil Tree explanator)
exp = ce.ContrastiveExplanation(dm)

# Explain the instance (sample) for the given model. Only the model's
# predict_proba function is passed in, so the model is treated as a black box.
# The returned explanation string is displayed as this cell's output.
exp.explain_instance_domain(model.predict_proba, sample)

sepal length (cm): 5.1
sepal width (cm): 3.3
petal length (cm): 1.7
petal width (cm): 0.5


"The model predicted 'setosa' instead of 'versicolor' because 'petal length (cm) <= 2.528 and sepal width (cm) <= 3.569'"

## Regression
Explain a single instance of the [Diabetes](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_diabetes.html#sklearn.datasets.load_diabetes) data set, where the model predicts a continuous value instead of a class.

**1. Train a (black-box) model on the Diabetes data**

In [7]:
# Load the Diabetes regression data and hold out 20% of it for evaluation.
r_data = datasets.load_diabetes()

rx_train, rx_test, ry_train, ry_test = model_selection.train_test_split(
    r_data.data,
    r_data.target,
    train_size=0.80,
    random_state=seed,
)

# Random forest regressor whose number of trees is chosen by grid search;
# the fitted search object itself is used as the (black-box) model afterwards.
m_cv = ensemble.RandomForestRegressor(random_state=seed)
r_model = model_selection.GridSearchCV(estimator=m_cv,
                                       param_grid={'n_estimators': [50, 100, 500]})

r_model.fit(rx_train, ry_train)

# Coefficient of determination on the held-out split.
print('Regressor performance (R-squared):',
      metrics.r2_score(y_true=ry_test, y_pred=r_model.predict(rx_test)))

NameError: name 'model_reg' is not defined

In [None]:
# Show the dataset description that ships with scikit-learn's Diabetes loader.
print(r_data.DESCR)

**2. Perform contrastive explanation**

In [None]:
import contrastive_explanation as ce

# Select a sample from the held-out Diabetes data to explain.
# (Previously `test[1]`, an undefined name -> NameError on a fresh kernel.)
r_sample = rx_test[1]
print_sample(r_data.feature_names, r_sample)
print('\n')

# Create a domain mapper (still tabular data, but for regression there are no
# named outcome labels, so no `contrast_names` are passed).
# Feature index 1 is 'sex'; mark it as categorical so it is not split like a
# continuous value.
# (Previously `data_reg.feature_names`, an undefined name -> NameError.)
r_dm = ce.domain_mappers.DomainMapperTabular(rx_train,
                                             feature_names=r_data.feature_names,
                                             categorical_features=[1])

# Create the CE object with `regression=True` for the continuous outcome.
# Again the Foil Tree explanator is used, this time with verbose=True so it
# prints its intermediate outcomes and steps; the ContrastiveExplanation
# wrapper itself is created with verbose=False.
r_exp = ce.ContrastiveExplanation(r_dm,
                                  regression=True,
                                  explanator=ce.explanators.TreeExplanator(verbose=True),
                                  verbose=False)

# Explain using the model's predict function (a regressor has no predict_proba);
# also include a 'factual' (non-contrastive 'why fact?') explanation.
r_exp.explain_instance_domain(r_model.predict, r_sample, include_factual=True)