diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4c2a6a9..9ddb7f2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -147,6 +147,30 @@ hatch run examples:list hatch run examples:deploy "Hello World" ``` +### Generate Event Files for Testing +```bash +# Generate expected events JSON file for a function +# This runs the function locally and captures execution events + +# Basic usage - hello_world example +hatch run examples:generate-events \ + --function-module hello_world \ + --function-name handler \ + --input '"test input"' \ + --output examples/events/hello_world_events.json + +# Available options: +# --function-module: Python module path (required) +# --function-name: Function name within module (required) +# --input: JSON string input for the function (optional) +# --output: Output path for events JSON file (required) +# --timeout: Execution timeout in seconds (default: 60) +# --verbose: Enable detailed logging + +# Use generated events in your tests with the event assertion helper: +# assert_events('events/hello_world_events.json', result.events) +``` + ### Other CLI Commands ```bash # Invoke deployed function diff --git a/examples/events/hello_world_events.json b/examples/events/hello_world_events.json new file mode 100644 index 0000000..6869592 --- /dev/null +++ b/examples/events/hello_world_events.json @@ -0,0 +1,113 @@ +{ + "events": [ + { + "event_type": "ExecutionStarted", + "event_timestamp": "2025-12-11T00:32:13.887857+00:00", + "event_id": 1, + "operation_id": "inv-12345678-1234-1234-1234-123456789012", + "name": "execution-name", + "execution_started_details": { + "input": { + "truncated": true + }, + "execution_timeout": 60 + } + }, + { + "event_type": "StepStarted", + "event_timestamp": "2025-12-11T00:32:13.994326+00:00", + "sub_type": "Step", + "event_id": 2, + "operation_id": "1ced8f5be2db23a6513eba4d819c73806424748a7bc6fa0d792cc1c7d1775a97", + "name": "step_1", + "step_started_details": {} + }, + { + "event_type": "StepSucceeded", + 
"event_timestamp": "2025-12-11T00:32:13.994354+00:00", + "sub_type": "Step", + "event_id": 3, + "operation_id": "1ced8f5be2db23a6513eba4d819c73806424748a7bc6fa0d792cc1c7d1775a97", + "name": "step_1", + "step_succeeded_details": { + "result": { + "truncated": true + }, + "retry_details": { + "current_attempt": 1, + "next_attempt_delay_seconds": 0 + } + } + }, + { + "event_type": "WaitStarted", + "event_timestamp": "2025-12-11T00:32:14.099840+00:00", + "sub_type": "Wait", + "event_id": 4, + "operation_id": "c5faca15ac2f93578b39ef4b6bbb871bdedce4ddd584fd31f0bb66fade3947e6", + "wait_started_details": { + "duration": 10, + "scheduled_end_timestamp": "2025-12-11T00:32:24.099828+00:00" + } + }, + { + "event_type": "InvocationCompleted", + "event_timestamp": "2025-12-11T00:32:14.205118+00:00", + "event_id": 5 + }, + { + "event_type": "WaitSucceeded", + "event_timestamp": "2025-12-11T00:32:24.206724+00:00", + "sub_type": "Wait", + "event_id": 6, + "operation_id": "c5faca15ac2f93578b39ef4b6bbb871bdedce4ddd584fd31f0bb66fade3947e6", + "wait_succeeded_details": { + "duration": 10 + } + }, + { + "event_type": "StepStarted", + "event_timestamp": "2025-12-11T00:32:24.310890+00:00", + "sub_type": "Step", + "event_id": 7, + "operation_id": "6f760b9e9eac89f07ab0223b0f4acb04d1e355d893a1b86a83f4d4b405adee99", + "name": "step_2", + "step_started_details": {} + }, + { + "event_type": "StepSucceeded", + "event_timestamp": "2025-12-11T00:32:24.310917+00:00", + "sub_type": "Step", + "event_id": 8, + "operation_id": "6f760b9e9eac89f07ab0223b0f4acb04d1e355d893a1b86a83f4d4b405adee99", + "name": "step_2", + "step_succeeded_details": { + "result": { + "truncated": true + }, + "retry_details": { + "current_attempt": 1, + "next_attempt_delay_seconds": 0 + } + } + }, + { + "event_type": "InvocationCompleted", + "event_timestamp": "2025-12-11T00:32:24.413013+00:00", + "event_id": 9 + }, + { + "event_type": "ExecutionSucceeded", + "event_timestamp": "2025-12-11T00:32:24.413238+00:00", + "event_id": 
10, + "operation_id": "inv-12345678-1234-1234-1234-123456789012", + "name": "execution-name", + "execution_succeeded_details": { + "result": { + "payload": "{\"statusCode\": 200, \"body\": \"Hello from Durable Lambda! (status: 200)\"}", + "truncated": true + } + } + } + ] +} \ No newline at end of file diff --git a/examples/scripts/cli_event_generator.py b/examples/scripts/cli_event_generator.py new file mode 100644 index 0000000..d7ab76a --- /dev/null +++ b/examples/scripts/cli_event_generator.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +"""CLI tool for generating event assertion files from durable function executions. + +This tool runs durable functions locally and captures their execution events +to generate JSON files that can be used for event-based test assertions. + +Usage: + python examples/scripts/cli_event_generator.py \ + --function-module hello_world \ + --function-name handler \ + --input '{"test": "data"}' \ + --output examples/events/hello_world_events.json +""" + +import argparse +import importlib +import json +import logging +import sys +from pathlib import Path +from typing import Any + +# Add src directories to Python path (this file lives in examples/scripts/) +examples_dir = Path(__file__).parent.parent +src_dir = examples_dir / "src" +main_src_dir = examples_dir.parent / "src" + +for path in [str(src_dir), str(main_src_dir)]: + if path not in sys.path: + sys.path.insert(0, path) + +from aws_durable_execution_sdk_python_testing.runner import DurableFunctionTestRunner + + +logger = logging.getLogger(__name__) + + +def setup_logging(verbose: bool = False) -> None: + """Configure logging for the CLI tool.""" + level = logging.DEBUG if verbose else logging.INFO + logging.basicConfig( + level=level, + format="%(levelname)s: %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + +def import_function(module_name: str, function_name: str) -> Any: + """Import a function from a module dynamically. 
+ + Args: + module_name: Python module path (e.g., 'examples.src.hello_world') + function_name: Function name within the module (e.g., 'handler') + + Returns: + The imported function + + Raises: + ImportError: If module or function cannot be imported + """ + try: + module = importlib.import_module(module_name) + return getattr(module, function_name) + except ImportError as e: + raise ImportError(f"Failed to import module '{module_name}': {e}") from e + except AttributeError as e: + raise ImportError( + f"Function '{function_name}' not found in module '{module_name}': {e}" + ) from e + + +def serialize_event(event: Any) -> dict: + """Serialize an Event object to a JSON-serializable dictionary. + + Args: + event: Event object to serialize + + Returns: + Dictionary representation of the event + """ + # Convert the event to a dictionary, handling datetime objects + event_dict = {} + + for field_name, field_value in event.__dict__.items(): + if field_value is None: + continue + + if hasattr(field_value, "isoformat"): # datetime objects + event_dict[field_name] = field_value.isoformat() + elif hasattr(field_value, "__dict__"): # nested objects + event_dict[field_name] = serialize_nested_object(field_value) + else: + event_dict[field_name] = field_value + + return event_dict + + +def serialize_nested_object(obj: Any) -> dict: + """Serialize nested objects recursively.""" + if obj is None: + return None + + result = {} + for field_name, field_value in obj.__dict__.items(): + if field_value is None: + continue + + if hasattr(field_value, "isoformat"): # datetime objects + result[field_name] = field_value.isoformat() + elif hasattr(field_value, "__dict__"): # nested objects + result[field_name] = serialize_nested_object(field_value) + else: + result[field_name] = field_value + + return result + + +def generate_events_file( + function_module: str, + function_name: str, + input_data: str | None, + output_path: Path, + timeout: int = 60, +) -> None: + """Generate events file by 
running the durable function locally. + + Args: + function_module: Python module containing the function + function_name: Name of the durable function + input_data: JSON string input for the function + output_path: Path where to save the events JSON file + timeout: Execution timeout in seconds + """ + logger.info(f"Importing function {function_name} from {function_module}") + handler = import_function(function_module, function_name) + + logger.info("Running durable function locally...") + with DurableFunctionTestRunner(handler=handler) as runner: + result = runner.run(input=input_data, timeout=timeout) + + logger.info(f"Execution completed with status: {result.status}") + logger.info(f"Captured {len(result.events)} events") + + # Serialize events to JSON-compatible format + events_data = {"events": [serialize_event(event) for event in result.events]} + + # Ensure output directory exists + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Write events to JSON file + with open(output_path, "w", encoding="utf-8") as f: + json.dump(events_data, f, indent=2, ensure_ascii=False) + + logger.info(f"Events saved to: {output_path}") + + +def main() -> None: + """Main CLI entry point.""" + parser = argparse.ArgumentParser( + description="Generate event assertion files from durable function executions", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Generate events for hello_world example + python examples/scripts/cli_event_generator.py \\ + --function-module hello_world \\ + --function-name handler \\ + --input '"test input"' \\ + --output examples/events/hello_world_events.json + + # Generate events for a function with complex input + python examples/scripts/cli_event_generator.py \\ + --function-module step.step_with_retry \\ + --function-name handler \\ + --input '{"retries": 3, "data": "test"}' \\ + --output examples/events/step_with_retry_events.json + """, + ) + + parser.add_argument( + "--function-module", + required=True, + help="Python 
module containing the durable function (e.g., 'hello_world' or 'step.step_with_retry')", + ) + + parser.add_argument( + "--function-name", + required=True, + help="Name of the durable function within the module (e.g., 'handler')", + ) + + parser.add_argument( + "--input", help="JSON string input for the function (default: None)" + ) + + parser.add_argument( + "--output", + type=Path, + required=True, + help="Output path for the events JSON file", + ) + + parser.add_argument( + "--timeout", + type=int, + default=60, + help="Execution timeout in seconds (default: 60)", + ) + + parser.add_argument( + "--verbose", "-v", action="store_true", help="Enable verbose logging" + ) + + args = parser.parse_args() + + setup_logging(args.verbose) + + try: + generate_events_file( + function_module=args.function_module, + function_name=args.function_name, + input_data=args.input, + output_path=args.output, + timeout=args.timeout, + ) + logger.info("Event generation completed successfully!") + + except Exception as e: + logger.error(f"Event generation failed: {e}") + if args.verbose: + logger.exception("Full traceback:") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/examples/test/event_helper.py b/examples/test/event_helper.py new file mode 100644 index 0000000..b446a08 --- /dev/null +++ b/examples/test/event_helper.py @@ -0,0 +1,283 @@ +"""Advanced event assertion helper for examples. + +This module provides sophisticated event assertion capabilities with three categories: +1. STRICT_EQUAL: Key and value must match exactly +2. KEY_EQUAL: Key must exist but value can vary +3. IGNORE: Field is completely ignored + +The helper handles nested objects and provides detailed assertion control. 
+""" + +import json +import logging +from enum import Enum +from pathlib import Path +from typing import Any, Dict, Set + +logger = logging.getLogger(__name__) + + +class FieldCategory(Enum): + """Field assertion categories.""" + + STRICT_EQUAL = "strict_equal" # Key and value must match exactly + KEY_EQUAL = "key_equal" # Key must exist but value can vary + IGNORE = "ignore" # Field is completely ignored + + +class EventAssertionError(Exception): + """Exception raised when event assertions fail.""" + + pass + + +# Field categorization for Event assertions using dot notation for nested fields +FIELD_CATEGORIES = { + # STRICT_EQUAL: Key and value must match exactly + FieldCategory.STRICT_EQUAL: { + "event_type", # Must match exactly + "sub_type", # Must match exactly (Step, Wait, etc.) + "name", # Must match exactly + # Nested field examples - use dot notation + "execution_succeeded_details.result", # Execution result should match + "wait_started_details.duration", # Wait duration should match exactly + "wait_succeeded_details.duration", # Wait duration should match exactly + }, + # KEY_EQUAL: Key must exist but value can vary + FieldCategory.KEY_EQUAL: { + "event_timestamp", # Must exist but timestamp will vary + "event_id", # Must exist but ID will vary + "operation_id", # Must exist but UUID will vary + "parent_id", # Must exist but UUID will vary + }, + # IGNORE: Completely ignore these fields + FieldCategory.IGNORE: set(), +} + +# Event type specific overrides using same format as FIELD_CATEGORIES +# Fields are optional - only specify what you want to override +EVENT_TYPE_OVERRIDE = { + "ExecutionStarted": { + FieldCategory.IGNORE: {"name"}, # Execution names can vary based on test setup + }, + "ExecutionSucceeded": { + FieldCategory.IGNORE: {"name"}, # Execution names can vary based on test setup + }, +} + + +def get_nested_value(obj: Any, path: str) -> Any: + """Get a nested value from an object using dot notation. 
+ + Args: + obj: Object to get value from + path: Dot-separated path (e.g., 'step_succeeded_details.result') + + Returns: + The nested value or None if path doesn't exist + """ + if obj is None: + return None + + current = obj + for part in path.split("."): + if hasattr(current, "__dict__"): + current = getattr(current, part, None) + elif isinstance(current, dict): + current = current.get(part) + else: + return None + + if current is None: + return None + + return current + + +def get_field_category(field_path: str, event_type: str = "") -> FieldCategory: + """Get the category for a field path, considering event type overrides. + + Override only affects specific keys mentioned in the override dict. + All other keys follow the general FIELD_CATEGORIES rules. + + Args: + field_path: Field path (can be nested with dots) + event_type: Event type for override checking + + Returns: + FieldCategory enum value + """ + # Check event type specific overrides first - only for keys explicitly mentioned + if event_type and event_type in EVENT_TYPE_OVERRIDE: + override_categories = EVENT_TYPE_OVERRIDE[event_type] + + # Check if this specific field_path is mentioned in any override category + if field_path in override_categories.get(FieldCategory.STRICT_EQUAL, set()): + return FieldCategory.STRICT_EQUAL + elif field_path in override_categories.get(FieldCategory.KEY_EQUAL, set()): + return FieldCategory.KEY_EQUAL + elif field_path in override_categories.get(FieldCategory.IGNORE, set()): + return FieldCategory.IGNORE + # If field_path is not in any override category, fall through to general rules + + # Apply general FIELD_CATEGORIES rules for all other fields + if field_path in FIELD_CATEGORIES[FieldCategory.STRICT_EQUAL]: + return FieldCategory.STRICT_EQUAL + elif field_path in FIELD_CATEGORIES[FieldCategory.KEY_EQUAL]: + return FieldCategory.KEY_EQUAL + elif field_path in FIELD_CATEGORIES[FieldCategory.IGNORE]: + return FieldCategory.IGNORE + else: + # Default behavior for 
unspecified fields + return FieldCategory.IGNORE + + +def assert_field_by_category( + field_path: str, + expected_value: Any, + actual_value: Any, + event_type: str = "", + context: str = "", +) -> None: + """Assert a field value based on its category. + + Args: + field_path: Path to the field (can be nested with dots) + expected_value: Expected value from JSON + actual_value: Actual value from event object (can be dict, object, or primitive) + event_type: Event type for override checking + context: Context string for error messages + """ + category = get_field_category(field_path, event_type) + + if category is FieldCategory.STRICT_EQUAL: + # Convert actual_value to comparable format if it's an object + if expected_value != actual_value: + raise EventAssertionError( + f"{context}Field '{field_path}' strict equality failed: " + f"expected {expected_value}, got {actual_value}" + ) + elif category is FieldCategory.KEY_EQUAL: + # Just check that both have the field (not None) + if expected_value is not None and actual_value is None: + raise EventAssertionError( + f"{context}Field '{field_path}' missing in actual event" + ) + if expected_value is None and actual_value is not None: + raise EventAssertionError( + f"{context}Field '{field_path}' unexpected in actual event" + ) + # If category is FieldCategory.IGNORE, do nothing + + +def assert_nested_fields( + expected_obj: dict, + actual_obj: Any, + parent_path: str = "", + event_type: str = "", + context: str = "", +) -> None: + """Recursively assert nested fields using dot notation paths. 
+ + Args: + expected_obj: Expected object/dict from JSON + actual_obj: Actual object from event + parent_path: Current path prefix for nested fields + event_type: Event type for override checking + context: Context string for error messages + """ + if not isinstance(expected_obj, dict): + return + + for key, expected_value in expected_obj.items(): + current_path = f"{parent_path}.{key}" if parent_path else key + + # Get actual value using dot notation + actual_value = get_nested_value(actual_obj, current_path) + + if isinstance(expected_value, dict) and expected_value: + # This is a nested object, recurse into it + assert_nested_fields( + expected_value, actual_obj, current_path, event_type, context + ) + else: + # This is a leaf value, assert it based on category + assert_field_by_category( + current_path, expected_value, actual_value, event_type, context + ) + + +def assert_events(path_to_json: str, events: list[Any]) -> None: + """Advanced event assertion with categorized field checking. + + This function provides sophisticated event assertion with three categories: + - STRICT_EQUAL: Key and value must match exactly + - KEY_EQUAL: Key must exist but value can vary + - IGNORE: Field is completely ignored + + Args: + path_to_json: Path to JSON file containing expected events + events: List of actual Event objects from execution + + Raises: + EventAssertionError: If events don't match expectations + FileNotFoundError: If JSON file doesn't exist + + Example: + assert_events('events/hello_world_events.json', result.events) + """ + events_file_path = Path(path_to_json) + + logger.info(f"Asserting events from: {path_to_json}") + + # Load expected data + if not events_file_path.exists(): + raise FileNotFoundError(f"Events file not found: {events_file_path}") + + with open(events_file_path, "r", encoding="utf-8") as f: + expected_data = json.load(f) + + expected_events = expected_data.get("events", []) + + # 1. 
Assert total event count + if len(events) != len(expected_events): + raise EventAssertionError( + f"Event count mismatch: expected {len(expected_events)}, got {len(events)}" + ) + + # 2. Assert each event with categorized field checking using dot notation + for i, (actual_event, expected_event) in enumerate(zip(events, expected_events)): + context = f"Event {i}: " + + # Get event type for override checking + event_type = expected_event.get("event_type", "") + + # Use recursive nested field assertion with dot notation + for field_name, expected_value in expected_event.items(): + # Get actual value using dot notation (handles both simple and nested fields) + actual_value = get_nested_value(actual_event, field_name) + + # Check if this field path should be treated as a whole object assertion + field_category = get_field_category(field_name, event_type) + + if ( + isinstance(expected_value, dict) + and expected_value + and field_category is not FieldCategory.STRICT_EQUAL + ): + # This is a nested object and NOT marked for strict_equal, recurse into it + assert_nested_fields( + expected_value, actual_event, field_name, event_type, context + ) + else: + # This is either: + # 1. A leaf value (string, int, etc.) + # 2. A nested object marked for strict_equal (assert the whole dict) + # 3. An empty dict + assert_field_by_category( + field_name, expected_value, actual_value, event_type, context + ) + + logger.info( + f"✅ All {len(events)} events match expected patterns from {path_to_json}" + ) diff --git a/examples/test/test_hello_world.py b/examples/test/test_hello_world.py index f0a5446..5706bd2 100644 --- a/examples/test/test_hello_world.py +++ b/examples/test/test_hello_world.py @@ -5,6 +5,7 @@ from src import hello_world from test.conftest import deserialize_operation_payload +from test.event_helper import assert_events @pytest.mark.example @@ -22,3 +23,4 @@ def test_hello_world(durable_runner): "statusCode": 200, "body": "Hello from Durable Lambda! 
(status: 200)", } + assert_events("examples/events/hello_world_events.json", result.events) diff --git a/pyproject.toml b/pyproject.toml index f3652d0..b9e7b48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,6 +86,7 @@ history = "python examples/cli.py history {args}" policy = "python examples/cli.py policy {args}" list = "python examples/cli.py list" clean = "rm -rf examples/build examples/.aws-sam examples/*.zip" +generate-events = "python examples/scripts/cli_event_generator.py {args}" [tool.hatch.envs.types] extra-dependencies = ["mypy>=1.0.0", "pytest"] diff --git a/src/aws_durable_execution_sdk_python_testing/runner.py b/src/aws_durable_execution_sdk_python_testing/runner.py index 34e80f4..f412cae 100644 --- a/src/aws_durable_execution_sdk_python_testing/runner.py +++ b/src/aws_durable_execution_sdk_python_testing/runner.py @@ -464,6 +464,7 @@ class DurableFunctionTestResult: status: InvocationStatus operations: list[Operation] + events: list[Event] = field(default_factory=list) result: OperationPayload | None = None error: ErrorObject | None = None @@ -485,6 +486,7 @@ def create(cls, execution: Execution) -> DurableFunctionTestResult: return cls( status=execution.result.status, operations=operations, + events=[], result=execution.result.result, error=execution.result.error, ) @@ -527,6 +529,7 @@ def from_execution_history( return cls( status=status, operations=operations, + events=history_response.events, result=execution_response.result, error=execution_response.error, ) @@ -668,11 +671,17 @@ def wait_for_result( if not completed: msg_timeout: str = "Execution did not complete within timeout" raise TimeoutError(msg_timeout) - execution: Execution = self._store.load(execution_arn) - return DurableFunctionTestResult.create(execution=execution) + history_response = self._executor.get_execution_history( execution_arn, include_execution_data=True ) + execution_response = 
self._executor.get_execution_details(execution_arn) + + return DurableFunctionTestResult.from_execution_history( + execution_response, history_response + ) def wait_for_callback( self, execution_arn: str, name: str | None = None, timeout: int = 60 diff --git a/tests/e2e/basic_success_path_test.py b/tests/e2e/basic_success_path_test.py index 3e93bcf..3bea07b 100644 --- a/tests/e2e/basic_success_path_test.py +++ b/tests/e2e/basic_success_path_test.py @@ -23,7 +23,6 @@ from aws_durable_execution_sdk_python.config import Duration -# brazil-test-exec pytest test/runner_int_test.py def test_basic_durable_function() -> None: @durable_step def one(step_context: StepContext, a: int, b: int) -> str: diff --git a/tests/runner_test.py b/tests/runner_test.py index 3b81269..5fcb80c 100644 --- a/tests/runner_test.py +++ b/tests/runner_test.py @@ -28,6 +28,7 @@ StartDurableExecutionInput, StartDurableExecutionOutput, GetDurableExecutionHistoryResponse, + GetDurableExecutionResponse, ) from aws_durable_execution_sdk_python_testing.runner import ( OPERATION_FACTORIES, @@ -740,14 +741,25 @@ def test_durable_function_test_runner_run(mock_store_class, mock_executor_class) mock_executor.start_execution.return_value = output mock_executor.wait_until_complete.return_value = True - # Mock execution for result creation - mock_execution = Mock(spec=Execution) - mock_execution.operations = [] - mock_execution.result = Mock() - mock_execution.result.status = InvocationStatus.SUCCEEDED - mock_execution.result.result = json.dumps("test-result") - mock_execution.result.error = None - mock_store.load.return_value = mock_execution + # Mock the new methods used by wait_for_result + mock_history_response = GetDurableExecutionHistoryResponse(events=[]) + mock_execution_response = GetDurableExecutionResponse( + durable_execution_arn="test-arn", + durable_execution_name="execution-name", + function_arn="arn:aws:lambda:us-west-2:123456789012:function:test-function", + status="SUCCEEDED", + 
start_timestamp=datetime.datetime( + 2023, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc + ), + end_timestamp=datetime.datetime( + 2023, 1, 1, 0, 1, 0, tzinfo=datetime.timezone.utc + ), + result=json.dumps("test-result"), + error=None, + ) + + mock_executor.get_execution_history.return_value = mock_history_response + mock_executor.get_execution_details.return_value = mock_execution_response runner = DurableFunctionTestRunner(handler) result = runner.run("test-input") @@ -764,8 +776,11 @@ def test_durable_function_test_runner_run(mock_store_class, mock_executor_class) # Verify wait_until_complete was called mock_executor.wait_until_complete.assert_called_once_with("test-arn", 900) - # Verify store.load was called - mock_store.load.assert_called_once_with("test-arn") + # Verify the methods are called + mock_executor.get_execution_history.assert_called_once_with( + "test-arn", include_execution_data=True + ) + mock_executor.get_execution_details.assert_called_once_with("test-arn") # Verify result assert isinstance(result, DurableFunctionTestResult) @@ -791,14 +806,25 @@ def test_durable_function_test_runner_run_with_custom_params( mock_executor.start_execution.return_value = output mock_executor.wait_until_complete.return_value = True - # Mock execution for result creation - mock_execution = Mock(spec=Execution) - mock_execution.operations = [] - mock_execution.result = Mock() - mock_execution.result.status = InvocationStatus.SUCCEEDED - mock_execution.result.result = json.dumps("test-result") - mock_execution.result.error = None - mock_store.load.return_value = mock_execution + # Mock the new methods used by wait_for_result + mock_history_response = GetDurableExecutionHistoryResponse(events=[]) + mock_execution_response = GetDurableExecutionResponse( + durable_execution_arn="test-arn", + durable_execution_name="custom-execution", + function_arn="arn:aws:lambda:us-west-2:987654321098:function:custom-function", + status="SUCCEEDED", + start_timestamp=datetime.datetime( + 
2023, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc + ), + end_timestamp=datetime.datetime( + 2023, 1, 1, 0, 1, 0, tzinfo=datetime.timezone.utc + ), + result=json.dumps("test-result"), + error=None, + ) + + mock_executor.get_execution_history.return_value = mock_history_response + mock_executor.get_execution_details.return_value = mock_execution_response runner = DurableFunctionTestRunner(handler) result = runner.run( @@ -820,6 +846,12 @@ def test_durable_function_test_runner_run_with_custom_params( # Verify wait_until_complete was called with custom timeout mock_executor.wait_until_complete.assert_called_once_with("test-arn", 1800) + # Verify the methods are called + mock_executor.get_execution_history.assert_called_once_with( + "test-arn", include_execution_data=True + ) + mock_executor.get_execution_details.assert_called_once_with("test-arn") + assert result.status is InvocationStatus.SUCCEEDED