Skip to content

Commit

Permalink
KAFKA-10199: Remove tasks from state updater on revoked and lost part…
Browse files Browse the repository at this point in the history
…itions (#12547)

Removes tasks from the state updater when the input partitions of the tasks are revoked or partitions are lost during a rebalance.

Reviewers: Guozhang Wang <wangguoz@gmail.com>
  • Loading branch information
cadonna committed Aug 22, 2022
1 parent ddb7fdd commit add4ca6
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 12 deletions.
Expand Up @@ -465,7 +465,7 @@ private void classifyTasksWithStateUpdater(final Map<TaskId, Set<TopicPartition>
standbyTasksToCreate.remove(taskId);
} else {
stateUpdater.remove(taskId);
tasks.addPendingTaskToClose(taskId);
tasks.addPendingTaskToCloseClean(taskId);
}
}
}
Expand Down Expand Up @@ -692,7 +692,7 @@ private void handleRemovedTasksFromStateUpdater() {

taskExceptions.putIfAbsent(taskId, e);
}
} else if (tasks.removePendingTaskToClose(task.id())) {
} else if (tasks.removePendingTaskToCloseClean(task.id())) {
try {
task.suspend();
task.closeClean();
Expand All @@ -710,6 +710,8 @@ private void handleRemovedTasksFromStateUpdater() {

taskExceptions.putIfAbsent(task.id(), e);
}
} else if (tasks.removePendingTaskToCloseDirty(task.id())) {
tasksToCloseDirty.add(task);
} else if ((inputPartitions = tasks.removePendingTaskToUpdateInputPartitions(task.id())) != null) {
task.updateInputPartitions(inputPartitions, topologyMetadata.nodeToSourceTopics(task.id()));
stateUpdater.add(task);
Expand Down Expand Up @@ -755,6 +757,8 @@ void handleRevocation(final Collection<TopicPartition> revokedPartitions) {
}
}

removeRevokedTasksFromStateUpdater(remainingRevokedPartitions);

if (!remainingRevokedPartitions.isEmpty()) {
log.debug("The following revoked partitions {} are missing from the current task partitions. It could "
+ "potentially be due to race condition of consumer detecting the heartbeat failure, or the tasks " +
Expand Down Expand Up @@ -840,6 +844,20 @@ void handleRevocation(final Collection<TopicPartition> revokedPartitions) {
}
}

private void removeRevokedTasksFromStateUpdater(final Set<TopicPartition> remainingRevokedPartitions) {
if (stateUpdater != null) {
for (final Task restoringTask : stateUpdater.getTasks()) {
if (restoringTask.isActive()) {
if (remainingRevokedPartitions.containsAll(restoringTask.inputPartitions())) {
tasks.addPendingTaskToCloseClean(restoringTask.id());
stateUpdater.remove(restoringTask.id());
remainingRevokedPartitions.removeAll(restoringTask.inputPartitions());
}
}
}
}
}

private void prepareCommitAndAddOffsetsToMap(final Set<Task> tasksToPrepare,
final Map<Task, Map<TopicPartition, OffsetAndMetadata>> consumedOffsetsPerTask) {
for (final Task task : tasksToPrepare) {
Expand Down Expand Up @@ -867,6 +885,15 @@ private void prepareCommitAndAddOffsetsToMap(final Set<Task> tasksToPrepare,
void handleLostAll() {
log.debug("Closing lost active tasks as zombies.");

closeRunningTasksDirty();
removeLostTasksFromStateUpdater();

if (processingMode == EXACTLY_ONCE_V2) {
activeTaskCreator.reInitializeThreadProducer();
}
}

private void closeRunningTasksDirty() {
final Set<Task> allTask = tasks.allTasks();
for (final Task task : allTask) {
// Even though we've apparently dropped out of the group, we can continue safely to maintain our
Expand All @@ -875,9 +902,16 @@ void handleLostAll() {
closeTaskDirty(task);
}
}
}

if (processingMode == EXACTLY_ONCE_V2) {
activeTaskCreator.reInitializeThreadProducer();
private void removeLostTasksFromStateUpdater() {
if (stateUpdater != null) {
for (final Task restoringTask : stateUpdater.getTasks()) {
if (restoringTask.isActive()) {
tasks.addPendingTaskToCloseDirty(restoringTask.id());
stateUpdater.remove(restoringTask.id());
}
}
}
}

Expand Down
Expand Up @@ -55,7 +55,9 @@ class Tasks {
private final Map<TaskId, Set<TopicPartition>> pendingTasksToRecycle = new HashMap<>();
private final Map<TaskId, Set<TopicPartition>> pendingTasksToUpdateInputPartitions = new HashMap<>();
private final Set<Task> pendingTasksToInit = new HashSet<>();
private final Set<TaskId> pendingTasksToClose = new HashSet<>();
private final Set<TaskId> pendingTasksToCloseClean = new HashSet<>();

private final Set<TaskId> pendingTasksToCloseDirty = new HashSet<>();

// TODO: convert to Stream/StandbyTask when we remove TaskManager#StateMachineTask with mocks
private final Map<TopicPartition, Task> activeTasksPerPartition = new HashMap<>();
Expand Down Expand Up @@ -111,12 +113,19 @@ void addPendingTaskToUpdateInputPartitions(final TaskId taskId, final Set<TopicP
pendingTasksToUpdateInputPartitions.put(taskId, inputPartitions);
}

boolean removePendingTaskToClose(final TaskId taskId) {
return pendingTasksToClose.remove(taskId);
boolean removePendingTaskToCloseDirty(final TaskId taskId) {
return pendingTasksToCloseDirty.remove(taskId);
}
void addPendingTaskToCloseDirty(final TaskId taskId) {
pendingTasksToCloseDirty.add(taskId);
}

boolean removePendingTaskToCloseClean(final TaskId taskId) {
return pendingTasksToCloseClean.remove(taskId);
}

void addPendingTaskToClose(final TaskId taskId) {
pendingTasksToClose.add(taskId);
void addPendingTaskToCloseClean(final TaskId taskId) {
pendingTasksToCloseClean.add(taskId);
}

Set<Task> drainPendingTaskToInit() {
Expand Down
Expand Up @@ -118,6 +118,7 @@
import static org.junit.Assert.fail;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.mock;

Expand All @@ -144,7 +145,9 @@ public class TaskManagerTest {

private final TaskId taskId02 = new TaskId(0, 2);
private final TopicPartition t1p2 = new TopicPartition(topic1, 2);
private final TopicPartition t1p2changelog = new TopicPartition("changelog", 2);
private final Set<TopicPartition> taskId02Partitions = mkSet(t1p2);
private final Set<TopicPartition> taskId02ChangelogPartitions = mkSet(t1p2changelog);

private final TaskId taskId03 = new TaskId(0, 3);
private final TopicPartition t1p3 = new TopicPartition(topic1, 3);
Expand Down Expand Up @@ -343,8 +346,8 @@ public void shouldHandleRemovedTasksToCloseFromStateUpdater() {
expectLastCall().anyTimes();
replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer);

taskManager.tasks().addPendingTaskToClose(taskId00);
taskManager.tasks().addPendingTaskToClose(taskId01);
taskManager.tasks().addPendingTaskToCloseClean(taskId00);
taskManager.tasks().addPendingTaskToCloseClean(taskId01);

taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter -> { });

Expand Down Expand Up @@ -383,6 +386,113 @@ public void shouldHandleRemovedTasksToUpdateInputPartitionsFromStateUpdater() {
Mockito.verify(stateUpdater).add(task01);
}

@Test
public void shouldRemoveStatefulTaskWithRevokedInputPartitionsFromStateUpdaterOnRevocation() {
final StreamTask task = statefulTask(taskId00, taskId00ChangelogPartitions)
.inState(State.RESTORING)
.withInputPartitions(taskId00Partitions).build();
final TaskManager taskManager = setupForRevocationAndLost(mkSet(task), mkSet(task));

taskManager.handleRevocation(taskId00Partitions);

Mockito.verify(stateUpdater).remove(task.id());

taskManager.tryToCompleteRestoration(time.milliseconds(), null);

Mockito.verify(task).closeClean();
}

public void shouldRemoveMultipleStatefulTaskWithRevokedInputPartitionsFromStateUpdaterOnRevocation() {
final StreamTask task1 = statefulTask(taskId00, taskId00ChangelogPartitions)
.inState(State.RESTORING)
.withInputPartitions(taskId00Partitions).build();
final StreamTask task2 = statefulTask(taskId01, taskId01ChangelogPartitions)
.inState(State.RESTORING)
.withInputPartitions(taskId01Partitions).build();
final TaskManager taskManager = setupForRevocationAndLost(mkSet(task1, task2), mkSet(task1, task2));

taskManager.handleRevocation(union(HashSet::new, taskId00Partitions, taskId01Partitions));

Mockito.verify(stateUpdater).remove(task1.id());
Mockito.verify(stateUpdater).remove(task2.id());

taskManager.tryToCompleteRestoration(time.milliseconds(), null);

Mockito.verify(task1).closeClean();
Mockito.verify(task2).closeClean();
}

@Test
public void shouldNotRemoveStatefulTaskWithoutRevokedInputPartitionsFromStateUpdaterOnRevocation() {
final StreamTask task = statefulTask(taskId00, taskId00ChangelogPartitions)
.inState(State.RESTORING)
.withInputPartitions(taskId00Partitions).build();
final TaskManager taskManager = setupForRevocationAndLost(mkSet(task), Collections.emptySet());

taskManager.handleRevocation(taskId01Partitions);

Mockito.verify(stateUpdater, never()).remove(task.id());

taskManager.tryToCompleteRestoration(time.milliseconds(), null);

Mockito.verify(task, never()).closeClean();
}

@Test
public void shouldNotRemoveStandbyTaskFromStateUpdaterOnRevocation() {
final StandbyTask task = standbyTask(taskId00, taskId00ChangelogPartitions)
.inState(State.RESTORING)
.withInputPartitions(taskId00Partitions).build();
final TaskManager taskManager = setupForRevocationAndLost(mkSet(task), Collections.emptySet());

taskManager.handleRevocation(taskId00Partitions);

Mockito.verify(stateUpdater, never()).remove(task.id());

taskManager.tryToCompleteRestoration(time.milliseconds(), null);

Mockito.verify(task, never()).closeClean();
}

@Test
public void shouldRemoveAllActiveTasksFromStateUpdaterOnPartitionLost() {
final StreamTask task1 = statefulTask(taskId00, taskId00ChangelogPartitions)
.inState(State.RESTORING)
.withInputPartitions(taskId00Partitions).build();
final StandbyTask task2 = standbyTask(taskId01, taskId01ChangelogPartitions)
.inState(State.RESTORING)
.withInputPartitions(taskId01Partitions).build();
final StreamTask task3 = statefulTask(taskId02, taskId02ChangelogPartitions)
.inState(State.RESTORING)
.withInputPartitions(taskId02Partitions).build();
final TaskManager taskManager = setupForRevocationAndLost(mkSet(task1, task2, task3), mkSet(task1, task3));

taskManager.handleLostAll();

Mockito.verify(stateUpdater).remove(task1.id());
Mockito.verify(stateUpdater).remove(task3.id());

taskManager.tryToCompleteRestoration(time.milliseconds(), null);

Mockito.verify(task1).closeDirty();
Mockito.verify(task3).closeDirty();
Mockito.verify(task2, never()).closeDirty();
Mockito.verify(task2, never()).closeClean();
}

private TaskManager setupForRevocationAndLost(final Set<Task> tasksInStateUpdater,
final Set<Task> removedTasks) {
final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, true);
when(stateUpdater.getTasks()).thenReturn(tasksInStateUpdater);
when(stateUpdater.drainRemovedTasks()).thenReturn(removedTasks);
expect(consumer.assignment()).andReturn(emptySet()).anyTimes();
consumer.resume(anyObject());
expectLastCall().anyTimes();
replay(consumer);

return taskManager;
}

@Test
public void shouldHandleRemovedTasksFromStateUpdater() {
// tasks to recycle
Expand Down Expand Up @@ -436,7 +546,7 @@ public void shouldHandleRemovedTasksFromStateUpdater() {
stateUpdater
);
taskManager.setMainConsumer(consumer);
taskManager.tasks().addPendingTaskToClose(taskId02);
taskManager.tasks().addPendingTaskToCloseClean(taskId02);
taskManager.tasks().addPendingTaskToRecycle(taskId00, taskId00Partitions);
taskManager.tasks().addPendingTaskToRecycle(taskId01, taskId01Partitions);
taskManager.tasks().addPendingTaskToUpdateInputPartitions(taskId03, taskId03Partitions);
Expand Down

0 comments on commit add4ca6

Please sign in to comment.