Merged
61 commits
345c739
started tests
AlyssaCote Jul 1, 2024
68ed090
more tests
AlyssaCote Jul 1, 2024
548caa4
starting to handle errors
AlyssaCote Jul 1, 2024
57faf70
nested try excepts. not good.
AlyssaCote Jul 2, 2024
da3e0a7
building up replies
AlyssaCote Jul 2, 2024
d574fa1
replaced my nestedness with if checks, added helper function and tests
AlyssaCote Jul 2, 2024
4fa1184
doc string and newlines
AlyssaCote Jul 2, 2024
011d113
Merge branch 'mli-feature' into new_failure_handling
AlyssaCote Jul 2, 2024
5fa1a37
changelog
AlyssaCote Jul 2, 2024
8709710
test groups
AlyssaCote Jul 3, 2024
96bf89d
style
AlyssaCote Jul 3, 2024
9d8d51b
build reply fails test
AlyssaCote Jul 3, 2024
9cff781
style
AlyssaCote Jul 3, 2024
d6e3a64
Merge branch 'mli-feature' into new_failure_handling
AlyssaCote Jul 3, 2024
1e93392
add running enum, remove, failed parameter
AlyssaCote Jul 9, 2024
7f04d5b
Merge branch 'mli-feature' into new_failure_handling
AlyssaCote Jul 10, 2024
ada632d
import or skip dragon
AlyssaCote Jul 10, 2024
6bb16f5
fix workermanager init in tests
AlyssaCote Jul 10, 2024
6fa2144
style
AlyssaCote Jul 10, 2024
9ed4c6b
fixing tests
AlyssaCote Jul 10, 2024
ee7b9d1
style
AlyssaCote Jul 10, 2024
7e41d9d
fix _on_iteration
AlyssaCote Jul 10, 2024
d2b7977
style
AlyssaCote Jul 10, 2024
8b7924b
return failure in reply channel asap
AlyssaCote Jul 10, 2024
ac2784a
fix test
AlyssaCote Jul 10, 2024
2ab2d50
pr comments
AlyssaCote Jul 10, 2024
a399fd5
more pr comments
AlyssaCote Jul 10, 2024
17baec6
typing
AlyssaCote Jul 10, 2024
f71fd4b
more typing
AlyssaCote Jul 10, 2024
9aab212
spelling
AlyssaCote Jul 10, 2024
7998843
mock test
AlyssaCote Jul 10, 2024
a8a6e29
fix test
AlyssaCote Jul 10, 2024
f7399ef
fix test again
AlyssaCote Jul 10, 2024
a8dae65
try again
AlyssaCote Jul 10, 2024
6e49886
oops add feature store
AlyssaCote Jul 10, 2024
2b1f5d4
add mock return values for previous functions
AlyssaCote Jul 10, 2024
b1e385b
style
AlyssaCote Jul 10, 2024
e64c0e6
positional args
AlyssaCote Jul 11, 2024
148da63
mock tests
AlyssaCote Jul 11, 2024
f30bc37
moving tests
AlyssaCote Jul 11, 2024
437837d
remove mli tests
AlyssaCote Jul 11, 2024
901fcef
style
AlyssaCote Jul 11, 2024
4bdcc59
parametrize mock tests
AlyssaCote Jul 11, 2024
13ecd13
ignore tests
AlyssaCote Jul 12, 2024
0af26a1
merge
AlyssaCote Jul 16, 2024
7451b2a
Revert "merge"
AlyssaCote Jul 16, 2024
eb352c8
merge
AlyssaCote Jul 16, 2024
ac312c2
merge fix
AlyssaCote Jul 16, 2024
365e7dc
StatusEnum -> Status
AlyssaCote Jul 16, 2024
1952561
more fixes
AlyssaCote Jul 16, 2024
48035c7
another fix
AlyssaCote Jul 16, 2024
e07a2c8
pr comments
AlyssaCote Jul 16, 2024
042e56d
remove extra punctuation
AlyssaCote Jul 16, 2024
ec32235
style
AlyssaCote Jul 16, 2024
7e53d08
more punctuation
AlyssaCote Jul 16, 2024
d986164
send_failure and tests
AlyssaCote Jul 17, 2024
d9e22fb
reintroduce questionable while loop
AlyssaCote Jul 17, 2024
3caf24b
remove send_failure
AlyssaCote Jul 17, 2024
be1ddea
adjust and add timing
AlyssaCote Jul 17, 2024
b3927c7
remove while loop
AlyssaCote Jul 17, 2024
7f104b0
remove comments, remove type ignore from workermanager
AlyssaCote Jul 17, 2024
6 changes: 3 additions & 3 deletions Makefile
@@ -169,17 +169,17 @@ test:
# help: test-verbose - Run all tests verbosely
.PHONY: test-verbose
test-verbose:
@python -m pytest -vv --ignore=tests/full_wlm/
@python -m pytest -vv --ignore=tests/full_wlm/ --ignore=tests/dragon

# help: test-debug - Run all tests with debug output
.PHONY: test-debug
test-debug:
@SMARTSIM_LOG_LEVEL=developer python -m pytest -s -o log_cli=true -vv --ignore=tests/full_wlm/
@SMARTSIM_LOG_LEVEL=developer python -m pytest -s -o log_cli=true -vv --ignore=tests/full_wlm/ --ignore=tests/dragon

# help: test-cov - Run all tests with coverage
.PHONY: test-cov
test-cov:
@python -m pytest -vv --cov=./smartsim --cov-config=${COV_FILE} --ignore=tests/full_wlm/
@python -m pytest -vv --cov=./smartsim --cov-config=${COV_FILE} --ignore=tests/full_wlm/ --ignore=tests/dragon


# help: test-full - Run all WLM tests with Python coverage (full test suite)
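The three generic pytest targets above now also skip `tests/dragon`, presumably because those tests require the Dragon runtime that a plain environment lacks (the `import or skip dragon` commit points the same way). For reference, a hedged sketch of the equivalent invocation driven from Python rather than make — the script itself is hypothetical:

```python
# Hypothetical driver: runs the same selection as the updated
# `test-verbose` target, via pytest's documented programmatic entry point.
import pytest

exit_code = pytest.main(["-vv", "--ignore=tests/full_wlm/", "--ignore=tests/dragon"])
raise SystemExit(exit_code)
```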
1 change: 1 addition & 0 deletions doc/changelog.md
@@ -14,6 +14,7 @@ Jump to:
Description

- Add TorchWorker first implementation and mock inference app example
- Add error handling in Worker Manager pipeline
- Add EnvironmentConfigLoader for ML Worker Manager
- Add Model schema with model metadata included
- Removed device from schemas, MessageHandler and tests
192 changes: 135 additions & 57 deletions smartsim/_core/mli/infrastructure/control/workermanager.py
@@ -51,13 +51,13 @@
MachineLearningWorkerBase,
)
from ...message_handler import MessageHandler
from ...mli_schemas.response.response_capnp import Response
from ...mli_schemas.response.response_capnp import Response, ResponseBuilder

if t.TYPE_CHECKING:
from dragon.fli import FLInterface

from smartsim._core.mli.mli_schemas.model.model_capnp import Model
from smartsim._core.mli.mli_schemas.response.response_capnp import StatusEnum
from smartsim._core.mli.mli_schemas.response.response_capnp import Status

logger = get_logger(__name__)

@@ -98,6 +98,7 @@ def deserialize_message(
input_bytes: t.Optional[t.List[bytes]] = (
None # these will really be tensors already
)
output_keys: t.Optional[t.List[str]] = None

input_meta: t.List[t.Any] = []

@@ -107,22 +108,26 @@
input_bytes = [data.blob for data in request.input.data]
input_meta = [data.tensorDescriptor for data in request.input.data]

if request.output:
output_keys = [tensor_key.key for tensor_key in request.output]

inference_request = InferenceRequest(
model_key=model_key,
callback=comm_channel,
raw_inputs=input_bytes,
input_meta=input_meta,
input_keys=input_keys,
input_meta=input_meta,
output_keys=output_keys,
raw_model=model_bytes,
batch_size=0,
)
return inference_request
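The new `output_keys` handling above mirrors the existing `input_keys` extraction: keys are read off the request only when `request.output` is populated, and the field stays `None` otherwise. A minimal standalone sketch of that pattern — the `Request` and `TensorKey` classes here are illustrative stand-ins, not the real capnp messages:

```python
import typing as t
from dataclasses import dataclass, field


@dataclass
class TensorKey:
    """Illustrative stand-in for the capnp TensorKey message."""
    key: str


@dataclass
class Request:
    """Illustrative stand-in for the capnp Request message."""
    output: t.List[TensorKey] = field(default_factory=list)


def extract_output_keys(request: Request) -> t.Optional[t.List[str]]:
    """Collect requested output keys, or None when the request names none."""
    output_keys: t.Optional[t.List[str]] = None
    if request.output:
        output_keys = [tensor_key.key for tensor_key in request.output]
    return output_keys


assert extract_output_keys(Request()) is None
assert extract_output_keys(Request(output=[TensorKey("out_0")])) == ["out_0"]
```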


def build_failure_reply(status: "StatusEnum", message: str) -> Response:
def build_failure_reply(status: "Status", message: str) -> ResponseBuilder:
return MessageHandler.build_response(
status=status, # todo: need to indicate correct status
message=message, # todo: decide what these will be
status=status,
message=message,
result=[],
custom_attributes=None,
)
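With the todo comments resolved, `build_failure_reply` pins down a single convention: a failure is an ordinary response whose `result` list is empty. A rough stand-in showing that shape — the dataclass below is illustrative only, since the real return type is a capnp `ResponseBuilder`, and the `Status` values are the ones that appear elsewhere in this diff:

```python
import typing as t
from dataclasses import dataclass

# Status values as used in this diff: "running", "complete", "fail".
Status = t.Literal["running", "complete", "fail"]


@dataclass
class FakeResponse:
    """Illustrative stand-in for the capnp ResponseBuilder."""
    status: Status
    message: str
    result: t.List[t.Any]
    custom_attributes: t.Optional[t.Any]


def build_failure_reply(status: Status, message: str) -> FakeResponse:
    # A failure carries no tensors or keys: only a status and a message.
    return FakeResponse(status=status, message=message, result=[], custom_attributes=None)


print(build_failure_reply("fail", "Failed while fetching the model."))
```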
@@ -154,17 +159,39 @@ def prepare_outputs(reply: InferenceReply) -> t.List[t.Any]:
return prepared_outputs


def build_reply(reply: InferenceReply) -> Response:
def build_reply(reply: InferenceReply) -> ResponseBuilder:
results = prepare_outputs(reply)

return MessageHandler.build_response(
status="complete",
message="success",
status=reply.status_enum,
message=reply.message,
result=results,
custom_attributes=None,
)


def exception_handler(
exc: Exception, reply_channel: t.Optional[CommChannelBase], failure_message: str
) -> None:
"""
Logs exceptions and sends a failure response.

:param exc: The exception to be logged
:param reply_channel: The channel used to send replies
:param failure_message: Failure message to log and send back
"""
logger.exception(
f"{failure_message}\n"
f"Exception type: {type(exc).__name__}\n"
f"Exception message: {str(exc)}"
)
serialized_resp = MessageHandler.serialize_response(
build_failure_reply("fail", failure_message)
)
if reply_channel:
reply_channel.send(serialized_resp)
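`exception_handler` is the piece that flattens the earlier nested try/excepts: every stage of `_on_iteration` can now catch, delegate here, and return. Note the `if reply_channel:` guard — a request may arrive without a callback channel, in which case the failure is only logged. A self-contained sketch of the same behavior, with stand-ins for the comm channel and for the serialization helpers:

```python
import logging
import typing as t

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class FakeCommChannel:
    """Illustrative stand-in for CommChannelBase."""

    def send(self, value: bytes) -> None:
        print(f"sent {len(value)} bytes back to the client")


def exception_handler(
    exc: Exception,
    reply_channel: t.Optional[FakeCommChannel],
    failure_message: str,
) -> None:
    """Log the exception, then send a serialized failure reply if possible."""
    logger.exception(
        f"{failure_message}\n"
        f"Exception type: {type(exc).__name__}\n"
        f"Exception message: {str(exc)}"
    )
    # Stand-in for MessageHandler.serialize_response(build_failure_reply(...)):
    serialized_resp = f"fail: {failure_message}".encode()
    if reply_channel:  # requests without a callback are only logged
        reply_channel.send(serialized_resp)


try:
    raise KeyError("model-key-123")
except Exception as e:
    exception_handler(e, FakeCommChannel(), "Failed while fetching the model.")
```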


class WorkerManager(Service):
"""An implementation of a service managing distribution of tasks to
machine learning workers"""
@@ -258,96 +285,147 @@ def _on_iteration(self) -> None:
timings.append(time.perf_counter() - interm) # timing
interm = time.perf_counter() # timing

reply = InferenceReply()

if not request.raw_model:
if request.model_key is None:
# A valid request should never get here.
raise ValueError("Could not read model key")
exception_handler(
ValueError("Could not find model key or model"),
request.callback,
"Could not find model key or model.",
)
return
if request.model_key in self._cached_models:
timings.append(time.perf_counter() - interm) # timing
interm = time.perf_counter() # timing
model_result = LoadModelResult(self._cached_models[request.model_key])

else:
fetch_model_result = None
while True:
try:
interm = time.perf_counter() # timing
fetch_model_result = self._worker.fetch_model(
request, self._feature_store
)
except KeyError:
time.sleep(0.1)
else:
break

if fetch_model_result is None:
raise SmartSimError("Could not retrieve model from feature store")
timings.append(time.perf_counter() - interm) # timing
interm = time.perf_counter() # timing
try:
fetch_model_result = self._worker.fetch_model(
request, self._feature_store
)
except Exception as e:
exception_handler(
e, request.callback, "Failed while fetching the model."
)
return

timings.append(time.perf_counter() - interm) # timing
interm = time.perf_counter() # timing
try:
model_result = self._worker.load_model(
request,
fetch_result=fetch_model_result,
device=self._device,
)
self._cached_models[request.model_key] = model_result.model
except Exception as e:
exception_handler(
e, request.callback, "Failed while loading the model."
)
return

else:
timings.append(time.perf_counter() - interm) # timing
interm = time.perf_counter() # timing
try:
fetch_model_result = self._worker.fetch_model(
request, self._feature_store
)
except Exception as e:
exception_handler(
e, request.callback, "Failed while fetching the model."
)
return

timings.append(time.perf_counter() - interm) # timing
interm = time.perf_counter() # timing
try:
model_result = self._worker.load_model(
request, fetch_model_result, self._device
request, fetch_result=fetch_model_result, device=self._device
)
self._cached_models[request.model_key] = model_result.model
else:
fetch_model_result = self._worker.fetch_model(request, None)
model_result = self._worker.load_model(
request, fetch_result=fetch_model_result, device=self._device
)
except Exception as e:
exception_handler(
e, request.callback, "Failed while loading the model."
)
return

timings.append(time.perf_counter() - interm) # timing
interm = time.perf_counter() # timing
fetch_input_result = self._worker.fetch_inputs(request, self._feature_store)
try:
fetch_input_result = self._worker.fetch_inputs(request, self._feature_store)
except Exception as e:
exception_handler(e, request.callback, "Failed while fetching the inputs.")
return

timings.append(time.perf_counter() - interm) # timing
interm = time.perf_counter() # timing
transformed_input = self._worker.transform_input(
request, fetch_input_result, self._device
)
try:
transformed_input = self._worker.transform_input(
request, fetch_input_result, self._device
)
except Exception as e:
exception_handler(
e, request.callback, "Failed while transforming the input."
)
return

timings.append(time.perf_counter() - interm) # timing
interm = time.perf_counter() # timing

reply = InferenceReply()

try:
execute_result = self._worker.execute(
request, model_result, transformed_input
)
except Exception as e:
exception_handler(e, request.callback, "Failed while executing.")
return

timings.append(time.perf_counter() - interm) # timing
interm = time.perf_counter() # timing
timings.append(time.perf_counter() - interm) # timing
interm = time.perf_counter() # timing
try:
transformed_output = self._worker.transform_output(
request, execute_result, self._device
)
except Exception as e:
exception_handler(
e, request.callback, "Failed while transforming the output."
)
return

timings.append(time.perf_counter() - interm) # timing
interm = time.perf_counter() # timing
if request.output_keys:
timings.append(time.perf_counter() - interm) # timing
interm = time.perf_counter() # timing
if request.output_keys:
try:
reply.output_keys = self._worker.place_output(
request, transformed_output, self._feature_store
request,
transformed_output,
self._feature_store,
)
else:
reply.outputs = transformed_output.outputs
except Exception:
logger.exception("Error executing worker")
reply.failed = True
except Exception as e:
exception_handler(
e, request.callback, "Failed while placing the output."
)
return
else:
reply.outputs = transformed_output.outputs

timings.append(time.perf_counter() - interm) # timing
interm = time.perf_counter() # timing

if reply.failed:
response = build_failure_reply("fail", "failure-occurred")
if reply.outputs is None or not reply.outputs:
response = build_failure_reply("fail", "Outputs not found.")
else:
if reply.outputs is None or not reply.outputs:
response = build_failure_reply("fail", "no-results")

reply.status_enum = "complete"
reply.message = "Success"
response = build_reply(reply)

timings.append(time.perf_counter() - interm) # timing
interm = time.perf_counter() # timing

# serialized = self._worker.serialize_reply(request, transformed_output)
serialized_resp = MessageHandler.serialize_response(response) # type: ignore
serialized_resp = MessageHandler.serialize_response(response)

timings.append(time.perf_counter() - interm) # timing
interm = time.perf_counter() # timing
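Taken together, the `_on_iteration` rewrite replaces the old retry loop and the single catch-all `except` with a flat, fail-fast pipeline: each stage (fetch model, load model, fetch inputs, transform input, execute, transform output, place output) runs in its own try/except, reports through `exception_handler`, and returns immediately on error. A condensed sketch of that control flow — the stage names and the `run_pipeline` helper are illustrative, not the real worker API:

```python
import typing as t

Stage = t.Tuple[str, t.Callable[[t.Any], t.Any]]


def run_pipeline(
    stages: t.List[Stage],
    payload: t.Any,
    on_failure: t.Callable[[Exception, str], None],
) -> t.Optional[t.Any]:
    """Run stages in order; on the first exception, report once and bail out."""
    for name, stage in stages:
        try:
            payload = stage(payload)
        except Exception as exc:
            on_failure(exc, f"Failed while {name}.")
            return None  # fail fast: later stages never run on a broken request
    return payload


result = run_pipeline(
    [
        ("fetching the model", lambda p: p + ["model"]),
        ("fetching the inputs", lambda p: p + ["inputs"]),
        ("executing", lambda p: p + ["output"]),
    ],
    [],
    lambda exc, msg: print(f"{msg} ({type(exc).__name__}: {exc})"),
)
print(result)  # ['model', 'inputs', 'output'] when every stage succeeds
```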
9 changes: 7 additions & 2 deletions smartsim/_core/mli/infrastructure/worker/worker.py
@@ -33,6 +33,9 @@
from ...infrastructure.storage.featurestore import FeatureStore
from ...mli_schemas.model.model_capnp import Model

if t.TYPE_CHECKING:
from smartsim._core.mli.mli_schemas.response.response_capnp import Status

logger = get_logger(__name__)


@@ -70,12 +73,14 @@ def __init__(
self,
outputs: t.Optional[t.Collection[t.Any]] = None,
output_keys: t.Optional[t.Collection[str]] = None,
failed: bool = False,
status_enum: "Status" = "running",
message: str = "In progress",
) -> None:
"""Initialize the object"""
self.outputs: t.Collection[t.Any] = outputs or []
self.output_keys: t.Collection[t.Optional[str]] = output_keys or []
self.failed = failed
self.status_enum = status_enum
self.message = message


class LoadModelResult:
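The boolean `failed` flag is gone: `InferenceReply` now starts life as `"running"`/`"In progress"` and is promoted to `"complete"`/`"Success"` only once outputs land, which is exactly what the tail of `_on_iteration` above relies on. A near-verbatim sketch of the class, with the `Status` alias assumed from this diff rather than imported from the real schema:

```python
import typing as t

# Assumed from this diff; the real alias lives in response_capnp.
Status = t.Literal["running", "complete", "fail"]


class InferenceReply:
    """Reply state accumulated while a request moves through the pipeline."""

    def __init__(
        self,
        outputs: t.Optional[t.Collection[t.Any]] = None,
        output_keys: t.Optional[t.Collection[str]] = None,
        status_enum: Status = "running",
        message: str = "In progress",
    ) -> None:
        self.outputs: t.Collection[t.Any] = outputs or []
        self.output_keys: t.Collection[t.Optional[str]] = output_keys or []
        self.status_enum = status_enum
        self.message = message


reply = InferenceReply()
reply.outputs = [b"tensor-bytes"]
reply.status_enum, reply.message = "complete", "Success"
```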
8 changes: 4 additions & 4 deletions smartsim/_core/mli/message_handler.py
@@ -360,7 +360,7 @@ def build_request(
request_attributes_capnp.TensorFlowRequestAttributes,
None,
],
) -> request_capnp.Request:
) -> request_capnp.RequestBuilder:
"""
Builds the request message.

@@ -405,7 +405,7 @@ def deserialize_request(request_bytes: bytes) -> request_capnp.Request:

@staticmethod
def _assign_status(
response: response_capnp.Response, status: "response_capnp.StatusEnum"
response: response_capnp.Response, status: "response_capnp.Status"
) -> None:
"""
Assigns a status to the supplied response.
@@ -498,7 +498,7 @@ def _assign_custom_response_attributes(

@staticmethod
def build_response(
status: "response_capnp.StatusEnum",
status: "response_capnp.Status",
message: str,
result: t.Union[
t.List[tensor_capnp.Tensor], t.List[data_references_capnp.TensorKey]
@@ -508,7 +508,7 @@
response_attributes_capnp.TensorFlowResponseAttributes,
None,
],
) -> response_capnp.Response:
) -> response_capnp.ResponseBuilder:
"""
Builds the response message.

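The signature changes in this file are consistent with the rest of the PR: construction helpers now advertise builder types (`RequestBuilder`, `ResponseBuilder`), since capnp message construction yields a builder while deserialization yields a reader, and `StatusEnum` is renamed `Status` throughout. A hedged sketch of the call site, assuming the smartsim MLI package is importable; the argument values mirror those used in `workermanager.py` above:

```python
from smartsim._core.mli.message_handler import MessageHandler

# Build a failure response the same way the worker manager does.
response = MessageHandler.build_response(
    status="fail",              # one of the schema's Status values
    message="Outputs not found.",
    result=[],                  # tensors or tensor keys on success
    custom_attributes=None,
)
serialized_resp = MessageHandler.serialize_response(response)
```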