diff --git a/notebooks/image_examples/image_captioning/Image Captioning using Azure Cognitive Services.ipynb b/notebooks/image_examples/image_captioning/Image Captioning using Azure Cognitive Services.ipynb index 17764d32d..a67ee64db 100644 --- a/notebooks/image_examples/image_captioning/Image Captioning using Azure Cognitive Services.ipynb +++ b/notebooks/image_examples/image_captioning/Image Captioning using Azure Cognitive Services.ipynb @@ -86,8 +86,7 @@ "outputs": [], "source": [ "def get_caption(path_to_image):\n", - " \"\"\"\n", - " Function to get image caption when path to image file is given.\n", + " \"\"\"Function to get image caption when path to image file is given.\n", " Note: API_KEY and ANALYZE_URL need to be defined before calling this function.\n", "\n", " Parameters\n", @@ -97,8 +96,8 @@ " Output\n", " -------\n", " image caption\n", - " \"\"\"\n", "\n", + " \"\"\"\n", " headers = {\n", " \"Ocp-Apim-Subscription-Key\": API_KEY,\n", " \"Content-Type\": \"application/octet-stream\",\n", @@ -314,8 +313,7 @@ " show_grid_plot=False,\n", " limit_grid=20,\n", "):\n", - " \"\"\"\n", - " Function to take a list of images and parameters such max evals etc. and return shap explanations (shap_values) for test images(X).\n", + " \"\"\"Function to take a list of images and parameters such max evals etc. and return shap explanations (shap_values) for test images(X).\n", " Paramaters\n", " ----------\n", " X : list of images which need to be explained\n", @@ -329,7 +327,6 @@ " ------\n", " shap_values_list: list of shap_values objects generated for the images\n", " \"\"\"\n", - "\n", " global image_counter\n", " global mask_counter\n", " shap_values_list = []\n", diff --git a/notebooks/image_examples/image_captioning/Image Captioning using Open Source.ipynb b/notebooks/image_examples/image_captioning/Image Captioning using Open Source.ipynb index 7a3652688..55318658c 100644 --- a/notebooks/image_examples/image_captioning/Image Captioning using Open Source.ipynb +++ b/notebooks/image_examples/image_captioning/Image Captioning using Open Source.ipynb @@ -154,24 +154,23 @@ "outputs": [], "source": [ "class ImageCaptioningPyTorchModel:\n", - " \"\"\"\n", - " Wrapper class to get image captions using Resnet model from setup above.\n", + " \"\"\"Wrapper class to get image captions using Resnet model from setup above.\n", " Note: This class is being used instead of tools/eval.py to get predictions (captions).\n", " To get more context for this class, please refer to tools/eval.py file.\n", " \"\"\"\n", "\n", " def __init__(self, model_path, infos_path, cnn_model=\"resnet101\", device=\"cuda\"):\n", - " \"\"\"\n", - " Initializing the class by loading torch model and vocabulary at path given and using Resnet weights stored in data/imagenet_weights.\n", + " \"\"\"Initializing the class by loading torch model and vocabulary at path given and using Resnet weights stored in data/imagenet_weights.\n", " This is done to speeden the process of getting image captions and avoid loading the model every time captions are needed.\n", + "\n", " Parameters\n", " ----------\n", " model_path : pre-trained model path\n", " infos_path : pre-trained infos (vocab) path\n", " cnn_model : resnet model weights to use; options: \"resnet101\" (default), \"resnet152\"\n", " device : \"cpu\" or \"cuda\" (default)\n", - " \"\"\"\n", "\n", + " \"\"\"\n", " # load infos\n", " with open(infos_path, \"rb\") as f:\n", " infos = utils.pickle_load(f)\n", @@ -201,8 +200,8 @@ " gc.collect()\n", "\n", " def __call__(self, image_folder, batch_size):\n", - " \"\"\"\n", - " Function to get captions for images placed in image_folder.\n", + " \"\"\"Function to get captions for images placed in image_folder.\n", + "\n", " Parameters\n", " ----------\n", " image_folder: folder of images for which captions are needed\n", @@ -210,8 +209,8 @@ " Output\n", " -------\n", " captions : list of captions for images in image_folder (will return a string if there is only one image in folder)\n", - " \"\"\"\n", "\n", + " \"\"\"\n", " # setting eval options\n", " opt = self.opt\n", " opt.batch_size = batch_size\n", @@ -414,8 +413,7 @@ "def run_masker(\n", " X, mask_value=\"inpaint_ns\", max_evals=300, batch_size=50, fixed_context=None\n", "):\n", - " \"\"\"\n", - " Function to take a list of images and parameters such max evals etc. and return shap explanations (shap_values) for test images(X).\n", + " \"\"\"Function to take a list of images and parameters such max evals etc. and return shap explanations (shap_values) for test images(X).\n", " Paramaters\n", " ----------\n", " X : list of images which need to be explained\n", diff --git a/pyproject.toml b/pyproject.toml index 9a0026000..8ca4115cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -136,16 +136,29 @@ select = [ "UP", # pyupgrade "E", # pycodestyle "W", # warning + "D", # pydocstyle + # D417 # undocumented parameter. FIXME: get this passing ] ignore = [ # Recommended rules to disable when using ruff formatter: "E117", # Over-indented "E501", # Line too long + + # pydocstyle issues not yet fixed + "D100", # Missing docstring in public module + "D101", # Missing docstring in public class + "D102", # Missing docstring in public method + "D103", # Missing docstring in public function + "D104", # Missing docstring in public package + "D105", # Missing docstring in magic method + "D205", # 1 blank line required between summary line and description + "D400", # First line should end with a period + "D401", # First line of docstring should be in imperative mood: "A basic partial dependence plot function." + "D404", # First word of the docstring should not be "This" ] -[tool.ruff.format] -# For now, only Jupyter notebooks are autoformatted. -exclude = ["**.py"] +[tool.ruff.lint.pydocstyle] +convention = "numpy" [tool.ruff.lint.per-file-ignores] # Don't apply linting/formatting to vendored code @@ -159,7 +172,14 @@ exclude = ["**.py"] "notebooks/tabular_examples/tree_based_models/tree_shap_paper/*" = ["ALL"] # Disable some unwanted rules on Jupyter notebooks -"*.ipynb" = ["E703", "E402"] # Allow trailing semicolons, allow imports not at top +"*.ipynb" = ["D", "E703", "E402"] # Allow trailing semicolons, allow imports not at top + +# Ignore pycodestyle in tests +"tests/*py" = ["D"] + +[tool.ruff.format] +# For now, only Jupyter notebooks are autoformatted. +exclude = ["**.py"] [tool.coverage.run] source_pkgs = ["shap"] diff --git a/setup.py b/setup.py index b11e01ae1..1be4b6946 100644 --- a/setup.py +++ b/setup.py @@ -135,8 +135,7 @@ def run_setup(*, with_binary, with_cuda): def try_run_setup(*, with_binary, with_cuda): - """ Fails gracefully when various install steps don't work. - """ + """Fails gracefully when various install steps don't work.""" global _BUILD_ATTEMPTS _BUILD_ATTEMPTS += 1 diff --git a/shap/_explanation.py b/shap/_explanation.py index 0a67fe4c1..0693f07ab 100644 --- a/shap/_explanation.py +++ b/shap/_explanation.py @@ -15,70 +15,60 @@ op_chain_root = OpChain("shap.Explanation") class MetaExplanation(type): - """ This metaclass exposes the Explanation object's methods for creating template op chains. - """ + """This metaclass exposes the Explanation object's methods for creating template op chains.""" def __getitem__(cls, item): return op_chain_root.__getitem__(item) @property def abs(cls): - """ Element-wise absolute value op. - """ + """Element-wise absolute value op.""" return op_chain_root.abs @property def identity(cls): - """ A no-op. - """ + """A no-op.""" return op_chain_root.identity @property def argsort(cls): - """ Numpy style argsort. - """ + """Numpy style argsort.""" return op_chain_root.argsort @property def sum(cls): - """ Numpy style sum. - """ + """Numpy style sum.""" return op_chain_root.sum @property def max(cls): - """ Numpy style max. - """ + """Numpy style max.""" return op_chain_root.max @property def min(cls): - """ Numpy style min. - """ + """Numpy style min.""" return op_chain_root.min @property def mean(cls): - """ Numpy style mean. - """ + """Numpy style mean.""" return op_chain_root.mean @property def sample(cls): - """ Numpy style sample. - """ + """Numpy style sample.""" return op_chain_root.sample @property def hclust(cls): - """ Hierarchical clustering op. - """ + """Hierarchical clustering op.""" return op_chain_root.hclust class Explanation(metaclass=MetaExplanation): - """ A sliceable set of parallel arrays representing a SHAP explanation. - """ + """A sliceable set of parallel arrays representing a SHAP explanation.""" + def __init__( self, values, @@ -164,14 +154,12 @@ def __init__( @property def shape(self): - """ Compute the shape over potentially complex data nesting. - """ + """Compute the shape over potentially complex data nesting.""" return _compute_shape(self._s.values) @property def values(self): - """ Pass-through from the underlying slicer object. - """ + """Pass-through from the underlying slicer object.""" return self._s.values @values.setter def values(self, new_values): @@ -179,8 +167,7 @@ def values(self, new_values): @property def base_values(self): - """ Pass-through from the underlying slicer object. - """ + """Pass-through from the underlying slicer object.""" return self._s.base_values @base_values.setter def base_values(self, new_base_values): @@ -188,8 +175,7 @@ def base_values(self, new_base_values): @property def data(self): - """ Pass-through from the underlying slicer object. - """ + """Pass-through from the underlying slicer object.""" return self._s.data @data.setter def data(self, new_data): @@ -197,8 +183,7 @@ def data(self, new_data): @property def display_data(self): - """ Pass-through from the underlying slicer object. - """ + """Pass-through from the underlying slicer object.""" return self._s.display_data @display_data.setter def display_data(self, new_display_data): @@ -208,14 +193,12 @@ def display_data(self, new_display_data): @property def instance_names(self): - """ Pass-through from the underlying slicer object. - """ + """Pass-through from the underlying slicer object.""" return self._s.instance_names @property def output_names(self): - """ Pass-through from the underlying slicer object. - """ + """Pass-through from the underlying slicer object.""" return self._s.output_names @output_names.setter def output_names(self, new_output_names): @@ -223,14 +206,12 @@ def output_names(self, new_output_names): @property def output_indexes(self): - """ Pass-through from the underlying slicer object. - """ + """Pass-through from the underlying slicer object.""" return self._s.output_indexes @property def feature_names(self): - """ Pass-through from the underlying slicer object. - """ + """Pass-through from the underlying slicer object.""" return self._s.feature_names @feature_names.setter def feature_names(self, new_feature_names): @@ -238,26 +219,22 @@ def feature_names(self, new_feature_names): @property def lower_bounds(self): - """ Pass-through from the underlying slicer object. - """ + """Pass-through from the underlying slicer object.""" return self._s.lower_bounds @property def upper_bounds(self): - """ Pass-through from the underlying slicer object. - """ + """Pass-through from the underlying slicer object.""" return self._s.upper_bounds @property def error_std(self): - """ Pass-through from the underlying slicer object. - """ + """Pass-through from the underlying slicer object.""" return self._s.error_std @property def main_effects(self): - """ Pass-through from the underlying slicer object. - """ + """Pass-through from the underlying slicer object.""" return self._s.main_effects @main_effects.setter def main_effects(self, new_main_effects): @@ -265,8 +242,7 @@ def main_effects(self, new_main_effects): @property def hierarchical_values(self): - """ Pass-through from the underlying slicer object. - """ + """Pass-through from the underlying slicer object.""" return self._s.hierarchical_values @hierarchical_values.setter def hierarchical_values(self, new_hierarchical_values): @@ -274,23 +250,22 @@ def hierarchical_values(self, new_hierarchical_values): @property def clustering(self): - """ Pass-through from the underlying slicer object. - """ + """Pass-through from the underlying slicer object.""" return self._s.clustering @clustering.setter def clustering(self, new_clustering): self._s.clustering = new_clustering def cohorts(self, cohorts): - """ Split this explanation into several cohorts. + """Split this explanation into several cohorts. Parameters ---------- cohorts : int or array If this is an integer then we auto build that many cohorts using a decision tree. If this is an array then we treat that as an array of cohort names/ids for each instance. - """ + """ if isinstance(cohorts, int): return _auto_cohorts(self, max_cohorts=cohorts) if isinstance(cohorts, (list, tuple, np.ndarray)): @@ -299,8 +274,7 @@ def cohorts(self, cohorts): raise TypeError("The given set of cohort indicators is not recognized! Please give an array or int.") def __repr__(self): - """ Display some basic printable info, but not everything. - """ + """Display some basic printable info, but not everything.""" out = ".values =\n"+self.values.__repr__() if self.base_values is not None: out += "\n\n.base_values =\n"+self.base_values.__repr__() @@ -309,8 +283,7 @@ def __repr__(self): return out def __getitem__(self, item): - """ This adds support for OpChain indexing. - """ + """This adds support for OpChain indexing.""" new_self = None if not isinstance(item, tuple): item = (item,) @@ -503,8 +476,7 @@ def __truediv__(self, other): # return new_self def _numpy_func(self, fname, **kwargs): - """ Apply a numpy-style function to this Explanation. - """ + """Apply a numpy-style function to this Explanation.""" new_self = copy.copy(self) axis = kwargs.get("axis", None) @@ -551,23 +523,19 @@ def _numpy_func(self, fname, **kwargs): return new_self def mean(self, axis): - """ Numpy-style mean function. - """ + """Numpy-style mean function.""" return self._numpy_func("mean", axis=axis) def max(self, axis): - """ Numpy-style mean function. - """ + """Numpy-style mean function.""" return self._numpy_func("max", axis=axis) def min(self, axis): - """ Numpy-style mean function. - """ + """Numpy-style mean function.""" return self._numpy_func("min", axis=axis) def sum(self, axis=None, grouping=None): - """ Numpy-style mean function. - """ + """Numpy-style mean function.""" if grouping is None: return self._numpy_func("sum", axis=axis) elif axis == 1 or len(self.shape) == 1: @@ -576,8 +544,7 @@ def sum(self, axis=None, grouping=None): raise DimensionError("Only axis = 1 is supported for grouping right now...") def hstack(self, other): - """ Stack two explanations column-wise. - """ + """Stack two explanations column-wise.""" assert self.shape[0] == other.shape[0], "Can't hstack explanations with different numbers of rows!" assert np.max(np.abs(self.base_values - other.base_values)) < 1e-6, "Can't hstack explanations with different base values!" @@ -620,7 +587,7 @@ def flip(self): def hclust(self, metric="sqeuclidean", axis=0): - """ Computes an optimal leaf ordering sort order using hclustering. + """Computes an optimal leaf ordering sort order using hclustering. hclust(metric="sqeuclidean") @@ -631,6 +598,7 @@ def hclust(self, metric="sqeuclidean", axis=0): axis : int The axis to cluster along. + """ values = self.values @@ -647,7 +615,7 @@ def hclust(self, metric="sqeuclidean", axis=0): return inds def sample(self, max_samples, replace=False, random_state=0): - """ Randomly samples the instances (rows) of the Explanation object. + """Randomly samples the instances (rows) of the Explanation object. Parameters ---------- @@ -657,6 +625,7 @@ def sample(self, max_samples, replace=False, random_state=0): replace : bool Sample with or without replacement. + """ prev_seed = np.random.seed(random_state) inds = np.random.choice(self.shape[0], min(max_samples, self.shape[0]), replace=replace) @@ -748,8 +717,7 @@ def group_features(shap_values, feature_map): ) def compute_output_dims(values, base_values, data, output_names): - """ Uses the passed data to infer which dimensions correspond to the model's output. - """ + """Uses the passed data to infer which dimensions correspond to the model's output.""" values_shape = _compute_shape(values) # input shape matches the data shape @@ -856,9 +824,7 @@ def __repr__(self): def _auto_cohorts(shap_values, max_cohorts): - """ This uses a DecisionTreeRegressor to build a group of cohorts with similar SHAP values. - """ - + """This uses a DecisionTreeRegressor to build a group of cohorts with similar SHAP values.""" # fit a decision tree that well separates the SHAP values m = sklearn.tree.DecisionTreeRegressor(max_leaf_nodes=max_cohorts) m.fit(shap_values.data, shap_values.values) @@ -893,8 +859,7 @@ def _auto_cohorts(shap_values, max_cohorts): return Cohorts(**cohorts) def list_wrap(x): - """ A helper to patch things since slicer doesn't handle arrays of arrays (it does handle lists of arrays) - """ + """A helper to patch things since slicer doesn't handle arrays of arrays (it does handle lists of arrays)""" if isinstance(x, np.ndarray) and len(x.shape) == 1 and isinstance(x[0], np.ndarray): return [v for v in x] else: diff --git a/shap/_serializable.py b/shap/_serializable.py index d9dbde388..88422bdbe 100644 --- a/shap/_serializable.py +++ b/shap/_serializable.py @@ -9,17 +9,15 @@ log = logging.getLogger('shap') class Serializable: - """ This is the superclass of all serializable objects. - """ + """This is the superclass of all serializable objects.""" def save(self, out_file): - """ Save the model to the given file stream. - """ + """Save the model to the given file stream.""" pickle.dump(type(self), out_file) @classmethod def load(cls, in_file, instantiate=True): - """ This is meant to be overridden by subclasses and called with super. + """This is meant to be overridden by subclasses and called with super. We return constructor argument values when not being instantiated. Since there are no constructor arguments for the Serializable class we just return an empty dictionary. @@ -30,7 +28,7 @@ def load(cls, in_file, instantiate=True): @classmethod def _instantiated_load(cls, in_file, **kwargs): - """ This is meant to be overridden by subclasses and called with super. + """This is meant to be overridden by subclasses and called with super. We return constructor argument values (we have no values to load in this abstract class). """ @@ -48,8 +46,8 @@ def _instantiated_load(cls, in_file, **kwargs): class Serializer: - """ Save data items to an input stream. - """ + """Save data items to an input stream.""" + def __init__(self, out_stream, block_name, version): self.out_stream = out_stream self.block_name = block_name @@ -70,8 +68,7 @@ def __exit__(self, exception_type, exception_value, traceback): pickle.dump("END_BLOCK___", self.out_stream) def save(self, name, value, encoder="auto"): - """ Dump a data item to the current input stream. - """ + """Dump a data item to the current input stream.""" log.debug("name = %s", name) pickle.dump(name, self.out_stream) if encoder is None or encoder is False: @@ -102,8 +99,7 @@ def save(self, name, value, encoder="auto"): log.debug("value = %s", str(value)) class Deserializer: - """ Load data items from an input stream. - """ + """Load data items from an input stream.""" def __init__(self, in_stream, block_name, min_version, max_version): self.in_stream = in_stream @@ -168,8 +164,7 @@ def __exit__(self, exception_type, exception_value, traceback): ) def load(self, name, decoder=None): - """ Load a data item from the current input stream. - """ + """Load a data item from the current input stream.""" # confirm the block name loaded_name = pickle.load(self.in_stream) log.debug("loaded_name = %s", loaded_name) diff --git a/shap/actions/_action.py b/shap/actions/_action.py index 6339e0c10..fba51ada1 100644 --- a/shap/actions/_action.py +++ b/shap/actions/_action.py @@ -1,6 +1,6 @@ class Action: - """ Abstract action class. - """ + """Abstract action class.""" + def __lt__(self, other_action): return self.cost < other_action.cost diff --git a/shap/benchmark/_compute.py b/shap/benchmark/_compute.py index ab46ca5d1..81c860e8e 100644 --- a/shap/benchmark/_compute.py +++ b/shap/benchmark/_compute.py @@ -2,8 +2,7 @@ class ComputeTime: - """ Extracts a runtime benchmark result from the passed Explanation. - """ + """Extracts a runtime benchmark result from the passed Explanation.""" def __call__(self, explanation, name): return BenchmarkResult("compute time", name, value=explanation.compute_time / explanation.shape[0]) diff --git a/shap/benchmark/_explanation_error.py b/shap/benchmark/_explanation_error.py index d325adcfe..fd052bf9f 100644 --- a/shap/benchmark/_explanation_error.py +++ b/shap/benchmark/_explanation_error.py @@ -12,7 +12,7 @@ class ExplanationError: - """ A measure of the explanation error relative to a model's actual output. + """A measure of the explanation error relative to a model's actual output. This benchmark metric measures the discrepancy between the output of the model predicted by an attribution explanation vs. the actual output of the model. This discrepancy is measured over @@ -27,7 +27,7 @@ class ExplanationError: """ def __init__(self, masker, model, *model_args, batch_size=500, num_permutations=10, link=links.identity, linearize_link=True, seed=38923): - """ Build a new explanation error benchmarker with the given masker, model, and model args. + """Build a new explanation error benchmarker with the given masker, model, and model args. Parameters ---------- @@ -58,8 +58,8 @@ def __init__(self, masker, model, *model_args, batch_size=500, num_permutations= linearize_link : bool Non-linear links can destroy additive separation in generalized linear models, so by linearizing the link we can retain additive separation. See upcoming paper/doc for details. - """ + """ self.masker = masker self.model = model self.model_args = model_args @@ -80,9 +80,7 @@ def __init__(self, masker, model, *model_args, batch_size=500, num_permutations= self.data_type = "tabular" def __call__(self, explanation, name, step_fraction=0.01, indices=[], silent=False): - """ Run this benchmark on the given explanation. - """ - + """Run this benchmark on the given explanation.""" if isinstance(explanation, np.ndarray): attributions = explanation elif isinstance(explanation, Explanation): diff --git a/shap/benchmark/_result.py b/shap/benchmark/_result.py index 0c31e0f0f..f4236ac88 100644 --- a/shap/benchmark/_result.py +++ b/shap/benchmark/_result.py @@ -13,8 +13,7 @@ } class BenchmarkResult: - """ The result of a benchmark run. - """ + """The result of a benchmark run.""" def __init__(self, metric, method, value=None, curve_x=None, curve_y=None, curve_y_std=None, value_sign=None): self.metric = metric diff --git a/shap/benchmark/_sequential.py b/shap/benchmark/_sequential.py index 4c5eca223..c4f5c6045 100644 --- a/shap/benchmark/_sequential.py +++ b/shap/benchmark/_sequential.py @@ -198,9 +198,7 @@ def __call__(self, name, explanation, *model_args, percent=0.01, indices=[], y=N return mask_vals, curves, aucs def score(self, explanation, X, percent=0.01, y=None, label=None, silent=False, debug_mode=False): - ''' - Will be deprecated once MaskedModel is in complete support - ''' + """Will be deprecated once MaskedModel is in complete support""" # if explainer is already the attributions if isinstance(explanation, np.ndarray): attributions = explanation diff --git a/shap/benchmark/experiments.py b/shap/benchmark/experiments.py index 5800d8330..ab70162c2 100644 --- a/shap/benchmark/experiments.py +++ b/shap/benchmark/experiments.py @@ -320,7 +320,7 @@ def __print_status(): def run_remote_experiments(experiments, thread_hosts, rate_limit=10): - """ Use ssh to run the experiments on remote machines in parallel. + """Use ssh to run the experiments on remote machines in parallel. Parameters ---------- @@ -333,8 +333,8 @@ def run_remote_experiments(experiments, thread_hosts, rate_limit=10): rate_limit : int How many ssh connections we make per minute to each host (to avoid throttling issues). - """ + """ global ssh_conn_per_min_limit ssh_conn_per_min_limit = rate_limit diff --git a/shap/benchmark/measures.py b/shap/benchmark/measures.py index d7a0fe291..bccbcad62 100644 --- a/shap/benchmark/measures.py +++ b/shap/benchmark/measures.py @@ -7,7 +7,7 @@ _remove_cache = {} def remove_retrain(nmask, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state): - """ The model is retrained for each test sample with the important features set to a constant. + """The model is retrained for each test sample with the important features set to a constant. If you want to know how important a set of features is you can ask how the model would be different if those features had never existed. To determine this we can mask those features @@ -18,7 +18,6 @@ def remove_retrain(nmask, X_train, y_train, X_test, y_test, attr_test, model_gen to get the change in model performance when a specified fraction of the most important features are withheld. """ - warnings.warn("The retrain based measures can incorrectly evaluate models in some cases!") # see if we match the last cached call @@ -70,9 +69,7 @@ def remove_retrain(nmask, X_train, y_train, X_test, y_test, attr_test, model_gen return metric(y_test, yp_masked_test) def remove_mask(nmask, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state): - """ Each test sample is masked by setting the important features to a constant. - """ - + """Each test sample is masked by setting the important features to a constant.""" X_train, X_test = to_array(X_train, X_test) # how many features to mask @@ -92,13 +89,12 @@ def remove_mask(nmask, X_train, y_train, X_test, y_test, attr_test, model_genera return metric(y_test, yp_masked_test) def remove_impute(nmask, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state): - """ The model is reevaluated for each test sample with the important features set to an imputed value. + """The model is reevaluated for each test sample with the important features set to an imputed value. Note that the imputation is done using a multivariate normality assumption on the dataset. This depends on being able to estimate the full data covariance matrix (and inverse) accuractly. So X_train.shape[0] should be significantly bigger than X_train.shape[1]. """ - X_train, X_test = to_array(X_train, X_test) # how many features to mask @@ -129,9 +125,7 @@ def remove_impute(nmask, X_train, y_train, X_test, y_test, attr_test, model_gene return metric(y_test, yp_masked_test) def remove_resample(nmask, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state): - """ The model is reevaluated for each test sample with the important features set to resample background values. - """ - + """The model is reevaluated for each test sample with the important features set to resample background values.""" X_train, X_test = to_array(X_train, X_test) # how many features to mask @@ -156,14 +150,13 @@ def remove_resample(nmask, X_train, y_train, X_test, y_test, attr_test, model_ge return metric(y_test, yp_masked_test) def batch_remove_retrain(nmask_train, nmask_test, X_train, y_train, X_test, y_test, attr_train, attr_test, model_generator, metric): - """ An approximation of holdout that only retraines the model once. + """An approximation of holdout that only retraines the model once. This is also called ROAR (RemOve And Retrain) in work by Google. It is much more computationally efficient that the holdout method because it masks the most important features in every sample and then retrains the model once, instead of retraining the model for every test sample like the holdout metric. """ - warnings.warn("The retrain based measures can incorrectly evaluate models in some cases!") X_train, X_test = to_array(X_train, X_test) @@ -194,7 +187,7 @@ def batch_remove_retrain(nmask_train, nmask_test, X_train, y_train, X_test, y_te _keep_cache = {} def keep_retrain(nkeep, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state): - """ The model is retrained for each test sample with the non-important features set to a constant. + """The model is retrained for each test sample with the non-important features set to a constant. If you want to know how important a set of features is you can ask how the model would be different if only those features had existed. To determine this we can mask the other features @@ -205,7 +198,6 @@ def keep_retrain(nkeep, X_train, y_train, X_test, y_test, attr_test, model_gener to get the change in model performance when a specified fraction of the most important features are retained. """ - warnings.warn("The retrain based measures can incorrectly evaluate models in some cases!") # see if we match the last cached call @@ -258,9 +250,7 @@ def keep_retrain(nkeep, X_train, y_train, X_test, y_test, attr_test, model_gener return metric(y_test, yp_masked_test) def keep_mask(nkeep, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state): - """ The model is reevaluated for each test sample with the non-important features set to their mean. - """ - + """The model is reevaluated for each test sample with the non-important features set to their mean.""" X_train, X_test = to_array(X_train, X_test) # how many features to mask @@ -281,13 +271,12 @@ def keep_mask(nkeep, X_train, y_train, X_test, y_test, attr_test, model_generato return metric(y_test, yp_masked_test) def keep_impute(nkeep, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state): - """ The model is reevaluated for each test sample with the non-important features set to an imputed value. + """The model is reevaluated for each test sample with the non-important features set to an imputed value. Note that the imputation is done using a multivariate normality assumption on the dataset. This depends on being able to estimate the full data covariance matrix (and inverse) accuractly. So X_train.shape[0] should be significantly bigger than X_train.shape[1]. """ - X_train, X_test = to_array(X_train, X_test) # how many features to mask @@ -318,9 +307,7 @@ def keep_impute(nkeep, X_train, y_train, X_test, y_test, attr_test, model_genera return metric(y_test, yp_masked_test) def keep_resample(nkeep, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state): - """ The model is reevaluated for each test sample with the non-important features set to resample background values. - """ # why broken? overwriting? - + """The model is reevaluated for each test sample with the non-important features set to resample background values.""" # why broken? overwriting? X_train, X_test = to_array(X_train, X_test) # how many features to mask @@ -345,14 +332,13 @@ def keep_resample(nkeep, X_train, y_train, X_test, y_test, attr_test, model_gene return metric(y_test, yp_masked_test) def batch_keep_retrain(nkeep_train, nkeep_test, X_train, y_train, X_test, y_test, attr_train, attr_test, model_generator, metric): - """ An approximation of keep that only retraines the model once. + """An approximation of keep that only retraines the model once. This is also called KAR (Keep And Retrain) in work by Google. It is much more computationally efficient that the keep method because it masks the unimportant features in every sample and then retrains the model once, instead of retraining the model for every test sample like the keep metric. """ - warnings.warn("The retrain based measures can incorrectly evaluate models in some cases!") X_train, X_test = to_array(X_train, X_test) @@ -382,9 +368,7 @@ def batch_keep_retrain(nkeep_train, nkeep_test, X_train, y_train, X_test, y_test return metric(y_test, yp_test_masked) def local_accuracy(X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model): - """ The how well do the features plus a constant base rate sum up to the model output. - """ - + """The how well do the features plus a constant base rate sum up to the model output.""" X_train, X_test = to_array(X_train, X_test) # how many features to mask @@ -399,8 +383,7 @@ def to_array(*args): return [a.values if isinstance(a, pd.DataFrame) else a for a in args] def const_rand(size, seed=23980): - """ Generate a random array with a fixed seed. - """ + """Generate a random array with a fixed seed.""" old_seed = np.random.seed() np.random.seed(seed) out = np.random.rand(size) @@ -408,16 +391,14 @@ def const_rand(size, seed=23980): return out def const_shuffle(arr, seed=23980): - """ Shuffle an array in-place with a fixed seed. - """ + """Shuffle an array in-place with a fixed seed.""" old_seed = np.random.seed() np.random.seed(seed) np.random.shuffle(arr) np.random.seed(old_seed) def strip_list(attrs): - """ This assumes that if you have a list of outputs you just want the second one (the second class is the '1' class). - """ + """This assumes that if you have a list of outputs you just want the second one (the second class is the '1' class).""" if isinstance(attrs, list): return attrs[1] else: diff --git a/shap/benchmark/methods.py b/shap/benchmark/methods.py index 0e7766454..795294487 100644 --- a/shap/benchmark/methods.py +++ b/shap/benchmark/methods.py @@ -15,50 +15,47 @@ def linear_shap_corr(model, data): - """ Linear SHAP (corr 1000) - """ + """Linear SHAP (corr 1000)""" return LinearExplainer(model, data, feature_perturbation="correlation_dependent", nsamples=1000).shap_values def linear_shap_ind(model, data): - """ Linear SHAP (ind) - """ + """Linear SHAP (ind)""" return LinearExplainer(model, data, feature_perturbation="interventional").shap_values def coef(model, data): - """ Coefficients - """ + """Coefficients""" return other.CoefficentExplainer(model).attributions def random(model, data): - """ Random + """Random color = #777777 linestyle = solid """ return other.RandomExplainer().attributions def kernel_shap_1000_meanref(model, data): - """ Kernel SHAP 1000 mean ref. + """Kernel SHAP 1000 mean ref. color = red_blue_circle(0.5) linestyle = solid """ return lambda X: KernelExplainer(model.predict, kmeans(data, 1)).shap_values(X, nsamples=1000, l1_reg=0) def sampling_shap_1000(model, data): - """ IME 1000 + """IME 1000 color = red_blue_circle(0.5) linestyle = dashed """ return lambda X: SamplingExplainer(model.predict, data).shap_values(X, nsamples=1000) def tree_shap_tree_path_dependent(model, data): - """ TreeExplainer + """TreeExplainer color = red_blue_circle(0) linestyle = solid """ return TreeExplainer(model, feature_perturbation="tree_path_dependent").shap_values def tree_shap_independent_200(model, data): - """ TreeExplainer (independent) + """TreeExplainer (independent) color = red_blue_circle(0) linestyle = dashed """ @@ -66,7 +63,7 @@ def tree_shap_independent_200(model, data): return TreeExplainer(model, data_subsample, feature_perturbation="interventional").shap_values def mean_abs_tree_shap(model, data): - """ mean(|TreeExplainer|) + """mean(|TreeExplainer|) color = red_blue_circle(0.25) linestyle = solid """ @@ -79,47 +76,46 @@ def f(X): return f def saabas(model, data): - """ Saabas + """Saabas color = red_blue_circle(0) linestyle = dotted """ return lambda X: TreeExplainer(model).shap_values(X, approximate=True) def tree_gain(model, data): - """ Gain/Gini Importance + """Gain/Gini Importance color = red_blue_circle(0.25) linestyle = dotted """ return other.TreeGainExplainer(model).attributions def lime_tabular_regression_1000(model, data): - """ LIME Tabular 1000 + """LIME Tabular 1000 color = red_blue_circle(0.75) """ return lambda X: other.LimeTabularExplainer(model.predict, data, mode="regression").attributions(X, nsamples=1000) def lime_tabular_classification_1000(model, data): - """ LIME Tabular 1000 + """LIME Tabular 1000 color = red_blue_circle(0.75) """ return lambda X: other.LimeTabularExplainer(model.predict_proba, data, mode="classification").attributions(X, nsamples=1000)[1] def maple(model, data): - """ MAPLE + """MAPLE color = red_blue_circle(0.6) """ return lambda X: other.MapleExplainer(model.predict, data).attributions(X, multiply_by_input=False) def tree_maple(model, data): - """ Tree MAPLE + """Tree MAPLE color = red_blue_circle(0.6) linestyle = dashed """ return lambda X: other.TreeMapleExplainer(model, data).attributions(X, multiply_by_input=False) def deep_shap(model, data): - """ Deep SHAP (DeepLIFT) - """ + """Deep SHAP (DeepLIFT)""" if isinstance(model, KerasWrap): model = model.model explainer = DeepExplainer(model, kmeans(data, 1).data) @@ -133,8 +129,7 @@ def f(X): return f def expected_gradients(model, data): - """ Expected Gradients - """ + """Expected Gradients""" if isinstance(model, KerasWrap): model = model.model explainer = GradientExplainer(model, data) diff --git a/shap/benchmark/metrics.py b/shap/benchmark/metrics.py index 1ff8db7a3..2bf7a9cec 100644 --- a/shap/benchmark/metrics.py +++ b/shap/benchmark/metrics.py @@ -20,11 +20,10 @@ def runtime(X, y, model_generator, method_name): - """ Runtime (sec / 1k samples) + """Runtime (sec / 1k samples) transform = "negate_log" sort_order = 2 """ - old_seed = np.random.seed() np.random.seed(3293) @@ -54,14 +53,13 @@ def runtime(X, y, model_generator, method_name): return None, np.mean(method_reps) def local_accuracy(X, y, model_generator, method_name): - """ Local Accuracy + """Local Accuracy transform = "identity" sort_order = 0 """ def score_map(true, pred): - """ Computes local accuracy as the normalized standard deviation of numerical scores. - """ + """Computes local accuracy as the normalized standard deviation of numerical scores.""" return np.std(pred - true) / (np.std(true) + 1e-6) def score_function(X_train, X_test, y_train, y_test, attr_function, trained_model, random_state): @@ -72,11 +70,10 @@ def score_function(X_train, X_test, y_train, y_test, attr_function, trained_mode return None, __score_method(X, y, None, model_generator, score_function, method_name) def consistency_guarantees(X, y, model_generator, method_name): - """ Consistency Guarantees + """Consistency Guarantees transform = "identity" sort_order = 1 """ - # 1.0 - perfect consistency # 0.8 - guarantees depend on sampling # 0.6 - guarantees depend on approximation @@ -104,12 +101,11 @@ def consistency_guarantees(X, y, model_generator, method_name): return None, guarantees[method_name] def __mean_pred(true, pred): - """ A trivial metric that is just is the output of the model. - """ + """A trivial metric that is just is the output of the model.""" return np.mean(pred) def keep_positive_mask(X, y, model_generator, method_name, num_fcounts=11): - """ Keep Positive (mask) + """Keep Positive (mask) xlabel = "Max fraction of features kept" ylabel = "Mean model output" transform = "identity" @@ -118,7 +114,7 @@ def keep_positive_mask(X, y, model_generator, method_name, num_fcounts=11): return __run_measure(measures.keep_mask, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred) def keep_negative_mask(X, y, model_generator, method_name, num_fcounts=11): - """ Keep Negative (mask) + """Keep Negative (mask) xlabel = "Max fraction of features kept" ylabel = "Negative mean model output" transform = "negate" @@ -127,7 +123,7 @@ def keep_negative_mask(X, y, model_generator, method_name, num_fcounts=11): return __run_measure(measures.keep_mask, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred) def keep_absolute_mask__r2(X, y, model_generator, method_name, num_fcounts=11): - """ Keep Absolute (mask) + """Keep Absolute (mask) xlabel = "Max fraction of features kept" ylabel = "R^2" transform = "identity" @@ -136,7 +132,7 @@ def keep_absolute_mask__r2(X, y, model_generator, method_name, num_fcounts=11): return __run_measure(measures.keep_mask, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score) def keep_absolute_mask__roc_auc(X, y, model_generator, method_name, num_fcounts=11): - """ Keep Absolute (mask) + """Keep Absolute (mask) xlabel = "Max fraction of features kept" ylabel = "ROC AUC" transform = "identity" @@ -145,7 +141,7 @@ def keep_absolute_mask__roc_auc(X, y, model_generator, method_name, num_fcounts= return __run_measure(measures.keep_mask, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score) def remove_positive_mask(X, y, model_generator, method_name, num_fcounts=11): - """ Remove Positive (mask) + """Remove Positive (mask) xlabel = "Max fraction of features removed" ylabel = "Negative mean model output" transform = "negate" @@ -154,7 +150,7 @@ def remove_positive_mask(X, y, model_generator, method_name, num_fcounts=11): return __run_measure(measures.remove_mask, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred) def remove_negative_mask(X, y, model_generator, method_name, num_fcounts=11): - """ Remove Negative (mask) + """Remove Negative (mask) xlabel = "Max fraction of features removed" ylabel = "Mean model output" transform = "identity" @@ -163,7 +159,7 @@ def remove_negative_mask(X, y, model_generator, method_name, num_fcounts=11): return __run_measure(measures.remove_mask, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred) def remove_absolute_mask__r2(X, y, model_generator, method_name, num_fcounts=11): - """ Remove Absolute (mask) + """Remove Absolute (mask) xlabel = "Max fraction of features removed" ylabel = "1 - R^2" transform = "one_minus" @@ -172,7 +168,7 @@ def remove_absolute_mask__r2(X, y, model_generator, method_name, num_fcounts=11) return __run_measure(measures.remove_mask, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score) def remove_absolute_mask__roc_auc(X, y, model_generator, method_name, num_fcounts=11): - """ Remove Absolute (mask) + """Remove Absolute (mask) xlabel = "Max fraction of features removed" ylabel = "1 - ROC AUC" transform = "one_minus" @@ -181,7 +177,7 @@ def remove_absolute_mask__roc_auc(X, y, model_generator, method_name, num_fcount return __run_measure(measures.remove_mask, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score) def keep_positive_resample(X, y, model_generator, method_name, num_fcounts=11): - """ Keep Positive (resample) + """Keep Positive (resample) xlabel = "Max fraction of features kept" ylabel = "Mean model output" transform = "identity" @@ -190,7 +186,7 @@ def keep_positive_resample(X, y, model_generator, method_name, num_fcounts=11): return __run_measure(measures.keep_resample, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred) def keep_negative_resample(X, y, model_generator, method_name, num_fcounts=11): - """ Keep Negative (resample) + """Keep Negative (resample) xlabel = "Max fraction of features kept" ylabel = "Negative mean model output" transform = "negate" @@ -199,7 +195,7 @@ def keep_negative_resample(X, y, model_generator, method_name, num_fcounts=11): return __run_measure(measures.keep_resample, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred) def keep_absolute_resample__r2(X, y, model_generator, method_name, num_fcounts=11): - """ Keep Absolute (resample) + """Keep Absolute (resample) xlabel = "Max fraction of features kept" ylabel = "R^2" transform = "identity" @@ -208,7 +204,7 @@ def keep_absolute_resample__r2(X, y, model_generator, method_name, num_fcounts=1 return __run_measure(measures.keep_resample, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score) def keep_absolute_resample__roc_auc(X, y, model_generator, method_name, num_fcounts=11): - """ Keep Absolute (resample) + """Keep Absolute (resample) xlabel = "Max fraction of features kept" ylabel = "ROC AUC" transform = "identity" @@ -217,7 +213,7 @@ def keep_absolute_resample__roc_auc(X, y, model_generator, method_name, num_fcou return __run_measure(measures.keep_resample, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score) def remove_positive_resample(X, y, model_generator, method_name, num_fcounts=11): - """ Remove Positive (resample) + """Remove Positive (resample) xlabel = "Max fraction of features removed" ylabel = "Negative mean model output" transform = "negate" @@ -226,7 +222,7 @@ def remove_positive_resample(X, y, model_generator, method_name, num_fcounts=11) return __run_measure(measures.remove_resample, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred) def remove_negative_resample(X, y, model_generator, method_name, num_fcounts=11): - """ Remove Negative (resample) + """Remove Negative (resample) xlabel = "Max fraction of features removed" ylabel = "Mean model output" transform = "identity" @@ -235,7 +231,7 @@ def remove_negative_resample(X, y, model_generator, method_name, num_fcounts=11) return __run_measure(measures.remove_resample, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred) def remove_absolute_resample__r2(X, y, model_generator, method_name, num_fcounts=11): - """ Remove Absolute (resample) + """Remove Absolute (resample) xlabel = "Max fraction of features removed" ylabel = "1 - R^2" transform = "one_minus" @@ -244,7 +240,7 @@ def remove_absolute_resample__r2(X, y, model_generator, method_name, num_fcounts return __run_measure(measures.remove_resample, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score) def remove_absolute_resample__roc_auc(X, y, model_generator, method_name, num_fcounts=11): - """ Remove Absolute (resample) + """Remove Absolute (resample) xlabel = "Max fraction of features removed" ylabel = "1 - ROC AUC" transform = "one_minus" @@ -253,7 +249,7 @@ def remove_absolute_resample__roc_auc(X, y, model_generator, method_name, num_fc return __run_measure(measures.remove_resample, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score) def keep_positive_impute(X, y, model_generator, method_name, num_fcounts=11): - """ Keep Positive (impute) + """Keep Positive (impute) xlabel = "Max fraction of features kept" ylabel = "Mean model output" transform = "identity" @@ -262,7 +258,7 @@ def keep_positive_impute(X, y, model_generator, method_name, num_fcounts=11): return __run_measure(measures.keep_impute, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred) def keep_negative_impute(X, y, model_generator, method_name, num_fcounts=11): - """ Keep Negative (impute) + """Keep Negative (impute) xlabel = "Max fraction of features kept" ylabel = "Negative mean model output" transform = "negate" @@ -271,7 +267,7 @@ def keep_negative_impute(X, y, model_generator, method_name, num_fcounts=11): return __run_measure(measures.keep_impute, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred) def keep_absolute_impute__r2(X, y, model_generator, method_name, num_fcounts=11): - """ Keep Absolute (impute) + """Keep Absolute (impute) xlabel = "Max fraction of features kept" ylabel = "R^2" transform = "identity" @@ -280,7 +276,7 @@ def keep_absolute_impute__r2(X, y, model_generator, method_name, num_fcounts=11) return __run_measure(measures.keep_impute, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score) def keep_absolute_impute__roc_auc(X, y, model_generator, method_name, num_fcounts=11): - """ Keep Absolute (impute) + """Keep Absolute (impute) xlabel = "Max fraction of features kept" ylabel = "ROC AUC" transform = "identity" @@ -289,7 +285,7 @@ def keep_absolute_impute__roc_auc(X, y, model_generator, method_name, num_fcount return __run_measure(measures.keep_mask, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score) def remove_positive_impute(X, y, model_generator, method_name, num_fcounts=11): - """ Remove Positive (impute) + """Remove Positive (impute) xlabel = "Max fraction of features removed" ylabel = "Negative mean model output" transform = "negate" @@ -298,7 +294,7 @@ def remove_positive_impute(X, y, model_generator, method_name, num_fcounts=11): return __run_measure(measures.remove_impute, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred) def remove_negative_impute(X, y, model_generator, method_name, num_fcounts=11): - """ Remove Negative (impute) + """Remove Negative (impute) xlabel = "Max fraction of features removed" ylabel = "Mean model output" transform = "identity" @@ -307,7 +303,7 @@ def remove_negative_impute(X, y, model_generator, method_name, num_fcounts=11): return __run_measure(measures.remove_impute, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred) def remove_absolute_impute__r2(X, y, model_generator, method_name, num_fcounts=11): - """ Remove Absolute (impute) + """Remove Absolute (impute) xlabel = "Max fraction of features removed" ylabel = "1 - R^2" transform = "one_minus" @@ -316,7 +312,7 @@ def remove_absolute_impute__r2(X, y, model_generator, method_name, num_fcounts=1 return __run_measure(measures.remove_impute, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score) def remove_absolute_impute__roc_auc(X, y, model_generator, method_name, num_fcounts=11): - """ Remove Absolute (impute) + """Remove Absolute (impute) xlabel = "Max fraction of features removed" ylabel = "1 - ROC AUC" transform = "one_minus" @@ -325,7 +321,7 @@ def remove_absolute_impute__roc_auc(X, y, model_generator, method_name, num_fcou return __run_measure(measures.remove_mask, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score) def keep_positive_retrain(X, y, model_generator, method_name, num_fcounts=11): - """ Keep Positive (retrain) + """Keep Positive (retrain) xlabel = "Max fraction of features kept" ylabel = "Mean model output" transform = "identity" @@ -334,7 +330,7 @@ def keep_positive_retrain(X, y, model_generator, method_name, num_fcounts=11): return __run_measure(measures.keep_retrain, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred) def keep_negative_retrain(X, y, model_generator, method_name, num_fcounts=11): - """ Keep Negative (retrain) + """Keep Negative (retrain) xlabel = "Max fraction of features kept" ylabel = "Negative mean model output" transform = "negate" @@ -343,7 +339,7 @@ def keep_negative_retrain(X, y, model_generator, method_name, num_fcounts=11): return __run_measure(measures.keep_retrain, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred) def remove_positive_retrain(X, y, model_generator, method_name, num_fcounts=11): - """ Remove Positive (retrain) + """Remove Positive (retrain) xlabel = "Max fraction of features removed" ylabel = "Negative mean model output" transform = "negate" @@ -352,7 +348,7 @@ def remove_positive_retrain(X, y, model_generator, method_name, num_fcounts=11): return __run_measure(measures.remove_retrain, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred) def remove_negative_retrain(X, y, model_generator, method_name, num_fcounts=11): - """ Remove Negative (retrain) + """Remove Negative (retrain) xlabel = "Max fraction of features removed" ylabel = "Mean model output" transform = "identity" @@ -377,7 +373,7 @@ def score_function(fcount, X_train, X_test, y_train, y_test, attr_function, trai return fcounts, __score_method(X, y, fcounts, model_generator, score_function, method_name) def batch_remove_absolute_retrain__r2(X, y, model_generator, method_name, num_fcounts=11): - """ Batch Remove Absolute (retrain) + """Batch Remove Absolute (retrain) xlabel = "Fraction of features removed" ylabel = "1 - R^2" transform = "one_minus" @@ -386,7 +382,7 @@ def batch_remove_absolute_retrain__r2(X, y, model_generator, method_name, num_fc return __run_batch_abs_metric(measures.batch_remove_retrain, X, y, model_generator, method_name, sklearn.metrics.r2_score, num_fcounts) def batch_keep_absolute_retrain__r2(X, y, model_generator, method_name, num_fcounts=11): - """ Batch Keep Absolute (retrain) + """Batch Keep Absolute (retrain) xlabel = "Fraction of features kept" ylabel = "R^2" transform = "identity" @@ -395,7 +391,7 @@ def batch_keep_absolute_retrain__r2(X, y, model_generator, method_name, num_fcou return __run_batch_abs_metric(measures.batch_keep_retrain, X, y, model_generator, method_name, sklearn.metrics.r2_score, num_fcounts) def batch_remove_absolute_retrain__roc_auc(X, y, model_generator, method_name, num_fcounts=11): - """ Batch Remove Absolute (retrain) + """Batch Remove Absolute (retrain) xlabel = "Fraction of features removed" ylabel = "1 - ROC AUC" transform = "one_minus" @@ -404,7 +400,7 @@ def batch_remove_absolute_retrain__roc_auc(X, y, model_generator, method_name, n return __run_batch_abs_metric(measures.batch_remove_retrain, X, y, model_generator, method_name, sklearn.metrics.roc_auc_score, num_fcounts) def batch_keep_absolute_retrain__roc_auc(X, y, model_generator, method_name, num_fcounts=11): - """ Batch Keep Absolute (retrain) + """Batch Keep Absolute (retrain) xlabel = "Fraction of features kept" ylabel = "ROC AUC" transform = "identity" @@ -429,9 +425,7 @@ def score_function(fcount, X_train, X_test, y_train, y_test, attr_function, trai _attribution_cache = {} def __score_method(X, y, fcounts, model_generator, score_function, method_name, nreps=10, test_size=100, cache_dir="/tmp"): - """ Test an explanation method. - """ - + """Test an explanation method.""" try: pickle except NameError: @@ -512,15 +506,13 @@ def __intlogspace(start, end, count): return np.unique(np.round(start + (end-start) * (np.logspace(0, 1, count, endpoint=True) - 1) / 9).astype(int)) def __toarray(X): - """ Converts DataFrames to numpy arrays. - """ + """Converts DataFrames to numpy arrays.""" if hasattr(X, "values"): X = X.values return X def __strip_list(attrs): - """ This assumes that if you have a list of outputs you just want the second one (the second class). - """ + """This assumes that if you have a list of outputs you just want the second one (the second class).""" if isinstance(attrs, list): return attrs[1] else: @@ -566,7 +558,7 @@ def _human_and(X, model_generator, method_name, fever, cough): return "human", (human_consensus, methods_attrs[0,:]) def human_and_00(X, y, model_generator, method_name): - """ AND (false/false) + """AND (false/false) This tests how well a feature attribution method agrees with human intuition for an AND operation combined with linear effects. This metric deals @@ -582,7 +574,7 @@ def human_and_00(X, y, model_generator, method_name): return _human_and(X, model_generator, method_name, False, False) def human_and_01(X, y, model_generator, method_name): - """ AND (false/true) + """AND (false/true) This tests how well a feature attribution method agrees with human intuition for an AND operation combined with linear effects. This metric deals @@ -598,7 +590,7 @@ def human_and_01(X, y, model_generator, method_name): return _human_and(X, model_generator, method_name, False, True) def human_and_11(X, y, model_generator, method_name): - """ AND (true/true) + """AND (true/true) This tests how well a feature attribution method agrees with human intuition for an AND operation combined with linear effects. This metric deals @@ -637,7 +629,7 @@ def _human_or(X, model_generator, method_name, fever, cough): return "human", (human_consensus, methods_attrs[0,:]) def human_or_00(X, y, model_generator, method_name): - """ OR (false/false) + """OR (false/false) This tests how well a feature attribution method agrees with human intuition for an OR operation combined with linear effects. This metric deals @@ -653,7 +645,7 @@ def human_or_00(X, y, model_generator, method_name): return _human_or(X, model_generator, method_name, False, False) def human_or_01(X, y, model_generator, method_name): - """ OR (false/true) + """OR (false/true) This tests how well a feature attribution method agrees with human intuition for an OR operation combined with linear effects. This metric deals @@ -669,7 +661,7 @@ def human_or_01(X, y, model_generator, method_name): return _human_or(X, model_generator, method_name, False, True) def human_or_11(X, y, model_generator, method_name): - """ OR (true/true) + """OR (true/true) This tests how well a feature attribution method agrees with human intuition for an OR operation combined with linear effects. This metric deals @@ -708,7 +700,7 @@ def _human_xor(X, model_generator, method_name, fever, cough): return "human", (human_consensus, methods_attrs[0,:]) def human_xor_00(X, y, model_generator, method_name): - """ XOR (false/false) + """XOR (false/false) This tests how well a feature attribution method agrees with human intuition for an eXclusive OR operation combined with linear effects. This metric deals @@ -724,7 +716,7 @@ def human_xor_00(X, y, model_generator, method_name): return _human_xor(X, model_generator, method_name, False, False) def human_xor_01(X, y, model_generator, method_name): - """ XOR (false/true) + """XOR (false/true) This tests how well a feature attribution method agrees with human intuition for an eXclusive OR operation combined with linear effects. This metric deals @@ -740,7 +732,7 @@ def human_xor_01(X, y, model_generator, method_name): return _human_xor(X, model_generator, method_name, False, True) def human_xor_11(X, y, model_generator, method_name): - """ XOR (true/true) + """XOR (true/true) This tests how well a feature attribution method agrees with human intuition for an eXclusive OR operation combined with linear effects. This metric deals @@ -779,7 +771,7 @@ def _human_sum(X, model_generator, method_name, fever, cough): return "human", (human_consensus, methods_attrs[0,:]) def human_sum_00(X, y, model_generator, method_name): - """ SUM (false/false) + """SUM (false/false) This tests how well a feature attribution method agrees with human intuition for a SUM operation. This metric deals @@ -794,7 +786,7 @@ def human_sum_00(X, y, model_generator, method_name): return _human_sum(X, model_generator, method_name, False, False) def human_sum_01(X, y, model_generator, method_name): - """ SUM (false/true) + """SUM (false/true) This tests how well a feature attribution method agrees with human intuition for a SUM operation. This metric deals @@ -809,7 +801,7 @@ def human_sum_01(X, y, model_generator, method_name): return _human_sum(X, model_generator, method_name, False, True) def human_sum_11(X, y, model_generator, method_name): - """ SUM (true/true) + """SUM (true/true) This tests how well a feature attribution method agrees with human intuition for a SUM operation. This metric deals diff --git a/shap/benchmark/models.py b/shap/benchmark/models.py index 2c0a886f4..b0c045c60 100644 --- a/shap/benchmark/models.py +++ b/shap/benchmark/models.py @@ -5,8 +5,8 @@ class KerasWrap: - """ A wrapper that allows us to set parameters in the constructor and do a reset before fitting. - """ + """A wrapper that allows us to set parameters in the constructor and do a reset before fitting.""" + def __init__(self, model, epochs, flatten_output=False): self.model = model self.epochs = epochs @@ -33,38 +33,31 @@ def predict(self, X): # This models are all tuned for the corrgroups60 dataset def corrgroups60__lasso(): - """ Lasso Regression - """ + """Lasso Regression""" return sklearn.linear_model.Lasso(alpha=0.1) def corrgroups60__ridge(): - """ Ridge Regression - """ + """Ridge Regression""" return sklearn.linear_model.Ridge(alpha=1.0) def corrgroups60__decision_tree(): - """ Decision Tree - """ - + """Decision Tree""" # max_depth was chosen to minimise test error return sklearn.tree.DecisionTreeRegressor(random_state=0, max_depth=6) def corrgroups60__random_forest(): - """ Random Forest - """ + """Random Forest""" return sklearn.ensemble.RandomForestRegressor(100, random_state=0) def corrgroups60__gbm(): - """ Gradient Boosted Trees - """ + """Gradient Boosted Trees""" import xgboost # max_depth and learning_rate were fixed then n_estimators was chosen using a train/test split return xgboost.XGBRegressor(max_depth=6, n_estimators=50, learning_rate=0.1, n_jobs=8, random_state=0) def corrgroups60__ffnn(): - """ 4-Layer Neural Network - """ + """4-Layer Neural Network""" from tensorflow.keras.layers import Dense from tensorflow.keras.models import Sequential @@ -82,38 +75,31 @@ def corrgroups60__ffnn(): def independentlinear60__lasso(): - """ Lasso Regression - """ + """Lasso Regression""" return sklearn.linear_model.Lasso(alpha=0.1) def independentlinear60__ridge(): - """ Ridge Regression - """ + """Ridge Regression""" return sklearn.linear_model.Ridge(alpha=1.0) def independentlinear60__decision_tree(): - """ Decision Tree - """ - + """Decision Tree""" # max_depth was chosen to minimise test error return sklearn.tree.DecisionTreeRegressor(random_state=0, max_depth=4) def independentlinear60__random_forest(): - """ Random Forest - """ + """Random Forest""" return sklearn.ensemble.RandomForestRegressor(100, random_state=0) def independentlinear60__gbm(): - """ Gradient Boosted Trees - """ + """Gradient Boosted Trees""" import xgboost # max_depth and learning_rate were fixed then n_estimators was chosen using a train/test split return xgboost.XGBRegressor(max_depth=6, n_estimators=100, learning_rate=0.1, n_jobs=8, random_state=0) def independentlinear60__ffnn(): - """ 4-Layer Neural Network - """ + """4-Layer Neural Network""" from tensorflow.keras.layers import Dense from tensorflow.keras.models import Sequential @@ -131,8 +117,7 @@ def independentlinear60__ffnn(): def cric__lasso(): - """ Lasso Regression - """ + """Lasso Regression""" model = sklearn.linear_model.LogisticRegression(penalty="l1", C=0.002) # we want to explain the raw probability outputs of the trees @@ -141,8 +126,7 @@ def cric__lasso(): return model def cric__ridge(): - """ Ridge Regression - """ + """Ridge Regression""" model = sklearn.linear_model.LogisticRegression(penalty="l2") # we want to explain the raw probability outputs of the trees @@ -151,8 +135,7 @@ def cric__ridge(): return model def cric__decision_tree(): - """ Decision Tree - """ + """Decision Tree""" model = sklearn.tree.DecisionTreeClassifier(random_state=0, max_depth=4) # we want to explain the raw probability outputs of the trees @@ -161,8 +144,7 @@ def cric__decision_tree(): return model def cric__random_forest(): - """ Random Forest - """ + """Random Forest""" model = sklearn.ensemble.RandomForestClassifier(100, random_state=0) # we want to explain the raw probability outputs of the trees @@ -171,8 +153,7 @@ def cric__random_forest(): return model def cric__gbm(): - """ Gradient Boosted Trees - """ + """Gradient Boosted Trees""" import xgboost # max_depth and subsample match the params used for the full cric data in the paper @@ -187,8 +168,7 @@ def cric__gbm(): return model def cric__ffnn(): - """ 4-Layer Neural Network - """ + """4-Layer Neural Network""" from tensorflow.keras.layers import Dense, Dropout from tensorflow.keras.models import Sequential @@ -207,9 +187,7 @@ def cric__ffnn(): def human__decision_tree(): - """ Decision Tree - """ - + """Decision Tree""" # build data N = 1000000 M = 3 diff --git a/shap/benchmark/plots.py b/shap/benchmark/plots.py index 56bb204b7..b47560266 100644 --- a/shap/benchmark/plots.py +++ b/shap/benchmark/plots.py @@ -371,9 +371,7 @@ def plot_human(dataset, model, metric, cmap=benchmark_color_map): return pl.gcf() def _human_score_map(human_consensus, methods_attrs): - """ Converts human agreement differences to numerical scores for coloring. - """ - + """Converts human agreement differences to numerical scores for coloring.""" v = 1 - min(np.sum(np.abs(methods_attrs - human_consensus)) / (np.abs(human_consensus).sum() + 1), 1.0) return v diff --git a/shap/datasets.py b/shap/datasets.py index e33bce1f2..a5bce4aa1 100644 --- a/shap/datasets.py +++ b/shap/datasets.py @@ -11,7 +11,7 @@ def imagenet50(display=False, resolution=224, n_points=None): - """ This is a set of 50 images representative of ImageNet images. + """This is a set of 50 images representative of ImageNet images. This dataset was collected by randomly finding a working ImageNet link and then pasting the original ImageNet image into Google image search restricted to images licensed for reuse. A @@ -21,7 +21,6 @@ def imagenet50(display=False, resolution=224, n_points=None): Note that because the images are only rough replacements the labels might no longer be correct. """ - prefix = github_data_url + "imagenet50_" X = np.load(cache(f"{prefix}{resolution}x{resolution}.npy")).astype(np.float32) y = np.loadtxt(cache(f"{prefix}labels.csv")) @@ -34,8 +33,7 @@ def imagenet50(display=False, resolution=224, n_points=None): def california(display=False, n_points=None): - """ Return the california housing data in a nice package. """ - + """Return the california housing data in a nice package.""" d = sklearn.datasets.fetch_california_housing() df = pd.DataFrame(data=d.data, columns=d.feature_names) target = d.target @@ -48,8 +46,7 @@ def california(display=False, n_points=None): def linnerud(display=False, n_points=None): - """ Return the linnerud data in a nice package (multi-target regression). """ - + """Return the linnerud data in a nice package (multi-target regression).""" d = sklearn.datasets.load_linnerud() X = pd.DataFrame(d.data, columns=d.feature_names) y = pd.DataFrame(d.target, columns=d.target_names) @@ -62,12 +59,11 @@ def linnerud(display=False, n_points=None): def imdb(display=False, n_points=None): - """ Return the classic IMDB sentiment analysis training data in a nice package. + """Return the classic IMDB sentiment analysis training data in a nice package. Full data is at: http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz Paper to cite when using the data is: http://www.aclweb.org/anthology/P11-1015 """ - with open(cache(github_data_url + "imdb_train.txt"), encoding="utf-8") as f: data = f.readlines() y = np.ones(25000, dtype=bool) @@ -81,12 +77,11 @@ def imdb(display=False, n_points=None): def communitiesandcrime(display=False, n_points=None): - """ Predict total number of non-violent crimes per 100K popuation. + """Predict total number of non-violent crimes per 100K popuation. This dataset is from the classic UCI Machine Learning repository: https://archive.ics.uci.edu/ml/datasets/Communities+and+Crime+Unnormalized """ - raw_data = pd.read_csv( cache(github_data_url + "CommViolPredUnnormalizedData.txt"), na_values="?" @@ -109,8 +104,7 @@ def communitiesandcrime(display=False, n_points=None): def diabetes(display=False, n_points=None): - """ Return the diabetes data in a nice package. """ - + """Return the diabetes data in a nice package.""" d = sklearn.datasets.load_diabetes() df = pd.DataFrame(data=d.data, columns=d.feature_names) target = d.target @@ -123,8 +117,7 @@ def diabetes(display=False, n_points=None): def iris(display=False, n_points=None): - """ Return the classic iris data in a nice package. """ - + """Return the classic iris data in a nice package.""" d = sklearn.datasets.load_iris() df = pd.DataFrame(data=d.data, columns=d.feature_names) target = d.target @@ -139,7 +132,7 @@ def iris(display=False, n_points=None): def adult(display=False, n_points=None): - """ Return the Adult census data in a nice package. """ + """Return the Adult census data in a nice package.""" dtypes = [ ("Age", "float32"), ("Workclass", "category"), ("fnlwgt", "float32"), ("Education", "category"), ("Education-Num", "float32"), ("Marital Status", "category"), @@ -181,8 +174,7 @@ def adult(display=False, n_points=None): def nhanesi(display=False, n_points=None): - """ A nicely packaged version of NHANES I data with surivival times as labels. - """ + """A nicely packaged version of NHANES I data with surivival times as labels.""" X = pd.read_csv(cache(github_data_url + "NHANESI_X.csv"), index_col=0) y = pd.read_csv(cache(github_data_url + "NHANESI_y.csv"), index_col=0)["y"] @@ -198,11 +190,10 @@ def nhanesi(display=False, n_points=None): def corrgroups60(display=False, n_points=1_000): - """ Correlated Groups 60 + """Correlated Groups 60 A simulated dataset with tight correlations among distinct groups of features. """ - # set a constant seed old_seed = np.random.seed() np.random.seed(0) @@ -243,9 +234,7 @@ def f(X): def independentlinear60(display=False, n_points=1_000): - """ A simulated dataset with tight correlations among distinct groups of features. - """ - + """A simulated dataset with tight correlations among distinct groups of features.""" # set a constant seed old_seed = np.random.seed() np.random.seed(0) @@ -271,8 +260,7 @@ def f(X): def a1a(n_points=None): - """ A sparse dataset in scipy csr matrix format. - """ + """A sparse dataset in scipy csr matrix format.""" data, target = sklearn.datasets.load_svmlight_file(cache(github_data_url + 'a1a.svmlight')) if n_points is not None: @@ -283,8 +271,7 @@ def a1a(n_points=None): def rank(): - """ Ranking datasets from lightgbm repository. - """ + """Ranking datasets from lightgbm repository.""" rank_data_url = 'https://raw.githubusercontent.com/Microsoft/LightGBM/master/examples/lambdarank/' x_train, y_train = sklearn.datasets.load_svmlight_file(cache(rank_data_url + 'rank.train')) x_test, y_test = sklearn.datasets.load_svmlight_file(cache(rank_data_url + 'rank.test')) @@ -295,8 +282,7 @@ def rank(): def cache(url, file_name=None): - """ Loads a file from the URL and caches it locally. - """ + """Loads a file from the URL and caches it locally.""" if file_name is None: file_name = os.path.basename(url) data_dir = os.path.join(os.path.dirname(__file__), "cached_data") diff --git a/shap/explainers/_additive.py b/shap/explainers/_additive.py index 3eb5dc47e..a8d8c4782 100644 --- a/shap/explainers/_additive.py +++ b/shap/explainers/_additive.py @@ -5,7 +5,7 @@ class AdditiveExplainer(Explainer): - """ Computes SHAP values for generalized additive models. + """Computes SHAP values for generalized additive models. This assumes that the model only has first-order effects. Extending this to second- and third-order effects is future work (if you apply this to those models right now @@ -13,7 +13,7 @@ class AdditiveExplainer(Explainer): """ def __init__(self, model, masker, link=None, feature_names=None, linearize_link=True): - """ Build an Additive explainer for the given model using the given masker object. + """Build an Additive explainer for the given model using the given masker object. Parameters ---------- @@ -21,13 +21,14 @@ def __init__(self, model, masker, link=None, feature_names=None, linearize_link= A callable python object that executes the model given a set of input data samples. masker : function or numpy.array or pandas.DataFrame - A callable python object used to "mask" out hidden features of the form `masker(mask, *fargs)`. + A callable python object used to "mask" out hidden features of the form ``masker(mask, *fargs)``. It takes a single a binary mask and an input sample and returns a matrix of masked samples. These masked samples are evaluated using the model function and the outputs are then averaged. As a shortcut for the standard masking used by SHAP you can pass a background data matrix instead of a function and that matrix will be used for masking. To use a clustering - game structure you can pass a shap.maskers.Tabular(data, hclustering=\"correlation\") object, but + game structure you can pass a ``shap.maskers.Tabular(data, hclustering="correlation")`` object, but note that this structure information has no effect on the explanations of additive models. + """ super().__init__(model, masker, feature_names=feature_names, linearize_link=linearize_link) @@ -64,16 +65,14 @@ def __init__(self, model, masker, link=None, feature_names=None, linearize_link= self._expected_value = self._input_offsets.sum() + self._zero_offset def __call__(self, *args, max_evals=None, silent=False): - """ Explains the output of model(*args), where args represents one or more parallel iterable args. - """ - + """Explains the output of model(*args), where args represents one or more parallel iterable args.""" # we entirely rely on the general call implementation, we override just to remove **kwargs # from the function signature return super().__call__(*args, max_evals=max_evals, silent=silent) @staticmethod def supports_model_with_masker(model, masker): - """ Determines if this explainer can handle the given model. + """Determines if this explainer can handle the given model. This is an abstract static method meant to be implemented by each subclass. """ @@ -85,9 +84,7 @@ def supports_model_with_masker(model, masker): return False def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_size, outputs, silent): - """ Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes). - """ - + """Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes).""" x = row_args[0] inputs = np.zeros((len(x), len(x))) for i in range(len(x)): diff --git a/shap/explainers/_deep/__init__.py b/shap/explainers/_deep/__init__.py index 94b159dfa..c7b2f5eb7 100644 --- a/shap/explainers/_deep/__init__.py +++ b/shap/explainers/_deep/__init__.py @@ -21,7 +21,7 @@ class DeepExplainer(Explainer): """ def __init__(self, model, data, session=None, learning_phase_flags=None): - """ An explainer object for a differentiable model using a given background dataset. + """An explainer object for a differentiable model using a given background dataset. Note that the complexity of the method scales linearly with the number of background data samples. Passing the entire training dataset as `data` will give very accurate expected @@ -89,7 +89,7 @@ def __init__(self, model, data, session=None, learning_phase_flags=None): self.explainer.framework = framework def shap_values(self, X, ranked_outputs=None, output_rank_order='max', check_additivity=True): - """ Return approximate SHAP values for the model applied to the data given by X. + """Return approximate SHAP values for the model applied to the data given by X. Parameters ---------- @@ -130,5 +130,6 @@ def shap_values(self, X, ranked_outputs=None, output_rank_order='max', check_add .. versionchanged:: 0.45.0 Return type for models with multiple outputs and one input changed from list to np.ndarray. + """ return self.explainer.shap_values(X, ranked_outputs, output_rank_order, check_additivity=check_additivity) diff --git a/shap/explainers/_deep/deep_pytorch.py b/shap/explainers/_deep/deep_pytorch.py index b97384d18..3045e09d9 100644 --- a/shap/explainers/_deep/deep_pytorch.py +++ b/shap/explainers/_deep/deep_pytorch.py @@ -64,8 +64,7 @@ def add_target_handle(self, layer): self.target_handle = input_handle def add_handles(self, model, forward_handle, backward_handle): - """ - Add handles to all non-container layers in the model. + """Add handles to all non-container layers in the model. Recursively for non-container layers """ handles_list = [] @@ -79,8 +78,7 @@ def add_handles(self, model, forward_handle, backward_handle): return handles_list def remove_attributes(self, model): - """ - Removes the x and y attributes which were added by the forward handles + """Removes the x and y attributes which were added by the forward handles Recursively searches for non-container layers """ for child in model.children(): diff --git a/shap/explainers/_deep/deep_tf.py b/shap/explainers/_deep/deep_tf.py index b128836b0..af218bf32 100644 --- a/shap/explainers/_deep/deep_tf.py +++ b/shap/explainers/_deep/deep_tf.py @@ -15,7 +15,7 @@ tf_gradients_impl = None def custom_record_gradient(op_name, inputs, attrs, results): - """ This overrides tensorflow.python.eager.backprop._record_gradient. + """This overrides tensorflow.python.eager.backprop._record_gradient. We need to override _record_gradient in order to get gradient backprop to get called for ResourceGather operations. In order to make this work we @@ -38,14 +38,13 @@ def custom_record_gradient(op_name, inputs, attrs, results): return out class TFDeep(Explainer): - """ - Using tf.gradients to implement the backpropagation was + """Using tf.gradients to implement the backpropagation was inspired by the gradient-based implementation approach proposed by Ancona et al, ICLR 2018. Note that this package does not currently use the reveal-cancel rule for ReLu units proposed in DeepLIFT. """ def __init__(self, model, data, session=None, learning_phase_flags=None): - """ An explainer object for a deep model using a given background dataset. + """An explainer object for a deep model using a given background dataset. Note that the complexity of the method scales linearly with the number of background data samples. Passing the entire training dataset as `data` will give very accurate expected @@ -216,8 +215,7 @@ def _init_between_tensors(self, out_op, model_inputs): self.used_types[op.type] = True def _variable_inputs(self, op): - """ Return which inputs of this operation are variable (i.e. depend on the model inputs). - """ + """Return which inputs of this operation are variable (i.e. depend on the model inputs).""" if op not in self._vinputs: out = np.zeros(len(op.inputs), dtype=bool) for i,t in enumerate(op.inputs): @@ -226,8 +224,7 @@ def _variable_inputs(self, op): return self._vinputs[op] def phi_symbolic(self, i): - """ Get the SHAP value computation graph for a given model output. - """ + """Get the SHAP value computation graph for a given model output.""" if self.phi_symbolics[i] is None: if not tf.executing_eagerly(): @@ -341,8 +338,7 @@ def shap_values(self, X, ranked_outputs=None, output_rank_order="max", check_add return output_phis def run(self, out, model_inputs, X): - """ Runs the model while also setting the learning phase flags to False. - """ + """Runs the model while also setting the learning phase flags to False.""" if not tf.executing_eagerly(): feed_dict = dict(zip(model_inputs, X)) for t in self.learning_phase_flags: @@ -370,8 +366,7 @@ def anon(): return self.execute_with_overridden_gradients(anon) def custom_grad(self, op, *grads): - """ Passes a gradient op creation request to the correct handler. - """ + """Passes a gradient op creation request to the correct handler.""" type_name = op.type[5:] if op.type.startswith("shap_") else op.type out = op_handlers[type_name](self, op, *grads) # we cut off the shap_ prefix before the lookup return out @@ -422,7 +417,7 @@ def execute_with_overridden_gradients(self, f): return [v.numpy() for v in out] def tensors_blocked_by_false(ops): - """ Follows a set of ops assuming their value is False and find blocked Switch paths. + """Follows a set of ops assuming their value is False and find blocked Switch paths. This is used to prune away parts of the model graph that are only used during the training phase (like dropout, batch norm, etc.). @@ -467,7 +462,7 @@ def forward_walk_ops(start_ops, tensor_blacklist, op_type_blacklist, within_ops) def softmax(explainer, op, *grads): - """ Just decompose softmax into its components and recurse, we can handle all of them :) + """Just decompose softmax into its components and recurse, we can handle all of them :) We assume the 'axis' is the last dimension because the TF codebase swaps the 'axis' to the last dimension before the softmax op if 'axis' is not already the last dimension. @@ -683,7 +678,7 @@ def passthrough(explainer, op, *grads): return explainer.orig_grads[op.type](op, *grads) def break_dependence(explainer, op, *grads): - """ This function name is used to break attribution dependence in the graph traversal. + """This function name is used to break attribution dependence in the graph traversal. These operation types may be connected above input data values in the graph but their outputs don't depend on the input values (for example they just depend on the shape). diff --git a/shap/explainers/_exact.py b/shap/explainers/_exact.py index 3ade9a821..c6dc095a7 100644 --- a/shap/explainers/_exact.py +++ b/shap/explainers/_exact.py @@ -17,7 +17,7 @@ class ExactExplainer(Explainer): - """ Computes SHAP values via an optimized exact enumeration. + """Computes SHAP values via an optimized exact enumeration. This works well for standard Shapley value maskers for models with less than ~15 features that vary from the background per sample. It also works well for Owen values from hclustering structured @@ -28,7 +28,7 @@ class ExactExplainer(Explainer): """ def __init__(self, model, masker, link=links.identity, linearize_link=True, feature_names=None): - """ Build an explainers.Exact object for the given model using the given masker object. + """Build an explainers.Exact object for the given model using the given masker object. Parameters ---------- @@ -56,6 +56,7 @@ def __init__(self, model, masker, link=links.identity, linearize_link=True, feat many samples. This for example means that a linear logistic regression model would have interaction effects that arise from the non-linear changes in expectation averaging. To retain the additively of the model with still respecting the link function we linearize the link function by default. + """ # TODO link to the link linearization paper when done super().__init__(model, masker, link=link, linearize_link=linearize_link, feature_names=feature_names) @@ -68,9 +69,7 @@ def __init__(self, model, masker, link=links.identity, linearize_link=True, feat self._gray_code_cache = {} # used to avoid regenerating the same gray code patterns def __call__(self, *args, max_evals=100000, main_effects=False, error_bounds=False, batch_size="auto", interactions=1, silent=False): - """ Explains the output of model(*args), where args represents one or more parallel iterators. - """ - + """Explains the output of model(*args), where args represents one or more parallel iterators.""" # we entirely rely on the general call implementation, we override just to remove **kwargs # from the function signature return super().__call__( @@ -84,9 +83,7 @@ def _cached_gray_codes(self, n): return self._gray_code_cache[n] def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_size, outputs, interactions, silent): - """ Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes). - """ - + """Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes).""" # build a masked version of the model for the current input sample fm = MaskedModel(self.model, self.masker, self.link, self.linearize_link, *row_args) @@ -238,9 +235,7 @@ def _compute_grey_code_row_values_st(row_values, mask, inds, outputs, shapley_co row_values[k,j] += delta def partition_delta_indexes(partition_tree, all_masks): - """ Return an delta index encoded array of all the masks possible while following the given partition tree. - """ - + """Return an delta index encoded array of all the masks possible while following the given partition tree.""" # convert the masks to delta index format mask = np.zeros(all_masks.shape[1], dtype=bool) delta_inds = [] @@ -258,9 +253,7 @@ def partition_delta_indexes(partition_tree, all_masks): return np.array(delta_inds) def partition_masks(partition_tree): - """ Return an array of all the masks possible while following the given partition tree. - """ - + """Return an array of all the masks possible while following the given partition tree.""" M = partition_tree.shape[0] + 1 mask_matrix = make_masks(partition_tree) all_masks = [] @@ -326,7 +319,7 @@ def _partition_masks_recurse(index, m00, ind00, ind11, inds_lists, mask_matrix, def gray_code_masks(nbits): - """ Produces an array of all binary patterns of size nbits in gray code order. + """Produces an array of all binary patterns of size nbits in gray code order. This is based on code from: http://code.activestate.com/recipes/576592-gray-code-generatoriterator/ """ @@ -346,7 +339,7 @@ def gray_code_masks(nbits): return out def gray_code_indexes(nbits): - """ Produces an array of which bits flip at which position. + """Produces an array of which bits flip at which position. We assume the masks start at all zero and -1 means don't do a flip. This is a more efficient representation of the gray_code_masks version. diff --git a/shap/explainers/_explainer.py b/shap/explainers/_explainer.py index 77a9d626e..dadb7ccf7 100644 --- a/shap/explainers/_explainer.py +++ b/shap/explainers/_explainer.py @@ -17,7 +17,7 @@ class Explainer(Serializable): - """ Uses Shapley values to explain any machine learning model or python function. + """Uses Shapley values to explain any machine learning model or python function. This is the primary explainer interface for the SHAP library. It takes any combination of a model and masker and returns a callable subclass object that implements @@ -26,7 +26,7 @@ class Explainer(Serializable): def __init__(self, model, masker=None, link=links.identity, algorithm="auto", output_names=None, feature_names=None, linearize_link=True, seed=None, **kwargs): - """ Build a new explainer for the passed model. + """Build a new explainer for the passed model. Parameters ---------- @@ -74,7 +74,6 @@ def __init__(self, model, masker=None, link=links.identity, algorithm="auto", ou seed for reproducibility """ - self.model = model self.output_names = output_names self.feature_names = feature_names @@ -202,13 +201,12 @@ def __init__(self, model, masker=None, link=links.identity, algorithm="auto", ou def __call__(self, *args, max_evals="auto", main_effects=False, error_bounds=False, batch_size="auto", outputs=None, silent=False, **kwargs): - """ Explains the output of model(*args), where args is a list of parallel iterable datasets. + """Explains the output of model(*args), where args is a list of parallel iterable datasets. Note this default version could be an abstract method that is implemented by each algorithm-specific subclass of Explainer. Descriptions of each subclasses' __call__ arguments are available in their respective doc-strings. """ - # if max_evals == "auto": # self._brute_force_fallback @@ -365,7 +363,7 @@ def __call__(self, *args, max_evals="auto", main_effects=False, error_bounds=Fal return out[0] if len(out) == 1 else out def explain_row(self, *row_args, max_evals, main_effects, error_bounds, outputs, silent, **kwargs): - """ Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes, main_effects). + """Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes, main_effects). This is an abstract method meant to be implemented by each subclass. @@ -377,13 +375,13 @@ def explain_row(self, *row_args, max_evals, main_effects, error_bounds, outputs, the expected value of the model for each sample (which is the same for all samples unless there are fixed inputs present, like labels when explaining the loss), and row_mask_shapes is a list of all the input shapes (since the row_values is always flattened), - """ + """ return {} @staticmethod def supports_model_with_masker(model, masker): - """ Determines if this explainer can handle the given model. + """Determines if this explainer can handle the given model. This is an abstract static method meant to be implemented by each subclass. """ @@ -391,8 +389,7 @@ def supports_model_with_masker(model, masker): @staticmethod def _compute_main_effects(fm, expected_value, inds): - """ A utility method to compute the main effects from a MaskedModel. - """ + """A utility method to compute the main effects from a MaskedModel.""" warnings.warn( "This function is not used within the shap library and will therefore be removed in an upcoming release. " "If you rely on this function, please open an issue: https://github.com/shap/shap/issues.", @@ -419,8 +416,7 @@ def _compute_main_effects(fm, expected_value, inds): return expanded_main_effects def save(self, out_file, model_saver=".save", masker_saver=".save"): - """ Write the explainer to the given file stream. - """ + """Write the explainer to the given file stream.""" super().save(out_file) with Serializer(out_file, "shap.Explainer", version=0) as s: s.save("model", self.model, model_saver) @@ -429,11 +425,12 @@ def save(self, out_file, model_saver=".save", masker_saver=".save"): @classmethod def load(cls, in_file, model_loader=Model.load, masker_loader=Masker.load, instantiate=True): - """ Load an Explainer from the given file stream. + """Load an Explainer from the given file stream. Parameters ---------- in_file : The file stream to load objects from. + """ if instantiate: return cls._instantiated_load(in_file, model_loader=model_loader, masker_loader=masker_loader) @@ -446,9 +443,7 @@ def load(cls, in_file, model_loader=Model.load, masker_loader=Masker.load, insta return kwargs def pack_values(values): - """ Used the clean up arrays before putting them into an Explanation object. - """ - + """Used the clean up arrays before putting them into an Explanation object.""" if not hasattr(values, "__len__"): return values diff --git a/shap/explainers/_gpu_tree.py b/shap/explainers/_gpu_tree.py index f9ed61dac..8097798fd 100644 --- a/shap/explainers/_gpu_tree.py +++ b/shap/explainers/_gpu_tree.py @@ -16,18 +16,18 @@ class GPUTreeExplainer(TreeExplainer): - """ - Experimental GPU accelerated version of TreeExplainer. Currently requires source build with + """Experimental GPU accelerated version of TreeExplainer. Currently requires source build with cuda available and 'CUDA_PATH' environment variable defined. Examples -------- See `GPUTree explainer examples `_ + """ def shap_values(self, X, y=None, tree_limit=None, approximate=False, check_additivity=True, from_call=False): - """ Estimate the SHAP values for a set of samples. + """Estimate the SHAP values for a set of samples. Parameters ---------- @@ -61,6 +61,7 @@ def shap_values(self, X, y=None, tree_limit=None, approximate=False, check_addit attribute of the explainer when it is constant). For models with vector outputs this returns a list of such matrices, one for each output. + """ assert not approximate, "approximate not supported" @@ -90,7 +91,7 @@ def shap_values(self, X, y=None, tree_limit=None, approximate=False, check_addit return out def shap_interaction_values(self, X, y=None, tree_limit=None): - """ Estimate the SHAP interaction values for a set of samples. + """Estimate the SHAP interaction values for a set of samples. Parameters ---------- @@ -123,8 +124,8 @@ def shap_interaction_values(self, X, y=None, tree_limit=None): interaction effects between all pairs of features for that sample. For models with vector outputs this returns a list of tensors, one for each output. - """ + """ assert self.model.model_output == "raw", "Only model_output = \"raw\" is supported for " \ "SHAP interaction values right now!" assert self.feature_perturbation != "interventional", 'feature_perturbation="interventional" is not yet supported for ' + \ diff --git a/shap/explainers/_gradient.py b/shap/explainers/_gradient.py index ed54b98a9..b0385ae23 100644 --- a/shap/explainers/_gradient.py +++ b/shap/explainers/_gradient.py @@ -18,7 +18,7 @@ class GradientExplainer(Explainer): - """ Explains a model using expected gradients (an extension of integrated gradients). + """Explains a model using expected gradients (an extension of integrated gradients). Expected gradients an extension of the integrated gradients method (Sundararajan et al. 2017), a feature attribution method designed for differentiable models based on an extension of Shapley @@ -32,10 +32,11 @@ class GradientExplainer(Explainer): Examples -------- See :ref:`Gradient Explainer Examples ` + """ def __init__(self, model, data, session=None, batch_size=50, local_smoothing=0): - """ An explainer object for a differentiable model using a given background dataset. + """An explainer object for a differentiable model using a given background dataset. Parameters ---------- @@ -56,8 +57,8 @@ def __init__(self, model, data, session=None, batch_size=50, local_smoothing=0): The background dataset to use for integrating out features. Gradient explainer integrates over these samples. The data passed here must match the input tensors given in the first argument. Single element lists can be passed unwrapped. - """ + """ # first, we need to find the framework if type(model) is tuple: a, b = model @@ -84,7 +85,7 @@ def __init__(self, model, data, session=None, batch_size=50, local_smoothing=0): self.explainer = _PyTorchGradient(model, data, batch_size, local_smoothing) def __call__(self, X, nsamples=200): - """ Return an explanation object for the model applied to X. + """Return an explanation object for the model applied to X. Parameters ---------- @@ -95,15 +96,17 @@ def __call__(self, X, nsamples=200): explain the model's output. nsamples : int number of background samples + Returns ------- shap.Explanation: + """ shap_values = self.shap_values(X, nsamples) return Explanation(values=shap_values, data=X, feature_names=self.features) def shap_values(self, X, nsamples=200, ranked_outputs=None, output_rank_order="max", rseed=None, return_variances=False): - """ Return the values for the model applied to X. + """Return the values for the model applied to X. Parameters ---------- diff --git a/shap/explainers/_kernel.py b/shap/explainers/_kernel.py index 2098c052c..f7b6101fc 100644 --- a/shap/explainers/_kernel.py +++ b/shap/explainers/_kernel.py @@ -78,6 +78,7 @@ class KernelExplainer(Explainer): Examples -------- See :ref:`Kernel Explainer Examples `. + """ def __init__(self, model, data, feature_names=None, link="identity", **kwargs): @@ -178,7 +179,7 @@ def __call__(self, X, l1_reg="auto", silent=False): ) def shap_values(self, X, **kwargs): - """ Estimate the SHAP values for a set of samples. + """Estimate the SHAP values for a set of samples. Parameters ---------- @@ -228,6 +229,7 @@ def shap_values(self, X, **kwargs): .. versionchanged:: 0.45.0 Return type for models with multiple outputs and one input changed from list to np.ndarray. + """ # convert dataframes if isinstance(X, pd.Series): diff --git a/shap/explainers/_linear.py b/shap/explainers/_linear.py index 3e81879f9..d3e052a31 100644 --- a/shap/explainers/_linear.py +++ b/shap/explainers/_linear.py @@ -79,6 +79,7 @@ class LinearExplainer(Explainer): Examples -------- See `Linear explainer examples `_ + """ def __init__(self, model, masker, link=links.identity, nsamples=1000, feature_perturbation=None, **kwargs): @@ -211,7 +212,7 @@ def __init__(self, model, masker, link=links.identity, nsamples=1000, feature_pe raise InvalidFeaturePerturbationError("Unknown type of feature_perturbation provided: " + self.feature_perturbation) def _estimate_transforms(self, nsamples): - """ Uses block matrix inversion identities to quickly estimate transforms. + """Uses block matrix inversion identities to quickly estimate transforms. After a bit of matrix math we can isolate a transform matrix (# features x # features) that is independent of any sample we are explaining. It is the result of averaging over @@ -276,8 +277,7 @@ def _estimate_transforms(self, nsamples): @staticmethod def _parse_model(model): - """ Attempt to pull out the coefficients and intercept from the given model object. - """ + """Attempt to pull out the coefficients and intercept from the given model object.""" # raw coefficients if type(model) == tuple and len(model) == 2: coef = model[0] @@ -302,9 +302,7 @@ def _parse_model(model): @staticmethod def supports_model_with_masker(model, masker): - """ Determines if we can parse the given model. - """ - + """Determines if we can parse the given model.""" if not isinstance(masker, (maskers.Independent, maskers.Partition, maskers.Impute)): return False @@ -315,9 +313,7 @@ def supports_model_with_masker(model, masker): return True def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_size, outputs, silent): - """ Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes). - """ - + """Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes).""" assert len(row_args) == 1, "Only single-argument functions are supported by the Linear explainer!" X = row_args[0] @@ -367,7 +363,7 @@ def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_si def shap_values(self, X): - """ Estimate the SHAP values for a set of samples. + """Estimate the SHAP values for a set of samples. Parameters ---------- @@ -391,8 +387,8 @@ def shap_values(self, X): .. versionchanged:: 0.45.0 Return type for models with multiple outputs changed from list to np.ndarray. - """ + """ # convert dataframes if isinstance(X, (pd.Series, pd.DataFrame)): X = X.values diff --git a/shap/explainers/_partition.py b/shap/explainers/_partition.py index 79fbbc92c..5f760be8d 100644 --- a/shap/explainers/_partition.py +++ b/shap/explainers/_partition.py @@ -62,8 +62,8 @@ def __init__(self, model, masker, *, output_names=None, link=links.identity, lin Examples -------- See `Partition explainer examples `_ - """ + """ super().__init__(model, masker, link=link, linearize_link=linearize_link, algorithm="partition", \ output_names = output_names, feature_names=feature_names) @@ -124,17 +124,14 @@ def __call__(self, *args, max_evals=500, fixed_context=None, main_effects=False, # note that changes to this function signature should be copied to the default call argument wrapper above def __call__(self, *args, max_evals=500, fixed_context=None, main_effects=False, error_bounds=False, batch_size="auto", outputs=None, silent=False): - """ Explain the output of the model on the given arguments. - """ + """Explain the output of the model on the given arguments.""" return super().__call__( *args, max_evals=max_evals, fixed_context=fixed_context, main_effects=main_effects, error_bounds=error_bounds, batch_size=batch_size, outputs=outputs, silent=silent ) def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_size, outputs, silent, fixed_context = "auto"): - """ Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes). - """ - + """Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes).""" if fixed_context == "auto": # if isinstance(self.masker, maskers.Text): # fixed_context = 1 # we err on the side of speed for text models @@ -202,9 +199,7 @@ def __str__(self): return "shap.explainers.PartitionExplainer()" def owen(self, fm, f00, f11, max_evals, output_indexes, fixed_context, batch_size, silent): - """ Compute a nested set of recursive Owen values based on an ordering recursion. - """ - + """Compute a nested set of recursive Owen values based on an ordering recursion.""" #f = self._reshaped_model #r = self.masker #masks = np.zeros(2*len(inds)+1, dtype=int) @@ -341,9 +336,7 @@ def owen(self, fm, f00, f11, max_evals, output_indexes, fixed_context, batch_siz return output_indexes, base_value def owen3(self, fm, f00, f11, max_evals, output_indexes, fixed_context, batch_size, silent): - """ Compute a nested set of recursive Owen values based on an ordering recursion. - """ - + """Compute a nested set of recursive Owen values based on an ordering recursion.""" #f = self._reshaped_model #r = self.masker #masks = np.zeros(2*len(inds)+1, dtype=int) diff --git a/shap/explainers/_permutation.py b/shap/explainers/_permutation.py index 7eabddafa..0f7748c58 100644 --- a/shap/explainers/_permutation.py +++ b/shap/explainers/_permutation.py @@ -8,7 +8,7 @@ class PermutationExplainer(Explainer): - """ This method approximates the Shapley values by iterating through permutations of the inputs. + """This method approximates the Shapley values by iterating through permutations of the inputs. This is a model agnostic explainer that guarantees local accuracy (additivity) by iterating completely through an entire permutation of the features in both forward and reverse directions (antithetic sampling). @@ -21,7 +21,7 @@ class PermutationExplainer(Explainer): """ def __init__(self, model, masker, link=links.identity, feature_names=None, linearize_link=True, seed=None, **call_args): - """ Build an explainers.Permutation object for the given model using the given masker object. + """Build an explainers.Permutation object for the given model using the given masker object. Parameters ---------- @@ -29,20 +29,20 @@ def __init__(self, model, masker, link=links.identity, feature_names=None, linea A callable python object that executes the model given a set of input data samples. masker : function or numpy.array or pandas.DataFrame - A callable python object used to "mask" out hidden features of the form `masker(binary_mask, x)`. + A callable python object used to "mask" out hidden features of the form ``masker(binary_mask, x)``. It takes a single input sample and a binary mask and returns a matrix of masked samples. These masked samples are evaluated using the model function and the outputs are then averaged. As a shortcut for the standard masking using by SHAP you can pass a background data matrix instead of a function and that matrix will be used for masking. To use a clustering - game structure you can pass a shap.maskers.Tabular(data, clustering=\"correlation\") object. + game structure you can pass a ``shap.maskers.Tabular(data, clustering="correlation")`` object. seed: None or int Seed for reproducibility **call_args : valid argument to the __call__ method These arguments are saved and passed to the __call__ method as the new default values for these arguments. - """ + """ # setting seed for random generation: if seed is not None, then shap values computation should be reproducible np.random.seed(seed) @@ -73,17 +73,14 @@ def __call__(self, *args, max_evals=500, main_effects=False, error_bounds=False, # note that changes to this function signature should be copied to the default call argument wrapper above def __call__(self, *args, max_evals=500, main_effects=False, error_bounds=False, batch_size="auto", outputs=None, silent=False): - """ Explain the output of the model on the given arguments. - """ + """Explain the output of the model on the given arguments.""" return super().__call__( *args, max_evals=max_evals, main_effects=main_effects, error_bounds=error_bounds, batch_size=batch_size, outputs=outputs, silent=silent ) def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_size, outputs, silent): - """ Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes). - """ - + """Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes).""" # build a masked version of the model for the current input sample fm = MaskedModel(self.model, self.masker, self.link, self.linearize_link, *row_args) @@ -184,7 +181,7 @@ def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_si def shap_values(self, X, npermutations=10, main_effects=False, error_bounds=False, batch_evals=True, silent=False): - """ Legacy interface to estimate the SHAP values for a set of samples. + """Legacy interface to estimate the SHAP values for a set of samples. Parameters ---------- @@ -206,8 +203,8 @@ def shap_values(self, X, npermutations=10, main_effects=False, error_bounds=Fals sample and the expected value of the model output (which is stored as expected_value attribute of the explainer). For models with vector outputs this returns a list of such matrices, one for each output. - """ + """ explanation = self(X, max_evals=npermutations * X.shape[1], main_effects=main_effects) return explanation.values diff --git a/shap/explainers/_sampling.py b/shap/explainers/_sampling.py index bffa70a89..0dd6f3c50 100644 --- a/shap/explainers/_sampling.py +++ b/shap/explainers/_sampling.py @@ -37,6 +37,7 @@ class SamplingExplainer(KernelExplainer): we would approximate a feature being missing by setting it to zero. Unlike the KernelExplainer, this data can be the whole training set, even if that is a large set. This is because SamplingExplainer only samples from this background dataset. + """ def __init__(self, model, data, **kwargs): diff --git a/shap/explainers/_tree.py b/shap/explainers/_tree.py index b4e23aa89..2ea0ce780 100644 --- a/shap/explainers/_tree.py +++ b/shap/explainers/_tree.py @@ -84,6 +84,7 @@ class TreeExplainer(Explainer): Examples -------- See `Tree explainer examples `_ + """ def __init__( @@ -98,7 +99,7 @@ def __init__( link=None, linearize_link=None, ): - """ Build a new Tree explainer for the passed model. + """Build a new Tree explainer for the passed model. Parameters ---------- @@ -245,9 +246,7 @@ def __init__( self.expected_value = [1 - self.expected_value, self.expected_value] def __dynamic_expected_value(self, y): - """ This computes the expected value conditioned on the given label value. - """ - + """This computes the expected value conditioned on the given label value.""" return self.model.predict(self.data, np.ones(self.data.shape[0]) * y).mean(0) def __call__(self, X, y=None, interactions=False, check_additivity=True): @@ -368,7 +367,7 @@ def _validate_inputs(self, X, y, tree_limit, check_additivity): return X, y, X_missing, flat_output, tree_limit, check_additivity def shap_values(self, X, y=None, tree_limit=None, approximate=False, check_additivity=True, from_call=False): - """ Estimate the SHAP values for a set of samples. + """Estimate the SHAP values for a set of samples. Parameters ---------- @@ -410,6 +409,7 @@ def shap_values(self, X, y=None, tree_limit=None, approximate=False, check_addit .. versionchanged:: 0.45.0 Return type for models with multiple outputs changed from list to np.ndarray. + """ # see if we have a default tree_limit in place. if tree_limit is None: @@ -539,7 +539,7 @@ def _get_shap_output(self, phi, flat_output): return out def shap_interaction_values(self, X, y=None, tree_limit=None): - """ Estimate the SHAP interaction values for a set of samples. + """Estimate the SHAP interaction values for a set of samples. Parameters ---------- @@ -571,8 +571,8 @@ def shap_interaction_values(self, X, y=None, tree_limit=None): .. versionchanged:: 0.45.0 Return type for models with multiple outputs changed from list to np.ndarray. - """ + """ assert self.model.model_output == "raw", "Only model_output = \"raw\" is supported for SHAP interaction values right now!" #assert self.feature_perturbation == "tree_path_dependent", "Only feature_perturbation = \"tree_path_dependent\" is supported for SHAP interaction values right now!" transform = "identity" @@ -681,11 +681,10 @@ def check_sum(sum_val, model_output): @staticmethod def supports_model_with_masker(model, masker): - """ Determines if this explainer can handle the given model. + """Determines if this explainer can handle the given model. This is an abstract static method meant to be implemented by each subclass. """ - if not isinstance(masker, (maskers.Independent)) and masker is not None: return False @@ -697,7 +696,7 @@ def supports_model_with_masker(model, masker): class TreeEnsemble: - """ An ensemble of decision trees. + """An ensemble of decision trees. This object provides a common interface to many different types of models. """ @@ -1352,8 +1351,7 @@ def num_outputs(self) -> int: return self.trees[0].values.shape[1] def get_transform(self): - """ A consistent interface to make predictions from this model. - """ + """A consistent interface to make predictions from this model.""" if self.model_output == "raw": transform = "identity" elif self.model_output in ("probability", "probability_doubled"): @@ -1396,8 +1394,8 @@ def predict(self, X, y=None, output=None, tree_limit=None): tree_limit : None (default) or int Limit the number of trees used by the model. By default None means no use the limit of the original model, and -1 means no limit. - """ + """ if output is None: output = self.model_output @@ -1500,7 +1498,9 @@ class SingleTree: max_depth : int The max depth of the tree. + """ + def __init__(self, tree, normalize=False, scaling=1.0, data=None, data_missing=None): assert_import("cext") @@ -1770,9 +1770,8 @@ def extract_data(node, tree): class IsoTree(SingleTree): - """ - In sklearn the tree of the Isolation Forest does not calculated in a good way. - """ + """In sklearn the tree of the Isolation Forest does not calculated in a good way.""" + def __init__(self, tree, tree_features, normalize=False, scaling=1.0, data=None, data_missing=None): super().__init__(tree, normalize, scaling, data, data_missing) if safe_isinstance(tree, "sklearn.tree._tree.Tree"): @@ -1909,7 +1908,7 @@ def __init__(self, xgb_model) -> None: self.cat_feature_indices = None def to_integers(data: list[int]) -> np.ndarray: - "Handle u8 array from UBJSON." + """Handle u8 array from UBJSON.""" assert isinstance(data, list) return np.asanyarray(data, dtype=np.uint8) diff --git a/shap/explainers/other/_coefficient.py b/shap/explainers/other/_coefficient.py index 6384710b9..9644e0fd6 100644 --- a/shap/explainers/other/_coefficient.py +++ b/shap/explainers/other/_coefficient.py @@ -4,11 +4,12 @@ class Coefficient(Explainer): - """ Simply returns the model coefficients as the feature attributions. + """Simply returns the model coefficients as the feature attributions. This is only for benchmark comparisons and does not approximate SHAP values in a meaningful way. """ + def __init__(self, model): assert hasattr(model, "coef_"), "The passed model does not have a coef_ attribute!" self.model = model diff --git a/shap/explainers/other/_lime.py b/shap/explainers/other/_lime.py index 2757c8537..f096a4d64 100644 --- a/shap/explainers/other/_lime.py +++ b/shap/explainers/other/_lime.py @@ -10,7 +10,7 @@ pass class LimeTabular(Explainer): - """ Simply wrap of lime.lime_tabular.LimeTabularExplainer into the common shap interface. + """Simply wrap of lime.lime_tabular.LimeTabularExplainer into the common shap interface. Parameters ---------- @@ -24,6 +24,7 @@ class LimeTabular(Explainer): mode : "classification" or "regression" Control the mode of LIME tabular. + """ def __init__(self, model, data, mode="classification"): diff --git a/shap/explainers/other/_random.py b/shap/explainers/other/_random.py index 414b3de54..1605fd0aa 100644 --- a/shap/explainers/other/_random.py +++ b/shap/explainers/other/_random.py @@ -8,11 +8,12 @@ class Random(Explainer): - """ Simply returns random (normally distributed) feature attributions. + """Simply returns random (normally distributed) feature attributions. This is only for benchmark comparisons. It supports both fully random attributions and random attributions that are constant across all explanations. """ + def __init__(self, model, masker, link=links.identity, feature_names=None, linearize_link=True, constant=False, **call_args): super().__init__(model, masker, link=link, linearize_link=linearize_link, feature_names=feature_names) @@ -26,9 +27,7 @@ def __init__(self, model, masker, link=links.identity, feature_names=None, linea self.constant_attributions = None def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_size, outputs, silent): - """ Explains a single row. - """ - + """Explains a single row.""" # build a masked version of the model for the current input sample fm = MaskedModel(self.model, self.masker, self.link, self.linearize_link, *row_args) diff --git a/shap/explainers/other/_treegain.py b/shap/explainers/other/_treegain.py index edcd54bb5..82ce00ee9 100644 --- a/shap/explainers/other/_treegain.py +++ b/shap/explainers/other/_treegain.py @@ -4,10 +4,11 @@ class TreeGain(Explainer): - """ Simply returns the global gain/gini feature importances for tree models. + """Simply returns the global gain/gini feature importances for tree models. This is only for benchmark comparisons and is not meant to approximate SHAP values. """ + def __init__(self, model): if str(type(model)).endswith("sklearn.tree.tree.DecisionTreeRegressor'>"): pass diff --git a/shap/explainers/pytree.py b/shap/explainers/pytree.py index 2802849b7..9a2063316 100644 --- a/shap/explainers/pytree.py +++ b/shap/explainers/pytree.py @@ -1,5 +1,4 @@ -""" -This module is a pure python implementation of Tree SHAP. +"""This module is a pure python implementation of Tree SHAP. It is primarily for illustration since it is slower than the 'tree' module which uses a compiled C++ implementation. """ @@ -136,8 +135,7 @@ class TreeExplainer: - """ A pure Python (slow) implementation of Tree SHAP. - """ + """A pure Python (slow) implementation of Tree SHAP.""" def __init__(self, model, **kwargs): self.model_type = "internal" diff --git a/shap/explainers/tf_utils.py b/shap/explainers/tf_utils.py index 6c66b1f11..50be69421 100644 --- a/shap/explainers/tf_utils.py +++ b/shap/explainers/tf_utils.py @@ -4,14 +4,13 @@ def _import_tf(): - """ Tries to import tensorflow. - """ + """Tries to import tensorflow.""" global tf if tf is None: import tensorflow as tf def _get_session(session): - """ Common utility to get the session for the tensorflow-based explainer. + """Common utility to get the session for the tensorflow-based explainer. Parameters ---------- @@ -22,6 +21,7 @@ def _get_session(session): session : tf.compat.v1.Session An optional existing session. + """ _import_tf() # if we are not given a session find a default session @@ -33,13 +33,14 @@ def _get_session(session): return tf.get_default_session() if session is None else session def _get_graph(explainer): - """ Common utility to get the graph for the tensorflow-based explainer. + """Common utility to get the graph for the tensorflow-based explainer. Parameters ---------- explainer : Explainer One of the tensorflow-based explainers. + """ _import_tf() if not tf.executing_eagerly(): @@ -50,13 +51,14 @@ def _get_graph(explainer): return graph def _get_model_inputs(model): - """ Common utility to determine the model inputs. + """Common utility to determine the model inputs. Parameters ---------- model : Tensorflow Keras model or tuple The tensorflow model or tuple. + """ _import_tf() if str(type(model)).endswith("keras.engine.sequential.Sequential'>") or \ @@ -72,13 +74,14 @@ def _get_model_inputs(model): def _get_model_output(model): - """ Common utility to determine the model output. + """Common utility to determine the model output. Parameters ---------- model : Tensorflow Keras model or tuple The tensorflow model or tuple. + """ _import_tf() if str(type(model)).endswith("keras.engine.sequential.Sequential'>") or \ diff --git a/shap/links.py b/shap/links.py index 3325f6988..2b2fd7295 100644 --- a/shap/links.py +++ b/shap/links.py @@ -4,8 +4,7 @@ @numba.njit def identity(x): - """ A no-op link function. - """ + """A no-op link function.""" return x @numba.njit def _identity_inverse(x): @@ -14,8 +13,7 @@ def _identity_inverse(x): @numba.njit def logit(x): - """ A logit link function useful for going from probability units to log-odds units. - """ + """A logit link function useful for going from probability units to log-odds units.""" return np.log(x/(1-x)) @numba.njit def _logit_inverse(x): diff --git a/shap/maskers/_composite.py b/shap/maskers/_composite.py index 1f9b132e1..a73a6bda8 100644 --- a/shap/maskers/_composite.py +++ b/shap/maskers/_composite.py @@ -5,7 +5,7 @@ class Composite(Masker): - """ This merges several maskers for different inputs together into a single composite masker. + """This merges several maskers for different inputs together into a single composite masker. This is not yet implemented. """ @@ -40,8 +40,7 @@ def __init__(self, *maskers): self.clustering = types.MethodType(joint_clustering, self) def shape(self, *args): - """ Compute the shape of this masker as the sum of all the sub masker shapes. - """ + """Compute the shape of this masker as the sum of all the sub masker shapes.""" assert len(args) == self.total_args, "The number of passed args is incorrect!" rows = None @@ -61,8 +60,7 @@ def shape(self, *args): return rows, cols def mask_shapes(self, *args): - """ The shape of the masks we expect. - """ + """The shape of the masks we expect.""" out = [] pos = 0 for i, masker in enumerate(self.maskers): @@ -70,8 +68,7 @@ def mask_shapes(self, *args): return out def data_transform(self, *args): - """ Transform the argument - """ + """Transform the argument""" arg_pos = 0 out = [] for i, masker in enumerate(self.maskers): @@ -125,9 +122,7 @@ def __call__(self, mask, *args): return tuple(masked) def joint_clustering(self, *args): - """ Return a joint clustering that merges the clusterings of all the submaskers. - """ - + """Return a joint clustering that merges the clusterings of all the submaskers.""" single_clustering = [] arg_pos = 0 for i, masker in enumerate(self.maskers): diff --git a/shap/maskers/_fixed.py b/shap/maskers/_fixed.py index e00bc3075..5fd50e7e2 100644 --- a/shap/maskers/_fixed.py +++ b/shap/maskers/_fixed.py @@ -4,7 +4,7 @@ class Fixed(Masker): - """ This leaves the input unchanged during masking, and is used for things like scoring labels. + """This leaves the input unchanged during masking, and is used for things like scoring labels. Sometimes there are inputs to the model that we do not want to explain, but rather we want to consider them fixed. The primary example of this is when we explain the loss of the model using @@ -13,6 +13,7 @@ class Fixed(Masker): the model's input features. This is where the Fixed masker can help, since we can apply it to the label inputs. """ + def __init__(self): self.shape = (None, 0) self.clustering = np.zeros((0, 4)) @@ -21,6 +22,5 @@ def __call__(self, mask, x): return ([x],) def mask_shapes(self, x): - """ The shape of the masks we expect. - """ + """The shape of the masks we expect.""" return [(0,)] diff --git a/shap/maskers/_fixed_composite.py b/shap/maskers/_fixed_composite.py index 861a945eb..8272f7b71 100644 --- a/shap/maskers/_fixed_composite.py +++ b/shap/maskers/_fixed_composite.py @@ -5,11 +5,10 @@ class FixedComposite(Masker): - """ A masker that outputs both the masked data and the original data as a pair. - """ + """A masker that outputs both the masked data and the original data as a pair.""" def __init__(self, masker): - """ Creates a Composite masker from an underlying masker and returns the original args along with the masked output. + """Creates a Composite masker from an underlying masker and returns the original args along with the masked output. Parameters ---------- @@ -20,6 +19,7 @@ def __init__(self, masker): ------- tuple A tuple consisting of the masked input using the underlying masker appended with the original args in a list. + """ self.masker = masker @@ -30,8 +30,7 @@ def __init__(self, masker): setattr(self, masker_attribute, getattr(self.masker, masker_attribute)) def __call__(self, mask, *args): - """ Computes mask on the args using the masker data attribute and returns tuple containing masked input with args. - """ + """Computes mask on the args using the masker data attribute and returns tuple containing masked input with args.""" masked_X = self.masker(mask, *args) wrapped_args = [] for item in args: @@ -42,8 +41,7 @@ def __call__(self, mask, *args): return masked_X + wrapped_args def save(self, out_file): - """ Write a FixedComposite masker to a file stream. - """ + """Write a FixedComposite masker to a file stream.""" super().save(out_file) # Increment the version number when the encoding changes! @@ -52,8 +50,7 @@ def save(self, out_file): @classmethod def load(cls, in_file, instantiate=True): - """ Load a FixedComposite masker from a file stream. - """ + """Load a FixedComposite masker from a file stream.""" if instantiate: return cls._instantiated_load(in_file) diff --git a/shap/maskers/_image.py b/shap/maskers/_image.py index 12e7b462a..1875ef97e 100644 --- a/shap/maskers/_image.py +++ b/shap/maskers/_image.py @@ -109,8 +109,7 @@ def __call__(self, mask, x): return (out.reshape(1, *in_shape),) def inpaint(self, x, mask, method): - """ Fill in the masked parts of the image through inpainting. - """ + """Fill in the masked parts of the image through inpainting.""" reshaped_mask = mask.reshape(self.input_shape).astype(np.uint8).max(2) if reshaped_mask.sum() == np.prod(self.input_shape[:-1]): out = x.reshape(self.input_shape).copy() @@ -125,9 +124,7 @@ def inpaint(self, x, mask, method): ).astype(x.dtype).ravel() def build_partition_tree(self): - """ This partitions an image into a herarchical clustering based on axis-aligned splits. - """ - + """This partitions an image into a herarchical clustering based on axis-aligned splits.""" xmin = 0 xmax = self.input_shape[0] ymin = 0 @@ -144,8 +141,7 @@ def build_partition_tree(self): self.clustering = clustering def save(self, out_file): - """ Write a Image masker to a file stream. - """ + """Write a Image masker to a file stream.""" super().save(out_file) # Increment the version number when the encoding changes! @@ -155,8 +151,7 @@ def save(self, out_file): @classmethod def load(cls, in_file, instantiate=True): - """ Load a Image masker from a file stream. - """ + """Load a Image masker from a file stream.""" if instantiate: return cls._instantiated_load(in_file) @@ -168,9 +163,7 @@ def load(cls, in_file, instantiate=True): @njit def _jit_build_partition_tree(xmin, xmax, ymin, ymax, zmin, zmax, total_ywidth, total_zwidth, M, clustering, q): - """ This partitions an image into a herarchical clustering based on axis-aligned splits. - """ - + """This partitions an image into a herarchical clustering based on axis-aligned splits.""" # heapq.heappush(q, (0, xmin, xmax, ymin, ymax, zmin, zmax, -1, False)) # q.put((0, xmin, xmax, ymin, ymax, zmin, zmax, -1, False)) diff --git a/shap/maskers/_masker.py b/shap/maskers/_masker.py index 895a7f1a0..f83f37cab 100644 --- a/shap/maskers/_masker.py +++ b/shap/maskers/_masker.py @@ -4,16 +4,13 @@ class Masker(Serializable): - """ This is the superclass of all maskers. - """ + """This is the superclass of all maskers.""" def __call__(self, mask, *args): - """ Maskers are callable objects that accept the same inputs as the model plus a binary mask. - """ + """Maskers are callable objects that accept the same inputs as the model plus a binary mask.""" def _standardize_mask(self, mask, *args): - """ This allows users to pass True/False as short hand masks. - """ + """This allows users to pass True/False as short hand masks.""" if mask is True or mask is False: if callable(self.shape): shape = self.shape(*args) diff --git a/shap/maskers/_output_composite.py b/shap/maskers/_output_composite.py index 1217aa9ab..ae2e94bcd 100644 --- a/shap/maskers/_output_composite.py +++ b/shap/maskers/_output_composite.py @@ -3,11 +3,10 @@ class OutputComposite(Masker): - """ A masker that is a combination of a masker and a model and outputs both masked args and the model's output. - """ + """A masker that is a combination of a masker and a model and outputs both masked args and the model's output.""" def __init__(self, masker, model): - """ Creates a masker from an underlying masker and and model. + """Creates a masker from an underlying masker and and model. This masker returns the masked input along with the model output for the passed args. @@ -23,6 +22,7 @@ def __init__(self, masker, model): ------- tuple A tuple consisting of the masked input using the underlying masker appended with the model output for passed args. + """ self.masker = masker self.model = model @@ -34,8 +34,7 @@ def __init__(self, masker, model): setattr(self, masker_attribute, getattr(self.masker, masker_attribute)) def __call__(self, mask, *args): - """ Mask the args using the masker and return a tuple containing the masked input and the model output on the args. - """ + """Mask the args using the masker and return a tuple containing the masked input and the model output on the args.""" masked_X = self.masker(mask, *args) y = self.model(*args) # wrap model output @@ -47,8 +46,7 @@ def __call__(self, mask, *args): return masked_X + y def save(self, out_file): - """ Write a OutputComposite masker to a file stream. - """ + """Write a OutputComposite masker to a file stream.""" super().save(out_file) # Increment the version number when the encoding changes! @@ -58,8 +56,7 @@ def save(self, out_file): @classmethod def load(cls, in_file, instantiate=True): - """ Load a OutputComposite masker from a file stream. - """ + """Load a OutputComposite masker from a file stream.""" if instantiate: return cls._instantiated_load(in_file) diff --git a/shap/maskers/_tabular.py b/shap/maskers/_tabular.py index 2cb971b79..5842eb40b 100644 --- a/shap/maskers/_tabular.py +++ b/shap/maskers/_tabular.py @@ -14,11 +14,10 @@ class Tabular(Masker): - """ A common base class for Independent and Partition. - """ + """A common base class for Independent and Partition.""" def __init__(self, data, max_samples=100, clustering=None): - """ This masks out tabular features by integrating over the given background dataset. + """This masks out tabular features by integrating over the given background dataset. Parameters ---------- @@ -41,8 +40,8 @@ def __init__(self, data, max_samples=100, clustering=None): `matching`, `minkowski`, `rogerstanimoto`, `russellrao`, `seuclidean`, `sokalmichener`, `sokalsneath`, `sqeuclidean`, `yule`. These are all the options from scipy.spatial.distance.pdist's metric argument. - """ + """ self.output_dataframe = False if isinstance(data, pd.DataFrame): self.feature_names = data.columns @@ -136,12 +135,11 @@ def __call__(self, mask, x): def invariants(self, x): - """ This returns a mask of which features change when we mask them. + """This returns a mask of which features change when we mask them. This optional masking method allows explainers to avoid re-evaluating the model when the features that would have been masked are all invariant. """ - # make sure we got valid data if x.shape != self.data.shape[1:]: raise DimensionError( @@ -152,8 +150,7 @@ def invariants(self, x): return np.isclose(x, self.data) def save(self, out_file): - """ Write a Tabular masker to a file stream. - """ + """Write a Tabular masker to a file stream.""" super().save(out_file) # Increment the version number when the encoding changes! @@ -172,8 +169,7 @@ def save(self, out_file): @classmethod def load(cls, in_file, instantiate=True): - """ Load a Tabular masker from a file stream. - """ + """Load a Tabular masker from a file stream.""" if instantiate: return cls._instantiated_load(in_file) @@ -198,11 +194,10 @@ def _single_delta_mask(dind, masked_inputs, last_mask, data, x, noop_code): @njit def _delta_masking(masks, x, curr_delta_inds, varying_rows_out, masked_inputs_tmp, last_mask, data, variants, masked_inputs_out, noop_code): - """ Implements the special (high speed) delta masking API that only flips the positions we need to. + """Implements the special (high speed) delta masking API that only flips the positions we need to. Note that we attempt to avoid doing any allocation inside this function for speed reasons. """ - dpos = 0 i = -1 masks_pos = 0 @@ -243,11 +238,10 @@ def _delta_masking(masks, x, curr_delta_inds, varying_rows_out, class Independent(Tabular): - """ This masks out tabular features by integrating over the given background dataset. - """ + """This masks out tabular features by integrating over the given background dataset.""" def __init__(self, data, max_samples=100): - """ Build a Independent masker with the given background data. + """Build a Independent masker with the given background data. Parameters ---------- @@ -260,18 +254,19 @@ def __init__(self, data, max_samples=100): samples coming out of the masker (to be integrated over) matches the number of samples in the background dataset. This means larger background dataset cause longer runtimes. Normally about 1, 10, 100, or 1000 background samples are reasonable choices. + """ super().__init__(data, max_samples=max_samples, clustering=None) class Partition(Tabular): - """ This masks out tabular features by integrating over the given background dataset. + """This masks out tabular features by integrating over the given background dataset. Unlike Independent, Partition respects a hierarchical structure of the data. """ def __init__(self, data, max_samples=100, clustering="correlation"): - """ Build a Partition masker with the given background data and clustering. + """Build a Partition masker with the given background data and clustering. Parameters ---------- @@ -295,23 +290,25 @@ def __init__(self, data, max_samples=100, clustering="correlation"): `sokalmichener`, `sokalsneath`, `sqeuclidean`, `yule`. These are all the options from scipy.spatial.distance.pdist's metric argument. If an array, then this is assumed to be the clustering of the features. + """ super().__init__(data, max_samples=max_samples, clustering=clustering) class Impute(Masker): # we should inherit from Tabular once we add support for arbitrary masking - """ This imputes the values of missing features using the values of the observed features. + """This imputes the values of missing features using the values of the observed features. Unlike Independent, Gaussian imputes missing values based on correlations with observed data points. """ def __init__(self, data, method="linear"): - """ Build a Partition masker with the given background data and clustering. + """Build a Partition masker with the given background data and clustering. Parameters ---------- data : numpy.ndarray, pandas.DataFrame or {"mean: numpy.ndarray, "cov": numpy.ndarray} dictionary The background dataset that is used for masking. + """ if data is dict and "mean" in data: self.mean = data.get("mean", None) diff --git a/shap/maskers/_text.py b/shap/maskers/_text.py index 1f4ebec1c..6e3fc9bab 100644 --- a/shap/maskers/_text.py +++ b/shap/maskers/_text.py @@ -14,15 +14,16 @@ class Text(Masker): - """ This masks out tokens according to the given tokenizer. + """This masks out tokens according to the given tokenizer. The masked variables are output_type : "string" (default) or "token_ids" """ + def __init__(self, tokenizer=None, mask_token=None, collapse_mask_token="auto", output_type="string"): - """ Build a new Text masker given an optional passed tokenizer. + """Build a new Text masker given an optional passed tokenizer. Parameters ---------- @@ -40,8 +41,8 @@ def __init__(self, tokenizer=None, mask_token=None, collapse_mask_token="auto", collapse_mask_token : True, False, or "auto" If True, when several consecutive tokens are masked only one mask token is used to replace the entire series of original tokens. - """ + """ if tokenizer is None: self.tokenizer = SimpleTokenizer() elif callable(tokenizer): @@ -166,14 +167,11 @@ def __call__(self, mask, s): return (np.array([out]),) def data_transform(self, s): - """ Called by explainers to allow us to convert data to better match masking (here this means tokenizing). - """ + """Called by explainers to allow us to convert data to better match masking (here this means tokenizing).""" return (self.token_segments(s)[0],) def token_segments(self, s): - """ Returns the substrings associated with each token in the given string. - """ - + """Returns the substrings associated with each token in the given string.""" try: token_data = self.tokenizer(s, return_offsets_mapping=True) offsets = token_data["offset_mapping"] @@ -210,8 +208,7 @@ def token_segments(self, s): return tokens, token_ids def clustering(self, s): - """ Compute the clustering of tokens for the given string. - """ + """Compute the clustering of tokens for the given string.""" self._update_s_cache(s) special_tokens = [] sep_token = getattr_silent(self.tokenizer, "sep_token") @@ -284,7 +281,7 @@ def _update_s_cache(self, s): self._segments_s = np.array(tokens) def shape(self, s): - """ The shape of what we return as a masker. + """The shape of what we return as a masker. Note we only return a single sample, so there is no expectation averaging. """ @@ -292,14 +289,12 @@ def shape(self, s): return (1, len(self._tokenized_s)) def mask_shapes(self, s): - """ The shape of the masks we expect. - """ + """The shape of the masks we expect.""" self._update_s_cache(s) return [(len(self._tokenized_s),)] def invariants(self, s): - """ The names of the features for each mask position for the given input string. - """ + """The names of the features for each mask position for the given input string.""" self._update_s_cache(s) invariants = np.zeros(len(self._tokenized_s), dtype=bool) @@ -314,14 +309,12 @@ def invariants(self, s): return invariants.reshape(1, -1) def feature_names(self, s): - """ The names of the features for each mask position for the given input string. - """ + """The names of the features for each mask position for the given input string.""" self._update_s_cache(s) return [[v.strip() for v in self._segments_s]] def save(self, out_file): - """ Save a Text masker to a file stream. - """ + """Save a Text masker to a file stream.""" super().save(out_file) with Serializer(out_file, "shap.maskers.Text", version=0) as s: s.save("tokenizer", self.tokenizer) @@ -331,8 +324,7 @@ def save(self, out_file): @classmethod def load(cls, in_file, instantiate=True): - """ Load a Text masker from a file stream. - """ + """Load a Text masker from a file stream.""" if instantiate: return cls._instantiated_load(in_file) @@ -346,16 +338,14 @@ def load(cls, in_file, instantiate=True): class SimpleTokenizer: - """ A basic model agnostic tokenizer. - """ + """A basic model agnostic tokenizer.""" + def __init__(self, split_pattern=r"\W+"): - """ Create a tokenizer based on a simple splitting pattern. - """ + """Create a tokenizer based on a simple splitting pattern.""" self.split_pattern = re.compile(split_pattern) def __call__(self, s, return_offsets_mapping=True): - """ Tokenize the passed string, optionally returning the offsets of each token in the original string. - """ + """Tokenize the passed string, optionally returning the offsets of each token in the original string.""" pos = 0 offset_ranges = [] input_ids = [] @@ -376,8 +366,7 @@ def __call__(self, s, return_offsets_mapping=True): def post_process_sentencepiece_tokenizer_output(s): - """ replaces whitespace encoded as '_' with ' ' for sentencepiece tokenizers. - """ + """Replaces whitespace encoded as '_' with ' ' for sentencepiece tokenizers.""" s = s.replace('▁', ' ') return s @@ -391,8 +380,8 @@ def post_process_sentencepiece_tokenizer_output(s): connectors = ["but", "and", "or"] class Token: - """ A token representation used for token clustering. - """ + """A token representation used for token clustering.""" + def __init__(self, value): self.s = value if value in openers or value in closers: @@ -409,8 +398,8 @@ def __repr__(self): return self.s class TokenGroup: - """ A token group (substring) representation used for token clustering. - """ + """A token group (substring) representation used for token clustering.""" + def __init__(self, group, index=None): self.g = group self.index = index @@ -428,7 +417,7 @@ def __len__(self): return len(self.g) def merge_score(group1, group2, special_tokens): - """ Compute the score of merging two token groups. + """Compute the score of merging two token groups. special_tokens: tokens (such as separator tokens) that should be grouped last """ @@ -489,8 +478,7 @@ def merge_score(group1, group2, special_tokens): return score def merge_closest_groups(groups, special_tokens): - """ Finds the two token groups with the best merge score and merges them. - """ + """Finds the two token groups with the best merge score and merges them.""" scores = [merge_score(groups[i], groups[i+1], special_tokens) for i in range(len(groups)-1)] #print(scores) ind = np.argmax(scores) @@ -504,7 +492,7 @@ def merge_closest_groups(groups, special_tokens): groups.pop(ind+1) def partition_tree(decoded_tokens, special_tokens): - """ Build a heriarchial clustering of tokens that align with sentence structure. + """Build a heriarchial clustering of tokens that align with sentence structure. Note that this is fast and heuristic right now. TODO: Build this using a real constituency parser. diff --git a/shap/models/_model.py b/shap/models/_model.py index 2b1380c8c..d1fe94d2b 100644 --- a/shap/models/_model.py +++ b/shap/models/_model.py @@ -5,12 +5,10 @@ class Model(Serializable): - """ This is the superclass of all models. - """ + """This is the superclass of all models.""" def __init__(self, model=None): - """ Wrap a callable model as a SHAP Model object. - """ + """Wrap a callable model as a SHAP Model object.""" if isinstance(model, Model): self.inner_model = model.inner_model else: @@ -26,8 +24,7 @@ def __call__(self, *args): return out def save(self, out_file): - """ Save the model to the given file stream. - """ + """Save the model to the given file stream.""" super().save(out_file) with Serializer(out_file, "shap.Model", version=0) as s: s.save("model", self.inner_model) diff --git a/shap/models/_teacher_forcing.py b/shap/models/_teacher_forcing.py index 790ca4674..83315497e 100644 --- a/shap/models/_teacher_forcing.py +++ b/shap/models/_teacher_forcing.py @@ -9,7 +9,7 @@ class TeacherForcing(Model): - """ Generates scores (log odds) for output text explanation algorithms using Teacher Forcing technique. + """Generates scores (log odds) for output text explanation algorithms using Teacher Forcing technique. This class supports generation of log odds for transformer models as well as functions. In model agnostic cases (model is function) it expects a similarity_model and similarity_tokenizer to approximate log odd scores @@ -17,7 +17,7 @@ class TeacherForcing(Model): """ def __init__(self, model, tokenizer=None, similarity_model=None, similarity_tokenizer=None, batch_size=128, device=None): - """ Build a teacher forcing model from the given text generation model. + """Build a teacher forcing model from the given text generation model. Parameters ---------- @@ -43,6 +43,7 @@ def __init__(self, model, tokenizer=None, similarity_model=None, similarity_toke ------- numpy.ndarray The scores (log odds) of generating target sentence ids using the model. + """ super().__init__(model) @@ -89,7 +90,7 @@ def __init__(self, model, tokenizer=None, similarity_model=None, similarity_toke self.similarity_model_type = "tf" def __call__(self, X, Y): - """ Computes log odds scores of generating output(text) for a given batch of input(text/image) . + """Computes log odds scores of generating output(text) for a given batch of input(text/image) . Parameters ---------- @@ -103,6 +104,7 @@ def __call__(self, X, Y): ------- numpy.ndarray A numpy array of log odds scores for every input pair (masked_X, X) + """ output_batch = None # caching updates output names and target sentence ids @@ -121,7 +123,7 @@ def __call__(self, X, Y): return output_batch def update_output_names(self, output): - """ The function updates output tokens. + """The function updates output tokens. It mimics the caching mechanism to update the output tokens for every new row of explanation that are to be explained. @@ -130,6 +132,7 @@ def update_output_names(self, output): ---------- output: numpy.ndarray Output(sentence/sentence ids) for an explanation row. + """ # check if the target sentence has been updated (occurs when explaining a new row) if (self.output is None) or (not np.array_equal(self.output, output)): @@ -137,7 +140,7 @@ def update_output_names(self, output): self.output_names = self.get_output_names(output) def get_output_names(self, output): - """ Gets the output tokens by computing the output sentence ids and output names using the similarity_tokenizer. + """Gets the output tokens by computing the output sentence ids and output names using the similarity_tokenizer. Parameters ---------- @@ -148,13 +151,14 @@ def get_output_names(self, output): ------- list A list of output tokens. + """ output_ids = self.get_outputs(output) output_names = [self.similarity_tokenizer.decode([x]).strip() for x in output_ids[0, :]] return output_names def get_outputs(self, X): - """ The function tokenizes output sentences and returns ids. + """The function tokenizes output sentences and returns ids. Parameters ---------- @@ -165,6 +169,7 @@ def get_outputs(self, X): ------- numpy.ndarray An array of output(target sentence) ids. + """ # check if output is a sentence or already parsed target ids if X.dtype.type is np.str_: @@ -179,7 +184,7 @@ def get_outputs(self, X): return output_ids def get_inputs(self, X, padding_side='right'): - """ The function tokenizes source sentences. + """The function tokenizes source sentences. In model agnostic case, the function calls model(X) which is expected to return a batch of output sentences which is tokenized to compute inputs. @@ -193,6 +198,7 @@ def get_inputs(self, X, padding_side='right'): ------- dict Dictionary of padded source sentence ids and attention mask as tensors("pt" or "tf" based on similarity_model_type). + """ if self.model_agnostic: # In model agnostic case, we first pass the input through the model and then tokenize output sentence @@ -208,7 +214,7 @@ def get_inputs(self, X, padding_side='right'): return inputs def get_logodds(self, logits): - """ Calculates log odds from logits. + """Calculates log odds from logits. This function passes the logits through softmax and then computes log odds for the output(target sentence) ids. @@ -221,6 +227,7 @@ def get_logodds(self, logits): ------- numpy.ndarray Computes log odds for corresponding output ids. + """ # set output ids for which scores are to be extracted if self.output.dtype.type is np.str_: @@ -239,7 +246,7 @@ def calc_logodds(arr): return logodds_for_output_ids def model_inference(self, inputs, output_ids): - """ This function performs model inference for tensorflow and pytorch models. + """This function performs model inference for tensorflow and pytorch models. Parameters ---------- @@ -253,6 +260,7 @@ def model_inference(self, inputs, output_ids): ------- numpy.ndarray Returns output logits from the model. + """ if self.similarity_model_type == "pt": import torch @@ -307,7 +315,7 @@ def model_inference(self, inputs, output_ids): return logits def get_teacher_forced_logits(self, X, Y): - """ The function generates logits for transformer models. + """The function generates logits for transformer models. It generates logits for encoder-decoder models as well as decoder only models by using the teacher forcing technique. @@ -323,6 +331,7 @@ def get_teacher_forced_logits(self, X, Y): ------- numpy.ndarray Decoder output logits for output(target sentence) ids. + """ # check if type of model architecture assigned in model config if (hasattr(self.similarity_model.config, "is_encoder_decoder") and not self.similarity_model.config.is_encoder_decoder) \ diff --git a/shap/models/_text_generation.py b/shap/models/_text_generation.py index 5da41f1a0..6960becae 100644 --- a/shap/models/_text_generation.py +++ b/shap/models/_text_generation.py @@ -6,13 +6,13 @@ class TextGeneration(Model): - """ Generates target sentence/ids using a base model. + """Generates target sentence/ids using a base model. It generates target sentence/ids for a model (a pretrained transformer model or a function). """ def __init__(self, model=None, tokenizer=None, target_sentences=None, device=None): - """ Create a text generator model from a pretrained transformer model or a function. + """Create a text generator model from a pretrained transformer model or a function. For a pretrained transformer model, a tokenizer should be passed. @@ -34,6 +34,7 @@ def __init__(self, model=None, tokenizer=None, target_sentences=None, device=Non ------- numpy.ndarray Array of target sentence/ids. + """ super().__init__(model) @@ -60,7 +61,7 @@ def __init__(self, model=None, tokenizer=None, target_sentences=None, device=Non self.target_X = None def __call__(self, X): - """ Generates target sentence/ids from X. + """Generates target sentence/ids from X. Parameters ---------- @@ -71,6 +72,7 @@ def __call__(self, X): ------- numpy.ndarray Array of target sentence/ids. + """ if (self.X is None) or (isinstance(self.X, np.ndarray) and not np.array_equal(self.X, X)) or \ (isinstance(self.X, str) and (self.X != X)): @@ -88,7 +90,7 @@ def __call__(self, X): return np.array(self.target_X) def get_inputs(self, X, padding_side='right'): - """ The function tokenizes source sentences. + """The function tokenizes source sentences. In model agnostic case, the function calls model(X) which is expected to return a batch of output sentences which is tokenized to compute inputs. @@ -102,6 +104,7 @@ def get_inputs(self, X, padding_side='right'): ------- dict Dictionary of padded source sentence ids and attention mask as tensors("pt" or "tf" based on model_type). + """ # set tokenizer padding to prepare inputs for batch inferencing # padding_side="left" for only decoder models text generation eg. GPT2 @@ -112,7 +115,7 @@ def get_inputs(self, X, padding_side='right'): return inputs def model_generate(self, X): - """ This function performs text generation for tensorflow and pytorch models. + """This function performs text generation for tensorflow and pytorch models. Parameters ---------- @@ -123,6 +126,7 @@ def model_generate(self, X): ------- numpy.ndarray Returns target sentence ids. + """ if (hasattr(self.inner_model.config, "is_encoder_decoder") and not self.inner_model.config.is_encoder_decoder) \ and (hasattr(self.inner_model.config, "is_decoder") and not self.inner_model.config.is_decoder): @@ -189,7 +193,7 @@ def model_generate(self, X): return target_X def parse_prefix_suffix_for_model_generate_output(self, output): - """ Calculates if special tokens are present in the beginning/end of the model generated output. + """Calculates if special tokens are present in the beginning/end of the model generated output. Parameters ---------- @@ -200,6 +204,7 @@ def parse_prefix_suffix_for_model_generate_output(self, output): ------- dict Dictionary of prefix and suffix lengths concerning special tokens in output ids. + """ keep_prefix, keep_suffix = 0, 0 if self.tokenizer.convert_ids_to_tokens(output[0]) in self.tokenizer.special_tokens_map.values(): diff --git a/shap/models/_topk_lm.py b/shap/models/_topk_lm.py index 12281bf6b..da59cabf6 100644 --- a/shap/models/_topk_lm.py +++ b/shap/models/_topk_lm.py @@ -8,11 +8,10 @@ class TopKLM(Model): - """ Generates scores (log odds) for the top-k tokens for Causal/Masked LM. - """ + """Generates scores (log odds) for the top-k tokens for Causal/Masked LM.""" def __init__(self, model, tokenizer, k=10, generate_topk_token_ids=None, batch_size=128, device=None): - """ Take Causal/Masked LM model and tokenizer and build a log odds output model for the top-k tokens. + """Take Causal/Masked LM model and tokenizer and build a log odds output model for the top-k tokens. Parameters ---------- @@ -35,6 +34,7 @@ def __init__(self, model, tokenizer, k=10, generate_topk_token_ids=None, batch_s ------- numpy.ndarray The scores (log odds) of generating top-k token ids using the model. + """ super().__init__(model) @@ -62,7 +62,7 @@ def __init__(self, model, tokenizer, k=10, generate_topk_token_ids=None, batch_s def __call__(self, masked_X, X): - """ Computes log odds scores for a given batch of masked inputs for the top-k tokens for Causal/Masked LM. + """Computes log odds scores for a given batch of masked inputs for the top-k tokens for Causal/Masked LM. Parameters ---------- @@ -76,6 +76,7 @@ def __call__(self, masked_X, X): ------- numpy.ndarray A numpy array of log odds scores for top-k tokens for every input pair (masked_X, X) + """ output_batch = None self.update_cache_X(X[:1]) @@ -91,7 +92,7 @@ def __call__(self, masked_X, X): return output_batch def update_cache_X(self, X): - """ The function updates original input(X) and top-k token ids for the Causal/Masked LM. + """The function updates original input(X) and top-k token ids for the Causal/Masked LM. It mimics the caching mechanism to update the original input and topk token ids that are to be explained and which updates for every new row of explanation. @@ -100,6 +101,7 @@ def update_cache_X(self, X): ---------- X: np.ndarray Input(Text) for an explanation row. + """ # check if the source sentence has been updated (occurs when explaining a new row) if (self.X is None) or (not np.array_equal(self.X, X)): @@ -107,7 +109,7 @@ def update_cache_X(self, X): self.output_names = self.get_output_names_and_update_topk_token_ids(self.X) def get_output_names_and_update_topk_token_ids(self, X): - """ Gets the token names for top-k token ids for Causal/Masked LM. + """Gets the token names for top-k token ids for Causal/Masked LM. Parameters ---------- @@ -118,8 +120,8 @@ def get_output_names_and_update_topk_token_ids(self, X): ------- list A list of output tokens. - """ + """ # see if the user gave a custom token generator if self._custom_generate_topk_token_ids is not None: return self._custom_generate_topk_token_ids(X) @@ -130,7 +132,7 @@ def get_output_names_and_update_topk_token_ids(self, X): return output_names def get_logodds(self, logits): - """ Calculates log odds from logits. + """Calculates log odds from logits. This function passes the logits through softmax and then computes log odds for the top-k token ids. @@ -143,6 +145,7 @@ def get_logodds(self, logits): ------- numpy.ndarray Computes log odds for corresponding top-k token ids. + """ # pass logits through softmax, get the token corresponding score and convert back to log odds (as one vs all) def calc_logodds(arr): @@ -156,7 +159,7 @@ def calc_logodds(arr): return logodds_for_topk_token_ids def get_inputs(self, X, padding_side='right'): - """ The function tokenizes source sentence. + """The function tokenizes source sentence. Parameters ---------- @@ -167,6 +170,7 @@ def get_inputs(self, X, padding_side='right'): ------- dict Dictionary of padded source sentence ids and attention mask as tensors("pt" or "tf" based on similarity_model_type). + """ self.tokenizer.padding_side = padding_side inputs = self.tokenizer(X.tolist(), return_tensors=self.model_type, padding=True) @@ -175,7 +179,7 @@ def get_inputs(self, X, padding_side='right'): return inputs def generate_topk_token_ids(self, X): - """ Generates top-k token ids for Causal/Masked LM. + """Generates top-k token ids for Causal/Masked LM. Parameters ---------- @@ -186,13 +190,14 @@ def generate_topk_token_ids(self, X): ------- np.ndarray An array of top-k token ids. + """ logits = self.get_lm_logits(X) topk_tokens_ids = (-logits).argsort()[0, :self.k] return topk_tokens_ids def get_lm_logits(self, X): - """ Evaluates a Causal/Masked LM model and returns logits corresponding to next word/masked word. + """Evaluates a Causal/Masked LM model and returns logits corresponding to next word/masked word. Parameters ---------- @@ -203,6 +208,7 @@ def get_lm_logits(self, X): ------- numpy.ndarray Logits corresponding to next word/masked word. + """ if safe_isinstance(self.inner_model, MODELS_FOR_CAUSAL_LM): inputs = self.get_inputs(X, padding_side="left") diff --git a/shap/models/_transformers_pipeline.py b/shap/models/_transformers_pipeline.py index 5b0c85979..87d80dff0 100644 --- a/shap/models/_transformers_pipeline.py +++ b/shap/models/_transformers_pipeline.py @@ -5,7 +5,7 @@ class TransformersPipeline(Model): - """ This wraps a transformers pipeline object for easy explanations. + """This wraps a transformers pipeline object for easy explanations. By default transformers pipeline object output lists of dictionaries, not standard tensors as expected by SHAP. This class wraps pipelines to make them output nice @@ -13,8 +13,7 @@ class TransformersPipeline(Model): """ def __init__(self, pipeline, rescale_to_logits=False): - """ Build a new model by wrapping the given pipeline object. - """ + """Build a new model by wrapping the given pipeline object.""" super().__init__(pipeline) # the pipeline becomes our inner_model self.rescale_to_logits = rescale_to_logits diff --git a/shap/plots/_bar.py b/shap/plots/_bar.py index d35da5813..36e77f2c1 100644 --- a/shap/plots/_bar.py +++ b/shap/plots/_bar.py @@ -66,7 +66,6 @@ def bar(shap_values, max_display=10, order=Explanation.abs, clustering=None, clu See `bar plot examples `_. """ - # assert str(type(shap_values)).endswith("Explanation'>"), "The shap_values parameter must be a shap.Explanation object!" # convert Explanation objects to dictionaries diff --git a/shap/plots/_beeswarm.py b/shap/plots/_beeswarm.py index 559cc115b..64f812bad 100644 --- a/shap/plots/_beeswarm.py +++ b/shap/plots/_beeswarm.py @@ -1,5 +1,4 @@ -""" Summary plots of SHAP values across a whole dataset. -""" +"""Summary plots of SHAP values across a whole dataset.""" import warnings @@ -62,11 +61,9 @@ def beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0), Examples -------- - See `beeswarm plot examples `_. """ - if not isinstance(shap_values, Explanation): emsg = ( "The beeswarm plot requires an `Explanation` object as the " @@ -490,8 +487,8 @@ def summary_legacy(shap_values, features=None, feature_names=None, max_display=N show_values_in_legend: bool Flag to print the mean of the SHAP values in the multi-output bar plot. Set to False by default. - """ + """ # support passing an explanation object if str(type(shap_values)).endswith("Explanation'>"): shap_exp = shap_values diff --git a/shap/plots/_benchmark.py b/shap/plots/_benchmark.py index 7f7b16bd9..aba715482 100644 --- a/shap/plots/_benchmark.py +++ b/shap/plots/_benchmark.py @@ -16,9 +16,7 @@ } def benchmark(benchmark, show=True): - """ Plot a BenchmarkResult or list of such results. - """ - + """Plot a BenchmarkResult or list of such results.""" if hasattr(benchmark, "__iter__"): benchmark = list(benchmark) diff --git a/shap/plots/_decision.py b/shap/plots/_decision.py index 213942421..7ac78da81 100644 --- a/shap/plots/_decision.py +++ b/shap/plots/_decision.py @@ -1,4 +1,4 @@ -""" Visualize cumulative SHAP values.""" +"""Visualize cumulative SHAP values.""" from typing import Union @@ -52,8 +52,7 @@ def __decision_plot_matplotlib( legend_labels, legend_location, ): - """matplotlib rendering for decision_plot()""" - + """Matplotlib rendering for decision_plot()""" # image size row_height = 0.4 if auto_size_plot: @@ -170,8 +169,7 @@ class DecisionPlotResult: """ def __init__(self, base_value, shap_values, feature_names, feature_idx, xlim): - """ - Example + """Example ------- Plot two decision plots using the same feature order and x-axis. >>> range1, range2 = range(20), range(20, 40) @@ -342,7 +340,6 @@ def decision( Examples -------- - Plot two decision plots using the same feature order and x-axis. >>> range1, range2 = range(20), range(20, 40) @@ -352,7 +349,6 @@ def decision( See more `decision plot examples here `_. """ - # code taken from force_plot. auto unwrap the base_value if type(base_value) == np.ndarray and len(base_value) == 1: base_value = base_value[0] @@ -584,8 +580,8 @@ def multioutput_decision(base_values, shap_values, row_index, **kwargs) -> Union ------- DecisionPlotResult or None Returns a DecisionPlotResult object if `return_objects=True`. Returns `None` otherwise (the default). - """ + """ if not (isinstance(base_values, list) and isinstance(shap_values, list)): raise ValueError("The base_values and shap_values args expect lists.") diff --git a/shap/plots/_embedding.py b/shap/plots/_embedding.py index d1f4fc52a..fb1b79709 100644 --- a/shap/plots/_embedding.py +++ b/shap/plots/_embedding.py @@ -7,7 +7,7 @@ def embedding(ind, shap_values, feature_names=None, method="pca", alpha=1.0, show=True): - """ Use the SHAP values as an embedding which we project to 2D for visualization. + """Use the SHAP values as an embedding which we project to 2D for visualization. Parameters ---------- @@ -32,8 +32,8 @@ def embedding(ind, shap_values, feature_names=None, method="pca", alpha=1.0, sho alpha : float The transparency of the data points (between 0 and 1). This can be useful to the show density of the data points when using a large dataset. - """ + """ if feature_names is None: feature_names = [labels['FEATURE'] % str(i) for i in range(shap_values.shape[1])] diff --git a/shap/plots/_force.py b/shap/plots/_force.py index 7c0bf6b57..a0b0b2432 100644 --- a/shap/plots/_force.py +++ b/shap/plots/_force.py @@ -1,5 +1,4 @@ -""" Visualize the SHAP values with additive force style layouts. -""" +"""Visualize the SHAP values with additive force style layouts.""" import base64 import json @@ -93,8 +92,8 @@ def force( Controls the feature names/values that are displayed on force plot. Only features that the magnitude of their shap value is larger than min_perc * (sum of all abs shap values) will be displayed. - """ + """ # support passing an explanation object if str(type(base_value)).endswith("Explanation'>"): shap_exp = base_value @@ -248,9 +247,7 @@ class AdditiveExplanation(Explanation): """Data structure for AdditiveForceVisualizer / AdditiveForceArrayVisualizer.""" def __init__(self, base_value, out_value, effects, effects_var, instance, link, model, data): - """ - - Parameters + """Parameters ---------- base_value : float This is the reference value that the feature contributions start from. @@ -259,6 +256,7 @@ def __init__(self, base_value, out_value, effects, effects_var, instance, link, out_value : float The model prediction value, taken as the sum of the SHAP values across all features and the ``base_value``. + """ self.base_value = base_value self.out_value = out_value @@ -308,7 +306,7 @@ def initjs(): def save_html(out_file, plot, full_html=True): - """ Save html plots to an output file. + """Save html plots to an output file. Parameters ---------- @@ -322,8 +320,8 @@ def save_html(out_file, plot, full_html=True): If ``True``, writes a complete HTML document starting with an ```` tag. If ``False``, only script and div tags are included. - """ + """ if not isinstance(plot, BaseVisualizer): raise TypeError("`save_html` requires a Visualizer returned by `shap.plots.force()`.") @@ -399,6 +397,7 @@ def visualize( ---------- e : AdditiveExplanation Contains the data necessary for additive force plots. + """ plot_cmap = verify_valid_cmap(plot_cmap) @@ -474,15 +473,14 @@ class AdditiveForceVisualizer(BaseVisualizer): """Visualizer for a single Additive Force plot.""" def __init__(self, e, plot_cmap="RdBu"): - """ - - Parameters + """Parameters ---------- e : AdditiveExplanation Contains the data necessary for additive force plots. plot_cmap : str or list[str] Color map to use. It can be a string (defaults to ``RdBu``) or a list of hex color strings. + """ if not isinstance(e, AdditiveExplanation): emsg = "AdditiveForceVisualizer can only visualize AdditiveExplanation objects!" diff --git a/shap/plots/_group_difference.py b/shap/plots/_group_difference.py index b983dc74d..7498cf162 100644 --- a/shap/plots/_group_difference.py +++ b/shap/plots/_group_difference.py @@ -6,7 +6,7 @@ def group_difference(shap_values, group_mask, feature_names=None, xlabel=None, xmin=None, xmax=None, max_display=None, sort=True, show=True, ax=None): - """ This plots the difference in mean SHAP values between two groups. + """This plots the difference in mean SHAP values between two groups. It is useful to decompose many group level metrics about the model output among the input features. Quantitative fairness metrics for machine learning models are @@ -22,8 +22,8 @@ def group_difference(shap_values, group_mask, feature_names=None, xlabel=None, x feature_names : list A list of feature names. - """ + """ # Compute confidence bounds for the group difference value vs = [] gmean = group_mask.mean() diff --git a/shap/plots/_heatmap.py b/shap/plots/_heatmap.py index 7a42b34f6..9e1a7f239 100644 --- a/shap/plots/_heatmap.py +++ b/shap/plots/_heatmap.py @@ -51,11 +51,9 @@ def heatmap(shap_values, instance_order=Explanation.hclust(), feature_values=Exp Examples -------- - See `heatmap plot examples `_. """ - # sort the SHAP values matrix by rows and columns values = shap_values.values if issubclass(type(feature_values), OpChain): diff --git a/shap/plots/_image.py b/shap/plots/_image.py index 841f659f4..c67648e55 100644 --- a/shap/plots/_image.py +++ b/shap/plots/_image.py @@ -64,11 +64,9 @@ def image(shap_values: Explanation or np.ndarray, Examples -------- - See `image plot examples `_. """ - # support passing an explanation object if str(type(shap_values)).endswith("Explanation'>"): shap_exp = shap_values @@ -178,7 +176,7 @@ def image(shap_values: Explanation or np.ndarray, def image_to_text(shap_values): - """ Plots SHAP values for image inputs with test outputs. + """Plots SHAP values for image inputs with test outputs. Parameters ---------- diff --git a/shap/plots/_monitoring.py b/shap/plots/_monitoring.py index 5a7f8d05e..c3951ee30 100644 --- a/shap/plots/_monitoring.py +++ b/shap/plots/_monitoring.py @@ -14,7 +14,7 @@ def truncate_text(text, max_len): return text def monitoring(ind, shap_values, features, feature_names=None, show=True): - """ Create a SHAP monitoring plot. + """Create a SHAP monitoring plot. (Note this function is preliminary and subject to change!!) A SHAP monitoring plot is meant to display the behavior of a model @@ -35,8 +35,8 @@ def monitoring(ind, shap_values, features, feature_names=None, show=True): feature_names : list Names of the features (length # features) - """ + """ if isinstance(features, pd.DataFrame): if feature_names is None: feature_names = features.columns diff --git a/shap/plots/_partial_dependence.py b/shap/plots/_partial_dependence.py index 4935079a6..629b8567e 100644 --- a/shap/plots/_partial_dependence.py +++ b/shap/plots/_partial_dependence.py @@ -8,11 +8,10 @@ def compute_bounds(xmin, xmax, xv): - """ Handles any setting of xmax and xmin. + """Handles any setting of xmax and xmin. Note that we handle None, float, or "percentile(float)" formats. """ - if xmin is not None or xmax is not None: if isinstance(xmin, str) and xmin.startswith("percentile"): xmin = np.nanpercentile(xv, float(xmin[11:-1])) @@ -31,9 +30,7 @@ def partial_dependence(ind, model, data, xmin="percentile(0)", xmax="percentile( feature_expected_value=False, shap_values=None, ylabel=None, ice=True, ace_opacity=1, pd_opacity=1, pd_linewidth=2, ace_linewidth='auto', ax=None, show=True): - """ A basic partial dependence plot function. - """ - + """A basic partial dependence plot function.""" if isinstance(data, Explanation): features = data.data shap_values = data diff --git a/shap/plots/_scatter.py b/shap/plots/_scatter.py index 5b23dc4e5..9d1d24b61 100644 --- a/shap/plots/_scatter.py +++ b/shap/plots/_scatter.py @@ -77,11 +77,9 @@ def scatter(shap_values, color="#1E88E5", hist=True, axis_color="#333333", cmap= Examples -------- - See `scatter plot examples `_. """ - assert str(type(shap_values)).endswith("Explanation'>"), "The shap_values parameter must be a shap.Explanation object!" # see if we are plotting multiple columns @@ -477,7 +475,7 @@ def dependence_legacy(ind, shap_values=None, features=None, feature_names=None, color="#1E88E5", axis_color="#333333", cmap=None, dot_size=16, x_jitter=0, alpha=1, title=None, xmin=None, xmax=None, ax=None, show=True, ymin=None, ymax=None): - """ Create a SHAP dependence plot, colored by an interaction feature. + """Create a SHAP dependence plot, colored by an interaction feature. Plots the value of the feature on the x-axis and the SHAP value of the same feature on the y-axis. This shows how the model depends on the given feature, and is like a @@ -538,7 +536,6 @@ def dependence_legacy(ind, shap_values=None, features=None, feature_names=None, Represents the upper bound of the plot's y-axis. """ - if cmap is None: cmap = colors.red_blue diff --git a/shap/plots/_text.py b/shap/plots/_text.py index 7309436e5..c89ca9aa6 100644 --- a/shap/plots/_text.py +++ b/shap/plots/_text.py @@ -59,14 +59,12 @@ def text(shap_values, num_starting_labels=0, grouping_threshold=0.01, separator= Examples -------- - See `text plot examples `_. """ def values_min_max(values, base_values): - """ Used to pick our axis limits. - """ + """Used to pick our axis limits.""" fx = base_values + values.sum() xmin = fx - values[values > 0].sum() xmax = fx - values[values < 0].sum() @@ -702,12 +700,11 @@ def draw_tick_mark(xval, label=None, bold=False, backing=False): def text_old(shap_values, tokens, partition_tree=None, num_starting_labels=0, grouping_threshold=1, separator=''): - """ Plots an explanation of a string of text using coloring and interactive labels. + """Plots an explanation of a string of text using coloring and interactive labels. The output is interactive HTML and you can click on any token to toggle the display of the SHAP value assigned to that token. """ - # See if we got hierarchical input data. If we did then we need to reprocess the # shap_values and tokens to get the groups we want to display warnings.warn( diff --git a/shap/plots/_utils.py b/shap/plots/_utils.py index dda4dd112..fdc549565 100644 --- a/shap/plots/_utils.py +++ b/shap/plots/_utils.py @@ -30,9 +30,7 @@ def convert_ordering(ordering, shap_values): def get_sort_order(dist, clust_order, cluster_threshold, feature_order): - """ Returns a sorted order of the values where we respect the clustering order when dist[i,j] < cluster_threshold - """ - + """Returns a sorted order of the values where we respect the clustering order when dist[i,j] < cluster_threshold""" #feature_imp = np.abs(values) # if partition_tree is not None: @@ -77,8 +75,7 @@ def get_sort_order(dist, clust_order, cluster_threshold, feature_order): return feature_order def merge_nodes(values, partition_tree): - """ This merges the two clustered leaf nodes with the smallest total value. - """ + """This merges the two clustered leaf nodes with the smallest total value.""" M = partition_tree.shape[0] + 1 ptind = 0 @@ -130,12 +127,11 @@ def merge_nodes(values, partition_tree): return partition_tree_new, ind1, ind2 def dendrogram_coords(leaf_positions, partition_tree): - """ Returns the x and y coords of the lines of a dendrogram where the leaf order is given. + """Returns the x and y coords of the lines of a dendrogram where the leaf order is given. Note that scipy can compute these coords as well, but it does not allow you to easily specify a specific leaf order, hence this reimplementation. """ - xout = [] yout = [] _dendrogram_coords_rec(partition_tree.shape[0]-1, leaf_positions, partition_tree, xout, yout) @@ -161,8 +157,7 @@ def _dendrogram_coords_rec(pos, leaf_positions, partition_tree, xout, yout): return (x_left + x_right) / 2, y_curr def fill_internal_max_values(partition_tree, leaf_values): - """ This fills the forth column of the partition tree matrix with the max leaf value in that cluster. - """ + """This fills the forth column of the partition tree matrix with the max leaf value in that cluster.""" M = partition_tree.shape[0] + 1 new_tree = partition_tree.copy() for i in range(new_tree.shape[0]): @@ -183,8 +178,7 @@ def fill_internal_max_values(partition_tree, leaf_values): return new_tree def fill_counts(partition_tree): - """ This updates the - """ + """This updates the""" M = partition_tree.shape[0] + 1 for i in range(partition_tree.shape[0]): val = 0 diff --git a/shap/plots/_violin.py b/shap/plots/_violin.py index 6c8c8fbc0..97bcf241d 100644 --- a/shap/plots/_violin.py +++ b/shap/plots/_violin.py @@ -1,5 +1,4 @@ -""" Summary plots of SHAP values (violin plot) across a whole dataset. -""" +"""Summary plots of SHAP values (violin plot) across a whole dataset.""" import warnings @@ -64,11 +63,9 @@ def violin(shap_values, features=None, feature_names=None, max_display=None, plo Examples -------- - See `violin plot examples `_. """ - # support passing an explanation object if str(type(shap_values)).endswith("Explanation'>"): shap_exp = shap_values diff --git a/shap/plots/_waterfall.py b/shap/plots/_waterfall.py index 7bf8fd4d6..8a73646ba 100644 --- a/shap/plots/_waterfall.py +++ b/shap/plots/_waterfall.py @@ -40,11 +40,9 @@ def waterfall(shap_values, max_display=10, show=True): Examples -------- - See `waterfall plot examples `_. """ - # Turn off interactive plot if show is False: plt.ioff() @@ -320,7 +318,7 @@ def waterfall(shap_values, max_display=10, show=True): def waterfall_legacy(expected_value, shap_values=None, features=None, feature_names=None, max_display=10, show=True): - """ Plots an explanation of a single prediction as a waterfall plot. + """Plots an explanation of a single prediction as a waterfall plot. The SHAP value of a feature represents the impact of the evidence provided by that feature on the model's output. The waterfall plot is designed to visually display how the SHAP values (evidence) of each feature @@ -351,8 +349,8 @@ def waterfall_legacy(expected_value, shap_values=None, features=None, feature_na show : bool Whether matplotlib.pyplot.show() is called before returning. Setting this to False allows the plot to be customized further after it has been created. - """ + """ # Turn off interactive plot when not calling plt.show if show is False: plt.ioff() diff --git a/shap/plots/colors/_colors.py b/shap/plots/colors/_colors.py index 248f66f5b..3b5ab627b 100644 --- a/shap/plots/colors/_colors.py +++ b/shap/plots/colors/_colors.py @@ -1,5 +1,4 @@ -""" This defines some common colors. -""" +"""This defines some common colors.""" import numpy as np diff --git a/shap/utils/_clustering.py b/shap/utils/_clustering.py index 0d9e65c10..c8ddab081 100644 --- a/shap/utils/_clustering.py +++ b/shap/utils/_clustering.py @@ -17,7 +17,7 @@ def partition_tree(X, metric="correlation"): def partition_tree_shuffle(indexes, index_mask, partition_tree): - """ Randomly shuffle the indexes in a way that is consistent with the given partition tree. + """Randomly shuffle the indexes in a way that is consistent with the given partition tree. Parameters ---------- @@ -27,6 +27,7 @@ def partition_tree_shuffle(indexes, index_mask, partition_tree): A bool mask of which indexes we want to include in the shuffled list. partition_tree: np.array The partition tree we should follow. + """ M = len(index_mask) #switch = np.random.randn(M) < 0 @@ -79,16 +80,14 @@ def _mask_delta_score(m1, m2): def hclust_ordering(X, metric="sqeuclidean", anchor_first=False): - """ A leaf ordering is under-defined, this picks the ordering that keeps nearby samples similar. - """ - + """A leaf ordering is under-defined, this picks the ordering that keeps nearby samples similar.""" # compute a hierarchical clustering and return the optimal leaf ordering D = scipy.spatial.distance.pdist(X, metric) cluster_matrix = scipy.cluster.hierarchy.complete(D) return scipy.cluster.hierarchy.leaves_list(scipy.cluster.hierarchy.optimal_leaf_ordering(cluster_matrix, D)) def xgboost_distances_r2(X, y, learning_rate=0.6, early_stopping_rounds=2, subsample=1, max_estimators=10000, random_state=0): - """ Compute reducancy distances scaled from 0-1 among all the feature in X relative to the label y. + """Compute reducancy distances scaled from 0-1 among all the feature in X relative to the label y. Distances are measured by training univariate XGBoost models of y for all the features, and then predicting the output of these models using univariate XGBoost models of other features. If one @@ -97,7 +96,6 @@ def xgboost_distances_r2(X, y, learning_rate=0.6, early_stopping_rounds=2, subsa to no redundancy while a distance of 0 corresponds to perfect redundancy (measured using the proportion of variance explained). Note these distances are not symmetric. """ - import xgboost # pick our train/text split @@ -169,6 +167,7 @@ def hclust(X, y=None, linkage="single", metric="auto", random_state=0): ------- clustering: np.array The hierarchical clustering encoded as a linkage matrix. + """ if isinstance(X, pd.DataFrame): X = X.values diff --git a/shap/utils/_exceptions.py b/shap/utils/_exceptions.py index 0e8b3d0e4..87e114c46 100644 --- a/shap/utils/_exceptions.py +++ b/shap/utils/_exceptions.py @@ -1,6 +1,5 @@ class DimensionError(Exception): - """ - Used for instances where dimensions are either + """Used for instances where dimensions are either not supported or cause errors. """ @@ -17,9 +16,8 @@ class InvalidMaskerError(ValueError): pass class ExplainerError(Exception): - """ - Generic errors related to Explainers - """ + """Generic errors related to Explainers""" + pass class InvalidAlgorithmError(ValueError): diff --git a/shap/utils/_general.py b/shap/utils/_general.py index 3de0eaf98..43a5018af 100644 --- a/shap/utils/_general.py +++ b/shap/utils/_general.py @@ -50,12 +50,11 @@ def convert_name(ind, shap_values, input_names): return ind def potential_interactions(shap_values_column, shap_values_matrix): - """ Order other features by how much interaction they seem to have with the feature at the given index. + """Order other features by how much interaction they seem to have with the feature at the given index. This just bins the SHAP values for a feature along that feature's value. For true Shapley interaction index values for SHAP see the interaction_contribs option implemented in XGBoost. """ - # ignore inds that are identical to the column ignore_inds = np.where((shap_values_matrix.values.T - shap_values_column.values).T.std(0) < 1e-8) @@ -99,12 +98,11 @@ def potential_interactions(shap_values_column, shap_values_matrix): def approximate_interactions(index, shap_values, X, feature_names=None): - """ Order other features by how much interaction they seem to have with the feature at the given index. + """Order other features by how much interaction they seem to have with the feature at the given index. This just bins the SHAP values for a feature along that feature's value. For true Shapley interaction index values for SHAP see the interaction_contribs option implemented in XGBoost. """ - # convert from DataFrames if we got any if isinstance(X, pd.DataFrame): if feature_names is None: @@ -184,6 +182,7 @@ def sample(X, nsamples=100, random_state=0): random_state : Determines random number generation for shuffling the data. Use this to ensure reproducibility across multiple function calls. + """ if hasattr(X, "shape"): over_count = nsamples >= X.shape[0] @@ -196,8 +195,7 @@ def sample(X, nsamples=100, random_state=0): def safe_isinstance(obj, class_path_str): - """ - Acts as a safe version of isinstance without having to explicitly + """Acts as a safe version of isinstance without having to explicitly import packages which may not exist in the users environment. Checks if obj is an instance of type specified by class_path_str. @@ -211,8 +209,9 @@ def safe_isinstance(obj, class_path_str): Example: `sklearn.ensemble.RandomForestRegressor` Returns - -------- + ------- bool: True if isinstance is true and the package exists, False otherwise + """ if isinstance(class_path_str, str): class_path_strs = [class_path_str] @@ -251,9 +250,7 @@ def safe_isinstance(obj, class_path_str): def format_value(s, format_str): - """ Strips trailing zeros and uses a unicode minus sign. - """ - + """Strips trailing zeros and uses a unicode minus sign.""" if not issubclass(type(s), str): s = format_str % s s = re.sub(r'\.?0+$', '', s) @@ -263,21 +260,18 @@ def format_value(s, format_str): # From: https://groups.google.com/forum/m/#!topic/openrefine/G7_PSdUeno0 def ordinal_str(n): - """ Converts a number to and ordinal string. - """ + """Converts a number to and ordinal string.""" return str(n) + {1: 'st', 2: 'nd', 3: 'rd'}.get(4 if 10 <= n % 100 < 20 else n % 10, "th") class OpChain: - """ A way to represent a set of dot chained operations on an object without actually running them. - """ + """A way to represent a set of dot chained operations on an object without actually running them.""" def __init__(self, root_name=""): self._ops = [] self._root_name = root_name def apply(self, obj): - """ Applies all our ops to the given object. - """ + """Applies all our ops to the given object.""" for o in self._ops: op,args,kwargs = o if args is not None: @@ -287,8 +281,7 @@ def apply(self, obj): return obj def __call__(self, *args, **kwargs): - """ Update the args for the previous operation. - """ + """Update the args for the previous operation.""" new_self = OpChain(self._root_name) new_self._ops = copy.copy(self._ops) new_self._ops[-1][1] = args diff --git a/shap/utils/_keras.py b/shap/utils/_keras.py index ab705c549..7aaef9ca7 100644 --- a/shap/utils/_keras.py +++ b/shap/utils/_keras.py @@ -1,11 +1,8 @@ -""" This file contains various utility functions that are useful but not core to SHAP. -""" +"""This file contains various utility functions that are useful but not core to SHAP.""" def clone_keras_layers(model, start_layer, stop_layer): - """ Clones the keras layers between the start and stop layer as a new model. - """ - + """Clones the keras layers between the start and stop layer as a new model.""" import tensorflow as tf if isinstance(start_layer, int): @@ -46,12 +43,11 @@ def clone_keras_layers(model, start_layer, stop_layer): return tf.keras.Model(layer_input, new_layers[stop_layer.output.name]) def split_keras_model(model, layer): - """ Splits the keras model around layer into two models. + """Splits the keras model around layer into two models. This is done such that model2(model1(X)) = model(X) and mode11(X) == layer(X) """ - if isinstance(layer, str): layer = model.get_layer(layer) elif isinstance(layer, int): diff --git a/shap/utils/_legacy.py b/shap/utils/_legacy.py index bc4364913..93af3f090 100644 --- a/shap/utils/_legacy.py +++ b/shap/utils/_legacy.py @@ -8,7 +8,7 @@ def kmeans(X, k, round_values=True): - """ Summarize a dataset with k mean samples weighted by the number of data points they + """Summarize a dataset with k mean samples weighted by the number of data points they each represent. Parameters @@ -26,8 +26,8 @@ def kmeans(X, k, round_values=True): Returns ------- DenseData object. - """ + """ group_names = [str(i) for i in range(X.shape[1])] if isinstance(X, pd.DataFrame): group_names = X.columns @@ -99,7 +99,7 @@ def __init__(self, f, out_names): def convert_to_model(val, keep_index=False): - """ Convert a model to a Model object. + """Convert a model to a Model object. Parameters ---------- @@ -109,6 +109,7 @@ def convert_to_model(val, keep_index=False): keep_index : bool If True then the index values will be passed to the model function as the first argument. When this is False the feature names will be removed from the model object to avoid unnecessary warnings. + """ if isinstance(val, Model): out = val diff --git a/shap/utils/_masked_model.py b/shap/utils/_masked_model.py index 595b8e5df..3ddef8059 100644 --- a/shap/utils/_masked_model.py +++ b/shap/utils/_masked_model.py @@ -9,7 +9,7 @@ class MaskedModel: - """ This is a utility class that combines a model, a masker object, and a current input. + """This is a utility class that combines a model, a masker object, and a current input. The combination of a model, a masker object, and a current input produces a binary set function that can be called to mask out any set of inputs. This class attempts to be smart @@ -225,7 +225,7 @@ def mask_shapes(self): return [a.shape for a in self.args] # TODO: this will need to get more flexible def __len__(self): - """ How many binary inputs there are to toggle. + """How many binary inputs there are to toggle. By default we just match what the masker tells us. But if the masker doesn't help us out by giving a length then we assume is the number of data inputs. @@ -239,9 +239,7 @@ def varying_inputs(self): return np.where(np.any(self._variants, axis=0))[0] def main_effects(self, inds=None, batch_size=None): - """ Compute the main effects for this model. - """ - + """Compute the main effects for this model.""" # if no indexes are given then we assume all indexes could be non-zero if inds is None: inds = np.arange(len(self)) @@ -272,9 +270,7 @@ def _assert_output_input_match(inputs, outputs): f"The model produced {len(outputs)} output rows when given {len(inputs[0])} input rows! Check the implementation of the model you provided for errors." def _convert_delta_mask_to_full(masks, full_masks): - """ This converts a delta masking array to a full bool masking array. - """ - + """This converts a delta masking array to a full bool masking array.""" i = -1 masks_pos = 0 while masks_pos < len(masks): @@ -412,11 +408,10 @@ def _build_fixed_multi_output(averaged_outs, last_outs, outputs, batch_positions def make_masks(cluster_matrix): - """ Builds a sparse CSR mask matrix from the given clustering. + """Builds a sparse CSR mask matrix from the given clustering. This function is optimized since trees for images can be very large. """ - M = cluster_matrix.shape[0] + 1 indices_row_pos = np.zeros(2 * M - 1, dtype=int) indptr = np.zeros(2 * M, dtype=int) @@ -466,14 +461,13 @@ def _rec_fill_masks(cluster_matrix, indices_row_pos, indptr, indices, M, ind): indices[pos + lind_size:pos + lind_size + rind_size] = indices[rpos:rpos + rind_size] def link_reweighting(p, link): - """ Returns a weighting that makes mean(weights*link(p)) == link(mean(p)). + """Returns a weighting that makes mean(weights*link(p)) == link(mean(p)). This is based on a linearization of the link function. When the link function is monotonic then we can find a set of positive weights that adjust for the non-linear influence changes on the expected value. Note that there are many possible reweightings that can satisfy the above property. This function returns the one that has the lowest L2 norm. """ - # the linearized link function is a first order Taylor expansion of the link function # centered at the expected value expected_value = np.mean(p, axis=0) diff --git a/shap/utils/_show_progress.py b/shap/utils/_show_progress.py index ce69a39ce..3f2391d32 100644 --- a/shap/utils/_show_progress.py +++ b/shap/utils/_show_progress.py @@ -4,8 +4,8 @@ class ShowProgress: - """ This is a simple wrapper around tqdm that includes a starting delay before printing. - """ + """This is a simple wrapper around tqdm that includes a starting delay before printing.""" + def __init__(self, iterable, total, desc, silent, start_delay): self.iter = iter(iterable) self.start_time = time.time() diff --git a/shap/utils/image.py b/shap/utils/image.py index 1107e0f09..8cae1be47 100644 --- a/shap/utils/image.py +++ b/shap/utils/image.py @@ -9,7 +9,7 @@ def is_empty(path): - """ Function to check if folder at given path exists and is not empty. + """Function to check if folder at given path exists and is not empty. Returns True if folder is empty or does not exist. """ @@ -25,9 +25,7 @@ def is_empty(path): return empty def make_dir(path): - """ - Function to create a new directory with given path or empty if it already exists. - """ + """Function to create a new directory with given path or empty if it already exists.""" if not os.path.exists(path): if not os.path.isfile(path): # make directory if it does not exist @@ -42,9 +40,7 @@ def make_dir(path): os.remove(path+file) def add_sample_images(path): - """ - Function to add sample images from imagenet50 SHAP data in the given folder. - """ + """Function to add sample images from imagenet50 SHAP data in the given folder.""" X, _ = shap.datasets.imagenet50() counter = 1 indexes_list = [25, 26, 30, 44] @@ -55,16 +51,13 @@ def add_sample_images(path): counter += 1 def load_image(path_to_image): - """ - Function to load image at given path and return numpy array of RGB float values. - """ + """Function to load image at given path and return numpy array of RGB float values.""" image = cv2.imread(path_to_image) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) return np.array(image).astype('float') def check_valid_image(path_to_image): - """ - Function to check if a file has valid image extensions and return True if it does. + """Function to check if a file has valid image extensions and return True if it does. Note: Azure Cognitive Services only accepts below file formats. """ valid_extensions = (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".jfif") @@ -72,24 +65,24 @@ def check_valid_image(path_to_image): return True def save_image(array, path_to_image): - """ - Function to save image(RGB values array) at given path (filename and location). - """ + """Function to save image(RGB values array) at given path (filename and location).""" # saving array of RGB values as an image image = np.array(array)/255.0 plt.imsave(path_to_image, image) def resize_image(path_to_image, reshaped_dir): - """ - Function to resize given image retaining original aspect ratio and save in given directory 'reshaped_dir'. + """Function to resize given image retaining original aspect ratio and save in given directory 'reshaped_dir'. Returns numpy array of resized image and path where resized file is saved. - Note: + + Note + ---- Azure COGS CV has size limit of < 4MB and min size of 50x50 for images. Hence, large image files are being reshaped in code below to increase speed of SHAP explanations and run Azure COGS for image captions. If image (pixel_size, pixel_size) is greater than 500 for either of the dimensions: 1 - image is resized to have max. 500 pixel size for the dimension > 500 2 - other dimension is resized retaining the original aspect ratio + """ image = load_image(path_to_image) @@ -122,10 +115,7 @@ def resize_image(path_to_image, reshaped_dir): def display_grid_plot(list_of_captions, list_of_images, max_columns=4, figsize=(20,20)): - """ - Function to display grid of images and their titles/captions. - """ - + """Function to display grid of images and their titles/captions.""" # load list of images masked_images = [] for filename in list_of_images: diff --git a/shap/utils/transformers.py b/shap/utils/transformers.py index dd0dfe294..f19afc1d1 100644 --- a/shap/utils/transformers.py +++ b/shap/utils/transformers.py @@ -83,13 +83,12 @@ ] def is_transformers_lm(model): - """ Check if the given model object is a huggingface transformers language model. - """ + """Check if the given model object is a huggingface transformers language model.""" return (safe_isinstance(model, "transformers.PreTrainedModel") or safe_isinstance(model, "transformers.TFPreTrainedModel")) and \ safe_isinstance(model, MODELS_FOR_SEQ_TO_SEQ_CAUSAL_LM + MODELS_FOR_CAUSAL_LM) def parse_prefix_suffix_for_tokenizer(tokenizer): - """ Set prefix and suffix tokens based on null tokens. + """Set prefix and suffix tokens based on null tokens. Example for distillgpt2: null_tokens=[], for BART: null_tokens = [0,2] and for MarianMT: null_tokens=[0] used to slice tokens belonging to sentence after passing through tokenizer.encode(). @@ -132,12 +131,11 @@ def parse_prefix_suffix_for_tokenizer(tokenizer): } def getattr_silent(obj, attr): - """ This turns of verbose logging of missing attributes for huggingface transformers. + """This turns of verbose logging of missing attributes for huggingface transformers. This is motivated by huggingface transformers objects that print error warnings when we access unset properties. """ - reset_verbose = False if getattr(obj, 'verbose', False): reset_verbose = True diff --git a/tests/actions/_action.py b/tests/actions/_action.py index 687d585c8..827376313 100644 --- a/tests/actions/_action.py +++ b/tests/actions/_action.py @@ -1,5 +1,4 @@ -""" Unit tests for the Exact explainer. -""" +"""Unit tests for the Exact explainer.""" import numpy as np import pandas as pd @@ -11,8 +10,7 @@ def test_create_and_run(): X = pd.DataFrame({"feature1": np.ones(5), "feature2": np.ones(5)}) class IncreaseFeature1(shap.actions.Action): - """ Sample action. - """ + """Sample action.""" def __init__(self, amount): self.amount = amount diff --git a/tests/actions/_optimizer.py b/tests/actions/_optimizer.py index 629896dd3..497a5501e 100644 --- a/tests/actions/_optimizer.py +++ b/tests/actions/_optimizer.py @@ -1,5 +1,4 @@ -""" Unit tests for the Exact explainer. -""" +"""Unit tests for the Exact explainer.""" import numpy as np import pandas as pd @@ -13,8 +12,7 @@ def create_basic_scenario(): X = pd.DataFrame({"feature1": np.ones(5), "feature2": np.ones(5), "feature3": np.ones(5)}) class IncreaseFeature1(shap.actions.Action): - """ Sample action. - """ + """Sample action.""" def __init__(self, amount): self.amount = amount @@ -27,8 +25,7 @@ def __str__(self): return f"Improve feature1 by {self.amount}." class IncreaseFeature2(shap.actions.Action): - """ Sample action. - """ + """Sample action.""" def __init__(self, amount): self.amount = amount @@ -41,8 +38,7 @@ def __str__(self): return f"Improve feature2 by {self.amount}." class IncreaseFeature3(shap.actions.Action): - """ Sample action. - """ + """Sample action.""" def __init__(self, amount): self.amount = amount diff --git a/tests/benchmark/framework.py b/tests/benchmark/framework.py index 4b0f65b68..ced156f5b 100644 --- a/tests/benchmark/framework.py +++ b/tests/benchmark/framework.py @@ -12,8 +12,7 @@ def model(x): masker = X def test_update(): - """ This is to test the update function within benchmark/framework - """ + """This is to test the update function within benchmark/framework""" sort_order = 'positive' def score_function(true, pred): return np.mean(pred) @@ -28,8 +27,7 @@ def score_function(true, pred): assert len(scores['values'][metric]) == 3 def test_get_benchmark(): - """ This is to test the get benchmark function within benchmark/framework - """ + """This is to test the get benchmark function within benchmark/framework""" metrics = {'sort_order': ['positive', 'negative'], 'perturbation': ['keep']} scores = shap.benchmark.get_benchmark(model, X, y, explainer, masker, metrics) @@ -39,8 +37,7 @@ def test_get_benchmark(): assert len(scores['values']) == 2 def test_get_metrics(): - """ This is to test the get metrics function with respect to different selection method - """ + """This is to test the get metrics function with respect to different selection method""" scores1 = {'name': 'test1', 'metrics': ['keep positive', 'keep absolute'], 'values': dict()} scores2 = {'name': 'test2', 'metrics': ['keep positive', 'keep negative'], 'values': dict()} benchmarks = {'test1': scores1, 'test2': scores2} diff --git a/tests/explainers/__init__.py b/tests/explainers/__init__.py index c688c5ca3..1e98b551f 100644 --- a/tests/explainers/__init__.py +++ b/tests/explainers/__init__.py @@ -1,5 +1,4 @@ -""" This modules tests all the explainer types. -""" +"""This modules tests all the explainer types.""" import matplotlib diff --git a/tests/explainers/common.py b/tests/explainers/common.py index 54a939a95..d65672979 100644 --- a/tests/explainers/common.py +++ b/tests/explainers/common.py @@ -7,8 +7,7 @@ def basic_xgboost_scenario(max_samples=None, dataset=shap.datasets.adult): - """ Create a basic XGBoost model on a data set. - """ + """Create a basic XGBoost model on a data set.""" xgboost = pytest.importorskip('xgboost') # get a dataset on income prediction @@ -27,8 +26,7 @@ def basic_xgboost_scenario(max_samples=None, dataset=shap.datasets.adult): def test_additivity(explainer_type, model, masker, data, **kwargs): - """ Test explainer and masker for additivity on a single output prediction problem. - """ + """Test explainer and masker for additivity on a single output prediction problem.""" explainer = explainer_type(model, masker, **kwargs) shap_values = explainer(data) @@ -50,8 +48,7 @@ def test_additivity(explainer_type, model, masker, data, **kwargs): assert np.max(np.abs(shap_values.base_values + shap_values.values.sum(1) - model(data)) < 1e6) def test_interactions_additivity(explainer_type, model, masker, data, **kwargs): - """ Test explainer and masker for additivity on a single output prediction problem. - """ + """Test explainer and masker for additivity on a single output prediction problem.""" explainer = explainer_type(model, masker, **kwargs) shap_values = explainer(data, interactions=True) @@ -78,9 +75,7 @@ def test_interactions_additivity(explainer_type, model, masker, data, **kwargs): # assert np.max(np.abs(shap_values.base_values + shap_values.values.sum((1, 2)) - model.predict(X[:100])) < 1e6) def test_serialization(explainer_type, model, masker, data, rtol=1e-05, atol=1e-8, **kwargs): - """ Test serialization with a given explainer algorithm. - """ - + """Test serialization with a given explainer algorithm.""" explainer_kwargs = {k: v for k,v in kwargs.items() if k in ["algorithm"]} explainer_original = explainer_type(model, masker, **explainer_kwargs) shap_values_original = explainer_original(data[:1]) diff --git a/tests/explainers/conftest.py b/tests/explainers/conftest.py index 963c7e0d3..e7a507d0c 100644 --- a/tests/explainers/conftest.py +++ b/tests/explainers/conftest.py @@ -6,8 +6,7 @@ @pytest.mark.skipif(sys.platform == 'win32', reason="Integer division bug in HuggingFace on Windows") @pytest.fixture(scope="session") def basic_translation_scenario(): - """ Create a basic transformers translation model and tokenizer. - """ + """Create a basic transformers translation model and tokenizer.""" AutoTokenizer = pytest.importorskip("transformers").AutoTokenizer AutoModelForSeq2SeqLM = pytest.importorskip("transformers").AutoModelForSeq2SeqLM diff --git a/tests/explainers/test_deep.py b/tests/explainers/test_deep.py index 327a9963b..8d17f9136 100644 --- a/tests/explainers/test_deep.py +++ b/tests/explainers/test_deep.py @@ -1,5 +1,4 @@ -""" Tests for the Deep explainer. -""" +"""Tests for the Deep explainer.""" import numpy as np @@ -16,8 +15,7 @@ ############################ def test_tf_eager(random_seed): - """ This is a basic eager example from keras. - """ + """This is a basic eager example from keras.""" tf = pytest.importorskip('tensorflow') tf.compat.v1.random.set_random_seed(random_seed) @@ -43,8 +41,7 @@ def test_tf_eager(random_seed): def test_tf_keras_mnist_cnn(random_seed): - """ This is the basic mnist cnn example from keras. - """ + """This is the basic mnist cnn example from keras.""" tf = pytest.importorskip('tensorflow') rs = np.random.RandomState(random_seed) tf.compat.v1.random.set_random_seed(random_seed) @@ -136,9 +133,7 @@ def test_tf_keras_mnist_cnn(random_seed): @pytest.mark.parametrize("activation", ["relu", "elu", "selu"]) def test_tf_keras_activations(activation): - """Test verifying that a linear model with linear data gives the correct result. - """ - + """Test verifying that a linear model with linear data gives the correct result.""" # FIXME: this test should ideally pass with any random seed. See #2960 random_seed = 0 @@ -178,9 +173,7 @@ def test_tf_keras_activations(activation): def test_tf_keras_linear(): - """Test verifying that a linear model with linear data gives the correct result. - """ - + """Test verifying that a linear model with linear data gives the correct result.""" # FIXME: this test should ideally pass with any random seed. See #2960 random_seed = 0 @@ -224,8 +217,7 @@ def test_tf_keras_linear(): def test_tf_keras_imdb_lstm(random_seed): - """ Basic LSTM example using the keras API defined in tensorflow - """ + """Basic LSTM example using the keras API defined in tensorflow""" tf = pytest.importorskip('tensorflow') rs = np.random.RandomState(random_seed) tf.compat.v1.random.set_random_seed(random_seed) @@ -310,8 +302,7 @@ def test_tf_deep_multi_inputs_multi_outputs(): ####################### def _torch_cuda_available(): - """ Checks whether cuda is available. If so, torch-related tests are also tested on gpu. - """ + """Checks whether cuda is available. If so, torch-related tests are also tested on gpu.""" try: import torch @@ -334,16 +325,15 @@ def _torch_cuda_available(): @pytest.mark.parametrize("torch_device", TORCH_DEVICES) @pytest.mark.parametrize("interim", [True, False]) def test_pytorch_mnist_cnn(torch_device, interim): - """The same test as above, but for pytorch - """ + """The same test as above, but for pytorch""" torch = pytest.importorskip('torch') from torch import nn from torch.nn import functional as F class RandData: - """ Random test data. - """ + """Random test data.""" + def __init__(self, batch_size): self.current = 0 self.batch_size = batch_size @@ -359,8 +349,7 @@ def __next__(self): class Net(nn.Module): - """ Basic conv net. - """ + """Basic conv net.""" def __init__(self): super().__init__() @@ -384,8 +373,7 @@ def __init__(self): ) def forward(self, x): - """ Run the model. - """ + """Run the model.""" x = self.conv_layers(x) x = x.view(-1, 320) x = self.fc_layers(x) @@ -454,8 +442,7 @@ def train(model, device, train_loader, optimizer, _, cutoff=20): @pytest.mark.parametrize("torch_device", TORCH_DEVICES) def test_pytorch_custom_nested_models(torch_device): - """Testing single outputs - """ + """Testing single outputs""" torch = pytest.importorskip('torch') from sklearn.datasets import fetch_california_housing @@ -464,8 +451,8 @@ def test_pytorch_custom_nested_models(torch_device): from torch.utils.data import DataLoader, TensorDataset class CustomNet1(nn.Module): - """ Model 1. - """ + """Model 1.""" + def __init__(self, num_features): super().__init__() self.net = nn.Sequential( @@ -477,13 +464,12 @@ def __init__(self, num_features): ) def forward(self, X): - """ Run the model. - """ + """Run the model.""" return self.net(X.unsqueeze(1)).squeeze(1) class CustomNet2(nn.Module): - """ Model 2. - """ + """Model 2.""" + def __init__(self, num_features): super().__init__() self.net = nn.Sequential( @@ -492,13 +478,12 @@ def __init__(self, num_features): ) def forward(self, X): - """ Run the model. - """ + """Run the model.""" return self.net(X).unsqueeze(1) class CustomNet(nn.Module): - """ Model 3. - """ + """Model 3.""" + def __init__(self, num_features): super().__init__() self.net1 = CustomNet1(num_features) @@ -506,8 +491,7 @@ def __init__(self, num_features): self.maxpool2 = nn.MaxPool1d(kernel_size=2) def forward(self, X): - """ Run the model. - """ + """Run the model.""" x = self.net1(X) return self.maxpool2(self.net2(x)).squeeze(1) @@ -580,8 +564,7 @@ def train(model, device, train_loader, optimizer, epoch): @pytest.mark.parametrize("torch_device", TORCH_DEVICES) def test_pytorch_single_output(torch_device): - """Testing single outputs - """ + """Testing single outputs""" torch = pytest.importorskip('torch') from sklearn.datasets import fetch_california_housing @@ -590,8 +573,8 @@ def test_pytorch_single_output(torch_device): from torch.utils.data import DataLoader, TensorDataset class Net(nn.Module): - """ Test model. - """ + """Test model.""" + def __init__(self, num_features): super().__init__() self.linear = nn.Linear(num_features // 2, 2) @@ -602,8 +585,7 @@ def __init__(self, num_features): self.maxpool2 = nn.MaxPool1d(kernel_size=2) def forward(self, X): - """ Run the model. - """ + """Run the model.""" x = self.aapool1d(self.convt1d(self.conv1d(X.unsqueeze(1)))).squeeze(1) return self.maxpool2(self.linear(self.leaky_relu(x)).unsqueeze(1)).squeeze(1) @@ -676,8 +658,7 @@ def train(model, device, train_loader, optimizer, epoch): @pytest.mark.parametrize("torch_device", TORCH_DEVICES) @pytest.mark.parametrize("disconnected", [True, False]) def test_pytorch_multiple_inputs(torch_device, disconnected): - """ Check a multi-input scenario. - """ + """Check a multi-input scenario.""" torch = pytest.importorskip('torch') from sklearn.datasets import fetch_california_housing @@ -687,8 +668,8 @@ def test_pytorch_multiple_inputs(torch_device, disconnected): class Net(nn.Module): - """ Testing model. - """ + """Testing model.""" + def __init__(self, num_features, disconnected): super().__init__() self.disconnected = disconnected @@ -701,8 +682,7 @@ def __init__(self, num_features, disconnected): ) def forward(self, x1, x2): - """ Run the model. - """ + """Run the model.""" if self.disconnected: x = self.linear(x1).unsqueeze(1) else: diff --git a/tests/explainers/test_exact.py b/tests/explainers/test_exact.py index 1b183ed42..1a4246947 100644 --- a/tests/explainers/test_exact.py +++ b/tests/explainers/test_exact.py @@ -1,5 +1,4 @@ -""" Unit tests for the Exact explainer. -""" +"""Unit tests for the Exact explainer.""" import pickle diff --git a/tests/explainers/test_explainer.py b/tests/explainers/test_explainer.py index 53a86a46e..60b222542 100644 --- a/tests/explainers/test_explainer.py +++ b/tests/explainers/test_explainer.py @@ -1,5 +1,4 @@ -""" Tests for Explainer class. -""" +"""Tests for Explainer class.""" import pytest import sklearn @@ -9,7 +8,6 @@ def test_explainer_to_permutationexplainer(): """Checks that Explainer maps to PermutationExplainer as expected.""" - X_train, X_test, y_train, _ = sklearn.model_selection.train_test_split(*shap.datasets.adult(), test_size=0.1, random_state=0) lr = sklearn.linear_model.LogisticRegression(solver="liblinear") lr.fit(X_train, y_train) @@ -27,9 +25,7 @@ def test_explainer_to_permutationexplainer(): def test_wrapping_for_text_to_text_teacher_forcing_model(): - """ This tests using the Explainer class to auto wrap a masker in a text to text scenario. - """ - + """This tests using the Explainer class to auto wrap a masker in a text to text scenario.""" transformers = pytest.importorskip("transformers") def f(x): @@ -46,9 +42,7 @@ def f(x): assert shap.utils.safe_isinstance(explainer.masker, "shap.maskers.OutputComposite") def test_wrapping_for_topk_lm_model(): - """ This tests using the Explainer class to auto wrap a masker in a language modelling scenario. - """ - + """This tests using the Explainer class to auto wrap a masker in a language modelling scenario.""" transformers = pytest.importorskip("transformers") name = "hf-internal-testing/tiny-random-BartForCausalLM" diff --git a/tests/explainers/test_gpu_tree.py b/tests/explainers/test_gpu_tree.py index da4ae09b2..1cda9e567 100644 --- a/tests/explainers/test_gpu_tree.py +++ b/tests/explainers/test_gpu_tree.py @@ -1,5 +1,4 @@ -""" Test gpu accelerated tree functions. -""" +"""Test gpu accelerated tree functions.""" import numpy as np import pytest import sklearn diff --git a/tests/explainers/test_gradient.py b/tests/explainers/test_gradient.py index 6151c0949..90d5cc074 100644 --- a/tests/explainers/test_gradient.py +++ b/tests/explainers/test_gradient.py @@ -8,9 +8,7 @@ def test_tf_keras_mnist_cnn(random_seed): - """ This is the basic mnist cnn example from keras. - """ - + """This is the basic mnist cnn example from keras.""" tf = pytest.importorskip('tensorflow') rs = np.random.RandomState(random_seed) @@ -135,8 +133,7 @@ def test_tf_multi_inputs_multi_outputs(): def test_pytorch_mnist_cnn(): - """The same test as above, but for pytorch - """ + """The same test as above, but for pytorch""" # FIXME: this test should ideally pass with any random seed. See #2960 random_seed = 0 @@ -151,8 +148,8 @@ def test_pytorch_mnist_cnn(): batch_size = 128 class RandData: - """ Ranomd data for testing. - """ + """Ranomd data for testing.""" + def __init__(self, batch_size): self.current = 0 self.batch_size = batch_size @@ -189,8 +186,8 @@ def __next__(self): def run_test(train_loader, test_loader, interim): class Net(nn.Module): - """ A test model. - """ + """A test model.""" + def __init__(self): super().__init__() self.conv1 = nn.Conv2d(1, 5, kernel_size=5) @@ -200,8 +197,7 @@ def __init__(self): self.fc2 = nn.Linear(20, 10) def forward(self, x): - """ Run the model. - """ + """Run the model.""" x = F.relu(F.max_pool2d(self.conv1(x), 2)) x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) x = x.view(-1, 160) @@ -261,8 +257,7 @@ def train(model, device, train_loader, optimizer, _, cutoff=20): def test_pytorch_multiple_inputs(random_seed): - """ Test multi-input scenarios.""" - + """Test multi-input scenarios.""" torch = pytest.importorskip('torch') from torch import nn @@ -274,15 +269,14 @@ def test_pytorch_multiple_inputs(random_seed): background = [torch.zeros(batch_size, 3), torch.zeros(batch_size, 4)] class Net(nn.Module): - """ A test model. - """ + """A test model.""" + def __init__(self): super().__init__() self.linear = nn.Linear(7, 1) def forward(self, x1, x2): - """ Run the model. - """ + """Run the model.""" return self.linear(torch.cat((x1, x2), dim=-1)) model = Net() @@ -301,8 +295,7 @@ def forward(self, x1, x2): def test_pytorch_multiple_inputs_multiple_outputs(random_seed): - """ Test multi-input scenarios.""" - + """Test multi-input scenarios.""" torch = pytest.importorskip('torch') from torch import nn @@ -342,7 +335,7 @@ def forward(self, input1, input2): @pytest.mark.parametrize("input_type", ["numpy", "dataframe"]) def test_tf_input(random_seed, input_type): - """ Test tabular (batch_size, features) pd.DataFrame and numpy input. """ + """Test tabular (batch_size, features) pd.DataFrame and numpy input.""" tf = pytest.importorskip('tensorflow') tf.random.set_seed(random_seed) diff --git a/tests/explainers/test_kernel.py b/tests/explainers/test_kernel.py index ff41a1e6e..c99f04e53 100644 --- a/tests/explainers/test_kernel.py +++ b/tests/explainers/test_kernel.py @@ -8,23 +8,19 @@ def test_null_model_small(): - """ Test a small null model. - """ + """Test a small null model.""" explainer = shap.KernelExplainer(lambda x: np.zeros(x.shape[0]), np.ones((2, 4)), nsamples=100) e = explainer.explain(np.ones((1, 4))) assert np.sum(np.abs(e)) < 1e-8 def test_null_model(): - """ Test a larger null model. - """ + """Test a larger null model.""" explainer = shap.KernelExplainer(lambda x: np.zeros(x.shape[0]), np.ones((2, 10)), nsamples=100) e = explainer.explain(np.ones((1, 10))) assert np.sum(np.abs(e)) < 1e-8 def test_front_page_model_agnostic(): - """ Test the ReadMe kernel expainer example. - """ - + """Test the ReadMe kernel expainer example.""" # print the JS visualization code to the notebook shap.initjs() @@ -42,9 +38,7 @@ def test_front_page_model_agnostic(): shap.force_plot(explainer.expected_value[0], shap_values[0, :, 0], X_test.iloc[0, :], link="logit") def test_front_page_model_agnostic_rank(): - """ Test the rank regularized explanation of the ReadMe example. - """ - + """Test the rank regularized explanation of the ReadMe example.""" # print the JS visualization code to the notebook shap.initjs() @@ -61,9 +55,7 @@ def test_front_page_model_agnostic_rank(): shap.force_plot(explainer.expected_value[0], shap_values[0, :, 0], X_test.iloc[0, :], link="logit") def test_kernel_shap_with_call_method(): - """ Test the __call__ method of the Kernel class - """ - + """Test the __call__ method of the Kernel class""" # print the JS visualization code to the notebook shap.initjs() @@ -88,8 +80,7 @@ def test_kernel_shap_with_call_method(): np.testing.assert_allclose(sigm(shap_values.sum(1) + explainer.expected_value), outputs) def test_kernel_shap_with_dataframe(random_seed): - """ Test with a Pandas DataFrame. - """ + """Test with a Pandas DataFrame.""" rs = np.random.RandomState(random_seed) df_X = pd.DataFrame(rs.random((10, 3)), columns=list('abc')) @@ -128,9 +119,7 @@ def test_kernel_shap_with_dataframe_explanation(random_seed): shap.plots.scatter(explanation[:, "a"], show=False) def test_kernel_shap_with_a1a_sparse_zero_background(): - """ Test with a sparse matrix for the background. - """ - + """Test with a sparse matrix for the background.""" X, y = shap.datasets.a1a() x_train, x_test, y_train, _ = sklearn.model_selection.train_test_split(X, y, test_size=0.01, random_state=0) linear_model = sklearn.linear_model.LinearRegression() @@ -143,8 +132,7 @@ def test_kernel_shap_with_a1a_sparse_zero_background(): explainer.shap_values(x_test) def test_kernel_shap_with_a1a_sparse_nonzero_background(): - """ Check with a sparse non zero background matrix. - """ + """Check with a sparse non zero background matrix.""" np.set_printoptions(threshold=100000) X, y = shap.datasets.a1a() @@ -168,9 +156,7 @@ def dense_to_sparse_predict(data): assert np.allclose(shap_values, shap_values_dense, rtol=1e-02, atol=1e-01) def test_kernel_shap_with_high_dim_sparse(): - """ Verifies we can run on very sparse data produced from feature hashing. - """ - + """Verifies we can run on very sparse data produced from feature hashing.""" remove = ('headers', 'footers', 'quotes') categories = [ 'alt.atheism', @@ -193,9 +179,7 @@ def test_kernel_shap_with_high_dim_sparse(): _ = explainer.shap_values(x_test) def test_kernel_sparse_vs_dense_multirow_background(): - """ Mix sparse and dense matrix values. - """ - + """Mix sparse and dense matrix values.""" # train a logistic regression classifier X_train, X_test, Y_train, _ = sklearn.model_selection.train_test_split(*shap.datasets.iris(), test_size=0.1, random_state=0) lr = sklearn.linear_model.LogisticRegression(solver='lbfgs') @@ -223,7 +207,7 @@ def test_kernel_sparse_vs_dense_multirow_background(): def test_linear(random_seed): - """ Tests that KernelExplainer returns the correct result when the model is linear. + """Tests that KernelExplainer returns the correct result when the model is linear. (as per corollary 1 of https://arxiv.org/abs/1705.07874) """ @@ -246,9 +230,7 @@ def f(x): def test_non_numeric(): - """ Test using non-numeric data. - """ - + """Test using non-numeric data.""" # create dummy data X = np.array([['A', '0', '0'], ['A', '1', '0'], ['B', '0', '0'], ['B', '1', '0'], ['A', '1', '0']]) y = np.array([0, 1, 2, 3, 4]) @@ -297,8 +279,7 @@ def test_kernel_explainer_with_tensors(): explainer.shap_values(X[:1]) def test_kernel_multiclass_single_row(): - """ Check a multi-input scenario. - """ + """Check a multi-input scenario.""" X, y = shap.datasets.iris() lr = sklearn.linear_model.LogisticRegression(solver='lbfgs') @@ -311,8 +292,7 @@ def test_kernel_multiclass_single_row(): def test_kernel_multiclass_multiple_rows(): - """ Check a multi-input scenario. - """ + """Check a multi-input scenario.""" X, y = shap.datasets.iris() lr = sklearn.linear_model.LogisticRegression(solver='lbfgs') diff --git a/tests/explainers/test_linear.py b/tests/explainers/test_linear.py index efc65d8fb..bd028db76 100644 --- a/tests/explainers/test_linear.py +++ b/tests/explainers/test_linear.py @@ -1,5 +1,4 @@ -""" Unit tests for the Linear explainer. -""" +"""Unit tests for the Linear explainer.""" import numpy as np import pytest import scipy.special @@ -143,8 +142,7 @@ def test_shape_values_linear_many_features(): np.testing.assert_allclose(expected - values, 0, atol=0.01) def test_single_feature(random_seed): - """ Make sure things work with a univariate linear regression. - """ + """Make sure things work with a univariate linear regression.""" Ridge = pytest.importorskip('sklearn.linear_model').Ridge # generate linear data @@ -163,8 +161,7 @@ def test_single_feature(random_seed): assert np.max(np.abs(explainer.expected_value + shap_values.sum(1) - model.predict(X))) < 1e-6 def test_sparse(): - """ Validate running LinearExplainer on scipy sparse data - """ + """Validate running LinearExplainer on scipy sparse data""" make_multilabel_classification = pytest.importorskip('sklearn.datasets').make_multilabel_classification LogisticRegression = pytest.importorskip('sklearn.linear_model').LogisticRegression @@ -187,8 +184,7 @@ def test_sparse(): @pytest.mark.xfail(reason="This should pass but it doesn't.") def test_sparse_multi_class(): - """ Validate running LinearExplainer on scipy sparse data - """ + """Validate running LinearExplainer on scipy sparse data""" make_multilabel_classification = pytest.importorskip('sklearn.datasets').make_multilabel_classification LogisticRegression = pytest.importorskip('sklearn.linear_model').LogisticRegression diff --git a/tests/explainers/test_partition.py b/tests/explainers/test_partition.py index e98d720f3..ed4ac9b1b 100644 --- a/tests/explainers/test_partition.py +++ b/tests/explainers/test_partition.py @@ -1,5 +1,4 @@ -""" This file contains tests for partition explainer. -""" +"""This file contains tests for partition explainer.""" import pickle diff --git a/tests/explainers/test_permutation.py b/tests/explainers/test_permutation.py index e15dfb2b0..1689272af 100644 --- a/tests/explainers/test_permutation.py +++ b/tests/explainers/test_permutation.py @@ -1,5 +1,4 @@ -""" Unit tests for the Permutation explainer. -""" +"""Unit tests for the Permutation explainer.""" import pickle @@ -11,8 +10,7 @@ def test_exact_second_order(random_seed): - """ This tests that the Perumtation explain gives exact answers for second order functions. - """ + """This tests that the Perumtation explain gives exact answers for second order functions.""" rs = np.random.RandomState(random_seed) data = rs.randint(0, 2, size=(100,5)) def model(data): diff --git a/tests/explainers/test_sampling.py b/tests/explainers/test_sampling.py index b139fb1cd..14fcce445 100644 --- a/tests/explainers/test_sampling.py +++ b/tests/explainers/test_sampling.py @@ -1,5 +1,4 @@ -""" Unit tests for the Sampling explainer. -""" +"""Unit tests for the Sampling explainer.""" import numpy as np import pytest diff --git a/tests/explainers/test_tree.py b/tests/explainers/test_tree.py index 5b6ce4aa2..924e2eda3 100644 --- a/tests/explainers/test_tree.py +++ b/tests/explainers/test_tree.py @@ -419,8 +419,7 @@ def test_catboost_interactions(): def _average_path_length(n_samples_leaf): - """ - Vendored from: https://github.com/scikit-learn/scikit-learn/blob/399131c8545cd525724e4bacf553416c512ac82c/sklearn/ensemble/_iforest.py#L531 + """Vendored from: https://github.com/scikit-learn/scikit-learn/blob/399131c8545cd525724e4bacf553416c512ac82c/sklearn/ensemble/_iforest.py#L531 For use in isolation forest tests. """ @@ -574,7 +573,7 @@ def test_provided_background_independent_prob_output(): def test_single_tree_compare_with_kernel_shap(): - """ Compare with Kernel SHAP, which makes the same independence assumptions + """Compare with Kernel SHAP, which makes the same independence assumptions as Independent Tree SHAP. Namely, they both assume independence between the set being conditioned on, and the remainder set. """ @@ -616,7 +615,7 @@ def f(inp): def test_several_trees(): - """ Make sure Independent Tree SHAP sums up to the correct value for + """Make sure Independent Tree SHAP sums up to the correct value for larger models (20 trees). """ # FIXME: this test should ideally pass with any random seed. See #2960 @@ -650,7 +649,7 @@ def test_several_trees(): def test_single_tree_nonlinear_transformations(): - """ Make sure Independent Tree SHAP single trees with non-linear + """Make sure Independent Tree SHAP single trees with non-linear transformations. """ # Supported non-linear transforms @@ -757,7 +756,6 @@ def test_singletree_lightgbm_basic(self): """A basic test for checking that a LightGBM `dump_model()["tree_info"]` dictionary is parsed properly into a `SingleTree` object. """ - # Stump (only root node) tree sample_tree = { "tree_index": 256, @@ -922,9 +920,7 @@ def transform(self, X, y=None, **fit_params): ) def test_sklearn_random_forest_newsgroups(self): - """ - note: this test used to fail in native TreeExplainer code due to memory corruption - """ + """note: this test used to fail in native TreeExplainer code due to memory corruption""" newsgroups_train, newsgroups_test, _ = create_binary_newsgroups_data() pipeline = self._create_vectorizer_for_randomforestclassifier() pipeline.fit(newsgroups_train.data, newsgroups_train.target) @@ -1200,6 +1196,7 @@ class TestExplainerXGBoost: * XGBRFClassifier * XGBRanker """ + xgboost = pytest.importorskip("xgboost") regressors = [xgboost.XGBRegressor, xgboost.XGBRFRegressor] @@ -1226,11 +1223,9 @@ def test_xgboost_regression(self, Reg): @pytest.mark.parametrize("Clf", classifiers) def test_xgboost_dmatrix_propagation(self, Clf): - """ - Test that xgboost sklearn attributues are properly passed to the DMatrix + """Test that xgboost sklearn attributues are properly passed to the DMatrix initiated during shap value calculation. see GH #3313 """ - X, y = shap.datasets.adult(n_points=100) # Randomly add missing data to the input where missing data is encoded as 1e-8 @@ -1430,7 +1425,8 @@ def test_explanation_data_not_dmatrix(self, random_seed): """Checks that DMatrix is not stored in Explanation.data after TreeExplainer.__call__, since it is not supported by our plotting functions. - See GH #3357 for more information.""" + See GH #3357 for more information. + """ xgboost = pytest.importorskip("xgboost") rs = np.random.RandomState(random_seed) diff --git a/tests/maskers/test_custom.py b/tests/maskers/test_custom.py index 40ca868c1..50def7451 100644 --- a/tests/maskers/test_custom.py +++ b/tests/maskers/test_custom.py @@ -1,5 +1,4 @@ -""" This file contains tests for custom (user supplied) maskers. -""" +"""This file contains tests for custom (user supplied) maskers.""" import numpy as np @@ -7,9 +6,7 @@ def test_raw_function(): - """ Make sure passing a simple masking function works. - """ - + """Make sure passing a simple masking function works.""" X, _ = shap.datasets.california(n_points=500) def test(X): diff --git a/tests/maskers/test_fixed_composite.py b/tests/maskers/test_fixed_composite.py index 980ff65f2..8160e8544 100644 --- a/tests/maskers/test_fixed_composite.py +++ b/tests/maskers/test_fixed_composite.py @@ -1,5 +1,4 @@ -""" This file contains tests for the FixedComposite masker. -""" +"""This file contains tests for the FixedComposite masker.""" import tempfile @@ -11,9 +10,7 @@ @pytest.mark.skip(reason="fails on travis and I don't know why yet...Ryan might need to take a look since this API will change soon anyway") def test_fixed_composite_masker_call(): - """ Test to make sure the FixedComposite masker works when masking everything. - """ - + """Test to make sure the FixedComposite masker works when masking everything.""" AutoTokenizer = pytest.importorskip("transformers").AutoTokenizer args = ("This is a test statement for fixed composite masker",) @@ -30,9 +27,7 @@ def test_fixed_composite_masker_call(): assert fixed_composite_masked_output == expected_fixed_composite_masked_output def test_serialization_fixedcomposite_masker(): - """ Make sure fixedcomposite serialization works. - """ - + """Make sure fixedcomposite serialization works.""" AutoTokenizer = pytest.importorskip("transformers").AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased", use_fast=False) diff --git a/tests/maskers/test_image.py b/tests/maskers/test_image.py index cff56571e..8e0e56f70 100644 --- a/tests/maskers/test_image.py +++ b/tests/maskers/test_image.py @@ -1,5 +1,4 @@ -""" This file contains tests for the Image masker. -""" +"""This file contains tests for the Image masker.""" import tempfile @@ -15,9 +14,7 @@ pytestmark = pytest.mark.skip("opencv not installed") def test_serialization_image_masker_inpaint_telea(): - """ Make sure image serialization works with inpaint telea mask. - """ - + """Make sure image serialization works with inpaint telea mask.""" test_image_height = 500 test_image_width = 500 test_data = np.ones((test_image_height, test_image_width, 3)) * 50 @@ -44,9 +41,7 @@ def test_serialization_image_masker_inpaint_telea(): assert np.array_equal(original_image_masker(mask, test_data), new_image_masker(mask, test_data)) def test_serialization_image_masker_inpaint_ns(): - """ Make sure image serialization works with inpaint ns mask. - """ - + """Make sure image serialization works with inpaint ns mask.""" test_image_height = 500 test_image_width = 500 test_data = np.ones((test_image_height, test_image_width, 3)) * 50 @@ -73,9 +68,7 @@ def test_serialization_image_masker_inpaint_ns(): assert np.array_equal(original_image_masker(mask, test_data), new_image_masker(mask, test_data)) def test_serialization_image_masker_blur(): - """ Make sure image serialization works with blur mask. - """ - + """Make sure image serialization works with blur mask.""" test_image_height = 500 test_image_width = 500 test_data = np.ones((test_image_height, test_image_width, 3)) * 50 @@ -102,9 +95,7 @@ def test_serialization_image_masker_blur(): assert np.array_equal(original_image_masker(mask, test_data), new_image_masker(mask, test_data)) def test_serialization_image_masker_mask(): - """ Make sure image serialization works. - """ - + """Make sure image serialization works.""" test_image_height = 500 test_image_width = 500 test_data = np.ones((test_image_height, test_image_width, 3)) * 50 diff --git a/tests/maskers/test_tabular.py b/tests/maskers/test_tabular.py index a951576cb..c1e118f9e 100644 --- a/tests/maskers/test_tabular.py +++ b/tests/maskers/test_tabular.py @@ -1,5 +1,4 @@ -""" This file contains tests for the Tabular maskers. -""" +"""This file contains tests for the Tabular maskers.""" import tempfile @@ -9,9 +8,7 @@ def test_serialization_independent_masker_dataframe(): - """ Test the serialization of an Independent masker based on a data frame. - """ - + """Test the serialization of an Independent masker based on a data frame.""" X, _ = shap.datasets.california(n_points=500) # initialize independent masker @@ -35,10 +32,7 @@ def test_serialization_independent_masker_dataframe(): assert np.array_equal(original_independent_masker(mask, X[:1].values[0])[1], new_independent_masker(mask, X[:1].values[0])[1]) def test_serialization_independent_masker_numpy(): - """ Test the serialization of an Independent masker based on a numpy array. - """ - - + """Test the serialization of an Independent masker based on a numpy array.""" X, _ = shap.datasets.california(n_points=500) X = X.values @@ -64,9 +58,7 @@ def test_serialization_independent_masker_numpy(): assert np.array_equal(original_independent_masker(mask, X[0])[0], new_independent_masker(mask, X[0])[0]) def test_serialization_partion_masker_dataframe(): - """ Test the serialization of a Partition masker based on a DataFrame. - """ - + """Test the serialization of a Partition masker based on a DataFrame.""" X, _ = shap.datasets.california(n_points=500) # initialize partition masker @@ -90,9 +82,7 @@ def test_serialization_partion_masker_dataframe(): assert np.array_equal(original_partition_masker(mask, X[:1].values[0])[1], new_partition_masker(mask, X[:1].values[0])[1]) def test_serialization_partion_masker_numpy(): - """ Test the serialization of a Partition masker based on a numpy array. - """ - + """Test the serialization of a Partition masker based on a numpy array.""" X, _ = shap.datasets.california(n_points=500) X = X.values diff --git a/tests/maskers/test_text.py b/tests/maskers/test_text.py index 4be087778..547f35770 100644 --- a/tests/maskers/test_text.py +++ b/tests/maskers/test_text.py @@ -1,5 +1,4 @@ -""" This file contains tests for the Text masker. -""" +"""This file contains tests for the Text masker.""" import tempfile @@ -10,9 +9,7 @@ def test_method_token_segments_pretrained_tokenizer(): - """ Check that the Text masker produces the same segments as its non-fast pretrained tokenizer. - """ - + """Check that the Text masker produces the same segments as its non-fast pretrained tokenizer.""" AutoTokenizer = pytest.importorskip("transformers").AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased", use_fast=False) @@ -26,9 +23,7 @@ def test_method_token_segments_pretrained_tokenizer(): def test_method_token_segments_pretrained_tokenizer_fast(): - """ Check that the Text masker produces the same segments as its fast pretrained tokenizer. - """ - + """Check that the Text masker produces the same segments as its fast pretrained tokenizer.""" AutoTokenizer = pytest.importorskip("transformers").AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", use_fast=True) @@ -42,9 +37,7 @@ def test_method_token_segments_pretrained_tokenizer_fast(): def test_masker_call_pretrained_tokenizer(): - """ Check that the Text masker with a non-fast pretrained tokenizer masks correctly. - """ - + """Check that the Text masker with a non-fast pretrained tokenizer masks correctly.""" AutoTokenizer = pytest.importorskip("transformers").AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", use_fast=False) @@ -59,9 +52,7 @@ def test_masker_call_pretrained_tokenizer(): assert output_masked_text[0] == correct_masked_text def test_masker_call_pretrained_tokenizer_fast(): - """ Check that the Text masker with a fast pretrained tokenizer masks correctly. - """ - + """Check that the Text masker with a fast pretrained tokenizer masks correctly.""" AutoTokenizer = pytest.importorskip("transformers").AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", use_fast=True) @@ -78,9 +69,7 @@ def test_masker_call_pretrained_tokenizer_fast(): @pytest.mark.filterwarnings(r"ignore:Recommended. pip install sacremoses") def test_sentencepiece_tokenizer_output(): - """ Tests for output for sentencepiece tokenizers to not have '_' in output of masker when passed a mask of ones. - """ - + """Tests for output for sentencepiece tokenizers to not have '_' in output of masker when passed a mask of ones.""" AutoTokenizer = pytest.importorskip("transformers").AutoTokenizer pytest.importorskip("sentencepiece") @@ -96,9 +85,7 @@ def test_sentencepiece_tokenizer_output(): assert sentencepiece_tokenizer_output_processed[0][0] == expected_sentencepiece_tokenizer_output_processed def test_keep_prefix_suffix_tokenizer_parsing(): - """ Checks parsed keep prefix and keep suffix for different tokenizers. - """ - + """Checks parsed keep prefix and keep suffix for different tokenizers.""" AutoTokenizer = pytest.importorskip("transformers").AutoTokenizer tokenizer_mt = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-es") @@ -123,9 +110,7 @@ def test_keep_prefix_suffix_tokenizer_parsing(): def test_text_infill_with_collapse_mask_token(): - """ Tests for different text infilling output combinations with collapsing mask token. - """ - + """Tests for different text infilling output combinations with collapsing mask token.""" AutoTokenizer = pytest.importorskip("transformers").AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("gpt2") @@ -167,9 +152,7 @@ def test_text_infill_with_collapse_mask_token(): text_infilled_ex3_mist== expected_text_infilled_ex3 and text_infilled_ex4_mist == expected_text_infilled_ex4 def test_serialization_text_masker(): - """ Make sure text serialization works. - """ - + """Make sure text serialization works.""" AutoTokenizer = pytest.importorskip("transformers").AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased", use_fast=False) @@ -195,9 +178,7 @@ def test_serialization_text_masker(): assert original_masked_output == new_masked_output def test_serialization_text_masker_custom_mask(): - """ Make sure text serialization works with custom mask. - """ - + """Make sure text serialization works with custom mask.""" AutoTokenizer = pytest.importorskip("transformers").AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased", use_fast=True) @@ -222,9 +203,7 @@ def test_serialization_text_masker_custom_mask(): assert original_masked_output == new_masked_output def test_serialization_text_masker_collapse_mask_token(): - """ Make sure text serialization works with collapse mask token. - """ - + """Make sure text serialization works with collapse mask token.""" AutoTokenizer = pytest.importorskip("transformers").AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased", use_fast=True) diff --git a/tests/models/test_teacher_forcing_logits.py b/tests/models/test_teacher_forcing_logits.py index 1c2e55f31..8314bea77 100644 --- a/tests/models/test_teacher_forcing_logits.py +++ b/tests/models/test_teacher_forcing_logits.py @@ -1,5 +1,4 @@ -""" This file contains tests for the TeacherForcingLogits class. -""" +"""This file contains tests for the TeacherForcingLogits class.""" import numpy as np import pytest @@ -8,9 +7,7 @@ def test_method_get_teacher_forced_logits_for_encoder_decoder_model(): - """ Tests if get_teacher_forced_logits() works for encoder-decoder models. - """ - + """Tests if get_teacher_forced_logits() works for encoder-decoder models.""" transformers = pytest.importorskip("transformers") requests = pytest.importorskip("requests") @@ -32,9 +29,7 @@ def test_method_get_teacher_forced_logits_for_encoder_decoder_model(): assert not np.isnan(np.sum(logits)) def test_method_get_teacher_forced_logits_for_decoder_model(): - """ Tests if get_teacher_forced_logits() works for decoder only models. - """ - + """Tests if get_teacher_forced_logits() works for decoder only models.""" transformers = pytest.importorskip("transformers") requests = pytest.importorskip("requests") diff --git a/tests/models/test_text_generation.py b/tests/models/test_text_generation.py index 07dca4e9d..1524d99e8 100644 --- a/tests/models/test_text_generation.py +++ b/tests/models/test_text_generation.py @@ -1,5 +1,4 @@ -""" This file contains tests for the TextGeneration class. -""" +"""This file contains tests for the TextGeneration class.""" import sys @@ -10,10 +9,9 @@ @pytest.mark.skipif(sys.platform == 'win32', reason="Integer division bug in HuggingFace on Windows") def test_call_function_text_generation(): - """ Tests if target sentence from model and model wrapped in a function (mimics model agnostic scenario) - produces the same ids. + """Tests if target sentence from model and model wrapped in a function (mimics model agnostic scenario) + produces the same ids. """ - torch = pytest.importorskip("torch") transformers = pytest.importorskip("transformers") diff --git a/tests/plots/__init__.py b/tests/plots/__init__.py index 63062f7e2..5d3cbba2e 100644 --- a/tests/plots/__init__.py +++ b/tests/plots/__init__.py @@ -1,5 +1,4 @@ -""" -The plotting baseline folder is generated using the pytest-mpl plugin. +"""The plotting baseline folder is generated using the pytest-mpl plugin. If you have made changes to the plots, the baseline folder will need rebuilding before the tests can run successfully. Run the following in the root directory: diff --git a/tests/plots/conftest.py b/tests/plots/conftest.py index 5b513ea40..cef433202 100644 --- a/tests/plots/conftest.py +++ b/tests/plots/conftest.py @@ -13,8 +13,7 @@ def close_matplotlib_plots_after_tests(): @pytest.fixture() def explainer(): - """ A simple explainer to be used as a test fixture. - """ + """A simple explainer to be used as a test fixture.""" xgboost = pytest.importorskip('xgboost') # get a dataset on income prediction X, y = shap.datasets.adult() diff --git a/tests/plots/test_bar.py b/tests/plots/test_bar.py index 418a4bcd3..d93f480fe 100644 --- a/tests/plots/test_bar.py +++ b/tests/plots/test_bar.py @@ -1,5 +1,4 @@ -"""This file contains tests for the bar plot. -""" +"""This file contains tests for the bar plot.""" import matplotlib.pyplot as plt import numpy as np import pytest diff --git a/tests/plots/test_beeswarm.py b/tests/plots/test_beeswarm.py index dd297aa74..ece2693e8 100644 --- a/tests/plots/test_beeswarm.py +++ b/tests/plots/test_beeswarm.py @@ -7,8 +7,7 @@ def test_beeswarm_input_is_explanation(): - """Checks an error is raised if a non-Explanation object is passed as input. - """ + """Checks an error is raised if a non-Explanation object is passed as input.""" with pytest.raises( TypeError, match="beeswarm plot requires an `Explanation` object", diff --git a/tests/plots/test_decision.py b/tests/plots/test_decision.py index 237f648ee..c68d780f1 100644 --- a/tests/plots/test_decision.py +++ b/tests/plots/test_decision.py @@ -7,8 +7,7 @@ matplotlib.use('Agg') def test_random_decision(random_seed): - """ Make sure the decision plot does not crash on random data. - """ + """Make sure the decision plot does not crash on random data.""" rs = np.random.RandomState(random_seed) shap.decision_plot( 0, diff --git a/tests/plots/test_dependence.py b/tests/plots/test_dependence.py index 73210c87a..b578123bd 100644 --- a/tests/plots/test_dependence.py +++ b/tests/plots/test_dependence.py @@ -7,18 +7,15 @@ matplotlib.use('Agg') def test_random_dependence(): - """ Make sure a dependence plot does not crash. - """ + """Make sure a dependence plot does not crash.""" shap.dependence_plot(0, np.random.randn(20, 5), np.random.randn(20, 5), show=False) def test_random_dependence_no_interaction(): - """ Make sure a dependence plot does not crash when we are not showing interactions. - """ + """Make sure a dependence plot does not crash when we are not showing interactions.""" shap.dependence_plot(0, np.random.randn(20, 5), np.random.randn(20, 5), show=False, interaction_index=None) def test_dependence_use_line_collection_bug(): - """ Make sure a dependence plot does not crash. - """ + """Make sure a dependence plot does not crash.""" # GH 3368 sklearn = pytest.importorskip("sklearn") diff --git a/tests/plots/test_dependence_string_features.py b/tests/plots/test_dependence_string_features.py index f60d0ff05..3f2cf5b6c 100644 --- a/tests/plots/test_dependence_string_features.py +++ b/tests/plots/test_dependence_string_features.py @@ -8,8 +8,7 @@ def test_dependence_one_string_feature(): - """ Test the dependence plot with a string feature. - """ + """Test the dependence plot with a string feature.""" X = _create_sample_dataset(string_features={"Sex"}) shap.dependence_plot( @@ -22,8 +21,7 @@ def test_dependence_one_string_feature(): def test_dependence_two_string_features(): - """ Test the dependence plot with two string features. - """ + """Test the dependence plot with two string features.""" X = _create_sample_dataset(string_features={"Sex", "Blood group"}) shap.dependence_plot( @@ -36,8 +34,7 @@ def test_dependence_two_string_features(): def test_dependence_one_string_feature_no_interaction(): - """ Test the dependence plot with no interactions. - """ + """Test the dependence plot with no interactions.""" X = _create_sample_dataset(string_features={"Sex"}) shap.dependence_plot( @@ -50,8 +47,7 @@ def test_dependence_one_string_feature_no_interaction(): def test_dependence_one_string_feature_auto_interaction(): - """ Test the dependence plot with auto interaction detection. - """ + """Test the dependence plot with auto interaction detection.""" X = _create_sample_dataset(string_features={"Sex"}) shap.dependence_plot( @@ -64,8 +60,7 @@ def test_dependence_one_string_feature_auto_interaction(): def test_approximate_interactions(): - """ Test the approximate interaction detector. - """ + """Test the approximate interaction detector.""" X_no_string_features = _create_sample_dataset(string_features={}) X_one_string_feature = _create_sample_dataset(string_features={"Sex"}) X_two_string_features = _create_sample_dataset(string_features={"Sex", "Blood group"}) diff --git a/tests/plots/test_force.py b/tests/plots/test_force.py index 2ee643a15..5f2823c03 100644 --- a/tests/plots/test_force.py +++ b/tests/plots/test_force.py @@ -52,9 +52,7 @@ def test_verify_valid_cmap(cmap, exp_ctx): def test_random_force_plot_mpl_with_data(): - """ Test if force plot with matplotlib works. - """ - + """Test if force plot with matplotlib works.""" RandomForestRegressor = pytest.importorskip('sklearn.ensemble').RandomForestRegressor # train model @@ -72,9 +70,7 @@ def test_random_force_plot_mpl_with_data(): shap.force_plot([1, 1], shap_values, X.iloc[0, :], show=False) def test_random_force_plot_mpl_text_rotation_with_data(): - """ Test if force plot with matplotlib works when supplied with text_rotation. - """ - + """Test if force plot with matplotlib works when supplied with text_rotation.""" RandomForestRegressor = pytest.importorskip('sklearn.ensemble').RandomForestRegressor # train model diff --git a/tests/plots/test_heatmap.py b/tests/plots/test_heatmap.py index 43ae88e1d..0c3a41495 100644 --- a/tests/plots/test_heatmap.py +++ b/tests/plots/test_heatmap.py @@ -7,8 +7,7 @@ @pytest.mark.mpl_image_compare def test_heatmap(explainer): - """ Make sure the heatmap plot is unchanged. - """ + """Make sure the heatmap plot is unchanged.""" fig = plt.figure() shap_values = explainer(explainer.data) shap.plots.heatmap(shap_values, show=False) @@ -18,8 +17,7 @@ def test_heatmap(explainer): @pytest.mark.mpl_image_compare def test_heatmap_feature_order(explainer): - """ Make sure the heatmap plot is unchanged when we apply a feature ordering. - """ + """Make sure the heatmap plot is unchanged when we apply a feature ordering.""" fig = plt.figure() shap_values = explainer(explainer.data) shap.plots.heatmap(shap_values, max_display=5, diff --git a/tests/plots/test_image.py b/tests/plots/test_image.py index 4138abaab..ef3802a3d 100644 --- a/tests/plots/test_image.py +++ b/tests/plots/test_image.py @@ -7,24 +7,19 @@ def test_random_single_image(): - """ Just make sure the image_plot function doesn't crash. - """ - + """Just make sure the image_plot function doesn't crash.""" shap.image_plot(np.random.randn(3, 20, 20), np.random.randn(3, 20, 20), show=False) def test_random_multi_image(): - """ Just make sure the image_plot function doesn't crash. - """ - + """Just make sure the image_plot function doesn't crash.""" shap.image_plot([np.random.randn(3, 20, 20) for i in range(3)], np.random.randn(3, 20, 20), show=False) def test_image_to_text_single(): - """ Just make sure the image_to_text function doesn't crash. - """ + """Just make sure the image_to_text function doesn't crash.""" class MockImageExplanation: - """ Fake explanation object. - """ + """Fake explanation object.""" + def __init__(self, data, values, output_names): self.data = data self.values = values diff --git a/tests/plots/test_summary.py b/tests/plots/test_summary.py index d9d546a8f..b5fb7150d 100644 --- a/tests/plots/test_summary.py +++ b/tests/plots/test_summary.py @@ -7,8 +7,7 @@ @pytest.mark.mpl_image_compare def test_random_summary(): - """ Just make sure the summary_plot function doesn't crash. - """ + """Just make sure the summary_plot function doesn't crash.""" np.random.seed(0) fig = plt.figure() shap.summary_plot(np.random.randn(20, 5), show=False) @@ -18,8 +17,7 @@ def test_random_summary(): @pytest.mark.mpl_image_compare def test_random_summary_with_data(): - """ Just make sure the summary_plot function doesn't crash with data. - """ + """Just make sure the summary_plot function doesn't crash with data.""" np.random.seed(0) fig = plt.figure() shap.summary_plot(np.random.randn(20, 5), np.random.randn(20, 5), show=False) @@ -29,8 +27,7 @@ def test_random_summary_with_data(): @pytest.mark.mpl_image_compare def test_random_multi_class_summary(): - """ Check a multiclass run. - """ + """Check a multiclass run.""" np.random.seed(0) fig = plt.figure() shap.summary_plot([np.random.randn(20, 5) for i in range(3)], np.random.randn(20, 5), show=False) @@ -40,8 +37,8 @@ def test_random_multi_class_summary(): @pytest.mark.mpl_image_compare def test_random_multi_class_summary_legend_decimals(): - """ Check the functionality of printing the legend in the plot of a multiclass run when - all the SHAP values are smaller than 1. + """Check the functionality of printing the legend in the plot of a multiclass run when + all the SHAP values are smaller than 1. """ np.random.seed(0) fig = plt.figure() @@ -53,8 +50,8 @@ def test_random_multi_class_summary_legend_decimals(): @pytest.mark.mpl_image_compare def test_random_multi_class_summary_legend(): - """ Check the functionality of printing the legend in the plot of a multiclass run when - SHAP values are bigger than 1. + """Check the functionality of printing the legend in the plot of a multiclass run when + SHAP values are bigger than 1. """ np.random.seed(0) fig = plt.figure() @@ -66,8 +63,7 @@ def test_random_multi_class_summary_legend(): @pytest.mark.mpl_image_compare def test_random_summary_bar_with_data(): - """ Check a bar chart. - """ + """Check a bar chart.""" np.random.seed(0) fig = plt.figure() shap.summary_plot(np.random.randn(20, 5), np.random.randn(20, 5), plot_type="bar", show=False) @@ -77,8 +73,7 @@ def test_random_summary_bar_with_data(): @pytest.mark.mpl_image_compare def test_random_summary_dot_with_data(): - """ Check a dot chart. - """ + """Check a dot chart.""" np.random.seed(0) fig = plt.figure() shap.summary_plot(np.random.randn(20, 5), np.random.randn(20, 5), plot_type="dot", show=False) @@ -88,8 +83,7 @@ def test_random_summary_dot_with_data(): @pytest.mark.mpl_image_compare def test_random_summary_violin_with_data(): - """ Check a violin chart. - """ + """Check a violin chart.""" np.random.seed(0) fig = plt.figure() shap.summary_plot(np.random.randn(20, 5), np.random.randn(20, 5), plot_type="violin", show=False) @@ -99,8 +93,7 @@ def test_random_summary_violin_with_data(): @pytest.mark.mpl_image_compare def test_random_summary_layered_violin_with_data(): - """ Check a layered violin chart. - """ + """Check a layered violin chart.""" rs = np.random.RandomState(0) fig = plt.figure() shap_values = rs.randn(200, 5) @@ -117,8 +110,7 @@ def test_random_summary_layered_violin_with_data(): @pytest.mark.mpl_image_compare(tolerance=6) def test_random_summary_with_log_scale(): - """ Check a with a log scale. - """ + """Check a with a log scale.""" np.random.seed(0) fig = plt.figure() shap.summary_plot(np.random.randn(20, 5), use_log_scale=True, show=False) diff --git a/tests/plots/test_text.py b/tests/plots/test_text.py index e2c96f043..ef81d65b6 100644 --- a/tests/plots/test_text.py +++ b/tests/plots/test_text.py @@ -4,9 +4,7 @@ def test_single_text_to_text(): - """ Just make sure the test_plot function doesn't crash. - """ - + """Just make sure the test_plot function doesn't crash.""" test_values = np.array([ [10.61284012, 3.28389317], [-3.77245945, 10.76889759], diff --git a/tests/plots/test_waterfall.py b/tests/plots/test_waterfall.py index 312eb808a..faebae530 100644 --- a/tests/plots/test_waterfall.py +++ b/tests/plots/test_waterfall.py @@ -8,8 +8,7 @@ def test_waterfall_input_is_explanation(): - """Checks an error is raised if a non-Explanation object is passed as input. - """ + """Checks an error is raised if a non-Explanation object is passed as input.""" with pytest.raises( TypeError, match="waterfall plot requires an `Explanation` object", @@ -27,8 +26,7 @@ def test_waterfall_wrong_explanation_shape(explainer): @pytest.mark.mpl_image_compare(tolerance=3) def test_waterfall(explainer): - """ Test the new waterfall plot. - """ + """Test the new waterfall plot.""" fig = plt.figure() shap_values = explainer(explainer.data) shap.plots.waterfall(shap_values[0]) @@ -38,8 +36,7 @@ def test_waterfall(explainer): @pytest.mark.mpl_image_compare(tolerance=3) def test_waterfall_legacy(explainer): - """ Test the old waterfall plot. - """ + """Test the old waterfall plot.""" shap_values = explainer.shap_values(explainer.data) fig = plt.figure() shap.plots._waterfall.waterfall_legacy(explainer.expected_value, shap_values[0]) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index af5f8b2fa..623bb282b 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,5 +1,4 @@ -""" This file contains tests for the `shap.datasets` module. -""" +"""This file contains tests for the `shap.datasets` module.""" import pytest diff --git a/tests/test_explanation.py b/tests/test_explanation.py index 1e3ba0e99..8a3360c76 100644 --- a/tests/test_explanation.py +++ b/tests/test_explanation.py @@ -1,5 +1,4 @@ -"""This file contains tests for the `shap._explanation` module. -""" +"""This file contains tests for the `shap._explanation` module.""" import numpy as np import pytest @@ -29,8 +28,7 @@ def test_explanation_hstack(random_seed): def test_explanation_hstack_errors(random_seed): - """Checks that `hstack` throws errors on invalid input. - """ + """Checks that `hstack` throws errors on invalid input.""" # generate 2 Explanation objects for stacking rs = np.random.RandomState(random_seed) base_vals = np.ones(20) * 0.123