Skip to content


[SPARK-25214][SS] Fix the issue that Kafka v2 source may return duplicated records when `failOnDataLoss=false`
Browse files Browse the repository at this point in the history

## What changes were proposed in this pull request?

When there are missing offsets and `failOnDataLoss=false`, the Kafka v2 source may return duplicated records because it doesn't skip the missing offsets.

This PR fixes the issue and also adds regression tests for all Kafka readers.

## How was this patch tested?

New tests.

Closes #22207 from zsxwing/SPARK-25214.

Authored-by: Shixiong Zhu <>
Signed-off-by: Shixiong Zhu <>
  • Loading branch information
zsxwing committed Aug 24, 2018
1 parent c20916a commit 8bb9414
Show file tree
Hide file tree
Showing 4 changed files with 276 additions and 175 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ private[kafka010] case class KafkaMicroBatchPartitionReader(
val record = consumer.get(nextOffset, rangeToRead.untilOffset, pollTimeoutMs, failOnDataLoss)
if (record != null) {
nextRow = converter.toUnsafeRow(record)
nextOffset = record.offset + 1
} else {
Expand All @@ -352,7 +353,6 @@ private[kafka010] case class KafkaMicroBatchPartitionReader(

override def get(): UnsafeRow = {
assert(nextRow != null)
nextOffset += 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,44 +77,6 @@ private[kafka010] class KafkaSourceRDD( { case (o, i) => new KafkaSourceRDDPartition(i, o) }.toArray

override def count(): Long =

override def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = {
val c = count
new PartialResult(new BoundedDouble(c, 1.0, c, c), true)

/** Reports whether this RDD holds no records, i.e. whether `count` is zero. */
override def isEmpty(): Boolean = {
  val total = count
  total == 0L
}

override def take(num: Int): Array[ConsumerRecord[Array[Byte], Array[Byte]]] = {
val nonEmptyPartitions =[KafkaSourceRDDPartition]).filter(_.offsetRange.size > 0)

if (num < 1 || nonEmptyPartitions.isEmpty) {
return new Array[ConsumerRecord[Array[Byte], Array[Byte]]](0)

// Determine in advance how many messages need to be taken from each partition
val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) =>
val remain = num - result.values.sum
if (remain > 0) {
val taken = Math.min(remain, part.offsetRange.size)
result + (part.index -> taken.toInt)
} else {

val buf = new ArrayBuffer[ConsumerRecord[Array[Byte], Array[Byte]]]
val res = context.runJob(
(tc: TaskContext, it: Iterator[ConsumerRecord[Array[Byte], Array[Byte]]]) =>
it.take(parts(tc.partitionId)).toArray, parts.keys.toArray
res.foreach(buf ++= _)

override def getPreferredLocations(split: Partition): Seq[String] = {
val part = split.asInstanceOf[KafkaSourceRDDPartition]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* See the License for the specific language governing permissions and
* limitations under the License.

package org.apache.spark.sql.kafka010

import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable
import scala.util.Random

import org.scalatest.time.SpanSugar._

import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter}
import org.apache.spark.sql.streaming.{StreamTest, Trigger}
import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}

* A basic test trait that sets up a Kafka cluster which keeps only a few records in
* a topic and ages out records very quickly. It is a helper trait for testing the
* "failOnDataLoss=false" case with missing offsets.
* Note: there is a hard-coded 30-second delay (kafka.log.LogManager.InitialTaskDelayMs) before
* records are cleaned up. Hence each class extending this trait needs to wait at least 30 seconds
* (or even longer when running on a slow Jenkins machine) before records start to be removed. To
* make sure a test does see missing offsets, check the earliest offset in `eventually` and make
* sure it's not 0, rather than sleeping for a hard-coded duration.
trait KafkaMissingOffsetsTest extends SharedSQLContext {

protected var testUtils: KafkaTestUtils = _

override def createSparkSession(): TestSparkSession = {
// Set maxRetries to 3 to handle NPE from `poll` when deleting a topic
new TestSparkSession(new SparkContext("local[2,3]", "test-sql-context", sparkConf))

override def beforeAll(): Unit = {
testUtils = new KafkaTestUtils {
override def brokerConfiguration: Properties = {
val props = super.brokerConfiguration
// Try to make Kafka clean up messages as fast as possible. However, there is a hard-coded
// 30-second delay (kafka.log.LogManager.InitialTaskDelayMs), so this test should run for at
// least 30 seconds.
props.put("", "100")
// The size of RecordBatch V2 increases to support transactional write.
props.put("log.segment.bytes", "70")
props.put("log.retention.bytes", "40")
props.put("", "100")
props.put("", "10")
props.put("", "10")

override def afterAll(): Unit = {
if (testUtils != null) {
testUtils = null

class KafkaDontFailOnDataLossSuite extends KafkaMissingOffsetsTest {

import testImplicits._

private val topicId = new AtomicInteger(0)

/**
 * Builds a fresh topic name: the "failOnDataLoss-" prefix followed by the current value of
 * `topicId`, which is incremented as a side effect so each call yields a different name.
 */
private def newTopic(): String = {
  val suffix = topicId.getAndIncrement()
  s"failOnDataLoss-$suffix"
}

* @param testStreamingQuery whether to test a streaming query or a batch query.
* @param writeToTable the function to write the specified [[DataFrame]] to the given table.
private def verifyMissingOffsetsDontCauseDuplicatedRecords(
testStreamingQuery: Boolean)(writeToTable: (DataFrame, String) => Unit): Unit = {
val topic = newTopic()
testUtils.createTopic(topic, partitions = 1)
testUtils.sendMessages(topic, (0 until 50).map(_.toString).toArray)

eventually(timeout(60.seconds)) {
testUtils.getEarliestOffsets(Set(topic)).head._2 > 0,
"Kafka didn't delete records after 1 minute")

val table = "DontFailOnDataLoss"
withTable(table) {
val kafkaOptions = Map(
"kafka.bootstrap.servers" -> testUtils.brokerAddress,
"" -> "1",
"subscribe" -> topic,
"startingOffsets" -> s"""{"$topic":{"0":0}}""",
"failOnDataLoss" -> "false",
"kafkaConsumer.pollTimeoutMs" -> "1000")
val df =
if (testStreamingQuery) {
val reader = spark.readStream.format("kafka")
kafkaOptions.foreach(kv => reader.option(kv._1, kv._2))
} else {
val reader ="kafka")
kafkaOptions.foreach(kv => reader.option(kv._1, kv._2))
writeToTable(df.selectExpr("CAST(value AS STRING)"), table)
val result = spark.table(table).as[String].collect().toList
assert(result.distinct.size === result.size, s"$result contains duplicated records")
// Make sure Kafka did remove some records so that this test is valid.
assert(result.size > 0 && result.size < 50)

test("failOnDataLoss=false should not return duplicated records: v1") {
"spark.sql.streaming.disabledV2MicroBatchReaders" ->
classOf[KafkaSourceProvider].getCanonicalName) {
verifyMissingOffsetsDontCauseDuplicatedRecords(testStreamingQuery = true) { (df, table) =>
val query = df.writeStream.format("memory").queryName(table).start()
try {
} finally {

test("failOnDataLoss=false should not return duplicated records: v2") {
verifyMissingOffsetsDontCauseDuplicatedRecords(testStreamingQuery = true) { (df, table) =>
val query = df.writeStream.format("memory").queryName(table).start()
try {
} finally {

test("failOnDataLoss=false should not return duplicated records: continuous processing") {
verifyMissingOffsetsDontCauseDuplicatedRecords(testStreamingQuery = true) { (df, table) =>
val query = df.writeStream
try {
} finally {

test("failOnDataLoss=false should not return duplicated records: batch") {
verifyMissingOffsetsDontCauseDuplicatedRecords(testStreamingQuery = false) { (df, table) =>

class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with KafkaMissingOffsetsTest {

import testImplicits._

private val topicId = new AtomicInteger(0)

/**
 * Returns the next topic name, formed from the "failOnDataLoss-" prefix and the value
 * obtained from `topicId.getAndIncrement()`.
 */
private def newTopic(): String = {
  val nextId = topicId.getAndIncrement()
  "failOnDataLoss-" + nextId
}

protected def startStream(ds: Dataset[Int]) = {
ds.writeStream.foreach(new ForeachWriter[Int] {

// Returns true unconditionally, regardless of the partition id and version passed in.
override def open(partitionId: Long, version: Long): Boolean = true

override def process(value: Int): Unit = {
// Slow down the processing speed so that messages may be aged out.

// Intentionally a no-op: the body is empty and `errorOrNull` is ignored.
override def close(errorOrNull: Throwable): Unit = {}

test("stress test for failOnDataLoss=false") {
val reader = spark
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("", "1")
.option("", "3000")
.option("subscribePattern", "failOnDataLoss.*")
.option("startingOffsets", "earliest")
.option("failOnDataLoss", "false")
.option("fetchOffset.retryIntervalMs", "3000")
val kafka = reader.load()
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
.as[(String, String)]
val query = startStream( => kv._2.toInt))

val testTime = 1.minutes
val startTime = System.currentTimeMillis()
// Track the current existing topics
val topics = mutable.ArrayBuffer[String]()
// Track topics that have been deleted
val deletedTopics = mutable.Set[String]()
while (System.currentTimeMillis() - testTime.toMillis < startTime) {
Random.nextInt(10) match {
case 0 => // Create a new topic
val topic = newTopic()
topics += topic
// As pushing messages into Kafka updates Zookeeper asynchronously, there is a small
// chance that a topic will be recreated after deletion due to the asynchronous update.
// Hence, always overwrite to handle this race condition.
testUtils.createTopic(topic, partitions = 1, overwrite = true)
logInfo(s"Create topic $topic")
case 1 if topics.nonEmpty => // Delete an existing topic
val topic = topics.remove(Random.nextInt(topics.size))
logInfo(s"Delete topic $topic")
deletedTopics += topic
case 2 if deletedTopics.nonEmpty => // Recreate a topic that was deleted.
val topic = deletedTopics.toSeq(Random.nextInt(deletedTopics.size))
deletedTopics -= topic
topics += topic
// As pushing messages into Kafka updates Zookeeper asynchronously, there is a small
// chance that a topic will be recreated after deletion due to the asynchronous update.
// Hence, always overwrite to handle this race condition.
testUtils.createTopic(topic, partitions = 1, overwrite = true)
logInfo(s"Create topic $topic")
case 3 =>
case _ => // Push random messages
for (topic <- topics) {
val size = Random.nextInt(10)
for (_ <- 0 until size) {
testUtils.sendMessages(topic, Array(Random.nextInt(10).toString))
// `failOnDataLoss` is `false`, we should not fail the query
if (query.exception.nonEmpty) {
throw query.exception.get

// `failOnDataLoss` is `false`, we should not fail the query
if (query.exception.nonEmpty) {
throw query.exception.get

0 comments on commit 8bb9414

Please sign in to comment.