[SPARK-28875][DSTREAMS][SS][TESTS] Add Task retry tests to make sure new consumer used
gaborgsomogyi committed Aug 26, 2019
1 parent c353a84 commit d1fa313
Showing 3 changed files with 67 additions and 21 deletions.
@@ -78,7 +78,7 @@ private[kafka010] sealed trait KafkaDataConsumer {
def release(): Unit

/** Reference to the internal implementation that this wrapper delegates to */
protected def internalConsumer: InternalKafkaConsumer
def internalConsumer: InternalKafkaConsumer
}
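
For context, the hunk above widens the visibility of internalConsumer so the new tests can compare wrapper delegates by identity. Below is a minimal, self-contained sketch of that wrapper shape, assuming only the release/internalConsumer members shown in this diff; the stub class and wrapper names are illustrative stand-ins, not the Spark source.

// Illustrative sketch; InternalConsumerStub and the wrapper names stand in
// for InternalKafkaConsumer and its cached/non-cached wrappers.
class InternalConsumerStub {
  var closed = false
  def close(): Unit = { closed = true }
}

sealed trait ConsumerWrapper {
  def release(): Unit
  // Exposed (not protected) so a test can assert on delegate identity.
  def internalConsumer: InternalConsumerStub
}

// A cached wrapper hands its delegate back to a shared pool on release,
// while a non-cached wrapper simply closes it.
case class CachedWrapper(internalConsumer: InternalConsumerStub) extends ConsumerWrapper {
  override def release(): Unit = { /* return the delegate to the cache */ }
}

case class NonCachedWrapper(internalConsumer: InternalConsumerStub) extends ConsumerWrapper {
  override def release(): Unit = { internalConsumer.close() }
}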


@@ -512,7 +512,7 @@ private[kafka010] object KafkaDataConsumer extends Logging {
override def release(): Unit = { internalConsumer.close() }
}

private case class CacheKey(groupId: String, topicPartition: TopicPartition) {
private[kafka010] case class CacheKey(groupId: String, topicPartition: TopicPartition) {
def this(topicPartition: TopicPartition, kafkaParams: ju.Map[String, Object]) =
this(kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String], topicPartition)
}
@@ -521,7 +521,7 @@ private[kafka010] object KafkaDataConsumer extends Logging {
// - We make a best-effort attempt to maintain the max size of the cache as configured capacity.
// The capacity is not guaranteed to be maintained, especially when there are more active
// tasks simultaneously using consumers than the capacity.
private lazy val cache = {
private[kafka010] lazy val cache = {
val conf = SparkEnv.get.conf
val capacity = conf.get(CONSUMER_CACHE_CAPACITY)
new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer](capacity, 0.75f, true) {
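
The cache above is an access-ordered java.util.LinkedHashMap, and the preceding comment promises only best-effort capacity. As a standalone illustration of that pattern, here is a sketch assuming a fixed capacity and an unconditional eviction rule; both are simplifications, since the real cache reads its size from CONSUMER_CACHE_CAPACITY and, per the comment, only maintains that capacity on a best-effort basis.

import java.{util => ju}

// Standalone illustration of an access-ordered LinkedHashMap acting as an
// LRU cache; the key/value types, the fixed capacity, and the unconditional
// eviction rule are simplifications for the sketch.
object LruCacheSketch {
  val capacity = 64

  // accessOrder = true (third constructor argument) makes iteration order
  // follow recency of access, so the eldest entry is the least recently used.
  val cache = new ju.LinkedHashMap[String, String](capacity, 0.75f, true) {
    override def removeEldestEntry(eldest: ju.Map.Entry[String, String]): Boolean = {
      // Evict once over capacity; called by LinkedHashMap after each put.
      size() > capacity
    }
  }

  def main(args: Array[String]): Unit = {
    (1 to 100).foreach(i => cache.put(s"k-$i", s"v-$i"))
    println(cache.size()) // 64: least recently used entries were evicted
  }
}

Access ordering matters here because an insertion-ordered map would evict the oldest consumer even when it is the one being reused most often.
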
@@ -20,22 +20,23 @@ package org.apache.spark.sql.kafka010
import java.util.concurrent.{Executors, TimeUnit}

import scala.collection.JavaConverters._
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.util.Random

import org.apache.kafka.clients.consumer.ConsumerConfig
import org.apache.kafka.clients.consumer.ConsumerConfig._
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.serialization.ByteArrayDeserializer
import org.scalatest.PrivateMethodTester

import org.apache.spark.{TaskContext, TaskContextImpl}
import org.apache.spark.sql.kafka010.KafkaDataConsumer.CacheKey
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.ThreadUtils

class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester {

protected var testUtils: KafkaTestUtils = _
private val topic = "topic" + Random.nextInt()
private val topicPartition = new TopicPartition(topic, 0)
private val groupId = "groupId"

override def beforeAll(): Unit = {
super.beforeAll()
@@ -51,6 +52,15 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
super.afterAll()
}

private def getKafkaParams() = Map[String, Object](
GROUP_ID_CONFIG -> "groupId",
BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress,
KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
AUTO_OFFSET_RESET_CONFIG -> "earliest",
ENABLE_AUTO_COMMIT_CONFIG -> "false"
).asJava

test("SPARK-19886: Report error cause correctly in reportDataLoss") {
val cause = new Exception("D'oh!")
val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0)
@@ -60,23 +70,36 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
assert(e.getCause === cause)
}

test("new KafkaDataConsumer instance in case of Task retry") {
KafkaDataConsumer.cache.clear()

val kafkaParams = getKafkaParams()
val key = new CacheKey(groupId, topicPartition)

val context1 = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
TaskContext.setTaskContext(context1)
val consumer1 = KafkaDataConsumer.acquire(topicPartition, kafkaParams, true)
consumer1.release()

assert(KafkaDataConsumer.cache.size() == 1)
assert(KafkaDataConsumer.cache.get(key).eq(consumer1.internalConsumer))

val context2 = new TaskContextImpl(0, 0, 0, 0, 1, null, null, null)
TaskContext.setTaskContext(context2)
val consumer2 = KafkaDataConsumer.acquire(topicPartition, kafkaParams, true)
consumer2.release()

// The first consumer should be removed from the cache and a new, non-cached one should be returned
assert(KafkaDataConsumer.cache.size() == 0)
assert(consumer1.internalConsumer.ne(consumer2.internalConsumer))
}
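
The assertions above rely on acquire treating a task re-attempt specially: a consumer cached by a failed attempt may be in an unknown state, so the retry gets a fresh, non-cached consumer and the stale entry is dropped. A simplified, self-contained sketch of that decision logic, assuming the attempt number is the trigger (the cache type, key type, and helper names are illustrative, not the Spark implementation):

import java.{util => ju}

// Illustrative retry-aware acquire; not the Spark source.
object RetryAwareAcquireSketch {
  final case class Consumer(id: Long) { def close(): Unit = () }

  private val cache = new ju.HashMap[String, Consumer]()
  private var nextId = 0L
  private def newConsumer(): Consumer = { nextId += 1; Consumer(nextId) }

  def acquire(key: String, attemptNumber: Int, useCache: Boolean): Consumer = synchronized {
    if (attemptNumber >= 1) {
      // Task retry: drop and close whatever the failed attempt left behind,
      // then hand out a fresh consumer that is not put into the cache.
      Option(cache.remove(key)).foreach(_.close())
      newConsumer()
    } else if (useCache) {
      // First attempt: reuse the cached consumer or create and cache one.
      Option(cache.get(key)).getOrElse {
        val created = newConsumer()
        cache.put(key, created)
        created
      }
    } else {
      newConsumer()
    }
  }
}

With this shape, the first acquire leaves one cached entry behind (the first pair of assertions), and the retried acquire empties the cache and returns a different instance (the second pair).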

test("SPARK-23623: concurrent use of KafkaDataConsumer") {
val topic = "topic" + Random.nextInt()
val data = (1 to 1000).map(_.toString)
testUtils.createTopic(topic, 1)
testUtils.sendMessages(topic, data.toArray)
val topicPartition = new TopicPartition(topic, 0)

import ConsumerConfig._
val kafkaParams = Map[String, Object](
GROUP_ID_CONFIG -> "groupId",
BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress,
KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
AUTO_OFFSET_RESET_CONFIG -> "earliest",
ENABLE_AUTO_COMMIT_CONFIG -> "false"
)

val kafkaParams = getKafkaParams()
val numThreads = 100
val numConsumerUsages = 500

@@ -90,8 +113,7 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
null
}
TaskContext.setTaskContext(taskContext)
val consumer = KafkaDataConsumer.acquire(
topicPartition, kafkaParams.asJava, useCache)
val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache)
try {
val range = consumer.getAvailableOffsetRange()
val rcvd = range.earliest until range.latest map { offset =>
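
The concurrent-use test above (its body is truncated in this hunk) exercises acquire, reads the available offset range, and releases, from many threads at once. The general harness pattern it builds on is a fixed thread pool wrapped in an ExecutionContext, one Future per usage, and ThreadUtils.awaitResult on the combined Future so any per-thread failure fails the test. A sketch of that harness, with doWork standing in for the acquire/consume/release body:

import java.util.concurrent.Executors

import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration

import org.apache.spark.util.ThreadUtils

// Illustrative harness only; doWork stands in for the body of the real test.
object ConcurrencyHarnessSketch {
  def doWork(i: Int): Unit = {
    // e.g. acquire a consumer, read a few records, then release it
  }

  def main(args: Array[String]): Unit = {
    val numThreads = 100
    val numUsages = 500
    val threadpool = Executors.newFixedThreadPool(numThreads)
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(threadpool)
    try {
      val futures = (1 to numUsages).map(i => Future { doWork(i) })
      // Propagates the first failure from any thread, like the real test.
      ThreadUtils.awaitResult(Future.sequence(futures), Duration.Inf)
    } finally {
      threadpool.shutdown()
    }
  }
}
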
@@ -87,6 +87,30 @@ class KafkaDataConsumerSuite extends SparkFunSuite with MockitoSugar with Before
assert(existingInternalConsumer.eq(consumer2.internalConsumer))
}

test("new KafkaDataConsumer instance in case of Task retry") {
KafkaDataConsumer.cache.clear()

val kafkaParams = getKafkaParams()
val key = new CacheKey(groupId, topicPartition)

val context1 = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
val consumer1 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
topicPartition, kafkaParams, context1, true)
consumer1.release()

assert(KafkaDataConsumer.cache.size() == 1)
assert(KafkaDataConsumer.cache.get(key).eq(consumer1.internalConsumer))

val context2 = new TaskContextImpl(0, 0, 0, 0, 1, null, null, null)
val consumer2 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
topicPartition, kafkaParams, context2, true)
consumer2.release()

// The first consumer should be removed from the cache and a new, non-cached one should be returned
assert(KafkaDataConsumer.cache.size() == 0)
assert(consumer1.internalConsumer.ne(consumer2.internalConsumer))
}
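
One note on the bare constructor arguments used in both retry tests: assuming the TaskContextImpl parameter order begins with stageId, stageAttemptNumber, partitionId, taskAttemptId and attemptNumber, the only value that differs between the two contexts is the fifth one, the attempt number, which is what marks the second acquisition as a task retry:

import org.apache.spark.TaskContextImpl

// Assumed parameter order: stageId, stageAttemptNumber, partitionId,
// taskAttemptId, attemptNumber, taskMemoryManager, localProperties, metricsSystem.
val firstAttempt = new TaskContextImpl(0, 0, 0, 0, /* attemptNumber = */ 0, null, null, null)
val retryAttempt = new TaskContextImpl(0, 0, 0, 0, /* attemptNumber = */ 1, null, null, null)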

test("concurrent use of KafkaDataConsumer") {
val data = (1 to 1000).map(_.toString)
testUtils.createTopic(topic)
