Skip to content

Commit

Permalink
Making init_spec() a @classmethod
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 506286474
  • Loading branch information
RyanMullins authored and LIT team committed Feb 1, 2023
1 parent c5f0216 commit db51d9d
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 30 deletions.
10 changes: 6 additions & 4 deletions lit_nlp/api/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def description(self) -> str:
"""
return self._description or inspect.getdoc(self) or '' # pytype: disable=bad-return-type

def init_spec(self) -> Optional[types.Spec]:
@classmethod
def init_spec(cls) -> Optional[types.Spec]:
"""Attempts to infer a Spec describing a Dataset's constructor parameters.
The Dataset base class attempts to infer a Spec for the constructor using
Expand All @@ -116,11 +117,12 @@ def init_spec(self) -> Optional[types.Spec]:
could not be inferred.
"""
try:
spec = types.infer_spec_for_func(self.__init__)
spec = types.infer_spec_for_func(cls.__init__)
except TypeError as e:
spec = None
logging.warning("Unable to infer init spec for dataset '%s'. %s",
self.__class__.__name__, str(e))
logging.warning(
"Unable to infer init spec for dataset '%s'. %s", cls.__name__, str(e)
)
return spec

def load(self, path: str):
Expand Down
24 changes: 13 additions & 11 deletions lit_nlp/api/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,26 +42,28 @@ def __init__(self, path: str, max_examples: int = 200, max_qps: float = 1.0):
class DatasetTest(parameterized.TestCase):

@parameterized.named_parameters(
("empty_init", TestDatasetEmptyInit()),
("pass_thru_init", TestDatasetPassThroughInit()),
("empty_init", TestDatasetEmptyInit),
("pass_thru_init", TestDatasetPassThroughInit),
)
def test_init_spec_empty(self, dataset: lit_dataset.Dataset):
self.assertEmpty(dataset.init_spec())

def test_init_spec_populated(self):
dataset = TestDatasetInitWithArgs("test/path")
self.assertEqual(dataset.init_spec(), {
"path": types.String(),
"max_examples": types.Integer(default=200, required=False),
"max_qps": types.Scalar(default=1.0, required=False),
})
self.assertEqual(
TestDatasetInitWithArgs.init_spec(),
{
"path": types.String(),
"max_examples": types.Integer(default=200, required=False),
"max_qps": types.Scalar(default=1.0, required=False),
},
)

@parameterized.named_parameters(
# All base Dataset classes are incompatible with automated spec inference
# due to the complexity of their arguments, thus return None.
("dataset", lit_dataset.Dataset()),
("indexed_dataset", lit_dataset.IndexedDataset(id_fn=lambda x: x)),
("none_dataset", lit_dataset.NoneDataset(models={})),
("dataset", lit_dataset.Dataset),
("indexed_dataset", lit_dataset.IndexedDataset),
("none_dataset", lit_dataset.NoneDataset),
)
def test_init_spec_none(self, dataset: lit_dataset.Dataset):
self.assertIsNone(dataset.init_spec())
Expand Down
10 changes: 6 additions & 4 deletions lit_nlp/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def max_minibatch_size(self) -> int:
"""Maximum minibatch size for this model."""
return 1

def init_spec(self) -> Optional[Spec]:
@classmethod
def init_spec(cls) -> Optional[Spec]:
"""Attempts to infer a Spec describing a Model's constructor parameters.
The Model base class attempts to infer a Spec for the constructor using
Expand All @@ -111,11 +112,12 @@ def init_spec(self) -> Optional[Spec]:
not be inferred.
"""
try:
spec = types.infer_spec_for_func(self.__init__)
spec = types.infer_spec_for_func(cls.__init__)
except TypeError as e:
spec = None
logging.warning("Unable to infer init spec for model '%s'. %s",
self.__class__.__name__, str(e))
logging.warning(
"Unable to infer init spec for model '%s'. %s", cls.__name__, str(e)
)
return spec

def is_compatible_with_dataset(self, dataset: lit_dataset.Dataset) -> bool:
Expand Down
21 changes: 10 additions & 11 deletions lit_nlp/api/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,23 +198,22 @@ def test_batched_predict(self, inputs: list[model.JsonDict],
self.assertEqual(test_model.count, expected_run_count)

def test_init_spec_empty(self):
mdl = TestBatchingModel()
self.assertEmpty(mdl.init_spec())
self.assertEmpty(TestBatchingModel.init_spec())

def test_init_spec_populated(self):
mdl = TestSavedModel("test/path")
self.assertEqual(mdl.init_spec(), {
"path": types.String(),
"compute_embs": types.Boolean(default=False, required=False),
})
self.assertEqual(
TestSavedModel.init_spec(),
{
"path": types.String(),
"compute_embs": types.Boolean(default=False, required=False),
},
)

@parameterized.named_parameters(
("bad_args", CompatibilityTestModel({})),
("bad_args", CompatibilityTestModel),
# All ModelWrapper instances should return None, regardless of the model
# the instance is wrapping.
("wrap_bad_args", model.ModelWrapper(CompatibilityTestModel({}))),
("wrap_good_args", model.ModelWrapper(TestSavedModel("test/path"))),
("wrap_no_args", model.ModelWrapper(TestBatchingModel())),
("wrapper", model.ModelWrapper),
)
def test_init_spec_none(self, mdl: model.Model):
self.assertIsNone(mdl.init_spec())
Expand Down

0 comments on commit db51d9d

Please sign in to comment.