In [1]:
import numpy as np

import fatf
import fatf.utils.data.datasets as fatf_datasets
from sklearn.neural_network import MLPClassifier as NN
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

import itertools as it
from supp import sample
# Fix random seed
fatf.setup_random_seed(42)

21-Nov-02 21:51:54 fatf         INFO     Seeding RNGs using the input parameter.
21-Nov-02 21:51:54 fatf         INFO     Seeding RNGs with 42.


In [2]:

# Load data
iris_data_dict = fatf_datasets.load_iris()
iris_X = iris_data_dict['data']
iris_y = iris_data_dict['target']
iris_feature_names = iris_data_dict['feature_names'].tolist()
iris_class_names = iris_data_dict['target_names'].tolist()

# Train a model
# clf = fatf_models.KNN()
clf = NN(max_iter=5000)
clf.fit(iris_X, iris_y)

MLPClassifier(max_iter=5000)

In [3]:
def get_explainers(x, predict_fn, surrogate_model, kwargs_list, attributes):
    radii = np.linspace(0.1, 2, 10)
    surrogates = {}
    for r in radii:
        sample_x = sample(x, r, 1000)
        sample_y = predict_fn(sample_x)
        
        for kwargs in kwargs_list:
            surrogate = surrogate_model(**kwargs).fit(sample_x, sample_y)
            out = {
                'surrogate': surrogate,
                'score': surrogate.score(sample_x, sample_y),
            }
            
            for attr_name, attr in attributes.items():
                out[attr_name] = attr(surrogate)
                
            surrogates[tuple([r]+list(kwargs.values()))] = out
    return surrogates

In [4]:
complexity_attributes = {
    'depth': lambda x: x.get_depth(),
    'n_leaves': lambda x: x.get_n_leaves(),
    'fis': lambda x: x.feature_importances_,
}

parameters = {
    'max_depth': [3, 4, 5],
}

kwargs_list = [
    dict(zip(parameters.keys(), _tuple))
    for _tuple in it.product(*parameters.values())
]

subject = iris_X[[0], :]

In [5]:
out = get_explainers(
    subject,
    clf.predict,
    DecisionTreeClassifier,
    kwargs_list,
    complexity_attributes
)

In [6]:
out

{(0.1, 3): {'surrogate': DecisionTreeClassifier(max_depth=3),
  'score': 1.0,
  'depth': 0,
  'n_leaves': 1,
  'fis': array([0., 0., 0., 0.])},
 (0.1, 4): {'surrogate': DecisionTreeClassifier(max_depth=4),
  'score': 1.0,
  'depth': 0,
  'n_leaves': 1,
  'fis': array([0., 0., 0., 0.])},
 (0.1, 5): {'surrogate': DecisionTreeClassifier(max_depth=5),
  'score': 1.0,
  'depth': 0,
  'n_leaves': 1,
  'fis': array([0., 0., 0., 0.])},
 (0.3111111111111111, 3): {'surrogate': DecisionTreeClassifier(max_depth=3),
  'score': 1.0,
  'depth': 0,
  'n_leaves': 1,
  'fis': array([0., 0., 0., 0.])},
 (0.3111111111111111, 4): {'surrogate': DecisionTreeClassifier(max_depth=4),
  'score': 1.0,
  'depth': 0,
  'n_leaves': 1,
  'fis': array([0., 0., 0., 0.])},
 (0.3111111111111111, 5): {'surrogate': DecisionTreeClassifier(max_depth=5),
  'score': 1.0,
  'depth': 0,
  'n_leaves': 1,
  'fis': array([0., 0., 0., 0.])},
 (0.5222222222222223, 3): {'surrogate': DecisionTreeClassifier(max_depth=3),
  'score': 1.0