diff --git a/packages/aws-durable-execution-sdk-js/src/context/durable-context/durable-context.ts b/packages/aws-durable-execution-sdk-js/src/context/durable-context/durable-context.ts index ec517cf0..4c8f80e7 100644 --- a/packages/aws-durable-execution-sdk-js/src/context/durable-context/durable-context.ts +++ b/packages/aws-durable-execution-sdk-js/src/context/durable-context/durable-context.ts @@ -237,13 +237,13 @@ export class DurableContextImpl implements DurableContext { funcIdOrInput?: string | I, inputOrConfig?: I | InvokeConfig, maybeConfig?: InvokeConfig, - ): Promise { + ): DurablePromise { validateContextUsage( this._stepPrefix, "invoke", this.executionContext.terminationManager, ); - return this.withModeManagement(() => { + return this.withDurableModeManagement(() => { const invokeHandler = createInvokeHandler( this.executionContext, this.checkpoint, @@ -268,13 +268,13 @@ export class DurableContextImpl implements DurableContext { nameOrFn: string | undefined | ChildFunc, fnOrOptions?: ChildFunc | ChildConfig, maybeOptions?: ChildConfig, - ): Promise { + ): DurablePromise { validateContextUsage( this._stepPrefix, "runInChildContext", this.executionContext.terminationManager, ); - return this.withModeManagement(() => { + return this.withDurableModeManagement(() => { const blockHandler = createRunInChildContextHandler( this.executionContext, this.checkpoint, @@ -284,23 +284,20 @@ export class DurableContextImpl implements DurableContext { createDurableContext, this._parentId, ); - const promise = blockHandler(nameOrFn, fnOrOptions, maybeOptions); - // Prevent unhandled promise rejections - promise?.catch(() => {}); - return promise; + return blockHandler(nameOrFn, fnOrOptions, maybeOptions); }); } wait( nameOrDuration: string | Duration, maybeDuration?: Duration, - ): Promise { + ): DurablePromise { validateContextUsage( this._stepPrefix, "wait", this.executionContext.terminationManager, ); - return this.withModeManagement(() => { + return this.withDurableModeManagement(() => { const waitHandler = createWaitHandler( this.executionContext, this.checkpoint, @@ -367,27 +364,22 @@ export class DurableContextImpl implements DurableContext { nameOrSubmitter?: string | undefined | WaitForCallbackSubmitterFunc, submitterOrConfig?: WaitForCallbackSubmitterFunc | WaitForCallbackConfig, maybeConfig?: WaitForCallbackConfig, - ): Promise { + ): DurablePromise { validateContextUsage( this._stepPrefix, "waitForCallback", this.executionContext.terminationManager, ); - return this.withModeManagement(() => { + return this.withDurableModeManagement(() => { const waitForCallbackHandler = createWaitForCallbackHandler( this.executionContext, this.runInChildContext.bind(this), ); - const promise = waitForCallbackHandler( + return waitForCallbackHandler( nameOrSubmitter!, submitterOrConfig, maybeConfig, ); - // Prevent unhandled promise rejections - promise?.catch(() => {}); - return promise?.finally(() => { - this.checkAndUpdateReplayMode(); - }); }); } @@ -397,13 +389,13 @@ export class DurableContextImpl implements DurableContext { | WaitForConditionCheckFunc | WaitForConditionConfig, maybeConfig?: WaitForConditionConfig, - ): Promise { + ): DurablePromise { validateContextUsage( this._stepPrefix, "waitForCondition", this.executionContext.terminationManager, ); - return this.withModeManagement(() => { + return this.withDurableModeManagement(() => { const waitForConditionHandler = createWaitForConditionHandler( this.executionContext, this.checkpoint, @@ -416,20 +408,17 @@ export class DurableContextImpl implements DurableContext { this._parentId, ); - const promise = - typeof nameOrCheckFunc === "string" || nameOrCheckFunc === undefined - ? waitForConditionHandler( - nameOrCheckFunc, - checkFuncOrConfig as WaitForConditionCheckFunc, - maybeConfig!, - ) - : waitForConditionHandler( - nameOrCheckFunc, - checkFuncOrConfig as WaitForConditionConfig, - ); - // Prevent unhandled promise rejections - promise?.catch(() => {}); - return promise; + return typeof nameOrCheckFunc === "string" || + nameOrCheckFunc === undefined + ? waitForConditionHandler( + nameOrCheckFunc, + checkFuncOrConfig as WaitForConditionCheckFunc, + maybeConfig!, + ) + : waitForConditionHandler( + nameOrCheckFunc, + checkFuncOrConfig as WaitForConditionConfig, + ); }); } diff --git a/packages/aws-durable-execution-sdk-js/src/handlers/run-in-child-context-handler/run-in-child-context-handler-two-phase.test.ts b/packages/aws-durable-execution-sdk-js/src/handlers/run-in-child-context-handler/run-in-child-context-handler-two-phase.test.ts index 9462db4e..d04cf03b 100644 --- a/packages/aws-durable-execution-sdk-js/src/handlers/run-in-child-context-handler/run-in-child-context-handler-two-phase.test.ts +++ b/packages/aws-durable-execution-sdk-js/src/handlers/run-in-child-context-handler/run-in-child-context-handler-two-phase.test.ts @@ -1,5 +1,5 @@ import { createRunInChildContextHandler } from "./run-in-child-context-handler"; -import { ExecutionContext, DurableContext } from "../../types"; +import { ExecutionContext } from "../../types"; import { DurablePromise } from "../../types/durable-promise"; import { Context } from "aws-lambda"; @@ -58,15 +58,18 @@ describe("Run In Child Context Handler Two-Phase Execution", () => { // Phase 1: Create the promise - this executes the logic immediately const childPromise = handler(childFn); - // Wait for phase 1 to complete - await new Promise((resolve) => setTimeout(resolve, 50)); - // Should return a DurablePromise expect(childPromise).toBeInstanceOf(DurablePromise); - // Phase 1 should have executed the child function + // Wait briefly for phase 1 to start executing + await new Promise((resolve) => setTimeout(resolve, 10)); + + // Phase 1 should have executed the child function (before we await the promise) expect(childFn).toHaveBeenCalled(); expect(mockCheckpoint).toHaveBeenCalled(); + + // Now await the promise to verify it completes + await childPromise; }); it("should return cached result in phase 2 when awaited", async () => { @@ -84,10 +87,12 @@ describe("Run In Child Context Handler Two-Phase Execution", () => { // Phase 1: Create the promise const childPromise = handler(childFn); - // Wait for phase 1 to complete - await new Promise((resolve) => setTimeout(resolve, 50)); + // Wait briefly for phase 1 to execute + await new Promise((resolve) => setTimeout(resolve, 10)); + // Child function should have been called before we await the promise const phase1Calls = childFn.mock.calls.length; + expect(phase1Calls).toBeGreaterThan(0); // Phase 2: Await the promise - should return cached result const result = await childPromise; diff --git a/packages/aws-durable-execution-sdk-js/src/handlers/step-handler/step-handler-two-phase.test.ts b/packages/aws-durable-execution-sdk-js/src/handlers/step-handler/step-handler-two-phase.test.ts index 8a1bbf77..253df6db 100644 --- a/packages/aws-durable-execution-sdk-js/src/handlers/step-handler/step-handler-two-phase.test.ts +++ b/packages/aws-durable-execution-sdk-js/src/handlers/step-handler/step-handler-two-phase.test.ts @@ -68,15 +68,18 @@ describe("Step Handler Two-Phase Execution", () => { // Phase 1: Create the promise - this executes the logic immediately const stepPromise = stepHandler(stepFn); - // Wait for phase 1 to complete - await new Promise((resolve) => setTimeout(resolve, 50)); - // Should return a DurablePromise expect(stepPromise).toBeInstanceOf(DurablePromise); - // Phase 1 should have executed the step function + // Wait briefly for phase 1 to start executing + await new Promise((resolve) => setTimeout(resolve, 10)); + + // Phase 1 should have executed the step function (before we await the promise) expect(stepFn).toHaveBeenCalled(); expect(mockCheckpoint).toHaveBeenCalled(); + + // Now await the promise to verify it completes + await stepPromise; }); it("should return cached result in phase 2 when awaited", async () => { @@ -97,10 +100,12 @@ describe("Step Handler Two-Phase Execution", () => { // Phase 1: Create the promise const stepPromise = stepHandler(stepFn); - // Wait for phase 1 to complete - await new Promise((resolve) => setTimeout(resolve, 50)); + // Wait briefly for phase 1 to execute + await new Promise((resolve) => setTimeout(resolve, 10)); + // Step function should have been called before we await the promise const phase1Calls = stepFn.mock.calls.length; + expect(phase1Calls).toBeGreaterThan(0); // Phase 2: Await the promise - should return cached result const result = await stepPromise; diff --git a/packages/aws-durable-execution-sdk-js/src/handlers/wait-for-condition-handler/wait-for-condition-handler-two-phase.test.ts b/packages/aws-durable-execution-sdk-js/src/handlers/wait-for-condition-handler/wait-for-condition-handler-two-phase.test.ts new file mode 100644 index 00000000..5c9f1b0b --- /dev/null +++ b/packages/aws-durable-execution-sdk-js/src/handlers/wait-for-condition-handler/wait-for-condition-handler-two-phase.test.ts @@ -0,0 +1,167 @@ +import { createWaitForConditionHandler } from "./wait-for-condition-handler"; +import { ExecutionContext, WaitForConditionCheckFunc } from "../../types"; +import { EventEmitter } from "events"; +import { DurablePromise } from "../../types/durable-promise"; + +describe("WaitForCondition Handler Two-Phase Execution", () => { + let mockContext: ExecutionContext; + let mockCheckpoint: any; + let createStepId: () => string; + let createContextLogger: (stepId: string, attempt?: number) => any; + let addRunningOperation: jest.Mock; + let removeRunningOperation: jest.Mock; + let hasRunningOperations: () => boolean; + let getOperationsEmitter: () => EventEmitter; + let stepIdCounter = 0; + + beforeEach(() => { + stepIdCounter = 0; + mockContext = { + getStepData: jest.fn().mockReturnValue(null), + durableExecutionArn: "test-arn", + terminationManager: { + shouldTerminate: jest.fn().mockReturnValue(false), + terminate: jest.fn(), + }, + } as any; + + mockCheckpoint = jest.fn().mockResolvedValue(undefined); + mockCheckpoint.force = jest.fn().mockResolvedValue(undefined); + mockCheckpoint.setTerminating = jest.fn(); + mockCheckpoint.hasPendingAncestorCompletion = jest + .fn() + .mockReturnValue(false); + + createStepId = (): string => `step-${++stepIdCounter}`; + createContextLogger = jest.fn().mockReturnValue({ + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + }); + addRunningOperation = jest.fn(); + removeRunningOperation = jest.fn(); + hasRunningOperations = jest.fn().mockReturnValue(false) as () => boolean; + getOperationsEmitter = (): EventEmitter => new EventEmitter(); + }); + + it("should execute check function in phase 1 immediately", async () => { + const waitForConditionHandler = createWaitForConditionHandler( + mockContext, + mockCheckpoint, + createStepId, + createContextLogger, + addRunningOperation, + removeRunningOperation, + hasRunningOperations, + getOperationsEmitter, + undefined, + ); + + const checkFn: WaitForConditionCheckFunc = jest + .fn() + .mockResolvedValue(10); + + // Phase 1: Create the promise - this executes the logic immediately + const promise = waitForConditionHandler(checkFn, { + initialState: 0, + waitStrategy: (_state) => ({ shouldContinue: false }), + }); + + // Should return a DurablePromise + expect(promise).toBeInstanceOf(DurablePromise); + + // Wait briefly for phase 1 to start executing + await new Promise((resolve) => setTimeout(resolve, 10)); + + // Phase 1 should have executed the check function (before we await the promise) + expect(checkFn).toHaveBeenCalled(); + expect(mockCheckpoint).toHaveBeenCalled(); + + // Now await the promise to verify it completes + await promise; + }); + + it("should return cached result in phase 2 when awaited", async () => { + const waitForConditionHandler = createWaitForConditionHandler( + mockContext, + mockCheckpoint, + createStepId, + createContextLogger, + addRunningOperation, + removeRunningOperation, + hasRunningOperations, + getOperationsEmitter, + undefined, + ); + + const checkFn: WaitForConditionCheckFunc = jest + .fn() + .mockResolvedValue("completed"); + + // Phase 1: Create the promise + const promise = waitForConditionHandler(checkFn, { + initialState: "initial", + waitStrategy: (_state) => ({ shouldContinue: false }), + }); + + // Wait briefly for phase 1 to execute + await new Promise((resolve) => setTimeout(resolve, 10)); + + // Check function should have been called before we await the promise + expect(checkFn).toHaveBeenCalledTimes(1); + + // Phase 2: Await the promise to get the result + const result = await promise; + + expect(result).toBe("completed"); + expect(checkFn).toHaveBeenCalledTimes(1); + }); + + it("should execute check function before await", async () => { + const waitForConditionHandler = createWaitForConditionHandler( + mockContext, + mockCheckpoint, + createStepId, + createContextLogger, + addRunningOperation, + removeRunningOperation, + hasRunningOperations, + getOperationsEmitter, + undefined, + ); + + let executionOrder: string[] = []; + const checkFn: WaitForConditionCheckFunc = jest.fn(async () => { + executionOrder.push("check-executed"); + return 42; + }); + + // Phase 1: Create the promise + executionOrder.push("promise-created"); + const promise = waitForConditionHandler(checkFn, { + initialState: 0, + waitStrategy: (_state) => ({ shouldContinue: false }), + }); + executionOrder.push("after-handler-call"); + + // Wait briefly for phase 1 to execute + await new Promise((resolve) => setTimeout(resolve, 10)); + + // Check should have executed before we await + expect(checkFn).toHaveBeenCalled(); + + executionOrder.push("before-await"); + const result = await promise; + executionOrder.push("after-await"); + + // Verify execution order: check should execute before await + expect(executionOrder).toEqual([ + "promise-created", + "check-executed", + "after-handler-call", + "before-await", + "after-await", + ]); + expect(result).toBe(42); + }); +}); diff --git a/packages/aws-durable-execution-sdk-js/src/handlers/wait-for-condition-handler/wait-for-condition-handler.test.ts b/packages/aws-durable-execution-sdk-js/src/handlers/wait-for-condition-handler/wait-for-condition-handler.test.ts index 178f762f..066bd136 100644 --- a/packages/aws-durable-execution-sdk-js/src/handlers/wait-for-condition-handler/wait-for-condition-handler.test.ts +++ b/packages/aws-durable-execution-sdk-js/src/handlers/wait-for-condition-handler/wait-for-condition-handler.test.ts @@ -9,6 +9,7 @@ import { WaitForConditionConfig, OperationSubType, Logger, + DurablePromise, } from "../../types"; import { TerminationManager } from "../../termination-manager/termination-manager"; import { TerminationReason } from "../../termination-manager/types"; @@ -509,8 +510,8 @@ describe("WaitForCondition Handler", () => { message: "Retry scheduled for step-1", }); - // Verify that the promise is indeed never-resolving by checking its constructor - expect(promise).toBeInstanceOf(Promise); + // Verify that the promise is indeed a DurablePromise + expect(promise).toBeInstanceOf(DurablePromise); }); it("should wait for timer when status is PENDING", async () => { diff --git a/packages/aws-durable-execution-sdk-js/src/handlers/wait-for-condition-handler/wait-for-condition-handler.ts b/packages/aws-durable-execution-sdk-js/src/handlers/wait-for-condition-handler/wait-for-condition-handler.ts index d6fc6c9d..f145d5c0 100644 --- a/packages/aws-durable-execution-sdk-js/src/handlers/wait-for-condition-handler/wait-for-condition-handler.ts +++ b/packages/aws-durable-execution-sdk-js/src/handlers/wait-for-condition-handler/wait-for-condition-handler.ts @@ -6,6 +6,7 @@ import { OperationSubType, WaitForConditionContext, Logger, + DurablePromise, } from "../../types"; import { durationToSeconds } from "../../utils/duration/duration"; import { terminate } from "../../utils/termination-helper/termination-helper"; @@ -41,6 +42,7 @@ const waitForContinuation = async ( hasRunningOperations: () => boolean, checkpoint: ReturnType, operationsEmitter: EventEmitter, + onAwaitedChange?: (callback: () => void) => void, ): Promise => { const stepData = context.getStepData(stepId); @@ -65,6 +67,7 @@ const waitForContinuation = async ( hasRunningOperations, operationsEmitter, checkpoint, + onAwaitedChange, }); // Return to let the main loop re-evaluate step status @@ -81,112 +84,151 @@ export const createWaitForConditionHandler = ( getOperationsEmitter: () => EventEmitter, parentId: string | undefined, ) => { - return async ( + return ( nameOrCheck: string | undefined | WaitForConditionCheckFunc, checkOrConfig?: WaitForConditionCheckFunc | WaitForConditionConfig, maybeConfig?: WaitForConditionConfig, - ): Promise => { - let name: string | undefined; - let check: WaitForConditionCheckFunc; - let config: WaitForConditionConfig; - - // Parse overloaded parameters - if (typeof nameOrCheck === "string" || nameOrCheck === undefined) { - name = nameOrCheck; - check = checkOrConfig as WaitForConditionCheckFunc; - config = maybeConfig as WaitForConditionConfig; - } else { - check = nameOrCheck; - config = checkOrConfig as WaitForConditionConfig; - } - - if (!config || !config.waitStrategy || config.initialState === undefined) { - throw new Error( - "waitForCondition requires config with waitStrategy and initialState", - ); - } - - const stepId = createStepId(); - - log("🔄", "Running waitForCondition:", { - stepId, - name, - config, - }); + ): DurablePromise => { + // Two-phase execution: Phase 1 starts immediately, Phase 2 returns result when awaited + let phase1Result: T | undefined; + let phase1Error: unknown; + let isAwaited = false; + let waitingCallback: (() => void) | undefined; + + const setWaitingCallback = (cb: () => void): void => { + waitingCallback = cb; + }; - // Main waitForCondition logic - can be re-executed if step status changes - while (true) { - try { - const stepData = context.getStepData(stepId); + // Phase 1: Start execution immediately and capture result/error + const phase1Promise = (async (): Promise => { + let name: string | undefined; + let check: WaitForConditionCheckFunc; + let config: WaitForConditionConfig; + + // Parse overloaded parameters - validation errors thrown here are async + if (typeof nameOrCheck === "string" || nameOrCheck === undefined) { + name = nameOrCheck; + check = checkOrConfig as WaitForConditionCheckFunc; + config = maybeConfig as WaitForConditionConfig; + } else { + check = nameOrCheck; + config = checkOrConfig as WaitForConditionConfig; + } - // Check if already completed - if (stepData?.Status === OperationStatus.SUCCEEDED) { - return await handleCompletedWaitForCondition( - context, - stepId, - name, - config.serdes, - ); - } + if ( + !config || + !config.waitStrategy || + config.initialState === undefined + ) { + throw new Error( + "waitForCondition requires config with waitStrategy and initialState", + ); + } - if (stepData?.Status === OperationStatus.FAILED) { - // Return an async rejected promise to ensure it's handled asynchronously - return (async (): Promise => { - // Reconstruct the original error from stored ErrorObject - if (stepData.StepDetails?.Error) { - throw DurableOperationError.fromErrorObject( - stepData.StepDetails.Error, - ); - } else { - // Fallback for legacy data without Error field - const errorMessage = stepData?.StepDetails?.Result; - throw new WaitForConditionError( - errorMessage || "waitForCondition failed", - ); - } - })(); - } + const stepId = createStepId(); - // If PENDING, wait for timer to complete - if (stepData?.Status === OperationStatus.PENDING) { - await waitForContinuation( + log("🔄", "Running waitForCondition:", { + stepId, + name, + config, + }); + // Main waitForCondition logic - can be re-executed if step status changes + while (true) { + try { + const stepData = context.getStepData(stepId); + + // Check if already completed + if (stepData?.Status === OperationStatus.SUCCEEDED) { + return await handleCompletedWaitForCondition( + context, + stepId, + name, + config.serdes, + ); + } + + if (stepData?.Status === OperationStatus.FAILED) { + // Return an async rejected promise to ensure it's handled asynchronously + return (async (): Promise => { + // Reconstruct the original error from stored ErrorObject + if (stepData.StepDetails?.Error) { + throw DurableOperationError.fromErrorObject( + stepData.StepDetails.Error, + ); + } else { + // Fallback for legacy data without Error field + const errorMessage = stepData?.StepDetails?.Result; + throw new WaitForConditionError( + errorMessage || "waitForCondition failed", + ); + } + })(); + } + + // If PENDING, wait for timer to complete + if (stepData?.Status === OperationStatus.PENDING) { + await waitForContinuation( + context, + stepId, + name, + hasRunningOperations, + checkpoint, + getOperationsEmitter(), + isAwaited ? undefined : setWaitingCallback, + ); + continue; // Re-evaluate step status after waiting + } + + // Execute check function for READY, STARTED, or first time (undefined) + const result = await executeWaitForCondition( context, + checkpoint, stepId, name, + check, + config, + createContextLogger, + addRunningOperation, + removeRunningOperation, hasRunningOperations, - checkpoint, - getOperationsEmitter(), + getOperationsEmitter, + parentId, + isAwaited ? undefined : setWaitingCallback, ); - continue; // Re-evaluate step status after waiting - } - // Execute check function for READY, STARTED, or first time (undefined) - const result = await executeWaitForCondition( - context, - checkpoint, - stepId, - name, - check, - config, - createContextLogger, - addRunningOperation, - removeRunningOperation, - hasRunningOperations, - getOperationsEmitter, - parentId, - ); + // If executeWaitForCondition signals to continue the main loop, do so + if (result === CONTINUE_MAIN_LOOP) { + continue; + } - // If executeWaitForCondition signals to continue the main loop, do so - if (result === CONTINUE_MAIN_LOOP) { - continue; + return result; + } catch (error) { + // For any error from executeWaitForCondition, re-throw it + throw error; } + } + })() + .then((result) => { + phase1Result = result; + }) + .catch((error) => { + phase1Error = error; + }); - return result; - } catch (error) { - // For any error from executeWaitForCondition, re-throw it - throw error; + // Phase 2: Return DurablePromise that returns Phase 1 result when awaited + return new DurablePromise(async () => { + // When promise is awaited, mark as awaited and invoke waiting callback + isAwaited = true; + if (waitingCallback) { + waitingCallback(); } - } + + await phase1Promise; + if (phase1Error !== undefined) { + throw phase1Error; + } + return phase1Result!; + }); }; }; @@ -227,6 +269,7 @@ export const executeWaitForCondition = async ( hasRunningOperations: () => boolean, getOperationsEmitter: () => EventEmitter, parentId: string | undefined, + onAwaitedChange?: ((callback: () => void) => void) | undefined, ): Promise => { const serdes = config.serdes || defaultSerdes; @@ -383,6 +426,7 @@ export const executeWaitForCondition = async ( hasRunningOperations, checkpoint, getOperationsEmitter(), + onAwaitedChange, ); return CONTINUE_MAIN_LOOP; }