Skip to content

Commit

Permalink
Adds infer_spec_for_func() to lit_nlp.api.types, with tests.
Browse files Browse the repository at this point in the history
This function attempts to infer a Spec describing the arguments of a given function for a class. It will be used in subsequent enhancements in an attempt to automatically infer an init_spec() for lit_nlp.api.dataset.Dataset and lit_nlp.api.model.Model sub-classes.

PiperOrigin-RevId: 499348089
  • Loading branch information
RyanMullins authored and LIT team committed Jan 4, 2023
1 parent 1590a19 commit d28eec3
Show file tree
Hide file tree
Showing 3 changed files with 342 additions and 2 deletions.
117 changes: 116 additions & 1 deletion lit_nlp/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
"""
import abc
import enum
import inspect
import math
import numbers
from typing import Any, NewType, Optional, Sequence, Type, TypedDict, Union
from typing import Any, Callable, NewType, Optional, Sequence, Type, TypedDict, Union

import attr
from lit_nlp.api import dtypes
Expand Down Expand Up @@ -894,6 +895,11 @@ class MetricResult(LitType):
best_value: MetricBestValue = MetricBestValue.NONE


@attr.s(auto_attribs=True, frozen=True, kw_only=True)
class Integer(Scalar):
step: int = 1


# LINT.ThenChange(../client/lib/lit_types.ts)

# Type aliases for backend use.
Expand All @@ -908,3 +914,112 @@ def get_type_by_name(typename: str) -> Type[LitType]:
cls = globals()[typename]
assert issubclass(cls, LitType)
return cls


# A map from Python's native type annotations to their LitType corollary for use
# by infer_spec_for_func().
_INFERENCE_TYPES_TO_LIT_TYPES: dict[Type[Any], Callable[..., LitType]] = {
bool: Boolean,
Optional[bool]: Boolean,
float: Scalar,
Optional[float]: Scalar,
Union[float, int]: Scalar,
Optional[Union[float, int]]: Scalar,
int: Integer,
Optional[int]: Integer,
str: String,
Optional[str]: String,
}


def infer_spec_for_func(func: Callable[..., Any]) -> Spec:
"""Infers a Spec from the arguments of a Callable's signature.
LIT uses
[Specs](https://github.com/PAIR-code/lit/blob/main/documentation/api.md#type-system)
as a mechanism to communicate how the web app should construct user interface
elements to enable user input for certain tasks, such as parameterizing an
Interpreter or loading a Model or Dataset at runtime. This includes
information about the type, default value (if any), required status, etc. of
the arguments to enable robust construction of HTML input elements.
As many LIT components are essentially Python functions that can be
parameterized and run via the LIT web app, this function exists to automate
the creation of Specs for some use cases, e.g., the `init_spec()` API of
`lit_nlp.api.dataset.Dataset` and `lit_nlp.api.model.Model` classes. It
attempts to infer a Spec for the Callable passed in as the value of `func` by:
1. Using `inspect.signature()` to retreive the Callable's signature info;
2. Processing `signature.parameters` to transform them into a corollary
`LitType` object that is consumable by the web app, either using
`Parameter.annotation` or by inferring a type from `Parameter.default`;
3. Adding an entry to a `Spec` dictionary where the key is `Paramater.name`
and the value is the `LitType`; and then
4. Returning the `Spec` dictionary after all arguments are processed.
Due to limitations of LIT's typing system and front-end support for these
types, this function is only able to infer Specs for Callables with arguments
of the following types (or `Optional` variants thereof) at this time. Support
for additional types may be added in the future. A `TypeError` will be raised
if this function encounters a type aside from those listed below.
* `bool` ==> `Boolean()`
* `float` ==> `Scalar()`
* `int` ==> `Integer()`
* `Union[float, int]` ==> `Scalar()`
* `str` ==> `String()`
Specs inferred by this function will not include entries for the `self`
parameter of instance methods of classes as this is unnecessary/implied, or
for `*args`- or `**kwargs`-like parameters of any funciton as we cannot safely
infer how variable arguments will be mutated, passed, or used.
Args:
func: The Callable for which a spec will be inferred.
Returns:
A Spec object where the keys are the parameter names and the values are the
`LitType` representation of that parameter (its type, default value, and
whether or not it is required).
Raises:
TypeError: If unable to infer a type, the type is not supported, or `func`
is not a `Callable`.
"""
if not callable(func):
raise TypeError("Attempted to infer a spec for a non-'Callable', "
f"'{type(func)}'.")

signature = inspect.signature(func)
spec: Spec = {}

for param in signature.parameters.values():
if (param.name == "self" or
param.kind is param.VAR_KEYWORD or
param.kind is param.VAR_POSITIONAL):
continue # self, *args, and **kwargs are not returned in inferred Specs.

# Otherwise, attempt to infer a type from the Paramater object.
if param.annotation is param.empty and param.default is param.empty:
raise TypeError(f"Unable to infer a type for parameter '{param.name}' "
f"of '{func.__name__}'. Please add a type hint or "
"default value, or implement a Spec literal.")

if param.annotation is param.empty:
param_type = type(param.default)
else:
param_type = param.annotation

if param_type in _INFERENCE_TYPES_TO_LIT_TYPES:
lit_type_cstr = _INFERENCE_TYPES_TO_LIT_TYPES[param_type]
lit_type_params = {"required": param.default is param.empty,}
if param.default is not param.empty:
lit_type_params["default"] = param.default
spec[param.name] = lit_type_cstr(**lit_type_params)
else:
raise TypeError(f"Unsupported type '{param_type}' for parameter "
f"'{param.name}' of '{func.__name__}'. If possible "
"(e.g., this parameter is Optional), please implement a "
"spec literal instead of using inferencing.")

return spec
219 changes: 218 additions & 1 deletion lit_nlp/api/types_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tests for types."""

from typing import Any
from typing import Any, Callable, Optional, Union
from absl.testing import absltest
from absl.testing import parameterized
from lit_nlp.api import dtypes
Expand Down Expand Up @@ -575,5 +575,222 @@ def test_attention(self):
out_spec, output, ds_spec, ds_spec, example)


