diff --git a/src/aws_durable_execution_sdk_python/exceptions.py b/src/aws_durable_execution_sdk_python/exceptions.py
index dcaa2c1..fad9568 100644
--- a/src/aws_durable_execution_sdk_python/exceptions.py
+++ b/src/aws_durable_execution_sdk_python/exceptions.py
@@ -8,12 +8,29 @@
 import time
 from dataclasses import dataclass
 from enum import Enum
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Self, TypedDict
 
 if TYPE_CHECKING:
     import datetime
 
 
+class AwsErrorObj(TypedDict):
+    """The ``Error`` entry of a boto error response."""
+
+    Code: str | None
+    Message: str | None
+
+
+class AwsErrorMetadata(TypedDict, total=False):
+    """The ``ResponseMetadata`` entry of a boto error response."""
+
+    RequestId: str | None
+    HostId: str | None
+    HTTPStatusCode: int | None
+    HTTPHeaders: dict[str, str] | None
+    RetryAttempts: int | None
+
+
 class TerminationReason(Enum):
     """Reasons why a durable execution terminated."""
 
@@ -69,12 +86,41 @@ def __init__(self, message: str, callback_id: str | None = None):
         self.callback_id = callback_id
 
 
-class CheckpointFailedError(InvocationError):
-    """Error when checkpoint operation fails."""
+class BotoClientError(InvocationError):
+    """Invocation error carrying AWS error details from a failed client call."""
+
+    def __init__(
+        self,
+        message: str,
+        error: AwsErrorObj | None = None,
+        response_metadata: AwsErrorMetadata | None = None,
+        termination_reason: TerminationReason = TerminationReason.INVOCATION_ERROR,
+    ):
+        super().__init__(message=message, termination_reason=termination_reason)
+        self.error: AwsErrorObj | None = error
+        self.response_metadata: AwsErrorMetadata | None = response_metadata
 
-    def __init__(self, message: str, step_id: str | None = None):
-        super().__init__(message, TerminationReason.CHECKPOINT_FAILED)
-        self.step_id = step_id
+    @classmethod
+    def from_exception(cls, exception: Exception) -> Self:
+        """Create an instance from an exception's boto-style ``response`` dict."""
+        response = getattr(exception, "response", None)
+        if not isinstance(response, dict):
+            response = {}
+        response_metadata = response.get("ResponseMetadata")
+        error = response.get("Error")
+        return cls(
+            message=str(exception), error=error, response_metadata=response_metadata
+        )
+
+    def build_logger_extras(self) -> dict:
+        """Return the non-empty AWS error details as ``extra`` fields for logging."""
+        extras: dict = {}
+        # preserve PascalCase to be consistent with other languages
+        if error := self.error:
+            extras["Error"] = error
+        if response_metadata := self.response_metadata:
+            extras["ResponseMetadata"] = response_metadata
+        return extras
 
 
 class NonDeterministicExecutionError(ExecutionError):
@@ -85,21 +131,44 @@ def __init__(self, message: str, step_id: str | None = None):
         self.step_id = step_id
 
 
-class CheckpointError(CheckpointFailedError):
+class CheckpointError(BotoClientError):
     """Failure to checkpoint. Will terminate the lambda."""
 
-    def __init__(self, message: str):
-        super().__init__(message)
-
-    @classmethod
-    def from_exception(cls, exception: Exception) -> CheckpointError:
-        return cls(message=str(exception))
+    def __init__(
+        self,
+        message: str,
+        error: AwsErrorObj | None = None,
+        response_metadata: AwsErrorMetadata | None = None,
+    ):
+        super().__init__(
+            message,
+            error,
+            response_metadata,
+            termination_reason=TerminationReason.CHECKPOINT_FAILED,
+        )
 
 
 class ValidationError(DurableExecutionsError):
     """Incorrect arguments to a Durable Function operation."""
 
 
+class GetExecutionStateError(BotoClientError):
+    """Raised when failing to retrieve execution state"""
+
+    def __init__(
+        self,
+        message: str,
+        error: AwsErrorObj | None = None,
+        response_metadata: AwsErrorMetadata | None = None,
+    ):
+        super().__init__(
+            message,
+            error,
+            response_metadata,
+            termination_reason=TerminationReason.INVOCATION_ERROR,
+        )
+
+
 class InvalidStateError(DurableExecutionsError):
     """Raised when an operation is attempted on an object in an invalid state."""
 
diff --git a/src/aws_durable_execution_sdk_python/lambda_service.py b/src/aws_durable_execution_sdk_python/lambda_service.py
index 97841c1..b450823 100644
--- a/src/aws_durable_execution_sdk_python/lambda_service.py
+++ b/src/aws_durable_execution_sdk_python/lambda_service.py
@@ -13,6 +13,7 @@
 from aws_durable_execution_sdk_python.exceptions import (
     CallableRuntimeError,
     CheckpointError,
+    GetExecutionStateError,
 )
 
 if TYPE_CHECKING:
