From 9ebec3a7242cd225626a9e17bad4e315f9cbab53 Mon Sep 17 00:00:00 2001 From: Tony Wei Date: Thu, 25 May 2017 10:39:22 +0800 Subject: [PATCH] [FLINK-6653] Avoid directly serializing AWS's Shard class in Kinesis consumer's checkpoints --- .../kinesis/FlinkKinesisConsumer.java | 71 ++++-- .../kinesis/internals/KinesisDataFetcher.java | 92 ++++++-- .../kinesis/internals/ShardConsumer.java | 8 +- .../kinesis/model/KinesisStreamShard.java | 5 +- .../model/KinesisStreamShardState.java | 21 +- .../kinesis/model/KinesisStreamShardV2.java | 171 +++++++++++++++ .../kinesis/model/StreamShardHandle.java | 129 +++++++++++ .../kinesis/proxy/GetShardListResult.java | 16 +- .../kinesis/proxy/KinesisProxy.java | 12 +- .../kinesis/proxy/KinesisProxyInterface.java | 4 +- .../FlinkKinesisConsumerMigrationTest.java | 7 +- .../kinesis/FlinkKinesisConsumerTest.java | 205 +++++++++++------- .../internals/KinesisDataFetcherTest.java | 111 +++++++--- .../kinesis/internals/ShardConsumerTest.java | 16 +- .../FakeKinesisBehavioursFactory.java | 22 +- 15 files changed, 698 insertions(+), 192 deletions(-) create mode 100644 flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/model/KinesisStreamShardV2.java create mode 100644 flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/model/StreamShardHandle.java diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java index 4982f7f39ab14..b7f550629ca63 100644 --- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java @@ -38,6 +38,8 @@ import 
org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState; import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber; import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber; +import org.apache.flink.streaming.connectors.kinesis.model.StreamShardHandle; +import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardV2; import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema; import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchemaWrapper; import org.apache.flink.streaming.connectors.kinesis.util.KinesisConfigUtil; @@ -98,7 +100,7 @@ public class FlinkKinesisConsumer extends RichParallelSourceFunction imple private transient KinesisDataFetcher fetcher; /** The sequence numbers to restore to upon restore from failure */ - private transient HashMap sequenceNumsToRestore; + private transient HashMap sequenceNumsToRestore; private volatile boolean running = true; @@ -109,7 +111,7 @@ public class FlinkKinesisConsumer extends RichParallelSourceFunction imple /** State name to access shard sequence number states; cannot be changed */ private static final String sequenceNumsStateStoreName = "Kinesis-Stream-Shard-State"; - private transient ListState> sequenceNumsStateForCheckpoint; + private transient ListState> sequenceNumsStateForCheckpoint; // ------------------------------------------------------------------------ // Constructors @@ -197,25 +199,26 @@ public void run(SourceContext sourceContext) throws Exception { KinesisDataFetcher fetcher = createFetcher(streams, sourceContext, getRuntimeContext(), configProps, deserializer); // initial discovery - List allShards = fetcher.discoverNewShardsToSubscribe(); + List allShards = fetcher.discoverNewShardsToSubscribe(); - for (KinesisStreamShard shard : allShards) { + for (StreamShardHandle shard : allShards) { + KinesisStreamShardV2 kinesisStreamShard = 
KinesisDataFetcher.createKinesisStreamShardV2(shard); if (sequenceNumsToRestore != null) { - if (sequenceNumsToRestore.containsKey(shard)) { + if (sequenceNumsToRestore.containsKey(kinesisStreamShard)) { // if the shard was already seen and is contained in the state, // just use the sequence number stored in the state fetcher.registerNewSubscribedShardState( - new KinesisStreamShardState(shard, sequenceNumsToRestore.get(shard))); + new KinesisStreamShardState(kinesisStreamShard, shard, sequenceNumsToRestore.get(kinesisStreamShard))); if (LOG.isInfoEnabled()) { LOG.info("Subtask {} is seeding the fetcher with restored shard {}," + " starting state set to the restored sequence number {}", - getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), sequenceNumsToRestore.get(shard)); + getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), sequenceNumsToRestore.get(kinesisStreamShard)); } } else { // the shard wasn't discovered in the previous run, therefore should be consumed from the beginning fetcher.registerNewSubscribedShardState( - new KinesisStreamShardState(shard, SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get())); + new KinesisStreamShardState(kinesisStreamShard, shard, SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get())); if (LOG.isInfoEnabled()) { LOG.info("Subtask {} is seeding the fetcher with new discovered shard {}," + @@ -231,7 +234,7 @@ public void run(SourceContext sourceContext) throws Exception { ConsumerConfigConstants.DEFAULT_STREAM_INITIAL_POSITION)).toSentinelSequenceNumber(); fetcher.registerNewSubscribedShardState( - new KinesisStreamShardState(shard, startingSeqNum.get())); + new KinesisStreamShardState(kinesisStreamShard, shard, startingSeqNum.get())); if (LOG.isInfoEnabled()) { LOG.info("Subtask {} will be seeded with initial shard {}, starting state set as sequence number {}", @@ -295,8 +298,8 @@ public TypeInformation getProducedType() { @Override public void initializeState(FunctionInitializationContext 
context) throws Exception { - TypeInformation> shardsStateTypeInfo = new TupleTypeInfo<>( - TypeInformation.of(KinesisStreamShard.class), + TypeInformation> shardsStateTypeInfo = new TupleTypeInfo<>( + TypeInformation.of(KinesisStreamShardV2.class), TypeInformation.of(SequenceNumber.class)); sequenceNumsStateForCheckpoint = context.getOperatorStateStore().getUnionListState( @@ -305,7 +308,7 @@ public void initializeState(FunctionInitializationContext context) throws Except if (context.isRestored()) { if (sequenceNumsToRestore == null) { sequenceNumsToRestore = new HashMap<>(); - for (Tuple2 kinesisSequenceNumber : sequenceNumsStateForCheckpoint.get()) { + for (Tuple2 kinesisSequenceNumber : sequenceNumsStateForCheckpoint.get()) { sequenceNumsToRestore.put(kinesisSequenceNumber.f0, kinesisSequenceNumber.f1); } @@ -330,12 +333,12 @@ public void snapshotState(FunctionSnapshotContext context) throws Exception { if (fetcher == null) { if (sequenceNumsToRestore != null) { - for (Map.Entry entry : sequenceNumsToRestore.entrySet()) { + for (Map.Entry entry : sequenceNumsToRestore.entrySet()) { // sequenceNumsToRestore is the restored global union state; // should only snapshot shards that actually belong to us if (KinesisDataFetcher.isThisSubtaskShouldSubscribeTo( - entry.getKey(), + KinesisDataFetcher.createStreamShardHandle(entry.getKey()), getRuntimeContext().getNumberOfParallelSubtasks(), getRuntimeContext().getIndexOfThisSubtask())) { @@ -344,14 +347,14 @@ public void snapshotState(FunctionSnapshotContext context) throws Exception { } } } else { - HashMap lastStateSnapshot = fetcher.snapshotState(); + HashMap lastStateSnapshot = fetcher.snapshotState(); if (LOG.isDebugEnabled()) { LOG.debug("Snapshotted state, last processed sequence numbers: {}, checkpoint id: {}, timestamp: {}", lastStateSnapshot.toString(), context.getCheckpointId(), context.getCheckpointTimestamp()); } - for (Map.Entry entry : lastStateSnapshot.entrySet()) { + for (Map.Entry entry : 
lastStateSnapshot.entrySet()) { sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue())); } } @@ -363,7 +366,14 @@ public void restoreState(HashMap restoredSta LOG.info("Subtask {} restoring offsets from an older Flink version: {}", getRuntimeContext().getIndexOfThisSubtask(), sequenceNumsToRestore); - sequenceNumsToRestore = restoredState.isEmpty() ? null : restoredState; + if (restoredState.isEmpty()) { + sequenceNumsToRestore = null; + } else { + sequenceNumsToRestore = new HashMap<>(); + for (Map.Entry kv: restoredState.entrySet()) { + sequenceNumsToRestore.put(createKinesisStreamShardV2(kv.getKey()), kv.getValue()); + } + } } /** This method is exposed for tests that need to mock the KinesisDataFetcher in the consumer. */ @@ -378,7 +388,32 @@ protected KinesisDataFetcher createFetcher( } @VisibleForTesting - HashMap getRestoredState() { + HashMap getRestoredState() { return sequenceNumsToRestore; } + + /** + * Utility function to convert {@link KinesisStreamShard} into {@link KinesisStreamShardV2} + * + * @param kinesisStreamShard the {@link KinesisStreamShard} to be converted + * @return a {@link KinesisStreamShardV2} object + */ + public static KinesisStreamShardV2 createKinesisStreamShardV2(KinesisStreamShard kinesisStreamShard) { + KinesisStreamShardV2 kinesisStreamShardV2 = new KinesisStreamShardV2(); + + kinesisStreamShardV2.setStreamName(kinesisStreamShard.getStreamName()); + kinesisStreamShardV2.setShardId(kinesisStreamShard.getShard().getShardId()); + kinesisStreamShardV2.setParentShardId(kinesisStreamShard.getShard().getParentShardId()); + kinesisStreamShardV2.setAdjacentParentShardId(kinesisStreamShard.getShard().getAdjacentParentShardId()); + if (kinesisStreamShard.getShard().getHashKeyRange() != null) { + kinesisStreamShardV2.setStartingHashKey(kinesisStreamShard.getShard().getHashKeyRange().getStartingHashKey()); + kinesisStreamShardV2.setEndingHashKey(kinesisStreamShard.getShard().getHashKeyRange().getEndingHashKey()); + } 
+ if (kinesisStreamShard.getShard().getSequenceNumberRange() != null) { + kinesisStreamShardV2.setStartingSequenceNumber(kinesisStreamShard.getShard().getSequenceNumberRange().getStartingSequenceNumber()); + kinesisStreamShardV2.setEndingSequenceNumber(kinesisStreamShard.getShard().getSequenceNumberRange().getEndingSequenceNumber()); + } + + return kinesisStreamShardV2; + } } diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java index 99305cb435e3e..b0dceec07f60b 100644 --- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java @@ -17,13 +17,17 @@ package org.apache.flink.streaming.connectors.kinesis.internals; +import com.amazonaws.services.kinesis.model.HashKeyRange; +import com.amazonaws.services.kinesis.model.SequenceNumberRange; +import com.amazonaws.services.kinesis.model.Shard; import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants; -import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard; +import org.apache.flink.streaming.connectors.kinesis.model.StreamShardHandle; import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState; import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber; import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber; +import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardV2; import 
org.apache.flink.streaming.connectors.kinesis.proxy.GetShardListResult; import org.apache.flink.streaming.connectors.kinesis.proxy.KinesisProxy; import org.apache.flink.streaming.connectors.kinesis.proxy.KinesisProxyInterface; @@ -259,7 +263,7 @@ public void runFetcher() throws Exception { if (LOG.isInfoEnabled()) { LOG.info("Subtask {} will start consuming seeded shard {} from sequence number {} with ShardConsumer {}", - indexOfThisConsumerSubtask, seededShardState.getKinesisStreamShard().toString(), + indexOfThisConsumerSubtask, seededShardState.getStreamShardHandle().toString(), seededShardState.getLastProcessedSequenceNum(), seededStateIndex); } @@ -267,7 +271,7 @@ public void runFetcher() throws Exception { new ShardConsumer<>( this, seededStateIndex, - subscribedShardsState.get(seededStateIndex).getKinesisStreamShard(), + subscribedShardsState.get(seededStateIndex).getStreamShardHandle(), subscribedShardsState.get(seededStateIndex).getLastProcessedSequenceNum())); } } @@ -293,19 +297,19 @@ public void runFetcher() throws Exception { LOG.debug("Subtask {} is trying to discover new shards that were created due to resharding ...", indexOfThisConsumerSubtask); } - List newShardsDueToResharding = discoverNewShardsToSubscribe(); + List newShardsDueToResharding = discoverNewShardsToSubscribe(); - for (KinesisStreamShard shard : newShardsDueToResharding) { + for (StreamShardHandle shard : newShardsDueToResharding) { // since there may be delay in discovering a new shard, all new shards due to // resharding should be read starting from the earliest record possible KinesisStreamShardState newShardState = - new KinesisStreamShardState(shard, SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get()); + new KinesisStreamShardState(createKinesisStreamShardV2(shard), shard, SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get()); int newStateIndex = registerNewSubscribedShardState(newShardState); if (LOG.isInfoEnabled()) { LOG.info("Subtask {} has discovered a new 
shard {} due to resharding, and will start consuming " + "the shard from sequence number {} with ShardConsumer {}", - indexOfThisConsumerSubtask, newShardState.getKinesisStreamShard().toString(), + indexOfThisConsumerSubtask, newShardState.getStreamShardHandle().toString(), newShardState.getLastProcessedSequenceNum(), newStateIndex); } @@ -313,7 +317,7 @@ public void runFetcher() throws Exception { new ShardConsumer<>( this, newStateIndex, - newShardState.getKinesisStreamShard(), + newShardState.getStreamShardHandle(), newShardState.getLastProcessedSequenceNum())); } @@ -349,11 +353,11 @@ public void runFetcher() throws Exception { * * @return state snapshot */ - public HashMap snapshotState() { + public HashMap snapshotState() { // this method assumes that the checkpoint lock is held assert Thread.holdsLock(checkpointLock); - HashMap stateSnapshot = new HashMap<>(); + HashMap stateSnapshot = new HashMap<>(); for (KinesisStreamShardState shardWithState : subscribedShardsState) { stateSnapshot.put(shardWithState.getKinesisStreamShard(), shardWithState.getLastProcessedSequenceNum()); } @@ -405,7 +409,7 @@ public void advanceLastDiscoveredShardOfStream(String stream, String shardId) { if (lastSeenShardIdOfStream == null) { // if not previously set, simply put as the last seen shard id this.subscribedStreamsToLastDiscoveredShardIds.put(stream, shardId); - } else if (KinesisStreamShard.compareShardIds(shardId, lastSeenShardIdOfStream) > 0) { + } else if (StreamShardHandle.compareShardIds(shardId, lastSeenShardIdOfStream) > 0) { this.subscribedStreamsToLastDiscoveredShardIds.put(stream, shardId); } } @@ -419,17 +423,17 @@ public void advanceLastDiscoveredShardOfStream(String stream, String shardId) { * 3. 
Update the subscribedStreamsToLastDiscoveredShardIds state so that we won't get shards * that we have already seen before the next time this function is called */ - public List discoverNewShardsToSubscribe() throws InterruptedException { + public List discoverNewShardsToSubscribe() throws InterruptedException { - List newShardsToSubscribe = new LinkedList<>(); + List newShardsToSubscribe = new LinkedList<>(); GetShardListResult shardListResult = kinesis.getShardList(subscribedStreamsToLastDiscoveredShardIds); if (shardListResult.hasRetrievedShards()) { Set streamsWithNewShards = shardListResult.getStreamsWithRetrievedShards(); for (String stream : streamsWithNewShards) { - List newShardsOfStream = shardListResult.getRetrievedShardListOfStream(stream); - for (KinesisStreamShard newShard : newShardsOfStream) { + List newShardsOfStream = shardListResult.getRetrievedShardListOfStream(stream); + for (StreamShardHandle newShard : newShardsOfStream) { if (isThisSubtaskShouldSubscribeTo(newShard, totalNumberOfConsumerSubtasks, indexOfThisConsumerSubtask)) { newShardsToSubscribe.add(newShard); } @@ -502,7 +506,7 @@ protected void updateState(int shardStateIndex, SequenceNumber lastSequenceNumbe // we've finished reading the shard and should determine it to be non-active if (lastSequenceNumber.equals(SentinelSequenceNumber.SENTINEL_SHARD_ENDING_SEQUENCE_NUM.get())) { LOG.info("Subtask {} has reached the end of subscribed shard: {}", - indexOfThisConsumerSubtask, subscribedShardsState.get(shardStateIndex).getKinesisStreamShard()); + indexOfThisConsumerSubtask, subscribedShardsState.get(shardStateIndex).getStreamShardHandle()); // check if we need to mark the source as idle; // note that on resharding, if registerNewSubscribedShardState was invoked for newly discovered shards @@ -549,7 +553,7 @@ public int registerNewSubscribedShardState(KinesisStreamShardState newSubscribed * @param totalNumberOfConsumerSubtasks total number of consumer subtasks * @param 
indexOfThisConsumerSubtask index of this consumer subtask */ - public static boolean isThisSubtaskShouldSubscribeTo(KinesisStreamShard shard, + public static boolean isThisSubtaskShouldSubscribeTo(StreamShardHandle shard, int totalNumberOfConsumerSubtasks, int indexOfThisConsumerSubtask) { return (Math.abs(shard.hashCode() % totalNumberOfConsumerSubtasks)) == indexOfThisConsumerSubtask; @@ -582,4 +586,58 @@ protected static HashMap createInitialSubscribedStreamsToLastDis } return initial; } + + /** + * Utility function to convert {@link StreamShardHandle} into {@link KinesisStreamShardV2} + * + * @param streamShardHandle the {@link StreamShardHandle} to be converted + * @return a {@link KinesisStreamShardV2} object + */ + public static KinesisStreamShardV2 createKinesisStreamShardV2(StreamShardHandle streamShardHandle) { + KinesisStreamShardV2 kinesisStreamShardV2 = new KinesisStreamShardV2(); + + kinesisStreamShardV2.setStreamName(streamShardHandle.getStreamName()); + kinesisStreamShardV2.setShardId(streamShardHandle.getShard().getShardId()); + kinesisStreamShardV2.setParentShardId(streamShardHandle.getShard().getParentShardId()); + kinesisStreamShardV2.setAdjacentParentShardId(streamShardHandle.getShard().getAdjacentParentShardId()); + if (streamShardHandle.getShard().getHashKeyRange() != null) { + kinesisStreamShardV2.setStartingHashKey(streamShardHandle.getShard().getHashKeyRange().getStartingHashKey()); + kinesisStreamShardV2.setEndingHashKey(streamShardHandle.getShard().getHashKeyRange().getEndingHashKey()); + } + if (streamShardHandle.getShard().getSequenceNumberRange() != null) { + kinesisStreamShardV2.setStartingSequenceNumber(streamShardHandle.getShard().getSequenceNumberRange().getStartingSequenceNumber()); + kinesisStreamShardV2.setEndingSequenceNumber(streamShardHandle.getShard().getSequenceNumberRange().getEndingSequenceNumber()); + } + + return kinesisStreamShardV2; + } + + /** + * Utility function to convert {@link KinesisStreamShardV2} into {@link 
StreamShardHandle} + * + * @param kinesisStreamShard the {@link KinesisStreamShardV2} to be converted + * @return a {@link StreamShardHandle} object; each range is rebuilt when either of its bounds is set, because an open shard has a starting but no ending sequence number + */ + public static StreamShardHandle createStreamShardHandle(KinesisStreamShardV2 kinesisStreamShard) { + Shard shard = new Shard(); + shard.withShardId(kinesisStreamShard.getShardId()); + shard.withParentShardId(kinesisStreamShard.getParentShardId()); + shard.withAdjacentParentShardId(kinesisStreamShard.getAdjacentParentShardId()); + + if (kinesisStreamShard.getStartingHashKey() != null || kinesisStreamShard.getEndingHashKey() != null) { + HashKeyRange hashKeyRange = new HashKeyRange(); + hashKeyRange.withStartingHashKey(kinesisStreamShard.getStartingHashKey()); + hashKeyRange.withEndingHashKey(kinesisStreamShard.getEndingHashKey()); + shard.withHashKeyRange(hashKeyRange); + } + + if (kinesisStreamShard.getStartingSequenceNumber() != null || kinesisStreamShard.getEndingSequenceNumber() != null) { + SequenceNumberRange sequenceNumberRange = new SequenceNumberRange(); + sequenceNumberRange.withStartingSequenceNumber(kinesisStreamShard.getStartingSequenceNumber()); + sequenceNumberRange.withEndingSequenceNumber(kinesisStreamShard.getEndingSequenceNumber()); + shard.withSequenceNumberRange(sequenceNumberRange); + } + + return new StreamShardHandle(kinesisStreamShard.getStreamName(), shard); + } } diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumer.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumer.java index ca85854ea99fa..a724b49bb02f3 100644 --- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumer.java +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumer.java @@ -24,7 +24,7 @@ import 
com.amazonaws.services.kinesis.model.ShardIteratorType; import org.apache.flink.streaming.api.TimeCharacteristic; import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants; -import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard; +import org.apache.flink.streaming.connectors.kinesis.model.StreamShardHandle; import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber; import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber; import org.apache.flink.streaming.connectors.kinesis.proxy.KinesisProxyInterface; @@ -60,7 +60,7 @@ public class ShardConsumer implements Runnable { private final KinesisDataFetcher fetcherRef; - private final KinesisStreamShard subscribedShard; + private final StreamShardHandle subscribedShard; private final int maxNumberOfRecordsPerFetch; private final long fetchIntervalMillis; @@ -79,7 +79,7 @@ public class ShardConsumer implements Runnable { */ public ShardConsumer(KinesisDataFetcher fetcherRef, Integer subscribedShardStateIndex, - KinesisStreamShard subscribedShard, + StreamShardHandle subscribedShard, SequenceNumber lastSequenceNum) { this(fetcherRef, subscribedShardStateIndex, @@ -91,7 +91,7 @@ public ShardConsumer(KinesisDataFetcher fetcherRef, /** This constructor is exposed for testing purposes */ protected ShardConsumer(KinesisDataFetcher fetcherRef, Integer subscribedShardStateIndex, - KinesisStreamShard subscribedShard, + StreamShardHandle subscribedShard, SequenceNumber lastSequenceNum, KinesisProxyInterface kinesis) { this.fetcherRef = checkNotNull(fetcherRef); diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/model/KinesisStreamShard.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/model/KinesisStreamShard.java index 53ed11b1b7d97..f3dcfe15cdbea 100644 --- 
a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/model/KinesisStreamShard.java +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/model/KinesisStreamShard.java @@ -24,9 +24,8 @@ import static org.apache.flink.util.Preconditions.checkNotNull; /** - * A serializable representation of a AWS Kinesis Stream shard. It is basically a wrapper class around the information - * provided along with {@link com.amazonaws.services.kinesis.model.Shard}, with some extra utility methods to - * determine whether or not a shard is closed and whether or not the shard is a result of parent shard splits or merges. + * A legacy serializable representation of a AWS Kinesis Stream shard. It is basically a wrapper class around the information + * provided along with {@link com.amazonaws.services.kinesis.model.Shard}. */ public class KinesisStreamShard implements Serializable { diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/model/KinesisStreamShardState.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/model/KinesisStreamShardState.java index 00181da149a56..e68129dd134c8 100644 --- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/model/KinesisStreamShardState.java +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/model/KinesisStreamShardState.java @@ -18,22 +18,28 @@ package org.apache.flink.streaming.connectors.kinesis.model; /** - * A wrapper class that bundles a {@link KinesisStreamShard} with its last processed sequence number. + * A wrapper class that bundles a {@link StreamShardHandle} with its last processed sequence number. 
*/ public class KinesisStreamShardState { - private KinesisStreamShard kinesisStreamShard; + private KinesisStreamShardV2 kinesisStreamShard; + private StreamShardHandle streamShardHandle; private SequenceNumber lastProcessedSequenceNum; - public KinesisStreamShardState(KinesisStreamShard kinesisStreamShard, SequenceNumber lastProcessedSequenceNum) { + public KinesisStreamShardState(KinesisStreamShardV2 kinesisStreamShard, StreamShardHandle streamShardHandle, SequenceNumber lastProcessedSequenceNum) { this.kinesisStreamShard = kinesisStreamShard; + this.streamShardHandle = streamShardHandle; this.lastProcessedSequenceNum = lastProcessedSequenceNum; } - public KinesisStreamShard getKinesisStreamShard() { + public KinesisStreamShardV2 getKinesisStreamShard() { return this.kinesisStreamShard; } + public StreamShardHandle getStreamShardHandle() { + return this.streamShardHandle; + } + public SequenceNumber getLastProcessedSequenceNum() { return this.lastProcessedSequenceNum; } @@ -46,6 +52,7 @@ public void setLastProcessedSequenceNum(SequenceNumber update) { public String toString() { return "KinesisStreamShardState{" + "kinesisStreamShard='" + kinesisStreamShard.toString() + "'" + + ", streamShardHandle='" + streamShardHandle.toString() + "'" + ", lastProcessedSequenceNumber='" + lastProcessedSequenceNum.toString() + "'}"; } @@ -61,11 +68,13 @@ public boolean equals(Object obj) { KinesisStreamShardState other = (KinesisStreamShardState) obj; - return kinesisStreamShard.equals(other.getKinesisStreamShard()) && lastProcessedSequenceNum.equals(other.getLastProcessedSequenceNum()); + return kinesisStreamShard.equals(other.getKinesisStreamShard()) && + streamShardHandle.equals(other.getStreamShardHandle()) && + lastProcessedSequenceNum.equals(other.getLastProcessedSequenceNum()); } @Override public int hashCode() { - return 37 * (kinesisStreamShard.hashCode() + lastProcessedSequenceNum.hashCode()); + return 37 * (kinesisStreamShard.hashCode() + streamShardHandle.hashCode() 
+ lastProcessedSequenceNum.hashCode()); } } diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/model/KinesisStreamShardV2.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/model/KinesisStreamShardV2.java new file mode 100644 index 0000000000000..71cb6fa0adc89 --- /dev/null +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/model/KinesisStreamShardV2.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.kinesis.model; + +import java.io.Serializable; +import java.util.Objects; + +/** + * A serializable representation of a AWS Kinesis Stream shard. It is basically a wrapper class around the information + * disintegrating from {@link com.amazonaws.services.kinesis.model.Shard} and its nested classes. 
+ */ +public class KinesisStreamShardV2 implements Serializable { + + private static final long serialVersionUID = 5134869582298563604L; + + private String streamName; + private String shardId; + private String parentShardId; + private String adjacentParentShardId; + private String startingHashKey; + private String endingHashKey; + private String startingSequenceNumber; + private String endingSequenceNumber; + + public void setStreamName(String streamName) { + this.streamName = streamName; + } + + public void setShardId(String shardId) { + this.shardId = shardId; + } + + public void setParentShardId(String parentShardId) { + this.parentShardId = parentShardId; + } + + public void setAdjacentParentShardId(String adjacentParentShardId) { + this.adjacentParentShardId = adjacentParentShardId; + } + + public void setStartingHashKey(String startingHashKey) { + this.startingHashKey = startingHashKey; + } + + public void setEndingHashKey(String endingHashKey) { + this.endingHashKey = endingHashKey; + } + + public void setStartingSequenceNumber(String startingSequenceNumber) { + this.startingSequenceNumber = startingSequenceNumber; + } + + public void setEndingSequenceNumber(String endingSequenceNumber) { + this.endingSequenceNumber = endingSequenceNumber; + } + + public String getStreamName() { + return this.streamName; + } + + public String getShardId() { + return this.shardId; + } + + public String getParentShardId() { + return this.parentShardId; + } + + public String getAdjacentParentShardId() { + return this.adjacentParentShardId; + } + + public String getStartingHashKey() { + return this.startingHashKey; + } + + public String getEndingHashKey() { + return this.endingHashKey; + } + + public String getStartingSequenceNumber() { + return this.startingSequenceNumber; + } + + public String getEndingSequenceNumber() { + return this.endingSequenceNumber; + } + + @Override + public String toString() { + return "KinesisStreamShardV2{" + + "streamName='" + streamName + "'" + + 
", shardId='" + shardId + "'" + + ", parentShardId='" + parentShardId + "'" + + ", adjacentParentShardId='" + adjacentParentShardId + "'" + + ", startingHashKey='" + startingHashKey + "'" + + ", endingHashKey='" + endingHashKey + "'" + + ", startingSequenceNumber='" + startingSequenceNumber + "'" + + ", endingSequenceNumber='" + endingSequenceNumber + "'}"; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof KinesisStreamShardV2)) { + return false; + } + + if (obj == this) { + return true; + } + + KinesisStreamShardV2 other = (KinesisStreamShardV2) obj; /* every field is set through a setter and may still be null, so all of them are compared null-safely, matching hashCode() */ + + return Objects.equals(streamName, other.getStreamName()) && + Objects.equals(shardId, other.getShardId()) && + Objects.equals(parentShardId, other.getParentShardId()) && + Objects.equals(adjacentParentShardId, other.getAdjacentParentShardId()) && + Objects.equals(startingHashKey, other.getStartingHashKey()) && + Objects.equals(endingHashKey, other.getEndingHashKey()) && + Objects.equals(startingSequenceNumber, other.getStartingSequenceNumber()) && + Objects.equals(endingSequenceNumber, other.getEndingSequenceNumber()); + } + + @Override + public int hashCode() { + int hash = 17; + + if (streamName != null) { + hash = 37 * hash + streamName.hashCode(); + } + if (shardId != null) { + hash = 37 * hash + shardId.hashCode(); + } + if (parentShardId != null) { + hash = 37 * hash + parentShardId.hashCode(); + } + if (adjacentParentShardId != null) { + hash = 37 * hash + adjacentParentShardId.hashCode(); + } + if (startingHashKey != null) { + hash = 37 * hash + startingHashKey.hashCode(); + } + if (endingHashKey != null) { + hash = 37 * hash + endingHashKey.hashCode(); + } + if (startingSequenceNumber != null) { + hash = 37 * hash + startingSequenceNumber.hashCode(); + } + if (endingSequenceNumber != null) { + hash = 37 * hash + endingSequenceNumber.hashCode(); + } + + return hash; + } + +} diff --git 
a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/model/StreamShardHandle.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/model/StreamShardHandle.java new file mode 100644 index 0000000000000..d340a88c04b79 --- /dev/null +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/model/StreamShardHandle.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.kinesis.model; + +import com.amazonaws.services.kinesis.model.Shard; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * A wrapper class around the information provided along with streamName and {@link com.amazonaws.services.kinesis.model.Shard}, + * with some extra utility methods to determine whether or not a shard is closed and whether or not the shard is + * a result of parent shard splits or merges. 
+ */ +public class StreamShardHandle { + + private final String streamName; + private final Shard shard; + + private final int cachedHash; + + /** + * Create a new StreamShardHandle; neither argument may be null. + * + * @param streamName + * the name of the Kinesis stream that this shard belongs to + * @param shard + * the actual AWS Shard instance that will be wrapped within this StreamShardHandle + */ + public StreamShardHandle(String streamName, Shard shard) { + this.streamName = checkNotNull(streamName); + this.shard = checkNotNull(shard); + + // since our description of Kinesis Streams shards can be fully defined with the stream name and shard id, + // our hash doesn't need to use hash code of Amazon's description of Shards, which uses other info for calculation + int hash = 17; + hash = 37 * hash + streamName.hashCode(); + hash = 37 * hash + shard.getShardId().hashCode(); + this.cachedHash = hash; + } + + public String getStreamName() { + return streamName; + } + + public boolean isClosed() { + return (shard.getSequenceNumberRange().getEndingSequenceNumber() != null); + } + + public Shard getShard() { + return shard; + } + + @Override + public String toString() { + return "StreamShardHandle{" + + "streamName='" + streamName + "'" + + ", shard='" + shard.toString() + "'}"; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof StreamShardHandle)) { + return false; + } + + if (obj == this) { + return true; + } + + StreamShardHandle other = (StreamShardHandle) obj; + + return streamName.equals(other.getStreamName()) && shard.equals(other.getShard()); + } + + @Override + public int hashCode() { + return cachedHash; + } + + /** + * Utility function to compare two shard ids by their numeric suffix; throws IllegalArgumentException if either id is not valid. + * + * @param firstShardId first shard id to compare + * @param secondShardId second shard id to compare + * @return a value less than 0 if the first shard id is smaller than the second shard id, + * or a value larger than 0 if the first shard id is larger than the second shard id, + * or 0 if they are
equal + */ + public static int compareShardIds(String firstShardId, String secondShardId) { + if (!isValidShardId(firstShardId)) { + throw new IllegalArgumentException("The first shard id has invalid format."); + } + + if (!isValidShardId(secondShardId)) { + throw new IllegalArgumentException("The second shard id has invalid format."); + } + + // digit segment of the shard id starts at index 8 + return Long.compare(Long.parseLong(firstShardId.substring(8)), Long.parseLong(secondShardId.substring(8))); + } + + /** + * Checks if a shard id has a valid format. + * Kinesis stream shard ids have 12-digit numbers left-padded with 0's, + * prefixed with "shardId-", e.g. "shardId-000000000015". + * + * @param shardId the shard id to check; null is treated as invalid + * @return whether the shard id is valid + */ + public static boolean isValidShardId(String shardId) { + if (shardId == null) { return false; } + return shardId.matches("^shardId-\\d{12}"); + } +} diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/proxy/GetShardListResult.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/proxy/GetShardListResult.java index 04b165441f3de..aadb31ccf36fd 100644 --- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/proxy/GetShardListResult.java +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/proxy/GetShardListResult.java @@ -17,7 +17,7 @@ package org.apache.flink.streaming.connectors.kinesis.proxy; -import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard; +import org.apache.flink.streaming.connectors.kinesis.model.StreamShardHandle; import java.util.LinkedList; import java.util.List; @@ -30,25 +30,25 @@ */ public class GetShardListResult { - private final Map> streamsToRetrievedShardList = new HashMap<>(); + private final Map> streamsToRetrievedShardList = new
HashMap<>(); - public void addRetrievedShardToStream(String stream, KinesisStreamShard retrievedShard) { + public void addRetrievedShardToStream(String stream, StreamShardHandle retrievedShard) { if (!streamsToRetrievedShardList.containsKey(stream)) { - streamsToRetrievedShardList.put(stream, new LinkedList()); + streamsToRetrievedShardList.put(stream, new LinkedList()); } streamsToRetrievedShardList.get(stream).add(retrievedShard); } - public void addRetrievedShardsToStream(String stream, List retrievedShards) { + public void addRetrievedShardsToStream(String stream, List retrievedShards) { if (retrievedShards.size() != 0) { if (!streamsToRetrievedShardList.containsKey(stream)) { - streamsToRetrievedShardList.put(stream, new LinkedList()); + streamsToRetrievedShardList.put(stream, new LinkedList()); } streamsToRetrievedShardList.get(stream).addAll(retrievedShards); } } - public List getRetrievedShardListOfStream(String stream) { + public List getRetrievedShardListOfStream(String stream) { if (!streamsToRetrievedShardList.containsKey(stream)) { return null; } else { @@ -56,7 +56,7 @@ public List getRetrievedShardListOfStream(String stream) { } } - public KinesisStreamShard getLastSeenShardOfStream(String stream) { + public StreamShardHandle getLastSeenShardOfStream(String stream) { if (!streamsToRetrievedShardList.containsKey(stream)) { return null; } else { diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/proxy/KinesisProxy.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/proxy/KinesisProxy.java index 580555f0fcbd9..70c128602537e 100644 --- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/proxy/KinesisProxy.java +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/proxy/KinesisProxy.java @@ -32,7 +32,7 @@ import 
com.amazonaws.services.kinesis.model.GetShardIteratorRequest; import com.amazonaws.services.kinesis.model.ShardIteratorType; import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants; -import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard; +import org.apache.flink.streaming.connectors.kinesis.model.StreamShardHandle; import org.apache.flink.streaming.connectors.kinesis.util.AWSUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -237,7 +237,7 @@ public GetShardListResult getShardList(Map streamNamesWithLastSe * {@inheritDoc} */ @Override - public String getShardIterator(KinesisStreamShard shard, String shardIteratorType, @Nullable Object startingMarker) throws InterruptedException { + public String getShardIterator(StreamShardHandle shard, String shardIteratorType, @Nullable Object startingMarker) throws InterruptedException { GetShardIteratorRequest getShardIteratorRequest = new GetShardIteratorRequest() .withStreamName(shard.getStreamName()) .withShardId(shard.getShard().getShardId()) @@ -315,8 +315,8 @@ protected static boolean isRecoverableException(AmazonServiceException ex) { } } - private List getShardsOfStream(String streamName, @Nullable String lastSeenShardId) throws InterruptedException { - List shardsOfStream = new ArrayList<>(); + private List getShardsOfStream(String streamName, @Nullable String lastSeenShardId) throws InterruptedException { + List shardsOfStream = new ArrayList<>(); DescribeStreamResult describeStreamResult; do { @@ -324,7 +324,7 @@ private List getShardsOfStream(String streamName, @Nullable List shards = describeStreamResult.getStreamDescription().getShards(); for (Shard shard : shards) { - shardsOfStream.add(new KinesisStreamShard(streamName, shard)); + shardsOfStream.add(new StreamShardHandle(streamName, shard)); } if (shards.size() != 0) { @@ -384,7 +384,7 @@ private DescribeStreamResult describeStream(String streamName, @Nullable String List shards = 
describeStreamResult.getStreamDescription().getShards(); Iterator shardItr = shards.iterator(); while (shardItr.hasNext()) { - if (KinesisStreamShard.compareShardIds(shardItr.next().getShardId(), startShardId) <= 0) { + if (StreamShardHandle.compareShardIds(shardItr.next().getShardId(), startShardId) <= 0) { shardItr.remove(); } } diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/proxy/KinesisProxyInterface.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/proxy/KinesisProxyInterface.java index 9f6d594a9baaa..807a16381ca49 100644 --- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/proxy/KinesisProxyInterface.java +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/proxy/KinesisProxyInterface.java @@ -18,7 +18,7 @@ package org.apache.flink.streaming.connectors.kinesis.proxy; import com.amazonaws.services.kinesis.model.GetRecordsResult; -import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard; +import org.apache.flink.streaming.connectors.kinesis.model.StreamShardHandle; import java.util.Map; @@ -43,7 +43,7 @@ public interface KinesisProxyInterface { * operation has exceeded the rate limit; this exception will be thrown * if the backoff is interrupted. 
*/ - String getShardIterator(KinesisStreamShard shard, String shardIteratorType, Object startingMarker) throws InterruptedException; + String getShardIterator(StreamShardHandle shard, String shardIteratorType, Object startingMarker) throws InterruptedException; /** * Get the next batch of data records using a specific shard iterator diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java index ec9a9b5ff630e..d2af6ad2bec69 100644 --- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java +++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java @@ -24,6 +24,7 @@ import org.apache.flink.streaming.api.operators.StreamSource; import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants; import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher; +import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardV2; import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard; import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber; import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema; @@ -101,9 +102,9 @@ public void testRestoreFromFlink11() throws Exception { testHarness.open(); // the expected state in "kafka-consumer-migration-test-flink1.1-snapshot" - final HashMap expectedState = new HashMap<>(); - expectedState.put(new KinesisStreamShard("fakeStream1", - new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))), + final HashMap expectedState = new HashMap<>(); + 
expectedState.put(FlinkKinesisConsumer.createKinesisStreamShardV2(new KinesisStreamShard("fakeStream1", + new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0)))), new SequenceNumber("987654321")); // assert that state is correctly restored from legacy checkpoint diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java index 4b178c74620ca..760858a2d4894 100644 --- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java +++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java @@ -17,12 +17,17 @@ package org.apache.flink.streaming.connectors.kinesis; +import com.amazonaws.services.kinesis.model.HashKeyRange; +import com.amazonaws.services.kinesis.model.SequenceNumberRange; import com.amazonaws.services.kinesis.model.Shard; +import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.OperatorStateStore; +import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.runtime.PojoSerializer; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.state.FunctionSnapshotContext; import org.apache.flink.runtime.state.StateInitializationContext; @@ -36,6 +41,8 @@ import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber; import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber; import 
org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState; +import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardV2; +import org.apache.flink.streaming.connectors.kinesis.model.StreamShardHandle; import org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGenerator; import org.apache.flink.streaming.connectors.kinesis.testutils.TestableFlinkKinesisConsumer; import org.apache.flink.streaming.connectors.kinesis.util.KinesisConfigUtil; @@ -62,6 +69,7 @@ import static org.junit.Assert.fail; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; @@ -109,8 +117,8 @@ public void testUnrecognizableAwsRegionInConfig() { @Test public void testCredentialProviderTypeSetToBasicButNoCredentialSetInConfig() { exception.expect(IllegalArgumentException.class); - exception.expectMessage("Please set values for AWS Access Key ID ('"+ AWSConfigConstants.AWS_ACCESS_KEY_ID +"') " + - "and Secret Key ('" + AWSConfigConstants.AWS_SECRET_ACCESS_KEY + "') when using the BASIC AWS credential provider type."); + exception.expectMessage("Please set values for AWS Access Key ID ('" + AWSConfigConstants.AWS_ACCESS_KEY_ID + "') " + + "and Secret Key ('" + AWSConfigConstants.AWS_SECRET_ACCESS_KEY + "') when using the BASIC AWS credential provider type."); Properties testConfig = new Properties(); testConfig.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1"); @@ -535,28 +543,26 @@ public void testUseRestoredStateForSnapshotIfFetcherNotInitialized() throws Exce config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId"); config.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey"); - OperatorStateStore operatorStateStore = mock(OperatorStateStore.class); - - List> globalUnionState = new ArrayList<>(4); + List> 
globalUnionState = new ArrayList<>(4); globalUnionState.add(Tuple2.of( - new KinesisStreamShard("fakeStream", - new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))), + KinesisDataFetcher.createKinesisStreamShardV2(new StreamShardHandle("fakeStream", + new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0)))), new SequenceNumber("1"))); globalUnionState.add(Tuple2.of( - new KinesisStreamShard("fakeStream", - new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))), + KinesisDataFetcher.createKinesisStreamShardV2(new StreamShardHandle("fakeStream", + new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1)))), new SequenceNumber("1"))); globalUnionState.add(Tuple2.of( - new KinesisStreamShard("fakeStream", - new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))), + KinesisDataFetcher.createKinesisStreamShardV2(new StreamShardHandle("fakeStream", + new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2)))), new SequenceNumber("1"))); globalUnionState.add(Tuple2.of( - new KinesisStreamShard("fakeStream", - new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(3))), + KinesisDataFetcher.createKinesisStreamShardV2(new StreamShardHandle("fakeStream", + new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(3)))), new SequenceNumber("1"))); - TestingListState> listState = new TestingListState<>(); - for (Tuple2 state : globalUnionState) { + TestingListState> listState = new TestingListState<>(); + for (Tuple2 state : globalUnionState) { listState.add(state); } @@ -566,10 +572,10 @@ public void testUseRestoredStateForSnapshotIfFetcherNotInitialized() throws Exce when(context.getNumberOfParallelSubtasks()).thenReturn(2); consumer.setRuntimeContext(context); + OperatorStateStore operatorStateStore = mock(OperatorStateStore.class); 
when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState); StateInitializationContext initializationContext = mock(StateInitializationContext.class); - when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore); when(initializationContext.isRestored()).thenReturn(true); @@ -600,32 +606,32 @@ public void testListStateChangedAfterSnapshotState() throws Exception { config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId"); config.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey"); - ArrayList> initialState = new ArrayList<>(1); + ArrayList> initialState = new ArrayList<>(1); initialState.add(Tuple2.of( - new KinesisStreamShard("fakeStream1", - new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))), + KinesisDataFetcher.createKinesisStreamShardV2(new StreamShardHandle("fakeStream1", + new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0)))), new SequenceNumber("1"))); - ArrayList> expectedStateSnapshot = new ArrayList<>(3); + ArrayList> expectedStateSnapshot = new ArrayList<>(3); expectedStateSnapshot.add(Tuple2.of( - new KinesisStreamShard("fakeStream1", - new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))), + KinesisDataFetcher.createKinesisStreamShardV2(new StreamShardHandle("fakeStream1", + new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0)))), new SequenceNumber("12"))); expectedStateSnapshot.add(Tuple2.of( - new KinesisStreamShard("fakeStream1", - new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))), + KinesisDataFetcher.createKinesisStreamShardV2(new StreamShardHandle("fakeStream1", + new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1)))), new SequenceNumber("11"))); expectedStateSnapshot.add(Tuple2.of( - new KinesisStreamShard("fakeStream1", - new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))), + 
KinesisDataFetcher.createKinesisStreamShardV2(new StreamShardHandle("fakeStream1", + new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2)))), new SequenceNumber("31"))); // ---------------------------------------------------------------------- // mock operator state backend and initial state for initializeState() // ---------------------------------------------------------------------- - TestingListState> listState = new TestingListState<>(); - for (Tuple2 state: initialState) { + TestingListState> listState = new TestingListState<>(); + for (Tuple2 state : initialState) { listState.add(state); } @@ -640,8 +646,8 @@ public void testListStateChangedAfterSnapshotState() throws Exception { // mock a running fetcher and its state for snapshot // ---------------------------------------------------------------------- - HashMap stateSnapshot = new HashMap<>(); - for (Tuple2 tuple: expectedStateSnapshot) { + HashMap stateSnapshot = new HashMap<>(); + for (Tuple2 tuple : expectedStateSnapshot) { stateSnapshot.put(tuple.f0, tuple.f1); } @@ -668,15 +674,15 @@ public void testListStateChangedAfterSnapshotState() throws Exception { assertEquals(true, listState.clearCalled); assertEquals(3, listState.getList().size()); - for (Tuple2 state: initialState) { - for (Tuple2 currentState: listState.getList()) { + for (Tuple2 state : initialState) { + for (Tuple2 currentState : listState.getList()) { assertNotEquals(state, currentState); } } - for (Tuple2 state: expectedStateSnapshot) { + for (Tuple2 state : expectedStateSnapshot) { boolean hasOneIsSame = false; - for (Tuple2 currentState: listState.getList()) { + for (Tuple2 currentState : listState.getList()) { hasOneIsSame = hasOneIsSame || state.equals(currentState); } assertEquals(true, hasOneIsSame); @@ -706,10 +712,14 @@ public void testFetcherShouldNotBeRestoringFromFailureIfNotRestoringFromCheckpoi @Test @SuppressWarnings("unchecked") public void 
testFetcherShouldBeCorrectlySeededIfRestoringFromLegacyCheckpoint() throws Exception { - HashMap fakeRestoredState = getFakeRestoredStore("all"); + HashMap fakeRestoredState = getFakeRestoredStore("all"); + HashMap legacyFakeRestoredState = new HashMap<>(); + for (Map.Entry kv : fakeRestoredState.entrySet()) { + legacyFakeRestoredState.put(new KinesisStreamShard(kv.getKey().getStreamName(), kv.getKey().getShard()), kv.getValue()); + } KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class); - List shards = new ArrayList<>(); + List shards = new ArrayList<>(); shards.addAll(fakeRestoredState.keySet()); when(mockedFetcher.discoverNewShardsToSubscribe()).thenReturn(shards); PowerMockito.whenNew(KinesisDataFetcher.class).withAnyArguments().thenReturn(mockedFetcher); @@ -720,13 +730,14 @@ public void testFetcherShouldBeCorrectlySeededIfRestoringFromLegacyCheckpoint() TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer( "fakeStream", new Properties(), 10, 2); - consumer.restoreState(fakeRestoredState); + consumer.restoreState(legacyFakeRestoredState); consumer.open(new Configuration()); consumer.run(Mockito.mock(SourceFunction.SourceContext.class)); - for (Map.Entry restoredShard : fakeRestoredState.entrySet()) { + for (Map.Entry restoredShard : fakeRestoredState.entrySet()) { Mockito.verify(mockedFetcher).registerNewSubscribedShardState( - new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue())); + new KinesisStreamShardState(KinesisDataFetcher.createKinesisStreamShardV2(restoredShard.getKey()), + restoredShard.getKey(), restoredShard.getValue())); } } @@ -738,15 +749,15 @@ public void testFetcherShouldBeCorrectlySeededIfRestoringFromCheckpoint() throws // setup initial state // ---------------------------------------------------------------------- - HashMap fakeRestoredState = getFakeRestoredStore("all"); + HashMap fakeRestoredState = getFakeRestoredStore("all"); // 
---------------------------------------------------------------------- // mock operator state backend and initial state for initializeState() // ---------------------------------------------------------------------- - TestingListState> listState = new TestingListState<>(); - for (Map.Entry state: fakeRestoredState.entrySet()) { - listState.add(Tuple2.of(state.getKey(), state.getValue())); + TestingListState> listState = new TestingListState<>(); + for (Map.Entry state : fakeRestoredState.entrySet()) { + listState.add(Tuple2.of(KinesisDataFetcher.createKinesisStreamShardV2(state.getKey()), state.getValue())); } OperatorStateStore operatorStateStore = mock(OperatorStateStore.class); @@ -761,7 +772,7 @@ public void testFetcherShouldBeCorrectlySeededIfRestoringFromCheckpoint() throws // ---------------------------------------------------------------------- KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class); - List shards = new ArrayList<>(); + List shards = new ArrayList<>(); shards.addAll(fakeRestoredState.keySet()); when(mockedFetcher.discoverNewShardsToSubscribe()).thenReturn(shards); PowerMockito.whenNew(KinesisDataFetcher.class).withAnyArguments().thenReturn(mockedFetcher); @@ -780,9 +791,10 @@ public void testFetcherShouldBeCorrectlySeededIfRestoringFromCheckpoint() throws consumer.open(new Configuration()); consumer.run(Mockito.mock(SourceFunction.SourceContext.class)); - for (Map.Entry restoredShard : fakeRestoredState.entrySet()) { + for (Map.Entry restoredShard : fakeRestoredState.entrySet()) { Mockito.verify(mockedFetcher).registerNewSubscribedShardState( - new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue())); + new KinesisStreamShardState(KinesisDataFetcher.createKinesisStreamShardV2(restoredShard.getKey()), + restoredShard.getKey(), restoredShard.getValue())); } } @@ -794,20 +806,20 @@ public void testFetcherShouldBeCorrectlySeededOnlyItsOwnStates() throws Exceptio // setup initial state // 
---------------------------------------------------------------------- - HashMap fakeRestoredState = getFakeRestoredStore("fakeStream1"); + HashMap fakeRestoredState = getFakeRestoredStore("fakeStream1"); - HashMap fakeRestoredStateForOthers = getFakeRestoredStore("fakeStream2"); + HashMap fakeRestoredStateForOthers = getFakeRestoredStore("fakeStream2"); // ---------------------------------------------------------------------- // mock operator state backend and initial state for initializeState() // ---------------------------------------------------------------------- - TestingListState> listState = new TestingListState<>(); - for (Map.Entry state: fakeRestoredState.entrySet()) { - listState.add(Tuple2.of(state.getKey(), state.getValue())); + TestingListState> listState = new TestingListState<>(); + for (Map.Entry state : fakeRestoredState.entrySet()) { + listState.add(Tuple2.of(KinesisDataFetcher.createKinesisStreamShardV2(state.getKey()), state.getValue())); } - for (Map.Entry state: fakeRestoredStateForOthers.entrySet()) { - listState.add(Tuple2.of(state.getKey(), state.getValue())); + for (Map.Entry state : fakeRestoredStateForOthers.entrySet()) { + listState.add(Tuple2.of(KinesisDataFetcher.createKinesisStreamShardV2(state.getKey()), state.getValue())); } OperatorStateStore operatorStateStore = mock(OperatorStateStore.class); @@ -822,7 +834,7 @@ public void testFetcherShouldBeCorrectlySeededOnlyItsOwnStates() throws Exceptio // ---------------------------------------------------------------------- KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class); - List shards = new ArrayList<>(); + List shards = new ArrayList<>(); shards.addAll(fakeRestoredState.keySet()); when(mockedFetcher.discoverNewShardsToSubscribe()).thenReturn(shards); PowerMockito.whenNew(KinesisDataFetcher.class).withAnyArguments().thenReturn(mockedFetcher); @@ -841,15 +853,17 @@ public void testFetcherShouldBeCorrectlySeededOnlyItsOwnStates() throws Exceptio 
consumer.open(new Configuration()); consumer.run(Mockito.mock(SourceFunction.SourceContext.class)); - for (Map.Entry restoredShard : fakeRestoredStateForOthers.entrySet()) { + for (Map.Entry restoredShard : fakeRestoredStateForOthers.entrySet()) { // should never get restored state not belonging to itself Mockito.verify(mockedFetcher, never()).registerNewSubscribedShardState( - new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue())); + new KinesisStreamShardState(KinesisDataFetcher.createKinesisStreamShardV2(restoredShard.getKey()), + restoredShard.getKey(), restoredShard.getValue())); } - for (Map.Entry restoredShard : fakeRestoredState.entrySet()) { + for (Map.Entry restoredShard : fakeRestoredState.entrySet()) { // should get restored state belonging to itself Mockito.verify(mockedFetcher).registerNewSubscribedShardState( - new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue())); + new KinesisStreamShardState(KinesisDataFetcher.createKinesisStreamShardV2(restoredShard.getKey()), + restoredShard.getKey(), restoredShard.getValue())); } } @@ -890,15 +904,15 @@ public void testFetcherShouldBeCorrectlySeededWithNewDiscoveredKinesisStreamShar // setup initial state // ---------------------------------------------------------------------- - HashMap fakeRestoredState = getFakeRestoredStore("all"); + HashMap fakeRestoredState = getFakeRestoredStore("all"); // ---------------------------------------------------------------------- // mock operator state backend and initial state for initializeState() // ---------------------------------------------------------------------- - TestingListState> listState = new TestingListState<>(); - for (Map.Entry state: fakeRestoredState.entrySet()) { - listState.add(Tuple2.of(state.getKey(), state.getValue())); + TestingListState> listState = new TestingListState<>(); + for (Map.Entry state : fakeRestoredState.entrySet()) { + 
listState.add(Tuple2.of(KinesisDataFetcher.createKinesisStreamShardV2(state.getKey()), state.getValue())); } OperatorStateStore operatorStateStore = mock(OperatorStateStore.class); @@ -913,9 +927,9 @@ public void testFetcherShouldBeCorrectlySeededWithNewDiscoveredKinesisStreamShar // ---------------------------------------------------------------------- KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class); - List shards = new ArrayList<>(); + List shards = new ArrayList<>(); shards.addAll(fakeRestoredState.keySet()); - shards.add(new KinesisStreamShard("fakeStream2", + shards.add(new StreamShardHandle("fakeStream2", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2)))); when(mockedFetcher.discoverNewShardsToSubscribe()).thenReturn(shards); PowerMockito.whenNew(KinesisDataFetcher.class).withAnyArguments().thenReturn(mockedFetcher); @@ -934,15 +948,58 @@ public void testFetcherShouldBeCorrectlySeededWithNewDiscoveredKinesisStreamShar consumer.open(new Configuration()); consumer.run(Mockito.mock(SourceFunction.SourceContext.class)); - fakeRestoredState.put(new KinesisStreamShard("fakeStream2", + fakeRestoredState.put(new StreamShardHandle("fakeStream2", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))), SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get()); - for (Map.Entry restoredShard : fakeRestoredState.entrySet()) { + for (Map.Entry restoredShard : fakeRestoredState.entrySet()) { Mockito.verify(mockedFetcher).registerNewSubscribedShardState( - new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue())); + new KinesisStreamShardState(KinesisDataFetcher.createKinesisStreamShardV2(restoredShard.getKey()), + restoredShard.getKey(), restoredShard.getValue())); } } + @Test + public void testCreateFunctionToConvertBetweenKinesisStreamShardAndKinesisStreamShardV2() { + String streamName = "fakeStream1"; + String shardId = "shard-000001"; + String parentShardId = 
"shard-000002"; + String adjacentParentShardId = "shard-000003"; + String startingHashKey = "key-000001"; + String endingHashKey = "key-000010"; + String startingSequenceNumber = "seq-0000021"; + String endingSequenceNumber = "seq-00000031"; + + KinesisStreamShardV2 kinesisStreamShardV2 = new KinesisStreamShardV2(); + kinesisStreamShardV2.setStreamName(streamName); + kinesisStreamShardV2.setShardId(shardId); + kinesisStreamShardV2.setParentShardId(parentShardId); + kinesisStreamShardV2.setAdjacentParentShardId(adjacentParentShardId); + kinesisStreamShardV2.setStartingHashKey(startingHashKey); + kinesisStreamShardV2.setEndingHashKey(endingHashKey); + kinesisStreamShardV2.setStartingSequenceNumber(startingSequenceNumber); + kinesisStreamShardV2.setEndingSequenceNumber(endingSequenceNumber); + + Shard shard = new Shard() + .withShardId(shardId) + .withParentShardId(parentShardId) + .withAdjacentParentShardId(adjacentParentShardId) + .withHashKeyRange(new HashKeyRange() + .withStartingHashKey(startingHashKey) + .withEndingHashKey(endingHashKey)) + .withSequenceNumberRange(new SequenceNumberRange() + .withStartingSequenceNumber(startingSequenceNumber) + .withEndingSequenceNumber(endingSequenceNumber)); + KinesisStreamShard kinesisStreamShard = new KinesisStreamShard(streamName, shard); + + assertEquals(kinesisStreamShardV2, FlinkKinesisConsumer.createKinesisStreamShardV2(kinesisStreamShard)); + } + + @Test + public void testKinesisStreamShardV2WillUsePojoSerializer() { + TypeInformation typeInformation = TypeInformation.of(KinesisStreamShardV2.class); + assertTrue(typeInformation.createSerializer(new ExecutionConfig()) instanceof PojoSerializer); + } + private static final class TestingListState implements ListState { private final List list = new ArrayList<>(); @@ -973,31 +1030,31 @@ public boolean isClearCalled() { } } - private HashMap getFakeRestoredStore(String streamName) { - HashMap fakeRestoredState = new HashMap<>(); + private HashMap 
getFakeRestoredStore(String streamName) { + HashMap fakeRestoredState = new HashMap<>(); if (streamName.equals("fakeStream1") || streamName.equals("all")) { fakeRestoredState.put( - new KinesisStreamShard("fakeStream1", + new StreamShardHandle("fakeStream1", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))), new SequenceNumber(UUID.randomUUID().toString())); fakeRestoredState.put( - new KinesisStreamShard("fakeStream1", + new StreamShardHandle("fakeStream1", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))), new SequenceNumber(UUID.randomUUID().toString())); fakeRestoredState.put( - new KinesisStreamShard("fakeStream1", + new StreamShardHandle("fakeStream1", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))), new SequenceNumber(UUID.randomUUID().toString())); } if (streamName.equals("fakeStream2") || streamName.equals("all")) { fakeRestoredState.put( - new KinesisStreamShard("fakeStream2", + new StreamShardHandle("fakeStream2", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))), new SequenceNumber(UUID.randomUUID().toString())); fakeRestoredState.put( - new KinesisStreamShard("fakeStream2", + new StreamShardHandle("fakeStream2", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))), new SequenceNumber(UUID.randomUUID().toString())); } diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java index 800fde566318c..7c369457ef1b0 100644 --- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java +++ 
b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java @@ -17,14 +17,17 @@ package org.apache.flink.streaming.connectors.kinesis.internals; +import com.amazonaws.services.kinesis.model.HashKeyRange; +import com.amazonaws.services.kinesis.model.SequenceNumberRange; import com.amazonaws.services.kinesis.model.Shard; import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisConsumer; import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants; -import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard; +import org.apache.flink.streaming.connectors.kinesis.model.StreamShardHandle; import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber; import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState; +import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardV2; import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema; import org.apache.flink.streaming.connectors.kinesis.testutils.FakeKinesisBehavioursFactory; import org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGenerator; @@ -46,6 +49,7 @@ import java.util.UUID; import java.util.concurrent.atomic.AtomicReference; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -149,33 +153,33 @@ public void testStreamToLastSeenShardStateIsCorrectlySetWhenNoNewShardsSinceRest fakeStreams.add("fakeStream1"); fakeStreams.add("fakeStream2"); - Map restoredStateUnderTest = new HashMap<>(); + Map restoredStateUnderTest = new HashMap<>(); // fakeStream1 has 3 shards before restore restoredStateUnderTest.put( - new KinesisStreamShard( + new 
StreamShardHandle( "fakeStream1", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))), UUID.randomUUID().toString()); restoredStateUnderTest.put( - new KinesisStreamShard( + new StreamShardHandle( "fakeStream1", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))), UUID.randomUUID().toString()); restoredStateUnderTest.put( - new KinesisStreamShard( + new StreamShardHandle( "fakeStream1", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))), UUID.randomUUID().toString()); // fakeStream2 has 2 shards before restore restoredStateUnderTest.put( - new KinesisStreamShard( + new StreamShardHandle( "fakeStream2", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))), UUID.randomUUID().toString()); restoredStateUnderTest.put( - new KinesisStreamShard( + new StreamShardHandle( "fakeStream2", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))), UUID.randomUUID().toString()); @@ -198,10 +202,11 @@ public void testStreamToLastSeenShardStateIsCorrectlySetWhenNoNewShardsSinceRest subscribedStreamsToLastSeenShardIdsUnderTest, FakeKinesisBehavioursFactory.nonReshardedStreamsBehaviour(streamToShardCount)); - for (Map.Entry restoredState : restoredStateUnderTest.entrySet()) { + for (Map.Entry restoredState : restoredStateUnderTest.entrySet()) { fetcher.advanceLastDiscoveredShardOfStream(restoredState.getKey().getStreamName(), restoredState.getKey().getShard().getShardId()); fetcher.registerNewSubscribedShardState( - new KinesisStreamShardState(restoredState.getKey(), new SequenceNumber(restoredState.getValue()))); + new KinesisStreamShardState(KinesisDataFetcher.createKinesisStreamShardV2(restoredState.getKey()), + restoredState.getKey(), new SequenceNumber(restoredState.getValue()))); } PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class)); @@ -238,33 +243,33 @@ public void 
testStreamToLastSeenShardStateIsCorrectlySetWhenNewShardsFoundSinceR fakeStreams.add("fakeStream1"); fakeStreams.add("fakeStream2"); - Map restoredStateUnderTest = new HashMap<>(); + Map restoredStateUnderTest = new HashMap<>(); // fakeStream1 has 3 shards before restore restoredStateUnderTest.put( - new KinesisStreamShard( + new StreamShardHandle( "fakeStream1", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))), UUID.randomUUID().toString()); restoredStateUnderTest.put( - new KinesisStreamShard( + new StreamShardHandle( "fakeStream1", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))), UUID.randomUUID().toString()); restoredStateUnderTest.put( - new KinesisStreamShard( + new StreamShardHandle( "fakeStream1", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))), UUID.randomUUID().toString()); // fakeStream2 has 2 shards before restore restoredStateUnderTest.put( - new KinesisStreamShard( + new StreamShardHandle( "fakeStream2", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))), UUID.randomUUID().toString()); restoredStateUnderTest.put( - new KinesisStreamShard( + new StreamShardHandle( "fakeStream2", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))), UUID.randomUUID().toString()); @@ -288,10 +293,11 @@ public void testStreamToLastSeenShardStateIsCorrectlySetWhenNewShardsFoundSinceR subscribedStreamsToLastSeenShardIdsUnderTest, FakeKinesisBehavioursFactory.nonReshardedStreamsBehaviour(streamToShardCount)); - for (Map.Entry restoredState : restoredStateUnderTest.entrySet()) { + for (Map.Entry restoredState : restoredStateUnderTest.entrySet()) { fetcher.advanceLastDiscoveredShardOfStream(restoredState.getKey().getStreamName(), restoredState.getKey().getShard().getShardId()); fetcher.registerNewSubscribedShardState( - new KinesisStreamShardState(restoredState.getKey(), new SequenceNumber(restoredState.getValue()))); + new 
KinesisStreamShardState(KinesisDataFetcher.createKinesisStreamShardV2(restoredState.getKey()), + restoredState.getKey(), new SequenceNumber(restoredState.getValue()))); } PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class)); @@ -330,33 +336,33 @@ public void testStreamToLastSeenShardStateIsCorrectlySetWhenNoNewShardsSinceRest fakeStreams.add("fakeStream3"); // fakeStream3 will not have any shards fakeStreams.add("fakeStream4"); // fakeStream4 will not have any shards - Map restoredStateUnderTest = new HashMap<>(); + Map restoredStateUnderTest = new HashMap<>(); // fakeStream1 has 3 shards before restore restoredStateUnderTest.put( - new KinesisStreamShard( + new StreamShardHandle( "fakeStream1", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))), UUID.randomUUID().toString()); restoredStateUnderTest.put( - new KinesisStreamShard( + new StreamShardHandle( "fakeStream1", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))), UUID.randomUUID().toString()); restoredStateUnderTest.put( - new KinesisStreamShard( + new StreamShardHandle( "fakeStream1", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))), UUID.randomUUID().toString()); // fakeStream2 has 2 shards before restore restoredStateUnderTest.put( - new KinesisStreamShard( + new StreamShardHandle( "fakeStream2", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))), UUID.randomUUID().toString()); restoredStateUnderTest.put( - new KinesisStreamShard( + new StreamShardHandle( "fakeStream2", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))), UUID.randomUUID().toString()); @@ -382,10 +388,11 @@ public void testStreamToLastSeenShardStateIsCorrectlySetWhenNoNewShardsSinceRest subscribedStreamsToLastSeenShardIdsUnderTest, FakeKinesisBehavioursFactory.nonReshardedStreamsBehaviour(streamToShardCount)); - for (Map.Entry restoredState : 
restoredStateUnderTest.entrySet()) { + for (Map.Entry restoredState : restoredStateUnderTest.entrySet()) { fetcher.advanceLastDiscoveredShardOfStream(restoredState.getKey().getStreamName(), restoredState.getKey().getShard().getShardId()); fetcher.registerNewSubscribedShardState( - new KinesisStreamShardState(restoredState.getKey(), new SequenceNumber(restoredState.getValue()))); + new KinesisStreamShardState(KinesisDataFetcher.createKinesisStreamShardV2(restoredState.getKey()), + restoredState.getKey(), new SequenceNumber(restoredState.getValue()))); } PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class)); @@ -425,33 +432,33 @@ public void testStreamToLastSeenShardStateIsCorrectlySetWhenNewShardsFoundSinceR fakeStreams.add("fakeStream3"); // fakeStream3 will not have any shards fakeStreams.add("fakeStream4"); // fakeStream4 will not have any shards - Map restoredStateUnderTest = new HashMap<>(); + Map restoredStateUnderTest = new HashMap<>(); // fakeStream1 has 3 shards before restore restoredStateUnderTest.put( - new KinesisStreamShard( + new StreamShardHandle( "fakeStream1", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))), UUID.randomUUID().toString()); restoredStateUnderTest.put( - new KinesisStreamShard( + new StreamShardHandle( "fakeStream1", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))), UUID.randomUUID().toString()); restoredStateUnderTest.put( - new KinesisStreamShard( + new StreamShardHandle( "fakeStream1", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))), UUID.randomUUID().toString()); // fakeStream2 has 2 shards before restore restoredStateUnderTest.put( - new KinesisStreamShard( + new StreamShardHandle( "fakeStream2", new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))), UUID.randomUUID().toString()); restoredStateUnderTest.put( - new KinesisStreamShard( + new StreamShardHandle( "fakeStream2", 
new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))), UUID.randomUUID().toString()); @@ -477,10 +484,11 @@ public void testStreamToLastSeenShardStateIsCorrectlySetWhenNewShardsFoundSinceR subscribedStreamsToLastSeenShardIdsUnderTest, FakeKinesisBehavioursFactory.nonReshardedStreamsBehaviour(streamToShardCount)); - for (Map.Entry restoredState : restoredStateUnderTest.entrySet()) { + for (Map.Entry restoredState : restoredStateUnderTest.entrySet()) { fetcher.advanceLastDiscoveredShardOfStream(restoredState.getKey().getStreamName(), restoredState.getKey().getShard().getShardId()); fetcher.registerNewSubscribedShardState( - new KinesisStreamShardState(restoredState.getKey(), new SequenceNumber(restoredState.getValue()))); + new KinesisStreamShardState(KinesisDataFetcher.createKinesisStreamShardV2(restoredState.getKey()), + restoredState.getKey(), new SequenceNumber(restoredState.getValue()))); } PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class)); @@ -512,6 +520,43 @@ public void run() { assertTrue(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream4") == null); } + @Test + public void testCreateFunctionToConvertBetweenKinesisStreamShardV2AndStreamShardHandle() { + String streamName = "fakeStream1"; + String shardId = "shard-000001"; + String parentShardId = "shard-000002"; + String adjacentParentShardId = "shard-000003"; + String startingHashKey = "key-000001"; + String endingHashKey = "key-000010"; + String startingSequenceNumber = "seq-0000021"; + String endingSequenceNumber = "seq-00000031"; + + KinesisStreamShardV2 kinesisStreamShard = new KinesisStreamShardV2(); + kinesisStreamShard.setStreamName(streamName); + kinesisStreamShard.setShardId(shardId); + kinesisStreamShard.setParentShardId(parentShardId); + kinesisStreamShard.setAdjacentParentShardId(adjacentParentShardId); + kinesisStreamShard.setStartingHashKey(startingHashKey); + 
kinesisStreamShard.setEndingHashKey(endingHashKey); + kinesisStreamShard.setStartingSequenceNumber(startingSequenceNumber); + kinesisStreamShard.setEndingSequenceNumber(endingSequenceNumber); + + Shard shard = new Shard() + .withShardId(shardId) + .withParentShardId(parentShardId) + .withAdjacentParentShardId(adjacentParentShardId) + .withHashKeyRange(new HashKeyRange() + .withStartingHashKey(startingHashKey) + .withEndingHashKey(endingHashKey)) + .withSequenceNumberRange(new SequenceNumberRange() + .withStartingSequenceNumber(startingSequenceNumber) + .withEndingSequenceNumber(endingSequenceNumber)); + StreamShardHandle streamShardHandle = new StreamShardHandle(streamName, shard); + + assertEquals(kinesisStreamShard, KinesisDataFetcher.createKinesisStreamShardV2(streamShardHandle)); + assertEquals(streamShardHandle, KinesisDataFetcher.createStreamShardHandle(kinesisStreamShard)); + } + private static class DummyFlinkKafkaConsumer extends FlinkKinesisConsumer { private static final long serialVersionUID = 1L; diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumerTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumerTest.java index 96764a4c96c75..4e063296e02c3 100644 --- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumerTest.java +++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumerTest.java @@ -20,7 +20,7 @@ import com.amazonaws.services.kinesis.model.HashKeyRange; import com.amazonaws.services.kinesis.model.Shard; import org.apache.commons.lang.StringUtils; -import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard; +import org.apache.flink.streaming.connectors.kinesis.model.StreamShardHandle; import 
org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState; import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber; import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber; @@ -43,7 +43,7 @@ public class ShardConsumerTest { @Test public void testCorrectNumOfCollectedRecordsAndUpdatedState() { - KinesisStreamShard fakeToBeConsumedShard = new KinesisStreamShard( + StreamShardHandle fakeToBeConsumedShard = new StreamShardHandle( "fakeStream", new Shard() .withShardId(KinesisShardIdGenerator.generateFromShardOrder(0)) @@ -54,7 +54,8 @@ public void testCorrectNumOfCollectedRecordsAndUpdatedState() { LinkedList subscribedShardsStateUnderTest = new LinkedList<>(); subscribedShardsStateUnderTest.add( - new KinesisStreamShardState(fakeToBeConsumedShard, new SequenceNumber("fakeStartingState"))); + new KinesisStreamShardState(KinesisDataFetcher.createKinesisStreamShardV2(fakeToBeConsumedShard), + fakeToBeConsumedShard, new SequenceNumber("fakeStartingState"))); TestableKinesisDataFetcher fetcher = new TestableKinesisDataFetcher( @@ -70,7 +71,7 @@ public void testCorrectNumOfCollectedRecordsAndUpdatedState() { new ShardConsumer<>( fetcher, 0, - subscribedShardsStateUnderTest.get(0).getKinesisStreamShard(), + subscribedShardsStateUnderTest.get(0).getStreamShardHandle(), subscribedShardsStateUnderTest.get(0).getLastProcessedSequenceNum(), FakeKinesisBehavioursFactory.totalNumOfRecordsAfterNumOfGetRecordsCalls(1000, 9)).run(); @@ -81,7 +82,7 @@ public void testCorrectNumOfCollectedRecordsAndUpdatedState() { @Test public void testCorrectNumOfCollectedRecordsAndUpdatedStateWithUnexpectedExpiredIterator() { - KinesisStreamShard fakeToBeConsumedShard = new KinesisStreamShard( + StreamShardHandle fakeToBeConsumedShard = new StreamShardHandle( "fakeStream", new Shard() .withShardId(KinesisShardIdGenerator.generateFromShardOrder(0)) @@ -92,7 +93,8 @@ public void 
testCorrectNumOfCollectedRecordsAndUpdatedStateWithUnexpectedExpired LinkedList subscribedShardsStateUnderTest = new LinkedList<>(); subscribedShardsStateUnderTest.add( - new KinesisStreamShardState(fakeToBeConsumedShard, new SequenceNumber("fakeStartingState"))); + new KinesisStreamShardState(KinesisDataFetcher.createKinesisStreamShardV2(fakeToBeConsumedShard), + fakeToBeConsumedShard, new SequenceNumber("fakeStartingState"))); TestableKinesisDataFetcher fetcher = new TestableKinesisDataFetcher( @@ -108,7 +110,7 @@ public void testCorrectNumOfCollectedRecordsAndUpdatedStateWithUnexpectedExpired new ShardConsumer<>( fetcher, 0, - subscribedShardsStateUnderTest.get(0).getKinesisStreamShard(), + subscribedShardsStateUnderTest.get(0).getStreamShardHandle(), subscribedShardsStateUnderTest.get(0).getLastProcessedSequenceNum(), // Get a total of 1000 records with 9 getRecords() calls, // and the 7th getRecords() call will encounter an unexpected expired shard iterator diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/FakeKinesisBehavioursFactory.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/FakeKinesisBehavioursFactory.java index b62e7de9dc3c8..ce5a0de1d1c7e 100644 --- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/FakeKinesisBehavioursFactory.java +++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/FakeKinesisBehavioursFactory.java @@ -22,7 +22,7 @@ import com.amazonaws.services.kinesis.model.Record; import com.amazonaws.services.kinesis.model.Shard; import org.apache.flink.configuration.ConfigConstants; -import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard; +import org.apache.flink.streaming.connectors.kinesis.model.StreamShardHandle; import 
org.apache.flink.streaming.connectors.kinesis.proxy.GetShardListResult; import org.apache.flink.streaming.connectors.kinesis.proxy.KinesisProxyInterface; @@ -55,7 +55,7 @@ public GetShardListResult getShardList(Map streamNamesWithLastSe } @Override - public String getShardIterator(KinesisStreamShard shard, String shardIteratorType, Object startingMarker) { + public String getShardIterator(StreamShardHandle shard, String shardIteratorType, Object startingMarker) { return null; } @@ -122,7 +122,7 @@ public GetRecordsResult getRecords(String shardIterator, int maxRecordsToGet) { } @Override - public String getShardIterator(KinesisStreamShard shard, String shardIteratorType, Object startingMarker) { + public String getShardIterator(StreamShardHandle shard, String shardIteratorType, Object startingMarker) { if (!expiredOnceAlready) { // for the first call, just return the iterator of the first batch of records return "0"; @@ -181,7 +181,7 @@ public GetRecordsResult getRecords(String shardIterator, int maxRecordsToGet) { } @Override - public String getShardIterator(KinesisStreamShard shard, String shardIteratorType, Object startingMarker) { + public String getShardIterator(StreamShardHandle shard, String shardIteratorType, Object startingMarker) { // this will be called only one time per ShardConsumer; // so, simply return the iterator of the first batch of records return "0"; @@ -209,7 +209,7 @@ public static List createRecordBatchWithRange(int min, int max) { private static class NonReshardedStreamsKinesis implements KinesisProxyInterface { - private Map> streamsWithListOfShards = new HashMap<>(); + private Map> streamsWithListOfShards = new HashMap<>(); public NonReshardedStreamsKinesis(Map streamsToShardCount) { for (Map.Entry streamToShardCount : streamsToShardCount.entrySet()) { @@ -219,10 +219,10 @@ public NonReshardedStreamsKinesis(Map streamsToShardCount) { if (shardCount == 0) { // don't do anything } else { - List shardsOfStream = new ArrayList<>(shardCount); 
+ List shardsOfStream = new ArrayList<>(shardCount); for (int i=0; i < shardCount; i++) { shardsOfStream.add( - new KinesisStreamShard( + new StreamShardHandle( streamName, new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(i)))); } @@ -234,13 +234,13 @@ public NonReshardedStreamsKinesis(Map streamsToShardCount) { @Override public GetShardListResult getShardList(Map streamNamesWithLastSeenShardIds) { GetShardListResult result = new GetShardListResult(); - for (Map.Entry> streamsWithShards : streamsWithListOfShards.entrySet()) { + for (Map.Entry> streamsWithShards : streamsWithListOfShards.entrySet()) { String streamName = streamsWithShards.getKey(); - for (KinesisStreamShard shard : streamsWithShards.getValue()) { + for (StreamShardHandle shard : streamsWithShards.getValue()) { if (streamNamesWithLastSeenShardIds.get(streamName) == null) { result.addRetrievedShardToStream(streamName, shard); } else { - if (KinesisStreamShard.compareShardIds( + if (StreamShardHandle.compareShardIds( shard.getShard().getShardId(), streamNamesWithLastSeenShardIds.get(streamName)) > 0) { result.addRetrievedShardToStream(streamName, shard); } @@ -251,7 +251,7 @@ public GetShardListResult getShardList(Map streamNamesWithLastSe } @Override - public String getShardIterator(KinesisStreamShard shard, String shardIteratorType, Object startingMarker) { + public String getShardIterator(StreamShardHandle shard, String shardIteratorType, Object startingMarker) { return null; }