diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java index f458629cf1f2c..2f3360a3f53ed 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java @@ -1458,6 +1458,14 @@ protected boolean maybeUpdateSubscriptionVersion(final int receivedAssignmentMet @Override public void onAssignment(final Assignment assignment, final ConsumerGroupMetadata metadata) { + final Set tasksWithOpenTransactions = taskManager.allOwnedTasks() + .values() + .stream() + .filter(t -> t.commitNeeded()) + .collect(Collectors.toSet()); + log.info("Committing {} tasks with open transactions before onAssignment()", tasksWithOpenTransactions.size()); + taskManager.commit(tasksWithOpenTransactions); + final List partitions = new ArrayList<>(assignment.partitions()); partitions.sort(PARTITION_COMPARATOR); diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java index 08674857b7182..8e072f08012d3 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java @@ -1442,6 +1442,14 @@ public void testOnAssignment(final Map parameterizedConfig) { setUp(parameterizedConfig, false); taskManager = mock(TaskManager.class); + final Task notToCommit = mock(Task.class); + final Task toCommit = mock(Task.class); + final TaskId notToCommitId = new TaskId(0, 0); + final TaskId toCommitId = new TaskId(0, 1); + when(notToCommit.commitNeeded()).thenReturn(false); + when(toCommit.commitNeeded()).thenReturn(true); + when(taskManager.allOwnedTasks()).thenReturn(Map.of(notToCommitId, notToCommit, toCommitId, toCommit)); + final Map> hostState = Collections.singletonMap( new HostInfo("localhost", 9090), Set.of(t3p0, t3p3)); @@ -1465,6 +1473,7 @@ public void testOnAssignment(final Map parameterizedConfig) { verify(streamsMetadataState).onChange(eq(hostState), any(), topicPartitionInfoCaptor.capture()); verify(taskManager).handleAssignment(activeTasks, standbyTasks); + verify(taskManager).commit(Set.of(toCommit)); assertTrue(topicPartitionInfoCaptor.getValue().containsKey(t3p0)); assertTrue(topicPartitionInfoCaptor.getValue().containsKey(t3p3));