Skip to content

Commit

Permalink
Remove redundant Model.spec() method, plus other minor cleanup
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 508432240
  • Loading branch information
iftenney authored and LIT team committed Feb 9, 2023
1 parent c794605 commit 16b72f7
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 54 deletions.
2 changes: 1 addition & 1 deletion lit_nlp/api/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def __init__(self, models): # pylint: disable=super-init-not-called
def spec(self):
combined_spec = {}
for _, model in self._models.items():
req_inputs = {k: v for (k, v) in model.spec().input.items() if v.required}
req_inputs = {k: v for (k, v) in model.input_spec().items() if v.required}
# Ensure that there are no conflicting spec keys.
assert not self.has_conflicting_keys(combined_spec, req_inputs)
combined_spec.update(req_inputs)
Expand Down
53 changes: 18 additions & 35 deletions lit_nlp/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from typing import Iterable, Iterator, Optional, Union

from absl import logging
import attr
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types
from lit_nlp.lib import utils
Expand All @@ -30,7 +29,7 @@
Spec = types.Spec


def maybe_copy(arr):
def maybe_copy_np(arr):
"""Decide if we should make a copy of an array in order to release memory.
NumPy arrays may be views into other array objects, by which a small array can
Expand Down Expand Up @@ -63,15 +62,8 @@ def maybe_copy(arr):


def scrub_numpy_refs(output: JsonDict) -> JsonDict:
  """Scrub numpy pointers; see maybe_copy_np() and Model.predict().

  Args:
    output: a single example's output dict, as returned by a model.

  Returns:
    A shallow copy of `output` in which arrays that may be views into a
    larger backing tensor are replaced by owning copies (via
    maybe_copy_np()), so the backing tensor can be garbage-collected.
  """
  return {k: maybe_copy_np(v) for k, v in output.items()}


class Model(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -189,10 +181,6 @@ def output_spec(self) -> types.Spec:
"""Return a spec describing model outputs."""
return

# TODO(lit-dev): annotate as @final once we migrate to python 3.8+
def spec(self) -> ModelSpec:
return ModelSpec(input=self.input_spec(), output=self.output_spec())

def get_embedding_table(self) -> tuple[list[str], np.ndarray]:
"""Return the full vocabulary and embedding table.
Expand All @@ -213,29 +201,24 @@ def fit_transform_with_metadata(self, indexed_inputs: list[JsonDict]):

##
# Concrete implementations of common functions.
def predict(self, inputs: Iterable[JsonDict], **kw) -> Iterable[JsonDict]:
  """Run prediction on a dataset.

  This uses minibatch inference for efficiency, but yields per-example
  output.

  This will also copy some NumPy arrays if they look like slices of a larger
  tensor. This adds some overhead, but reduces memory leaks by allowing the
  source tensor (which may be a large padded matrix) to be garbage collected.

  Args:
    inputs: iterable of input dicts
    **kw: additional kwargs passed to predict_minibatch()

  Returns:
    model outputs, for each input
  """
  results = self._batched_predict(inputs, **kw)
  # Copy array slices so the (possibly large) source tensors can be freed.
  results = (scrub_numpy_refs(res) for res in results)
  return results

