In [1]:
import pandas as pd
from collections import defaultdict
import time

from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

from rf_counterfactuals import RandomForestExplainer, visualize

### 'miniadult.csv' is a 5% subset of original adult data

In [2]:
adult_dataset = pd.read_csv("datasets/miniadult.csv")
adult_dataset

Unnamed: 0,age,workclass,fnlwgt,education,educational-num,marital-status,occupation,relationship,race,gender,capital-gain,capital-loss,hours-per-week,native-country,income
0,17,Private,209949,11th,7,Never-married,Sales,Own-child,White,Female,0,1602,12,United-States,<=50K
1,19,Local-gov,169853,HS-grad,9,Never-married,Craft-repair,Own-child,White,Male,0,0,40,United-States,<=50K
2,35,Federal-gov,179262,HS-grad,9,Never-married,Adm-clerical,Unmarried,White,Female,0,0,40,United-States,<=50K
3,59,Private,126677,Some-college,10,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,0,50,United-States,>50K
4,30,Self-emp-not-inc,247328,Assoc-voc,11,Separated,Sales,Not-in-family,White,Male,0,0,40,Mexico,<=50K
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2437,19,Private,140985,Some-college,10,Never-married,Adm-clerical,Other-relative,White,Male,0,0,25,United-States,<=50K
2438,46,Private,99385,Bachelors,13,Separated,Exec-managerial,Unmarried,White,Female,0,0,40,United-States,<=50K
2439,31,Private,213750,Some-college,10,Never-married,Sales,Not-in-family,White,Male,0,0,45,United-States,<=50K
2440,51,Private,135388,12th,8,Widowed,Machine-op-inspct,Not-in-family,White,Male,0,1564,40,United-States,>50K


### Encode features to allow scikit-learn Random Forest work on

In [3]:
d = defaultdict(LabelEncoder)

adult_dataset = adult_dataset.apply(lambda x: d[x.name].fit_transform(x))
adult_dataset

Unnamed: 0,age,workclass,fnlwgt,education,educational-num,marital-status,occupation,relationship,race,gender,capital-gain,capital-loss,hours-per-week,native-country,income
0,0,3,1477,1,6,4,11,3,4,0,0,8,10,34,0
1,2,2,988,11,8,4,2,3,4,1,0,0,36,34,0
2,18,1,1114,11,8,4,1,4,4,0,0,0,36,34,0
3,42,3,609,15,9,2,3,0,4,1,0,0,46,34,1
4,13,5,1730,8,10,5,11,1,4,1,0,0,36,23,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2437,2,3,712,15,9,4,1,2,4,1,0,0,21,34,0
2438,29,3,369,9,12,5,3,4,4,0,0,0,36,34,0
2439,14,3,1503,15,9,4,11,1,4,1,0,0,41,34,0
2440,34,3,667,2,7,6,6,1,4,1,0,5,36,34,1


### Split dataset to test and train and train RF classifer on training data

In [4]:
train, test = train_test_split(adult_dataset, train_size=0.67, random_state=420, stratify=adult_dataset['income'])

X_train = train.loc[:, train.columns!="income"]
y_train = train["income"]

X_test = test.loc[:, test.columns!="income"]
y_test = test["income"]

rf = RandomForestClassifier()
rf.fit(X_train, y_train)

### Run a RandomForestExplainer on pretrained RF and look for counterfactual examples which leads to change class label '0' (<=50K) -> '1' (>50K)

Calculation may take a couple of minutes

In [5]:
categorical_features = [1, 3, 5, 6, 7, 8, 9, 13] # workclass, education, marital-status, occupation, relationship, race, gender, native-country
frozen_features = [8, 9] # race, gender
left_frozen_features = [0] # age

### Make an RandomForestExplainer object with input of: RandomForest model, training data and constraints
rfe = RandomForestExplainer(rf, X_train, y_train, categorical_features=categorical_features, 
                            left_frozen_features=left_frozen_features, frozen_features=frozen_features)

### Look for counterfactual examples in test data, which lead to change label's value from '0'('<=50K') to '1'('>50K')
### Counterfactual examples are selected from first Pareto front based on 'hoem' and 'unmatched_components_ratio' metrics value
X_test_label_0 = X_test[y_test==0]

