In [1]:
import sys
import pandas as pd
import numpy as np
import time

from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.model_selection import train_test_split

from generalizedtrees.recipes import trepan
from generalizedtrees.vis.vis import explanation_to_html
from generalizedtrees.features import FeatureSpec

from sklearn.metrics import classification_report

# Random sources used later in the notebook: a numpy Generator that is passed to
# trepan(rng=...), and a legacy RandomState that is passed to the MLP's random_state.
NP_SEED, SK_SEED = 8372234, 3957458
np_rng = np.random.default_rng(NP_SEED)
sk_rng = np.random.RandomState(SK_SEED)

In [2]:
# Load the raw iris table; the 'species' column is used as the label below.
IRIS_CSV_PATH = 'iris.csv'
data_train = pd.read_csv(IRIS_CSV_PATH)

In [3]:
# Prepare data: separate the features from the 'species' label, one-hot encode
# any non-numeric feature columns, and label-encode the target.
data_df = data_train.drop(['species'], axis=1)

encoder = OneHotEncoder(drop='if_binary')
lencoder = LabelEncoder()

numeric_features = data_df.select_dtypes(include='number')
categorical_features_df = data_df.select_dtypes(exclude='number')

if categorical_features_df.shape[1] > 0:
    categorical_features = encoder.fit_transform(categorical_features_df).toarray()
    # get_feature_names was deprecated in scikit-learn 1.0 and removed in 1.2;
    # fall back to it only on older versions that lack get_feature_names_out.
    if hasattr(encoder, 'get_feature_names_out'):
        categorical_names = encoder.get_feature_names_out(categorical_features_df.columns)
    else:
        categorical_names = encoder.get_feature_names(categorical_features_df.columns)
else:
    # No categorical columns (e.g. all-numeric iris): don't fit the encoder on an
    # empty frame; contribute zero columns and zero names instead.
    # NOTE: `encoder` is left unfitted in this branch.
    categorical_features = np.empty((len(data_df), 0))
    categorical_names = np.array([], dtype=object)

# Names are ordered numeric-first, matching the column order of x below.
feature_names = np.append(numeric_features.columns, categorical_names)

x = np.append(
    numeric_features,
    categorical_features,
    axis=1)

# Integer class codes; lencoder.classes_ maps them back to species names.
y = lencoder.fit_transform(data_train['species'])


# Split data: hold out 10% as the test set, with a fixed split seed.
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.1, random_state=42)

In [4]:
# Show the model's input columns, in the same order as the columns of x.
pd.DataFrame(feature_names, columns=['Feature Name'])

Unnamed: 0,Feature Name
0,sepal length
1,sepal width
2,petal length
3,petal width


In [5]:
#create a neural network as oracle
from sklearn.model_selection import GridSearchCV
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Standardize inputs, then an lbfgs MLP; the grid search chooses the hidden-layer
# width and refits the winner on all of x_train.
mlp_pipeline = make_pipeline(
    StandardScaler(),
    MLPClassifier(solver='lbfgs', alpha=1e-5, random_state=sk_rng),
)
hidden_layer_grid = {
    'mlpclassifier__hidden_layer_sizes': [(5,), (10,), (20,), (40,)],
}

clf = GridSearchCV(mlp_pipeline, param_grid=hidden_layer_grid, refit=True)
clf.fit(x_train, y_train)
clf.best_estimator_

Pipeline(steps=[('standardscaler', StandardScaler()),
                ('mlpclassifier',
                 MLPClassifier(alpha=1e-05, hidden_layer_sizes=(5,),
                               random_state=RandomState(MT19937) at 0x26F87471340,
                               solver='lbfgs'))])

In [6]:
# Prepare the TREPAN explainer: m-of-n split tests, entropy impurity, and a
# max_tree_size of 10 (presumably a node count — confirm in generalizedtrees).
# np_rng is handed over as its random generator.
explanation = trepan(
    m_of_n=True,
    impurity='entropy',
    max_tree_size=10,
    rng=np_rng,
)

In [7]:
# Run TREPAN against the fitted black box (clf) and report how long it took.
# perf_counter is monotonic and high-resolution, so it is the clock intended for
# measuring durations; time.time() can jump if the system clock is adjusted.
t0 = time.perf_counter()

explanation.fit(x_train, clf)

t1 = time.perf_counter()

print(f'Time taken to learn explanation: {t1-t0} seconds')

print(explanation.show_tree())

Assuming continuous features in the absence of feature specifications


