Skip to content

Commit

Permalink
[FLINK-8516] [kinesis] Allow for custom hash function for shard to su…
Browse files Browse the repository at this point in the history
…btask mapping in Kinesis consumer

This closes #5393.
  • Loading branch information
tweise authored and tzulitai committed Feb 15, 2018
1 parent 35ee062 commit 942649e
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 13 deletions.
5 changes: 5 additions & 0 deletions docs/dev/connectors/kinesis.md
Expand Up @@ -119,6 +119,11 @@ then some consumer subtasks will simply be idle and wait until it gets assigned
new shards (i.e., when the streams are resharded to increase the new shards (i.e., when the streams are resharded to increase the
number of shards for higher provisioned Kinesis service throughput). number of shards for higher provisioned Kinesis service throughput).


Also note that the assignment of shards to subtasks may not be optimal when
shard IDs are not consecutive (as a result of dynamic re-sharding in Kinesis).
For cases where skew in the assignment leads to significantly imbalanced consumption,
a custom implementation of `KinesisShardAssigner` can be set on the consumer.

### Configuring Starting Position ### Configuring Starting Position


The Flink Kinesis Consumer currently provides the following options to configure where to start reading Kinesis streams, simply by setting `ConsumerConfigConstants.STREAM_INITIAL_POSITION` to The Flink Kinesis Consumer currently provides the following options to configure where to start reading Kinesis streams, simply by setting `ConsumerConfigConstants.STREAM_INITIAL_POSITION` to
Expand Down
Expand Up @@ -24,6 +24,7 @@
import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.ClosureCleaner;
import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable; import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.apache.flink.api.java.typeutils.TupleTypeInfo;
Expand Down Expand Up @@ -68,6 +69,15 @@
* low-level control on the management of stream state. The Flink Kinesis Connector also supports setting the initial * low-level control on the management of stream state. The Flink Kinesis Connector also supports setting the initial
* starting points of Kinesis streams, namely TRIM_HORIZON and LATEST.</p> * starting points of Kinesis streams, namely TRIM_HORIZON and LATEST.</p>
* *
* <p>Kinesis and the Flink consumer support dynamic re-sharding, and shard IDs, while sequential,
* cannot be assumed to be consecutive. There is no perfect generic default assignment function.
* The default shard to subtask assignment, which is based on hash code, may result in skew,
* with some subtasks having many shards assigned and others none.
*
* <p>It is recommended to monitor the shard distribution and adjust assignment appropriately.
* A custom assigner implementation can be set via {@link #setShardAssigner(KinesisShardAssigner)} to optimize the
* hash function or use static overrides to limit skew.
*
* @param <T> the type of data emitted * @param <T> the type of data emitted
*/ */
@PublicEvolving @PublicEvolving
Expand All @@ -93,6 +103,11 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
/** User supplied deserialization schema to convert Kinesis byte messages to Flink objects. */ /** User supplied deserialization schema to convert Kinesis byte messages to Flink objects. */
private final KinesisDeserializationSchema<T> deserializer; private final KinesisDeserializationSchema<T> deserializer;


/**
* The function that determines which subtask a shard should be assigned to.
*/
private KinesisShardAssigner shardAssigner = KinesisDataFetcher.DEFAULT_SHARD_ASSIGNER;

// ------------------------------------------------------------------------ // ------------------------------------------------------------------------
// Runtime state // Runtime state
// ------------------------------------------------------------------------ // ------------------------------------------------------------------------
Expand Down Expand Up @@ -192,6 +207,19 @@ public FlinkKinesisConsumer(List<String> streams, KinesisDeserializationSchema<T
} }
} }


/**
 * Returns the function that determines which subtask a shard should be assigned to.
 *
 * @return the shard assigner currently held by this consumer
 */
public KinesisShardAssigner getShardAssigner() {
	return this.shardAssigner;
}

