diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java index 64f80f7963f9..83f705a25617 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java @@ -128,7 +128,7 @@ private static class ClientMetadata { private final ClientState state; private final SortedSet consumers; - ClientMetadata(final String endPoint, final Map clientTags) { + ClientMetadata(final UUID processId, final String endPoint, final Map clientTags) { // get the host info, or null if no endpoint is configured (ie endPoint == null) hostInfo = HostInfo.buildFromEndpoint(endPoint); @@ -137,7 +137,7 @@ private static class ClientMetadata { consumers = new TreeSet<>(); // initialize the client state with client tags - state = new ClientState(clientTags); + state = new ClientState(processId, clientTags); } void addConsumer(final String consumerMemberId, final List ownedPartitions) { @@ -340,7 +340,7 @@ public GroupAssignment assign(final Cluster metadata, final GroupSubscription gr futureMetadataVersion = usedVersion; processId = FUTURE_ID; if (!clientMetadataMap.containsKey(FUTURE_ID)) { - clientMetadataMap.put(FUTURE_ID, new ClientMetadata(null, Collections.emptyMap())); + clientMetadataMap.put(FUTURE_ID, new ClientMetadata(FUTURE_ID, null, Collections.emptyMap())); } } else { processId = info.processId(); @@ -350,7 +350,7 @@ public GroupAssignment assign(final Cluster metadata, final GroupSubscription gr // create the new client metadata if necessary if (clientMetadata == null) { - clientMetadata = new ClientMetadata(info.userEndPoint(), info.clientTags()); + clientMetadata = new ClientMetadata(info.processId(), info.userEndPoint(), info.clientTags()); clientMetadataMap.put(info.processId(), clientMetadata); } diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java index b8ba4ce27e10..170ec315a90c 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java @@ -59,25 +59,31 @@ public class ClientState { private final ClientStateTask revokingActiveTasks = new ClientStateTask(null, new TreeMap<>()); private int capacity; + private UUID processId; public ClientState() { - this(0); + this(null, 0); } - public ClientState(final Map clientTags) { - this(0, clientTags); + public ClientState(final UUID processId, final Map clientTags) { + this(processId, 0, clientTags); } ClientState(final int capacity) { - this(capacity, Collections.emptyMap()); + this(null, capacity); } - ClientState(final int capacity, final Map clientTags) { + ClientState(final UUID processId, final int capacity) { + this(processId, capacity, Collections.emptyMap()); + } + + ClientState(final UUID processId, final int capacity, final Map clientTags) { previousStandbyTasks.taskIds(new TreeSet<>()); previousActiveTasks.taskIds(new TreeSet<>()); taskOffsetSums = new TreeMap<>(); taskLagTotals = new TreeMap<>(); this.capacity = capacity; + this.processId = processId; this.clientTags = unmodifiableMap(clientTags); } @@ -99,6 +105,10 @@ int capacity() { return capacity; } + 
UUID processId() { + return processId; + } + public void incrementCapacity() { capacity++; } diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientTagAwareStandbyTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientTagAwareStandbyTaskAssignor.java index cabfa545b101..de5036fe809d 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientTagAwareStandbyTaskAssignor.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientTagAwareStandbyTaskAssignor.java @@ -16,6 +16,9 @@ */ package org.apache.kafka.streams.processor.internals.assignment; +import java.util.List; +import java.util.function.BiFunction; +import java.util.function.Function; import org.apache.kafka.streams.processor.TaskId; import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs; import org.slf4j.Logger; @@ -42,6 +45,19 @@ class ClientTagAwareStandbyTaskAssignor implements StandbyTaskAssignor { private static final Logger log = LoggerFactory.getLogger(ClientTagAwareStandbyTaskAssignor.class); + private final BiFunction> clientTagFunction; + private final Function> tagsFunction; + + public ClientTagAwareStandbyTaskAssignor() { + this((uuid, clientState) -> clientState.clientTags(), assignmentConfigs -> assignmentConfigs.rackAwareAssignmentTags); + } + + public ClientTagAwareStandbyTaskAssignor(final BiFunction> clientTagFunction, + final Function> tagsFunction) { + this.clientTagFunction = clientTagFunction; + this.tagsFunction = tagsFunction; + } + /** * The algorithm distributes standby tasks for the {@param statefulTaskIds} over different tag dimensions. * For each stateful task, the number of standby tasks will be assigned based on configured {@link AssignmentConfigs#numStandbyReplicas}. 
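Editor's note, for readers of this patch: a minimal sketch of what the new two-argument constructor enables. The generic parameters are stripped in this extract, so BiFunction<UUID, ClientState, Map<String, String>> and Function<AssignmentConfigs, List<String>> are assumptions, and the helper name below is hypothetical; the body mirrors the rack-based fallback that StandbyTaskAssignorFactory wires up later in this patch and would need to live in the same assignment package, since the classes involved are package-private.

```java
import static org.apache.kafka.common.utils.Utils.mkEntry;
import static org.apache.kafka.common.utils.Utils.mkMap;

import java.util.Collections;
import java.util.Map;
import java.util.UUID;

// Hypothetical helper: reuse the tag-aware standby distribution algorithm, but source a
// single synthetic "rack" tag from pre-resolved per-process rack ids instead of the
// user-configured client tags.
static StandbyTaskAssignor rackBasedStandbyTaskAssignor(final Map<UUID, String> racksForProcess) {
    return new ClientTagAwareStandbyTaskAssignor(
        // Tag lookup: ignore the configured client tags and report the process's rack id.
        (processId, clientState) -> mkMap(mkEntry("rack", racksForProcess.get(processId))),
        // Tag keys to spread standbys over: only the synthetic "rack" dimension.
        assignmentConfigs -> Collections.singletonList("rack")
    );
}
```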
@@ -56,7 +72,7 @@ public boolean assign(final Map clients, final Set statefulTaskIds, final AssignorConfiguration.AssignmentConfigs configs) { final int numStandbyReplicas = configs.numStandbyReplicas; - final Set rackAwareAssignmentTags = new HashSet<>(configs.rackAwareAssignmentTags); + final Set rackAwareAssignmentTags = new HashSet<>(tagsFunction.apply(configs)); final Map tasksToRemainingStandbys = computeTasksToRemainingStandbys( numStandbyReplicas, @@ -128,8 +144,8 @@ private static void assignPendingStandbyTasksToLeastLoadedClients(final Map sourceClientTags = source.clientTags(); - final Map destinationClientTags = destination.clientTags(); + final Map sourceClientTags = clientTagFunction.apply(source.processId(), source); + final Map destinationClientTags = clientTagFunction.apply(destination.processId(), destination); for (final Entry sourceClientTagEntry : sourceClientTags.entrySet()) { if (!sourceClientTagEntry.getValue().equals(destinationClientTags.get(sourceClientTagEntry.getKey()))) { @@ -141,14 +157,14 @@ public boolean isAllowedTaskMovement(final ClientState source, final ClientState } // Visible for testing - static void fillClientsTagStatistics(final Map clientStates, - final Map> tagEntryToClients, - final Map> tagKeyToValues) { + void fillClientsTagStatistics(final Map clientStates, + final Map> tagEntryToClients, + final Map> tagKeyToValues) { for (final Entry clientStateEntry : clientStates.entrySet()) { final UUID clientId = clientStateEntry.getKey(); final ClientState clientState = clientStateEntry.getValue(); - clientState.clientTags().forEach((tagKey, tagValue) -> { + clientTagFunction.apply(clientId, clientState).forEach((tagKey, tagValue) -> { tagKeyToValues.computeIfAbsent(tagKey, ignored -> new HashSet<>()).add(tagValue); tagEntryToClients.computeIfAbsent(new TagEntry(tagKey, tagValue), ignored -> new HashSet<>()).add(clientId); }); @@ -156,16 +172,16 @@ static void fillClientsTagStatistics(final Map clientStates, } // Visible for testing - static void assignStandbyTasksToClientsWithDifferentTags(final int numberOfStandbyClients, - final ConstrainedPrioritySet standbyTaskClientsByTaskLoad, - final TaskId activeTaskId, - final UUID activeTaskClient, - final Set rackAwareAssignmentTags, - final Map clientStates, - final Map tasksToRemainingStandbys, - final Map> tagKeyToValues, - final Map> tagEntryToClients, - final Map pendingStandbyTasksToClientId) { + void assignStandbyTasksToClientsWithDifferentTags(final int numberOfStandbyClients, + final ConstrainedPrioritySet standbyTaskClientsByTaskLoad, + final TaskId activeTaskId, + final UUID activeTaskClient, + final Set rackAwareAssignmentTags, + final Map clientStates, + final Map tasksToRemainingStandbys, + final Map> tagKeyToValues, + final Map> tagEntryToClients, + final Map pendingStandbyTasksToClientId) { standbyTaskClientsByTaskLoad.offerAll(clientStates.keySet()); // We set countOfUsedClients as 1 because client where active task is located has to be considered as used. @@ -199,9 +215,10 @@ static void assignStandbyTasksToClientsWithDifferentTags(final int numberOfStand numRemainingStandbys--; log.debug("Assigning {} out of {} standby tasks for an active task [{}] with client tags {}. 
" + - "Standby task client tags are {}.", - numberOfStandbyClients - numRemainingStandbys, numberOfStandbyClients, activeTaskId, - clientStates.get(activeTaskClient).clientTags(), clientStateOnUsedTagDimensions.clientTags()); + "Standby task client tags are {}.", + numberOfStandbyClients - numRemainingStandbys, numberOfStandbyClients, activeTaskId, + clientTagFunction.apply(activeTaskClient, clientStates.get(activeTaskClient)), + clientTagFunction.apply(clientStateOnUsedTagDimensions.processId(), clientStateOnUsedTagDimensions)); clientStateOnUsedTagDimensions.assignStandby(activeTaskId); lastUsedClient = clientOnUnusedTagDimensions; @@ -218,7 +235,7 @@ static void assignStandbyTasksToClientsWithDifferentTags(final int numberOfStand "Standby task assignment will fall back to assigning standby tasks to the least loaded clients.", numRemainingStandbys, numberOfStandbyClients, activeTaskId, rackAwareAssignmentTags, - clientStates.get(activeTaskClient).clientTags()); + clientTagFunction.apply(activeTaskClient, clientStates.get(activeTaskClient))); } else { tasksToRemainingStandbys.remove(activeTaskId); @@ -230,14 +247,14 @@ private static boolean isClientUsedOnAnyOfTheTagEntries(final UUID client, return tagEntryToUsedClients.values().stream().anyMatch(usedClients -> usedClients.contains(client)); } - private static void updateClientsOnAlreadyUsedTagEntries(final UUID usedClient, + private void updateClientsOnAlreadyUsedTagEntries(final UUID usedClient, final int countOfUsedClients, final Set rackAwareAssignmentTags, final Map clientStates, final Map> tagEntryToClients, final Map> tagKeyToValues, final Map> tagEntryToUsedClients) { - final Map usedClientTags = clientStates.get(usedClient).clientTags(); + final Map usedClientTags = clientTagFunction.apply(usedClient, clientStates.get(usedClient)); for (final Entry usedClientTagEntry : usedClientTags.entrySet()) { final String tagKey = usedClientTagEntry.getKey(); diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java index c54199ad1773..f402f8c279df 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java @@ -134,7 +134,7 @@ private void assignStandbyReplicaTasks(final TreeMap clientSt return; } - final StandbyTaskAssignor standbyTaskAssignor = StandbyTaskAssignorFactory.create(configs); + final StandbyTaskAssignor standbyTaskAssignor = StandbyTaskAssignorFactory.create(configs, null); standbyTaskAssignor.assign(clientStates, allTaskIds, statefulTasks, configs); diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java index 0b52cad482d6..9eb82593062e 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java @@ -17,6 +17,7 @@ package org.apache.kafka.streams.processor.internals.assignment; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -48,23 +49,29 @@ public 
class RackAwareTaskAssignor { private final Cluster fullMetadata; private final Map> partitionsForTask; - private final Map>> racksForProcess; private final AssignmentConfigs assignmentConfigs; private final Map> racksForPartition; + private final Map racksForProcess; private final InternalTopicManager internalTopicManager; + private final boolean validClientRack; public RackAwareTaskAssignor(final Cluster fullMetadata, final Map> partitionsForTask, final Map> tasksForTopicGroup, - final Map>> racksForProcess, + final Map>> racksForProcessConsumer, final InternalTopicManager internalTopicManager, final AssignmentConfigs assignmentConfigs) { this.fullMetadata = fullMetadata; this.partitionsForTask = partitionsForTask; - this.racksForProcess = racksForProcess; this.internalTopicManager = internalTopicManager; this.assignmentConfigs = assignmentConfigs; this.racksForPartition = new HashMap<>(); + this.racksForProcess = new HashMap<>(); + validClientRack = validateClientRack(racksForProcessConsumer); + } + + public boolean validClientRack() { + return validClientRack; } public synchronized boolean canEnableRackAwareAssignor() { @@ -75,17 +82,12 @@ public synchronized boolean canEnableRackAwareAssignor() { return false; } */ - - if (!validateClientRack()) { - return false; - } - - return validateTopicPartitionRack(); + return validClientRack && validateTopicPartitionRack(); // TODO: add changelog topic, standby task validation } // Visible for testing. This method also checks if all TopicPartitions exist in cluster - public boolean populateTopicsToDiscribe(final Set topicsToDescribe) { + public boolean populateTopicsToDescribe(final Set topicsToDescribe) { // Make sure rackId exist for all TopicPartitions needed for (final Set topicPartitions : partitionsForTask.values()) { for (final TopicPartition topicPartition : topicPartitions) { @@ -115,7 +117,7 @@ public boolean populateTopicsToDiscribe(final Set topicsToDescribe) { private boolean validateTopicPartitionRack() { // Make sure rackId exist for all TopicPartitions needed final Set topicsToDescribe = new HashSet<>(); - if (!populateTopicsToDiscribe(topicsToDescribe)) { + if (!populateTopicsToDescribe(topicsToDescribe)) { return false; } @@ -155,51 +157,56 @@ private boolean validateTopicPartitionRack() { return true; } - // Visible for testing - public boolean validateClientRack() { + private boolean validateClientRack(final Map>> racksForProcessConsumer) { /* * Check rack information is populated correctly in clients * 1. RackId exist for all clients * 2. Different consumerId for same process should have same rackId */ - for (final Map.Entry>> entry : racksForProcess.entrySet()) { + for (final Map.Entry>> entry : racksForProcessConsumer.entrySet()) { final UUID processId = entry.getKey(); KeyValue previousRackInfo = null; for (final Map.Entry> rackEntry : entry.getValue().entrySet()) { if (!rackEntry.getValue().isPresent()) { - log.warn("RackId doesn't exist for process {} and consumer {}. Disable {}", - processId, rackEntry.getKey(), getClass().getName()); + log.error(String.format("RackId doesn't exist for process %s and consumer %s", + processId, rackEntry.getKey())); return false; } if (previousRackInfo == null) { previousRackInfo = KeyValue.pair(rackEntry.getKey(), rackEntry.getValue().get()); } else if (!previousRackInfo.value.equals(rackEntry.getValue().get())) { log.error( - "Consumers {} and {} for same process {} has different rackId {} and {}. File a ticket for this bug. 
Disable {}", - previousRackInfo.key, - rackEntry.getKey(), - entry.getKey(), - previousRackInfo.value, - rackEntry.getValue().get(), - getClass().getName()); + String.format("Consumers %s and %s for same process %s has different rackId %s and %s. File a ticket for this bug", + previousRackInfo.key, + rackEntry.getKey(), + entry.getKey(), + previousRackInfo.value, + rackEntry.getValue().get() + ) + ); return false; } } + if (previousRackInfo == null) { + log.error(String.format("RackId doesn't exist for process %s", processId)); + return false; + } + racksForProcess.put(entry.getKey(), previousRackInfo.value); } + return true; } + public Map racksForProcess() { + return Collections.unmodifiableMap(racksForProcess); + } + private int getCost(final TaskId taskId, final UUID processId, final boolean inCurrentAssignment, final int trafficCost, final int nonOverlapCost) { - final Map> clientRacks = racksForProcess.get(processId); - if (clientRacks == null) { - throw new IllegalStateException("Client " + processId + " doesn't exist in processRacks"); - } - final Optional> clientRackOpt = clientRacks.values().stream().filter(Optional::isPresent).findFirst(); - if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) { + final String clientRack = racksForProcess.get(processId); + if (clientRack == null) { throw new IllegalStateException("Client " + processId + " doesn't have rack configured. Maybe forgot to call canEnableRackAwareAssignor first"); } - final String clientRack = clientRackOpt.get().get(); final Set topicPartitions = partitionsForTask.get(taskId); if (topicPartitions == null || topicPartitions.isEmpty()) { throw new IllegalStateException("Task " + taskId + " has no TopicPartitions"); @@ -227,11 +234,20 @@ private static int getSinkID(final List clientList, final List tas return clientList.size() + taskIdList.size(); } - // For testing. canEnableRackAwareAssignor must be called first - long activeTasksCost(final SortedMap clientStates, final SortedSet activeTasks, final int trafficCost, final int nonOverlapCost) { + /** + * Compute the cost for the provided {@code activeTasks}. The passed in active tasks must be contained in {@code clientState}. + */ + long activeTasksCost(final SortedSet activeTasks, + final SortedMap clientStates, + final int trafficCost, + final int nonOverlapCost) { + if (activeTasks.isEmpty()) { + return 0; + } + final List clientList = new ArrayList<>(clientStates.keySet()); final List taskIdList = new ArrayList<>(activeTasks); - final Graph graph = constructActiveTaskGraph(activeTasks, clientList, taskIdList, + final Graph graph = constructActiveTaskGraph(clientList, taskIdList, clientStates, new HashMap<>(), new HashMap<>(), trafficCost, nonOverlapCost); return graph.totalCost(); } @@ -245,14 +261,14 @@ long activeTasksCost(final SortedMap clientStates, final Sort * cross rack traffic can be higher. In extreme case, if we set {@code nonOverlapCost} to 0 and @{code trafficCost} * to a positive value, the computed assignment will be minimum for cross rack traffic. If we set {@code trafficCost} to 0, * and {@code nonOverlapCost} to a positive value, the computed assignment should be the same as input - * @param clientStates Client states * @param activeTasks Tasks to reassign if needed. 
They must be assigned already in clientStates + * @param clientStates Client states * @param trafficCost Cost of cross rack traffic for each TopicPartition * @param nonOverlapCost Cost of assign a task to a different client * @return Total cost after optimization */ - public long optimizeActiveTasks(final SortedMap clientStates, - final SortedSet activeTasks, + public long optimizeActiveTasks(final SortedSet activeTasks, + final SortedMap clientStates, final int trafficCost, final int nonOverlapCost) { if (activeTasks.isEmpty()) { @@ -263,20 +279,19 @@ public long optimizeActiveTasks(final SortedMap clientStates, final List taskIdList = new ArrayList<>(activeTasks); final Map taskClientMap = new HashMap<>(); final Map originalAssignedTaskNumber = new HashMap<>(); - final Graph graph = constructActiveTaskGraph(activeTasks, clientList, taskIdList, + final Graph graph = constructActiveTaskGraph(clientList, taskIdList, clientStates, taskClientMap, originalAssignedTaskNumber, trafficCost, nonOverlapCost); graph.solveMinCostFlow(); final long cost = graph.totalCost(); - assignActiveTaskFromMinCostFlow(graph, activeTasks, clientList, taskIdList, + assignActiveTaskFromMinCostFlow(graph, clientList, taskIdList, clientStates, originalAssignedTaskNumber, taskClientMap); return cost; } - private Graph constructActiveTaskGraph(final SortedSet activeTasks, - final List clientList, + private Graph constructActiveTaskGraph(final List clientList, final List taskIdList, final Map clientStates, final Map taskClientMap, @@ -285,7 +300,7 @@ private Graph constructActiveTaskGraph(final SortedSet activeTa final int nonOverlapCost) { final Graph graph = new Graph<>(); - for (final TaskId taskId : activeTasks) { + for (final TaskId taskId : taskIdList) { for (final Entry clientState : clientStates.entrySet()) { if (clientState.getValue().hasAssignedTask(taskId)) { originalAssignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum); @@ -315,13 +330,12 @@ private Graph constructActiveTaskGraph(final SortedSet activeTa if (!taskClientMap.containsKey(taskId)) { throw new IllegalArgumentException("Task " + taskId + " not assigned to any client"); } - } - final int sinkId = getSinkID(clientList, taskIdList); - for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) { + // Add edge from source to task graph.addEdge(SOURCE_ID, taskNodeId, 1, 0, 1); } + final int sinkId = getSinkID(clientList, taskIdList); // It's possible that some clients have 0 task assign. These clients will have 0 tasks assigned // even though it may have higher traffic cost. 
This is to maintain the original assigned task count for (int i = 0; i < clientList.size(); i++) { @@ -338,7 +352,6 @@ private Graph constructActiveTaskGraph(final SortedSet activeTa } private void assignActiveTaskFromMinCostFlow(final Graph graph, - final SortedSet activeTasks, final List clientList, final List taskIdList, final Map clientStates, @@ -367,14 +380,14 @@ private void assignActiveTaskFromMinCostFlow(final Graph graph, } // Validate task assigned - if (tasksAssigned != activeTasks.size()) { + if (tasksAssigned != taskIdList.size()) { throw new IllegalStateException("Computed active task assignment number " - + tasksAssigned + " is different size " + activeTasks.size()); + + tasksAssigned + " is different size " + taskIdList.size()); } // Validate original assigned task number matches final Map assignedTaskNumber = new HashMap<>(); - for (final TaskId taskId : activeTasks) { + for (final TaskId taskId : taskIdList) { for (final Entry clientState : clientStates.entrySet()) { if (clientState.getValue().hasAssignedTask(taskId)) { assignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum); diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignorFactory.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignorFactory.java index 30c78f33f38d..4d47a843042e 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignorFactory.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignorFactory.java @@ -16,12 +16,27 @@ */ package org.apache.kafka.streams.processor.internals.assignment; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; + +import java.util.Collections; +import java.util.Map; +import java.util.UUID; + class StandbyTaskAssignorFactory { private StandbyTaskAssignorFactory() {} - static StandbyTaskAssignor create(final AssignorConfiguration.AssignmentConfigs configs) { + static StandbyTaskAssignor create(final AssignorConfiguration.AssignmentConfigs configs, + final RackAwareTaskAssignor rackAwareTaskAssignor) { if (!configs.rackAwareAssignmentTags.isEmpty()) { return new ClientTagAwareStandbyTaskAssignor(); + } else if (rackAwareTaskAssignor != null && rackAwareTaskAssignor.validClientRack()) { + // racksForProcess should be populated if rackAwareTaskAssignor isn't null + final Map racksForProcess = rackAwareTaskAssignor.racksForProcess(); + return new ClientTagAwareStandbyTaskAssignor( + (processId, clientState) -> mkMap(mkEntry("rack", racksForProcess.get(processId))), + assignmentConfigs -> Collections.singletonList("rack") + ); } else { return new DefaultStandbyTaskAssignor(); } diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java index 5f53e028a7ab..5da4da5ce96f 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java @@ -71,11 +71,11 @@ public final class AssignmentTestUtils { public static final UUID UUID_8 = uuidForInt(8); public static final UUID UUID_9 = uuidForInt(9); - public static final String RACK_0 = "rock0"; - public static final String RACK_1 = "rock1"; - public 
static final String RACK_2 = "rock2"; - public static final String RACK_3 = "rock3"; - public static final String RACK_4 = "rock4"; + public static final String RACK_0 = "rack0"; + public static final String RACK_1 = "rack1"; + public static final String RACK_2 = "rack2"; + public static final String RACK_3 = "rack3"; + public static final String RACK_4 = "rack4"; public static final Node NODE_0 = new Node(0, "node0", 1, RACK_0); public static final Node NODE_1 = new Node(1, "node1", 1, RACK_1); diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java index 940c69313222..54df914cf534 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java @@ -46,6 +46,9 @@ import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TP_1_0; import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TP_1_1; import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TP_1_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_3; import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.hasActiveTasks; import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.hasStandbyTasks; import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.UNKNOWN_OFFSET_SUM; @@ -55,6 +58,7 @@ import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; @@ -537,7 +541,7 @@ public void shouldThrowIllegalStateExceptionIfAssignedTasksForConsumerToNonClien @Test public void shouldReturnClientTags() { final Map clientTags = mkMap(mkEntry("k1", "v1")); - assertEquals(clientTags, new ClientState(0, clientTags).clientTags()); + assertEquals(clientTags, new ClientState(null, 0, clientTags).clientTags()); } @Test @@ -545,4 +549,11 @@ public void shouldReturnEmptyClientTagsMapByDefault() { assertTrue(new ClientState().clientTags().isEmpty()); } + @Test + public void shouldSetProcessId() { + assertEquals(UUID_1, new ClientState(UUID_1, 1).processId()); + assertEquals(UUID_2, new ClientState(UUID_2, mkMap()).processId()); + assertEquals(UUID_3, new ClientState(UUID_3, 1, mkMap()).processId()); + assertNull(new ClientState().processId()); + } } diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientTagAwareStandbyTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientTagAwareStandbyTaskAssignorTest.java index 631430c6a82f..07da8e4f83ea 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientTagAwareStandbyTaskAssignorTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientTagAwareStandbyTaskAssignorTest.java @@ -44,13 +44,15 @@ import static 
org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_1; import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_2; import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.uuidForInt; -import static org.apache.kafka.streams.processor.internals.assignment.ClientTagAwareStandbyTaskAssignor.assignStandbyTasksToClientsWithDifferentTags; -import static org.apache.kafka.streams.processor.internals.assignment.ClientTagAwareStandbyTaskAssignor.fillClientsTagStatistics; import static org.apache.kafka.streams.processor.internals.assignment.StandbyTaskAssignmentUtils.computeTasksToRemainingStandbys; import static org.apache.kafka.streams.processor.internals.assignment.StandbyTaskAssignmentUtils.createLeastLoadedPrioritySetConstrainedByAssignedTask; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class ClientTagAwareStandbyTaskAssignorTest { private static final String ZONE_TAG = "zone"; @@ -90,17 +92,17 @@ public void shouldNotAssignStatelessTasksToAnyClients() { ); final Map clientStates = mkMap( - mkEntry(UUID_1, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_0, TASK_1_0)), - mkEntry(UUID_2, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_1)))), - mkEntry(UUID_3, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_1)))), + mkEntry(UUID_1, createClientStateWithCapacity(UUID_1, 2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_0, TASK_1_0)), + mkEntry(UUID_2, createClientStateWithCapacity(UUID_2, 2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_1)))), + mkEntry(UUID_3, createClientStateWithCapacity(UUID_3, 2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_1)))), - mkEntry(UUID_4, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_2)), TASK_0_1, TASK_1_1)), - mkEntry(UUID_5, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_2)))), - mkEntry(UUID_6, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_2)))), + mkEntry(UUID_4, createClientStateWithCapacity(UUID_4, 2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_2)), TASK_0_1, TASK_1_1)), + mkEntry(UUID_5, createClientStateWithCapacity(UUID_5, 2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_2)))), + mkEntry(UUID_6, createClientStateWithCapacity(UUID_6, 2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_2)))), - mkEntry(UUID_7, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_3)), TASK_0_2, TASK_1_2)), - mkEntry(UUID_8, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_3)))), - mkEntry(UUID_9, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_3)))) + mkEntry(UUID_7, createClientStateWithCapacity(UUID_7, 2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_3)), TASK_0_2, TASK_1_2)), + mkEntry(UUID_8, createClientStateWithCapacity(UUID_8, 2, mkMap(mkEntry(ZONE_TAG, ZONE_2), 
mkEntry(CLUSTER_TAG, CLUSTER_3)))), + mkEntry(UUID_9, createClientStateWithCapacity(UUID_9, 2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_3)))) ); final Set allActiveTasks = findAllActiveTasks(clientStates); @@ -120,9 +122,9 @@ public void shouldRemoveClientToRemainingStandbysAndNotPopulatePendingStandbyTas final int numStandbyReplicas = 2; final Set rackAwareAssignmentTags = mkSet(ZONE_TAG, CLUSTER_TAG); final Map clientStates = mkMap( - mkEntry(UUID_1, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_0)), - mkEntry(UUID_2, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_2)), TASK_0_1)), - mkEntry(UUID_3, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_3)), TASK_0_2)) + mkEntry(UUID_1, createClientStateWithCapacity(UUID_1, 2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_0)), + mkEntry(UUID_2, createClientStateWithCapacity(UUID_2, 2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_2)), TASK_0_1)), + mkEntry(UUID_3, createClientStateWithCapacity(UUID_3, 2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_3)), TASK_0_2)) ); final ConstrainedPrioritySet constrainedPrioritySet = createLeastLoadedPrioritySetConstrainedByAssignedTask(clientStates); @@ -134,13 +136,13 @@ public void shouldRemoveClientToRemainingStandbysAndNotPopulatePendingStandbyTas final Map> tagKeyToValues = new HashMap<>(); final Map> tagEntryToClients = new HashMap<>(); - fillClientsTagStatistics(clientStates, tagEntryToClients, tagKeyToValues); + new ClientTagAwareStandbyTaskAssignor().fillClientsTagStatistics(clientStates, tagEntryToClients, tagKeyToValues); final Map pendingStandbyTasksToClientId = new HashMap<>(); final Map tasksToRemainingStandbys = computeTasksToRemainingStandbys(numStandbyReplicas, allActiveTasks); for (final TaskId activeTaskId : allActiveTasks) { - assignStandbyTasksToClientsWithDifferentTags( + new ClientTagAwareStandbyTaskAssignor().assignStandbyTasksToClientsWithDifferentTags( numStandbyReplicas, constrainedPrioritySet, activeTaskId, @@ -162,9 +164,9 @@ public void shouldRemoveClientToRemainingStandbysAndNotPopulatePendingStandbyTas public void shouldUpdateClientToRemainingStandbysAndPendingStandbyTasksToClientIdWhenNotAllStandbyTasksWereAssigned() { final Set rackAwareAssignmentTags = mkSet(ZONE_TAG, CLUSTER_TAG); final Map clientStates = mkMap( - mkEntry(UUID_1, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_0)), - mkEntry(UUID_2, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_2)), TASK_0_1)), - mkEntry(UUID_3, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_3)), TASK_0_2)) + mkEntry(UUID_1, createClientStateWithCapacity(UUID_1, 2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_0)), + mkEntry(UUID_2, createClientStateWithCapacity(UUID_2, 2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_2)), TASK_0_1)), + mkEntry(UUID_3, createClientStateWithCapacity(UUID_3, 2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_3)), TASK_0_2)) ); final ConstrainedPrioritySet constrainedPrioritySet = createLeastLoadedPrioritySetConstrainedByAssignedTask(clientStates); @@ -177,13 +179,13 @@ public void shouldUpdateClientToRemainingStandbysAndPendingStandbyTasksToClientI final Map> 
tagKeyToValues = new HashMap<>(); final Map> tagEntryToClients = new HashMap<>(); - fillClientsTagStatistics(clientStates, tagEntryToClients, tagKeyToValues); + new ClientTagAwareStandbyTaskAssignor().fillClientsTagStatistics(clientStates, tagEntryToClients, tagKeyToValues); final Map pendingStandbyTasksToClientId = new HashMap<>(); final Map tasksToRemainingStandbys = computeTasksToRemainingStandbys(numStandbyReplicas, allActiveTasks); for (final TaskId activeTaskId : allActiveTasks) { - assignStandbyTasksToClientsWithDifferentTags( + new ClientTagAwareStandbyTaskAssignor().assignStandbyTasksToClientsWithDifferentTags( numStandbyReplicas, constrainedPrioritySet, activeTaskId, @@ -214,16 +216,16 @@ public void shouldUpdateClientToRemainingStandbysAndPendingStandbyTasksToClientI @Test public void shouldPermitTaskMovementWhenClientTagsMatch() { - final ClientState source = createClientStateWithCapacity(1, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1))); - final ClientState destination = createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1))); + final ClientState source = createClientStateWithCapacity(UUID_1, 1, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1))); + final ClientState destination = createClientStateWithCapacity(UUID_2, 2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1))); assertTrue(standbyTaskAssignor.isAllowedTaskMovement(source, destination)); } @Test public void shouldDeclineTaskMovementWhenClientTagsDoNotMatch() { - final ClientState source = createClientStateWithCapacity(1, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1))); - final ClientState destination = createClientStateWithCapacity(1, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_1))); + final ClientState source = createClientStateWithCapacity(UUID_1, 1, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1))); + final ClientState destination = createClientStateWithCapacity(UUID_2, 1, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_1))); assertFalse(standbyTaskAssignor.isAllowedTaskMovement(source, destination)); } @@ -231,17 +233,17 @@ public void shouldDeclineTaskMovementWhenClientTagsDoNotMatch() { @Test public void shouldDistributeStandbyTasksWhenActiveTasksAreLocatedOnSameZone() { final Map clientStates = mkMap( - mkEntry(UUID_1, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_0, TASK_1_0)), - mkEntry(UUID_2, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_1)))), - mkEntry(UUID_3, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_1)))), + mkEntry(UUID_1, createClientStateWithCapacity(UUID_1, 2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_0, TASK_1_0)), + mkEntry(UUID_2, createClientStateWithCapacity(UUID_2, 2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_1)))), + mkEntry(UUID_3, createClientStateWithCapacity(UUID_3, 2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_1)))), - mkEntry(UUID_4, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_2)), TASK_0_1, TASK_1_1)), - mkEntry(UUID_5, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_2)))), - mkEntry(UUID_6, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_2)))), + 
mkEntry(UUID_4, createClientStateWithCapacity(UUID_4, 2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_2)), TASK_0_1, TASK_1_1)), + mkEntry(UUID_5, createClientStateWithCapacity(UUID_5, 2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_2)))), + mkEntry(UUID_6, createClientStateWithCapacity(UUID_6, 2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_2)))), - mkEntry(UUID_7, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_3)), TASK_0_2, TASK_1_2)), - mkEntry(UUID_8, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_3)))), - mkEntry(UUID_9, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_3)))) + mkEntry(UUID_7, createClientStateWithCapacity(UUID_7, 2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_3)), TASK_0_2, TASK_1_2)), + mkEntry(UUID_8, createClientStateWithCapacity(UUID_8, 2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_3)))), + mkEntry(UUID_9, createClientStateWithCapacity(UUID_9, 2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_3)))) ); final Set allActiveTasks = findAllActiveTasks(clientStates); @@ -313,20 +315,133 @@ public void shouldDistributeStandbyTasksWhenActiveTasksAreLocatedOnSameZone() { ); } + @Test + public void shouldDistributeStandbyTasksUsingFunctionAndSupplierTags() { + final Map racksForProcess = mkMap( + mkEntry(UUID_1, "rack1"), + mkEntry(UUID_2, "rack2"), + mkEntry(UUID_3, "rack3"), + mkEntry(UUID_4, "rack1"), + mkEntry(UUID_5, "rack2"), + mkEntry(UUID_6, "rack3"), + mkEntry(UUID_7, "rack1"), + mkEntry(UUID_8, "rack2"), + mkEntry(UUID_9, "rack3") + ); + final RackAwareTaskAssignor rackAwareTaskAssignor = mock(RackAwareTaskAssignor.class); + when(rackAwareTaskAssignor.validClientRack()).thenReturn(true); + when(rackAwareTaskAssignor.racksForProcess()).thenReturn(racksForProcess); + final AssignmentConfigs assignmentConfigs = newAssignmentConfigs(2); + standbyTaskAssignor = StandbyTaskAssignorFactory.create(assignmentConfigs, rackAwareTaskAssignor); + verify(rackAwareTaskAssignor, times(1)).racksForProcess(); + + final Map clientStates = mkMap( + mkEntry(UUID_1, createClientStateWithCapacity(UUID_1, 2, mkMap(), TASK_0_0, TASK_1_0)), + mkEntry(UUID_2, createClientStateWithCapacity(UUID_2, 2, mkMap(), TASK_0_1, TASK_1_1)), + mkEntry(UUID_3, createClientStateWithCapacity(UUID_3, 2, mkMap(), TASK_0_2, TASK_1_2)), + + mkEntry(UUID_4, createClientStateWithCapacity(UUID_4, 2, mkMap())), + mkEntry(UUID_5, createClientStateWithCapacity(UUID_5, 2, mkMap())), + mkEntry(UUID_6, createClientStateWithCapacity(UUID_6, 2, mkMap())), + + mkEntry(UUID_7, createClientStateWithCapacity(UUID_7, 2, mkMap())), + mkEntry(UUID_8, createClientStateWithCapacity(UUID_8, 2, mkMap())), + mkEntry(UUID_9, createClientStateWithCapacity(UUID_9, 2, mkMap())) + ); + + final Map clientStatesWithTags = mkMap( + mkEntry(UUID_1, createClientStateWithCapacity(UUID_1, 2, mkMap(mkEntry(ZONE_TAG, ZONE_1)), TASK_0_0, TASK_1_0)), + mkEntry(UUID_2, createClientStateWithCapacity(UUID_2, 2, mkMap(mkEntry(ZONE_TAG, ZONE_2)), TASK_0_1, TASK_1_1)), + mkEntry(UUID_3, createClientStateWithCapacity(UUID_3, 2, mkMap(mkEntry(ZONE_TAG, ZONE_3)), TASK_0_2, TASK_1_2)), + + mkEntry(UUID_4, createClientStateWithCapacity(UUID_4, 2, mkMap(mkEntry(ZONE_TAG, ZONE_1)))), + mkEntry(UUID_5, createClientStateWithCapacity(UUID_5, 2, mkMap(mkEntry(ZONE_TAG, ZONE_2)))), + mkEntry(UUID_6, 
createClientStateWithCapacity(UUID_6, 2, mkMap(mkEntry(ZONE_TAG, ZONE_3)))), + + mkEntry(UUID_7, createClientStateWithCapacity(UUID_7, 2, mkMap(mkEntry(ZONE_TAG, ZONE_1)))), + mkEntry(UUID_8, createClientStateWithCapacity(UUID_8, 2, mkMap(mkEntry(ZONE_TAG, ZONE_2)))), + mkEntry(UUID_9, createClientStateWithCapacity(UUID_9, 2, mkMap(mkEntry(ZONE_TAG, ZONE_3)))) + ); + + final Set allActiveTasks = findAllActiveTasks(clientStates); + + standbyTaskAssignor.assign(clientStates, allActiveTasks, allActiveTasks, assignmentConfigs); + + final AssignmentConfigs assignmentConfigsWithTags = newAssignmentConfigs(2, ZONE_TAG); + standbyTaskAssignor = new ClientTagAwareStandbyTaskAssignor(); + standbyTaskAssignor.assign(clientStatesWithTags, allActiveTasks, allActiveTasks, assignmentConfigsWithTags); + + Stream.of(clientStates, clientStatesWithTags).forEach( + cs -> { + assertTrue(cs.values().stream().allMatch(ClientState::reachedCapacity)); + Stream.of(UUID_1, UUID_2, UUID_3) + .forEach(client -> assertStandbyTaskCountForClientEqualsTo(cs, client, 0)); + Stream.of(UUID_4, UUID_5, UUID_6, UUID_7, UUID_8, UUID_9) + .forEach(client -> assertStandbyTaskCountForClientEqualsTo(cs, client, 2)); + assertTotalNumberOfStandbyTasksEqualsTo(cs, 12); + + assertTrue( + containsStandbyTasks( + TASK_0_0, + cs, + mkSet(UUID_2, UUID_3, UUID_5, UUID_6, UUID_8, UUID_9) + ) + ); + assertTrue( + containsStandbyTasks( + TASK_1_0, + cs, + mkSet(UUID_2, UUID_3, UUID_5, UUID_6, UUID_8, UUID_9) + ) + ); + + assertTrue( + containsStandbyTasks( + TASK_0_1, + cs, + mkSet(UUID_1, UUID_3, UUID_4, UUID_6, UUID_7, UUID_9) + ) + ); + assertTrue( + containsStandbyTasks( + TASK_1_1, + cs, + mkSet(UUID_1, UUID_3, UUID_4, UUID_6, UUID_7, UUID_9) + ) + ); + + assertTrue( + containsStandbyTasks( + TASK_0_2, + cs, + mkSet(UUID_1, UUID_2, UUID_4, UUID_5, UUID_7, UUID_8) + ) + ); + assertTrue( + containsStandbyTasks( + TASK_1_2, + cs, + mkSet(UUID_1, UUID_2, UUID_4, UUID_5, UUID_7, UUID_8) + ) + ); + } + ); + } + @Test public void shouldDistributeStandbyTasksWhenActiveTasksAreLocatedOnSameCluster() { final Map clientStates = mkMap( - mkEntry(UUID_1, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_0, TASK_1_0)), - mkEntry(UUID_2, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_1, TASK_1_1)), - mkEntry(UUID_3, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_2, TASK_1_2)), + mkEntry(UUID_1, createClientStateWithCapacity(UUID_1, 2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_0, TASK_1_0)), + mkEntry(UUID_2, createClientStateWithCapacity(UUID_2, 2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_1, TASK_1_1)), + mkEntry(UUID_3, createClientStateWithCapacity(UUID_3, 2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_2, TASK_1_2)), - mkEntry(UUID_4, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_2)))), - mkEntry(UUID_5, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_2)))), - mkEntry(UUID_6, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_2)))), + mkEntry(UUID_4, createClientStateWithCapacity(UUID_4, 2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_2)))), + mkEntry(UUID_5, createClientStateWithCapacity(UUID_5, 2, mkMap(mkEntry(ZONE_TAG, 
ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_2)))), + mkEntry(UUID_6, createClientStateWithCapacity(UUID_6, 2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_2)))), - mkEntry(UUID_7, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_3)))), - mkEntry(UUID_8, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_3)))), - mkEntry(UUID_9, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_3)))) + mkEntry(UUID_7, createClientStateWithCapacity(UUID_7, 2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_3)))), + mkEntry(UUID_8, createClientStateWithCapacity(UUID_8, 2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_3)))), + mkEntry(UUID_9, createClientStateWithCapacity(UUID_9, 2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_3)))) ); final Set allActiveTasks = findAllActiveTasks(clientStates); @@ -401,13 +516,13 @@ public void shouldDistributeStandbyTasksWhenActiveTasksAreLocatedOnSameCluster() @Test public void shouldDoThePartialRackAwareness() { final Map clientStates = mkMap( - mkEntry(UUID_1, createClientStateWithCapacity(1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_0_0)), - mkEntry(UUID_2, createClientStateWithCapacity(1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_2)))), - mkEntry(UUID_3, createClientStateWithCapacity(1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_3)))), + mkEntry(UUID_1, createClientStateWithCapacity(UUID_1, 1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_0_0)), + mkEntry(UUID_2, createClientStateWithCapacity(UUID_2, 1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_2)))), + mkEntry(UUID_3, createClientStateWithCapacity(UUID_3, 1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_3)))), - mkEntry(UUID_4, createClientStateWithCapacity(1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_2), mkEntry(ZONE_TAG, ZONE_1)))), - mkEntry(UUID_5, createClientStateWithCapacity(1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_2), mkEntry(ZONE_TAG, ZONE_2)))), - mkEntry(UUID_6, createClientStateWithCapacity(1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_2), mkEntry(ZONE_TAG, ZONE_3)), TASK_1_0)) + mkEntry(UUID_4, createClientStateWithCapacity(UUID_4, 1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_2), mkEntry(ZONE_TAG, ZONE_1)))), + mkEntry(UUID_5, createClientStateWithCapacity(UUID_5, 1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_2), mkEntry(ZONE_TAG, ZONE_2)))), + mkEntry(UUID_6, createClientStateWithCapacity(UUID_6, 1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_2), mkEntry(ZONE_TAG, ZONE_3)), TASK_1_0)) ); final Set allActiveTasks = findAllActiveTasks(clientStates); @@ -466,12 +581,12 @@ public void shouldDoThePartialRackAwareness() { @Test public void shouldDistributeClientsOnDifferentZoneTagsEvenWhenClientsReachedCapacity() { final Map clientStates = mkMap( - mkEntry(UUID_1, createClientStateWithCapacity(1, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_0)), - mkEntry(UUID_2, createClientStateWithCapacity(1, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_1)), - mkEntry(UUID_3, createClientStateWithCapacity(1, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_2)), - mkEntry(UUID_4, createClientStateWithCapacity(1, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_1_0)), - mkEntry(UUID_5, createClientStateWithCapacity(1, mkMap(mkEntry(ZONE_TAG, 
ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_1_1)), - mkEntry(UUID_6, createClientStateWithCapacity(1, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_1_2)) + mkEntry(UUID_1, createClientStateWithCapacity(UUID_1, 1, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_0)), + mkEntry(UUID_2, createClientStateWithCapacity(UUID_2, 1, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_1)), + mkEntry(UUID_3, createClientStateWithCapacity(UUID_3, 1, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_2)), + mkEntry(UUID_4, createClientStateWithCapacity(UUID_4, 1, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_1_0)), + mkEntry(UUID_5, createClientStateWithCapacity(UUID_5, 1, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_1_1)), + mkEntry(UUID_6, createClientStateWithCapacity(UUID_6, 1, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_1_2)) ); final Set allActiveTasks = findAllActiveTasks(clientStates); @@ -543,10 +658,10 @@ public void shouldDistributeClientsOnDifferentZoneTagsEvenWhenClientsReachedCapa @Test public void shouldIgnoreTagsThatAreNotPresentInRackAwareness() { final Map clientStates = mkMap( - mkEntry(UUID_1, createClientStateWithCapacity(1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_0_0)), - mkEntry(UUID_2, createClientStateWithCapacity(2, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_2)))), + mkEntry(UUID_1, createClientStateWithCapacity(UUID_1, 1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_0_0)), + mkEntry(UUID_2, createClientStateWithCapacity(UUID_2, 2, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_2)))), - mkEntry(UUID_3, createClientStateWithCapacity(1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_2), mkEntry(ZONE_TAG, ZONE_1)))) + mkEntry(UUID_3, createClientStateWithCapacity(UUID_3, 1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_2), mkEntry(ZONE_TAG, ZONE_1)))) ); final Set allActiveTasks = findAllActiveTasks(clientStates); @@ -561,8 +676,8 @@ public void shouldIgnoreTagsThatAreNotPresentInRackAwareness() { @Test public void shouldHandleOverlappingTagValuesBetweenDifferentTagKeys() { final Map clientStates = mkMap( - mkEntry(UUID_1, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_0)), - mkEntry(UUID_2, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, CLUSTER_1), mkEntry(CLUSTER_TAG, CLUSTER_3)))) + mkEntry(UUID_1, createClientStateWithCapacity(UUID_1, 2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_0)), + mkEntry(UUID_2, createClientStateWithCapacity(UUID_2, 2, mkMap(mkEntry(ZONE_TAG, CLUSTER_1), mkEntry(CLUSTER_TAG, CLUSTER_3)))) ); final Set allActiveTasks = findAllActiveTasks(clientStates); @@ -585,10 +700,10 @@ public void shouldHandleOverlappingTagValuesBetweenDifferentTagKeys() { @Test public void shouldDistributeStandbyTasksOnLeastLoadedClientsWhenClientsAreNotOnDifferentTagDimensions() { final Map clientStates = mkMap( - mkEntry(UUID_1, createClientStateWithCapacity(3, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_0_0)), - mkEntry(UUID_2, createClientStateWithCapacity(3, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_0_1)), - mkEntry(UUID_3, createClientStateWithCapacity(3, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_0_2)), - mkEntry(UUID_4, 
createClientStateWithCapacity(3, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_1_0)) + mkEntry(UUID_1, createClientStateWithCapacity(UUID_1, 3, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_0_0)), + mkEntry(UUID_2, createClientStateWithCapacity(UUID_2, 3, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_0_1)), + mkEntry(UUID_3, createClientStateWithCapacity(UUID_3, 3, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_0_2)), + mkEntry(UUID_4, createClientStateWithCapacity(UUID_4, 3, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_1_0)) ); final Set allActiveTasks = findAllActiveTasks(clientStates); @@ -606,7 +721,7 @@ public void shouldDistributeStandbyTasksOnLeastLoadedClientsWhenClientsAreNotOnD @Test public void shouldNotAssignStandbyTasksIfThereAreNoEnoughClients() { final Map clientStates = mkMap( - mkEntry(UUID_1, createClientStateWithCapacity(3, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_0_0)) + mkEntry(UUID_1, createClientStateWithCapacity(UUID_1, 3, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_0_0)) ); final Set allActiveTasks = findAllActiveTasks(clientStates); @@ -644,6 +759,13 @@ private static boolean standbyClientsHonorRackAwareness(final TaskId activeTaskI .count() == 1; } + private static boolean containsStandbyTasks(final TaskId activeTaskId, + final Map clientStates, + final Set validClientIdsBasedOnRackAwareAssignmentTags) { + final Set standbyTaskClientIds = findAllStandbyTaskClients(clientStates, activeTaskId); + return validClientIdsBasedOnRackAwareAssignmentTags.containsAll(standbyTaskClientIds); + } + private static Set findAllStandbyTaskClients(final Map clientStates, final TaskId task) { return clientStates.keySet() .stream() @@ -660,10 +782,11 @@ private static AssignmentConfigs newAssignmentConfigs(final int numStandbyReplic asList(rackAwareAssignmentTags)); } - private static ClientState createClientStateWithCapacity(final int capacity, + private static ClientState createClientStateWithCapacity(final UUID processId, + final int capacity, final Map clientTags, final TaskId... 
tasks) { - final ClientState clientState = new ClientState(capacity, clientTags); + final ClientState clientState = new ClientState(processId, capacity, clientTags); Optional.ofNullable(tasks).ifPresent(t -> clientState.assignActiveTasks(asList(t))); diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignorTest.java index 554e3461d03d..572cc1e15aaa 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignorTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignorTest.java @@ -176,8 +176,7 @@ public void shouldDisableActiveWhenMissingClusterInfo() { // False since partitionWithoutInfo10 is missing in cluster metadata assertFalse(assignor.canEnableRackAwareAssignor()); - assertFalse(assignor.populateTopicsToDiscribe(new HashSet<>())); - assertTrue(assignor.validateClientRack()); + assertFalse(assignor.populateTopicsToDescribe(new HashSet<>())); } @Test @@ -191,14 +190,13 @@ public void shouldDisableActiveWhenRackMissingInNode() { new AssignorConfiguration(streamsConfig.originals()).assignmentConfigs() ); - assertTrue(assignor.validateClientRack()); - assertFalse(assignor.populateTopicsToDiscribe(new HashSet<>())); + assertFalse(assignor.populateTopicsToDescribe(new HashSet<>())); // False since nodeMissingRack has one node which doesn't have rack assertFalse(assignor.canEnableRackAwareAssignor()); } @Test - public void shouldDisableActiveWhenRackMissingInClient() { + public void shouldReturnInvalidClientRackWhenRackMissingInClientConsumer() { final RackAwareTaskAssignor assignor = new RackAwareTaskAssignor( getClusterForTopic0(), getTaskTopicPartitionMapForTask0(), @@ -207,14 +205,41 @@ public void shouldDisableActiveWhenRackMissingInClient() { mockInternalTopicManager, new AssignorConfiguration(streamsConfig.originals()).assignmentConfigs() ); + // False since process1 doesn't have rackId + assertFalse(assignor.validClientRack()); + } + @Test + public void shouldReturnFalseWhenRackMissingInProcess() { + final RackAwareTaskAssignor assignor = new RackAwareTaskAssignor( + getClusterForTopic0(), + getTaskTopicPartitionMapForTask0(), + getTopologyGroupTaskMap(), + getProcessWithNoConsumerRacks(), + mockInternalTopicManager, + new AssignorConfiguration(streamsConfig.originals()).assignmentConfigs() + ); // False since process1 doesn't have rackId - assertFalse(assignor.validateClientRack()); - assertFalse(assignor.canEnableRackAwareAssignor()); + assertFalse(assignor.validClientRack()); } @Test - public void shouldDisableActiveWhenRackDiffersInSameProcess() { + public void shouldPopulateRacksForProcess() { + // Throws since process1 doesn't have rackId + final RackAwareTaskAssignor assignor = new RackAwareTaskAssignor( + getClusterForTopic0(), + getTaskTopicPartitionMapForTask0(), + getTopologyGroupTaskMap(), + getProcessRacksForProcess0(), + mockInternalTopicManager, + new AssignorConfiguration(streamsConfig.originals()).assignmentConfigs() + ); + final Map racksForProcess = assignor.racksForProcess(); + assertEquals(mkMap(mkEntry(UUID_1, RACK_1)), racksForProcess); + } + + @Test + public void shouldReturnInvalidClientRackWhenRackDiffersInSameProcess() { final Map>> processRacks = new HashMap<>(); // Different consumers in same process have different rack ID. This shouldn't happen. 
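Editor's note, as a reading aid for the RackAwareTaskAssignorTest changes below: the constructor now receives the per-consumer rack map and eagerly collapses it to one rack id per process, which validClientRack() and racksForProcess() expose. A sketch of that flow follows, reusing the fixtures already present in this test class (getClusterForTopic0, getTaskTopicPartitionMapForTask0, getTopologyGroupTaskMap, mockInternalTopicManager, streamsConfig); the generic parameters are stripped in this extract, so the map shape Map<UUID, Map<String, Optional<String>>> is an assumption, and the consumer ids are hypothetical.

```java
// Per-consumer rack ids, possibly absent, keyed by process UUID and then consumer id.
// Consumers of the same process must agree on the rack id, otherwise validation fails.
final Map<UUID, Map<String, Optional<String>>> racksForProcessConsumer = mkMap(
    mkEntry(UUID_1, mkMap(
        mkEntry("consumer1", Optional.of(RACK_1)),
        mkEntry("consumer2", Optional.of(RACK_1))
    ))
);

final RackAwareTaskAssignor assignor = new RackAwareTaskAssignor(
    getClusterForTopic0(),
    getTaskTopicPartitionMapForTask0(),
    getTopologyGroupTaskMap(),
    racksForProcessConsumer,
    mockInternalTopicManager,
    new AssignorConfiguration(streamsConfig.originals()).assignmentConfigs()
);

// validateClientRack already ran in the constructor; on success every process
// is mapped to exactly one rack id.
assertTrue(assignor.validClientRack());
assertEquals(mkMap(mkEntry(UUID_1, RACK_1)), assignor.racksForProcess());
```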
@@ -231,12 +256,11 @@ public void shouldDisableActiveWhenRackDiffersInSameProcess() {
             new AssignorConfiguration(streamsConfig.originals()).assignmentConfigs()
         );
 
-        assertFalse(assignor.validateClientRack());
-        assertFalse(assignor.canEnableRackAwareAssignor());
+        assertFalse(assignor.validClientRack());
     }
 
     @Test
-    public void shouldEnableRackAwareAssignorForActiveWithoutDescribingTopics() {
+    public void shouldEnableRackAwareAssignorWithoutDescribingTopics() {
         final RackAwareTaskAssignor assignor = new RackAwareTaskAssignor(
             getClusterForTopic0(),
             getTaskTopicPartitionMapForTask0(),
@@ -251,7 +275,7 @@ public void shouldEnableRackAwareAssignorForActiveWithoutDescribingTopics() {
     }
 
     @Test
-    public void shouldEnableRackAwareAssignorForActiveWithDescribingTopics() {
+    public void shouldEnableRackAwareAssignorWithDescribingTopics() {
         final MockInternalTopicManager spyTopicManager = spy(mockInternalTopicManager);
         doReturn(
             Collections.singletonMap(
@@ -275,7 +299,7 @@ public void shouldEnableRackAwareAssignorForActiveWithDescribingTopics() {
     }
 
     @Test
-    public void shouldDisableRackAwareAssignorForActiveWithDescribingTopicsFailure() {
+    public void shouldDisableRackAwareAssignorWithDescribingTopicsFailure() {
         final MockInternalTopicManager spyTopicManager = spy(mockInternalTopicManager);
         doThrow(new TimeoutException("Timeout describing topic")).when(spyTopicManager).getTopicPartitionInfo(Collections.singleton(
             TP_0_NAME));
@@ -290,7 +314,7 @@ public void shouldDisableRackAwareAssignorForActiveWithDescribingTopicsFailure()
         );
 
         assertFalse(assignor.canEnableRackAwareAssignor());
-        assertTrue(assignor.populateTopicsToDiscribe(new HashSet<>()));
+        assertTrue(assignor.populateTopicsToDescribe(new HashSet<>()));
     }
 
     @Test
@@ -314,10 +338,10 @@ public void shouldOptimizeEmptyActiveTasks() {
         final SortedSet<TaskId> taskIds = mkSortedSet();
 
         assertTrue(assignor.canEnableRackAwareAssignor());
-        final long originalCost = assignor.activeTasksCost(clientStateMap, taskIds, trafficCost, nonOverlapCost);
+        final long originalCost = assignor.activeTasksCost(taskIds, clientStateMap, trafficCost, nonOverlapCost);
         assertEquals(0, originalCost);
 
-        final long cost = assignor.optimizeActiveTasks(clientStateMap, taskIds, trafficCost, nonOverlapCost);
+        final long cost = assignor.optimizeActiveTasks(taskIds, clientStateMap, trafficCost, nonOverlapCost);
         assertEquals(0, cost);
 
         assertEquals(mkSet(TASK_0_1, TASK_1_1), clientState0.activeTasks());
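From this hunk on, activeTasksCost and optimizeActiveTasks are called with the task set before the client map. A sketch of the updated call pattern, using the tests' own variable names; the cost relationship is the one the later assertions check (optimized cost no greater than the original):

    final long originalCost = assignor.activeTasksCost(taskIds, clientStateMap, trafficCost, nonOverlapCost);
    final long optimizedCost = assignor.optimizeActiveTasks(taskIds, clientStateMap, trafficCost, nonOverlapCost);
    // optimizeActiveTasks rearranges the active tasks held by the ClientState entries in
    // clientStateMap and returns the resulting cost, expected to be <= originalCost.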
@@ -356,11 +380,11 @@ public void shouldOptimizeActiveTasks() {
         assertTrue(assignor.canEnableRackAwareAssignor());
 
         int expected = stateful ? 40 : 4;
-        final long originalCost = assignor.activeTasksCost(clientStateMap, taskIds, trafficCost, nonOverlapCost);
+        final long originalCost = assignor.activeTasksCost(taskIds, clientStateMap, trafficCost, nonOverlapCost);
         assertEquals(expected, originalCost);
 
         expected = stateful ? 4 : 0;
-        final long cost = assignor.optimizeActiveTasks(clientStateMap, taskIds, trafficCost, nonOverlapCost);
+        final long cost = assignor.optimizeActiveTasks(taskIds, clientStateMap, trafficCost, nonOverlapCost);
         assertEquals(expected, cost);
 
         assertEquals(mkSet(TASK_0_0, TASK_1_0), clientState0.activeTasks());
@@ -390,8 +414,8 @@ public void shouldOptimizeRandom() {
             .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().activeTasks().size()));
 
         assertTrue(assignor.canEnableRackAwareAssignor());
-        final long originalCost = assignor.activeTasksCost(clientStateMap, taskIds, trafficCost, nonOverlapCost);
-        final long cost = assignor.optimizeActiveTasks(clientStateMap, taskIds, trafficCost, nonOverlapCost);
+        final long originalCost = assignor.activeTasksCost(taskIds, clientStateMap, trafficCost, nonOverlapCost);
+        final long cost = assignor.optimizeActiveTasks(taskIds, clientStateMap, trafficCost, nonOverlapCost);
         assertThat(cost, lessThanOrEqualTo(originalCost));
 
         for (final Entry<UUID, ClientState> entry : clientStateMap.entrySet()) {
@@ -425,10 +449,10 @@ public void shouldMaintainOriginalAssignment() {
         // Because trafficCost is 0, original assignment should be maintained
         assertTrue(assignor.canEnableRackAwareAssignor());
-        final long originalCost = assignor.activeTasksCost(clientStateMap, taskIds, 0, 1);
+        final long originalCost = assignor.activeTasksCost(taskIds, clientStateMap, 0, 1);
         assertEquals(0, originalCost);
 
-        final long cost = assignor.optimizeActiveTasks(clientStateMap, taskIds, 0, 1);
+        final long cost = assignor.optimizeActiveTasks(taskIds, clientStateMap, 0, 1);
         assertEquals(0, cost);
 
         // Make sure assignment doesn't change
@@ -468,12 +492,12 @@ public void shouldOptimizeActiveTasksWithMoreClients() {
         final SortedSet<TaskId> taskIds = mkSortedSet(TASK_0_0, TASK_1_0);
 
         assertTrue(assignor.canEnableRackAwareAssignor());
-        final long originalCost = assignor.activeTasksCost(clientStateMap, taskIds, trafficCost, nonOverlapCost);
+        final long originalCost = assignor.activeTasksCost(taskIds, clientStateMap, trafficCost, nonOverlapCost);
         int expected = stateful ? 20 : 2;
         assertEquals(expected, originalCost);
 
         expected = stateful ? 2 : 0;
-        final long cost = assignor.optimizeActiveTasks(clientStateMap, taskIds, trafficCost, nonOverlapCost);
+        final long cost = assignor.optimizeActiveTasks(taskIds, clientStateMap, trafficCost, nonOverlapCost);
         assertEquals(expected, cost);
 
         // UUID_1 remains empty
@@ -512,12 +536,12 @@ public void shouldOptimizeActiveTasksWithMoreClientsWithMoreThanOneTask() {
         final SortedSet<TaskId> taskIds = mkSortedSet(TASK_0_0, TASK_0_1, TASK_1_0);
 
         assertTrue(assignor.canEnableRackAwareAssignor());
-        final long originalCost = assignor.activeTasksCost(clientStateMap, taskIds, trafficCost, nonOverlapCost);
+        final long originalCost = assignor.activeTasksCost(taskIds, clientStateMap, trafficCost, nonOverlapCost);
         int expected = stateful ? 20 : 2;
         assertEquals(expected, originalCost);
 
         expected = stateful ? 2 : 0;
-        final long cost = assignor.optimizeActiveTasks(clientStateMap, taskIds, trafficCost, nonOverlapCost);
+        final long cost = assignor.optimizeActiveTasks(taskIds, clientStateMap, trafficCost, nonOverlapCost);
         assertEquals(expected, cost);
 
         // Because original assignment is not balanced (3 tasks but client 0 has no task), we maintain it
@@ -555,10 +579,10 @@ public void shouldBalanceAssignmentWithMoreCost() {
         assertTrue(assignor.canEnableRackAwareAssignor());
 
         final int expectedCost = stateful ? 10 : 1;
-        final long originalCost = assignor.activeTasksCost(clientStateMap, taskIds, trafficCost, nonOverlapCost);
+        final long originalCost = assignor.activeTasksCost(taskIds, clientStateMap, trafficCost, nonOverlapCost);
         assertEquals(expectedCost, originalCost);
 
-        final long cost = assignor.optimizeActiveTasks(clientStateMap, taskIds, trafficCost, nonOverlapCost);
+        final long cost = assignor.optimizeActiveTasks(taskIds, clientStateMap, trafficCost, nonOverlapCost);
         assertEquals(expectedCost, cost);
 
         // Even though assigning all tasks to UUID_2 will result in min cost, but it's not balanced
@@ -590,9 +614,9 @@ public void shouldThrowIfMissingCallcanEnableRackAwareAssignor() {
         ));
         final SortedSet<TaskId> taskIds = mkSortedSet(TASK_0_0, TASK_0_1, TASK_1_1);
         final Exception exception = assertThrows(IllegalStateException.class,
-            () -> assignor.optimizeActiveTasks(clientStateMap, taskIds, trafficCost, nonOverlapCost));
-        Assertions.assertEquals("TopicPartition topic0-0 has no rack information. Maybe forgot to call "
-            + "canEnableRackAwareAssignor first", exception.getMessage());
+            () -> assignor.optimizeActiveTasks(taskIds, clientStateMap, trafficCost, nonOverlapCost));
+        Assertions.assertEquals("TopicPartition topic0-0 has no rack information. "
+            + "Maybe forgot to call canEnableRackAwareAssignor first", exception.getMessage());
     }
 
     @Test
@@ -619,7 +643,7 @@ public void shouldThrowIfTaskInMultipleClients() {
         final SortedSet<TaskId> taskIds = mkSortedSet(TASK_0_0, TASK_0_1, TASK_1_1);
         assertTrue(assignor.canEnableRackAwareAssignor());
         final Exception exception = assertThrows(IllegalArgumentException.class,
-            () -> assignor.optimizeActiveTasks(clientStateMap, taskIds, trafficCost, nonOverlapCost));
+            () -> assignor.optimizeActiveTasks(taskIds, clientStateMap, trafficCost, nonOverlapCost));
         Assertions.assertEquals(
             "Task 1_1 assigned to multiple clients 00000000-0000-0000-0000-000000000005, "
                 + "00000000-0000-0000-0000-000000000002", exception.getMessage());
@@ -649,7 +673,7 @@ public void shouldThrowIfTaskMissingInClients() {
         final SortedSet<TaskId> taskIds = mkSortedSet(TASK_0_0, TASK_0_1, TASK_1_0, TASK_1_1);
         assertTrue(assignor.canEnableRackAwareAssignor());
         final Exception exception = assertThrows(IllegalArgumentException.class,
-            () -> assignor.optimizeActiveTasks(clientStateMap, taskIds, trafficCost, nonOverlapCost));
+            () -> assignor.optimizeActiveTasks(taskIds, clientStateMap, trafficCost, nonOverlapCost));
         Assertions.assertEquals(
             "Task 1_0 not assigned to any client", exception.getMessage());
     }
@@ -787,6 +811,12 @@ private Map>> getProcessRacksForProcess0(fina
         return processRacks;
     }
 
+    private Map<UUID, Map<String, Optional<String>>> getProcessWithNoConsumerRacks() {
+        return mkMap(
+            mkEntry(UUID_1, mkMap())
+        );
+    }
+
     private Map<TaskId, Set<TopicPartition>> getTaskTopicPartitionMapForTask0() {
         return getTaskTopicPartitionMapForTask0(false);
     }
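The IllegalStateException case above (shouldThrowIfMissingCallcanEnableRackAwareAssignor) suggests the intended call order: rack metadata is resolved while canEnableRackAwareAssignor() runs, so optimization is only attempted once that check succeeds. A hedged sketch of the guard, reusing the tests' variable names:

    if (assignor.canEnableRackAwareAssignor()) {
        assignor.optimizeActiveTasks(taskIds, clientStateMap, trafficCost, nonOverlapCost);
    }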
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignorFactoryTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignorFactoryTest.java
index fdd7fa1d4737..4c4b35f83a2e 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignorFactoryTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignorFactoryTest.java
@@ -16,30 +16,91 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
+import java.util.Arrays;
+import java.util.Collection;
+import org.junit.Before;
 import org.junit.Test;
 import java.util.Collections;
 import java.util.List;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
+import org.mockito.quality.Strictness;
 
 import static java.util.Collections.singletonList;
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
+@RunWith(Parameterized.class)
 public class StandbyTaskAssignorFactoryTest {
+    @org.junit.Rule
+    public MockitoRule rule = MockitoJUnit.rule().strictness(Strictness.LENIENT);
+
     private static final long ACCEPTABLE_RECOVERY_LAG = 0L;
     private static final int MAX_WARMUP_REPLICAS = 1;
     private static final int NUMBER_OF_STANDBY_REPLICAS = 1;
     private static final long PROBING_REBALANCE_INTERVAL_MS = 60000L;
 
+    enum State {
+        DISABLED,
+        ENABLED,
+        NULL
+    }
+
+    private RackAwareTaskAssignor rackAwareTaskAssignor;
+
+    @Parameter
+    public State state;
+
+    @Parameterized.Parameters(name = "RackAwareTaskAssignor={0}")
+    public static Collection parameters() {
+        return Arrays.asList(State.DISABLED, State.ENABLED, State.NULL);
+    }
+
+    @Before
+    public void setUp() {
+        if (state == State.ENABLED) {
+            rackAwareTaskAssignor = mock(RackAwareTaskAssignor.class);
+            when(rackAwareTaskAssignor.validClientRack()).thenReturn(true);
+        } else if (state == State.DISABLED) {
+            rackAwareTaskAssignor = mock(RackAwareTaskAssignor.class);
+            when(rackAwareTaskAssignor.validClientRack()).thenReturn(false);
+        } else {
+            rackAwareTaskAssignor = null;
+        }
+    }
+
     @Test
     public void shouldReturnClientTagAwareStandbyTaskAssignorWhenRackAwareAssignmentTagsIsSet() {
-        final StandbyTaskAssignor standbyTaskAssignor = StandbyTaskAssignorFactory.create(newAssignmentConfigs(singletonList("az")));
+        final StandbyTaskAssignor standbyTaskAssignor = StandbyTaskAssignorFactory.create(newAssignmentConfigs(singletonList("az")), rackAwareTaskAssignor);
         assertTrue(standbyTaskAssignor instanceof ClientTagAwareStandbyTaskAssignor);
+        if (state != State.NULL) {
+            verify(rackAwareTaskAssignor, never()).racksForProcess();
+            verify(rackAwareTaskAssignor, never()).validClientRack();
+        }
     }
 
     @Test
-    public void shouldReturnDefaultStandbyTaskAssignorWhenRackAwareAssignmentTagsIsEmpty() {
-        final StandbyTaskAssignor standbyTaskAssignor = StandbyTaskAssignorFactory.create(newAssignmentConfigs(Collections.emptyList()));
-        assertTrue(standbyTaskAssignor instanceof DefaultStandbyTaskAssignor);
+    public void shouldReturnDefaultOrRackAwareStandbyTaskAssignorWhenRackAwareAssignmentTagsIsEmpty() {
+        final StandbyTaskAssignor standbyTaskAssignor = StandbyTaskAssignorFactory.create(newAssignmentConfigs(Collections.emptyList()), rackAwareTaskAssignor);
+        if (state == State.ENABLED) {
+            assertTrue(standbyTaskAssignor instanceof ClientTagAwareStandbyTaskAssignor);
+            verify(rackAwareTaskAssignor, times(1)).racksForProcess();
+            verify(rackAwareTaskAssignor, times(1)).validClientRack();
+        } else if (state == State.DISABLED) {
+            assertTrue(standbyTaskAssignor instanceof DefaultStandbyTaskAssignor);
+            verify(rackAwareTaskAssignor, never()).racksForProcess();
+            verify(rackAwareTaskAssignor, times(1)).validClientRack();
+        } else {
+            assertTrue(standbyTaskAssignor instanceof DefaultStandbyTaskAssignor);
+        }
     }
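Read together, the parameterized cases above describe the factory's selection rules: with rack-aware assignment tags configured it returns a ClientTagAwareStandbyTaskAssignor without consulting the RackAwareTaskAssignor, and with no tags it returns ClientTagAwareStandbyTaskAssignor only when a non-null RackAwareTaskAssignor reports a valid client rack, otherwise DefaultStandbyTaskAssignor. A usage sketch built from the calls in these tests (newAssignmentConfigs is the helper defined below; rackAwareTaskAssignor may be null):

    final StandbyTaskAssignor taggedAssignor =
        StandbyTaskAssignorFactory.create(newAssignmentConfigs(singletonList("az")), rackAwareTaskAssignor);
    final StandbyTaskAssignor untaggedAssignor =
        StandbyTaskAssignorFactory.create(newAssignmentConfigs(Collections.emptyList()), rackAwareTaskAssignor);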
 
     private static AssignorConfiguration.AssignmentConfigs newAssignmentConfigs(final List<String> rackAwareAssignmentTags) {