diff --git a/examples/ml-multi-cloud/EXAMPLE_SIMPLE.md b/examples/ml-multi-cloud/EXAMPLE_SIMPLE.md index ce2684afd..d5bbdc074 100644 --- a/examples/ml-multi-cloud/EXAMPLE_SIMPLE.md +++ b/examples/ml-multi-cloud/EXAMPLE_SIMPLE.md @@ -243,10 +243,8 @@ The operation will end with information like: config: operation: module: - moduleClass: RandomWalk - moduleName: orchestrator.modules.operators.randomwalk - modulePath: . - moduleType: operation + operatorName: random_walk + operationType: search parameters: batchSize: 1 numberEntities: 48 diff --git a/examples/optimization_test_functions/README.md b/examples/optimization_test_functions/README.md index 6b3636a42..161087224 100644 --- a/examples/optimization_test_functions/README.md +++ b/examples/optimization_test_functions/README.md @@ -280,10 +280,8 @@ Operation: metadata: {} operation: module: - moduleClass: RayTune - moduleName: ado_ray_tune.operator - modulePath: . - moduleType: operation + operatorName: ray_tune + operationType: search parameters: tuneConfig: max_concurrent_trials: 2 diff --git a/examples/pfas-generative-models/operation_transformer_benchmark.yaml b/examples/pfas-generative-models/operation_transformer_benchmark.yaml index 5181fb2c9..ebfaa38c0 100644 --- a/examples/pfas-generative-models/operation_transformer_benchmark.yaml +++ b/examples/pfas-generative-models/operation_transformer_benchmark.yaml @@ -4,7 +4,8 @@ spaces: - 'space-abcdef-123456' # Replace with space id from space_transformer_with_objective.yaml operation: module: - moduleClass: "RandomWalk" + operatorName: "random_walk" + operationType: "search" parameters: numberEntities: 43000 batchSize: 1000 diff --git a/orchestrator/cli/resources/operation/template.py b/orchestrator/cli/resources/operation/template.py index da00da41f..5a82e5d46 100644 --- a/orchestrator/cli/resources/operation/template.py +++ b/orchestrator/cli/resources/operation/template.py @@ -21,7 +21,7 @@ DiscoveryOperationConfiguration, DiscoveryOperationEnum, 
DiscoveryOperationResourceConfiguration, - OperatorFunctionConf, + OperatorReference, ) @@ -103,12 +103,13 @@ def template_operation(parameters: AdoTemplateCommandParameters) -> None: # We are now sure that the operator_type is supported and # has a value - default_operation_parameters = operators[ - parameters.operator_type - ].default_configuration_model_for_operation(parameters.operator_name) + operator = operators[parameters.operator_type].operators.get( + parameters.operator_name + ) + default_operation_parameters = operator.example_configuration if operator else None # Certain operators may not have a default configuration model - # Use an OperatorFunctionConf and set the values we have + # Use an OperatorReference and set the values we have if not default_operation_parameters: console_print( @@ -119,7 +120,7 @@ def template_operation(parameters: AdoTemplateCommandParameters) -> None: default_operation_parameters = {} default_operation_configuration = DiscoveryOperationConfiguration( - module=OperatorFunctionConf( + module=OperatorReference( operatorName=parameters.operator_name, operationType=parameters.operator_type, ), @@ -181,5 +182,5 @@ def operator_type_has_operator( operator_name in orchestrator.modules.operators.collections.operationCollectionMap[ operator_type - ].function_operations + ].operators ) diff --git a/orchestrator/cli/resources/operator/get.py b/orchestrator/cli/resources/operator/get.py index fa9fc10dc..b3f9d842c 100644 --- a/orchestrator/cli/resources/operator/get.py +++ b/orchestrator/cli/resources/operator/get.py @@ -42,22 +42,46 @@ def get_operator(parameters: AdoGetCommandParameters) -> None: ) parameters.output_format = AdoGetSupportedOutputFormats.TABLE - # Build entries for DataFrame + # Handle NAME output format + if parameters.output_format == AdoGetSupportedOutputFormats.NAME: + + # Collect all operator names + operator_names = [] + for ( + collection + ) in 
orchestrator.modules.operators.collections.operationCollectionMap.values(): + operator_names.extend(collection.operators.keys()) + + if parameters.resource_id: + # Single operator: verify it exists and output its name + if parameters.resource_id not in operator_names: + console_print( + f"{ERROR}{parameters.resource_id} is not among the available operators.\n" + f"{HINT}Run {cyan('ado get operators')} to list them.", + stderr=True, + ) + raise typer.Exit(1) + console_print(parameters.resource_id) + else: + # Multiple operators: output all names + for operator_name in sorted(operator_names): + console_print(operator_name) + return + + # Build entries for TABLE format entries = [] for ( collection ) in orchestrator.modules.operators.collections.operationCollectionMap.values(): - for function_name in collection.function_operations: + for operator_name, operator in collection.operators.items(): entry = { - "OPERATOR": function_name, - "VERSION": collection.function_operation_versions.get( - function_name, "" - ), + "OPERATOR": operator_name, + "VERSION": operator.version, "TYPE": collection.type.value, } if parameters.show_details: entry["DESCRIPTION"] = normalize_and_truncate_at_period( - collection.function_operation_descriptions.get(function_name, "") + operator.description or "" ) entries.append(entry) diff --git a/orchestrator/core/operation/config.py b/orchestrator/core/operation/config.py index ddeba2b10..72ee90901 100644 --- a/orchestrator/core/operation/config.py +++ b/orchestrator/core/operation/config.py @@ -182,19 +182,115 @@ def operationType(self) -> DiscoveryOperationEnum: c: type[orchestrator.modules.operators.base.DiscoveryOperationBase] = ( load_module_class_or_function(self) ) - return c.operationType() + return c.operator_metadata().type @property def operatorIdentifier(self) -> str: c: type[orchestrator.modules.operators.base.DiscoveryOperationBase] = ( load_module_class_or_function(self) ) + return c.operator_metadata().operatorIdentifier - return 
c.operatorIdentifier() +class OperatorMetadata(pydantic.BaseModel): + """Registry metadata for a registered operator.""" + + name: Annotated[ + str, + pydantic.Field(description="Canonical name the operator is registered under."), + ] + function: Annotated[ + typing.Callable | None, + pydantic.Field( + description=( + "The callable implementing the operator. None when returned by " + "operator_metadata() before the decorator injects it." + ), + ), + ] = None + version: Annotated[ + str, + pydantic.Field( + description=( + "PEP 440 version string for the operator (e.g. '0.1.0', " + "'1.2.3.dev4+abc.dirty'). Validated on construction." + ), + ), + ] = "0.1.0" + + description: Annotated[ + str | None, + pydantic.Field( + description="Human-readable description of the operator.", + ), + ] = None + configuration_model: Annotated[ + type[pydantic.BaseModel], + pydantic.Field( + description="Pydantic model class used to validate operation parameters.", + ), + ] + example_configuration: Annotated[ + pydantic.BaseModel, + pydantic.Field( + description="Default instance of the configuration model.", + ), + ] + cls: Annotated[ + type | None, + pydantic.Field( + description=( + "For explore operators: the unwrapped Python class implementing the " + "operator. None for function-only operators. The concrete " + "modules/operators layer enforces that this is a DiscoveryOperationBase " + "subclass; config.py treats it as an opaque type to stay decoupled." + ), + ), + ] = None + type: Annotated[ + DiscoveryOperationEnum, + pydantic.Field( + description="The discovery operation type this operator belongs to." + ), + ] + + @pydantic.field_validator("version", mode="after") + @classmethod + def validate_version_is_pep440(cls, value: str) -> str: + """Validate that *version* is a valid PEP 440 version string. + + Args: + value: The version string to validate. + + Returns: + The original version string unchanged. 
+ + Raises: + ValueError: If *value* is not a valid PEP 440 version string. + """ + from packaging.version import InvalidVersion, Version + + try: + Version(value) + except InvalidVersion as exc: + raise ValueError( + f"Operator version {value!r} is not a valid PEP 440 version string: {exc}" + ) from exc + return value + + @property + def operatorIdentifier(self) -> str: + """Canonical identifier for this operator: ``{name}-{version}``.""" + return f"{self.name}-{self.version}" -class OperatorFunctionConf(pydantic.BaseModel): - """Describes an operator vended as a function""" + +class OperatorReference(pydantic.BaseModel): + """Identifies a registered operator by name and operation type. + + A lightweight reference used to look up an operator from the registry and + dispatch to its callable. Paired with :class:`OperatorMetadata`, which + holds the full operator metadata stored in the registry. + """ model_config = ConfigDict(extra="forbid") operationType: Annotated[ @@ -216,7 +312,7 @@ def validateOperatorExists(self) -> bool: if ( self.operatorName - not in operationCollectionMap[self.operationType].function_operations + not in operationCollectionMap[self.operationType].operators ): raise ValueError( f"Operator {self.operatorName} had no functions of type {self.operationType}" @@ -234,18 +330,64 @@ def operationFunction( self.operationType ] - return collection.function_operations.get(self.operatorName) + operator = collection.operators.get(self.operatorName) + return operator.function if operator else None @property def operatorIdentifier(self) -> str: + """Canonical identifier delegated to ``OperatorMetadata.operatorIdentifier``. + Returns: + ``"{operatorName}-{version}"`` as stored in the operator registry, + or ``"{operatorName}-None"`` if the operator is not yet registered. 
+ """ import orchestrator.modules.operators.collections collection = orchestrator.modules.operators.collections.operationCollectionMap[ self.operationType ] - return f"{self.operatorName}-{collection.function_operation_versions.get(self.operatorName)}" + operator = collection.operators.get(self.operatorName) + return operator.operatorIdentifier if operator else f"{self.operatorName}-None" + + +# --------------------------------------------------------------------------- +# Backwards-compatibility alias — use OperatorReference in new code +# --------------------------------------------------------------------------- + + +class OperatorFunctionConf(OperatorReference): + """Deprecated alias for :class:`OperatorReference`. + + .. deprecated:: + ``OperatorFunctionConf`` has been renamed to :class:`OperatorReference`. + Update imports and instantiation sites to use ``OperatorReference`` + directly. + """ + + @pydantic.model_validator(mode="wrap") + @classmethod + def _warn_deprecated( + cls, value: object, handler: pydantic.ValidatorFunctionWrapHandler + ) -> "OperatorFunctionConf": + """Emit a deprecation warning whenever OperatorFunctionConf is instantiated. + + Args: + value: The raw input value passed to the model. + handler: The pydantic validation handler. + + Returns: + The validated model instance. + """ + import warnings + + warnings.warn( + "OperatorFunctionConf has been renamed to OperatorReference. 
" + "Update your import to use OperatorReference instead.", + DeprecationWarning, + stacklevel=2, + ) + return handler(value) class DiscoveryOperationConfiguration(pydantic.BaseModel): @@ -254,7 +396,7 @@ class DiscoveryOperationConfiguration(pydantic.BaseModel): model_config = ConfigDict(extra="forbid") module: Annotated[ - OperatorModuleConf | OperatorFunctionConf, + OperatorModuleConf | OperatorReference, pydantic.Field( description="The module or function providing the discovery operation" ), @@ -270,8 +412,8 @@ class DiscoveryOperationConfiguration(pydantic.BaseModel): @pydantic.field_validator("module", mode="after") @classmethod def ensure_module_is_installed( - cls, module: OperatorModuleConf | OperatorFunctionConf - ) -> OperatorModuleConf | OperatorFunctionConf: + cls, module: OperatorModuleConf | OperatorReference + ) -> OperatorModuleConf | OperatorReference: """Validates that the operator module is installed and accessible. Args: @@ -283,7 +425,7 @@ def ensure_module_is_installed( Raises: ValueError: If the operator module is not installed or cannot be imported. """ - if isinstance(module, OperatorFunctionConf): + if isinstance(module, OperatorReference): return module import importlib @@ -299,10 +441,10 @@ def ensure_module_is_installed( @pydantic.model_validator(mode="after") def validate_and_downcast_parameters(self) -> Self: - """Validates and downcasts operation parameters based on the module type. + """Validates and downcasts operation parameters. For OperatorModuleConf modules, validates parameters using the operation's - validateOperationParameters method. For OperatorFunctionConf modules, + validateOperationParameters method. For OperatorReference modules, validates parameters against the configuration model if available. 
Returns: @@ -313,30 +455,26 @@ def validate_and_downcast_parameters(self) -> Self: """ if isinstance(self.module, OperatorModuleConf): # This is guaranteed to not raise an error thanks to ensure_module_is_installed - operation = getattr( + operator_class = getattr( importlib.import_module(self.module.moduleName), self.module.moduleClass ) - self.parameters = operation.validateOperationParameters(self.parameters) + operator_metadata = operator_class.operator_metadata() + self.parameters = operator_metadata.configuration_model.model_validate( + self.parameters + ) else: - import logging - from orchestrator.modules.operators.collections import ( operationCollectionMap, ) operation_type = self.module.operationType operator_name = self.module.operatorName - configuration_model = operationCollectionMap[ - operation_type - ].configuration_model_for_operation(operator_name) - - if configuration_model: - self.parameters = configuration_model.model_validate(self.parameters) - else: - logger = logging.getLogger(__file__) - logger.warning( - f"No configuration model was available for operation {operator_name} of type {operation_type}" - ) + operator_metadata = operationCollectionMap[operation_type].operators[ + operator_name + ] + self.parameters = operator_metadata.configuration_model.model_validate( + self.parameters + ) return self diff --git a/orchestrator/modules/actuators/registry.py b/orchestrator/modules/actuators/registry.py index b98353f59..de005f406 100644 --- a/orchestrator/modules/actuators/registry.py +++ b/orchestrator/modules/actuators/registry.py @@ -22,6 +22,7 @@ from orchestrator.schema.reference import ExperimentReference from orchestrator.utilities.distribution import distribution_from_module from orchestrator.utilities.logging import configure_logging +from orchestrator.utilities.ray import extract_base_class if typing.TYPE_CHECKING: import pandas as pd @@ -35,66 +36,6 @@ moduleLogger = logging.getLogger("registry") -def _extract_base_actuator_class( - 
actuator: typing.Any, # noqa: ANN401 -) -> "type[ActuatorBase]": - """Extract the base actuator class from a potentially Ray-decorated class. - - Args: - actuator: Either a Ray-decorated ActorClass instance or an undecorated - ActuatorBase subclass. - - Returns: - The undecorated base ActuatorBase subclass. - - Raises: - ValueError: If the actuator is a Ray ActorClass but the base class - cannot be extracted. - """ - from orchestrator.modules.actuators.base import ActuatorBase - - # First, check if this is already a regular class (not decorated) - try: - issubclass(actuator, ActuatorBase) - except TypeError: # actuator is an instance -> decorated - pass - else: - return actuator - - # Try to import Ray and check if it's an ActorClass - try: - import ray.actor - - if issubclass(actuator.__class__, ray.actor.ActorClass): - # It's a Ray-decorated class, extract the original class - # Ray stores the original class in __ray_actor_class__ - if hasattr(actuator, "__ray_actor_class__"): - original_class = actuator.__ray_actor_class__ - if isinstance(original_class, type) and issubclass( - original_class, ActuatorBase - ): - return original_class - - # Could not extract base class - raise ValueError( - f"Could not extract base ActuatorBase class from Ray ActorClass {actuator}. " - "The ActorClass does not expose the original class through expected attributes." 
- ) - except ImportError: - # Ray not available, fall through - pass - - # If we get here, it's neither a regular class nor a Ray ActorClass we can handle - # Check if it's an instance and raise a helpful error - if not isinstance(actuator, type): - raise TypeError( - f"Expected a class or Ray ActorClass, got instance of {type(actuator)}" - ) - - # It's a class but not an ActuatorBase subclass - raise TypeError(f"Expected ActuatorBase subclass, got {actuator}") - - class UnknownExperimentError(Exception): pass @@ -142,7 +83,7 @@ def __init__( import orchestrator.modules.actuators as builtin_actuators from orchestrator.modules.actuators.base import ActuatorBase, ActuatorModuleConf - # Mpass actuator ids to actuator configurations: G + # Maps actuator ids to generic actuator parameter payloads from configuration. self.actuatorConfigurationMap: dict[str, GenericActuatorParameters] = {} if actuator_configurations: self.actuatorConfigurationMap.update(actuator_configurations) @@ -258,7 +199,7 @@ def __init__( # registered the actuator self.log.debug(f"Add actuator plugin {actuator}") # Extract base class in case actuator_class is Ray-decorated - actuator_class = _extract_base_actuator_class(actuator_class) + actuator_class = extract_base_class(actuator_class, ActuatorBase) self.registerActuator( actuatorid=actuator_class.identifier, actuatorClass=actuator_class, diff --git a/orchestrator/modules/operators/_cleanup.py b/orchestrator/modules/operators/_cleanup.py index 031cf16de..46558370c 100644 --- a/orchestrator/modules/operators/_cleanup.py +++ b/orchestrator/modules/operators/_cleanup.py @@ -29,6 +29,7 @@ def handler(sig: int, frame: typing.Any | None) -> None: # noqa: ANN401 if shutdown_signal_received: moduleLog.info("Graceful shutdown already completed") + return shutdown_signal_received = True moduleLog.info("Calling cleanup callbacks") diff --git a/orchestrator/modules/operators/_explore_orchestration.py b/orchestrator/modules/operators/_explore_orchestration.py 
index 2bd341540..540fc9bfb 100644 --- a/orchestrator/modules/operators/_explore_orchestration.py +++ b/orchestrator/modules/operators/_explore_orchestration.py @@ -12,11 +12,10 @@ from orchestrator.core.discoveryspace.space import DiscoverySpace from orchestrator.core.operation.config import ( FunctionOperationInfo, - OperatorModuleConf, + OperatorMetadata, ) from orchestrator.core.operation.operation import OperationOutput from orchestrator.modules.actuators.measurement_queue import MeasurementQueue -from orchestrator.modules.module import load_module_class_or_function from orchestrator.modules.operators._cleanup import ( CLEANER_ACTOR, cleanup_callback_functions, @@ -49,7 +48,7 @@ def graceful_explore_operation_shutdown( identifier: str, operator: "OperatorActor", - state: "DiscoverySpaceManagerActor", + discovery_space_manager: "DiscoverySpaceManagerActor", actuators: list["ActuatorActor"], namespace: str, timeout: int = 60, @@ -71,7 +70,7 @@ def graceful_explore_operation_shutdown( ) as status: moduleLog.debug("Shutting down state") - ray.get(state.shutdown.remote()) + ray.get(discovery_space_manager.shutdown.remote()) status.update(f"Shutdown ({identifier}) - cleaning up custom actors") @@ -90,7 +89,7 @@ def graceful_explore_operation_shutdown( terminate_actor_waitables = [ operator.__ray_terminate__.remote(), - state.__ray_terminate__.remote(), + discovery_space_manager.__ray_terminate__.remote(), ] # __ray_terminate allows atexit handlers of actors to run # see https://docs.ray.io/en/latest/ray-core/api/doc/ray.kill.html @@ -100,7 +99,7 @@ def graceful_explore_operation_shutdown( n_actors = len(terminate_actor_waitables) moduleLog.debug(f"waiting for graceful shutdown of {n_actors} actors") - actors = [operator, state] + actors = [operator, discovery_space_manager] actors.extend(actuators) terminate_waitable_to_actor_lookup = dict( @@ -129,7 +128,7 @@ def graceful_explore_operation_shutdown( def run_explore_operation_core_closure( - operator: 
"OperatorActor", state: "DiscoverySpaceManagerActor" + operator: "OperatorActor", discovery_space_manager: "DiscoverySpaceManagerActor" ) -> typing.Callable[[], OperationOutput | None]: def _run_explore_operation_core() -> OperationOutput: @@ -142,10 +141,10 @@ def _run_explore_operation_core() -> OperationOutput: name="RichConsoleQueue", lifetime="detached", get_if_exists=True ).remote() - discovery_space = ray.get(state.discoverySpace.remote()) + discovery_space = ray.get(discovery_space_manager.discoverySpace.remote()) operation_id = ray.get(operator.operationIdentifier.remote()) - state.startMonitoring.remote() + discovery_space_manager.startMonitoring.remote() future = operator.run.remote() # Start the rich live updates @@ -163,12 +162,12 @@ def _run_explore_operation_core() -> OperationOutput: def orchestrate_explore_operation( - operator_module: OperatorModuleConf, + operator_metadata: OperatorMetadata, discovery_space: DiscoverySpace, parameters: dict, operation_info: FunctionOperationInfo, ) -> OperationOutput: - """Orchestrates an explore operation + """Orchestrates an explore operation. This function sets up and executes an explore (search) operation. It handles: - Initializing the resource cleaner @@ -182,7 +181,8 @@ def orchestrate_explore_operation( execute the operation, handle exceptions, and store the operation results. Params: - operator_module: Configuration for the operator module (class-based operation) + operator_metadata: Registered metadata for the operator, carrying the class, + configuration model, name, and type. 
discovery_space: The discovery space to operate on parameters: Dictionary of parameters for the operation operation_info: Information about the operation including metadata, actuator @@ -192,8 +192,8 @@ def orchestrate_explore_operation( OperationOutput containing the results and status of the operation Raises: - ValueError: If the MeasurementSpace is not consistent with EntitySpace or if - actuator configurations are invalid + ValueError: If the MeasurementSpace is not consistent with EntitySpace, + actuator configurations are invalid, or no operator class is registered pydantic.ValidationError: If the operation parameters are not valid OperationException: If there is an error during the operation ray.exceptions.ActorDiedError: If there was an error initializing the actuators @@ -206,9 +206,12 @@ def orchestrate_explore_operation( if not operation_info.ray_namespace: operation_info.ray_namespace = ( - f"{operator_module.moduleClass}-namespace-{str(uuid.uuid4())[:8]}" + f"{operator_metadata.name}-namespace-{str(uuid.uuid4())[:8]}" ) + # Validate parameters + operator_metadata.configuration_model.model_validate(parameters) + # Check the space if not discovery_space.measurementSpace.isConsistent: moduleLog.critical("Measurement space is inconsistent - aborting") @@ -251,7 +254,7 @@ def orchestrate_explore_operation( queue=measurement_queue, space=discovery_space, namespace=operation_info.ray_namespace, - ) # type: "InternalStateActor" + ) # type: DiscoverySpaceManagerActor moduleLog.debug( f"Waiting for discovery space manager to be ready: {discovery_space_manager}" ) @@ -262,20 +265,14 @@ def orchestrate_explore_operation( # OPERATOR # - # Validate the parameters for the operation - operator_class = load_module_class_or_function( - operator_module - ) # type: typing.Type["StateSubscribingDiscoveryOperation"] - operator_class.validateOperationParameters(parameters) - # Create operator actor operator = orchestrator.modules.operators.setup.setup_operator( - 
operator_module=operator_module, + operator_metadata=operator_metadata, parameters=parameters, discovery_space=discovery_space, actuators=actuators, namespace=operation_info.ray_namespace, - state=discovery_space_manager, + discovery_space_manager=discovery_space_manager, ) # type: "OperatorActor" identifier = ray.get(operator.operationIdentifier.remote()) @@ -292,7 +289,7 @@ def orchestrate_explore_operation( lambda: graceful_explore_operation_shutdown( identifier=identifier, operator=operator, - state=discovery_space_manager, + discovery_space_manager=discovery_space_manager, actuators=list(actuators.values()), namespace=operation_info.ray_namespace, ) @@ -311,17 +308,17 @@ def finalize_callback_closure( def finalize_callback(operation_resource: OperationResource) -> None: # Even on exception we can still get entities submitted - logging.debug("Finalize callback - Getting entities submitted") + moduleLog.debug("Finalize callback - Getting entities submitted") try: operation_resource.metadata["entities_submitted"] = ray.get( operator_actor.numberEntitiesSampled.remote(), timeout=10 ) - logging.debug("Finalize callback - Getting experiments requested") + moduleLog.debug("Finalize callback - Getting experiments requested") operation_resource.metadata["experiments_requested"] = ray.get( operator_actor.numberMeasurementsRequested.remote() ) except GetTimeoutError: - logging.warning( + moduleLog.warning( "Unable to retrieve entity/experiment submission data from operator" ) @@ -331,7 +328,7 @@ def finalize_callback(operation_resource: OperationResource) -> None: operation_output = _run_operation_harness( run_closure=explore_run_closure, discovery_space=discovery_space, - operator_module=operator_module, + operator_metadata=operator_metadata, operation_parameters=parameters, operation_info=operation_info, operation_identifier=identifier, @@ -344,7 +341,7 @@ def finalize_callback(operation_resource: OperationResource) -> None: graceful_explore_operation_shutdown( 
identifier=identifier, operator=operator, - state=discovery_space_manager, + discovery_space_manager=discovery_space_manager, actuators=list(actuators.values()), namespace=operation_info.ray_namespace, ) diff --git a/orchestrator/modules/operators/_general_orchestration.py b/orchestrator/modules/operators/_general_orchestration.py index 236110f05..a06f743a4 100644 --- a/orchestrator/modules/operators/_general_orchestration.py +++ b/orchestrator/modules/operators/_general_orchestration.py @@ -4,15 +4,10 @@ import logging import typing -import pydantic - -import orchestrator.core -import orchestrator.modules -import orchestrator.modules.operators._cleanup from orchestrator.core.discoveryspace.space import DiscoverySpace from orchestrator.core.operation.config import ( FunctionOperationInfo, - OperatorFunctionConf, + OperatorMetadata, get_actuator_configurations, validate_actuator_configurations_against_space_configuration, ) @@ -21,19 +16,13 @@ _run_operation_harness, log_space_details, ) +from orchestrator.modules.operators.base import OperatorFunction moduleLog = logging.getLogger("general_orchestration") def run_general_operation_core_closure( - operation_function: typing.Callable[ - [ - DiscoverySpace, - FunctionOperationInfo, - ..., - ], - OperationOutput | None, - ], + operation_function: OperatorFunction, discovery_space: DiscoverySpace, operationInfo: FunctionOperationInfo, operation_parameters: dict, @@ -48,25 +37,16 @@ def _run_general_operation_core() -> OperationOutput | None: def orchestrate_general_operation( - operator_function: typing.Callable[ - [ - DiscoverySpace, - FunctionOperationInfo, - ..., - ], - OperationOutput, - ], + operator_metadata: OperatorMetadata, operation_parameters: dict, - parameters_model: type[pydantic.BaseModel] | None, discovery_space: DiscoverySpace, operation_info: FunctionOperationInfo, - operation_type: orchestrator.core.operation.config.DiscoveryOperationEnum, ) -> OperationOutput: """Orchestrates a general operation 
(non-explore) This function handles the orchestration of non-explore operations (characterize, compare, modify, fuse, learn, etc.). It performs the following: - - Validates operation parameters if a parameters model is provided + - Validates operation parameters against the configuration model - Checks measurement space consistency - Validates actuator configurations against the space - Inserts graceful shutdown handler for keyboard interrupts @@ -75,15 +55,12 @@ def orchestrate_general_operation( execute the operation, handle exceptions, and stores the operation results. Params: - operator_function: The function that implements the operation. Must accept - DiscoverySpace and FunctionOperationInfo as first two arguments, followed - by operation-specific parameters + operator_metadata: Registered metadata for the operator, carrying the callable, + configuration model, operation type, and name. operation_parameters: Dictionary of parameters to pass to the operator function - parameters_model: Optional Pydantic model to validate operation_parameters against discovery_space: The discovery space to operate on operation_info: Information about the operation including metadata, actuator configuration identifiers, and namespace - operation_type: The type of operation being executed Returns: OperationOutput containing the results and status of the operation @@ -103,18 +80,18 @@ def orchestrate_general_operation( # for general operations it makes no difference # if a signal handler for SIGTERM is in place or not + if operator_metadata.function is None: + raise ValueError( + f"Operator '{operator_metadata.name}' has no function registered" + ) + operator_function = typing.cast("OperatorFunction", operator_metadata.function) + if not operation_info.ray_namespace: operation_info.ray_namespace = ( - f"{operator_function.__name__}-namespace-{str(uuid.uuid4())[:8]}" + f"{operator_metadata.name}-namespace-{str(uuid.uuid4())[:8]}" ) - operator_module = OperatorFunctionConf( - 
operatorName=operator_function.__name__, - operationType=operation_type, - ) - - if parameters_model: - parameters_model.model_validate(operation_parameters) + operator_metadata.configuration_model.model_validate(operation_parameters) # Check the space if not discovery_space.measurementSpace.isConsistent: @@ -144,7 +121,7 @@ def orchestrate_general_operation( return _run_operation_harness( run_closure=operation_run_closure, discovery_space=discovery_space, - operator_module=operator_module, + operator_metadata=operator_metadata, operation_parameters=operation_parameters, operation_info=operation_info, ) diff --git a/orchestrator/modules/operators/_orchestrate_core.py b/orchestrator/modules/operators/_orchestrate_core.py index 1fef8b0eb..9d702ceef 100644 --- a/orchestrator/modules/operators/_orchestrate_core.py +++ b/orchestrator/modules/operators/_orchestrate_core.py @@ -13,8 +13,8 @@ from orchestrator.core.discoveryspace.space import DiscoverySpace from orchestrator.core.operation.config import ( FunctionOperationInfo, - OperatorFunctionConf, - OperatorModuleConf, + OperatorMetadata, + OperatorReference, ) from orchestrator.core.operation.operation import OperationException, OperationOutput from orchestrator.core.operation.resource import ( @@ -46,7 +46,7 @@ def log_space_details(discovery_space: "DiscoverySpace") -> None: def _run_operation_harness( run_closure: typing.Callable[[], OperationOutput], discovery_space: DiscoverySpace, - operator_module: OperatorModuleConf | OperatorFunctionConf, + operator_metadata: OperatorMetadata, operation_parameters: dict, operation_info: FunctionOperationInfo, operation_identifier: str | None = None, @@ -61,7 +61,7 @@ def _run_operation_harness( Params: run_closure: Callable that executes the operation and returns OperationOutput discovery_space: The discovery space the operation is running on - operator_module: Configuration for the operator (either module or function-based) + operator_metadata: Metadata for the registered 
operator. operation_parameters: Dictionary of parameters for the operation operation_info: Information about the operation including metadata and actuator configs operation_identifier: Optional pre-existing identifier for the operation resource @@ -80,9 +80,13 @@ def _run_operation_harness( # Create and add OperationResource to metastore # + operator_reference = OperatorReference( + operatorName=operator_metadata.name, + operationType=operator_metadata.type, + ) operation_resource = create_operation_and_add_to_metastore( discovery_space=discovery_space, - operator_module=operator_module, + operator_module=operator_reference, operation_parameters=operation_parameters, metastore=discovery_space.metadataStore, operation_info=operation_info, @@ -103,7 +107,7 @@ def _run_operation_harness( operationStatus = OperationResourceStatus( event=OperationResourceEventEnum.FINISHED, exit_state=OperationExitStateEnum.ERROR, - message="Operation exited due uncaught exception)", + message="Operation exited due to uncaught exception)", ) try: operation_resource.status.append( @@ -178,18 +182,16 @@ def _run_operation_harness( sys.stdout.flush() if shutdown_signal_received: moduleLog.warning( - f"Operation {operation_identifier} exited normally but an external event e.g. SIGTERM, has already initiated shutdown" + f"Operation {operation_resource.identifier} exited normally but an external event e.g. SIGTERM, has already initiated shutdown" ) if operation_output: moduleLog.info("Operation returned output - will save") - operationStatus = ( - OperationResourceStatus( - event=OperationResourceEventEnum.FINISHED, - exit_state=OperationExitStateEnum.ERROR, - message="An external event e.g. SIGTERM, initiated shutdown. " - "This may have caused the operation to exit early", - ), + operationStatus = OperationResourceStatus( + event=OperationResourceEventEnum.FINISHED, + exit_state=OperationExitStateEnum.ERROR, + message="An external event e.g. SIGTERM, initiated shutdown. 
" + "This may have caused the operation to exit early", ) else: if not operation_output: @@ -202,7 +204,7 @@ def _run_operation_harness( ) else: moduleLog.debug( - f"Operation {operation_identifier} exited normally with status {operation_output.exitStatus}" + f"Operation {operation_resource.identifier} exited normally with status {operation_output.exitStatus}" ) finally: if operation_output: @@ -212,7 +214,7 @@ def _run_operation_harness( # Add it to metastore moduleLog.info( - f"Adding output for operation {operation_identifier} to metastore" + f"Adding output for operation {operation_resource.identifier} to metastore" ) add_operation_output_to_metastore( operation=operation_resource, diff --git a/orchestrator/modules/operators/base.py b/orchestrator/modules/operators/base.py index bf62ce2d1..3eae731b3 100644 --- a/orchestrator/modules/operators/base.py +++ b/orchestrator/modules/operators/base.py @@ -7,12 +7,11 @@ import contextlib import logging import typing +from typing import Protocol -import pydantic import ray import ray.exceptions -import orchestrator.core.metadata import orchestrator.core.operation.resource import orchestrator.core.resources import orchestrator.metastore.project @@ -21,18 +20,16 @@ import orchestrator.schema.reference from orchestrator.core.discoveryspace.space import DiscoverySpace from orchestrator.core.operation.config import ( - DiscoveryOperationEnum, DiscoveryOperationResourceConfiguration, FunctionOperationInfo, - OperatorFunctionConf, OperatorModuleConf, + OperatorReference, ) from orchestrator.core.operation.operation import OperationOutput from orchestrator.core.operation.resource import OperationResource from orchestrator.metastore.sqlstore import SQLStore from orchestrator.modules.actuators.measurement_queue import MeasurementQueue from orchestrator.modules.operators.discovery_space_manager import ( - DiscoverySpaceManager, DiscoverySpaceUpdateSubscriber, ) from orchestrator.schema.entity import Entity @@ -43,49 +40,174 @@ 
from orchestrator.metastore.base import ResourceStore from orchestrator.modules.actuators.base import ActuatorBase + from .discovery_space_manager import DiscoverySpaceManagerActor + moduleLog = logging.getLogger("operation_base") +class OperatorFunction(Protocol): + def __call__( + self, + discoverySpace: DiscoverySpace, + operationInfo: FunctionOperationInfo | None = None, + **kwargs: object, + ) -> OperationOutput: ... + + +def validate_operator_function_signature(fn: typing.Callable) -> None: + """Validate that *fn* conforms to the :class:`OperatorFunction` Protocol. + + Validates the positional parameter count against the Protocol, then + checks that every positional parameter and the return value carry a type + annotation whose type matches the Protocol. Type hints are mandatory: + an operator function without them cannot be verified to be correct, so + a ``ValueError`` is raised rather than silently skipping the check. + + If ``inspect.signature`` cannot be obtained for *fn* a ``ValueError`` is + raised, because conformance cannot be confirmed. A failure to inspect the + Protocol itself propagates as-is (it indicates a framework bug). + + Args: + fn: The callable to validate. + + Raises: + ValueError: If *fn* does not conform to the Protocol, if its + signature cannot be inspected, or if any required type hints are + absent or unresolvable. + """ + import inspect + + _HINTS_MISSING_HINT = ( + "Operator functions must declare the correct types for all parameters " + "and the return value to pass validation." + ) + + proto_sig = inspect.signature(OperatorFunction.__call__) + + try: + sig = inspect.signature(fn) + except (TypeError, ValueError) as exc: + raise ValueError( + f"Cannot validate operator function {fn!r}: signature could not " + f"be inspected ({exc})." 
+ ) from exc + + proto_positional = [ + p + for p in proto_sig.parameters.values() + if p.name != "self" + and p.kind + in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.POSITIONAL_ONLY, + ) + ] + actual_positional = [ + p + for p in sig.parameters.values() + if p.kind + in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.POSITIONAL_ONLY, + ) + ] + if len(actual_positional) < len(proto_positional): + missing = [p.name for p in proto_positional[len(actual_positional) :]] + raise ValueError( + f"Operator function is missing required positional parameter(s): " + f"{missing!r}." + ) + if len(actual_positional) > len(proto_positional): + extra = [p.name for p in actual_positional[len(proto_positional) :]] + raise ValueError( + f"Operator function has extra positional parameter(s) not in the " + f"Protocol: {extra!r}." + ) + + proto_hints = typing.get_type_hints(OperatorFunction.__call__) + + try: + hints = typing.get_type_hints(fn) + except Exception as exc: # noqa: BLE001 + raise ValueError( + f"Operator function has valid structure but type hints are missing " + f"or unresolvable ({exc}). {_HINTS_MISSING_HINT}" + ) from exc + + missing_hints: list[str] = [ + f"parameter {p.name!r}" for p in actual_positional if p.name not in hints + ] + if "return" not in hints: + missing_hints.append("return type") + if missing_hints: + raise ValueError( + f"Operator function has valid structure but type hints are missing " + f"for {missing_hints}. {_HINTS_MISSING_HINT}" + ) + + expected_return = proto_hints.get("return") + actual_return = hints["return"] + if expected_return is not None and actual_return != expected_return: + raise ValueError( + f"Operator function return type must be {expected_return!r}, " + f"got {actual_return!r}." 
+ ) + + for idx, (proto_param, actual_param) in enumerate( + zip(proto_positional, actual_positional, strict=True), start=1 + ): + expected_type = proto_hints.get(proto_param.name) + actual_type = hints[actual_param.name] + if expected_type is not None and actual_type != expected_type: + raise ValueError( + f"Operator function parameter {actual_param.name!r} " + f"(position {idx}) must be {expected_type!r}, " + f"got {actual_type!r}." + ) + + # Some operations are RayActors: These operations use Actuators and StateUpdateQueue and require Ray # Some operations are not RayActors: They don't have to use Actuators and StateUpdateQueue or Ray. They can use ray-workers class DiscoveryOperationBase(metaclass=abc.ABCMeta): - @abc.abstractmethod def operationIdentifier(self) -> str: - """A unique id for the operation instance being run by the operator - - should have form $operatorIdentifier-$version-$uid""" - - @classmethod - @abc.abstractmethod - def operatorIdentifier(cls) -> str: - """The identifier of this operator + """A unique id for the operation instance being run by the operator. - should have form method-version""" + Returns ``"-<8-hex-uuid>"``, stable for the lifetime of + the instance. Override when a deterministic or richer identifier is + needed (e.g. embedding the operator version or a caller-supplied run + id). + """ + import uuid - @classmethod - @abc.abstractmethod - def operationType(cls) -> DiscoveryOperationEnum: - """The type of operation this operator applies""" + return f"{type(self).__name__.lower()}-{uuid.uuid4().hex[:8]}" @classmethod @abc.abstractmethod - def defaultOperationParameters( + def operator_metadata( cls, - ) -> pydantic.BaseModel: - """A default pytdantic model for this operations parameters with this operator""" - - @classmethod - @abc.abstractmethod - def validateOperationParameters( - cls, - parameters: dict | pydantic.BaseModel, - ) -> pydantic.BaseModel: - """If the parameters are valid returns a model for them. 
- - Otherwise, will raise ValidationErrors""" + ) -> "orchestrator.core.operation.config.OperatorMetadata": + """Return :class:`~orchestrator.core.operation.config.OperatorMetadata` for this operator. + + Subclasses must override this method and return an + :class:`~orchestrator.core.operation.config.OperatorMetadata` instance + that describes the operator's name, version, type, and configuration + model. The ``@explore_operation`` decorator fills in the ``function`` + and ``cls`` fields before registering:: + + @classmethod + def operator_metadata(cls) -> OperatorMetadata: + return OperatorMetadata( + name="my_op", + version="0.1.0", + configuration_model=MyOpParameters, + example_configuration=MyOpParameters(), + type=DiscoveryOperationEnum.SEARCH, + ) + """ + raise NotImplementedError(f"{cls.__name__} must implement operator_metadata().") class UnaryDiscoveryOperation(metaclass=abc.ABCMeta): @@ -110,7 +232,49 @@ async def run(self) -> OperationOutput | None: pass -# Note: We need async and sync versions because depending on agent +class DiscoverySpaceSubscribingDiscoveryOperation( + DiscoveryOperationBase, + DiscoverySpaceUpdateSubscriber, + metaclass=abc.ABCMeta, +): + """Instances of this class can receive notifications when measurements are added to a DiscoverySpace""" + + def __init__( + self, + operationActorName: str, + namespace: str | None, + discovery_space_manager: "DiscoverySpaceManagerActor", + actuators: dict[str, "orchestrator.modules.actuators.base.ActuatorBase"], + ) -> None: + self.actorName = operationActorName + self.namespace = namespace + self.ds_manager = discovery_space_manager + self.actuators = actuators + # noinspection PyUnresolvedReferences + self.ds_manager.subscribeToUpdates.remote(subscriberName=self.actorName) + + super().__init__() + + +class Explore( + DiscoverySpaceSubscribingDiscoveryOperation, + UnaryDiscoveryOperation, + metaclass=abc.ABCMeta, +): + """Subclasses sample entities from a DiscoverySpace and run measurements on 
them + + The general pattern is that on calling run() the subclass + 1. Samples a set of entities from the discovery space + 2. Submits them for measurement using measure_or_replay() + 3. Waits for notifications that measurements are completed + 4. Returns to 1 or finishes + + Subclasses define different sampling strategies and submission strategies + e.g. wait for all measurements to complete before sending new batch or + submit new measurements as soon as one completes. + """ + + def measure_or_replay( requestIndex: int, requesterid: str, @@ -210,77 +374,6 @@ def measure_or_replay( return request_ids -class DiscoverySpaceSubscribingDiscoveryOperation( - DiscoveryOperationBase, - DiscoverySpaceUpdateSubscriber, - metaclass=abc.ABCMeta, -): - """Instances of this class can cause updates the state and receives details of updates via the StateUpdateSubscriber interface - - Instances of this class are RayActors and must run in Ray. - They work on a Ray wrapped instance of the DiscoveryState (models.actors.InternalState) - """ - - def __init__( - self, - operationActorName: str, - namespace: str | None, - state: DiscoverySpaceManager, - # Will actually be ray.actor.ActorHandle accessing InternalState - actuators: dict[str, "orchestrator.modules.actuators.base.ActuatorBase"], - params: dict | None = None, - metadata: orchestrator.core.metadata.ConfigurationMetadata | None = None, - ) -> None: - # Common code for StateSubscribingDiscoveryOperations - self.state = state - self.actorName = operationActorName - self.namespace = namespace - # noinspection PyUnresolvedReferences - self.state.subscribeToUpdates.remote(subscriberName=self.actorName) - - super().__init__() - - # async def run(self, discoveryState: orchestrator.model.actors.InternalState): - # - # pass - - -class Characterize( - DiscoverySpaceSubscribingDiscoveryOperation, - UnaryDiscoveryOperation, - metaclass=abc.ABCMeta, -): - pass - - -class Search( - DiscoverySpaceSubscribingDiscoveryOperation, - 
UnaryDiscoveryOperation, - metaclass=abc.ABCMeta, -): - pass - - -class Compare( - DiscoveryOperationBase, MultivariateDiscoveryOperation, metaclass=abc.ABCMeta -): - pass - - -class Modify(DiscoveryOperationBase, UnaryDiscoveryOperation, metaclass=abc.ABCMeta): - pass - - -class Fuse( - DiscoveryOperationBase, MultivariateDiscoveryOperation, metaclass=abc.ABCMeta -): - pass - - -class Learn(DiscoveryOperationBase, UnaryDiscoveryOperation, metaclass=abc.ABCMeta): - pass - - def add_operation_output_to_metastore( operation: "OperationResource", output: "OperationOutput", @@ -332,7 +425,7 @@ def add_operation_and_output_to_metastore( def create_operation_and_add_to_metastore( discovery_space: DiscoverySpace, - operator_module: OperatorModuleConf | OperatorFunctionConf, + operator_module: OperatorModuleConf | OperatorReference, operation_parameters: dict, operation_info: FunctionOperationInfo, metastore: SQLStore, diff --git a/orchestrator/modules/operators/collections.py b/orchestrator/modules/operators/collections.py index e31a1da55..5de2954e0 100644 --- a/orchestrator/modules/operators/collections.py +++ b/orchestrator/modules/operators/collections.py @@ -7,127 +7,104 @@ from typing import Annotated import pydantic -from pydantic import ConfigDict -import orchestrator.core.metadata +import orchestrator.core.operation.config from orchestrator.core.discoveryspace.space import DiscoverySpace from orchestrator.core.operation.config import ( DiscoveryOperationEnum, FunctionOperationInfo, + OperatorMetadata, +) +from orchestrator.modules.operators.base import ( + DiscoveryOperationBase, + DiscoverySpaceSubscribingDiscoveryOperation, + OperationOutput, + OperatorFunction, + validate_operator_function_signature, +) +from orchestrator.modules.operators.orchestrate import ( + orchestrate_explore_operation, + orchestrate_general_operation, ) -from orchestrator.modules.operators.base import DiscoveryOperationBase, OperationOutput -from 
orchestrator.modules.operators.orchestrate import orchestrate_general_operation moduleLog = logging.getLogger("operation_collections") -class OperationCollections(pydantic.BaseModel): - type: DiscoveryOperationEnum - function_operations: Annotated[ - dict[typing.AnyStr, typing.Callable], pydantic.Field(default_factory=dict) - ] - object_operations: Annotated[ - dict[typing.AnyStr, DiscoveryOperationBase], - pydantic.Field(default_factory=dict), - ] - function_operation_models: Annotated[ - dict[typing.AnyStr, type[pydantic.BaseModel]], - pydantic.Field(default_factory=dict), - ] - function_operation_model_defaults: Annotated[ - dict[typing.AnyStr, pydantic.BaseModel], pydantic.Field(default_factory=dict) - ] - function_operation_versions: Annotated[ - dict[typing.AnyStr, str], pydantic.Field(default_factory=dict) - ] - function_operation_descriptions: Annotated[ - dict[typing.AnyStr, str], pydantic.Field(default_factory=dict) - ] - - model_config = ConfigDict(arbitrary_types_allowed=True) - - def add_operation_function(self, name: str, fn: typing.Callable) -> None: - self.function_operations[name] = fn - - def add_operation_version(self, name: str, version: str) -> None: - self.function_operation_versions[name] = version - - def add_operation_description(self, name: str, version: str) -> None: - self.function_operation_descriptions[name] = version - - def add_operation_configuration_model( - self, name: str, model: type[pydantic.BaseModel] - ) -> None: - self.function_operation_models[name] = model - - def add_operation_configuration_model_default( - self, name: str, default: pydantic.BaseModel - ) -> None: - self.function_operation_model_defaults[name] = default +def _warn_if_operator_name_reused( + collection_label: str, name: str, operators: dict[str, OperatorMetadata] +) -> None: + """Log a warning when registering under a name that is already in use.""" + if name in operators: + moduleLog.warning( + "Operator %r is already registered in %s; replacing the 
existing entry", + name, + collection_label, + ) - def add_operation_object(self, name: str, object: DiscoveryOperationBase) -> None: - self.object_operations[name] = object - def list_operations(self) -> list: - return list(self.function_operations.keys()) + list( - self.object_operations.keys() - ) +class OperatorCollection(pydantic.BaseModel): + """A registry of operators of a single discovery operation type. - def configuration_model_for_operation(self, name: str) -> type[pydantic.BaseModel]: - if name not in self.function_operation_models: - raise ValueError(f"Unknown operator {name}") + Operators are added via the decorator functions (e.g. ``characterize_operation``, + ``explore_operation``). Each registered name maps to an + :class:`~orchestrator.core.operation.config.OperatorMetadata` instance that + carries the function, version, description, configuration model, + example configuration, and optional actor class. - return self.function_operation_models.get(name) + Attributes: + type: The discovery operation type all operators in this collection belong to. + operators: Mapping of operator name to :class:`~orchestrator.core.operation.config.OperatorMetadata` instance. + """ - def default_configuration_model_for_operation( - self, name: str - ) -> pydantic.BaseModel: - if name not in self.function_operation_models: - raise ValueError(f"Unknown operator {name}") + type: DiscoveryOperationEnum + operators: Annotated[ + dict[str, OperatorMetadata], pydantic.Field(default_factory=dict) + ] - return self.function_operation_model_defaults.get(name) + def list_operators(self) -> list[str]: + """Returns all registered operator names.""" + return list(self.operators.keys()) - def description_for_operation(self, name: str) -> str: - if name not in self.function_operation_models: - raise ValueError(f"Unknown operator {name}") + def __getattr__(self, item: str) -> OperatorFunction | None: + """Returns the operator function for the given registered name. 
- return self.function_operation_descriptions.get(name) + Args: + item: Registered operator name. - def __getattr__( - self, item: str - ) -> typing.Callable[..., object] | DiscoveryOperationBase: - if item in self.function_operations: - retval = self.function_operations[item] - elif item in self.object_operations: - retval = self.object_operations[item] - else: - raise AttributeError(f"Unknown attribute {item}") + Returns: + The callable registered under that name + Or None if no callable registered. - return retval + Raises: + AttributeError: If no operator is registered with that name. + """ + if item in self.operators: + return self.operators[item].function + raise AttributeError(f"Unknown attribute {item}") -characterize = OperationCollections( +characterize = OperatorCollection( type=orchestrator.core.operation.config.DiscoveryOperationEnum.CHARACTERIZE ) -explore = OperationCollections( +explore = OperatorCollection( type=orchestrator.core.operation.config.DiscoveryOperationEnum.SEARCH ) -modify = OperationCollections( +modify = OperatorCollection( type=orchestrator.core.operation.config.DiscoveryOperationEnum.MODIFY ) -export = OperationCollections( +export = OperatorCollection( type=orchestrator.core.operation.config.DiscoveryOperationEnum.EXPORT ) -compare = OperationCollections( +compare = OperatorCollection( type=orchestrator.core.operation.config.DiscoveryOperationEnum.COMPARE ) -fuse = OperationCollections( +fuse = OperatorCollection( type=orchestrator.core.operation.config.DiscoveryOperationEnum.FUSE ) -study = OperationCollections( +study = OperatorCollection( type=orchestrator.core.operation.config.DiscoveryOperationEnum.STUDY ) -learn = OperationCollections( +learn = OperatorCollection( type=orchestrator.core.operation.config.DiscoveryOperationEnum.LEARN ) operationCollectionMap = { @@ -142,178 +119,258 @@ def __getattr__( } # -# Decorators for registering operation functions +# Decorators for registering operator functions # -def 
register_characterize_operation( - func: typing.Callable[..., object], -) -> typing.Callable[ - [DiscoverySpace, FunctionOperationInfo | None, dict[str, dict]], OperationOutput -]: - @functools.wraps(func) - def characterize_operation_wrapper( - discoverySpace: DiscoverySpace, - operationInfo: FunctionOperationInfo | None = None, - **kwargs: dict, - ) -> OperationOutput: - - return orchestrate_general_operation( - operator_function=func, - operation_parameters=kwargs, - parameters_model=operationCollectionMap[ - DiscoveryOperationEnum.CHARACTERIZE - ].configuration_model_for_operation(func.__name__), - discovery_space=discoverySpace, - operation_info=operationInfo or FunctionOperationInfo(), - operation_type=orchestrator.core.operation.config.DiscoveryOperationEnum.CHARACTERIZE, - ) - - characterize.add_operation_function(func.__name__, characterize_operation_wrapper) - - return characterize_operation_wrapper - - def characterize_operation( name: str, + version: str, + configuration_model: type[pydantic.BaseModel], + configuration_model_default: pydantic.BaseModel, description: str | None = None, - version: str | None = "v0.1", - configuration_model: type[pydantic.BaseModel] | None = None, - configuration_model_default: pydantic.BaseModel | None = None, -) -> typing.Callable[ - [typing.Callable[..., object]], - typing.Callable[ - [DiscoverySpace, FunctionOperationInfo | None, dict[str, dict]], OperationOutput - ], -]: - characterize.add_operation_configuration_model(name, configuration_model) - characterize.add_operation_configuration_model_default( - name, configuration_model_default - ) - characterize.add_operation_version(name, version) - characterize.add_operation_description(name, description) - - return register_characterize_operation - - -def register_explore_operation( - func: typing.Callable[..., object], -) -> typing.Callable[..., object]: - """Registers a function that performs an explore operation on a DiscoverySpace""" +) -> 
typing.Callable[[OperatorFunction], OperatorFunction]: + """Decorator that registers a function as a characterize operation. + + Args: + name: Canonical operator name used in the registry. + version: Version string included in the operator identifier. + configuration_model: Pydantic model used to validate operation parameters. + configuration_model_default: Default parameter model instance. + description: Human-readable description shown in the registry. + + Returns: + A decorator that wraps and registers the decorated function under ``name``. + """ + + def _register(func: OperatorFunction) -> OperatorFunction: + @functools.wraps(func) + def wrapper( + discoverySpace: DiscoverySpace, + operationInfo: FunctionOperationInfo | None = None, + **kwargs: object, + ) -> OperationOutput: + return orchestrate_general_operation( + operator_metadata=characterize.operators[name], + operation_parameters=kwargs, + discovery_space=discoverySpace, + operation_info=operationInfo or FunctionOperationInfo(), + ) - # All explore operation function must call explore_operation_function_wrapper - # This function will validate params, create OperationResource, - # set up the necessary ray actors, run the operation etc. 
- explore.add_operation_function(func.__name__, func) + validate_operator_function_signature(wrapper) + wrapper = typing.cast("OperatorFunction", wrapper) + _warn_if_operator_name_reused("characterize", name, characterize.operators) + characterize.operators[name] = OperatorMetadata( + name=name, + function=wrapper, + version=version, + description=description, + configuration_model=configuration_model, + example_configuration=configuration_model_default, + type=DiscoveryOperationEnum.CHARACTERIZE, + ) + return wrapper - return func + return _register -def explore_operation( - name: str, - description: str | None = None, - configuration_model: type[pydantic.BaseModel] | None = None, - version: str | None = "v0.1", - configuration_model_default: pydantic.BaseModel | None = None, -) -> typing.Callable[[typing.Callable[..., object]], typing.Callable[..., object]]: - explore.add_operation_configuration_model(name, configuration_model) - explore.add_operation_configuration_model_default(name, configuration_model_default) - explore.add_operation_version(name, version) - explore.add_operation_description(name, description) +def _validate_explore_cls(t: type, metadata: OperatorMetadata) -> None: + """Validate a class-decorated explore operator and its metadata. - return register_explore_operation + Args: + t: The decorated class. + metadata: The :class:`~orchestrator.core.operation.config.OperatorMetadata` + returned by ``t.operator_metadata()``. + Raises: + TypeError: If ``t`` is not a + :class:`~orchestrator.modules.operators.base.DiscoverySpaceSubscribingDiscoveryOperation` + subclass, or if ``metadata.cls`` is set to a class other than ``t``. + """ + if not issubclass(t, DiscoverySpaceSubscribingDiscoveryOperation): + raise TypeError( + f"@explore_operation: {t.__name__} must be a subclass of " + "DiscoverySpaceSubscribingDiscoveryOperation (e.g. subclass " + "`Explore` or another discovery operation that subscribes to the space)." 
+ ) + if metadata.cls is not None and metadata.cls is not t: + raise TypeError( + f"@explore_operation on {t.__name__}: operator_metadata().cls is " + f"{metadata.cls!r} but the decorated class is {t!r}. " + "Leave cls as None in operator_metadata() — the decorator sets it." + ) -def register_modify_operation( - func: typing.Callable[..., object], -) -> typing.Callable[[typing.Callable[..., object]], OperationCollections]: - """Registers a function that modifies a discovery space to return a new discovery space""" - @functools.wraps(func) - def modify_operation_wrapper( +def explore_operation( + cls: "type[DiscoveryOperationBase]", +) -> "type[DiscoveryOperationBase]": + """Decorator that registers an explore (search) operator class. + + All metadata is sourced from the class's ``operator_metadata()`` + classmethod. The decorator generates an :class:`OperatorFunction`, + validates its signature, registers it in the explore collection, and + returns the **original class unchanged**:: + + @explore_operation + class MyOp(Explore): + @classmethod + def operator_metadata(cls) -> OperatorMetadata: + return OperatorMetadata( + name="my_op", + version="0.1.0", + configuration_model=MyOpParameters, + example_configuration=MyOpParameters(), + type=DiscoveryOperationEnum.SEARCH, + ) + + async def run(self) -> OperationOutput | None: ... + + The generated operator function is accessible via + ``explore.operators[name].function``; the class name continues to refer + to the class itself. + + Args: + cls: The operator class to register. Must be a subclass of + :class:`~orchestrator.modules.operators.base.DiscoverySpaceSubscribingDiscoveryOperation` + and must implement ``operator_metadata()``. + + Returns: + *cls* unchanged. + + Raises: + NotImplementedError: If ``cls.operator_metadata()`` is not implemented. + TypeError: If ``cls`` fails :func:`_validate_explore_cls`. 
+ """ + metadata = cls.operator_metadata() + _validate_explore_cls(cls, metadata) + op_name = metadata.name + + def _generated( discoverySpace: DiscoverySpace, operationInfo: FunctionOperationInfo | None = None, - **kwargs: dict, + **kwargs: object, ) -> OperationOutput: - - return orchestrate_general_operation( - operator_function=func, - operation_parameters=kwargs, - parameters_model=operationCollectionMap[ - DiscoveryOperationEnum.MODIFY - ].configuration_model_for_operation(func.__name__), + return orchestrate_explore_operation( + operator_metadata=explore.operators[op_name], discovery_space=discoverySpace, + parameters=kwargs, operation_info=operationInfo or FunctionOperationInfo(), - operation_type=orchestrator.core.operation.config.DiscoveryOperationEnum.MODIFY, ) - modify.add_operation_function(func.__name__, modify_operation_wrapper) - - return modify + _generated.__name__ = op_name + _generated.__qualname__ = op_name + validate_operator_function_signature(_generated) + _generated = typing.cast("OperatorFunction", _generated) + _warn_if_operator_name_reused("explore", op_name, explore.operators) + explore.operators[op_name] = metadata.model_copy( + update={ + "function": _generated, + "cls": cls, + } + ) + return cls def modify_operation( name: str, + version: str, + configuration_model: type[pydantic.BaseModel], + configuration_model_default: pydantic.BaseModel, description: str | None = None, - version: str | None = "v0.1", - configuration_model: type[pydantic.BaseModel] | None = None, - configuration_model_default: pydantic.BaseModel | None = None, -) -> typing.Callable[[typing.Callable[..., object]], OperationCollections]: - modify.add_operation_configuration_model(name, configuration_model) - modify.add_operation_configuration_model_default(name, configuration_model_default) - modify.add_operation_version(name, version) - modify.add_operation_description(name, description) - - return register_modify_operation - - -def register_export_operation( - 
func: typing.Callable[..., object], -) -> typing.Callable[ - [DiscoverySpace, FunctionOperationInfo | None, dict[str, dict]], OperationOutput -]: - """Registers a function that performs a lakehouse operation on a DiscoverySpace""" - - @functools.wraps(func) - def export_operation_wrapper( - discoverySpace: DiscoverySpace, - operationInfo: FunctionOperationInfo | None = None, - **kwargs: dict, - ) -> OperationOutput: - return orchestrate_general_operation( - operator_function=func, - operation_parameters=kwargs, - parameters_model=operationCollectionMap[ - DiscoveryOperationEnum.EXPORT - ].configuration_model_for_operation(func.__name__), - discovery_space=discoverySpace, - operation_info=operationInfo or FunctionOperationInfo(), - operation_type=orchestrator.core.operation.config.DiscoveryOperationEnum.EXPORT, - ) +) -> typing.Callable[[OperatorFunction], OperatorFunction]: + """Decorator that registers a function as a modify operation. + + Args: + name: Canonical operator name used in the registry. + version: Version string included in the operator identifier. + configuration_model: Pydantic model used to validate operation parameters. + configuration_model_default: Default parameter model instance. + description: Human-readable description shown in the registry. + + Returns: + A decorator that wraps and registers the decorated function under ``name``. 
+ """ + + def _register(func: OperatorFunction) -> OperatorFunction: + @functools.wraps(func) + def wrapper( + discoverySpace: DiscoverySpace, + operationInfo: FunctionOperationInfo | None = None, + **kwargs: object, + ) -> OperationOutput: + return orchestrate_general_operation( + operator_metadata=modify.operators[name], + operation_parameters=kwargs, + discovery_space=discoverySpace, + operation_info=operationInfo or FunctionOperationInfo(), + ) - export.add_operation_function(func.__name__, export_operation_wrapper) + validate_operator_function_signature(wrapper) + wrapper = typing.cast("OperatorFunction", wrapper) + _warn_if_operator_name_reused("modify", name, modify.operators) + modify.operators[name] = OperatorMetadata( + name=name, + function=wrapper, + version=version, + description=description, + configuration_model=configuration_model, + example_configuration=configuration_model_default, + type=DiscoveryOperationEnum.MODIFY, + ) + return wrapper - return export_operation_wrapper + return _register def export_operation( name: str, + version: str, + configuration_model: type[pydantic.BaseModel], + configuration_model_default: pydantic.BaseModel, description: str | None = None, - configuration_model: type[pydantic.BaseModel] | None = None, - version: str | None = "v0.1", - configuration_model_default: pydantic.BaseModel | None = None, -) -> typing.Callable[ - [typing.Callable[..., object]], - typing.Callable[ - [DiscoverySpace, FunctionOperationInfo | None, dict[str, dict]], OperationOutput - ], -]: - export.add_operation_configuration_model(name, configuration_model) - export.add_operation_configuration_model_default(name, configuration_model_default) - export.add_operation_version(name, version) - export.add_operation_description(name, description) - - return register_export_operation +) -> typing.Callable[[OperatorFunction], OperatorFunction]: + """Decorator that registers a function as an export operation. 
+ + Args: + name: Canonical operator name used in the registry. + version: Version string included in the operator identifier. + configuration_model: Pydantic model used to validate operation parameters. + configuration_model_default: Default parameter model instance. + description: Human-readable description shown in the registry. + + Returns: + A decorator that wraps and registers the decorated function under ``name``. + """ + + def _register(func: OperatorFunction) -> OperatorFunction: + @functools.wraps(func) + def wrapper( + discoverySpace: DiscoverySpace, + operationInfo: FunctionOperationInfo | None = None, + **kwargs: object, + ) -> OperationOutput: + return orchestrate_general_operation( + operator_metadata=export.operators[name], + operation_parameters=kwargs, + discovery_space=discoverySpace, + operation_info=operationInfo or FunctionOperationInfo(), + ) + + validate_operator_function_signature(wrapper) + wrapper = typing.cast("OperatorFunction", wrapper) + _warn_if_operator_name_reused("export", name, export.operators) + export.operators[name] = OperatorMetadata( + name=name, + function=wrapper, + version=version, + description=description, + configuration_model=configuration_model, + example_configuration=configuration_model_default, + type=DiscoveryOperationEnum.EXPORT, + ) + return wrapper + + return _register def load_operators() -> None: diff --git a/orchestrator/modules/operators/orchestrate.py b/orchestrator/modules/operators/orchestrate.py index 4b74eb2a4..d2d632dc6 100644 --- a/orchestrator/modules/operators/orchestrate.py +++ b/orchestrator/modules/operators/orchestrate.py @@ -15,6 +15,7 @@ from orchestrator.core.operation.config import ( DiscoveryOperationResourceConfiguration, FunctionOperationInfo, + OperatorModuleConf, ) from orchestrator.core.operation.operation import OperationException, OperationOutput from orchestrator.metastore.base import ResourceDoesNotExistError @@ -25,11 +26,12 @@ cleanup_callback_functions, 
def _check_if_using_unsupported_operator_module_conf(
    operation_resource_configuration: DiscoveryOperationResourceConfiguration,
) -> None:
    """Reject operation configurations that use the legacy module format.

    The ``operation.module`` field of the supplied configuration is inspected;
    if it is an ``OperatorModuleConf`` instance a warning is logged and a
    ``ValueError`` carrying the same text is raised. Any other kind of module
    reference is accepted silently.

    Args:
        operation_resource_configuration: The operation resource configuration
            whose ``operation.module`` field is checked.

    Raises:
        ValueError: If ``operation.module`` is an ``OperatorModuleConf``.
    """
    if not isinstance(
        operation_resource_configuration.operation.module, OperatorModuleConf
    ):
        return

    # Adjacent string literals are joined with no separator, so each fragment
    # must carry its own trailing space. The URL in particular has to be
    # followed by a space, otherwise the next word is fused onto the link and
    # the address no longer resolves.
    message = (
        "The supplied operation configuration uses an unsupported legacy format for the "
        "operation.module field: Use operatorName/operationType instead "
        "of moduleName/moduleClass. See "
        "https://ibm.github.io/ado/examples/random-walk/#exploring-the-discoveryspace "
        "for an example."
    )

    # Log before raising so the reason is visible in the module log even when
    # a caller further up handles the exception.
    moduleLog.warning(message)
    raise ValueError(message)
orchestrator.core.operation.operation import OperationOutput from orchestrator.modules.module import ( ModuleConf, ModuleTypeEnum, load_module_class_or_function, ) -from orchestrator.modules.operators.base import Characterize, measure_or_replay +from orchestrator.modules.operators.base import Explore, measure_or_replay from orchestrator.modules.operators.collections import explore_operation from orchestrator.modules.operators.discovery_space_manager import DiscoverySpaceManager -from orchestrator.modules.operators.orchestrate import orchestrate_explore_operation from orchestrator.schema.entity import Entity from orchestrator.schema.measurementspace import MeasurementSpace from orchestrator.schema.request import MeasurementRequest, MeasurementRequestStateEnum @@ -421,24 +418,10 @@ def __str__(self) -> str: ) -@ray.remote -class RandomWalk(Characterize): +@explore_operation +class RandomWalk(Explore): """Performs a random walk through a set of known entities in a space""" - @classmethod - def defaultOperationParameters( - cls, - ) -> RandomWalkParameters: - - return RandomWalkParameters() - - @classmethod - def validateOperationParameters( - cls, parameters: dict | pydantic.BaseModel - ) -> RandomWalkParameters: - - return RandomWalkParameters.model_validate(parameters) - @classmethod def description(cls) -> str: @@ -453,7 +436,7 @@ def __init__( self, operationActorName: str, namespace: str, - state: DiscoverySpaceManager, + discovery_space_manager: DiscoverySpaceManager, actuators: dict[str, "ActuatorBase"], params: dict | None = None, ) -> None: @@ -472,7 +455,6 @@ def __init__( self.criticalError = False self.update_queue = asyncio.queues.Queue() - self.actuators = actuators self._entitiesSampled = 0 self._experimentsRequested = 0 # Key is requestIndex, value is RequestRetry @@ -481,11 +463,10 @@ def __init__( # If this was not true we would need to use the entity id+requestIndex self._retriedExperimentRequests = {} # type: dict[int, RequestRetry] - # Sets state, 
actorName ivars and subscribes to the state super().__init__( operationActorName=operationActorName, namespace=namespace, - state=state, + discovery_space_manager=discovery_space_manager, actuators=actuators, ) @@ -511,16 +492,16 @@ async def run(self) -> None: f"Starting random walk. Sampler config is: {self.params.samplerConfig}" ) # noinspection PyUnresolvedReferences - measurement_queue = await self.state.measurement_queue.remote() + measurement_queue = await self.ds_manager.measurement_queue.remote() sampler = self.params.samplerConfig.sampler() self.log.debug(sampler) iterator = await sampler.remoteEntityIterator( - remoteDiscoverySpace=self.state, batchsize=1 + remoteDiscoverySpace=self.ds_manager, batchsize=1 ) # noinspection PyUnresolvedReferences - ds = await self.state.discoverySpace.remote() # type: DiscoverySpace + ds = await self.ds_manager.discoverySpace.remote() # type: DiscoverySpace measurement_space = ds.measurementSpace entity_space: EntitySpaceRepresentation | None = ds.entitySpace @@ -536,7 +517,7 @@ async def run(self) -> None: number_entities = entity_space.size except AttributeError as error: # noinspection PyUnresolvedReferences - self.state.unsubscribeFromUpdates.remote( + self.ds_manager.unsubscribeFromUpdates.remote( subscriberName=self.actorName ) raise ValueError( @@ -549,7 +530,7 @@ async def run(self) -> None: ) else: # noinspection PyUnresolvedReferences - self.state.unsubscribeFromUpdates.remote( + self.ds_manager.unsubscribeFromUpdates.remote( subscriberName=self.actorName ) raise ValueError( @@ -772,7 +753,7 @@ async def run(self) -> None: f"Was notified that {finished_requests} measurements had completed before error." 
) # noinspection PyUnresolvedReferences - self.state.unsubscribeFromUpdates.remote(subscriberName=self.actorName) + self.ds_manager.unsubscribeFromUpdates.remote(subscriberName=self.actorName) def _processCompletedMeasurement( self, @@ -965,54 +946,16 @@ def numberMeasurementsRequested(self) -> int: def operationIdentifier(self) -> str: - return f"{self.__class__.operatorIdentifier()}-{self.runid}" + return f"{self.__class__.operator_metadata().operatorIdentifier}-{self.runid}" @classmethod - def operatorIdentifier(cls) -> str: - - from importlib.metadata import version - - version = version("ado-core") - - return f"randomwalk-{version}" - - @classmethod - def operationType(cls) -> DiscoveryOperationEnum: - - return DiscoveryOperationEnum.SEARCH - - -@explore_operation( - name="random_walk", - description=RandomWalk.description(), - configuration_model=RandomWalkParameters, - configuration_model_default=RandomWalkParameters(), - version=version("ado-core"), -) -def random_walk( - discoverySpace: DiscoverySpace, - operationInfo: FunctionOperationInfo | None = None, - **kwargs: dict, -) -> OperationOutput: - """ - Performs a random_walk operation on a given discoverySpace - - """ - - import orchestrator.modules.module - - if operationInfo is None: - operationInfo = FunctionOperationInfo() - - module = orchestrator.core.operation.config.OperatorModuleConf( - moduleName="orchestrator.modules.operators.randomwalk", - moduleClass="RandomWalk", - moduleType=orchestrator.modules.module.ModuleTypeEnum.OPERATION, - ) - - return orchestrate_explore_operation( - discovery_space=discoverySpace, - operator_module=module, - parameters=kwargs, - operation_info=operationInfo, - ) + def operator_metadata(cls) -> OperatorMetadata: + """Returns operator metadata for the random_walk explore operator.""" + return OperatorMetadata( + name="random_walk", + version=version("ado-core"), + description=cls.description(), + configuration_model=RandomWalkParameters, + 
example_configuration=RandomWalkParameters(), + type=DiscoveryOperationEnum.SEARCH, + ) diff --git a/orchestrator/modules/operators/setup.py b/orchestrator/modules/operators/setup.py index b9954332f..25ac7f69e 100644 --- a/orchestrator/modules/operators/setup.py +++ b/orchestrator/modules/operators/setup.py @@ -10,12 +10,12 @@ from orchestrator.core.discoveryspace.space import DiscoverySpace from orchestrator.core.operation.config import ( DiscoveryOperationConfiguration, - OperatorModuleConf, + OperatorMetadata, + OperatorReference, get_actuator_configurations, validate_actuator_configurations_against_space_configuration, ) from orchestrator.modules.actuators.measurement_queue import MeasurementQueue -from orchestrator.modules.module import load_module_class_or_function from orchestrator.utilities.logging import configure_logging if typing.TYPE_CHECKING: @@ -134,51 +134,72 @@ def setup_actuators( def setup_operator( - operator_module: OperatorModuleConf, + operator_metadata: OperatorMetadata, parameters: dict, discovery_space: DiscoverySpace, namespace: str, - state: "DiscoverySpaceManagerActor", + discovery_space_manager: "DiscoverySpaceManagerActor", actuators: dict, ) -> "OperatorActor": - """Sets up and creates an operator actor for class-based operations + """Sets up and creates an operator actor for class-based explore operations. - This function loads the operator class, creates a Ray actor instance with the - specified namespace, and initializes it with the provided parameters, state, - and actuators. + Instantiates the operator class from ``operator_metadata`` as a Ray actor. Params: - operator_module: Configuration for the operator module to load + operator_metadata: Registered metadata for the operator, carrying the class + and canonical name. 
parameters: Dictionary of parameters to pass to the operator discovery_space: The discovery space the operator will operate on namespace: Ray namespace to create the operator actor in - state: DiscoverySpaceManager actor handle for state management + discovery_space_manager: DiscoverySpaceManager actor handle actuators: Dictionary of actuator actor handles keyed by actuator identifier Returns: OperatorActor handle for the created operator actor + + Raises: + ValueError: If ``operator_metadata.cls`` is None (operator was not registered + via the class decorator). """ + import ray + import orchestrator.utilities.output moduleLog.info("Creating operation") - operatorClass = load_module_class_or_function(operator_module) - operator = operatorClass.options( - name=operator_module.moduleClass, namespace=namespace - ).remote( - operationActorName=operator_module.moduleClass, - namespace=namespace, - state=state, - params=parameters, - actuators=actuators, + if operator_metadata.cls is None: + raise ValueError( + f"No operator class registered for '{operator_metadata.name}'. " + "Ensure the operator is decorated with @explore_operation." 
def extract_base_class(
    obj: typing.Any,  # noqa: ANN401
    base_class: type[T],
) -> type[T]:
    """Return the plain Python class behind ``obj``.

    ``obj`` may be an ordinary class or the Ray ``ActorClass`` object obtained
    by decorating a class with ``@ray.remote``. In both cases the undecorated
    class is returned, provided it derives from ``base_class``. Having the
    undecorated class matters when ``ray.remote`` is applied dynamically,
    because it must not be applied to an already-decorated class.

    Args:
        obj: An undecorated subclass of ``base_class`` or a Ray ``ActorClass``
            instance wrapping one.
        base_class: The class the result is required to derive from.

    Returns:
        The undecorated class, which is a subclass of ``base_class``.

    Raises:
        ValueError: If ``obj`` is a Ray ``ActorClass`` whose
            ``__ray_actor_class__`` attribute is absent or is not a subclass
            of ``base_class``.
        TypeError: If ``obj`` is neither a type nor a Ray ``ActorClass``, or
            is a type that does not derive from ``base_class``.
    """

    def _is_wanted(candidate: object) -> bool:
        # True only for real classes that derive from the requested base.
        return isinstance(candidate, type) and issubclass(candidate, base_class)

    # Undecorated classes need no unwrapping.
    if _is_wanted(obj):
        return obj  # type: ignore[return-value]

    # Ray is optional: when it cannot be imported the ActorClass branch is
    # skipped and the generic type errors below apply.
    try:
        import ray.actor as ray_actor

        if isinstance(obj, ray_actor.ActorClass):
            wrapped = getattr(obj, "__ray_actor_class__", None)
            if not _is_wanted(wrapped):
                raise ValueError(
                    f"Could not extract {base_class.__name__} from Ray ActorClass {obj}: "
                    "__ray_actor_class__ is missing or is not a subclass of "
                    f"{base_class.__name__}."
                )
            return wrapped  # type: ignore[return-value]
    except ImportError:
        pass

    # Nothing matched: report whether the caller passed a wrong class or
    # something that is not a class at all.
    if isinstance(obj, type):
        detail = f"Expected a subclass of {base_class.__name__}, got {obj!r}."
    else:
        detail = (
            f"Expected a {base_class.__name__} subclass or a Ray ActorClass, "
            f"got instance of {type(obj).__name__}."
        )
    raise TypeError(detail)
Note only one spaceid allowed operation: module: - moduleClass: RandomWalk - moduleName: orchestrator.modules.operators.randomwalk + operatorName: random_walk + operationType: search parameters: numberEntities: 10 batchSize: 2 diff --git a/plugins/operators/profile_space/profile_space/operator.py b/plugins/operators/profile_space/profile_space/operator.py index 0bc8f919b..2384d9db8 100644 --- a/plugins/operators/profile_space/profile_space/operator.py +++ b/plugins/operators/profile_space/profile_space/operator.py @@ -3,6 +3,7 @@ from importlib.metadata import version import pandas as pd +import pydantic from orchestrator.core.discoveryspace.space import DiscoverySpace from orchestrator.core.operation.config import FunctionOperationInfo @@ -10,14 +11,18 @@ from orchestrator.modules.operators.collections import characterize_operation +class ProfileParameters(pydantic.BaseModel): + """Parameters for the profile operator (no configurable options).""" + + # See https://ibm.github.io/ado/operators/creating-operators/#ado-operator-functions # for documentation on the decorator and its parameters @characterize_operation( name="profile", - configuration_model=None, # You can use this field to define the option of your operator if any - see https://ibm.github.io/ado/operators/creating-operators/#describing-your-operation-input-parameters - configuration_model_default=None, # Use this field to provide default/example values for your operator - description="Returns a ydata_profiling ProfileReport for the space", version=version("ado-core"), + configuration_model=ProfileParameters, + configuration_model_default=ProfileParameters(), + description="Returns a ydata_profiling ProfileReport for the space", ) # operator function can have any name but have similar parameters - see https://ibm.github.io/ado/operators/creating-operators/#operator-function-parameters def profile( diff --git a/plugins/operators/ray_tune/ado_ray_tune/operator.py 
b/plugins/operators/ray_tune/ado_ray_tune/operator.py index bb4e33e67..63fd7531c 100644 --- a/plugins/operators/ray_tune/ado_ray_tune/operator.py +++ b/plugins/operators/ray_tune/ado_ray_tune/operator.py @@ -25,6 +25,7 @@ ) from orchestrator.core.operation.config import ( DiscoveryOperationEnum, + OperatorMetadata, ) from orchestrator.core.operation.operation import OperationOutput from orchestrator.core.operation.resource import ( @@ -34,9 +35,10 @@ ) from orchestrator.modules.actuators.measurement_queue import MeasurementQueue from orchestrator.modules.operators.base import ( - Search, + Explore, measure_or_replay, ) +from orchestrator.modules.operators.collections import explore_operation from orchestrator.modules.operators.discovery_space_manager import DiscoverySpaceManager from orchestrator.schema.domain import PropertyDomain from orchestrator.schema.entity import ( @@ -761,27 +763,10 @@ def search_space_from_explicit_entity_space( return space -@ray.remote -class RayTune(Search): +@explore_operation +class RayTune(Explore): """Uses raytune optimization algorithm to search through entities in a space""" - @classmethod - def defaultOperationParameters( - cls, - ) -> RayTuneConfiguration: - return RayTuneConfiguration( - tuneConfig=OrchTuneConfig( - metric="wallclock_time", search_alg=OrchSearchAlgorithm(name="bayesopt") - ), - runtimeConfig=OrchRunConfig(), - ) - - @classmethod - def validateOperationParameters( - cls, parameters: dict | pydantic.BaseModel - ) -> RayTuneConfiguration: - return RayTuneConfiguration.model_validate(parameters) - @classmethod def description(cls) -> str: return """RayTune provides capabilities for sampling points in an entity space and applying @@ -797,7 +782,7 @@ def __init__( self, operationActorName: str, namespace: str, - state: DiscoverySpaceManager, + discovery_space_manager: DiscoverySpaceManager, actuators: dict[str, "orchestrator.modules.actuators.base.ActuatorBase"], params: dict | None = None, ) -> None: @@ -814,18 
+799,16 @@ def __init__( self.params = RayTuneConfiguration(**params) - self.actuators = actuators self._entitiesSubmitted = 0 self._finishedMeasurements = {} self._requestIndex = 0 self.received_critical_error_notification = False self.criticalError = None # Will store the critical error if we receive one - # Sets state, actorName ivars and subscribes to the state super().__init__( operationActorName=operationActorName, namespace=namespace, - state=state, + discovery_space_manager=discovery_space_manager, actuators=actuators, ) @@ -858,11 +841,11 @@ async def run(self) -> OperationOutput: try: # noinspection PyUnresolvedReferences entity_space = ( - await self.state.entitySpace.remote() + await self.ds_manager.entitySpace.remote() ) # type: EntitySpaceRepresentation # noinspection PyUnresolvedReferences measurement_space = ( - await self.state.measurementSpace.remote() + await self.ds_manager.measurementSpace.remote() ) # type: MeasurementSpace metric_or_metrics = self.params.tuneConfig.metric @@ -895,7 +878,7 @@ async def run(self) -> OperationOutput: measurement_space=measurement_space, entity_space=entity_space, actuators=self.actuators, - state=self.state, + state=self.ds_manager, orchestrator_config=self.params.orchestratorConfig, target_metric=self.params.tuneConfig.metric, debugging=False, @@ -984,7 +967,7 @@ async def run(self) -> OperationOutput: ) # noinspection PyUnresolvedReferences - self.state.unsubscribeFromUpdates.remote(subscriberName=self.actorName) + self.ds_manager.unsubscribeFromUpdates.remote(subscriberName=self.actorName) return operation_output @@ -996,16 +979,24 @@ def numberMeasurementsRequested(self) -> int: return self._requestIndex def operationIdentifier(self) -> str: - return f"{self.__class__.operatorIdentifier()}-{self.params.tuneConfig.search_alg.name}-{self.runid}" + return f"{self.__class__.operator_metadata().operatorIdentifier}-{self.params.tuneConfig.search_alg.name}-{self.runid}" @classmethod - def operatorIdentifier(cls) -> 
str: + def operator_metadata(cls) -> OperatorMetadata: + """Returns operator metadata for the ray_tune explore operator.""" from importlib.metadata import version - version = version("ado-ray-tune") - - return f"raytune-{version}" - - @classmethod - def operationType(cls) -> DiscoveryOperationEnum: - return DiscoveryOperationEnum.SEARCH + return OperatorMetadata( + name="ray_tune", + version=version("ado-ray-tune"), + description=cls.description(), + configuration_model=RayTuneConfiguration, + example_configuration=RayTuneConfiguration( + tuneConfig=OrchTuneConfig( + metric="wallclock_time", + search_alg=OrchSearchAlgorithm(name="bayesopt"), + ), + runtimeConfig=OrchRunConfig(), + ), + type=DiscoveryOperationEnum.SEARCH, + ) diff --git a/plugins/operators/ray_tune/ado_ray_tune/operator_function.py b/plugins/operators/ray_tune/ado_ray_tune/operator_function.py deleted file mode 100644 index b60ba2133..000000000 --- a/plugins/operators/ray_tune/ado_ray_tune/operator_function.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright IBM Corporation 2025, 2026 -# SPDX-License-Identifier: MIT -from importlib.metadata import version - -import orchestrator.core -import orchestrator.modules.module -from orchestrator.core.discoveryspace.space import DiscoverySpace -from orchestrator.core.operation.config import FunctionOperationInfo -from orchestrator.core.operation.operation import OperationOutput -from orchestrator.modules.operators.collections import explore_operation -from orchestrator.modules.operators.orchestrate import ( - orchestrate_explore_operation, -) - -from .config import RayTuneConfiguration -from .operator import RayTune - - -@explore_operation( - name="ray_tune", - description=RayTune.description(), - configuration_model=RayTuneConfiguration, - configuration_model_default=RayTune.defaultOperationParameters(), - version=version("ado-ray-tune"), -) -def ray_tune( - discoverySpace: DiscoverySpace, - operationInfo: FunctionOperationInfo | None = None, - **kwargs: dict, -) 
-> OperationOutput: - """ - Performs a random_walk operation on a given discoverySpace - - """ - if operationInfo is None: - operationInfo = FunctionOperationInfo() - - module = orchestrator.core.operation.config.OperatorModuleConf( - moduleName="ado_ray_tune.operator", - moduleClass="RayTune", - moduleType=orchestrator.modules.module.ModuleTypeEnum.OPERATION, - ) - - # validate parameters - RayTuneConfiguration.model_validate(kwargs) - - return orchestrate_explore_operation( - discovery_space=discoverySpace, - operator_module=module, - parameters=kwargs, - operation_info=operationInfo, - ) diff --git a/plugins/operators/ray_tune/pyproject.toml b/plugins/operators/ray_tune/pyproject.toml index 809e6f7db..837a9887d 100644 --- a/plugins/operators/ray_tune/pyproject.toml +++ b/plugins/operators/ray_tune/pyproject.toml @@ -17,7 +17,7 @@ dependencies = [ dynamic = ["version"] [project.entry-points."ado.operators"] -ado-ray-tune = "ado_ray_tune.operator_function" +ado-ray-tune = "ado_ray_tune.operator" rifferla = "ado_ray_tune.rifferla" [build-system] diff --git a/tests/actuators/test_registry_extraction.py b/tests/actuators/test_registry_extraction.py index 45e980106..9f8f325f4 100644 --- a/tests/actuators/test_registry_extraction.py +++ b/tests/actuators/test_registry_extraction.py @@ -12,7 +12,7 @@ def test_extract_base_class_from_undecorated_actuator() -> None: """Test that extraction works with undecorated classes (returns as-is).""" - from orchestrator.modules.actuators.registry import _extract_base_actuator_class + from orchestrator.utilities.ray import extract_base_class class TestActuator(ActuatorBase): # noqa: ANN001, ANN202, ANN206 identifier = "test_undecorated" @@ -33,7 +33,7 @@ def catalog(cls, actuator_configuration=None): # noqa: ANN001, ANN206 return ExperimentCatalog(identifier=cls.identifier, experiments=[]) # Extract from undecorated class (should return the same class) - extracted_class = _extract_base_actuator_class(TestActuator) + extracted_class = 
extract_base_class(TestActuator, ActuatorBase) # Should return the same class assert extracted_class is TestActuator @@ -49,12 +49,12 @@ def test_extract_base_class_from_undecorated_actuator_in_codebase() -> None: class as-is for undecorated actuators. """ from orchestrator.modules.actuators import custom_experiments - from orchestrator.modules.actuators.registry import _extract_base_actuator_class + from orchestrator.utilities.ray import extract_base_class actuator_class = custom_experiments.CustomExperiments # Extraction should return the same class for undecorated actuators - extracted = _extract_base_actuator_class(actuator_class) + extracted = extract_base_class(actuator_class, ActuatorBase) assert extracted is actuator_class assert issubclass(extracted, ActuatorBase) assert extracted.identifier == "custom_experiments" diff --git a/tests/core/test_operation.py b/tests/core/test_operation.py index 5d379138c..52cc995ba 100644 --- a/tests/core/test_operation.py +++ b/tests/core/test_operation.py @@ -144,7 +144,9 @@ def test_operation_config_file_valid(valid_operation_config_file: str) -> None: moduleClass = load_module_class_or_function( module ) # type: "orchestrator.modules.operators.base.DiscoveryOperationBase" - moduleClass.validateOperationParameters(parameters=op_cfg.parameters) + meta = moduleClass.operator_metadata() + if meta.configuration_model is not None: + meta.configuration_model.model_validate(op_cfg.parameters) def test_set_manual_operation_identifier( diff --git a/tests/operators/test_operators.py b/tests/operators/test_operators.py index 8319ebe17..764e0e1c0 100644 --- a/tests/operators/test_operators.py +++ b/tests/operators/test_operators.py @@ -1,6 +1,7 @@ # Copyright IBM Corporation 2025, 2026 # SPDX-License-Identifier: MIT import itertools +import logging import re import typing @@ -35,34 +36,12 @@ RandomWalk, RandomWalkParameters, SamplerModuleConf, - random_walk, ) -def test_randomwalk_class_methods() -> None: - - import 
orchestrator.metastore.project - - assert ( - RandomWalk.operationType() - == orchestrator.core.operation.config.DiscoveryOperationEnum.SEARCH - ) - assert RandomWalk.operatorIdentifier().split("-")[0] == "randomwalk" - - -def test_raytune_class_methods() -> None: - import orchestrator.metastore.project - - assert ( - RayTune.operationType() - == orchestrator.core.operation.config.DiscoveryOperationEnum.SEARCH - ) - assert RayTune.operatorIdentifier().split("-")[0] == "raytune" - - def test_operator_function_conf() -> None: - function = orchestrator.core.operation.config.OperatorFunctionConf( + function = orchestrator.core.operation.config.OperatorReference( operationType=orchestrator.core.operation.config.DiscoveryOperationEnum.MODIFY, operatorName="rifferla", ) @@ -81,42 +60,27 @@ def test_operator_module_conf( operator_module_conf: orchestrator.core.operation.config.OperatorModuleConf, ) -> None: - assert ( - operator_module_conf.operationType - == orchestrator.core.operation.config.DiscoveryOperationEnum.SEARCH - ) - assert ( - operator_module_conf.operatorIdentifier.split("-")[0] - == operator_module_conf.moduleClass.lower() - ) - + from orchestrator.modules.module import load_module_class_or_function -def test_operator_module_conf_random_walk() -> None: - - module = orchestrator.core.operation.config.OperatorModuleConf( - moduleName="orchestrator.modules.operators.randomwalk", - moduleClass="RandomWalk", - ) - - assert module.operatorIdentifier - assert isinstance(module.operatorIdentifier, str) assert ( - module.operationType + operator_module_conf.operationType == orchestrator.core.operation.config.DiscoveryOperationEnum.SEARCH ) - assert module.operatorIdentifier.split("-")[0] == "randomwalk" + cls = load_module_class_or_function(operator_module_conf) + expected_name = cls.operator_metadata().name + assert operator_module_conf.operatorIdentifier.startswith(f"{expected_name}-") def test_characterize(expected_characterize_operators: list[str]) -> None: 
assert len( - orchestrator.modules.operators.collections.characterize.list_operations() + orchestrator.modules.operators.collections.characterize.list_operators() ) == len(expected_characterize_operators) for operation in expected_characterize_operators: assert ( operation - in orchestrator.modules.operators.collections.characterize.list_operations() + in orchestrator.modules.operators.collections.characterize.list_operators() ) assert orchestrator.modules.operators.collections.characterize.__getattr__( operation @@ -126,13 +90,13 @@ def test_characterize(expected_characterize_operators: list[str]) -> None: def test_explore(expected_explore_operators: list[str]) -> None: assert len( - orchestrator.modules.operators.collections.explore.list_operations() + orchestrator.modules.operators.collections.explore.list_operators() ) == len(expected_explore_operators) for operation in expected_explore_operators: assert ( operation - in orchestrator.modules.operators.collections.explore.list_operations() + in orchestrator.modules.operators.collections.explore.list_operators() ) assert orchestrator.modules.operators.collections.explore.__getattr__(operation) @@ -142,7 +106,7 @@ def test_characterize_operator_function_configurations( ) -> None: for operationName in expected_characterize_operators: - operationConf = orchestrator.core.operation.config.OperatorFunctionConf( + operationConf = orchestrator.core.operation.config.OperatorReference( operatorName=operationName, operationType=orchestrator.core.operation.config.DiscoveryOperationEnum.CHARACTERIZE, ) @@ -154,11 +118,40 @@ def test_explore_operator_function_configurations( ) -> None: for operationName in expected_explore_operators: - operationConf = orchestrator.core.operation.config.OperatorFunctionConf( + operationConf = orchestrator.core.operation.config.OperatorReference( operatorName=operationName, operationType=orchestrator.core.operation.config.DiscoveryOperationEnum.SEARCH, ) assert operationConf is not None + assert 
operationConf.validateOperatorExists() + assert operationConf.operatorName == operationName + # operatorIdentifier must be -, matching ado get operators + assert operationConf.operatorIdentifier.startswith(f"{operationName}-") + + +def test_explore_operator_class_registration( + expected_explore_operators: list[str], +) -> None: + """Each explore operator must have its actor class registered in the collection.""" + for name in expected_explore_operators: + operator = orchestrator.modules.operators.collections.explore.operators[name] + cls = operator.cls + assert cls is not None + + +def test_explore_operator_function_conf_identifier_matches_registered_name() -> None: + """operatorIdentifier via OperatorReference must use the registered name.""" + for name in ["random_walk", "ray_tune"]: + conf = orchestrator.core.operation.config.OperatorReference( + operatorName=name, + operationType=orchestrator.core.operation.config.DiscoveryOperationEnum.SEARCH, + ) + # The identifier must start with the registered function name, not the + # class name (e.g. 
"random_walk-v0.1", not "randomwalk-1.7.1.dev...") + identifier = conf.operatorIdentifier + assert identifier.startswith( + f"{name}-" + ), f"Expected identifier to start with '{name}-', got '{identifier}'" def test_operator_function_configuration_incorrect_type( @@ -170,7 +163,7 @@ def test_operator_function_configuration_incorrect_type( ) for operator_name in expected_explore_operators: - operationConf = orchestrator.core.operation.config.OperatorFunctionConf( + operationConf = orchestrator.core.operation.config.OperatorReference( operatorName=operator_name, operationType=operation_type, ) @@ -191,7 +184,7 @@ def test_operator_function_configuration_unknown_function() -> None: orchestrator.core.operation.config.DiscoveryOperationEnum.CHARACTERIZE ) - operationConf = orchestrator.core.operation.config.OperatorFunctionConf( + operationConf = orchestrator.core.operation.config.OperatorReference( operatorName="UnknownOperationName", operationType=orchestrator.core.operation.config.DiscoveryOperationEnum.CHARACTERIZE, ) @@ -210,7 +203,7 @@ def test_operator_function_configuration_unknown_type() -> None: operator_name = "raytune" operation_type = orchestrator.core.operation.config.DiscoveryOperationEnum.STUDY - operationConf = orchestrator.core.operation.config.OperatorFunctionConf( + operationConf = orchestrator.core.operation.config.OperatorReference( operatorName=operator_name, operationType=operation_type, ) @@ -231,18 +224,23 @@ def test_random_walk_operation_configuration() -> None: RandomWalkParameters, ) - assert random_walk assert ( - orchestrator.modules.operators.collections.explore.configuration_model_for_operation( + orchestrator.modules.operators.collections.explore.operators[ "random_walk" - ) + ].function + is not None + ) + assert ( + orchestrator.modules.operators.collections.explore.operators[ + "random_walk" + ].configuration_model == RandomWalkParameters ) assert ( - 
orchestrator.modules.operators.collections.explore.default_configuration_model_for_operation( + orchestrator.modules.operators.collections.explore.operators[ "random_walk" - ) - == RandomWalk.defaultOperationParameters() + ].example_configuration + == RandomWalk.operator_metadata().example_configuration ) @@ -250,24 +248,28 @@ def test_raytune_operation_configuration( raytuneConf: DiscoveryOperationResourceConfiguration, ) -> None: - import ado_ray_tune.operator_function from ado_ray_tune.operator import ( RayTune, RayTuneConfiguration, ) - assert ado_ray_tune.operator_function.ray_tune assert ( - orchestrator.modules.operators.collections.explore.configuration_model_for_operation( + orchestrator.modules.operators.collections.explore.operators[ "ray_tune" - ) + ].function + is not None + ) + assert ( + orchestrator.modules.operators.collections.explore.operators[ + "ray_tune" + ].configuration_model == RayTuneConfiguration ) assert ( - orchestrator.modules.operators.collections.explore.default_configuration_model_for_operation( + orchestrator.modules.operators.collections.explore.operators[ "ray_tune" - ) - == RayTune.defaultOperationParameters() + ].example_configuration + == RayTune.operator_metadata().example_configuration ) @@ -288,12 +290,10 @@ def test_random_walk_config( import pydantic assert randomWalkConf is not None - assert RandomWalk.validateOperationParameters( - parameters=randomWalkConf.operation.parameters - ) + assert RandomWalkParameters.model_validate(randomWalkConf.operation.parameters) - parameters_model: RandomWalkParameters = RandomWalk.validateOperationParameters( - parameters=randomWalkConf.operation.parameters + parameters_model: RandomWalkParameters = RandomWalkParameters.model_validate( + randomWalkConf.operation.parameters ) # Test sampler @@ -308,7 +308,7 @@ def test_random_walk_config( parameters_dict["foo"] = "bar" with pytest.raises(pydantic.ValidationError): - RandomWalk.validateOperationParameters(parameters=parameters_dict) + 
RandomWalkParameters.model_validate(parameters_dict) # Test extra params not allowed @@ -317,7 +317,7 @@ def test_random_walk_config( parameters_dict["number-iterations"] = 6 with pytest.raises(pydantic.ValidationError): - RandomWalk.validateOperationParameters(parameters=parameters_dict) + RandomWalkParameters.model_validate(parameters_dict) def test_random_walk_custom_sampler_config() -> None: @@ -394,17 +394,16 @@ def test_ray_tune_config( """Test running a random_walk operation via the operation functions""" import pydantic + from ado_ray_tune.operator import RayTuneConfiguration assert raytuneConf is not None - assert RayTune.validateOperationParameters( - parameters=raytuneConf.operation.parameters - ) + assert RayTuneConfiguration.model_validate(raytuneConf.operation.parameters) parameters_dict = raytuneConf.operation.parameters.model_dump() parameters_dict["foo"] = "bar" with pytest.raises(pydantic.ValidationError): - RandomWalk.validateOperationParameters(parameters=parameters_dict) + RandomWalkParameters.model_validate(parameters_dict) def test_run_random_walk_operation( @@ -420,11 +419,14 @@ def test_run_random_walk_operation( assert discoverySpace is not None assert randomWalkConf is not None randomWalkConf.spaces[0] = ml_multi_cloud_space.uri - assert RandomWalk.validateOperationParameters( - parameters=randomWalkConf.operation.parameters - ) + assert RandomWalkParameters.model_validate(randomWalkConf.operation.parameters) - operationOutput = random_walk( + random_walk_fn = orchestrator.modules.operators.collections.explore.operators[ + "random_walk" + ].function + assert random_walk_fn is not None + + operationOutput = random_walk_fn( discoverySpace, **randomWalkConf.operation.parameters.model_dump() ) @@ -491,10 +493,15 @@ def test_random_walk_fail_invalid_config( # Note: Number of entities being greater than space size (valueGreaterThanSize) raises a ValueError # as it is detected at RandomWalk.run() not during configuration validation (which can't 
check this as it has no access to the space) # This is captured and raise as a OperationException + random_walk_fn = orchestrator.modules.operators.collections.explore.operators[ + "random_walk" + ].function + assert random_walk_fn is not None + try: - random_walk( + random_walk_fn( discoverySpace, **invalidRandomWalkConf.operation.parameters.model_dump() - ) # type: orchestrator.modules.operators.base.OperationOutput + ) except orchestrator.core.operation.operation.OperationException as error: operation = error.operation assert operation @@ -531,19 +538,22 @@ def test_run_ray_tune_operation( ) -> None: """Test running a ray_tune operation via the operation functions""" + from ado_ray_tune.operator import RayTuneConfiguration + import orchestrator.core.resources discoverySpace = ml_multi_cloud_space assert discoverySpace is not None assert raytuneConf is not None - assert RayTune.validateOperationParameters( - parameters=raytuneConf.operation.parameters - ) + assert RayTuneConfiguration.model_validate(raytuneConf.operation.parameters) - import ado_ray_tune.operator_function + ray_tune_fn = orchestrator.modules.operators.collections.explore.operators[ + "ray_tune" + ].function + assert ray_tune_fn is not None - operationOutput = ado_ray_tune.operator_function.ray_tune( + operationOutput = ray_tune_fn( discoverySpace, **raytuneConf.operation.parameters.model_dump() ) @@ -589,8 +599,315 @@ def test_operator_default_and_validate( optimizer_operator: type[RandomWalk] | type[RayTune], ) -> None: - assert optimizer_operator - default = optimizer_operator.defaultOperationParameters() - parameters = default.model_dump() if not isinstance(default, dict) else default + meta = optimizer_operator.operator_metadata() + assert meta.configuration_model is not None + assert meta.example_configuration is not None + parameters = meta.example_configuration.model_dump() + assert meta.configuration_model.model_validate(parameters) + + +# 
--------------------------------------------------------------------------- +# OperatorMetadata and explore_operation class-decorator tests +# --------------------------------------------------------------------------- + + +def test_operator_metadata_identifier_property() -> None: + """OperatorMetadata.operatorIdentifier returns '{name}-{version}'.""" + import pydantic + + from orchestrator.core.operation.config import ( + DiscoveryOperationEnum, + OperatorMetadata, + ) + + class _P(pydantic.BaseModel): + pass + + meta = OperatorMetadata( + name="my_op", + version="v2.0", + configuration_model=_P, + example_configuration=_P(), + type=DiscoveryOperationEnum.SEARCH, + ) + assert meta.operatorIdentifier == "my_op-v2.0" + + +def test_operator_metadata_identifier_default_version() -> None: + """OperatorMetadata.operatorIdentifier uses '0.1.0' when version is not supplied.""" + import pydantic + + from orchestrator.core.operation.config import ( + DiscoveryOperationEnum, + OperatorMetadata, + ) + + class _P(pydantic.BaseModel): + pass + + meta = OperatorMetadata( + name="op", + configuration_model=_P, + example_configuration=_P(), + type=DiscoveryOperationEnum.SEARCH, + ) + assert meta.operatorIdentifier == "op-0.1.0" + + +def test_operator_metadata_version_valid_pep440() -> None: + """OperatorMetadata accepts valid PEP 440 version strings.""" + import pydantic + + from orchestrator.core.operation.config import ( + DiscoveryOperationEnum, + OperatorMetadata, + ) + + class _P(pydantic.BaseModel): + pass + + valid_versions = [ + "0.1.0", + "1.2.3", + "1.0.0.dev5", + "1.7.1.dev82+1ee4e59.dirty", + "2.0.0a1", + "2.0.0rc1", + "v1.0.0", + ] + for ver in valid_versions: + meta = OperatorMetadata( + name="op", + version=ver, + configuration_model=_P, + example_configuration=_P(), + type=DiscoveryOperationEnum.SEARCH, + ) + assert meta.version == ver + + +def test_operator_metadata_version_invalid_pep440() -> None: + """OperatorMetadata rejects strings that are not valid PEP 440 
versions.""" + import pydantic + + from orchestrator.core.operation.config import ( + DiscoveryOperationEnum, + OperatorMetadata, + ) + + class _P(pydantic.BaseModel): + pass + + invalid_versions = ["not-a-version", "hello", "1.0.0-final"] + for ver in invalid_versions: + with pytest.raises(pydantic.ValidationError, match="PEP 440"): + OperatorMetadata( + name="op", + version=ver, + configuration_model=_P, + example_configuration=_P(), + type=DiscoveryOperationEnum.SEARCH, + ) + + +def test_operator_function_conf_identifier_delegates_to_operator_metadata() -> None: + """OperatorReference.operatorIdentifier equals explore.operators[name].operatorIdentifier.""" + from orchestrator.core.operation.config import ( + DiscoveryOperationEnum, + OperatorReference, + ) + from orchestrator.modules.operators.collections import explore + + for name in ["random_walk", "ray_tune"]: + conf = OperatorReference( + operatorName=name, + operationType=DiscoveryOperationEnum.SEARCH, + ) + assert conf.operatorIdentifier == explore.operators[name].operatorIdentifier + + +def test_explore_operation_class_decorator_registers_function() -> None: + """@explore_operation returns the class unchanged and stores the OperatorFunction in the collection.""" + import inspect + + import pydantic + + from orchestrator.core.operation.config import ( + DiscoveryOperationEnum, + OperatorMetadata, + ) + from orchestrator.modules.operators.base import Explore + from orchestrator.modules.operators.collections import explore, explore_operation + + class _Params(pydantic.BaseModel): + pass + + @explore_operation + class _TestOp(Explore): + @classmethod + def operator_metadata(cls) -> OperatorMetadata: + return OperatorMetadata( + name="_test_class_op", + version="0.1.0", + description="A test operator.", + configuration_model=_Params, + example_configuration=_Params(), + type=DiscoveryOperationEnum.SEARCH, + ) + + def operationIdentifier(self) -> str: + return "_test_class_op-run" + + async def run(self) -> 
None: + pass + + # The decorator returns the class unchanged + assert isinstance(_TestOp, type) + assert issubclass(_TestOp, Explore) + + # The generated OperatorFunction is stored in the collection + assert "_test_class_op" in explore.operators + fn = explore.operators["_test_class_op"].function + assert callable(fn) + sig = inspect.signature(fn) + params = list(sig.parameters.keys()) + assert "discoverySpace" in params + assert "operationInfo" in params + + +def test_explore_operation_class_decorator_cls_stored() -> None: + """explore.operators[name].cls is the unwrapped class after class decoration.""" + import pydantic + + from orchestrator.core.operation.config import ( + DiscoveryOperationEnum, + OperatorMetadata, + ) + from orchestrator.modules.operators.base import Explore + from orchestrator.modules.operators.collections import explore, explore_operation + + class _ParamsCls(pydantic.BaseModel): + pass + + @explore_operation + class _TestOpCls(Explore): + @classmethod + def operator_metadata(cls) -> OperatorMetadata: + return OperatorMetadata( + name="_test_cls_stored", + version="0.1.0", + configuration_model=_ParamsCls, + example_configuration=_ParamsCls(), + type=DiscoveryOperationEnum.SEARCH, + ) + + async def run(self) -> None: + pass + + op = explore.operators.get("_test_cls_stored") + assert op is not None + assert op.cls is not None + + +def test_explore_operation_class_decorator_metadata_from_class() -> None: + """All OperatorMetadata fields come from the class's operator_metadata().""" + import pydantic + + from orchestrator.core.operation.config import ( + DiscoveryOperationEnum, + OperatorMetadata, + ) + from orchestrator.modules.operators.base import Explore + from orchestrator.modules.operators.collections import explore, explore_operation + + class _Params2(pydantic.BaseModel): + x: int = 42 + + @explore_operation + class _TestOp2(Explore): + @classmethod + def operator_metadata(cls) -> OperatorMetadata: + return OperatorMetadata( + 
name="_test_class_op2", + version="3.0.0", + description="Another test operator.", + configuration_model=_Params2, + example_configuration=_Params2(), + type=DiscoveryOperationEnum.SEARCH, + ) + + def operationIdentifier(self) -> str: + return "_test_class_op2-run" + + async def run(self) -> None: + pass + + registered = explore.operators["_test_class_op2"] + assert registered.name == "_test_class_op2" + assert registered.version == "3.0.0" + assert registered.description == "Another test operator." + assert registered.configuration_model is _Params2 + assert isinstance(registered.example_configuration, _Params2) + assert registered.type == DiscoveryOperationEnum.SEARCH + + +def test_explore_operation_class_decorator_missing_operator_metadata_raises() -> None: + """Decorating a Search subclass without operator_metadata() raises NotImplementedError.""" + + from orchestrator.modules.operators.base import Explore + from orchestrator.modules.operators.collections import explore_operation + + with pytest.raises(NotImplementedError): + + @explore_operation + class _BadOp(Explore): + # No operator_metadata() and no legacy classmethods — must raise. 
+ async def run(self) -> None: + pass + + +def test_random_walk_registration() -> None: + from orchestrator.modules.operators.collections import explore + + assert "random_walk" in explore.operators + rw = explore.operators["random_walk"] + assert rw.name == "random_walk" + assert rw.cls is not None + assert callable(rw.function) + + +def test_ray_tune_registration() -> None: + from orchestrator.modules.operators.collections import explore + + assert "ray_tune" in explore.operators + rt = explore.operators["ray_tune"] + assert rt.name == "ray_tune" + assert rt.cls is not None + assert callable(rt.function) + + +def test_warn_if_operator_name_reused_logs_for_duplicate( + caplog: pytest.LogCaptureFixture, +) -> None: + """Reusing an operator name logs a warning before the registry entry is replaced.""" + from orchestrator.core.operation.config import ( + DiscoveryOperationEnum, + OperatorMetadata, + ) + from orchestrator.modules.operators.collections import _warn_if_operator_name_reused + + class _Cfg(pydantic.BaseModel): + pass + + placeholder = OperatorMetadata( + name="dup", + configuration_model=_Cfg, + example_configuration=_Cfg(), + type=DiscoveryOperationEnum.CHARACTERIZE, + ) + ops: dict[str, OperatorMetadata] = {"dup": placeholder} + + with caplog.at_level(logging.WARNING): + _warn_if_operator_name_reused("characterize", "dup", ops) - assert optimizer_operator.validateOperationParameters(parameters=parameters) + assert any("already registered" in r.getMessage() for r in caplog.records) diff --git a/tests/operators/test_ray_tune_validation.py b/tests/operators/test_ray_tune_validation.py new file mode 100644 index 000000000..727ae1c82 --- /dev/null +++ b/tests/operators/test_ray_tune_validation.py @@ -0,0 +1,115 @@ +# Copyright IBM Corporation 2025, 2026 +# SPDX-License-Identifier: MIT + +"""Tests for Ray Tune operator points_to_evaluate validation.""" + +import pytest +from ado_ray_tune.operator import _validate_points_to_evaluate + +from 
orchestrator.schema.domain import PropertyDomain +from orchestrator.schema.entityspace import EntitySpaceRepresentation +from orchestrator.schema.property import ConstitutiveProperty + + +def _pigeon10_entity_space() -> EntitySpaceRepresentation: + """Entity space matching cplex_mip_pigeon10: mps_file, node_selection, variable_selection.""" + mps_file = ConstitutiveProperty( + identifier="mps_file", + propertyDomain=PropertyDomain(values=["pigeon-10.mps.gz"]), + ) + node_selection = ConstitutiveProperty( + identifier="node_selection", + propertyDomain=PropertyDomain(values=[0, 1, 2, 3]), + ) + variable_selection = ConstitutiveProperty( + identifier="variable_selection", + propertyDomain=PropertyDomain(values=[0, 1, 2, 3]), + ) + return EntitySpaceRepresentation([mps_file, node_selection, variable_selection]) + + +def _minimal_entity_space() -> EntitySpaceRepresentation: + """Minimal 2-property entity space for unit tests.""" + cp1 = ConstitutiveProperty( + identifier="a", + propertyDomain=PropertyDomain(values=[1, 2, 3]), + ) + cp2 = ConstitutiveProperty( + identifier="b", + propertyDomain=PropertyDomain(values=["x", "y"]), + ) + return EntitySpaceRepresentation([cp1, cp2]) + + +def test_validate_points_to_evaluate_none_passes() -> None: + """None or empty points_to_evaluate should pass without error.""" + entity_space = _minimal_entity_space() + _validate_points_to_evaluate(None, entity_space) + _validate_points_to_evaluate([], entity_space) + + +def test_validate_points_to_evaluate_valid_point_passes() -> None: + """Valid complete points should pass.""" + entity_space = _minimal_entity_space() + _validate_points_to_evaluate( + [{"a": 1, "b": "x"}, {"a": 2, "b": "y"}], + entity_space, + ) + + +def test_validate_points_to_evaluate_missing_property_raises() -> None: + """Point missing a constitutive property should raise ValueError.""" + entity_space = _minimal_entity_space() + with pytest.raises(ValueError, match=r"missing properties.*\bb\b"): + 
_validate_points_to_evaluate([{"a": 1}], entity_space) + + +def test_validate_points_to_evaluate_extra_property_raises() -> None: + """Point with extra property not in space should raise ValueError.""" + entity_space = _minimal_entity_space() + with pytest.raises(ValueError, match=r"extra properties.*\bc\b"): + _validate_points_to_evaluate([{"a": 1, "b": "x", "c": 0}], entity_space) + + +def test_validate_points_to_evaluate_value_out_of_domain_raises() -> None: + """Point with value outside constitutive property domain should raise ValueError.""" + entity_space = _minimal_entity_space() + with pytest.raises(ValueError, match=r"invalid"): + _validate_points_to_evaluate([{"a": 99, "b": "x"}], entity_space) + + +def test_validate_points_to_evaluate_non_dict_raises() -> None: + """Point that is not a dict should raise ValueError.""" + entity_space = _minimal_entity_space() + with pytest.raises(ValueError, match=r"must be a dict.*got list"): + _validate_points_to_evaluate([{"a": 1, "b": "x"}, [1, 2, 3]], entity_space) + + +def test_validate_points_to_evaluate_pigeon10_space_bab6_points_raises() -> None: + """Regression: bab6 points_to_evaluate (missing mps_file) on pigeon10 space raises.""" + entity_space = _pigeon10_entity_space() + # Points from operation_lhs.yaml - designed for bab6, missing mps_file for pigeon10 + invalid_points = [ + { + "n_threads": 1, + "rins_frequency": 0, + "cut_passes": 0, + "node_selection": 1, + "variable_selection": 0, + } + ] + with pytest.raises(ValueError, match=r"missing properties.*mps_file"): + _validate_points_to_evaluate(invalid_points, entity_space) + + +def test_validate_points_to_evaluate_pigeon10_space_valid_points_passes() -> None: + """Complete pigeon10 points should pass.""" + entity_space = _pigeon10_entity_space() + valid_points = [ + { + "mps_file": "pigeon-10.mps.gz", + "node_selection": 1, + "variable_selection": 0, + } + ] + _validate_points_to_evaluate(valid_points, entity_space) diff --git 
a/tests/operators/test_validate_operator_function_signature.py b/tests/operators/test_validate_operator_function_signature.py new file mode 100644 index 000000000..e33d799da --- /dev/null +++ b/tests/operators/test_validate_operator_function_signature.py @@ -0,0 +1,144 @@ +# Copyright IBM Corporation 2025, 2026 +# SPDX-License-Identifier: MIT +"""Tests for validate_operator_function_signature.""" + +import pytest + +from orchestrator.core.discoveryspace.space import DiscoverySpace +from orchestrator.core.operation.config import FunctionOperationInfo +from orchestrator.core.operation.operation import OperationOutput +from orchestrator.modules.operators.base import validate_operator_function_signature + + +def valid_op( + discoverySpace: DiscoverySpace, + operationInfo: FunctionOperationInfo | None = None, + **kwargs: object, +) -> OperationOutput: ... + + +def valid_op_no_kwargs( + discoverySpace: DiscoverySpace, + operationInfo: FunctionOperationInfo | None = None, +) -> OperationOutput: ... + + +def no_annotations( + discoverySpace, # noqa: ANN001 + operationInfo=None, # noqa: ANN001 +) -> OperationOutput: ... + + +def missing_operation_info( + discoverySpace: DiscoverySpace, + **kwargs: object, +) -> OperationOutput: ... + + +def extra_positional( + discoverySpace: DiscoverySpace, + operationInfo: FunctionOperationInfo | None = None, + extra: int = 0, +) -> OperationOutput: ... + + +def wrong_first_param_type( + discoverySpace: int, + operationInfo: FunctionOperationInfo | None = None, +) -> OperationOutput: ... + + +def wrong_second_param_type( + discoverySpace: DiscoverySpace, + operationInfo: int = 0, +) -> OperationOutput: ... + + +def wrong_return_type( + discoverySpace: DiscoverySpace, + operationInfo: FunctionOperationInfo | None = None, +) -> int: ... 
+ + +class TestValidSignatures: + def test_valid_with_kwargs(self) -> None: + """Full protocol-matching signature passes.""" + validate_operator_function_signature(valid_op) + + def test_valid_without_kwargs(self) -> None: + """Omitting **kwargs is allowed.""" + validate_operator_function_signature(valid_op_no_kwargs) + + +class TestHintIntrospectionFailure: + def test_unresolvable_forward_reference_raises(self) -> None: + """An unresolvable forward reference in annotations raises ValueError. + + A function whose annotation references a name that is not in scope + causes typing.get_type_hints to raise NameError. The function has + valid structure but the hints are unresolvable, so validation must + reject it rather than silently skipping the type checks. + """ + + def bad_hints( + discoverySpace: "UnresolvableType", # noqa: F821 + operationInfo: FunctionOperationInfo | None = None, + ) -> OperationOutput: ... + + with pytest.raises(ValueError, match="type hints are missing or unresolvable"): + validate_operator_function_signature(bad_hints) + + def test_uninspectable_fn_raises(self) -> None: + """A callable whose signature cannot be inspected must raise ValueError. + + Conformance cannot be confirmed when introspection fails, so silently + passing would allow invalid callables through. 
+ """ + import inspect + import unittest.mock as mock + + fn = lambda: None # noqa: E731 + original_signature = inspect.signature + + def patched_signature(obj: object, **kwargs: object) -> inspect.Signature: + if obj is fn: + raise TypeError("not inspectable") + return original_signature(obj, **kwargs) # type: ignore[arg-type] + + with ( + mock.patch("inspect.signature", side_effect=patched_signature), + pytest.raises(ValueError, match="signature could not be inspected"), + ): + validate_operator_function_signature(fn) + + +class TestInvalidSignatures: + def test_missing_param_annotations_rejected(self) -> None: + """A function with unannotated parameters is rejected.""" + with pytest.raises(ValueError, match="type hints are missing"): + validate_operator_function_signature(no_annotations) + + def test_missing_operation_info(self) -> None: + """Function with only one positional parameter is rejected.""" + with pytest.raises(ValueError, match="operationInfo"): + validate_operator_function_signature(missing_operation_info) + + def test_extra_positional_parameter(self) -> None: + """Function with more positional parameters than the protocol is rejected.""" + with pytest.raises(ValueError, match="extra"): + validate_operator_function_signature(extra_positional) + + def test_wrong_first_param_type(self) -> None: + """Wrong type on the first positional parameter is rejected.""" + with pytest.raises(ValueError, match="discoverySpace"): + validate_operator_function_signature(wrong_first_param_type) + + def test_wrong_second_param_type(self) -> None: + """Wrong type on the second positional parameter is rejected.""" + with pytest.raises(ValueError, match="operationInfo"): + validate_operator_function_signature(wrong_second_param_type) + + def test_wrong_return_type(self) -> None: + """Wrong return type is rejected.""" + with pytest.raises(ValueError, match="return"): + validate_operator_function_signature(wrong_return_type) diff --git 
a/website/docs/operators/creating-operators.md b/website/docs/operators/creating-operators.md index 4d9e09409..c08a06efa 100644 --- a/website/docs/operators/creating-operators.md +++ b/website/docs/operators/creating-operators.md @@ -193,7 +193,7 @@ discussed in [explore operators](#creating-explore-operators). > created in. The operator function must return data using the -`orchestrators.core.operation.operation.OperationOutput` pydantic model. +`orchestrator.core.operation.operation.OperationOutput` pydantic model. ```python class OperationOutput(pydantic.BaseModel): @@ -495,77 +495,189 @@ try: # operator logic ... except KeyboardInterrupt as error: - # Assumes created_resources is an array containing all ado resource already created - raise InterruptedOperationError(resources=created_resources) from error -except ( - InterruptedOperationError -) as nested_operation_error: # This is when a nested operation was interrupted first - # IMPORTANT: You must add the identifier of the interrupted nested operation + # Assumes created_resources lists all ADO resources already created, and + # operation_id is the identifier string of this operation (from your context). raise InterruptedOperationError( - resources=created_resources, identifier=nested_operation_error.identifier + operation_identifier=operation_id, + resources=created_resources, ) from error +except InterruptedOperationError as nested_operation_error: + # Nested operation was interrupted first; propagate using its operation identifier. + raise InterruptedOperationError( + operation_identifier=nested_operation_error.operation_identifier, + resources=created_resources, + ) from nested_operation_error ``` ## Creating Explore Operators -Explore operators sample and measure entities. 
In `ado` all explore operation -run as distributed ray jobs with: - -- actuator ray actors for performing measurements -- discovery space manager actor for storing and notifying about measurement - results - -This means explore operators need to be implemented differently to the others, -in particular - -- The logic of your explore operator must be implemented as a ray actor (a - class) -- The explore operator functions must call this class i.e. you won't have any - operator logic in the function +Explore operators sample entities from a discovery space and submit them for +measurement. Unlike other operator types, the logic runs inside a **Ray +actor** and requires a class to be implemented. -### Explore operation functions +### Implementation -All explore operation functions follow this pattern: +1. Create a class that subclasses +`orchestrator.modules.operators.base.Explore` +2. Decorate it with `@explore_operation` +3. Implement, at least, the following methods: + - **`operator_metadata()`** — a classmethod returning an + `OperatorMetadata` instance that describes your operator. 
+ - **`run()`** — an async method containing your operator logic + - **`onUpdate()`**, **`onCompleted()`** and **`onError()`** — methods + that handle notifications about completed measurements + +A simple example is shown below: ```python -@explore_operation( - name="ray_tune", - description=RayTune.description(), - configuration_model=RayTuneConfiguration, - configuration_model_default=RayTuneConfiguration(), +import asyncio +import pydantic +from orchestrator.core.datacontainer.resource import DataContainer +from orchestrator.core.datacontainer.resource import DataContainerResource +from orchestrator.core.operation.config import DiscoveryOperationEnum, OperatorMetadata +from orchestrator.core.operation.operation import OperationOutput +from orchestrator.core.operation.resource import ( + OperationExitStateEnum, + OperationResourceEventEnum, + OperationResourceStatus, ) -def ray_tune( - discoverySpace: DiscoverySpace, - operationInfo: FunctionOperationInfo = FunctionOperationInfo(), - **kwargs: typing.Dict, -) -> OperationOutput: - """ - Performs an optimization on a given discoverySpace +from orchestrator.modules.operators.base import Explore, measure_or_replay +from orchestrator.modules.operators.collections import explore_operation - """ - from orchestrator.core.operation.config import OperatorModuleConf - from orchestrator.module.operator.orchestrate import orchestrate_explore_operation +class MySearchParameters(pydantic.BaseModel): + num_entities: int = 10 - ## This describes where the class that implements your explore operation is - module = OperatorModuleConf( - moduleName="ado_ray_tune.operator", # The name of the package containing your explore actor - moduleClass="RayTune", # The name of your explore actor class - ) +@explore_operation +class MySearchOperator(Explore): - # Tell ado to execute your class - return orchestrate_explore_operation( - discovery_space=discoverySpace, - module=module, - parameters=kwargs, - operation_info=operationInfo, # Important: 
This is where you must pass the operationInfo parameter to ado - ) -``` + def __init__(self, operationActorName, namespace, discovery_space_manager, actuators, params=None): + self.params = MySearchParameters(**(params or {})) + # Queue for completed-measurement notifications received via onUpdate + self.completed_measurements_queue = asyncio.Queue() + super().__init__( + operationActorName=operationActorName, + namespace=namespace, + discovery_space_manager=discovery_space_manager, + actuators=actuators, + ) -### Explore operator classes + @classmethod + def operator_metadata(cls) -> OperatorMetadata: + return OperatorMetadata( + name="my_search", + version="0.1.0", + description="A minimal example search operator.", + configuration_model=MySearchParameters, + example_configuration=MySearchParameters(), + type=DiscoveryOperationEnum.SEARCH, + ) + + # --- callbacks from DiscoverySpaceManager -------------------------------- + # onUpdate: called when a measurement completes + # onError: called on an unrecoverable error + + def onUpdate(self, measurementRequest) -> None: + self.completed_measurements_queue.put_nowait(measurementRequest) + + def onCompleted(self) -> None: + pass + + def onError(self, error: Exception) -> None: + self.completed_measurements_queue.put_nowait(error) + + # ------------------------------------------------------------------------- + + async def run(self) -> OperationOutput | None: + measurement_queue = await self.ds_manager.measurement_queue.remote() + ds = await self.ds_manager.discoverySpace.remote() + experiments = ds.measurementSpace.independentExperiments + + error_message = "" + + # Sample entities and submit them for measurement + submitted = 0 + async for entities in ...: # use your chosen sampling strategy + for experiment in experiments: + request_ids = measure_or_replay( + requestIndex=submitted, + requesterid=self.operationIdentifier(), + entities=entities, + experimentReference=experiment.reference, + actuators=self.actuators, 
+ measurement_queue=measurement_queue, + memoize=False, + ) + submitted += len(request_ids) + + # Wait for all submitted measurements to complete + completed = 0 + while not error_message and completed < submitted: + item = await self.completed_measurements_queue.get() + if isinstance(item, Exception): + error_message = f"Discovery space manager error: {item}" + break + if item.operation_id == self.operationIdentifier(): + completed += 1 + + self.ds_manager.unsubscribeFromUpdates.remote(subscriberName=self.actorName) + + if error_message: + return OperationOutput( + exitStatus=OperationResourceStatus( + event=OperationResourceEventEnum.FINISHED, + exit_state=OperationExitStateEnum.FAIL, + message=error_message, + ) + ) + # exitStatus defaults to success when omitted + summary = DataContainer(data={"entities_submitted": submitted}) + return OperationOutput(resources=[DataContainerResource(config=summary)]) +``` -TBA +### Tips + +**Submit measurements in batches.** Submitting a batch at once, then waiting +for all of them to complete before sampling the next batch, is the simplest +pattern. A more advanced approach is to submit the next entity as soon as one +measurement finishes (continuous batching), which keeps actuators busy and +reduces idle time. + +**Use `measure_or_replay` for all submissions.** Do not call actuators +directly. `measure_or_replay` handles memoisation (reusing a prior +measurement result when `memoize=True`) and routes the request to the correct +actuator. + +**Use `self.operationIdentifier()` as the `requesterid`.** The update +notifications you receive via `onUpdate` include the `operation_id` that +created the request. Filtering on `measurement_request.operation_id == +self.operationIdentifier()` lets you ignore notifications from other +concurrent operations sharing the same space. 
+ +**Unsubscribe before returning.** Call +`self.ds_manager.unsubscribeFromUpdates.remote(subscriberName=self.actorName)` +at the end of `run()` as a courtesy — it stops the `DiscoverySpaceManager` from +dispatching further `onUpdate` and `onCompleted` calls to an operator that has +already finished. + +### Error handling + +**Errors from `measure_or_replay`.** The function raises `KeyError` (no +actuator can handle the experiment) or `MeasurementError` (experiment is +deprecated for the actuator version in use). These don't have to be caught, +as they are handled by `ado`, which records the operation with exit state `ERROR`, including the full +exception message. You only need to catch them explicitly if you want finer +control. + +**Errors from the discovery space manager.** If the discovery space manager +encounters an unrecoverable problem, it calls `onError`. This call arrives asynchronously, +and without explicit handling, `run()` could wait forever for a new measurement result +to arrive. A recommended pattern to handle this is shown in the example above: `onError` +puts the exception onto the same `asyncio.Queue` that `onUpdate` uses for +completed measurements. In the wait loop, check whether the item dequeued is an +`Exception` to detect this sentinel and exit early, then return a failed +`OperationOutput`. ## Operator plugin packages diff --git a/website/docs/operators/explore_operators.md b/website/docs/operators/explore_operators.md index f290a1eb7..881cfd4ac 100644 --- a/website/docs/operators/explore_operators.md +++ b/website/docs/operators/explore_operators.md @@ -141,10 +141,8 @@ config: issue: "904" operation: module: - moduleClass: RandomWalk - moduleName: orchestrator.modules.operators.randomwalk - modulePath: . - moduleType: operation + operatorName: random_walk + operationType: search parameters: batchSize: 4 singleMeasurement: false