@@ -1007,8 +1008,11 @@ def checkpoint(
 
             return CheckpointOutput.from_dict(result)
         except Exception as e:
-            logger.exception("Failed to checkpoint.")
-            raise CheckpointError.from_exception(e) from e
+            checkpoint_error = CheckpointError.from_exception(e)
+            logger.exception(
+                "Failed to checkpoint.", extra=checkpoint_error.build_logger_extras()
+            )
+            raise checkpoint_error from None
 
     def get_execution_state(
         self,
@@ -1017,13 +1021,20 @@ def get_execution_state(
         next_marker: str,
         max_items: int = 1000,
     ) -> StateOutput:
-        result: MutableMapping[str, Any] = self.client.get_durable_execution_state(
-            DurableExecutionArn=durable_execution_arn,
-            CheckpointToken=checkpoint_token,
-            Marker=next_marker,
-            MaxItems=max_items,
-        )
-        return StateOutput.from_dict(result)
+        try:
+            result: MutableMapping[str, Any] = self.client.get_durable_execution_state(
+                DurableExecutionArn=durable_execution_arn,
+                CheckpointToken=checkpoint_token,
+                Marker=next_marker,
+                MaxItems=max_items,
+            )
+            return StateOutput.from_dict(result)
+        except Exception as e:
+            error = GetExecutionStateError.from_exception(e)
+            logger.exception(
+                "Failed to get execution state.", extra=error.build_logger_extras()
+            )
+            raise error from None
 
     # endregion client
 
diff --git a/tests/lambda_service_test.py b/tests/lambda_service_test.py
index 35214b9..f06bf5f 100644
--- a/tests/lambda_service_test.py
+++ b/tests/lambda_service_test.py
@@ -8,6 +8,7 @@
 from aws_durable_execution_sdk_python.exceptions import (
     CallableRuntimeError,
     CheckpointError,
+    GetExecutionStateError,
 )
 from aws_durable_execution_sdk_python.identifier import OperationIdentifier
 from aws_durable_execution_sdk_python.lambda_service import (
@@ -1788,6 +1789,92 @@ def test_lambda_client_checkpoint_with_exception():
         lambda_client.checkpoint("arn123", "token123", [update], None)
 
 
+@patch("aws_durable_execution_sdk_python.lambda_service.logger")
+def test_lambda_client_checkpoint_logs_response_metadata(mock_logger):
+    """Test LambdaClient.checkpoint logs ResponseMetadata from boto3 exception."""
+    mock_client = Mock()
+    boto_error = Exception("API Error")
+    boto_error.response = {
+        "ResponseMetadata": {
+            "RequestId": "test-request-id-123",
+            "HTTPStatusCode": 500,
+            "RetryAttempts": 2,
+        }
+    }
+    mock_client.checkpoint_durable_execution.side_effect = boto_error
+
+    lambda_client = LambdaClient(mock_client)
+    update = OperationUpdate(
+        operation_id="op1",
+        operation_type=OperationType.STEP,
+        action=OperationAction.START,
+    )
+
+    with pytest.raises(CheckpointError):
+        lambda_client.checkpoint("arn123", "token123", [update], None)
+
+    mock_logger.exception.assert_called_once_with(
+        "Failed to checkpoint.",
+        extra={
+            "ResponseMetadata": {
+                "RequestId": "test-request-id-123",
+                "HTTPStatusCode": 500,
+                "RetryAttempts": 2,
+            },
+        },
+    )
+
+
+@patch("aws_durable_execution_sdk_python.lambda_service.logger")
+def test_lambda_client_get_execution_state_logs_response_metadata(mock_logger):
+    """Test LambdaClient.get_execution_state logs ResponseMetadata from boto3 exception."""
+    mock_client = Mock()
+    boto_error = Exception("API Error")
+    boto_error.response = {
+        "ResponseMetadata": {
+            "RequestId": "test-request-id-456",
+            "HTTPStatusCode": 503,
+            "RetryAttempts": 1,
+        }
+    }
+    mock_client.get_durable_execution_state.side_effect = boto_error
+
+    lambda_client = LambdaClient(mock_client)
+
+    with pytest.raises(GetExecutionStateError) as exc_info:
+        lambda_client.get_execution_state("arn123", "token123", "", 1000)
+
+    assert exc_info.value.error is None
+    assert exc_info.value.response_metadata == {
+        "RequestId": "test-request-id-456",
+        "HTTPStatusCode": 503,
+        "RetryAttempts": 1,
+    }
+
+    mock_logger.exception.assert_called_once_with(
+        "Failed to get execution state.",
+        extra={
+            "ResponseMetadata": {
+                "RequestId": "test-request-id-456",
+                "HTTPStatusCode": 503,
+                "RetryAttempts": 1,
+            },
+        },
+    )
+
+
+def test_checkpoint_error_from_exception_with_none_response():
+    """Test from_exception tolerates an exception whose response is None."""
+    boto_error = Exception("API Error")
+    boto_error.response = None
+
+    error = CheckpointError.from_exception(boto_error)
+
+    assert error.error is None
+    assert error.response_metadata is None
+    assert error.build_logger_extras() == {}
+
+
 def test_durable_service_client_protocol_checkpoint():
     """Test DurableServiceClient protocol checkpoint method signature."""
     mock_client = Mock(spec=DurableServiceClient)