diff --git a/conftest.py b/conftest.py index 991c0d17b6..54a47f9e23 100644 --- a/conftest.py +++ b/conftest.py @@ -93,6 +93,7 @@ test_hostlist = None has_aprun = shutil.which("aprun") is not None + def get_account() -> str: return test_account @@ -227,7 +228,6 @@ def kill_all_test_spawned_processes() -> None: print("Not all processes were killed after test") - def get_hostlist() -> t.Optional[t.List[str]]: global test_hostlist if not test_hostlist: diff --git a/doc/changelog.md b/doc/changelog.md index 7d08c9376f..b0e326d1f7 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -13,12 +13,12 @@ Jump to: Description +- Implement asynchronous notifications for shared data - Quick bug fix in _validate - Add helper methods to MLI classes - Update error handling for consistency - Parameterize installation of dragon package with `smart build` - Update docstrings -- Implement asynchronous notifications for shared data - Filenames conform to snake case - Update SmartSim environment variables using new naming convention - Refactor `exception_handler` diff --git a/ex/high_throughput_inference/mock_app.py b/ex/high_throughput_inference/mock_app.py index dcc52296ef..c3b3eaaf4c 100644 --- a/ex/high_throughput_inference/mock_app.py +++ b/ex/high_throughput_inference/mock_app.py @@ -37,18 +37,10 @@ import argparse import io -import numpy -import os -import time + import torch -from mpi4py import MPI -from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( - DragonFeatureStore, -) -from smartsim._core.mli.message_handler import MessageHandler from smartsim.log import get_logger -from smartsim._core.utils.timings import PerfTimer torch.set_num_interop_threads(16) torch.set_num_threads(1) @@ -56,83 +48,24 @@ logger = get_logger("App") logger.info("Started app") -CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False +from collections import OrderedDict -class ProtoClient: - def __init__(self, timing_on: bool): - comm = MPI.COMM_WORLD - rank = comm.Get_rank() - connect_to_infrastructure() - ddict_str = os.environ["_SMARTSIM_INFRA_BACKBONE"] - self._ddict = DDict.attach(ddict_str) - self._backbone_descriptor = DragonFeatureStore(self._ddict).descriptor - to_worker_fli_str = None - while to_worker_fli_str is None: - try: - to_worker_fli_str = self._ddict["to_worker_fli"] - self._to_worker_fli = fli.FLInterface.attach(to_worker_fli_str) - except KeyError: - time.sleep(1) - self._from_worker_ch = Channel.make_process_local() - self._from_worker_ch_serialized = self._from_worker_ch.serialize() - self._to_worker_ch = Channel.make_process_local() - - self.perf_timer: PerfTimer = PerfTimer(debug=False, timing_on=timing_on, prefix=f"a{rank}_") - - def run_model(self, model: bytes | str, batch: torch.Tensor): - tensors = [batch.numpy()] - self.perf_timer.start_timings("batch_size", batch.shape[0]) - built_tensor_desc = MessageHandler.build_tensor_descriptor( - "c", "float32", list(batch.shape) - ) - self.perf_timer.measure_time("build_tensor_descriptor") - if isinstance(model, str): - model_arg = MessageHandler.build_model_key(model, self._backbone_descriptor) - else: - model_arg = MessageHandler.build_model(model, "resnet-50", "1.0") - request = MessageHandler.build_request( - reply_channel=self._from_worker_ch_serialized, - model=model_arg, - inputs=[built_tensor_desc], - outputs=[], - output_descriptors=[], - custom_attributes=None, - ) - self.perf_timer.measure_time("build_request") - request_bytes = MessageHandler.serialize_request(request) - self.perf_timer.measure_time("serialize_request") - with 
self._to_worker_fli.sendh(timeout=None, stream_channel=self._to_worker_ch) as to_sendh: - to_sendh.send_bytes(request_bytes) - self.perf_timer.measure_time("send_request") - for tensor in tensors: - to_sendh.send_bytes(tensor.tobytes()) #TODO NOT FAST ENOUGH!!! - self.perf_timer.measure_time("send_tensors") - with self._from_worker_ch.recvh(timeout=None) as from_recvh: - resp = from_recvh.recv_bytes(timeout=None) - self.perf_timer.measure_time("receive_response") - response = MessageHandler.deserialize_response(resp) - self.perf_timer.measure_time("deserialize_response") - # list of data blobs? recv depending on the len(response.result.descriptors)? - data_blob: bytes = from_recvh.recv_bytes(timeout=None) - self.perf_timer.measure_time("receive_tensor") - result = torch.from_numpy( - numpy.frombuffer( - data_blob, - dtype=str(response.result.descriptors[0].dataType), - ) - ) - self.perf_timer.measure_time("deserialize_tensor") +from smartsim.log import get_logger, log_to_file +from smartsim._core.mli.client.protoclient import ProtoClient - self.perf_timer.end_timings() - return result +logger = get_logger("App") - def set_model(self, key: str, model: bytes): - self._ddict[key] = model +CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False class ResNetWrapper: + """Wrapper around a pre-trained ResNet model.""" def __init__(self, name: str, model: str): + """Initialize the instance. + + :param name: The name to use for the model + :param model: The path to the pre-trained PyTorch model""" self._model = torch.jit.load(model) self._name = name buffer = io.BytesIO() @@ -141,16 +74,28 @@ def __init__(self, name: str, model: str): self._serialized_model = buffer.getvalue() def get_batch(self, batch_size: int = 32): + """Create a random batch of data with the correct dimensions to + invoke a ResNet model. + + :param batch_size: The desired number of samples to produce + :returns: A PyTorch tensor""" return torch.randn((batch_size, 3, 224, 224), dtype=torch.float32) @property - def model(self): + def model(self) -> bytes: + """The content of a model file. + + :returns: The model bytes""" return self._serialized_model @property - def name(self): + def name(self) -> str: + """The name applied to the model.
+ + :returns: The name""" return self._name + if __name__ == "__main__": parser = argparse.ArgumentParser("Mock application") @@ -166,24 +111,32 @@ def name(self): if CHECK_RESULTS_AND_MAKE_ALL_SLOWER: # TODO: adapt to non-Nvidia devices torch_device = args.device.replace("gpu", "cuda") - pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to(torch_device) + pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to( + torch_device + ) TOTAL_ITERATIONS = 100 - for log2_bsize in range(args.log_max_batchsize+1): + for log2_bsize in range(args.log_max_batchsize + 1): b_size: int = 2**log2_bsize logger.info(f"Batch size: {b_size}") - for iteration_number in range(TOTAL_ITERATIONS + int(b_size==1)): + for iteration_number in range(TOTAL_ITERATIONS + int(b_size == 1)): logger.info(f"Iteration: {iteration_number}") sample_batch = resnet.get_batch(b_size) remote_result = client.run_model(resnet.name, sample_batch) logger.info(client.perf_timer.get_last("total_time")) if CHECK_RESULTS_AND_MAKE_ALL_SLOWER: local_res = pt_model(sample_batch.to(torch_device)) - err_norm = torch.linalg.vector_norm(torch.flatten(remote_result).to(torch_device)-torch.flatten(local_res), ord=1).cpu() + err_norm = torch.linalg.vector_norm( + torch.flatten(remote_result).to(torch_device) + - torch.flatten(local_res), + ord=1, + ).cpu() res_norm = torch.linalg.vector_norm(remote_result, ord=1).item() local_res_norm = torch.linalg.vector_norm(local_res, ord=1).item() - logger.info(f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}") + logger.info( + f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}" + ) torch.cuda.synchronize() - client.perf_timer.print_timings(to_file=True) \ No newline at end of file + client.perf_timer.print_timings(to_file=True) diff --git a/ex/high_throughput_inference/standalone_worker_manager.py b/ex/high_throughput_inference/standalone_worker_manager.py index feb1af1aee..b4527bc5d2 100644 --- a/ex/high_throughput_inference/standalone_worker_manager.py +++ b/ex/high_throughput_inference/standalone_worker_manager.py @@ -37,6 +37,7 @@ from dragon.globalservices.api_setup import connect_to_infrastructure from dragon.managed_memory import MemoryPool from dragon.utils import b64decode, b64encode + # pylint enable=import-error # isort: off @@ -46,33 +47,27 @@ import base64 import multiprocessing as mp import os -import pickle import socket -import sys import time import typing as t import cloudpickle -import optparse -import os from smartsim._core.entrypoints.service import Service -from smartsim._core.mli.comm.channel.channel import CommChannelBase from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel -from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( - DragonFeatureStore, -) +from smartsim._core.mli.comm.channel.dragon_util import create_local from smartsim._core.mli.infrastructure.control.request_dispatcher import ( RequestDispatcher, ) from smartsim._core.mli.infrastructure.control.worker_manager import WorkerManager from smartsim._core.mli.infrastructure.environment_loader import EnvironmentConfigLoader +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( DragonFeatureStore, ) -from 
smartsim._core.mli.infrastructure.worker.worker import MachineLearningWorkerBase - from smartsim.log import get_logger logger = get_logger("Worker Manager Entry Point") @@ -85,7 +80,6 @@ logger.info(f"CPUS: {os.cpu_count()}") - def service_as_dragon_proc( service: Service, cpu_affinity: list[int], gpu_affinity: list[int] ) -> dragon_process.Process: @@ -108,8 +102,6 @@ def service_as_dragon_proc( ) - - if __name__ == "__main__": parser = argparse.ArgumentParser("Worker Manager") parser.add_argument( @@ -143,27 +135,26 @@ def service_as_dragon_proc( args = parser.parse_args() connect_to_infrastructure() - ddict_str = os.environ["_SMARTSIM_INFRA_BACKBONE"] - ddict = DDict.attach(ddict_str) + ddict_str = os.environ[BackboneFeatureStore.MLI_BACKBONE] + + backbone = BackboneFeatureStore.from_descriptor(ddict_str) - to_worker_channel = Channel.make_process_local() + to_worker_channel = create_local() to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None) - to_worker_fli_serialized = to_worker_fli.serialize() - ddict["to_worker_fli"] = to_worker_fli_serialized + to_worker_fli_comm_ch = DragonFLIChannel(to_worker_fli) + + backbone.worker_queue = to_worker_fli_comm_ch.descriptor + + os.environ[BackboneFeatureStore.MLI_WORKER_QUEUE] = to_worker_fli_comm_ch.descriptor + os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone.descriptor arg_worker_type = cloudpickle.loads( base64.b64decode(args.worker_class.encode("ascii")) ) - dfs = DragonFeatureStore(ddict) - comm_channel = DragonFLIChannel(to_worker_fli_serialized) - - descriptor = base64.b64encode(to_worker_fli_serialized).decode("utf-8") - os.environ["_SMARTSIM_REQUEST_QUEUE"] = descriptor - config_loader = EnvironmentConfigLoader( featurestore_factory=DragonFeatureStore.from_descriptor, - callback_factory=DragonCommChannel, + callback_factory=DragonCommChannel.from_descriptor, queue_factory=DragonFLIChannel.from_descriptor, ) @@ -178,7 +169,7 @@ def service_as_dragon_proc( worker_device = args.device for wm_idx in range(args.num_workers): - worker_manager = WorkerManager( + worker_manager = WorkerManager( config_loader=config_loader, worker_type=arg_worker_type, as_service=True, @@ -196,21 +187,25 @@ def service_as_dragon_proc( # the GPU-to-CPU mapping is taken from the nvidia-smi tool # TODO can this be computed on the fly? gpu_to_cpu_aff: dict[int, list[int]] = {} - gpu_to_cpu_aff[0] = list(range(48,64)) + list(range(112,128)) - gpu_to_cpu_aff[1] = list(range(32,48)) + list(range(96,112)) - gpu_to_cpu_aff[2] = list(range(16,32)) + list(range(80,96)) - gpu_to_cpu_aff[3] = list(range(0,16)) + list(range(64,80)) + gpu_to_cpu_aff[0] = list(range(48, 64)) + list(range(112, 128)) + gpu_to_cpu_aff[1] = list(range(32, 48)) + list(range(96, 112)) + gpu_to_cpu_aff[2] = list(range(16, 32)) + list(range(80, 96)) + gpu_to_cpu_aff[3] = list(range(0, 16)) + list(range(64, 80)) worker_manager_procs = [] for worker_idx in range(args.num_workers): wm_cpus = len(gpu_to_cpu_aff[worker_idx]) - 4 wm_affinity = gpu_to_cpu_aff[worker_idx][:wm_cpus] disp_affinity.extend(gpu_to_cpu_aff[worker_idx][wm_cpus:]) - worker_manager_procs.append(service_as_dragon_proc( + worker_manager_procs.append( + service_as_dragon_proc( worker_manager, cpu_affinity=wm_affinity, gpu_affinity=[worker_idx] - )) + ) + ) - dispatcher_proc = service_as_dragon_proc(dispatcher, cpu_affinity=disp_affinity, gpu_affinity=[]) + dispatcher_proc = service_as_dragon_proc( + dispatcher, cpu_affinity=disp_affinity, gpu_affinity=[] + ) # TODO: use ProcessGroup and restart=True? 
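+    # NOTE: per the TODO above, a Dragon ProcessGroup with restart enabled could
+    # manage the dispatcher and worker managers as a unit; for now each service
+    # is wrapped in its own process via service_as_dragon_proc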
all_procs = [dispatcher_proc, *worker_manager_procs] diff --git a/smartsim/_core/_cli/scripts/dragon_install.py b/smartsim/_core/_cli/scripts/dragon_install.py index 4fd0be3004..b6666f7c8e 100644 --- a/smartsim/_core/_cli/scripts/dragon_install.py +++ b/smartsim/_core/_cli/scripts/dragon_install.py @@ -57,7 +57,7 @@ def __init__( def _check(self) -> None: """Perform validation of this instance - :raises: ValueError if any value fails validation""" + :raises ValueError: if any value fails validation""" if not self.repo_name or len(self.repo_name.split("/")) != 2: raise ValueError( f"Invalid dragon repository name. Example: `dragonhpc/dragon`" @@ -95,13 +95,13 @@ def get_auth_token(request: DragonInstallRequest) -> t.Optional[Token]: def create_dotenv(dragon_root_dir: pathlib.Path, dragon_version: str) -> None: """Create a .env file with required environment variables for the Dragon runtime""" dragon_root = str(dragon_root_dir) - dragon_inc_dir = str(dragon_root_dir / "include") - dragon_lib_dir = str(dragon_root_dir / "lib") - dragon_bin_dir = str(dragon_root_dir / "bin") + dragon_inc_dir = dragon_root + "/include" + dragon_lib_dir = dragon_root + "/lib" + dragon_bin_dir = dragon_root + "/bin" dragon_vars = { "DRAGON_BASE_DIR": dragon_root, - "DRAGON_ROOT_DIR": dragon_root, # note: same as base_dir + "DRAGON_ROOT_DIR": dragon_root, "DRAGON_INCLUDE_DIR": dragon_inc_dir, "DRAGON_LIB_DIR": dragon_lib_dir, "DRAGON_VERSION": dragon_version, @@ -286,7 +286,7 @@ def retrieve_asset( :param request: details of a request for the installation of the dragon package :param asset: GitHub release asset to retrieve :returns: path to the directory containing the extracted release asset - :raises: SmartSimCLIActionCancelled if the asset cannot be downloaded or extracted + :raises SmartSimCLIActionCancelled: if the asset cannot be downloaded or extracted """ download_dir = request.working_dir / str(asset.id) diff --git a/smartsim/_core/entrypoints/service.py b/smartsim/_core/entrypoints/service.py index 6b4ef74b67..719c2a60fe 100644 --- a/smartsim/_core/entrypoints/service.py +++ b/smartsim/_core/entrypoints/service.py @@ -35,26 +35,50 @@ class Service(ABC): - """Base contract for standalone entrypoint scripts. Defines API for entrypoint - behaviors (event loop, automatic shutdown, cooldown) as well as simple - hooks for status changes""" + """Core API for standalone entrypoint scripts. Makes use of overridable hook + methods to modify behaviors (event loop, automatic shutdown, cooldown) as + well as simple hooks for status changes""" def __init__( - self, as_service: bool = False, cooldown: int = 0, loop_delay: int = 0 + self, + as_service: bool = False, + cooldown: float = 0, + loop_delay: float = 0, + health_check_frequency: float = 0, ) -> None: - """Initialize the ServiceHost - :param as_service: Determines if the host will run until shutdown criteria - are met or as a run-once instance - :param cooldown: Period of time to allow service to run before automatic - shutdown, in seconds. A non-zero, positive integer. - :param loop_delay: delay between iterations of the event loop""" + """Initialize the Service + + :param as_service: Determines lifetime of the service. When `True`, calling + execute on the service will run continuously until shutdown criteria are met. + Otherwise, `execute` performs a single pass through the service lifecycle and + automatically exits (regardless of the result of `_can_shutdown`). 
+ :param cooldown: Period of time (in seconds) to allow the service to run + after a shutdown is permitted. Enables the service to avoid restarting if + new work is discovered. A value of 0 disables the cooldown. + :param loop_delay: Duration (in seconds) of a forced delay between + iterations of the event loop + :param health_check_frequency: Time (in seconds) between calls to a + health check handler. A value of 0 triggers the health check on every + iteration. + """ self._as_service = as_service - """If the service should run until shutdown function returns True""" + """Determines lifetime of the service. When `True`, calling + `execute` on the service will run continuously until shutdown criteria are met. + Otherwise, `execute` performs a single pass through the service lifecycle and + automatically exits (regardless of the result of `_can_shutdown`).""" self._cooldown = abs(cooldown) - """Duration of a cooldown period between requests to the service - before shutdown""" + """Period of time (in seconds) to allow the service to run + after a shutdown is permitted. Enables the service to avoid restarting if + new work is discovered. A value of 0 disables the cooldown.""" self._loop_delay = abs(loop_delay) - """Forced delay between iterations of the event loop""" + """Duration (in seconds) of a forced delay between + iterations of the event loop""" + self._health_check_frequency = health_check_frequency + """Time (in seconds) between calls to a + health check handler. A value of 0 triggers the health check on every + iteration.""" + self._last_health_check = time.time() + """The timestamp of the latest health check""" @abstractmethod def _on_iteration(self) -> None: @@ -68,7 +92,7 @@ def _can_shutdown(self) -> bool: def _on_start(self) -> None: """Empty hook method for use by subclasses. Called on initial entry into - ServiceHost `execute` event loop before `_on_iteration` is invoked.""" + Service `execute` event loop before `_on_iteration` is invoked.""" logger.debug(f"Starting {self.__class__.__name__}") def _on_shutdown(self) -> None: @@ -76,6 +100,11 @@ def _on_shutdown(self) -> None: the main event loop during automatic shutdown.""" logger.debug(f"Shutting down {self.__class__.__name__}") + def _on_health_check(self) -> None: + """Empty hook method for use by subclasses. Invoked based on the + value of `self._health_check_frequency`.""" + logger.debug(f"Performing health check for {self.__class__.__name__}") + def _on_cooldown_elapsed(self) -> None: """Empty hook method for use by subclasses. Called on every event loop iteration immediately upon exceeding the cooldown period""" @@ -98,13 +127,30 @@ def execute(self) -> None: """The main event loop of a service host. Evaluates shutdown criteria and combines with a cooldown period to allow automatic service termination. 
Responsible for executing calls to subclass implementation of `_on_iteration`""" - self._on_start() + + try: + self._on_start() + except Exception: + logger.exception("Unable to start service.") + return running = True cooldown_start: t.Optional[datetime.datetime] = None while running: - self._on_iteration() + try: + self._on_iteration() + except Exception: + running = False + logger.exception( + "Failure in event loop resulted in service termination" + ) + + if self._health_check_frequency >= 0: + hc_elapsed = time.time() - self._last_health_check + if hc_elapsed >= self._health_check_frequency: + self._on_health_check() + self._last_health_check = time.time() # allow immediate shutdown if not set to run as a service if not self._as_service: @@ -133,4 +179,7 @@ def execute(self) -> None: self._on_delay() time.sleep(self._loop_delay) - self._on_shutdown() + try: + self._on_shutdown() + except Exception: + logger.exception("Service shutdown may not have completed.") diff --git a/smartsim/_core/launcher/dragon/dragonBackend.py b/smartsim/_core/launcher/dragon/dragonBackend.py index 7526af14ad..5e01299141 100644 --- a/smartsim/_core/launcher/dragon/dragonBackend.py +++ b/smartsim/_core/launcher/dragon/dragonBackend.py @@ -26,6 +26,8 @@ import collections import functools import itertools +import os +import socket import time import typing as t from dataclasses import dataclass, field @@ -34,18 +36,26 @@ from tabulate import tabulate -# pylint: disable=import-error +# pylint: disable=import-error,C0302,R0915 # isort: off -import dragon.data.ddict.ddict as dragon_ddict + import dragon.infrastructure.connection as dragon_connection import dragon.infrastructure.policy as dragon_policy import dragon.infrastructure.process_desc as dragon_process_desc -import dragon.native.group_state as dragon_group_state + import dragon.native.process as dragon_process import dragon.native.process_group as dragon_process_group import dragon.native.machine as dragon_machine from smartsim._core.launcher.dragon.pqueue import NodePrioritizer, PrioritizerFilter +from smartsim._core.mli.infrastructure.control.listener import ( + ConsumerRegistrationListener, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.dragon_util import create_ddict +from smartsim.error.errors import SmartSimError # pylint: enable=import-error # isort: on @@ -72,8 +82,8 @@ class DragonStatus(str, Enum): - ERROR = str(dragon_group_state.Error()) - RUNNING = str(dragon_group_state.Running()) + ERROR = "Error" + RUNNING = "Running" def __str__(self) -> str: return self.value @@ -90,7 +100,7 @@ class ProcessGroupInfo: return_codes: t.Optional[t.List[int]] = None """List of return codes of completed processes""" hosts: t.List[str] = field(default_factory=list) - """List of hosts on which the Process Group """ + """List of hosts on which the Process Group should be executed""" redir_workers: t.Optional[dragon_process_group.ProcessGroup] = None """Workers used to redirect stdout and stderr to file""" @@ -147,6 +157,11 @@ class DragonBackend: by threads spawned by it. 
""" + _DEFAULT_NUM_MGR_PER_NODE = 2 + """The default number of manager processes for each feature store node""" + _DEFAULT_MEM_PER_NODE = 512 * 1024**2 + """The default memory capacity (in bytes) to allocate for a feaure store node""" + def __init__(self, pid: int) -> None: self._pid = pid """PID of dragon executable which launched this server""" @@ -180,14 +195,12 @@ def __init__(self, pid: int) -> None: """Whether the server frontend should shut down when the backend does""" self._shutdown_initiation_time: t.Optional[float] = None """The time at which the server initiated shutdown""" - smartsim_config = get_config() - self._cooldown_period = ( - smartsim_config.telemetry_frequency * 2 + 5 - if smartsim_config.telemetry_enabled - else 5 - ) - """Time in seconds needed to server to complete shutdown""" - self._infra_ddict: t.Optional[dragon_ddict.DDict] = None + self._cooldown_period = self._initialize_cooldown() + """Time in seconds needed by the server to complete shutdown""" + self._backbone: t.Optional[BackboneFeatureStore] = None + """The backbone feature store""" + self._listener: t.Optional[dragon_process.Process] = None + """The standalone process executing the event consumer""" self._nodes: t.List["dragon_machine.Node"] = [] """Node capability information for hosts in the allocation""" @@ -201,8 +214,6 @@ def __init__(self, pid: int) -> None: """Mapping with hostnames as keys and a set of running step IDs as the value""" self._initialize_hosts() - self._view = DragonBackendView(self) - logger.debug(self._view.host_desc) self._prioritizer = NodePrioritizer(self._nodes, self._queue_lock) @property @@ -254,12 +265,11 @@ def status_message(self) -> str: :returns: a status message """ - return ( - "Dragon server backend update\n" - f"{self._view.host_table}\n{self._view.step_table}" - ) + view = DragonBackendView(self) + return "Dragon server backend update\n" f"{view.host_table}\n{view.step_table}" def _heartbeat(self) -> None: + """Update the value of the last heartbeat to the current time.""" self._last_beat = self.current_time @property @@ -539,21 +549,83 @@ def _stop_steps(self) -> None: self._group_infos[step_id].status = SmartSimStatus.STATUS_CANCELLED self._group_infos[step_id].return_codes = [-9] - @property - def infra_ddict(self) -> str: - """Create a Dragon distributed dictionary and return its - serialized descriptor + def _create_backbone(self) -> BackboneFeatureStore: + """ + Creates a BackboneFeatureStore if one does not exist. Updates + environment variables of this process to include the backbone + descriptor. + + :returns: The backbone feature store + """ + if self._backbone is None: + backbone_storage = create_ddict( + len(self._hosts), + self._DEFAULT_NUM_MGR_PER_NODE, + self._DEFAULT_MEM_PER_NODE, + ) + + self._backbone = BackboneFeatureStore( + backbone_storage, allow_reserved_writes=True + ) + + # put the backbone descriptor in the env vars + os.environ.update(self._backbone.get_env()) + + return self._backbone + + @staticmethod + def _initialize_cooldown() -> int: + """Load environment configuration and determine the correct cooldown + period to apply to the backend process. + + :returns: The calculated cooldown (in seconds) + """ + smartsim_config = get_config() + return ( + smartsim_config.telemetry_frequency * 2 + 5 + if smartsim_config.telemetry_enabled + else 5 + ) + + def start_event_listener( + self, cpu_affinity: list[int], gpu_affinity: list[int] + ) -> dragon_process.Process: + """Start a standalone event listener. 
+ + :param cpu_affinity: The CPU affinity for the process + :param gpu_affinity: The GPU affinity for the process + :returns: The dragon Process managing the listener process + :raises SmartSimError: If the backbone is not provided """ - if self._infra_ddict is None: - logger.info("Creating DDict") - self._infra_ddict = dragon_ddict.DDict( - n_nodes=len(self._hosts), total_mem=len(self._hosts) * 1024**3 - ) # todo: parametrize - logger.info("Created DDict") - self._infra_ddict["creation"] = str(time.time()) - logger.info(self._infra_ddict["creation"]) + if self._backbone is None: + raise SmartSimError("Backbone feature store is not available") - return str(self._infra_ddict.serialize()) + service = ConsumerRegistrationListener( + self._backbone, 1.0, 2.0, as_service=True, health_check_frequency=90 + ) + + options = dragon_process_desc.ProcessOptions(make_inf_channels=True) + local_policy = dragon_policy.Policy( + placement=dragon_policy.Policy.Placement.HOST_NAME, + host_name=socket.gethostname(), + cpu_affinity=cpu_affinity, + gpu_affinity=gpu_affinity, + ) + process = dragon_process.Process( + target=service.execute, + args=[], + cwd=os.getcwd(), + env={ + **os.environ, + **self._backbone.get_env(), + }, + policy=local_policy, + options=options, + stderr=dragon_process.Popen.STDOUT, + stdout=dragon_process.Popen.STDOUT, + ) + process.start() + return process @staticmethod def create_run_policy( @@ -595,7 +667,9 @@ def create_run_policy( ) def _start_steps(self) -> None: + """Start all new steps created since the last update.""" self._heartbeat() + with self._queue_lock: started = [] for step_id, request in self._queued_steps.items(): @@ -622,7 +696,7 @@ def _start_steps(self) -> None: env={ **request.current_env, **request.env, - "_SMARTSIM_INFRA_BACKBONE": self.infra_ddict, + **(self._backbone.get_env() if self._backbone else {}), }, stdout=dragon_process.Popen.PIPE, stderr=dragon_process.Popen.PIPE, @@ -758,6 +832,9 @@ def _refresh_statuses(self) -> None: group_info.redir_workers = None def _update_shutdown_status(self) -> None: + """Query the status of running tasks and update the status + of any that have completed. + """ self._heartbeat() with self._queue_lock: self._can_shutdown |= ( @@ -771,6 +848,9 @@ def _update_shutdown_status(self) -> None: ) def _should_print_status(self) -> bool: + """Determine if status messages should be printed based on the last + update. Returns `True` to trigger prints, `False` otherwise.
+ """ if self.current_time - self._last_update_time > 10: self._last_update_time = self.current_time return True @@ -778,6 +858,8 @@ def _should_print_status(self) -> bool: def _update(self) -> None: """Trigger all update queries and update local state database""" + self._create_backbone() + self._stop_steps() self._start_steps() self._refresh_statuses() @@ -785,6 +867,9 @@ def _update(self) -> None: def _kill_all_running_jobs(self) -> None: with self._queue_lock: + if self._listener and self._listener.is_alive: + self._listener.kill() + for step_id, group_info in self._group_infos.items(): if group_info.status not in TERMINAL_STATUSES: self._stop_requests.append(DragonStopRequest(step_id=step_id)) @@ -872,6 +957,8 @@ def __init__(self, backend: DragonBackend) -> None: self._backend = backend """A dragon backend used to produce the view""" + logger.debug(self.host_desc) + @property def host_desc(self) -> str: hosts = self._backend.hosts diff --git a/smartsim/_core/launcher/dragon/dragonConnector.py b/smartsim/_core/launcher/dragon/dragonConnector.py index 0cd68c24e9..1144b7764e 100644 --- a/smartsim/_core/launcher/dragon/dragonConnector.py +++ b/smartsim/_core/launcher/dragon/dragonConnector.py @@ -71,17 +71,23 @@ class DragonConnector: def __init__(self) -> None: self._context: zmq.Context[t.Any] = zmq.Context.instance() + """ZeroMQ context used to share configuration across requests""" self._context.setsockopt(zmq.REQ_CORRELATE, 1) self._context.setsockopt(zmq.REQ_RELAXED, 1) self._authenticator: t.Optional[zmq.auth.thread.ThreadAuthenticator] = None + """ZeroMQ authenticator used to secure queue access""" config = get_config() self._reset_timeout(config.dragon_server_timeout) self._dragon_head_socket: t.Optional[zmq.Socket[t.Any]] = None + """ZeroMQ socket exposing the connection to the DragonBackend""" self._dragon_head_process: t.Optional[subprocess.Popen[bytes]] = None + """A handle to the process executing the DragonBackend""" # Returned by dragon head, useful if shutdown is to be requested # but process was started by another connector self._dragon_head_pid: t.Optional[int] = None + """Process ID of the process executing the DragonBackend""" self._dragon_server_path = config.dragon_server_path + """Path to a dragon installation""" logger.debug(f"Dragon Server path was set to {self._dragon_server_path}") self._env_vars: t.Dict[str, str] = {} if self._dragon_server_path is None: @@ -95,7 +101,7 @@ def __init__(self) -> None: @property def is_connected(self) -> bool: - """Whether the Connector established a connection to the server + """Whether the Connector established a connection to the server. :return: True if connected """ @@ -104,12 +110,18 @@ def is_connected(self) -> bool: @property def can_monitor(self) -> bool: """Whether the Connector knows the PID of the dragon server head process - and can monitor its status + and can monitor its status. :return: True if the server can be monitored""" return self._dragon_head_pid is not None def _handshake(self, address: str) -> None: + """Perform the handshake process with the DragonBackend and + confirm two-way communication is established. + + :param address: The address of the head node socket to initiate a + handhake with + """ self._dragon_head_socket = dragonSockets.get_secure_socket( self._context, zmq.REQ, False ) @@ -132,6 +144,11 @@ def _handshake(self, address: str) -> None: ) from e def _reset_timeout(self, timeout: int = get_config().dragon_server_timeout) -> None: + """Reset the timeout applied to the ZMQ context. 
If an authenticator is + enabled, also update the authenticator timeouts. + + :param timeout: The timeout value to apply to ZMQ sockets + """ self._context.setsockopt(zmq.SNDTIMEO, value=timeout) self._context.setsockopt(zmq.RCVTIMEO, value=timeout) if self._authenticator is not None and self._authenticator.thread is not None: @@ -183,11 +200,19 @@ def _get_new_authenticator( @staticmethod def _get_dragon_log_level() -> str: + """Maps the log level from SmartSim to a valid log level + for a dragon process. + + :returns: The dragon log level string + """ smartsim_to_dragon = defaultdict(lambda: "NONE") smartsim_to_dragon["developer"] = "INFO" return smartsim_to_dragon.get(get_config().log_level, "NONE") def _connect_to_existing_server(self, path: Path) -> None: + """Connects to an existing DragonBackend using address information from + a persisted dragon log file. + + :param path: Path to the directory containing the dragon log file + """ config = get_config() dragon_config_log = path / config.dragon_log_filename @@ -217,6 +242,11 @@ def _connect_to_existing_server(self, path: Path) -> None: return def _start_connector_socket(self, socket_addr: str) -> zmq.Socket[t.Any]: + """Instantiate the ZMQ socket to be used by the connector. + + :param socket_addr: The socket address the connector should bind to + :returns: The bound socket + """ config = get_config() connector_socket: t.Optional[zmq.Socket[t.Any]] = None self._reset_timeout(config.dragon_server_startup_timeout) @@ -245,9 +275,14 @@ def load_persisted_env(self) -> t.Dict[str, str]: with open(config.dragon_dotenv, encoding="utf-8") as dot_env: for kvp in dot_env.readlines(): - split = kvp.strip().split("=", maxsplit=1) - key, value = split[0], split[-1] - self._env_vars[key] = value + if not kvp: + continue + + # skip any commented lines + if not kvp.startswith("#"): + split = kvp.strip().split("=", maxsplit=1) + key, value = split[0], split[-1] + self._env_vars[key] = value return self._env_vars @@ -418,6 +453,15 @@ def send_request(self, request: DragonRequest, flags: int = 0) -> DragonResponse @staticmethod def _parse_launched_dragon_server_info_from_iterable( stream: t.Iterable[str], num_dragon_envs: t.Optional[int] = None ) -> t.List[t.Dict[str, str]]: + """Parses dragon backend connection information from a stream. + + :param stream: The stream to inspect. Usually the stdout of the + DragonBackend process + :param num_dragon_envs: The expected number of dragon environments + to parse from the stream. + :returns: A list of dictionaries, one per environment, containing + the parsed server information + """ lines = (line.strip() for line in stream) lines = (line for line in lines if line) tokenized = (line.split(maxsplit=1) for line in lines) @@ -444,6 +488,15 @@ def _parse_launched_dragon_server_info_from_files( file_paths: t.List[t.Union[str, "os.PathLike[str]"]], num_dragon_envs: t.Optional[int] = None, ) -> t.List[t.Dict[str, str]]: + """Read a known log file into a Stream and parse dragon server configuration + from the stream. + + :param file_paths: Path to a file containing dragon server configuration + :param num_dragon_envs: The expected number of dragon environments to be found + in the file + :returns: The parsed server configuration, one item per + discovered dragon environment + """ with fileinput.FileInput(file_paths) as ifstream: dragon_envs = cls._parse_launched_dragon_server_info_from_iterable( ifstream, num_dragon_envs @@ -458,6 +511,15 @@ def _send_req_with_socket( send_flags: int = 0, recv_flags: int = 0, ) -> DragonResponse: + """Sends a synchronous request through a ZMQ socket.
+ + :param socket: Socket to send on + :param request: The request to send + :param send_flags: Configuration to apply to the send operation + :param recv_flags: Configuration to apply to the recv operation; used to + allow the receiver to immediately respond to the sent request. + :returns: The response from the target + """ client = dragonSockets.as_client(socket) with DRG_LOCK: logger.debug(f"Sending {type(request).__name__}: {request}") @@ -469,6 +531,13 @@ def _send_req_with_socket( def _assert_schema_type(obj: object, typ: t.Type[_SchemaT], /) -> _SchemaT: + """Verify that objects can be sent as messages acceptable to the target. + + :param obj: The message to test + :param typ: The type that is acceptable + :returns: The original `obj` if it is of the requested type + :raises TypeError: If the object fails the test and is not + an instance of the desired type""" if not isinstance(obj, typ): raise TypeError(f"Expected schema of type `{typ}`, but got {type(obj)}") return obj @@ -520,6 +589,12 @@ def _dragon_cleanup( def _resolve_dragon_path(fallback: t.Union[str, "os.PathLike[str]"]) -> Path: + """Determine the applicable dragon server path for the connector + + :param fallback: A default dragon server path to use if one is not + found in the runtime configuration + :returns: The path to the dragon libraries + """ dragon_server_path = get_config().dragon_server_path or os.path.join( fallback, ".smartsim", "dragon" ) diff --git a/smartsim/_core/mli/client/__init__.py b/smartsim/_core/mli/client/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/smartsim/_core/mli/client/protoclient.py b/smartsim/_core/mli/client/protoclient.py new file mode 100644 index 0000000000..46598a8171 --- /dev/null +++ b/smartsim/_core/mli/client/protoclient.py @@ -0,0 +1,348 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
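+
+# Example usage (a minimal sketch; it assumes an MLI deployment in which the
+# backbone feature store and worker queue were already created, e.g. by
+# `standalone_worker_manager.py`; `model_bytes` and `batch` are hypothetical
+# placeholders for a serialized TorchScript model and a `torch.Tensor` input):
+#
+#     from smartsim._core.mli.client.protoclient import ProtoClient
+#
+#     client = ProtoClient(timing_on=False)
+#     client.set_model("resnet-50", model_bytes)
+#     result = client.run_model("resnet-50", batch)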
+ +# isort: off +# pylint: disable=unused-import,import-error +import dragon +import dragon.channels +from dragon.globalservices.api_setup import connect_to_infrastructure + +try: + from mpi4py import MPI # type: ignore[import-not-found] +except Exception: + MPI = None + print("Unable to import `mpi4py` package") + +# isort: on +# pylint: enable=unused-import,import-error + +import numbers +import os +import time +import typing as t +from collections import OrderedDict + +import numpy +import torch + +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster +from smartsim._core.mli.infrastructure.comm.event import OnWriteFeatureStore +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.message_handler import MessageHandler +from smartsim._core.utils.timings import PerfTimer +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +_TimingDict = OrderedDict[str, list[str]] + + +logger = get_logger("App") +logger.info("Started app") +CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False + + +class ProtoClient: + """Proof of concept implementation of a client enabling user applications + to interact with MLI resources.""" + + _DEFAULT_BACKBONE_TIMEOUT = 1.0 + """A default timeout period applied to connection attempts with the + backbone feature store.""" + + _DEFAULT_WORK_QUEUE_SIZE = 500 + """A default number of events to be buffered in the work queue before + triggering QueueFull exceptions.""" + + _EVENT_SOURCE = "proto-client" + """A user-friendly name for this class instance to identify + the client as the publisher of an event.""" + + @staticmethod + def _attach_to_backbone() -> BackboneFeatureStore: + """Use the supplied environment variables to attach + to a pre-existing backbone featurestore. Requires the + environment to contain `_SMARTSIM_INFRA_BACKBONE` + environment variable. + + :returns: The attached backbone featurestore + :raises SmartSimError: If the backbone descriptor is not contained + in the appropriate environment variable + """ + descriptor = os.environ.get(BackboneFeatureStore.MLI_BACKBONE, None) + if descriptor is None or not descriptor: + raise SmartSimError( + "Missing required backbone configuration in environment: " + f"{BackboneFeatureStore.MLI_BACKBONE}" + ) + + backbone = t.cast( + BackboneFeatureStore, BackboneFeatureStore.from_descriptor(descriptor) + ) + return backbone + + def _attach_to_worker_queue(self) -> DragonFLIChannel: + """Wait until the backbone contains the worker queue configuration, + then attach an FLI to the given worker queue. + + :returns: The attached FLI channel + :raises SmartSimError: if the required configuration is not found in the + backbone feature store + """ + + descriptor = "" + try: + # NOTE: without wait_for, this MUST be in the backbone.... + config = self._backbone.wait_for( + [BackboneFeatureStore.MLI_WORKER_QUEUE], self.backbone_timeout + ) + descriptor = str(config[BackboneFeatureStore.MLI_WORKER_QUEUE]) + except Exception as ex: + logger.info( + f"Unable to retrieve {BackboneFeatureStore.MLI_WORKER_QUEUE} " + "to attach to the worker queue." 
+ ) + raise SmartSimError("Unable to locate worker queue using backbone") from ex + + return DragonFLIChannel.from_descriptor(descriptor) + + def _create_broadcaster(self) -> EventBroadcaster: + """Create an EventBroadcaster that broadcasts events to + all MLI components registered to consume them. + + :returns: An EventBroadcaster instance + """ + broadcaster = EventBroadcaster( + self._backbone, DragonCommChannel.from_descriptor + ) + return broadcaster + + def __init__( + self, + timing_on: bool, + backbone_timeout: float = _DEFAULT_BACKBONE_TIMEOUT, + ) -> None: + """Initialize the client instance. + + :param timing_on: Flag indicating if timing information should be + written to file + :param backbone_timeout: Maximum wait time (in seconds) allowed to attach to the + worker queue + :raises SmartSimError: If unable to attach to a backbone featurestore + :raises ValueError: If an invalid backbone timeout is specified + """ + if MPI is not None: + # TODO: determine a way to make MPI work in the test environment + # - consider catching the import exception and defaulting rank to 0 + comm = MPI.COMM_WORLD + rank: int = comm.Get_rank() + else: + rank = 0 + + if backbone_timeout <= 0: + raise ValueError( + f"Invalid backbone timeout provided: {backbone_timeout}. " + "The value must be greater than zero." + ) + self._backbone_timeout = max(backbone_timeout, 0.1) + + connect_to_infrastructure() + + self._backbone = self._attach_to_backbone() + self._backbone.wait_timeout = self.backbone_timeout + self._to_worker_fli = self._attach_to_worker_queue() + + self._from_worker_ch = create_local(self._DEFAULT_WORK_QUEUE_SIZE) + self._to_worker_ch = create_local(self._DEFAULT_WORK_QUEUE_SIZE) + + self._publisher = self._create_broadcaster() + + self.perf_timer: PerfTimer = PerfTimer( + debug=False, timing_on=timing_on, prefix=f"a{rank}_" + ) + self._start: t.Optional[float] = None + self._interm: t.Optional[float] = None + self._timings: _TimingDict = OrderedDict() + self._timing_on = timing_on + + @property + def backbone_timeout(self) -> float: + """The timeout (in seconds) applied to retrievals + from the backbone feature store. + + :returns: A float indicating the number of seconds to allow""" + return self._backbone_timeout + + def _add_label_to_timings(self, label: str) -> None: + """Adds a new label into the timing dictionary to prepare for + receiving timing events. + + :param label: The label to create storage for + """ + if label not in self._timings: + self._timings[label] = [] + + @staticmethod + def _format_number(number: t.Union[numbers.Number, float]) -> str: + """Utility function for formatting numbers consistently for logs. + + :param number: The number to convert to a formatted string + :returns: The formatted string containing the number + """ + return f"{number:0.4e}" + + def start_timings(self, batch_size: numbers.Number) -> None: + """Configure the client to begin storing timing information. 
+ + :param batch_size: The size of batches to generate as inputs + to the model + """ + if self._timing_on: + self._add_label_to_timings("batch_size") + self._timings["batch_size"].append(self._format_number(batch_size)) + self._start = time.perf_counter() + self._interm = time.perf_counter() + + def end_timings(self) -> None: + """Configure the client to stop storing timing information.""" + if self._timing_on and self._start is not None: + self._add_label_to_timings("total_time") + self._timings["total_time"].append( + self._format_number(time.perf_counter() - self._start) + ) + + def measure_time(self, label: str) -> None: + """Measures elapsed time since the last recorded signal. + + :param label: The label to measure time for + """ + if self._timing_on and self._interm is not None: + self._add_label_to_timings(label) + self._timings[label].append( + self._format_number(time.perf_counter() - self._interm) + ) + self._interm = time.perf_counter() + + def print_timings(self, to_file: bool = False) -> None: + """Print timing information to standard output. If `to_file` + is `True`, also write results to a file. + + :param to_file: If `True`, also saves timing information + to the files `timings.npy` and `timings.txt` + """ + print(" ".join(self._timings.keys())) + + value_array = numpy.array(list(self._timings.values()), dtype=float) + value_array = numpy.transpose(value_array) + for i in range(value_array.shape[0]): + print(" ".join(self._format_number(value) for value in value_array[i])) + if to_file: + numpy.save("timings.npy", value_array) + numpy.savetxt("timings.txt", value_array) + + def run_model(self, model: t.Union[bytes, str], batch: torch.Tensor) -> t.Any: + """Execute a batch of inference requests with the supplied ML model. + + :param model: The raw bytes or path to a PyTorch model + :param batch: The tensor batch to perform inference on + :returns: The inference results + :raises ValueError: if the worker queue is not configured properly + in the environment variables + """ + tensors = [batch.numpy()] + self.perf_timer.start_timings("batch_size", batch.shape[0]) + built_tensor_desc = MessageHandler.build_tensor_descriptor( + "c", "float32", list(batch.shape) + ) + self.perf_timer.measure_time("build_tensor_descriptor") + if isinstance(model, str): + model_arg = MessageHandler.build_model_key(model, self._backbone.descriptor) + else: + model_arg = MessageHandler.build_model( + model, "resnet-50", "1.0" + ) # type: ignore + request = MessageHandler.build_request( + reply_channel=self._from_worker_ch.descriptor, + model=model_arg, + inputs=[built_tensor_desc], + outputs=[], + output_descriptors=[], + custom_attributes=None, + ) + self.perf_timer.measure_time("build_request") + request_bytes = MessageHandler.serialize_request(request) + self.perf_timer.measure_time("serialize_request") + + if self._to_worker_fli is None: + raise ValueError("No worker queue available.") + + # pylint: disable-next=protected-access + with self._to_worker_fli._channel.sendh( # type: ignore + timeout=None, + stream_channel=self._to_worker_ch.channel, + ) as to_sendh: + to_sendh.send_bytes(request_bytes) + self.perf_timer.measure_time("send_request") + for tensor in tensors: + to_sendh.send_bytes(tensor.tobytes()) # TODO NOT FAST ENOUGH!!!
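+            # the serialized request is streamed first, followed by one raw
+            # payload per tensor; the worker is expected to read the request
+            # and then receive each tensor from the same stream channel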
+ logger.info(f"Message size: {len(request_bytes)} bytes") + + self.perf_timer.measure_time("send_tensors") + with self._from_worker_ch.channel.recvh(timeout=None) as from_recvh: + resp = from_recvh.recv_bytes(timeout=None) + self.perf_timer.measure_time("receive_response") + response = MessageHandler.deserialize_response(resp) + self.perf_timer.measure_time("deserialize_response") + + # recv depending on the len(response.result.descriptors)? + data_blob: bytes = from_recvh.recv_bytes(timeout=None) + self.perf_timer.measure_time("receive_tensor") + result = torch.from_numpy( + numpy.frombuffer( + data_blob, + dtype=str(response.result.descriptors[0].dataType), + ) + ) + self.perf_timer.measure_time("deserialize_tensor") + + self.perf_timer.end_timings() + return result + + def set_model(self, key: str, model: bytes) -> None: + """Write the supplied model to the feature store. + + :param key: The unique key used to identify the model + :param model: The raw bytes of the model to execute + """ + self._backbone[key] = model + + # notify components of a change in the data at this key + event = OnWriteFeatureStore(self._EVENT_SOURCE, self._backbone.descriptor, key) + self._publisher.send(event) diff --git a/smartsim/_core/mli/comm/channel/channel.py b/smartsim/_core/mli/comm/channel/channel.py index 9a12e4c8dc..104333ce7f 100644 --- a/smartsim/_core/mli/comm/channel/channel.py +++ b/smartsim/_core/mli/comm/channel/channel.py @@ -26,6 +26,7 @@ import base64 import typing as t +import uuid from abc import ABC, abstractmethod from smartsim.log import get_logger @@ -36,24 +37,31 @@ class CommChannelBase(ABC): """Base class for abstracting a message passing mechanism""" - def __init__(self, descriptor: t.Union[str, bytes]) -> None: + def __init__( + self, + descriptor: str, + name: t.Optional[str] = None, + ) -> None: """Initialize the CommChannel instance. :param descriptor: Channel descriptor """ self._descriptor = descriptor + """An opaque identifier used to connect to an underlying communication channel""" + self._name = name or str(uuid.uuid4()) + """A user-friendly identifier for channel-related logging""" @abstractmethod - def send(self, value: bytes, timeout: float = 0) -> None: + def send(self, value: bytes, timeout: float = 0.001) -> None: """Send a message through the underlying communication channel. - :param timeout: Maximum time to wait (in seconds) for messages to send :param value: The value to send + :param timeout: Maximum time to wait (in seconds) for messages to send :raises SmartSimError: If sending message fails """ @abstractmethod - def recv(self, timeout: float = 0) -> t.List[bytes]: + def recv(self, timeout: float = 0.001) -> t.List[bytes]: """Receives message(s) through the underlying communication channel. :param timeout: Maximum time to wait (in seconds) for messages to arrive @@ -61,11 +69,14 @@ def recv(self, timeout: float = 0) -> t.List[bytes]: """ @property - def descriptor(self) -> bytes: + def descriptor(self) -> str: """Return the channel descriptor for the underlying dragon channel. 
- :returns: Byte encoded channel descriptor + :returns: String-encoded channel descriptor """ - if isinstance(self._descriptor, str): - return base64.b64decode(self._descriptor.encode("utf-8")) return self._descriptor + + def __str__(self) -> str: + """Build a string representation of the channel useful for printing.""" + classname = type(self).__name__ + return f"{classname}('{self._name}', '{self._descriptor}')" diff --git a/smartsim/_core/mli/comm/channel/dragon_channel.py b/smartsim/_core/mli/comm/channel/dragon_channel.py index 1363c0d675..110f19258a 100644 --- a/smartsim/_core/mli/comm/channel/dragon_channel.py +++ b/smartsim/_core/mli/comm/channel/dragon_channel.py @@ -24,65 +24,17 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import base64 -import sys import typing as t import dragon.channels as dch -import dragon.infrastructure.facts as df -import dragon.infrastructure.parameters as dp -import dragon.managed_memory as dm -import dragon.utils as du import smartsim._core.mli.comm.channel.channel as cch +import smartsim._core.mli.comm.channel.dragon_util as drg_util from smartsim.error.errors import SmartSimError from smartsim.log import get_logger logger = get_logger(__name__) -import dragon.channels as dch - -DEFAULT_CHANNEL_BUFFER_SIZE = 500 -"""Maximum number of messages that can be buffered. DragonCommChannel will -raise an exception if no clients consume messages before the buffer is filled.""" - - -def create_local(capacity: int = 0) -> dch.Channel: - """Creates a Channel attached to the local memory pool. - - :param capacity: The number of events the channel can buffer; uses the default - buffer size `DEFAULT_CHANNEL_BUFFER_SIZE` when not supplied - :returns: The instantiated channel - :raises SmartSimError: If unable to attach local channel - """ - pool = dm.MemoryPool.attach(du.B64.str_to_bytes(dp.this_process.default_pd)) - channel: t.Optional[dch.Channel] = None - offset = 0 - - capacity = capacity if capacity > 0 else DEFAULT_CHANNEL_BUFFER_SIZE - - while not channel: - # search for an open channel ID - offset += 1 - cid = df.BASE_USER_MANAGED_CUID + offset - try: - channel = dch.Channel( - mem_pool=pool, - c_uid=cid, - capacity=capacity, - ) - logger.debug( - f"Channel {cid} created in pool {pool.serialize()} w/capacity {capacity}" - ) - except Exception as e: - if offset < 100: - logger.warning(f"Unable to attach to channel id {cid}.
Retrying...") - else: - logger.error(f"All attempts to attach local channel have failed") - raise SmartSimError("Failed to attach local channel") from e - - return channel - class DragonCommChannel(cch.CommChannelBase): """Passes messages by writing to a Dragon channel.""" @@ -92,10 +44,10 @@ def __init__(self, channel: "dch.Channel") -> None: :param channel: A channel to use for communications """ - serialized_ch = channel.serialize() - descriptor = base64.b64encode(serialized_ch).decode("utf-8") + descriptor = drg_util.channel_to_descriptor(channel) super().__init__(descriptor) self._channel = channel + """The underlying dragon channel used by this CommChannel for communications""" @property def channel(self) -> "dch.Channel": @@ -114,11 +66,11 @@ def send(self, value: bytes, timeout: float = 0.001) -> None: """ try: with self._channel.sendh(timeout=timeout) as sendh: - sendh.send_bytes(value) - logger.debug(f"DragonCommChannel {self.descriptor!r} sent message") + sendh.send_bytes(value, blocking=False) + logger.debug(f"DragonCommChannel {self.descriptor} sent message") except Exception as e: raise SmartSimError( - f"Error sending message: DragonCommChannel {self.descriptor!r}" + f"Error sending via DragonCommChannel {self.descriptor}" ) from e def recv(self, timeout: float = 0.001) -> t.List[bytes]: @@ -133,56 +85,43 @@ def recv(self, timeout: float = 0.001) -> t.List[bytes]: try: message_bytes = recvh.recv_bytes(timeout=timeout) messages.append(message_bytes) - logger.debug(f"DragonCommChannel {self.descriptor!r} received message") + logger.debug(f"DragonCommChannel {self.descriptor} received message") except dch.ChannelEmpty: # emptied the queue, ok to swallow this ex - logger.debug(f"DragonCommChannel exhausted: {self.descriptor!r}") - except dch.ChannelRecvTimeout as ex: - logger.debug(f"Timeout exceeded on channel.recv: {self.descriptor!r}") + logger.debug(f"DragonCommChannel exhausted: {self.descriptor}") + except dch.ChannelRecvTimeout: + logger.debug(f"Timeout exceeded on channel.recv: {self.descriptor}") return messages - @property - def descriptor_string(self) -> str: - """Return the channel descriptor for the underlying dragon channel - as a string. Automatically performs base64 encoding to ensure the - string can be used in a call to `from_descriptor`. - - :returns: String representation of channel descriptor - :raises ValueError: If unable to convert descriptor to a string - """ - if isinstance(self._descriptor, str): - return self._descriptor - - if isinstance(self._descriptor, bytes): - return base64.b64encode(self._descriptor).decode("utf-8") - - raise ValueError(f"Unable to convert channel descriptor: {self._descriptor}") - @classmethod def from_descriptor( cls, - descriptor: t.Union[bytes, str], + descriptor: str, ) -> "DragonCommChannel": """A factory method that creates an instance from a descriptor string. - :param descriptor: The descriptor that uniquely identifies the resource. Output - from `descriptor_string` is correctly encoded. + :param descriptor: The descriptor that uniquely identifies the resource. 
:returns: An attached DragonCommChannel :raises SmartSimError: If creation of comm channel fails """ try: - utf8_descriptor: t.Union[str, bytes] = descriptor - if isinstance(descriptor, str): - utf8_descriptor = descriptor.encode("utf-8") - - # todo: ensure the bytes argument and condition are removed - # after refactoring the RPC models - - actual_descriptor = base64.b64decode(utf8_descriptor) - channel = dch.Channel.attach(actual_descriptor) + channel = drg_util.descriptor_to_channel(descriptor) return DragonCommChannel(channel) except Exception as ex: raise SmartSimError( - f"Failed to create dragon comm channel: {descriptor!r}" + f"Failed to create dragon comm channel: {descriptor}" ) from ex + + @classmethod + def from_local(cls, _descriptor: t.Optional[str] = None) -> "DragonCommChannel": + """A factory method that creates a local channel instance. + + :param _descriptor: Unused placeholder + :returns: An attached DragonCommChannel""" + try: + channel = drg_util.create_local() + return DragonCommChannel(channel) + except Exception: + logger.error("Failed to create local dragon comm channel", exc_info=True) + raise diff --git a/smartsim/_core/mli/comm/channel/dragon_fli.py b/smartsim/_core/mli/comm/channel/dragon_fli.py index 84d809c8ac..5fb0790a84 100644 --- a/smartsim/_core/mli/comm/channel/dragon_fli.py +++ b/smartsim/_core/mli/comm/channel/dragon_fli.py @@ -26,19 +26,14 @@ # isort: off from dragon import fli -import dragon.channels as dch -import dragon.infrastructure.facts as df -import dragon.infrastructure.parameters as dp -import dragon.managed_memory as dm -import dragon.utils as du +from dragon.channels import Channel # isort: on -import base64 import typing as t import smartsim._core.mli.comm.channel.channel as cch -from smartsim._core.mli.comm.channel.dragon_channel import create_local +import smartsim._core.mli.comm.channel.dragon_util as drg_util from smartsim.error.errors import SmartSimError from smartsim.log import get_logger @@ -50,36 +45,70 @@ class DragonFLIChannel(cch.CommChannelBase): def __init__( self, - fli_desc: bytes, - sender_supplied: bool = True, - buffer_size: int = 0, + fli_: fli.FLInterface, + buffer_size: int = drg_util.DEFAULT_CHANNEL_BUFFER_SIZE, ) -> None: """Initialize the DragonFLIChannel instance. - :param fli_desc: The descriptor of the FLI channel to attach + :param fli_: The FLInterface to use as the underlying communications channel - :param sender_supplied: Flag indicating if the FLI uses sender-supplied streams :param buffer_size: Maximum number of sent messages that can be buffered """ - super().__init__(fli_desc) - self._fli: "fli" = fli.FLInterface.attach(fli_desc) - self._channel: t.Optional["dch"] = ( - create_local(buffer_size) if sender_supplied else None - ) + descriptor = drg_util.channel_to_descriptor(fli_) + super().__init__(descriptor) + + self._channel: t.Optional["Channel"] = None + """The underlying dragon Channel used by a sender-side DragonFLIChannel + to attach to the main FLI channel""" + + self._fli = fli_ + """The underlying dragon FLInterface used by this CommChannel for communications""" + + self._buffer_size: int = buffer_size + """Maximum number of messages that can be buffered before sending""" def send(self, value: bytes, timeout: float = 0.001) -> None: """Send a message through the underlying communication channel.
     def send(self, value: bytes, timeout: float = 0.001) -> None:
         """Send a message through the underlying communication channel.

-        :param timeout: Maximum time to wait (in seconds) for messages to send
         :param value: The value to send
+        :param timeout: Maximum time to wait (in seconds) for messages to send
         :raises SmartSimError: If sending message fails
         """
         try:
+            if self._channel is None:
+                self._channel = drg_util.create_local(self._buffer_size)
+
             with self._fli.sendh(timeout=None, stream_channel=self._channel) as sendh:
                 sendh.send_bytes(value, timeout=timeout)
-                logger.debug(f"DragonFLIChannel {self.descriptor!r} sent message")
+                logger.debug(f"DragonFLIChannel {self.descriptor} sent message")
+        except Exception as e:
+            self._channel = None
+            raise SmartSimError(
+                f"Error sending via DragonFLIChannel {self.descriptor}"
+            ) from e
+
+    def send_multiple(
+        self,
+        values: t.Sequence[bytes],
+        timeout: float = 0.001,
+    ) -> None:
+        """Send multiple messages through the underlying communication channel.
+
+        :param values: The values to send
+        :param timeout: Maximum time to wait (in seconds) for messages to send
+        :raises SmartSimError: If sending messages fails
+        """
+        try:
+            if self._channel is None:
+                self._channel = drg_util.create_local(self._buffer_size)
+
+            with self._fli.sendh(timeout=None, stream_channel=self._channel) as sendh:
+                for value in values:
+                    sendh.send_bytes(value)
+                logger.debug(
+                    f"DragonFLIChannel {self.descriptor} sent {len(values)} messages"
+                )
         except Exception as e:
+            self._channel = None
             raise SmartSimError(
-                f"Error sending message: DragonFLIChannel {self.descriptor!r}"
+                f"Error sending via DragonFLIChannel {self.descriptor}"
             ) from e

     def recv(self, timeout: float = 0.001) -> t.List[bytes]:
@@ -96,14 +125,13 @@ def recv(self, timeout: float = 0.001) -> t.List[bytes]:
             try:
                 message, _ = recvh.recv_bytes(timeout=timeout)
                 messages.append(message)
-                logger.debug(
-                    f"DragonFLIChannel {self.descriptor!r} received message"
-                )
+                logger.debug(f"DragonFLIChannel {self.descriptor} received message")
             except fli.FLIEOT:
                 eot = True
+                logger.debug(f"DragonFLIChannel exhausted: {self.descriptor}")
             except Exception as e:
                 raise SmartSimError(
-                    f"Error receiving messages: DragonFLIChannel {self.descriptor!r}"
+                    f"Error receiving messages: DragonFLIChannel {self.descriptor}"
                 ) from e
         return messages
@@ -116,13 +144,14 @@ def from_descriptor(

         :param descriptor: The descriptor that uniquely identifies the resource
         :returns: An attached DragonFLIChannel
-        :raises SmartSimError: If creation of DragonFLIChanenel fails
+        :raises SmartSimError: If creation of DragonFLIChannel fails
+        :raises ValueError: If the descriptor is invalid
         """
+        if not descriptor:
+            raise ValueError("Invalid descriptor provided")
+
         try:
-            return DragonFLIChannel(
-                fli_desc=base64.b64decode(descriptor),
-                sender_supplied=True,
-            )
+            return DragonFLIChannel(fli_=drg_util.descriptor_to_fli(descriptor))
         except Exception as e:
             raise SmartSimError(
                 f"Error while creating DragonFLIChannel: {descriptor}"
diff --git a/smartsim/_core/mli/comm/channel/dragon_util.py b/smartsim/_core/mli/comm/channel/dragon_util.py
new file mode 100644
index 0000000000..8517979ec4
--- /dev/null
+++ b/smartsim/_core/mli/comm/channel/dragon_util.py
@@ -0,0 +1,131 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2.
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import base64 +import binascii +import typing as t + +import dragon.channels as dch +import dragon.fli as fli +import dragon.managed_memory as dm + +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +logger = get_logger(__name__) + +DEFAULT_CHANNEL_BUFFER_SIZE = 500 +"""Maximum number of messages that can be buffered. DragonCommChannel will +raise an exception if no clients consume messages before the buffer is filled.""" + +LAST_OFFSET = 0 +"""The last offset used to create a local channel. This is used to avoid +unnecessary retries when creating a local channel.""" + + +def channel_to_descriptor(channel: t.Union[dch.Channel, fli.FLInterface]) -> str: + """Convert a dragon channel to a descriptor string. + + :param channel: The dragon channel to convert + :returns: The descriptor string + :raises ValueError: If a dragon channel is not provided + """ + if channel is None: + raise ValueError("Channel is not available to create a descriptor") + + serialized_ch = channel.serialize() + return base64.b64encode(serialized_ch).decode("utf-8") + + +def pool_to_descriptor(pool: dm.MemoryPool) -> str: + """Convert a dragon memory pool to a descriptor string. + + :param pool: The memory pool to convert + :returns: The descriptor string + :raises ValueError: If a memory pool is not provided + """ + if pool is None: + raise ValueError("Memory pool is not available to create a descriptor") + + serialized_pool = pool.serialize() + return base64.b64encode(serialized_pool).decode("utf-8") + + +def descriptor_to_fli(descriptor: str) -> "fli.FLInterface": + """Create and attach a new FLI instance given + the string-encoded descriptor. + + :param descriptor: The descriptor of an FLI to attach to + :returns: The attached dragon FLI + :raises ValueError: If the descriptor is empty or incorrectly formatted + :raises SmartSimError: If attachment using the descriptor fails + """ + if len(descriptor) < 1: + raise ValueError("Descriptors may not be empty") + + try: + encoded = descriptor.encode("utf-8") + descriptor_ = base64.b64decode(encoded) + return fli.FLInterface.attach(descriptor_) + except binascii.Error: + raise ValueError("The descriptor was not properly base64 encoded") + except fli.DragonFLIError: + raise SmartSimError("The descriptor did not address an available FLI") + + +def descriptor_to_channel(descriptor: str) -> dch.Channel: + """Create and attach a new Channel instance given + the string-encoded descriptor. 
+
+    :param descriptor: The descriptor of a channel to attach to
+    :returns: The attached dragon Channel
+    :raises ValueError: If the descriptor is empty or incorrectly formatted
+    :raises SmartSimError: If attachment using the descriptor fails
+    """
+    if len(descriptor) < 1:
+        raise ValueError("Descriptors may not be empty")
+
+    try:
+        encoded = descriptor.encode("utf-8")
+        descriptor_ = base64.b64decode(encoded)
+        return dch.Channel.attach(descriptor_)
+    except binascii.Error:
+        raise ValueError("The descriptor was not properly base64 encoded")
+    except dch.ChannelError:
+        raise SmartSimError("The descriptor did not address an available channel")
+
+
+def create_local(_capacity: int = 0) -> dch.Channel:
+    """Creates a Channel attached to the local memory pool. Replacement for
+    direct calls to `dch.Channel.make_process_local()` so callers can supply
+    a channel capacity.
+
+    :param _capacity: The number of events the channel can buffer. Currently
+    unused; the channel is created with the dragon default capacity
+    :returns: The instantiated channel
+    """
+    channel = dch.Channel.make_process_local()
+    return channel
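The helpers above reduce channel hand-off between processes to a string contract. A minimal round-trip sketch, assuming an active dragon runtime:

```python
import smartsim._core.mli.comm.channel.dragon_util as drg_util

# create a process-local channel; the capacity argument is accepted but not
# yet applied (see create_local above)
channel = drg_util.create_local()

# serialize + base64-encode so the handle can travel as an env var or store value
descriptor = drg_util.channel_to_descriptor(channel)

# any process holding the string can re-attach to the same channel
attached = drg_util.descriptor_to_channel(descriptor)

# malformed descriptors fail fast before any attach is attempted
try:
    drg_util.descriptor_to_channel("definitely-not-base64!")
except ValueError as ex:
    print(ex)
```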
diff --git a/smartsim/_core/mli/infrastructure/comm/__init__.py b/smartsim/_core/mli/infrastructure/comm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/smartsim/_core/mli/infrastructure/comm/broadcaster.py b/smartsim/_core/mli/infrastructure/comm/broadcaster.py
new file mode 100644
index 0000000000..56dcf549f7
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/comm/broadcaster.py
@@ -0,0 +1,239 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import typing as t
+import uuid
+from collections import defaultdict, deque
+
+from smartsim._core.mli.comm.channel.channel import CommChannelBase
+from smartsim._core.mli.infrastructure.comm.event import EventBase
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    BackboneFeatureStore,
+)
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class BroadcastResult(t.NamedTuple):
+    """Contains summary details about a broadcast."""
+
+    num_sent: int
+    """The total number of messages delivered across all consumers"""
+    num_failed: int
+    """The total number of messages not delivered across all consumers"""
+
+
+class EventBroadcaster:
+    """Performs fan-out publishing of system events."""
+
+    def __init__(
+        self,
+        backbone: BackboneFeatureStore,
+        channel_factory: t.Optional[t.Callable[[str], CommChannelBase]] = None,
+        name: t.Optional[str] = None,
+    ) -> None:
+        """Initialize the EventBroadcaster instance.
+
+        :param backbone: The MLI backbone feature store
+        :param channel_factory: Factory method to construct new channel instances
+        :param name: A user-friendly name for logging. If not provided, an
+        auto-generated GUID will be used
+        """
+        self._backbone = backbone
+        """The backbone feature store used to retrieve consumer descriptors"""
+        self._channel_factory = channel_factory
+        """A factory method used to instantiate channels from descriptors"""
+        self._channel_cache: t.Dict[str, t.Optional[CommChannelBase]] = defaultdict(
+            lambda: None
+        )
+        """A mapping of instantiated channels that can be re-used. Automatically
+        calls the channel factory if a descriptor is not already in the collection"""
+        self._event_buffer: t.Deque[EventBase] = deque()
+        """A buffer for storing events when a consumer list is not found"""
+        self._descriptors: t.Set[str]
+        """Stores the most recent list of broadcast consumers. Updated automatically
+        on each broadcast"""
+        self._name = name or str(uuid.uuid4())
+        """A unique identifier assigned to the broadcaster for logging"""
+
+    @property
+    def name(self) -> str:
+        """The friendly name assigned to the broadcaster.
+
+        :returns: The broadcaster name if one is assigned, otherwise a unique
+        id assigned by the system.
+        """
+        return self._name
+
+    @property
+    def num_buffered(self) -> int:
+        """Return the number of events currently buffered to send.
+
+        :returns: Number of buffered events
+        """
+        return len(self._event_buffer)
+
+    def _save_to_buffer(self, event: EventBase) -> None:
+        """Places the event in the buffer to be sent once a consumer
+        list is available.
+ + :param event: The event to buffer + :raises ValueError: If the event cannot be buffered + """ + try: + self._event_buffer.append(event) + logger.debug(f"Buffered event {event=}") + except Exception as ex: + raise ValueError( + f"Unable to buffer event {event} in broadcaster {self.name}" + ) from ex + + def _log_broadcast_start(self) -> None: + """Logs broadcast statistics.""" + num_events = len(self._event_buffer) + num_copies = len(self._descriptors) + logger.debug( + f"Broadcast {num_events} events to {num_copies} consumers from {self.name}" + ) + + def _prune_unused_consumers(self) -> None: + """Performs maintenance on the channel cache by pruning any channel + that has been removed from the consumers list.""" + active_consumers = set(self._descriptors) + current_channels = set(self._channel_cache.keys()) + + # find any cached channels that are now unused + inactive_channels = current_channels.difference(active_consumers) + new_channels = active_consumers.difference(current_channels) + + for descriptor in inactive_channels: + self._channel_cache.pop(descriptor) + + logger.debug( + f"Pruning {len(inactive_channels)} stale consumers and" + f" found {len(new_channels)} new channels for {self.name}" + ) + + def _get_comm_channel(self, descriptor: str) -> CommChannelBase: + """Helper method to build and cache a comm channel. + + :param descriptor: The descriptor to pass to the channel factory + :returns: The instantiated channel + :raises SmartSimError: If the channel fails to attach + """ + comm_channel = self._channel_cache[descriptor] + if comm_channel is not None: + return comm_channel + + if self._channel_factory is None: + raise SmartSimError("No channel factory provided for consumers") + + try: + channel = self._channel_factory(descriptor) + self._channel_cache[descriptor] = channel + return channel + except Exception as ex: + msg = f"Unable to construct channel with descriptor: {descriptor}" + logger.error(msg, exc_info=True) + raise SmartSimError(msg) from ex + + def _get_next_event(self) -> t.Optional[EventBase]: + """Pop the next event to be sent from the queue. + + :returns: The next event to send if any events are enqueued, otherwise `None`. + """ + try: + return self._event_buffer.popleft() + except IndexError: + logger.debug(f"Broadcast buffer exhausted for {self.name}") + + return None + + def _broadcast(self, timeout: float = 0.001) -> BroadcastResult: + """Broadcasts all buffered events to registered event consumers. 
+
+        :param timeout: Maximum time to wait (in seconds) for messages to send
+        :returns: BroadcastResult containing the number of messages that were
+        successfully and unsuccessfully sent for all consumers
+        :raises SmartSimError: If the channel fails to attach or broadcasting fails
+        """
+        # allow descriptors to be empty since events are buffered
+        self._descriptors = set(x for x in self._backbone.notification_channels if x)
+        if not self._descriptors:
+            msg = f"No event consumers are registered for {self.name}"
+            logger.warning(msg)
+            return BroadcastResult(0, 0)
+
+        self._prune_unused_consumers()
+        self._log_broadcast_start()
+
+        num_listeners = len(self._descriptors)
+        num_sent = 0
+        num_failures = 0
+
+        # send each event to every consumer
+        while event := self._get_next_event():
+            logger.debug(f"Broadcasting {event=} to {num_listeners} listeners")
+            event_bytes = bytes(event)
+
+            for i, descriptor in enumerate(self._descriptors):
+                comm_channel = self._get_comm_channel(descriptor)
+
+                try:
+                    comm_channel.send(event_bytes, timeout)
+                    num_sent += 1
+                except Exception:
+                    msg = (
+                        f"Broadcast {i+1}/{num_listeners} for event {event.uid} to "
+                        f"channel {descriptor} from {self.name} failed."
+                    )
+                    logger.exception(msg)
+                    num_failures += 1
+
+        return BroadcastResult(num_sent, num_failures)
+
+    def send(self, event: EventBase, timeout: float = 0.001) -> int:
+        """Implementation of the `send` method of the `EventProducer` protocol.
+        Publishes the supplied event to all registered broadcast consumers.
+
+        :param event: An event to publish
+        :param timeout: Maximum time to wait (in seconds) for messages to send
+        :returns: The total number of events successfully published to consumers
+        :raises ValueError: If event serialization fails
+        :raises AttributeError: If event cannot be serialized
+        :raises KeyError: If channel fails to attach using registered descriptors
+        :raises SmartSimError: If any unexpected error occurs during send
+        """
+        try:
+            self._save_to_buffer(event)
+            result = self._broadcast(timeout)
+            return result.num_sent
+        except (KeyError, ValueError, AttributeError, SmartSimError):
+            raise
+        except Exception as ex:
+            raise SmartSimError("An unexpected failure occurred while sending") from ex
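A hedged usage sketch for the broadcaster, assuming an attached backbone whose descriptor is available via `_SMARTSIM_INFRA_BACKBONE`; the feature store key is illustrative:

```python
import os

from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster
from smartsim._core.mli.infrastructure.comm.event import OnWriteFeatureStore
from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
    BackboneFeatureStore,
)

# attach to the backbone advertised by the infrastructure
backbone = BackboneFeatureStore.from_writable_descriptor(
    os.environ["_SMARTSIM_INFRA_BACKBONE"]
)

broadcaster = EventBroadcaster(
    backbone,
    channel_factory=DragonCommChannel.from_descriptor,
    name="example-publisher",
)

# events are buffered if no consumers are registered yet; a later send flushes them
event = OnWriteFeatureStore(broadcaster.name, backbone.descriptor, "model-key")
num_delivered = broadcaster.send(event, timeout=0.01)
```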
diff --git a/smartsim/_core/mli/infrastructure/comm/consumer.py b/smartsim/_core/mli/infrastructure/comm/consumer.py
new file mode 100644
index 0000000000..08b5c47852
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/comm/consumer.py
@@ -0,0 +1,281 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pickle
+import time
+import typing as t
+import uuid
+
+from smartsim._core.mli.comm.channel.channel import CommChannelBase
+from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
+from smartsim._core.mli.infrastructure.comm.event import (
+    EventBase,
+    OnCreateConsumer,
+    OnRemoveConsumer,
+    OnShutdownRequested,
+)
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    BackboneFeatureStore,
+)
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class EventConsumer:
+    """Reads system events published to a communications channel."""
+
+    _BACKBONE_WAIT_TIMEOUT = 10.0
+    """Maximum time (in seconds) to wait for the backbone to register the consumer"""
+
+    def __init__(
+        self,
+        comm_channel: CommChannelBase,
+        backbone: BackboneFeatureStore,
+        filters: t.Optional[t.List[str]] = None,
+        name: t.Optional[str] = None,
+        event_handler: t.Optional[t.Callable[[EventBase], None]] = None,
+    ) -> None:
+        """Initialize the EventConsumer instance.
+
+        :param comm_channel: Communications channel to listen to for events
+        :param backbone: The MLI backbone feature store
+        :param filters: A list of event types to deliver. When empty, all
+        events will be delivered
+        :param name: A user-friendly name for logging. If not provided, an
+        auto-generated GUID will be used
+        :param event_handler: Optional function invoked for each event that
+        passes the filters
+        """
+        self._comm_channel = comm_channel
+        """The comm channel used by the consumer to receive messages. The channel
+        descriptor will be published for senders to discover."""
+        self._backbone = backbone
+        """The backbone instance used to bootstrap the instance. The EventConsumer
+        uses the backbone to discover where it can publish its descriptor."""
+        self._global_filters = filters or []
+        """A set of global filters to apply to incoming events. Global filters are
+        combined with per-call filters. Filters act as an allow-list."""
+        self._name = name or str(uuid.uuid4())
+        """User-friendly name assigned to a consumer for logging. Automatically
+        assigned if not provided."""
+        self._event_handler = event_handler
+        """The function that should be executed when an event that passes
+        the filters is received."""
+        self.listening = True
+        """Flag indicating that the consumer is currently listening for new
+        events. Setting this flag to `False` will cause any active calls to
+        `listen` to terminate."""
+
+    @property
+    def descriptor(self) -> str:
+        """The descriptor of the underlying comm channel.
+
+        :returns: The comm channel descriptor"""
+        return self._comm_channel.descriptor
+
+    @property
+    def name(self) -> str:
+        """The friendly name assigned to the consumer.
+
+        :returns: The consumer name if one is assigned, otherwise a unique
+        id assigned by the system.
+        """
+        return self._name
+
+    def recv(
+        self,
+        filters: t.Optional[t.List[str]] = None,
+        timeout: float = 0.001,
+        batch_timeout: float = 1.0,
+    ) -> t.List[EventBase]:
+        """Receives available published event(s).
+
+        :param filters: Additional filters to add to the global filters configured
+        on the EventConsumer instance
+        :param timeout: Maximum time to wait for a single message to arrive
+        :param batch_timeout: Maximum time to wait for messages to arrive; allows
+        multiple batches to be retrieved in one call to `recv`
+        :returns: A list of events that pass any configured filters
+        :raises ValueError: If a positive, non-zero value is not provided for the
+        timeout or batch_timeout.
+        """
+        if filters is None:
+            filters = []
+
+        if timeout is not None and timeout <= 0:
+            raise ValueError("request timeout must be a non-zero, positive value")
+
+        if batch_timeout is not None and batch_timeout <= 0:
+            raise ValueError("batch_timeout must be a non-zero, positive value")
+
+        filter_set = {*self._global_filters, *filters}
+        all_message_bytes: t.List[bytes] = []
+
+        # firehose as many messages as possible within the batch_timeout
+        start_at = time.time()
+        remaining = batch_timeout
+
+        batch_message_bytes = self._comm_channel.recv(timeout=timeout)
+        while batch_message_bytes:
+            all_message_bytes.extend(batch_message_bytes)
+            batch_message_bytes = []
+
+            # avoid getting stuck indefinitely waiting for the channel
+            elapsed = time.time() - start_at
+            remaining = batch_timeout - elapsed
+
+            if remaining > 0:
+                batch_message_bytes = self._comm_channel.recv(timeout=timeout)
+
+        events_received: t.List[EventBase] = []
+
+        # Timeout elapsed or no messages received - return the empty list
+        if not all_message_bytes:
+            return events_received
+
+        for message in all_message_bytes:
+            # remove any empty messages that will fail to decode
+            if not message:
+                continue
+
+            event = pickle.loads(message)
+            if not event:
+                logger.warning(f"Consumer {self.name} is unable to unpickle message")
+                continue
+
+            # skip events that don't pass a filter
+            if filter_set and event.category not in filter_set:
+                continue
+
+            events_received.append(event)
+
+        return events_received
+
+    def _send_to_registrar(self, event: EventBase) -> None:
+        """Send an event directly to the registrar listener.
+
+        :param event: The event to send
+        """
+        registrar_key = BackboneFeatureStore.MLI_REGISTRAR_CONSUMER
+        config = self._backbone.wait_for([registrar_key], self._BACKBONE_WAIT_TIMEOUT)
+        registrar_descriptor = config.get(registrar_key, None)
+
+        if not registrar_descriptor:
+            logger.warning(
+                f"Unable to send {event.category} from {self.name}. "
+                "No registrar channel found."
+            )
+            return
+
+        logger.debug(f"Sending {event.category} from {self.name}")
+
+        registrar_channel = DragonCommChannel.from_descriptor(str(registrar_descriptor))
+        registrar_channel.send(bytes(event), timeout=1.0)
+
+        logger.debug(f"{event.category} from {self.name} sent")
+
+    def register(self) -> None:
+        """Send an event to register this consumer as a listener."""
+        descriptor = self._comm_channel.descriptor
+        event = OnCreateConsumer(self.name, descriptor, self._global_filters)
+
+        self._send_to_registrar(event)
+
+    def unregister(self) -> None:
+        """Send an event to un-register this consumer as a listener."""
+        descriptor = self._comm_channel.descriptor
+        event = OnRemoveConsumer(self.name, descriptor)
+
+        self._send_to_registrar(event)
+
+    def _on_handler_missing(self, event: EventBase) -> None:
+        """A "dead letter" event handler that is called to perform
+        processing on events before they're discarded.
+
+        :param event: The event to handle
+        """
+        logger.warning(
+            "No event handler is registered in consumer "
+            f"{self.name}. Discarding {event=}"
+        )
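The consumer-side counterpart, sketched under the same assumptions (active dragon runtime, backbone descriptor available in `_SMARTSIM_INFRA_BACKBONE`):

```python
import os

from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
from smartsim._core.mli.infrastructure.comm.consumer import EventConsumer
from smartsim._core.mli.infrastructure.comm.event import OnWriteFeatureStore
from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
    BackboneFeatureStore,
)

backbone = BackboneFeatureStore.from_writable_descriptor(
    os.environ["_SMARTSIM_INFRA_BACKBONE"]
)

consumer = EventConsumer(
    DragonCommChannel.from_local(),
    backbone,
    filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN],
    name="example-consumer",
)

consumer.register()  # advertises this consumer to the registrar listener

# drain whatever arrives within the batch window; only filtered events are returned
for event in consumer.recv(timeout=0.01, batch_timeout=0.5):
    print(f"{event.category}: {event}")

consumer.unregister()  # removes this consumer from the notification list
```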
+
+    def listen_once(self, timeout: float = 0.001, batch_timeout: float = 1.0) -> None:
+        """Receives messages for the consumer a single time. Delivers
+        all messages that pass the consumer filters. Shutdown requests
+        are handled by a default event handler.
+
+        NOTE: Executes a single batch-retrieval to receive the maximum
+        number of messages available under batch timeout. To continually
+        listen, use `listen` in a non-blocking thread/process.
+
+        :param timeout: Maximum time to wait (in seconds) for a message to arrive
+        :param batch_timeout: Maximum time to wait (in seconds) for a batch to arrive
+        """
+        logger.info(
+            f"Consumer {self.name} listening with {timeout} second timeout"
+            f" on channel {self._comm_channel.descriptor}"
+        )
+
+        if not self._event_handler:
+            logger.info("Unable to handle messages. No event handler is registered.")
+
+        incoming_messages = self.recv(timeout=timeout, batch_timeout=batch_timeout)
+
+        if not incoming_messages:
+            logger.info(f"Consumer {self.name} received empty message list")
+
+        for message in incoming_messages:
+            logger.info(f"Consumer {self.name} is handling event {message=}")
+            self._handle_shutdown(message)
+
+            if self._event_handler:
+                self._event_handler(message)
+            else:
+                self._on_handler_missing(message)
+
+    def _handle_shutdown(self, event: EventBase) -> bool:
+        """Handles shutdown requests sent to the consumer by setting the
+        `self.listening` attribute to `False`.
+
+        :param event: The event to handle
+        :returns: A bool indicating if the event was a shutdown request
+        """
+        if isinstance(event, OnShutdownRequested):
+            logger.debug(f"Shutdown requested from: {event.source}")
+            self.listening = False
+            return True
+        return False
+
+    def listen(self, timeout: float = 0.001, batch_timeout: float = 1.0) -> None:
+        """Receives messages for the consumer until a shutdown request is received.
+
+        :param timeout: Maximum time to wait (in seconds) for a message to arrive
+        :param batch_timeout: Maximum time to wait (in seconds) for a batch to arrive
+        """
+
+        logger.debug(f"Consumer {self.name} is now listening for events.")
+
+        while self.listening:
+            self.listen_once(timeout, batch_timeout)
+
+        logger.debug(f"Consumer {self.name} is no longer listening.")
diff --git a/smartsim/_core/mli/infrastructure/comm/event.py b/smartsim/_core/mli/infrastructure/comm/event.py
new file mode 100644
index 0000000000..ccef9f9b86
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/comm/event.py
@@ -0,0 +1,162 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pickle
+import typing as t
+import uuid
+from dataclasses import dataclass, field
+
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+@dataclass
+class EventBase:
+    """Core API for an event."""
+
+    category: str
+    """Unique category name for an event class"""
+    source: str
+    """A unique identifier for the publisher of the event"""
+    uid: str = field(default_factory=lambda: str(uuid.uuid4()))
+    """A unique identifier for this event"""
+
+    def __bytes__(self) -> bytes:
+        """Default conversion to bytes for an event required to publish
+        messages using byte-oriented communication channels.
+
+        :returns: This entity encoded as bytes"""
+        return pickle.dumps(self)
+
+    def __str__(self) -> str:
+        """Convert the event to a string.
+
+        :returns: A string representation of this instance"""
+        return f"{self.uid}|{self.category}"
+
+
+class OnShutdownRequested(EventBase):
+    """Publish this event to trigger the listener to shutdown."""
+
+    SHUTDOWN: t.ClassVar[str] = "consumer-unregister"
+    """Unique category name for an event raised to request a listener shutdown"""
+
+    def __init__(self, source: str) -> None:
+        """Initialize the event instance.
+
+        :param source: A unique identifier for the publisher of the event
+        """
+        super().__init__(self.SHUTDOWN, source)
+
+
+class OnCreateConsumer(EventBase):
+    """Publish this event when a new event consumer registration is required."""
+
+    descriptor: str
+    """Descriptor of the comm channel exposed by the consumer"""
+    filters: t.List[str]
+    """The collection of filters indicating messages of interest to this consumer"""
+
+    CONSUMER_CREATED: t.ClassVar[str] = "consumer-created"
+    """Unique category name for an event raised when a new consumer is registered"""
+
+    def __init__(self, source: str, descriptor: str, filters: t.Sequence[str]) -> None:
+        """Initialize the event instance.
+
+        :param source: A unique identifier for the publisher of the event
+        :param descriptor: Descriptor of the comm channel exposed by the consumer
+        :param filters: Collection of filters indicating messages of interest
+        """
+        super().__init__(self.CONSUMER_CREATED, source)
+        self.descriptor = descriptor
+        self.filters = list(filters)
+
+    def __str__(self) -> str:
+        """Convert the event to a string.
+
+        :returns: A string representation of this instance
+        """
+        _filters = ",".join(self.filters)
+        return f"{super().__str__()}|{self.descriptor}|{_filters}"
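Events cross channels as pickled bytes; a small sketch of the round trip, mirroring what `EventConsumer.recv` does internally:

```python
import pickle

from smartsim._core.mli.infrastructure.comm.event import OnCreateConsumer

event = OnCreateConsumer("publisher-1", "channel-descriptor", filters=[])

# __bytes__ pickles the event for transport over byte-oriented channels
payload = bytes(event)

# the receiving side unpickles and filters on the category attribute
received = pickle.loads(payload)
assert received.category == OnCreateConsumer.CONSUMER_CREATED
assert received.uid == event.uid
```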
+
+
+class OnRemoveConsumer(EventBase):
+    """Publish this event when a consumer is shutting down and
+    should be removed from notification lists."""
+
+    descriptor: str
+    """Descriptor of the comm channel exposed by the consumer"""
+
+    CONSUMER_REMOVED: t.ClassVar[str] = "consumer-removed"
+    """Unique category name for an event raised when a consumer is unregistered"""
+
+    def __init__(self, source: str, descriptor: str) -> None:
+        """Initialize the OnRemoveConsumer event.
+
+        :param source: A unique identifier for the publisher of the event
+        :param descriptor: Descriptor of the comm channel exposed by the consumer
+        """
+        super().__init__(self.CONSUMER_REMOVED, source)
+        self.descriptor = descriptor
+
+    def __str__(self) -> str:
+        """Convert the event to a string.
+
+        :returns: A string representation of this instance
+        """
+        return f"{super().__str__()}|{self.descriptor}"
+
+
+class OnWriteFeatureStore(EventBase):
+    """Publish this event when a feature store key is written."""
+
+    descriptor: str
+    """The descriptor of the feature store where the write occurred"""
+    key: str
+    """The key identifying where the write occurred"""
+
+    FEATURE_STORE_WRITTEN: t.ClassVar[str] = "feature-store-written"
+    """Event category for an event raised when a feature store key is written"""
+
+    def __init__(self, source: str, descriptor: str, key: str) -> None:
+        """Initialize the OnWriteFeatureStore event.
+
+        :param source: A unique identifier for the publisher of the event
+        :param descriptor: The descriptor of the feature store where the write occurred
+        :param key: The key identifying where the write occurred
+        """
+        super().__init__(self.FEATURE_STORE_WRITTEN, source)
+        self.descriptor = descriptor
+        self.key = key
+
+    def __str__(self) -> str:
+        """Convert the event to a string.
+
+        :returns: A string representation of this instance
+        """
+        return f"{super().__str__()}|{self.descriptor}|{self.key}"
diff --git a/smartsim/_core/mli/infrastructure/comm/producer.py b/smartsim/_core/mli/infrastructure/comm/producer.py
new file mode 100644
index 0000000000..2d8a7c14ad
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/comm/producer.py
@@ -0,0 +1,44 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import typing as t
+
+from smartsim._core.mli.infrastructure.comm.event import EventBase
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class EventProducer(t.Protocol):
+    """Core API of a class that publishes events."""
+
+    def send(self, event: EventBase, timeout: float = 0.001) -> int:
+        """Send an event using the configured comm channel.
+ + :param event: The event to send + :param timeout: Maximum time to wait (in seconds) for messages to send + :returns: The number of messages that were sent + """ diff --git a/smartsim/_core/mli/infrastructure/control/error_handling.py b/smartsim/_core/mli/infrastructure/control/error_handling.py index 8961cac543..a75f533a37 100644 --- a/smartsim/_core/mli/infrastructure/control/error_handling.py +++ b/smartsim/_core/mli/infrastructure/control/error_handling.py @@ -48,7 +48,7 @@ def build_failure_reply(status: "Status", message: str) -> ResponseBuilder: return MessageHandler.build_response( status=status, message=message, - result=[], + result=None, custom_attributes=None, ) diff --git a/smartsim/_core/mli/infrastructure/control/listener.py b/smartsim/_core/mli/infrastructure/control/listener.py new file mode 100644 index 0000000000..56a7b12d34 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/control/listener.py @@ -0,0 +1,352 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# isort: off +# pylint: disable=import-error +# pylint: disable=unused-import +import socket +import dragon + +# pylint: enable=unused-import +# pylint: enable=import-error +# isort: on + +import argparse +import multiprocessing as mp +import os +import sys +import typing as t + +from smartsim._core.entrypoints.service import Service +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.comm.consumer import EventConsumer +from smartsim._core.mli.infrastructure.comm.event import ( + EventBase, + OnCreateConsumer, + OnRemoveConsumer, + OnShutdownRequested, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class ConsumerRegistrationListener(Service): + """A long-running service that manages the list of consumers receiving + events that are broadcast. 
It hosts handlers for adding and removing consumers.
+    """
+
+    def __init__(
+        self,
+        backbone: BackboneFeatureStore,
+        timeout: float,
+        batch_timeout: float,
+        as_service: bool = False,
+        cooldown: int = 0,
+        health_check_frequency: float = 60.0,
+    ) -> None:
+        """Initialize the ConsumerRegistrationListener.
+
+        :param backbone: The backbone feature store
+        :param timeout: Maximum time (in seconds) to allow a single recv request to wait
+        :param batch_timeout: Maximum time (in seconds) to allow a batch of receives to
+        continue to build
+        :param as_service: Specifies run-once or run-until-complete behavior of service
+        :param cooldown: Number of seconds to wait before shutting down after
+        shutdown criteria are met
+        :param health_check_frequency: Time (in seconds) between health checks
+        """
+        super().__init__(
+            as_service, cooldown, health_check_frequency=health_check_frequency
+        )
+        self._timeout = timeout
+        """Maximum time (in seconds) to allow a single recv request to wait"""
+        self._batch_timeout = batch_timeout
+        """Maximum time (in seconds) to allow a batch of receives to
+        continue to build"""
+        self._consumer: t.Optional[EventConsumer] = None
+        """The event consumer that handles receiving events"""
+        self._backbone = backbone
+        """A standalone, system-created feature store used to share internal
+        information among MLI components"""
+
+    def _on_start(self) -> None:
+        """Called on initial entry into Service `execute` event loop before
+        `_on_iteration` is invoked."""
+        super()._on_start()
+        self._create_eventing()
+
+    def _on_shutdown(self) -> None:
+        """Release dragon resources. Called immediately after exiting
+        the main event loop during automatic shutdown."""
+        super()._on_shutdown()
+
+        if not self._consumer:
+            return
+
+        # remove descriptor for this listener from the backbone if it's there
+        if registered_consumer := self._backbone.backend_channel:
+            # if there is a descriptor in the backbone and it's still this listener
+            if registered_consumer == self._consumer.descriptor:
+                logger.info(
+                    f"Listener clearing backend consumer {self._consumer.name} "
+                    "from backbone"
+                )
+
+                # unregister this listener in the backbone
+                self._backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER)
+
+                # TODO: need the channel to be cleaned up
+                # self._consumer._comm_channel._channel.destroy()
+
+    def _on_iteration(self) -> None:
+        """Executes calls to the machine learning worker implementation to complete
+        the inference pipeline."""
+
+        if self._consumer is None:
+            logger.info("Unable to listen. No consumer available.")
+            return
+
+        self._consumer.listen_once(self._timeout, self._batch_timeout)
+
+    def _can_shutdown(self) -> bool:
+        """Determines if the event consumer is ready to stop listening.
+
+        :returns: True when criteria to shutdown the service are met, False otherwise
+        """
+
+        if self._backbone is None:
+            logger.info("Listener must shutdown. No backbone attached")
+            return True
+
+        if self._consumer is None:
+            logger.info("Listener must shutdown. No consumer channel created")
+            return True
+
+        if not self._consumer.listening:
+            logger.info(
+                f"Listener can shutdown. Consumer `{self._consumer.name}` "
+                "is not listening"
+            )
+            return True
+
+        return False
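The two handlers that follow keep the backbone's consumer list in sync. A hypothetical trace of their effect; `listener` and `consumer_descriptor` are illustrative stand-ins, not part of this changeset:

```python
from smartsim._core.mli.infrastructure.comm.event import (
    OnCreateConsumer,
    OnRemoveConsumer,
)

# assumed: `listener` is an initialized ConsumerRegistrationListener and
# `consumer_descriptor` is the comm channel descriptor of some consumer
listener._on_register(OnCreateConsumer("app-1", consumer_descriptor, filters=[]))
assert consumer_descriptor in listener._backbone.notification_channels

listener._on_unregister(OnRemoveConsumer("app-1", consumer_descriptor))
assert consumer_descriptor not in listener._backbone.notification_channels
```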
+    def _on_unregister(self, event: OnRemoveConsumer) -> None:
+        """Event handler for updating the backbone when event consumers
+        are un-registered.
+
+        :param event: The event that was received
+        """
+        notify_list = set(self._backbone.notification_channels)
+
+        # remove the descriptor specified in the event
+        if event.descriptor in notify_list:
+            logger.debug(f"Removing notify consumer: {event.descriptor}")
+            notify_list.remove(event.descriptor)
+
+        # push the updated list back into the backbone
+        self._backbone.notification_channels = list(notify_list)
+
+    def _on_register(self, event: OnCreateConsumer) -> None:
+        """Event handler for updating the backbone when new event consumers
+        are registered.
+
+        :param event: The event that was received
+        """
+        notify_list = set(self._backbone.notification_channels)
+        logger.debug(f"Adding notify consumer: {event.descriptor}")
+        notify_list.add(event.descriptor)
+        self._backbone.notification_channels = list(notify_list)
+
+    def _on_event_received(self, event: EventBase) -> None:
+        """Primary event handler for the listener. Distributes events to
+        type-specific handlers.
+
+        :param event: The event that was received
+        """
+        if self._backbone is None:
+            logger.info("Unable to handle event. Backbone is missing.")
+            return
+
+        if isinstance(event, OnCreateConsumer):
+            self._on_register(event)
+        elif isinstance(event, OnRemoveConsumer):
+            self._on_unregister(event)
+        else:
+            logger.info(
+                "Consumer registration listener received an "
+                f"unexpected event: {event=}"
+            )
+
+    def _on_health_check(self) -> None:
+        """Check if this consumer has been replaced by a new listener
+        and automatically trigger a shutdown. Invoked based on the
+        value of `self._health_check_frequency`."""
+        super()._on_health_check()
+
+        try:
+            logger.debug("Retrieving registered listener descriptor")
+            descriptor = self._backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER]
+        except KeyError:
+            descriptor = None
+            if self._consumer:
+                self._consumer.listening = False
+
+        if self._consumer and descriptor != self._consumer.descriptor:
+            logger.warning(
+                f"Consumer `{self._consumer.name}` for `ConsumerRegistrationListener` "
+                "is no longer registered. It will automatically shut down."
+            )
+            self._consumer.listening = False
+
+    def _publish_consumer(self) -> None:
+        """Publish the registrar consumer descriptor to the backbone."""
+        if self._consumer is None:
+            logger.warning("No registrar consumer descriptor available to publish")
+            return
+
+        logger.debug(f"Publishing {self._consumer.descriptor} to backbone")
+        self._backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] = (
+            self._consumer.descriptor
+        )
+
+    def _create_eventing(self) -> EventConsumer:
+        """
+        Create an event publisher and event consumer for communicating with
+        other MLI resources.
+
+        NOTE: the backbone must be initialized before connecting eventing clients.
+
+        :returns: The newly created EventConsumer instance
+        :raises SmartSimError: If a listener channel cannot be created
+        """
+
+        if self._consumer:
+            return self._consumer
+
+        logger.info("Creating event consumer")
+
+        dragon_channel = create_local(500)
+        event_channel = DragonCommChannel(dragon_channel)
+
+        if not event_channel.descriptor:
+            raise SmartSimError(
+                "Unable to generate the descriptor for the event channel"
+            )
+
+        self._consumer = EventConsumer(
+            event_channel,
+            self._backbone,
+            [
+                OnCreateConsumer.CONSUMER_CREATED,
+                OnRemoveConsumer.CONSUMER_REMOVED,
+                OnShutdownRequested.SHUTDOWN,
+            ],
+            name=f"ConsumerRegistrar.{socket.gethostname()}",
+            event_handler=self._on_event_received,
+        )
+        self._publish_consumer()
+
+        logger.info(
+            f"Backend consumer `{self._consumer.name}` created: "
+            f"{self._consumer.descriptor}"
+        )
+
+        return self._consumer
+
+
+def _create_parser() -> argparse.ArgumentParser:
+    """
+    Create an argument parser that contains the arguments
+    required to start the listener as a new process:
+
+      --timeout
+      --batch_timeout
+
+    :returns: A configured parser
+    """
+    arg_parser = argparse.ArgumentParser(prog="ConsumerRegistrarEventListener")
+
+    arg_parser.add_argument("--timeout", type=float, default=1.0)
+    arg_parser.add_argument("--batch_timeout", type=float, default=1.0)
+
+    return arg_parser
+
+
+def _connect_backbone() -> t.Optional[BackboneFeatureStore]:
+    """
+    Load the backbone by retrieving the descriptor from environment variables.
+
+    :returns: The backbone feature store, or `None` if no descriptor is configured
+    :raises SmartSimError: If the backbone descriptor cannot be attached
+    """
+    descriptor = os.environ.get(BackboneFeatureStore.MLI_BACKBONE, "")
+
+    if not descriptor:
+        return None
+
+    logger.info(f"Listener backbone descriptor: {descriptor}")
+
+    # `from_writable_descriptor` ensures we can update the backbone
+    return BackboneFeatureStore.from_writable_descriptor(descriptor)
+
+
+if __name__ == "__main__":
+    mp.set_start_method("dragon")
+
+    parser = _create_parser()
+    args = parser.parse_args()
+
+    backbone_fs = _connect_backbone()
+
+    if backbone_fs is None:
+        logger.error(
+            "Unable to attach to the backbone without the "
+            f"`{BackboneFeatureStore.MLI_BACKBONE}` environment variable."
+        )
+        sys.exit(1)
+
+    logger.debug(f"Listener attached to backbone: {backbone_fs.descriptor}")
+
+    listener = ConsumerRegistrationListener(
+        backbone_fs,
+        float(args.timeout),
+        float(args.batch_timeout),
+        as_service=True,
+    )
+
+    logger.info(f"Listener created: 
{listener}") + + try: + listener.execute() + sys.exit(0) + except Exception: + logger.exception("An error occurred in the event listener") + sys.exit(1) diff --git a/smartsim/_core/mli/infrastructure/control/request_dispatcher.py b/smartsim/_core/mli/infrastructure/control/request_dispatcher.py index 67797fe448..e22a2c8f62 100644 --- a/smartsim/_core/mli/infrastructure/control/request_dispatcher.py +++ b/smartsim/_core/mli/infrastructure/control/request_dispatcher.py @@ -142,13 +142,22 @@ def ready(self) -> bool: :returns: True if the queue can be flushed, False otherwise """ if self.empty(): + logger.debug("Request dispatcher queue is empty") return False - timed_out = ( - self._batch_timeout > 0 and self._elapsed_time >= self._batch_timeout - ) - logger.debug(f"Is full: {self.full()} or has timed out: {timed_out}") - return self.full() or timed_out + timed_out = False + if self._batch_timeout >= 0: + timed_out = self._elapsed_time >= self._batch_timeout + + if self.full(): + logger.debug("Request dispatcher ready to deliver full batch") + return True + + if timed_out: + logger.debug("Request dispatcher delivering partial batch") + return True + + return False def make_disposable(self) -> None: """Set this queue as disposable, and never use it again after it gets @@ -218,7 +227,6 @@ def __init__( :param config_loader: Object to load configuration from environment :param worker_type: Type of worker to instantiate to batch inputs :param mem_pool_size: Size of the memory pool used to allocate tensors - :raises SmartSimError: If config_loaded.get_queue() does not return a channel """ super().__init__(as_service=True, cooldown=1) self._queues: dict[str, list[BatchQueue]] = {} @@ -281,7 +289,7 @@ def _check_feature_stores(self, request: InferenceRequest) -> bool: fs_missing = fs_desired - fs_actual if not self.has_featurestore_factory: - logger.error("No feature store factory configured") + logger.error("No feature store factory is configured. Unable to dispatch.") return False # create the feature stores we need to service request @@ -363,6 +371,7 @@ def _on_iteration(self) -> None: None, ) + logger.debug(f"Dispatcher is processing {len(bytes_list)} messages") request_bytes = bytes_list[0] tensor_bytes_list = bytes_list[1:] self._perf_timer.start_timings() @@ -463,7 +472,7 @@ def dispatch(self, request: InferenceRequest) -> None: ) self._active_queues[tmp_id] = tmp_queue self._queues[tmp_id] = [tmp_queue] - tmp_queue.put_nowait(request) + tmp_queue.put(request) tmp_queue.make_disposable() return diff --git a/smartsim/_core/mli/infrastructure/environment_loader.py b/smartsim/_core/mli/infrastructure/environment_loader.py index 02043fbd80..5ba0fccc27 100644 --- a/smartsim/_core/mli/infrastructure/environment_loader.py +++ b/smartsim/_core/mli/infrastructure/environment_loader.py @@ -39,10 +39,15 @@ class EnvironmentConfigLoader: Facilitates the loading of a FeatureStore and Queue into the WorkerManager. 
""" + REQUEST_QUEUE_ENV_VAR = "_SMARTSIM_REQUEST_QUEUE" + """The environment variable that holds the request queue descriptor""" + BACKBONE_ENV_VAR = "_SMARTSIM_INFRA_BACKBONE" + """The environment variable that holds the backbone descriptor""" + def __init__( self, featurestore_factory: t.Callable[[str], FeatureStore], - callback_factory: t.Callable[[bytes], CommChannelBase], + callback_factory: t.Callable[[str], CommChannelBase], queue_factory: t.Callable[[str], CommChannelBase], ) -> None: """Initialize the config loader instance with the factories necessary for @@ -76,14 +81,16 @@ def get_backbone(self) -> t.Optional[FeatureStore]: :returns: The attached feature store via `_SMARTSIM_INFRA_BACKBONE` """ - descriptor = os.getenv("_SMARTSIM_INFRA_BACKBONE", "") + descriptor = os.getenv(self.BACKBONE_ENV_VAR, "") if not descriptor: logger.warning("No backbone descriptor is configured") return None if self._featurestore_factory is None: - logger.warning("No feature store factory is configured") + logger.warning( + "No feature store factory is configured. Backbone not created." + ) return None self.backbone = self._featurestore_factory(descriptor) @@ -95,7 +102,7 @@ def get_queue(self) -> t.Optional[CommChannelBase]: :returns: The attached queue specified via `_SMARTSIM_REQUEST_QUEUE` """ - descriptor = os.getenv("_SMARTSIM_REQUEST_QUEUE", "") + descriptor = os.getenv(self.REQUEST_QUEUE_ENV_VAR, "") if not descriptor: logger.warning("No queue descriptor is configured") diff --git a/smartsim/_core/mli/infrastructure/storage/backbone_feature_store.py b/smartsim/_core/mli/infrastructure/storage/backbone_feature_store.py index b6655bded6..b12d7b11b4 100644 --- a/smartsim/_core/mli/infrastructure/storage/backbone_feature_store.py +++ b/smartsim/_core/mli/infrastructure/storage/backbone_feature_store.py @@ -24,13 +24,10 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import enum -import pickle +import itertools +import os import time import typing as t -import uuid -from collections import defaultdict, deque -from dataclasses import dataclass # pylint: disable=import-error # isort: off @@ -38,7 +35,6 @@ # isort: on -from smartsim._core.mli.comm.channel.channel import CommChannelBase from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( DragonFeatureStore, ) @@ -48,16 +44,29 @@ logger = get_logger(__name__) -# todo: did i create an arms race where a developer just grabs the backbone -# and passes it wherever they need a FeatureStore? 
 class BackboneFeatureStore(DragonFeatureStore):
     """A DragonFeatureStore wrapper with utility methods for accessing shared
     information stored in the MLI backbone feature store."""

     MLI_NOTIFY_CONSUMERS = "_SMARTSIM_MLI_NOTIFY_CONSUMERS"
+    """Unique key used in the backbone to locate the consumer list"""
+
+    MLI_REGISTRAR_CONSUMER = "_SMARTSIM_MLI_REGISTRAR_CONSUMER"
+    """Unique key used in the backbone to locate the registration consumer"""
+
+    MLI_WORKER_QUEUE = "_SMARTSIM_REQUEST_QUEUE"
+    """Unique key used in the backbone to locate the MLI work queue"""
+
+    MLI_BACKBONE = "_SMARTSIM_INFRA_BACKBONE"
+    """Unique key used in the backbone to locate the backbone feature store"""
+
+    _CREATED_ON = "creation"
+    """Unique key used in the backbone to locate the creation date of the
+    feature store"""
+
+    _DEFAULT_WAIT_TIMEOUT = 1.0
+    """The default wait time (in seconds) for blocking requests to
+    the feature store"""

     def __init__(
-        self, storage: "dragon_ddict.DDict", allow_reserved_writes: bool = False
+        self,
+        storage: dragon_ddict.DDict,
+        allow_reserved_writes: bool = False,
     ) -> None:
         """Initialize the DragonFeatureStore instance.
@@ -68,13 +77,33 @@ def __init__(
         super().__init__(storage)
         self._enable_reserved_writes = allow_reserved_writes

+        self._wait_timeout = self._DEFAULT_WAIT_TIMEOUT
+        """The wait timeout (in seconds) applied to calls to `wait_for`"""
+
+        self._record_creation_data()
+
+    @property
+    def wait_timeout(self) -> float:
+        """Retrieve the wait timeout for this feature store. The wait timeout is
+        applied to all calls to `wait_for`.
+
+        :returns: The wait timeout (in seconds).
+        """
+        return self._wait_timeout
+
+    @wait_timeout.setter
+    def wait_timeout(self, value: float) -> None:
+        """Set the wait timeout (in seconds) for this feature store. The wait
+        timeout is applied to all calls to `wait_for`.
+
+        :param value: The new value to set
+        """
+        self._wait_timeout = value
+
     @property
     def notification_channels(self) -> t.Sequence[str]:
         """Retrieve descriptors for all registered MLI notification channels.

-        :returns: The list of descriptors
+        :returns: The list of channel descriptors
         """
-        if "_SMARTSIM_MLI_NOTIFY_CONSUMERS" in self:
+        if self.MLI_NOTIFY_CONSUMERS in self:
             stored_consumers = self[self.MLI_NOTIFY_CONSUMERS]
             return str(stored_consumers).split(",")
         return []
@@ -85,335 +114,146 @@ def notification_channels(self, values: t.Sequence[str]) -> None:

         :param values: The list of channel descriptors to save
         """
-        self[self.MLI_NOTIFY_CONSUMERS] = ",".join([str(value) for value in values])
-
-
-class EventCategory(str, enum.Enum):
-    """Predefined event types raised by SmartSim backend."""
-
-    CONSUMER_CREATED: str = "consumer-created"
-    FEATURE_STORE_WRITTEN: str = "feature-store-written"
-
-
-@dataclass
-class EventBase:
-    """Core API for an event."""
-
-    # todo: shift eventing code to: infrastructure / event / event.py
-    category: EventCategory
-    """The event category for this event; may be used for addressing,
-    prioritization, or filtering of events by a event publisher/consumer"""
-
-    uid: str
-    """A unique identifier for this event"""
-
-    def __bytes__(self) -> bytes:
-        """Default conversion to bytes for an event required to publish
-        messages using byte-oriented communication channels.
-
-        :returns: This entity encoded as bytes"""
-        return pickle.dumps(self)
-
-    def __str__(self) -> str:
-        """Convert the event to a string.
- - :returns: A string representation of this instance""" - return f"{self.uid}|{self.category}" - - -class OnCreateConsumer(EventBase): - """Publish this event when a new event consumer registration is required.""" - - descriptor: str - """Descriptor of the comm channel exposed by the consumer""" - - def __init__(self, descriptor: str) -> None: - """Initialize the OnCreateConsumer event. - - :param descriptor: Descriptor of the comm channel exposed by the consumer - """ - super().__init__(EventCategory.CONSUMER_CREATED, str(uuid.uuid4())) - self.descriptor = descriptor - - def __str__(self) -> str: - """Convert the event to a string. - - :returns: A string representation of this instance - """ - return f"{str(super())}|{self.descriptor}" - - -class OnWriteFeatureStore(EventBase): - """Publish this event when a feature store key is written.""" - - descriptor: str - """The descriptor of the feature store where the write occurred""" - - key: str - """The key identifying where the write occurred""" - - def __init__(self, descriptor: str, key: str) -> None: - """Initialize the OnWriteFeatureStore event. - - :param descriptor: The descriptor of the feature store where the write occurred - :param key: The key identifying where the write occurred - """ - super().__init__(EventCategory.FEATURE_STORE_WRITTEN, str(uuid.uuid4())) - self.descriptor = descriptor - self.key = key - - def __str__(self) -> str: - """Convert the event to a string. - - :returns: A string representation of this instance - """ - return f"{str(super())}|{self.descriptor}|{self.key}" - - -class EventProducer(t.Protocol): - """Core API of a class that publishes events.""" - - def send(self, event: EventBase, timeout: float = 0.001) -> int: - """The send operation. - - :param event: The event to send - :param timeout: Maximum time to wait (in seconds) for messages to send - """ - - -class EventBroadcaster: - """Performs fan-out publishing of system events.""" - - def __init__( - self, - backbone: BackboneFeatureStore, - channel_factory: t.Optional[t.Callable[[str], CommChannelBase]] = None, - ) -> None: - """Initialize the EventPublisher instance. - - :param backbone: The MLI backbone feature store - :param channel_factory: Factory method to construct new channel instances - """ - self._backbone = backbone - """The backbone feature store used to retrieve consumer descriptors""" - self._channel_factory = channel_factory - """A factory method used to instantiate channels from descriptors""" - self._channel_cache: t.Dict[str, t.Optional[CommChannelBase]] = defaultdict( - lambda: None + self[self.MLI_NOTIFY_CONSUMERS] = ",".join( + [str(value) for value in values if value] ) - """A mapping of instantiated channels that can be re-used. Automatically - calls the channel factory if a descriptor is not already in the collection""" - self._event_buffer: t.Deque[bytes] = deque() - """A buffer for storing events when a consumer list is not found""" - self._descriptors: t.Set[str] - """Stores the most recent list of broadcast consumers. Updated automatically - on each broadcast""" - self._uid = str(uuid.uuid4()) - """A unique identifer assigned to the broadcaster for logging""" @property - def num_buffered(self) -> int: - """Return the number of events currently buffered to send. + def backend_channel(self) -> t.Optional[str]: + """Retrieve the channel descriptor used to register event consumers. 
- :returns: Number of buffered events - """ - return len(self._event_buffer) + :returns: The channel descriptor""" + if self.MLI_REGISTRAR_CONSUMER in self: + return str(self[self.MLI_REGISTRAR_CONSUMER]) + return None - def _save_to_buffer(self, event: EventBase) -> None: - """Places a serialized event in the buffer to be sent once a consumer - list is available. - - :param event: The event to serialize and buffer - :raises ValueError: If the event cannot be serialized - """ - try: - event_bytes = bytes(event) - self._event_buffer.append(event_bytes) - except Exception as ex: - raise ValueError(f"Unable to serialize event from {self._uid}") from ex - - def _log_broadcast_start(self) -> None: - """Logs broadcast statistics.""" - num_events = len(self._event_buffer) - num_copies = len(self._descriptors) - logger.debug( - f"Broadcast {num_events} events to {num_copies} consumers from {self._uid}" - ) + @backend_channel.setter + def backend_channel(self, value: str) -> None: + """Set the channel used to register event consumers. - def _prune_unused_consumers(self) -> None: - """Performs maintenance on the channel cache by pruning any channel - that has been removed from the consumers list.""" - active_consumers = set(self._descriptors) - current_channels = set(self._channel_cache.keys()) + :param value: The stringified channel descriptor""" + self[self.MLI_REGISTRAR_CONSUMER] = value - # find any cached channels that are now unused - inactive_channels = current_channels.difference(active_consumers) - new_channels = active_consumers.difference(current_channels) + @property + def worker_queue(self) -> t.Optional[str]: + """Retrieve the channel descriptor used to send work to MLI worker managers. - for descriptor in inactive_channels: - self._channel_cache.pop(descriptor) + :returns: The channel descriptor, if found. Otherwise, `None`""" + if self.MLI_WORKER_QUEUE in self: + return str(self[self.MLI_WORKER_QUEUE]) + return None - logger.debug( - f"Pruning {len(inactive_channels)} stale consumers and" - f" found {len(new_channels)} new channels for {self._uid}" - ) + @worker_queue.setter + def worker_queue(self, value: str) -> None: + """Set the channel descriptor used to send work to MLI worker managers. - def _get_comm_channel(self, descriptor: str) -> CommChannelBase: - """Helper method to build and cache a comm channel. + :param value: The channel descriptor""" + self[self.MLI_WORKER_QUEUE] = value - :param descriptor: The descriptor to pass to the channel factory - :returns: The instantiated channel - :raises SmartSimError: If the channel fails to build + @property + def creation_date(self) -> str: + """Return the creation date for the backbone feature store. + + :returns: The string-formatted date when feature store was created""" + return str(self[self._CREATED_ON]) + + def _record_creation_data(self) -> None: + """Write the creation timestamp to the feature store.""" + if self._CREATED_ON not in self: + if not self._allow_reserved_writes: + logger.warning( + "Recorded creation from a write-protected backbone instance" + ) + self[self._CREATED_ON] = str(time.time()) + + os.environ[self.MLI_BACKBONE] = self.descriptor + + @classmethod + def from_writable_descriptor( + cls, + descriptor: str, + ) -> "BackboneFeatureStore": + """A factory method that creates an instance from a descriptor string. 
+
+        :param descriptor: The descriptor that uniquely identifies the resource
+        :returns: An attached BackboneFeatureStore
+        :raises SmartSimError: If attachment to the DragonFeatureStore fails
         """
-        comm_channel = self._channel_cache[descriptor]
-        if comm_channel is not None:
-            return comm_channel
-
-        if self._channel_factory is None:
-            raise SmartSimError("No channel factory provided for consumers")
-
         try:
-            channel = self._channel_factory(descriptor)
-            self._channel_cache[descriptor] = channel
-            return channel
+            return BackboneFeatureStore(dragon_ddict.DDict.attach(descriptor), True)
         except Exception as ex:
-            msg = f"Unable to construct channel with descriptor: {descriptor}"
-            logger.error(msg, exc_info=True)
-            raise SmartSimError(msg) from ex
+            raise SmartSimError(
+                f"Error creating backbone feature store: {descriptor}"
+            ) from ex
 
-    def _broadcast(self, timeout: float = 0.001) -> int:
-        """Broadcasts all buffered events to registered event consumers.
+    def _check_wait_timeout(
+        self, start_time: float, timeout: float, indicators: t.Dict[str, bool]
+    ) -> None:
+        """Perform timeout verification.
 
-        :param timeout: Maximum time to wait (in seconds) for messages to send
-        :returns: The number of events broadcasted to consumers
-        :raises SmartSimError: If broadcasting fails
+        :param start_time: The start time to use for the elapsed calculation
+        :param timeout: The timeout (in seconds)
+        :param indicators: Latest retrieval status for the requested keys
+        :raises SmartSimError: If the timeout elapses before all values are
+            retrieved
         """
-        # allow descriptors to be empty since events are buffered
-        self._descriptors = set(x for x in self._backbone.notification_channels if x)
-        if not self._descriptors:
-            logger.warning(f"No event consumers are registered for {self._uid}")
-            return 0
-
-        self._prune_unused_consumers()
-        self._log_broadcast_start()
-
-        num_sent: int = 0
-        next_event: t.Optional[bytes] = self._event_buffer.popleft()
-
-        # send each event to every consumer
-        while next_event is not None:
-            for descriptor in map(str, self._descriptors):
-                comm_channel = self._get_comm_channel(descriptor)
-
-                try:
-                    # todo: given a failure, the message is not sent to any other
-                    # recipients. consider retrying, adding a dead letter queue, or
-                    # logging the message details more intentionally
-                    comm_channel.send(next_event, timeout)
-                    num_sent += 1
-                except Exception as ex:
-                    raise SmartSimError(
-                        f"Failed broadcast to channel {descriptor} from {self._uid}"
-                    ) from ex
-
-            try:
-                next_event = self._event_buffer.popleft()
-            except IndexError:
-                next_event = None
-                logger.debug(f"Broadcast buffer exhausted for {self._uid}")
-
-        return num_sent
-
-    def send(self, event: EventBase, timeout: float = 0.001) -> int:
-        """Implementation of `send` method of the `EventPublisher` protocol. Publishes
-        the supplied event to all registered broadcast consumers.
- - :param event: An event to publish - :param timeout: Maximum time to wait (in seconds) for messages to send - :returns: The number of events successfully published - :raises ValueError: If event serialization fails - :raises KeyError: If channel fails to attach using registered descriptors - :raises SmartSimError: If any unexpected error occurs during send + elapsed = time.time() - start_time + if timeout and elapsed > timeout: + raise SmartSimError( + f"Backbone {self.descriptor=} timeout after {elapsed} " + f"seconds retrieving keys: {indicators}" + ) + + def wait_for( + self, keys: t.List[str], timeout: float = _DEFAULT_WAIT_TIMEOUT + ) -> t.Dict[str, t.Union[str, bytes, None]]: + """Perform a blocking wait until all specified keys have been found + in the backbone. + + :param keys: The required collection of keys to retrieve + :param timeout: The maximum wait time in seconds + :returns: Dictionary containing the keys and values requested + :raises SmartSimError: If the timeout elapses without retrieving + all requested keys """ - try: - self._save_to_buffer(event) - return self._broadcast(timeout) - except (KeyError, ValueError, SmartSimError): - raise - except Exception as ex: - raise SmartSimError("An unexpected failure occurred while sending") from ex + if timeout < 0: + timeout = self._DEFAULT_WAIT_TIMEOUT + logger.info(f"Using default wait_for timeout: {timeout}s") + if not keys: + return {} -class EventConsumer: - """Reads system events published to a communications channel.""" + values: t.Dict[str, t.Union[str, bytes, None]] = {k: None for k in set(keys)} + is_found = {k: False for k in values.keys()} - def __init__( - self, - comm_channel: CommChannelBase, - backbone: BackboneFeatureStore, - filters: t.Optional[t.List[EventCategory]] = None, - batch_timeout: t.Optional[float] = None, - ) -> None: - """Initialize the EventConsumer instance. - - :param comm_channel: Communications channel to listen to for events - :param backbone: The MLI backbone feature store - :param filters: A list of event types to deliver. when empty, all - events will be delivered - :param timeout: Maximum time to wait for messages to arrive; may be overridden - on individual calls to `receive` - :raises ValueError: If batch_timeout <= 0 - """ - if batch_timeout is not None and batch_timeout <= 0: - raise ValueError("batch_timeout must be a non-zero, positive value") - - self._comm_channel = comm_channel - self._backbone = backbone - self._global_filters = filters or [] - self._global_timeout = batch_timeout or 1.0 - - def receive( - self, filters: t.Optional[t.List[EventCategory]] = None, timeout: float = 0 - ) -> t.List[EventBase]: - """Receives available published event(s). 
-
-        :param filters: Additional filters to add to the global filters configured
-        on the EventConsumer instance
-        :param timeout: Maximum time to wait for messages to arrive
-        :returns: A list of events that pass any configured filters
-        """
-        if filters is None:
-            filters = []
-
-        filter_set = {*self._global_filters, *filters}
-        messages: t.List[t.Any] = []
+        backoff = (0.1, 0.2, 0.4, 0.8)
+        backoff_iter = itertools.cycle(backoff)
+        start_time = time.time()
 
-        # use the local timeout to override a global setting
-        start_at = time.time_ns()
+        while not all(is_found.values()):
+            delay = next(backoff_iter)
 
-        while msg_bytes_list := self._comm_channel.recv(timeout=timeout):
-            # remove any empty messages that will fail to decode
-            msg_bytes_list = [msg for msg in msg_bytes_list if msg]
+            for key in [k for k, v in is_found.items() if not v]:
+                try:
+                    values[key] = self[key]
+                    is_found[key] = True
+                except Exception:
+                    if delay == backoff[-1]:
+                        logger.debug(f"Re-attempting `{key}` retrieval in {delay}s")
 
-            msg: t.Optional[EventBase] = None
-            if msg_bytes_list:
-                for message in msg_bytes_list:
-                    msg = pickle.loads(message)
+            if all(is_found.values()):
+                logger.debug(f"wait_for({keys}) retrieved all keys")
+                continue
 
-            if not msg:
-                logger.warning("Unable to unpickle message")
-                continue
+            self._check_wait_timeout(start_time, timeout, is_found)
+            time.sleep(delay)
 
-            # ignore anything that doesn't match a filter (if one is
-            # supplied), otherwise return everything
-            if not filter_set or msg.category in filter_set:
-                messages.append(msg)
+        return values
 
-            # avoid getting stuck indefinitely waiting for the channel
-            elapsed = (time.time_ns() - start_at) / 1000000000
-            remaining = elapsed - self._global_timeout
-            if remaining > 0:
-                logger.debug(f"Consumer batch timeout exceeded by: {abs(remaining)}")
-                break
+    def get_env(self) -> t.Dict[str, str]:
+        """Returns a dictionary populated with environment variables necessary to
+        connect a process to the existing backbone instance.
 
-        return messages
+        :returns: The dictionary of environment variables
+        """
+        return {self.MLI_BACKBONE: self.descriptor}
diff --git a/smartsim/_core/mli/infrastructure/storage/dragon_feature_store.py b/smartsim/_core/mli/infrastructure/storage/dragon_feature_store.py
index d7b37ffe61..24f2221c87 100644
--- a/smartsim/_core/mli/infrastructure/storage/dragon_feature_store.py
+++ b/smartsim/_core/mli/infrastructure/storage/dragon_feature_store.py
@@ -32,6 +32,10 @@
 
 # isort: on
 
+from smartsim._core.mli.infrastructure.storage.dragon_util import (
+    ddict_to_descriptor,
+    descriptor_to_ddict,
+)
 from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore
 from smartsim.error import SmartSimError
 from smartsim.log import get_logger
@@ -46,15 +50,20 @@ def __init__(self, storage: "dragon_ddict.DDict") -> None:
         """Initialize the DragonFeatureStore instance.
 
         :param storage: A distributed dictionary to be used as the underlying
-            storage mechanism of the feature store
-        """
+            storage mechanism of the feature store"""
+        if storage is None:
+            raise ValueError(
+                "Storage is required when instantiating a DragonFeatureStore."
+            )
+
+        descriptor = ""
         if isinstance(storage, dragon_ddict.DDict):
-            descriptor = str(storage.serialize())
-        else:
-            descriptor = "not-set"
+            descriptor = ddict_to_descriptor(storage)
 
         super().__init__(descriptor)
         self._storage: t.Dict[str, t.Union[str, bytes]] = storage
+        """The underlying storage mechanism of the DragonFeatureStore; a
+        distributed, in-memory key-value store"""
 
     def _get(self, key: str) -> t.Union[str, bytes]:
         """Retrieve a value from the underlying storage mechanism.
@@ -65,7 +74,7 @@ def _get(self, key: str) -> t.Union[str, bytes]:
         """
         try:
             return self._storage[key]
-        except KeyError as e:
+        except dragon_ddict.DDictError as e:
            raise KeyError(f"Key not found in FeatureStore: {key}") from e
 
     def _set(self, key: str, value: t.Union[str, bytes]) -> None:
@@ -85,6 +94,17 @@ def _contains(self, key: str) -> bool:
         """
         return key in self._storage
 
+    def pop(self, key: str) -> t.Union[str, bytes, None]:
+        """Remove the value from the dictionary and return the value.
+
+        :param key: Dictionary key to retrieve
+        :returns: The value held at the key if it exists, otherwise `None`
+        """
+        try:
+            return self._storage.pop(key)
+        except dragon_ddict.DDictError:
+            return None
+
     @classmethod
     def from_descriptor(
         cls,
@@ -97,9 +117,10 @@ def from_descriptor(
         :raises SmartSimError: If attachment to DragonFeatureStore fails
         """
         try:
-            return DragonFeatureStore(dragon_ddict.DDict.attach(descriptor))
+            logger.debug(f"Attaching to FeatureStore with descriptor: {descriptor}")
+            storage = descriptor_to_ddict(descriptor)
+            return cls(storage)
         except Exception as ex:
-            logger.error(f"Error creating dragon feature store: {descriptor}")
             raise SmartSimError(
-                f"Error creating dragon feature store: {descriptor}"
+                f"Error creating dragon feature store from descriptor: {descriptor}"
             ) from ex
diff --git a/smartsim/_core/mli/infrastructure/storage/dragon_util.py b/smartsim/_core/mli/infrastructure/storage/dragon_util.py
new file mode 100644
index 0000000000..50d15664c0
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/storage/dragon_util.py
@@ -0,0 +1,101 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# pylint: disable=import-error
+# isort: off
+import dragon.data.ddict.ddict as dragon_ddict
+
+# isort: on
+
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+def ddict_to_descriptor(ddict: dragon_ddict.DDict) -> str:
+    """Convert a DDict to a descriptor string.
+
+    :param ddict: The dragon dictionary to convert
+    :returns: The descriptor string
+    :raises ValueError: If a ddict is not provided
+    """
+    if ddict is None:
+        raise ValueError("DDict is not available to create a descriptor")
+
+    # unlike other dragon objects, the dictionary serializes to a string
+    # instead of bytes
+    return str(ddict.serialize())
+
+
+def descriptor_to_ddict(descriptor: str) -> dragon_ddict.DDict:
+    """Create and attach a new DDict instance given
+    the string-encoded descriptor.
+
+    :param descriptor: The descriptor of a dictionary to attach to
+    :returns: The attached dragon dictionary"""
+    return dragon_ddict.DDict.attach(descriptor)
+
+
+def create_ddict(
+    num_nodes: int, mgr_per_node: int, mem_per_node: int
+) -> dragon_ddict.DDict:
+    """Create a distributed dragon dictionary.
+
+    :param num_nodes: The number of distributed nodes to distribute the dictionary to.
+        At least one node is required.
+    :param mgr_per_node: The number of manager processes per node
+    :param mem_per_node: The amount of memory (in bytes) to allocate per node. Total
+        memory available will be calculated as `num_nodes * mem_per_node`
+
+    :returns: The instantiated dragon dictionary
+    :raises ValueError: If invalid num_nodes is supplied
+    :raises ValueError: If invalid mem_per_node is supplied
+    :raises ValueError: If invalid mgr_per_node is supplied
+    """
+    if num_nodes < 1:
+        raise ValueError("A dragon dictionary must have at least 1 node")
+
+    if mgr_per_node < 1:
+        raise ValueError("A dragon dictionary requires at least 1 manager per node")
+
+    if mem_per_node < dragon_ddict.DDICT_MIN_SIZE:
+        raise ValueError(
+            "A dragon dictionary requires at least "
+            f"{dragon_ddict.DDICT_MIN_SIZE / 1024**2} MB per node"
+        )
+
+    mem_total = num_nodes * mem_per_node
+
+    logger.debug(
+        f"Creating dragon dictionary with {num_nodes} nodes, {mem_total} bytes of memory"
+    )
+
+    distributed_dict = dragon_ddict.DDict(num_nodes, mgr_per_node, total_mem=mem_total)
+    logger.debug(
+        "Successfully created dragon dictionary with "
+        f"{num_nodes} nodes, {mem_total} bytes of total memory"
+    )
+    return distributed_dict
diff --git a/smartsim/_core/mli/infrastructure/storage/feature_store.py b/smartsim/_core/mli/infrastructure/storage/feature_store.py
index a55c523058..ebca07ed4e 100644
--- a/smartsim/_core/mli/infrastructure/storage/feature_store.py
+++ b/smartsim/_core/mli/infrastructure/storage/feature_store.py
@@ -43,6 +43,14 @@ class ReservedKeys(str, enum.Enum):
     """Storage location for the list of registered consumers that will
     receive events from an EventBroadcaster"""
 
+    MLI_REGISTRAR_CONSUMER = "_SMARTIM_MLI_REGISTRAR_CONSUMER"
+    """Storage location for the channel used to send messages directly to
+    the MLI backend"""
+
+    MLI_WORKER_QUEUE = "_SMARTSIM_REQUEST_QUEUE"
+    """Storage location for the channel used to send work requests
+    to the available worker managers"""
+
     @classmethod
     def contains(cls, value: str) -> bool:
         """Convert a string representation into an enumeration member.
@@ -59,7 +67,27 @@ def contains(cls, value: str) -> bool: @dataclass(frozen=True) -class FeatureStoreKey: +class TensorKey: + """A key,descriptor pair enabling retrieval of an item from a feature store.""" + + key: str + """The unique key of an item in a feature store""" + descriptor: str + """The unique identifier of the feature store containing the key""" + + def __post_init__(self) -> None: + """Ensure the key and descriptor have at least one character. + + :raises ValueError: If key or descriptor are empty strings + """ + if len(self.key) < 1: + raise ValueError("Key must have at least one character.") + if len(self.descriptor) < 1: + raise ValueError("Descriptor must have at least one character.") + + +@dataclass(frozen=True) +class ModelKey: """A key,descriptor pair enabling retrieval of an item from a feature store.""" key: str @@ -119,8 +147,8 @@ def __getitem__(self, key: str) -> t.Union[str, bytes]: """ try: return self._get(key) - except KeyError as ex: - raise SmartSimError(f"An unknown key was requested: {key}") from ex + except KeyError: + raise except Exception as ex: # note: explicitly avoid round-trip to check for key existence raise SmartSimError( diff --git a/smartsim/_core/mli/infrastructure/worker/worker.py b/smartsim/_core/mli/infrastructure/worker/worker.py index 530d251540..9556b8e438 100644 --- a/smartsim/_core/mli/infrastructure/worker/worker.py +++ b/smartsim/_core/mli/infrastructure/worker/worker.py @@ -39,17 +39,16 @@ from ...comm.channel.channel import CommChannelBase from ...message_handler import MessageHandler from ...mli_schemas.model.model_capnp import Model -from ..storage.feature_store import FeatureStore, FeatureStoreKey +from ..storage.feature_store import FeatureStore, ModelKey, TensorKey if t.TYPE_CHECKING: - from smartsim._core.mli.mli_schemas.data.data_references_capnp import TensorKey from smartsim._core.mli.mli_schemas.response.response_capnp import Status from smartsim._core.mli.mli_schemas.tensor.tensor_capnp import TensorDescriptor logger = get_logger(__name__) # Placeholder -ModelIdentifier = FeatureStoreKey +ModelIdentifier = ModelKey class InferenceRequest: @@ -57,12 +56,12 @@ class InferenceRequest: def __init__( self, - model_key: t.Optional[FeatureStoreKey] = None, + model_key: t.Optional[ModelKey] = None, callback: t.Optional[CommChannelBase] = None, raw_inputs: t.Optional[t.List[bytes]] = None, - input_keys: t.Optional[t.List[FeatureStoreKey]] = None, + input_keys: t.Optional[t.List[TensorKey]] = None, input_meta: t.Optional[t.List[t.Any]] = None, - output_keys: t.Optional[t.List[FeatureStoreKey]] = None, + output_keys: t.Optional[t.List[TensorKey]] = None, raw_model: t.Optional[Model] = None, batch_size: int = 0, ): @@ -112,7 +111,7 @@ def has_model_key(self) -> bool: @property def has_raw_inputs(self) -> bool: - """Check if the InferenceRequest contains raw_outputs. + """Check if the InferenceRequest contains raw_inputs. 
:returns: True if raw_outputs is not None and is not an empty list, False otherwise @@ -153,7 +152,7 @@ class InferenceReply: def __init__( self, outputs: t.Optional[t.Collection[t.Any]] = None, - output_keys: t.Optional[t.Collection[FeatureStoreKey]] = None, + output_keys: t.Optional[t.Collection[TensorKey]] = None, status_enum: "Status" = "running", message: str = "In progress", ) -> None: @@ -166,7 +165,7 @@ def __init__( """ self.outputs: t.Collection[t.Any] = outputs or [] """List of output data""" - self.output_keys: t.Collection[t.Optional[FeatureStoreKey]] = output_keys or [] + self.output_keys: t.Collection[t.Optional[TensorKey]] = output_keys or [] """List of keys used for output data""" self.status_enum = status_enum """Status of the reply""" @@ -201,6 +200,7 @@ def __init__(self, model: t.Any) -> None: :param model: The loaded model """ self.model = model + """The loaded model (e.g. a TensorFlow, PyTorch, ONNX, etc. model)""" class TransformInputResult: @@ -320,7 +320,7 @@ class RequestBatch: """List of InferenceRequests in the batch""" inputs: t.Optional[TransformInputResult] """Transformed batch of input tensors""" - model_id: ModelIdentifier + model_id: "ModelIdentifier" """Model (key, descriptor) tuple""" @property @@ -350,7 +350,7 @@ def raw_model(self) -> t.Optional[t.Any]: return None @property - def input_keys(self) -> t.List[FeatureStoreKey]: + def input_keys(self) -> t.List[TensorKey]: """All input keys available in this batch's requests. :returns: All input keys belonging to requests in this batch""" @@ -361,7 +361,7 @@ def input_keys(self) -> t.List[FeatureStoreKey]: return keys @property - def output_keys(self) -> t.List[FeatureStoreKey]: + def output_keys(self) -> t.List[TensorKey]: """All output keys available in this batch's requests. :returns: All output keys belonging to requests in this batch""" @@ -378,7 +378,7 @@ class MachineLearningWorkerCore: @staticmethod def deserialize_message( data_blob: bytes, - callback_factory: t.Callable[[bytes], CommChannelBase], + callback_factory: t.Callable[[str], CommChannelBase], ) -> InferenceRequest: """Deserialize a message from a byte stream into an InferenceRequest. 
@@ -388,27 +388,27 @@ def deserialize_message( :returns: The raw input message deserialized into an InferenceRequest """ request = MessageHandler.deserialize_request(data_blob) - model_key: t.Optional[FeatureStoreKey] = None + model_key: t.Optional[ModelKey] = None model_bytes: t.Optional[Model] = None if request.model.which() == "key": - model_key = FeatureStoreKey( + model_key = ModelKey( key=request.model.key.key, - descriptor=request.model.key.featureStoreDescriptor, + descriptor=request.model.key.descriptor, ) elif request.model.which() == "data": model_bytes = request.model.data callback_key = request.replyChannel.descriptor comm_channel = callback_factory(callback_key) - input_keys: t.Optional[t.List[FeatureStoreKey]] = None + input_keys: t.Optional[t.List[TensorKey]] = None input_bytes: t.Optional[t.List[bytes]] = None - output_keys: t.Optional[t.List[FeatureStoreKey]] = None + output_keys: t.Optional[t.List[TensorKey]] = None input_meta: t.Optional[t.List[TensorDescriptor]] = None if request.input.which() == "keys": input_keys = [ - FeatureStoreKey(key=value.key, descriptor=value.featureStoreDescriptor) + TensorKey(key=value.key, descriptor=value.descriptor) for value in request.input.keys ] elif request.input.which() == "descriptors": @@ -416,7 +416,7 @@ def deserialize_message( if request.output: output_keys = [ - FeatureStoreKey(key=value.key, descriptor=value.featureStoreDescriptor) + TensorKey(key=value.key, descriptor=value.descriptor) for value in request.output ] @@ -490,7 +490,7 @@ def fetch_model( feature_store = feature_stores[fsd] raw_bytes: bytes = t.cast(bytes, feature_store[key]) return FetchModelResult(raw_bytes) - except FileNotFoundError as ex: + except (FileNotFoundError, KeyError) as ex: logger.exception(ex) raise SmartSimError(f"Model could not be retrieved with key {key}") from ex @@ -545,12 +545,12 @@ def place_output( request: InferenceRequest, transform_result: TransformOutputResult, feature_stores: t.Dict[str, FeatureStore], - ) -> t.Collection[t.Optional[FeatureStoreKey]]: + ) -> t.Collection[t.Optional[TensorKey]]: """Given a collection of data, make it available as a shared resource in the feature store. :param request: The request that triggered the pipeline - :param execute_result: Results from inference + :param transform_result: Transformed version of the inference result :param feature_stores: Available feature stores used for persistence :returns: A collection of keys that were placed in the feature store :raises ValueError: If a feature store is not provided @@ -558,7 +558,7 @@ def place_output( if not feature_stores: raise ValueError("Feature store is required for output persistence") - keys: t.List[t.Optional[FeatureStoreKey]] = [] + keys: t.List[t.Optional[TensorKey]] = [] # need to decide how to get back to original sub-batch inputs so they can be # accurately placed, datum might need to include this. @@ -580,10 +580,12 @@ class MachineLearningWorkerBase(MachineLearningWorkerCore, ABC): def load_model( batch: RequestBatch, fetch_result: FetchModelResult, device: str ) -> LoadModelResult: - """Given a loaded MachineLearningModel, ensure it is loaded into - device memory. + """Given the raw bytes of an ML model that were fetched, ensure + it is loaded into device memory. :param request: The request that triggered the pipeline + :param fetch_result: The result of a fetch-model operation; contains + the raw bytes of the ML model. 
         :param device: The device on which the model must be placed
         :returns: LoadModelResult wrapping the model loaded for the request
         :raises ValueError: If model reference object is not found
@@ -600,7 +602,7 @@ def transform_input(
         """Given a collection of data, perform a transformation on the data and
         put the raw tensor data on a MemoryPool allocation.
 
-        :param request: The request that triggered the pipeline
+        :param batch: The request that triggered the pipeline
         :param fetch_result: Raw outputs from fetching inputs out of a feature store
         :param mem_pool: The memory pool used to access batched input tensors
         :returns: The transformed inputs wrapped in a TransformInputResult
diff --git a/smartsim/_core/mli/message_handler.py b/smartsim/_core/mli/message_handler.py
index 71def143ad..e3d46a7ab3 100644
--- a/smartsim/_core/mli/message_handler.py
+++ b/smartsim/_core/mli/message_handler.py
@@ -35,6 +35,10 @@
 
 
 class MessageHandler:
+    """Utility methods for transforming capnproto messages to and from
+    internal representations.
+    """
+
     @staticmethod
     def build_tensor_descriptor(
         order: "tensor_capnp.Order",
@@ -73,7 +77,7 @@ def build_output_tensor_descriptor(
         order, data type, and dimensions.
 
         :param order: Order of the tensor, such as row-major (c) or column-major (f)
-        :param keys: List of TensorKeys to apply transorm descriptor to
+        :param keys: List of TensorKeys to apply the transform descriptor to
         :param data_type: Tranform data type of the tensor
         :param dimensions: Transform dimensions of the tensor
         :returns: The OutputDescriptor
@@ -92,14 +96,12 @@ def build_output_tensor_descriptor(
         return description
 
     @staticmethod
-    def build_tensor_key(
-        key: str, feature_store_descriptor: str
-    ) -> data_references_capnp.TensorKey:
+    def build_tensor_key(key: str, descriptor: str) -> data_references_capnp.TensorKey:
         """
         Builds a new TensorKey message with the provided key.
 
         :param key: String to set the TensorKey
-        :param feature_store_descriptor: A descriptor identifying the feature store
+        :param descriptor: A descriptor identifying the feature store
         containing the key
         :returns: The TensorKey
         :raises ValueError: If building fails
@@ -107,7 +109,7 @@ def build_tensor_key(
         try:
             tensor_key = data_references_capnp.TensorKey.new_message()
             tensor_key.key = key
-            tensor_key.featureStoreDescriptor = feature_store_descriptor
+            tensor_key.descriptor = descriptor
         except Exception as e:
             raise ValueError("Error building tensor key.") from e
         return tensor_key
@@ -133,14 +135,12 @@ def build_model(data: bytes, name: str, version: str) -> model_capnp.Model:
         return model
 
     @staticmethod
-    def build_model_key(
-        key: str, feature_store_descriptor: str
-    ) -> data_references_capnp.ModelKey:
+    def build_model_key(key: str, descriptor: str) -> data_references_capnp.ModelKey:
         """
         Builds a new ModelKey message with the provided key.
 
         :param key: String to set the ModelKey
-        :param feature_store_descriptor: A descriptor identifying the feature store
+        :param descriptor: A descriptor identifying the feature store
         containing the key
         :returns: The ModelKey
         :raises ValueError: If building fails
@@ -148,9 +148,9 @@ def build_model_key(
         try:
             model_key = data_references_capnp.ModelKey.new_message()
             model_key.key = key
-            model_key.featureStoreDescriptor = feature_store_descriptor
+            model_key.descriptor = descriptor
         except Exception as e:
             raise ValueError("Error building model key.") from e
         return model_key
 
     @staticmethod
@@ -242,7 +242,7 @@ def _assign_model(
 
     @staticmethod
     def _assign_reply_channel(
-        request: request_capnp.Request, reply_channel: bytes
+        request: request_capnp.Request, reply_channel: str
     ) -> None:
         """
         Assigns a reply channel to the supplied request.
@@ -360,7 +360,7 @@ def _assign_custom_request_attributes(
 
     @staticmethod
     def build_request(
-        reply_channel: bytes,
+        reply_channel: str,
         model: t.Union[data_references_capnp.ModelKey, model_capnp.Model],
         inputs: t.Union[
             t.List[data_references_capnp.TensorKey],
diff --git a/smartsim/_core/mli/mli_schemas/data/data_references.capnp b/smartsim/_core/mli/mli_schemas/data/data_references.capnp
index 699abe5d22..65293be7b2 100644
--- a/smartsim/_core/mli/mli_schemas/data/data_references.capnp
+++ b/smartsim/_core/mli/mli_schemas/data/data_references.capnp
@@ -28,10 +28,10 @@
 
 struct ModelKey {
   key @0 :Text;
-  featureStoreDescriptor @1 :Text;
+  descriptor @1 :Text;
 }
 
 struct TensorKey {
   key @0 :Text;
-  featureStoreDescriptor @1 :Text;
+  descriptor @1 :Text;
 }
diff --git a/smartsim/_core/mli/mli_schemas/data/data_references_capnp.pyi b/smartsim/_core/mli/mli_schemas/data/data_references_capnp.pyi
index bcf53e0a04..a5e318a556 100644
--- a/smartsim/_core/mli/mli_schemas/data/data_references_capnp.pyi
+++ b/smartsim/_core/mli/mli_schemas/data/data_references_capnp.pyi
@@ -36,7 +36,7 @@ from typing import Iterator
 
 class ModelKey:
     key: str
-    featureStoreDescriptor: str
+    descriptor: str
     @staticmethod
     @contextmanager
     def from_bytes(
@@ -72,7 +72,7 @@ class ModelKeyBuilder(ModelKey):
 
 class TensorKey:
     key: str
-    featureStoreDescriptor: str
+    descriptor: str
     @staticmethod
     @contextmanager
     def from_bytes(
diff --git a/smartsim/_core/mli/mli_schemas/request/request.capnp b/smartsim/_core/mli/mli_schemas/request/request.capnp
index 4be1cfa215..26d9542d9f 100644
--- a/smartsim/_core/mli/mli_schemas/request/request.capnp
+++ b/smartsim/_core/mli/mli_schemas/request/request.capnp
@@ -32,7 +32,7 @@ using DataRef = import "../data/data_references.capnp";
 using Models = import "../model/model.capnp";
 
 struct ChannelDescriptor {
-  descriptor @0 :Data;
+  descriptor @0 :Text;
 }
 
 struct Request {
diff --git a/smartsim/_core/mli/mli_schemas/request/request_capnp.pyi b/smartsim/_core/mli/mli_schemas/request/request_capnp.pyi
index a4ad631f9f..2aab80b1d0 100644
--- a/smartsim/_core/mli/mli_schemas/request/request_capnp.pyi
+++ b/smartsim/_core/mli/mli_schemas/request/request_capnp.pyi
@@ -61,7 +61,7 @@ from .request_attributes.request_attributes_capnp import (
 )
 
 class ChannelDescriptor:
-    descriptor: bytes
+    descriptor: str
     @staticmethod
     @contextmanager
     def from_bytes(
diff --git a/smartsim/_core/utils/timings.py b/smartsim/_core/utils/timings.py
index 114db88d90..f99950739e 100644
--- a/smartsim/_core/utils/timings.py
+++ b/smartsim/_core/utils/timings.py
@@ -145,10 +145,12 @@ def max_length(self) -> int:
         return max(len(value) for value in self._timings.values())
 
     def print_timings(self, to_file: bool = False) -> None:
-        """Print all timing information
+        """Print timing information to standard output. If `to_file`
+        is `True`, also write results to a file.
 
-        :param to_file: flag indicating if timing should be written to stdout
-        or to the timing file"""
+        :param to_file: If `True`, also saves timing information
+            to the files `timings.npy` and `timings.txt`
+        """
         print(" ".join(self._timings.keys()))
         try:
             value_array = np.array(list(self._timings.values()), dtype=float)
diff --git a/smartsim/log.py b/smartsim/log.py
index 3d6c0860ee..c8fed9329f 100644
--- a/smartsim/log.py
+++ b/smartsim/log.py
@@ -252,16 +252,21 @@ def filter(self, record: logging.LogRecord) -> bool:
         return record.levelno <= level_no
 
 
-def log_to_file(filename: str, log_level: str = "debug") -> None:
+def log_to_file(
+    filename: str, log_level: str = "debug", logger: t.Optional[logging.Logger] = None
+) -> None:
     """Installs a second filestream handler to the root logger,
     allowing subsequent logging calls to be sent to filename.
 
-    :param filename: the name of the desired log file.
-    :param log_level: as defined in get_logger. Can be specified
+    :param filename: The name of the desired log file.
+    :param log_level: As defined in get_logger. Can be specified
         to allow the file to store more or less verbose logging information.
+    :param logger: If supplied, a logger to add the file stream logging
+        behavior to. By default, the `SmartSim` logger is used.
     """
-    logger = logging.getLogger("SmartSim")
+    if logger is None:
+        logger = logging.getLogger("SmartSim")
     stream = open(  # pylint: disable=consider-using-with
         filename, "w+", encoding="utf-8"
     )
diff --git a/tests/dragon/channel.py b/tests/dragon/channel.py
index 2348784236..4c46359c2d 100644
--- a/tests/dragon/channel.py
+++ b/tests/dragon/channel.py
@@ -39,17 +39,15 @@
 class FileSystemCommChannel(CommChannelBase):
     """Passes messages by writing to a file"""
 
-    def __init__(self, key: t.Union[bytes, pathlib.Path]) -> None:
-        """Initialize the FileSystemCommChannel instance
+    def __init__(self, key: pathlib.Path) -> None:
+        """Initialize the FileSystemCommChannel instance.
 
-        :param key: a path to the root directory of the feature store"""
+        :param key: a path to the root directory of the feature store
+        """
         self._lock = threading.RLock()
-        if isinstance(key, pathlib.Path):
-            super().__init__(key.as_posix().encode("utf-8"))
-            self._file_path = key
-        else:
-            super().__init__(key)
-            self._file_path = pathlib.Path(key.decode("utf-8"))
+
+        super().__init__(key.as_posix())
+        self._file_path = key
 
         if not self._file_path.parent.exists():
             self._file_path.parent.mkdir(parents=True)
@@ -57,10 +55,11 @@ def __init__(self, key: pathlib.Path) -> None:
         self._file_path.touch()
 
     def send(self, value: bytes, timeout: float = 0) -> None:
-        """Send a message throuh the underlying communication channel
+        """Send a message through the underlying communication channel.
 
+        :param value: The value to send
         :param timeout: maximum time to wait (in seconds) for messages to send
-        :param value: The value to send"""
+        """
         with self._lock:
             # write as text so we can add newlines as delimiters
             with open(self._file_path, "a") as fp:
@@ -69,11 +68,12 @@ def send(self, value: bytes, timeout: float = 0) -> None:
             logger.debug(f"FileSystemCommChannel {self._file_path} sent message")
 
     def recv(self, timeout: float = 0) -> t.List[bytes]:
-        """Receives message(s) through the underlying communication channel
+        """Receives message(s) through the underlying communication channel.
 
         :param timeout: maximum time to wait (in seconds) for messages to arrive
         :returns: the received message
-        :raises SmartSimError: if the descriptor points to a missing file"""
+        :raises SmartSimError: if the descriptor points to a missing file
+        """
         with self._lock:
             messages: t.List[bytes] = []
             if not self._file_path.exists():
@@ -102,7 +102,7 @@ def recv(self, timeout: float = 0) -> t.List[bytes]:
         return messages
 
     def clear(self) -> None:
-        """Create an empty file for events"""
+        """Create an empty file for events."""
         if self._file_path.exists():
             self._file_path.unlink()
             self._file_path.touch()
@@ -110,17 +110,15 @@ def recv(self, timeout: float = 0) -> t.List[bytes]:
     @classmethod
     def from_descriptor(
         cls,
-        descriptor: t.Union[str, bytes],
+        descriptor: str,
     ) -> "FileSystemCommChannel":
-        """A factory method that creates an instance from a descriptor string
+        """A factory method that creates an instance from a descriptor string.
 
         :param descriptor: The descriptor that uniquely identifies the resource
-        :returns: An attached FileSystemCommChannel"""
+        :returns: An attached FileSystemCommChannel
+        """
         try:
-            if isinstance(descriptor, str):
-                path = pathlib.Path(descriptor)
-            else:
-                path = pathlib.Path(descriptor.decode("utf-8"))
+            path = pathlib.Path(descriptor)
             return FileSystemCommChannel(path)
         except:
             logger.warning(f"failed to create fs comm channel: {descriptor}")
diff --git a/tests/dragon/conftest.py b/tests/dragon/conftest.py
new file mode 100644
index 0000000000..d542700175
--- /dev/null
+++ b/tests/dragon/conftest.py
@@ -0,0 +1,129 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from __future__ import annotations
+
+import os
+import pathlib
+import socket
+import subprocess
+import sys
+import typing as t
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+# isort: off
+import dragon.data.ddict.ddict as dragon_ddict
+import dragon.infrastructure.policy as dragon_policy
+import dragon.infrastructure.process_desc as dragon_process_desc
+import dragon.native.process as dragon_process
+
+from dragon.fli import FLInterface
+
+# isort: on
+
+from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
+from smartsim._core.mli.comm.channel.dragon_util import create_local
+from smartsim._core.mli.infrastructure.storage import dragon_util
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    BackboneFeatureStore,
+)
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+@pytest.fixture(scope="module")
+def the_storage() -> dragon_ddict.DDict:
+    """Fixture to instantiate a dragon distributed dictionary."""
+    return dragon_util.create_ddict(1, 2, 32 * 1024**2)
+
+
+@pytest.fixture(scope="module")
+def the_worker_channel() -> DragonFLIChannel:
+    """Fixture to create a worker comm channel with a valid descriptor
+    that tests can attach to."""
+    channel_ = create_local()
+    fli_ = FLInterface(main_ch=channel_, manager_ch=None)
+    comm_channel = DragonFLIChannel(fli_)
+    return comm_channel
+
+
+@pytest.fixture(scope="module")
+def the_backbone(
+    the_storage: t.Any, the_worker_channel: DragonFLIChannel
+) -> BackboneFeatureStore:
+    """Fixture to create a distributed dragon dictionary and wrap it
+    in a BackboneFeatureStore.
+
+    :param the_storage: The dragon storage engine to use
+    :param the_worker_channel: Pre-configured worker channel
+    """
+
+    backbone = BackboneFeatureStore(the_storage, allow_reserved_writes=True)
+    backbone[BackboneFeatureStore.MLI_WORKER_QUEUE] = the_worker_channel.descriptor
+
+    return backbone
+
+
+@pytest.fixture(scope="module")
+def backbone_descriptor(the_backbone: BackboneFeatureStore) -> str:
+    # create a shared backbone featurestore
+    return the_backbone.descriptor
+
+
+def function_as_dragon_proc(
+    entrypoint_fn: t.Callable[[t.Any], None],
+    args: t.List[t.Any],
+    cpu_affinity: t.List[int],
+    gpu_affinity: t.List[int],
+) -> dragon_process.Process:
+    """Execute a function as an independent dragon process.
+ + :param entrypoint_fn: The function to execute + :param args: The arguments for the entrypoint function + :param cpu_affinity: The cpu affinity for the process + :param gpu_affinity: The gpu affinity for the process + :returns: The dragon process handle + """ + options = dragon_process_desc.ProcessOptions(make_inf_channels=True) + local_policy = dragon_policy.Policy( + placement=dragon_policy.Policy.Placement.HOST_NAME, + host_name=socket.gethostname(), + cpu_affinity=cpu_affinity, + gpu_affinity=gpu_affinity, + ) + return dragon_process.Process( + target=entrypoint_fn, + args=args, + cwd=os.getcwd(), + policy=local_policy, + options=options, + stderr=dragon_process.Popen.STDOUT, + stdout=dragon_process.Popen.STDOUT, + ) diff --git a/tests/dragon/test_core_machine_learning_worker.py b/tests/dragon/test_core_machine_learning_worker.py index ed9ac625cd..e9c356b4e0 100644 --- a/tests/dragon/test_core_machine_learning_worker.py +++ b/tests/dragon/test_core_machine_learning_worker.py @@ -34,7 +34,7 @@ import torch import smartsim.error as sse -from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStoreKey +from smartsim._core.mli.infrastructure.storage.feature_store import ModelKey, TensorKey from smartsim._core.mli.infrastructure.worker.worker import ( InferenceRequest, MachineLearningWorkerCore, @@ -98,7 +98,7 @@ def test_fetch_model_disk(persist_torch_model: pathlib.Path, test_dir: str) -> N fsd = feature_store.descriptor feature_store[str(persist_torch_model)] = persist_torch_model.read_bytes() - model_key = FeatureStoreKey(key=key, descriptor=fsd) + model_key = ModelKey(key=key, descriptor=fsd) request = InferenceRequest(model_key=model_key) batch = RequestBatch([request], None, model_key) @@ -116,7 +116,7 @@ def test_fetch_model_disk_missing() -> None: key = "/path/that/doesnt/exist" - model_key = FeatureStoreKey(key=key, descriptor=fsd) + model_key = ModelKey(key=key, descriptor=fsd) request = InferenceRequest(model_key=model_key) batch = RequestBatch([request], None, model_key) @@ -141,7 +141,7 @@ def test_fetch_model_feature_store(persist_torch_model: pathlib.Path) -> None: fsd = feature_store.descriptor feature_store[key] = persist_torch_model.read_bytes() - model_key = FeatureStoreKey(key=key, descriptor=feature_store.descriptor) + model_key = ModelKey(key=key, descriptor=feature_store.descriptor) request = InferenceRequest(model_key=model_key) batch = RequestBatch([request], None, model_key) @@ -159,7 +159,7 @@ def test_fetch_model_feature_store_missing() -> None: feature_store = MemoryFeatureStore() fsd = feature_store.descriptor - model_key = FeatureStoreKey(key=key, descriptor=feature_store.descriptor) + model_key = ModelKey(key=key, descriptor=feature_store.descriptor) request = InferenceRequest(model_key=model_key) batch = RequestBatch([request], None, model_key) @@ -182,7 +182,7 @@ def test_fetch_model_memory(persist_torch_model: pathlib.Path) -> None: fsd = feature_store.descriptor feature_store[key] = persist_torch_model.read_bytes() - model_key = FeatureStoreKey(key=key, descriptor=feature_store.descriptor) + model_key = ModelKey(key=key, descriptor=feature_store.descriptor) request = InferenceRequest(model_key=model_key) batch = RequestBatch([request], None, model_key) @@ -199,11 +199,9 @@ def test_fetch_input_disk(persist_torch_tensor: pathlib.Path) -> None: feature_store = MemoryFeatureStore() fsd = feature_store.descriptor - request = InferenceRequest( - input_keys=[FeatureStoreKey(key=tensor_name, descriptor=fsd)] - ) + request = 
InferenceRequest(input_keys=[TensorKey(key=tensor_name, descriptor=fsd)]) - model_key = FeatureStoreKey(key="test-model", descriptor=fsd) + model_key = ModelKey(key="test-model", descriptor=fsd) batch = RequestBatch([request], None, model_key) worker = MachineLearningWorkerCore @@ -223,9 +221,9 @@ def test_fetch_input_disk_missing() -> None: fsd = feature_store.descriptor key = "/path/that/doesnt/exist" - request = InferenceRequest(input_keys=[FeatureStoreKey(key=key, descriptor=fsd)]) + request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)]) - model_key = FeatureStoreKey(key="test-model", descriptor=fsd) + model_key = ModelKey(key="test-model", descriptor=fsd) batch = RequestBatch([request], None, model_key) with pytest.raises(sse.SmartSimError) as ex: @@ -245,14 +243,12 @@ def test_fetch_input_feature_store(persist_torch_tensor: pathlib.Path) -> None: feature_store = MemoryFeatureStore() fsd = feature_store.descriptor - request = InferenceRequest( - input_keys=[FeatureStoreKey(key=tensor_name, descriptor=fsd)] - ) + request = InferenceRequest(input_keys=[TensorKey(key=tensor_name, descriptor=fsd)]) # put model bytes into the feature store feature_store[tensor_name] = persist_torch_tensor.read_bytes() - model_key = FeatureStoreKey(key="test-model", descriptor=fsd) + model_key = ModelKey(key="test-model", descriptor=fsd) batch = RequestBatch([request], None, model_key) fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) @@ -284,13 +280,13 @@ def test_fetch_multi_input_feature_store(persist_torch_tensor: pathlib.Path) -> request = InferenceRequest( input_keys=[ - FeatureStoreKey(key=tensor_name + "1", descriptor=fsd), - FeatureStoreKey(key=tensor_name + "2", descriptor=fsd), - FeatureStoreKey(key=tensor_name + "3", descriptor=fsd), + TensorKey(key=tensor_name + "1", descriptor=fsd), + TensorKey(key=tensor_name + "2", descriptor=fsd), + TensorKey(key=tensor_name + "3", descriptor=fsd), ] ) - model_key = FeatureStoreKey(key="test-model", descriptor=fsd) + model_key = ModelKey(key="test-model", descriptor=fsd) batch = RequestBatch([request], None, model_key) fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) @@ -310,9 +306,9 @@ def test_fetch_input_feature_store_missing() -> None: key = "bad-key" feature_store = MemoryFeatureStore() fsd = feature_store.descriptor - request = InferenceRequest(input_keys=[FeatureStoreKey(key=key, descriptor=fsd)]) + request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)]) - model_key = FeatureStoreKey(key="test-model", descriptor=fsd) + model_key = ModelKey(key="test-model", descriptor=fsd) batch = RequestBatch([request], None, model_key) with pytest.raises(sse.SmartSimError) as ex: @@ -332,9 +328,9 @@ def test_fetch_input_memory(persist_torch_tensor: pathlib.Path) -> None: key = "test-model" feature_store[key] = persist_torch_tensor.read_bytes() - request = InferenceRequest(input_keys=[FeatureStoreKey(key=key, descriptor=fsd)]) + request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)]) - model_key = FeatureStoreKey(key="test-model", descriptor=fsd) + model_key = ModelKey(key="test-model", descriptor=fsd) batch = RequestBatch([request], None, model_key) fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) @@ -351,9 +347,9 @@ def test_place_outputs() -> None: # create a key to retrieve from the feature store keys = [ - FeatureStoreKey(key=key_name + "1", descriptor=fsd), - FeatureStoreKey(key=key_name + "2", descriptor=fsd), - FeatureStoreKey(key=key_name + "3", 
descriptor=fsd), + TensorKey(key=key_name + "1", descriptor=fsd), + TensorKey(key=key_name + "2", descriptor=fsd), + TensorKey(key=key_name + "3", descriptor=fsd), ] data = [b"abcdef", b"ghijkl", b"mnopqr"] @@ -376,6 +372,6 @@ def test_place_outputs() -> None: pytest.param("key", "", id="invalid descriptor"), ], ) -def test_invalid_featurestorekey(key, descriptor) -> None: +def test_invalid_tensorkey(key, descriptor) -> None: with pytest.raises(ValueError): - fsk = FeatureStoreKey(key, descriptor) + fsk = TensorKey(key, descriptor) diff --git a/tests/dragon/test_device_manager.py b/tests/dragon/test_device_manager.py index c58879cb62..d270e921cb 100644 --- a/tests/dragon/test_device_manager.py +++ b/tests/dragon/test_device_manager.py @@ -36,7 +36,8 @@ ) from smartsim._core.mli.infrastructure.storage.feature_store import ( FeatureStore, - FeatureStoreKey, + ModelKey, + TensorKey, ) from smartsim._core.mli.infrastructure.worker.worker import ( ExecuteResult, @@ -116,9 +117,9 @@ def test_device_manager_model_in_request(): worker = MockWorker() - tensor_key = FeatureStoreKey(key="key", descriptor="desc") - output_key = FeatureStoreKey(key="key", descriptor="desc") - model_key = FeatureStoreKey(key="model key", descriptor="desc") + tensor_key = TensorKey(key="key", descriptor="desc") + output_key = TensorKey(key="key", descriptor="desc") + model_key = ModelKey(key="model key", descriptor="desc") request = InferenceRequest( model_key=model_key, @@ -154,9 +155,9 @@ def test_device_manager_model_key(): worker = MockWorker() - tensor_key = FeatureStoreKey(key="key", descriptor="desc") - output_key = FeatureStoreKey(key="key", descriptor="desc") - model_key = FeatureStoreKey(key="model key", descriptor="desc") + tensor_key = TensorKey(key="key", descriptor="desc") + output_key = TensorKey(key="key", descriptor="desc") + model_key = ModelKey(key="model key", descriptor="desc") request = InferenceRequest( model_key=model_key, diff --git a/tests/dragon/test_dragon_backend.py b/tests/dragon/test_dragon_backend.py new file mode 100644 index 0000000000..2b2ef50f99 --- /dev/null +++ b/tests/dragon/test_dragon_backend.py @@ -0,0 +1,307 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import time +import uuid + +import pytest + +dragon = pytest.importorskip("dragon") + + +from smartsim._core.launcher.dragon.dragonBackend import DragonBackend +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.infrastructure.comm.event import ( + OnCreateConsumer, + OnShutdownRequested, +) +from smartsim._core.mli.infrastructure.control.listener import ( + ConsumerRegistrationListener, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim.log import get_logger + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon +logger = get_logger(__name__) + + +@pytest.fixture(scope="module") +def the_backend() -> DragonBackend: + return DragonBackend(pid=9999) + + +def test_dragonbackend_start_listener(the_backend: DragonBackend): + """Verify the background process listening to consumer registration events + is up and processing messages as expected.""" + + # We need to let the backend create the backbone to continue + backbone = the_backend._create_backbone() + backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS) + backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER) + + os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone.descriptor + + with pytest.raises(KeyError) as ex: + # we expect the value of the consumer to be empty until + # the listener start-up completes. 
+        backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER]
+
+    assert "not found" in ex.value.args[0]
+
+    drg_process = the_backend.start_event_listener(cpu_affinity=[], gpu_affinity=[])
+
+    # confirm there is a process still running
+    logger.info(f"Dragon process started: {drg_process}")
+    assert drg_process is not None, "Backend was unable to start event listener"
+    assert drg_process.puid != 0, "Process unique ID is empty"
+    assert drg_process.returncode is None, "Listener terminated early"
+
+    # wait for the event listener to come up
+    try:
+        config = backbone.wait_for(
+            [BackboneFeatureStore.MLI_REGISTRAR_CONSUMER], timeout=30
+        )
+        # verify result was in the returned configuration map
+        assert config[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER]
+    except Exception:
+        raise KeyError(
+            f"Unable to locate {BackboneFeatureStore.MLI_REGISTRAR_CONSUMER} "
+            "in the backbone"
+        )
+
+    # wait_for ensures the normal retrieval will now work, error-free
+    descriptor = backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER]
+    assert descriptor is not None
+
+    # register a new listener channel
+    comm_channel = DragonCommChannel.from_descriptor(descriptor)
+    mock_descriptor = str(uuid.uuid4())
+    event = OnCreateConsumer("test_dragonbackend_start_listener", mock_descriptor, [])
+
+    event_bytes = bytes(event)
+    comm_channel.send(event_bytes)
+
+    subscriber_list = []
+
+    # Give the channel time to write the message and the listener time to handle it
+    for i in range(20):
+        time.sleep(1)
+        # Retrieve the subscriber list from the backbone and verify it is updated
+        if subscriber_list := backbone.notification_channels:
+            logger.debug(f"The subscriber list was populated after {i} iterations")
+            break
+
+    assert mock_descriptor in subscriber_list
+
+    # check whether the listener has already exited
+    return_code = drg_process.returncode
+
+    # clean up if the OnShutdownRequested wasn't properly handled
+    if return_code is None and drg_process.is_alive:
+        drg_process.kill()
+        drg_process.join()
+
+
+def test_dragonbackend_backend_consumer(the_backend: DragonBackend):
+    """Verify the listener background process updates the appropriate
+    value in the backbone."""
+
+    # We need to let the backend create the backbone to continue
+    backbone = the_backend._create_backbone()
+    backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS)
+    backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER)
+
+    assert backbone._allow_reserved_writes
+
+    # create listener with `as_service=False` to perform a single loop iteration
+    listener = ConsumerRegistrationListener(backbone, 1.0, 1.0, as_service=False)
+
+    logger.debug(f"backbone loaded? {listener._backbone}")
+    logger.debug(f"listener created? {listener}")
+
+    try:
+        # call the service execute method directly to trigger
+        # the entire service lifecycle
+        listener.execute()
+
+        consumer_desc = backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER]
+        logger.debug(f"MLI_REGISTRAR_CONSUMER: {consumer_desc}")
+
+        assert consumer_desc
+    except Exception as ex:
+        logger.exception(ex)
+    finally:
+        listener._on_shutdown()
+
+
+def test_dragonbackend_event_handled(the_backend: DragonBackend):
+    """Verify the event listener process updates the appropriate
+    value in the backbone when an event is received and again on shutdown.
+    """
+    # We need to let the backend create the backbone to continue
+    backbone = the_backend._create_backbone()
+    backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS)
+    backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER)
+
+    # create the listener to be tested
+    listener = ConsumerRegistrationListener(backbone, 1.0, 1.0, as_service=False)
+
+    assert listener._backbone, "The listener is not attached to a backbone"
+
+    try:
+        # set up the listener but don't let the service event loop start
+        listener._create_eventing()  # listener.execute()
+
+        # grab the channel descriptor so we can simulate registrations
+        channel_desc = backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER]
+        comm_channel = DragonCommChannel.from_descriptor(channel_desc)
+
+        num_events = 5
+        events = []
+        for i in range(num_events):
+            # register some mock consumers using the backend channel
+            event = OnCreateConsumer(
+                "test_dragonbackend_event_handled",
+                f"mock-consumer-descriptor-{uuid.uuid4()}",
+                [],
+            )
+            event_bytes = bytes(event)
+            comm_channel.send(event_bytes)
+            events.append(event)
+
+        # run a few iterations of the event loop in case it takes a few cycles to write
+        for i in range(20):
+            listener._on_iteration()
+            # Grab the value that should be getting updated
+            notify_consumers = set(backbone.notification_channels)
+            if len(notify_consumers) == len(events):
+                logger.info(f"Retrieved all consumers after {i} listen cycles")
+                break
+
+        # ... and confirm that all the mock consumer descriptors are registered
+        assert set([e.descriptor for e in events]) == set(notify_consumers)
+        logger.info(f"Number of registered consumers: {len(notify_consumers)}")
+
+    except Exception as ex:
+        logger.exception(f"test_dragonbackend_event_handled - exception occurred: {ex}")
+        assert False
+    finally:
+        # shutdown should unregister a registration listener
+        listener._on_shutdown()
+
+    for i in range(10):
+        time.sleep(0.1)
+        if BackboneFeatureStore.MLI_REGISTRAR_CONSUMER not in backbone:
+            logger.debug(f"The listener was removed after {i} iterations")
+            channel_desc = None
+            break
+
+    # we should see that there is no listener registered
+    assert not channel_desc, "Listener shutdown failed to clean up the backbone"
+
+
+def test_dragonbackend_shutdown_event(the_backend: DragonBackend):
+    """Verify the background process shuts down when it receives a
+    shutdown request."""
+
+    # We need to let the backend create the backbone to continue
+    backbone = the_backend._create_backbone()
+    backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS)
+    backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER)
+
+    listener = ConsumerRegistrationListener(backbone, 1.0, 1.0, as_service=True)
+
+    # set up the listener but don't let the listener loop start
+    listener._create_eventing()  # listener.execute()
+
+    # grab the channel descriptor so we can publish to it
+    channel_desc = backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER]
+    comm_channel = DragonCommChannel.from_descriptor(channel_desc)
+
+    assert listener._consumer.listening, "Listener isn't ready to listen"
+
+    # send a shutdown request...
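+    # (the event is serialized and published on the registrar's own channel)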
+ event = OnShutdownRequested("test_dragonbackend_shutdown_event") + event_bytes = bytes(event) + comm_channel.send(event_bytes, 0.1) + + # execute should encounter the shutdown and exit + listener.execute() + + # ...and confirm the listener is now cancelled + assert not listener._consumer.listening + + +@pytest.mark.parametrize("health_check_frequency", [10, 20]) +def test_dragonbackend_shutdown_on_health_check( + the_backend: DragonBackend, + health_check_frequency: float, +): + """Verify that the event listener automatically shuts down when + a new listener is registered in its place. + + :param health_check_frequency: The expected frequency of service health check + invocations""" + + # We need to let the backend create the backbone to continue + backbone = the_backend._create_backbone() + backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS) + backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER) + + listener = ConsumerRegistrationListener( + backbone, + 1.0, + 1.0, + as_service=True, # allow service to run long enough to health check + health_check_frequency=health_check_frequency, + ) + + # set up the listener but don't let the listener loop start + listener._create_eventing() # listener.execute() + assert listener._consumer.listening, "Listener wasn't ready to listen" + + # Replace the consumer descriptor in the backbone to trigger + # an automatic shutdown + backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] = str(uuid.uuid4()) + + # set the last health check manually to verify the duration + start_at = time.time() + listener._last_health_check = time.time() + + # run execute to let the service trigger health checks + listener.execute() + elapsed = time.time() - start_at + + # confirm the frequency of the health check was honored + assert elapsed >= health_check_frequency + + # ...and confirm the listener is now cancelled + assert ( + not listener._consumer.listening + ), "Listener was not automatically shutdown by the health check" diff --git a/tests/dragon/test_dragon_ddict_utils.py b/tests/dragon/test_dragon_ddict_utils.py new file mode 100644 index 0000000000..c8bf687ef1 --- /dev/null +++ b/tests/dragon/test_dragon_ddict_utils.py @@ -0,0 +1,117 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+# isort: off
+import dragon.data.ddict.ddict as dragon_ddict
+
+# isort: on
+
+from smartsim._core.mli.infrastructure.storage import dragon_util
+from smartsim.log import get_logger
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+logger = get_logger(__name__)
+
+
+@pytest.mark.parametrize(
+    "num_nodes, num_managers, mem_per_node",
+    [
+        pytest.param(1, 1, 3 * 1024**2, id="3MB, Bare minimum allocation"),
+        pytest.param(2, 2, 128 * 1024**2, id="128 MB allocation, 2 nodes, 2 mgr"),
+        pytest.param(2, 1, 512 * 1024**2, id="512 MB allocation, 2 nodes, 1 mgr"),
+    ],
+)
+def test_dragon_storage_util_create_ddict(
+    num_nodes: int,
+    num_managers: int,
+    mem_per_node: int,
+):
+    """Verify that a dragon dictionary is successfully created.
+
+    :param num_nodes: Number of ddict nodes to attempt to create
+    :param num_managers: Number of managers per node to request
+    :param mem_per_node: Memory to allocate per node
+    """
+    ddict = dragon_util.create_ddict(num_nodes, num_managers, mem_per_node)
+
+    assert ddict is not None
+
+
+@pytest.mark.parametrize(
+    "num_nodes, num_managers, mem_per_node",
+    [
+        pytest.param(-1, 1, 3 * 1024**2, id="Negative Node Count"),
+        pytest.param(0, 1, 3 * 1024**2, id="Invalid Node Count"),
+        pytest.param(1, -1, 3 * 1024**2, id="Negative Mgr Count"),
+        pytest.param(1, 0, 3 * 1024**2, id="Invalid Mgr Count"),
+        pytest.param(1, 1, -3 * 1024**2, id="Negative Mem Per Node"),
+        pytest.param(1, 1, (3 * 1024**2) - 1, id="Invalid Mem Per Node"),
+        pytest.param(1, 1, 0 * 1024**2, id="No Mem Per Node"),
+    ],
+)
+def test_dragon_storage_util_create_ddict_validators(
+    num_nodes: int,
+    num_managers: int,
+    mem_per_node: int,
+):
+    """Verify that invalid creation parameters are rejected with a ValueError.
+
+    :param num_nodes: Number of ddict nodes to attempt to create
+    :param num_managers: Number of managers per node to request
+    :param mem_per_node: Memory to allocate per node
+    """
+    with pytest.raises(ValueError):
+        dragon_util.create_ddict(num_nodes, num_managers, mem_per_node)
+
+
+def test_dragon_storage_util_get_ddict_descriptor(the_storage: dragon_ddict.DDict):
+    """Verify that a descriptor is created.
+
+    :param the_storage: A pre-allocated ddict
+    """
+    value = dragon_util.ddict_to_descriptor(the_storage)
+
+    assert isinstance(value, str)
+    assert len(value) > 0
+
+
+def test_dragon_storage_util_get_ddict_from_descriptor(the_storage: dragon_ddict.DDict):
+    """Verify that a ddict is created from a descriptor.
+
+    :param the_storage: A pre-allocated ddict
+    """
+    descriptor = dragon_util.ddict_to_descriptor(the_storage)
+
+    value = dragon_util.descriptor_to_ddict(descriptor)
+
+    assert value is not None
+    assert isinstance(value, dragon_ddict.DDict)
+    assert dragon_util.ddict_to_descriptor(value) == descriptor
diff --git a/tests/dragon/test_environment_loader.py b/tests/dragon/test_environment_loader.py
index e9bcc8dfd9..07b2a45c1c 100644
--- a/tests/dragon/test_environment_loader.py
+++ b/tests/dragon/test_environment_loader.py
@@ -28,15 +28,15 @@ dragon = pytest.importorskip("dragon")
+import dragon.data.ddict.ddict as dragon_ddict
 import dragon.utils as du
-from dragon.channels import Channel
-from dragon.data.ddict.ddict import DDict
-from dragon.fli import DragonFLIError, FLInterface
+from dragon.fli import FLInterface
 
 from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
 from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
+from smartsim._core.mli.comm.channel.dragon_util import create_local
 from smartsim._core.mli.infrastructure.environment_loader import EnvironmentConfigLoader
-from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
     DragonFeatureStore,
 )
 from smartsim.error.errors import SmartSimError
@@ -53,11 +53,12 @@
     ],
 )
 def test_environment_loader_attach_fli(content: bytes, monkeypatch: pytest.MonkeyPatch):
-    """A descriptor can be stored, loaded, and reattached"""
-    chan = Channel.make_process_local()
+    """A descriptor can be stored, loaded, and reattached."""
+    chan = create_local()
     queue = FLInterface(main_ch=chan)
     monkeypatch.setenv(
-        "_SMARTSIM_REQUEST_QUEUE", du.B64.bytes_to_str(queue.serialize())
+        EnvironmentConfigLoader.REQUEST_QUEUE_ENV_VAR,
+        du.B64.bytes_to_str(queue.serialize()),
     )
 
     config = EnvironmentConfigLoader(
@@ -76,11 +77,12 @@ def test_environment_loader_attach_fli(content: bytes, monkeypatch: pytest.Monke
 
 def test_environment_loader_serialize_fli(monkeypatch: pytest.MonkeyPatch):
     """The serialized descriptors of a loaded and unloaded
-    queue are the same"""
-    chan = Channel.make_process_local()
+    queue are the same."""
+    chan = create_local()
     queue = FLInterface(main_ch=chan)
     monkeypatch.setenv(
-        "_SMARTSIM_REQUEST_QUEUE", du.B64.bytes_to_str(queue.serialize())
+        EnvironmentConfigLoader.REQUEST_QUEUE_ENV_VAR,
+        du.B64.bytes_to_str(queue.serialize()),
    )
 
     config = EnvironmentConfigLoader(
@@ -93,8 +95,10 @@ def test_environment_loader_serialize_fli(monkeypatch: pytest.MonkeyPatch):
 
 
 def test_environment_loader_flifails(monkeypatch: pytest.MonkeyPatch):
-    """An incorrect serialized descriptor will fails to attach"""
-    monkeypatch.setenv("_SMARTSIM_REQUEST_QUEUE", "randomstring")
+    """An incorrect serialized descriptor will fail to attach."""
+
+    monkeypatch.setenv(EnvironmentConfigLoader.REQUEST_QUEUE_ENV_VAR, "randomstring")
+
     config = EnvironmentConfigLoader(
         featurestore_factory=DragonFeatureStore.from_descriptor,
         callback_factory=None,
@@ -105,11 +109,15 @@ def test_environment_loader_flifails(monkeypatch: pytest.MonkeyPatch):
         config.get_queue()
 
 
-def test_environment_loader_backbone_load_dfs(monkeypatch: pytest.MonkeyPatch):
+def test_environment_loader_backbone_load_dfs(
+    monkeypatch: pytest.MonkeyPatch, the_storage: dragon_ddict.DDict
+):
     """Verify the dragon feature store is loaded correctly by the
-    EnvironmentConfigLoader to demonstrate featurestore_factory correctness"""
-    feature_store = 
DragonFeatureStore(DDict()) - monkeypatch.setenv("_SMARTSIM_INFRA_BACKBONE", feature_store.descriptor) + EnvironmentConfigLoader to demonstrate featurestore_factory correctness.""" + feature_store = DragonFeatureStore(the_storage) + monkeypatch.setenv( + EnvironmentConfigLoader.BACKBONE_ENV_VAR, feature_store.descriptor + ) config = EnvironmentConfigLoader( featurestore_factory=DragonFeatureStore.from_descriptor, @@ -123,13 +131,17 @@ def test_environment_loader_backbone_load_dfs(monkeypatch: pytest.MonkeyPatch): assert backbone is not None -def test_environment_variables_not_set(): +def test_environment_variables_not_set(monkeypatch: pytest.MonkeyPatch): """EnvironmentConfigLoader getters return None when environment - variables are not set""" - config = EnvironmentConfigLoader( - featurestore_factory=DragonFeatureStore.from_descriptor, - callback_factory=DragonCommChannel.from_descriptor, - queue_factory=DragonCommChannel.from_descriptor, - ) - assert config.get_backbone() is None - assert config.get_queue() is None + variables are not set.""" + with monkeypatch.context() as patch: + patch.setenv(EnvironmentConfigLoader.BACKBONE_ENV_VAR, "") + patch.setenv(EnvironmentConfigLoader.REQUEST_QUEUE_ENV_VAR, "") + + config = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=DragonCommChannel.from_descriptor, + queue_factory=DragonCommChannel.from_descriptor, + ) + assert config.get_backbone() is None + assert config.get_queue() is None diff --git a/tests/dragon/test_error_handling.py b/tests/dragon/test_error_handling.py index 618b00d87e..aacd47b556 100644 --- a/tests/dragon/test_error_handling.py +++ b/tests/dragon/test_error_handling.py @@ -24,6 +24,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+import typing as t from unittest.mock import MagicMock import pytest @@ -32,14 +33,13 @@ import multiprocessing as mp -import dragon.utils as du from dragon.channels import Channel from dragon.data.ddict.ddict import DDict from dragon.fli import FLInterface from dragon.mpbridge.queues import DragonQueue +from smartsim._core.mli.comm.channel.channel import CommChannelBase from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel -from smartsim._core.mli.infrastructure.control.device_manager import WorkerDevice from smartsim._core.mli.infrastructure.control.request_dispatcher import ( RequestDispatcher, ) @@ -48,25 +48,30 @@ exception_handler, ) from smartsim._core.mli.infrastructure.environment_loader import EnvironmentConfigLoader +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( DragonFeatureStore, ) from smartsim._core.mli.infrastructure.storage.feature_store import ( FeatureStore, - FeatureStoreKey, + ModelKey, + TensorKey, ) from smartsim._core.mli.infrastructure.worker.worker import ( ExecuteResult, FetchInputResult, FetchModelResult, - InferenceReply, InferenceRequest, LoadModelResult, + MachineLearningWorkerBase, RequestBatch, TransformInputResult, TransformOutputResult, ) from smartsim._core.mli.message_handler import MessageHandler +from smartsim._core.mli.mli_schemas.response.response_capnp import ResponseBuilder from .utils.channel import FileSystemCommChannel from .utils.worker import IntegratedTorchWorker @@ -75,37 +80,29 @@ pytestmark = pytest.mark.dragon -@pytest.fixture -def backbone_descriptor() -> str: - # create a shared backbone featurestore - feature_store = DragonFeatureStore(DDict()) - return feature_store.descriptor - - -@pytest.fixture -def app_feature_store() -> FeatureStore: +@pytest.fixture(scope="module") +def app_feature_store(the_storage) -> FeatureStore: # create a standalone feature store to mimic a user application putting # data into an application-owned resource (app should not access backbone) - app_fs = DragonFeatureStore(DDict()) + app_fs = DragonFeatureStore(the_storage) return app_fs @pytest.fixture def setup_worker_manager_model_bytes( - test_dir, + test_dir: str, monkeypatch: pytest.MonkeyPatch, backbone_descriptor: str, app_feature_store: FeatureStore, + the_worker_channel: DragonFLIChannel, ): integrated_worker_type = IntegratedTorchWorker - chan = Channel.make_process_local() - queue = FLInterface(main_ch=chan) monkeypatch.setenv( - "_SMARTSIM_REQUEST_QUEUE", du.B64.bytes_to_str(queue.serialize()) + BackboneFeatureStore.MLI_WORKER_QUEUE, the_worker_channel.descriptor ) # Put backbone descriptor into env var for the `EnvironmentConfigLoader` - monkeypatch.setenv("_SMARTSIM_INFRA_BACKBONE", backbone_descriptor) + monkeypatch.setenv(BackboneFeatureStore.MLI_BACKBONE, backbone_descriptor) config_loader = EnvironmentConfigLoader( featurestore_factory=DragonFeatureStore.from_descriptor, @@ -113,7 +110,7 @@ def setup_worker_manager_model_bytes( queue_factory=DragonFLIChannel.from_descriptor, ) - dispatcher_task_queue = mp.Queue(maxsize=0) + dispatcher_task_queue: mp.Queue[RequestBatch] = mp.Queue(maxsize=0) worker_manager = WorkerManager( config_loader=config_loader, @@ -123,10 +120,10 @@ def setup_worker_manager_model_bytes( cooldown=3, ) - tensor_key = FeatureStoreKey(key="key", descriptor=app_feature_store.descriptor) - output_key = FeatureStoreKey(key="key", descriptor=app_feature_store.descriptor) + 
tensor_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + output_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) - request = InferenceRequest( + inf_request = InferenceRequest( model_key=None, callback=None, raw_inputs=None, @@ -137,10 +134,10 @@ def setup_worker_manager_model_bytes( batch_size=0, ) - model_id = FeatureStoreKey(key="key", descriptor=app_feature_store.descriptor) + model_id = ModelKey(key="key", descriptor=app_feature_store.descriptor) request_batch = RequestBatch( - [request], + [inf_request], TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), model_id=model_id, ) @@ -155,16 +152,15 @@ def setup_worker_manager_model_key( monkeypatch: pytest.MonkeyPatch, backbone_descriptor: str, app_feature_store: FeatureStore, + the_worker_channel: DragonFLIChannel, ): integrated_worker_type = IntegratedTorchWorker - chan = Channel.make_process_local() - queue = FLInterface(main_ch=chan) monkeypatch.setenv( - "_SMARTSIM_REQUEST_QUEUE", du.B64.bytes_to_str(queue.serialize()) + BackboneFeatureStore.MLI_WORKER_QUEUE, the_worker_channel.descriptor ) # Put backbone descriptor into env var for the `EnvironmentConfigLoader` - monkeypatch.setenv("_SMARTSIM_INFRA_BACKBONE", backbone_descriptor) + monkeypatch.setenv(BackboneFeatureStore.MLI_BACKBONE, backbone_descriptor) config_loader = EnvironmentConfigLoader( featurestore_factory=DragonFeatureStore.from_descriptor, @@ -172,7 +168,7 @@ def setup_worker_manager_model_key( queue_factory=DragonFLIChannel.from_descriptor, ) - dispatcher_task_queue = mp.Queue(maxsize=0) + dispatcher_task_queue: mp.Queue[RequestBatch] = mp.Queue(maxsize=0) worker_manager = WorkerManager( config_loader=config_loader, @@ -182,9 +178,9 @@ def setup_worker_manager_model_key( cooldown=3, ) - tensor_key = FeatureStoreKey(key="key", descriptor=app_feature_store.descriptor) - output_key = FeatureStoreKey(key="key", descriptor=app_feature_store.descriptor) - model_id = FeatureStoreKey(key="model key", descriptor=app_feature_store.descriptor) + tensor_key = TensorKey(key="key", descriptor=app_feature_store.descriptor) + output_key = TensorKey(key="key", descriptor=app_feature_store.descriptor) + model_id = ModelKey(key="model key", descriptor=app_feature_store.descriptor) request = InferenceRequest( model_key=model_id, @@ -208,20 +204,19 @@ def setup_worker_manager_model_key( @pytest.fixture def setup_request_dispatcher_model_bytes( - test_dir, + test_dir: str, monkeypatch: pytest.MonkeyPatch, backbone_descriptor: str, app_feature_store: FeatureStore, + the_worker_channel: DragonFLIChannel, ): integrated_worker_type = IntegratedTorchWorker - chan = Channel.make_process_local() - queue = FLInterface(main_ch=chan) monkeypatch.setenv( - "_SMARTSIM_REQUEST_QUEUE", du.B64.bytes_to_str(queue.serialize()) + BackboneFeatureStore.MLI_WORKER_QUEUE, the_worker_channel.descriptor ) # Put backbone descriptor into env var for the `EnvironmentConfigLoader` - monkeypatch.setenv("_SMARTSIM_INFRA_BACKBONE", backbone_descriptor) + monkeypatch.setenv(BackboneFeatureStore.MLI_BACKBONE, backbone_descriptor) config_loader = EnvironmentConfigLoader( featurestore_factory=DragonFeatureStore.from_descriptor, @@ -252,20 +247,19 @@ def setup_request_dispatcher_model_bytes( @pytest.fixture def setup_request_dispatcher_model_key( - test_dir, + test_dir: str, monkeypatch: pytest.MonkeyPatch, backbone_descriptor: str, app_feature_store: FeatureStore, + the_worker_channel: DragonFLIChannel, ): integrated_worker_type = 
IntegratedTorchWorker - chan = Channel.make_process_local() - queue = FLInterface(main_ch=chan) monkeypatch.setenv( - "_SMARTSIM_REQUEST_QUEUE", du.B64.bytes_to_str(queue.serialize()) + BackboneFeatureStore.MLI_WORKER_QUEUE, the_worker_channel.descriptor ) # Put backbone descriptor into env var for the `EnvironmentConfigLoader` - monkeypatch.setenv("_SMARTSIM_INFRA_BACKBONE", backbone_descriptor) + monkeypatch.setenv(BackboneFeatureStore.MLI_BACKBONE, backbone_descriptor) config_loader = EnvironmentConfigLoader( featurestore_factory=DragonFeatureStore.from_descriptor, @@ -284,7 +278,7 @@ def setup_request_dispatcher_model_key( tensor_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) output_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) model_key = MessageHandler.build_model_key( - key="model key", feature_store_descriptor=app_feature_store.descriptor + key="model key", descriptor=app_feature_store.descriptor ) request = MessageHandler.build_request( test_dir, model_key, [tensor_key], [output_key], [], None @@ -296,8 +290,12 @@ def setup_request_dispatcher_model_key( return request_dispatcher, integrated_worker_type -def mock_pipeline_stage(monkeypatch: pytest.MonkeyPatch, integrated_worker, stage): - def mock_stage(*args, **kwargs): +def mock_pipeline_stage( + monkeypatch: pytest.MonkeyPatch, + integrated_worker: MachineLearningWorkerBase, + stage: str, +) -> t.Callable[[t.Any], ResponseBuilder]: + def mock_stage(*args: t.Any, **kwargs: t.Any) -> None: raise ValueError(f"Simulated error in {stage}") monkeypatch.setattr(integrated_worker, stage, mock_stage) @@ -314,8 +312,10 @@ def mock_stage(*args, **kwargs): mock_reply_channel = MagicMock() mock_reply_channel.send = MagicMock() - def mock_exception_handler(exc, reply_channel, failure_message): - return exception_handler(exc, mock_reply_channel, failure_message) + def mock_exception_handler( + exc: Exception, reply_channel: CommChannelBase, failure_message: str + ) -> None: + exception_handler(exc, mock_reply_channel, failure_message) monkeypatch.setattr( "smartsim._core.mli.infrastructure.control.worker_manager.exception_handler", @@ -362,12 +362,12 @@ def mock_exception_handler(exc, reply_channel, failure_message): ], ) def test_wm_pipeline_stage_errors_handled( - request, - setup_worker_manager, + request: pytest.FixtureRequest, + setup_worker_manager: str, monkeypatch: pytest.MonkeyPatch, stage: str, error_message: str, -): +) -> None: """Ensures that the worker manager does not crash after a failure in various pipeline stages""" worker_manager, integrated_worker_type = request.getfixturevalue( setup_worker_manager @@ -446,12 +446,12 @@ def test_wm_pipeline_stage_errors_handled( ], ) def test_dispatcher_pipeline_stage_errors_handled( - request, - setup_request_dispatcher, + request: pytest.FixtureRequest, + setup_request_dispatcher: str, monkeypatch: pytest.MonkeyPatch, stage: str, error_message: str, -): +) -> None: """Ensures that the request dispatcher does not crash after a failure in various pipeline stages""" request_dispatcher, integrated_worker_type = request.getfixturevalue( setup_request_dispatcher @@ -473,7 +473,7 @@ def test_dispatcher_pipeline_stage_errors_handled( mock_reply_fn.assert_called_with("fail", error_message) -def test_exception_handling_helper(monkeypatch: pytest.MonkeyPatch): +def test_exception_handling_helper(monkeypatch: pytest.MonkeyPatch) -> None: """Ensures that the worker manager does not crash after a failure in the execute pipeline stage""" @@ 
-498,3 +498,14 @@ def test_exception_handling_helper(monkeypatch: pytest.MonkeyPatch): mock_reply_fn.assert_called_once() mock_reply_fn.assert_called_with("fail", "Failure while fetching the model.") + + +def test_dragon_feature_store_invalid_storage(): + """Verify that attempting to create a DragonFeatureStore without storage fails.""" + storage = None + + with pytest.raises(ValueError) as ex: + DragonFeatureStore(storage) + + assert "storage" in ex.value.args[0].lower() + assert "required" in ex.value.args[0].lower() diff --git a/tests/dragon/test_event_consumer.py b/tests/dragon/test_event_consumer.py new file mode 100644 index 0000000000..8a241bab19 --- /dev/null +++ b/tests/dragon/test_event_consumer.py @@ -0,0 +1,386 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import time +import typing as t +from unittest import mock + +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster +from smartsim._core.mli.infrastructure.comm.consumer import EventConsumer +from smartsim._core.mli.infrastructure.comm.event import ( + OnCreateConsumer, + OnShutdownRequested, + OnWriteFeatureStore, +) +from smartsim._core.mli.infrastructure.control.listener import ( + ConsumerRegistrationListener, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim.log import get_logger + +logger = get_logger(__name__) + +# isort: off +from dragon import fli +from dragon.channels import Channel + +# isort: on + +if t.TYPE_CHECKING: + import conftest + + +# The tests in this file must run in a dragon environment +pytestmark = pytest.mark.dragon + + +def test_eventconsumer_eventpublisher_integration( + the_backbone: t.Any, test_dir: str +) -> None: + """Verify that the publisher and consumer integrate as expected when + multiple publishers and consumers are sending simultaneously. 
This
+    test closely tracks the test in tests/test_featurestore_base.py also named
+    test_eventconsumer_eventpublisher_integration but requires dragon entities.
+
+    :param the_backbone: The BackboneFeatureStore to use
+    :param test_dir: Automatically generated unique working
+    directories for individual test outputs
+    """
+
+    wmgr_channel = DragonCommChannel(create_local())
+    capp_channel = DragonCommChannel(create_local())
+    back_channel = DragonCommChannel(create_local())
+
+    wmgr_consumer_descriptor = wmgr_channel.descriptor
+    capp_consumer_descriptor = capp_channel.descriptor
+    back_consumer_descriptor = back_channel.descriptor
+
+    # create some consumers to receive messages
+    wmgr_consumer = EventConsumer(
+        wmgr_channel,
+        the_backbone,
+        filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN],
+    )
+    capp_consumer = EventConsumer(
+        capp_channel,
+        the_backbone,
+    )
+    back_consumer = EventConsumer(
+        back_channel,
+        the_backbone,
+        filters=[OnCreateConsumer.CONSUMER_CREATED],
+    )
+
+    # create some broadcasters to publish messages
+    mock_worker_mgr = EventBroadcaster(
+        the_backbone,
+        channel_factory=DragonCommChannel.from_descriptor,
+    )
+    mock_client_app = EventBroadcaster(
+        the_backbone,
+        channel_factory=DragonCommChannel.from_descriptor,
+    )
+
+    # register all of the consumers, even though the OnCreateConsumer event
+    # should trigger registration; event processing is tested elsewhere
+    the_backbone.notification_channels = [
+        wmgr_consumer_descriptor,
+        capp_consumer_descriptor,
+        back_consumer_descriptor,
+    ]
+
+    # simulate worker manager sending a notification to backend that it's alive
+    event_1 = OnCreateConsumer(
+        "test_eventconsumer_eventpublisher_integration",
+        wmgr_consumer_descriptor,
+        filters=[],
+    )
+    mock_worker_mgr.send(event_1)
+
+    # simulate the app updating a model a few times
+    for key in ["key-1", "key-2", "key-1"]:
+        event = OnWriteFeatureStore(
+            "test_eventconsumer_eventpublisher_integration",
+            the_backbone.descriptor,
+            key,
+        )
+        mock_client_app.send(event, timeout=0.1)
+
+    # worker manager should only get updates about feature updates
+    wmgr_messages = wmgr_consumer.recv()
+    assert len(wmgr_messages) == 3
+
+    # the backend should only receive messages about consumer creation
+    back_messages = back_consumer.recv()
+    assert len(back_messages) == 1
+
+    # hypothetical app has no filters and will get all events
+    app_messages = capp_consumer.recv()
+    assert len(app_messages) == 4
+
+
+@pytest.mark.parametrize(
+    "timeout, batch_timeout, exp_err_msg",
+    [(-1, 1, " timeout"), (1, -1, "batch_timeout")],
+)
+def test_eventconsumer_invalid_timeout(
+    timeout: float,
+    batch_timeout: float,
+    exp_err_msg: str,
+    test_dir: str,
+    the_backbone: BackboneFeatureStore,
+) -> None:
+    """Verify that the event consumer raises an exception
+    when provided an invalid request timeout.
+
+    :param timeout: The request timeout for the event consumer recv call
+    :param batch_timeout: The batch timeout for the event consumer recv call
+    :param exp_err_msg: A unique value from the error message that should be raised
+    :param the_backbone: The BackboneFeatureStore to use
+    :param test_dir: Automatically generated unique working
+    directories for individual test outputs
+    """
+
+    wmgr_channel = DragonCommChannel(create_local())
+
+    # create some consumers to receive messages
+    wmgr_consumer = EventConsumer(
+        wmgr_channel,
+        the_backbone,
+        filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN],
+    )
+
+    # the consumer should report an error for the invalid timeout value
+    with pytest.raises(ValueError) as ex:
+        wmgr_consumer.recv(timeout=timeout, batch_timeout=batch_timeout)
+
+    assert exp_err_msg in ex.value.args[0]
+
+
+def test_eventconsumer_no_event_handler_registered(
+    the_backbone: t.Any, test_dir: str
+) -> None:
+    """Verify that a consumer discards messages received on a
+    channel when no event handler is registered.
+
+    :param the_backbone: The BackboneFeatureStore to use
+    :param test_dir: Automatically generated unique working
+    directories for individual test outputs
+    """
+
+    wmgr_channel = DragonCommChannel(create_local())
+
+    # create a consumer to receive messages
+    wmgr_consumer = EventConsumer(wmgr_channel, the_backbone, event_handler=None)
+
+    # create a broadcaster to publish messages
+    mock_worker_mgr = EventBroadcaster(
+        the_backbone,
+        channel_factory=DragonCommChannel.from_descriptor,
+    )
+
+    # manually register the consumers since we don't have a backend running
+    the_backbone.notification_channels = [wmgr_channel.descriptor]
+
+    # simulate the app updating a model a few times
+    for key in ["key-1", "key-2", "key-1"]:
+        event = OnWriteFeatureStore(
+            "test_eventconsumer_no_event_handler_registered",
+            the_backbone.descriptor,
+            key,
+        )
+        mock_worker_mgr.send(event, timeout=0.1)
+
+    # run the handler and let it discard messages
+    for _ in range(15):
+        wmgr_consumer.listen_once(0.2, 2.0)
+
+    assert wmgr_consumer.listening
+
+
+def test_eventconsumer_no_event_handler_registered_shutdown(
+    the_backbone: t.Any, test_dir: str
+) -> None:
+    """Verify that a consumer without an event handler
+    registered still honors shutdown requests.
+
+    :param the_backbone: The BackboneFeatureStore to use
+    :param test_dir: Automatically generated unique working
+    directories for individual test outputs
+    """
+
+    wmgr_channel = DragonCommChannel(create_local())
+    capp_channel = DragonCommChannel(create_local())
+
+    # create a consumer to receive messages
+    wmgr_consumer = EventConsumer(wmgr_channel, the_backbone)
+
+    # create a broadcaster to publish messages
+    mock_worker_mgr = EventBroadcaster(
+        the_backbone,
+        channel_factory=DragonCommChannel.from_descriptor,
+    )
+
+    # manually register the consumers since we don't have a backend running
+    the_backbone.notification_channels = [
+        wmgr_channel.descriptor,
+        capp_channel.descriptor,
+    ]
+
+    # simulate the app updating a model a few times
+    for key in ["key-1", "key-2", "key-1"]:
+        event = OnWriteFeatureStore(
+            "test_eventconsumer_no_event_handler_registered_shutdown",
+            the_backbone.descriptor,
+            key,
+        )
+        mock_worker_mgr.send(event, timeout=0.1)
+
+    event = OnShutdownRequested(
+        "test_eventconsumer_no_event_handler_registered_shutdown"
+    )
+    mock_worker_mgr.send(event, timeout=0.1)
+
+    # wmgr will stop listening to messages when it is told to stop listening
+    wmgr_consumer.listen(timeout=0.1, batch_timeout=2.0)
+
+    for _ in range(15):
+        wmgr_consumer.listen_once(timeout=0.1, batch_timeout=2.0)
+
+    # confirm the messages were processed, discarded, and the shutdown was received
+    assert not wmgr_consumer.listening
+
+
+def test_eventconsumer_registration(
+    the_backbone: t.Any, test_dir: str, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Verify that a consumer is correctly registered in
+    the backbone after sending a registration request. Then
+    confirm the consumer is unregistered after sending the
+    unregister request.
+
+    :param the_backbone: The BackboneFeatureStore to use
+    :param test_dir: Automatically generated unique working
+    directories for individual test outputs
+    """
+
+    with monkeypatch.context() as patch:
+        registrar = ConsumerRegistrationListener(
+            the_backbone, 1.0, 2.0, as_service=False
+        )
+
+        # NOTE: service.execute(as_service=False) will complete the service life-
+        # cycle and remove the registrar from the backbone, so mock _on_shutdown
+        disabled_shutdown = mock.MagicMock()
+        patch.setattr(registrar, "_on_shutdown", disabled_shutdown)
+
+        # initialize registrar resources
+        registrar.execute()
+
+        # create a consumer that will be registered
+        wmgr_channel = DragonCommChannel(create_local())
+        wmgr_consumer = EventConsumer(wmgr_channel, the_backbone)
+
+        registered_channels = the_backbone.notification_channels
+
+        # trigger the consumer-to-registrar handshake
+        wmgr_consumer.register()
+
+        current_registrations: t.List[str] = []
+
+        # have the registrar run a few times to pick up the msg
+        for i in range(15):
+            registrar.execute()
+            current_registrations = the_backbone.notification_channels
+            if len(current_registrations) != len(registered_channels):
+                logger.debug(f"The event was processed on iteration {i}")
+                break
+
+        # confirm the consumer is registered
+        assert wmgr_channel.descriptor in current_registrations
+
+        # copy the old list so we can compare against it.
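+        # (a successful unregister should remove exactly one descriptor)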
+        registered_channels = list(current_registrations)
+
+        # trigger the consumer removal
+        wmgr_consumer.unregister()
+
+        # have the registrar run a few times to pick up the msg
+        for i in range(15):
+            registrar.execute()
+            current_registrations = the_backbone.notification_channels
+            if len(current_registrations) != len(registered_channels):
+                logger.debug(f"The event was processed on iteration {i}")
+                break
+
+        # confirm the consumer is no longer registered
+        assert wmgr_channel.descriptor not in current_registrations
+
+
+def test_registrar_teardown(
+    the_backbone: t.Any, test_dir: str, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Verify that the consumer registrar removes itself from
+    the backbone when it shuts down.
+
+    :param the_backbone: The BackboneFeatureStore to use
+    :param test_dir: Automatically generated unique working
+    directories for individual test outputs
+    """
+
+    with monkeypatch.context() as patch:
+        registrar = ConsumerRegistrationListener(
+            the_backbone, 1.0, 2.0, as_service=False
+        )
+
+        # directly initialize registrar resources to avoid the service life-cycle
+        registrar._create_eventing()
+
+        # confirm the registrar is published to the backbone
+        cfg = the_backbone.wait_for([BackboneFeatureStore.MLI_REGISTRAR_CONSUMER], 10)
+        assert BackboneFeatureStore.MLI_REGISTRAR_CONSUMER in cfg
+
+        # execute the entire service lifecycle once
+        registrar.execute()
+
+        consumer_found = BackboneFeatureStore.MLI_REGISTRAR_CONSUMER in the_backbone
+
+        for i in range(15):
+            time.sleep(0.1)
+            consumer_found = BackboneFeatureStore.MLI_REGISTRAR_CONSUMER in the_backbone
+            if not consumer_found:
+                logger.debug(f"Registrar removed from the backbone on iteration {i}")
+                break
+
+        assert BackboneFeatureStore.MLI_REGISTRAR_CONSUMER not in the_backbone
diff --git a/tests/dragon/test_featurestore.py b/tests/dragon/test_featurestore.py
new file mode 100644
index 0000000000..019dcde7a0
--- /dev/null
+++ b/tests/dragon/test_featurestore.py
@@ -0,0 +1,327 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+import multiprocessing as mp
+import random
+import time
+import typing as t
+import unittest.mock as mock
+import uuid
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    BackboneFeatureStore,
+)
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    time as bbtime,
+)
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+# isort: off
+from dragon import fli
+from dragon.channels import Channel
+
+# isort: on
+
+if t.TYPE_CHECKING:
+    import conftest
+
+
+# The tests in this file must run in a dragon environment
+pytestmark = pytest.mark.dragon
+
+
+def test_backbone_wait_for_no_keys(
+    the_backbone: BackboneFeatureStore, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Verify that asking the backbone to wait for a value succeeds
+    immediately and does not cause a wait to occur if the supplied key
+    list is empty.
+
+    :param the_backbone: the storage engine to use, prepopulated by the fixture
+    """
+    # set a very low timeout to confirm that it does not wait
+
+    with monkeypatch.context() as ctx:
+        # all keys should be found and the timeout should never be checked.
+        ctx.setattr(bbtime, "sleep", mock.MagicMock())
+
+        values = the_backbone.wait_for([])
+        assert len(values) == 0
+
+        # confirm that no wait occurred
+        bbtime.sleep.assert_not_called()
+
+
+def test_backbone_wait_for_prepopulated(
+    the_backbone: BackboneFeatureStore, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Verify that asking the backbone to wait for a value succeeds
+    immediately and does not cause a wait to occur if the data exists.
+
+    :param the_backbone: the storage engine to use, prepopulated by the fixture
+    """
+    # set a very low timeout to confirm that it does not wait
+
+    with monkeypatch.context() as ctx:
+        # all keys should be found and the timeout should never be checked.
+        ctx.setattr(bbtime, "sleep", mock.MagicMock())
+
+        values = the_backbone.wait_for([BackboneFeatureStore.MLI_WORKER_QUEUE], 0.1)
+
+        # confirm that wait_for with one key returns one value
+        assert len(values) == 1
+
+        # confirm that the descriptor is non-null w/some non-trivial value
+        assert len(values[BackboneFeatureStore.MLI_WORKER_QUEUE]) > 5
+
+        # confirm that no wait occurred
+        bbtime.sleep.assert_not_called()
+
+
+def test_backbone_wait_for_prepopulated_dupe(
+    the_backbone: BackboneFeatureStore, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Verify that asking the backbone to wait for keys that are duplicated
+    results in a single value being returned for each key.
+
+    :param the_backbone: the storage engine to use, prepopulated by the fixture
+    """
+    # set a very low timeout to confirm that it does not wait
+
+    key1, key2 = "key-1", "key-2"
+    value1, value2 = "i-am-value-1", "i-am-value-2"
+    the_backbone[key1] = value1
+    the_backbone[key2] = value2
+
+    with monkeypatch.context() as ctx:
+        # all keys should be found and the timeout should never be checked.
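+        # patching the module-level time.sleep makes any hidden polling detectable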
+        ctx.setattr(bbtime, "sleep", mock.MagicMock())
+
+        values = the_backbone.wait_for([key1, key2, key1])  # key1 is duplicated
+
+        # confirm that each unique key is returned exactly once
+        assert len(values) == 2
+        assert key1 in values
+        assert key2 in values
+
+        assert values[key1] == value1
+        assert values[key2] == value2
+
+
+def set_value_after_delay(
+    descriptor: str, key: str, value: str, delay: float = 5
+) -> None:
+    """Helper function to persist a value into the backbone after a delay.
+
+    :param descriptor: the backbone feature store descriptor to attach to
+    :param key: the key to write to
+    :param value: a value to write to the key
+    :param delay: amount of delay to apply before writing the key
+    """
+    time.sleep(delay)
+
+    backbone = BackboneFeatureStore.from_descriptor(descriptor)
+    backbone[key] = value
+    logger.debug(f"set_value_after_delay wrote `{value}` to backbone[`{key}`]")
+
+
+@pytest.mark.parametrize(
+    "delay",
+    [
+        pytest.param(
+            0,
+            marks=pytest.mark.skip(
+                "Must use entrypoint instead of mp.Process to run on build agent"
+            ),
+        ),
+        pytest.param(
+            1,
+            marks=pytest.mark.skip(
+                "Must use entrypoint instead of mp.Process to run on build agent"
+            ),
+        ),
+        pytest.param(
+            2,
+            marks=pytest.mark.skip(
+                "Must use entrypoint instead of mp.Process to run on build agent"
+            ),
+        ),
+        pytest.param(
+            4,
+            marks=pytest.mark.skip(
+                "Must use entrypoint instead of mp.Process to run on build agent"
+            ),
+        ),
+        pytest.param(
+            8,
+            marks=pytest.mark.skip(
+                "Must use entrypoint instead of mp.Process to run on build agent"
+            ),
+        ),
+    ],
+)
+def test_backbone_wait_for_partial_prepopulated(
+    the_backbone: BackboneFeatureStore, delay: float
+) -> None:
+    """Verify that when data is not all in the backbone, the `wait_for` operation
+    continues to poll until it finds everything it needs.
+
+    :param the_backbone: the storage engine to use, prepopulated by the fixture
+    :param delay: the number of seconds the second process will wait before
+    setting the target value in the backbone featurestore
+    """
+    # allow enough time for the delayed writer to complete before timing out
+    wait_timeout = 10
+
+    key, value = str(uuid.uuid4()), str(random.random() * 10)
+
+    logger.debug(f"Starting process to write {key} after {delay}s")
+    p = mp.Process(
+        target=set_value_after_delay, args=(the_backbone.descriptor, key, value, delay)
+    )
+    p.start()
+
+    p2 = mp.Process(
+        target=the_backbone.wait_for,
+        args=([BackboneFeatureStore.MLI_WORKER_QUEUE, key],),
+        kwargs={"timeout": wait_timeout},
+    )
+    p2.start()
+
+    p.join()
+    p2.join()
+
+    # both values should be written at this time
+    ret_vals = the_backbone.wait_for(
+        [key, BackboneFeatureStore.MLI_WORKER_QUEUE, key], 0.1
+    )
+    # confirm that wait_for with two keys returns two values
+    assert len(ret_vals) == 2, "values should contain values for both awaited keys"
+
+    # confirm the pre-populated value has the correct output
+    assert (
+        ret_vals[BackboneFeatureStore.MLI_WORKER_QUEUE] == "12345"
+    )  # mock descriptor value from fixture
+
+    # confirm the writer process completed and the awaited value is correct
+    assert ret_vals[key] == value, "the awaited value must match what was written"
+
+
+@pytest.mark.parametrize(
+    "num_keys",
+    [
+        pytest.param(
+            0,
+            marks=pytest.mark.skip(
+                "Must use entrypoint instead of mp.Process to run on build agent"
+            ),
+        ),
+        pytest.param(
+            1,
+            marks=pytest.mark.skip(
+                "Must use entrypoint instead of mp.Process to run on build agent"
+            ),
+        ),
+        pytest.param(
+            3,
+            marks=pytest.mark.skip(
+                "Must use entrypoint instead of mp.Process to run on build agent"
+            ),
+        ),
+        pytest.param(
+            7,
+            marks=pytest.mark.skip(
+                "Must use entrypoint instead of mp.Process to run on build agent"
+            ),
+        ),
+        pytest.param(
+            11,
+            marks=pytest.mark.skip(
+                "Must use entrypoint instead of mp.Process to run on build agent"
+            ),
+        ),
+    ],
+)
+def test_backbone_wait_for_multikey(
+    the_backbone: BackboneFeatureStore,
+    num_keys: int,
+    test_dir: str,
+) -> None:
+    """Verify that asking the backbone to wait for multiple keys results
+    in that number of values being returned.
+
+    :param the_backbone: the storage engine to use, prepopulated by the fixture
+    :param num_keys: the number of extra keys to set & request in the backbone
+    """
+    # maximum delay allowed for setter processes
+    max_delay = 5
+
+    extra_keys = [str(uuid.uuid4()) for _ in range(num_keys)]
+    extra_values = [str(uuid.uuid4()) for _ in range(num_keys)]
+    extras = dict(zip(extra_keys, extra_values))
+    delays = [random.random() * max_delay for _ in range(num_keys)]
+    processes = []
+
+    for key, value, delay in zip(extra_keys, extra_values, delays):
+        assert delay < max_delay, "write delay exceeds test timeout"
+        logger.debug(f"Delaying {key} write by {delay} seconds")
+        p = mp.Process(
+            target=set_value_after_delay,
+            args=(the_backbone.descriptor, key, value, delay),
+        )
+        p.start()
+        processes.append(p)
+
+    p2 = mp.Process(
+        target=the_backbone.wait_for,
+        args=(extra_keys,),
+        kwargs={"timeout": max_delay * 2},
+    )
+    p2.start()
+    for p in processes:
+        p.join(timeout=max_delay * 2)
+    p2.join(
+        timeout=max_delay * 2
+    )  # allow the waiting process the same window as its own wait_for timeout
+
+    # use a minimal timeout to verify all values are already written
+    num_keys = len(extra_keys)
+    actual_values = the_backbone.wait_for(extra_keys, timeout=0.01)
+
+    # confirm that wait_for returns all the expected values
+    assert len(actual_values) == num_keys
+
+    # confirm that the returned values match (e.g. are returned in the right order)
+    for k in extras:
+        assert extras[k] == actual_values[k]
diff --git a/tests/dragon/test_featurestore_base.py b/tests/dragon/test_featurestore_base.py
index 932e734c8a..6daceb9061 100644
--- a/tests/dragon/test_featurestore_base.py
+++ b/tests/dragon/test_featurestore_base.py
@@ -24,20 +24,22 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 import pathlib
+import time
 import typing as t
 
 import pytest
 
 dragon = pytest.importorskip("dragon")
 
-from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
-    BackboneFeatureStore,
-    EventBroadcaster,
-    EventCategory,
-    EventConsumer,
+from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster
+from smartsim._core.mli.infrastructure.comm.consumer import EventConsumer
+from smartsim._core.mli.infrastructure.comm.event import (
     OnCreateConsumer,
     OnWriteFeatureStore,
 )
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    BackboneFeatureStore,
+)
 from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
     DragonFeatureStore,
 )
@@ -55,15 +57,21 @@
 pytestmark = pytest.mark.dragon
 
 
+def boom(*args, **kwargs) -> None:
+    """Helper function that blows up when used to mock up
+    some other function."""
+    raise Exception(f"you shall not pass! 
{args}, {kwargs}") + + def test_event_uid() -> None: - """Verify that all events include a unique identifier""" + """Verify that all events include a unique identifier.""" uids: t.Set[str] = set() num_iters = 1000 # generate a bunch of events and keep track all the IDs for i in range(num_iters): - event_a = OnCreateConsumer(str(i)) - event_b = OnWriteFeatureStore(str(i), "key") + event_a = OnCreateConsumer("test_event_uid", str(i), filters=[]) + event_b = OnWriteFeatureStore("test_event_uid", "test_event_uid", str(i)) uids.add(event_a.uid) uids.add(event_b.uid) @@ -74,7 +82,7 @@ def test_event_uid() -> None: def test_mli_reserved_keys_conversion() -> None: """Verify that conversion from a string to an enum member - works as expected""" + works as expected.""" for reserved_key in ReservedKeys: # iterate through all keys and verify `from_string` works @@ -87,7 +95,7 @@ def test_mli_reserved_keys_conversion() -> None: def test_mli_reserved_keys_writes() -> None: """Verify that attempts to write to reserved keys are blocked from a - standard DragonFeatureStore but enabled with the BackboneFeatureStore""" + standard DragonFeatureStore but enabled with the BackboneFeatureStore.""" mock_storage = {} dfs = DragonFeatureStore(mock_storage) @@ -116,10 +124,8 @@ def test_mli_reserved_keys_writes() -> None: def test_mli_consumers_read_by_key() -> None: - """Verify that the value returned from the mli consumers - method is written to the correct key and reads are - allowed via standard dragon feature store. - NOTE: should reserved reads also be blocked""" + """Verify that the value returned from the mli consumers method is written + to the correct key and reads are allowed via standard dragon feature store.""" mock_storage = {} dfs = DragonFeatureStore(mock_storage) @@ -138,7 +144,7 @@ def test_mli_consumers_read_by_key() -> None: def test_mli_consumers_read_by_backbone() -> None: """Verify that the backbone reads the correct location - when using the backbone feature store API instead of mapping API""" + when using the backbone feature store API instead of mapping API.""" mock_storage = {} backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) @@ -152,7 +158,7 @@ def test_mli_consumers_read_by_backbone() -> None: def test_mli_consumers_write_by_backbone() -> None: """Verify that the backbone writes the correct location - when using the backbone feature store API instead of mapping API""" + when using the backbone feature store API instead of mapping API.""" mock_storage = {} backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) @@ -166,10 +172,11 @@ def test_mli_consumers_write_by_backbone() -> None: def test_eventpublisher_broadcast_no_factory(test_dir: str) -> None: """Verify that a broadcast operation without any registered subscribers - succeeds without raising Exceptions + succeeds without raising Exceptions. :param test_dir: pytest fixture automatically generating unique working - directories for individual test outputs""" + directories for individual test outputs + """ storage_path = pathlib.Path(test_dir) / "features" mock_storage = {} consumer_descriptor = storage_path / "test-consumer" @@ -177,7 +184,9 @@ def test_eventpublisher_broadcast_no_factory(test_dir: str) -> None: # NOTE: we're not putting any consumers into the backbone here! 
backbone = BackboneFeatureStore(mock_storage) - event = OnCreateConsumer(consumer_descriptor) + event = OnCreateConsumer( + "test_eventpublisher_broadcast_no_factory", consumer_descriptor, filters=[] + ) publisher = EventBroadcaster(backbone) num_receivers = 0 @@ -185,7 +194,9 @@ def test_eventpublisher_broadcast_no_factory(test_dir: str) -> None: # publishing this event without any known consumers registered should succeed # but report that it didn't have anybody to send the event to consumer_descriptor = storage_path / f"test-consumer" - event = OnCreateConsumer(consumer_descriptor) + event = OnCreateConsumer( + "test_eventpublisher_broadcast_no_factory", consumer_descriptor, filters=[] + ) num_receivers += publisher.send(event) @@ -201,10 +212,11 @@ def test_eventpublisher_broadcast_no_factory(test_dir: str) -> None: def test_eventpublisher_broadcast_to_empty_consumer_list(test_dir: str) -> None: """Verify that a broadcast operation without any registered subscribers - succeeds without raising Exceptions + succeeds without raising Exceptions. :param test_dir: pytest fixture automatically generating unique working - directories for individual test outputs""" + directories for individual test outputs + """ storage_path = pathlib.Path(test_dir) / "features" mock_storage = {} @@ -215,7 +227,11 @@ def test_eventpublisher_broadcast_to_empty_consumer_list(test_dir: str) -> None: backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) backbone.notification_channels = [] - event = OnCreateConsumer(consumer_descriptor) + event = OnCreateConsumer( + "test_eventpublisher_broadcast_to_empty_consumer_list", + consumer_descriptor, + filters=[], + ) publisher = EventBroadcaster( backbone, channel_factory=FileSystemCommChannel.from_descriptor ) @@ -233,10 +249,11 @@ def test_eventpublisher_broadcast_to_empty_consumer_list(test_dir: str) -> None: def test_eventpublisher_broadcast_without_channel_factory(test_dir: str) -> None: """Verify that a broadcast operation reports an error if no channel - factory was supplied for constructing the consumer channels + factory was supplied for constructing the consumer channels. :param test_dir: pytest fixture automatically generating unique working - directories for individual test outputs""" + directories for individual test outputs + """ storage_path = pathlib.Path(test_dir) / "features" mock_storage = {} @@ -247,7 +264,11 @@ def test_eventpublisher_broadcast_without_channel_factory(test_dir: str) -> None backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) backbone.notification_channels = [consumer_descriptor] - event = OnCreateConsumer(consumer_descriptor) + event = OnCreateConsumer( + "test_eventpublisher_broadcast_without_channel_factory", + consumer_descriptor, + filters=[], + ) publisher = EventBroadcaster( backbone, # channel_factory=FileSystemCommChannel.from_descriptor # <--- not supplied @@ -261,10 +282,11 @@ def test_eventpublisher_broadcast_without_channel_factory(test_dir: str) -> None def test_eventpublisher_broadcast_empties_buffer(test_dir: str) -> None: """Verify that a successful broadcast clears messages from the event - buffer when a new message is sent and consumers are registered + buffer when a new message is sent and consumers are registered. 
:param test_dir: pytest fixture automatically generating unique working - directories for individual test outputs""" + directories for individual test outputs + """ storage_path = pathlib.Path(test_dir) / "features" mock_storage = {} @@ -281,11 +303,17 @@ def test_eventpublisher_broadcast_empties_buffer(test_dir: str) -> None: # mock building up some buffered events num_buffered_events = 14 for i in range(num_buffered_events): - event = OnCreateConsumer(storage_path / f"test-consumer-{str(i)}") + event = OnCreateConsumer( + "test_eventpublisher_broadcast_empties_buffer", + storage_path / f"test-consumer-{str(i)}", + [], + ) publisher._event_buffer.append(bytes(event)) event0 = OnCreateConsumer( - storage_path / f"test-consumer-{str(num_buffered_events + 1)}" + "test_eventpublisher_broadcast_empties_buffer", + storage_path / f"test-consumer-{str(num_buffered_events + 1)}", + [], ) num_receivers = publisher.send(event0) @@ -332,13 +360,21 @@ def test_eventpublisher_broadcast_returns_total_sent( # mock building up some buffered events for i in range(num_buffered): - event = OnCreateConsumer(storage_path / f"test-consumer-{str(i)}") + event = OnCreateConsumer( + "test_eventpublisher_broadcast_returns_total_sent", + storage_path / f"test-consumer-{str(i)}", + [], + ) publisher._event_buffer.append(bytes(event)) assert publisher.num_buffered == num_buffered # this event will trigger clearing anything already in buffer - event0 = OnCreateConsumer(storage_path / f"test-consumer-{num_buffered}") + event0 = OnCreateConsumer( + "test_eventpublisher_broadcast_returns_total_sent", + storage_path / f"test-consumer-{num_buffered}", + [], + ) # num_receivers should contain a number that computes w/all consumers and all events num_receivers = publisher.send(event0) @@ -347,10 +383,11 @@ def test_eventpublisher_broadcast_returns_total_sent( def test_eventpublisher_prune_unused_consumer(test_dir: str) -> None: - """Verify that any unused consumers are pruned each time a new event is sent + """Verify that any unused consumers are pruned each time a new event is sent. :param test_dir: pytest fixture automatically generating unique working - directories for individual test outputs""" + directories for individual test outputs + """ storage_path = pathlib.Path(test_dir) / "features" mock_storage = {} @@ -363,7 +400,11 @@ def test_eventpublisher_prune_unused_consumer(test_dir: str) -> None: backbone, channel_factory=FileSystemCommChannel.from_descriptor ) - event = OnCreateConsumer(consumer_descriptor) + event = OnCreateConsumer( + "test_eventpublisher_prune_unused_consumer", + consumer_descriptor, + filters=[], + ) # the only registered cnosumer is in the event, expect no pruning backbone.notification_channels = (consumer_descriptor,) @@ -377,7 +418,9 @@ def test_eventpublisher_prune_unused_consumer(test_dir: str) -> None: # ... and remove the old descriptor from the backbone when it's looked up backbone.notification_channels = (consumer_descriptor2,) - event = OnCreateConsumer(consumer_descriptor2) + event = OnCreateConsumer( + "test_eventpublisher_prune_unused_consumer", consumer_descriptor2, filters=[] + ) publisher.send(event) @@ -413,12 +456,13 @@ def test_eventpublisher_prune_unused_consumer(test_dir: str) -> None: def test_eventpublisher_serialize_failure( test_dir: str, monkeypatch: pytest.MonkeyPatch ) -> None: - """Verify that errors during message serialization are raised to the caller + """Verify that errors during message serialization are raised to the caller. 
:param test_dir: pytest fixture automatically generating unique working directories for individual test outputs :param monkeypatch: pytest fixture for modifying behavior of existing code - with mock implementations""" + with mock implementations + """ storage_path = pathlib.Path(test_dir) / "features" storage_path.mkdir(parents=True, exist_ok=True) @@ -433,15 +477,21 @@ def test_eventpublisher_serialize_failure( ) with monkeypatch.context() as patch: - event = OnCreateConsumer(target_descriptor) + event = OnCreateConsumer( + "test_eventpublisher_serialize_failure", target_descriptor, filters=[] + ) # patch the __bytes__ implementation to cause pickling to fail during send - patch.setattr(event, "__bytes__", lambda x: b"abc") + def bad_bytes(self) -> bytes: + return b"abc" + + # this patch causes an attribute error when event pickling is attempted + patch.setattr(event, "__bytes__", bad_bytes) backbone.notification_channels = (target_descriptor,) # send a message into the channel - with pytest.raises(ValueError) as ex: + with pytest.raises(AttributeError) as ex: publisher.send(event) assert "serialize" in ex.value.args[0] @@ -450,12 +500,13 @@ def test_eventpublisher_serialize_failure( def test_eventpublisher_factory_failure( test_dir: str, monkeypatch: pytest.MonkeyPatch ) -> None: - """Verify that errors during channel construction are raised to the caller + """Verify that errors during channel construction are raised to the caller. :param test_dir: pytest fixture automatically generating unique working directories for individual test outputs :param monkeypatch: pytest fixture for modifying behavior of existing code - with mock implementations""" + with mock implementations + """ storage_path = pathlib.Path(test_dir) / "features" storage_path.mkdir(parents=True, exist_ok=True) @@ -471,7 +522,9 @@ def boom(descriptor: str) -> None: publisher = EventBroadcaster(backbone, channel_factory=boom) with monkeypatch.context() as patch: - event = OnCreateConsumer(target_descriptor) + event = OnCreateConsumer( + "test_eventpublisher_factory_failure", target_descriptor, filters=[] + ) backbone.notification_channels = (target_descriptor,) @@ -484,12 +537,13 @@ def boom(descriptor: str) -> None: def test_eventpublisher_failure(test_dir: str, monkeypatch: pytest.MonkeyPatch) -> None: """Verify that unexpected errors during message send are caught and wrapped in a - SmartSimError so they are not propagated directly to the caller + SmartSimError so they are not propagated directly to the caller. :param test_dir: pytest fixture automatically generating unique working directories for individual test outputs :param monkeypatch: pytest fixture for modifying behavior of existing code - with mock implementations""" + with mock implementations + """ storage_path = pathlib.Path(test_dir) / "features" storage_path.mkdir(parents=True, exist_ok=True) @@ -507,7 +561,9 @@ def boom(self) -> None: raise Exception("That was unexpected...") with monkeypatch.context() as patch: - event = OnCreateConsumer(target_descriptor) + event = OnCreateConsumer( + "test_eventpublisher_failure", target_descriptor, filters=[] + ) # patch the _broadcast implementation to cause send to fail after # after the event has been pickled @@ -524,10 +580,11 @@ def boom(self) -> None: def test_eventconsumer_receive(test_dir: str) -> None: - """Verify that a consumer retrieves a message from the given channel + """Verify that a consumer retrieves a message from the given channel. 
:param test_dir: pytest fixture automatically generating unique working
-    directories for individual test outputs"""
+    directories for individual test outputs
+    """
     storage_path = pathlib.Path(test_dir) / "features"
     storage_path.mkdir(parents=True, exist_ok=True)
@@ -538,14 +595,16 @@ def test_eventconsumer_receive(test_dir: str) -> None:
     backbone = BackboneFeatureStore(mock_storage)
     comm_channel = FileSystemCommChannel.from_descriptor(target_descriptor)
 
-    event = OnCreateConsumer(target_descriptor)
+    event = OnCreateConsumer(
+        "test_eventconsumer_receive", target_descriptor, filters=[]
+    )
 
     # simulate a sent event by writing directly to the input comm channel
     comm_channel.send(bytes(event))
 
     consumer = EventConsumer(comm_channel, backbone)
 
-    all_received: t.List[OnCreateConsumer] = consumer.receive()
+    all_received: t.List[OnCreateConsumer] = consumer.recv()
     assert len(all_received) == 1
 
     # verify we received the same event that was raised
@@ -555,12 +614,13 @@ def test_eventconsumer_receive(test_dir: str) -> None:
 
 @pytest.mark.parametrize("num_sent", [0, 1, 2, 4, 8, 16])
 def test_eventconsumer_receive_multi(test_dir: str, num_sent: int) -> None:
-    """Verify that a consumer retrieves multiple message from the given channel
+    """Verify that a consumer retrieves multiple messages from the given channel.
 
     :param test_dir: pytest fixture automatically generating unique working
     directories for individual test outputs
     :param num_sent: parameterized value used to vary the number of events
-    that are enqueued and validations are checked at multiple queue sizes"""
+    that are enqueued and validations are checked at multiple queue sizes
+    """
     storage_path = pathlib.Path(test_dir) / "features"
     storage_path.mkdir(parents=True, exist_ok=True)
@@ -574,21 +634,24 @@ def test_eventconsumer_receive_multi(test_dir: str, num_sent: int) -> None:
 
     # simulate multiple sent events by writing directly to the input comm channel
     for _ in range(num_sent):
-        event = OnCreateConsumer(target_descriptor)
+        event = OnCreateConsumer(
+            "test_eventconsumer_receive_multi", target_descriptor, filters=[]
+        )
         comm_channel.send(bytes(event))
 
     consumer = EventConsumer(comm_channel, backbone)
 
-    all_received: t.List[OnCreateConsumer] = consumer.receive()
+    all_received: t.List[OnCreateConsumer] = consumer.recv()
     assert len(all_received) == num_sent
 
 
 def test_eventconsumer_receive_empty(test_dir: str) -> None:
     """Verify that a consumer receiving an empty message ignores the
-    message and continues processing
+    message and continues processing.
 
     :param test_dir: pytest fixture automatically generating unique working
-    directories for individual test outputs"""
+    directories for individual test outputs
+    """
     storage_path = pathlib.Path(test_dir) / "features"
     storage_path.mkdir(parents=True, exist_ok=True)
@@ -605,7 +668,7 @@ def test_eventconsumer_receive_empty(test_dir: str) -> None:
 
     consumer = EventConsumer(comm_channel, backbone)
 
-    messages = consumer.receive()
+    messages = consumer.recv()
 
     # the messages array should be empty
     assert not messages
@@ -616,7 +679,8 @@ def test_eventconsumer_eventpublisher_integration(test_dir: str) -> None:
     multiple publishers and consumers are sending simultaneously.
:param test_dir: pytest fixture automatically generating unique working - directories for individual test outputs""" + directories for individual test outputs + """ storage_path = pathlib.Path(test_dir) / "features" storage_path.mkdir(parents=True, exist_ok=True) @@ -628,15 +692,15 @@ def test_eventconsumer_eventpublisher_integration(test_dir: str) -> None: capp_channel = FileSystemCommChannel(storage_path / "test-capp") back_channel = FileSystemCommChannel(storage_path / "test-backend") - wmgr_consumer_descriptor = wmgr_channel.descriptor.decode("utf-8") - capp_consumer_descriptor = capp_channel.descriptor.decode("utf-8") - back_consumer_descriptor = back_channel.descriptor.decode("utf-8") + wmgr_consumer_descriptor = wmgr_channel.descriptor + capp_consumer_descriptor = capp_channel.descriptor + back_consumer_descriptor = back_channel.descriptor # create some consumers to receive messages wmgr_consumer = EventConsumer( wmgr_channel, backbone, - filters=[EventCategory.FEATURE_STORE_WRITTEN], + filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN], ) capp_consumer = EventConsumer( capp_channel, @@ -645,7 +709,7 @@ def test_eventconsumer_eventpublisher_integration(test_dir: str) -> None: back_consumer = EventConsumer( back_channel, backbone, - filters=[EventCategory.CONSUMER_CREATED], + filters=[OnCreateConsumer.CONSUMER_CREATED], ) # create some broadcasters to publish messages @@ -667,28 +731,38 @@ def test_eventconsumer_eventpublisher_integration(test_dir: str) -> None: ] # simulate worker manager sending a notification to backend that it's alive - event_1 = OnCreateConsumer(wmgr_consumer_descriptor) + event_1 = OnCreateConsumer( + "test_eventconsumer_eventpublisher_integration", + wmgr_consumer_descriptor, + filters=[], + ) mock_worker_mgr.send(event_1) # simulate the app updating a model a few times - event_2 = OnWriteFeatureStore(mock_fs_descriptor, "key-1") - event_3 = OnWriteFeatureStore(mock_fs_descriptor, "key-2") - event_4 = OnWriteFeatureStore(mock_fs_descriptor, "key-1") + event_2 = OnWriteFeatureStore( + "test_eventconsumer_eventpublisher_integration", mock_fs_descriptor, "key-1" + ) + event_3 = OnWriteFeatureStore( + "test_eventconsumer_eventpublisher_integration", mock_fs_descriptor, "key-2" + ) + event_4 = OnWriteFeatureStore( + "test_eventconsumer_eventpublisher_integration", mock_fs_descriptor, "key-1" + ) mock_client_app.send(event_2) mock_client_app.send(event_3) mock_client_app.send(event_4) # worker manager should only get updates about feature update - wmgr_messages = wmgr_consumer.receive() + wmgr_messages = wmgr_consumer.recv() assert len(wmgr_messages) == 3 # the backend should only receive messages about consumer creation - back_messages = back_consumer.receive() + back_messages = back_consumer.recv() assert len(back_messages) == 1 # hypothetical app has no filters and will get all events - app_messages = capp_consumer.receive() + app_messages = capp_consumer.recv() assert len(app_messages) == 4 @@ -702,7 +776,8 @@ def test_eventconsumer_batch_timeout( :param invalid_timeout: any invalid timeout that should fail validation :param test_dir: pytest fixture automatically generating unique working - directories for individual test outputs""" + directories for individual test outputs + """ storage_path = pathlib.Path(test_dir) / "features" storage_path.mkdir(parents=True, exist_ok=True) @@ -713,11 +788,57 @@ def test_eventconsumer_batch_timeout( with pytest.raises(ValueError) as ex: # try to create a consumer w/a max recv size of 0 - EventConsumer( + consumer = 
EventConsumer(
             channel,
             backbone,
-            filters=[EventCategory.FEATURE_STORE_WRITTEN],
-            batch_timeout=invalid_timeout,
+            filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN],
         )
+        consumer.recv(batch_timeout=invalid_timeout)
 
     assert "positive" in ex.value.args[0]
+
+
+@pytest.mark.parametrize(
+    "wait_timeout, exp_wait_max",
+    [
+        # aggregate the 1+1+1 into 3 on remaining parameters
+        pytest.param(1, 1 + 1 + 1, id="1s wait, 3 cycle steps"),
+        pytest.param(2, 3 + 2, id="2s wait, 4 cycle steps"),
+        pytest.param(4, 3 + 2 + 4, id="4s wait, 5 cycle steps"),
+        pytest.param(9, 3 + 2 + 4 + 8, id="9s wait, 6 cycle steps"),
+        # aggregate an entire cycle into 16
+        pytest.param(19.5, 16 + 3 + 2 + 4, id="20s wait, repeat cycle"),
+    ],
+)
+def test_backbone_wait_timeout(wait_timeout: float, exp_wait_max: float) -> None:
+    """Verify that attempts to wait for a missing key in the backbone time out
+    in an appropriate amount of time. Note: due to the backoff, we verify the
+    elapsed time is less than the 15s of a cycle of waits.
+
+    :param wait_timeout: Maximum amount of time (in seconds) to allow the backbone
+    to wait for the requested value to exist
+    :param exp_wait_max: Maximum amount of time (in seconds) to set as the upper
+    bound to allow the delays with backoff to occur
+    """
+
+    # NOTE: exp_wait_time maps to the cycled backoff of [0.1, 0.2, 0.4, 0.8]
+    # with leeway added (by allowing 1s each for the 0.1 and 0.5 steps)
+    start_time = time.time()
+
+    storage = {}
+    backbone = BackboneFeatureStore(storage)
+
+    with pytest.raises(SmartSimError) as ex:
+        backbone.wait_for(["does-not-exist"], wait_timeout)
+
+    assert "timeout" in str(ex.value.args[0]).lower()
+
+    end_time = time.time()
+    elapsed = end_time - start_time
+
+    # confirm that we met our timeout
+    assert elapsed > wait_timeout, f"below configured timeout {wait_timeout}"
+
+    # confirm that the total wait time is aligned with the sleep cycle
+    assert elapsed < exp_wait_max, f"above expected max wait {exp_wait_max}"
diff --git a/tests/dragon/test_featurestore_integration.py b/tests/dragon/test_featurestore_integration.py
index 59801eebe2..23fdc55ab6 100644
--- a/tests/dragon/test_featurestore_integration.py
+++ b/tests/dragon/test_featurestore_integration.py
@@ -30,21 +30,17 @@
 
 dragon = pytest.importorskip("dragon")
 
-from smartsim._core.mli.comm.channel.dragon_channel import (
+from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
+from smartsim._core.mli.comm.channel.dragon_util import (
     DEFAULT_CHANNEL_BUFFER_SIZE,
-    DragonCommChannel,
     create_local,
 )
-from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
+from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster
+from smartsim._core.mli.infrastructure.comm.consumer import EventConsumer
+from smartsim._core.mli.infrastructure.comm.event import OnWriteFeatureStore
 from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
     BackboneFeatureStore,
-    EventBroadcaster,
-    EventCategory,
-    EventConsumer,
-    OnCreateConsumer,
-    OnWriteFeatureStore,
 )
-from smartsim._core.mli.infrastructure.storage.dragon_feature_store import dragon_ddict
 
 # isort: off
 from dragon.channels import Channel
@@ -59,187 +55,135 @@
 
 pytestmark = pytest.mark.dragon
 
 
-@pytest.fixture
-def storage_for_dragon_fs() -> t.Dict[str, str]:
-    return dragon_ddict.DDict()
-
-
-def test_eventconsumer_eventpublisher_integration(
-    storage_for_dragon_fs: t.Any, test_dir: str
-) -> None:
-
"""Verify that the publisher and consumer integrate as expected when - multiple publishers and consumers are sending simultaneously. This - test closely tracks the test in tests/test_featurestore.py also named - test_eventconsumer_eventpublisher_integration but requires dragon entities - - :param storage_for_dragon_fs: the dragon storage engine to use - :param test_dir: pytest fixture automatically generating unique working - directories for individual test outputs""" - - mock_storage = storage_for_dragon_fs - backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) - mock_fs_descriptor = backbone.descriptor - - # verify ability to write and read from ddict - backbone["test_dir"] = test_dir - assert backbone["test_dir"] == test_dir - - wmgr_channel_ = Channel.make_process_local() - capp_channel_ = Channel.make_process_local() - back_channel_ = Channel.make_process_local() - +@pytest.fixture(scope="module") +def the_worker_channel() -> DragonCommChannel: + """Fixture to create a valid descriptor for a worker channel + that can be attached to.""" + wmgr_channel_ = create_local() wmgr_channel = DragonCommChannel(wmgr_channel_) - capp_channel = DragonCommChannel(capp_channel_) - back_channel = DragonCommChannel(back_channel_) - - wmgr_consumer_descriptor = wmgr_channel.descriptor_string - capp_consumer_descriptor = capp_channel.descriptor_string - back_consumer_descriptor = back_channel.descriptor_string - - # create some consumers to receive messages - wmgr_consumer = EventConsumer( - wmgr_channel, - backbone, - filters=[EventCategory.FEATURE_STORE_WRITTEN], - ) - capp_consumer = EventConsumer( - capp_channel, - backbone, - ) - back_consumer = EventConsumer( - back_channel, - backbone, - filters=[EventCategory.CONSUMER_CREATED], - ) - - # create some broadcasters to publish messages - mock_worker_mgr = EventBroadcaster( - backbone, - channel_factory=DragonCommChannel.from_descriptor, - ) - mock_client_app = EventBroadcaster( - backbone, - channel_factory=DragonCommChannel.from_descriptor, - ) - - # register all of the consumers even though the OnCreateConsumer really should - # trigger its registration. event processing is tested elsewhere. 
-    backbone.notification_channels = [
-        wmgr_consumer_descriptor,
-        capp_consumer_descriptor,
-        back_consumer_descriptor,
-    ]
-
-    # simulate worker manager sending a notification to backend that it's alive
-    event_1 = OnCreateConsumer(wmgr_consumer_descriptor)
-    mock_worker_mgr.send(event_1)
-
-    # simulate the app updating a model a few times
-    for key in ["key-1", "key-2", "key-1"]:
-        event = OnWriteFeatureStore(backbone.descriptor, key)
-        mock_client_app.send(event, timeout=0.1)
-
-    # worker manager should only get updates about feature update
-    wmgr_messages = wmgr_consumer.receive()
-    assert len(wmgr_messages) == 3
-
-    # the backend should only receive messages about consumer creation
-    back_messages = back_consumer.receive()
-    assert len(back_messages) == 1
-
-    # hypothetical app has no filters and will get all events
-    app_messages = capp_consumer.receive()
-    assert len(app_messages) == 4
+    return wmgr_channel
 
 
 @pytest.mark.parametrize(
-    "num_events, batch_timeout",
+    "num_events, batch_timeout, max_batches_expected",
     [
-        pytest.param(1, 1.0, id="under 1s timeout"),
-        pytest.param(20, 1.0, id="test 1s timeout w/20"),
-        pytest.param(50, 1.0, id="test 1s timeout w/50"),
-        pytest.param(60, 0.1, id="small batches"),
-        pytest.param(100, 0.1, id="many small batches"),
+        pytest.param(1, 1.0, 2, id="under 1s timeout"),
+        pytest.param(20, 1.0, 3, id="test 1s timeout 20x"),
+        pytest.param(30, 0.2, 5, id="test 0.2s timeout 30x"),
+        pytest.param(60, 0.4, 4, id="small batches"),
+        pytest.param(100, 0.1, 10, id="many small batches"),
     ],
 )
 def test_eventconsumer_max_dequeue(
     num_events: int,
     batch_timeout: float,
-    storage_for_dragon_fs: t.Any,
+    max_batches_expected: int,
+    the_worker_channel: DragonCommChannel,
+    the_backbone: BackboneFeatureStore,
 ) -> None:
     """Verify that a consumer does not sit and collect messages indefinitely
-    by checking that a consumer returns after a maximum timeout is exceeded
-
-    :param num_events: the total number of events to raise in the test
-    :param batch_timeout: the maximum wait time for a message to be sent.
-    :param storage_for_dragon_fs: the dragon storage engine to use"""
+    by checking that a consumer returns after a maximum timeout is exceeded.
 
-    mock_storage = storage_for_dragon_fs
-    backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
-
-    wmgr_channel_ = Channel.make_process_local()
-    wmgr_channel = DragonCommChannel(wmgr_channel_)
-    wmgr_consumer_descriptor = wmgr_channel.descriptor_string
+    :param num_events: Total number of events to raise in the test
+    :param batch_timeout: Maximum wait time (in seconds) for a message to be sent
+    :param max_batches_expected: Maximum number of receives that should occur
+    :param the_worker_channel: Fixture providing a reusable worker comm channel
+    :param the_backbone: Fixture providing the backbone feature store to use
+    """
 
     # create some consumers to receive messages
     wmgr_consumer = EventConsumer(
-        wmgr_channel,
-        backbone,
-        filters=[EventCategory.FEATURE_STORE_WRITTEN],
-        batch_timeout=batch_timeout,
+        the_worker_channel,
+        the_backbone,
+        filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN],
     )
 
     # create a broadcaster to publish messages
     mock_client_app = EventBroadcaster(
-        backbone,
+        the_backbone,
         channel_factory=DragonCommChannel.from_descriptor,
     )
 
     # register all of the consumers even though the OnCreateConsumer really should
     # trigger its registration. event processing is tested elsewhere.
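# --- editor's note: illustrative sketch, not part of the patch ---
# Consumer registration is normally a side effect of handling an
# OnCreateConsumer event; the tests here shortcut it by assigning channel
# descriptors directly, assuming `notification_channels` accepts any iterable:
#
#     the_backbone.notification_channels = (wmgr_descriptor, capp_descriptor)
#     publisher = EventBroadcaster(the_backbone, channel_factory=...)
#     publisher.send(event)  # fans out to every registered descriptor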
-    backbone.notification_channels = [wmgr_consumer_descriptor]
+    the_backbone.notification_channels = [the_worker_channel.descriptor]
 
     # simulate the app updating a model a lot of times
     for key in (f"key-{i}" for i in range(num_events)):
-        event = OnWriteFeatureStore(backbone.descriptor, key)
-        mock_client_app.send(event, timeout=0.1)
+        event = OnWriteFeatureStore(
+            "test_eventconsumer_max_dequeue", the_backbone.descriptor, key
+        )
+        mock_client_app.send(event, timeout=0.01)
 
     num_dequeued = 0
+    num_batches = 0
 
-    while wmgr_messages := wmgr_consumer.receive(timeout=0.01):
+    while wmgr_messages := wmgr_consumer.recv(
+        timeout=0.1,
+        batch_timeout=batch_timeout,
+    ):
         # worker manager should not get more than `max_num_msgs` events
         num_dequeued += len(wmgr_messages)
+        num_batches += 1
 
     # make sure we made all the expected dequeue calls and got everything
     assert num_dequeued == num_events
+    assert num_batches > 0
+    assert num_batches < max_batches_expected, "too many recv calls were made"
 
 
 @pytest.mark.parametrize(
     "buffer_size",
     [
-        pytest.param(-1, id="use default: 500"),
-        pytest.param(0, id="use default: 500"),
-        pytest.param(1, id="non-zero buffer size: 1"),
-        pytest.param(500, id="buffer size: 500"),
-        pytest.param(1000, id="buffer size: 1000"),
+        pytest.param(
+            -1,
+            id="replace negative, default to 500",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+        pytest.param(
+            0,
+            id="replace zero, default to 500",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+        pytest.param(
+            1,
+            id="non-zero buffer size: 1",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+        # pytest.param(500, id="maximum size edge case: 500"),
+        pytest.param(
+            550,
+            id="larger than default: 550",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+        pytest.param(
+            800,
+            id="much larger than default: 800",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+        pytest.param(
+            1000,
+            id="very large buffer: 1000, unreliable in dragon-v0.10",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
     ],
 )
 def test_channel_buffer_size(
     buffer_size: int,
-    storage_for_dragon_fs: t.Any,
+    the_storage: t.Any,
 ) -> None:
     """Verify that a channel used by an EventBroadcaster can buffer messages
     until a configured maximum value is exceeded.
 
-    :param buffer_size: the maximum number of messages allowed in a channel buffer
-    :param storage_for_dragon_fs: the dragon storage engine to use"""
+    :param buffer_size: Maximum number of messages allowed in a channel buffer
+    :param the_storage: The dragon storage engine to use
+    """
 
-    mock_storage = storage_for_dragon_fs
+    mock_storage = the_storage
     backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
 
     wmgr_channel_ = create_local(buffer_size)  # <--- vary buffer size
     wmgr_channel = DragonCommChannel(wmgr_channel_)
-    wmgr_consumer_descriptor = wmgr_channel.descriptor_string
+    wmgr_consumer_descriptor = wmgr_channel.descriptor
 
     # create a broadcaster to publish messages.
create no consumers to # push the number of sent messages past the allotted buffer size @@ -259,9 +203,11 @@ def test_channel_buffer_size( # simulate the app updating a model a lot of times for key in (f"key-{i}" for i in range(buffer_size)): - event = OnWriteFeatureStore(backbone.descriptor, key) - mock_client_app.send(event, timeout=0.1) + event = OnWriteFeatureStore( + "test_channel_buffer_size", backbone.descriptor, key + ) + mock_client_app.send(event, timeout=0.01) # adding 1 more over the configured buffer size should report the error with pytest.raises(Exception) as ex: - mock_client_app.send(event, timeout=0.1) + mock_client_app.send(event, timeout=0.01) diff --git a/tests/dragon/test_inference_reply.py b/tests/dragon/test_inference_reply.py index 1eb137ae61..bdc7be14bc 100644 --- a/tests/dragon/test_inference_reply.py +++ b/tests/dragon/test_inference_reply.py @@ -28,7 +28,7 @@ dragon = pytest.importorskip("dragon") -from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStoreKey +from smartsim._core.mli.infrastructure.storage.feature_store import TensorKey from smartsim._core.mli.infrastructure.worker.worker import InferenceReply from smartsim._core.mli.message_handler import MessageHandler @@ -44,8 +44,8 @@ def inference_reply() -> InferenceReply: @pytest.fixture -def fs_key() -> FeatureStoreKey: - return FeatureStoreKey("key", "descriptor") +def fs_key() -> TensorKey: + return TensorKey("key", "descriptor") @pytest.mark.parametrize( diff --git a/tests/dragon/test_inference_request.py b/tests/dragon/test_inference_request.py index 909d021d6e..f5c8b9bdc7 100644 --- a/tests/dragon/test_inference_request.py +++ b/tests/dragon/test_inference_request.py @@ -28,7 +28,7 @@ dragon = pytest.importorskip("dragon") -from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStoreKey +from smartsim._core.mli.infrastructure.storage.feature_store import TensorKey from smartsim._core.mli.infrastructure.worker.worker import InferenceRequest from smartsim._core.mli.message_handler import MessageHandler @@ -44,8 +44,8 @@ def inference_request() -> InferenceRequest: @pytest.fixture -def fs_key() -> FeatureStoreKey: - return FeatureStoreKey("key", "descriptor") +def fs_key() -> TensorKey: + return TensorKey("key", "descriptor") @pytest.mark.parametrize( diff --git a/tests/dragon/test_protoclient.py b/tests/dragon/test_protoclient.py new file mode 100644 index 0000000000..f84417107d --- /dev/null +++ b/tests/dragon/test_protoclient.py @@ -0,0 +1,313 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import os
+import pickle
+import time
+import typing as t
+from unittest.mock import MagicMock
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
+from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
+from smartsim._core.mli.comm.channel.dragon_util import create_local
+from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster
+from smartsim._core.mli.infrastructure.comm.event import OnWriteFeatureStore
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    BackboneFeatureStore,
+)
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+# isort: off
+from dragon import fli
+from dragon.data.ddict.ddict import DDict
+
+# from ..ex..high_throughput_inference.mock_app import ProtoClient
+from smartsim._core.mli.client.protoclient import ProtoClient
+
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+WORK_QUEUE_KEY = BackboneFeatureStore.MLI_WORKER_QUEUE
+logger = get_logger(__name__)
+
+
+@pytest.fixture(scope="module")
+def the_worker_queue(the_backbone: BackboneFeatureStore) -> DragonFLIChannel:
+    """Fixture that creates a dragon FLI channel as a stand-in for the
+    worker queue created by the worker.
+
+    :param the_backbone: The backbone feature store to update
+    with the worker queue descriptor.
+    :returns: The attached `DragonFLIChannel`
+    """
+
+    # create the FLI
+    to_worker_channel = create_local()
+    fli_ = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None)
+    comm_channel = DragonFLIChannel(fli_)
+
+    # store the descriptor in the backbone
+    the_backbone.worker_queue = comm_channel.descriptor
+
+    try:
+        comm_channel.send(b"foo")
+    except Exception as ex:
+        logger.exception("Test send from worker channel failed")
+
+    return comm_channel
+
+
+@pytest.mark.parametrize(
+    "backbone_timeout, exp_wait_max",
+    [
+        # aggregate the 1+1+1 into 3 on remaining parameters
+        pytest.param(0.5, 1 + 1 + 1, id="0.5s wait, 3 cycle steps"),
+        pytest.param(2, 3 + 2, id="2s wait, 4 cycle steps"),
+        pytest.param(4, 3 + 2 + 4, id="4s wait, 5 cycle steps"),
+    ],
+)
+def test_protoclient_timeout(
+    backbone_timeout: float,
+    exp_wait_max: float,
+    the_backbone: BackboneFeatureStore,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    """Verify that attempts to attach to the worker queue from the protoclient
+    time out in an appropriate amount of time. Note: due to the backoff, we verify
+    the elapsed time is less than the 15s of a cycle of waits.
+ + :param backbone_timeout: a timeout for use when configuring a proto client + :param exp_wait_max: a ceiling for the expected time spent waiting for + the timeout + :param the_backbone: a pre-initialized backbone featurestore for setting up + the environment variable required by the client + """ + + # NOTE: exp_wait_time maps to the cycled backoff of [0.1, 0.2, 0.4, 0.8] + # with leeway added (by allowing 1s each for the 0.1 and 0.5 steps) + + with monkeypatch.context() as ctx, pytest.raises(SmartSimError) as ex: + start_time = time.time() + # remove the worker queue value from the backbone if it exists + # to ensure the timeout occurs + the_backbone.pop(BackboneFeatureStore.MLI_WORKER_QUEUE) + + ctx.setenv(BackboneFeatureStore.MLI_BACKBONE, the_backbone.descriptor) + + ProtoClient(timing_on=False, backbone_timeout=backbone_timeout) + elapsed = time.time() - start_time + logger.info(f"ProtoClient timeout occurred in {elapsed} seconds") + + # confirm that we met our timeout + assert ( + elapsed >= backbone_timeout + ), f"below configured timeout {backbone_timeout}" + + # confirm that the total wait time is aligned with the sleep cycle + assert elapsed < exp_wait_max, f"above expected max wait {exp_wait_max}" + + +def test_protoclient_initialization_no_backbone( + monkeypatch: pytest.MonkeyPatch, the_worker_queue: DragonFLIChannel +): + """Verify that attempting to start the client without required environment variables + results in an exception. + + :param the_worker_queue: Passing the worker queue fixture to ensure + the worker queue environment is correctly configured. + + NOTE: os.environ[BackboneFeatureStore.MLI_BACKBONE] is not set""" + + with monkeypatch.context() as patch, pytest.raises(SmartSimError) as ex: + patch.setenv(BackboneFeatureStore.MLI_BACKBONE, "") + + ProtoClient(timing_on=False) + + # confirm the missing value error has been raised + assert {"backbone", "configuration"}.issubset(set(ex.value.args[0].split(" "))) + + +def test_protoclient_initialization( + the_backbone: BackboneFeatureStore, + the_worker_queue: DragonFLIChannel, + monkeypatch: pytest.MonkeyPatch, +): + """Verify that attempting to start the client with required env vars results + in a fully initialized client. + + :param the_backbone: a pre-initialized backbone featurestore + :param the_worker_queue: an FLI channel the client will retrieve + from the backbone""" + + with monkeypatch.context() as ctx: + ctx.setenv(BackboneFeatureStore.MLI_BACKBONE, the_backbone.descriptor) + # NOTE: rely on `the_worker_queue` fixture to put MLI_WORKER_QUEUE in backbone + + client = ProtoClient(timing_on=False) + + fs_descriptor = the_backbone.descriptor + wq_descriptor = the_worker_queue.descriptor + + # confirm the backbone was attached correctly + assert client._backbone is not None + assert client._backbone.descriptor == fs_descriptor + + # we expect the backbone to add its descriptor to the local env + assert os.environ[BackboneFeatureStore.MLI_BACKBONE] == fs_descriptor + + # confirm the worker queue is created and attached correctly + assert client._to_worker_fli is not None + assert client._to_worker_fli.descriptor == wq_descriptor + + # we expect the worker queue descriptor to be placed into the backbone + # we do NOT expect _from_worker_ch to be placed anywhere. 
it's a specific callback
+    assert the_backbone[BackboneFeatureStore.MLI_WORKER_QUEUE] == wq_descriptor
+
+    # confirm the worker channels are created
+    assert client._from_worker_ch is not None
+    assert client._to_worker_ch is not None
+
+    # wrap the channels just to easily verify they produce a descriptor
+    assert DragonCommChannel(client._from_worker_ch).descriptor
+    assert DragonCommChannel(client._to_worker_ch).descriptor
+
+    # confirm a publisher is created
+    assert client._publisher is not None
+
+
+def test_protoclient_write_model(
+    the_backbone: BackboneFeatureStore,
+    the_worker_queue: DragonFLIChannel,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    """Verify that writing a model using the client causes the model data to be
+    written to a feature store.
+
+    :param the_backbone: a pre-initialized backbone featurestore
+    :param the_worker_queue: the worker queue fixture, passed to ensure
+    the worker queue environment is correctly configured
+    """
+
+    with monkeypatch.context() as ctx:
+        # we won't actually send here
+        client = ProtoClient(timing_on=False)
+
+        ctx.setenv(BackboneFeatureStore.MLI_BACKBONE, the_backbone.descriptor)
+        # NOTE: rely on `the_worker_queue` fixture to put MLI_WORKER_QUEUE in backbone
+
+        client = ProtoClient(timing_on=False)
+
+        model_key = "my-model"
+        model_bytes = b"12345"
+
+        client.set_model(model_key, model_bytes)
+
+        # confirm the client modified the underlying feature store
+        assert client._backbone[model_key] == model_bytes
+
+
+@pytest.mark.parametrize(
+    "num_listeners, num_model_updates",
+    [(1, 1), (1, 4), (2, 4), (16, 4), (64, 8)],
+)
+def test_protoclient_write_model_notification_sent(
+    the_backbone: BackboneFeatureStore,
+    the_worker_queue: DragonFLIChannel,
+    monkeypatch: pytest.MonkeyPatch,
+    num_listeners: int,
+    num_model_updates: int,
+):
+    """Verify that writing a model sends a key-written event.
+
+    :param the_backbone: a pre-initialized backbone featurestore
+    :param the_worker_queue: an FLI channel the client will retrieve
+    from the backbone
+    :param num_listeners: vary the number of registered listeners
+    to verify that the event is broadcast to everyone
+    :param num_model_updates: vary the number of model updates to verify
+    the broadcast counts messages sent correctly
+    """
+
+    # we won't actually send here, but it won't try without registered listeners
+    listeners = [f"mock-ch-desc-{i}" for i in range(num_listeners)]
+
+    the_backbone[BackboneFeatureStore.MLI_BACKBONE] = the_backbone.descriptor
+    the_backbone[BackboneFeatureStore.MLI_WORKER_QUEUE] = the_worker_queue.descriptor
+    the_backbone[BackboneFeatureStore.MLI_NOTIFY_CONSUMERS] = ",".join(listeners)
+    the_backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] = None
+
+    with monkeypatch.context() as ctx:
+        ctx.setenv(BackboneFeatureStore.MLI_BACKBONE, the_backbone.descriptor)
+        # NOTE: rely on `the_worker_queue` fixture to put MLI_WORKER_QUEUE in backbone
+
+        client = ProtoClient(timing_on=False)
+
+        publisher = t.cast(EventBroadcaster, client._publisher)
+
+        # mock attaching to a channel given the mock-ch-desc in backbone
+        mock_send = MagicMock(return_value=None)
+        mock_comm_channel = MagicMock(**{"send": mock_send}, spec=DragonCommChannel)
+        mock_get_comm_channel = MagicMock(return_value=mock_comm_channel)
+        ctx.setattr(publisher, "_get_comm_channel", mock_get_comm_channel)
+
+        model_key = "my-model"
+        model_bytes = b"12345"
+
+        for i in range(num_model_updates):
+            client.set_model(model_key, model_bytes)
+
+        # confirm that a listener channel was attached
+        # once for each registered listener in backbone
+        assert mock_get_comm_channel.call_count == num_listeners * num_model_updates
+
+        # confirm the client raised the key-written event
+        assert (
+            mock_send.call_count == num_listeners * num_model_updates
+        ), f"Expected {num_listeners * num_model_updates} sends with {num_listeners} registrations"
+
+        # with at least 1 consumer registered, we can verify the message is sent
+        for call_args in mock_send.call_args_list:
+            send_args = call_args.args
+            event_bytes, timeout = send_args[0], send_args[1]
+
+            assert event_bytes, "Expected event bytes to be supplied to send"
+            assert (
+                timeout == 0.001
+            ), "Expected default timeout on call to `publisher.send`"
+
+            # confirm the correct event was raised
+            event = t.cast(
+                OnWriteFeatureStore,
+                pickle.loads(event_bytes),
+            )
+            assert event.descriptor == the_backbone.descriptor
+            assert event.key == model_key
diff --git a/tests/dragon/test_reply_building.py b/tests/dragon/test_reply_building.py
index 063200dd64..48493b3c4d 100644
--- a/tests/dragon/test_reply_building.py
+++ b/tests/dragon/test_reply_building.py
@@ -31,7 +31,6 @@
 dragon = pytest.importorskip("dragon")
 
 from smartsim._core.mli.infrastructure.control.worker_manager import build_failure_reply
-from smartsim._core.mli.infrastructure.worker.worker import InferenceReply
 
 if t.TYPE_CHECKING:
     from smartsim._core.mli.mli_schemas.response.response_capnp import Status
diff --git a/tests/dragon/test_request_dispatcher.py b/tests/dragon/test_request_dispatcher.py
index ccdbce58c3..70d73e243f 100644
--- a/tests/dragon/test_request_dispatcher.py
+++ b/tests/dragon/test_request_dispatcher.py
@@ -25,10 +25,8 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
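# --- editor's note: illustrative sketch, not part of the patch ---
# The wait/timeout tests above assume the backbone polls for missing keys on a
# cycled exponential backoff of [0.1, 0.2, 0.4, 0.8] seconds until the caller's
# deadline passes. A minimal, self-contained model of that loop follows;
# `wait_with_backoff` and its arguments are hypothetical names, not SmartSim API.
import itertools
import time
import typing as t

def wait_with_backoff(predicate: t.Callable[[], bool], timeout: float) -> bool:
    """Poll `predicate` until it returns True or `timeout` seconds elapse."""
    deadline = time.time() + timeout
    for delay in itertools.cycle((0.1, 0.2, 0.4, 0.8)):
        if predicate():
            return True
        if time.time() >= deadline:
            # a caller such as `wait_for` would raise a timeout error here
            return False
        time.sleep(delay)
    return False  # unreachable; satisfies type checkers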
import gc -import io -import logging -import pathlib -import socket +import os +import subprocess as sp import time import typing as t from queue import Empty @@ -36,33 +34,27 @@ import numpy as np import pytest -torch = pytest.importorskip("torch") -dragon = pytest.importorskip("dragon") +from . import conftest +from .utils import msg_pump + +pytest.importorskip("dragon") + -import base64 +# isort: off +import dragon import multiprocessing as mp -try: - mp.set_start_method("dragon") -except Exception: - pass +import torch -import os +# isort: on -import dragon.channels as dch -import dragon.infrastructure.policy as dragon_policy -import dragon.infrastructure.process_desc as dragon_process_desc -import dragon.native.process as dragon_process from dragon import fli -from dragon.channels import Channel from dragon.data.ddict.ddict import DDict -from dragon.managed_memory import MemoryAlloc, MemoryPool -from dragon.mpbridge.queues import DragonQueue +from dragon.managed_memory import MemoryAlloc -from smartsim._core.entrypoints.service import Service -from smartsim._core.mli.comm.channel.channel import CommChannelBase from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local from smartsim._core.mli.infrastructure.control.request_dispatcher import ( RequestBatch, RequestDispatcher, @@ -70,210 +62,122 @@ from smartsim._core.mli.infrastructure.control.worker_manager import ( EnvironmentConfigLoader, ) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( DragonFeatureStore, ) -from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker -from smartsim._core.mli.message_handler import MessageHandler from smartsim.log import get_logger -from .feature_store import FileSystemFeatureStore -from .utils.channel import FileSystemCommChannel - logger = get_logger(__name__) + # The tests in this file belong to the dragon group pytestmark = pytest.mark.dragon -def persist_model_file(model_path: pathlib.Path) -> pathlib.Path: - """Create a simple torch model and persist to disk for - testing purposes. 
- - TODO: remove once unit tests are in place""" - # test_path = pathlib.Path(work_dir) - if not model_path.parent.exists(): - model_path.parent.mkdir(parents=True, exist_ok=True) - - model_path.unlink(missing_ok=True) - - model = torch.nn.Linear(2, 1) - torch.save(model, model_path) - - return model_path +try: + mp.set_start_method("dragon") +except Exception: + pass -def mock_messages( - request_dispatcher_queue: DragonFLIChannel, - feature_store: FeatureStore, - feature_store_root_dir: pathlib.Path, - comm_channel_root_dir: pathlib.Path, +@pytest.mark.parametrize("num_iterations", [4]) +def test_request_dispatcher( + num_iterations: int, + the_storage: DDict, + test_dir: str, ) -> None: - """Mock event producer for triggering the inference pipeline""" - feature_store_root_dir.mkdir(parents=True, exist_ok=True) - comm_channel_root_dir.mkdir(parents=True, exist_ok=True) - - model_path = persist_model_file(feature_store_root_dir.parent / "model_original.pt") - model_bytes = model_path.read_bytes() - model_key = str(feature_store_root_dir / "model_fs.pt") - - feature_store[model_key] = model_bytes - - for iteration_number in range(2): - - channel = Channel.make_process_local() - callback_channel = DragonCommChannel(channel) - - input_path = feature_store_root_dir / f"{iteration_number}/input.pt" - output_path = feature_store_root_dir / f"{iteration_number}/output.pt" - - input_key = str(input_path) - output_key = str(output_path) - - tensor = ( - (iteration_number + 1) * torch.ones((1, 2), dtype=torch.float32) - ).numpy() - fsd = feature_store.descriptor - - tensor_desc = MessageHandler.build_tensor_descriptor( - "c", "float32", list(tensor.shape) - ) - - message_tensor_output_key = MessageHandler.build_tensor_key(output_key, fsd) - message_tensor_input_key = MessageHandler.build_tensor_key(input_key, fsd) - message_model_key = MessageHandler.build_model_key(model_key, fsd) - - request = MessageHandler.build_request( - reply_channel=base64.b64encode(channel.serialize()).decode("utf-8"), - model=message_model_key, - inputs=[tensor_desc], - outputs=[message_tensor_output_key], - output_descriptors=[], - custom_attributes=None, - ) - request_bytes = MessageHandler.serialize_request(request) - with request_dispatcher_queue._fli.sendh( - timeout=None, stream_channel=request_dispatcher_queue._channel - ) as sendh: - sendh.send_bytes(request_bytes) - sendh.send_bytes(tensor.tobytes()) - time.sleep(1) - - -@pytest.fixture -def prepare_environment(test_dir: str) -> pathlib.Path: - """Cleanup prior outputs to run demo repeatedly""" - path = pathlib.Path(f"{test_dir}/workermanager.log") - logging.basicConfig(filename=path.absolute(), level=logging.DEBUG) - return path - - -def service_as_dragon_proc( - service: Service, cpu_affinity: list[int], gpu_affinity: list[int] -) -> dragon_process.Process: - - options = dragon_process_desc.ProcessOptions(make_inf_channels=True) - local_policy = dragon_policy.Policy( - placement=dragon_policy.Policy.Placement.HOST_NAME, - host_name=socket.gethostname(), - cpu_affinity=cpu_affinity, - gpu_affinity=gpu_affinity, - ) - return dragon_process.Process( - target=service.execute, - args=[], - cwd=os.getcwd(), - policy=local_policy, - options=options, - stderr=dragon_process.Popen.STDOUT, - stdout=dragon_process.Popen.STDOUT, - ) - - -def test_request_dispatcher(prepare_environment: pathlib.Path) -> None: """Test the request dispatcher batching and queueing system This also includes setting a queue to disposable, checking that it is no longer referenced by the 
dispatcher. """ - test_path = prepare_environment - fs_path = test_path / "feature_store" - comm_path = test_path / "comm_store" - - to_worker_channel = dch.Channel.make_process_local() + to_worker_channel = create_local() to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None) - to_worker_fli_serialized = to_worker_fli.serialize() + to_worker_fli_comm_ch = DragonFLIChannel(to_worker_fli) + + backbone_fs = BackboneFeatureStore(the_storage, allow_reserved_writes=True) # NOTE: env vars should be set prior to instantiating EnvironmentConfigLoader # or test environment may be unable to send messages w/queue - descriptor = base64.b64encode(to_worker_fli_serialized).decode("utf-8") - os.environ["_SMARTSIM_REQUEST_QUEUE"] = descriptor - - ddict = DDict(1, 2, 4 * 1024**2) - dragon_fs = DragonFeatureStore(ddict) + os.environ[BackboneFeatureStore.MLI_WORKER_QUEUE] = to_worker_fli_comm_ch.descriptor + os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone_fs.descriptor config_loader = EnvironmentConfigLoader( featurestore_factory=DragonFeatureStore.from_descriptor, callback_factory=DragonCommChannel.from_descriptor, queue_factory=DragonFLIChannel.from_descriptor, ) - integrated_worker_type = TorchWorker request_dispatcher = RequestDispatcher( - batch_timeout=0, + batch_timeout=1000, batch_size=2, config_loader=config_loader, - worker_type=integrated_worker_type, + worker_type=TorchWorker, mem_pool_size=2 * 1024**2, ) worker_queue = config_loader.get_queue() if worker_queue is None: - logger.warn( + logger.warning( "FLI input queue not loaded correctly from config_loader: " f"{config_loader._queue_descriptor}" ) request_dispatcher._on_start() - for _ in range(2): + # put some messages into the work queue for the dispatcher to pickup + channels = [] + processes = [] + for i in range(num_iterations): batch: t.Optional[RequestBatch] = None mem_allocs = [] tensors = [] - fs_path = test_path / f"feature_store" - comm_path = test_path / f"comm_store" - model_key = str(fs_path / "model_fs.pt") - - # create a mock client application to populate the request queue - msg_pump = mp.Process( - target=mock_messages, - args=( - worker_queue, - dragon_fs, - fs_path, - comm_path, - ), - ) - - msg_pump.start() - time.sleep(1) + # NOTE: creating callbacks in test to avoid a local channel being torn + # down when mock_messages terms but before the final response message is sent + + callback_channel = DragonCommChannel.from_local() + channels.append(callback_channel) + + process = conftest.function_as_dragon_proc( + msg_pump.mock_messages, + [ + worker_queue.descriptor, + backbone_fs.descriptor, + i, + callback_channel.descriptor, + ], + [], + [], + ) + processes.append(process) + process.start() + assert process.returncode is None, "The message pump failed to start" - for attempts in range(15): + # give dragon some time to populate the message queues + for i in range(15): try: request_dispatcher._on_iteration() - batch = request_dispatcher.task_queue.get(timeout=1) + batch = request_dispatcher.task_queue.get(timeout=1.0) break except Empty: + time.sleep(2) + logger.warning(f"Task queue is empty on iteration {i}") continue except Exception as exc: + logger.error(f"Task queue exception on iteration {i}") raise exc - try: - assert batch is not None - assert batch.has_valid_requests + assert batch is not None + assert batch.has_valid_requests + + model_key = batch.model_id.key + try: transform_result = batch.inputs for transformed, dims, dtype in zip( transform_result.transformed, @@ -316,8 +220,6 @@ 
def test_request_dispatcher(prepare_environment: pathlib.Path) -> None: for mem_alloc in mem_allocs: mem_alloc.free() - msg_pump.kill() - request_dispatcher._active_queues[model_key].make_disposable() assert request_dispatcher._active_queues[model_key].can_be_removed diff --git a/tests/dragon/test_torch_worker.py b/tests/dragon/test_torch_worker.py index 9a5ed6309f..2a9e7d01bd 100644 --- a/tests/dragon/test_torch_worker.py +++ b/tests/dragon/test_torch_worker.py @@ -37,7 +37,7 @@ from torch import nn from torch.nn import functional as F -from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStoreKey +from smartsim._core.mli.infrastructure.storage.feature_store import ModelKey from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker from smartsim._core.mli.infrastructure.worker.worker import ( ExecuteResult, @@ -109,7 +109,7 @@ def get_request() -> InferenceRequest: ] return InferenceRequest( - model_key=FeatureStoreKey(key="model", descriptor="xyz"), + model_key=ModelKey(key="model", descriptor="xyz"), callback=None, raw_inputs=tensor_numpy, input_keys=None, diff --git a/tests/dragon/test_worker_manager.py b/tests/dragon/test_worker_manager.py index 1ebc512a50..4047a731fc 100644 --- a/tests/dragon/test_worker_manager.py +++ b/tests/dragon/test_worker_manager.py @@ -34,7 +34,6 @@ torch = pytest.importorskip("torch") dragon = pytest.importorskip("dragon") -import base64 import multiprocessing as mp try: @@ -44,25 +43,26 @@ import os -import dragon.channels as dch +import torch.nn as nn from dragon import fli -from dragon.mpbridge.queues import DragonQueue -from smartsim._core.mli.comm.channel.channel import CommChannelBase from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local from smartsim._core.mli.infrastructure.control.worker_manager import ( EnvironmentConfigLoader, WorkerManager, ) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( DragonFeatureStore, ) -from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore +from smartsim._core.mli.infrastructure.storage.dragon_util import create_ddict from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker from smartsim._core.mli.message_handler import MessageHandler from smartsim.log import get_logger -from .feature_store import FileSystemFeatureStore from .utils.channel import FileSystemCommChannel logger = get_logger(__name__) @@ -70,111 +70,205 @@ pytestmark = pytest.mark.dragon -def persist_model_file(model_path: pathlib.Path) -> pathlib.Path: +class MiniModel(nn.Module): + """A torch model that can be executed by the default torch worker""" + + def __init__(self): + """Initialize the model.""" + super().__init__() + + self._name = "mini-model" + self._net = torch.nn.Linear(2, 1) + + def forward(self, input): + """Execute a forward pass.""" + return self._net(input) + + @property + def bytes(self) -> bytes: + """Retrieve the serialized model + + :returns: The byte stream of the model file + """ + buffer = io.BytesIO() + scripted = torch.jit.trace(self._net, self.get_batch()) + torch.jit.save(scripted, buffer) + return buffer.getvalue() + + @classmethod + def get_batch(cls) -> "torch.Tensor": + """Generate a single batch of data with the correct + shape for inference. 
+ + :returns: The batch as a torch tensor + """ + return torch.randn((100, 2), dtype=torch.float32) + + +def create_model(model_path: pathlib.Path) -> pathlib.Path: """Create a simple torch model and persist to disk for testing purposes. - TODO: remove once unit tests are in place""" - # test_path = pathlib.Path(work_dir) + :param model_path: The path to the torch model file + """ if not model_path.parent.exists(): model_path.parent.mkdir(parents=True, exist_ok=True) model_path.unlink(missing_ok=True) - # model_path = test_path / "basic.pt" - model = torch.nn.Linear(2, 1) - torch.save(model, model_path) + mini_model = MiniModel() + torch.save(mini_model, model_path) return model_path +def load_model() -> bytes: + """Create a simple torch model in memory for testing.""" + mini_model = MiniModel() + return mini_model.bytes + + def mock_messages( - worker_manager_queue: CommChannelBase, - feature_store: FeatureStore, feature_store_root_dir: pathlib.Path, comm_channel_root_dir: pathlib.Path, + kill_queue: mp.Queue, ) -> None: - """Mock event producer for triggering the inference pipeline""" + """Mock event producer for triggering the inference pipeline. + + :param feature_store_root_dir: Path to a directory where a + FileSystemFeatureStore can read & write results + :param comm_channel_root_dir: Path to a directory where a + FileSystemCommChannel can read & write messages + :param kill_queue: Queue used by unit test to stop mock_message process + """ feature_store_root_dir.mkdir(parents=True, exist_ok=True) comm_channel_root_dir.mkdir(parents=True, exist_ok=True) - model_path = persist_model_file(feature_store_root_dir.parent / "model_original.pt") - model_bytes = model_path.read_bytes() - model_key = str(feature_store_root_dir / "model_fs.pt") + iteration_number = 0 + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + backbone = config_loader.get_backbone() - feature_store[model_key] = model_bytes + worker_queue = config_loader.get_queue() + if worker_queue is None: + queue_desc = config_loader._queue_descriptor + logger.warn( + f"FLI input queue not loaded correctly from config_loader: {queue_desc}" + ) - iteration_number = 0 + model_key = "mini-model" + model_bytes = load_model() + backbone[model_key] = model_bytes while True: + if not kill_queue.empty(): + return iteration_number += 1 time.sleep(1) - # 1. for demo, ignore upstream and just put stuff into downstream - # 2. 
for demo, only one downstream but we'd normally have to filter
-        #    msg content and send to the correct downstream (worker) queue
-        # timestamp = time.time_ns()
-        # mock_channel = test_path / f"brainstorm-{timestamp}.txt"
-        # mock_channel.touch()
-
-        # thread - just look for key (wait for keys)
-        # call checkpoint, try to get non-persistent key, it blocks
-        # working set size > 1 has side-effects
-        # only incurs cost when working set size has been exceeded
         channel_key = comm_channel_root_dir / f"{iteration_number}/channel.txt"
         callback_channel = FileSystemCommChannel(pathlib.Path(channel_key))
 
-        input_path = feature_store_root_dir / f"{iteration_number}/input.pt"
-        output_path = feature_store_root_dir / f"{iteration_number}/output.pt"
+        batch = MiniModel.get_batch()
+        shape = batch.shape
+        batch_bytes = batch.numpy().tobytes()
 
-        input_key = str(input_path)
-        output_key = str(output_path)
+        logger.debug(f"Model content: {backbone[model_key][:20]}")
 
-        buffer = io.BytesIO()
-        tensor = torch.randn((1, 2), dtype=torch.float32)
-        torch.save(tensor, buffer)
-        feature_store[input_key] = buffer.getvalue()
-        fsd = feature_store.descriptor
-
-        message_tensor_output_key = MessageHandler.build_tensor_key(output_key, fsd)
-        message_tensor_input_key = MessageHandler.build_tensor_key(input_key, fsd)
-        message_model_key = MessageHandler.build_model_key(model_key, fsd)
+        input_descriptor = MessageHandler.build_tensor_descriptor(
+            "f", "float32", list(shape)
+        )
 
+        # The first request is always the metadata...
         request = MessageHandler.build_request(
             reply_channel=callback_channel.descriptor,
-            model=message_model_key,
-            inputs=[message_tensor_input_key],
-            outputs=[message_tensor_output_key],
+            model=MessageHandler.build_model(model_bytes, "mini-model", "1.0"),
+            inputs=[input_descriptor],
+            outputs=[],
             output_descriptors=[],
             custom_attributes=None,
         )
         request_bytes = MessageHandler.serialize_request(request)
-        worker_manager_queue.send(request_bytes)
+        fli: DragonFLIChannel = worker_queue
+
+        with fli._fli.sendh(timeout=None, stream_channel=fli._channel) as sendh:
+            sendh.send_bytes(request_bytes)
+            sendh.send_bytes(batch_bytes)
+
+        logger.info("published message")
+
+        if iteration_number > 5:
+            return
+
+
+def mock_mli_infrastructure_mgr() -> None:
+    """Create resources normally instantiated by the infrastructure
+    management portion of the DragonBackend.
+    """
+    config_loader = EnvironmentConfigLoader(
+        featurestore_factory=DragonFeatureStore.from_descriptor,
+        callback_factory=FileSystemCommChannel.from_descriptor,
+        queue_factory=DragonFLIChannel.from_descriptor,
+    )
+
+    integrated_worker = TorchWorker
+
+    worker_manager = WorkerManager(
+        config_loader,
+        integrated_worker,
+        as_service=True,
+        cooldown=10,
+        device="cpu",
+        dispatcher_queue=mp.Queue(maxsize=0),
+    )
+    worker_manager.execute()
 
 
 @pytest.fixture
 def prepare_environment(test_dir: str) -> pathlib.Path:
-    """Cleanup prior outputs to run demo repeatedly"""
+    """Cleanup prior outputs to run demo repeatedly.
+
+    :param test_dir: the directory to prepare
+    :returns: The path to the log file
+    """
     path = pathlib.Path(f"{test_dir}/workermanager.log")
     logging.basicConfig(filename=path.absolute(), level=logging.DEBUG)
     return path
 
 
 def test_worker_manager(prepare_environment: pathlib.Path) -> None:
-    """Test the worker manager"""
+    """Test the worker manager.
+ + :param prepare_environment: Pass this fixture to configure + global resources before the worker manager executes + """ test_path = prepare_environment fs_path = test_path / "feature_store" comm_path = test_path / "comm_store" - to_worker_channel = dch.Channel.make_process_local() + mgr_per_node = 1 + num_nodes = 2 + mem_per_node = 128 * 1024**2 + + storage = create_ddict(num_nodes, mgr_per_node, mem_per_node) + backbone = BackboneFeatureStore(storage, allow_reserved_writes=True) + + to_worker_channel = create_local() to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None) - to_worker_fli_serialized = to_worker_fli.serialize() - # NOTE: env vars should be set prior to instantiating EnvironmentConfigLoader + to_worker_fli_comm_channel = DragonFLIChannel(to_worker_fli) + + # NOTE: env vars must be set prior to instantiating EnvironmentConfigLoader # or test environment may be unable to send messages w/queue - descriptor = base64.b64encode(to_worker_fli_serialized).decode("utf-8") - os.environ["_SMARTSIM_REQUEST_QUEUE"] = descriptor + os.environ[BackboneFeatureStore.MLI_WORKER_QUEUE] = ( + to_worker_fli_comm_channel.descriptor + ) + os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone.descriptor config_loader = EnvironmentConfigLoader( featurestore_factory=DragonFeatureStore.from_descriptor, @@ -197,22 +291,24 @@ def test_worker_manager(prepare_environment: pathlib.Path) -> None: logger.warn( f"FLI input queue not loaded correctly from config_loader: {config_loader._queue_descriptor}" ) + backbone.worker_queue = to_worker_fli_comm_channel.descriptor # create a mock client application to populate the request queue + kill_queue = mp.Queue() msg_pump = mp.Process( target=mock_messages, - args=( - worker_queue, - FileSystemFeatureStore(fs_path), - fs_path, - comm_path, - ), + args=(fs_path, comm_path, kill_queue), ) msg_pump.start() # create a process to execute commands - process = mp.Process(target=worker_manager.execute) + process = mp.Process(target=mock_mli_infrastructure_mgr) + + # let it send some messages before starting the worker manager + msg_pump.join(timeout=5) process.start() + msg_pump.join(timeout=5) + kill_queue.put_nowait("kill!") process.join(timeout=5) - process.kill() msg_pump.kill() + process.kill() diff --git a/tests/dragon/utils/channel.py b/tests/dragon/utils/channel.py index 6cde6258f2..4c46359c2d 100644 --- a/tests/dragon/utils/channel.py +++ b/tests/dragon/utils/channel.py @@ -39,17 +39,15 @@ class FileSystemCommChannel(CommChannelBase): """Passes messages by writing to a file""" - def __init__(self, key: t.Union[bytes, pathlib.Path]) -> None: - """Initialize the FileSystemCommChannel instance + def __init__(self, key: pathlib.Path) -> None: + """Initialize the FileSystemCommChannel instance. 
-        :param key: a path to the root directory of the feature store"""
+        :param key: a path to the root directory of the feature store
+        """
         self._lock = threading.RLock()
-        if not isinstance(key, bytes):
-            super().__init__(key.as_posix().encode("utf-8"))
-            self._file_path = key
-        else:
-            super().__init__(key)
-            self._file_path = pathlib.Path(key.decode("utf-8"))
+
+        super().__init__(key.as_posix())
+        self._file_path = key
 
         if not self._file_path.parent.exists():
             self._file_path.parent.mkdir(parents=True)
@@ -57,10 +55,11 @@ def __init__(self, key: t.Union[bytes, pathlib.Path]) -> None:
         self._file_path.touch()
 
     def send(self, value: bytes, timeout: float = 0) -> None:
-        """Send a message throuh the underlying communication channel
+        """Send a message through the underlying communication channel.
+
         :param value: The value to send
         :param timeout: maximum time to wait (in seconds) for messages to send
-        :param value: The value to send"""
+        """
         with self._lock:
             # write as text so we can add newlines as delimiters
             with open(self._file_path, "a") as fp:
@@ -69,11 +68,12 @@ def send(self, value: bytes, timeout: float = 0) -> None:
             logger.debug(f"FileSystemCommChannel {self._file_path} sent message")
 
     def recv(self, timeout: float = 0) -> t.List[bytes]:
-        """Receives message(s) through the underlying communication channel
+        """Receives message(s) through the underlying communication channel.
 
         :param timeout: maximum time to wait (in seconds) for messages to arrive
         :returns: the received message
-        :raises SmartSimError: if the descriptor points to a missing file"""
+        :raises SmartSimError: if the descriptor points to a missing file
+        """
         with self._lock:
             messages: t.List[bytes] = []
             if not self._file_path.exists():
@@ -102,7 +102,7 @@ def recv(self, timeout: float = 0) -> t.List[bytes]:
         return messages
 
     def clear(self) -> None:
-        """Create an empty file for events"""
+        """Create an empty file for events."""
         if self._file_path.exists():
             self._file_path.unlink()
             self._file_path.touch()
@@ -110,18 +110,16 @@ def clear(self) -> None:
     @classmethod
     def from_descriptor(
         cls,
-        descriptor: t.Union[str, bytes],
+        descriptor: str,
     ) -> "FileSystemCommChannel":
-        """A factory method that creates an instance from a descriptor string
+        """A factory method that creates an instance from a descriptor string.
 
         :param descriptor: The descriptor that uniquely identifies the resource
-        :returns: An attached FileSystemCommChannel"""
+        :returns: An attached FileSystemCommChannel
+        """
         try:
-            if isinstance(descriptor, str):
-                path = pathlib.Path(descriptor)
-            else:
-                path = pathlib.Path(descriptor.decode("utf-8"))
+            path = pathlib.Path(descriptor)
             return FileSystemCommChannel(path)
         except:
-            logger.warning(f"failed to create fs comm channel: {descriptor!r}")
+            logger.warning(f"failed to create fs comm channel: {descriptor}")
             raise
diff --git a/tests/dragon/utils/msg_pump.py b/tests/dragon/utils/msg_pump.py
new file mode 100644
index 0000000000..8d69e57c63
--- /dev/null
+++ b/tests/dragon/utils/msg_pump.py
@@ -0,0 +1,225 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2.
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import io +import logging +import pathlib +import sys +import time +import typing as t + +import pytest + +pytest.importorskip("torch") +pytest.importorskip("dragon") + + +# isort: off +import dragon +import multiprocessing as mp +import torch +import torch.nn as nn + +# isort: on + +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.message_handler import MessageHandler +from smartsim.log import get_logger + +logger = get_logger(__name__, log_level=logging.DEBUG) + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + +try: + mp.set_start_method("dragon") +except Exception: + pass + + +class MiniModel(nn.Module): + def __init__(self): + super().__init__() + + self._name = "mini-model" + self._net = torch.nn.Linear(2, 1) + + def forward(self, input): + return self._net(input) + + @property + def bytes(self) -> bytes: + """Returns the model serialized to a byte stream""" + buffer = io.BytesIO() + scripted = torch.jit.trace(self._net, self.get_batch()) + torch.jit.save(scripted, buffer) + return buffer.getvalue() + + @classmethod + def get_batch(cls) -> "torch.Tensor": + return torch.randn((100, 2), dtype=torch.float32) + + +def load_model() -> bytes: + """Create a simple torch model in memory for testing""" + mini_model = MiniModel() + return mini_model.bytes + + +def persist_model_file(model_path: pathlib.Path) -> pathlib.Path: + """Create a simple torch model and persist to disk for + testing purposes. 
+
+    :returns: Path to the model file
+    """
+    if not model_path.parent.exists():
+        model_path.parent.mkdir(parents=True, exist_ok=True)
+
+    model_path.unlink(missing_ok=True)
+
+    model = torch.nn.Linear(2, 1)
+    torch.save(model, model_path)
+
+    return model_path
+
+
+def _mock_messages(
+    dispatch_fli_descriptor: str,
+    fs_descriptor: str,
+    parent_iteration: int,
+    callback_descriptor: str,
+) -> None:
+    """Mock event producer for triggering the inference pipeline.
+
+    :param dispatch_fli_descriptor: Descriptor of the request dispatcher FLI queue
+    :param fs_descriptor: Descriptor of the backbone feature store
+    :param parent_iteration: Iteration of the caller, used to offset message indices
+    :param callback_descriptor: Descriptor of the reply channel
+    """
+    model_key = "mini-model"
+    # mock_messages sends 2 messages, so we offset by 2 * (# of iterations in caller)
+    offset = 2 * parent_iteration
+
+    feature_store = BackboneFeatureStore.from_descriptor(fs_descriptor)
+    request_dispatcher_queue = DragonFLIChannel.from_descriptor(dispatch_fli_descriptor)
+
+    feature_store[model_key] = load_model()
+
+    for iteration_number in range(2):
+        logged_iteration = offset + iteration_number
+        logger.debug(f"Sending mock message {logged_iteration}")
+
+        output_key = f"output-{iteration_number}"
+
+        tensor = (
+            (iteration_number + 1) * torch.ones((1, 2), dtype=torch.float32)
+        ).numpy()
+        fsd = feature_store.descriptor
+
+        tensor_desc = MessageHandler.build_tensor_descriptor(
+            "c", "float32", list(tensor.shape)
+        )
+
+        message_tensor_output_key = MessageHandler.build_tensor_key(output_key, fsd)
+        message_model_key = MessageHandler.build_model_key(model_key, fsd)
+
+        request = MessageHandler.build_request(
+            reply_channel=callback_descriptor,
+            model=message_model_key,
+            inputs=[tensor_desc],
+            outputs=[message_tensor_output_key],
+            output_descriptors=[],
+            custom_attributes=None,
+        )
+
+        logger.info(f"Sending request {iteration_number} to request_dispatcher_queue")
+        request_bytes = MessageHandler.serialize_request(request)
+
+        logger.info("Sending msg_envelope")
+
+        # send the header & body together so they arrive together
+        try:
+            request_dispatcher_queue.send_multiple([request_bytes, tensor.tobytes()])
+            logger.info(f"\tenvelope 0: {request_bytes[:5]}...")
+            logger.info(f"\tenvelope 1: {tensor.tobytes()[:5]}...")
+        except Exception:
+            logger.exception("Unable to send request envelope")
+
+    logger.info("All messages sent")
+
+    # keep the process alive for an extra 15 seconds to let the processor
+    # have access to the channels before they're destroyed
+    for _ in range(15):
+        time.sleep(1)
+
+
+def mock_messages(
+    dispatch_fli_descriptor: str,
+    fs_descriptor: str,
+    parent_iteration: int,
+    callback_descriptor: str,
+) -> int:
+    """Mock event producer for triggering the inference pipeline. Used
+    when starting via multiprocessing.
+
+    :param dispatch_fli_descriptor: Descriptor of the request dispatcher FLI queue
+    :param fs_descriptor: Descriptor of the backbone feature store
+    :param parent_iteration: Iteration of the caller, used to offset message indices
+    :param callback_descriptor: Descriptor of the reply channel
+    :returns: Process return code: 0 on success, 1 on failure
+    """
+    logger.info(f"{dispatch_fli_descriptor=}")
+    logger.info(f"{fs_descriptor=}")
+    logger.info(f"{parent_iteration=}")
+    logger.info(f"{callback_descriptor=}")
+
+    try:
+        _mock_messages(
+            dispatch_fli_descriptor,
+            fs_descriptor,
+            parent_iteration,
+            callback_descriptor,
+        )
+    except Exception:
+        logger.exception("Unable to produce mock messages")
+        return 1
+
+    return 0
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--dispatch-fli-descriptor", type=str)
+    parser.add_argument("--fs-descriptor", type=str)
+    parser.add_argument("--parent-iteration", type=int)
+    parser.add_argument("--callback-descriptor", type=str)
+
+    args = parser.parse_args()
+
+    return_code = mock_messages(
+        args.dispatch_fli_descriptor,
+        args.fs_descriptor,
+        args.parent_iteration,
+        args.callback_descriptor,
+    )
+    sys.exit(return_code)
diff --git a/tests/mli/channel.py b/tests/mli/channel.py
index 2348784236..4c46359c2d 100644
--- a/tests/mli/channel.py
+++ b/tests/mli/channel.py
@@ -39,17 +39,15 @@ class FileSystemCommChannel(CommChannelBase):
     """Passes messages by writing to a file"""
 
-    def __init__(self, key: t.Union[bytes, pathlib.Path]) -> None:
-        """Initialize the FileSystemCommChannel instance
+    def __init__(self, key: pathlib.Path) -> None:
+        """Initialize the FileSystemCommChannel instance.
 
-        :param key: a path to the root directory of the feature store"""
+        :param key: a path to the root directory of the feature store
+        """
         self._lock = threading.RLock()
-        if isinstance(key, pathlib.Path):
-            super().__init__(key.as_posix().encode("utf-8"))
-            self._file_path = key
-        else:
-            super().__init__(key)
-            self._file_path = pathlib.Path(key.decode("utf-8"))
+
+        super().__init__(key.as_posix())
+        self._file_path = key
 
         if not self._file_path.parent.exists():
             self._file_path.parent.mkdir(parents=True)
@@ -57,10 +55,11 @@ def __init__(self, key: t.Union[bytes, pathlib.Path]) -> None:
         self._file_path.touch()
 
     def send(self, value: bytes, timeout: float = 0) -> None:
-        """Send a message throuh the underlying communication channel
+        """Send a message through the underlying communication channel.
+
         :param value: The value to send
         :param timeout: maximum time to wait (in seconds) for messages to send
-        :param value: The value to send"""
+        """
         with self._lock:
             # write as text so we can add newlines as delimiters
             with open(self._file_path, "a") as fp:
@@ -69,11 +68,12 @@ def send(self, value: bytes, timeout: float = 0) -> None:
             logger.debug(f"FileSystemCommChannel {self._file_path} sent message")
 
     def recv(self, timeout: float = 0) -> t.List[bytes]:
-        """Receives message(s) through the underlying communication channel
:param timeout: maximum time to wait (in seconds) for messages to arrive :returns: the received message - :raises SmartSimError: if the descriptor points to a missing file""" + :raises SmartSimError: if the descriptor points to a missing file + """ with self._lock: messages: t.List[bytes] = [] if not self._file_path.exists(): @@ -102,7 +102,7 @@ def recv(self, timeout: float = 0) -> t.List[bytes]: return messages def clear(self) -> None: - """Create an empty file for events""" + """Create an empty file for events.""" if self._file_path.exists(): self._file_path.unlink() self._file_path.touch() @@ -110,17 +110,15 @@ def clear(self) -> None: @classmethod def from_descriptor( cls, - descriptor: t.Union[str, bytes], + descriptor: str, ) -> "FileSystemCommChannel": - """A factory method that creates an instance from a descriptor string + """A factory method that creates an instance from a descriptor string. :param descriptor: The descriptor that uniquely identifies the resource - :returns: An attached FileSystemCommChannel""" + :returns: An attached FileSystemCommChannel + """ try: - if isinstance(descriptor, str): - path = pathlib.Path(descriptor) - else: - path = pathlib.Path(descriptor.decode("utf-8")) + path = pathlib.Path(descriptor) return FileSystemCommChannel(path) except: logger.warning(f"failed to create fs comm channel: {descriptor}") diff --git a/tests/mli/test_default_torch_worker.py b/tests/mli/test_default_torch_worker.py deleted file mode 100644 index b2ec6c3dca..0000000000 --- a/tests/mli/test_default_torch_worker.py +++ /dev/null @@ -1,206 +0,0 @@ -# # BSD 2-Clause License -# # -# # Copyright (c) 2021-2024, Hewlett Packard Enterprise -# # All rights reserved. -# # -# # Redistribution and use in source and binary forms, with or without -# # modification, are permitted provided that the following conditions are met: -# # -# # 1. Redistributions of source code must retain the above copyright notice, this -# # list of conditions and the following disclaimer. -# # -# # 2. Redistributions in binary form must reproduce the above copyright notice, -# # this list of conditions and the following disclaimer in the documentation -# # and/or other materials provided with the distribution. -# # -# # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -# import io -# import pathlib -# import typing as t - -# import pytest -# import torch - -# from smartsim._core.mli.infrastructure.worker.integratedtorchworker import ( -# IntegratedTorchWorker, -# ) -# import smartsim.error as sse -# from smartsim._core.mli.infrastructure import MemoryFeatureStore -# from smartsim._core.mli.infrastructure.worker.worker import ( -# ExecuteResult, -# FetchInputResult, -# FetchModelResult, -# InferenceRequest, -# TransformInputResult, -# LoadModelResult, -# ) -# from smartsim._core.utils import installed_redisai_backends - -# # The tests in this file belong to the group_a group -# pytestmark = pytest.mark.group_b - -# # retrieved from pytest fixtures -# is_dragon = pytest.test_launcher == "dragon" -# torch_available = "torch" in installed_redisai_backends() - - -# @pytest.fixture -# def persist_torch_model(test_dir: str) -> pathlib.Path: -# test_path = pathlib.Path(test_dir) -# model_path = test_path / "basic.pt" - -# model = torch.nn.Linear(2, 1) -# torch.save(model, model_path) - -# return model_path - - -# # def test_deserialize() -> None: -# # """Verify that serialized requests are properly deserialized to -# # and converted to the internal representation used by ML workers""" -# # worker = SampleTorchWorker -# # buffer = io.BytesIO() - -# # exp_model_key = "model-key" -# # msg = InferenceRequest(model_key=exp_model_key) -# # pickle.dump(msg, buffer) - -# # deserialized: InferenceRequest = worker.deserialize(buffer.getvalue()) - -# # assert deserialized.model_key == exp_model_key -# # # assert deserialized.backend == exp_backend - - -# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") -# def test_load_model_from_disk(persist_torch_model: pathlib.Path) -> None: -# """Verify that a model can be loaded using a FileSystemFeatureStore""" -# worker = IntegratedTorchWorker -# request = InferenceRequest(raw_model=persist_torch_model.read_bytes()) - -# fetch_result = FetchModelResult(persist_torch_model.read_bytes()) -# load_result = worker.load_model(request, fetch_result) - -# input = torch.randn(2) -# pred = load_result.model(input) - -# assert pred - - -# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") -# def test_transform_input() -> None: -# """Verify that the default input transform operation is a no-op copy""" -# rows, cols = 1, 4 -# num_values = 7 -# tensors = [torch.randn((rows, cols)) for _ in range(num_values)] - -# request = InferenceRequest() - -# inputs: t.List[bytes] = [] -# for tensor in tensors: -# buffer = io.BytesIO() -# torch.save(tensor, buffer) -# inputs.append(buffer.getvalue()) - -# fetch_result = FetchInputResult(inputs) -# worker = IntegratedTorchWorker -# result = worker.transform_input(request, fetch_result) -# transformed: t.Collection[torch.Tensor] = result.transformed - -# assert len(transformed) == num_values - -# for output, expected in zip(transformed, tensors): -# assert output.shape == expected.shape -# assert output.equal(expected) - -# transformed = list(transformed) - -# original: torch.Tensor = tensors[0] -# assert transformed[0].equal(original) - -# # verify a copy was made -# transformed[0] = 2 * transformed[0] -# assert transformed[0].equal(2 * original) - - -# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") -# def test_execute_model(persist_torch_model: pathlib.Path) -> None: -# """Verify that a model executes corrrectly via the worker""" - -# # put model bytes into memory -# model_name = "test-key" -# feature_store = 
MemoryFeatureStore() -# feature_store[model_name] = persist_torch_model.read_bytes() - -# worker = IntegratedTorchWorker -# request = InferenceRequest(model_key=model_name) -# fetch_result = FetchModelResult(persist_torch_model.read_bytes()) -# load_result = worker.load_model(request, fetch_result) - -# value = torch.randn(2) -# transform_result = TransformInputResult([value]) - -# execute_result = worker.execute(request, load_result, transform_result) - -# assert execute_result.predictions is not None - - -# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") -# def test_execute_missing_model(persist_torch_model: pathlib.Path) -> None: -# """Verify that a executing a model with an invalid key fails cleanly""" - -# # use key that references an un-set model value -# model_name = "test-key" -# feature_store = MemoryFeatureStore() -# feature_store[model_name] = persist_torch_model.read_bytes() - -# worker = IntegratedTorchWorker -# request = InferenceRequest(input_keys=[model_name]) - -# load_result = LoadModelResult(None) -# transform_result = TransformInputResult( -# [torch.randn(2), torch.randn(2), torch.randn(2)] -# ) - -# with pytest.raises(sse.SmartSimError) as ex: -# worker.execute(request, load_result, transform_result) - -# assert "Model must be loaded" in ex.value.args[0] - - -# @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") -# def test_transform_output() -> None: -# """Verify that the default output transform operation is a no-op copy""" -# rows, cols = 1, 4 -# num_values = 7 -# inputs = [torch.randn((rows, cols)) for _ in range(num_values)] -# exp_outputs = [torch.Tensor(tensor) for tensor in inputs] - -# worker = SampleTorchWorker -# request = InferenceRequest() -# exec_result = ExecuteResult(inputs) - -# result = worker.transform_output(request, exec_result) - -# assert len(result.outputs) == num_values - -# for output, expected in zip(result.outputs, exp_outputs): -# assert output.shape == expected.shape -# assert output.equal(expected) - -# transformed = list(result.outputs) - -# # verify a copy was made -# original: torch.Tensor = inputs[0] -# transformed[0] = 2 * transformed[0] - -# assert transformed[0].equal(2 * original) diff --git a/tests/mli/test_service.py b/tests/mli/test_service.py index 617738f949..3635f6ff78 100644 --- a/tests/mli/test_service.py +++ b/tests/mli/test_service.py @@ -27,6 +27,7 @@ import datetime import multiprocessing as mp import pathlib +import time import typing as t from asyncore import loop @@ -47,23 +48,37 @@ class SimpleService(Service): def __init__( self, log: t.List[str], - quit_after: int = 0, + quit_after: int = -1, as_service: bool = False, - cooldown: int = 0, - loop_delay: int = 0, + cooldown: float = 0, + loop_delay: float = 0, + hc_freq: float = -1, + run_for: float = 0, ) -> None: - super().__init__(as_service, cooldown, loop_delay) + super().__init__(as_service, cooldown, loop_delay, hc_freq) self._log = log self._quit_after = quit_after - self.num_iterations = 0 self.num_starts = 0 self.num_shutdowns = 0 + self.num_health_checks = 0 self.num_cooldowns = 0 - self.num_can_shutdown = 0 self.num_delays = 0 + self.num_iterations = 0 + self.num_can_shutdown = 0 + self.run_for = run_for + self.start_time = time.time() - def _on_iteration(self) -> None: - self.num_iterations += 1 + @property + def runtime(self) -> float: + return time.time() - self.start_time + + def _can_shutdown(self) -> bool: + self.num_can_shutdown += 1 + + if self._quit_after > -1 and 
self.num_iterations >= self._quit_after:
+            return True
+        if self.run_for > 0:
+            return self.runtime >= self.run_for
+        return False
 
     def _on_start(self) -> None:
         self.num_starts += 1
@@ -71,16 +86,17 @@ def _on_start(self) -> None:
     def _on_shutdown(self) -> None:
         self.num_shutdowns += 1
 
+    def _on_health_check(self) -> None:
+        self.num_health_checks += 1
+
     def _on_cooldown_elapsed(self) -> None:
         self.num_cooldowns += 1
 
     def _on_delay(self) -> None:
         self.num_delays += 1
 
-    def _can_shutdown(self) -> bool:
-        self.num_can_shutdown += 1
-        if self._quit_after == 0:
-            return True
+    def _on_iteration(self) -> None:
+        self.num_iterations += 1
 
         return self.num_iterations >= self._quit_after
@@ -134,6 +150,7 @@ def test_service_run_until_can_shutdown(num_iterations: int) -> None:
         # no matter what, it should always execute the _on_iteration method
         assert service.num_iterations == 1
     else:
+        # the shutdown check follows _on_iteration; there will be one last call
        assert service.num_iterations == num_iterations
 
     assert service.num_starts == 1
@@ -203,3 +220,71 @@ def test_service_delay(delay: int, num_iterations: int) -> None:
     assert duration_in_seconds <= expected_duration
     assert service.num_cooldowns == 0
     assert service.num_shutdowns == 1
+
+
+@pytest.mark.parametrize(
+    "health_check_freq, run_for",
+    [
+        pytest.param(1, 5.5, id="1s freq, 5x"),
+        pytest.param(5, 10.5, id="5s freq, 2x"),
+        pytest.param(0.1, 5.1, id="0.1s freq, 50x"),
+    ],
+)
+def test_service_health_check_freq(health_check_freq: float, run_for: float) -> None:
+    """Verify that the health check frequency is honored.
+
+    :param health_check_freq: The desired frequency of the health check
+    :param run_for: A fixed duration to allow the service to run
+    """
+    activity_log: t.List[str] = []
+
+    service = SimpleService(
+        activity_log,
+        quit_after=-1,
+        as_service=True,
+        cooldown=0,
+        hc_freq=health_check_freq,
+        run_for=run_for,
+    )
+
+    service.execute()
+
+    # the expected number of health checks over the fixed runtime
+    expected_hc_count = run_for // health_check_freq
+
+    # allow some wiggle room for frequency comparison
+    assert expected_hc_count - 1 <= service.num_health_checks <= expected_hc_count + 1
+
+    assert service.num_cooldowns == 0
+    assert service.num_shutdowns == 1
+
+
+def test_service_health_check_freq_unbound() -> None:
+    """Verify that a health check frequency of zero is treated as
+    "always on" and is called each loop iteration.
+    """
+    health_check_freq: float = 0.0
+    run_for: float = 5
+
+    activity_log: t.List[str] = []
+
+    service = SimpleService(
+        activity_log,
+        quit_after=-1,
+        as_service=True,
+        cooldown=0,
+        hc_freq=health_check_freq,
+        run_for=run_for,
+    )
+
+    service.execute()
+
+    # an unbound frequency results in a health check on every iteration
+    assert service.num_health_checks == service.num_iterations
+    assert service.num_cooldowns == 0
+    assert service.num_shutdowns == 1
diff --git a/tests/test_dragon_comm_utils.py b/tests/test_dragon_comm_utils.py
new file mode 100644
index 0000000000..a6f9c206a4
--- /dev/null
+++ b/tests/test_dragon_comm_utils.py
@@ -0,0 +1,257 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1.
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import base64 +import pathlib +import uuid + +import pytest + +from smartsim.error.errors import SmartSimError + +dragon = pytest.importorskip("dragon") + +# isort: off +import dragon.channels as dch +import dragon.infrastructure.parameters as dp +import dragon.managed_memory as dm +import dragon.fli as fli + +# isort: on + +from smartsim._core.mli.comm.channel import dragon_util +from smartsim.log import get_logger + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon +logger = get_logger(__name__) + + +@pytest.fixture(scope="function") +def the_pool() -> dm.MemoryPool: + """Creates a memory pool.""" + raw_pool_descriptor = dp.this_process.default_pd + descriptor_ = base64.b64decode(raw_pool_descriptor) + + pool = dm.MemoryPool.attach(descriptor_) + return pool + + +@pytest.fixture(scope="function") +def the_channel() -> dch.Channel: + """Creates a Channel attached to the local memory pool.""" + channel = dch.Channel.make_process_local() + return channel + + +@pytest.fixture(scope="function") +def the_fli(the_channel) -> fli.FLInterface: + """Creates an FLI attached to the local memory pool.""" + fli_ = fli.FLInterface(main_ch=the_channel, manager_ch=None) + return fli_ + + +def test_descriptor_to_channel_empty() -> None: + """Verify that `descriptor_to_channel` raises an exception when + provided with an empty descriptor.""" + descriptor = "" + + with pytest.raises(ValueError) as ex: + dragon_util.descriptor_to_channel(descriptor) + + assert "empty" in ex.value.args[0] + + +@pytest.mark.parametrize( + "descriptor", + ["a", "ab", "abc", "x1", pathlib.Path(".").absolute().as_posix()], +) +def test_descriptor_to_channel_b64fail(descriptor: str) -> None: + """Verify that `descriptor_to_channel` raises an exception when + provided with an incorrectly encoded descriptor. + + :param descriptor: A descriptor that is not properly base64 encoded + """ + + with pytest.raises(ValueError) as ex: + dragon_util.descriptor_to_channel(descriptor) + + assert "base64" in ex.value.args[0] + + +@pytest.mark.parametrize( + "descriptor", + [str(uuid.uuid4())], +) +def test_descriptor_to_channel_channel_fail(descriptor: str) -> None: + """Verify that `descriptor_to_channel` raises an exception when a correctly + formatted descriptor that does not describe a real channel is passed. 
+
+    :param descriptor: A correctly formatted descriptor that does not
+    reference a real channel
+    """
+
+    with pytest.raises(SmartSimError) as ex:
+        dragon_util.descriptor_to_channel(descriptor)
+
+    # ensure we're receiving the right exception
+    assert "address" in ex.value.args[0]
+    assert "channel" in ex.value.args[0]
+
+
+def test_descriptor_to_channel_channel_not_available(the_channel: dch.Channel) -> None:
+    """Verify that `descriptor_to_channel` raises an exception when a channel
+    is no longer available.
+
+    :param the_channel: A dragon channel
+    """
+
+    # get a good descriptor & wipe out the channel so it can't be attached
+    descriptor = dragon_util.channel_to_descriptor(the_channel)
+    the_channel.destroy()
+
+    with pytest.raises(SmartSimError) as ex:
+        dragon_util.descriptor_to_channel(descriptor)
+
+    assert "address" in ex.value.args[0]
+
+
+def test_descriptor_to_channel_happy_path(the_channel: dch.Channel) -> None:
+    """Verify that `descriptor_to_channel` works as expected when provided
+    a valid descriptor.
+
+    :param the_channel: A dragon channel
+    """
+
+    # get a good descriptor
+    descriptor = dragon_util.channel_to_descriptor(the_channel)
+
+    reattached = dragon_util.descriptor_to_channel(descriptor)
+    assert reattached
+
+    # and just make sure creation of the descriptor is transitive
+    assert dragon_util.channel_to_descriptor(reattached) == descriptor
+
+
+def test_descriptor_to_fli_empty() -> None:
+    """Verify that `descriptor_to_fli` raises an exception when
+    provided with an empty descriptor."""
+    descriptor = ""
+
+    with pytest.raises(ValueError) as ex:
+        dragon_util.descriptor_to_fli(descriptor)
+
+    assert "empty" in ex.value.args[0]
+
+
+@pytest.mark.parametrize(
+    "descriptor",
+    ["a", "ab", "abc", "x1", pathlib.Path(".").absolute().as_posix()],
+)
+def test_descriptor_to_fli_b64fail(descriptor: str) -> None:
+    """Verify that `descriptor_to_fli` raises an exception when
+    provided with an incorrectly encoded descriptor.
+
+    :param descriptor: A descriptor that is not properly base64 encoded
+    """
+
+    with pytest.raises(ValueError) as ex:
+        dragon_util.descriptor_to_fli(descriptor)
+
+    assert "base64" in ex.value.args[0]
+
+
+@pytest.mark.parametrize(
+    "descriptor",
+    [str(uuid.uuid4())],
+)
+def test_descriptor_to_fli_fli_fail(descriptor: str) -> None:
+    """Verify that `descriptor_to_fli` raises an exception when a correctly
+    formatted descriptor that does not describe a real FLI is passed.
+
+    :param descriptor: A correctly formatted descriptor that does not
+    reference a real FLI
+    """
+
+    with pytest.raises(SmartSimError) as ex:
+        dragon_util.descriptor_to_fli(descriptor)
+
+    # ensure we're receiving the right exception
+    assert "address" in ex.value.args[0]
+    assert "fli" in ex.value.args[0].lower()
+
+
+def test_descriptor_to_fli_fli_not_available(
+    the_fli: fli.FLInterface, the_channel: dch.Channel
+) -> None:
+    """Verify that `descriptor_to_fli` raises an exception when the underlying
+    channel is no longer available.
+
+    :param the_fli: A dragon FLInterface
+    :param the_channel: A dragon channel
+    """
+
+    # get a good descriptor & wipe out the FLI so it can't be attached
+    descriptor = dragon_util.channel_to_descriptor(the_fli)
+    the_fli.destroy()
+    the_channel.destroy()
+
+    with pytest.raises(SmartSimError) as ex:
+        dragon_util.descriptor_to_fli(descriptor)
+
+    # ensure we're receiving the right exception
+    assert "address" in ex.value.args[0]
+
+
+def test_descriptor_to_fli_happy_path(the_fli: fli.FLInterface) -> None:
+    """Verify that `descriptor_to_fli` works as expected when provided
+    a valid descriptor.
+
+    :param the_fli: A dragon FLInterface
+    """
+
+    # get a good descriptor
+    descriptor = dragon_util.channel_to_descriptor(the_fli)
+
+    reattached = dragon_util.descriptor_to_fli(descriptor)
+    assert reattached
+
+    # and just make sure creation of the descriptor is transitive
+    assert dragon_util.channel_to_descriptor(reattached) == descriptor
+
+
+def test_pool_to_descriptor_empty() -> None:
+    """Verify that `pool_to_descriptor` raises an exception when
+    provided with a null pool."""
+
+    with pytest.raises(ValueError):
+        dragon_util.pool_to_descriptor(None)
+
+
+def test_pool_to_happy_path(the_pool: dm.MemoryPool) -> None:
+    """Verify that `pool_to_descriptor` creates a descriptor
+    when supplied with a valid memory pool."""
+
+    descriptor = dragon_util.pool_to_descriptor(the_pool)
+    assert descriptor
diff --git a/tests/test_dragon_installer.py b/tests/test_dragon_installer.py
index 7b678239a0..b1d8cd34c9 100644
--- a/tests/test_dragon_installer.py
+++ b/tests/test_dragon_installer.py
@@ -511,10 +511,18 @@ def test_create_dotenv_existing_dotenv(monkeypatch: pytest.MonkeyPatch, test_dir
 
     # ensure file was overwritten and env vars are not duplicated
     dotenv_content = exp_env_path.read_text(encoding="utf-8")
-    split_content = dotenv_content.split(var_name)
-
-    # split to confirm env var only appars once
-    assert len(split_content) == 2
+    lines = [
+        line for line in dotenv_content.split("\n") if line and "#" not in line
+    ]
+    for line in lines:
+        if line.startswith(var_name):
+            # make sure the var isn't defined recursively
+            # DRAGON_BASE_DIR=$DRAGON_BASE_DIR
+            assert var_name not in line[len(var_name) + 1 :]
+        else:
+            # make sure any values reference the original base dir var
+            if var_name in line:
+                assert f"${var_name}" in line
 
 
 def test_create_dotenv_format(monkeypatch: pytest.MonkeyPatch, test_dir: str):
@@ -532,7 +540,7 @@ def test_create_dotenv_format(monkeypatch: pytest.MonkeyPatch, test_dir: str):
     content = exp_env_path.read_text(encoding="utf-8")
 
     # ensure we have values written, but ignore empty lines
-    lines = [line for line in content.split("\n") if line]
+    lines = [line for line in content.split("\n") if line and "#" not in line]
     assert lines
 
     # ensure each line is formatted as key=value
diff --git a/tests/test_dragon_launcher.py b/tests/test_dragon_launcher.py
index 37c46a573b..ea45a2cb71 100644
--- a/tests/test_dragon_launcher.py
+++ b/tests/test_dragon_launcher.py
@@ -510,7 +510,26 @@ def test_load_env_env_file_created(monkeypatch: pytest.MonkeyPatch, test_dir: st
     assert loaded_env
 
     # confirm .env was parsed as expected by inspecting a key
+    assert "DRAGON_BASE_DIR" in loaded_env
+    base_dir = loaded_env["DRAGON_BASE_DIR"]
+
+    assert "DRAGON_ROOT_DIR" in loaded_env
+    assert loaded_env["DRAGON_ROOT_DIR"] == base_dir
+
+    assert "DRAGON_INCLUDE_DIR" in loaded_env
+    assert loaded_env["DRAGON_INCLUDE_DIR"] == f"{base_dir}/include"
+
+    assert "DRAGON_LIB_DIR" in loaded_env
+    assert
loaded_env["DRAGON_LIB_DIR"] == f"{base_dir}/lib" + + assert "DRAGON_VERSION" in loaded_env + assert loaded_env["DRAGON_VERSION"] == DEFAULT_DRAGON_VERSION + + assert "PATH" in loaded_env + assert loaded_env["PATH"] == f"{base_dir}/bin" + + assert "LD_LIBRARY_PATH" in loaded_env + assert loaded_env["LD_LIBRARY_PATH"] == f"{base_dir}/lib" def test_load_env_cached_env(monkeypatch: pytest.MonkeyPatch, test_dir: str): diff --git a/tests/test_message_handler/test_build_model_key.py b/tests/test_message_handler/test_build_model_key.py index c09c787fcf..6c9b3dc951 100644 --- a/tests/test_message_handler/test_build_model_key.py +++ b/tests/test_message_handler/test_build_model_key.py @@ -38,7 +38,7 @@ def test_build_model_key_successful(): fsd = "mock-feature-store-descriptor" model_key = handler.build_model_key("tensor_key", fsd) assert model_key.key == "tensor_key" - assert model_key.featureStoreDescriptor == fsd + assert model_key.descriptor == fsd def test_build_model_key_unsuccessful(): diff --git a/tests/test_message_handler/test_request.py b/tests/test_message_handler/test_request.py index 7ede41b50d..a60818f7dd 100644 --- a/tests/test_message_handler/test_request.py +++ b/tests/test_message_handler/test_request.py @@ -101,7 +101,7 @@ "reply_channel, model, input, output, output_descriptors, custom_attributes", [ pytest.param( - b"reply channel", + "reply channel", model_key, [input_key1, input_key2], [output_key1, output_key2], @@ -109,7 +109,7 @@ torch_attributes, ), pytest.param( - b"another reply channel", + "another reply channel", model, [input_key1], [output_key2], @@ -117,7 +117,7 @@ tf_attributes, ), pytest.param( - b"another reply channel", + "another reply channel", model, [input_key1], [output_key2], @@ -125,7 +125,7 @@ torch_attributes, ), pytest.param( - b"reply channel", + "reply channel", model_key, [input_key1], [output_key1], @@ -185,7 +185,7 @@ def test_build_request_indirect_successful( id="bad channel", ), pytest.param( - b"reply channel", + "reply channel", "bad model", [input_key1], [output_key2], @@ -194,7 +194,7 @@ def test_build_request_indirect_successful( id="bad model", ), pytest.param( - b"reply channel", + "reply channel", model_key, ["input_key1", "input_key2"], [output_key1, output_key2], @@ -212,7 +212,7 @@ def test_build_request_indirect_successful( id="bad input schema type", ), pytest.param( - b"reply channel", + "reply channel", model_key, [input_key1], ["output_key1", "output_key2"], @@ -230,7 +230,7 @@ def test_build_request_indirect_successful( id="bad output schema type", ), pytest.param( - b"reply channel", + "reply channel", model_key, [input_key1], [output_key1, output_key2], @@ -239,7 +239,7 @@ def test_build_request_indirect_successful( id="bad custom attributes", ), pytest.param( - b"reply channel", + "reply channel", model_key, [input_key1], [output_key1, output_key2], @@ -248,7 +248,7 @@ def test_build_request_indirect_successful( id="bad custom attributes schema type", ), pytest.param( - b"reply channel", + "reply channel", model_key, [input_key1], [output_key1, output_key2], @@ -276,7 +276,7 @@ def test_build_request_indirect_unsuccessful( "reply_channel, model, input, output, output_descriptors, custom_attributes", [ pytest.param( - b"reply channel", + "reply channel", model_key, [tensor_1, tensor_2], [], @@ -284,7 +284,7 @@ def test_build_request_indirect_unsuccessful( torch_attributes, ), pytest.param( - b"another reply channel", + "another reply channel", model, [tensor_1], [], @@ -292,7 +292,7 @@ def 
test_build_request_indirect_unsuccessful( tf_attributes, ), pytest.param( - b"another reply channel", + "another reply channel", model, [tensor_2], [], @@ -300,7 +300,7 @@ def test_build_request_indirect_unsuccessful( tf_attributes, ), pytest.param( - b"another reply channel", + "another reply channel", model, [tensor_1], [],