diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java index 82c5b5fb..a25464b4 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/ParallelWithWaitExample.java @@ -62,6 +62,8 @@ public Output handleRequest(Input input, DurableContext context) { var deliveries = futures.stream().map(DurableFuture::get).toList(); logger.info("All {} notifications delivered", deliveries.size()); + // Test replay + context.wait("wait for finalization", Duration.ofSeconds(5)); return new Output(deliveries); } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java index e06b1d12..a0c08b65 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java @@ -12,6 +12,7 @@ import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.context.DurableContextImpl; import software.amazon.lambda.durable.exception.ConcurrencyExecutionException; +import software.amazon.lambda.durable.execution.ExecutionManager; import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; import software.amazon.lambda.durable.model.OperationIdentifier; import software.amazon.lambda.durable.model.OperationSubType; @@ -42,6 +43,8 @@ */ public class ParallelOperation extends ConcurrencyOperation { + private boolean skipCheckpoint = false; + public ParallelOperation( OperationIdentifier operationIdentifier, TypeToken resultTypeToken, @@ -79,6 +82,10 @@ protected ChildContextOperation createItem( @Override protected void handleSuccess() { + if 
(skipCheckpoint) { + // Existing operation was already terminal when replayed (see replay()); do not checkpoint SUCCEED again + return; + } sendOperationUpdate(OperationUpdate.builder() .action(OperationAction.SUCCEED) .subType(getSubType().getValue()) @@ -99,8 +106,9 @@ protected void start() { @Override protected void replay(Operation existing) { - // Always replay child branches for parallel - start(); + // Do not call start() here: child branches handle their own replay via ChildContextOperation.replay(). + // Set skipCheckpoint when the existing operation is already terminal so handleSuccess() skips re-checkpointing it. + skipCheckpoint = ExecutionManager.isTerminalStatus(existing.status()); } @Override diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java index 62fccce7..a2b07312 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java @@ -199,6 +199,100 @@ void contextHierarchy_branchesUseParallelContextAsParent() throws Exception { assertNotNull(childOp); } + // ===== Replay ===== + + @Test + void replay_doesNotSendStartCheckpoint() throws Exception { + // Simulate the parallel operation already existing in the service (STARTED status) + when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)) + .thenReturn(Operation.builder() + .id(OPERATION_ID) + .name("test-parallel") + .type(OperationType.CONTEXT) + .subType(OperationSubType.PARALLEL.getValue()) + .status(OperationStatus.STARTED) + .build()); + // Both branches already succeeded + when(executionManager.getOperationAndUpdateReplayState("child-1")) + .thenReturn(Operation.builder() + .id("child-1") + .name("branch-1") + .type(OperationType.CONTEXT) + .subType(OperationSubType.PARALLEL_BRANCH.getValue()) + .status(OperationStatus.SUCCEEDED) + .contextDetails( + ContextDetails.builder().result("\"r1\"").build()) + .build()); + 
when(executionManager.getOperationAndUpdateReplayState("child-2")) + .thenReturn(Operation.builder() + .id("child-2") + .name("branch-2") + .type(OperationType.CONTEXT) + .subType(OperationSubType.PARALLEL_BRANCH.getValue()) + .status(OperationStatus.SUCCEEDED) + .contextDetails( + ContextDetails.builder().result("\"r2\"").build()) + .build()); + + var op = createOperation(-1, -1, 0); + setOperationIdGenerator(op, mockIdGenerator); + op.execute(); + op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); + op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES); + + runJoin(op); + + verify(executionManager, never()) + .sendOperationUpdate(argThat(update -> update.action() == OperationAction.START)); + verify(executionManager, times(1)) + .sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); + } + + @Test + void replay_doesNotSendSucceedCheckpointWhenParallelAlreadySucceeded() throws Exception { + when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)) + .thenReturn(Operation.builder() + .id(OPERATION_ID) + .name("test-parallel") + .type(OperationType.CONTEXT) + .subType(OperationSubType.PARALLEL.getValue()) + .status(OperationStatus.SUCCEEDED) + .build()); + when(executionManager.getOperationAndUpdateReplayState("child-1")) + .thenReturn(Operation.builder() + .id("child-1") + .name("branch-1") + .type(OperationType.CONTEXT) + .subType(OperationSubType.PARALLEL_BRANCH.getValue()) + .status(OperationStatus.SUCCEEDED) + .contextDetails( + ContextDetails.builder().result("\"r1\"").build()) + .build()); + when(executionManager.getOperationAndUpdateReplayState("child-2")) + .thenReturn(Operation.builder() + .id("child-2") + .name("branch-2") + .type(OperationType.CONTEXT) + .subType(OperationSubType.PARALLEL_BRANCH.getValue()) + .status(OperationStatus.SUCCEEDED) + .contextDetails( + ContextDetails.builder().result("\"r2\"").build()) + .build()); + + var op = createOperation(-1, -1, 0); + 
setOperationIdGenerator(op, mockIdGenerator); + op.execute(); + op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES); + op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES); + + runJoin(op); + + verify(executionManager, never()) + .sendOperationUpdate(argThat(update -> update.action() == OperationAction.START)); + verify(executionManager, never()) + .sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED)); + } + // ===== handleFailure still sends SUCCEED ===== @Test