Removes predict_with_metadata() overrides in Model subclasses.
This change leaves lit_nlp.api.model.Model as the only implementer of predict_with_metadata(), paving the way for removing the method entirely in a future change.

PiperOrigin-RevId: 551569578
nadah09 authored and LIT team committed Jul 27, 2023
1 parent ad65fd9 commit bc6f82b
Showing 4 changed files with 8 additions and 34 deletions.
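
For context, every override deleted below reduces to the same pattern: unwrap each IndexedInput to its bare example and defer to predict(). The following is a minimal sketch of the base-class implementation that remains, inferred from the deleted overrides rather than copied from model.py (the actual body and signature may differ):

from collections.abc import Iterable

JsonDict = dict  # simplified stand-in for lit_nlp's JsonDict alias

class Model:
  """Sketch of the base class; only the relevant methods are shown."""

  def predict(self, inputs: Iterable[JsonDict], **kw) -> Iterable[JsonDict]:
    raise NotImplementedError

  def predict_with_metadata(
      self, indexed_inputs: Iterable[JsonDict], **kw
  ) -> Iterable[JsonDict]:
    """As predict(), but inputs are IndexedInput."""
    # Because this calls self.predict(), Python's method resolution reaches
    # a subclass's predict() override automatically, which is what makes the
    # identical per-subclass copies deleted in this commit redundant.
    return self.predict((ex["data"] for ex in indexed_inputs), **kw)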
11 changes: 0 additions & 11 deletions lit_nlp/api/model.py
@@ -285,17 +285,6 @@ def predict(
   ) -> Iterable[JsonDict]:
     return self.wrapped.predict(inputs, *args, **kw)
 
-  # NOTE: if a subclass modifies predict(), it should also override this to
-  # call the custom predict() method - otherwise this will delegate to the
-  # wrapped class and call /that class's/ predict() method, likely leading to
-  # incorrect results.
-  # b/171513556 will solve this problem by removing the need for any
-  # *_with_metadata() methods.
-  def predict_with_metadata(
-      self, indexed_inputs: Iterable[JsonDict], **kw
-  ) -> Iterable[JsonDict]:
-    return self.wrapped.predict_with_metadata(indexed_inputs, **kw)
-
   def load(self, path: str):
     """Load a new model and wrap it with this class."""
     new_model = self.wrapped.load(path)
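
The NOTE deleted above describes a real dispatch trap in the old ModelWrapper override: forwarding to the wrapped model's predict_with_metadata() meant a subclass's custom predict() was silently bypassed. A self-contained toy illustration (all names here are invented for the demo, not LIT APIs):

class Inner:
  def predict(self, inputs):
    return [x * 2 for x in inputs]

  def predict_with_metadata(self, indexed_inputs):
    return self.predict(ex["data"] for ex in indexed_inputs)

class Wrapper:
  def __init__(self, wrapped):
    self.wrapped = wrapped

  def predict(self, inputs):
    return self.wrapped.predict(inputs)

  # The now-removed behavior: forward to the wrapped object.
  def predict_with_metadata(self, indexed_inputs):
    return self.wrapped.predict_with_metadata(indexed_inputs)

class Doubler(Wrapper):
  def predict(self, inputs):  # a subclass customizing predict()
    return [x * 10 for x in super().predict(inputs)]

m = Doubler(Inner())
print(m.predict([1]))                          # [20]: custom path taken
print(m.predict_with_metadata([{"data": 1}]))  # [2]: custom predict() bypassed

With the override removed, predict_with_metadata() instead resolves to the base-class version, whose self.predict(...) call dispatches to the Doubler-style override as expected.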
8 changes: 0 additions & 8 deletions lit_nlp/examples/models/t5.py
@@ -436,10 +436,6 @@ def predict(self, inputs):
     outputs = self.wrapped.predict(model_inputs)
     return (utils.remap_dict(mo, self.FIELD_RENAMES) for mo in outputs)
 
-  def predict_with_metadata(self, indexed_inputs):
-    """As predict(), but inputs are IndexedInput."""
-    return self.predict((ex["data"] for ex in indexed_inputs))
-
   def input_spec(self):
     spec = lit_types.remap_spec(self.wrapped.input_spec(), self.FIELD_RENAMES)
     spec["source_language"] = lit_types.CategoryLabel()
@@ -505,10 +501,6 @@ def predict(self, inputs):
       mo["rougeL"] = float(score["rougeL"].fmeasure)
       yield mo
 
-  def predict_with_metadata(self, indexed_inputs):
-    """As predict(), but inputs are IndexedInput."""
-    return self.predict((ex["data"] for ex in indexed_inputs))
-
   def input_spec(self):
     return lit_types.remap_spec(self.wrapped.input_spec(), self.FIELD_RENAMES)
7 changes: 0 additions & 7 deletions lit_nlp/lib/caching.py
@@ -285,13 +285,6 @@ def predict(self,
 
     return cached_results
 
-  # TODO(b/171513556): remove this method once we no longer need to override
-  # ModelWrapper.predict_with_metadata()
-  def predict_with_metadata(self, indexed_inputs: Iterable[JsonDict], **kw):
-    """As predict(), but inputs are IndexedInput."""
-    results = self.predict((ex["data"] for ex in indexed_inputs), **kw)
-    return results
-
   def _get_results_from_cache(self, input_keys: list[CacheKey]):
     with self._cache.lock:
       return [self._cache.get(input_key) for input_key in input_keys]
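
Deleting this override (and its TODO) does not lose caching: the inherited base method routes through self.predict(), which still resolves to CachingModelWrapper.predict(). A toy sketch of why, with invented names and an inferred base-class body:

class BaseModel:
  def predict(self, inputs):
    raise NotImplementedError

  def predict_with_metadata(self, indexed_inputs):
    # Inferred base behavior: unwrap, then defer to self.predict().
    return self.predict([ex["data"] for ex in indexed_inputs])

class CachingWrapper(BaseModel):
  def __init__(self, wrapped):
    self.wrapped = wrapped
    self.cache = {}

  def predict(self, inputs):
    out = []
    for ex in inputs:
      key = repr(ex)
      if key not in self.cache:  # cache miss: run the real model
        self.cache[key] = self.wrapped.predict([ex])[0]
      out.append(self.cache[key])
    return out

class CountingIdentity:
  def __init__(self):
    self.count = 0

  def predict(self, inputs):
    self.count += len(inputs)
    return list(inputs)

w = CachingWrapper(CountingIdentity())
w.predict_with_metadata([{"data": 7}])  # populates the cache
w.predict_with_metadata([{"data": 7}])  # served from the cache
print(w.wrapped.count)                  # 1: the underlying model ran once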
16 changes: 8 additions & 8 deletions lit_nlp/lib/caching_test.py
@@ -37,21 +37,21 @@ def test_caching_model_wrapper_no_dataset_skip_cache(self):
     model = testing_utils.IdentityRegressionModelForTesting()
     wrapper = caching.CachingModelWrapper(model, "test")
     examples = [{"data": {"val": 1}, "id": "my_id"}]
-    results = wrapper.predict_with_metadata(examples)
+    results = list(wrapper.predict_with_metadata(examples))
     self.assertEqual(1, model.count)
     self.assertEqual({"score": 1}, results[0])
-    results = wrapper.predict_with_metadata(examples)
+    results = list(wrapper.predict_with_metadata(examples))
     self.assertEqual(2, model.count)
     self.assertEqual({"score": 1}, results[0])
 
   def test_caching_model_wrapper_use_cache(self):
     model = testing_utils.IdentityRegressionModelForTesting()
     wrapper = caching.CachingModelWrapper(model, "test")
     examples = [{"data": {"val": 1, "_id": "id_to_cache"}, "id": "id_to_cache"}]
-    results = wrapper.predict_with_metadata(examples)
+    results = list(wrapper.predict_with_metadata(examples))
     self.assertEqual(1, model.count)
     self.assertEqual({"score": 1}, results[0])
-    results = wrapper.predict_with_metadata(examples)
+    results = list(wrapper.predict_with_metadata(examples))
     self.assertEqual(1, model.count)
     self.assertEqual({"score": 1}, results[0])
     self.assertEmpty(wrapper._cache._pred_locks)
@@ -60,11 +60,11 @@ def test_caching_model_wrapper_not_cached(self):
     model = testing_utils.IdentityRegressionModelForTesting()
     wrapper = caching.CachingModelWrapper(model, "test")
     examples = [{"data": {"val": 1}, "id": "my_id"}]
-    results = wrapper.predict_with_metadata(examples)
+    results = list(wrapper.predict_with_metadata(examples))
     self.assertEqual(1, model.count)
     self.assertEqual({"score": 1}, results[0])
     examples = [{"data": {"val": 2}, "id": "other_id"}]
-    results = wrapper.predict_with_metadata(examples)
+    results = list(wrapper.predict_with_metadata(examples))
     self.assertEqual(2, model.count)
     self.assertEqual({"score": 2}, results[0])

@@ -98,14 +98,14 @@ def test_caching_model_wrapper_mixed_list(self):
     subset = examples[:1]
 
     # Run the CachingModelWrapper over a subset of examples
-    results = wrapper.predict_with_metadata(subset)
+    results = list(wrapper.predict_with_metadata(subset))
     self.assertEqual(1, model.count)
     self.assertEqual({"score": 0}, results[0])
 
     # Now, run the CachingModelWrapper over all of the examples. This should
     # only pass the examples that were not in subset to the wrapped model, and
     # the total number of inputs processed by the wrapped model should be 3
-    results = wrapper.predict_with_metadata(examples)
+    results = list(wrapper.predict_with_metadata(examples))
     self.assertEqual(3, model.count)
     self.assertEqual({"score": 0}, results[0])
     self.assertEqual({"score": 1}, results[1])
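
The list(...) calls added throughout this test file follow from the signature visible in the model.py diff above: predict_with_metadata() is typed as returning Iterable[JsonDict], not a list, so once these tests go through the inherited base implementation they can no longer assume an indexable result. A minimal illustration of the failure mode with a lazy iterable:

def fake_predictions():  # stand-in for a lazily evaluated model output
  yield {"score": 1}

results = fake_predictions()
# results[0]  # would raise TypeError: 'generator' object is not subscriptable
results = list(fake_predictions())
print(results[0])  # {'score': 1}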
