Skip to content

Commit

Permalink
[python] model_diagnostics and predict_surrogate (#290)
Browse files Browse the repository at this point in the history
* update readme

* add ResidualDiagnostics

* add predict_surrogate

* upgrade the predict_fn selection

* add plot to ResidualDiagnostics

* add print method

* add tests to model_diagnostics

* optimize tests

* add tests to predict_surrogate

* fix test

* fix ci
  • Loading branch information
hbaniecki committed Aug 20, 2020
1 parent 0f57acc commit 49edcce
Show file tree
Hide file tree
Showing 25 changed files with 555 additions and 40 deletions.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -38,7 +38,7 @@ install.packages("DALEX")
The **Python** version of dalex is available on [pip](https://pypi.org/project/dalex/)

```console
pip install dalex
pip install dalex -U
```

## Learn more
Expand Down
13 changes: 12 additions & 1 deletion python/dalex/NEWS.md
@@ -1,6 +1,17 @@
dalex (development)
----------------------------------------------------------------
* ...

### bug fixes

* `ModelPerformance.plot` now uses a drwhy color palette

### features

* added the `ResidualDiagnostics` object with a `plot` method
* added `model_diagnostics` method to the `Explainer`, which performs residual diagnostics
* added `predict_surrogate` method to the `Explainer`, which is a wrapper for the `lime`
tabular explanation from the [lime](https://github.com/marcotcr/lime) package
* added a `__str__` method to all of the explanation objects (it prints the `result` attribute)

dalex 0.2.0
----------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion python/dalex/README.md
Expand Up @@ -27,7 +27,7 @@ The `dalex` package is a part of [DrWhy.AI](http://DrWhy.AI) universe.
## Installation

```console
pip install dalex==0.1.9
pip install dalex -U
```

## Resources
Expand Down
28 changes: 28 additions & 0 deletions python/dalex/dalex/_explainer/checks.py
@@ -1,5 +1,6 @@
import numpy as np
import pandas as pd
from copy import deepcopy

from .helper import verbose_cat, is_y_in_data
from .yhat import *
Expand Down Expand Up @@ -290,3 +291,30 @@ def check_loss_function(explainer, loss_function):

def check_model_type(model_type, model_type_):
    """Pick the model type to use.

    Parameters
    -----------
    model_type : object or None
        Explicitly requested model type; takes precedence when it is not None.
    model_type_ : object
        Fallback value used only when `model_type` is None.

    Returns
    -----------
    object
        `model_type` if it is not None, otherwise `model_type_`.
    """
    if model_type is None:
        return model_type_
    return model_type


def check_new_observation_lime(new_observation):
    """Convert a single observation into the 1D np.ndarray that lime expects.

    lime accepts only a 1D np.ndarray as `data_row`, so every supported
    container is first turned into an independent array (the input is never
    modified or aliased) and then validated with one shared set of
    dimension checks.

    Parameters
    -----------
    new_observation : list or np.ndarray or pd.Series or pd.DataFrame
        One observation, given either as a 1D container or as a 2D container
        with exactly one row.

    Returns
    -----------
    np.ndarray
        1D array with the values of the observation.

    Raises
    -----------
    TypeError
        If `new_observation` is none of the supported types.
    ValueError
        If `new_observation` is not 1D and is not a 2D container with one row.
    """

    if isinstance(new_observation, (pd.Series, pd.DataFrame)):
        # copy=True keeps the result independent from the caller's object
        new_observation_ = new_observation.to_numpy(copy=True)
    elif isinstance(new_observation, (list, np.ndarray)):
        # np.array copies by default; nested lists become 2D arrays and are
        # validated below in the same way as arrays
        new_observation_ = np.array(new_observation)
    else:
        raise TypeError("new_observation must be a list or numpy.ndarray or pandas.Series or pandas.DataFrame")

    if new_observation_.ndim == 2:
        if new_observation_.shape[0] != 1:
            raise ValueError("Wrong new_observation dimension")
        # make 2D array 1D
        new_observation_ = new_observation_.flatten()
    elif new_observation_.ndim != 1:
        # 0D scalars and arrays with more than two dimensions are not observations
        raise ValueError("Wrong new_observation dimension")

    return new_observation_
30 changes: 30 additions & 0 deletions python/dalex/dalex/_explainer/helper.py
Expand Up @@ -15,3 +15,33 @@ def is_y_in_data(data, y):
def get_model_info(model):
model_package = re.search("(?<=<class ').*?(?=\.)", str(type(model)))[0]
return {'model_package': model_package}


def unpack_kwargs_lime(explainer, new_observation, **kwargs):
    """Split `**kwargs` between the two lime calls used by predict_surrogate(type='lime').

    Keyword arguments whose names appear in the signature of
    `LimeTabularExplainer` go to the first dictionary; of the remaining ones,
    those whose names appear in the signature of
    `LimeTabularExplainer.explain_instance` go to the second dictionary.
    The idea follows https://stackoverflow.com/a/58543357

    `training_data`, `mode`, `data_row` and `predict_fn` are filled in from
    `explainer` and `new_observation` when they were not passed explicitly.

    Parameters
    -----------
    explainer : object
        Object providing the `data`, `model_type` and `model` attributes.
    new_observation : object
        Default value for the `data_row` argument.
    kwargs :
        Keyword arguments to distribute.

    Returns
    -----------
    (dict, dict)
        Arguments for `LimeTabularExplainer` and for `explain_instance`.

    Raises
    -----------
    ValueError
        If `predict_fn` was not passed and the model has neither
        a `predict_proba` nor a `predict` attribute.
    """
    import inspect
    from lime.lime_tabular import LimeTabularExplainer

    def _take_matching(target):
        # move every still-unclaimed kwarg accepted by `target` into a new dict
        accepted = inspect.signature(target).parameters
        taken = {}
        for name in list(kwargs):
            if name in accepted:
                taken[name] = kwargs.pop(name)
        return taken

    # the constructor claims its arguments first, explain_instance gets the rest
    explainer_dict = _take_matching(LimeTabularExplainer)
    explanation_dict = _take_matching(LimeTabularExplainer.explain_instance)

    if 'training_data' not in explainer_dict:
        explainer_dict['training_data'] = explainer.data.to_numpy()
    if 'mode' not in explainer_dict:
        explainer_dict['mode'] = explainer.model_type
    if 'data_row' not in explanation_dict:
        explanation_dict['data_row'] = new_observation

    if 'predict_fn' in explanation_dict:
        return explainer_dict, explanation_dict

    # prefer predict_proba over predict when the model offers both
    for method_name in ('predict_proba', 'predict'):
        if hasattr(explainer.model, method_name):
            explanation_dict['predict_fn'] = getattr(explainer.model, method_name)
            break
    else:
        raise ValueError("Pass a `predict_fn` parameter to the `predict_surrogate` method. "
                         "See https://lime-ml.readthedocs.io/en/latest/lime.html#lime.lime_tabular.LimeTabularExplainer.explain_instance")

    return explainer_dict, explanation_dict
94 changes: 92 additions & 2 deletions python/dalex/dalex/_explainer/object.py
@@ -1,7 +1,8 @@
from dalex.dataset_level import ModelPerformance, VariableImportance, AggregatedProfiles
from dalex.dataset_level import ModelPerformance, VariableImportance,\
AggregatedProfiles, ResidualDiagnostics
from dalex.instance_level import BreakDown, Shap, CeterisParibus
from .checks import *
from .helper import get_model_info
from .helper import get_model_info, unpack_kwargs_lime


class Explainer:
Expand Down Expand Up @@ -259,6 +260,12 @@ def predict_parts(self,
BreakDown or Shap class object
Explanation object containing the main result attribute and the plot method.
Object class, its attributes, and the plot method depend on the `type` parameter.
Notes
--------
https://pbiecek.github.io/ema/breakDown.html
https://pbiecek.github.io/ema/iBreakDown.html
https://pbiecek.github.io/ema/shapley.html
"""

types = ('break_down_interactions', 'break_down', 'shap')
Expand Down Expand Up @@ -333,6 +340,10 @@ def predict_profile(self,
-----------
CeterisParibus class object
Explanation object containing the main result attribute and the plot method.
Notes
--------
https://pbiecek.github.io/ema/ceterisParibus.html
"""

types = ('ceteris_paribus', )
Expand All @@ -352,6 +363,45 @@ def predict_profile(self,

return predict_profile_

def predict_surrogate(self, new_observation, type='lime', **kwargs):
    """Wrapper for surrogate model explanations

    This function uses the lime package to create model explanation.
    See https://lime-ml.readthedocs.io/en/latest/lime.html#module-lime.lime_tabular

    Parameters
    -----------
    new_observation : list or pd.Series or np.ndarray (1d) or pd.DataFrame (1,p)
        An observation for which a prediction needs to be explained.
    type : {'lime'}
        Type of explanation method
        (default is 'lime', which uses the lime package to create an explanation).
    kwargs :
        Keyword arguments passed to the lime.lime_tabular.LimeTabularExplainer object
        and the LimeTabularExplainer.explain_instance method.
        `training_data`, `mode`, `data_row` and `predict_fn` are set automatically
        by `unpack_kwargs_lime` unless they are passed explicitly. Other parameters:
        https://lime-ml.readthedocs.io/en/latest/lime.html#module-lime.lime_tabular

    Returns
    -----------
    lime.explanation.Explanation
        Explanation object.

    Raises
    -----------
    ValueError
        If `type` is not a supported explanation method.

    Notes
    -----------
    https://github.com/marcotcr/lime
    """

    types = ('lime', )
    # fail early with a clear message instead of reaching the return statement
    # with an unbound local variable
    if type not in types:
        raise ValueError("'type' must be one of: {}".format(types))

    # imported lazily so that lime is required only when this method is used
    from lime.lime_tabular import LimeTabularExplainer
    new_observation = check_new_observation_lime(new_observation)

    explainer_dict, explanation_dict = unpack_kwargs_lime(self, new_observation, **kwargs)
    lime_tabular_explainer = LimeTabularExplainer(**explainer_dict)
    explanation = lime_tabular_explainer.explain_instance(**explanation_dict)

    return explanation

def model_performance(self,
model_type=None,
cutoff=0.5):
Expand All @@ -370,6 +420,10 @@ def model_performance(self,
-----------
ModelPerformance class object
Explanation object containing the main result attribute and the plot method.
Notes
--------
https://pbiecek.github.io/ema/modelPerformance.html
"""

if model_type is None and self.model_type is None:
Expand Down Expand Up @@ -429,6 +483,10 @@ def model_parts(self,
-----------
VariableImportance class object
Explanation object containing the main result attribute and the plot method.
Notes
--------
https://pbiecek.github.io/ema/featureImportance.html
"""

types = ('variable_importance', 'ratio', 'difference')
Expand Down Expand Up @@ -508,6 +566,11 @@ def model_profile(self,
-----------
AggregatedProfiles class object
Explanation object containing the main result attribute and the plot method.
Notes
--------
https://pbiecek.github.io/ema/partialDependenceProfiles.html
https://pbiecek.github.io/ema/accumulatedLocalProfiles.html
"""

types = ('partial', 'accumulated', 'conditional')
Expand Down Expand Up @@ -544,6 +607,33 @@ def model_profile(self,

return model_profile_

def model_diagnostics(self, variables=None):
    """Calculate dataset level residuals diagnostics

    Parameters
    -----------
    variables : str or array_like of str, optional
        Variables for which the data will be calculated
        (default is None, which means all of the variables).

    Returns
    -----------
    ResidualDiagnostics class object
        Explanation object containing the main result attribute and the plot method.

    Notes
    --------
    https://pbiecek.github.io/ema/residualDiagnostic.html
    """

    # build the explanation object, then fit it on this explainer
    diagnostics = ResidualDiagnostics(variables=variables)
    diagnostics.fit(self)
    return diagnostics

def dumps(self, *args, **kwargs):
"""Return the pickled representation (bytes object) of the Explainer
Expand Down
4 changes: 3 additions & 1 deletion python/dalex/dalex/dataset_level/__init__.py
@@ -1,9 +1,11 @@
from ._aggregated_profiles.object import AggregatedProfiles
from ._model_performance.object import ModelPerformance
from ._variable_importance.object import VariableImportance
from ._residual_diagnostics import ResidualDiagnostics

__all__ = [
"ModelPerformance",
"VariableImportance",
"AggregatedProfiles"
"AggregatedProfiles",
"ResidualDiagnostics"
]
Expand Up @@ -89,6 +89,10 @@ def __init__(self,
self.raw_profiles = None
self.random_state = random_state

def __str__(self):
    """Return the text representation of the `result` attribute.

    `__str__` has to return a `str` object: returning None makes
    `str(obj)` and `print(obj)` raise a TypeError.

    Returns
    -----------
    str
        `str(self.result)`.
    """
    return str(self.result)

def fit(self,
ceteris_paribus,
verbose=True):
Expand Down
27 changes: 15 additions & 12 deletions python/dalex/dalex/dataset_level/_model_performance/object.py
Expand Up @@ -2,6 +2,7 @@

from dalex.dataset_level._model_performance.plot import ecdf
from .utils import *
from ..._explainer.theme import get_default_colors


class ModelPerformance:
Expand Down Expand Up @@ -39,6 +40,10 @@ def __init__(self,
self.result = None
self.residuals = None

def __str__(self):
from IPython.display import display
display(self.result)

def fit(self, explainer):
"""Calculate the result of explanation
Expand Down Expand Up @@ -121,7 +126,7 @@ def plot(self,
Parameters
-----------
objects : ModelPerformance object or array_like of ModelPerformance objects
Additional objects to plot in subplots (default is None).
Additional objects to plot (default is None).
title : str, optional
Title of the plot (default depends on the `type` attribute).
show : bool, optional
Expand All @@ -136,31 +141,29 @@ def plot(self,

# are there any other objects to plot?
if objects is None:
n = 1
_residuals_df_list = [self.residuals.copy()]
_df_list = [self.residuals.copy()]
elif isinstance(objects, self.__class__): # allow for objects to be a single element
n = 2
_residuals_df_list = [self.residuals.copy(), objects.residuals.copy()]
_df_list = [self.residuals.copy(), objects.residuals.copy()]
else: # objects as tuple or array
n = len(objects) + 1
_residuals_df_list = [self.residuals.copy()]
_df_list = [self.residuals.copy()]
for ob in objects:
if not isinstance(ob, self.__class__):
raise TypeError("Some explanations aren't of ModelPerformance class")
_residuals_df_list += [ob.residuals.copy()]
_df_list += [ob.residuals.copy()]

colors = get_default_colors(len(_df_list), 'line')
fig = go.Figure()

for i in range(n):
_residuals_df = _residuals_df_list[i]
_abs_residuals = np.abs(_residuals_df['residuals'])
for i, _df in enumerate(_df_list):
_abs_residuals = np.abs(_df['residuals'])
_unique_abs_residuals = np.unique(_abs_residuals)

fig.add_scatter(
x=_unique_abs_residuals,
y=1 - ecdf(_abs_residuals)(_unique_abs_residuals),
line_shape='hv',
name=_residuals_df.iloc[0, _residuals_df.columns.get_loc('label')]
name=_df.iloc[0, _df.columns.get_loc('label')],
marker=dict(color=colors[i])
)

fig.update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': 'outside',
Expand Down
@@ -0,0 +1,5 @@
from .object import ResidualDiagnostics

__all__ = [
"ResidualDiagnostics"
]
17 changes: 17 additions & 0 deletions python/dalex/dalex/dataset_level/_residual_diagnostics/checks.py
@@ -0,0 +1,17 @@
import numpy as np
import pandas as pd


def check_variables(variables):
    """Normalize the `variables` argument to None or a list.

    Treating variables as a list simplifies the code that uses them.

    Parameters
    -----------
    variables : None or str or list or np.ndarray or pd.Series
        Variable name(s) to normalize.

    Returns
    -----------
    None or list
        None is passed through, a single str is wrapped in a one-element list,
        any other accepted container is converted with `list`.

    Raises
    -----------
    TypeError
        If `variables` is none of the accepted types.
    """
    if variables is None:
        return None

    if isinstance(variables, str):
        return [variables]

    if isinstance(variables, (list, np.ndarray, pd.Series)):
        return list(variables)

    raise TypeError("variables must be None or str or list or np.ndarray or pd.Series")

0 comments on commit 49edcce

Please sign in to comment.