Skip to content

Commit

Permalink
Moves Model-Dataset compatibility checks to Model class instead of Mo…
Browse files Browse the repository at this point in the history
…delSpec

PiperOrigin-RevId: 495310542
  • Loading branch information
RyanMullins authored and LIT team committed Dec 14, 2022
1 parent 6e0df41 commit c268ce4
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 121 deletions.
48 changes: 25 additions & 23 deletions lit_nlp/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
import inspect
import itertools
import multiprocessing # for ThreadPool
from typing import List, Tuple, Iterable, Iterator, Text, Union
from typing import Iterable, Iterator, Union

import attr
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types
from lit_nlp.lib import utils
import numpy as np
Expand Down Expand Up @@ -71,21 +72,6 @@ class ModelSpec(object):
input: Spec
output: Spec

def is_compatible_with_dataset(self, dataset_spec: Spec) -> bool:
  """Report whether a dataset spec can supply this model's inputs.

  Args:
    dataset_spec: the dataset's spec, keyed by field name.

  Returns:
    True when every model input field found in the dataset is accepted by the
    dataset field's is_compatible() check and every input field missing from
    the dataset is not marked as required; False otherwise.
  """

  def _field_is_satisfied(name, wanted) -> bool:
    # A missing field is tolerated only if the model treats it as optional.
    if name not in dataset_spec:
      return not wanted.required
    # A present field must be accepted by the dataset's own field spec.
    return bool(dataset_spec[name].is_compatible(wanted))

  return all(
      _field_is_satisfied(name, wanted) for name, wanted in self.input.items())


class Model(metaclass=abc.ABCMeta):
"""Base class for LIT models."""
Expand All @@ -106,6 +92,22 @@ def max_minibatch_size(self) -> int:
"""Maximum minibatch size for this model."""
return 1

def is_compatible_with_dataset(self, dataset: lit_dataset.Dataset) -> bool:
  """Report whether the given dataset can supply this model's inputs.

  Args:
    dataset: the dataset whose spec() is checked against input_spec().

  Returns:
    True when every input field found in the dataset's spec is accepted by the
    dataset field's is_compatible() check and every input field missing from
    the dataset is not marked as required; False otherwise.
  """
  available = dataset.spec()
  for name, wanted in self.input_spec().items():
    if name not in available:
      # A missing field is acceptable only if the model marks it as optional.
      if wanted.required:
        return False
      continue
    # The field exists in the dataset, so the two field specs must agree.
    if not available[name].is_compatible(wanted):
      return False
  return True

@property
def supports_concurrent_predictions(self):
"""Indicates support for multiple concurrent predict calls across threads.
Expand All @@ -119,7 +121,7 @@ def supports_concurrent_predictions(self):
return False

@abc.abstractmethod
def predict_minibatch(self, inputs: List[JsonDict]) -> List[JsonDict]:
def predict_minibatch(self, inputs: list[JsonDict]) -> list[JsonDict]:
"""Run prediction on a batch of inputs.
Args:
Expand Down Expand Up @@ -163,7 +165,7 @@ def output_spec(self) -> types.Spec:
def spec(self) -> ModelSpec:
  """Return a ModelSpec built from input_spec() and output_spec()."""
  spec_fields = dict(input=self.input_spec(), output=self.output_spec())
  return ModelSpec(**spec_fields)

