Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KAFKA-2763: better stream task assignment #497

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -29,6 +29,7 @@
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.UUID;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For easier troubleshooting and debugging, could we add the host name as a prefix to the UUID?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does the consumer group protocol support such debugging? Identifying a problematic host is a general enough concern that it should be handled by the KafkaConsumer implementation.

import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

Expand Down Expand Up @@ -85,11 +86,13 @@ public class KafkaStreaming {
private final StreamThread[] threads;

private String clientId;
private final UUID uuid;
private final Metrics metrics;

public KafkaStreaming(TopologyBuilder builder, StreamingConfig config) throws Exception {
// create the metrics
this.time = new SystemTime();
this.uuid = UUID.randomUUID();

MetricConfig metricConfig = new MetricConfig().samples(config.getInt(StreamingConfig.METRICS_NUM_SAMPLES_CONFIG))
.timeWindow(config.getLong(StreamingConfig.METRICS_SAMPLE_WINDOW_MS_CONFIG),
Expand All @@ -104,7 +107,7 @@ public KafkaStreaming(TopologyBuilder builder, StreamingConfig config) throws Ex

this.threads = new StreamThread[config.getInt(StreamingConfig.NUM_STREAM_THREADS_CONFIG)];
for (int i = 0; i < this.threads.length; i++) {
this.threads[i] = new StreamThread(builder, config, this.clientId, this.metrics, this.time);
this.threads[i] = new StreamThread(builder, config, this.clientId, this.uuid, this.metrics, this.time);
}
}

Expand Down
Expand Up @@ -27,8 +27,8 @@
import org.apache.kafka.common.serialization.Deserializer;
import org.apache.kafka.common.serialization.Serializer;
import org.apache.kafka.streams.processor.DefaultPartitionGrouper;
import org.apache.kafka.streams.processor.PartitionGrouper;
import org.apache.kafka.streams.processor.internals.KafkaStreamingPartitionAssignor;
import org.apache.kafka.streams.processor.internals.StreamThread;

import java.util.Map;

Expand Down Expand Up @@ -205,16 +205,16 @@ public class StreamingConfig extends AbstractConfig {
}

/**
 * Keys used to carry internal object instances (rather than user-supplied values)
 * through a configuration map. getConsumerConfigs stores an instance under one of
 * these keys and KafkaStreamingPartitionAssignor.configure reads it back.
 */
public static class InternalConfig {
// key under which a PartitionGrouper instance is stored
public static final String PARTITION_GROUPER_INSTANCE = "__partition.grouper.instance__";
// key under which a StreamThread instance is stored
public static final String STREAM_THREAD_INSTANCE = "__stream.thread.instance__";
}

/**
 * Creates a streaming configuration from the given properties.
 *
 * @param props the configuration properties; passed to the superclass constructor
 *              together with the CONFIG definition
 */
public StreamingConfig(Map<?, ?> props) {
super(CONFIG, props);
}

public Map<String, Object> getConsumerConfigs(PartitionGrouper partitionGrouper) {
public Map<String, Object> getConsumerConfigs(StreamThread streamThread) {
Map<String, Object> props = getConsumerConfigs();
props.put(StreamingConfig.InternalConfig.PARTITION_GROUPER_INSTANCE, partitionGrouper);
props.put(StreamingConfig.InternalConfig.STREAM_THREAD_INSTANCE, streamThread);
props.put(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, KafkaStreamingPartitionAssignor.class.getName());
return props;
}
Expand Down
Expand Up @@ -50,4 +50,8 @@ public Set<TaskId> taskIds(TopicPartition partition) {
return partitionAssignor.taskIds(partition);
}

/**
 * Returns the standby tasks as reported by the associated partition assignor.
 *
 * @return the result of partitionAssignor.standbyTasks(), unmodified
 */
public Set<TaskId> standbyTasks() {
return partitionAssignor.standbyTasks();
}

}
Expand Up @@ -17,7 +17,9 @@

package org.apache.kafka.streams.processor;

public class TaskId {
import java.nio.ByteBuffer;

public class TaskId implements Comparable<TaskId> {

public final int topicGroupId;
public final int partition;
Expand Down Expand Up @@ -45,6 +47,15 @@ public static TaskId parse(String string) {
}
}

/**
 * Serializes this task id into the given buffer as two consecutive ints:
 * topicGroupId first, then partition.
 *
 * @param buf the buffer to write into
 */
public void writeTo(ByteBuffer buf) {
    // putInt returns the same buffer, so the two writes can be chained in order
    buf.putInt(topicGroupId).putInt(partition);
}

/**
 * Deserializes a task id by reading two consecutive ints from the given buffer.
 * The first int read becomes the first constructor argument and the second int
 * read becomes the second.
 *
 * @param buf the buffer to read from
 * @return a new TaskId built from the two ints read
 */
public static TaskId readFrom(ByteBuffer buf) {
    // read into locals so the read order is explicit
    final int first = buf.getInt();
    final int second = buf.getInt();
    return new TaskId(first, second);
}

@Override
public boolean equals(Object o) {
if (o instanceof TaskId) {
Expand All @@ -61,6 +72,16 @@ public int hashCode() {
return (int) (n % 0xFFFFFFFFL);
}

/**
 * Orders task ids by topicGroupId, then by partition.
 *
 * @param other the task id to compare against
 * @return -1, 0 or 1 when this id sorts before, equal to, or after {@code other}
 */
@Override
public int compareTo(TaskId other) {
    // primary key: topic group id
    if (this.topicGroupId != other.topicGroupId) {
        return this.topicGroupId < other.topicGroupId ? -1 : 1;
    }
    // tie-breaker: partition
    if (this.partition != other.partition) {
        return this.partition < other.partition ? -1 : 1;
    }
    return 0;
}

/**
 * Unchecked exception type declared alongside TaskId.
 * NOTE(review): presumably thrown by parse(String) when the input is malformed;
 * confirm against the parse body, which is not visible here.
 */
public static class TaskIdFormatException extends RuntimeException {
}
}
Expand Up @@ -23,37 +23,49 @@
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.streams.StreamingConfig;
import org.apache.kafka.streams.processor.PartitionGrouper;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo;
import org.apache.kafka.streams.processor.internals.assignment.ClientState;
import org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo;
import org.apache.kafka.streams.processor.internals.assignment.TaskAssignmentException;
import org.apache.kafka.streams.processor.internals.assignment.TaskAssignor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;

public class KafkaStreamingPartitionAssignor implements PartitionAssignor, Configurable {

private static final Logger log = LoggerFactory.getLogger(KafkaStreamingPartitionAssignor.class);

private PartitionGrouper partitionGrouper;
private StreamThread streamThread;
private Map<TopicPartition, Set<TaskId>> partitionToTaskIds;
private Set<TaskId> standbyTasks;

@Override
public void configure(Map<String, ?> configs) {
Object o = configs.get(StreamingConfig.InternalConfig.PARTITION_GROUPER_INSTANCE);
if (o == null)
throw new KafkaException("PartitionGrouper is not specified");
Object o = configs.get(StreamingConfig.InternalConfig.STREAM_THREAD_INSTANCE);
if (o == null) {
KafkaException ex = new KafkaException("StreamThread is not specified");
log.error(ex.getMessage(), ex);
throw ex;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since log.error(.., ex) already prints the stack trace, maybe we can skip re-throwing the exception.

EDIT: if we want to stop the whole process by throwing the exception, then we can drop the log.error() call instead.

}

if (!PartitionGrouper.class.isInstance(o))
throw new KafkaException(o.getClass().getName() + " is not an instance of " + PartitionGrouper.class.getName());
if (!(o instanceof StreamThread)) {
KafkaException ex = new KafkaException(o.getClass().getName() + " is not an instance of " + StreamThread.class.getName());
log.error(ex.getMessage(), ex);
throw ex;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto here.

}

partitionGrouper = (PartitionGrouper) o;
partitionGrouper.partitionAssignor(this);
streamThread = (StreamThread) o;
streamThread.partitionGrouper.partitionAssignor(this);
}

@Override
Expand All @@ -63,38 +75,110 @@ public String name() {

@Override
public Subscription subscription(Set<String> topics) {
return new Subscription(new ArrayList<>(topics));
// Adds the following information to subscription
// 1. Client UUID (a unique id assigned to an instance of KafkaStreaming)
// 2. Task ids of previously running tasks
// 3. Task ids of valid local states on the client's state directory.

Set<TaskId> prevTasks = streamThread.prevTasks();
Set<TaskId> standbyTasks = streamThread.cachedTasks();
standbyTasks.removeAll(prevTasks);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can leave some comments here and in the assignor's logic, reusing the descriptions from the PR. For example, under which scenario would a client have checkpoints for tasks that have already been migrated away from it?

SubscriptionInfo data = new SubscriptionInfo(streamThread.clientUUID, prevTasks, standbyTasks);

return new Subscription(new ArrayList<>(topics), data.encode());
}

@Override
public Map<String, Assignment> assign(Cluster metadata, Map<String, Subscription> subscriptions) {
Map<TaskId, Set<TopicPartition>> partitionGroups = partitionGrouper.partitionGroups(metadata);
// This assigns tasks to consumer clients in two steps.
// 1. using TaskAssignor tasks are assigned to streaming clients.
// - Assign a task to a client which was running it previously.
// If there is no such client, assign a task to a client which has its valid local state.
// - A client may have more than one stream threads.
// The assignor tries to assign tasks to a client proportionally to the number of threads.
// - We try not to assign the same set of tasks to two different clients
// We do the assignment in one-pass. The result may not satisfy above all.
// 2. within each client, tasks are assigned to consumer clients in round-robin manner.

Map<UUID, Set<String>> consumersByClient = new HashMap<>();
Map<UUID, ClientState<TaskId>> states = new HashMap<>();

// Decode subscription info
for (Map.Entry<String, Subscription> entry : subscriptions.entrySet()) {
String consumerId = entry.getKey();
Subscription subscription = entry.getValue();

SubscriptionInfo info = SubscriptionInfo.decode(subscription.userData());

Set<String> consumers = consumersByClient.get(info.clientUUID);
if (consumers == null) {
consumers = new HashSet<>();
consumersByClient.put(info.clientUUID, consumers);
}
consumers.add(consumerId);

ClientState<TaskId> state = states.get(info.clientUUID);
if (state == null) {
state = new ClientState<>();
states.put(info.clientUUID, state);
}

state.prevActiveTasks.addAll(info.prevTasks);
state.prevAssignedTasks.addAll(info.prevTasks);
state.prevAssignedTasks.addAll(info.standbyTasks);
state.capacity = state.capacity + 1d;
}

String[] clientIds = subscriptions.keySet().toArray(new String[subscriptions.size()]);
TaskId[] taskIds = partitionGroups.keySet().toArray(new TaskId[partitionGroups.size()]);
// Get partition groups from the partition grouper
Map<TaskId, Set<TopicPartition>> partitionGroups = streamThread.partitionGrouper.partitionGroups(metadata);

states = TaskAssignor.assign(states, partitionGroups.keySet(), 0); // TODO: enable standby tasks
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add some comments explaining the two-level assignment algorithm? 1) load-based task-id-to-client assignment, and then 2) within each client, round-robin task-id-to-consumer assignment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

Map<String, Assignment> assignment = new HashMap<>();

for (int i = 0; i < clientIds.length; i++) {
List<TopicPartition> partitions = new ArrayList<>();
List<TaskId> ids = new ArrayList<>();
for (int j = i; j < taskIds.length; j += clientIds.length) {
TaskId taskId = taskIds[j];
for (TopicPartition partition : partitionGroups.get(taskId)) {
partitions.add(partition);
ids.add(taskId);
}
for (Map.Entry<UUID, Set<String>> entry : consumersByClient.entrySet()) {
UUID uuid = entry.getKey();
Set<String> consumers = entry.getValue();
ClientState<TaskId> state = states.get(uuid);

ArrayList<TaskId> taskIds = new ArrayList<>(state.assignedTasks.size());
final int numActiveTasks = state.activeTasks.size();
for (TaskId id : state.activeTasks) {
taskIds.add(id);
}
ByteBuffer buf = ByteBuffer.allocate(4 + ids.size() * 8);
//version
buf.putInt(1);
// encode task ids
for (TaskId id : ids) {
buf.putInt(id.topicGroupId);
buf.putInt(id.partition);
for (TaskId id : state.assignedTasks) {
if (!state.activeTasks.contains(id))
taskIds.add(id);
}

final int numConsumers = consumers.size();
List<TaskId> active = new ArrayList<>();
Set<TaskId> standby = new HashSet<>();

int i = 0;
for (String consumer : consumers) {
List<TopicPartition> partitions = new ArrayList<>();

final int numTaskIds = taskIds.size();
for (int j = i; j < numTaskIds; j += numConsumers) {
TaskId taskId = taskIds.get(j);
if (j < numActiveTasks) {
for (TopicPartition partition : partitionGroups.get(taskId)) {
partitions.add(partition);
active.add(taskId);
}
} else {
// no partition to a standby task
standby.add(taskId);
}
}

AssignmentInfo data = new AssignmentInfo(active, standby);
assignment.put(consumer, new Assignment(partitions, data.encode()));
i++;

active.clear();
standby.clear();
}
buf.rewind();
assignment.put(clientIds[i], new Assignment(partitions, buf));
}

return assignment;
Expand All @@ -103,27 +187,29 @@ public Map<String, Assignment> assign(Cluster metadata, Map<String, Subscription
@Override
public void onAssignment(Assignment assignment) {
List<TopicPartition> partitions = assignment.partitions();
ByteBuffer data = assignment.userData();
data.rewind();

AssignmentInfo info = AssignmentInfo.decode(assignment.userData());
this.standbyTasks = info.standbyTasks;

Map<TopicPartition, Set<TaskId>> partitionToTaskIds = new HashMap<>();
Iterator<TaskId> iter = info.activeTasks.iterator();
for (TopicPartition partition : partitions) {
Set<TaskId> taskIds = partitionToTaskIds.get(partition);
if (taskIds == null) {
taskIds = new HashSet<>();
partitionToTaskIds.put(partition, taskIds);
}

// check version
int version = data.getInt();
if (version == 1) {
for (TopicPartition partition : partitions) {
Set<TaskId> taskIds = partitionToTaskIds.get(partition);
if (taskIds == null) {
taskIds = new HashSet<>();
partitionToTaskIds.put(partition, taskIds);
}
// decode a task id
taskIds.add(new TaskId(data.getInt(), data.getInt()));
if (iter.hasNext()) {
taskIds.add(iter.next());
} else {
TaskAssignmentException ex = new TaskAssignmentException(
"failed to find a task id for the partition=" + partition.toString() +
", partitions=" + partitions.size() + ", assignmentInfo=" + info.toString()
);
log.error(ex.getMessage(), ex);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

throw ex;
}
} else {
KafkaException ex = new KafkaException("unknown assignment data version: " + version);
log.error(ex.getMessage(), ex);
throw ex;
}
this.partitionToTaskIds = partitionToTaskIds;
}
Expand All @@ -132,4 +218,7 @@ public Set<TaskId> taskIds(TopicPartition partition) {
return partitionToTaskIds.get(partition);
}

/**
 * Returns the standby tasks held in this assignor's standbyTasks field.
 * The field is assigned in onAssignment from the decoded AssignmentInfo and has
 * no initializer, so this returns null until an assignment has been received.
 *
 * @return the current standby task set, or null if none has been assigned yet
 */
public Set<TaskId> standbyTasks() {
return standbyTasks;
}
}