Skip to content

Commit

Permalink
Change all uses of run_with_metadata() in lit_nlp/components to call run() directly in preparation for removing all *_with_metadata() methods across LIT.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 551209869
  • Loading branch information
nadah09 authored and LIT team committed Jul 26, 2023
1 parent 21d523d commit 5f1a971
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 128 deletions.
20 changes: 10 additions & 10 deletions lit_nlp/components/curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@
class CurvesInterpreter(lit_components.Interpreter):
"""Returns data for rendering ROC and Precision-Recall curves."""

def run_with_metadata(self,
indexed_inputs: Sequence[IndexedInput],
model: lit_model.Model,
dataset: lit_dataset.IndexedDataset,
model_outputs: Optional[Sequence[JsonDict]] = None,
config: Optional[JsonDict] = None):
def run(self,
inputs: Sequence[JsonDict],
model: lit_model.Model,
dataset: lit_dataset.Dataset,
model_outputs: Optional[Sequence[JsonDict]] = None,
config: Optional[JsonDict] = None):
if not config:
raise ValueError('Curves required config parameters but received none.')

Expand All @@ -69,12 +69,12 @@ def run_with_metadata(self,
' model spec to output a single MulticlassPreds field.'
)

if not indexed_inputs:
if not inputs:
return {ROC_DATA: [], PR_DATA: []}

# Run prediction if needed:
if model_outputs is None:
model_outputs = list(model.predict_with_metadata(indexed_inputs))
model_outputs = list(model.predict(inputs))

