diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java index 03c36b0daf3a..478b783d68f2 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java @@ -465,7 +465,7 @@ private void classifyTasksWithStateUpdater(final Map standbyTasksToCreate.remove(taskId); } else { stateUpdater.remove(taskId); - tasks.addPendingTaskToClose(taskId); + tasks.addPendingTaskToCloseClean(taskId); } } } @@ -692,7 +692,7 @@ private void handleRemovedTasksFromStateUpdater() { taskExceptions.putIfAbsent(taskId, e); } - } else if (tasks.removePendingTaskToClose(task.id())) { + } else if (tasks.removePendingTaskToCloseClean(task.id())) { try { task.suspend(); task.closeClean(); @@ -710,6 +710,8 @@ private void handleRemovedTasksFromStateUpdater() { taskExceptions.putIfAbsent(task.id(), e); } + } else if (tasks.removePendingTaskToCloseDirty(task.id())) { + tasksToCloseDirty.add(task); } else if ((inputPartitions = tasks.removePendingTaskToUpdateInputPartitions(task.id())) != null) { task.updateInputPartitions(inputPartitions, topologyMetadata.nodeToSourceTopics(task.id())); stateUpdater.add(task); @@ -755,6 +757,8 @@ void handleRevocation(final Collection revokedPartitions) { } } + removeRevokedTasksFromStateUpdater(remainingRevokedPartitions); + if (!remainingRevokedPartitions.isEmpty()) { log.debug("The following revoked partitions {} are missing from the current task partitions. It could " + "potentially be due to race condition of consumer detecting the heartbeat failure, or the tasks " + @@ -840,6 +844,20 @@ void handleRevocation(final Collection revokedPartitions) { } } + private void removeRevokedTasksFromStateUpdater(final Set remainingRevokedPartitions) { + if (stateUpdater != null) { + for (final Task restoringTask : stateUpdater.getTasks()) { + if (restoringTask.isActive()) { + if (remainingRevokedPartitions.containsAll(restoringTask.inputPartitions())) { + tasks.addPendingTaskToCloseClean(restoringTask.id()); + stateUpdater.remove(restoringTask.id()); + remainingRevokedPartitions.removeAll(restoringTask.inputPartitions()); + } + } + } + } + } + private void prepareCommitAndAddOffsetsToMap(final Set tasksToPrepare, final Map> consumedOffsetsPerTask) { for (final Task task : tasksToPrepare) { @@ -867,6 +885,15 @@ private void prepareCommitAndAddOffsetsToMap(final Set tasksToPrepare, void handleLostAll() { log.debug("Closing lost active tasks as zombies."); + closeRunningTasksDirty(); + removeLostTasksFromStateUpdater(); + + if (processingMode == EXACTLY_ONCE_V2) { + activeTaskCreator.reInitializeThreadProducer(); + } + } + + private void closeRunningTasksDirty() { final Set allTask = tasks.allTasks(); for (final Task task : allTask) { // Even though we've apparently dropped out of the group, we can continue safely to maintain our @@ -875,9 +902,16 @@ void handleLostAll() { closeTaskDirty(task); } } + } - if (processingMode == EXACTLY_ONCE_V2) { - activeTaskCreator.reInitializeThreadProducer(); + private void removeLostTasksFromStateUpdater() { + if (stateUpdater != null) { + for (final Task restoringTask : stateUpdater.getTasks()) { + if (restoringTask.isActive()) { + tasks.addPendingTaskToCloseDirty(restoringTask.id()); + stateUpdater.remove(restoringTask.id()); + } + } } } diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java index 9628b42d927c..8178fe369119 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java @@ -55,7 +55,9 @@ class Tasks { private final Map> pendingTasksToRecycle = new HashMap<>(); private final Map> pendingTasksToUpdateInputPartitions = new HashMap<>(); private final Set pendingTasksToInit = new HashSet<>(); - private final Set pendingTasksToClose = new HashSet<>(); + private final Set pendingTasksToCloseClean = new HashSet<>(); + + private final Set pendingTasksToCloseDirty = new HashSet<>(); // TODO: convert to Stream/StandbyTask when we remove TaskManager#StateMachineTask with mocks private final Map activeTasksPerPartition = new HashMap<>(); @@ -111,12 +113,19 @@ void addPendingTaskToUpdateInputPartitions(final TaskId taskId, final Set drainPendingTaskToInit() { diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java index 12ea6477e540..b65b38285333 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java @@ -118,6 +118,7 @@ import static org.junit.Assert.fail; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.when; import static org.mockito.Mockito.mock; @@ -144,7 +145,9 @@ public class TaskManagerTest { private final TaskId taskId02 = new TaskId(0, 2); private final TopicPartition t1p2 = new TopicPartition(topic1, 2); + private final TopicPartition t1p2changelog = new TopicPartition("changelog", 2); private final Set taskId02Partitions = mkSet(t1p2); + private final Set taskId02ChangelogPartitions = mkSet(t1p2changelog); private final TaskId taskId03 = new TaskId(0, 3); private final TopicPartition t1p3 = new TopicPartition(topic1, 3); @@ -343,8 +346,8 @@ public void shouldHandleRemovedTasksToCloseFromStateUpdater() { expectLastCall().anyTimes(); replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer); - taskManager.tasks().addPendingTaskToClose(taskId00); - taskManager.tasks().addPendingTaskToClose(taskId01); + taskManager.tasks().addPendingTaskToCloseClean(taskId00); + taskManager.tasks().addPendingTaskToCloseClean(taskId01); taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter -> { }); @@ -383,6 +386,113 @@ public void shouldHandleRemovedTasksToUpdateInputPartitionsFromStateUpdater() { Mockito.verify(stateUpdater).add(task01); } + @Test + public void shouldRemoveStatefulTaskWithRevokedInputPartitionsFromStateUpdaterOnRevocation() { + final StreamTask task = statefulTask(taskId00, taskId00ChangelogPartitions) + .inState(State.RESTORING) + .withInputPartitions(taskId00Partitions).build(); + final TaskManager taskManager = setupForRevocationAndLost(mkSet(task), mkSet(task)); + + taskManager.handleRevocation(taskId00Partitions); + + Mockito.verify(stateUpdater).remove(task.id()); + + taskManager.tryToCompleteRestoration(time.milliseconds(), null); + + Mockito.verify(task).closeClean(); + } + + public void shouldRemoveMultipleStatefulTaskWithRevokedInputPartitionsFromStateUpdaterOnRevocation() { + final StreamTask task1 = statefulTask(taskId00, taskId00ChangelogPartitions) + .inState(State.RESTORING) + .withInputPartitions(taskId00Partitions).build(); + final StreamTask task2 = statefulTask(taskId01, taskId01ChangelogPartitions) + .inState(State.RESTORING) + .withInputPartitions(taskId01Partitions).build(); + final TaskManager taskManager = setupForRevocationAndLost(mkSet(task1, task2), mkSet(task1, task2)); + + taskManager.handleRevocation(union(HashSet::new, taskId00Partitions, taskId01Partitions)); + + Mockito.verify(stateUpdater).remove(task1.id()); + Mockito.verify(stateUpdater).remove(task2.id()); + + taskManager.tryToCompleteRestoration(time.milliseconds(), null); + + Mockito.verify(task1).closeClean(); + Mockito.verify(task2).closeClean(); + } + + @Test + public void shouldNotRemoveStatefulTaskWithoutRevokedInputPartitionsFromStateUpdaterOnRevocation() { + final StreamTask task = statefulTask(taskId00, taskId00ChangelogPartitions) + .inState(State.RESTORING) + .withInputPartitions(taskId00Partitions).build(); + final TaskManager taskManager = setupForRevocationAndLost(mkSet(task), Collections.emptySet()); + + taskManager.handleRevocation(taskId01Partitions); + + Mockito.verify(stateUpdater, never()).remove(task.id()); + + taskManager.tryToCompleteRestoration(time.milliseconds(), null); + + Mockito.verify(task, never()).closeClean(); + } + + @Test + public void shouldNotRemoveStandbyTaskFromStateUpdaterOnRevocation() { + final StandbyTask task = standbyTask(taskId00, taskId00ChangelogPartitions) + .inState(State.RESTORING) + .withInputPartitions(taskId00Partitions).build(); + final TaskManager taskManager = setupForRevocationAndLost(mkSet(task), Collections.emptySet()); + + taskManager.handleRevocation(taskId00Partitions); + + Mockito.verify(stateUpdater, never()).remove(task.id()); + + taskManager.tryToCompleteRestoration(time.milliseconds(), null); + + Mockito.verify(task, never()).closeClean(); + } + + @Test + public void shouldRemoveAllActiveTasksFromStateUpdaterOnPartitionLost() { + final StreamTask task1 = statefulTask(taskId00, taskId00ChangelogPartitions) + .inState(State.RESTORING) + .withInputPartitions(taskId00Partitions).build(); + final StandbyTask task2 = standbyTask(taskId01, taskId01ChangelogPartitions) + .inState(State.RESTORING) + .withInputPartitions(taskId01Partitions).build(); + final StreamTask task3 = statefulTask(taskId02, taskId02ChangelogPartitions) + .inState(State.RESTORING) + .withInputPartitions(taskId02Partitions).build(); + final TaskManager taskManager = setupForRevocationAndLost(mkSet(task1, task2, task3), mkSet(task1, task3)); + + taskManager.handleLostAll(); + + Mockito.verify(stateUpdater).remove(task1.id()); + Mockito.verify(stateUpdater).remove(task3.id()); + + taskManager.tryToCompleteRestoration(time.milliseconds(), null); + + Mockito.verify(task1).closeDirty(); + Mockito.verify(task3).closeDirty(); + Mockito.verify(task2, never()).closeDirty(); + Mockito.verify(task2, never()).closeClean(); + } + + private TaskManager setupForRevocationAndLost(final Set tasksInStateUpdater, + final Set removedTasks) { + final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, true); + when(stateUpdater.getTasks()).thenReturn(tasksInStateUpdater); + when(stateUpdater.drainRemovedTasks()).thenReturn(removedTasks); + expect(consumer.assignment()).andReturn(emptySet()).anyTimes(); + consumer.resume(anyObject()); + expectLastCall().anyTimes(); + replay(consumer); + + return taskManager; + } + @Test public void shouldHandleRemovedTasksFromStateUpdater() { // tasks to recycle @@ -436,7 +546,7 @@ public void shouldHandleRemovedTasksFromStateUpdater() { stateUpdater ); taskManager.setMainConsumer(consumer); - taskManager.tasks().addPendingTaskToClose(taskId02); + taskManager.tasks().addPendingTaskToCloseClean(taskId02); taskManager.tasks().addPendingTaskToRecycle(taskId00, taskId00Partitions); taskManager.tasks().addPendingTaskToRecycle(taskId01, taskId01Partitions); taskManager.tasks().addPendingTaskToUpdateInputPartitions(taskId03, taskId03Partitions);