def get_embedding_table(self) -> Tuple[List[Text], np.ndarray]:
def get_embedding_table(self) -> tuple[list[str], np.ndarray]:
"""Return the full vocabulary and embedding table.
Implementing this is optional, but needed for some techniques such as
Expand All @@ -175,7 +177,7 @@ def get_embedding_table(self) -> Tuple[List[Text], np.ndarray]:
raise NotImplementedError('get_embedding_table() not implemented for ' +
self.__class__.__name__)

def fit_transform_with_metadata(self, indexed_inputs: List[JsonDict]):
def fit_transform_with_metadata(self, indexed_inputs: list[JsonDict]):
"""For internal use by UMAP and other sklearn-based models."""
raise NotImplementedError(
'fit_transform_with_metadata() not implemented for ' +
Expand Down Expand Up @@ -255,7 +257,7 @@ def max_minibatch_size(self) -> int:
def supports_concurrent_predictions(self):
return self.wrapped.supports_concurrent_predictions

def predict_minibatch(self, inputs: List[JsonDict], **kw) -> List[JsonDict]:
def predict_minibatch(self, inputs: list[JsonDict], **kw) -> list[JsonDict]:
return self.wrapped.predict_minibatch(inputs, **kw)

def predict(self, inputs: Iterable[JsonDict], *args,
Expand Down Expand Up @@ -288,10 +290,10 @@ def spec(self) -> ModelSpec:

##
# Special methods
def get_embedding_table(self) -> Tuple[List[Text], np.ndarray]:
def get_embedding_table(self) -> tuple[list[str], np.ndarray]:
return self.wrapped.get_embedding_table()

def fit_transform_with_metadata(self, indexed_inputs: List[JsonDict]):
def fit_transform_with_metadata(self, indexed_inputs: list[JsonDict]):
return self.wrapped.fit_transform_with_metadata(indexed_inputs)


Expand Down Expand Up @@ -331,7 +333,7 @@ def supports_concurrent_predictions(self):
return True

@abc.abstractmethod
def predict_minibatch(self, inputs: List[JsonDict]) -> List[JsonDict]:
def predict_minibatch(self, inputs: list[JsonDict]) -> list[JsonDict]:
"""Run prediction on a batch of inputs.
Subclass should implement this.
Expand Down
165 changes: 93 additions & 72 deletions lit_nlp/api/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,84 +15,105 @@
"""Tests for lit_nlp.api.model."""

from absl.testing import absltest
from absl.testing import parameterized

from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import model
from lit_nlp.api import types
from lit_nlp.lib import testing_utils


class SpecTest(absltest.TestCase):
  """Exercises ModelSpec.is_compatible_with_dataset() on dataset specs."""

  def test_compatibility_fullmatch(self):
    """The dataset spec is the model's own input spec."""
    model_inputs = {
        "text_a": types.TextSegment(),
        "text_b": types.TextSegment(),
    }
    model_spec = model.ModelSpec(input=model_inputs, output={})
    dataset_spec = model_spec.input
    self.assertTrue(model_spec.is_compatible_with_dataset(dataset_spec))

  def test_compatibility_mismatch(self):
    """The dataset has none of the field names the model asks for."""
    model_inputs = {
        "text_a": types.TextSegment(),
        "text_b": types.TextSegment(),
    }
    model_spec = model.ModelSpec(input=model_inputs, output={})
    dataset_spec = {
        "premise": types.TextSegment(),
        "hypothesis": types.TextSegment(),
    }
    self.assertFalse(model_spec.is_compatible_with_dataset(dataset_spec))

  def test_compatibility_extrafield(self):
    """The dataset carries one field beyond the model's inputs."""
    model_inputs = {
        "text_a": types.TextSegment(),
        "text_b": types.TextSegment(),
    }
    model_spec = model.ModelSpec(input=model_inputs, output={})
    dataset_spec = {
        "text_a": types.TextSegment(),
        "text_b": types.TextSegment(),
        "label": types.CategoryLabel(vocab=["0", "1"]),
    }
    self.assertTrue(model_spec.is_compatible_with_dataset(dataset_spec))

  def test_compatibility_optionals(self):
    """The model marks some inputs required=False; one is absent."""
    model_inputs = {
        "text": types.TextSegment(),
        "tokens": types.Tokens(parent="text", required=False),
        "label": types.CategoryLabel(vocab=["0", "1"], required=False),
    }
    model_spec = model.ModelSpec(input=model_inputs, output={})
    dataset_spec = {
        "text": types.TextSegment(),
        "label": types.CategoryLabel(vocab=["0", "1"]),
    }
    self.assertTrue(model_spec.is_compatible_with_dataset(dataset_spec))

  def test_compatibility_optionals_mismatch(self):
    """An optional input is present in the dataset with a different vocab."""
    model_inputs = {
        "text": types.TextSegment(),
        "tokens": types.Tokens(parent="text", required=False),
        "label": types.CategoryLabel(vocab=["0", "1"], required=False),
    }
    model_spec = model.ModelSpec(input=model_inputs, output={})
    dataset_spec = {
        "text": types.TextSegment(),
        # This label field doesn't match the one the model expects.
        "label": types.CategoryLabel(vocab=["foo", "bar"]),
    }
    self.assertFalse(model_spec.is_compatible_with_dataset(dataset_spec))