Time taken to learn explanation: 92.53777551651001 seconds
Test 2 of ['x[2] ≤ 2.45', 'x[3] ≤ 0.8', 'x[1] > 3.6500000000000004', 'x[2] ≤ 5.15', 'x[3] ≤ 1.55']
+--Test 2 of ['x[2] ≤ 2.45', 'x[1] > 3.25', 'x[0] ≤ 5.75']
|  +--[0.48  0.327 0.193]
|  +--Test x[2] ≤ 4.95
|     +--[0.238 0.619 0.143]
|     +--Test x[1] ≤ 2.5
|        +--[0.049 0.326 0.625]
|        +--[0.094 0.458 0.448]
+--Test 2 of ['x[2] ≤ 4.6', 'x[3] ≤ 1.5', 'x[1] > 3.3499999999999996']
   +--[0.403 0.344 0.253]
   +--Test 1 of ['x[2] ≤ 4.6', 'x[3] ≤ 1.5']


In [8]:
# Save the learned explanation tree to an HTML file in the working directory.
explanation_html_path = 'explanation.html'
explanation_to_html(explanation, explanation_html_path)

In [9]:
# Compare performance: predict the held-out test set with the TREPAN tree and the black box.
y_test_trepan, y_test_model = explanation.predict(x_test), clf.predict(x_test)

In [10]:
# Performance of TREPAN and the neural network against the true test labels.
# labels= pins each report to every encoded class (LabelEncoder codes are
# 0..n_classes-1), so it always lines up with target_names. Without it,
# classification_report raises if a class is absent from both y_true and y_pred.
class_labels = np.arange(len(lencoder.classes_))

for report_title, y_pred in (('Trepan:', y_test_trepan), ('Black Box:', y_test_model)):
    print(report_title)
    print(classification_report(y_test, y_pred, labels=class_labels, target_names=lencoder.classes_))

Trepan:
                 precision    recall  f1-score   support

    Iris-setosa       1.00      1.00      1.00         6
Iris-versicolor       1.00      1.00      1.00         6
 Iris-virginica       1.00      1.00      1.00         3

       accuracy                           1.00        15
      macro avg       1.00      1.00      1.00        15
   weighted avg       1.00      1.00      1.00        15

Black Box:
                 precision    recall  f1-score   support

    Iris-setosa       1.00      1.00      1.00         6
Iris-versicolor       1.00      1.00      1.00         6
 Iris-virginica       1.00      1.00      1.00         3

       accuracy                           1.00        15
      macro avg       1.00      1.00      1.00        15
   weighted avg       1.00      1.00      1.00        15



In [11]:
# Comparing TREPAN and the neural network via fidelity: the black box's predictions
# act as the reference (y_true) and TREPAN's predictions are scored against them.
# labels= keeps the report aligned with target_names even if neither model predicts
# some class on a split (likely on the 15-sample test set).
class_labels = np.arange(len(lencoder.classes_))

print('Training set fidelity')
print(classification_report(clf.predict(x_train), explanation.predict(x_train),
                            labels=class_labels, target_names=lencoder.classes_))

print('Test set fidelity')
print(classification_report(y_test_model, y_test_trepan,
                            labels=class_labels, target_names=lencoder.classes_))

Training set fidelity
                 precision    recall  f1-score   support

    Iris-setosa       0.98      1.00      0.99        44
Iris-versicolor       1.00      0.91      0.95        44
 Iris-virginica       0.94      1.00      0.97        47

       accuracy                           0.97       135
      macro avg       0.97      0.97      0.97       135
   weighted avg       0.97      0.97      0.97       135

Test set fidelity
                 precision    recall  f1-score   support

    Iris-setosa       1.00      1.00      1.00         6
Iris-versicolor       1.00      1.00      1.00         6
 Iris-virginica       1.00      1.00      1.00         3

       accuracy                           1.00        15
      macro avg       1.00      1.00      1.00        15
   weighted avg       1.00      1.00      1.00        15



In [12]:
from sklearn.tree import DecisionTreeClassifier

In [13]:
# Plain decision-tree baseline trained directly on the true labels.
# scikit-learn's tree permutes features at each split, so equally good splits can be
# chosen differently between runs. A fixed random_state (the same value used for the
# train/test split) makes Restart & Run All reproduce the same tree.
tree_clf = DecisionTreeClassifier(random_state=42)
tree_model = tree_clf.fit(x_train, y_train)

In [14]:
# Predict the held-out test set with the decision-tree baseline.
y_test_tree_model = tree_model.predict(X=x_test)

In [15]:
# Decision-tree baseline scored against the true test labels. labels= keeps the
# report aligned with target_names even if a class is missing from this split.
print('Tree:')
print(classification_report(y_test, y_test_tree_model,
                            labels=np.arange(len(lencoder.classes_)),
                            target_names=lencoder.classes_))

Tree:
                 precision    recall  f1-score   support

    Iris-setosa       1.00      1.00      1.00         6
Iris-versicolor       1.00      1.00      1.00         6
 Iris-virginica       1.00      1.00      1.00         3

       accuracy                           1.00        15
      macro avg       1.00      1.00      1.00        15
   weighted avg       1.00      1.00      1.00        15

