In [3]:
from z3 import *
from sklearn.datasets import load_iris
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier

In [4]:
from logic_explain_ml.xgboost import XGBoostExplainer

In [5]:
# Load Iris into a DataFrame so the fitted model keeps the feature names.
iris = load_iris()
X = pd.DataFrame(iris.data, columns=iris.feature_names)

# Turn the 3-class target into a binary one: class 2 is merged into class 0,
# class 1 stays 1. Relabel a copy, because `iris.target` would otherwise be
# overwritten in place through the shared array.
y = iris.target.copy()
y[y == 2] = 0
# Alternative binarization (class 0 vs. the rest):
# y = np.where(y == 0, 0, 1)
# Optional: keep only the first two feature columns:
# X = X.iloc[:, :2]

# Hold out 20% of the rows; the fixed random_state keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=101
)

xgbc = XGBClassifier(n_estimators=100, max_depth=3, learning_rate=0.1)
xgbc.fit(X_train, y_train)

# Predicted binary labels for the held-out rows; displayed as the cell output.
preds = xgbc.predict(X_test)
preds

array([0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0,
       1, 1, 1, 1, 1, 0, 0, 0])

In [6]:
# Build the explainer from the trained classifier and the full feature frame X
# (all rows, not only the training split), then call its fit step.
# Later cells call explainer.explain() on this object.
# NOTE(review): what fit() precomputes is defined in logic_explain_ml — confirm there.
explainer = XGBoostExplainer(xgbc, X)
explainer.fit()

In [7]:
# Display the model's per-feature importances together with its feature names;
# the two arrays are shown as a tuple, in matching order.
xgbc.feature_importances_, xgbc.feature_names_in_

(array([0.03358656, 0.02979117, 0.3463788 , 0.59024346], dtype=float32),
 array(['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)',
        'petal width (cm)'], dtype='<U17'))

In [8]:
# Explain one hand-written sample and print the resulting explanation.
# Assumes the four values follow the column order of X
# (sepal length, sepal width, petal length, petal width) — TODO confirm.
# NOTE(review): reorder="asc" is passed straight to explain(); its exact meaning
# is defined in logic_explain_ml — confirm there.
sample = [5.5, 4.2, 1.4, 0.2]
exp = explainer.explain(sample, reorder="asc")
print(exp)

[petal length (cm) == 1.4]


In [9]:
# Print one explanation per held-out row. Iterating over the array directly
# evaluates `X_test.values` once instead of on every pass and drops the
# manual index bookkeeping; each row is the same 1-D array as before.
for row in X_test.values:
    print(explainer.explain(row, reorder="asc"))

[petal length (cm) == 1.4]
[petal length (cm) == 1.3]
[petal length (cm) == 1.6]
[sepal width (cm) == 3, sepal length (cm) == 7.2, petal length (cm) == 5.8]
[petal length (cm) == 4.7, petal width (cm) == 1.4]
[petal width (cm) == 1.8]
[petal length (cm) == 4.5, petal width (cm) == 1.5]
[petal length (cm) == 4, petal width (cm) == 1.3]
[petal width (cm) == 1.9]
[petal length (cm) == 1.4]
[petal width (cm) == 2]
[petal length (cm) == 1.5]
[petal length (cm) == 1.5]
[petal width (cm) == 2]
[petal width (cm) == 1.8]
[petal length (cm) == 4.3, petal width (cm) == 1.3]
[petal length (cm) == 4.1, petal width (cm) == 1.3]
[petal length (cm) == 4.2, petal width (cm) == 1.3]
[petal length (cm) == 1.3]
[sepal length (cm) == 6.1, petal length (cm) == 5.6]
[petal length (cm) == 3.8, petal width (cm) == 1.1]
[petal length (cm) == 1.5]
[petal length (cm) == 4.4, petal width (cm) == 1.2]
[petal length (cm) == 4.4, petal width (cm) == 1.4]
[petal length (cm) == 4.3, petal width (cm) == 1.3]
[petal leng