/**
 * Provide a custom assigner to influence how shards are distributed over subtasks.
 *
 * <p>The argument is validated before it is stored, so a rejected assigner
 * leaves the previously configured one in place.
 *
 * @param shardAssigner the shard assigner to use; must not be null
 */
public void setShardAssigner(KinesisShardAssigner shardAssigner) {
	checkNotNull(shardAssigner, "function can not be null");
	// Run the closure cleaner (with its check flag enabled) before touching the field,
	// so that a failure here does not leave a rejected assigner set on the consumer.
	ClosureCleaner.clean(shardAssigner, true);
	this.shardAssigner = shardAssigner;
}

// ------------------------------------------------------------------------ // ------------------------------------------------------------------------
// Source life cycle // Source life cycle
// ------------------------------------------------------------------------ // ------------------------------------------------------------------------
Expand Down Expand Up @@ -351,9 +379,11 @@ public void snapshotState(FunctionSnapshotContext context) throws Exception {
for (Map.Entry<StreamShardMetadata.EquivalenceWrapper, SequenceNumber> entry : sequenceNumsToRestore.entrySet()) { for (Map.Entry<StreamShardMetadata.EquivalenceWrapper, SequenceNumber> entry : sequenceNumsToRestore.entrySet()) {
// sequenceNumsToRestore is the restored global union state; // sequenceNumsToRestore is the restored global union state;
// should only snapshot shards that actually belong to us // should only snapshot shards that actually belong to us

int hashCode = shardAssigner.assign(
KinesisDataFetcher.convertToStreamShardHandle(entry.getKey().getShardMetadata()),
getRuntimeContext().getNumberOfParallelSubtasks());
if (KinesisDataFetcher.isThisSubtaskShouldSubscribeTo( if (KinesisDataFetcher.isThisSubtaskShouldSubscribeTo(
KinesisDataFetcher.convertToStreamShardHandle(entry.getKey().getShardMetadata()), hashCode,
getRuntimeContext().getNumberOfParallelSubtasks(), getRuntimeContext().getNumberOfParallelSubtasks(),
getRuntimeContext().getIndexOfThisSubtask())) { getRuntimeContext().getIndexOfThisSubtask())) {


Expand Down Expand Up @@ -384,7 +414,7 @@ protected KinesisDataFetcher<T> createFetcher(
Properties configProps, Properties configProps,
KinesisDeserializationSchema<T> deserializationSchema) { KinesisDeserializationSchema<T> deserializationSchema) {


return new KinesisDataFetcher<>(streams, sourceContext, runtimeContext, configProps, deserializationSchema); return new KinesisDataFetcher<>(streams, sourceContext, runtimeContext, configProps, deserializationSchema, shardAssigner);
} }


@VisibleForTesting @VisibleForTesting
Expand Down
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.streaming.connectors.kinesis;

import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.streaming.connectors.kinesis.model.StreamShardHandle;

import java.io.Serializable;

/**
 * Utility to map Kinesis shards to Flink subtask indices. Users can implement this interface to optimize
 * distribution of shards over subtasks. See {@link #assign(StreamShardHandle, int)} for details.
 *
 * <p>The interface declares a single abstract method, so it can be implemented with a lambda expression.
 */
@PublicEvolving
@FunctionalInterface
public interface KinesisShardAssigner extends Serializable {

	/**
	 * Returns the index of the target subtask that a specific shard should be
	 * assigned to. For return values outside the subtask range, modulus operation will
	 * be applied automatically, hence it is also valid to just return a hash code.
	 *
	 * <p>The resulting distribution of shards should have the following contract:
	 * <ul>
	 *   <li>Uniform distribution across subtasks.</li>
	 *   <li>Deterministic: calls for a given shard always return the same index.</li>
	 * </ul>
	 *
	 * <p>The above contract is crucial and cannot be broken. Consumer subtasks rely on this
	 * contract to filter out shards that they should not subscribe to, guaranteeing
	 * that each shard of a stream will always be assigned to one subtask in a
	 * uniformly distributed manner.
	 *
	 * @param shard the shard to determine
	 * @param numParallelSubtasks total number of subtasks
	 * @return target index, if index falls outside of the range, modulus operation will be applied
	 */
	int assign(StreamShardHandle shard, int numParallelSubtasks);
}
Expand Up @@ -22,6 +22,7 @@
import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.metrics.MetricGroup; import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.connectors.kinesis.KinesisShardAssigner;
import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants; import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
import org.apache.flink.streaming.connectors.kinesis.metrics.KinesisConsumerMetricConstants; import org.apache.flink.streaming.connectors.kinesis.metrics.KinesisConsumerMetricConstants;
import org.apache.flink.streaming.connectors.kinesis.metrics.ShardMetricsReporter; import org.apache.flink.streaming.connectors.kinesis.metrics.ShardMetricsReporter;
Expand Down Expand Up @@ -78,6 +79,8 @@
@Internal @Internal
public class KinesisDataFetcher<T> { public class KinesisDataFetcher<T> {


public static final KinesisShardAssigner DEFAULT_SHARD_ASSIGNER = (shard, subtasks) -> shard.hashCode();

private static final Logger LOG = LoggerFactory.getLogger(KinesisDataFetcher.class); private static final Logger LOG = LoggerFactory.getLogger(KinesisDataFetcher.class);


// ------------------------------------------------------------------------ // ------------------------------------------------------------------------
Expand All @@ -97,6 +100,11 @@ public class KinesisDataFetcher<T> {
*/ */
private final KinesisDeserializationSchema<T> deserializationSchema; private final KinesisDeserializationSchema<T> deserializationSchema;


/**
* The function that determines which subtask a shard should be assigned to.
*/
private final KinesisShardAssigner shardAssigner;

// ------------------------------------------------------------------------ // ------------------------------------------------------------------------
// Consumer metrics // Consumer metrics
// ------------------------------------------------------------------------ // ------------------------------------------------------------------------
Expand Down Expand Up @@ -184,13 +192,15 @@ public KinesisDataFetcher(List<String> streams,
SourceFunction.SourceContext<T> sourceContext, SourceFunction.SourceContext<T> sourceContext,
RuntimeContext runtimeContext, RuntimeContext runtimeContext,
Properties configProps, Properties configProps,
KinesisDeserializationSchema<T> deserializationSchema) { KinesisDeserializationSchema<T> deserializationSchema,
KinesisShardAssigner shardAssigner) {
this(streams, this(streams,
sourceContext, sourceContext,
sourceContext.getCheckpointLock(), sourceContext.getCheckpointLock(),
runtimeContext, runtimeContext,
configProps, configProps,
deserializationSchema, deserializationSchema,
shardAssigner,
new AtomicReference<>(), new AtomicReference<>(),
new ArrayList<>(), new ArrayList<>(),
createInitialSubscribedStreamsToLastDiscoveredShardsState(streams), createInitialSubscribedStreamsToLastDiscoveredShardsState(streams),
Expand All @@ -204,6 +214,7 @@ protected KinesisDataFetcher(List<String> streams,
RuntimeContext runtimeContext, RuntimeContext runtimeContext,
Properties configProps, Properties configProps,
KinesisDeserializationSchema<T> deserializationSchema, KinesisDeserializationSchema<T> deserializationSchema,
KinesisShardAssigner shardAssigner,
AtomicReference<Throwable> error, AtomicReference<Throwable> error,
List<KinesisStreamShardState> subscribedShardsState, List<KinesisStreamShardState> subscribedShardsState,
HashMap<String, String> subscribedStreamsToLastDiscoveredShardIds, HashMap<String, String> subscribedStreamsToLastDiscoveredShardIds,
Expand All @@ -216,6 +227,7 @@ protected KinesisDataFetcher(List<String> streams,
this.totalNumberOfConsumerSubtasks = runtimeContext.getNumberOfParallelSubtasks(); this.totalNumberOfConsumerSubtasks = runtimeContext.getNumberOfParallelSubtasks();
this.indexOfThisConsumerSubtask = runtimeContext.getIndexOfThisSubtask(); this.indexOfThisConsumerSubtask = runtimeContext.getIndexOfThisSubtask();
this.deserializationSchema = checkNotNull(deserializationSchema); this.deserializationSchema = checkNotNull(deserializationSchema);
this.shardAssigner = checkNotNull(shardAssigner);
this.kinesis = checkNotNull(kinesis); this.kinesis = checkNotNull(kinesis);


this.consumerMetricGroup = runtimeContext.getMetricGroup() this.consumerMetricGroup = runtimeContext.getMetricGroup()
Expand Down Expand Up @@ -453,7 +465,8 @@ public List<StreamShardHandle> discoverNewShardsToSubscribe() throws Interrupted
for (String stream : streamsWithNewShards) { for (String stream : streamsWithNewShards) {
List<StreamShardHandle> newShardsOfStream = shardListResult.getRetrievedShardListOfStream(stream); List<StreamShardHandle> newShardsOfStream = shardListResult.getRetrievedShardListOfStream(stream);
for (StreamShardHandle newShard : newShardsOfStream) { for (StreamShardHandle newShard : newShardsOfStream) {
if (isThisSubtaskShouldSubscribeTo(newShard, totalNumberOfConsumerSubtasks, indexOfThisConsumerSubtask)) { int hashCode = shardAssigner.assign(newShard, totalNumberOfConsumerSubtasks);
if (isThisSubtaskShouldSubscribeTo(hashCode, totalNumberOfConsumerSubtasks, indexOfThisConsumerSubtask)) {
newShardsToSubscribe.add(newShard); newShardsToSubscribe.add(newShard);
} }
} }
Expand Down Expand Up @@ -596,14 +609,14 @@ private static ShardMetricsReporter registerShardMetrics(MetricGroup metricGroup
/** /**
* Utility function to determine whether a shard should be subscribed by this consumer subtask. * Utility function to determine whether a shard should be subscribed by this consumer subtask.
* *
* @param shard the shard to determine * @param shardHash hash code for the shard
* @param totalNumberOfConsumerSubtasks total number of consumer subtasks * @param totalNumberOfConsumerSubtasks total number of consumer subtasks
* @param indexOfThisConsumerSubtask index of this consumer subtask * @param indexOfThisConsumerSubtask index of this consumer subtask
*/ */
public static boolean isThisSubtaskShouldSubscribeTo(StreamShardHandle shard, public static boolean isThisSubtaskShouldSubscribeTo(int shardHash,
int totalNumberOfConsumerSubtasks, int totalNumberOfConsumerSubtasks,
int indexOfThisConsumerSubtask) { int indexOfThisConsumerSubtask) {
return (Math.abs(shard.hashCode() % totalNumberOfConsumerSubtasks)) == indexOfThisConsumerSubtask; return (Math.abs(shardHash % totalNumberOfConsumerSubtasks)) == indexOfThisConsumerSubtask;
} }


@VisibleForTesting @VisibleForTesting
Expand Down
Expand Up @@ -17,7 +17,6 @@


package org.apache.flink.streaming.connectors.kinesis; package org.apache.flink.streaming.connectors.kinesis;


import com.amazonaws.services.kinesis.model.SequenceNumberRange;
import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.serialization.SimpleStringSchema; import org.apache.flink.api.common.serialization.SimpleStringSchema;
import org.apache.flink.core.testutils.OneShotLatch; import org.apache.flink.core.testutils.OneShotLatch;
Expand All @@ -42,8 +41,8 @@
import org.apache.flink.streaming.util.migration.MigrationTestUtil; import org.apache.flink.streaming.util.migration.MigrationTestUtil;
import org.apache.flink.streaming.util.migration.MigrationVersion; import org.apache.flink.streaming.util.migration.MigrationVersion;


import com.amazonaws.services.kinesis.model.SequenceNumberRange;
import com.amazonaws.services.kinesis.model.Shard; import com.amazonaws.services.kinesis.model.Shard;

import org.junit.Ignore; import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
Expand Down Expand Up @@ -419,7 +418,7 @@ public TestFetcher(
HashMap<StreamShardMetadata, SequenceNumber> testStateSnapshot, HashMap<StreamShardMetadata, SequenceNumber> testStateSnapshot,
List<StreamShardHandle> testInitialDiscoveryShards) { List<StreamShardHandle> testInitialDiscoveryShards) {


super(streams, sourceContext, runtimeContext, configProps, deserializationSchema); super(streams, sourceContext, runtimeContext, configProps, deserializationSchema, DEFAULT_SHARD_ASSIGNER);


this.testStateSnapshot = testStateSnapshot; this.testStateSnapshot = testStateSnapshot;
this.testInitialDiscoveryShards = testInitialDiscoveryShards; this.testInitialDiscoveryShards = testInitialDiscoveryShards;
Expand Down
Expand Up @@ -22,6 +22,7 @@
import org.apache.flink.core.testutils.CheckedThread; import org.apache.flink.core.testutils.CheckedThread;
import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisConsumer; import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisConsumer;
import org.apache.flink.streaming.connectors.kinesis.KinesisShardAssigner;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState; import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber; import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
import org.apache.flink.streaming.connectors.kinesis.model.StreamShardHandle; import org.apache.flink.streaming.connectors.kinesis.model.StreamShardHandle;
Expand All @@ -33,13 +34,14 @@
import org.apache.flink.streaming.connectors.kinesis.testutils.TestSourceContext; import org.apache.flink.streaming.connectors.kinesis.testutils.TestSourceContext;
import org.apache.flink.streaming.connectors.kinesis.testutils.TestUtils; import org.apache.flink.streaming.connectors.kinesis.testutils.TestUtils;
import org.apache.flink.streaming.connectors.kinesis.testutils.TestableKinesisDataFetcher; import org.apache.flink.streaming.connectors.kinesis.testutils.TestableKinesisDataFetcher;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.TestLogger;


import com.amazonaws.services.kinesis.model.HashKeyRange; import com.amazonaws.services.kinesis.model.HashKeyRange;
import com.amazonaws.services.kinesis.model.SequenceNumberRange; import com.amazonaws.services.kinesis.model.SequenceNumberRange;
import com.amazonaws.services.kinesis.model.Shard; import com.amazonaws.services.kinesis.model.Shard;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.TestLogger;
import org.junit.Test; import org.junit.Test;
import org.powermock.reflect.Whitebox;


import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
Expand All @@ -53,6 +55,7 @@
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;


import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -644,4 +647,67 @@ public RuntimeContext getRuntimeContext() {
return context; return context;
} }
} }


// ----------------------------------------------------------------------
// Tests shard distribution with custom hash function
// ----------------------------------------------------------------------

@Test
public void testShardToSubtaskMappingWithCustomHashFunction() throws Exception {

	final int parallelism = 10;
	final int numShards = 3;

	for (int i = 0; i < 2; i++) {

		// lambdas may only capture effectively final locals, hence the copy of the loop counter
		final int hash = i;
		final KinesisShardAssigner singleTargetAssigner = (shard, subtasks) -> hash;

		List<String> fakeStreams = new LinkedList<>();
		fakeStreams.add("fakeStream");
		Map<String, Integer> streamToShardCount = new HashMap<>();
		streamToShardCount.put("fakeStream", numShards);

		for (int subtaskIndex = 0; subtaskIndex < parallelism; subtaskIndex++) {

			// the fetcher is built without a custom assigner; the assigner under test is injected below
			final TestableKinesisDataFetcher fetcher =
				new TestableKinesisDataFetcher(
					fakeStreams,
					new TestSourceContext<>(),
					new Properties(),
					new KinesisDeserializationSchemaWrapper<>(new SimpleStringSchema()),
					parallelism,
					subtaskIndex,
					new AtomicReference<>(),
					new LinkedList<>(),
					KinesisDataFetcher.createInitialSubscribedStreamsToLastDiscoveredShardsState(fakeStreams),
					FakeKinesisBehavioursFactory.nonReshardedStreamsBehaviour(streamToShardCount));
			Whitebox.setInternalState(fetcher, "shardAssigner", singleTargetAssigner); // override hashing

			List<StreamShardHandle> discovered = fetcher.discoverNewShardsToSubscribe();
			fetcher.shutdownFetcher();

			// only the subtask whose index equals the fixed assignment should see the shards
			final int expectedCount = (subtaskIndex == hash) ? numShards : 0;
			assertEquals(
				String.format("for hash=%d, subtask=%d", hash, subtaskIndex),
				expectedCount,
				discovered.size());
		}
	}
}

/**
 * Verifies the mapping from an assigner result to the owning subtask index,
 * including negative values, which the fetcher folds into range via
 * {@code Math.abs(hash % totalSubtasks)}.
 */
@Test
public void testIsThisSubtaskShouldSubscribeTo() {
	// non-negative values: plain modulus
	assertTrue(KinesisDataFetcher.isThisSubtaskShouldSubscribeTo(0, 2, 0));
	assertFalse(KinesisDataFetcher.isThisSubtaskShouldSubscribeTo(1, 2, 0));
	assertTrue(KinesisDataFetcher.isThisSubtaskShouldSubscribeTo(2, 2, 0));
	assertFalse(KinesisDataFetcher.isThisSubtaskShouldSubscribeTo(0, 2, 1));
	assertTrue(KinesisDataFetcher.isThisSubtaskShouldSubscribeTo(1, 2, 1));
	assertFalse(KinesisDataFetcher.isThisSubtaskShouldSubscribeTo(2, 2, 1));

	// negative values (e.g. raw hash codes): the remainder is negative, its absolute value is the index
	assertTrue(KinesisDataFetcher.isThisSubtaskShouldSubscribeTo(-1, 2, 1));
	assertFalse(KinesisDataFetcher.isThisSubtaskShouldSubscribeTo(-1, 2, 0));
	assertTrue(KinesisDataFetcher.isThisSubtaskShouldSubscribeTo(-2, 2, 0));
	// Integer.MIN_VALUE is even, so the remainder is 0 and no abs() overflow can occur
	assertTrue(KinesisDataFetcher.isThisSubtaskShouldSubscribeTo(Integer.MIN_VALUE, 2, 0));
	assertFalse(KinesisDataFetcher.isThisSubtaskShouldSubscribeTo(Integer.MIN_VALUE, 2, 1));
}

} }
Expand Up @@ -68,6 +68,7 @@ public TestableKinesisDataFetcher(
getMockedRuntimeContext(fakeTotalCountOfSubtasks, fakeIndexOfThisSubtask), getMockedRuntimeContext(fakeTotalCountOfSubtasks, fakeIndexOfThisSubtask),
fakeConfiguration, fakeConfiguration,
deserializationSchema, deserializationSchema,
DEFAULT_SHARD_ASSIGNER,
thrownErrorUnderTest, thrownErrorUnderTest,
subscribedShardsStateUnderTest, subscribedShardsStateUnderTest,
subscribedStreamsToLastDiscoveredShardIdsStateUnderTest, subscribedStreamsToLastDiscoveredShardIdsStateUnderTest,
Expand Down

0 comments on commit 942649e

Please sign in to comment.