def _batched_predict(self, inputs: Iterable[JsonDict],
Expand All @@ -253,8 +236,9 @@ def _batched_predict(self, inputs: Iterable[JsonDict],
yield from self.predict_minibatch(minibatch, **kw)

# TODO(b/171513556): remove this method.
def predict_with_metadata(
    self, indexed_inputs: Iterable[JsonDict], **kw
) -> Iterable[JsonDict]:
  """As predict(), but inputs are IndexedInput.

  Args:
    indexed_inputs: iterable of IndexedInput dicts, each carrying the
      actual example under the 'data' key.
    **kw: additional kwargs passed through to predict().

  Returns:
    Model outputs, one per input.
  """
  return self.predict((ex['data'] for ex in indexed_inputs), **kw)

Expand Down Expand Up @@ -288,8 +272,9 @@ def supports_concurrent_predictions(self):
def predict_minibatch(self, inputs: list[JsonDict], **kw) -> list[JsonDict]:
return self.wrapped.predict_minibatch(inputs, **kw)

def predict(
    self, inputs: Iterable[JsonDict], *args, **kw
) -> Iterable[JsonDict]:
  """Run prediction by delegating to the wrapped model."""
  return self.wrapped.predict(inputs, *args, **kw)

# NOTE: if a subclass modifies predict(), it should also override this to
Expand All @@ -298,8 +283,9 @@ def predict(self, inputs: Iterable[JsonDict], *args,
# incorrect results.
# b/171513556 will solve this problem by removing the need for any
# *_with_metadata() methods.
def predict_with_metadata(
    self, indexed_inputs: Iterable[JsonDict], **kw
) -> Iterable[JsonDict]:
  """As predict(), but inputs are IndexedInput; delegates to the wrapped model."""
  return self.wrapped.predict_with_metadata(indexed_inputs, **kw)

def load(self, path: str):
Expand All @@ -313,9 +299,6 @@ def input_spec(self) -> types.Spec:
def output_spec(self) -> types.Spec:
  """Return the output spec of the wrapped model."""
  return self.wrapped.output_spec()

def spec(self) -> ModelSpec:
return ModelSpec(input=self.input_spec(), output=self.output_spec())

##
# Special methods
def get_embedding_table(self) -> tuple[list[str], np.ndarray]:
Expand Down
34 changes: 20 additions & 14 deletions lit_nlp/components/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,10 @@ def run(self,
if model_outputs is None:
model_outputs = list(model.predict(inputs))

spec = model.spec()
field_map = map_pred_keys(dataset.spec(), spec.output,
self.is_field_compatible)
output_spec = model.output_spec()
field_map = map_pred_keys(
dataset.spec(), output_spec, self.is_field_compatible
)
ret = []
for pred_key, label_key in field_map.items():
# Extract fields
Expand All @@ -90,8 +91,9 @@ def run(self,
labels,
preds,
label_spec=dataset.spec()[label_key],
pred_spec=spec.output[pred_key],
config=config.get(pred_key) if config else None)
pred_spec=output_spec[pred_key],
config=config.get(pred_key) if config else None,
)
# Format for frontend.
ret.append({
'pred_key': pred_key,
Expand All @@ -112,9 +114,10 @@ def run_with_metadata(self,
# TODO(lit-team): pre-compute this mapping in constructor?
# This would require passing a model name to this function so we can
# reference a pre-computed list.
spec = model.spec()
field_map = map_pred_keys(dataset.spec(), spec.output,
self.is_field_compatible)
output_spec = model.output_spec()
field_map = map_pred_keys(
dataset.spec(), output_spec, self.is_field_compatible
)
ret = []
for pred_key, label_key in field_map.items():
# Extract fields
Expand All @@ -127,10 +130,11 @@ def run_with_metadata(self,
labels,
preds,
label_spec=dataset.spec()[label_key],
pred_spec=spec.output[pred_key],
pred_spec=output_spec[pred_key],
indices=indices,
metas=metas,
config=config.get(pred_key) if config else None)
config=config.get(pred_key) if config else None,
)
# Format for frontend.
ret.append({
'pred_key': pred_key,
Expand Down Expand Up @@ -173,8 +177,9 @@ def run(self,
config: Optional[JsonDict] = None):
# Get margin for each input for each pred key and add them to a config dict
# to pass to the wrapped metrics.
field_map = map_pred_keys(dataset.spec(),
model.spec().output, self.is_field_compatible)
field_map = map_pred_keys(
dataset.spec(), model.output_spec(), self.is_field_compatible
)
margin_config = {}
for pred_key in field_map:
field_config = config.get(pred_key) if config else None
Expand All @@ -194,8 +199,9 @@ def run_with_metadata(self,
config: Optional[JsonDict] = None) -> list[JsonDict]:
# Get margin for each input for each pred key and add them to a config dict
# to pass to the wrapped metrics.
field_map = map_pred_keys(dataset.spec(),
model.spec().output, self.is_field_compatible)
field_map = map_pred_keys(
dataset.spec(), model.output_spec(), self.is_field_compatible
)
margin_config = {}
for pred_key in field_map:
inputs = [ex['data'] for ex in indexed_inputs]
Expand Down
9 changes: 5 additions & 4 deletions lit_nlp/components/remote_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,17 @@ def __init__(self, url: Text, name: Text, max_minibatch_size: int = 256):

# Get specs
server_info = query_lit_server(self._url, 'get_info')
self._spec = lit_model.ModelSpec(
**server_info['models'][self._name]['spec'])
model_spec = server_info['models'][self._name]['spec']
self._input_spec = model_spec['input']
self._output_spec = model_spec['output']

self._max_minibatch_size = max_minibatch_size

def input_spec(self):
  """Return the input spec, as fetched from the remote server's get_info."""
  return self._input_spec

def output_spec(self):
  """Return the output spec, as fetched from the remote server's get_info."""
  return self._output_spec

def max_minibatch_size(self):
  """Return the configured maximum minibatch size."""
  return self._max_minibatch_size
Expand Down

0 comments on commit 16b72f7

Please sign in to comment.