start_time = time.time()
counterfactuals = rfe.explain_with_multiple_metrics(X_test_label_0, 1, metrics=('hoem', 'unmatched_components_ratio'), limit=1)
end_time = time.time()

total_time = end_time - start_time
print(f"Finished in {total_time: 1.4f}s")

total_cfs = sum([len(c) for c in counterfactuals])
print(f"Total counterfactuals found: {total_cfs}")

[1/3] Extracting positive paths.
[2/3] Generating counterfactual examples for each tree. Total number of tasks: 100


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done   5 tasks      | elapsed:   27.4s
[Parallel(n_jobs=-1)]: Done  10 tasks      | elapsed:   40.4s
[Parallel(n_jobs=-1)]: Done  17 tasks      | elapsed:  1.0min
[Parallel(n_jobs=-1)]: Done  24 tasks      | elapsed:  1.3min
[Parallel(n_jobs=-1)]: Done  33 tasks      | elapsed:  1.8min
[Parallel(n_jobs=-1)]: Done  42 tasks      | elapsed:  2.3min
[Parallel(n_jobs=-1)]: Done  53 tasks      | elapsed:  2.8min
[Parallel(n_jobs=-1)]: Done  64 tasks      | elapsed:  3.3min
[Parallel(n_jobs=-1)]: Done  77 tasks      | elapsed:  4.0min
[Parallel(n_jobs=-1)]: Done  90 tasks      | elapsed:  4.7min
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed:  5.1min finished


[3/3] Calculating loss function. Total number of tasks: 566


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done   5 tasks      | elapsed:    5.3s
[Parallel(n_jobs=-1)]: Done  10 tasks      | elapsed:   12.9s
[Parallel(n_jobs=-1)]: Done  17 tasks      | elapsed:   19.8s
[Parallel(n_jobs=-1)]: Done  24 tasks      | elapsed:   26.4s
[Parallel(n_jobs=-1)]: Done  33 tasks      | elapsed:   37.2s
[Parallel(n_jobs=-1)]: Done  42 tasks      | elapsed:   46.9s
[Parallel(n_jobs=-1)]: Done  53 tasks      | elapsed:  1.0min
[Parallel(n_jobs=-1)]: Done  64 tasks      | elapsed:  1.2min
[Parallel(n_jobs=-1)]: Done  77 tasks      | elapsed:  1.6min
[Parallel(n_jobs=-1)]: Done  90 tasks      | elapsed:  1.9min
[Parallel(n_jobs=-1)]: Done 105 tasks      | elapsed:  2.2min
[Parallel(n_jobs=-1)]: Done 120 tasks      | elapsed:  2.4min
[Parallel(n_jobs=-1)]: Done 137 tasks      | elapsed:  2.7min
[Parallel(n_jobs=-1)]: Done 154 tasks      | elapsed:  3.0min
[Parallel(n_jobs=-1)]: Done 173 tasks      | elapsed:  3

Finished in  1098.7639s
Total counterfactuals found: 566


### Visualize a sample row and its counterfactual example from dataset

In [6]:
### Visualize an example row (row_no = 0) of data with its counterfactual
row_index_to_visualize = 0

row = X_test_label_0.iloc[row_index_to_visualize]

# First counterfactual found for row 0th
cf = counterfactuals[row_index_to_visualize].iloc[0]

print(f"row label: {rf.predict(row.to_frame(0).T)[0]} |\t cf label: {rf.predict(cf.to_frame(0).T)[0]}")

# Provide 'd', a dict with encoded values of features, to decode them in visualization
print(visualize(rfe, row, cf, d))

row label: 0 |	 cf label: 1
                               X               X'  difference  \
age                           25             25.0       0.000   
workclass                Private          Private       0.000   
fnlwgt                      1894         2209.787     315.787   
education           Some-college     Some-college       0.000   
educational-num                9              9.0       0.000   
marital-status          Divorced         Divorced       0.000   
occupation       Exec-managerial  Exec-managerial       0.000   
relationship       Not-in-family    Not-in-family       0.000   
race                       White            White       0.000   
gender                    Female           Female       0.000   
capital-gain                   0              0.0       0.000   
capital-loss                   0              0.0       0.000   
hours-per-week                36           39.624       3.624   
native-country     United-States    United-States       0.000 