class ModelTest(absltest.TestCase):
class CompatibilityTestModel(model.Model):
  """Minimal Model whose input spec is injected by the test.

  Used to exercise Model.is_compatible_with_dataset(). Only input_spec()
  carries information: the output spec is empty and predict_minibatch()
  returns no predictions.
  """

  def __init__(self, input_spec: types.Spec):
    self._input_spec = input_spec

  def input_spec(self) -> types.Spec:
    """Returns the spec supplied at construction."""
    return self._input_spec

  def output_spec(self) -> types.Spec:
    """Returns an empty output spec."""
    return dict()

  def predict_minibatch(
      self, inputs: list[model.JsonDict]) -> list[model.JsonDict]:
    """Returns an empty list regardless of the inputs."""
    return list()


class ModelTest(parameterized.TestCase):

@parameterized.named_parameters(
    {
        "testcase_name": "full_match",
        "input_spec": {
            "text_a": types.TextSegment(),
            "text_b": types.TextSegment(),
        },
        "dataset_spec": {
            "text_a": types.TextSegment(),
            "text_b": types.TextSegment(),
        },
        "expected": True,
    },
    {
        "testcase_name": "mismatch",
        "input_spec": {
            "text_a": types.TextSegment(),
            "text_b": types.TextSegment(),
        },
        "dataset_spec": {
            "premise": types.TextSegment(),
            "hypothesis": types.TextSegment(),
        },
        "expected": False,
    },
    {
        "testcase_name": "extra_field",
        "input_spec": {
            "text_a": types.TextSegment(),
            "text_b": types.TextSegment(),
        },
        "dataset_spec": {
            "text_a": types.TextSegment(),
            "text_b": types.TextSegment(),
            "label": types.CategoryLabel(vocab=["0", "1"]),
        },
        "expected": True,
    },
    {
        "testcase_name": "optionals",
        "input_spec": {
            "text": types.TextSegment(),
            "tokens": types.Tokens(parent="text", required=False),
            "label": types.CategoryLabel(vocab=["0", "1"], required=False),
        },
        "dataset_spec": {
            "text": types.TextSegment(),
            "label": types.CategoryLabel(vocab=["0", "1"]),
        },
        "expected": True,
    },
    {
        "testcase_name": "optionals_mismatch",
        "input_spec": {
            "text": types.TextSegment(),
            "tokens": types.Tokens(parent="text", required=False),
            "label": types.CategoryLabel(vocab=["0", "1"], required=False),
        },
        "dataset_spec": {
            "text": types.TextSegment(),
            # This label field doesn't match the one the model expects.
            "label": types.CategoryLabel(vocab=["foo", "bar"]),
        },
        "expected": False,
    },
)
def test_compatibility(self, input_spec: types.Spec, dataset_spec: types.Spec,
                       expected: bool):
  """Checks Model.is_compatible_with_dataset() for one spec pairing.

  Args:
    input_spec: input spec handed to the dummy model.
    dataset_spec: spec used to construct the Dataset under comparison.
    expected: the compatibility verdict the model should return.
  """
  dataset = lit_dataset.Dataset(spec=dataset_spec)
  test_model = CompatibilityTestModel(input_spec)
  is_compatible = test_model.is_compatible_with_dataset(dataset)
  self.assertEqual(is_compatible, expected)

