Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] model_diagnostics and predict_surrogate #290

Merged
merged 11 commits into from
Aug 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ install.packages("DALEX")
The **Python** version of dalex is available on [pip](https://pypi.org/project/dalex/)

```console
pip install dalex
pip install dalex -U
```

## Learn more
Expand Down
13 changes: 12 additions & 1 deletion python/dalex/NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
dalex (development)
----------------------------------------------------------------
* ...

### bug fixes

* `ModelPerformance.plot` now uses a drwhy color palette

### features

* added the `ResidualDiagnostics` object with a `plot` method
* added `model_diagnostics` method to the `Explainer`, which performs residual diagnostics
* added `predict_surrogate` method to the `Explainer`, which is a wrapper for the `lime`
tabular explanation from the [lime](https://github.com/marcotcr/lime) package
* added a `__str__` method to all of the explanation objects (it prints the `result` attribute)

dalex 0.2.0
----------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion python/dalex/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ The `dalex` package is a part of [DrWhy.AI](http://DrWhy.AI) universe.
## Installation

```console
pip install dalex==0.1.9
pip install dalex -U
```

## Resources
Expand Down
28 changes: 28 additions & 0 deletions python/dalex/dalex/_explainer/checks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pandas as pd
from copy import deepcopy

from .helper import verbose_cat, is_y_in_data
from .yhat import *
Expand Down Expand Up @@ -290,3 +291,30 @@ def check_loss_function(explainer, loss_function):

def check_model_type(model_type, model_type_):
    """Resolve the model type, preferring an explicitly supplied value.

    Parameters
    -----------
    model_type : object or None
        Value requested by the caller.
    model_type_ : object
        Fallback used only when `model_type` is None.

    Returns
    -----------
    object
        `model_type` unless it is None, otherwise `model_type_`.
    """
    if model_type is None:
        return model_type_
    return model_type


def check_new_observation_lime(new_observation):
    """Convert a single observation into the 1d numpy array expected by lime.

    Parameters
    -----------
    new_observation : list or np.ndarray or pd.Series or pd.DataFrame
        One observation, given either as a 1d sequence of values or as
        a 2d structure with exactly one row.

    Returns
    -----------
    np.ndarray
        A new 1d array; the input object is never modified or aliased.

    Raises
    -----------
    TypeError
        If `new_observation` is not one of the supported types.
    ValueError
        If `new_observation` is not 1d and is not a single-row 2d structure.
    """
    # lime accepts only a 1d np.array as data_row
    if isinstance(new_observation, (pd.Series, pd.DataFrame)):
        raw_values = new_observation.to_numpy()
    elif isinstance(new_observation, (list, np.ndarray)):
        raw_values = new_observation
    else:
        raise TypeError("new_observation must be a list or numpy.ndarray or pandas.Series or pandas.DataFrame")

    # np.array copies by default, so the caller's data stays untouched
    new_observation_ = np.array(raw_values)

    # one shared dimensionality check, so nested lists are validated
    # exactly like 2d arrays and single-row data frames
    if new_observation_.ndim == 2:
        if new_observation_.shape[0] != 1:
            raise ValueError("Wrong new_observation dimension")
        # make 2D array 1D
        new_observation_ = new_observation_.flatten()
    elif new_observation_.ndim != 1:
        raise ValueError("Wrong new_observation dimension")

    return new_observation_
30 changes: 30 additions & 0 deletions python/dalex/dalex/_explainer/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,33 @@ def is_y_in_data(data, y):
def get_model_info(model):
    """Collect basic information about the model object.

    Parameters
    -----------
    model : object
        Model whose originating package should be identified.

    Returns
    -----------
    dict
        Dictionary with a single key, 'model_package', holding the top-level
        package name taken from the model's class.
    """
    model_class = type(model)
    # raw string: `\.` must reach the regex engine as an escaped dot
    match = re.search(r"(?<=<class ').*?(?=\.)", str(model_class))
    if match is not None:
        model_package = match[0]
    else:
        # the class repr has no dotted path (e.g. builtin types), so read the
        # defining module directly instead of failing on a missing match
        model_package = model_class.__module__.split('.')[0]
    return {'model_package': model_package}


def unpack_kwargs_lime(explainer, new_observation, **kwargs):
    """Split keyword arguments between the two lime calls used by predict_surrogate.

    Helper for predict_surrogate(type='lime'). Each keyword is routed by name:
    first to the LimeTabularExplainer constructor, otherwise to its
    explain_instance method. Keywords matching neither signature are dropped.
    The routing idea follows https://stackoverflow.com/a/58543357

    Parameters
    -----------
    explainer : Explainer
        Source of the defaults: `data`, `model_type` and `model` are read.
    new_observation : np.ndarray
        Used as `data_row` when the caller did not pass one.
    kwargs :
        Keyword arguments to distribute.

    Returns
    -----------
    (dict, dict)
        Arguments for LimeTabularExplainer and for explain_instance.

    Raises
    -----------
    ValueError
        If no `predict_fn` was passed and the model has neither a
        `predict_proba` nor a `predict` attribute.
    """
    from lime.lime_tabular import LimeTabularExplainer
    import inspect

    constructor_params = tuple(inspect.signature(LimeTabularExplainer).parameters)
    explain_params = tuple(inspect.signature(LimeTabularExplainer.explain_instance).parameters)

    explainer_dict = {}
    explanation_dict = {}
    for name, value in kwargs.items():
        # the constructor takes precedence when a name matches both signatures
        if name in constructor_params:
            explainer_dict[name] = value
        elif name in explain_params:
            explanation_dict[name] = value

    # fill in defaults only where the caller left a gap
    if 'training_data' not in explainer_dict:
        explainer_dict['training_data'] = explainer.data.to_numpy()
    if 'mode' not in explainer_dict:
        explainer_dict['mode'] = explainer.model_type
    if 'data_row' not in explanation_dict:
        explanation_dict['data_row'] = new_observation
    if 'predict_fn' not in explanation_dict:
        # prefer probabilities, then plain predictions
        for method_name in ('predict_proba', 'predict'):
            if hasattr(explainer.model, method_name):
                explanation_dict['predict_fn'] = getattr(explainer.model, method_name)
                break
        else:
            raise ValueError("Pass a `predict_fn` parameter to the `predict_surrogate` method. "
                             "See https://lime-ml.readthedocs.io/en/latest/lime.html#lime.lime_tabular.LimeTabularExplainer.explain_instance")

    return explainer_dict, explanation_dict
94 changes: 92 additions & 2 deletions python/dalex/dalex/_explainer/object.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from dalex.dataset_level import ModelPerformance, VariableImportance, AggregatedProfiles
from dalex.dataset_level import ModelPerformance, VariableImportance,\
AggregatedProfiles, ResidualDiagnostics
from dalex.instance_level import BreakDown, Shap, CeterisParibus
from .checks import *
from .helper import get_model_info
from .helper import get_model_info, unpack_kwargs_lime


class Explainer:
Expand Down Expand Up @@ -259,6 +260,12 @@ def predict_parts(self,
BreakDown or Shap class object
Explanation object containing the main result attribute and the plot method.
Object class, its attributes, and the plot method depend on the `type` parameter.

Notes
--------
https://pbiecek.github.io/ema/breakDown.html
https://pbiecek.github.io/ema/iBreakDown.html
https://pbiecek.github.io/ema/shapley.html
"""

types = ('break_down_interactions', 'break_down', 'shap')
Expand Down Expand Up @@ -333,6 +340,10 @@ def predict_profile(self,
-----------
CeterisParibus class object
Explanation object containing the main result attribute and the plot method.

Notes
--------
https://pbiecek.github.io/ema/ceterisParibus.html
"""

types = ('ceteris_paribus', )
Expand All @@ -352,6 +363,45 @@ def predict_profile(self,

return predict_profile_

def predict_surrogate(self, new_observation, type='lime', **kwargs):
    """Wrapper for surrogate model explanations

    This function uses the lime package to create model explanation.
    See https://lime-ml.readthedocs.io/en/latest/lime.html#module-lime.lime_tabular

    Parameters
    -----------
    new_observation : pd.Series or np.ndarray (1d) or pd.DataFrame (1,p)
        An observation for which a prediction needs to be explained.
    type : {'lime'}
        Type of explanation method
        (default is 'lime', which uses the lime package to create an explanation).
    kwargs :
        Keyword arguments passed to the lime.lime_tabular.LimeTabularExplainer object
        and the LimeTabularExplainer.explain_instance method. Exceptions are:
        `training_data`, `mode`, `data_row` and `predict_fn`. Other parameters:
        https://lime-ml.readthedocs.io/en/latest/lime.html#module-lime.lime_tabular

    Returns
    -----------
    lime.explanation.Explanation
        Explanation object.

    Raises
    -----------
    ValueError
        If `type` is not one of the supported explanation methods.

    Notes
    -----------
    https://github.com/marcotcr/lime
    """

    types = ('lime',)
    # validate before doing any work: an unsupported type used to fall through
    # to the return statement and fail with an UnboundLocalError
    if type not in types:
        raise ValueError("'type' must be one of: {}".format(', '.join(types)))

    # imported lazily so that lime is required only when this method is used
    from lime.lime_tabular import LimeTabularExplainer
    new_observation = check_new_observation_lime(new_observation)

    explainer_dict, explanation_dict = unpack_kwargs_lime(self, new_observation, **kwargs)
    lime_tabular_explainer = LimeTabularExplainer(**explainer_dict)
    explanation = lime_tabular_explainer.explain_instance(**explanation_dict)

    return explanation

def model_performance(self,
model_type=None,
cutoff=0.5):
Expand All @@ -370,6 +420,10 @@ def model_performance(self,
-----------
ModelPerformance class object
Explanation object containing the main result attribute and the plot method.

Notes
--------
https://pbiecek.github.io/ema/modelPerformance.html
"""

if model_type is None and self.model_type is None:
Expand Down Expand Up @@ -429,6 +483,10 @@ def model_parts(self,
-----------
VariableImportance class object
Explanation object containing the main result attribute and the plot method.

Notes
--------
https://pbiecek.github.io/ema/featureImportance.html
"""

types = ('variable_importance', 'ratio', 'difference')
Expand Down Expand Up @@ -508,6 +566,11 @@ def model_profile(self,
-----------
AggregatedProfiles class object
Explanation object containing the main result attribute and the plot method.

Notes
--------
https://pbiecek.github.io/ema/partialDependenceProfiles.html
https://pbiecek.github.io/ema/accumulatedLocalProfiles.html
"""

types = ('partial', 'accumulated', 'conditional')
Expand Down Expand Up @@ -544,6 +607,33 @@ def model_profile(self,

return model_profile_

def model_diagnostics(self,
                      variables=None):
    """Run residual diagnostics on the dataset level

    Builds a ResidualDiagnostics object for the requested variables
    and fits it on this explainer.

    Parameters
    -----------
    variables : str or array_like of str, optional
        Variables for which the data will be calculated
        (default is None, which means all of the variables).

    Returns
    -----------
    ResidualDiagnostics class object
        Explanation object containing the main result attribute and the plot method.

    Notes
    --------
    https://pbiecek.github.io/ema/residualDiagnostic.html
    """

    diagnostics = ResidualDiagnostics(variables=variables)
    diagnostics.fit(self)
    return diagnostics

def dumps(self, *args, **kwargs):
"""Return the pickled representation (bytes object) of the Explainer

Expand Down
4 changes: 3 additions & 1 deletion python/dalex/dalex/dataset_level/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from ._aggregated_profiles.object import AggregatedProfiles
from ._model_performance.object import ModelPerformance
from ._variable_importance.object import VariableImportance
from ._residual_diagnostics import ResidualDiagnostics

__all__ = [
"ModelPerformance",
"VariableImportance",
"AggregatedProfiles"
"AggregatedProfiles",
"ResidualDiagnostics"
]
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def __init__(self,
self.raw_profiles = None
self.random_state = random_state

def __str__(self):
    """Return the text representation of the `result` attribute.

    Python requires `__str__` to return a str, so the result is converted
    to text here instead of being displayed as a side effect; this keeps
    `str(obj)` and `print(obj)` working, also outside of IPython.
    """
    return str(self.result)

def fit(self,
ceteris_paribus,
verbose=True):
Expand Down
27 changes: 15 additions & 12 deletions python/dalex/dalex/dataset_level/_model_performance/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from dalex.dataset_level._model_performance.plot import ecdf
from .utils import *
from ..._explainer.theme import get_default_colors


class ModelPerformance:
Expand Down Expand Up @@ -39,6 +40,10 @@ def __init__(self,
self.result = None
self.residuals = None

def __str__(self):
    """Return the text representation of the `result` attribute.

    Python requires `__str__` to return a str, so the result is converted
    to text here instead of being displayed as a side effect; this keeps
    `str(obj)` and `print(obj)` working, also outside of IPython.
    """
    return str(self.result)

def fit(self, explainer):
"""Calculate the result of explanation

Expand Down Expand Up @@ -121,7 +126,7 @@ def plot(self,
Parameters
-----------
objects : ModelPerformance object or array_like of ModelPerformance objects
Additional objects to plot in subplots (default is None).
Additional objects to plot (default is None).
title : str, optional
Title of the plot (default depends on the `type` attribute).
show : bool, optional
Expand All @@ -136,31 +141,29 @@ def plot(self,

# are there any other objects to plot?
if objects is None:
n = 1
_residuals_df_list = [self.residuals.copy()]
_df_list = [self.residuals.copy()]
elif isinstance(objects, self.__class__): # allow for objects to be a single element
n = 2
_residuals_df_list = [self.residuals.copy(), objects.residuals.copy()]
_df_list = [self.residuals.copy(), objects.residuals.copy()]
else: # objects as tuple or array
n = len(objects) + 1
_residuals_df_list = [self.residuals.copy()]
_df_list = [self.residuals.copy()]
for ob in objects:
if not isinstance(ob, self.__class__):
raise TypeError("Some explanations aren't of ModelPerformance class")
_residuals_df_list += [ob.residuals.copy()]
_df_list += [ob.residuals.copy()]

colors = get_default_colors(len(_df_list), 'line')
fig = go.Figure()

for i in range(n):
_residuals_df = _residuals_df_list[i]
_abs_residuals = np.abs(_residuals_df['residuals'])
for i, _df in enumerate(_df_list):
_abs_residuals = np.abs(_df['residuals'])
_unique_abs_residuals = np.unique(_abs_residuals)

fig.add_scatter(
x=_unique_abs_residuals,
y=1 - ecdf(_abs_residuals)(_unique_abs_residuals),
line_shape='hv',
name=_residuals_df.iloc[0, _residuals_df.columns.get_loc('label')]
name=_df.iloc[0, _df.columns.get_loc('label')],
marker=dict(color=colors[i])
)

fig.update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': 'outside',
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .object import ResidualDiagnostics

__all__ = [
"ResidualDiagnostics"
]
17 changes: 17 additions & 0 deletions python/dalex/dalex/dataset_level/_residual_diagnostics/checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import numpy as np
import pandas as pd


def check_variables(variables):
    """Normalize the `variables` argument to a list (or None).

    Treating variables as a list simplifies the downstream code.

    Parameters
    -----------
    variables : None or str or list or np.ndarray or pd.Series
        Variable name(s) selected by the user.

    Returns
    -----------
    list or None
        None when `variables` is None, a one-element list for a single str,
        otherwise a new list with the elements of `variables`.

    Raises
    -----------
    TypeError
        If `variables` is of any other type.
    """
    if variables is None:
        return None
    if isinstance(variables, str):
        # a single name becomes a one-element list
        return [variables]
    if isinstance(variables, (list, np.ndarray, pd.Series)):
        return list(variables)

    raise TypeError("variables must be None or str or list or np.ndarray or pd.Series")
Loading