class InferSpecTarget(object):
"""A dummy class for testing infer_spec_for_func against a class."""

def __init__(self, arg: bool = True):
self._arg = arg


# The following is a series of identity functions used to test
# types.infer_spec_for_func(Callable[..., Any]). These need to exist in this way
# because other options (e.g., lambdas) do not support annotation, which is
# required to test infer_spec_for_func() success cases.


def _bool_param(param: bool = True) -> bool:
return param


def _float_param(param: float = 1.2345) -> float:
return param


def _int_param(param: int = 1) -> int:
return param


def _many_params(
param_1: bool,
param_2: Optional[float],
param_3: int = 1,
param_4: Optional[str] = None
) -> tuple[bool, Optional[float], int, Optional[str]]:
return (param_1, param_2, param_3, param_4)


def _no_annotation(param="param") -> str:
return param


def _no_annotation_or_default(param) -> Any:
return param


def _no_args() -> Any:
return {}


def _no_default(param: str) -> str:
return param


def _object_param(param: object) -> object:
return param


def _optional_bool_param(param: Optional[bool]) -> Optional[bool]:
return param


def _optional_float_param(param: Optional[float]) -> Optional[float]:
return param


def _optional_int_param(param: Optional[int]) -> Optional[int]:
return param


def _optional_scalar_param(
param: Optional[Union[float, int]] = 1.2345) -> Optional[Union[float, int]]:
return param


def _optional_str_param(param: Optional[str]) -> Optional[str]:
return param


def _scalar_param(param: Union[float, int]) -> Union[float, int]:
return param


def _str_param(param: str = "str") -> str:
return param


class InferSpecTests(parameterized.TestCase):

@parameterized.named_parameters(
dict(
testcase_name="class",
func=InferSpecTarget,
expected_spec={"arg": types.Boolean(default=True, required=False)},
),
dict(
testcase_name="class_init",
func=InferSpecTarget.__init__,
expected_spec={"arg": types.Boolean(default=True, required=False)},
),
dict(
testcase_name="class_instance_init",
func=InferSpecTarget().__init__,
expected_spec={"arg": types.Boolean(default=True, required=False)},
),
dict(
testcase_name="empty_spec",
func=_no_args,
expected_spec={},
),
dict(
testcase_name="many_params",
func=_many_params,
expected_spec={
"param_1": types.Boolean(required=True),
"param_2": types.Scalar(required=True),
"param_3": types.Integer(default=1, required=False, step=1),
"param_4": types.String(default=None, required=False),
},
),
dict(
testcase_name="no_annotation",
func=_no_annotation,
expected_spec={
"param": types.String(default="param", required=False),
},
),
dict(
testcase_name="no_default",
func=_no_default,
expected_spec={
"param": types.String(default="", required=True),
},
),
dict(
testcase_name="optional_bool",
func=_optional_bool_param,
expected_spec={
"param": types.Boolean(required=True),
},
),
dict(
testcase_name="optional_float",
func=_optional_float_param,
expected_spec={
"param": types.Scalar(required=True),
},
),
dict(
testcase_name="optional_int",
func=_optional_int_param,
expected_spec={
"param": types.Integer(required=True),
},
),
dict(
testcase_name="optional_scalar",
func=_optional_scalar_param,
expected_spec={
"param": types.Scalar(default=1.2345, required=False),
},
),
dict(
testcase_name="optional_str",
func=_optional_str_param,
expected_spec={
"param": types.String(required=True),
},
),
dict(
testcase_name="single_bool",
func=_bool_param,
expected_spec={
"param": types.Boolean(default=True, required=False),
},
),
dict(
testcase_name="single_float",
func=_float_param,
expected_spec={
"param": types.Scalar(default=1.2345, required=False),
},
),
dict(
testcase_name="single_int",
func=_int_param,
expected_spec={
"param": types.Integer(default=1, required=False),
},
),
dict(
testcase_name="single_scalar",
func=_scalar_param,
expected_spec={
"param": types.Scalar(required=True),
},
),
dict(
testcase_name="single_str",
func=_str_param,
expected_spec={
"param": types.String(default="str", required=False),
},
),
)
def test_infer_spec_for_func(self, func: Callable[..., Any],
expected_spec: types.Spec):
spec = types.infer_spec_for_func(func)
self.assertEqual(spec, expected_spec)

@parameterized.named_parameters(
("class_instance", InferSpecTarget()),
("lambda", lambda x: x),
("no_annotation_or_default", _no_annotation_or_default),
("not_a_callable", "not_a_callable"),
("unsupported_type", _object_param),
)
def test_infer_spec_for_func_errors(self, func: Any):
self.assertRaises(TypeError, types.infer_spec_for_func, func)


if __name__ == "__main__":
absltest.main()
8 changes: 8 additions & 0 deletions lit_nlp/client/lib/lit_types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,14 @@ export class Scalar extends LitType {
step: number = .01;
}

/**
* An integer value
*/
@registered
export class Integer extends Scalar {
override step: number = 1;
}

/**
* Regression score, a single float.
*/
Expand Down

0 comments on commit d28eec3

Please sign in to comment.