def test_predict(self):
"""Tests predict() for a model with max batch size of 3."""
Expand Down
15 changes: 7 additions & 8 deletions lit_nlp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,19 @@ class LitApp(object):
def _build_metadata(self):
"""Build metadata from model and dataset specs."""
model_info = {}
for name, m in self._models.items():
mspec: lit_model.ModelSpec = m.spec()
for name, model in self._models.items():
info = {
'description': m.description(),
'description': model.description(),
'spec': {
'input': mspec.input,
'output': mspec.output
'input': model.input_spec(),
'output': model.output_spec(),
}
}

# List compatible datasets.
info['datasets'] = [
name for name, dataset in self._datasets.items()
if mspec.is_compatible_with_dataset(dataset.spec())
if model.is_compatible_with_dataset(dataset)
]
if len(info['datasets']) == 0: # pylint: disable=g-explicit-length-test
logging.error("Error: model '%s' has no compatible datasets!", name)
Expand All @@ -81,11 +80,11 @@ def _build_metadata(self):
dataset: lit_dataset.Dataset = self._datasets[d]
compat_gens.update([
name for name, gen in self._generators.items()
if gen.is_compatible(model=m, dataset=dataset)
if gen.is_compatible(model=model, dataset=dataset)
])
compat_interps.update([
name for name, interp in self._interpreters.items()
if interp.is_compatible(model=m, dataset=dataset)
if interp.is_compatible(model=model, dataset=dataset)
])

info['generators'] = [name for name in self._generators.keys()
Expand Down
6 changes: 3 additions & 3 deletions lit_nlp/components/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ def __init__(
self._models = models

# Create/Load indices.
for model_name, model_info in self._models.items():
for model_name, model in self._models.items():
compatible_datasets = [
dname for dname, ds in self.datasets.items()
if model_info.spec().is_compatible_with_dataset(ds.spec())
dataset_name for dataset_name, dataset in self.datasets.items()
if model.is_compatible_with_dataset(dataset)
]
for dataset in compatible_datasets:
self._create_empty_indices(model_name, dataset)
Expand Down
12 changes: 0 additions & 12 deletions lit_nlp/examples/models/mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from lit_nlp.api import types as lit_types
from lit_nlp.examples.models import imagenet_labels
from lit_nlp.lib import image_utils
from lit_nlp.lib import utils as lit_utils
import numpy as np
import tensorflow as tf

Expand All @@ -17,13 +16,6 @@
class MobileNet(model.Model):
"""MobileNet model trained on ImageNet dataset."""

class MobileNetSpec(model.ModelSpec):
  """ModelSpec that accepts any dataset exposing an ImageBytes field."""

  def is_compatible_with_dataset(self, dataset_spec: lit_types.Spec) -> bool:
    """Return True if find_spec_keys() finds any ImageBytes field."""
    return bool(
        lit_utils.find_spec_keys(dataset_spec, lit_types.ImageBytes))

def __init__(self) -> None:
# Initialize imagenet labels.
self.labels = [''] * len(imagenet_labels.IMAGENET_2012_LABELS)
Expand Down Expand Up @@ -89,7 +81,3 @@ def output_spec(self):
'grad_target':
lit_types.CategoryLabel(vocab=self.labels)
}

def spec(self) -> model.ModelSpec:
  """Bundle input_spec() and output_spec() into a MobileNetSpec."""
  spec_factory = self.MobileNetSpec
  return spec_factory(input=self.input_spec(), output=self.output_spec())

0 comments on commit c268ce4

Please sign in to comment.