-
Notifications
You must be signed in to change notification settings - Fork 3.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix for shap.maskers.Impute class #3379
base: master
Are you sure you want to change the base?
Changes from 5 commits
53835a6
7db61b6
dd13787
bf8a8ac
5f6e58a
43d47f6
75c88e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -3,6 +3,7 @@ | |||
import numpy as np | ||||
import pandas as pd | ||||
from numba import njit | ||||
from sklearn.impute import KNNImputer, SimpleImputer | ||||
|
||||
from .. import utils | ||||
from .._serializable import Deserializer, Serializer | ||||
|
@@ -17,7 +18,7 @@ class Tabular(Masker): | |||
""" A common base class for Independent and Partition. | ||||
""" | ||||
|
||||
def __init__(self, data, max_samples=100, clustering=None): | ||||
def __init__(self, data, max_samples=100, clustering=None, impute=None): | ||||
""" This masks out tabular features by integrating over the given background dataset. | ||||
|
||||
Parameters | ||||
|
@@ -61,6 +62,11 @@ def __init__(self, data, max_samples=100, clustering=None): | |||
self.clustering = clustering | ||||
self.max_samples = max_samples | ||||
|
||||
# prepare by fitting sklearn imputer | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would rather see this in the impute than in the generic tabular class There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure I understand. Do you suggest overwriting the |
||||
self.impute = impute | ||||
if self.impute is not None: | ||||
self.impute.fit(self.data) | ||||
|
||||
# # warn users about large background data sets | ||||
# if self.data.shape[0] > 100: | ||||
# log.warning("Using " + str(self.data.shape[0]) + " background data samples could cause slower " + | ||||
|
@@ -117,6 +123,9 @@ def __call__(self, mask, x): | |||
|
||||
return (masked_inputs_out,), varying_rows_out | ||||
|
||||
if self.impute is not None: | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should then also just be initialized in impute |
||||
self.data = self.impute.transform(x) | ||||
|
||||
# otherwise we update the whole set of masked data for a single sample | ||||
self._masked_data[:] = x * mask + self.data * np.invert(mask) | ||||
self._last_mask[:] = mask | ||||
|
@@ -299,24 +308,82 @@ def __init__(self, data, max_samples=100, clustering="correlation"): | |||
super().__init__(data, max_samples=max_samples, clustering=clustering) | ||||
|
||||
|
||||
class Impute(Masker): # we should inherit from Tabular once we add support for arbitrary masking | ||||
class LinearImpute:
    """ Simple class for imputing missing values using pandas.DataFrame.interpolate.
    """

    def __init__(self, missing_value=0):
        """ Build a linear imputer for Impute classes when method=linear

        Parameters
        ----------
        missing_value : int, float or numpy.nan
            The placeholder value that marks an entry of the data as missing.
        """
        self.missing_value = missing_value

    def fit(self, data):
        """ Linearly impute missing values in the data and return as array.

        Entries equal to ``missing_value`` are replaced by NaN and then filled by
        linear interpolation down each column (the pandas default axis), in both
        directions so that leading and trailing gaps are filled as well.

        Parameters
        ----------
        data : numpy.ndarray
            Array to impute missing values for, should be masked
            using missing_value.

        Returns
        -------
        numpy.ndarray
            The data with the missing entries interpolated.
        """
        # Keep the NaN-marked frame on the instance, as the original code did.
        # np.nan is used because the np.NaN alias no longer exists in NumPy 2.0.
        self.data = pd.DataFrame(data).replace(self.missing_value, np.nan)

        interpolated = self.data.interpolate(
            method="linear",
            limit_direction="both",
        )

        # DataFrame.values is a property; calling it raised a TypeError.
        return interpolated.to_numpy()
|
||||
|
||||
class Impute(Tabular): | ||||
""" This imputes the values of missing features using the values of the observed features. | ||||
|
||||
Unlike Independent, Gaussian imputes missing values based on correlations with observed data points. | ||||
""" | ||||
|
||||
def __init__(self, data, method="linear"): | ||||
def __init__(self, data, max_samples=100, method="mean"): | ||||
""" Build a Partition masker with the given background data and clustering. | ||||
|
||||
Parameters | ||||
---------- | ||||
data : numpy.ndarray, pandas.DataFrame or {"mean: numpy.ndarray, "cov": numpy.ndarray} dictionary | ||||
data : numpy.ndarray, pandas.DataFrame | ||||
The background dataset that is used for masking. | ||||
|
||||
max_samples : int | ||||
The maximum number of samples to use from the passed background data. If data has more | ||||
than max_samples then shap.utils.sample is used to subsample the dataset. The number of | ||||
samples coming out of the masker (to be integrated over) matches the number of samples in | ||||
the background dataset. This means larger background dataset cause longer runtimes. Normally | ||||
about 1, 10, 100, or 1000 background samples are reasonable choices. | ||||
|
||||
method : string or sklearn.impute object | ||||
If a string, then this is shorthand for the type of sklearn.impute object to generate. | ||||
Either a SimpleImputer or KNNImputer is used with default settings. | ||||
Mean, median, and mode refer to the supported SimpleImputer strategies. | ||||
For more finetuning, pass an already initialized sklearn.impute object. | ||||
Supported methods are: | ||||
mean - SimpleImputer with elements replaced by mean of feature in data. | ||||
median - SimpleImputer with elements replaced by median of feature in data. | ||||
mode - SimpleImputer with elements replaced by most frequent of feature in data. | ||||
knn - KNNImputer with elements replaced by mean value within 5 NNs of feature in data. | ||||
""" | ||||
if data is dict and "mean" in data: | ||||
self.mean = data.get("mean", None) | ||||
self.cov = data.get("cov", None) | ||||
data = np.expand_dims(data["mean"], 0) | ||||
methods = ["linear", "mean", "median", "mode", "knn"] | ||||
if isinstance(method, str): | ||||
if method not in methods: | ||||
raise NotImplementedError("Given imputation method is not supported.") | ||||
stompsjo marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
elif method == "knn": | ||||
impute = KNNImputer(missing_values=0) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMO the missing values should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think line 142 may cause issues: Line 141 in 75c88e8
If the missing value is np.nan , and that is present in x , then x*mask will still leave the np.nan in place. I think the current behavior is to have masked values set to 0 via the mask.
(That said, I recognize that as is, every time |
||||
elif method == "linear": | ||||
impute = LinearImpute(missing_value=0) | ||||
else: | ||||
impute = SimpleImputer(missing_values=0, strategy=method) | ||||
else: | ||||
impute = method | ||||
stompsjo marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
|
||||
self.data = data | ||||
self.method = method | ||||
super().__init__(data, max_samples=max_samples, impute=impute) |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -54,7 +54,7 @@ def test_serialization_independent_masker_numpy(): | |||||||||||||||||||||||||||||||||||
temp_serialization_file.seek(0) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
# deserialize masker | ||||||||||||||||||||||||||||||||||||
new_independent_masker = shap.maskers.Masker.load(temp_serialization_file) | ||||||||||||||||||||||||||||||||||||
new_independent_masker = shap.maskers.Independent.load(temp_serialization_file) | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why this change? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe I misinterpreted what this test is doing. I assumed that we want to use the load method specifically of the masker being tested, rather than the parent shap/tests/maskers/test_tabular.py Line 110 in 43d47f6
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
mask = np.ones(X.shape[1]).astype(int) | ||||||||||||||||||||||||||||||||||||
mask[0] = 0 | ||||||||||||||||||||||||||||||||||||
|
@@ -107,7 +107,60 @@ def test_serialization_partion_masker_numpy(): | |||||||||||||||||||||||||||||||||||
temp_serialization_file.seek(0) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
# deserialize masker | ||||||||||||||||||||||||||||||||||||
new_partition_masker = shap.maskers.Masker.load(temp_serialization_file) | ||||||||||||||||||||||||||||||||||||
new_partition_masker = shap.maskers.Partition.load(temp_serialization_file) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
mask = np.ones(X.shape[1]).astype(int) | ||||||||||||||||||||||||||||||||||||
mask[0] = 0 | ||||||||||||||||||||||||||||||||||||
mask[4] = 0 | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
# comparing masked values | ||||||||||||||||||||||||||||||||||||
assert np.array_equal(original_partition_masker(mask, X[0])[0], new_partition_masker(mask, X[0])[0]) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def test_serialization_impute_masker_dataframe():
    """ Test the serialization of an Impute masker based on a DataFrame.
    """

    X, _ = shap.datasets.california(n_points=500)

    # initialize impute masker
    original_impute_masker = shap.maskers.Impute(X)

    with tempfile.TemporaryFile() as temp_serialization_file:

        # serialize impute masker
        original_impute_masker.save(temp_serialization_file)

        # rewind so the load below reads from the start of the file
        temp_serialization_file.seek(0)

        # deserialize masker
        new_impute_masker = shap.maskers.Impute.load(temp_serialization_file)

    mask = np.ones(X.shape[1]).astype(int)
    mask[0] = 0
    mask[4] = 0

    # compare the second element of each masker's output on the first row of X
    sample = X[:1].values[0]
    assert np.array_equal(original_impute_masker(mask, sample)[1], new_impute_masker(mask, sample)[1])
|
||||||||||||||||||||||||||||||||||||
def test_serialization_impute_masker_numpy(): | ||||||||||||||||||||||||||||||||||||
""" Test the serialization of a Partition masker based on a numpy array. | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
X, _ = shap.datasets.california(n_points=500) | ||||||||||||||||||||||||||||||||||||
X = X.values | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
# initialize partition masker | ||||||||||||||||||||||||||||||||||||
original_partition_masker = shap.maskers.Impute(X) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
with tempfile.TemporaryFile() as temp_serialization_file: | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
# serialize partition masker | ||||||||||||||||||||||||||||||||||||
original_partition_masker.save(temp_serialization_file) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
temp_serialization_file.seek(0) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
# deserialize masker | ||||||||||||||||||||||||||||||||||||
new_partition_masker = shap.maskers.Impute.load(temp_serialization_file) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you please add some tests where you actually impute something? The housing dataset does not have missing data AFAIK. Parameterizing the tests for each method would be great as well. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here's a start: shap/tests/maskers/test_tabular.py Lines 175 to 191 in 75c88e8
I tried to come up with something reproducible for all methods. Let me know if you have feedback. |
||||||||||||||||||||||||||||||||||||
mask = np.ones(X.shape[1]).astype(int) | ||||||||||||||||||||||||||||||||||||
mask[0] = 0 | ||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would add this keyword only for the impute class and not for tabular. See my comment below.