diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
index cbb99fd7118e0..af240dc04eea8 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
@@ -78,7 +78,7 @@ private[kafka010] sealed trait KafkaDataConsumer {
   def release(): Unit
 
   /** Reference to the internal implementation that this wrapper delegates to */
-  protected def internalConsumer: InternalKafkaConsumer
+  def internalConsumer: InternalKafkaConsumer
 }
 
 
@@ -512,7 +512,7 @@ private[kafka010] object KafkaDataConsumer extends Logging {
     override def release(): Unit = { internalConsumer.close() }
   }
 
-  private case class CacheKey(groupId: String, topicPartition: TopicPartition) {
+  private[kafka010] case class CacheKey(groupId: String, topicPartition: TopicPartition) {
     def this(topicPartition: TopicPartition, kafkaParams: ju.Map[String, Object]) =
       this(kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String], topicPartition)
   }
@@ -521,7 +521,7 @@ private[kafka010] object KafkaDataConsumer extends Logging {
   // - We make a best-effort attempt to maintain the max size of the cache as configured capacity.
   //   The capacity is not guaranteed to be maintained, especially when there are more active
   //   tasks simultaneously using consumers than the capacity.
-  private lazy val cache = {
+  private[kafka010] lazy val cache = {
     val conf = SparkEnv.get.conf
     val capacity = conf.get(CONSUMER_CACHE_CAPACITY)
     new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer](capacity, 0.75f, true) {
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala
index 2aa869c02bc5d..8aa7e06e772a1 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala
@@ -20,22 +20,23 @@ package org.apache.spark.sql.kafka010
 import java.util.concurrent.{Executors, TimeUnit}
 
 import scala.collection.JavaConverters._
-import scala.concurrent.{ExecutionContext, Future}
-import scala.concurrent.duration.Duration
 import scala.util.Random
 
-import org.apache.kafka.clients.consumer.ConsumerConfig
+import org.apache.kafka.clients.consumer.ConsumerConfig._
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.serialization.ByteArrayDeserializer
 import org.scalatest.PrivateMethodTester
 
 import org.apache.spark.{TaskContext, TaskContextImpl}
+import org.apache.spark.sql.kafka010.KafkaDataConsumer.CacheKey
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.util.ThreadUtils
 
 class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester {
 
   protected var testUtils: KafkaTestUtils = _
+  private val topic = "topic" + Random.nextInt()
+  private val topicPartition = new TopicPartition(topic, 0)
+  private val groupId = "groupId"
 
   override def beforeAll(): Unit = {
     super.beforeAll()
@@ -51,6 +52,15 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
     super.afterAll()
   }
 
+  private def getKafkaParams() = Map[String, Object](
+    GROUP_ID_CONFIG -> "groupId",
+    BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress,
+    KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
+    VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
+    AUTO_OFFSET_RESET_CONFIG -> "earliest",
+    ENABLE_AUTO_COMMIT_CONFIG -> "false"
+  ).asJava
+
   test("SPARK-19886: Report error cause correctly in reportDataLoss") {
     val cause = new Exception("D'oh!")
     val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0)
@@ -60,23 +70,40 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
     assert(e.getCause === cause)
   }
 
+  test("new KafkaDataConsumer instance in case of Task retry") {
+    try {
+      KafkaDataConsumer.cache.clear()
+
+      val kafkaParams = getKafkaParams()
+      val key = new CacheKey(groupId, topicPartition)
+
+      val context1 = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
+      TaskContext.setTaskContext(context1)
+      val consumer1 = KafkaDataConsumer.acquire(topicPartition, kafkaParams, true)
+      consumer1.release()
+
+      assert(KafkaDataConsumer.cache.size() == 1)
+      assert(KafkaDataConsumer.cache.get(key).eq(consumer1.internalConsumer))
+
+      val context2 = new TaskContextImpl(0, 0, 0, 0, 1, null, null, null)
+      TaskContext.setTaskContext(context2)
+      val consumer2 = KafkaDataConsumer.acquire(topicPartition, kafkaParams, true)
+      consumer2.release()
+
+      // The first consumer should be removed from cache and new non-cached should be returned
+      assert(KafkaDataConsumer.cache.size() == 0)
+      assert(consumer1.internalConsumer.ne(consumer2.internalConsumer))
+    } finally {
+      TaskContext.unset()
+    }
+  }
+
   test("SPARK-23623: concurrent use of KafkaDataConsumer") {
-    val topic = "topic" + Random.nextInt()
     val data = (1 to 1000).map(_.toString)
     testUtils.createTopic(topic, 1)
     testUtils.sendMessages(topic, data.toArray)
 
-    val topicPartition = new TopicPartition(topic, 0)
-
-    import ConsumerConfig._
-    val kafkaParams = Map[String, Object](
-      GROUP_ID_CONFIG -> "groupId",
-      BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress,
-      KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
-      VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
-      AUTO_OFFSET_RESET_CONFIG -> "earliest",
-      ENABLE_AUTO_COMMIT_CONFIG -> "false"
-    )
+    val kafkaParams = getKafkaParams()
 
     val numThreads = 100
     val numConsumerUsages = 500
@@ -90,8 +117,7 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
           null
         }
         TaskContext.setTaskContext(taskContext)
-        val consumer = KafkaDataConsumer.acquire(
-          topicPartition, kafkaParams.asJava, useCache)
+        val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache)
         try {
           val range = consumer.getAvailableOffsetRange()
           val rcvd = range.earliest until range.latest map { offset =>
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
index d8df5496f612d..431473e7f1d38 100644
--- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
@@ -87,6 +87,30 @@ class KafkaDataConsumerSuite extends SparkFunSuite with MockitoSugar with Before
     assert(existingInternalConsumer.eq(consumer2.internalConsumer))
   }
 
+  test("new KafkaDataConsumer instance in case of Task retry") {
+    KafkaDataConsumer.cache.clear()
+
+    val kafkaParams = getKafkaParams()
+    val key = new CacheKey(groupId, topicPartition)
+
+    val context1 = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
+    val consumer1 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
+      topicPartition, kafkaParams, context1, true)
+    consumer1.release()
+
+    assert(KafkaDataConsumer.cache.size() == 1)
+    assert(KafkaDataConsumer.cache.get(key).eq(consumer1.internalConsumer))
+
+    val context2 = new TaskContextImpl(0, 0, 0, 0, 1, null, null, null)
+    val consumer2 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
+      topicPartition, kafkaParams, context2, true)
+    consumer2.release()
+
+    // The first consumer should be removed from cache and new non-cached should be returned
+    assert(KafkaDataConsumer.cache.size() == 0)
+    assert(consumer1.internalConsumer.ne(consumer2.internalConsumer))
+  }
+
   test("concurrent use of KafkaDataConsumer") {
     val data = (1 to 1000).map(_.toString)
     testUtils.createTopic(topic)