diff --git a/.gitignore b/.gitignore index 8695c4b2..bd355dc6 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,6 @@ __pycache__/ dist/ -.idea \ No newline at end of file +.idea + +.kiro/ \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f3b1b950..dc4dbef3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,9 +10,7 @@ readme = "README.md" requires-python = ">=3.13" license = "Apache-2.0" keywords = [] -authors = [ - { name = "yaythomas", email = "tgaigher@amazon.com" }, -] +authors = [{ name = "yaythomas", email = "tgaigher@amazon.com" }] classifiers = [ "Development Status :: 4 - Beta", "Programming Language :: Python", @@ -20,9 +18,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ - "boto3>=1.40.30" -] +dependencies = ["boto3>=1.40.30"] [project.urls] Documentation = "https://github.com/aws/aws-durable-execution-sdk-python#readme" @@ -38,48 +34,31 @@ packages = ["src/aws_durable_execution_sdk_python"] [tool.hatch.version] path = "src/aws_durable_execution_sdk_python/__about__.py" -# [tool.hatch.envs.default] -# dependencies=["pytest"] - -# [tool.hatch.envs.default.scripts] -# test="pytest" - [tool.hatch.envs.test] -dependencies = [ - "coverage[toml]", - "pytest", - "pytest-cov", -] +dependencies = ["coverage[toml]", "pytest", "pytest-cov"] [tool.hatch.envs.test.scripts] -cov="pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=src/aws_durable_execution_sdk_python --cov=tests --cov-fail-under=98" +cov = "pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=src/aws_durable_execution_sdk_python --cov-fail-under=98" [tool.hatch.envs.types] -extra-dependencies = [ - "mypy>=1.0.0", - "pytest" -] +extra-dependencies = ["mypy>=1.0.0", "pytest"] [tool.hatch.envs.types.scripts] check = "mypy --install-types --non-interactive {args:src/aws_durable_execution_sdk_python tests}" [tool.coverage.run] -source_pkgs = ["aws_durable_execution_sdk_python", "tests"] +source_pkgs = ["aws_durable_execution_sdk_python"] branch = true parallel = true -omit = [ - "src/aws_durable_execution_sdk_python/__about__.py", -] +omit = ["src/aws_durable_execution_sdk_python/__about__.py"] [tool.coverage.paths] -aws_durable_execution_sdk_python = ["src/aws_durable_execution_sdk_python", "*/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python"] -tests = ["tests", "*/aws-durable-execution-sdk-python/tests"] +aws_durable_execution_sdk_python = [ + "src/aws_durable_execution_sdk_python", + "*/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python", +] [tool.coverage.report] -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [tool.ruff] line-length = 88 @@ -88,4 +67,12 @@ line-length = 88 preview = false [tool.ruff.lint.per-file-ignores] -"tests/**" = ["ARG001", "ARG002", "ARG005", "S101", "PLR2004", "SIM117", "TRY301"] +"tests/**" = [ + "ARG001", + "ARG002", + "ARG005", + "S101", + "PLR2004", + "SIM117", + "TRY301", +] diff --git a/src/aws_durable_execution_sdk_python/config.py b/src/aws_durable_execution_sdk_python/config.py index 348e480b..0ac2473b 100644 --- a/src/aws_durable_execution_sdk_python/config.py +++ b/src/aws_durable_execution_sdk_python/config.py @@ -8,7 +8,8 @@ from aws_durable_execution_sdk_python.retries import RetryDecision # noqa: TCH001 -R = TypeVar("R") +P = TypeVar("P") # Payload type +R = TypeVar("R") # Result type T = TypeVar("T") U = TypeVar("U") @@ -133,6 +134,14 @@ class MapConfig: serdes: SerDes | None = None +@dataclass +class InvokeConfig(Generic[P, R]): + # retry_strategy: Callable[[Exception, int], RetryDecision] | None = None + timeout_seconds: int = 0 + serdes_payload: SerDes[P] | None = None + serdes_result: SerDes[R] | None = None + + @dataclass(frozen=True) class CallbackConfig: """Configuration for callbacks.""" diff --git a/src/aws_durable_execution_sdk_python/context.py b/src/aws_durable_execution_sdk_python/context.py index 62fda63b..e6f74c82 100644 --- a/src/aws_durable_execution_sdk_python/context.py +++ b/src/aws_durable_execution_sdk_python/context.py @@ -7,6 +7,7 @@ BatchedInput, CallbackConfig, ChildConfig, + InvokeConfig, MapConfig, ParallelConfig, StepConfig, @@ -30,6 +31,7 @@ wait_for_callback_handler, ) from aws_durable_execution_sdk_python.operation.child import child_handler +from aws_durable_execution_sdk_python.operation.invoke import invoke_handler from aws_durable_execution_sdk_python.operation.map import map_handler from aws_durable_execution_sdk_python.operation.parallel import parallel_handler from aws_durable_execution_sdk_python.operation.step import step_handler @@ -56,17 +58,19 @@ from aws_durable_execution_sdk_python.state import CheckpointedResult -R = TypeVar("R") +P = TypeVar("P") # Payload type +R = TypeVar("R") # Result type T = TypeVar("T") U = TypeVar("U") -P = ParamSpec("P") +Params = ParamSpec("Params") + logger = logging.getLogger(__name__) def durable_step( - func: Callable[Concatenate[StepContext, P], T], -) -> Callable[P, Callable[[StepContext], T]]: + func: Callable[Concatenate[StepContext, Params], T], +) -> Callable[Params, Callable[[StepContext], T]]: """Wrap your callable into a named function that a Durable step can run.""" def wrapper(*args, **kwargs): @@ -80,8 +84,8 @@ def function_with_arguments(context: StepContext): def durable_with_child_context( - func: Callable[Concatenate[DurableContext, P], T], -) -> Callable[P, Callable[[DurableContext], T]]: + func: Callable[Concatenate[DurableContext, Params], T], +) -> Callable[Params, Callable[[DurableContext], T]]: """Wrap your callable into a Durable child context.""" def wrapper(*args, **kwargs): @@ -291,6 +295,36 @@ def create_callback( serdes=config.serdes, ) + def invoke( + self, + function_name: str, + payload: P, + name: str | None = None, + config: InvokeConfig[P, R] | None = None, + ) -> R: + """Invoke another Durable Function. + + Args: + function_name: Name of the function to invoke + payload: Input payload to send to the function + name: Optional name for the operation + config: Optional configuration for the invoke operation + + Returns: + The result of the invoked function + """ + return invoke_handler( + function_name=function_name, + payload=payload, + state=self.state, + operation_identifier=OperationIdentifier( + operation_id=self._create_step_id(), + parent_id=self._parent_id, + name=name, + ), + config=config, + ) + def map( self, inputs: Sequence[U], diff --git a/src/aws_durable_execution_sdk_python/exceptions.py b/src/aws_durable_execution_sdk_python/exceptions.py index 0577751b..7aef680d 100644 --- a/src/aws_durable_execution_sdk_python/exceptions.py +++ b/src/aws_durable_execution_sdk_python/exceptions.py @@ -5,6 +5,7 @@ from __future__ import annotations +import time from dataclasses import dataclass @@ -77,6 +78,24 @@ def __init__(self, message: str, scheduled_timestamp: float): super().__init__(message) self.scheduled_timestamp = scheduled_timestamp + @classmethod + def from_delay(cls, message: str, delay_seconds: int) -> TimedSuspendExecution: + """Create a timed suspension with the delay calculated from now. + + Args: + message: Descriptive message for the suspension + delay_seconds: Duration to suspend in seconds from current time + + Returns: + TimedSuspendExecution: Instance with calculated resume time + + Example: + >>> exception = TimedSuspendExecution.from_delay("Waiting for callback", 30) + >>> # Will suspend for 30 seconds from now + """ + resume_time = time.time() + delay_seconds + return cls(message, scheduled_timestamp=resume_time) + class OrderedLockError(DurableExecutionsError): """An error from OrderedLock. diff --git a/src/aws_durable_execution_sdk_python/lambda_service.py b/src/aws_durable_execution_sdk_python/lambda_service.py index cb103e37..089dcf36 100644 --- a/src/aws_durable_execution_sdk_python/lambda_service.py +++ b/src/aws_durable_execution_sdk_python/lambda_service.py @@ -63,6 +63,7 @@ class OperationSubType(Enum): PARALLEL_BRANCH = "ParallelBranch" WAIT_FOR_CALLBACK = "WaitForCallback" WAIT_FOR_CONDITION = "WaitForCondition" + INVOKE = "Invoke" @dataclass(frozen=True) @@ -241,15 +242,11 @@ def to_dict(self) -> MutableMapping[str, Any]: @dataclass(frozen=True) class InvokeOptions: function_name: str - function_qualifier: str | None = None - durable_execution_name: str | None = None + timeout_seconds: int = 0 def to_dict(self) -> MutableMapping[str, Any]: - result = {"FunctionName": self.function_name} - if self.function_qualifier: - result["FunctionQualifier"] = self.function_qualifier - if self.durable_execution_name: - result["DurableExecutionName"] = self.durable_execution_name + result: MutableMapping[str, Any] = {"FunctionName": self.function_name} + result["TimeoutSeconds"] = self.timeout_seconds return result @@ -471,6 +468,28 @@ def create_step_retry( # endregion step + # region invoke + @classmethod + def create_invoke_start( + cls, + identifier: OperationIdentifier, + payload: str, + invoke_options: InvokeOptions, + ) -> OperationUpdate: + """Create an instance of OperationUpdate for type: INVOKE, action: START.""" + return cls( + operation_id=identifier.operation_id, + parent_id=identifier.parent_id, + operation_type=OperationType.INVOKE, + sub_type=OperationSubType.INVOKE, + action=OperationAction.START, + name=identifier.name, + payload=payload, + invoke_options=invoke_options, + ) + + # endregion invoke + # region wait for condition @classmethod def create_wait_for_condition_start( diff --git a/src/aws_durable_execution_sdk_python/operation/invoke.py b/src/aws_durable_execution_sdk_python/operation/invoke.py new file mode 100644 index 00000000..c9daf555 --- /dev/null +++ b/src/aws_durable_execution_sdk_python/operation/invoke.py @@ -0,0 +1,128 @@ +"""Implement the Durable invoke operation.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, TypeVar + +from aws_durable_execution_sdk_python.config import InvokeConfig +from aws_durable_execution_sdk_python.exceptions import ( + FatalError, + SuspendExecution, + TimedSuspendExecution, +) +from aws_durable_execution_sdk_python.lambda_service import ( + InvokeOptions, + OperationUpdate, +) +from aws_durable_execution_sdk_python.serdes import deserialize, serialize + +if TYPE_CHECKING: + from typing import NoReturn + + from aws_durable_execution_sdk_python.identifier import OperationIdentifier + from aws_durable_execution_sdk_python.state import ExecutionState + +P = TypeVar("P") # Payload type +R = TypeVar("R") # Result type + +logger = logging.getLogger(__name__) + + +def invoke_handler( + function_name: str, + payload: P, + state: ExecutionState, + operation_identifier: OperationIdentifier, + config: InvokeConfig[P, R] | None, +) -> R: + """Invoke another Durable Function.""" + logger.debug( + "🔗 Invoke %s (%s)", + operation_identifier.name or function_name, + operation_identifier.operation_id, + ) + + if not config: + config = InvokeConfig[P, R]() + + # Check if we have existing step data + checkpointed_result = state.get_checkpoint_result(operation_identifier.operation_id) + + if checkpointed_result.is_succeeded(): + # Return persisted result - no need to check for errors in successful operations + if ( + checkpointed_result.operation + and checkpointed_result.operation.invoke_details + and checkpointed_result.operation.invoke_details.result + ): + return deserialize( + serdes=config.serdes_result, + data=checkpointed_result.operation.invoke_details.result, + operation_id=operation_identifier.operation_id, + durable_execution_arn=state.durable_execution_arn, + ) + return None # type: ignore + + if checkpointed_result.is_failed() or checkpointed_result.is_timed_out(): + # Operation failed, throw the exact same error on replay as the checkpointed failure + checkpointed_result.raise_callable_error() + + if checkpointed_result.is_started(): + # Operation is still running, suspend until completion + logger.debug( + "⏳ Invoke %s still in progress, suspending", + operation_identifier.name or function_name, + ) + msg = f"Invoke {operation_identifier.operation_id} still in progress" + suspend_with_optional_timeout(msg, config.timeout_seconds) + + serialized_payload: str = serialize( + serdes=config.serdes_payload, + value=payload, + operation_id=operation_identifier.operation_id, + durable_execution_arn=state.durable_execution_arn, + ) + + # the backend will do the invoke once it gets this checkpoint + start_operation: OperationUpdate = OperationUpdate.create_invoke_start( + identifier=operation_identifier, + payload=serialized_payload, + invoke_options=InvokeOptions( + function_name=function_name, timeout_seconds=config.timeout_seconds + ), + ) + + state.create_checkpoint(operation_update=start_operation) + + logger.debug( + "🚀 Invoke %s started, suspending for async execution", + operation_identifier.name or function_name, + ) + + # Suspend so invoke executes asynchronously without consuming cpu here + msg = ( + f"Invoke {operation_identifier.operation_id} started, suspending for completion" + ) + suspend_with_optional_timeout(msg, config.timeout_seconds) + # This line should never be reached since suspend_with_optional_timeout always raises + msg = "suspend_with_optional_timeout should have raised an exception, but did not." + raise FatalError(msg) from None + + +def suspend_with_optional_timeout( + msg: str, timeout_seconds: int | None = None +) -> NoReturn: + """Suspend execution with optional timeout. + + Args: + msg: Descriptive message for the suspension + timeout_seconds: Duration to suspend in seconds, or None/0 for indefinite + + Raises: + TimedSuspendExecution: When timeout_seconds > 0 + SuspendExecution: When timeout_seconds is None or <= 0 + """ + if timeout_seconds and timeout_seconds > 0: + raise TimedSuspendExecution.from_delay(msg, timeout_seconds) + raise SuspendExecution(msg) diff --git a/tests/context_test.py b/tests/context_test.py index cfb2e32c..f49e84f1 100644 --- a/tests/context_test.py +++ b/tests/context_test.py @@ -444,6 +444,170 @@ def test_step_with_original_name(mock_handler): # endregion step +# region invoke +@patch("aws_durable_execution_sdk_python.context.invoke_handler") +def test_invoke_basic(mock_handler): + """Test invoke with basic parameters.""" + mock_handler.return_value = "invoke_result" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + + context = DurableContext(state=mock_state) + + result = context.invoke("test_function", "test_payload") + + assert result == "invoke_result" + + mock_handler.assert_called_once_with( + function_name="test_function", + payload="test_payload", + state=mock_state, + operation_identifier=OperationIdentifier("1", None, None), + config=None, + ) + + +@patch("aws_durable_execution_sdk_python.context.invoke_handler") +def test_invoke_with_name_and_config(mock_handler): + """Test invoke with name and config.""" + from aws_durable_execution_sdk_python.config import InvokeConfig + + mock_handler.return_value = "configured_result" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + config = InvokeConfig[str, str](timeout_seconds=30) + + context = DurableContext(state=mock_state) + [context._create_step_id() for _ in range(5)] # Set counter to 5 # noqa: SLF001 + + result = context.invoke( + "test_function", {"key": "value"}, name="named_invoke", config=config + ) + + assert result == "configured_result" + mock_handler.assert_called_once_with( + function_name="test_function", + payload={"key": "value"}, + state=mock_state, + operation_identifier=OperationIdentifier("6", None, "named_invoke"), + config=config, + ) + + +@patch("aws_durable_execution_sdk_python.context.invoke_handler") +def test_invoke_with_parent_id(mock_handler): + """Test invoke with parent_id.""" + mock_handler.return_value = "parent_result" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + + context = DurableContext(state=mock_state, parent_id="parent123") + [context._create_step_id() for _ in range(2)] # Set counter to 2 # noqa: SLF001 + + context.invoke("test_function", None) + + mock_handler.assert_called_once_with( + function_name="test_function", + payload=None, + state=mock_state, + operation_identifier=OperationIdentifier("parent123-3", "parent123", None), + config=None, + ) + + +@patch("aws_durable_execution_sdk_python.context.invoke_handler") +def test_invoke_increments_counter(mock_handler): + """Test invoke increments step counter.""" + mock_handler.return_value = "result" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + + context = DurableContext(state=mock_state) + [context._create_step_id() for _ in range(10)] # Set counter to 10 # noqa: SLF001 + + context.invoke("function1", "payload1") + context.invoke("function2", "payload2") + + assert context._step_counter.get_current() == 12 # noqa: SLF001 + assert mock_handler.call_args_list[0][1][ + "operation_identifier" + ] == OperationIdentifier("11", None, None) + assert mock_handler.call_args_list[1][1][ + "operation_identifier" + ] == OperationIdentifier("12", None, None) + + +@patch("aws_durable_execution_sdk_python.context.invoke_handler") +def test_invoke_with_none_payload(mock_handler): + """Test invoke with None payload.""" + mock_handler.return_value = None + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + + context = DurableContext(state=mock_state) + + result = context.invoke("test_function", None) + + assert result is None + + mock_handler.assert_called_once_with( + function_name="test_function", + payload=None, + state=mock_state, + operation_identifier=OperationIdentifier("1", None, None), + config=None, + ) + + +@patch("aws_durable_execution_sdk_python.context.invoke_handler") +def test_invoke_with_custom_serdes(mock_handler): + """Test invoke with custom serialization config.""" + mock_handler.return_value = {"transformed": "data"} + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + + from aws_durable_execution_sdk_python.config import InvokeConfig + + config = InvokeConfig[dict, dict]( + serdes_payload=CustomDictSerDes(), + serdes_result=CustomDictSerDes(), + timeout_seconds=60, + ) + + context = DurableContext(state=mock_state) + + result = context.invoke( + "test_function", + {"original": "data"}, + name="custom_serdes_invoke", + config=config, + ) + + assert result == {"transformed": "data"} + mock_handler.assert_called_once_with( + function_name="test_function", + payload={"original": "data"}, + state=mock_state, + operation_identifier=OperationIdentifier("1", None, "custom_serdes_invoke"), + config=config, + ) + + +# endregion invoke + + # region wait @patch("aws_durable_execution_sdk_python.context.wait_handler") def test_wait_basic(mock_handler): diff --git a/tests/exceptions_test.py b/tests/exceptions_test.py index de8e2797..32d2ea6f 100644 --- a/tests/exceptions_test.py +++ b/tests/exceptions_test.py @@ -1,5 +1,8 @@ """Unit tests for exceptions module.""" +import time +from unittest.mock import patch + import pytest from aws_durable_execution_sdk_python.exceptions import ( @@ -11,6 +14,7 @@ OrderedLockError, StepInterruptedError, SuspendExecution, + TimedSuspendExecution, UserlandError, ValidationError, ) @@ -122,3 +126,87 @@ def test_callable_runtime_error_serializable_details_frozen(): details = CallableRuntimeErrorSerializableDetails("Error", "message") with pytest.raises(AttributeError): details.type = "NewError" + + +def test_timed_suspend_execution(): + """Test TimedSuspendExecution exception.""" + scheduled_time = 1234567890.0 + error = TimedSuspendExecution("timed suspend", scheduled_time) + assert str(error) == "timed suspend" + assert error.scheduled_timestamp == scheduled_time + assert isinstance(error, SuspendExecution) + assert isinstance(error, BaseException) + + +def test_timed_suspend_execution_from_delay(): + """Test TimedSuspendExecution.from_delay factory method.""" + message = "Waiting for callback" + delay_seconds = 30 + + # Mock time.time() to get predictable results + with patch("time.time", return_value=1000.0): + error = TimedSuspendExecution.from_delay(message, delay_seconds) + + assert str(error) == message + assert error.scheduled_timestamp == 1030.0 # 1000.0 + 30 + assert isinstance(error, TimedSuspendExecution) + assert isinstance(error, SuspendExecution) + + +def test_timed_suspend_execution_from_delay_zero_delay(): + """Test TimedSuspendExecution.from_delay with zero delay.""" + message = "Immediate suspension" + delay_seconds = 0 + + with patch("time.time", return_value=500.0): + error = TimedSuspendExecution.from_delay(message, delay_seconds) + + assert str(error) == message + assert error.scheduled_timestamp == 500.0 # 500.0 + 0 + assert isinstance(error, TimedSuspendExecution) + + +def test_timed_suspend_execution_from_delay_negative_delay(): + """Test TimedSuspendExecution.from_delay with negative delay.""" + message = "Past suspension" + delay_seconds = -10 + + with patch("time.time", return_value=100.0): + error = TimedSuspendExecution.from_delay(message, delay_seconds) + + assert str(error) == message + assert error.scheduled_timestamp == 90.0 # 100.0 + (-10) + assert isinstance(error, TimedSuspendExecution) + + +def test_timed_suspend_execution_from_delay_large_delay(): + """Test TimedSuspendExecution.from_delay with large delay.""" + message = "Long suspension" + delay_seconds = 3600 # 1 hour + + with patch("time.time", return_value=0.0): + error = TimedSuspendExecution.from_delay(message, delay_seconds) + + assert str(error) == message + assert error.scheduled_timestamp == 3600.0 # 0.0 + 3600 + assert isinstance(error, TimedSuspendExecution) + + +def test_timed_suspend_execution_from_delay_calculation_accuracy(): + """Test that TimedSuspendExecution.from_delay calculates time accurately.""" + message = "Accurate timing test" + delay_seconds = 42 + + # Test with actual time.time() to ensure the calculation works in real scenarios + before_time = time.time() + error = TimedSuspendExecution.from_delay(message, delay_seconds) + after_time = time.time() + + # The scheduled timestamp should be within a reasonable range + # (accounting for the small time difference between calls) + expected_min = before_time + delay_seconds + expected_max = after_time + delay_seconds + + assert expected_min <= error.scheduled_timestamp <= expected_max + assert str(error) == message + assert isinstance(error, TimedSuspendExecution) diff --git a/tests/lambda_service_test.py b/tests/lambda_service_test.py index 328e536a..0baff328 100644 --- a/tests/lambda_service_test.py +++ b/tests/lambda_service_test.py @@ -339,14 +339,12 @@ def test_invoke_options_to_dict(): """Test InvokeOptions.to_dict method.""" options = InvokeOptions( function_name="test_function", - function_qualifier="$LATEST", - durable_execution_name="test_execution", + timeout_seconds=30, ) result = options.to_dict() expected = { "FunctionName": "test_function", - "FunctionQualifier": "$LATEST", - "DurableExecutionName": "test_execution", + "TimeoutSeconds": 30, } assert result == expected @@ -355,7 +353,7 @@ def test_invoke_options_to_dict_minimal(): """Test InvokeOptions.to_dict with minimal fields.""" options = InvokeOptions(function_name="test_function") result = options.to_dict() - assert result == {"FunctionName": "test_function"} + assert result == {"FunctionName": "test_function", "TimeoutSeconds": 0} def test_operation_update_to_dict(): @@ -400,9 +398,7 @@ def test_operation_update_to_dict_complete(): callback_options = CallbackOptions( timeout_seconds=300, heartbeat_timeout_seconds=60 ) - invoke_options = InvokeOptions( - function_name="test_func", function_qualifier="$LATEST" - ) + invoke_options = InvokeOptions(function_name="test_func", timeout_seconds=60) update = OperationUpdate( operation_id="op1", @@ -430,7 +426,7 @@ def test_operation_update_to_dict_complete(): "StepOptions": {"NextAttemptDelaySeconds": 30}, "WaitOptions": {"WaitSeconds": 60}, "CallbackOptions": {"TimeoutSeconds": 300, "HeartbeatTimeoutSeconds": 60}, - "InvokeOptions": {"FunctionName": "test_func", "FunctionQualifier": "$LATEST"}, + "InvokeOptions": {"FunctionName": "test_func", "TimeoutSeconds": 60}, } assert result == expected @@ -1377,9 +1373,7 @@ def test_operation_update_complete_with_new_fields(): callback_options = CallbackOptions( timeout_seconds=300, heartbeat_timeout_seconds=60 ) - invoke_options = InvokeOptions( - function_name="test_func", function_qualifier="$LATEST" - ) + invoke_options = InvokeOptions(function_name="test_func", timeout_seconds=60) update = OperationUpdate( operation_id="op1", @@ -1411,7 +1405,7 @@ def test_operation_update_complete_with_new_fields(): "StepOptions": {"NextAttemptDelaySeconds": 30}, "WaitOptions": {"WaitSeconds": 60}, "CallbackOptions": {"TimeoutSeconds": 300, "HeartbeatTimeoutSeconds": 60}, - "InvokeOptions": {"FunctionName": "test_func", "FunctionQualifier": "$LATEST"}, + "InvokeOptions": {"FunctionName": "test_func", "TimeoutSeconds": 60}, } assert result == expected diff --git a/tests/operation/child_test.py b/tests/operation/child_test.py index 3ae9e7d3..adc3685e 100644 --- a/tests/operation/child_test.py +++ b/tests/operation/child_test.py @@ -305,7 +305,7 @@ def test_child_handler_default_serialization(): assert len(success_call) == 1 -def test_child_handler_custom_serdes_not_start(): +def test_child_handler_custom_serdes_not_start() -> None: mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" mock_result = Mock() @@ -334,7 +334,7 @@ def test_child_handler_custom_serdes_not_start(): assert success_operation.payload == expected_checkpoointed_result -def test_child_handler_custom_serdes_already_succeeded(): +def test_child_handler_custom_serdes_already_succeeded() -> None: mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" mock_result = Mock() @@ -363,7 +363,7 @@ def test_child_handler_custom_serdes_already_succeeded(): # large payload with summary generator -def test_child_handler_large_payload_with_summary_generator(): +def test_child_handler_large_payload_with_summary_generator() -> None: """Test child_handler with large payload and summary generator.""" mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" @@ -397,7 +397,7 @@ def my_summary(result: str) -> str: # large payload without summary generator -def test_child_handler_large_payload_without_summary_generator(): +def test_child_handler_large_payload_without_summary_generator() -> None: """Test child_handler with large payload and no summary generator.""" mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" @@ -427,7 +427,7 @@ def test_child_handler_large_payload_without_summary_generator(): # mocked children replay mode execute the function again -def test_child_handler_replay_children_mode(): +def test_child_handler_replay_children_mode() -> None: """Test child_handler in ReplayChildren mode.""" mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" diff --git a/tests/operation/invoke_test.py b/tests/operation/invoke_test.py new file mode 100644 index 00000000..fe4eaa5a --- /dev/null +++ b/tests/operation/invoke_test.py @@ -0,0 +1,531 @@ +"""Unit tests for invoke handler.""" + +from __future__ import annotations + +import json +from unittest.mock import Mock, patch + +import pytest + +from aws_durable_execution_sdk_python.config import InvokeConfig +from aws_durable_execution_sdk_python.exceptions import ( + CallableRuntimeError, + FatalError, + SuspendExecution, + TimedSuspendExecution, +) +from aws_durable_execution_sdk_python.identifier import OperationIdentifier +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + InvokeDetails, + Operation, + OperationAction, + OperationStatus, + OperationType, +) +from aws_durable_execution_sdk_python.operation.invoke import ( + invoke_handler, + suspend_with_optional_timeout, +) +from aws_durable_execution_sdk_python.state import CheckpointedResult, ExecutionState +from tests.serdes_test import CustomDictSerDes + + +def test_invoke_handler_already_succeeded(): + """Test invoke_handler when operation already succeeded.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + operation = Operation( + operation_id="invoke1", + operation_type=OperationType.INVOKE, + status=OperationStatus.SUCCEEDED, + invoke_details=InvokeDetails( + durable_execution_arn="invoked_arn", result=json.dumps("test_result") + ), + ) + mock_result = CheckpointedResult.create_from_operation(operation) + mock_state.get_checkpoint_result.return_value = mock_result + + result = invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier("invoke1", None, "test_invoke"), + config=None, + ) + + assert result == "test_result" + mock_state.create_checkpoint.assert_not_called() + + +def test_invoke_handler_already_succeeded_none_result(): + """Test invoke_handler when operation succeeded with None result.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + operation = Operation( + operation_id="invoke2", + operation_type=OperationType.INVOKE, + status=OperationStatus.SUCCEEDED, + invoke_details=InvokeDetails(durable_execution_arn="invoked_arn", result=None), + ) + mock_result = CheckpointedResult.create_from_operation(operation) + mock_state.get_checkpoint_result.return_value = mock_result + + result = invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier("invoke2", None, "test_invoke"), + config=None, + ) + + assert result is None + + +def test_invoke_handler_already_succeeded_no_invoke_details(): + """Test invoke_handler when operation succeeded but has no invoke_details.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + operation = Operation( + operation_id="invoke3", + operation_type=OperationType.INVOKE, + status=OperationStatus.SUCCEEDED, + invoke_details=None, + ) + mock_result = CheckpointedResult.create_from_operation(operation) + mock_state.get_checkpoint_result.return_value = mock_result + + result = invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier("invoke3", None, "test_invoke"), + config=None, + ) + + assert result is None + + +def test_invoke_handler_already_failed(): + """Test invoke_handler when operation already failed.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + error = ErrorObject( + message="Test error", type="TestError", data=None, stack_trace=None + ) + operation = Operation( + operation_id="invoke4", + operation_type=OperationType.INVOKE, + status=OperationStatus.FAILED, + invoke_details=InvokeDetails(durable_execution_arn="invoked_arn", error=error), + ) + mock_result = CheckpointedResult.create_from_operation(operation) + mock_state.get_checkpoint_result.return_value = mock_result + + with pytest.raises(CallableRuntimeError): + invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier("invoke4", None, "test_invoke"), + config=None, + ) + + +def test_invoke_handler_already_timed_out(): + """Test invoke_handler when operation already timed out.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + error = ErrorObject( + message="Operation timed out", type="TimeoutError", data=None, stack_trace=None + ) + operation = Operation( + operation_id="invoke5", + operation_type=OperationType.INVOKE, + status=OperationStatus.TIMED_OUT, + invoke_details=InvokeDetails(durable_execution_arn="invoked_arn", error=error), + ) + mock_result = CheckpointedResult.create_from_operation(operation) + mock_state.get_checkpoint_result.return_value = mock_result + + with pytest.raises(CallableRuntimeError): + invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier("invoke5", None, "test_invoke"), + config=None, + ) + + +def test_invoke_handler_already_started(): + """Test invoke_handler when operation is already started.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + operation = Operation( + operation_id="invoke6", + operation_type=OperationType.INVOKE, + status=OperationStatus.STARTED, + invoke_details=InvokeDetails(durable_execution_arn="invoked_arn"), + ) + mock_result = CheckpointedResult.create_from_operation(operation) + mock_state.get_checkpoint_result.return_value = mock_result + + with pytest.raises(SuspendExecution, match="Invoke invoke6 still in progress"): + invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier("invoke6", None, "test_invoke"), + config=None, + ) + + +def test_invoke_handler_already_started_with_timeout(): + """Test invoke_handler when operation is already started with timeout config.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + operation = Operation( + operation_id="invoke7", + operation_type=OperationType.INVOKE, + status=OperationStatus.STARTED, + invoke_details=InvokeDetails(durable_execution_arn="invoked_arn"), + ) + mock_result = CheckpointedResult.create_from_operation(operation) + mock_state.get_checkpoint_result.return_value = mock_result + + config = InvokeConfig[str, str](timeout_seconds=30) + + with pytest.raises(TimedSuspendExecution): + invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier("invoke7", None, "test_invoke"), + config=config, + ) + + +def test_invoke_handler_new_operation(): + """Test invoke_handler when starting a new operation.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + mock_result = CheckpointedResult.create_not_found() + mock_state.get_checkpoint_result.return_value = mock_result + + config = InvokeConfig[str, str](timeout_seconds=60) + + with pytest.raises( + SuspendExecution, match="Invoke invoke8 started, suspending for completion" + ): + invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier("invoke8", None, "test_invoke"), + config=config, + ) + + # Verify checkpoint was created + mock_state.create_checkpoint.assert_called_once() + operation_update = mock_state.create_checkpoint.call_args[1]["operation_update"] + + assert operation_update.operation_id == "invoke8" + assert operation_update.operation_type == OperationType.INVOKE + assert operation_update.action == OperationAction.START + assert operation_update.name == "test_invoke" + assert operation_update.payload == json.dumps("test_input") + assert operation_update.invoke_options.function_name == "test_function" + assert operation_update.invoke_options.timeout_seconds == 60 + + +def test_invoke_handler_new_operation_with_timeout(): + """Test invoke_handler when starting a new operation with timeout.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + mock_result = CheckpointedResult.create_not_found() + mock_state.get_checkpoint_result.return_value = mock_result + + config = InvokeConfig[str, str](timeout_seconds=30) + + with pytest.raises(TimedSuspendExecution): + invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier("invoke9", None, "test_invoke"), + config=config, + ) + + +def test_invoke_handler_new_operation_no_timeout(): + """Test invoke_handler when starting a new operation without timeout.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + mock_result = CheckpointedResult.create_not_found() + mock_state.get_checkpoint_result.return_value = mock_result + + config = InvokeConfig[str, str](timeout_seconds=0) + + with pytest.raises(SuspendExecution): + invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier("invoke10", None, "test_invoke"), + config=config, + ) + + +def test_invoke_handler_no_config(): + """Test invoke_handler when no config is provided.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + mock_result = CheckpointedResult.create_not_found() + mock_state.get_checkpoint_result.return_value = mock_result + + with pytest.raises(SuspendExecution): + invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier("invoke11", None, "test_invoke"), + config=None, + ) + + # Verify default config was used + operation_update = mock_state.create_checkpoint.call_args[1]["operation_update"] + assert operation_update.invoke_options.timeout_seconds == 0 + + +def test_invoke_handler_custom_serdes(): + """Test invoke_handler with custom serialization.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + operation = Operation( + operation_id="invoke12", + operation_type=OperationType.INVOKE, + status=OperationStatus.SUCCEEDED, + invoke_details=InvokeDetails( + durable_execution_arn="invoked_arn", + result='{"key": "VALUE", "number": "84", "list": [1, 2, 3]}', + ), + ) + mock_result = CheckpointedResult.create_from_operation(operation) + mock_state.get_checkpoint_result.return_value = mock_result + + config = InvokeConfig[dict, dict]( + serdes_payload=CustomDictSerDes(), serdes_result=CustomDictSerDes() + ) + + result = invoke_handler( + function_name="test_function", + payload={"key": "value", "number": 42, "list": [1, 2, 3]}, + state=mock_state, + operation_identifier=OperationIdentifier("invoke12", None, "test_invoke"), + config=config, + ) + + # CustomDictSerDes transforms the result back + assert result == {"key": "value", "number": 42, "list": [1, 2, 3]} + + +def test_invoke_handler_custom_serdes_new_operation(): + """Test invoke_handler with custom serialization for new operation.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + mock_result = CheckpointedResult.create_not_found() + mock_state.get_checkpoint_result.return_value = mock_result + + config = InvokeConfig[dict, dict]( + serdes_payload=CustomDictSerDes(), serdes_result=CustomDictSerDes() + ) + complex_payload = {"key": "value", "number": 42, "list": [1, 2, 3]} + + with pytest.raises(SuspendExecution): + invoke_handler( + function_name="test_function", + payload=complex_payload, + state=mock_state, + operation_identifier=OperationIdentifier("invoke13", None, "test_invoke"), + config=config, + ) + + # Verify custom serialization was used + operation_update = mock_state.create_checkpoint.call_args[1]["operation_update"] + expected_serialized = '{"key": "VALUE", "number": "84", "list": [1, 2, 3]}' + assert operation_update.payload == expected_serialized + + +def test_suspend_with_optional_timeout_with_timeout(): + """Test suspend_with_optional_timeout with timeout.""" + with pytest.raises(TimedSuspendExecution) as exc_info: + suspend_with_optional_timeout("test message", 30) + + assert "test message" in str(exc_info.value) + + +def test_suspend_with_optional_timeout_no_timeout(): + """Test suspend_with_optional_timeout without timeout.""" + with pytest.raises(SuspendExecution) as exc_info: + suspend_with_optional_timeout("test message", None) + + assert "test message" in str(exc_info.value) + + +def test_suspend_with_optional_timeout_zero_timeout(): + """Test suspend_with_optional_timeout with zero timeout.""" + with pytest.raises(SuspendExecution) as exc_info: + suspend_with_optional_timeout("test message", 0) + + assert "test message" in str(exc_info.value) + + +def test_suspend_with_optional_timeout_negative_timeout(): + """Test suspend_with_optional_timeout with negative timeout.""" + with pytest.raises(SuspendExecution) as exc_info: + suspend_with_optional_timeout("test message", -5) + + assert "test message" in str(exc_info.value) + + +def test_invoke_handler_with_operation_name(): + """Test invoke_handler uses operation name in logs when available.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + operation = Operation( + operation_id="invoke14", + operation_type=OperationType.INVOKE, + status=OperationStatus.STARTED, + invoke_details=InvokeDetails(durable_execution_arn="invoked_arn"), + ) + mock_result = CheckpointedResult.create_from_operation(operation) + mock_state.get_checkpoint_result.return_value = mock_result + + with pytest.raises(SuspendExecution): + invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier("invoke14", None, "named_invoke"), + config=None, + ) + + +def test_invoke_handler_without_operation_name(): + """Test invoke_handler uses function name in logs when no operation name.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + operation = Operation( + operation_id="invoke15", + operation_type=OperationType.INVOKE, + status=OperationStatus.STARTED, + invoke_details=InvokeDetails(durable_execution_arn="invoked_arn"), + ) + mock_result = CheckpointedResult.create_from_operation(operation) + mock_state.get_checkpoint_result.return_value = mock_result + + with pytest.raises(SuspendExecution): + invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier("invoke15", None, None), + config=None, + ) + + +def test_invoke_handler_with_none_payload(): + """Test invoke_handler when payload is None.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + mock_result = CheckpointedResult.create_not_found() + mock_state.get_checkpoint_result.return_value = mock_result + + with pytest.raises(SuspendExecution): + invoke_handler( + function_name="test_function", + payload=None, + state=mock_state, + operation_identifier=OperationIdentifier("invoke16", None, "test_invoke"), + config=None, + ) + + # Verify checkpoint was created with None payload + mock_state.create_checkpoint.assert_called_once() + operation_update = mock_state.create_checkpoint.call_args[1]["operation_update"] + assert operation_update.payload == "null" # JSON serialization of None + + +def test_invoke_handler_already_succeeded_with_none_payload(): + """Test invoke_handler when operation succeeded and original payload was None.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + operation = Operation( + operation_id="invoke17", + operation_type=OperationType.INVOKE, + status=OperationStatus.SUCCEEDED, + invoke_details=InvokeDetails( + durable_execution_arn="invoked_arn", result=json.dumps("test_result") + ), + ) + mock_result = CheckpointedResult.create_from_operation(operation) + mock_state.get_checkpoint_result.return_value = mock_result + + result = invoke_handler( + function_name="test_function", + payload=None, + state=mock_state, + operation_identifier=OperationIdentifier("invoke17", None, "test_invoke"), + config=None, + ) + + assert result == "test_result" + mock_state.create_checkpoint.assert_not_called() + + +@patch( + "aws_durable_execution_sdk_python.operation.invoke.suspend_with_optional_timeout" +) +def test_invoke_handler_suspend_does_not_raise(mock_suspend): + """Test invoke_handler when suspend_with_optional_timeout doesn't raise an exception.""" + + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + mock_result = CheckpointedResult.create_not_found() + mock_state.get_checkpoint_result.return_value = mock_result + + # Mock suspend_with_optional_timeout to not raise an exception (which it should always do) + mock_suspend.return_value = None + + with pytest.raises( + FatalError, + match="suspend_with_optional_timeout should have raised an exception, but did not.", + ): + invoke_handler( + function_name="test_function", + payload="test_input", + state=mock_state, + operation_identifier=OperationIdentifier("invoke18", None, "test_invoke"), + config=None, + ) + + mock_suspend.assert_called_once()