In [1]:
from rf_counterfactuals import RandomForestExplainer, visualize

from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

In [2]:
### Load the iris dataset as a pandas DataFrame (other formats aren't supported yet), split it, then train an RF classifier on it
SEED = 420  # single seed shared by the split and the forest so a fresh "Restart & Run All" is reproducible
X, y = load_iris(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.67, random_state=SEED, stratify=y)
# Without a fixed random_state the forest is grown differently on every run,
# so anything derived from the fitted model would change from run to run.
rf = RandomForestClassifier(random_state=SEED)
rf.fit(X_train, y_train)

In [3]:
### Build a RandomForestExplainer from the trained random forest model and the training data
rfe = RandomForestExplainer(rf, X_train, y_train)

### Take the test rows whose true label is '0' ('setosa') and look for counterfactual
### examples (max/limit 1) that change the label's value to '2' ('virginica').
### Counterfactual examples are selected based on the value of the 'hoem' metric.
is_label_0 = y_test == 0
X_test_label_0 = X_test.loc[is_label_0]
counterfactuals = rfe.explain_with_single_metric(X_test_label_0, 2, metric='hoem', limit=1)

[1/3] Extracting positive paths.
[2/3] Generating counterfactual examples for each tree. Total number of tasks: 100


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done   5 tasks      | elapsed:    7.2s
[Parallel(n_jobs=-1)]: Done  10 tasks      | elapsed:    7.4s
[Parallel(n_jobs=-1)]: Done  17 tasks      | elapsed:    7.9s
[Parallel(n_jobs=-1)]: Done  24 tasks      | elapsed:    8.4s
[Parallel(n_jobs=-1)]: Done  33 tasks      | elapsed:    9.0s
[Parallel(n_jobs=-1)]: Done  42 tasks      | elapsed:    9.7s
[Parallel(n_jobs=-1)]: Done  53 tasks      | elapsed:   10.8s
[Parallel(n_jobs=-1)]: Done  64 tasks      | elapsed:   11.8s
[Parallel(n_jobs=-1)]: Done  77 tasks      | elapsed:   12.9s
[Parallel(n_jobs=-1)]: Done  90 tasks      | elapsed:   13.9s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed:   14.6s finished


[3/3] Calculating loss function. Total number of tasks: 17


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done   5 tasks      | elapsed:    0.5s
[Parallel(n_jobs=-1)]: Done  10 tasks      | elapsed:    0.7s
[Parallel(n_jobs=-1)]: Done  12 out of  17 | elapsed:    0.7s remaining:    0.2s
[Parallel(n_jobs=-1)]: Done  14 out of  17 | elapsed:    1.1s remaining:    0.2s
[Parallel(n_jobs=-1)]: Done  17 out of  17 | elapsed:    1.3s finished


In [4]:
### Visualize one example row of the data (row_no = 0) together with its counterfactual
row_index_to_visualize = 0

row = X_test_label_0.iloc[row_index_to_visualize]

# Counterfactuals returned for the chosen row. Fail with a clear message instead of an
# opaque out-of-bounds IndexError from .iloc[0] in case the search returned none for it.
row_counterfactuals = counterfactuals[row_index_to_visualize]
if len(row_counterfactuals) == 0:
    raise ValueError(f"No counterfactual found for row {row_index_to_visualize}; choose a different row.")

# First counterfactual found for the chosen row
cf = row_counterfactuals.iloc[0]

# rf.predict needs 2-D input, so each Series is turned into a one-row frame before predicting
print(f"row label: {rf.predict(row.to_frame(0).T)[0]} |\t cf label: {rf.predict(cf.to_frame(0).T)[0]}")
print(visualize(rfe, row, cf))

row label: 0 |	 cf label: 2
                     X     X'  difference constraints
sepal length (cm)  5.8  6.283       0.483            
sepal width (cm)   4.0  4.000       0.000            
petal length (cm)  1.2  2.626       1.426            
petal width (cm)   0.2  1.876       1.676            
