Skip to content

Commit

Permalink
[python] potential fix for #325
Browse files Browse the repository at this point in the history
  • Loading branch information
hbaniecki committed Sep 20, 2020
1 parent fd3c36e commit 61f1c1b
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 5 deletions.
2 changes: 1 addition & 1 deletion python/dalex/NEWS.md
Expand Up @@ -20,7 +20,7 @@ dalex (development)
### defaults

* wrong parameter name `title_x` changed to `y_title` in `CeterisParibus.plot` and `AggregatedProfiles.plot` ([#317](https://github.com/ModelOriented/DALEX/issues/317))

* now warning the user in `Explainer` when `predict_function` returns an error or doesn't return `numpy.ndarray (1d)` ([#325](https://github.com/ModelOriented/DALEX/issues/325))

dalex 0.2.1
----------------------------------------------------------------
Expand Down
12 changes: 9 additions & 3 deletions python/dalex/dalex/_explainer/checks.py
@@ -1,6 +1,7 @@
# check functions for Explainer.__init__
import pandas as pd
from copy import deepcopy
from warnings import warn

from .helper import *
from .yhat import *
Expand Down Expand Up @@ -189,9 +190,14 @@ def check_predict_function_and_model_type(predict_function, model_type,
np.min(y_hat), np.mean(y_hat), np.max(y_hat)), verbose=verbose)

except (Exception, ValueError, TypeError) as error:
verbose_cat(" -> predicted values : the predict_function returns an error when executed \n",
verbose=verbose)
print(error)
# verbose_cat(" -> predicted values : the predict_function returns an error when executed \n",
# verbose=verbose)

warn("\n -> predicted values : the predict_function returns an error when executed \n" +
str(error), stacklevel=2)

if not isinstance(y_hat, np.ndarray) or y_hat.shape != (data.shape[0], ):
warn("\n -> predicted values : predict_function must return numpy.ndarray (1d)", stacklevel=2)

# check if predict_function accepts arrays
try:
Expand Down
54 changes: 53 additions & 1 deletion python/dalex/test/test_explainer.py
Expand Up @@ -7,6 +7,9 @@
from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder

import dalex as dx
import numpy as np

import warnings


class ExplainerTest(unittest.TestCase):
Expand Down Expand Up @@ -63,5 +66,54 @@ def test(self):
self.assertIsInstance(case5, dx.instance_level.BreakDown)
self.assertIsInstance(case6, dx.instance_level.CeterisParibus)

case5 = dx.Explainer(self.model, self.X, self.y, predict_function=1, verbose=False)
with warnings.catch_warnings(record=True) as w:
# Cause all warnings to always be triggered.
warnings.simplefilter("always")
# Trigger a warning.
case5 = dx.Explainer(self.model, self.X, self.y, predict_function=1, verbose=False)
assert issubclass(w[-1].category, UserWarning)

self.assertIsInstance(case5, dx.Explainer)

def test_errors(self):

from sklearn.ensemble import RandomForestRegressor

data = dx.datasets.load_fifa()
X = data.drop(columns=['nationality', 'value_eur']).iloc[1:100, :]
y = data['value_eur'][1:100]

model = RandomForestRegressor()
model.fit(X, y)

def predict_function_return_2d(m, d):
n_rows = d.shape[0]
prediction = m.predict(d)
return prediction.reshape((n_rows, 1))

def predict_function_return_3d(m, d):
n_rows = d.shape[0]
prediction = m.predict(d)
return prediction.reshape((n_rows, 1, 1))

def predict_function_return_one_element_array(m, d):
return np.array(0.2)

warnings.simplefilter("always")
with warnings.catch_warnings(record=True) as w:
# Trigger a warning.
dx.Explainer(model, X, y, verbose=False, model_type='regression',
predict_function=predict_function_return_2d)
assert issubclass(w[-1].category, UserWarning)

with warnings.catch_warnings(record=True) as w:
# Trigger a warning.
dx.Explainer(model, X, y, verbose=False, model_type='regression',
predict_function=predict_function_return_3d)
assert issubclass(w[-1].category, UserWarning)

with warnings.catch_warnings(record=True) as w:
# Trigger a warning.
dx.Explainer(model, X, y, verbose=False, model_type='regression',
predict_function=predict_function_return_one_element_array)
assert issubclass(w[-1].category, UserWarning)

0 comments on commit 61f1c1b

Please sign in to comment.