[SPARK-18020][Streaming][Kinesis] Checkpoint SHARD_END to finish reading closed shards #16213

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
KinesisCheckpointer.scala

@@ -26,7 +26,7 @@ import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
 import org.apache.spark.internal.Logging
 import org.apache.spark.streaming.Duration
 import org.apache.spark.streaming.util.RecurringTimer
-import org.apache.spark.util.{Clock, SystemClock, ThreadUtils}
+import org.apache.spark.util.{Clock, SystemClock}

 /**
  * This is a helper class for managing Kinesis checkpointing.
@@ -64,7 +64,20 @@ private[kinesis] class KinesisCheckpointer(
   def removeCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = {
     synchronized {
       checkpointers.remove(shardId)
-      checkpoint(shardId, checkpointer)
     }
+    if (checkpointer != null) {
+      try {
+        // We must call `checkpoint()` with no parameter to finish reading closed shards.
+        // See the URL below for details:
+        // https://forums.aws.amazon.com/thread.jspa?threadID=244218
+        KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100)
+      } catch {
+        case NonFatal(e) =>
+          logError(s"Exception: WorkerId $workerId encountered an exception while checkpointing " +
+            s"to finish reading a shard of $shardId.", e)
+          // Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor
+          throw e
+      }
+    }
+  }
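For background, the parameterless `checkpoint()` is how a KCL record processor acknowledges SHARD_END. When a shard is closed by a split or merge, the Worker shuts the processor down with `ShutdownReason.TERMINATE`, and only a no-argument checkpoint marks the shard as fully consumed so the KCL can move on to its child shards. A minimal sketch of that contract against the KCL 1.x interfaces (the class itself is hypothetical and not part of this patch):

    import java.util.{List => JList}

    import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer}
    import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
    import com.amazonaws.services.kinesis.model.Record

    class ShardEndAwareProcessor extends IRecordProcessor {
      override def initialize(shardId: String): Unit = ()

      override def processRecords(
          records: JList[Record], checkpointer: IRecordProcessorCheckpointer): Unit = {
        // Normal path: checkpoint at the sequence number of the last processed record.
        if (!records.isEmpty) {
          checkpointer.checkpoint(records.get(records.size - 1).getSequenceNumber)
        }
      }

      override def shutdown(
          checkpointer: IRecordProcessorCheckpointer, reason: ShutdownReason): Unit = {
        if (reason == ShutdownReason.TERMINATE) {
          // The shard was closed by a split/merge: only checkpoint() with no
          // argument records SHARD_END, letting the KCL proceed to child shards.
          checkpointer.checkpoint()
        }
      }
    }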
KinesisTestUtils.scala

@@ -40,11 +40,10 @@ import org.apache.spark.internal.Logging
  *
  * PLEASE KEEP THIS FILE UNDER src/main AS PYTHON TESTS NEED ACCESS TO THIS FILE!
  */
-private[kinesis] class KinesisTestUtils extends Logging {
+private[kinesis] class KinesisTestUtils(streamShardCount: Int = 2) extends Logging {

   val endpointUrl = KinesisTestUtils.endpointUrl
   val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName()
-  val streamShardCount = 2

   private val createStreamTimeoutSeconds = 300
   private val describeStreamPollTimeSeconds = 1
@@ -88,7 +87,7 @@ private[kinesis] class KinesisTestUtils extends Logging {
     logInfo(s"Creating stream ${_streamName}")
     val createStreamRequest = new CreateStreamRequest()
     createStreamRequest.setStreamName(_streamName)
-    createStreamRequest.setShardCount(2)
+    createStreamRequest.setShardCount(streamShardCount)
     kinesisClient.createStream(createStreamRequest)

     // The stream is now being created. Wait for it to become active.
@@ -97,6 +96,31 @@ private[kinesis] class KinesisTestUtils extends Logging {
     logInfo(s"Created stream ${_streamName}")
   }

+  def getShards(): Seq[Shard] = {
+    kinesisClient.describeStream(_streamName).getStreamDescription.getShards.asScala
+  }
+
+  def splitShard(shardId: String): Unit = {
+    val splitShardRequest = new SplitShardRequest()
+    splitShardRequest.withStreamName(_streamName)
+    splitShardRequest.withShardToSplit(shardId)
+    // Set the new starting hash key to half of the max hash value, i.e. 2^127
+    splitShardRequest.withNewStartingHashKey("170141183460469231731687303715884105728")
+    kinesisClient.splitShard(splitShardRequest)
+    // Wait for the shards to become active
+    waitForStreamToBeActive(_streamName)
+  }
+
+  def mergeShard(shardToMerge: String, adjacentShardToMerge: String): Unit = {
+    val mergeShardRequest = new MergeShardsRequest
+    mergeShardRequest.withStreamName(_streamName)
+    mergeShardRequest.withShardToMerge(shardToMerge)
+    mergeShardRequest.withAdjacentShardToMerge(adjacentShardToMerge)
+    kinesisClient.mergeShards(mergeShardRequest)
+    // Wait for the shards to become active
+    waitForStreamToBeActive(_streamName)
+  }
+
   /**
    * Push data to Kinesis stream and return a map of
    * shardId -> seq of (data, seq number) pushed to corresponding shard
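The long constant passed to `withNewStartingHashKey` is the midpoint of the Kinesis hash key space: hash keys run from 0 to 2^128 − 1, so starting the second child shard at 2^127 splits a single shard into two evenly sized halves. A quick sanity check:

    // Kinesis hash keys span [0, 2^128 - 1]; 2^127 is the midpoint.
    val midpoint = BigInt(2).pow(127)
    assert(midpoint.toString == "170141183460469231731687303715884105728")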
KPLBasedKinesisTestUtils.scala

@@ -25,7 +25,8 @@ import scala.collection.mutable.ArrayBuffer
 import com.amazonaws.services.kinesis.producer.{KinesisProducer => KPLProducer, KinesisProducerConfiguration, UserRecordResult}
 import com.google.common.util.concurrent.{FutureCallback, Futures}

-private[kinesis] class KPLBasedKinesisTestUtils extends KinesisTestUtils {
+private[kinesis] class KPLBasedKinesisTestUtils(streamShardCount: Int = 2)
+  extends KinesisTestUtils(streamShardCount) {
   override protected def getProducer(aggregate: Boolean): KinesisDataGenerator = {
     if (!aggregate) {
       new SimpleDataGenerator(kinesisClient)
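With the shard count exposed as a constructor parameter, a test can start from a single-shard stream and drive resharding directly. A usage sketch under the same assumptions as the suite below (AWS credentials and a Kinesis endpoint configured; this condenses what the new test does):

    val utils = new KPLBasedKinesisTestUtils(streamShardCount = 1)
    utils.createStream()
    try {
      val parent = utils.getShards().head
      utils.splitShard(parent.getShardId)  // 1 open shard -> 2 open + 1 closed
      val open = utils.getShards().filter(_.getSequenceNumberRange.getEndingSequenceNumber == null)
      val Seq(left, right) = open
      utils.mergeShard(left.getShardId, right.getShardId)  // back to 1 open shard
    } finally {
      utils.deleteStream()
    }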
KinesisCheckpointerSuite.scala

@@ -119,7 +119,7 @@ class KinesisCheckpointerSuite extends TestSuiteBase
     when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)

     kinesisCheckpointer.removeCheckpointer(shardId, checkpointerMock)
-    verify(checkpointerMock, times(1)).checkpoint(anyString())
+    verify(checkpointerMock, times(1)).checkpoint()
   }

   test("if checkpointing is going on, wait until finished before removing and checkpointing") {
@@ -146,7 +146,8 @@ class KinesisCheckpointerSuite extends TestSuiteBase

     clock.advance(checkpointInterval.milliseconds / 2)
     eventually(timeout(1 second)) {
-      verify(checkpointerMock, times(2)).checkpoint(anyString())
+      verify(checkpointerMock, times(1)).checkpoint(anyString)
+      verify(checkpointerMock, times(1)).checkpoint()
     }
   }
 }
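The two `verify` calls work because `checkpoint()` and `checkpoint(sequenceNumber)` are separate overloads on `IRecordProcessorCheckpointer`, and Mockito counts invocations per overload. A standalone illustration of that pattern (Mockito 1.x matchers, as this suite uses):

    import org.mockito.Matchers.anyString
    import org.mockito.Mockito.{mock, times, verify}

    import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer

    val cp = mock(classOf[IRecordProcessorCheckpointer])
    cp.checkpoint("0")  // scheduled checkpoint at a (dummy) sequence number
    cp.checkpoint()     // SHARD_END checkpoint

    verify(cp, times(1)).checkpoint(anyString())  // matches only the one-argument overload
    verify(cp, times(1)).checkpoint()             // matches only the zero-argument overload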
KinesisStreamSuite.scala

@@ -225,6 +225,76 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFunSuite
     ssc.stop(stopSparkContext = false)
   }

+  testIfEnabled("split and merge shards in a stream") {
+    // Since this test tries to split and merge shards in a stream, we create another
+    // temporary stream and then remove it when finished.
+    val localAppName = s"KinesisStreamSuite-${math.abs(Random.nextLong())}"
+    val localTestUtils = new KPLBasedKinesisTestUtils(1)
+    localTestUtils.createStream()
+    try {
+      val awsCredentials = KinesisTestUtils.getAWSCredentials()
+      val stream = KinesisUtils.createStream(ssc, localAppName, localTestUtils.streamName,
+        localTestUtils.endpointUrl, localTestUtils.regionName, InitialPositionInStream.LATEST,
+        Seconds(10), StorageLevel.MEMORY_ONLY,
+        awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey)
+
+      val collected = new mutable.HashSet[Int]
+      stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd =>
+        collected.synchronized {
+          collected ++= rdd.collect()
+          logInfo("Collected = " + collected.mkString(", "))
+        }
+      }
+      ssc.start()
+
+      val testData1 = 1 to 10
+      val testData2 = 11 to 20
+      val testData3 = 21 to 30
+
+      eventually(timeout(60 seconds), interval(10 seconds)) {
+        localTestUtils.pushData(testData1, aggregateTestData)
+        assert(collected.synchronized { collected === testData1.toSet },
+          "\nData received does not match data sent")
+      }
+
+      val shardToSplit = localTestUtils.getShards().head
+      localTestUtils.splitShard(shardToSplit.getShardId)
+      val (splitOpenShards, splitCloseShards) = localTestUtils.getShards().partition { shard =>
+        shard.getSequenceNumberRange.getEndingSequenceNumber == null
+      }
+
+      // We should have one closed shard and two open shards
+      assert(splitCloseShards.size == 1)
+      assert(splitOpenShards.size == 2)
+
+      eventually(timeout(60 seconds), interval(10 seconds)) {
+        localTestUtils.pushData(testData2, aggregateTestData)
+        assert(collected.synchronized { collected === (testData1 ++ testData2).toSet },
+          "\nData received does not match data sent after splitting a shard")
+      }
+
+      val Seq(shardToMerge, adjShard) = splitOpenShards
+      localTestUtils.mergeShard(shardToMerge.getShardId, adjShard.getShardId)
+      val (mergedOpenShards, mergedCloseShards) = localTestUtils.getShards().partition { shard =>
+        shard.getSequenceNumberRange.getEndingSequenceNumber == null
+      }
+
+      // We should have three closed shards and one open shard
+      assert(mergedCloseShards.size == 3)
+      assert(mergedOpenShards.size == 1)
+
+      eventually(timeout(60 seconds), interval(10 seconds)) {
+        localTestUtils.pushData(testData3, aggregateTestData)
+        assert(collected.synchronized { collected === (testData1 ++ testData2 ++ testData3).toSet },
+          "\nData received does not match data sent after merging shards")
+      }
+    } finally {
+      ssc.stop(stopSparkContext = false)
+      localTestUtils.deleteStream()
+      localTestUtils.deleteDynamoDBTable(localAppName)
+    }
+  }
+
   testIfEnabled("failure recovery") {
     val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
     val checkpointDir = Utils.createTempDir().getAbsolutePath
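The partition step encodes a Kinesis convention worth calling out: a shard closed by a split or merge gets an ending sequence number, while an open shard's range is unbounded above. Extracted as a tiny predicate (hypothetical helper, not part of the patch):

    import com.amazonaws.services.kinesis.model.Shard

    // A shard is still open iff its sequence number range has no upper bound.
    def isOpen(shard: Shard): Boolean =
      shard.getSequenceNumberRange.getEndingSequenceNumber == null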
python/pyspark/streaming/tests.py (1 addition, 1 deletion)

@@ -1420,7 +1420,7 @@ def test_kinesis_stream(self):

         import random
         kinesisAppName = ("KinesisStreamTests-%d" % abs(random.randint(0, 10000000)))
-        kinesisTestUtils = self.ssc._jvm.org.apache.spark.streaming.kinesis.KinesisTestUtils()
+        kinesisTestUtils = self.ssc._jvm.org.apache.spark.streaming.kinesis.KinesisTestUtils(2)
         try:
             kinesisTestUtils.createStream()
             aWSCredentials = kinesisTestUtils.getAWSCredentials()