Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion examples/src/wait_for_callback/wait_for_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from aws_durable_execution_sdk_python.config import WaitForCallbackConfig
from aws_durable_execution_sdk_python.context import DurableContext
from aws_durable_execution_sdk_python.execution import durable_execution
from aws_durable_execution_sdk_python.config import Duration


def external_system_call(_callback_id: str) -> None:
Expand All @@ -13,7 +14,9 @@ def external_system_call(_callback_id: str) -> None:

@durable_execution
def handler(_event: Any, context: DurableContext) -> str:
config = WaitForCallbackConfig(timeout_seconds=120, heartbeat_timeout_seconds=60)
config = WaitForCallbackConfig(
timeout=Duration.from_seconds(120), heartbeat_timeout=Duration.from_seconds(60)
)

result = context.wait_for_callback(
external_system_call, name="external_call", config=config
Expand Down
30 changes: 30 additions & 0 deletions examples/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,36 @@ def run(
"""Execute the durable function and return results."""
return self._runner.run(input=input, timeout=timeout)

def run_async(
self,
input: str | None = None, # noqa: A002
timeout: int = 60,
) -> str:
return self._runner.run_async(input=input, timeout=timeout)

def send_callback_success(self, callback_id: str) -> None:
self._runner.send_callback_success(callback_id=callback_id)

def send_callback_failure(self, callback_id: str) -> None:
self._runner.send_callback_failure(callback_id=callback_id)

def send_callback_heartbeat(self, callback_id: str) -> None:
self._runner.send_callback_heartbeat(callback_id=callback_id)

def wait_for_result(
self, execution_arn: str, timeout: int = 60
) -> DurableFunctionTestResult:
return self._runner.wait_for_result(
execution_arn=execution_arn, timeout=timeout
)

def wait_for_callback(
self, execution_arn: str, name: str | None = None, timeout: int = 60
) -> str:
return self._runner.wait_for_callback(
execution_arn=execution_arn, name=name, timeout=timeout
)

@property
def mode(self) -> str:
"""Get the runner mode (local or cloud)."""
Expand Down
224 changes: 202 additions & 22 deletions src/aws_durable_execution_sdk_python_testing/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import aws_durable_execution_sdk_python
import boto3 # type: ignore
from botocore.exceptions import ClientError # type: ignore
from aws_durable_execution_sdk_python.execution import (
InvocationStatus,
durable_execution,
Expand Down Expand Up @@ -75,6 +76,8 @@

from aws_durable_execution_sdk_python_testing.execution import Execution
from aws_durable_execution_sdk_python_testing.web.server import WebServiceConfig
from aws_durable_execution_sdk_python_testing.model import Event


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -792,9 +795,9 @@ def run(
msg = f"Failed to invoke Lambda function {self.function_name}: {e}"
raise DurableFunctionsTestError(msg) from e

# Check HTTP status code (200 for RequestResponse, 202 for Event, 204 for DryRun)
# Check HTTP status code, 200 for RequestResponse
status_code = response.get("StatusCode")
if status_code not in (200, 202, 204):
if status_code != 200:
error_payload = response["Payload"].read().decode("utf-8")
msg = f"Lambda invocation failed with status {status_code}: {error_payload}"
raise DurableFunctionsTestError(msg)
Expand All @@ -819,17 +822,126 @@ def run(
)
raise DurableFunctionsTestError(msg)

# Poll for completion
execution_response = self._wait_for_completion(execution_arn, timeout)
return self.wait_for_result(execution_arn=execution_arn, timeout=timeout)

# Get execution history
history_response = self._get_execution_history(execution_arn)
def run_async(
self,
input: str | None = None, # noqa: A002
timeout: int = 60,
) -> str:
"""Execute function on AWS Lambda asynchronously"""
logger.info(
"Invoking Lambda function: %s (timeout: %ds)", self.function_name, timeout
)
payload = json.dumps(input)
try:
response = self.lambda_client.invoke(
FunctionName=self.function_name,
InvocationType="Event",
Payload=payload,
)
except Exception as e:
msg = f"Failed to invoke Lambda function {self.function_name}: {e}"
raise DurableFunctionsTestError(msg) from e

# Build test result from execution history
return DurableFunctionTestResult.from_execution_history(
execution_response, history_response
# Check HTTP status code, 202 for Event
status_code = response.get("StatusCode")
if status_code != 202:
error_payload = response["Payload"].read().decode("utf-8")
msg = f"Lambda invocation failed with status {status_code}: {error_payload}"
raise DurableFunctionsTestError(msg)

return response.get("DurableExecutionArn")

def _get_callback_id_from_events(
self, events: list[Event], name: str | None = None
) -> str | None:
"""
Get callback ID from execution history for callbacks that haven't completed.

Args:
execution_arn: The ARN of the execution to query.
name: Optional callback name to search for. If not provided, returns the latest callback.

Returns:
The callback ID string for a non-completed callback, or None if not found.

Raises:
DurableFunctionsTestError: If the named callback has already succeeded/failed/timed out.
"""
callback_started_events = [
event for event in events if event.event_type == "CallbackStarted"
]

if not callback_started_events:
return None

completed_callback_ids = {
event.event_id
for event in events
if event.event_type
in ["CallbackSucceeded", "CallbackFailed", "CallbackTimedOut"]
}

if name is not None:
for event in callback_started_events:
if event.name == name:
callback_id = event.event_id
if callback_id in completed_callback_ids:
raise DurableFunctionsTestError(
f"Callback {name} has already completed (succeeded/failed/timed out)"
)
return (
event.callback_started_details.callback_id
if event.callback_started_details
else None
)
return None

# If name is not provided, find the latest non-completed callback event
active_callbacks = [
event
for event in callback_started_events
if event.event_id not in completed_callback_ids
]

if not active_callbacks:
return None

latest_event = active_callbacks[-1]
return (
latest_event.callback_started_details.callback_id
if latest_event.callback_started_details
else None
)

def send_callback_success(self, callback_id: str) -> None:
try:
self.lambda_client.send_durable_execution_callback_success(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to update this to support the optional Result field. Same for error on callback failures.

CallbackId=callback_id
)
except Exception as e:
msg = f"Failed to send callback success for {self.function_name}, callback_id {callback_id}: {e}"
raise DurableFunctionsTestError(msg) from e

def send_callback_failure(self, callback_id: str) -> None:
try:
self.lambda_client.send_durable_execution_callback_failure(
CallbackId=callback_id
)
except Exception as e:
msg = f"Failed to send callback failure for {self.function_name}, callback_id {callback_id}: {e}"
raise DurableFunctionsTestError(msg) from e

def send_callback_heartbeat(self, callback_id: str) -> None:
try:
self.lambda_client.send_durable_execution_callback_heartbeat(
CallbackId=callback_id
)
except Exception as e:
msg = f"Failed to send callback heartbeat for {self.function_name}, callback_id {callback_id}: {e}"
raise DurableFunctionsTestError(msg) from e

def _wait_for_completion(
self, execution_arn: str, timeout: int
) -> GetDurableExecutionResponse:
Expand Down Expand Up @@ -886,7 +998,81 @@ def _wait_for_completion(
)
raise TimeoutError(msg)

def _get_execution_history(
def wait_for_result(
self, execution_arn: str, timeout: int = 60
) -> DurableFunctionTestResult:
# Poll for completion
execution_response = self._wait_for_completion(execution_arn, timeout)

try:
history_response = self._fetch_execution_history(execution_arn)
except Exception as e:
msg = f"Failed to fetch execution history: {e}"
raise DurableFunctionsTestError(msg) from e

# Build test result from execution history
return DurableFunctionTestResult.from_execution_history(
execution_response, history_response
)

def wait_for_callback(
self, execution_arn: str, name: str | None = None, timeout: int = 60
) -> str:
"""
Wait for and retrieve a callback ID from a Step Functions execution.

Polls the execution history at regular intervals until a callback ID is found
or the timeout is reached.

Args:
execution_arn: Execution Arn
name: Specific callback name, default to None
timeout: Maximum time in seconds to wait for callback. Defaults to 60.

Returns:
str: The callback ID/token retrieved from the execution history

Raises:
TimeoutError: If callback is not found within the specified timeout period
DurableFunctionsTestError: If there's an error fetching execution history
(excluding retryable errors)
"""
start_time = time.time()

while time.time() - start_time < timeout:
try:
history_response = self._fetch_execution_history(execution_arn)
callback_id = self._get_callback_id_from_events(
events=history_response.events, name=name
)
if callback_id:
return callback_id
except ClientError as e:
error_code = e.response["Error"]["Code"]
# retryable error, the execution may not start yet in async invoke situation
if error_code in ["ResourceNotFoundException"]:
pass
else:
msg = f"Failed to fetch execution history: {e}"
raise DurableFunctionsTestError(msg) from e
except DurableFunctionsTestError as e:
raise e
except Exception as e:
msg = f"Failed to fetch execution history: {e}"
raise DurableFunctionsTestError(msg) from e

# Wait before next poll
time.sleep(self.poll_interval)

# Timeout reached
elapsed = time.time() - start_time
msg = (
f"Callback did not available within {timeout}s "
f"(elapsed: {elapsed:.1f}s."
)
raise TimeoutError(msg)

def _fetch_execution_history(
self, execution_arn: str
) -> GetDurableExecutionHistoryResponse:
"""Retrieve execution history from Lambda service.
Expand All @@ -898,19 +1084,13 @@ def _get_execution_history(
GetDurableExecutionHistoryResponse with typed Event objects

Raises:
DurableFunctionsTestError: If history retrieval fails
ClientError: If lambda client encounter error
"""
try:
history_dict = self.lambda_client.get_durable_execution_history(
DurableExecutionArn=execution_arn,
IncludeExecutionData=True,
)
history_response = GetDurableExecutionHistoryResponse.from_dict(
history_dict
)
except Exception as e:
msg = f"Failed to get execution history: {e}"
raise DurableFunctionsTestError(msg) from e
history_dict = self.lambda_client.get_durable_execution_history(
DurableExecutionArn=execution_arn,
IncludeExecutionData=True,
)
history_response = GetDurableExecutionHistoryResponse.from_dict(history_dict)

logger.info("Retrieved %d events from history", len(history_response.events))

Expand Down
Loading