Fix issue that prevented clustering of LIME results from working.
The LIME interpreter's result is structured differently from the results of LIT's gradient-based interpreters. This commit aligns the handling of both result structures.

PiperOrigin-RevId: 426395387
eberts-google authored and LIT team committed Feb 4, 2022
1 parent 49faa00 commit e35d8d8
Showing 2 changed files with 134 additions and 119 deletions.
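
For orientation, a minimal sketch (not part of the commit) of the structural mismatch the diffs below address. The field names 'input_embs_grad' and 'segment' are taken from the tests in this commit; the TokenSalience stand-in and the toy scores are illustrative assumptions.

from typing import List, NamedTuple


class TokenSalience(NamedTuple):
  # Stand-in for LIT's per-example token salience result (illustrative only).
  tokens: List[str]
  salience: List[float]


# Gradient-based interpreters key each per-example result by a TokenGradients
# field from the model's output spec:
grad_result = {'input_embs_grad': TokenSalience(['a', 'b'], [0.1, 0.9])}

# LIME keys its result by the input text field instead:
lime_result = {'segment': TokenSalience(['a', 'b'], [0.2, 0.8])}

# Before this commit, the clustering component discovered fields via
# model.output_spec(), so LIME's 'segment' key was never found. After the
# commit, the fields come from the interpreter output itself, which covers
# both structures:
salience_fields = list(lime_result.keys())  # -> ['segment']
print(salience_fields)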
58 changes: 20 additions & 38 deletions lit_nlp/components/salience_clustering.py
@@ -15,13 +15,12 @@
# Lint as: python3
"""kmeans clustering of salience weights."""

from typing import Dict, List, Optional, Sequence, Text, Tuple
from typing import Dict, List, Optional, Sequence, Tuple

from lit_nlp.api import components as lit_components
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import model as lit_model
from lit_nlp.api import types
from lit_nlp.lib import utils
import numpy as np
from sklearn import cluster

@@ -45,17 +44,6 @@ def __init__(self, salience_mappers: Dict[str, lit_components.Interpreter]):
self.salience_mappers = salience_mappers
self.kmeans = {}

def find_fields(self, output_spec: Spec) -> List[Text]:
# Find TokenGradients fields
grad_fields = utils.find_spec_keys(output_spec, types.TokenGradients)

# Check that these are aligned to Tokens fields
for f in grad_fields:
tokens_field = output_spec[f].align # pytype: disable=attribute-error
assert tokens_field in output_spec
assert isinstance(output_spec[tokens_field], types.Tokens)
return grad_fields

def _build_vocab(
self,
token_saliencies: List[JsonDict]) -> Tuple[Dict[str, int], List[str]]:
@@ -66,7 +54,7 @@ def _build_vocab(
depends on the order of the tokens in the input.
Args:
token_saliencies: List of mappings from gradient field to TokenSaliency
token_saliencies: List of mappings from salience field to TokenSaliency
objects. This is the result of a post-hoc explanation method, such as
gradient l2.
@@ -119,20 +107,20 @@ def _compute_fixed_length_representation(
absolute value is largest.
Args:
token_saliencies: List of mappings from gradient field to TokenSaliency
token_saliencies: List of mappings from salience field to TokenSaliency
objects. This is the result of a post-hoc explanation method, such as
gradient l2.
vocab_lookup: Mapping from word type to its index in the vocabulary.
Returns:
List of one mapping per example. Every element maps a gradient field to
List of one mapping per example. Every element maps a salience field to
its fixed-length representation.
"""
representations = []
for instance in token_saliencies:
per_field_results = {}

for grad_field, token_salience in instance.items():
for salience_field, token_salience in instance.items():
token_weights = {}

for token, score in zip(token_salience.tokens, token_salience.salience):
@@ -144,7 +132,7 @@ def _compute_fixed_length_representation(
# Normalize to unit length.
representation = np.asarray(representation) / np.linalg.norm(
representation)
per_field_results[grad_field] = representation
per_field_results[salience_field] = representation
representations.append(per_field_results)
return representations

@@ -187,9 +175,6 @@ def run_with_metadata(
config = config or {}
# If no specific inputs provided, use the entire dataset.
inputs_to_use = indexed_inputs or dataset.examples

# Find gradient fields to interpret
grad_fields = self.find_fields(model.output_spec())
token_saliencies = self.salience_mappers[
config[SALIENCE_MAPPER_KEY]].run_with_metadata(inputs_to_use, model,
dataset, model_outputs,
@@ -198,47 +183,44 @@
if not token_saliencies:
return None

salience_fields = list(token_saliencies[0].keys())
vocab_lookup, vocab = self._build_vocab(token_saliencies)
representations = self._compute_fixed_length_representation(
token_saliencies, vocab_lookup)

cluster_ids = {}
grad_field_to_representations = {}
grad_field_to_top_tokens = {}
salience_field_to_representations = {}
salience_field_to_top_tokens = {}

for grad_field in grad_fields:
for salience_field in salience_fields:
weight_matrix = np.vstack(
representation[grad_field] for representation in representations)
representation[salience_field] for representation in representations)
n_clusters = int(
config.get(N_CLUSTERS_KEY,
self.config_spec()[N_CLUSTERS_KEY].default))
self.kmeans[grad_field] = cluster.KMeans(n_clusters=n_clusters)
cluster_ids[grad_field] = self.kmeans[grad_field].fit_predict(
self.kmeans[salience_field] = cluster.KMeans(n_clusters=n_clusters)
cluster_ids[salience_field] = self.kmeans[salience_field].fit_predict(
weight_matrix).tolist()
grad_field_to_representations[grad_field] = weight_matrix
grad_field_to_top_tokens[grad_field] = []
salience_field_to_representations[salience_field] = weight_matrix
salience_field_to_top_tokens[salience_field] = []

for cluster_id in range(n_clusters):
# <float32>[vocab size]
mean_weight_matrix = weight_matrix[np.asarray(cluster_ids[grad_field])
== cluster_id].mean(axis=0)
mean_weight_matrix = weight_matrix[np.asarray(
cluster_ids[salience_field]) == cluster_id].mean(axis=0)
top_indices = (
mean_weight_matrix.argsort()[::-1][:int(
config.get(N_TOP_TOKENS_KEY,
self.config_spec()[N_TOP_TOKENS_KEY].default))])
top_tokens = [(vocab[i], mean_weight_matrix[i]) for i in top_indices]
grad_field_to_top_tokens[grad_field].append(top_tokens)
salience_field_to_top_tokens[salience_field].append(top_tokens)

return {
CLUSTER_ID_KEY: cluster_ids,
REPRESENTATION_KEY: grad_field_to_representations,
TOP_TOKEN_KEY: grad_field_to_top_tokens,
REPRESENTATION_KEY: salience_field_to_representations,
TOP_TOKEN_KEY: salience_field_to_top_tokens,
}

def is_compatible(self, model: lit_model.Model):
compatible_fields = self.find_fields(model.output_spec())
return len(compatible_fields)

def config_spec(self) -> types.Spec:
return {
SALIENCE_MAPPER_KEY:
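
Before the test diff, a compact sketch (toy data, not the commit's code) of the pipeline implemented above: each example becomes a fixed-length bag-of-words vector that keeps the salience score with the largest absolute value per token type, is L2-normalized, and is then clustered with k-means; each cluster is summarized by its highest-mean-weight tokens. The vocabulary and scores below are assumptions; only numpy and scikit-learn are required.

import numpy as np
from sklearn import cluster

vocab = ['a', 'b', 'e', 'f']
vocab_lookup = {t: i for i, t in enumerate(vocab)}
examples = [(['a', 'b', 'a'], [0.2, -0.5, -0.9]),
            (['a', 'b'], [0.3, -0.4]),
            (['e', 'f', 'e'], [0.8, 0.1, 0.6]),
            (['e', 'f'], [0.7, 0.2])]

rows = []
for tokens, salience in examples:
  # Keep the score whose absolute value is largest for repeated tokens.
  token_weights = {}
  for token, score in zip(tokens, salience):
    if abs(score) > abs(token_weights.get(token, 0.0)):
      token_weights[token] = score
  row = np.zeros(len(vocab))
  for token, score in token_weights.items():
    row[vocab_lookup[token]] = score
  # Normalize to unit length.
  rows.append(row / np.linalg.norm(row))

weight_matrix = np.vstack(rows)
cluster_ids = cluster.KMeans(n_clusters=2).fit_predict(weight_matrix)
for cluster_id in range(2):
  # Mean representation of this cluster's examples, one slot per vocab token.
  mean_weights = weight_matrix[cluster_ids == cluster_id].mean(axis=0)
  top_indices = mean_weights.argsort()[::-1][:2]
  print(cluster_id, [(vocab[i], mean_weights[i]) for i in top_indices])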
195 changes: 114 additions & 81 deletions lit_nlp/components/salience_clustering_test.py
@@ -16,23 +16,95 @@
"""Tests for lit_nlp.components.gradient_maps."""

from absl.testing import absltest
from absl.testing import parameterized
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import dtypes
from lit_nlp.components import gradient_maps
from lit_nlp.components import lime_explainer
from lit_nlp.components import salience_clustering
from lit_nlp.lib import testing_utils
import numpy as np


class SalienceClusteringTest(absltest.TestCase):
class SalienceClusteringTest(parameterized.TestCase):

def setUp(self):
super(SalienceClusteringTest, self).setUp()
self.salience_mappers = {
'Grad L2 Norm': gradient_maps.GradientNorm(),
'Grad ⋅ Input': gradient_maps.GradientDotInput()
'Grad ⋅ Input': gradient_maps.GradientDotInput(),
'LIME': lime_explainer.LIME()
}

def _call_classification_model_on_standard_input(self, config, grad_key):
inputs = [
{
'data': {
'segment': 'a b c d'
}
},
{
'data': {
'segment': 'a b c d'
}
},
{
'data': {
'segment': 'e f e f'
}
},
{
'data': {
'segment': 'e f e f'
}
},
{
'data': {
'segment': 'e f e f'
}
},
]
model = testing_utils.TestModelClassification()
dataset = lit_dataset.Dataset(None, None)

model_outputs = [{
grad_key:
np.array([[0, 0, 1, 1], [0, 1, 0, 0], [1, 1, 1, 1], [1, 0, 1, 1]]),
'tokens': ['a', 'b', 'c', 'd'],
'grad_class':
'1'
}, {
grad_key:
np.array([[0, 0, 1, 1], [0, 1, 0, 0], [1, 1, 1, 1], [1, 0, 1, 1]]),
'tokens': ['a', 'b', 'c', 'd'],
'grad_class':
'1'
}, {
grad_key:
np.array([[1, 1, 1, 1], [1, 0, 1, 1], [1, 1, 1, 1], [0, 0, 1, 1]]),
'tokens': ['e', 'f', 'e', 'g'],
'grad_class':
'1'
}, {
grad_key:
np.array([[1, 1, 1, 1], [1, 0, 1, 1], [1, 1, 1, 1], [0, 0, 1, 1]]),
'tokens': ['e', 'f', 'e', 'g'],
'grad_class':
'1'
}, {
grad_key:
np.array([[1, 1, 1, 1], [1, 0, 1, 1], [1, 1, 1, 1], [0, 0, 1, 1]]),
'tokens': ['e', 'f', 'e', 'g'],
'grad_class':
'1'
}]

clustering_component = salience_clustering.SalienceClustering(
self.salience_mappers)
result = clustering_component.run_with_metadata(inputs, model, dataset,
model_outputs, config)
return result, clustering_component

def test_build_vocab(self):
token_saliencies = [
{
@@ -96,101 +168,51 @@ def test_convert_to_bow_vector(self):
]
np.testing.assert_equal(expected, representations)

def test_clustering(self):
inputs = [
{
'data': {
'segment': 'a b c d'
}
},
{
'data': {
'segment': 'a b c d'
}
},
{
'data': {
'segment': 'e f e f'
}
},
{
'data': {
'segment': 'e f e f'
}
},
{
'data': {
'segment': 'e f e f'
}
},
]
model = testing_utils.TestModelClassification()
dataset = lit_dataset.Dataset(None, None)
@parameterized.named_parameters(
('lit_internal_salience', 'Grad L2 Norm', 'input_embs_grad'),
('lime', 'LIME', 'segment'),
)
def test_clustering(self, salience_mapper, grad_key):
"""Tests clustering on LIT-internal gradient methods."""
config = {
salience_clustering.SALIENCE_MAPPER_KEY: 'Grad L2 Norm',
salience_clustering.SALIENCE_MAPPER_KEY: salience_mapper,
salience_clustering.N_CLUSTERS_KEY: 2,
salience_clustering.N_TOP_TOKENS_KEY: 2
}

model_outputs = [{
'input_embs_grad':
np.array([[0, 0, 1, 1], [0, 1, 0, 0], [1, 1, 1, 1], [1, 0, 1, 1]]),
'tokens': ['a', 'b', 'c', 'd'],
'grad_class':
'1'
}, {
'input_embs_grad':
np.array([[0, 0, 1, 1], [0, 1, 0, 0], [1, 1, 1, 1], [1, 0, 1, 1]]),
'tokens': ['a', 'b', 'c', 'd'],
'grad_class':
'1'
}, {
'input_embs_grad':
np.array([[1, 1, 1, 1], [1, 0, 1, 1], [1, 1, 1, 1], [0, 0, 1, 1]]),
'tokens': ['e', 'f', 'e', 'g'],
'grad_class':
'1'
}, {
'input_embs_grad':
np.array([[1, 1, 1, 1], [1, 0, 1, 1], [1, 1, 1, 1], [0, 0, 1, 1]]),
'tokens': ['e', 'f', 'e', 'g'],
'grad_class':
'1'
}, {
'input_embs_grad':
np.array([[1, 1, 1, 1], [1, 0, 1, 1], [1, 1, 1, 1], [0, 0, 1, 1]]),
'tokens': ['e', 'f', 'e', 'g'],
'grad_class':
'1'
}]

clustering_component = salience_clustering.SalienceClustering(
self.salience_mappers)
result = clustering_component.run_with_metadata(inputs, model, dataset,
model_outputs, config)
result, clustering_component = (
self._call_classification_model_on_standard_input(config, grad_key))
# Cluster id assignment is random, so in one run the first 2 examples may
# be cluster 0, in the next run they may be in cluster 1.
cluster_id_of_first = result[
salience_clustering.CLUSTER_ID_KEY]['input_embs_grad'][0]
salience_clustering.CLUSTER_ID_KEY][grad_key][0]
cluster_id_of_last = result[
salience_clustering.CLUSTER_ID_KEY]['input_embs_grad'][-1]
salience_clustering.CLUSTER_ID_KEY][grad_key][-1]
np.testing.assert_equal(
result[salience_clustering.CLUSTER_ID_KEY]['input_embs_grad'], [
result[salience_clustering.CLUSTER_ID_KEY][grad_key], [
cluster_id_of_first, cluster_id_of_first, cluster_id_of_last,
cluster_id_of_last, cluster_id_of_last
])
np.testing.assert_allclose(
result[salience_clustering.REPRESENTATION_KEY]['input_embs_grad'][0],
result[salience_clustering.REPRESENTATION_KEY]['input_embs_grad'][1])
result[salience_clustering.REPRESENTATION_KEY][grad_key][0],
result[salience_clustering.REPRESENTATION_KEY][grad_key][1])
np.testing.assert_allclose(
result[salience_clustering.REPRESENTATION_KEY]['input_embs_grad'][2],
result[salience_clustering.REPRESENTATION_KEY]['input_embs_grad'][3])
result[salience_clustering.REPRESENTATION_KEY][grad_key][2],
result[salience_clustering.REPRESENTATION_KEY][grad_key][3])
np.testing.assert_allclose(
result[salience_clustering.REPRESENTATION_KEY]['input_embs_grad'][2],
result[salience_clustering.REPRESENTATION_KEY]['input_embs_grad'][4])
self.assertIn('input_embs_grad', clustering_component.kmeans)
self.assertIsNotNone(clustering_component.kmeans['input_embs_grad'])
result[salience_clustering.REPRESENTATION_KEY][grad_key][2],
result[salience_clustering.REPRESENTATION_KEY][grad_key][4])
self.assertIn(grad_key, clustering_component.kmeans)
self.assertIsNotNone(clustering_component.kmeans[grad_key])

def test_top_tokens(self):
"""Tests top token results (doesn't apply for LIME with a test model)."""
config = {
salience_clustering.SALIENCE_MAPPER_KEY: 'Grad L2 Norm',
salience_clustering.N_CLUSTERS_KEY: 2,
salience_clustering.N_TOP_TOKENS_KEY: 2
}
result, _ = self._call_classification_model_on_standard_input(
config, 'input_embs_grad')
# Clustering isn't deterministic so we don't know if examples 1 and 2 are
# in cluster 0 or 1.
for cluster_id in range(config[salience_clustering.N_CLUSTERS_KEY]):
@@ -204,6 +226,17 @@ def test_clustering(self):
top_tokens_are_set_ef = subset_ef == top_tokens
self.assertTrue(top_tokens_are_set_cd or top_tokens_are_set_ef)

def test_string_config_item(self):
"""Tests clustering when config contains a string value."""
config = {
salience_clustering.SALIENCE_MAPPER_KEY: 'Grad L2 Norm',
salience_clustering.N_CLUSTERS_KEY: '2',
salience_clustering.N_TOP_TOKENS_KEY: 2
}
result, _ = self._call_classification_model_on_standard_input(
config, 'input_embs_grad')
self.assertIsNotNone(result)


if __name__ == '__main__':
absltest.main()
