[SPARK-28875][DSTREAMS][SS][TESTS] Add Task retry tests to make sure new consumer used

### What changes were proposed in this pull request?
When a task retry happens with the Kafka source, it is not known whether the consumer is the issue, so the old consumer is removed from the cache and a new consumer is created. The feature works fine but is not covered by tests.

In this PR I've added such tests for both DStreams and Structured Streaming.
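
For context, the behavior these tests exercise is the attempt-number handling reached through `KafkaDataConsumer.acquire`: on a retried task the cached consumer is treated as suspect, invalidated, and a fresh non-cached consumer is returned. The snippet below is a simplified, self-contained sketch of that flow, not the actual Spark code; `CacheKey` and `InternalKafkaConsumer` are reduced stand-ins, and the attempt number is passed explicitly instead of being read from `TaskContext`.

```scala
import java.{util => ju}

// Stand-ins for the real Spark/Kafka types, reduced to what the sketch needs.
case class CacheKey(groupId: String, topic: String, partition: Int)

class InternalKafkaConsumer(val key: CacheKey) {
  var inUse = false
  var markedForClose = false
  def close(): Unit = ()
}

object ConsumerCacheSketch {
  private val capacity = 64
  // LRU map keyed by (group id, topic, partition), mirroring the consumer cache.
  private val cache =
    new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer](capacity, 0.75f, true)

  // On a task retry (attemptNumber >= 1) the cached consumer may be the cause of
  // the failure, so it is closed (or marked for later close if still in use) and
  // a fresh, never-cached consumer is handed out.
  def acquire(key: CacheKey, attemptNumber: Int): InternalKafkaConsumer = synchronized {
    val existing = cache.get(key)
    if (attemptNumber >= 1) {
      if (existing != null) {
        if (existing.inUse) {
          existing.markedForClose = true
        } else {
          existing.close()
          cache.remove(key)
        }
      }
      new InternalKafkaConsumer(key)
    } else if (existing == null) {
      val created = new InternalKafkaConsumer(key)
      created.inUse = true
      cache.put(key, created)
      created
    } else {
      existing.inUse = true
      existing
    }
  }
}
```

The added tests in the diff below assert exactly this behavior: after an acquire under attempt number 0 the internal consumer sits in the cache, while an acquire under attempt number 1 empties the cache and yields a different internal consumer.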

### Why are the changes needed?
No such tests exist.

### Does this PR introduce any user-facing change?
No.

### How was this patch tested?
Existing + new unit tests.

Closes #25582 from gaborgsomogyi/SPARK-28875.

Authored-by: Gabor Somogyi <gabor.g.somogyi@gmail.com>
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>
gaborgsomogyi authored and Marcelo Vanzin committed Aug 26, 2019
1 parent 84d4f94 commit b205269
Showing 3 changed files with 71 additions and 21 deletions.
@@ -78,7 +78,7 @@ private[kafka010] sealed trait KafkaDataConsumer {
def release(): Unit

/** Reference to the internal implementation that this wrapper delegates to */
protected def internalConsumer: InternalKafkaConsumer
def internalConsumer: InternalKafkaConsumer
}


@@ -512,7 +512,7 @@ private[kafka010] object KafkaDataConsumer extends Logging {
override def release(): Unit = { internalConsumer.close() }
}

private case class CacheKey(groupId: String, topicPartition: TopicPartition) {
private[kafka010] case class CacheKey(groupId: String, topicPartition: TopicPartition) {
def this(topicPartition: TopicPartition, kafkaParams: ju.Map[String, Object]) =
this(kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String], topicPartition)
}
@@ -521,7 +521,7 @@ private[kafka010] object KafkaDataConsumer extends Logging {
// - We make a best-effort attempt to maintain the max size of the cache as configured capacity.
// The capacity is not guaranteed to be maintained, especially when there are more active
// tasks simultaneously using consumers than the capacity.
private lazy val cache = {
private[kafka010] lazy val cache = {
val conf = SparkEnv.get.conf
val capacity = conf.get(CONSUMER_CACHE_CAPACITY)
new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer](capacity, 0.75f, true) {
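Aside (not part of the commit): the cache built just above is a `java.util.LinkedHashMap` in access order, and, as the comment above it notes, the configured capacity is only best-effort. A minimal, self-contained sketch of that LRU pattern, using only the standard `removeEldestEntry` hook rather than the actual Spark eviction logic, looks like this:

```scala
import java.{util => ju}

// Best-effort LRU cache on top of java.util.LinkedHashMap with accessOrder = true.
class BestEffortLruCache[K, V](capacity: Int) {
  private val underlying = new ju.LinkedHashMap[K, V](capacity, 0.75f, true) {
    // LinkedHashMap calls this after every put; returning true evicts the
    // least-recently-used entry. A subclass may also decline the eviction
    // (for example, when the entry is still in use), which is why the
    // configured capacity is only a best-effort bound.
    override def removeEldestEntry(eldest: ju.Map.Entry[K, V]): Boolean =
      size() > capacity
  }

  def get(key: K): Option[V] = Option(underlying.get(key))
  def put(key: K, value: V): Unit = underlying.put(key, value)
  def size: Int = underlying.size()
}
```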
Expand Up @@ -20,22 +20,23 @@ package org.apache.spark.sql.kafka010
import java.util.concurrent.{Executors, TimeUnit}

import scala.collection.JavaConverters._
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.util.Random

import org.apache.kafka.clients.consumer.ConsumerConfig
import org.apache.kafka.clients.consumer.ConsumerConfig._
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.serialization.ByteArrayDeserializer
import org.scalatest.PrivateMethodTester

import org.apache.spark.{TaskContext, TaskContextImpl}
import org.apache.spark.sql.kafka010.KafkaDataConsumer.CacheKey
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.ThreadUtils

class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester {

protected var testUtils: KafkaTestUtils = _
private val topic = "topic" + Random.nextInt()
private val topicPartition = new TopicPartition(topic, 0)
private val groupId = "groupId"

override def beforeAll(): Unit = {
super.beforeAll()
@@ -51,6 +51,15 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
super.afterAll()
}

private def getKafkaParams() = Map[String, Object](
GROUP_ID_CONFIG -> "groupId",
BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress,
KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
AUTO_OFFSET_RESET_CONFIG -> "earliest",
ENABLE_AUTO_COMMIT_CONFIG -> "false"
).asJava

test("SPARK-19886: Report error cause correctly in reportDataLoss") {
val cause = new Exception("D'oh!")
val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0)
@@ -60,23 +70,40 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
assert(e.getCause === cause)
}

test("new KafkaDataConsumer instance in case of Task retry") {
try {
KafkaDataConsumer.cache.clear()

val kafkaParams = getKafkaParams()
val key = new CacheKey(groupId, topicPartition)

val context1 = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
TaskContext.setTaskContext(context1)
val consumer1 = KafkaDataConsumer.acquire(topicPartition, kafkaParams, true)
consumer1.release()

assert(KafkaDataConsumer.cache.size() == 1)
assert(KafkaDataConsumer.cache.get(key).eq(consumer1.internalConsumer))

val context2 = new TaskContextImpl(0, 0, 0, 0, 1, null, null, null)
TaskContext.setTaskContext(context2)
val consumer2 = KafkaDataConsumer.acquire(topicPartition, kafkaParams, true)
consumer2.release()

// The first consumer should be removed from the cache and a new non-cached one should be returned
assert(KafkaDataConsumer.cache.size() == 0)
assert(consumer1.internalConsumer.ne(consumer2.internalConsumer))
} finally {
TaskContext.unset()
}
}

test("SPARK-23623: concurrent use of KafkaDataConsumer") {
val topic = "topic" + Random.nextInt()
val data = (1 to 1000).map(_.toString)
testUtils.createTopic(topic, 1)
testUtils.sendMessages(topic, data.toArray)
val topicPartition = new TopicPartition(topic, 0)

import ConsumerConfig._
val kafkaParams = Map[String, Object](
GROUP_ID_CONFIG -> "groupId",
BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress,
KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
AUTO_OFFSET_RESET_CONFIG -> "earliest",
ENABLE_AUTO_COMMIT_CONFIG -> "false"
)

val kafkaParams = getKafkaParams()
val numThreads = 100
val numConsumerUsages = 500

@@ -90,8 +117,7 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
null
}
TaskContext.setTaskContext(taskContext)
val consumer = KafkaDataConsumer.acquire(
topicPartition, kafkaParams.asJava, useCache)
val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache)
try {
val range = consumer.getAvailableOffsetRange()
val rcvd = range.earliest until range.latest map { offset =>
@@ -87,6 +87,30 @@ class KafkaDataConsumerSuite extends SparkFunSuite with MockitoSugar with Before
assert(existingInternalConsumer.eq(consumer2.internalConsumer))
}

test("new KafkaDataConsumer instance in case of Task retry") {
KafkaDataConsumer.cache.clear()

val kafkaParams = getKafkaParams()
val key = new CacheKey(groupId, topicPartition)

val context1 = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
val consumer1 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
topicPartition, kafkaParams, context1, true)
consumer1.release()

assert(KafkaDataConsumer.cache.size() == 1)
assert(KafkaDataConsumer.cache.get(key).eq(consumer1.internalConsumer))

val context2 = new TaskContextImpl(0, 0, 0, 0, 1, null, null, null)
val consumer2 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
topicPartition, kafkaParams, context2, true)
consumer2.release()

// The first consumer should be removed from the cache and a new non-cached one should be returned
assert(KafkaDataConsumer.cache.size() == 0)
assert(consumer1.internalConsumer.ne(consumer2.internalConsumer))
}

test("concurrent use of KafkaDataConsumer") {
val data = (1 to 1000).map(_.toString)
testUtils.createTopic(topic)
