Skip to content

Commit

Permalink
addressing feedback from PR
Browse files Browse the repository at this point in the history
  • Loading branch information
Jordan Stomps committed Dec 11, 2023
1 parent 43d47f6 commit 75c88e8
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 10 deletions.
43 changes: 33 additions & 10 deletions shap/maskers/_tabular.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Callable

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -65,7 +66,12 @@ def __init__(self, data, max_samples=100, clustering=None, impute=None):
# prepare by fitting sklearn imputer
self.impute = impute
if self.impute is not None:
self.impute.fit(self.data)
if len(self.data.shape) == 1:
self.impute.fit(self.data.reshape(1, -1))
elif len(self.data.shape) == 2:
self.impute.fit(self.data)
elif len(self.data.shape) >= 2:
raise NotImplementedError(f"Currently only 1 and 2 dimensional data can by processed with the LinearImpute class. You provided {len(self.data.shape)}. If this is crucial to you, feel free to open an issue: https://github.com/shap/shap/issues.")

# # warn users about large background data sets
# if self.data.shape[0] > 100:
Expand Down Expand Up @@ -124,7 +130,12 @@ def __call__(self, mask, x):
return (masked_inputs_out,), varying_rows_out

if self.impute is not None:
self.data = self.impute.transform(x)
if len(x.shape) == 1:
self.data = self.impute.transform(x.reshape(1, -1))
elif len(x.shape) == 2:
self.data = self.impute.transform(x)
elif len(x.shape) >= 2:
raise NotImplementedError(f"Currently only 1 and 2 dimensional data can by processed with the LinearImpute class. You provided {len(x.shape)}. If this is crucial to you, feel free to open an issue: https://github.com/shap/shap/issues.")

# otherwise we update the whole set of masked data for a single sample
self._masked_data[:] = x * mask + self.data * np.invert(mask)
Expand Down Expand Up @@ -323,6 +334,9 @@ def __init__(self, missing_value=0):
self.missing_value = missing_value

def fit(self, data):
    """Store the background data as a DataFrame.

    The stored frame is what ``transform`` later concatenates new samples
    onto before interpolating.

    Parameters
    ----------
    data : array-like
        Background data; anything accepted by ``pd.DataFrame``.
    """
    background = pd.DataFrame(data)
    self.data = background

def transform(self, x):
    """Linearly impute missing values in ``x`` and return them as an array.

    The samples in ``x`` are appended below the background data stored by
    ``fit`` so that pandas can interpolate every column across the combined
    rows. Only the rows that belong to ``x`` are returned.

    Parameters
    ----------
    x : np.ndarray
        1 or 2 dimensional array to impute missing values for. Entries equal
        to ``self.missing_value`` are treated as missing. A 1 dimensional
        array is treated as a single sample.

    Returns
    -------
    np.ndarray
        2 dimensional array with one row per sample in ``x``.

    Raises
    ------
    NotImplementedError
        If ``x`` has more than two dimensions.
    """
    n_dims = len(x.shape)
    if n_dims > 2:
        raise NotImplementedError(f"Currently only 1 and 2 dimensional data can by processed with the LinearImpute class. You provided {len(x.shape)}. If this is crucial to you, feel free to open an issue: https://github.com/shap/shap/issues.")
    if n_dims == 1:
        # a single sample becomes a one-row matrix
        x = x.reshape(1, -1)

    # Pandas interpolate uses NaN as missing value. ``np.nan`` is used because
    # the ``np.NaN`` alias no longer exists in NumPy 2.0.
    # The frame is kept on ``self.x`` as before for backward compatibility.
    self.x = pd.DataFrame(x).replace(self.missing_value, np.nan)
    n_samples = self.x.shape[0]

    # combine with background data for interpolation
    interpolated = pd.concat([self.data, self.x], axis=0).interpolate(
        method="linear",
        limit_direction="both"
    )

    # Return all of the appended rows, not just one of them. A start index is
    # computed instead of using ``[-n_samples:]`` so that zero samples yield
    # zero rows rather than the whole frame.
    first_sample_row = interpolated.shape[0] - n_samples
    return interpolated.values[first_sample_row:]


class Impute(Tabular):
Expand Down Expand Up @@ -375,7 +396,7 @@ def __init__(self, data, max_samples=100, method="mean"):
mode - SimpleImputer with elements replaced by most frequent of feature in data.
knn - KNNImputer with elements replaced by mean value within 5 NNs of feature in data.
"""
methods = ["linear", "mean", "median", "mode", "knn"]
methods = ["linear", "mean", "median", "most_frequent", "knn"]
if isinstance(method, str):
if method not in methods:
raise NotImplementedError(f"Given imputation method is not supported. Please provide one of the following methods: {', '.join(methods)}")
Expand All @@ -385,7 +406,9 @@ def __init__(self, data, max_samples=100, method="mean"):
impute = LinearImpute(missing_value=0)
else:
impute = SimpleImputer(missing_values=0, strategy=method)
else:
elif isinstance(method, Callable):
impute = method
else:
raise NotImplementedError(f"Given imputation method is not supported. Please provide one of the following methods: {', '.join(methods)}")

super().__init__(data, max_samples=max_samples, impute=impute)
42 changes: 42 additions & 0 deletions tests/maskers/test_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import tempfile

import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor

import shap

Expand Down Expand Up @@ -168,3 +171,42 @@ def test_serialization_impute_masker_numpy():

# comparing masked values
assert np.array_equal(original_partition_masker(mask, X[0])[0], new_partition_masker(mask, X[0])[0])

def test_imputation():
    """Run every built-in imputation method of ``shap.maskers.Impute`` on a toy sample.

    One feature of the sample is masked and the masker output is compared
    with the original sample for each method name.
    """
    methods = ["linear", "mean", "median", "most_frequent", "knn"]

    # toy sample to impute; it is identical for every method, so build it once
    x = np.arange(1, 6)

    # only mask the second value
    mask = np.ones(x.shape, dtype=bool)
    mask[1] = False

    for method in methods:
        masker = shap.maskers.Impute(np.full((1, 5), 1), method=method)
        # masker should impute the original value (toy data is predictable)
        imputed = masker(mask, x)
        assert np.all(x == imputed), f"imputation changed the sample for method={method!r}"

def test_imputation_workflow():
    """Smoke-test the Impute masker inside a full ``shap.Explainer`` workflow.

    A small regressor is trained on toy data, explained with an ``Impute``
    masker as background, and the result is rebuilt into an ``Explanation``.
    All random sources are seeded so that the test is reproducible.
    """
    # toy data
    X, y = make_regression(n_samples=100, random_state=0)
    X_train, X_test, y_train, _ = train_test_split(
        X,
        y,
        train_size=0.75,
        random_state=0,
    )

    # train toy model
    model = MLPRegressor(random_state=0)
    model.fit(X_train, y_train)

    background = shap.maskers.Impute(X_train)
    # TypeError here prior to PR #3379
    explainer = shap.Explainer(model.predict, masker=background)

    shap_values = explainer(X_test)
    # one row of attributions per explained sample
    assert len(shap_values.values) == len(X_test)

    shap.Explanation(
        shap_values.values,
        shap_values.base_values,
        shap_values.data,
    )

0 comments on commit 75c88e8

Please sign in to comment.