Change run_with_metadata() calls to run() in preparation for removal of run_with_metadata() method across LIT.

PiperOrigin-RevId: 552536081
nadah09 authored and LIT team committed Jul 31, 2023
1 parent df14c00 commit 9767670
Showing 7 changed files with 106 additions and 170 deletions.
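
Every call site in this commit follows the same pattern: the `indexed_inputs=` argument of `run_with_metadata()` becomes the `inputs=` argument of `run()`, and callers read `Dataset.examples` instead of `Dataset.indexed_examples`. A minimal before/after sketch of the migration (the `interpreter`, `model`, `dataset`, and `config` names are illustrative placeholders, not identifiers from this commit):

    # Before: run_with_metadata() took IndexedInputs, which wrap each example
    # dict under a 'data' key (visible in salience_clustering_test.py below).
    result = interpreter.run_with_metadata(
        indexed_inputs=dataset.indexed_examples,
        model=model,
        dataset=dataset,
        config=config,
    )

    # After: run() takes the bare example dicts directly.
    result = interpreter.run(
        inputs=dataset.examples,
        model=model,
        dataset=dataset,
        config=config,
    )

The diffs below use both the keyword form (curves_test.py) and the positional form (the other files), where the examples are simply the first positional argument.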
12 changes: 6 additions & 6 deletions lit_nlp/components/curves_test.py
@@ -128,8 +128,8 @@ def setUp(self):
   def test_label_not_in_config(self):
     """The interpreter throws an error if the config doesn't have Label."""
     with self.assertRaises(ValueError):
-      self.ci.run_with_metadata(
-          indexed_inputs=self.dataset.indexed_examples,
+      self.ci.run(
+          inputs=self.dataset.examples,
           model=self.model,
           dataset=self.dataset,
       )
@@ -140,8 +140,8 @@ def test_model_output_is_missing_in_config(self):
     The interpreter throws an error if the name of the output is absent.
     """
     with self.assertRaises(ValueError):
-      self.ci.run_with_metadata(
-          indexed_inputs=self.dataset.indexed_examples,
+      self.ci.run(
+          inputs=self.dataset.examples,
           model=self.model,
           dataset=self.dataset,
           config={'Label': 'red'},
@@ -165,8 +165,8 @@ def test_interpreter_honors_user_selected_label(
       self, label: str, exp_roc: _Curve, exp_pr: _Curve
   ):
     """Tests a happy scenario when a user doesn't specify the class label."""
-    curves_data = self.ci.run_with_metadata(
-        indexed_inputs=self.dataset.indexed_examples,
+    curves_data = self.ci.run(
+        inputs=self.dataset.examples,
         model=self.model,
         dataset=self.dataset,
         config={
4 changes: 2 additions & 2 deletions lit_nlp/components/nearest_neighbors_test.py
@@ -88,8 +88,8 @@ def test_run_nn(self):
         'embedding_name': 'input_embs',
         'num_neighbors': 2,
     }
-    result = self.nearest_neighbors.run_with_metadata(
-        dataset.indexed_examples[1:2], model, dataset, config=config
+    result = self.nearest_neighbors.run(
+        dataset.examples[1:2], model, dataset, config=config
     )
     expected = {'nearest_neighbors': [
         {'id': '1', 'nn_distance': 0.0},
30 changes: 12 additions & 18 deletions lit_nlp/components/pdp_test.py
@@ -89,9 +89,8 @@ def test_regression_num(self):
     config = {
         'feature': 'num',
     }
-    result = self.pdp.run_with_metadata([self.dataset.indexed_examples[0]],
-                                        self.reg_model, self.dataset,
-                                        config=config)
+    result = self.pdp.run([self.dataset.examples[0]], self.reg_model,
+                          self.dataset, config=config)
     expected = {1.0: 2.0, 2.0: 3.0, 3.0: 4.0, 4.0: 5.0, 5.0: 6.0, 6.0: 7.0,
                 7.0: 8.0, 8.0: 9.0, 9.0: 10.0, 10.0: 11.0}
     testing_utils.assert_deep_almost_equal(self, result['score'], expected)
@@ -101,9 +100,8 @@ def test_provided_range(self):
         'feature': 'num',
         'range': [0, 9]
     }
-    result = self.pdp.run_with_metadata([self.dataset.indexed_examples[0]],
-                                        self.reg_model, self.dataset,
-                                        config=config)
+    result = self.pdp.run([self.dataset.examples[0]], self.reg_model,
+                          self.dataset, config=config)
     expected = {0.0: 1.0, 1.0: 2.0, 2.0: 3.0, 3.0: 4.0, 4.0: 5.0, 5.0: 6.0,
                 6.0: 7.0, 7.0: 8.0, 8.0: 9.0, 9.0: 10.0}
     testing_utils.assert_deep_almost_equal(self, result['score'], expected)
@@ -112,19 +110,17 @@ def test_regression_cat(self):
     config = {
         'feature': 'cats',
     }
-    result = self.pdp.run_with_metadata([self.dataset.indexed_examples[0]],
-                                        self.reg_model, self.dataset,
-                                        config=config)
+    result = self.pdp.run([self.dataset.examples[0]], self.reg_model,
+                          self.dataset, config=config)
     expected = {'One': 2.0, 'None': 1.0}
     testing_utils.assert_deep_almost_equal(self, result['score'], expected)
 
   def test_class_num(self):
     config = {
         'feature': 'num',
     }
-    result = self.pdp.run_with_metadata([self.dataset.indexed_examples[0]],
-                                        self.class_model, self.dataset,
-                                        config=config)
+    result = self.pdp.run([self.dataset.examples[0]], self.class_model,
+                          self.dataset, config=config)
 
     expected = {1.0: [0.49, 0.51], 2.0: [0.48, 0.52], 3.0: [0.47, 0.53],
                 4.0: [0.46, 0.54], 5.0: [0.45, 0.55], 6.0: [0.44, 0.56],
@@ -136,19 +132,17 @@ def test_classification_cat(self):
     config = {
         'feature': 'cats',
     }
-    result = self.pdp.run_with_metadata([self.dataset.indexed_examples[0]],
-                                        self.class_model, self.dataset,
-                                        config=config)
+    result = self.pdp.run([self.dataset.examples[0]], self.class_model,
+                          self.dataset, config=config)
     expected = {'One': [0.49, 0.51], 'None': [0.99, 0.01]}
     testing_utils.assert_deep_almost_equal(self, result['probas'], expected)
 
   def test_multiple_inputs(self):
     config = {
         'feature': 'num',
     }
-    result = self.pdp.run_with_metadata(self.dataset.indexed_examples[0:2],
-                                        self.reg_model, self.dataset,
-                                        config=config)
+    result = self.pdp.run(self.dataset.examples[0:2], self.reg_model,
+                          self.dataset, config=config)
     expected = {1.0: 1.5, 2.0: 2.5, 3.0: 3.5, 4.0: 4.5, 5.0: 5.5, 6.0: 6.5,
                 7.0: 7.5, 8.0: 8.5, 9.0: 9.5, 10.0: 10.5}
     testing_utils.assert_deep_almost_equal(self, result['score'], expected)
42 changes: 11 additions & 31 deletions lit_nlp/components/salience_clustering_test.py
@@ -37,31 +37,11 @@ def setUp(self):
 
   def _call_classification_model_on_standard_input(self, config, grad_key):
     inputs = [
-        {
-            'data': {
-                'segment': 'a b c d'
-            }
-        },
-        {
-            'data': {
-                'segment': 'a b c d'
-            }
-        },
-        {
-            'data': {
-                'segment': 'e f e f'
-            }
-        },
-        {
-            'data': {
-                'segment': 'e f e f'
-            }
-        },
-        {
-            'data': {
-                'segment': 'e f e f'
-            }
-        },
+        {'segment': 'a b c d'},
+        {'segment': 'a b c d'},
+        {'segment': 'e f e f'},
+        {'segment': 'e f e f'},
+        {'segment': 'e f e f'},
     ]
     model = testing_utils.ClassificationModelForTesting()
     dataset = lit_dataset.Dataset(None, None)
@@ -100,8 +80,8 @@ def _call_classification_model_on_standard_input(self, config, grad_key):
 
     clustering_component = salience_clustering.SalienceClustering(
         self.salience_mappers)
-    result = clustering_component.run_with_metadata(inputs, model, dataset,
-                                                    model_outputs, config)
+    result = clustering_component.run(inputs, model, dataset, model_outputs,
+                                      config)
     return result, clustering_component, inputs, model, dataset, model_outputs

def test_build_vocab(self):
@@ -215,8 +195,8 @@ def test_clustering_create_new_kmeans(self):
     _, clustering_component, inputs, model, dataset, model_outputs = (
         self._call_classification_model_on_standard_input(config, grad_key))
     kmeans_call_1 = clustering_component.kmeans[grad_key]
-    clustering_component.run_with_metadata(inputs, model, dataset,
-                                           model_outputs, config)
+    clustering_component.run(inputs, model, dataset, model_outputs,
+                             config)
     kmeans_call_2 = clustering_component.kmeans[grad_key]
     self.assertIsNot(kmeans_call_1, kmeans_call_2)

@@ -233,8 +213,8 @@ def test_clustering_reuse_kmeans(self):
     kmeans_call_1 = clustering_component.kmeans[grad_key]
 
     config[salience_clustering.REUSE_CLUSTERING] = True
-    clustering_component.run_with_metadata(inputs, model, dataset,
-                                           model_outputs, config)
+    clustering_component.run(inputs, model, dataset, model_outputs,
+                             config)
     kmeans_call_2 = clustering_component.kmeans[grad_key]
     self.assertIs(kmeans_call_1, kmeans_call_2)
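
The salience_clustering_test.py hunks above also show the input-shape change directly: run_with_metadata() consumed examples wrapped as {'data': {'segment': ...}}, while run() consumes the flat {'segment': ...} dicts. A hypothetical helper for callers still holding wrapped inputs (the unwrap name is illustrative only, not part of this commit or the LIT API):

    def unwrap(indexed_inputs):
      """Strips the IndexedInput wrapper, returning plain example dicts."""
      return [wrapped['data'] for wrapped in indexed_inputs]

    # e.g.: component.run(unwrap(indexed_inputs), model, dataset, config=config)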
