# 0. Preamble

This is a demo notebook for this library.

In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import json
import os
import subprocess
import re

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from category_encoders.one_hot import OneHotEncoder
from timeit import default_timer as timer
from copy import deepcopy
from psutil import cpu_count

from tree_asp.rule_extractor import RFGlobalRuleExtractor, RFLocalRuleExtractor
from tree_asp.classifier import RuleClassifier
from tree_asp.clasp_parser import generate_answers
from tree_asp.rule import Rule
from tree_asp.utils import time_print
from hyperparameter import optuna_random_forest
from utils import load_data

# Seed handed to the random forest as random_state in the training cell.
SEED = 2020
# Core count reported by psutil with logical=False.
# NOTE(review): NUM_CPU is not referenced by any other cell visible in this
# notebook (the clingo calls hard-code 8 threads) — confirm it is still needed.
NUM_CPU = cpu_count(logical=False)

# 1. Loading data

Here we use the 'adult' dataset, one of the datasets in the dataset directory.

In [3]:
def preprocess(dataset_name):
    """Load a dataset and prepare its feature table for a tree ensemble.

    Columns whose dtype is ``category`` are one-hot encoded. When at least one
    such column exists, every column name is then sanitised: comparison
    operators are spelled out (``>=`` becomes ``_ge_`` and so on) and each
    remaining run of characters outside ``[A-Za-z0-9_]`` collapses into a
    single underscore. Data without categorical columns is returned as loaded,
    with its column names untouched.

    Parameters
    ----------
    dataset_name
        Name forwarded to ``load_data`` (e.g. ``'adult'``).

    Returns
    -------
    tuple
        ``(X, y, feat)`` where ``feat`` is ``X.columns`` after any renaming.
    """
    X, y = load_data(dataset_name)
    categorical_features = list(X.columns[X.dtypes == 'category'])
    if not categorical_features:
        return (X, y, X.columns)

    encoder = OneHotEncoder(cols=categorical_features, use_cat_names=True)
    X = encoder.fit_transform(X)

    # avoid special character error
    # Two-character operators come before the single characters they contain,
    # so '>=' is never split into '>' followed by '='.
    replacements = (('>=', '_ge_'),
                    ('<=', '_le_'),
                    ('>',  '_gt_'),
                    ('<',  '_lt_'),
                    ('!=', '_nq_'),
                    ('=',  '_eq_'))

    def _sanitise(name):
        # Apply the operator substitutions in order, then squash whatever
        # non-identifier characters are left.
        for symbol, token in replacements:
            name = re.sub(symbol, token, name)
        return re.sub('[^A-Za-z0-9_]+', '_', name)

    X = X.rename(columns=_sanitise)
    return (X, y, X.columns)

X, y, feat = preprocess('adult')

# One display() call renders its arguments one after another: a preview of the
# encoded table, the feature names, and the column dtypes.
display(X.head(), feat, X.dtypes)

Unnamed: 0,age,workclass_State_gov,workclass_Self_emp_not_inc,workclass_Private,workclass_Federal_gov,workclass_Local_gov,workclass_Self_emp_inc,workclass_Without_pay,workclass_Never_worked,education_Bachelors,...,native_country_Outlying_US_Guam_USVI_etc_,native_country_Scotland,native_country_Trinadad_Tobago,native_country_Greece,native_country_Nicaragua,native_country_Vietnam,native_country_Hong,native_country_Ireland,native_country_Hungary,native_country_Holand_Netherlands
0,39,1,0,0,0,0,0,0,0,1,...,0,0,0,0,0,0,0,0,0,0
1,50,0,1,0,0,0,0,0,0,1,...,0,0,0,0,0,0,0,0,0,0
2,38,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,53,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,28,0,0,1,0,0,0,0,0,1,...,0,0,0,0,0,0,0,0,0,0


Index(['age', 'workclass_State_gov', 'workclass_Self_emp_not_inc',
       'workclass_Private', 'workclass_Federal_gov', 'workclass_Local_gov',
       'workclass_Self_emp_inc', 'workclass_Without_pay',
       'workclass_Never_worked', 'education_Bachelors',
       ...
       'native_country_Outlying_US_Guam_USVI_etc_', 'native_country_Scotland',
       'native_country_Trinadad_Tobago', 'native_country_Greece',
       'native_country_Nicaragua', 'native_country_Vietnam',
       'native_country_Hong', 'native_country_Ireland',
       'native_country_Hungary', 'native_country_Holand_Netherlands'],
      dtype='object', length=103)