# Get scores for the target label.
pred_spec = output_spec.get(predictions_key)
Expand All @@ -91,8 +91,8 @@ def run_with_metadata(self,
# Get ground truth for the target label.
parent_key = pred_spec.parent
ground_truth_list = []
for indexed_input in indexed_inputs:
ground_truth_label = indexed_input['data'][parent_key]
for ex in inputs:
ground_truth_label = ex[parent_key]
ground_truth = 1.0 if ground_truth_label == target_label else 0.0
ground_truth_list.append(ground_truth)

Expand Down
30 changes: 15 additions & 15 deletions lit_nlp/components/nearest_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,17 @@ def is_compatible(self, model: lit_model.Model,
model_out_embs = utils.spec_contains(model.output_spec(), types.Embeddings)
return dataset_embs or model_out_embs

def run_with_metadata(
def run(
self,
indexed_inputs: Sequence[IndexedInput],
inputs: Sequence[JsonDict],
model: lit_model.Model,
dataset: lit_dataset.IndexedDataset,
dataset: lit_dataset.Dataset,
model_outputs: Optional[list[JsonDict]] = None,
config: Optional[JsonDict] = None) -> Optional[list[JsonDict]]:
"""Finds the nearest neighbors of the example specified in the config.
Args:
indexed_inputs: the dataset example to find nearest neighbors for.
inputs: the dataset example to find nearest neighbors for.
model: the model being explained.
dataset: the dataset which the current examples belong to.
model_outputs: optional model outputs from calling model.predict(inputs).
Expand All @@ -88,31 +88,31 @@ def run_with_metadata(
if not config:
raise TypeError('config must be provided')

if not (isinstance(dataset, lit_dataset.IndexedDataset)):
raise TypeError('Nearest neighbors requires an IndexedDataset to track '
'uniqueness by ID.')

nnconf = NearestNeighborsConfig(**(config or {}))

# TODO(lit-dev): Add support for selecting nearest neighbors of a set.
if len(indexed_inputs) != 1:
if len(inputs) != 1:
raise ValueError('indexed_inputs must contain exactly 1 example, found '
f'{len(indexed_inputs)}.')
f'{len(inputs)}.')

if nnconf.use_input:
if not dataset.spec().get(nnconf.embedding_name):
raise KeyError('Could not find embeddings field, '
f'{nnconf.embedding_name} in dataset spec')
# If using input values, then treat inputs as outputs instead of running
# the model.
dataset_outputs = [inp['data'] for inp in dataset.indexed_examples]
example_outputs = [inp['data'] for inp in indexed_inputs]
dataset_outputs = dataset.examples
example_outputs = inputs
else:
if not model.output_spec().get(nnconf.embedding_name):
raise KeyError('Could not find embeddings field, '
f'{nnconf.embedding_name} in model output spec')
dataset_outputs = list(
model.predict_with_metadata(
dataset.indexed_examples, dataset_name=nnconf.dataset_name))
example_outputs = list(
model.predict_with_metadata(
indexed_inputs, dataset_name=nnconf.dataset_name))
dataset_outputs = list(model.predict(dataset.examples))
example_outputs = list(model.predict(inputs))

example_output = example_outputs[0]

Expand All @@ -123,7 +123,7 @@ def run_with_metadata(
sorted_indices = np.argsort(distances)
k = nnconf.num_neighbors
k_nearest_neighbors = [
{'id': dataset.indexed_examples[original_index]['id'],
{'id': dataset.examples[original_index]['_id'],
'nn_distance': distances[original_index]
} for original_index in sorted_indices[:k]]

Expand Down
19 changes: 10 additions & 9 deletions lit_nlp/components/nearest_neighbors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,28 +68,29 @@ def setUp(self):
def test_run_nn(self):
examples = [
{
'segment': 'a'
'segment': 'a',
'_id': 'a'
},
{
'segment': 'b'
'segment': 'b',
'_id': 'b'
},
{
'segment': 'c'
'segment': 'c',
'_id': 'c'
},
]
indexed_inputs = [{'id': caching.input_hash(ex), 'data': ex}
for ex in examples]

model = TestModelNearestNeighbors()
dataset = lit_dataset.IndexedDataset(id_fn=caching.input_hash,
indexed_examples=indexed_inputs)
examples=examples)
config = {
'embedding_name': 'input_embs',
'num_neighbors': 2,
}
result = self.nearest_neighbors.run_with_metadata([indexed_inputs[1]],
model, dataset,
config=config)
result = self.nearest_neighbors.run_with_metadata(
dataset.indexed_examples[1:2], model, dataset, config=config
)
expected = {'nearest_neighbors': [
{'id': '1', 'nn_distance': 0.0},
{'id': '0', 'nn_distance': 1.7320508075688772}]}
Expand Down
42 changes: 13 additions & 29 deletions lit_nlp/components/tcav.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import model as lit_model
from lit_nlp.api import types
from lit_nlp.lib import caching
from lit_nlp.lib import utils

import numpy as np
Expand Down Expand Up @@ -118,17 +117,17 @@ def is_compatible(self, model: lit_model.Model,

return False

def run_with_metadata(
def run(
self,
indexed_inputs: Sequence[IndexedInput],
inputs: Sequence[JsonDict],
model: lit_model.Model,
dataset: lit_dataset.IndexedDataset,
dataset: lit_dataset.Dataset,
model_outputs: Optional[list[JsonDict]] = None,
config: Optional[JsonDict] = None) -> Optional[list[JsonDict]]:
"""Runs the TCAV method given the params in the inputs and config.
Args:
indexed_inputs: all examples in the dataset, in the indexed input format.
inputs: all examples in the dataset.
model: the model being explained.
dataset: the dataset which the current examples belong to.
model_outputs: optional model outputs from calling model.predict(inputs).
Expand All @@ -154,7 +153,7 @@ def run_with_metadata(
tcav_config = TCAVConfig(**(config or {}))
# TODO(b/171513556): get these from the Dataset object once indices are
# available there.
dataset_examples = indexed_inputs
dataset_examples = inputs

# Get this layer's output spec keys for gradients and embeddings.
grad_layer = tcav_config.grad_layer
Expand All @@ -172,12 +171,7 @@ def run_with_metadata(

# Get outputs using model.predict().
if model_outputs is None:
pred_kw = {}
if isinstance(model, caching.CachingModelWrapper):
pred_kw['dataset_name'] = tcav_config.dataset_name
predictions = list(
model.predict_with_metadata(dataset_examples, **pred_kw)
)
predictions = list(model.predict(dataset_examples))
else:
predictions = model_outputs

Expand All @@ -191,19 +185,19 @@ def run_with_metadata(
}]

ids_set = set(tcav_config.concept_set_ids)
concept_set = [ex for ex in dataset_examples if ex['id'] in ids_set]
concept_set = [ex for ex in dataset_examples if ex['_id'] in ids_set]

if tcav_config.negative_set_ids:
negative_ids_set = set(tcav_config.negative_set_ids)
negative_set = [
ex for ex in dataset_examples if ex['id'] in negative_ids_set
ex for ex in dataset_examples if ex['_id'] in negative_ids_set
]
return self._run_relative_tcav(grad_layer, emb_layer, grad_class_key,
concept_set, negative_set, predictions,
model, tcav_config)
else:
non_concept_set = [
ex for ex in dataset_examples if ex['id'] not in ids_set
ex for ex in dataset_examples if ex['_id'] not in ids_set
]
return self._run_default_tcav(grad_layer, emb_layer, grad_class_key,
concept_set, non_concept_set, predictions,
Expand All @@ -215,13 +209,8 @@ def _subsample(self, examples, n):
def _run_default_tcav(self, grad_layer, emb_layer, grad_class_key,
concept_set, non_concept_set, dataset_outputs, model,
config):
pred_kw = {}
if isinstance(model, caching.CachingModelWrapper):
pred_kw['dataset_name'] = config.dataset_name
concept_outputs = list(model.predict_with_metadata(concept_set, **pred_kw))
non_concept_outputs = list(
model.predict_with_metadata(non_concept_set, **pred_kw)
)
concept_outputs = list(model.predict(concept_set))
non_concept_outputs = list(model.predict(non_concept_set))

concept_results = []
# If there are more concept set examples than non-concept set examples, we
Expand Down Expand Up @@ -276,13 +265,8 @@ def _run_default_tcav(self, grad_layer, emb_layer, grad_class_key,
def _run_relative_tcav(self, grad_layer, emb_layer, grad_class_key,
concept_set, negative_set, dataset_outputs, model,
config):
pred_kw = {}
if isinstance(model, caching.CachingModelWrapper):
pred_kw['dataset_name'] = config.dataset_name
positive_outputs = list(model.predict_with_metadata(concept_set, **pred_kw))
negative_outputs = list(
model.predict_with_metadata(negative_set, **pred_kw)
)
positive_outputs = list(model.predict(concept_set))
negative_outputs = list(model.predict(negative_set))

# Ideally, for relative TCAV, users would test concepts with at least ~100
# examples each so we can perform ~15 runs on unique subsets.
Expand Down
72 changes: 37 additions & 35 deletions lit_nlp/components/tcav_int_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,48 +33,50 @@
)

_ALPHABET_EXAMPLES = [
{'sentence': 'a'},
{'sentence': 'b'},
{'sentence': 'c'},
{'sentence': 'd'},
{'sentence': 'e'},
{'sentence': 'f'},
{'sentence': 'g'},
{'sentence': 'h'},
{'sentence': 'i'}
{'sentence': 'a', '_id': 'a'},
{'sentence': 'b', '_id': 'b'},
{'sentence': 'c', '_id': 'c'},
{'sentence': 'd', '_id': 'd'},
{'sentence': 'e', '_id': 'e'},
{'sentence': 'f', '_id': 'f'},
{'sentence': 'g', '_id': 'g'},
{'sentence': 'h', '_id': 'h'},
{'sentence': 'i', '_id': 'i'},
]

_ALPHABET_EXAMPLES_INDEXED = [
{'id': caching.input_hash(ex), 'data': ex} for ex in _ALPHABET_EXAMPLES
{'id': ex['_id'], 'data': ex} for ex in _ALPHABET_EXAMPLES
]

_EMOTION_EXAMPLES = [
{'sentence': 'happy'}, # 0
{'sentence': 'sad'}, # 1
{'sentence': 'good'}, # 2
{'sentence': 'bad'}, # 3
{'sentence': 'pretty'}, # 4
{'sentence': 'ugly'}, # 5
{'sentence': 'sweet'}, # 6
{'sentence': 'bitter'}, # 7
{'sentence': 'well'}, # 8
{'sentence': 'poor'}, # 9
{'sentence': 'compelling'}, # 10
{'sentence': 'boring'}, # 11
{'sentence': 'pleasing'}, # 12
{'sentence': 'gross'}, # 13
{'sentence': 'blue'}, # 14
{'sentence': 'red'}, # 15
{'sentence': 'flower'}, # 16
{'sentence': 'bee'}, # 17
{'sentence': 'snake'}, # 18
{'sentence': 'windshield'}, # 19
{'sentence': 'plant'}, # 20
{'sentence': 'scary'}, # 21
{'sentence': 'pencil'}, # 22
{'sentence': 'hello'} # 23
{'sentence': 'happy', '_id': 'happy'}, # 0
{'sentence': 'sad', '_id': 'sad'}, # 1
{'sentence': 'good', '_id': 'good'}, # 2
{'sentence': 'bad', '_id': 'bad'}, # 3
{'sentence': 'pretty', '_id': 'pretty'}, # 4
{'sentence': 'ugly', '_id': 'ugly'}, # 5
{'sentence': 'sweet', '_id': 'sweet'}, # 6
{'sentence': 'bitter', '_id': 'bitter'}, # 7
{'sentence': 'well', '_id': 'well'}, # 8
{'sentence': 'poor', '_id': 'poor'}, # 9
{'sentence': 'compelling', '_id': 'compelling'}, # 10
{'sentence': 'boring', '_id': 'boring'}, # 11
{'sentence': 'pleasing', '_id': 'pleasing'}, # 12
{'sentence': 'gross', '_id': 'gross'}, # 13
{'sentence': 'blue', '_id': 'blue'}, # 14
{'sentence': 'red', '_id': 'red'}, # 15
{'sentence': 'flower', '_id': 'flower'}, # 16
{'sentence': 'bee', '_id': 'bee'}, # 17
{'sentence': 'snake', '_id': 'snake'}, # 18
{'sentence': 'windshield', '_id': 'windshield'}, # 19
{'sentence': 'plant', '_id': 'plant'}, # 20
{'sentence': 'scary', '_id': 'scary'}, # 21
{'sentence': 'pencil', '_id': 'pencil'}, # 22
{'sentence': 'hello', '_id': 'hello'} # 23
]

_EMOTION_EXAMPLES_INDEXED = [
{'id': caching.input_hash(ex), 'data': ex} for ex in _EMOTION_EXAMPLES
{'id': ex['_id'], 'data': ex} for ex in _EMOTION_EXAMPLES
]


Expand Down

0 comments on commit 5f1a971

Please sign in to comment.