Move the batching logic from Model to BatchedModel.
The `Model` class will keep an abstract `predict` method, and both `BatchedModel` and `BatchedRemoteModel` will implement `predict` but require their subclasses to implement `predict_minibatch`.

PiperOrigin-RevId: 555984644
bdu91 authored and LIT team committed Aug 11, 2023
1 parent d25392b commit 6fdcbfe
Showing 6 changed files with 70 additions and 88 deletions.
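For orientation, here is a minimal sketch of what a model looks like under the refactored API described in the commit message. The class name and spec fields (`TinyLengthModel`, "text", "score") are hypothetical and not part of this commit; only the `BatchedModel` contract (required `predict_minibatch()`, optional `max_minibatch_size()`, inherited `predict()`) is taken from the change below.

from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types


class TinyLengthModel(lit_model.BatchedModel):
  """Hypothetical subclass: only predict_minibatch() must be implemented."""

  def max_minibatch_size(self) -> int:
    # Optional override; the inherited predict() batches to this size.
    return 32

  def predict_minibatch(
      self, inputs: list[lit_model.JsonDict]
  ) -> list[lit_model.JsonDict]:
    # One output dict per input dict, following output_spec().
    return [{"score": float(len(ex["text"]))} for ex in inputs]

  def input_spec(self) -> lit_types.Spec:
    return {"text": lit_types.TextSegment()}

  def output_spec(self) -> lit_types.Spec:
    return {"score": lit_types.Scalar()}

This mirrors the split described above: minibatching logic lives in `BatchedModel`, while `Model` keeps only the abstract `predict()`.
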
121 changes: 61 additions & 60 deletions lit_nlp/api/model.py
@@ -82,10 +82,6 @@ def description(self) -> str:
"""
return inspect.getdoc(self) or ''

def max_minibatch_size(self) -> int:
"""Maximum minibatch size for this model."""
return 1

@classmethod
def init_spec(cls) -> Optional[Spec]:
"""Attempts to infer a Spec describing a Model's constructor parameters.
@@ -137,22 +133,10 @@ def supports_concurrent_predictions(self):
Returns:
(bool) True if the model can handle multiple concurrent calls to its
`predict_minibatch` method.
`predict` method.
"""
return False

@abc.abstractmethod
def predict_minibatch(self, inputs: list[JsonDict]) -> list[JsonDict]:
"""Run prediction on a batch of inputs.
Args:
inputs: sequence of inputs, following model.input_spec()
Returns:
list of outputs, following model.output_spec()
"""
return

def load(self, path: str):
"""Load and return a new instance of this model loaded from a new path.
@@ -194,41 +178,10 @@ def get_embedding_table(self) -> tuple[list[str], np.ndarray]:
raise NotImplementedError('get_embedding_table() not implemented for ' +
self.__class__.__name__)

##
# Concrete implementations of common functions.
@abc.abstractmethod
def predict(self, inputs: Iterable[JsonDict], **kw) -> Iterable[JsonDict]:
"""Run prediction on a dataset.
This uses minibatch inference for efficiency, but yields per-example output.
This will also copy some NumPy arrays if they look like slices of a larger
tensor. This adds some overhead, but reduces memory leaks by allowing the
source tensor (which may be a large padded matrix) to be garbage collected.
Args:
inputs: iterable of input dicts
**kw: additional kwargs passed to predict_minibatch()
Returns:
model outputs, for each input
"""
results = self._batched_predict(inputs, **kw)
results = (scrub_numpy_refs(res) for res in results)
return results

def _batched_predict(self, inputs: Iterable[JsonDict],
**kw) -> Iterator[JsonDict]:
"""Internal helper to predict using minibatches."""
minibatch_size = self.max_minibatch_size(**kw)
minibatch = []
for ex in inputs:
if len(minibatch) < minibatch_size:
minibatch.append(ex)
if len(minibatch) >= minibatch_size:
yield from self.predict_minibatch(minibatch, **kw)
minibatch = []
if len(minibatch) > 0: # pylint: disable=g-explicit-length-test
yield from self.predict_minibatch(minibatch, **kw)
"""Run prediction on a list of inputs and return the outputs."""
pass


class ModelWrapper(Model):
@@ -250,16 +203,10 @@ def wrapped(self):
def description(self) -> str:
return self.wrapped.description()

def max_minibatch_size(self) -> int:
return self.wrapped.max_minibatch_size()

@property
def supports_concurrent_predictions(self):
return self.wrapped.supports_concurrent_predictions

def predict_minibatch(self, inputs: list[JsonDict], **kw) -> list[JsonDict]:
return self.wrapped.predict_minibatch(inputs, **kw)

def predict(
self, inputs: Iterable[JsonDict], *args, **kw
) -> Iterable[JsonDict]:
@@ -285,10 +232,64 @@ def get_embedding_table(self) -> tuple[list[str], np.ndarray]:
class BatchedModel(Model):
"""Generic base class for the batched model.
Currently this is a no-op pass-through of Model class and will be updated
after moving users of Model class over.
Subclass needs to implement predict_minibatch() and optionally
max_minibatch_size().
"""
pass

def max_minibatch_size(self) -> int:
"""Maximum minibatch size for this model."""
return 1

@property
def supports_concurrent_predictions(self):
return False

@abc.abstractmethod
def predict_minibatch(self, inputs: list[JsonDict]) -> list[JsonDict]:
"""Run prediction on a batch of inputs.
Args:
inputs: sequence of inputs, following model.input_spec()
Returns:
list of outputs, following model.output_spec()
"""
pass

def predict(self, inputs: Iterable[JsonDict], **kw) -> Iterable[JsonDict]:
"""Run prediction on a dataset.
This uses minibatch inference for efficiency, but yields per-example output.
This will also copy some NumPy arrays if they look like slices of a larger
tensor. This adds some overhead, but reduces memory leaks by allowing the
source tensor (which may be a large padded matrix) to be garbage collected.
Args:
inputs: iterable of input dicts
**kw: additional kwargs passed to predict_minibatch()
Returns:
model outputs, for each input
"""
results = self.batched_predict(inputs, **kw)
results = (scrub_numpy_refs(res) for res in results)
return results

def batched_predict(
self, inputs: Iterable[JsonDict], **kw
) -> Iterator[JsonDict]:
"""Internal helper to predict using minibatches."""
minibatch_size = self.max_minibatch_size(**kw)
minibatch = []
for ex in inputs:
if len(minibatch) < minibatch_size:
minibatch.append(ex)
if len(minibatch) >= minibatch_size:
yield from self.predict_minibatch(minibatch, **kw)
minibatch = []
if len(minibatch) > 0: # pylint: disable=g-explicit-length-test
yield from self.predict_minibatch(minibatch, **kw)


class BatchedRemoteModel(Model):
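To make the contract in the model.py diff concrete, the following usage sketch (hypothetical example data, reusing the `TinyLengthModel` sketch from earlier) shows that callers interact only with `predict()`, which groups the input iterable into minibatches of at most `max_minibatch_size()` and yields one output per example:

# Hypothetical usage; example data and field names are illustrative only.
examples = [{"text": "a"}, {"text": "bb"}, {"text": "ccc"}]
model = TinyLengthModel()

# predict() calls predict_minibatch() under the hood and yields per-example
# outputs in input order; predict_minibatch() is no longer part of the
# caller-facing API.
for ex, out in zip(examples, model.predict(examples)):
  print(ex["text"], out["score"])
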
5 changes: 2 additions & 3 deletions lit_nlp/api/model_test.py
@@ -33,8 +33,7 @@ def input_spec(self) -> types.Spec:
def output_spec(self) -> types.Spec:
return {}

def predict_minibatch(self,
inputs: list[model.JsonDict]) -> list[model.JsonDict]:
def predict(self, inputs: list[model.JsonDict]) -> list[model.JsonDict]:
return []


@@ -77,7 +76,7 @@ def input_spec(self) -> types.Spec:
def output_spec(self) -> types.Spec:
return {}

def predict_minibatch(self, *args, **kwargs) -> list[types.JsonDict]:
def predict(self, *args, **kwargs) -> list[types.JsonDict]:
return []


10 changes: 6 additions & 4 deletions lit_nlp/components/hotflip_test.py
@@ -58,8 +58,9 @@ def output_spec(self) -> dict[str, lit_types.LitType]:
def get_embedding_table(self):
return ([], np.ndarray([]))

def predict_minibatch(
self, inputs: list[lit_model.JsonDict]) -> list[lit_model.JsonDict]:
def predict(
self, inputs: list[lit_model.JsonDict]
) -> list[lit_model.JsonDict]:
pass


@@ -108,8 +109,9 @@ def output_spec(self) -> dict[str, lit_types.LitType]:
def get_embedding_table(self):
return ([], np.ndarray([]))

def predict_minibatch(
self, inputs: list[lit_model.JsonDict]) -> list[lit_model.JsonDict]:
def predict(
self, inputs: list[lit_model.JsonDict]
) -> list[lit_model.JsonDict]:
pass


2 changes: 1 addition & 1 deletion lit_nlp/components/shap_explainer_test.py
@@ -48,7 +48,7 @@ def input_spec(self) -> lit_types.Spec:
def output_spec(self) -> lit_types.Spec:
return {}

def predict_minibatch(self, inputs, **kw):
def predict(self, inputs, **kw):
return None


14 changes: 0 additions & 14 deletions lit_nlp/examples/models/t5.py
@@ -423,13 +423,6 @@ def preprocess(self, ex: JsonDict) -> JsonDict:
def description(self) -> str:
return "T5 for machine translation\n" + self.wrapped.description()

# TODO(b/170662608): remove these after batching API is cleaned up.
def max_minibatch_size(self) -> int:
raise NotImplementedError("Use predict() instead.")

def predict_minibatch(self, inputs):
raise NotImplementedError("Use predict() instead.")

def predict(self, inputs):
"""Predict on a single minibatch of examples."""
model_inputs = (self.preprocess(ex) for ex in inputs)
@@ -479,13 +472,6 @@ def preprocess(self, ex: JsonDict) -> JsonDict:
def description(self) -> str:
return "T5 for summarization\n" + self.wrapped.description()

# TODO(b/170662608): remove these after batching API is cleaned up.
def max_minibatch_size(self) -> int:
raise NotImplementedError("Use predict() instead.")

def predict_minibatch(self, inputs):
raise NotImplementedError("Use predict() instead.")

def predict(self, inputs):
"""Predict on a single minibatch of examples."""
inputs = list(inputs) # needs to be referenced below, so keep full list
6 changes: 0 additions & 6 deletions lit_nlp/lib/caching.py
@@ -278,12 +278,6 @@ def fit_transform(self, inputs: Iterable[JsonDict]):
self._cache.put(output, cache_key)
return outputs

# TODO(b/170662608) Remove once batching logic changes are done.
def predict_minibatch(self, *args, **kw):
raise RuntimeError(
"This method should be inaccessible as it bypasses the cache. Please"
" use CachingModelWrapper.predict().")

def predict(self,
inputs: Iterable[JsonDict],
progress_indicator: Optional[ProgressIndicator] = lambda x: x,
