From 46a6c403f9ca6d6938014ea706d96940ce09998d Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Wed, 8 Feb 2023 08:51:38 +0000 Subject: [PATCH 01/21] Switch from using apischema to pydantic --- pyproject.toml | 4 ++-- src/blueapi/service/model.py | 44 +++++++++++++++--------------------- 2 files changed, 20 insertions(+), 28 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b27dfef914..6ba3db4d08 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,9 +18,9 @@ dependencies = [ "ophyd", "nslsii", "pyepics", - "apischema", + "pydantic", "stomp.py", - "scanspec<=0.5.5", + "scanspec", "PyYAML", "click", ] diff --git a/src/blueapi/service/model.py b/src/blueapi/service/model.py index 9a9c6087df..ee220e64f2 100644 --- a/src/blueapi/service/model.py +++ b/src/blueapi/service/model.py @@ -1,29 +1,27 @@ -from dataclasses import dataclass from typing import Iterable, List -from apischema import settings from bluesky.protocols import HasName +from pydantic import BaseModel, Field from blueapi.core import BLUESKY_PROTOCOLS, Device, Plan _UNKNOWN_NAME = "UNKNOWN" -settings.camel_case = True - -@dataclass -class DeviceModel: +class DeviceModel(BaseModel): """ Representation of a device """ - name: str - protocols: List[str] + name: str = Field(description="Name of the device") + protocols: List[str] = Field( + description="Protocols that a device conforms to, indicating its capabilities" + ) @classmethod def from_device(cls, device: Device) -> "DeviceModel": name = device.name if isinstance(device, HasName) else _UNKNOWN_NAME - return cls(name, list(_protocol_names(device))) + return cls(name=name, protocols=list(_protocol_names(device))) def _protocol_names(device: Device) -> Iterable[str]: @@ -32,8 +30,7 @@ def _protocol_names(device: Device) -> Iterable[str]: yield protocol.__name__ -@dataclass -class DeviceRequest: +class DeviceRequest(BaseModel): """ A query for devices """ @@ -41,30 +38,27 @@ class DeviceRequest: ... -@dataclass -class DeviceResponse: +class DeviceResponse(BaseModel): """ Response to a query for devices """ - devices: List[DeviceModel] + devices: List[DeviceModel] = Field(description="Devices available to use in plans") -@dataclass -class PlanModel: +class PlanModel(BaseModel): """ Representation of a plan """ - name: str + name: str = Field(description="Name of the plan") @classmethod def from_plan(cls, plan: Plan) -> "PlanModel": - return cls(plan.name) + return cls(name=plan.name) -@dataclass -class PlanRequest: +class PlanRequest(BaseModel): """ A query for plans """ @@ -72,19 +66,17 @@ class PlanRequest: ... 
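# Illustrative sketch (not part of this commit): what the pydantic models above give
# over the apischema dataclasses, namely validation on construction and a JSON schema
# that carries the Field descriptions. DeviceModel mirrors the class defined in this
# hunk; the example values are hypothetical.
from typing import List

from pydantic import BaseModel, Field


class DeviceModel(BaseModel):
    name: str = Field(description="Name of the device")
    protocols: List[str] = Field(
        description="Protocols that a device conforms to, indicating its capabilities"
    )


model = DeviceModel(name="motor_x", protocols=["Movable", "Readable"])
print(model.json())          # {"name": "motor_x", "protocols": ["Movable", "Readable"]}
print(DeviceModel.schema())  # JSON schema including the Field descriptions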
-@dataclass -class PlanResponse: +class PlanResponse(BaseModel): """ Response to a query for plans """ - plans: List[PlanModel] + plans: List[PlanModel] = Field(description="Plans available to use by a worker") -@dataclass -class TaskResponse: +class TaskResponse(BaseModel): """ Acknowledgement that a task has started, includes its ID """ - task_name: str + task_name: str = Field(description="Unique identifier for the task") From 2e5201bb1435ca5e52f82f536f36ea7dfe90819f Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Mon, 13 Feb 2023 14:54:47 +0000 Subject: [PATCH 02/21] WIP Convert apischema custom logic to pydantic --- src/blueapi/core/bluesky_types.py | 12 +++-- src/blueapi/core/context.py | 14 ++++-- src/blueapi/utils/schema.py | 15 +++--- src/blueapi/worker/event.py | 84 ++++++++++++++++++++++--------- src/blueapi/worker/task.py | 69 ++++++++++++------------- 5 files changed, 116 insertions(+), 78 deletions(-) diff --git a/src/blueapi/core/bluesky_types.py b/src/blueapi/core/bluesky_types.py index 2a757b2c3e..b2e6182079 100644 --- a/src/blueapi/core/bluesky_types.py +++ b/src/blueapi/core/bluesky_types.py @@ -1,6 +1,6 @@ import inspect from dataclasses import dataclass -from typing import Any, Callable, Generator, Mapping, Type, Union +from typing import Any, Callable, Generator, Mapping, Optional, Type, Union from bluesky.protocols import ( Checkable, @@ -20,6 +20,7 @@ WritesExternalAssets, ) from bluesky.utils import Msg +from pydantic import BaseModel, Field try: from typing import Protocol, runtime_checkable @@ -72,14 +73,15 @@ def is_bluesky_plan_generator(func: PlanGenerator) -> bool: ) -@dataclass -class Plan: +class Plan(BaseModel): """ A plan that can be run """ - name: str - model: Type[Any] + name: str = Field(description="Referenceable name of the plan") + model: Type[BaseModel] = Field( + description="Validation model of the parameters for the plan" + ) @dataclass diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index 6e484e94ff..e607929e80 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -7,6 +7,8 @@ from bluesky import RunEngine from bluesky.protocols import Flyable, Readable +from pydantic import BaseConfig, Extra, validate_arguments +from pydantic.decorator import ValidatedFunction from blueapi.utils import load_module_all, schema_for_func @@ -22,6 +24,10 @@ LOGGER = logging.getLogger(__name__) +class PlanConfig(BaseConfig): + extra = Extra.forbid + + @dataclass class BlueskyContext: """ @@ -31,9 +37,8 @@ class BlueskyContext: run_engine: RunEngine = field( default_factory=lambda: RunEngine(context_managers=[]) ) - plans: Dict[str, Plan] = field(default_factory=dict) + plans: Dict[str, ValidatedFunction] = field(default_factory=dict) devices: Dict[str, Device] = field(default_factory=dict) - plan_functions: Dict[str, PlanGenerator] = field(default_factory=dict) def find_device(self, addr: Union[str, List[str]]) -> Optional[Device]: """ @@ -107,9 +112,8 @@ def my_plan(a: int, b: str): if not is_bluesky_plan_generator(plan): raise TypeError(f"{plan} is not a valid plan generator function") - schema = schema_for_func(plan) - self.plans[plan.__name__] = Plan(plan.__name__, schema) - self.plan_functions[plan.__name__] = plan + schema = ValidatedFunction(plan, PlanConfig) + self.plans[plan.__name__] = schema return plan def device(self, device: Device, name: Optional[str] = None) -> None: diff --git a/src/blueapi/utils/schema.py b/src/blueapi/utils/schema.py index a73ce72828..88e260fe4f 100644 --- 
a/src/blueapi/utils/schema.py +++ b/src/blueapi/utils/schema.py @@ -5,32 +5,31 @@ from apischema import deserialize from apischema.conversions.conversions import Conversion from apischema.conversions.converters import AnyConversion, default_deserialization +from pydantic import BaseModel -def schema_for_func(func: Callable[..., Any]) -> Type: +def schema_for_func(func: Callable[..., Any]) -> BaseModel: """ - Generate a dataclass that acts as a schema for validation with apischema. + Generate a pydantic model of the set of parameters to a function. Inspect the parameters, default values and type annotations of a function and generate the schema. Example: - def foo(a: int, b: str, c: bool): + def foo(a: int, b: str, c: bool = False): ... schema = schema_for_func(foo) Schema is the runtime equivalent of: - @dataclass - class fooo_params: + class foo_params(BaseModel): a: int b: str - c: bool + c: bool = False Args: - func (Callable[..., Any]): The source function, all parameters must have type - annotations + func: The source function, all parameters must have type annotations Raises: TypeError: If a type annotation is either `Any` or not supplied diff --git a/src/blueapi/worker/event.py b/src/blueapi/worker/event.py index 75edc413dc..84faa82573 100644 --- a/src/blueapi/worker/event.py +++ b/src/blueapi/worker/event.py @@ -1,8 +1,8 @@ -from dataclasses import dataclass, field from enum import Enum from typing import List, Mapping, Optional, Union from bluesky.run_engine import RunEngineStateMachine +from pydantic import BaseModel, Field from super_state_machine.extras import PropertyMachine, ProxyString # The RunEngine can return any of these three types as its state @@ -27,42 +27,81 @@ class WorkerState(Enum): @classmethod def from_bluesky_state(cls, bluesky_state: RawRunEngineState) -> "WorkerState": + """Convert the state of a bluesky RunEngine + + Args: + bluesky_state: Bluesky RunEngine state + + Returns: + RunnerState: Mapped RunEngine state + """ + if isinstance(bluesky_state, RunEngineStateMachine.States): return cls.from_bluesky_state(bluesky_state.value) return WorkerState(str(bluesky_state).upper()) -@dataclass -class StatusView: +class WorkerEvent(BaseModel): """ - A snapshot of a Status, optionally representing progress + Event emitted by a worker when the runner state changes """ - display_name: str = "UNKNOWN" - current: Optional[float] = None - initial: Optional[float] = None - target: Optional[float] = None - unit: str = "units" - precision: int = 3 - done: bool = False - percentage: Optional[float] = None - time_elapsed: Optional[float] = None - time_remaining: Optional[float] = None + state: WorkerState = Field(description="Current state of the worker") + current_task_name: Optional[str] = Field( + description="Unique ID of the currently running task, if any", default=None + ) -@dataclass -class ProgressEvent: +class StatusView(BaseModel): + """ + A snapshot of a Status of an operation, optionally representing progress + """ + + display_name: str = Field( + description="Human-readable name indicating what this status describes", + default="Unknown", + ) + current: Optional[float] = Field( + description="Current value of operation progress, if known", default=None + ) + initial: Optional[float] = Field( + description="Initial value of operation progress, if known", default=None + ) + target: Optional[float] = Field( + description="Target value operation of progress, if known", default=None + ) + unit: str = Field(description="Units of progress", default="units") + precision: 
int = Field( + description="Sensible precision of progress to display", default=3 + ) + done: bool = Field( + description="Whether the operation this status describes is complete", + default=False, + ) + percentage: Optional[float] = Field( + description="Percentage of status completion, if known", default=None + ) + time_elapsed: Optional[float] = Field( + description="Time elapsed since status operation beginning, if known", + default=None, + ) + time_remaining: Optional[float] = Field( + description="Estimated time remaining until operation completion, if known", + default=None, + ) + + +class ProgressEvent(BaseModel): """ Event describing the progress of processes within a running task, such as moving motors and exposing detectors. """ task_name: str - statuses: Mapping[str, StatusView] = field(default_factory=dict) + statuses: Mapping[str, StatusView] = Field(default_factory=dict) -@dataclass -class TaskStatus: +class TaskStatus(BaseModel): """ Status of a task the worker is running. """ @@ -72,8 +111,7 @@ class TaskStatus: task_failed: bool -@dataclass -class WorkerEvent: +class WorkerEvent(BaseModel): """ Event describing the state of the worker and any tasks it's running. Includes error and warning information. @@ -81,8 +119,8 @@ class WorkerEvent: state: WorkerState task_status: Optional[TaskStatus] = None - errors: List[str] = field(default_factory=list) - warnings: List[str] = field(default_factory=list) + errors: List[str] = Field(default_factory=list) + warnings: List[str] = Field(default_factory=list) def is_error(self) -> bool: return (self.task_status is not None and self.task_status.task_failed) or bool( diff --git a/src/blueapi/worker/task.py b/src/blueapi/worker/task.py index d136b65bd2..be67067676 100644 --- a/src/blueapi/worker/task.py +++ b/src/blueapi/worker/task.py @@ -1,40 +1,21 @@ import logging from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from typing import Any, Mapping, Union - -from apischema import deserializer, identity, serializer -from apischema.conversions import Conversion - -from blueapi.core import ( - BlueskyContext, - Device, - Plan, - create_bluesky_protocol_conversions, -) +from dataclasses import dataclass +from typing import Any, Mapping + +from pydantic import BaseModel, Field, parse_obj_as +from pydantic.decorator import ValidatedFunction + +from blueapi.core import BlueskyContext, Device, create_bluesky_protocol_conversions from blueapi.utils import nested_deserialize_with_overrides # TODO: Make a TaggedUnion -class Task(ABC): +class Task(ABC, BaseModel): """ Object that can run with a TaskContext """ - _union: Any = None - - # You can use __init_subclass__ to register new subclass automatically - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - # Deserializers stack directly as a Union - deserializer(Conversion(identity, source=cls, target=Task)) - # Only Base serializer must be registered (and updated for each subclass) as - # a Union, and not be inherited - Task._union = cls if Task._union is None else Union[Task._union, cls] - serializer( - Conversion(identity, source=Task, target=Task._union, inherited=False) - ) - @abstractmethod def do_task(self, __ctx: BlueskyContext) -> None: """ @@ -48,29 +29,43 @@ def do_task(self, __ctx: BlueskyContext) -> None: LOGGER = logging.getLogger(__name__) -@dataclass class RunPlan(Task): """ Task that will run a plan """ - name: str - params: Mapping[str, Any] = field(default_factory=dict) - # plan: Generator[Msg, None, Any] + name: str = 
Field(description="Name of plan to run") + params: Mapping[str, Any] = Field( + description="Values for parameters to plan, if any", default_factory=dict + ) def do_task(self, ctx: BlueskyContext) -> None: LOGGER.info(f"Asked to run plan {self.name} with {self.params}") plan = ctx.plans[self.name] - plan_function = ctx.plan_functions[self.name] - sanitized_params = lookup_params(ctx, plan, self.params) - plan_generator = plan_function(**sanitized_params) + sanitized_params = _lookup_params(ctx, plan, self.params) + plan_generator = plan.call(**sanitized_params) ctx.run_engine(plan_generator) -def lookup_params( - ctx: BlueskyContext, plan: Plan, params: Mapping[str, Any] -) -> Mapping[str, Any]: +def _lookup_params( + ctx: BlueskyContext, plan: ValidatedFunction, params: Mapping[str, Any] +) -> BaseModel: + """ + Checks plan parameters against context + + Args: + ctx: Context holding plans and devices + plan: Plan object including schema + params: Parameter values to be validated against schema + + Returns: + Mapping[str, Any]: _description_ + """ + + model = plan.model + return parse_obj_as(model, params) + def find_device(name: str) -> Device: device = ctx.find_device(name) if device is not None: From 74257208979fd3692100672ac8df3548f11c866c Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Thu, 23 Feb 2023 16:07:32 +0000 Subject: [PATCH 03/21] WIP on parameter passing --- src/blueapi/core/__init__.py | 2 + src/blueapi/core/bluesky_types.py | 15 +++-- src/blueapi/core/context.py | 95 +++++++++++++++++++++++++++---- 3 files changed, 98 insertions(+), 14 deletions(-) diff --git a/src/blueapi/core/__init__.py b/src/blueapi/core/__init__.py index 1040d86a34..afc759a5c3 100644 --- a/src/blueapi/core/__init__.py +++ b/src/blueapi/core/__init__.py @@ -7,6 +7,7 @@ PlanGenerator, WatchableStatus, is_bluesky_compatible_device, + is_bluesky_compatible_device_type, is_bluesky_plan_generator, ) from .context import BlueskyContext @@ -27,4 +28,5 @@ "WatchableStatus", "is_bluesky_compatible_device", "is_bluesky_plan_generator", + "is_bluesky_compatible_device_type", ] diff --git a/src/blueapi/core/bluesky_types.py b/src/blueapi/core/bluesky_types.py index b2e6182079..8a8f35f0b1 100644 --- a/src/blueapi/core/bluesky_types.py +++ b/src/blueapi/core/bluesky_types.py @@ -58,12 +58,19 @@ def is_bluesky_compatible_device(obj: Any) -> bool: is_object = not inspect.isclass(obj) - follows_protocols = any( - map(lambda protocol: isinstance(obj, protocol), BLUESKY_PROTOCOLS) - ) # We must separately check if Obj refers to an instance rather than a # class, as both follow the protocols but only one is a "device". - return is_object and follows_protocols + return is_object and _follows_bluesky_protocols(obj) + + +def is_bluesky_compatible_device_type(cls: Type[Any]) -> bool: + # We must separately check if Obj refers to an class rather than an + # instance, as both follow the protocols but only one is a type. 
+ return inspect.isclass(cls) and _follows_bluesky_protocols(cls) + + +def _follows_bluesky_protocols(obj: Any) -> bool: + return any(map(lambda protocol: isinstance(obj, protocol), BLUESKY_PROTOCOLS)) def is_bluesky_plan_generator(func: PlanGenerator) -> bool: diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index e607929e80..367c6542b6 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -1,14 +1,28 @@ import logging from dataclasses import dataclass, field from importlib import import_module +from inspect import Parameter, signature from pathlib import Path from types import ModuleType -from typing import Dict, List, Optional, Union +from typing import ( + Any, + Callable, + Deque, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, +) from bluesky import RunEngine from bluesky.protocols import Flyable, Readable -from pydantic import BaseConfig, Extra, validate_arguments -from pydantic.decorator import ValidatedFunction +from pydantic import BaseModel, create_model, validator from blueapi.utils import load_module_all, schema_for_func @@ -17,6 +31,7 @@ Plan, PlanGenerator, is_bluesky_compatible_device, + is_bluesky_compatible_device_type, is_bluesky_plan_generator, ) from .device_lookup import find_component @@ -24,10 +39,6 @@ LOGGER = logging.getLogger(__name__) -class PlanConfig(BaseConfig): - extra = Extra.forbid - - @dataclass class BlueskyContext: """ @@ -37,8 +48,9 @@ class BlueskyContext: run_engine: RunEngine = field( default_factory=lambda: RunEngine(context_managers=[]) ) - plans: Dict[str, ValidatedFunction] = field(default_factory=dict) + plans: Dict[str, Plan] = field(default_factory=dict) devices: Dict[str, Device] = field(default_factory=dict) + plan_functions: Dict[str, PlanGenerator] = field(default_factory=dict) def find_device(self, addr: Union[str, List[str]]) -> Optional[Device]: """ @@ -112,8 +124,12 @@ def my_plan(a: int, b: str): if not is_bluesky_plan_generator(plan): raise TypeError(f"{plan} is not a valid plan generator function") - schema = ValidatedFunction(plan, PlanConfig) - self.plans[plan.__name__] = schema + def get_device(name: str) -> Device: + return self.find_device(name) + + model = generate_plan_model(plan, get_device) + self.plans[plan.__name__] = Plan(name=plan.__name__, model=model) + self.plan_functions[plan.__name__] = plan return plan def device(self, device: Device, name: Optional[str] = None) -> None: @@ -142,3 +158,62 @@ def device(self, device: Device, name: Optional[str] = None) -> None: raise KeyError(f"Must supply a name for this device: {device}") self.devices[name] = device + + +def generate_plan_model( + plan: PlanGenerator, get_device: Callable[[str], Device] +) -> Type[BaseModel]: + model_annotations: Dict[str, Tuple[Type, Any]] = {} + validators: Dict[str, Any] = {} + for name, param in signature(plan).parameters.items(): + type_annotation = param.annotation + if is_bluesky_compatible_device_type(type_annotation): + type_annotation = str + validators[name] = validator(name)(get_device) + elif is_iterable_of_devices(type_annotation): + validators[name] = validator(name, each_item=True)(get_device) + + default_value = param.default + if default_value is Parameter.empty: + default_value = ... 
+ + anno = (type_annotation, default_value) + model_annotations[name] = anno + + name = f"{plan.__name__}_model" + from pprint import pprint + + pprint(model_annotations) + return create_model(name, **model_annotations, __validators__=validators) + + +def is_mapping_with_devices(dct: Type) -> bool: + if get_params(dct): + ... + + +def is_iterable_of_devices(lst: Type) -> bool: + if origin_is_iterable(lst): + params = list(get_params(lst)) + if params: + (inner,) = params + return is_bluesky_compatible_device_type(inner) + return False + + +def get_params(maybe_parametrised: Type) -> Iterable[Type]: + for attr in "__args__", "__parameters__": + yield from getattr(maybe_parametrised, attr, []) + + +def origin_is_iterable(to_check: Type) -> bool: + return any( + map( + lambda origin: origin_is(to_check, origin), + [List, Set, Tuple, FrozenSet, Deque], + ) + ) + + +def origin_is(to_check: Type, origin: Type) -> bool: + return hasattr(to_check, "__origin__") and to_check.__origin__ is origin From 9b0a86f0479c9ed4699652bf7a0da375cf6bd414 Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Thu, 16 Mar 2023 08:35:13 +0000 Subject: [PATCH 04/21] Remove redundant event --- src/blueapi/worker/event.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/blueapi/worker/event.py b/src/blueapi/worker/event.py index 84faa82573..7d8fbde698 100644 --- a/src/blueapi/worker/event.py +++ b/src/blueapi/worker/event.py @@ -41,17 +41,6 @@ def from_bluesky_state(cls, bluesky_state: RawRunEngineState) -> "WorkerState": return WorkerState(str(bluesky_state).upper()) -class WorkerEvent(BaseModel): - """ - Event emitted by a worker when the runner state changes - """ - - state: WorkerState = Field(description="Current state of the worker") - current_task_name: Optional[str] = Field( - description="Unique ID of the currently running task, if any", default=None - ) - - class StatusView(BaseModel): """ A snapshot of a Status of an operation, optionally representing progress From b7ed10919c37cb69c30a1ef6387d2f580d5fed5a Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Wed, 12 Apr 2023 19:56:04 +0100 Subject: [PATCH 05/21] Add type validators --- src/blueapi/utils/__init__.py | 3 + src/blueapi/utils/type_validator.py | 139 +++++++++++++++++++ tests/utils/test_type_validator.py | 201 ++++++++++++++++++++++++++++ 3 files changed, 343 insertions(+) create mode 100644 src/blueapi/utils/type_validator.py create mode 100644 tests/utils/test_type_validator.py diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index ebb10c6e74..5c6f21ef02 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -2,6 +2,7 @@ from .modules import load_module_all from .schema import nested_deserialize_with_overrides, schema_for_func from .thread_exception import handle_all_exceptions +from .type_validator import TypeConverter, create_model_with_type_validators __all__ = [ "handle_all_exceptions", @@ -9,4 +10,6 @@ "schema_for_func", "load_module_all", "ConfigLoader", + "create_model_with_type_validators", + "TypeConverter", ] diff --git a/src/blueapi/utils/type_validator.py b/src/blueapi/utils/type_validator.py new file mode 100644 index 0000000000..6b7f27cde4 --- /dev/null +++ b/src/blueapi/utils/type_validator.py @@ -0,0 +1,139 @@ +import functools +from dataclasses import dataclass +from inspect import isclass +from types import FunctionType +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Deque, + Dict, + FrozenSet, + Generic, + Iterable, + List, + Mapping, + 
NamedTuple, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) + +from pydantic import BaseConfig, BaseModel, Field, create_model, validator +from pydantic.fields import ModelField + +if TYPE_CHECKING: + from pydantic.typing import AnyCallable, AnyClassMethod +else: + AnyCallable, AnyClassMethod = Any, Any + + +_PYDANTIC_LIST_TYPES = [List, Tuple, Set] +_PYDANTIC_DICT_TYPES = [Dict, Mapping] + +T = TypeVar("T") +U = TypeVar("U") +FieldDefinition = Tuple[Type, Any] +Fields = Mapping[str, FieldDefinition] +Validator = Callable[[AnyCallable], AnyClassMethod] + + +class DefaultConfig(BaseConfig): + arbitrary_types_allowed = True + + +@dataclass +class TypeConverter(Generic[T, U]): + field_type: Type[T] + func: Callable[[T], U] + + def __str__(self) -> str: + type_name = getattr( + self.field_type, "__name__", str(hash(str(self.field_type))) + ) + return f"converter_{type_name}" + + +def create_model_with_type_validators( + name: str, + fields: Fields, + converters: Iterable[TypeConverter], + config: Type[BaseConfig] = DefaultConfig, +) -> Type[BaseModel]: + validators = type_validators(fields, converters) + return create_model(name, **fields, __validators__=validators, __config__=config) + + +def type_validators( + fields: Fields, + converters: Iterable[TypeConverter], +) -> Mapping[str, Validator]: + all_validators = {} + + for converter in converters: + # def make_validator(name: str) -> Validator: + # def validate_type(value: Any) -> Any: + # return apply_to_scalars(converter.func, value) + + # validate_type.__name__ = str(converter) + # return validator(name, allow_reuse=True, pre=True)(validate_type) + + field_names = determine_fields_of_type(fields, converter.field_type) + for name in field_names: + val = _make_type_validator(name, converter) + val_method_name = f"validate_{name}" + if val_method_name in all_validators: + raise TypeError(f"Ambiguous type validator for field: {name}") + all_validators[val_method_name] = val + + return all_validators + + +def _make_type_validator(name: str, converter: TypeConverter) -> Validator: + def validate_type(value: Any) -> Any: + return apply_to_scalars(converter.func, value) + + return validator(name, allow_reuse=True, pre=True)(validate_type) + + +def determine_fields_of_type(fields: Fields, field_type: Type) -> Iterable[str]: + for name, field in fields.items(): + annotation, _ = field + if is_type_or_container_type(annotation, field_type): + yield name + + +def is_type_or_container_type(type_to_check: Type, field_type: Type) -> bool: + return params_contains(type_to_check, field_type) + # or ( + # isclass(field_type) and issubclass(field_type, type_to_check) + # ) + + +def params_contains(type_to_check: Type, field_type: Type) -> bool: + type_params = getattr(type_to_check, "__args__", []) + getattr( + type_to_check, "__parameters__", [] + ) + return type_to_check is field_type or any( + map(lambda v: params_contains(v, field_type), type_params) + ) + + +def apply_to_scalars(func: Callable[[T], U], obj: Any) -> Any: + if is_list_type(obj): + return list(map(lambda v: apply_to_scalars(func, v), obj)) + elif is_dict_type(obj): + return {k: apply_to_scalars(func, v) for k, v in obj.items()} + else: + return func(obj) + + +def is_list_type(obj: Any) -> bool: + return any(map(lambda t: isinstance(obj, t), _PYDANTIC_LIST_TYPES)) + + +def is_dict_type(obj: Any) -> bool: + return any(map(lambda t: isinstance(obj, t), _PYDANTIC_DICT_TYPES)) diff --git a/tests/utils/test_type_validator.py b/tests/utils/test_type_validator.py new file mode 100644 
index 0000000000..8555c122e1 --- /dev/null +++ b/tests/utils/test_type_validator.py @@ -0,0 +1,201 @@ +from typing import Any, Dict, List, Mapping, NamedTuple, Set, Tuple, Type + +import pytest +from pydantic import BaseConfig, BaseModel, parse_obj_as +from pydantic.fields import Undefined + +from blueapi.utils import TypeConverter, create_model_with_type_validators + +_REG: Mapping[str, int] = { + letter: number for number, letter in enumerate("abcdefghijklmnopqrstuvwxyz") +} + + +class ComplexObject: + _name: str + + def __init__(self, name: str) -> None: + self._name = name + + def name(self) -> str: + return self._name + + def __eq__(self, __value: object) -> bool: + return isinstance(__value, ComplexObject) and __value.name() == self._name + + def __str__(self) -> str: + return f"ComplexObject({self._name})" + + def __repr__(self) -> str: + return f"ComplexObject({self._name})" + + +_DB: Mapping[str, ComplexObject] = {name: ComplexObject(name) for name in _REG.keys()} + + +def lookup(letter: str) -> int: + assert type(letter) is str, f"Expteced a string, got a {type(letter)}" + return _REG[letter] + + +def has_even_length(msg: str) -> bool: + assert type(msg) is str, f"Expteced a string, got a {type(msg)}" + return len(msg) % 2 == 0 + + +def lookup_complex(name: str) -> ComplexObject: + assert type(name) is str, f"Expteced a string, got a {type(name)}" + return _DB[name] + + +def test_validates_single_type() -> None: + assert_validates_single_type(int, "c", 2) + + +def test_leaves_unvalidated_types_alone() -> None: + model = create_model_with_type_validators( + "Foo", + {"a": (int, Undefined), "b": (str, Undefined)}, + [TypeConverter(int, lookup)], + ) + parsed = parse_obj_as(model, {"a": "c", "b": "hello"}) + assert parsed.a == 2 + assert parsed.b == "hello" + + +def test_validates_multiple_types() -> None: + model = create_model_with_type_validators( + "Foo", + {"a": (int, Undefined), "b": (bool, Undefined)}, + [TypeConverter(int, lookup), TypeConverter(bool, has_even_length)], + ) + parsed = parse_obj_as(model, {"a": "c", "b": "hello"}) + assert parsed.a == 2 + assert parsed.b == False + + +def test_validates_multiple_fields() -> None: + model = create_model_with_type_validators( + "Foo", + {"a": (int, Undefined), "b": (int, Undefined)}, + [TypeConverter(int, lookup)], + ) + parsed = parse_obj_as(model, {"a": "c", "b": "d"}) + assert parsed.a == 2 + assert parsed.b == 3 + + +def test_validates_multiple_fields_and_types() -> None: + model = create_model_with_type_validators( + "Foo", + { + "a": (int, Undefined), + "b": (bool, Undefined), + "c": (int, Undefined), + "d": (bool, Undefined), + }, + [TypeConverter(int, lookup), TypeConverter(bool, has_even_length)], + ) + parsed = parse_obj_as(model, {"a": "c", "b": "hello", "c": "d", "d": "word"}) + assert parsed.a == 2 + assert parsed.b == False + assert parsed.c == 3 + assert parsed.d == True + + +def test_does_not_tolerate_multiple_converters_for_same_type() -> None: + with pytest.raises(TypeError): + create_model_with_type_validators( + "Foo", + {"a": (int, Undefined), "b": (int, Undefined)}, + [TypeConverter(int, lookup), TypeConverter(int, int)], + ) + + +def test_validates_list_type() -> None: + assert_validates_single_type(List[int], ["a", "b", "c"], [0, 1, 2]) + + +def test_validates_set_type() -> None: + assert_validates_single_type(Set[int], ["a", "b", "c"], {0, 1, 2}) + + +def test_validates_tuple_type() -> None: + assert_validates_single_type(Tuple[int, ...], ["a", "b", "c"], (0, 1, 2)) + + +def 
test_validates_nested_container_type() -> None: + assert_validates_single_type( + List[Set[Tuple[int, int]]], + [[["a", "b"], ["c", "d"]], [["e", "f"]]], + [{(0, 1), (2, 3)}, {(4, 5)}], + ) + + +@pytest.mark.parametrize("dict_type", [Dict, Mapping]) +def test_validates_dict_type(dict_type: Type) -> None: + assert_validates_single_type( + dict_type[str, int], + { + "a": "a", + "b": "b", + "c": "c", + }, + { + "a": 0, + "b": 1, + "c": 2, + }, + ) + + +def test_validates_nested_mapping() -> None: + assert_validates_single_type( + Dict[str, List[int]], + { + "a": ["a", "b"], + "b": ["c", "d", "e"], + "c": ["f"], + }, + { + "a": [0, 1], + "b": [2, 3, 4], + "c": [5], + }, + ) + + +def test_validates_complex_object() -> None: + assert_validates_complex_object(ComplexObject, "d", ComplexObject("d")) + + +def test_validates_complex_object_list() -> None: + assert_validates_complex_object( + List[ComplexObject], + ["a", "b", "c"], + [ + ComplexObject("a"), + ComplexObject("b"), + ComplexObject("c"), + ], + ) + + +def assert_validates_single_type( + field_type: Type, input_value: Any, expected_output: Any +) -> None: + model = create_model_with_type_validators( + "Foo", {"ch": (field_type, Undefined)}, [TypeConverter(int, lookup)] + ) + assert parse_obj_as(model, {"ch": input_value}).ch == expected_output + + +def assert_validates_complex_object( + field_type: Type, input_value: Any, expected_output: Any +) -> None: + model = create_model_with_type_validators( + "Foo", + {"obj": (field_type, Undefined)}, + [TypeConverter(ComplexObject, lookup_complex)], + ) + assert parse_obj_as(model, {"obj": input_value}).obj == expected_output From c5a209e1aa13c111b9ad84c4c8618dcdfac535d0 Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Thu, 13 Apr 2023 09:40:12 +0100 Subject: [PATCH 06/21] Support basemodel and dataclass recursive type validators --- src/blueapi/utils/type_validator.py | 60 +++++++-- tests/utils/test_type_validator.py | 184 ++++++++++++++++++++++++++-- 2 files changed, 229 insertions(+), 15 deletions(-) diff --git a/src/blueapi/utils/type_validator.py b/src/blueapi/utils/type_validator.py index 6b7f27cde4..ee0b3e168c 100644 --- a/src/blueapi/utils/type_validator.py +++ b/src/blueapi/utils/type_validator.py @@ -20,10 +20,11 @@ Type, TypeVar, Union, + overload, ) from pydantic import BaseConfig, BaseModel, Field, create_model, validator -from pydantic.fields import ModelField +from pydantic.fields import ModelField, Undefined if TYPE_CHECKING: from pydantic.typing import AnyCallable, AnyClassMethod @@ -41,10 +42,6 @@ Validator = Callable[[AnyCallable], AnyClassMethod] -class DefaultConfig(BaseConfig): - arbitrary_types_allowed = True - - @dataclass class TypeConverter(Generic[T, U]): field_type: Type[T] @@ -57,14 +54,54 @@ def __str__(self) -> str: return f"converter_{type_name}" +@overload def create_model_with_type_validators( name: str, + converters: Iterable[TypeConverter], fields: Fields, + config: Optional[Type[BaseConfig]] = None, +) -> Type[BaseModel]: + ... + + +@overload +def create_model_with_type_validators( + name: str, + converters: Iterable[TypeConverter], + base: Type[BaseModel], +) -> Type[BaseModel]: + ... 
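# Illustrative sketch (not part of this commit): the plain-pydantic shape of what
# create_model_with_type_validators builds. A pre=True validator attached through
# create_model converts raw input before the field's own type check runs, mirroring
# the lookup() converter used in this module's tests.
from pydantic import create_model, parse_obj_as, validator

_REG = {letter: number for number, letter in enumerate("abcdefghijklmnopqrstuvwxyz")}


def lookup(value):
    # pre-validator: turn "c" into 2 before pydantic checks the int annotation
    return _REG[value] if isinstance(value, str) else value


Foo = create_model(
    "Foo",
    a=(int, ...),
    __validators__={"validate_a": validator("a", pre=True, allow_reuse=True)(lookup)},
)

assert parse_obj_as(Foo, {"a": "c"}).a == 2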
+ + +def create_model_with_type_validators( + name: str, converters: Iterable[TypeConverter], - config: Type[BaseConfig] = DefaultConfig, + fields: Optional[Fields] = None, + base: Optional[Type[BaseModel]] = None, + config: Optional[Type[BaseConfig]] = None, ) -> Type[BaseModel]: + fields = fields or {} + if base is not None: + fields = {**fields, **_extract_fields(base)} + for name, field in fields.items(): + annotation, val = field + model_type = find_model_type(annotation) + if model_type is not None: + recursed = create_model_with_type_validators( + annotation.__name__, converters, base=model_type + ) + fields[name] = recursed, val validators = type_validators(fields, converters) - return create_model(name, **fields, __validators__=validators, __config__=config) + return create_model( + name, **fields, __base__=base, __validators__=validators, __config__=config + ) + + +def _extract_fields(model: Type[BaseModel]) -> Fields: + return { + name: (field.type_, field.field_info) + for name, field in model.__fields__.items() + } def type_validators( @@ -137,3 +174,12 @@ def is_list_type(obj: Any) -> bool: def is_dict_type(obj: Any) -> bool: return any(map(lambda t: isinstance(obj, t), _PYDANTIC_DICT_TYPES)) + + +def find_model_type(anno: Type) -> Optional[Type[BaseModel]]: + if isclass(anno): + if issubclass(anno, BaseModel): + return anno + elif hasattr(anno, "__pydantic_model__"): + return getattr(anno, "__pydantic_model__") + return None diff --git a/tests/utils/test_type_validator.py b/tests/utils/test_type_validator.py index 8555c122e1..b140cc97ed 100644 --- a/tests/utils/test_type_validator.py +++ b/tests/utils/test_type_validator.py @@ -2,10 +2,16 @@ import pytest from pydantic import BaseConfig, BaseModel, parse_obj_as +from pydantic.dataclasses import dataclass from pydantic.fields import Undefined from blueapi.utils import TypeConverter, create_model_with_type_validators + +class DefaultConfig(BaseConfig): + arbitrary_types_allowed = True + + _REG: Mapping[str, int] = { letter: number for number, letter in enumerate("abcdefghijklmnopqrstuvwxyz") } @@ -30,6 +36,37 @@ def __repr__(self) -> str: return f"ComplexObject({self._name})" +class Bar(BaseModel): + a: int + b: ComplexObject + + class Config: + arbitrary_types_allowed = True + + +class Baz(BaseModel): + obj: Bar + c: str + + +@dataclass(config=DefaultConfig) +class DataclassBar: + a: int + b: ComplexObject + + +@dataclass +class DataclassBaz: + obj: DataclassBar + c: str + + +@dataclass +class DataclassMixed: + obj: Bar + c: str + + _DB: Mapping[str, ComplexObject] = {name: ComplexObject(name) for name in _REG.keys()} @@ -55,8 +92,8 @@ def test_validates_single_type() -> None: def test_leaves_unvalidated_types_alone() -> None: model = create_model_with_type_validators( "Foo", - {"a": (int, Undefined), "b": (str, Undefined)}, [TypeConverter(int, lookup)], + fields={"a": (int, Undefined), "b": (str, Undefined)}, ) parsed = parse_obj_as(model, {"a": "c", "b": "hello"}) assert parsed.a == 2 @@ -66,8 +103,8 @@ def test_leaves_unvalidated_types_alone() -> None: def test_validates_multiple_types() -> None: model = create_model_with_type_validators( "Foo", - {"a": (int, Undefined), "b": (bool, Undefined)}, [TypeConverter(int, lookup), TypeConverter(bool, has_even_length)], + fields={"a": (int, Undefined), "b": (bool, Undefined)}, ) parsed = parse_obj_as(model, {"a": "c", "b": "hello"}) assert parsed.a == 2 @@ -77,8 +114,8 @@ def test_validates_multiple_types() -> None: def test_validates_multiple_fields() -> None: model = 
create_model_with_type_validators( "Foo", - {"a": (int, Undefined), "b": (int, Undefined)}, [TypeConverter(int, lookup)], + fields={"a": (int, Undefined), "b": (int, Undefined)}, ) parsed = parse_obj_as(model, {"a": "c", "b": "d"}) assert parsed.a == 2 @@ -88,13 +125,13 @@ def test_validates_multiple_fields() -> None: def test_validates_multiple_fields_and_types() -> None: model = create_model_with_type_validators( "Foo", - { + [TypeConverter(int, lookup), TypeConverter(bool, has_even_length)], + fields={ "a": (int, Undefined), "b": (bool, Undefined), "c": (int, Undefined), "d": (bool, Undefined), }, - [TypeConverter(int, lookup), TypeConverter(bool, has_even_length)], ) parsed = parse_obj_as(model, {"a": "c", "b": "hello", "c": "d", "d": "word"}) assert parsed.a == 2 @@ -107,8 +144,8 @@ def test_does_not_tolerate_multiple_converters_for_same_type() -> None: with pytest.raises(TypeError): create_model_with_type_validators( "Foo", - {"a": (int, Undefined), "b": (int, Undefined)}, [TypeConverter(int, lookup), TypeConverter(int, int)], + fields={"a": (int, Undefined), "b": (int, Undefined)}, ) @@ -181,11 +218,141 @@ def test_validates_complex_object_list() -> None: ) +def test_applies_to_base() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeConverter(ComplexObject, lookup_complex)], + base=Bar, + ) + parsed = parse_obj_as(model, {"a": 2, "b": "g"}) + assert parsed.a == 2 + assert parsed.b == ComplexObject("g") + + +def test_applies_to_nested_base() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeConverter(ComplexObject, lookup_complex)], + base=Baz, + ) + parsed = parse_obj_as(model, {"obj": {"a": 2, "b": "g"}, "c": "hello"}) + assert parsed.obj.a == 2 + assert parsed.obj.b == ComplexObject("g") + assert parsed.c == "hello" + + +def test_validates_submodel() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeConverter(ComplexObject, lookup_complex)], + fields={"obj": (Bar, Undefined)}, + ) + parsed = parse_obj_as( + model, + { + "obj": { + "a": 2, + "b": "g", + }, + }, + ) + assert parsed.obj.a == 2 + assert parsed.obj.b == ComplexObject("g") + + +def test_validates_nested_submodel() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeConverter(ComplexObject, lookup_complex)], + fields={"obj": (Baz, Undefined)}, + ) + parsed = parse_obj_as( + model, + { + "obj": { + "obj": { + "a": 2, + "b": "g", + }, + "c": "hello", + } + }, + ) + assert parsed.obj.obj.a == 2 + assert parsed.obj.obj.b == ComplexObject("g") + assert parsed.obj.c == "hello" + + +def test_validates_dataclass() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeConverter(ComplexObject, lookup_complex)], + fields={"obj": (DataclassBar, Undefined)}, + ) + parsed = parse_obj_as( + model, + { + "obj": { + "a": 2, + "b": "g", + }, + }, + ) + assert parsed.obj.a == 2 + assert parsed.obj.b == ComplexObject("g") + + +def test_validates_nested_dataclass() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeConverter(ComplexObject, lookup_complex)], + fields={"obj": (DataclassBaz, Undefined)}, + ) + parsed = parse_obj_as( + model, + { + "obj": { + "obj": { + "a": 2, + "b": "g", + }, + "c": "hello", + } + }, + ) + assert parsed.obj.obj.a == 2 + assert parsed.obj.obj.b == ComplexObject("g") + assert parsed.obj.c == "hello" + + +def test_validates_mixed_dataclass() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeConverter(ComplexObject, lookup_complex)], + fields={"obj": (DataclassMixed, Undefined)}, + 
) + parsed = parse_obj_as( + model, + { + "obj": { + "obj": { + "a": 2, + "b": "g", + }, + "c": "hello", + } + }, + ) + assert parsed.obj.obj.a == 2 + assert parsed.obj.obj.b == ComplexObject("g") + assert parsed.obj.c == "hello" + + def assert_validates_single_type( field_type: Type, input_value: Any, expected_output: Any ) -> None: model = create_model_with_type_validators( - "Foo", {"ch": (field_type, Undefined)}, [TypeConverter(int, lookup)] + "Foo", [TypeConverter(int, lookup)], fields={"ch": (field_type, Undefined)} ) assert parse_obj_as(model, {"ch": input_value}).ch == expected_output @@ -195,7 +362,8 @@ def assert_validates_complex_object( ) -> None: model = create_model_with_type_validators( "Foo", - {"obj": (field_type, Undefined)}, [TypeConverter(ComplexObject, lookup_complex)], + fields={"obj": (field_type, Undefined)}, + config=DefaultConfig, ) assert parse_obj_as(model, {"obj": input_value}).obj == expected_output From 041add60cc42f6cfc982d1e125d8c459d5a40648 Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Thu, 13 Apr 2023 10:29:39 +0100 Subject: [PATCH 07/21] Support scanspecs --- src/blueapi/utils/type_validator.py | 14 ++++++++++++-- tests/utils/test_type_validator.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/src/blueapi/utils/type_validator.py b/src/blueapi/utils/type_validator.py index ee0b3e168c..2be1753763 100644 --- a/src/blueapi/utils/type_validator.py +++ b/src/blueapi/utils/type_validator.py @@ -151,8 +151,18 @@ def is_type_or_container_type(type_to_check: Type, field_type: Type) -> bool: def params_contains(type_to_check: Type, field_type: Type) -> bool: - type_params = getattr(type_to_check, "__args__", []) + getattr( - type_to_check, "__parameters__", [] + type_params = list( + getattr( + type_to_check, + "__args__", + [], + ) + ) + list( + getattr( + type_to_check, + "__parameters__", + [], + ) ) return type_to_check is field_type or any( map(lambda v: params_contains(v, field_type), type_params) diff --git a/tests/utils/test_type_validator.py b/tests/utils/test_type_validator.py index b140cc97ed..3ef1ea37f5 100644 --- a/tests/utils/test_type_validator.py +++ b/tests/utils/test_type_validator.py @@ -4,6 +4,8 @@ from pydantic import BaseConfig, BaseModel, parse_obj_as from pydantic.dataclasses import dataclass from pydantic.fields import Undefined +from scanspec.regions import Circle +from scanspec.specs import Line, Product, Spec from blueapi.utils import TypeConverter, create_model_with_type_validators @@ -348,6 +350,33 @@ def test_validates_mixed_dataclass() -> None: assert parsed.obj.c == "hello" +@pytest.mark.parametrize( + "spec", + [ + Line("x", 0.0, 10.0, 10), + Line("x", 0.0, 10.0, 10) * Line("y", 0.0, 10.0, 10), + (Line("x", 0.0, 10.0, 10) * Line("y", 0.0, 10.0, 10)) + & Circle("x", "y", 1.0, 2.8, radius=0.5), + ], +) +def test_validates_scanspec(spec: Spec) -> None: + assert parse_spec(spec).spec == spec + + +def test_validates_scanspec_with_complex_axis() -> None: + spec = Line(ComplexObject("x"), 0.0, 10.0, 10) + assert parse_spec(spec).spec.axes() == [ComplexObject("x")] + + +def parse_spec(spec: Spec) -> Any: + model = create_model_with_type_validators( + "Foo", + [TypeConverter(ComplexObject, lookup_complex)], + fields={"spec": (Spec, Undefined)}, + ) + return parse_obj_as(model, {"spec": spec.serialize()}) + + def assert_validates_single_type( field_type: Type, input_value: Any, expected_output: Any ) -> None: From 398c8e39c685518b7e0a2b402b7df7d4ab7eefd6 Mon Sep 17 00:00:00 2001 
From: Callum Forrester Date: Thu, 13 Apr 2023 10:30:19 +0100 Subject: [PATCH 08/21] Remove comments --- src/blueapi/utils/type_validator.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/blueapi/utils/type_validator.py b/src/blueapi/utils/type_validator.py index 2be1753763..e70e55a2f1 100644 --- a/src/blueapi/utils/type_validator.py +++ b/src/blueapi/utils/type_validator.py @@ -111,13 +111,6 @@ def type_validators( all_validators = {} for converter in converters: - # def make_validator(name: str) -> Validator: - # def validate_type(value: Any) -> Any: - # return apply_to_scalars(converter.func, value) - - # validate_type.__name__ = str(converter) - # return validator(name, allow_reuse=True, pre=True)(validate_type) - field_names = determine_fields_of_type(fields, converter.field_type) for name in field_names: val = _make_type_validator(name, converter) @@ -145,9 +138,6 @@ def determine_fields_of_type(fields: Fields, field_type: Type) -> Iterable[str]: def is_type_or_container_type(type_to_check: Type, field_type: Type) -> bool: return params_contains(type_to_check, field_type) - # or ( - # isclass(field_type) and issubclass(field_type, type_to_check) - # ) def params_contains(type_to_check: Type, field_type: Type) -> bool: From 0b9d3aed2aaf744efb19c593d525ad3b5056a77b Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Thu, 13 Apr 2023 10:51:44 +0100 Subject: [PATCH 09/21] Create models for functions --- src/blueapi/utils/type_validator.py | 33 +++++++++++++++++-- tests/utils/test_type_validator.py | 51 +++++++++++++++++++++++++++-- 2 files changed, 79 insertions(+), 5 deletions(-) diff --git a/src/blueapi/utils/type_validator.py b/src/blueapi/utils/type_validator.py index e70e55a2f1..829af2ee39 100644 --- a/src/blueapi/utils/type_validator.py +++ b/src/blueapi/utils/type_validator.py @@ -1,6 +1,6 @@ import functools from dataclasses import dataclass -from inspect import isclass +from inspect import Parameter, isclass, signature from types import FunctionType from typing import ( TYPE_CHECKING, @@ -64,6 +64,16 @@ def create_model_with_type_validators( ... +@overload +def create_model_with_type_validators( + name: str, + converters: Iterable[TypeConverter], + func: Callable[..., Any], + config: Optional[Type[BaseConfig]] = None, +) -> Type[BaseModel]: + ... 
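# Illustrative sketch (not part of this commit): the technique behind the new func=
# overload. inspect.signature turns a type-annotated function into the (type, default)
# pairs that pydantic's create_model accepts, so a plan's parameters become a
# validation model. build_model_from_function and count are hypothetical stand-ins.
from inspect import Parameter, signature

from pydantic import create_model


def build_model_from_function(func):
    fields = {}
    for name, param in signature(func).parameters.items():
        # parameters without a default become required fields (Ellipsis)
        default = ... if param.default is Parameter.empty else param.default
        fields[name] = (param.annotation, default)
    return create_model(f"{func.__name__}_params", **fields)


def count(detector_name: str, num: int = 1) -> None:
    ...


CountParams = build_model_from_function(count)
print(CountParams(detector_name="det").dict())  # {'detector_name': 'det', 'num': 1}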
+ + @overload def create_model_with_type_validators( name: str, @@ -78,11 +88,14 @@ def create_model_with_type_validators( converters: Iterable[TypeConverter], fields: Optional[Fields] = None, base: Optional[Type[BaseModel]] = None, + func: Optional[Callable[..., Any]] = None, config: Optional[Type[BaseConfig]] = None, ) -> Type[BaseModel]: fields = fields or {} if base is not None: - fields = {**fields, **_extract_fields(base)} + fields = {**fields, **_extract_fields_from_model(base)} + if func is not None: + fields = {**fields, **_extract_fields_from_function(func)} for name, field in fields.items(): annotation, val = field model_type = find_model_type(annotation) @@ -97,13 +110,27 @@ def create_model_with_type_validators( ) -def _extract_fields(model: Type[BaseModel]) -> Fields: +def _extract_fields_from_model(model: Type[BaseModel]) -> Fields: return { name: (field.type_, field.field_info) for name, field in model.__fields__.items() } +def _extract_fields_from_function(func: Callable[..., Any]) -> Fields: + fields: Fields = {} + for name, param in signature(func).parameters.items(): + type_annotation = param.annotation + default_value = param.default + if default_value is Parameter.empty: + default_value = Undefined + + anno = (type_annotation, default_value) + fields[name] = anno + + return fields + + def type_validators( fields: Fields, converters: Iterable[TypeConverter], diff --git a/tests/utils/test_type_validator.py b/tests/utils/test_type_validator.py index 3ef1ea37f5..d5bcddeac1 100644 --- a/tests/utils/test_type_validator.py +++ b/tests/utils/test_type_validator.py @@ -69,6 +69,18 @@ class DataclassMixed: c: str +def foo(a: int, b: str) -> None: + ... + + +def bar(obj: ComplexObject) -> None: + ... + + +def baz(bar: Bar) -> None: + ... 
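# Illustrative sketch (not part of this commit): the plain-pydantic shape of the
# nested-model behaviour these tests exercise. Rebuilding the inner model with its own
# pre=True validator is what lets a nested payload such as {"obj": {"a": 2, "b": "g"}}
# have "g" looked up into a rich object. Complex, DB and the field names here are
# hypothetical stand-ins.
from pydantic import BaseConfig, create_model, parse_obj_as, validator


class Complex:
    def __init__(self, name: str) -> None:
        self.name = name


DB = {"g": Complex("g")}


def to_complex(value):
    # pre-validator: replace a string key with the object it refers to
    return DB[value] if isinstance(value, str) else value


class Cfg(BaseConfig):
    arbitrary_types_allowed = True


Inner = create_model(
    "Inner",
    a=(int, ...),
    b=(Complex, ...),
    __config__=Cfg,
    __validators__={"validate_b": validator("b", pre=True, allow_reuse=True)(to_complex)},
)
Outer = create_model("Outer", obj=(Inner, ...), c=(str, ...))

parsed = parse_obj_as(Outer, {"obj": {"a": 2, "b": "g"}, "c": "hello"})
assert parsed.obj.b.name == "g"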
+ + _DB: Mapping[str, ComplexObject] = {name: ComplexObject(name) for name in _REG.keys()} @@ -368,6 +380,38 @@ def test_validates_scanspec_with_complex_axis() -> None: assert parse_spec(spec).spec.axes() == [ComplexObject("x")] +def test_model_from_simple_function_signature() -> None: + model = create_model_with_type_validators( + "Foo", [TypeConverter(int, lookup)], func=foo + ) + parsed = parse_obj_as(model, {"a": "g", "b": "hello"}) + assert parsed.a == 6 + assert parsed.b == "hello" + + +def test_model_from_complex_function_signature() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeConverter(ComplexObject, lookup_complex)], + func=bar, + config=DefaultConfig, + ) + parsed = parse_obj_as(model, {"obj": "f"}) + assert parsed.obj == ComplexObject("f") + + +def test_model_from_nested_function_signature() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeConverter(ComplexObject, lookup_complex)], + func=baz, + config=DefaultConfig, + ) + parsed = parse_obj_as(model, {"bar": {"a": 4, "b": "k"}}) + assert parsed.bar.a == 4 + assert parsed.bar.b == ComplexObject("k") + + def parse_spec(spec: Spec) -> Any: model = create_model_with_type_validators( "Foo", @@ -387,12 +431,15 @@ def assert_validates_single_type( def assert_validates_complex_object( - field_type: Type, input_value: Any, expected_output: Any + field_type: Type, + input_value: Any, + expected_output: Any, + default_value: Any = Undefined, ) -> None: model = create_model_with_type_validators( "Foo", [TypeConverter(ComplexObject, lookup_complex)], - fields={"obj": (field_type, Undefined)}, + fields={"obj": (field_type, default_value)}, config=DefaultConfig, ) assert parse_obj_as(model, {"obj": input_value}).obj == expected_output From 0345a0482f32a34a0f5b37e71a027fe0630fb147 Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Thu, 13 Apr 2023 15:26:56 +0100 Subject: [PATCH 10/21] Test default value validation --- src/blueapi/utils/type_validator.py | 2 +- tests/utils/test_type_validator.py | 32 ++++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/blueapi/utils/type_validator.py b/src/blueapi/utils/type_validator.py index 829af2ee39..c6fdbac1cd 100644 --- a/src/blueapi/utils/type_validator.py +++ b/src/blueapi/utils/type_validator.py @@ -153,7 +153,7 @@ def _make_type_validator(name: str, converter: TypeConverter) -> Validator: def validate_type(value: Any) -> Any: return apply_to_scalars(converter.func, value) - return validator(name, allow_reuse=True, pre=True)(validate_type) + return validator(name, allow_reuse=True, pre=True, always=True)(validate_type) def determine_fields_of_type(fields: Fields, field_type: Type) -> Iterable[str]: diff --git a/tests/utils/test_type_validator.py b/tests/utils/test_type_validator.py index d5bcddeac1..04cbc04214 100644 --- a/tests/utils/test_type_validator.py +++ b/tests/utils/test_type_validator.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Mapping, NamedTuple, Set, Tuple, Type import pytest -from pydantic import BaseConfig, BaseModel, parse_obj_as +from pydantic import BaseConfig, BaseModel, Field, parse_obj_as from pydantic.dataclasses import dataclass from pydantic.fields import Undefined from scanspec.regions import Circle @@ -362,6 +362,36 @@ def test_validates_mixed_dataclass() -> None: assert parsed.obj.c == "hello" +def test_validates_default_value() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeConverter(int, lookup)], + fields={"a": (int, "e")}, + 
config=DefaultConfig, + ) + assert parse_obj_as(model, {}).a == 4 + + +def test_validates_complex_value() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeConverter(ComplexObject, lookup_complex)], + fields={"obj": (ComplexObject, "t")}, + config=DefaultConfig, + ) + assert parse_obj_as(model, {}).obj == ComplexObject("t") + + +def test_validates_field_info() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeConverter(int, lookup)], + fields={"a": (int, Field(default="f"))}, + config=DefaultConfig, + ) + assert parse_obj_as(model, {}).a == 5 + + @pytest.mark.parametrize( "spec", [ From b0a19298afbf36826385e1f4077ec65804b8397c Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Thu, 13 Apr 2023 18:06:40 +0100 Subject: [PATCH 11/21] Add docstrings --- src/blueapi/utils/__init__.py | 4 +- src/blueapi/utils/type_validator.py | 126 +++++++++++++++++++++++----- tests/utils/test_type_validator.py | 52 +++++++----- 3 files changed, 135 insertions(+), 47 deletions(-) diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index 5c6f21ef02..fa8b28d207 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -2,7 +2,7 @@ from .modules import load_module_all from .schema import nested_deserialize_with_overrides, schema_for_func from .thread_exception import handle_all_exceptions -from .type_validator import TypeConverter, create_model_with_type_validators +from .type_validator import TypeValidatorDefinition, create_model_with_type_validators __all__ = [ "handle_all_exceptions", @@ -11,5 +11,5 @@ "load_module_all", "ConfigLoader", "create_model_with_type_validators", - "TypeConverter", + "TypeValidatorDefinition", ] diff --git a/src/blueapi/utils/type_validator.py b/src/blueapi/utils/type_validator.py index c6fdbac1cd..9e2ea6b47b 100644 --- a/src/blueapi/utils/type_validator.py +++ b/src/blueapi/utils/type_validator.py @@ -1,30 +1,24 @@ -import functools from dataclasses import dataclass from inspect import Parameter, isclass, signature -from types import FunctionType from typing import ( TYPE_CHECKING, Any, Callable, - Deque, Dict, - FrozenSet, Generic, Iterable, List, Mapping, - NamedTuple, Optional, Set, Tuple, Type, TypeVar, - Union, overload, ) -from pydantic import BaseConfig, BaseModel, Field, create_model, validator -from pydantic.fields import ModelField, Undefined +from pydantic import BaseConfig, BaseModel, create_model, validator +from pydantic.fields import Undefined if TYPE_CHECKING: from pydantic.typing import AnyCallable, AnyClassMethod @@ -43,7 +37,16 @@ @dataclass -class TypeConverter(Generic[T, U]): +class TypeValidatorDefinition(Generic[T, U]): + """ + Definition of a validator to be applied to all + types during validation. + + Args: + field_type: Convert all fields of this type + func: Convert using this function + """ + field_type: Type[T] func: Callable[[T], U] @@ -57,40 +60,102 @@ def __str__(self) -> str: @overload def create_model_with_type_validators( name: str, - converters: Iterable[TypeConverter], + definitions: Iterable[TypeValidatorDefinition], fields: Fields, config: Optional[Type[BaseConfig]] = None, ) -> Type[BaseModel]: + """ + Create a model based on the fields supplied + + Args: + name: Name of the new model + definitions: Definitions of how to validate which types of field + fields: Definitions of fields from which to make the model. + config: Pydantic config for the model. Defaults to None. 
+ + Returns: + Type[BaseModel]: A new pydantic model with the fields and + type validators supplied. + """ + ... @overload def create_model_with_type_validators( name: str, - converters: Iterable[TypeConverter], + definitions: Iterable[TypeValidatorDefinition], func: Callable[..., Any], config: Optional[Type[BaseConfig]] = None, ) -> Type[BaseModel]: + """ + Create a model from a function's parameters with type + validators. + + Args: + name: Name of the new model + definitions: Definitions of how to validate which types of field + func: The model is constructed from the function parameters, + which must be type-annotated. + config: Pydantic config for the model. Defaults to None. + + Returns: + Type[BaseModel]: A new pydantic model based on the + function parameters. + """ + ... @overload def create_model_with_type_validators( name: str, - converters: Iterable[TypeConverter], + definitions: Iterable[TypeValidatorDefinition], base: Type[BaseModel], ) -> Type[BaseModel]: + """ + Apply type validators to an existing model + + Args: + name: Name of the new model + definitions: Definitions of how to validate which types of field + base (Type[BaseModel]): Base class for the model + + Returns: + Type[BaseModel]: A new version of `base` with type validators + """ + ... def create_model_with_type_validators( name: str, - converters: Iterable[TypeConverter], + definitions: Iterable[TypeValidatorDefinition], fields: Optional[Fields] = None, base: Optional[Type[BaseModel]] = None, func: Optional[Callable[..., Any]] = None, config: Optional[Type[BaseConfig]] = None, ) -> Type[BaseModel]: + """ + Create a pydantic model with type validators according to + definitions given. Validators are applied to all fields + of a particular type. + + Args: + name: Name of the new model + definitions: Definitions of how to validate which types of field + fields: Definitions of fields from which to make the model. + Defaults to None. + base: Optional base class for the model. Defaults to None. + func: Function, if supplied, the model is constructed from the + function parameters, which must be type-annotated. + Defaults to None. + config: Pydantic config for the model. Defaults to None. + + Returns: + Type[BaseModel]: A new pydantic model + """ + fields = fields or {} if base is not None: fields = {**fields, **_extract_fields_from_model(base)} @@ -101,10 +166,10 @@ def create_model_with_type_validators( model_type = find_model_type(annotation) if model_type is not None: recursed = create_model_with_type_validators( - annotation.__name__, converters, base=model_type + annotation.__name__, definitions, base=model_type ) fields[name] = recursed, val - validators = type_validators(fields, converters) + validators = _type_validators(fields, definitions) return create_model( name, **fields, __base__=base, __validators__=validators, __config__=config ) @@ -131,16 +196,31 @@ def _extract_fields_from_function(func: Callable[..., Any]) -> Fields: return fields -def type_validators( +def _type_validators( fields: Fields, - converters: Iterable[TypeConverter], + definitions: Iterable[TypeValidatorDefinition], ) -> Mapping[str, Validator]: + """ + Generate type validators from fields and definitions. + + Args: + fields: fields to validate. + definitions: Definitions of how to validate which types of field + + Raises: + TypeError: If a validator can be applied to more than one field. + + Returns: + Mapping[str, Validator]: Dict-like structure mapping validator + names to pydantic validators. 
+ """ + all_validators = {} - for converter in converters: - field_names = determine_fields_of_type(fields, converter.field_type) + for definition in definitions: + field_names = _determine_fields_of_type(fields, definition.field_type) for name in field_names: - val = _make_type_validator(name, converter) + val = _make_type_validator(name, definition) val_method_name = f"validate_{name}" if val_method_name in all_validators: raise TypeError(f"Ambiguous type validator for field: {name}") @@ -149,14 +229,14 @@ def type_validators( return all_validators -def _make_type_validator(name: str, converter: TypeConverter) -> Validator: +def _make_type_validator(name: str, definition: TypeValidatorDefinition) -> Validator: def validate_type(value: Any) -> Any: - return apply_to_scalars(converter.func, value) + return apply_to_scalars(definition.func, value) return validator(name, allow_reuse=True, pre=True, always=True)(validate_type) -def determine_fields_of_type(fields: Fields, field_type: Type) -> Iterable[str]: +def _determine_fields_of_type(fields: Fields, field_type: Type) -> Iterable[str]: for name, field in fields.items(): annotation, _ = field if is_type_or_container_type(annotation, field_type): diff --git a/tests/utils/test_type_validator.py b/tests/utils/test_type_validator.py index 04cbc04214..2afd8f874b 100644 --- a/tests/utils/test_type_validator.py +++ b/tests/utils/test_type_validator.py @@ -7,7 +7,7 @@ from scanspec.regions import Circle from scanspec.specs import Line, Product, Spec -from blueapi.utils import TypeConverter, create_model_with_type_validators +from blueapi.utils import TypeValidatorDefinition, create_model_with_type_validators class DefaultConfig(BaseConfig): @@ -106,7 +106,7 @@ def test_validates_single_type() -> None: def test_leaves_unvalidated_types_alone() -> None: model = create_model_with_type_validators( "Foo", - [TypeConverter(int, lookup)], + [TypeValidatorDefinition(int, lookup)], fields={"a": (int, Undefined), "b": (str, Undefined)}, ) parsed = parse_obj_as(model, {"a": "c", "b": "hello"}) @@ -117,7 +117,10 @@ def test_leaves_unvalidated_types_alone() -> None: def test_validates_multiple_types() -> None: model = create_model_with_type_validators( "Foo", - [TypeConverter(int, lookup), TypeConverter(bool, has_even_length)], + [ + TypeValidatorDefinition(int, lookup), + TypeValidatorDefinition(bool, has_even_length), + ], fields={"a": (int, Undefined), "b": (bool, Undefined)}, ) parsed = parse_obj_as(model, {"a": "c", "b": "hello"}) @@ -128,7 +131,7 @@ def test_validates_multiple_types() -> None: def test_validates_multiple_fields() -> None: model = create_model_with_type_validators( "Foo", - [TypeConverter(int, lookup)], + [TypeValidatorDefinition(int, lookup)], fields={"a": (int, Undefined), "b": (int, Undefined)}, ) parsed = parse_obj_as(model, {"a": "c", "b": "d"}) @@ -139,7 +142,10 @@ def test_validates_multiple_fields() -> None: def test_validates_multiple_fields_and_types() -> None: model = create_model_with_type_validators( "Foo", - [TypeConverter(int, lookup), TypeConverter(bool, has_even_length)], + [ + TypeValidatorDefinition(int, lookup), + TypeValidatorDefinition(bool, has_even_length), + ], fields={ "a": (int, Undefined), "b": (bool, Undefined), @@ -158,7 +164,7 @@ def test_does_not_tolerate_multiple_converters_for_same_type() -> None: with pytest.raises(TypeError): create_model_with_type_validators( "Foo", - [TypeConverter(int, lookup), TypeConverter(int, int)], + [TypeValidatorDefinition(int, lookup), TypeValidatorDefinition(int, int)], 
fields={"a": (int, Undefined), "b": (int, Undefined)}, ) @@ -235,7 +241,7 @@ def test_validates_complex_object_list() -> None: def test_applies_to_base() -> None: model = create_model_with_type_validators( "Foo", - [TypeConverter(ComplexObject, lookup_complex)], + [TypeValidatorDefinition(ComplexObject, lookup_complex)], base=Bar, ) parsed = parse_obj_as(model, {"a": 2, "b": "g"}) @@ -246,7 +252,7 @@ def test_applies_to_base() -> None: def test_applies_to_nested_base() -> None: model = create_model_with_type_validators( "Foo", - [TypeConverter(ComplexObject, lookup_complex)], + [TypeValidatorDefinition(ComplexObject, lookup_complex)], base=Baz, ) parsed = parse_obj_as(model, {"obj": {"a": 2, "b": "g"}, "c": "hello"}) @@ -258,7 +264,7 @@ def test_applies_to_nested_base() -> None: def test_validates_submodel() -> None: model = create_model_with_type_validators( "Foo", - [TypeConverter(ComplexObject, lookup_complex)], + [TypeValidatorDefinition(ComplexObject, lookup_complex)], fields={"obj": (Bar, Undefined)}, ) parsed = parse_obj_as( @@ -277,7 +283,7 @@ def test_validates_submodel() -> None: def test_validates_nested_submodel() -> None: model = create_model_with_type_validators( "Foo", - [TypeConverter(ComplexObject, lookup_complex)], + [TypeValidatorDefinition(ComplexObject, lookup_complex)], fields={"obj": (Baz, Undefined)}, ) parsed = parse_obj_as( @@ -300,7 +306,7 @@ def test_validates_nested_submodel() -> None: def test_validates_dataclass() -> None: model = create_model_with_type_validators( "Foo", - [TypeConverter(ComplexObject, lookup_complex)], + [TypeValidatorDefinition(ComplexObject, lookup_complex)], fields={"obj": (DataclassBar, Undefined)}, ) parsed = parse_obj_as( @@ -319,7 +325,7 @@ def test_validates_dataclass() -> None: def test_validates_nested_dataclass() -> None: model = create_model_with_type_validators( "Foo", - [TypeConverter(ComplexObject, lookup_complex)], + [TypeValidatorDefinition(ComplexObject, lookup_complex)], fields={"obj": (DataclassBaz, Undefined)}, ) parsed = parse_obj_as( @@ -342,7 +348,7 @@ def test_validates_nested_dataclass() -> None: def test_validates_mixed_dataclass() -> None: model = create_model_with_type_validators( "Foo", - [TypeConverter(ComplexObject, lookup_complex)], + [TypeValidatorDefinition(ComplexObject, lookup_complex)], fields={"obj": (DataclassMixed, Undefined)}, ) parsed = parse_obj_as( @@ -365,7 +371,7 @@ def test_validates_mixed_dataclass() -> None: def test_validates_default_value() -> None: model = create_model_with_type_validators( "Foo", - [TypeConverter(int, lookup)], + [TypeValidatorDefinition(int, lookup)], fields={"a": (int, "e")}, config=DefaultConfig, ) @@ -375,7 +381,7 @@ def test_validates_default_value() -> None: def test_validates_complex_value() -> None: model = create_model_with_type_validators( "Foo", - [TypeConverter(ComplexObject, lookup_complex)], + [TypeValidatorDefinition(ComplexObject, lookup_complex)], fields={"obj": (ComplexObject, "t")}, config=DefaultConfig, ) @@ -385,7 +391,7 @@ def test_validates_complex_value() -> None: def test_validates_field_info() -> None: model = create_model_with_type_validators( "Foo", - [TypeConverter(int, lookup)], + [TypeValidatorDefinition(int, lookup)], fields={"a": (int, Field(default="f"))}, config=DefaultConfig, ) @@ -412,7 +418,7 @@ def test_validates_scanspec_with_complex_axis() -> None: def test_model_from_simple_function_signature() -> None: model = create_model_with_type_validators( - "Foo", [TypeConverter(int, lookup)], func=foo + "Foo", 
[TypeValidatorDefinition(int, lookup)], func=foo ) parsed = parse_obj_as(model, {"a": "g", "b": "hello"}) assert parsed.a == 6 @@ -422,7 +428,7 @@ def test_model_from_simple_function_signature() -> None: def test_model_from_complex_function_signature() -> None: model = create_model_with_type_validators( "Foo", - [TypeConverter(ComplexObject, lookup_complex)], + [TypeValidatorDefinition(ComplexObject, lookup_complex)], func=bar, config=DefaultConfig, ) @@ -433,7 +439,7 @@ def test_model_from_complex_function_signature() -> None: def test_model_from_nested_function_signature() -> None: model = create_model_with_type_validators( "Foo", - [TypeConverter(ComplexObject, lookup_complex)], + [TypeValidatorDefinition(ComplexObject, lookup_complex)], func=baz, config=DefaultConfig, ) @@ -445,7 +451,7 @@ def test_model_from_nested_function_signature() -> None: def parse_spec(spec: Spec) -> Any: model = create_model_with_type_validators( "Foo", - [TypeConverter(ComplexObject, lookup_complex)], + [TypeValidatorDefinition(ComplexObject, lookup_complex)], fields={"spec": (Spec, Undefined)}, ) return parse_obj_as(model, {"spec": spec.serialize()}) @@ -455,7 +461,9 @@ def assert_validates_single_type( field_type: Type, input_value: Any, expected_output: Any ) -> None: model = create_model_with_type_validators( - "Foo", [TypeConverter(int, lookup)], fields={"ch": (field_type, Undefined)} + "Foo", + [TypeValidatorDefinition(int, lookup)], + fields={"ch": (field_type, Undefined)}, ) assert parse_obj_as(model, {"ch": input_value}).ch == expected_output @@ -468,7 +476,7 @@ def assert_validates_complex_object( ) -> None: model = create_model_with_type_validators( "Foo", - [TypeConverter(ComplexObject, lookup_complex)], + [TypeValidatorDefinition(ComplexObject, lookup_complex)], fields={"obj": (field_type, default_value)}, config=DefaultConfig, ) From 3c68ffc552896e91af5ee66f15279d148580e600 Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Thu, 13 Apr 2023 18:10:15 +0100 Subject: [PATCH 12/21] Fix flake8 errors --- src/blueapi/core/bluesky_types.py | 2 +- src/blueapi/core/context.py | 3 +-- tests/utils/test_type_validator.py | 10 +++++----- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/blueapi/core/bluesky_types.py b/src/blueapi/core/bluesky_types.py index 8a8f35f0b1..3895a3cfbf 100644 --- a/src/blueapi/core/bluesky_types.py +++ b/src/blueapi/core/bluesky_types.py @@ -1,6 +1,6 @@ import inspect from dataclasses import dataclass -from typing import Any, Callable, Generator, Mapping, Optional, Type, Union +from typing import Any, Callable, Generator, Mapping, Type, Union from bluesky.protocols import ( Checkable, diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index 367c6542b6..12dda3d4d9 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -13,7 +13,6 @@ Iterable, List, Optional, - Sequence, Set, Tuple, Type, @@ -24,7 +23,7 @@ from bluesky.protocols import Flyable, Readable from pydantic import BaseModel, create_model, validator -from blueapi.utils import load_module_all, schema_for_func +from blueapi.utils import load_module_all from .bluesky_types import ( Device, diff --git a/tests/utils/test_type_validator.py b/tests/utils/test_type_validator.py index 2afd8f874b..f7626ec558 100644 --- a/tests/utils/test_type_validator.py +++ b/tests/utils/test_type_validator.py @@ -1,11 +1,11 @@ -from typing import Any, Dict, List, Mapping, NamedTuple, Set, Tuple, Type +from typing import Any, Dict, List, Mapping, Set, Tuple, Type import pytest from 
pydantic import BaseConfig, BaseModel, Field, parse_obj_as from pydantic.dataclasses import dataclass from pydantic.fields import Undefined from scanspec.regions import Circle -from scanspec.specs import Line, Product, Spec +from scanspec.specs import Line, Spec from blueapi.utils import TypeValidatorDefinition, create_model_with_type_validators @@ -125,7 +125,7 @@ def test_validates_multiple_types() -> None: ) parsed = parse_obj_as(model, {"a": "c", "b": "hello"}) assert parsed.a == 2 - assert parsed.b == False + assert parsed.b is False def test_validates_multiple_fields() -> None: @@ -155,9 +155,9 @@ def test_validates_multiple_fields_and_types() -> None: ) parsed = parse_obj_as(model, {"a": "c", "b": "hello", "c": "d", "d": "word"}) assert parsed.a == 2 - assert parsed.b == False + assert parsed.b is False assert parsed.c == 3 - assert parsed.d == True + assert parsed.d is True def test_does_not_tolerate_multiple_converters_for_same_type() -> None: From ee82631745c55ad4c3d09c6ba23fa5dd237b8689 Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Thu, 13 Apr 2023 18:45:50 +0100 Subject: [PATCH 13/21] Fix mypy issues --- src/blueapi/utils/type_validator.py | 31 ++++++----- tests/utils/test_type_validator.py | 82 +++++++++++++++-------------- 2 files changed, 60 insertions(+), 53 deletions(-) diff --git a/src/blueapi/utils/type_validator.py b/src/blueapi/utils/type_validator.py index 9e2ea6b47b..d4cd073259 100644 --- a/src/blueapi/utils/type_validator.py +++ b/src/blueapi/utils/type_validator.py @@ -14,6 +14,7 @@ Tuple, Type, TypeVar, + Union, overload, ) @@ -26,14 +27,14 @@ AnyCallable, AnyClassMethod = Any, Any -_PYDANTIC_LIST_TYPES = [List, Tuple, Set] -_PYDANTIC_DICT_TYPES = [Dict, Mapping] +_PYDANTIC_LIST_TYPES: List[Type] = [List, Tuple, Set] # type: ignore +_PYDANTIC_DICT_TYPES: List[Type] = [Dict, Mapping] T = TypeVar("T") U = TypeVar("U") FieldDefinition = Tuple[Type, Any] Fields = Mapping[str, FieldDefinition] -Validator = Callable[[AnyCallable], AnyClassMethod] +Validator = Union[Callable[[AnyCallable], AnyClassMethod], classmethod] @dataclass @@ -48,7 +49,7 @@ class TypeValidatorDefinition(Generic[T, U]): """ field_type: Type[T] - func: Callable[[T], U] + func: Callable[[U], T] def __str__(self) -> str: type_name = getattr( @@ -61,6 +62,7 @@ def __str__(self) -> str: def create_model_with_type_validators( name: str, definitions: Iterable[TypeValidatorDefinition], + *, fields: Fields, config: Optional[Type[BaseConfig]] = None, ) -> Type[BaseModel]: @@ -85,6 +87,7 @@ def create_model_with_type_validators( def create_model_with_type_validators( name: str, definitions: Iterable[TypeValidatorDefinition], + *, func: Callable[..., Any], config: Optional[Type[BaseConfig]] = None, ) -> Type[BaseModel]: @@ -111,6 +114,7 @@ def create_model_with_type_validators( def create_model_with_type_validators( name: str, definitions: Iterable[TypeValidatorDefinition], + *, base: Type[BaseModel], ) -> Type[BaseModel]: """ @@ -131,6 +135,7 @@ def create_model_with_type_validators( def create_model_with_type_validators( name: str, definitions: Iterable[TypeValidatorDefinition], + *, fields: Optional[Fields] = None, base: Optional[Type[BaseModel]] = None, func: Optional[Callable[..., Any]] = None, @@ -156,22 +161,22 @@ def create_model_with_type_validators( Type[BaseModel]: A new pydantic model """ - fields = fields or {} + all_fields = {**(fields or {})} if base is not None: - fields = {**fields, **_extract_fields_from_model(base)} + all_fields = {**all_fields, **_extract_fields_from_model(base)} if 
func is not None: - fields = {**fields, **_extract_fields_from_function(func)} - for name, field in fields.items(): + all_fields = {**all_fields, **_extract_fields_from_function(func)} + for name, field in all_fields.items(): annotation, val = field model_type = find_model_type(annotation) if model_type is not None: recursed = create_model_with_type_validators( annotation.__name__, definitions, base=model_type ) - fields[name] = recursed, val - validators = _type_validators(fields, definitions) - return create_model( - name, **fields, __base__=base, __validators__=validators, __config__=config + all_fields[name] = recursed, val + validators = _type_validators(all_fields, definitions) + return create_model( # type: ignore + name, **all_fields, __base__=base, __validators__=validators, __config__=config ) @@ -183,7 +188,7 @@ def _extract_fields_from_model(model: Type[BaseModel]) -> Fields: def _extract_fields_from_function(func: Callable[..., Any]) -> Fields: - fields: Fields = {} + fields: Dict[str, FieldDefinition] = {} for name, param in signature(func).parameters.items(): type_annotation = param.annotation default_value = param.default diff --git a/tests/utils/test_type_validator.py b/tests/utils/test_type_validator.py index f7626ec558..42b6af4903 100644 --- a/tests/utils/test_type_validator.py +++ b/tests/utils/test_type_validator.py @@ -110,8 +110,8 @@ def test_leaves_unvalidated_types_alone() -> None: fields={"a": (int, Undefined), "b": (str, Undefined)}, ) parsed = parse_obj_as(model, {"a": "c", "b": "hello"}) - assert parsed.a == 2 - assert parsed.b == "hello" + assert parsed.a == 2 # type: ignore + assert parsed.b == "hello" # type: ignore def test_validates_multiple_types() -> None: @@ -124,8 +124,8 @@ def test_validates_multiple_types() -> None: fields={"a": (int, Undefined), "b": (bool, Undefined)}, ) parsed = parse_obj_as(model, {"a": "c", "b": "hello"}) - assert parsed.a == 2 - assert parsed.b is False + assert parsed.a == 2 # type: ignore + assert parsed.b is False # type: ignore def test_validates_multiple_fields() -> None: @@ -135,8 +135,8 @@ def test_validates_multiple_fields() -> None: fields={"a": (int, Undefined), "b": (int, Undefined)}, ) parsed = parse_obj_as(model, {"a": "c", "b": "d"}) - assert parsed.a == 2 - assert parsed.b == 3 + assert parsed.a == 2 # type: ignore + assert parsed.b == 3 # type: ignore def test_validates_multiple_fields_and_types() -> None: @@ -154,10 +154,10 @@ def test_validates_multiple_fields_and_types() -> None: }, ) parsed = parse_obj_as(model, {"a": "c", "b": "hello", "c": "d", "d": "word"}) - assert parsed.a == 2 - assert parsed.b is False - assert parsed.c == 3 - assert parsed.d is True + assert parsed.a == 2 # type: ignore + assert parsed.b is False # type: ignore + assert parsed.c == 3 # type: ignore + assert parsed.d is True # type: ignore def test_does_not_tolerate_multiple_converters_for_same_type() -> None: @@ -245,8 +245,8 @@ def test_applies_to_base() -> None: base=Bar, ) parsed = parse_obj_as(model, {"a": 2, "b": "g"}) - assert parsed.a == 2 - assert parsed.b == ComplexObject("g") + assert parsed.a == 2 # type: ignore + assert parsed.b == ComplexObject("g") # type: ignore def test_applies_to_nested_base() -> None: @@ -256,9 +256,9 @@ def test_applies_to_nested_base() -> None: base=Baz, ) parsed = parse_obj_as(model, {"obj": {"a": 2, "b": "g"}, "c": "hello"}) - assert parsed.obj.a == 2 - assert parsed.obj.b == ComplexObject("g") - assert parsed.c == "hello" + assert parsed.obj.a == 2 # type: ignore + assert parsed.obj.b == 
ComplexObject("g") # type: ignore + assert parsed.c == "hello" # type: ignore def test_validates_submodel() -> None: @@ -276,8 +276,8 @@ def test_validates_submodel() -> None: }, }, ) - assert parsed.obj.a == 2 - assert parsed.obj.b == ComplexObject("g") + assert parsed.obj.a == 2 # type: ignore + assert parsed.obj.b == ComplexObject("g") # type: ignore def test_validates_nested_submodel() -> None: @@ -298,9 +298,9 @@ def test_validates_nested_submodel() -> None: } }, ) - assert parsed.obj.obj.a == 2 - assert parsed.obj.obj.b == ComplexObject("g") - assert parsed.obj.c == "hello" + assert parsed.obj.obj.a == 2 # type: ignore + assert parsed.obj.obj.b == ComplexObject("g") # type: ignore + assert parsed.obj.c == "hello" # type: ignore def test_validates_dataclass() -> None: @@ -318,8 +318,8 @@ def test_validates_dataclass() -> None: }, }, ) - assert parsed.obj.a == 2 - assert parsed.obj.b == ComplexObject("g") + assert parsed.obj.a == 2 # type: ignore + assert parsed.obj.b == ComplexObject("g") # type: ignore def test_validates_nested_dataclass() -> None: @@ -340,9 +340,9 @@ def test_validates_nested_dataclass() -> None: } }, ) - assert parsed.obj.obj.a == 2 - assert parsed.obj.obj.b == ComplexObject("g") - assert parsed.obj.c == "hello" + assert parsed.obj.obj.a == 2 # type: ignore + assert parsed.obj.obj.b == ComplexObject("g") # type: ignore + assert parsed.obj.c == "hello" # type: ignore def test_validates_mixed_dataclass() -> None: @@ -363,9 +363,9 @@ def test_validates_mixed_dataclass() -> None: } }, ) - assert parsed.obj.obj.a == 2 - assert parsed.obj.obj.b == ComplexObject("g") - assert parsed.obj.c == "hello" + assert parsed.obj.obj.a == 2 # type: ignore + assert parsed.obj.obj.b == ComplexObject("g") # type: ignore + assert parsed.obj.c == "hello" # type: ignore def test_validates_default_value() -> None: @@ -375,7 +375,7 @@ def test_validates_default_value() -> None: fields={"a": (int, "e")}, config=DefaultConfig, ) - assert parse_obj_as(model, {}).a == 4 + assert parse_obj_as(model, {}).a == 4 # type: ignore def test_validates_complex_value() -> None: @@ -385,7 +385,7 @@ def test_validates_complex_value() -> None: fields={"obj": (ComplexObject, "t")}, config=DefaultConfig, ) - assert parse_obj_as(model, {}).obj == ComplexObject("t") + assert parse_obj_as(model, {}).obj == ComplexObject("t") # type: ignore def test_validates_field_info() -> None: @@ -395,7 +395,7 @@ def test_validates_field_info() -> None: fields={"a": (int, Field(default="f"))}, config=DefaultConfig, ) - assert parse_obj_as(model, {}).a == 5 + assert parse_obj_as(model, {}).a == 5 # type: ignore @pytest.mark.parametrize( @@ -408,12 +408,12 @@ def test_validates_field_info() -> None: ], ) def test_validates_scanspec(spec: Spec) -> None: - assert parse_spec(spec).spec == spec + assert parse_spec(spec).spec == spec # type: ignore def test_validates_scanspec_with_complex_axis() -> None: spec = Line(ComplexObject("x"), 0.0, 10.0, 10) - assert parse_spec(spec).spec.axes() == [ComplexObject("x")] + assert parse_spec(spec).spec.axes() == [ComplexObject("x")] # type: ignore def test_model_from_simple_function_signature() -> None: @@ -421,8 +421,8 @@ def test_model_from_simple_function_signature() -> None: "Foo", [TypeValidatorDefinition(int, lookup)], func=foo ) parsed = parse_obj_as(model, {"a": "g", "b": "hello"}) - assert parsed.a == 6 - assert parsed.b == "hello" + assert parsed.a == 6 # type: ignore + assert parsed.b == "hello" # type: ignore def test_model_from_complex_function_signature() -> None: @@ -433,7 +433,7 
@@ def test_model_from_complex_function_signature() -> None: config=DefaultConfig, ) parsed = parse_obj_as(model, {"obj": "f"}) - assert parsed.obj == ComplexObject("f") + assert parsed.obj == ComplexObject("f") # type: ignore def test_model_from_nested_function_signature() -> None: @@ -444,8 +444,8 @@ def test_model_from_nested_function_signature() -> None: config=DefaultConfig, ) parsed = parse_obj_as(model, {"bar": {"a": 4, "b": "k"}}) - assert parsed.bar.a == 4 - assert parsed.bar.b == ComplexObject("k") + assert parsed.bar.a == 4 # type: ignore + assert parsed.bar.b == ComplexObject("k") # type: ignore def parse_spec(spec: Spec) -> Any: @@ -465,7 +465,8 @@ def assert_validates_single_type( [TypeValidatorDefinition(int, lookup)], fields={"ch": (field_type, Undefined)}, ) - assert parse_obj_as(model, {"ch": input_value}).ch == expected_output + parsed = parse_obj_as(model, {"ch": input_value}) + assert parsed.ch == expected_output # type: ignore def assert_validates_complex_object( @@ -480,4 +481,5 @@ def assert_validates_complex_object( fields={"obj": (field_type, default_value)}, config=DefaultConfig, ) - assert parse_obj_as(model, {"obj": input_value}).obj == expected_output + parsed = parse_obj_as(model, {"obj": input_value}) + assert parsed.obj == expected_output # type: ignore From 5a2c4c8122e874d4b6e3515462e51cd1a034eab7 Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Thu, 13 Apr 2023 19:46:36 +0100 Subject: [PATCH 14/21] Mostly working worker --- src/blueapi/core/context.py | 85 +++++++------------------- src/blueapi/messaging/stomptemplate.py | 6 +- src/blueapi/service/app.py | 6 +- src/blueapi/utils/__init__.py | 2 + src/blueapi/utils/serialization.py | 24 ++++++++ src/blueapi/worker/event.py | 2 +- src/blueapi/worker/reworker.py | 35 ++++++----- src/blueapi/worker/task.py | 22 +++---- 8 files changed, 85 insertions(+), 97 deletions(-) create mode 100644 src/blueapi/utils/serialization.py diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index 12dda3d4d9..cf68285843 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -21,11 +21,16 @@ from bluesky import RunEngine from bluesky.protocols import Flyable, Readable -from pydantic import BaseModel, create_model, validator +from pydantic import BaseConfig, BaseModel, create_model, validator -from blueapi.utils import load_module_all +from blueapi.utils import ( + TypeValidatorDefinition, + create_model_with_type_validators, + load_module_all, +) from .bluesky_types import ( + BLUESKY_PROTOCOLS, Device, Plan, PlanGenerator, @@ -38,6 +43,10 @@ LOGGER = logging.getLogger(__name__) +class PlanModelConfig(BaseConfig): + arbitrary_types_allowed = True + + @dataclass class BlueskyContext: """ @@ -123,10 +132,13 @@ def my_plan(a: int, b: str): if not is_bluesky_plan_generator(plan): raise TypeError(f"{plan} is not a valid plan generator function") - def get_device(name: str) -> Device: - return self.find_device(name) - - model = generate_plan_model(plan, get_device) + validators = device_validators(self) + model = create_model_with_type_validators( + plan.__name__, + validators, + func=plan, + config=PlanModelConfig, + ) self.plans[plan.__name__] = Plan(name=plan.__name__, model=model) self.plan_functions[plan.__name__] = plan return plan @@ -159,60 +171,9 @@ def device(self, device: Device, name: Optional[str] = None) -> None: self.devices[name] = device -def generate_plan_model( - plan: PlanGenerator, get_device: Callable[[str], Device] -) -> Type[BaseModel]: - model_annotations: 
Dict[str, Tuple[Type, Any]] = {} - validators: Dict[str, Any] = {} - for name, param in signature(plan).parameters.items(): - type_annotation = param.annotation - if is_bluesky_compatible_device_type(type_annotation): - type_annotation = str - validators[name] = validator(name)(get_device) - elif is_iterable_of_devices(type_annotation): - validators[name] = validator(name, each_item=True)(get_device) - - default_value = param.default - if default_value is Parameter.empty: - default_value = ... - - anno = (type_annotation, default_value) - model_annotations[name] = anno - - name = f"{plan.__name__}_model" - from pprint import pprint - - pprint(model_annotations) - return create_model(name, **model_annotations, __validators__=validators) - - -def is_mapping_with_devices(dct: Type) -> bool: - if get_params(dct): - ... - - -def is_iterable_of_devices(lst: Type) -> bool: - if origin_is_iterable(lst): - params = list(get_params(lst)) - if params: - (inner,) = params - return is_bluesky_compatible_device_type(inner) - return False - - -def get_params(maybe_parametrised: Type) -> Iterable[Type]: - for attr in "__args__", "__parameters__": - yield from getattr(maybe_parametrised, attr, []) - - -def origin_is_iterable(to_check: Type) -> bool: - return any( - map( - lambda origin: origin_is(to_check, origin), - [List, Set, Tuple, FrozenSet, Deque], - ) - ) - +def device_validators(ctx: BlueskyContext) -> Iterable[TypeValidatorDefinition]: + def get_device(name: str) -> Device: + return ctx.find_device(name) -def origin_is(to_check: Type, origin: Type) -> bool: - return hasattr(to_check, "__origin__") and to_check.__origin__ is origin + for proto in BLUESKY_PROTOCOLS: + yield TypeValidatorDefinition(proto, get_device) diff --git a/src/blueapi/messaging/stomptemplate.py b/src/blueapi/messaging/stomptemplate.py index c69acd0468..856abc001f 100644 --- a/src/blueapi/messaging/stomptemplate.py +++ b/src/blueapi/messaging/stomptemplate.py @@ -8,12 +8,12 @@ from typing import Any, Callable, Dict, List, Optional, Set import stomp -from apischema import deserialize, serialize +from pydantic import parse_obj_as from stomp.exception import ConnectFailedException from stomp.utils import Frame from blueapi.config import StompConfig -from blueapi.utils import handle_all_exceptions +from blueapi.utils import handle_all_exceptions, serialize from .base import DestinationProvider, MessageListener, MessagingTemplate from .context import MessageContext @@ -140,7 +140,7 @@ def subscribe(self, destination: str, callback: MessageListener) -> None: def wrapper(frame: Frame) -> None: as_dict = json.loads(frame.body) - value = deserialize(obj_type, as_dict) + value = parse_obj_as(obj_type, as_dict) context = MessageContext( frame.headers["destination"], diff --git a/src/blueapi/service/app.py b/src/blueapi/service/app.py index c53e6b5929..ff641d93c3 100644 --- a/src/blueapi/service/app.py +++ b/src/blueapi/service/app.py @@ -77,12 +77,12 @@ def _on_run_request(self, message_context: MessageContext, task: RunPlan) -> Non reply_queue = message_context.reply_destination if reply_queue is not None: - response = TaskResponse(correlation_id) + response = TaskResponse(task_name=correlation_id) self._template.send(reply_queue, response) def _get_plans(self, message_context: MessageContext, message: PlanRequest) -> None: plans = list(map(PlanModel.from_plan, self._ctx.plans.values())) - response = PlanResponse(plans) + response = PlanResponse(plans=plans) assert message_context.reply_destination is not None 
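The subscribe() wrapper above now runs parse_obj_as(obj_type, json.loads(frame.body)) before invoking the handler. A minimal illustrative sketch of a handler on that path, assuming obj_type is taken from the handler's annotation; the destination name below is invented and the import locations are as the package layout suggests:

    from blueapi.messaging.base import MessagingTemplate
    from blueapi.messaging.context import MessageContext
    from blueapi.worker.task import RunPlan

    def on_run_request(context: MessageContext, task: RunPlan) -> None:
        # By the time this is called the wrapper has already parsed the JSON body
        # into a validated RunPlan, so pydantic has checked name/params for us.
        print(task.name, task.params)

    def wire_up(template: MessagingTemplate) -> None:
        # "worker.run" is an invented destination name for the sketch
        template.subscribe("worker.run", on_run_request)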
self._template.send(message_context.reply_destination, response) @@ -91,7 +91,7 @@ def _get_devices( self, message_context: MessageContext, message: DeviceRequest ) -> None: devices = list(map(DeviceModel.from_device, self._ctx.devices.values())) - response = DeviceResponse(devices) + response = DeviceResponse(devices=devices) assert message_context.reply_destination is not None self._template.send(message_context.reply_destination, response) diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index fa8b28d207..db6a8d2bfe 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -1,6 +1,7 @@ from .config import ConfigLoader from .modules import load_module_all from .schema import nested_deserialize_with_overrides, schema_for_func +from .serialization import serialize from .thread_exception import handle_all_exceptions from .type_validator import TypeValidatorDefinition, create_model_with_type_validators @@ -12,4 +13,5 @@ "ConfigLoader", "create_model_with_type_validators", "TypeValidatorDefinition", + "serialize", ] diff --git a/src/blueapi/utils/serialization.py b/src/blueapi/utils/serialization.py new file mode 100644 index 0000000000..141d0b7023 --- /dev/null +++ b/src/blueapi/utils/serialization.py @@ -0,0 +1,24 @@ +from typing import Any + +from pydantic import BaseModel + + +def serialize(obj: Any) -> Any: + """ + Pydantic-aware serialization routine that can also be + used on primitives. So serialize(4) is 4, but + serialize(<a pydantic model>) is a dictionary. + + Args: + obj: The object to serialize + + Returns: + Any: The serialized object + """ + + if isinstance(obj, BaseModel): + return obj.dict() + elif hasattr(obj, "__pydantic_model__"): + return serialize(getattr(obj, "__pydantic_model__")) + else: + return obj diff --git a/src/blueapi/worker/event.py b/src/blueapi/worker/event.py index 7d8fbde698..457ec443c2 100644 --- a/src/blueapi/worker/event.py +++ b/src/blueapi/worker/event.py @@ -9,7 +9,7 @@ RawRunEngineState = Union[PropertyMachine, ProxyString, str] -class WorkerState(Enum): +class WorkerState(str, Enum): """ The state of the Worker.
""" diff --git a/src/blueapi/worker/reworker.py b/src/blueapi/worker/reworker.py index 382d99415c..f96347e806 100644 --- a/src/blueapi/worker/reworker.py +++ b/src/blueapi/worker/reworker.py @@ -141,16 +141,21 @@ def _report_status( warnings = self._warnings if self._current is not None: task_status = TaskStatus( - self._current.name, - self._current.is_complete, - self._current.is_error or bool(errors), + task_name=self._current.name, + task_complete=self._current.is_complete, + task_failed=self._current.is_error or bool(errors), ) correlation_id = self._current.name else: task_status = None correlation_id = None - event = WorkerEvent(self._state, task_status, errors, warnings) + event = WorkerEvent( + state=self._state, + task_status=task_status, + errors=errors, + warnings=warnings, + ) self._worker_events.publish(event, correlation_id) def _on_document(self, name: str, document: Mapping[str, Any]) -> None: @@ -204,16 +209,16 @@ def _on_status_event( else: percentage = 1.0 view = StatusView( - name or "UNKNOWN", - current, - initial, - target, - unit or "units", - precision or 3, - status.done, - percentage, - time_elapsed, - time_remaining, + display_name=name or "UNKNOWN", + current=current, + initial=initial, + target=target, + unit=unit or "units", + precision=precision or 3, + done=status.done, + percentage=percentage, + time_elapsed=time_elapsed, + time_remaining=time_remaining, ) self._status_snapshot[status_name] = view self._publish_status_snapshot() @@ -224,7 +229,7 @@ def _publish_status_snapshot(self) -> None: else: self._progress_events.publish( ProgressEvent( - self._current.name, + task_name=self._current.name, statuses=self._status_snapshot, ), self._current.name, diff --git a/src/blueapi/worker/task.py b/src/blueapi/worker/task.py index be67067676..46a516151a 100644 --- a/src/blueapi/worker/task.py +++ b/src/blueapi/worker/task.py @@ -6,7 +6,12 @@ from pydantic import BaseModel, Field, parse_obj_as from pydantic.decorator import ValidatedFunction -from blueapi.core import BlueskyContext, Device, create_bluesky_protocol_conversions +from blueapi.core import ( + BlueskyContext, + Device, + Plan, + create_bluesky_protocol_conversions, +) from blueapi.utils import nested_deserialize_with_overrides @@ -43,13 +48,14 @@ def do_task(self, ctx: BlueskyContext) -> None: LOGGER.info(f"Asked to run plan {self.name} with {self.params}") plan = ctx.plans[self.name] + func = ctx.plan_functions[self.name] sanitized_params = _lookup_params(ctx, plan, self.params) - plan_generator = plan.call(**sanitized_params) + plan_generator = func(**sanitized_params.dict()) ctx.run_engine(plan_generator) def _lookup_params( - ctx: BlueskyContext, plan: ValidatedFunction, params: Mapping[str, Any] + ctx: BlueskyContext, plan: Plan, params: Mapping[str, Any] ) -> BaseModel: """ Checks plan parameters against context @@ -66,16 +72,6 @@ def _lookup_params( model = plan.model return parse_obj_as(model, params) - def find_device(name: str) -> Device: - device = ctx.find_device(name) - if device is not None: - return device - else: - raise KeyError(f"Could not find device {name}") - - overrides = list(create_bluesky_protocol_conversions(find_device)) - return nested_deserialize_with_overrides(plan.model, params, overrides).__dict__ - @dataclass class ActiveTask: From 492d0ce679cbc9e99d5b0ae90f88ab99171f0db8 Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Thu, 13 Apr 2023 20:21:17 +0100 Subject: [PATCH 15/21] Make data events a pydantic model --- src/blueapi/core/bluesky_types.py | 3 +-- 
src/blueapi/worker/reworker.py | 4 +++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/blueapi/core/bluesky_types.py b/src/blueapi/core/bluesky_types.py index 3895a3cfbf..ac23111298 100644 --- a/src/blueapi/core/bluesky_types.py +++ b/src/blueapi/core/bluesky_types.py @@ -91,8 +91,7 @@ class Plan(BaseModel): ) -@dataclass -class DataEvent: +class DataEvent(BaseModel): """ Event representing collection of some data. Conforms to the Bluesky event model: https://github.com/bluesky/event-model diff --git a/src/blueapi/worker/reworker.py b/src/blueapi/worker/reworker.py index f96347e806..1415325a4a 100644 --- a/src/blueapi/worker/reworker.py +++ b/src/blueapi/worker/reworker.py @@ -161,7 +161,9 @@ def _report_status( def _on_document(self, name: str, document: Mapping[str, Any]) -> None: if self._current is not None: correlation_id = self._current.name - self._data_events.publish(DataEvent(name, document), correlation_id) + self._data_events.publish( + DataEvent(name=name, doc=document), correlation_id + ) else: raise KeyError( "Trying to emit a document despite the fact that the RunEngine is idle" From d4b4022e08938e283c373bd50963814c4a92d496 Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Thu, 13 Apr 2023 20:31:15 +0100 Subject: [PATCH 16/21] Remove all traces of apischema --- src/blueapi/cli/cli.py | 9 ++- src/blueapi/config.py | 21 +++--- src/blueapi/core/__init__.py | 2 - src/blueapi/core/device_lookup.py | 46 +------------- src/blueapi/plans/plans.py | 19 +----- src/blueapi/utils/__init__.py | 3 - src/blueapi/utils/config.py | 6 +- src/blueapi/utils/schema.py | 102 ------------------------------ src/blueapi/worker/task.py | 11 +--- tests/utils/test_config.py | 17 ++--- 10 files changed, 32 insertions(+), 204 deletions(-) delete mode 100644 src/blueapi/utils/schema.py diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index b969fe70eb..57ac9414d4 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -61,7 +61,14 @@ def controller(ctx, host: str, port: int, log_level: str): return logging.basicConfig(level=log_level) ctx.ensure_object(dict) - client = AmqClient(StompMessagingTemplate.autoconfigured(StompConfig(host, port))) + client = AmqClient( + StompMessagingTemplate.autoconfigured( + StompConfig( + host=host, + port=port, + ) + ) + ) ctx.obj["client"] = client client.app.connect() diff --git a/src/blueapi/config.py b/src/blueapi/config.py index e661965432..6000628d28 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -1,10 +1,10 @@ -from dataclasses import dataclass, field from pathlib import Path from typing import Union +from pydantic import BaseModel, Field -@dataclass -class StompConfig: + +class StompConfig(BaseModel): """ Config for connecting to stomp broker """ @@ -13,8 +13,7 @@ class StompConfig: port: int = 61613 -@dataclass -class EnvironmentConfig: +class EnvironmentConfig(BaseModel): """ Config for the RunEngine environment """ @@ -22,18 +21,16 @@ class EnvironmentConfig: startup_script: Union[Path, str] = "blueapi.startup.example" -@dataclass -class LoggingConfig: +class LoggingConfig(BaseModel): level: str = "INFO" -@dataclass -class ApplicationConfig: +class ApplicationConfig(BaseModel): """ Config for the worker application as a whole. Root of config tree. 
""" - stomp: StompConfig = field(default_factory=StompConfig) - env: EnvironmentConfig = field(default_factory=EnvironmentConfig) - logging: LoggingConfig = field(default_factory=LoggingConfig) + stomp: StompConfig = Field(default_factory=StompConfig) + env: EnvironmentConfig = Field(default_factory=EnvironmentConfig) + logging: LoggingConfig = Field(default_factory=LoggingConfig) diff --git a/src/blueapi/core/__init__.py b/src/blueapi/core/__init__.py index afc759a5c3..061c0ce031 100644 --- a/src/blueapi/core/__init__.py +++ b/src/blueapi/core/__init__.py @@ -11,7 +11,6 @@ is_bluesky_plan_generator, ) from .context import BlueskyContext -from .device_lookup import create_bluesky_protocol_conversions from .event import EventPublisher, EventStream __all__ = [ @@ -20,7 +19,6 @@ "MsgGenerator", "Device", "BLUESKY_PROTOCOLS", - "create_bluesky_protocol_conversions", "BlueskyContext", "EventPublisher", "EventStream", diff --git a/src/blueapi/core/device_lookup.py b/src/blueapi/core/device_lookup.py index 95c228a06f..957a057f7e 100644 --- a/src/blueapi/core/device_lookup.py +++ b/src/blueapi/core/device_lookup.py @@ -1,48 +1,6 @@ -from functools import partial -from typing import Any, Callable, Iterable, List, Optional, Type, TypeVar - -from apischema.conversions.conversions import Conversion - -from .bluesky_types import BLUESKY_PROTOCOLS, Device, is_bluesky_compatible_device - - -def create_bluesky_protocol_conversions( - device_lookup: Callable[[str], Device], -) -> Iterable[Conversion]: - """ - Generate a series of APISchema Conversions for the valid Device types. - The conversions use a user-defined function to lookup devices by name. - - Args: - device_lookup (Callable[[str], Device]): Function to lookup Device by name, - expects an Exception if name not - found - - Returns: - Iterable[Conversion]: Conversions for locating devices - """ - - def find_device_matching_name_and_type(target_type: Type, name: str) -> Any: - # Find the device in the - device = device_lookup(name) - - # The schema has asked for a particular protocol, at this point in the code we - # have found the device but need to check that it complies with the requested - # protocol. If it doesn't, it means there is a typing error in the plan. 
- if isinstance(device, target_type): - return device - else: - raise TypeError(f"{name} needs to be of type {target_type}") - - # Create a conversion for each type, the conversion function will automatically - # perform a structural subtyping check - for a_type in BLUESKY_PROTOCOLS: - yield Conversion( - partial(find_device_matching_name_and_type, a_type), - source=str, - target=a_type, - ) +from typing import Any, List, Optional, TypeVar +from .bluesky_types import Device, is_bluesky_compatible_device #: Device obeying Bluesky protocols D = TypeVar("D", bound=Device) diff --git a/src/blueapi/plans/plans.py b/src/blueapi/plans/plans.py index de9baff93a..c0ef93c879 100644 --- a/src/blueapi/plans/plans.py +++ b/src/blueapi/plans/plans.py @@ -3,9 +3,6 @@ from typing import Any, List, Mapping, Optional, Tuple, Type, Union import bluesky.plans as bp -from apischema import serialize -from apischema.conversions.conversions import Conversion -from apischema.conversions.converters import AnyConversion, default_serialization from bluesky.protocols import Movable, Readable from cycler import Cycler, cycler from scanspec.specs import Spec @@ -38,8 +35,8 @@ def scan( metadata = { "detectors": [detector.name for detector in detectors], - "scanspec": serialize(spec, default_conversion=_convert_devices), - "shape": _shape(spec), + # "scanspec": serialize(spec, default_conversion=_convert_devices), + "shape": spec.shape(), **(metadata or {}), } @@ -47,18 +44,6 @@ def scan( yield from bp.scan_nd(detectors, cycler, md=metadata) -# TODO: Use built-in scanspec utility method following completion of DAQ-4487 -def _shape(spec: Spec[Movable]) -> Tuple[int, ...]: - return tuple(len(dim) for dim in spec.calculate()) - - -def _convert_devices(a_type: Type[Any]) -> Optional[AnyConversion]: - if issubclass(a_type, Movable): - return Conversion(str, source=a_type) - else: - return default_serialization(a_type) - - def _scanspec_to_cycler(spec: Spec) -> Cycler: """ Convert a scanspec to a cycler for compatibility with legacy Bluesky plans such as diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index db6a8d2bfe..4d1bff4dc0 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -1,14 +1,11 @@ from .config import ConfigLoader from .modules import load_module_all -from .schema import nested_deserialize_with_overrides, schema_for_func from .serialization import serialize from .thread_exception import handle_all_exceptions from .type_validator import TypeValidatorDefinition, create_model_with_type_validators __all__ = [ "handle_all_exceptions", - "nested_deserialize_with_overrides", - "schema_for_func", "load_module_all", "ConfigLoader", "create_model_with_type_validators", diff --git a/src/blueapi/utils/config.py b/src/blueapi/utils/config.py index bcaa0c72c5..93d14a2ab0 100644 --- a/src/blueapi/utils/config.py +++ b/src/blueapi/utils/config.py @@ -2,10 +2,10 @@ from typing import Any, Generic, Mapping, Type, TypeVar import yaml -from apischema import deserialize +from pydantic import BaseModel, parse_obj_as #: Configuration schema dataclass -C = TypeVar("C") +C = TypeVar("C", bound=BaseModel) class ConfigLoader(Generic[C]): @@ -59,4 +59,4 @@ def load(self) -> C: C: Dataclass instance holding config """ - return deserialize(self._schema, self._values) + return parse_obj_as(self._schema, self._values) diff --git a/src/blueapi/utils/schema.py b/src/blueapi/utils/schema.py deleted file mode 100644 index 88e260fe4f..0000000000 --- a/src/blueapi/utils/schema.py +++ /dev/null 
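With ConfigLoader's schema typevar now bound to BaseModel and load() going through parse_obj_as, a partial config dictionary (for example one read from YAML) is validated against the whole model tree while the declared defaults fill in the rest. A small sketch of that behaviour; the host value is invented:

    from pydantic import parse_obj_as

    from blueapi.config import ApplicationConfig

    config = parse_obj_as(ApplicationConfig, {"stomp": {"host": "broker.example.com"}})
    assert config.stomp.host == "broker.example.com"
    assert config.stomp.port == 61613      # default from StompConfig
    assert config.logging.level == "INFO"  # untouched sub-models use their defaults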
@@ -1,102 +0,0 @@ -from dataclasses import make_dataclass -from inspect import Parameter, signature -from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, TypeVar, Union - -from apischema import deserialize -from apischema.conversions.conversions import Conversion -from apischema.conversions.converters import AnyConversion, default_deserialization -from pydantic import BaseModel - - -def schema_for_func(func: Callable[..., Any]) -> BaseModel: - """ - Generate a pydantic model of the set of parameters to a function. - Inspect the parameters, default values and type annotations of a function and - generate the schema. - - Example: - - def foo(a: int, b: str, c: bool = False): - ... - - schema = schema_for_func(foo) - - Schema is the runtime equivalent of: - - class foo_params(BaseModel): - a: int - b: str - c: bool = False - - Args: - func: The source function, all parameters must have type annotations - - Raises: - TypeError: If a type annotation is either `Any` or not supplied - - Returns: - Type: A runtime dataclass whose fields encapsulate the names, types and default - values of the function parameters - """ - - class_name = f"{func.__name__}_params" - fields: List[Union[Tuple[str, Type, Any], Tuple[str, Type]]] = [] - - # Iterate through parameters and convert them to dataclass fields - for name, param in signature(func).parameters.items(): - a_type = param.annotation - # Do not allow parameters without type annotations or with the `Any` annotation - if a_type is Parameter.empty: - raise TypeError( - f"Error serializing function {func.__name__}, all parameters must have " - "a type annotation" - ) - elif a_type is Any: - raise TypeError( - f"Error serializing function {func.__name__} parameter {name} all " - "parameters cannot have `Any` as a type annotation" - ) - - default_value = param.default - - # Include the default value in the field if there is onee - if default_value is not Parameter.empty: - fields.append((name, a_type, default_value)) - else: - fields.append((name, a_type)) - - data_class = make_dataclass(class_name, fields) - return data_class - - -T = TypeVar("T") - - -def nested_deserialize_with_overrides( - schema: Type[T], obj: Any, overrides: Optional[Iterable[Conversion]] = None -) -> T: - """ - Deserialize a dictionary using apischema with custom overrides. Unlike apischema's - built-in override argument, this propagates the overrides to nested dictionaries. - - Args: - schema (Type[T]): Type to deserialize to - obj (Any): Raw object to deserialize, usually a dictionary - overrides (Optional[Iterable[Conversion]], optional): apischema conversions to - customize deserialization. - Defaults to None. 
- - Returns: - T: Deserialized object - """ - - conversions = {conversion.target: conversion for conversion in overrides or []} - - def deserialize_with_converters(a_type: Type[Any]) -> Optional[AnyConversion]: - # If the type is in _conversions then we can override the function used to - # resolve the parameter, otherwise we use apischema's default deserializer - if a_type in conversions.keys(): - return conversions[a_type] - return default_deserialization(a_type) - - return deserialize(schema, obj, default_conversion=deserialize_with_converters) diff --git a/src/blueapi/worker/task.py b/src/blueapi/worker/task.py index 46a516151a..efbb01313e 100644 --- a/src/blueapi/worker/task.py +++ b/src/blueapi/worker/task.py @@ -4,15 +4,8 @@ from typing import Any, Mapping from pydantic import BaseModel, Field, parse_obj_as -from pydantic.decorator import ValidatedFunction - -from blueapi.core import ( - BlueskyContext, - Device, - Plan, - create_bluesky_protocol_conversions, -) -from blueapi.utils import nested_deserialize_with_overrides + +from blueapi.core import BlueskyContext, Plan # TODO: Make a TaggedUnion diff --git a/tests/utils/test_config.py b/tests/utils/test_config.py index f4abb3b4c3..6722ad5e14 100644 --- a/tests/utils/test_config.py +++ b/tests/utils/test_config.py @@ -1,35 +1,30 @@ import os -from dataclasses import dataclass, field from pathlib import Path from typing import Any, Type import pytest -from apischema import ValidationError +from pydantic import BaseModel, Field, ValidationError from blueapi.utils import ConfigLoader -@dataclass -class Config: +class Config(BaseModel): foo: int bar: str -@dataclass -class ConfigWithDefaults: +class ConfigWithDefaults(BaseModel): foo: int = 3 bar: str = "hello world" -@dataclass -class NestedConfig: +class NestedConfig(BaseModel): nested: Config baz: bool -@dataclass -class NestedConfigWithDefaults: - nested: ConfigWithDefaults = field(default_factory=ConfigWithDefaults) +class NestedConfigWithDefaults(BaseModel): + nested: ConfigWithDefaults = Field(default_factory=ConfigWithDefaults) baz: bool = False From fca93493bde93ff564f0988c71ceecfd982d43b6 Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Thu, 13 Apr 2023 20:51:24 +0100 Subject: [PATCH 17/21] Write failing scanspec test --- tests/utils/test_type_validator.py | 59 ++++++++++++++++++++++++------ 1 file changed, 48 insertions(+), 11 deletions(-) diff --git a/tests/utils/test_type_validator.py b/tests/utils/test_type_validator.py index 42b6af4903..a4b552fcef 100644 --- a/tests/utils/test_type_validator.py +++ b/tests/utils/test_type_validator.py @@ -38,6 +38,14 @@ def __repr__(self) -> str: return f"ComplexObject({self._name})" +class SpecWrapper(BaseModel): + spec: Spec + + +def spec_wrapper(spec: Spec) -> None: + ... 
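The ComplexObject lookups these tests exercise stand in for what the worker now does with device names: RunPlan.do_task() validates its raw params against the plan's generated model, so strings come in and devices come out. A rough sketch of that flow (import locations assumed), with invented plan, parameter and device names, for a context that already has a suitable plan and device registered:

    from blueapi.core import BlueskyContext
    from blueapi.worker.task import RunPlan

    def run_sweep(ctx: BlueskyContext) -> None:
        # Assumes "sweep" was registered with @ctx.plan and takes a Readable "det"
        # plus an int "num", and that a device was registered under "detector".
        task = RunPlan(name="sweep", params={"det": "detector", "num": 5})
        # do_task parses params with ctx.plans["sweep"].model, so the protocol
        # validators swap "detector" for the device object before the plan function
        # is called and its generator handed to ctx.run_engine.
        task.do_task(ctx)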
+ + class Bar(BaseModel): a: int b: ComplexObject @@ -398,22 +406,51 @@ def test_validates_field_info() -> None: assert parse_obj_as(model, {}).a == 5 # type: ignore -@pytest.mark.parametrize( - "spec", - [ - Line("x", 0.0, 10.0, 10), - Line("x", 0.0, 10.0, 10) * Line("y", 0.0, 10.0, 10), - (Line("x", 0.0, 10.0, 10) * Line("y", 0.0, 10.0, 10)) - & Circle("x", "y", 1.0, 2.8, radius=0.5), - ], -) +SPECS = [ + Line("x", 0.0, 10.0, 10), + Line("x", 0.0, 10.0, 10) * Line("y", 0.0, 10.0, 10), + (Line("x", 0.0, 10.0, 10) * Line("y", 0.0, 10.0, 10)) + & Circle("x", "y", 1.0, 2.8, radius=0.5), +] + + +@pytest.mark.parametrize("spec", SPECS) def test_validates_scanspec(spec: Spec) -> None: assert parse_spec(spec).spec == spec # type: ignore +@pytest.mark.parametrize("spec", SPECS) +def test_validates_scanspec_wrapper(spec: Spec) -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(ComplexObject, lookup_complex)], + fields={"wrapper": (SpecWrapper, Undefined)}, + ) + parsed = parse_obj_as(model, {"wrapper": {"spec": spec.serialize()}}) + assert parsed.wrapper.spec == spec + + +@pytest.mark.parametrize("spec", SPECS) +def test_validates_scanspec_wrapping_function(spec: Spec) -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(ComplexObject, lookup_complex)], + func=spec_wrapper, + ) + parsed = parse_obj_as(model, {"spec": spec.serialize()}) + assert parsed.spec == spec + + def test_validates_scanspec_with_complex_axis() -> None: - spec = Line(ComplexObject("x"), 0.0, 10.0, 10) - assert parse_spec(spec).spec.axes() == [ComplexObject("x")] # type: ignore + spec = Line("x", 0.0, 10.0, 10) + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(ComplexObject, lookup_complex)], + fields={"spec": (Spec[ComplexObject], Undefined)}, + config=DefaultConfig, + ) + parsed = parse_obj_as(model, {"spec": spec.serialize()}) + assert parsed.spec.axes() == [ComplexObject("x")] # type: ignore def test_model_from_simple_function_signature() -> None: From 3be1696ec573c8fe0291caa88318b0f3aa7cb651 Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Fri, 14 Apr 2023 12:50:28 +0100 Subject: [PATCH 18/21] Remove unused tests --- src/blueapi/utils/type_validator.py | 108 ++++++++++++++++++++++------ tests/core/test_device_lookup.py | 13 ---- tests/utils/test_schema.py | 39 ---------- tests/utils/test_type_validator.py | 40 +++++++++-- 4 files changed, 119 insertions(+), 81 deletions(-) delete mode 100644 tests/core/test_device_lookup.py delete mode 100644 tests/utils/test_schema.py diff --git a/src/blueapi/utils/type_validator.py b/src/blueapi/utils/type_validator.py index d4cd073259..05be7829b1 100644 --- a/src/blueapi/utils/type_validator.py +++ b/src/blueapi/utils/type_validator.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping as AbcMapping from dataclasses import dataclass from inspect import Parameter, isclass, signature from typing import ( @@ -15,6 +16,7 @@ Type, TypeVar, Union, + get_args, overload, ) @@ -61,7 +63,7 @@ def __str__(self) -> str: @overload def create_model_with_type_validators( name: str, - definitions: Iterable[TypeValidatorDefinition], + definitions: List[TypeValidatorDefinition], *, fields: Fields, config: Optional[Type[BaseConfig]] = None, @@ -86,7 +88,7 @@ def create_model_with_type_validators( @overload def create_model_with_type_validators( name: str, - definitions: Iterable[TypeValidatorDefinition], + definitions: List[TypeValidatorDefinition], *, func: Callable[..., Any], config: 
Optional[Type[BaseConfig]] = None, @@ -113,7 +115,7 @@ def create_model_with_type_validators( @overload def create_model_with_type_validators( name: str, - definitions: Iterable[TypeValidatorDefinition], + definitions: List[TypeValidatorDefinition], *, base: Type[BaseModel], ) -> Type[BaseModel]: @@ -134,12 +136,13 @@ def create_model_with_type_validators( def create_model_with_type_validators( name: str, - definitions: Iterable[TypeValidatorDefinition], + definitions: List[TypeValidatorDefinition], *, fields: Optional[Fields] = None, base: Optional[Type[BaseModel]] = None, func: Optional[Callable[..., Any]] = None, config: Optional[Type[BaseConfig]] = None, + cache: Optional[Dict[Type, Type]] = None, ) -> Type[BaseModel]: """ Create a pydantic model with type validators according to @@ -161,6 +164,7 @@ def create_model_with_type_validators( Type[BaseModel]: A new pydantic model """ + cache = cache or {} all_fields = {**(fields or {})} if base is not None: all_fields = {**all_fields, **_extract_fields_from_model(base)} @@ -168,18 +172,72 @@ def create_model_with_type_validators( all_fields = {**all_fields, **_extract_fields_from_function(func)} for name, field in all_fields.items(): annotation, val = field - model_type = find_model_type(annotation) - if model_type is not None: - recursed = create_model_with_type_validators( - annotation.__name__, definitions, base=model_type - ) - all_fields[name] = recursed, val + if annotation in cache: + all_fields[name] = cache[annotation], val + else: + all_fields[name] = apply_type_validators(annotation, definitions), val + # model_type = find_model_type(annotation) + # if model_type is not None: + # recursed = create_model_with_type_validators( + # annotation.__name__, definitions, base=model_type + # ) + # all_fields[name] = recursed, val validators = _type_validators(all_fields, definitions) return create_model( # type: ignore name, **all_fields, __base__=base, __validators__=validators, __config__=config ) +def apply_type_validators( + model_type: Type, + definitions: List[TypeValidatorDefinition], + cache: Optional[Dict[Type, Type]] = None, +) -> Type: + cache = cache or {} + if model_type in cache: + return cache[model_type] + + if isclass(model_type) and issubclass(model_type, BaseModel): + if "__root__" in model_type.__fields__: + # return create_model_with_type_validators( + # model_type.__name__, + # definitions, + # fields=_extract_fields_from_model(model_type), + # ) + return apply_type_validators( + model_type.__fields__["__root__"].type_, definitions, cache=cache + ) + else: + return create_model_with_type_validators( + model_type.__name__, + definitions, + base=model_type, + ) + elif isclass(model_type) and hasattr(model_type, "__pydantic_model__"): + model = getattr(model_type, "__pydantic_model__") + return apply_type_validators(model, definitions, cache=cache) + else: + params = [ + apply_type_validators(param, definitions, cache=cache) + for param in get_args(model_type) + ] + if params and hasattr(model_type, "__origin__"): + origin = getattr(model_type, "__origin__") + origin = _sanitise_origin(origin) + return origin[tuple(params)] + return model_type + + +def _sanitise_origin(origin: Type) -> Type: + return { + list: List, + set: Set, + tuple: Tuple, + AbcMapping: Mapping, + dict: Mapping, + }.get(origin, origin) + + def _extract_fields_from_model(model: Type[BaseModel]) -> Fields: return { name: (field.type_, field.field_info) @@ -253,24 +311,28 @@ def is_type_or_container_type(type_to_check: Type, field_type: Type) -> bool: 
def params_contains(type_to_check: Type, field_type: Type) -> bool: - type_params = list( - getattr( - type_to_check, - "__args__", - [], - ) - ) + list( - getattr( - type_to_check, - "__parameters__", - [], - ) - ) + type_params = get_args(type_to_check) return type_to_check is field_type or any( map(lambda v: params_contains(v, field_type), type_params) ) +# def params_of_type(type_to_check: Type) -> List[Type]: +# return list( +# getattr( +# type_to_check, +# "__args__", +# [], +# ) +# ) + list( +# getattr( +# type_to_check, +# "__parameters__", +# [], +# ) +# ) + + def apply_to_scalars(func: Callable[[T], U], obj: Any) -> Any: if is_list_type(obj): return list(map(lambda v: apply_to_scalars(func, v), obj)) diff --git a/tests/core/test_device_lookup.py b/tests/core/test_device_lookup.py deleted file mode 100644 index 7bac1d908b..0000000000 --- a/tests/core/test_device_lookup.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import Any, Type -from unittest.mock import MagicMock - -import pytest - -from blueapi.core import BLUESKY_PROTOCOLS, create_bluesky_protocol_conversions - - -@pytest.mark.parametrize("a_type", BLUESKY_PROTOCOLS) -def test_creates_resolver_for(a_type: Type[Any]): - converters = create_bluesky_protocol_conversions(MagicMock()) - target_types = map(lambda c: c.target, converters) - assert a_type in list(target_types) diff --git a/tests/utils/test_schema.py b/tests/utils/test_schema.py deleted file mode 100644 index 48728d9d78..0000000000 --- a/tests/utils/test_schema.py +++ /dev/null @@ -1,39 +0,0 @@ -import dataclasses -from typing import Any - -import pytest - -from blueapi.utils import schema_for_func - - -def test_schema_generated() -> None: - def func(foo: int, bar: str = "hello") -> None: - ... - - schema = schema_for_func(func) - assert dataclasses.is_dataclass(schema) - foo, bar = dataclasses.fields(schema) - - assert foo.name == "foo" - assert foo.type == int - assert foo.default == dataclasses.MISSING - - assert bar.name == "bar" - assert bar.type == str - assert bar.default == "hello" - - -def test_rejects_any() -> None: - def func(foo: int, bar: Any) -> None: - ... - - with pytest.raises(TypeError): - schema_for_func(func) - - -def test_rejects_no_param() -> None: - def func(foo: int, bar) -> None: - ... 
- - with pytest.raises(TypeError): - schema_for_func(func) diff --git a/tests/utils/test_type_validator.py b/tests/utils/test_type_validator.py index a4b552fcef..4dc74388c2 100644 --- a/tests/utils/test_type_validator.py +++ b/tests/utils/test_type_validator.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Mapping, Set, Tuple, Type +from typing import Any, Dict, List, Literal, Mapping, Optional, Set, Tuple, Type, Union import pytest from pydantic import BaseConfig, BaseModel, Field, parse_obj_as @@ -49,6 +49,7 @@ def spec_wrapper(spec: Spec) -> None: class Bar(BaseModel): a: int b: ComplexObject + type: Literal["Bar"] = Field(default="Bar") class Config: arbitrary_types_allowed = True @@ -57,6 +58,15 @@ class Config: class Baz(BaseModel): obj: Bar c: str + type: Literal["Baz"] = Field(default="Baz") + + +class ComplexLinkedList(BaseModel): + obj: ComplexObject + child: Optional["ComplexLinkedList"] = None + + class Config: + arbitrary_types_allowed = True @dataclass(config=DefaultConfig) @@ -441,16 +451,34 @@ def test_validates_scanspec_wrapping_function(spec: Spec) -> None: assert parsed.spec == spec -def test_validates_scanspec_with_complex_axis() -> None: - spec = Line("x", 0.0, 10.0, 10) +def lookup_union(value: Union[int, str]) -> int: + if isinstance(value, str): + return lookup(value) + else: + return value + + +@pytest.mark.parametrize("value,expected", [(4, 4), ("b", 1)]) +def test_validates_union(value: Union[int, str], expected: int) -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(Union[int, str], lookup_union)], + fields={"un": (Union[int, str], Undefined)}, + config=DefaultConfig, + ) + parsed = parse_obj_as(model, {"un": value}) + assert parsed.un == expected # type: ignore + + +def test_validates_model_union() -> None: model = create_model_with_type_validators( "Foo", [TypeValidatorDefinition(ComplexObject, lookup_complex)], - fields={"spec": (Spec[ComplexObject], Undefined)}, + fields={"un": (Union[Bar, Baz], Field(..., discriminator="type"))}, config=DefaultConfig, ) - parsed = parse_obj_as(model, {"spec": spec.serialize()}) - assert parsed.spec.axes() == [ComplexObject("x")] # type: ignore + parsed = parse_obj_as(model, {"un": {"a": 5, "b": "g", "type": "Bar"}}) + assert parsed.un == Bar(a=5, b=ComplexObject("g")) # type: ignore def test_model_from_simple_function_signature() -> None: From 23c793ead4db1677722249b349c0ba9065a076a2 Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Fri, 14 Apr 2023 13:04:07 +0100 Subject: [PATCH 19/21] Implement workaround to make scans work --- src/blueapi/plans/plans.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/blueapi/plans/plans.py b/src/blueapi/plans/plans.py index c0ef93c879..be805f4380 100644 --- a/src/blueapi/plans/plans.py +++ b/src/blueapi/plans/plans.py @@ -12,17 +12,20 @@ def scan( detectors: List[Readable], - spec: Spec[Movable], + axes_to_move: Mapping[str, Movable], + spec: Spec[str], metadata: Optional[Mapping[str, Any]] = None, ) -> MsgGenerator: """ Scan wrapping `bp.scan_nd` Args: - detectors (List[Readable]): List of readable devices, will take a reading at + detectors: List of readable devices, will take a reading at each point - spec (Spec[Movable]): ScanSpec modelling the path of the scan - metadata (Optional[Mapping[str, Any]], optional): Key-value metadata to include + axes_to_move: All axes involved in this scan, names and + objects + spec: ScanSpec modelling the path of the scan + metadata: Key-value metadata to 
include in exported data, defaults to None.
@@ -40,24 +43,26 @@ def scan(
         **(metadata or {}),
     }
 
-    cycler = _scanspec_to_cycler(spec)
+    cycler = _scanspec_to_cycler(spec, axes_to_move)
     yield from bp.scan_nd(detectors, cycler, md=metadata)
 
 
-def _scanspec_to_cycler(spec: Spec) -> Cycler:
+def _scanspec_to_cycler(spec: Spec[str], axes: Mapping[str, Movable]) -> Cycler:
     """
     Convert a scanspec to a cycler for compatibility with legacy Bluesky plans such as
     `bp.scan_nd`. Use the midpoints of the scanspec since cyclers are normally used
     for software-triggered scans.
 
     Args:
-        spec (Spec): A scanspec
+        spec: A scanspec
+        axes: Mapping of axis names to movable devices
 
     Returns:
         Cycler: A new cycler
     """
 
     midpoints = spec.frames().midpoints
+    midpoints = {axes[name]: points for name, points in midpoints.items()}
 
     # Need to "add" the cyclers for all the axes together. The code below is
     # effectively: cycler(motor1, [...]) + cycler(motor2, [...]) + ...

From 88644e31b5ae16e49e50312344dbac60f1689796 Mon Sep 17 00:00:00 2001
From: Callum Forrester
Date: Fri, 14 Apr 2023 13:18:03 +0100
Subject: [PATCH 20/21] Fix mypy

---
 src/blueapi/core/bluesky_types.py   |  1 -
 src/blueapi/core/context.py         | 26 +++++++-------------------
 src/blueapi/plans/plans.py          |  2 +-
 src/blueapi/utils/type_validator.py | 27 +++------------------------
 tests/utils/test_type_validator.py  | 27 ++++++++++++++++++-------
 5 files changed, 31 insertions(+), 52 deletions(-)

diff --git a/src/blueapi/core/bluesky_types.py b/src/blueapi/core/bluesky_types.py
index ac23111298..aca9d64445 100644
--- a/src/blueapi/core/bluesky_types.py
+++ b/src/blueapi/core/bluesky_types.py
@@ -1,5 +1,4 @@
 import inspect
-from dataclasses import dataclass
 from typing import Any, Callable, Generator, Mapping, Type, Union
 
 from bluesky.protocols import (
diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py
index cf68285843..5d46506c1a 100644
--- a/src/blueapi/core/context.py
+++ b/src/blueapi/core/context.py
@@ -1,27 +1,13 @@
 import logging
 from dataclasses import dataclass, field
 from importlib import import_module
-from inspect import Parameter, signature
 from pathlib import Path
 from types import ModuleType
-from typing import (
-    Any,
-    Callable,
-    Deque,
-    Dict,
-    FrozenSet,
-    Iterable,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    Type,
-    Union,
-)
+from typing import Dict, Iterable, List, Optional, Union
 
 from bluesky import RunEngine
 from bluesky.protocols import Flyable, Readable
-from pydantic import BaseConfig, BaseModel, create_model, validator
+from pydantic import BaseConfig
 
 from blueapi.utils import (
     TypeValidatorDefinition,
@@ -35,7 +21,6 @@
     Plan,
     PlanGenerator,
     is_bluesky_compatible_device,
-    is_bluesky_compatible_device_type,
     is_bluesky_plan_generator,
 )
 from .device_lookup import find_component
@@ -132,7 +117,7 @@ def my_plan(a: int, b: str):
         if not is_bluesky_plan_generator(plan):
             raise TypeError(f"{plan} is not a valid plan generator function")
 
-        validators = device_validators(self)
+        validators = list(device_validators(self))
         model = create_model_with_type_validators(
             plan.__name__,
             validators,
@@ -173,7 +158,10 @@ def device(self, device: Device, name: Optional[str] = None) -> None:
 
 def device_validators(ctx: BlueskyContext) -> Iterable[TypeValidatorDefinition]:
     def get_device(name: str) -> Device:
-        return ctx.find_device(name)
+        device = ctx.find_device(name)
+        if device is None:
+            raise KeyError(f"Could not find a device named {name}")
+        return device
 
     for proto in BLUESKY_PROTOCOLS:
         yield TypeValidatorDefinition(proto, get_device)
diff --git 
a/src/blueapi/plans/plans.py b/src/blueapi/plans/plans.py index be805f4380..da4ff3858c 100644 --- a/src/blueapi/plans/plans.py +++ b/src/blueapi/plans/plans.py @@ -1,6 +1,6 @@ import operator from functools import reduce -from typing import Any, List, Mapping, Optional, Tuple, Type, Union +from typing import Any, List, Mapping, Optional, Union import bluesky.plans as bp from bluesky.protocols import Movable, Readable diff --git a/src/blueapi/utils/type_validator.py b/src/blueapi/utils/type_validator.py index 05be7829b1..60f6866b81 100644 --- a/src/blueapi/utils/type_validator.py +++ b/src/blueapi/utils/type_validator.py @@ -40,7 +40,7 @@ @dataclass -class TypeValidatorDefinition(Generic[T, U]): +class TypeValidatorDefinition(Generic[T]): """ Definition of a validator to be applied to all types during validation. @@ -51,7 +51,7 @@ class TypeValidatorDefinition(Generic[T, U]): """ field_type: Type[T] - func: Callable[[U], T] + func: Callable[[Any], T] def __str__(self) -> str: type_name = getattr( @@ -199,11 +199,6 @@ def apply_type_validators( if isclass(model_type) and issubclass(model_type, BaseModel): if "__root__" in model_type.__fields__: - # return create_model_with_type_validators( - # model_type.__name__, - # definitions, - # fields=_extract_fields_from_model(model_type), - # ) return apply_type_validators( model_type.__fields__["__root__"].type_, definitions, cache=cache ) @@ -229,7 +224,7 @@ def apply_type_validators( def _sanitise_origin(origin: Type) -> Type: - return { + return { # type: ignore list: List, set: Set, tuple: Tuple, @@ -317,22 +312,6 @@ def params_contains(type_to_check: Type, field_type: Type) -> bool: ) -# def params_of_type(type_to_check: Type) -> List[Type]: -# return list( -# getattr( -# type_to_check, -# "__args__", -# [], -# ) -# ) + list( -# getattr( -# type_to_check, -# "__parameters__", -# [], -# ) -# ) - - def apply_to_scalars(func: Callable[[T], U], obj: Any) -> Any: if is_list_type(obj): return list(map(lambda v: apply_to_scalars(func, v), obj)) diff --git a/tests/utils/test_type_validator.py b/tests/utils/test_type_validator.py index 4dc74388c2..db9303e4df 100644 --- a/tests/utils/test_type_validator.py +++ b/tests/utils/test_type_validator.py @@ -196,7 +196,15 @@ def test_validates_set_type() -> None: def test_validates_tuple_type() -> None: - assert_validates_single_type(Tuple[int, ...], ["a", "b", "c"], (0, 1, 2)) + assert_validates_single_type( + Tuple[int, ...], # type: ignore + [ + "a", + "b", + "c", + ], + (0, 1, 2), + ) def test_validates_nested_container_type() -> None: @@ -437,7 +445,7 @@ def test_validates_scanspec_wrapper(spec: Spec) -> None: fields={"wrapper": (SpecWrapper, Undefined)}, ) parsed = parse_obj_as(model, {"wrapper": {"spec": spec.serialize()}}) - assert parsed.wrapper.spec == spec + assert parsed.wrapper.spec == spec # type: ignore @pytest.mark.parametrize("spec", SPECS) @@ -448,7 +456,7 @@ def test_validates_scanspec_wrapping_function(spec: Spec) -> None: func=spec_wrapper, ) parsed = parse_obj_as(model, {"spec": spec.serialize()}) - assert parsed.spec == spec + assert parsed.spec == spec # type: ignore def lookup_union(value: Union[int, str]) -> int: @@ -462,8 +470,8 @@ def lookup_union(value: Union[int, str]) -> int: def test_validates_union(value: Union[int, str], expected: int) -> None: model = create_model_with_type_validators( "Foo", - [TypeValidatorDefinition(Union[int, str], lookup_union)], - fields={"un": (Union[int, str], Undefined)}, + [TypeValidatorDefinition(Union[int, str], lookup_union)], # type: ignore + 
fields={"un": (Union[int, str], Undefined)}, # type: ignore config=DefaultConfig, ) parsed = parse_obj_as(model, {"un": value}) @@ -473,8 +481,13 @@ def test_validates_union(value: Union[int, str], expected: int) -> None: def test_validates_model_union() -> None: model = create_model_with_type_validators( "Foo", - [TypeValidatorDefinition(ComplexObject, lookup_complex)], - fields={"un": (Union[Bar, Baz], Field(..., discriminator="type"))}, + [TypeValidatorDefinition(ComplexObject, lookup_complex)], # type: ignore + fields={ + "un": ( # type: ignore + Union[Bar, Baz], + Field(..., discriminator="type"), + ) + }, config=DefaultConfig, ) parsed = parse_obj_as(model, {"un": {"a": 5, "b": "g", "type": "Bar"}}) From 1cfbd0ee3b92cc25c76e69d37ec0f3c3b8e6947a Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Fri, 14 Apr 2023 13:22:14 +0100 Subject: [PATCH 21/21] Fix tests --- src/blueapi/utils/type_validator.py | 2 ++ tests/messaging/test_stomptemplate.py | 7 +++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/blueapi/utils/type_validator.py b/src/blueapi/utils/type_validator.py index 60f6866b81..461324c434 100644 --- a/src/blueapi/utils/type_validator.py +++ b/src/blueapi/utils/type_validator.py @@ -244,6 +244,8 @@ def _extract_fields_from_function(func: Callable[..., Any]) -> Fields: fields: Dict[str, FieldDefinition] = {} for name, param in signature(func).parameters.items(): type_annotation = param.annotation + if type_annotation is Parameter.empty: + raise TypeError(f"Missing type annotation for parameter {name}") default_value = param.default if default_value is Parameter.empty: default_value = Undefined diff --git a/tests/messaging/test_stomptemplate.py b/tests/messaging/test_stomptemplate.py index 35a7753826..8128386604 100644 --- a/tests/messaging/test_stomptemplate.py +++ b/tests/messaging/test_stomptemplate.py @@ -1,10 +1,10 @@ import itertools from concurrent.futures import Future -from dataclasses import dataclass from queue import Queue from typing import Any, Iterable, Type import pytest +from pydantic import BaseModel from blueapi.config import StompConfig from blueapi.messaging import MessageContext, MessagingTemplate, StompMessagingTemplate @@ -97,8 +97,7 @@ def server(ctx: MessageContext, message: str) -> None: assert reply == "ack" -@dataclass -class Foo: +class Foo(BaseModel): a: int b: str @@ -106,7 +105,7 @@ class Foo: @pytest.mark.stomp @pytest.mark.parametrize( "message,message_type", - [("test", str), (1, int), (Foo(1, "test"), Foo)], + [("test", str), (1, int), (Foo(a=1, b="test"), Foo)], ) def test_deserialization( template: MessagingTemplate, test_queue: str, message: Any, message_type: Type