Skip to content

Commit

Permalink
Adds init_spec() to lit_nlp.api.dataset classes.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 499348503
  • Loading branch information
RyanMullins authored and LIT team committed Jan 4, 2023
1 parent d28eec3 commit d624562
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 1 deletion.
25 changes: 25 additions & 0 deletions lit_nlp/api/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,31 @@ def description(self) -> str:
"""
return self._description or inspect.getdoc(self) or '' # pytype: disable=bad-return-type

def init_spec(self) -> Optional[Spec]:
  """Attempts to infer a Spec describing a Dataset's constructor parameters.

  The Dataset base class attempts to infer a Spec for the constructor using
  `lit_nlp.api.types.infer_spec_for_func()`.

  If successful, this function will return a `dict[str, LitType]`. If
  unsuccessful (i.e., the inferencer raises a `TypeError` because it
  encounters a parameter that is not supported by `infer_spec_for_func()`),
  this function will return None, log a warning describing where and how the
  inferencing failed, and LIT users **will not** be able to load new
  instances of this Dataset from the UI.

  Returns:
    A Spec representation of the Dataset's constructor, or None if a Spec
    could not be inferred.
  """
  try:
    return types.infer_spec_for_func(self.__init__)
  except TypeError as e:
    # Failing to infer a spec is non-fatal: report it (with traceback) and
    # signal the failure to the caller by returning None.
    logging.warning("Unable to infer init spec for dataset '%s'. %s",
                    self.__class__.__name__, e, exc_info=True)
    return None

def load(self, path: str):
"""Load and return additional previously-saved datapoints for this dataset.
Expand Down
46 changes: 45 additions & 1 deletion lit_nlp/api/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,56 @@
"""Tests for lit_nlp.lib.model."""

from absl.testing import absltest
from absl.testing import parameterized

from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types


class DatasetTest(absltest.TestCase):
# Dataset stub whose constructor accepts no parameters besides `self`.
class TestDatasetEmptyInit(lit_dataset.Dataset):

  def __init__(self):
    """Does nothing; only the (empty) signature matters to these tests."""


# Dataset stub whose constructor only declares variadic *args / **kwargs.
class TestDatasetPassThroughInit(lit_dataset.Dataset):

  def __init__(self, *args, **kwargs):
    """Accepts anything and ignores it; only the signature matters here."""
    del args, kwargs  # Unused.


# Dataset stub whose constructor declares one required and two defaulted,
# type-annotated parameters.
class TestDatasetInitWithArgs(lit_dataset.Dataset):

  def __init__(self, path: str, max_examples: int = 200, max_qps: float = 1.0):
    """Ignores its arguments; only the annotated signature matters here."""
    del path, max_examples, max_qps  # Unused.


class DatasetTest(parameterized.TestCase):

@parameterized.named_parameters(
    dict(testcase_name="empty_init", dataset=TestDatasetEmptyInit()),
    dict(testcase_name="pass_thru_init", dataset=TestDatasetPassThroughInit()),
)
def test_init_spec_empty(self, dataset: lit_dataset.Dataset):
  """Checks that init_spec() is empty for the parameterized datasets."""
  inferred_spec = dataset.init_spec()
  self.assertEmpty(inferred_spec)

def test_init_spec_populated(self):
  """Checks the Spec inferred from a constructor with annotated parameters."""
  inferred_spec = TestDatasetInitWithArgs("test/path").init_spec()
  # One entry per constructor parameter; defaulted parameters are expected to
  # carry their default value and be marked as not required.
  expected_spec = {
      "path": types.String(),
      "max_examples": types.Integer(default=200, required=False),
      "max_qps": types.Scalar(default=1.0, required=False),
  }
  self.assertEqual(inferred_spec, expected_spec)

@parameterized.named_parameters(
    # All base Dataset classes are incompatible with automated spec inference
    # due to the complexity of their arguments, thus return None.
    dict(testcase_name="dataset", dataset=lit_dataset.Dataset()),
    dict(
        testcase_name="indexed_dataset",
        dataset=lit_dataset.IndexedDataset(id_fn=lambda x: x),
    ),
    dict(
        testcase_name="none_dataset",
        dataset=lit_dataset.NoneDataset(models={}),
    ),
)
def test_init_spec_none(self, dataset: lit_dataset.Dataset):
  """Checks that init_spec() is None for each parameterized base dataset."""
  inferred_spec = dataset.init_spec()
  self.assertIsNone(inferred_spec)

def test_remap(self):
"""Test remap method."""
Expand Down

0 comments on commit d624562

Please sign in to comment.