Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 32 additions & 20 deletions src/aws_durable_execution_sdk_python/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,17 @@
from aws_durable_execution_sdk_python.lambda_service import OperationSubType
from aws_durable_execution_sdk_python.logger import Logger, LogInfo
from aws_durable_execution_sdk_python.operation.callback import (
create_callback_handler,
CallbackOperationExecutor,
wait_for_callback_handler,
)
from aws_durable_execution_sdk_python.operation.child import child_handler
from aws_durable_execution_sdk_python.operation.invoke import invoke_handler
from aws_durable_execution_sdk_python.operation.invoke import InvokeOperationExecutor
from aws_durable_execution_sdk_python.operation.map import map_handler
from aws_durable_execution_sdk_python.operation.parallel import parallel_handler
from aws_durable_execution_sdk_python.operation.step import step_handler
from aws_durable_execution_sdk_python.operation.wait import wait_handler
from aws_durable_execution_sdk_python.operation.step import StepOperationExecutor
from aws_durable_execution_sdk_python.operation.wait import WaitOperationExecutor
from aws_durable_execution_sdk_python.operation.wait_for_condition import (
wait_for_condition_handler,
WaitForConditionOperationExecutor,
)
from aws_durable_execution_sdk_python.serdes import (
PassThroughSerDes,
Expand Down Expand Up @@ -323,13 +323,14 @@ def create_callback(
if not config:
config = CallbackConfig()
operation_id: str = self._create_step_id()
callback_id: str = create_callback_handler(
executor: CallbackOperationExecutor = CallbackOperationExecutor(
state=self.state,
operation_identifier=OperationIdentifier(
operation_id=operation_id, parent_id=self._parent_id, name=name
),
config=config,
)
callback_id: str = executor.process()
result: Callback = Callback(
callback_id=callback_id,
operation_id=operation_id,
Expand Down Expand Up @@ -357,8 +358,10 @@ def invoke(
Returns:
The result of the invoked function
"""
if not config:
config = InvokeConfig[P, R]()
operation_id = self._create_step_id()
result: R = invoke_handler(
executor: InvokeOperationExecutor[R] = InvokeOperationExecutor(
function_name=function_name,
payload=payload,
state=self.state,
Expand All @@ -369,6 +372,7 @@ def invoke(
),
config=config,
)
result: R = executor.process()
self.state.track_replay(operation_id=operation_id)
return result

Expand Down Expand Up @@ -505,8 +509,10 @@ def step(
) -> T:
step_name = self._resolve_step_name(name, func)
logger.debug("Step name: %s", step_name)
if not config:
config = StepConfig()
operation_id = self._create_step_id()
result: T = step_handler(
executor: StepOperationExecutor[T] = StepOperationExecutor(
func=func,
config=config,
state=self.state,
Expand All @@ -517,6 +523,7 @@ def step(
),
context_logger=self.logger,
)
result: T = executor.process()
self.state.track_replay(operation_id=operation_id)
return result

Expand All @@ -532,15 +539,17 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
msg = "duration must be at least 1 second"
raise ValidationError(msg)
operation_id = self._create_step_id()
wait_handler(
seconds=seconds,
wait_seconds = duration.seconds
executor: WaitOperationExecutor = WaitOperationExecutor(
seconds=wait_seconds,
state=self.state,
operation_identifier=OperationIdentifier(
operation_id=operation_id,
parent_id=self._parent_id,
name=name,
),
)
executor.process()
self.state.track_replay(operation_id=operation_id)

def wait_for_callback(
Expand Down Expand Up @@ -584,17 +593,20 @@ def wait_for_condition(
raise ValidationError(msg)

operation_id = self._create_step_id()
result: T = wait_for_condition_handler(
check=check,
config=config,
state=self.state,
operation_identifier=OperationIdentifier(
operation_id=operation_id,
parent_id=self._parent_id,
name=name,
),
context_logger=self.logger,
executor: WaitForConditionOperationExecutor[T] = (
WaitForConditionOperationExecutor(
check=check,
config=config,
state=self.state,
operation_identifier=OperationIdentifier(
operation_id=operation_id,
parent_id=self._parent_id,
name=name,
),
context_logger=self.logger,
)
)
result: T = executor.process()
self.state.track_replay(operation_id=operation_id)
return result

Expand Down
187 changes: 187 additions & 0 deletions src/aws_durable_execution_sdk_python/operation/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
"""Base classes for operation executors with checkpoint response handling."""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, TypeVar

from aws_durable_execution_sdk_python.exceptions import InvalidStateError

if TYPE_CHECKING:
from aws_durable_execution_sdk_python.state import CheckpointedResult

T = TypeVar("T")


@dataclass(frozen=True)
class CheckResult(Generic[T]):
"""Result of checking operation checkpoint status.

Encapsulates the outcome of checking an operation's status and determines
the next action in the operation execution flow.

IMPORTANT: Do not construct directly. Use factory methods:
- create_is_ready_to_execute(checkpoint) - operation ready to execute
- create_started() - checkpoint created, check status again
- create_completed(result) - terminal result available

Attributes:
is_ready_to_execute: True if the operation is ready to execute its logic
has_checkpointed_result: True if a terminal result is already available
checkpointed_result: Checkpoint data for execute()
deserialized_result: Final result when operation completed
"""

is_ready_to_execute: bool
has_checkpointed_result: bool
checkpointed_result: CheckpointedResult | None = None
deserialized_result: T | None = None

@classmethod
def create_is_ready_to_execute(
cls, checkpoint: CheckpointedResult
) -> CheckResult[T]:
"""Create a CheckResult indicating the operation is ready to execute.

Args:
checkpoint: The checkpoint data to pass to execute()

Returns:
CheckResult with is_ready_to_execute=True
"""
return cls(
is_ready_to_execute=True,
has_checkpointed_result=False,
checkpointed_result=checkpoint,
)

@classmethod
def create_started(cls) -> CheckResult[T]:
"""Create a CheckResult signaling that a checkpoint was created.

Signals that process() should verify checkpoint status again to detect
if the operation completed already during checkpoint creation.

Returns:
CheckResult indicating process() should check status again
"""
return cls(is_ready_to_execute=False, has_checkpointed_result=False)

@classmethod
def create_completed(cls, result: T) -> CheckResult[T]:
"""Create a CheckResult with a terminal result already deserialized.

Args:
result: The final deserialized result

Returns:
CheckResult with has_checkpointed_result=True and deserialized_result set
"""
return cls(
is_ready_to_execute=False,
has_checkpointed_result=True,
deserialized_result=result,
)


class OperationExecutor(ABC, Generic[T]):
"""Base class for durable operations with checkpoint response handling.

Provides a framework for implementing operations that check status after
creating START checkpoints to handle synchronous completion, avoiding
unnecessary execution or suspension.

The common pattern:
1. Check operation status
2. Create START checkpoint if needed
3. Check status again (detects synchronous completion)
4. Execute operation logic when ready

Subclasses must implement:
- check_result_status(): Check status, create checkpoint if needed, return next action
- execute(): Execute the operation logic with checkpoint data
"""

@abstractmethod
def check_result_status(self) -> CheckResult[T]:
"""Check operation status and create START checkpoint if needed.

Called twice by process() when creating synchronous checkpoints: once before
and once after, to detect if the operation completed immediately.

This method should:
1. Get the current checkpoint result
2. Check for terminal statuses (SUCCEEDED, FAILED, etc.) and handle them
3. Check for pending statuses and suspend if needed
4. Create a START checkpoint if the operation hasn't started
5. Return a CheckResult indicating the next action

Returns:
CheckResult indicating whether to:
- Return a terminal result (has_checkpointed_result=True)
- Execute operation logic (is_ready_to_execute=True)
- Check status again (neither flag set - checkpoint was just created)

Raises:
Operation-specific exceptions for terminal failure states
SuspendExecution for pending states
"""
... # pragma: no cover

@abstractmethod
def execute(self, checkpointed_result: CheckpointedResult) -> T:
"""Execute operation logic with checkpoint data.

This method is called when the operation is ready to execute its core logic.
It receives the checkpoint data that was returned by check_result_status().

Args:
checkpointed_result: The checkpoint data containing operation state

Returns:
The result of executing the operation

Raises:
May raise operation-specific errors during execution
"""
... # pragma: no cover

def process(self) -> T:
"""Process operation with checkpoint response handling.

Orchestrates the double-check pattern:
1. Check status (handles replay and existing checkpoints)
2. If checkpoint was just created, check status again (detects synchronous completion)
3. Return terminal result if available
4. Execute operation logic if ready
5. Raise error for invalid states

Returns:
The final result of the operation

Raises:
InvalidStateError: If the check result is in an invalid state
May raise operation-specific errors from check_result_status() or execute()
"""
# Check 1: Entry (handles replay and existing checkpoints)
result = self.check_result_status()

# If checkpoint was created, verify checkpoint response for immediate status change
if not result.is_ready_to_execute and not result.has_checkpointed_result:
result = self.check_result_status()

# Return terminal result if available (can be None for operations that return None)
if result.has_checkpointed_result:
return result.deserialized_result # type: ignore[return-value]

# Execute operation logic
if result.is_ready_to_execute:
if result.checkpointed_result is None:
msg = "CheckResult is marked ready to execute but checkpointed result is not set."
raise InvalidStateError(msg)
return self.execute(result.checkpointed_result)

# Invalid state - neither terminal nor ready to execute
msg = "Invalid CheckResult state: neither terminal nor ready to execute"
raise InvalidStateError(msg)
Loading