age                                  int64
workclass_State_gov                  int64
workclass_Self_emp_not_inc           int64
workclass_Private                    int64
workclass_Federal_gov                int64
                                     ...  
native_country_Vietnam               int64
native_country_Hong                  int64
native_country_Ireland               int64
native_country_Hungary               int64
native_country_Holand_Netherlands    int64
Length: 103, dtype: object

# 2. Train your Tree Ensemble

We are using Random Forest for this demo. Other supported models are Decision Tree and LightGBM.

For simplicity we will not perform hyperparameter optimization and cross validation.  
If you are interested in those, please take a look at the experiment runners (`wrapper_tree_*.py`).

In [4]:
# The demo fits on the complete dataset; no hold-out split is made.
x_train = X
y_train = y

rf = RandomForestClassifier(
    n_estimators=10,
    max_depth=5,
    n_jobs=1,
    random_state=SEED,
)
rf.fit(x_train, y_train)

RandomForestClassifier(max_depth=5, n_estimators=10, n_jobs=1,
                       random_state=2020)

# 3. Fit Rule Extractor for your Trees & Explanation type

For Random Forests, we will be using `RFGlobalRuleExtractor` and `RFLocalRuleExtractor`.

In [5]:
# Both extractors are fitted against the same trained forest and feature names.
extractor_fit_kwargs = dict(model=rf, feature_names=feat)

# for global explanations
rf_global_extractor = RFGlobalRuleExtractor()
rf_global_extractor.fit(x_train, y_train, **extractor_fit_kwargs)

# for local explanations
rf_local_extractor = RFLocalRuleExtractor()
rf_local_extractor.fit(x_train, y_train, **extractor_fit_kwargs)

<tree_asp.rule_extractor.RFLocalRuleExtractor at 0x7f2ce1fc8250>

# 4.1. Global Explanations

Export the rules as a pandas DataFrame.

In [6]:
# Collect the rules held by the fitted global extractor into a table for inspection.
global_exp_df = rf_global_extractor.export_rule_df()

In [7]:
# Preview the first rows of the exported rule table.
global_exp_df.head()

