Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/aws_durable_execution_sdk_python_testing/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def start_execution(
trace_fields=input.trace_fields,
tenant_id=input.tenant_id,
input=input.input,
lambda_endpoint=input.lambda_endpoint,
)

execution = Execution.new(input=input)
Expand Down
121 changes: 86 additions & 35 deletions src/aws_durable_execution_sdk_python_testing/invoker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import json
from threading import Lock
from typing import TYPE_CHECKING, Any, Protocol

import boto3 # type: ignore
Expand Down Expand Up @@ -108,21 +109,68 @@ def update_endpoint(self, endpoint_url: str, region_name: str) -> None:
class LambdaInvoker(Invoker):
def __init__(self, lambda_client: Any) -> None:
self.lambda_client = lambda_client
# Maps execution_arn -> endpoint for that execution
# Maps endpoint -> client to reuse clients across executions
self._execution_endpoints: dict[str, str] = {}
self._endpoint_clients: dict[str, Any] = {}
self._current_endpoint: str = "" # Track current endpoint for new executions
self._lock = Lock()

@staticmethod
def create(endpoint_url: str, region_name: str) -> LambdaInvoker:
"""Create with the boto lambda client."""
return LambdaInvoker(
invoker = LambdaInvoker(
boto3.client(
"lambdainternal", endpoint_url=endpoint_url, region_name=region_name
)
)
invoker._current_endpoint = endpoint_url
invoker._endpoint_clients[endpoint_url] = invoker.lambda_client
return invoker

def update_endpoint(self, endpoint_url: str, region_name: str) -> None:
"""Update the Lambda client endpoint."""
self.lambda_client = boto3.client(
"lambdainternal", endpoint_url=endpoint_url, region_name=region_name
)
# Cache client by endpoint to reuse across executions
with self._lock:
if endpoint_url not in self._endpoint_clients:
self._endpoint_clients[endpoint_url] = boto3.client(
"lambdainternal", endpoint_url=endpoint_url, region_name=region_name
)
self.lambda_client = self._endpoint_clients[endpoint_url]
self._current_endpoint = endpoint_url

def _get_client_for_execution(
self, durable_execution_arn: str, lambda_endpoint: str | None = None
) -> Any:
"""Get the appropriate client for this execution."""
# Use provided endpoint or fall back to cached endpoint for this execution
if lambda_endpoint:
# Client should already exist from update_endpoint() call
if lambda_endpoint not in self._endpoint_clients:
from aws_durable_execution_sdk_python_testing.exceptions import (
ServiceException,
)

raise ServiceException(
f"Lambda endpoint {lambda_endpoint} not configured. update_endpoint() must be called first."
)
return self._endpoint_clients[lambda_endpoint]

# Fallback to cached endpoint
if durable_execution_arn not in self._execution_endpoints:
with self._lock:
if durable_execution_arn not in self._execution_endpoints:
self._execution_endpoints[durable_execution_arn] = (
self._current_endpoint
)

endpoint = self._execution_endpoints[durable_execution_arn]

# If no endpoint configured, fall back to default client
if not endpoint:
return self.lambda_client

return self._endpoint_clients[endpoint]

def create_invocation_input(
self, execution: Execution
Expand Down Expand Up @@ -165,9 +213,12 @@ def invoke(
msg = "Function name is required"
raise InvalidParameterValueException(msg)

# Get the client for this execution
client = self._get_client_for_execution(input.durable_execution_arn)

try:
# Invoke AWS Lambda function using standard invoke method
response = self.lambda_client.invoke(
response = client.invoke(
FunctionName=function_name,
InvocationType="RequestResponse", # Synchronous invocation
Payload=json.dumps(input.to_dict(), default=str),
Expand All @@ -192,49 +243,49 @@ def invoke(
# Convert to DurableExecutionInvocationOutput
return DurableExecutionInvocationOutput.from_dict(response_dict)

except self.lambda_client.exceptions.ResourceNotFoundException as e:
except client.exceptions.ResourceNotFoundException as e:
msg = f"Function not found: {function_name}"
raise ResourceNotFoundException(msg) from e
except self.lambda_client.exceptions.InvalidParameterValueException as e:
except client.exceptions.InvalidParameterValueException as e:
msg = f"Invalid parameter: {e}"
raise InvalidParameterValueException(msg) from e
except (
self.lambda_client.exceptions.TooManyRequestsException,
self.lambda_client.exceptions.ServiceException,
self.lambda_client.exceptions.ResourceConflictException,
self.lambda_client.exceptions.InvalidRequestContentException,
self.lambda_client.exceptions.RequestTooLargeException,
self.lambda_client.exceptions.UnsupportedMediaTypeException,
self.lambda_client.exceptions.InvalidRuntimeException,
self.lambda_client.exceptions.InvalidZipFileException,
self.lambda_client.exceptions.ResourceNotReadyException,
self.lambda_client.exceptions.SnapStartTimeoutException,
self.lambda_client.exceptions.SnapStartNotReadyException,
self.lambda_client.exceptions.SnapStartException,
self.lambda_client.exceptions.RecursiveInvocationException,
client.exceptions.TooManyRequestsException,
client.exceptions.ServiceException,
client.exceptions.ResourceConflictException,
client.exceptions.InvalidRequestContentException,
client.exceptions.RequestTooLargeException,
client.exceptions.UnsupportedMediaTypeException,
client.exceptions.InvalidRuntimeException,
client.exceptions.InvalidZipFileException,
client.exceptions.ResourceNotReadyException,
client.exceptions.SnapStartTimeoutException,
client.exceptions.SnapStartNotReadyException,
client.exceptions.SnapStartException,
client.exceptions.RecursiveInvocationException,
) as e:
msg = f"Lambda invocation failed: {e}"
raise DurableFunctionsTestError(msg) from e
except (
self.lambda_client.exceptions.InvalidSecurityGroupIDException,
self.lambda_client.exceptions.EC2ThrottledException,
self.lambda_client.exceptions.EFSMountConnectivityException,
self.lambda_client.exceptions.SubnetIPAddressLimitReachedException,
self.lambda_client.exceptions.EC2UnexpectedException,
self.lambda_client.exceptions.InvalidSubnetIDException,
self.lambda_client.exceptions.EC2AccessDeniedException,
self.lambda_client.exceptions.EFSIOException,
self.lambda_client.exceptions.ENILimitReachedException,
self.lambda_client.exceptions.EFSMountTimeoutException,
self.lambda_client.exceptions.EFSMountFailureException,
client.exceptions.InvalidSecurityGroupIDException,
client.exceptions.EC2ThrottledException,
client.exceptions.EFSMountConnectivityException,
client.exceptions.SubnetIPAddressLimitReachedException,
client.exceptions.EC2UnexpectedException,
client.exceptions.InvalidSubnetIDException,
client.exceptions.EC2AccessDeniedException,
client.exceptions.EFSIOException,
client.exceptions.ENILimitReachedException,
client.exceptions.EFSMountTimeoutException,
client.exceptions.EFSMountFailureException,
) as e:
msg = f"Lambda infrastructure error: {e}"
raise DurableFunctionsTestError(msg) from e
except (
self.lambda_client.exceptions.KMSAccessDeniedException,
self.lambda_client.exceptions.KMSDisabledException,
self.lambda_client.exceptions.KMSNotFoundException,
self.lambda_client.exceptions.KMSInvalidStateException,
client.exceptions.KMSAccessDeniedException,
client.exceptions.KMSDisabledException,
client.exceptions.KMSNotFoundException,
client.exceptions.KMSInvalidStateException,
) as e:
msg = f"Lambda KMS error: {e}"
raise DurableFunctionsTestError(msg) from e
Expand Down
4 changes: 4 additions & 0 deletions src/aws_durable_execution_sdk_python_testing/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class StartDurableExecutionInput:
trace_fields: dict | None = None
tenant_id: str | None = None
input: str | None = None
lambda_endpoint: str | None = None # Endpoint for this specific execution

@classmethod
def from_dict(cls, data: dict) -> StartDurableExecutionInput:
Expand Down Expand Up @@ -146,6 +147,7 @@ def from_dict(cls, data: dict) -> StartDurableExecutionInput:
trace_fields=data.get("TraceFields"),
tenant_id=data.get("TenantId"),
input=data.get("Input"),
lambda_endpoint=data.get("LambdaEndpoint", None),
)

def to_dict(self) -> dict[str, Any]:
Expand All @@ -165,6 +167,8 @@ def to_dict(self) -> dict[str, Any]:
result["TenantId"] = self.tenant_id
if self.input is not None:
result["Input"] = self.input
if self.lambda_endpoint is not None:
result["LambdaEndpoint"] = self.lambda_endpoint
return result

def get_normalized_input(self):
Expand Down
Loading