# pylint: disable=import-error, wrong-import-position, wrong-import-order, duplicate-code
"""LIME explainer test suite"""
from common import *
import pytest
from trustyai.explainers import LimeExplainer
from trustyai.utils import TestModels
from trustyai.model import feature, Model, simple_prediction
from trustyai.metrics import ExplainabilityMetrics
from trustyai.visualizations import plot
from org.kie.trustyai.explainability.local import (
    LocalExplanationException,
)
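# NOTE: `np` (numpy), `pd` (pandas) and the `mock_feature` helper used below are not
# imported explicitly here; they are assumed to be provided by the shared `common`
# test-helper module pulled in through the wildcard import above.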


def mock_features(n_features: int):
    """Create a list of numeric mock features named f-num0..f-num{n-1}."""
    return [mock_feature(i, f"f-num{i}") for i in range(n_features)]


def test_empty_prediction():
    """Check that explaining an empty input raises a LocalExplanationException"""
    lime_explainer = LimeExplainer(seed=0, samples=10, perturbations=1)
    inputs = []
    model = TestModels.getSumSkipModel(0)
    outputs = model.predict([inputs])[0].outputs
    with pytest.raises(LocalExplanationException):
        lime_explainer.explain(inputs=inputs, outputs=outputs, model=model)


def test_non_empty_input():
    """Check that a non-empty input yields a non-null explanation"""
    lime_explainer = LimeExplainer(seed=0, samples=10, perturbations=1)
    features = [feature(name=f"f-num{i}", value=i, dtype="number") for i in range(4)]
    model = TestModels.getSumSkipModel(0)
    outputs = model.predict([features])[0].outputs
    saliency_map = lime_explainer.explain(inputs=features, outputs=outputs, model=model)
    assert saliency_map is not None


def test_sparse_balance():  # pylint: disable=too-many-locals
    """Check that the sparse-balance penalty does not increase per-feature saliency magnitudes"""
    for n_features in range(1, 4):
        lime_explainer_no_penalty = LimeExplainer(samples=100, penalise_sparse_balance=False)
        features = mock_features(n_features)
        model = TestModels.getSumSkipModel(0)
        outputs = model.predict([features])[0].outputs

        saliency_map_no_penalty = lime_explainer_no_penalty.explain(
            inputs=features, outputs=outputs, model=model
        ).saliency_map()
        assert saliency_map_no_penalty is not None

        decision_name = "sum-but0"
        saliency_no_penalty = saliency_map_no_penalty.get(decision_name)

        lime_explainer = LimeExplainer(samples=100, penalise_sparse_balance=True)
        saliency_map = lime_explainer.explain(
            inputs=features, outputs=outputs, model=model
        ).saliency_map()
        assert saliency_map is not None

        saliency = saliency_map.get(decision_name)
        for i in range(len(features)):
            score = saliency.getPerFeatureImportance().get(i).getScore()
            score_no_penalty = (
                saliency_no_penalty.getPerFeatureImportance().get(i).getScore()
            )
            # The penalised explainer should never produce a larger-magnitude score
            # than the unpenalised one.
            assert abs(score) <= abs(score_no_penalty)


def test_normalized_weights():
    """Check that saliency scores remain bounded when weight normalisation is enabled"""
    lime_explainer = LimeExplainer(normalise_weights=True, perturbations=2, samples=10)
    n_features = 4
    features = mock_features(n_features)
    model = TestModels.getSumSkipModel(0)
    outputs = model.predict([features])[0].outputs
    saliency_map = lime_explainer.explain(
        inputs=features, outputs=outputs, model=model
    ).saliency_map()
    assert saliency_map is not None

    decision_name = "sum-but0"
    saliency = saliency_map.get(decision_name)
    per_feature_importance = saliency.getPerFeatureImportance()
    for feature_importance in per_feature_importance:
        assert -3.0 < feature_importance.getScore() < 3.0


def lime_plots(block):
    """Render LIME explanations with both the default and the Bokeh plotting backends"""
    lime_explainer = LimeExplainer(normalise_weights=False, perturbations=2, samples=10)
    n_features = 15
    features = mock_features(n_features)
    model = TestModels.getSumSkipModel(0)
    outputs = model.predict([features])[0].outputs
    explanation = lime_explainer.explain(inputs=features, outputs=outputs, model=model)

    plot(explanation, block=block)
    plot(explanation, block=block, render_bokeh=True)
    plot(explanation, block=block, output_name="sum-but0")
    plot(explanation, block=block, output_name="sum-but0", render_bokeh=True)


@pytest.mark.block_plots
def test_lime_plots_blocking():
    """Run the plotting checks with blocking figures"""
    lime_plots(True)


def test_lime_plots():
    """Run the plotting checks with non-blocking figures"""
    lime_plots(False)


def test_lime_v2():
    """Check that LIME produces non-zero saliencies and well-formed dataframes for a linear model"""
    np.random.seed(0)
    data = pd.DataFrame(np.random.rand(1, 5)).values
    model_weights = np.random.rand(5)
    predict_function = lambda x: np.stack(
        [np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1
    )
    model = Model(predict_function)
    explainer = LimeExplainer(samples=100, perturbations=2, seed=23, normalise_weights=False)
    explanation = explainer.explain(inputs=data, outputs=model(data), model=model)

    for score in explanation.as_dataframe()["output-0"]["Saliency"]:
        assert score != 0

    for out_name, df in explanation.as_dataframe().items():
        assert "Feature" in df
        assert "output" in out_name
        assert all(x in str(df) for x in "01234")


def test_impact_score():
    """Check that the impact score of the top two LIME features is positive"""
    np.random.seed(0)
    data = pd.DataFrame(np.random.rand(1, 5))
    model_weights = np.random.rand(5)
    predict_function = lambda x: np.dot(x.values, model_weights)
    model = Model(predict_function, dataframe_input=True)
    output = model(data)
    pred = simple_prediction(data, output)
    explainer = LimeExplainer(samples=100, perturbations=2, seed=23, normalise_weights=False)
    explanation = explainer.explain(inputs=data, outputs=output, model=model)
    saliency = list(explanation.saliency_map().values())[0]
    top_features_t = saliency.getTopFeatures(2)
    impact = ExplainabilityMetrics.impactScore(model, pred, top_features_t)
    assert impact > 0
    return impact


def test_lime_as_html():
    """Check LIME saliencies for a model with Arrow conversion disabled"""
    np.random.seed(0)
    data = np.random.rand(1, 5)
    model_weights = np.random.rand(5)
    predict_function = lambda x: np.stack(
        [np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1
    )
    model = Model(predict_function, disable_arrow=True)
    explainer = LimeExplainer()
    explainer.explain(inputs=data, outputs=model(data), model=model)
    assert True

    explanation = explainer.explain(inputs=data, outputs=model(data), model=model)
    for score in explanation.as_dataframe()["output-0"]["Saliency"]:
        assert score != 0


def test_lime_numpy():
    """Check that custom feature and output names appear in the explanation dataframes"""
    np.random.seed(0)
    data = np.random.rand(101, 5)
    model_weights = np.random.rand(5)
    predict_function = lambda x: np.stack(
        [np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1
    )
    fnames = ["f{}".format(x) for x in "abcde"]
    onames = ["o{}".format(x) for x in "12"]
    model = Model(
        predict_function,
        feature_names=fnames,
        output_names=onames,
    )

    explainer = LimeExplainer()
    explanation = explainer.explain(inputs=data[0], outputs=model(data[0]), model=model)

    for oname in onames:
        assert oname in explanation.as_dataframe().keys()
        for fname in fnames:
            assert fname in explanation.as_dataframe()[oname]["Feature"].values