diff --git a/.travis.yml b/.travis.yml
index da3aa5d3..06ad0eb2 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -45,6 +45,9 @@ install:
   # Install other deps for the testing script below.
   - pip install pytest
 
+  # Dependencies only needed for testing.
+  - pip install lime
+
 script:
   # Build front-end.
   - pushd lit_nlp && yarn && yarn build && popd
diff --git a/environment.yml b/environment.yml
index 599a249d..53db0dcd 100644
--- a/environment.yml
+++ b/environment.yml
@@ -25,7 +25,6 @@ dependencies:
   - pip:
     - tensorflow
     - tensorflow-datasets
-    - lime
     - rouge-score
     - sacrebleu
     - umap-learn
diff --git a/lit_nlp/components/citrus/lime.py b/lit_nlp/components/citrus/lime.py
new file mode 100644
index 00000000..4b389e20
--- /dev/null
+++ b/lit_nlp/components/citrus/lime.py
@@ -0,0 +1,176 @@
+# Lint as: python3
+"""Local Interpretable Model-agnostic Explanations (LIME).
+
+LIME was proposed in the following paper:
+
+> "Why Should I Trust You?": Explaining the Predictions of Any Classifier
+> Marco Tulio Ribeiro, Sameer Singh, Carlos Guestrin
+> https://arxiv.org/abs/1602.04938
+
+LIME explains classifiers by returning a feature attribution score for each
+input feature. It works as follows:
+
+1) Sample perturbation masks. First the number of masked features is sampled
+   (uniform, at least 1), and then that many features are randomly chosen to
+   be masked out (without replacement).
+2) Get predictions from the model for those perturbations. Use these as labels.
+3) Fit a linear model to associate the input positions indicated by the binary
+   mask with the resulting predicted label.
+
+The resulting feature importance scores are the linear model coefficients for
+the requested output class or (in case of regression) the output score.
+
+This is a reimplementation of the original https://github.com/marcotcr/lime
+and is tested for compatibility. This version supports applying LIME to text
+input, including regression and binary classification, where the prediction
+function outputs a scalar for each input sentence.
+"""
+import functools
+from typing import Any, Callable, Iterable, Optional, Sequence
+from lit_nlp.components.citrus import helpers
+import numpy as np
+from sklearn import linear_model
+from sklearn import metrics
+
+DEFAULT_KERNEL_WIDTH = 25
+DEFAULT_MASK_TOKEN = '<unk>'
+DEFAULT_NUM_SAMPLES = 3000
+DEFAULT_SOLVER = 'cholesky'
+
+
+def sample_masks(num_samples: int, num_features: int, seed: int = None):
+  """Samples LIME masks with at least 1 position disabled per sampled mask.
+
+  The number of disabled features is sampled from a uniform distribution.
+
+  Args:
+    num_samples: The number of samples.
+    num_features: The number of features to sample a mask for. Typically this
+      is the number of tokens in the sentence.
+    seed: Set this to an integer to make the sampling deterministic.
+
+  Returns:
+    Masks [num_samples, num_features] indicating which features are
+    enabled (True) and which ones are disabled (False).
+  """
+  rng = np.random.RandomState(seed)
+  positions = np.tile(np.arange(num_features), (num_samples, 1))
+  permutation_fn = np.vectorize(rng.permutation, signature='(n)->(n)')
+  permutations = permutation_fn(positions)  # A shuffled range of positions.
+  num_disabled_features = rng.randint(1, num_features + 1, (num_samples, 1))
+  # For num_disabled_features[i] == 2, this will set indices 0 and 1 to False.
+  return permutations >= num_disabled_features
+
+
+def get_perturbations(tokens: Sequence[str],
+                      masks: np.ndarray,
+                      mask_token: str = '<unk>') -> Iterable[str]:
+  """Returns strings with the masked tokens replaced with `mask_token`."""
+  for mask in masks:
+    parts = [t if mask[i] else mask_token for i, t in enumerate(tokens)]
+    yield ' '.join(parts)
+
+
+def exponential_kernel(
+    distance: float, kernel_width: float = DEFAULT_KERNEL_WIDTH) -> np.ndarray:
+  """The exponential kernel."""
+  return np.sqrt(np.exp(-(distance**2) / kernel_width**2))
+
+
+def explain(
+    sentence: str,
+    predict_fn: Callable[[Iterable[str]], np.ndarray],
+    class_to_explain: Optional[int] = None,
+    num_samples: int = DEFAULT_NUM_SAMPLES,
+    tokenizer: Any = str.split,
+    mask_token: str = DEFAULT_MASK_TOKEN,
+    alpha: float = 1.0,
+    solver: str = DEFAULT_SOLVER,
+    kernel: Callable[..., np.ndarray] = exponential_kernel,
+    distance_fn: Callable[..., np.ndarray] = functools.partial(
+        metrics.pairwise.pairwise_distances, metric='cosine'),
+    distance_scale: float = 100.,
+    return_model: bool = False,
+    return_score: bool = False,
+    return_prediction: bool = False,
+    seed: Optional[int] = None,
+) -> helpers.PosthocExplanation:
+  """Returns the LIME explanation for a given sentence.
+
+  By default, this function returns an explanation object containing feature
+  importance scores and the intercept. Optionally, more information can be
+  returned, such as the linear model, the score of the model on the
+  perturbation set, and the prediction that the linear model makes on the
+  original sentence.
+
+  Args:
+    sentence: An input to be explained.
+    predict_fn: A prediction function that returns an array of outputs given a
+      list of inputs. The output shape is [len(inputs)] for regression and
+      binary classification (with scalar output), and
+      [len(inputs), num_classes] for multi-class classification.
+    class_to_explain: The class ID to explain in case of multi-class
+      classification, where `predict_fn` returns outputs with multiple
+      dimensions for each input. For example, use 2 to explain the third class
+      in 3-class classification. For regression and binary classification,
+      where `predict_fn` returns a scalar for each input, this does not need
+      to be set.
+    num_samples: The number of perturbations to sample.
+    tokenizer: A function that splits the input sentence into tokens.
+    mask_token: The token that is used for masking tokens, e.g., '<unk>'.
+    alpha: Regularization strength of the linear approximation model. See
+      `sklearn.linear_model.Ridge` for details.
+    solver: Solver to use in the linear approximation model. See
+      `sklearn.linear_model.Ridge` for details.
+    kernel: A kernel function to be used on the distance function. By default,
+      the exponential kernel (with kernel width DEFAULT_KERNEL_WIDTH) is used.
+    distance_fn: A distance function to use in range [0, 1]. Default: cosine.
+    distance_scale: A scalar factor multiplied with the distances before the
+      kernel is applied.
+    return_model: Returns the fitted linear model.
+    return_score: Returns the score of the linear model on the perturbations.
+      This is the R^2 of the linear model predictions w.r.t. their targets.
+    return_prediction: Returns the prediction of the linear model on the full
+      original sentence.
+    seed: Optional random seed to make the explanation deterministic.
+
+  Returns:
+    The explanation for the requested class.
+  """
+  # TODO(bastings): Provide sentence already tokenized to reduce split/join ops.
+  tokens = tokenizer(sentence)
+
+  masks = sample_masks(num_samples + 1, len(tokens), seed=seed)
+  assert masks.shape[0] == num_samples + 1, 'Expected num_samples + 1 masks.'
+  all_true_mask = np.ones_like(masks[0], dtype=np.bool)
+  masks[0] = all_true_mask  # First mask is the full sentence.
+
+  perturbations = list(get_perturbations(tokens, masks, mask_token))
+  outputs = predict_fn(perturbations)
+
+  if len(outputs.shape) > 1:
+    assert class_to_explain is not None, \
+        'class_to_explain needs to be set when `predict_fn` returns a 2D tensor'
+    outputs = outputs[:, class_to_explain]  # We are only interested in 1 class.
+
+  distances = distance_fn(all_true_mask.reshape(1, -1), masks).flatten()
+  distances = distance_scale * distances
+  distances = kernel(distances)
+
+  # Fit a linear model for the requested output class.
+  model = linear_model.Ridge(
+      alpha=alpha, solver=solver, random_state=seed).fit(
+          masks, outputs, sample_weight=distances)
+
+  explanation = helpers.PosthocExplanation(
+      feature_importance=model.coef_, intercept=model.intercept_)
+
+  if return_model:
+    explanation.model = model
+
+  if return_score:
+    explanation.score = model.score(masks, outputs)
+
+  if return_prediction:
+    explanation.prediction = model.predict(all_true_mask.reshape(1, -1))
+
+  return explanation
diff --git a/lit_nlp/components/citrus/lime_test.py b/lit_nlp/components/citrus/lime_test.py
new file mode 100644
index 00000000..a27bb936
--- /dev/null
+++ b/lit_nlp/components/citrus/lime_test.py
@@ -0,0 +1,312 @@
+# Lint as: python3
+import collections
+import functools
+from absl.testing import absltest
+from absl.testing import parameterized
+import lime as original_lime
+from lit_nlp.components.citrus import lime
+from lit_nlp.components.citrus import utils
+import numpy as np
+from scipy import special
+from scipy import stats
+
+
+class LimeTest(parameterized.TestCase):
+
+  def test_sample_masks_returns_correct_shape_and_type(self):
+    """Tests if LIME mask samples have the right shape and type."""
+    num_samples = 2
+    num_features = 3
+    masks = lime.sample_masks(num_samples, num_features, seed=0)
+    self.assertEqual(np.dtype('bool'), masks.dtype)
+    self.assertEqual((num_samples, num_features), masks.shape)
+
+  def test_sample_masks_contains_extreme_samples(self):
+    """Tests if the masks contain extreme samples (1 or all features)."""
+    num_samples = 1000
+    num_features = 10
+    masks = lime.sample_masks(num_samples, num_features, seed=0)
+    num_disabled = (~masks).sum(axis=-1)
+    self.assertEqual(1, min(num_disabled))
+    self.assertEqual(num_features, max(num_disabled))
+
+  def test_sample_masks_returns_uniformly_distributed_masks(self):
+    """Tests if the masked positions are uniformly distributed."""
+    num_samples = 10000
+    num_features = 100
+    masks = lime.sample_masks(num_samples, num_features, seed=0)
+    # The mean should be ~0.5, but this is also true when normally distributed.
+    np.testing.assert_almost_equal(masks.mean(), 0.5, decimal=2)
+    # We should see each possible masked count approx. the same number of times.
+    # We check this by looking at the entropy which should be around 1.0.
+    counter = collections.Counter(masks.sum(axis=-1))
+    entropy = stats.entropy(list(counter.values()), base=num_features)
+    np.testing.assert_almost_equal(entropy, 1.0, decimal=2)
+
+  def test_get_perturbations_returns_correctly_masked_string(self):
+    """Tests obtaining perturbations from tokens and a mask."""
+    sentence = 'It is a great movie but also somewhat bad .'
+    tokens = sentence.split()
+    # We create a mock mask with False for tokens with an 'a', True otherwise.
+    masks = np.array([[False if 'a' in token else True for token in tokens]])
+    perturbations = list(lime.get_perturbations(tokens, masks, mask_token='_'))
+    expected = 'It is _ _ movie but _ _ _ .'
+    self.assertEqual(expected, perturbations[0])
+
+  @parameterized.named_parameters(
+      {
+          'testcase_name': 'is_one_for_zero_distance',
+          'distance': 0.,
+          'kernel_width': 10,
+          'expected': 1.,
+      }, {
+          'testcase_name': 'is_zero_for_exp_kernel_width_distance',
+          'distance': np.exp(10),
+          'kernel_width': 10,
+          'expected': 0.,
+      })
+  def test_exponential_kernel(self, distance, kernel_width, expected):
+    """Tests a few known exponential kernel results."""
+    result = lime.exponential_kernel(distance, kernel_width)
+    np.testing.assert_almost_equal(expected, result)
+
+  @parameterized.named_parameters(
+      {
+          'testcase_name': 'correctly_identifies_important_tokens_for_1d_input',
+          'sentence': 'It is a great movie but also somewhat bad .',
+          'num_samples': 1000,
+          'positive_token': 'great',
+          'negative_token': 'bad',
+          'num_classes': 1,
+          'class_to_explain': None,
+      }, {
+          'testcase_name': 'correctly_identifies_important_tokens_for_2d_input',
+          'sentence': 'It is a great movie but also somewhat bad .',
+          'num_samples': 1000,
+          'positive_token': 'great',
+          'negative_token': 'bad',
+          'num_classes': 2,
+          'class_to_explain': 1,
+      }, {
+          'testcase_name': 'correctly_identifies_important_tokens_for_3d_input',
+          'sentence': 'It is a great movie but also somewhat bad .',
+          'num_samples': 1000,
+          'positive_token': 'great',
+          'negative_token': 'bad',
+          'num_classes': 3,
+          'class_to_explain': 2,
+      })
+  def test_explain(self, sentence, num_samples, positive_token, negative_token,
+                   num_classes, class_to_explain):
+    """Tests explaining text classifiers with various output dimensions."""
+
+    def _predict_fn(sentences):
+      """Mock prediction function."""
+      rs = np.random.RandomState(seed=0)
+      predictions = []
+      for sentence in sentences:
+        probs = rs.uniform(0., 1., num_classes)
+        # To check if LIME finds the right positive/negative correlations.
+        if negative_token in sentence:
+          probs[class_to_explain] = probs[class_to_explain] - 1.
+        if positive_token in sentence:
+          probs[class_to_explain] = probs[class_to_explain] + 1.
+        predictions.append(probs)
+
+      predictions = np.stack(predictions, axis=0)
+      if num_classes == 1:
+        return np.squeeze(special.expit(predictions), -1)
+      else:
+        return special.softmax(predictions, axis=-1)
+
+    explanation = lime.explain(
+        sentence,
+        _predict_fn,
+        class_to_explain=class_to_explain,
+        num_samples=num_samples,
+        tokenizer=str.split)
+
+    self.assertLen(explanation.feature_importance, len(sentence.split()))
+
+    # The positive word should have the highest attribution score.
+    positive_token_idx = sentence.split().index(positive_token)
+    self.assertEqual(positive_token_idx,
+                     np.argmax(explanation.feature_importance))
+
+    # The negative word should have the lowest attribution score.
+    negative_token_idx = sentence.split().index(negative_token)
+    self.assertEqual(negative_token_idx,
+                     np.argmin(explanation.feature_importance))
+
+  @parameterized.named_parameters({
+      'testcase_name': 'correctly_identifies_important_tokens_for_regression',
+      'sentence': 'It is a great movie but also somewhat bad .',
+      'num_samples': 1000,
+      'positive_token': 'great',
+      'negative_token': 'bad',
+  })
+  def test_explain_regression(self, sentence, num_samples, positive_token,
+                              negative_token):
+    """Tests explaining a regression model with scalar outputs."""
+
+    def _predict_fn(sentences):
+      """Mock prediction function."""
+      rs = np.random.RandomState(seed=0)
+      predictions = []
+      for sentence in sentences:
+        output = rs.uniform(-2., 2.)
+        # To check if LIME finds the right positive/negative correlations.
+        if negative_token in sentence:
+          output -= rs.uniform(0., 2.)
+        if positive_token in sentence:
+          output += rs.uniform(0., 2.)
+        predictions.append(output)
+
+      predictions = np.stack(predictions, axis=0)
+      return predictions
+
+    explanation = lime.explain(
+        sentence, _predict_fn, num_samples=num_samples, tokenizer=str.split)
+
+    self.assertLen(explanation.feature_importance, len(sentence.split()))
+
+    # The positive word should have the highest attribution score.
+    positive_token_idx = sentence.split().index(positive_token)
+    self.assertEqual(positive_token_idx,
+                     np.argmax(explanation.feature_importance))
+
+    # The negative word should have the lowest attribution score.
+    negative_token_idx = sentence.split().index(negative_token)
+    self.assertEqual(negative_token_idx,
+                     np.argmin(explanation.feature_importance))
+
+  def test_explain_returns_explanation_with_intercept(self):
+    """Tests if the explanation contains an intercept value."""
+
+    def _predict_fn(sentences):
+      return np.random.uniform(0., 1., [len(list(sentences)), 2])
+
+    explanation = lime.explain('Test sentence', _predict_fn, 1, num_samples=5)
+    self.assertNotEqual(explanation.intercept, 0.)
+
+  def test_explain_returns_explanation_with_model(self):
+    """Tests if the explanation contains the model."""
+
+    def _predict_fn(sentences):
+      return np.random.uniform(0., 1., [len(list(sentences)), 2])
+
+    explanation = lime.explain(
+        'Test sentence',
+        _predict_fn,
+        class_to_explain=1,
+        num_samples=5,
+        return_model=True)
+    self.assertIsNotNone(explanation.model)
+
+  def test_explain_returns_explanation_with_score(self):
+    """Tests if the explanation contains a linear model score."""
+
+    def _predict_fn(sentences):
+      return np.random.uniform(0., 1., [len(list(sentences)), 2])
+
+    explanation = lime.explain(
+        'Test sentence',
+        _predict_fn,
+        class_to_explain=1,
+        num_samples=5,
+        return_score=True)
+    self.assertIsNotNone(explanation.score)
+
+  def test_explain_returns_explanation_with_prediction(self):
+    """Tests if the explanation contains a prediction."""
+
+    def _predict_fn(sentences):
+      return np.random.uniform(0., 1., [len(list(sentences)), 2])
+
+    explanation = lime.explain(
+        'Test sentence',
+        _predict_fn,
+        class_to_explain=1,
+        num_samples=5,
+        return_prediction=True)
+    self.assertIsNotNone(explanation.prediction)
+
+  @parameterized.named_parameters(
+      {
+          'testcase_name': 'for_2d_input',
+          'sentence': ' '.join(list('abcdefghijklmnopqrstuvwxyz')),
+          'num_samples': 5000,
+          'num_classes': 2,
+          'class_to_explain': 1,
+      }, {
+          'testcase_name': 'for_3d_input',
+          'sentence': ' '.join(list('abcdefghijklmnopqrstuvwxyz')),
+          'num_samples': 5000,
+          'num_classes': 3,
+          'class_to_explain': 2,
+      })
+  def test_explain_matches_original_lime(self, sentence, num_samples,
+                                         num_classes, class_to_explain):
+    """Tests if Citrus LIME matches the original implementation."""
+    # Assign some weight to each token a-z.
+    # Each token contributes positively/negatively to the prediction.
+    rs = np.random.RandomState(seed=0)
+    token_weights = {token: rs.normal() for token in sentence.split()}
+    token_weights[lime.DEFAULT_MASK_TOKEN] = 0.
+
+    def _predict_fn(sentences):
+      """Mock prediction function."""
+      rs = np.random.RandomState(seed=0)
+      predictions = []
+      for sentence in sentences:
+        probs = rs.normal(0., 0.1, size=num_classes)
+        # To check if LIME finds the right positive/negative correlations.
+        for token in sentence.split():
+          probs[class_to_explain] += token_weights[token]
+        predictions.append(probs)
+      return np.stack(predictions, axis=0)
+
+    # Explain the prediction using Citrus LIME.
+    explanation = lime.explain(
+        sentence,
+        _predict_fn,
+        class_to_explain=class_to_explain,
+        num_samples=num_samples,
+        tokenizer=str.split,
+        mask_token=lime.DEFAULT_MASK_TOKEN,
+        kernel=functools.partial(
+            lime.exponential_kernel, kernel_width=lime.DEFAULT_KERNEL_WIDTH))
+    scores = explanation.feature_importance  # [seq_len]
+    scores = utils.normalize_scores(scores, make_positive=False)
+
+    # Explain the prediction using original LIME.
+    original_lime_explainer = original_lime.lime_text.LimeTextExplainer(
+        class_names=map(str, np.arange(num_classes)),
+        mask_string=lime.DEFAULT_MASK_TOKEN,
+        kernel_width=lime.DEFAULT_KERNEL_WIDTH,
+        split_expression=str.split,
+        bow=False)
+    num_features = len(sentence.split())
+    original_explanation = original_lime_explainer.explain_instance(
+        sentence,
+        _predict_fn,
+        labels=(class_to_explain,),
+        num_features=num_features,
+        num_samples=num_samples)
+
+    # original_explanation.local_exp is a dict that has a key class_to_explain,
+    # which gives a sequence of (index, score) pairs.
+    # We convert it to an array [seq_len] with a score per position.
+    original_scores = np.zeros(num_features)
+    for index, score in original_explanation.local_exp[class_to_explain]:
+      original_scores[index] = score
+    original_scores = utils.normalize_scores(
+        original_scores, make_positive=False)
+
+    # Test that Citrus LIME and original LIME match.
+    np.testing.assert_allclose(scores, original_scores, atol=0.01)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/lit_nlp/components/lime_explainer.py b/lit_nlp/components/lime_explainer.py
index 0ef32092..97094539 100644
--- a/lit_nlp/components/lime_explainer.py
+++ b/lit_nlp/components/lime_explainer.py
@@ -16,36 +16,46 @@
 """Gradient-based attribution."""
 
 import copy
-from typing import cast, Any, List, Text, Optional
+import functools
+from typing import Any, Iterable, List, Optional
 
 from absl import logging
-from lime import lime_text
 from lit_nlp.api import components as lit_components
 from lit_nlp.api import dataset as lit_dataset
 from lit_nlp.api import dtypes
 from lit_nlp.api import model as lit_model
 from lit_nlp.api import types
+from lit_nlp.components.citrus import lime
+from lit_nlp.components.citrus import utils as citrus_util
 from lit_nlp.lib import utils
+
 import numpy as np
 
 JsonDict = types.JsonDict
 Spec = types.Spec
 
 
-def new_example(original_example: JsonDict, field: Text, new_value: Any):
+def new_example(original_example: JsonDict, field: str, new_value: Any):
   """Deep copies the example and replaces `field` with `new_value`."""
   example = copy.deepcopy(original_example)
   example[field] = new_value
   return example
 
 
-def explanation_to_array(explanation: Any):
-  """Given a LIME explanation object, return a numpy array with scores."""
-  # local_exp is a List[(word_position, score)]. We need to sort it.
-  scores = sorted(explanation.local_exp[1])  # Puts it back in word order.
-  scores = np.array([v for k, v in scores])
-  scores = scores / np.abs(scores).sum()
-  return scores
+def _predict_fn(strings: Iterable[str], model: Any, original_example: JsonDict,
+                text_key: str, pred_key: str):
+  """Given raw strings, return model outputs. Used by `lime.explain`."""
+  # Prepare example objects to be fed to the model for each sentence/string.
+  input_examples = [new_example(original_example, text_key, s) for s in strings]
+
+  # Get model predictions for the examples.
+  model_outputs = model.predict(input_examples)
+  outputs = np.array([output[pred_key] for output in model_outputs])
+  # Make outputs 1D in case of regression or binary classification.
+  if outputs.ndim == 2 and outputs.shape[1] == 1:
+    outputs = np.squeeze(outputs, axis=1)
+  # [len(strings)] or [len(strings), num_labels].
+  return outputs
 
 
 class LIME(lit_components.Interpreter):
@@ -64,6 +74,8 @@ def run(
       kernel_width: int = 25,  # TODO(lit-dev): make configurable in UI.
       mask_string: str = '[MASK]',  # TODO(lit-dev): make configurable in UI.
      num_samples: int = 256,  # TODO(lit-dev): make configurable in UI.
+      class_to_explain: Optional[int] = 1,  # TODO(lit-dev): b/173469699.
+      seed: Optional[int] = None,  # TODO(lit-dev): make configurable in UI.
   ) -> Optional[List[JsonDict]]:
     """Run this component, given a model and input(s)."""
 
@@ -77,54 +89,44 @@ def run(
     logging.info('Found text fields for LIME attribution: %s', str(text_keys))
 
     # Find the key of output probabilities field(s).
-    pred_keys = utils.find_spec_keys(model.output_spec(), types.MulticlassPreds)
+    pred_keys = utils.find_spec_keys(
+        model.output_spec(), (types.MulticlassPreds, types.RegressionScore))
     if not pred_keys:
-      logging.warning('LIME did not find a multi-class predictions field.')
+      logging.warning('LIME did not find any supported output fields.')
       return None
+
     pred_key = pred_keys[0]  # TODO(lit-dev): configure which prob field to use.
-    pred_spec = cast(types.MulticlassPreds, model.output_spec()[pred_key])
-    label_names = pred_spec.vocab
-
-    # Create a LIME text explainer instance.
-    explainer = lime_text.LimeTextExplainer(
-        class_names=label_names,
-        split_expression=str.split,
-        kernel_width=kernel_width,
-        mask_string=mask_string,  # This is the string used to mask words.
-        bow=False)  # bow=False masks inputs, instead of deleting them entirely.
-
     all_results = []
 
     # Explain each input.
     for input_ in inputs:
       # Dict[field name -> interpretations]
       result = {}
+      predict_fn = functools.partial(
+          _predict_fn, model=model, original_example=input_, pred_key=pred_key)
 
       # Explain each text segment in the input, keeping the others constant.
       for text_key in text_keys:
         input_string = input_[text_key]
         logging.info('Explaining: %s', input_string)
-        # Use the number of words as the number of features.
-        num_features = len(input_string.split())
-
-        def _predict_proba(strings: List[Text]):
-          """Given raw strings, return probabilities. Used by `explainer`."""
-          input_examples = [new_example(input_, text_key, s) for s in strings]
-          model_outputs = model.predict(input_examples)
-          probs = np.array([output[pred_key] for output in model_outputs])
-          return probs  # [len(strings), num_labels]
-
         # Perturbs the input string, gets model predictions, fits linear model.
-        explanation = explainer.explain_instance(
-            input_string,
-            _predict_proba,
-            num_features=num_features,
-            num_samples=num_samples)
+        explanation = lime.explain(
+            sentence=input_string,
+            predict_fn=functools.partial(predict_fn, text_key=text_key),
+            # `class_to_explain` is ignored when predict_fn output is a scalar.
+            class_to_explain=class_to_explain,  # Index of the class to explain.
+            num_samples=num_samples,
+            tokenizer=str.split,
+            mask_token=mask_string,
+            kernel=functools.partial(
+                lime.exponential_kernel, kernel_width=kernel_width),
+            seed=seed)
 
         # Turn the LIME explanation into a list following original word order.
-        scores = explanation_to_array(explanation)
+        scores = explanation.feature_importance
+        # TODO(lit-dev): Move score normalization to the UI.
+        scores = citrus_util.normalize_scores(scores)
         result[text_key] = dtypes.SalienceMap(input_string.split(), scores)
 
       all_results.append(result)
diff --git a/pip_package/setup.py b/pip_package/setup.py
index eac8c911..e6286a07 100644
--- a/pip_package/setup.py
+++ b/pip_package/setup.py
@@ -27,7 +27,6 @@
     "scipy",
     "pandas",
     "scikit-learn",
-    "lime",
     "sacrebleu",
     "umap-learn",
     "Werkzeug",
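Usage sketch (illustrative, appended outside the diff): the new citrus `lime.explain` API added above can also be called directly, along the lines below; the toy `predict_fn` is hypothetical and stands in for a real model's prediction function, and the parameter values are only examples.

    import numpy as np
    from lit_nlp.components.citrus import lime

    def predict_fn(sentences):
      # Toy regression-style scorer: +1 if 'great' appears, -1 if 'bad' appears.
      # Returns one scalar per sentence, so class_to_explain is not needed.
      return np.array([('great' in s) - ('bad' in s) for s in sentences],
                      dtype=np.float32)

    explanation = lime.explain(
        'It is a great movie but also somewhat bad .',
        predict_fn,
        num_samples=256,
        tokenizer=str.split,
        seed=0)
    print(explanation.feature_importance)  # One attribution score per token.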