Unnamed: 0,tree_idx,rule_idx,rule_condition_str,rule_predicted_class,rule_original_predict_class,rule_conditions,rule_conditions_idx,support,size,accuracy,precision_score,recall,f1_score
0,0,0,relationship_Husband <= 0.5 AND occupation_Oth...,1,0.0,[condition_idx=0. condition=relationship_Husba...,"[0, 1, 2, 3, 4]",41,5,40,6,10,7
1,0,1,relationship_Husband <= 0.5 AND occupation_Oth...,1,0.0,[condition_idx=0. condition=relationship_Husba...,"[0, 1, 2, 3, 5]",5,5,73,18,4,6
2,0,2,relationship_Husband <= 0.5 AND occupation_Oth...,1,0.0,[condition_idx=0. condition=relationship_Husba...,"[0, 1, 2, 6, 7]",1,5,76,26,1,1
3,0,3,relationship_Husband <= 0.5 AND occupation_Oth...,1,0.0,[condition_idx=0. condition=relationship_Husba...,"[0, 1, 2, 6, 8]",0,5,76,6,0,0
4,0,4,relationship_Husband <= 0.5 AND occupation_Oth...,1,1.0,[condition_idx=0. condition=relationship_Husba...,"[0, 1, 9, 10, 11]",3,5,76,53,6,10


In [8]:
# transform() yields the rules as text; the printed result consists of ASP
# facts such as `rule(0). condition(0,0). ... predict_class(0,1).`
# This string is what gets written to rules.lp for clingo in the next cell.
global_res_str = rf_global_extractor.transform(x_train, y_train)
print(global_res_str)

rule(0). condition(0,0). condition(0,1). condition(0,2). condition(0,3). condition(0,4). support(0,41). size(0,5). accuracy(0,40). precision(0,6). recall(0,10). f1_score(0,7). error_rate(0,60). predict_class(0,1).
rule(1). condition(1,0). condition(1,1). condition(1,2). condition(1,3). condition(1,5). support(1,5). size(1,5). accuracy(1,73). precision(1,18). recall(1,4). f1_score(1,6). error_rate(1,27). predict_class(1,1).
rule(2). condition(2,0). condition(2,1). condition(2,2). condition(2,6). condition(2,7). support(2,1). size(2,5). accuracy(2,76). precision(2,26). recall(2,1). f1_score(2,1). error_rate(2,24). predict_class(2,1).
rule(3). condition(3,0). condition(3,1). condition(3,2). condition(3,6). condition(3,8). support(3,0). size(3,5). accuracy(3,76). precision(3,6). recall(3,0). f1_score(3,0). error_rate(3,24). predict_class(3,1).
rule(4). condition(4,0). condition(4,1). condition(4,9). condition(4,10). condition(4,11). support(4,3). size(4,5). accuracy(4,76). precision(4,53). recall(4,6). f1_score(4,10). error_rate(4,24). predict_class(4,1).

In [9]:
# write rules into temporary files

# Make sure the scratch directory exists so the writes below cannot fail on a
# fresh checkout.
os.makedirs('./tree_asp/tmp/scratch', exist_ok=True)

# Rule facts produced by the global extractor.
with open('./tree_asp/tmp/scratch/rules.lp', 'w', encoding='utf-8') as outfile:
    outfile.write(global_res_str)

# Target class fact consumed by the encoding. This demo always targets class 1,
# so the fact is written as a plain literal (a str.format call on a string
# without a replacement field would not change it).
with open('./tree_asp/tmp/scratch/class.lp', 'w', encoding='utf-8') as outfile:
    outfile.write('class(1).')

In [10]:
# run clingo in a subprocess

# Command line: encoding file, rule facts, class fact, then solver options
# ('0' as the model-count argument, and 8 solver threads in split mode).
clingo_global_cmd = [
    'clingo',
    './tree_asp/asp_encoding/global_accuracy_coverage.lp',
    './tree_asp/tmp/scratch/rules.lp',
    './tree_asp/tmp/scratch/class.lp',
    '0',
    '--parallel-mode=8,split',
]
# stdout/stderr are captured for parsing; the run is abandoned after 600 seconds.
o = subprocess.run(clingo_global_cmd, capture_output=True, timeout=600)

In [11]:
# verify that clingo finds at least one answer set

# Parse the captured clingo stdout. `answers` is iterated in the next cell
# (each item exposes `is_optimal` and `answer`); `clasp_info` is the solver
# summary text printed here.
answers, clasp_info = generate_answers(o.stdout.decode())
print(clasp_info)

clingo version 5.6.2
Reading from ...p_encoding/global_accuracy_coverage.lp ...
Solving...
Models: 3
Optimum: yes
Optimization: -15 -10 1
Calls: 1
Time: 0.042s (Solving: 0.00s 1st Model: 0.00s Unsat: 0.00s)
CPU Time: 0.042s
Threads: 8        (Winner: 3)


In [12]:
# we will only check the first optimal answer in this demo

# Pick the first answer set flagged as optimal (if there is one) and print
# every rule it selects.
first_optimal = next((candidate for candidate in answers if candidate.is_optimal), None)
if first_optimal is not None:
    for atom in first_optimal.answer:   # list(tuple(str, tuple(int)))
        rule_position = atom[-1][0]     # first argument of the atom
        pat = rf_global_extractor.rules_[rule_position]  # type: Rule
        print(f'class {pat.predict_class} IF {pat.rule_str}')

class 1 IF occupation_Handlers_cleaners <= 0.5 AND sex_Female <= 0.5 AND occupation_Machine_op_inspct <= 0.5 AND marital_status_Married_civ_spouse > 0.5 AND hours_per_week > 33.5
class 1 IF marital_status_Married_civ_spouse > 0.5 AND occupation_Exec_managerial > 0.5 AND education_HS_grad <= 0.5 AND workclass_Self_emp_not_inc <= 0.5 AND education_9th <= 0.5
class 1 IF sex_Male > 0.5 AND native_country_Mexico <= 0.5 AND relationship_Not_in_family <= 0.5 AND relationship_Husband > 0.5 AND hours_per_week > 41.5


# 4.2. Local Explanations

In [13]:
# NOTE: local explanation is for single instances, 
# so if you need more than 1 instance explained, you need to loop over this.

In [14]:
# Explain one instance: the row labelled 10. X is selected with a list indexer
# ([[10]]) while y is looked up with the scalar label.
# NOTE(review): the result is indexed with [0] when written to rules.lp below,
# so transform appears to return a sequence here — confirm.
local_res_str = rf_local_extractor.transform(x_train.loc[[10]], y_train.loc[10], model=rf)

In [15]:
# write rules into temporary files

# Make sure the scratch directory exists so the writes below cannot fail on a
# fresh checkout.
os.makedirs('./tree_asp/tmp/scratch', exist_ok=True)

# Rule facts for the single explained instance (first element of the local
# extractor's output). This overwrites the rules.lp written for the global run.
with open('./tree_asp/tmp/scratch/rules.lp', 'w', encoding='utf-8') as outfile:
    outfile.write(local_res_str[0])

# Target class fact consumed by the encoding. This demo always targets class 1,
# so the fact is written as a plain literal (a str.format call on a string
# without a replacement field would not change it).
with open('./tree_asp/tmp/scratch/class.lp', 'w', encoding='utf-8') as outfile:
    outfile.write('class(1).')

In [16]:
# run clingo in a subprocess

# Same invocation shape as the global run, but with the local encoding file.
clingo_local_cmd = [
    'clingo',
    './tree_asp/asp_encoding/local_accuracy_coverage.lp',
    './tree_asp/tmp/scratch/rules.lp',
    './tree_asp/tmp/scratch/class.lp',
    '0',
    '--parallel-mode=8,split',
]
# stdout/stderr are captured for parsing; the run is abandoned after 600 seconds.
o = subprocess.run(clingo_local_cmd, capture_output=True, timeout=600)

In [17]:
# verify that clingo finds at least one answer set

# Parse the captured clingo stdout of the local run. `answers` is iterated in
# the next cell; `clasp_info` is the solver summary text printed here.
answers, clasp_info = generate_answers(o.stdout.decode())
print(clasp_info)

clingo version 5.6.2
Reading from ...sp_encoding/local_accuracy_coverage.lp ...
Solving...
Models: 4
Optimum: yes
Optimization: -15 -10 1
Calls: 1
Time: 0.002s (Solving: 0.00s 1st Model: 0.00s Unsat: 0.00s)
CPU Time: 0.003s
Threads: 8        (Winner: 3)


In [18]:
# we will only check the first optimal answer in this demo

# Pick the first answer set flagged as optimal (if there is one) and print
# only the first rule it selects.
first_optimal = next((candidate for candidate in answers if candidate.is_optimal), None)
if first_optimal is not None:
    for atom in first_optimal.answer:   # list(tuple(str, tuple(int)))
        rule_position = atom[-1][0]     # first argument of the atom
        pat = rf_local_extractor.rules_[rule_position]  # type: Rule
        print(f'class {pat.predict_class} IF {pat.rule_str}')
        break   # a single rule is enough for the local explanation shown here

class 1 IF occupation_Handlers_cleaners <= 0.5 AND sex_Female <= 0.5 AND occupation_Machine_op_inspct <= 0.5 AND marital_status_Married_civ_spouse > 0.5 AND hours_per_week > 33.5


# 5. Changing the selection criteria

You can edit and/or change the selection criteria by writing them in the ASP language of clingo.

In [19]:
# For example, this is the default encoding which maximizes accuracy and coverage
# Stream the encoding file line by line, printing each line with surrounding
# whitespace removed.
default_encoding_path = './tree_asp/asp_encoding/global_accuracy_coverage.lp'
with open(default_encoding_path, 'r') as infile:
    for raw_line in infile:
        print(raw_line.strip())

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% List of Atoms
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% rule(I)            % I is index, I-th rule
%% condition(I,C)     % C is index, C-th condition
%% size(I,S)          % S is the number of conditions (items) in the rule
%% accuracy(I,A)      % A is the accuracy of this rule
%% error_rate(I,E)    % E is the error rate (1-accuracy) of this rule
%% precision(I,P)     % P is the precision of this rule
%% recall(I,R)        % R is the recall of this rule
%% f1_score(I,F)      % F is the F1 score of this rule
%% predict_class(I,X) % X is the predicted class (head) of this rule
%% class(K)           % K is the target class
%% selected(I)        % I is the selected rule
%% valid(I)           % I is a valid rule (not invalid)
%% invalid(I)         % True when I is invalid
%% rule_overlap(I,J,Cn) % Cn is the number of conditions shared between rules I and J

%%%%%%%%%%

In [20]:
from tempfile import NamedTemporaryFile

In [21]:
# Custom ASP encoding used in place of the default one. As written it:
#   * selects between 1 and 2 valid rules for every class(K) fact,
#   * rejects any selection for which some valid rule J beats a selected rule I
#     on accuracy without losing support, or on support without losing accuracy,
#   * maximises accuracy and support terms (each divided by rule size) and
#     minimises the number of conditions shared between selected rules,
#   * shows only the selected/1 atoms.
changed_encoding="""
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Settings
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% class(0..2).
% You need to add class(0..K). by yourself if you're running this script
% manually. If this is ran automatically there's a separate file with just
% class(1). in it.

% we would like to pick at least 1 pattern for each predict_class
1 { selected(I) :  predict_class(I, K), valid(I) } 2 :- class(K).

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% User Defined Local Constraints and Selection Criteria
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% rule is not invalid
valid(I) :- rule(I), not invalid(I).

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Dominance Relation Definitions
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% I is dominated by J
gt_acc_geq_cov(J) :- selected(I), valid(J),
accuracy(I,Ai), accuracy(J,Aj), support(I,Spi), support(J,Spj),
Ai < Aj, Spi <= Spj.

geq_acc_gt_cov(J) :- selected(I), valid(J),
accuracy(I,Ai), accuracy(J,Aj), support(I,Spi), support(J,Spj),
Ai <= Aj, Spi < Spj.

dominated :- valid(J), gt_acc_geq_cov(J).
dominated :- valid(J), geq_acc_gt_cov(J).

% cannot be dominated
:- dominated.

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Optimization Over Answer Sets
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% number of shared conditions between rules
rule_overlap(I,J,Cn) :- selected(I), selected(J), I!=J,
Cn = #count{Ci : Ci=Cj, condition(I,Ci), condition(J,Cj)}.

selected_rules(SR) :- SR = #count { I : selected(I) }, SR != 0.

#maximize { Ai/(S*SR)@2,I : selected(I), size(I,S), accuracy(I,Ai), selected_rules(SR) }.
#maximize { Sp/S@3,I : selected(I), size(I,S), support(I,Sp) }.
#minimize { Cn : selected(I), selected(J), rule_overlap(I,J,Cn) }.

#show selected/1.
"""

In [22]:
# Scratch location for the custom encoding handed to clingo.
tmp_file = './tree_asp/tmp/scratch/ec1'

# Create the scratch directory if needed and write the encoding with an
# explicit UTF-8 encoding, matching the other files written by this notebook.
os.makedirs(os.path.dirname(tmp_file), exist_ok=True)
with open(tmp_file, 'w', encoding='utf-8') as ec:
    ec.write(changed_encoding)

# Same invocation as before, with the custom encoding in place of the default
# one. rules.lp and class.lp are reused as last written by earlier cells.
o = subprocess.run(['clingo',
                    tmp_file,
                    './tree_asp/tmp/scratch/rules.lp',
                    './tree_asp/tmp/scratch/class.lp',
                    '0', '--parallel-mode=8,split'],
                   capture_output=True, timeout=600)

In [23]:
answers, clasp_info = generate_answers(o.stdout.decode())

# Pick the first answer set flagged as optimal (if there is one) and print
# every rule it selects.
# NOTE(review): rules.lp was last written from the local extractor's output,
# while indices are resolved against rf_global_extractor.rules_ here — confirm
# that both extractors share the same rule indexing.
first_optimal = next((candidate for candidate in answers if candidate.is_optimal), None)
if first_optimal is not None:
    for atom in first_optimal.answer:   # list(tuple(str, tuple(int)))
        rule_position = atom[-1][0]     # first argument of the atom
        pat = rf_global_extractor.rules_[rule_position]  # type: Rule
        print(f'class {pat.predict_class} IF {pat.rule_str}')

class 1 IF occupation_Handlers_cleaners <= 0.5 AND sex_Female <= 0.5 AND occupation_Machine_op_inspct <= 0.5 AND marital_status_Married_civ_spouse > 0.5 AND hours_per_week > 33.5
class 1 IF sex_Female <= 0.5 AND marital_status_Never_married <= 0.5 AND occupation_Other_service <= 0.5 AND workclass_Federal_gov <= 0.5 AND occupation_Handlers_cleaners <= 0.5
