[SPARK-28875][DSTREAMS][SS][TESTS] Add Task retry tests to make sure new consumer used #25582

Closed · wants to merge 2 commits
@@ -78,7 +78,7 @@ private[kafka010] sealed trait KafkaDataConsumer {
def release(): Unit

/** Reference to the internal implementation that this wrapper delegates to */
- protected def internalConsumer: InternalKafkaConsumer
+ def internalConsumer: InternalKafkaConsumer
Contributor: That's technically private[kafka010] at class scope, so it seems OK.
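
A minimal, hypothetical sketch of the point above: because the enclosing trait is already private[kafka010], a member without `protected` is still unreachable from outside that package. The names below are illustrative only, not Spark code.

```scala
// Members of a private[kafka010] trait are only reachable where the trait
// itself is visible, so dropping `protected` does not expose internalConsumer
// outside the org.apache.spark.sql.kafka010 package.
package org.apache.spark.sql.kafka010 {
  private[kafka010] sealed trait ExampleWrapper {
    def internalConsumer: String
  }
}

package org.apache.spark.sql.other {
  object Outside {
    // Would not compile if uncommented: ExampleWrapper (and therefore
    // internalConsumer) is inaccessible from outside the kafka010 package.
    // val w: org.apache.spark.sql.kafka010.ExampleWrapper = ???
  }
}
```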

}


@@ -512,7 +512,7 @@ private[kafka010] object KafkaDataConsumer extends Logging {
override def release(): Unit = { internalConsumer.close() }
}

- private case class CacheKey(groupId: String, topicPartition: TopicPartition) {
+ private[kafka010] case class CacheKey(groupId: String, topicPartition: TopicPartition) {
def this(topicPartition: TopicPartition, kafkaParams: ju.Map[String, Object]) =
this(kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String], topicPartition)
}
@@ -521,7 +521,7 @@ private[kafka010] object KafkaDataConsumer extends Logging {
// - We make a best-effort attempt to maintain the max size of the cache as configured capacity.
// The capacity is not guaranteed to be maintained, especially when there are more active
// tasks simultaneously using consumers than the capacity.
- private lazy val cache = {
+ private[kafka010] lazy val cache = {
val conf = SparkEnv.get.conf
val capacity = conf.get(CONSUMER_CACHE_CAPACITY)
new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer](capacity, 0.75f, true) {
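
The overridden body of this LinkedHashMap subclass is collapsed in the diff above. As a general illustration only (not the elided Spark code), an access-ordered LinkedHashMap is commonly turned into a bounded, LRU-style cache by overriding removeEldestEntry:

```scala
import java.{util => ju}

// Generic LRU-style cache sketch; illustrative only, not the elided Spark body.
object LruCacheSketch {
  def build[K, V](capacity: Int): ju.LinkedHashMap[K, V] =
    // accessOrder = true: iteration order follows most recent access,
    // so the "eldest" entry is the least recently used one.
    new ju.LinkedHashMap[K, V](capacity, 0.75f, true) {
      override def removeEldestEntry(eldest: ju.Map.Entry[K, V]): Boolean =
        size() > capacity // evict once the configured capacity is exceeded
    }
}
```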
@@ -20,22 +20,23 @@ package org.apache.spark.sql.kafka010
import java.util.concurrent.{Executors, TimeUnit}

import scala.collection.JavaConverters._
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.util.Random

import org.apache.kafka.clients.consumer.ConsumerConfig
import org.apache.kafka.clients.consumer.ConsumerConfig._
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.serialization.ByteArrayDeserializer
import org.scalatest.PrivateMethodTester

import org.apache.spark.{TaskContext, TaskContextImpl}
import org.apache.spark.sql.kafka010.KafkaDataConsumer.CacheKey
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.ThreadUtils

class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester {

protected var testUtils: KafkaTestUtils = _
private val topic = "topic" + Random.nextInt()
private val topicPartition = new TopicPartition(topic, 0)
private val groupId = "groupId"

override def beforeAll(): Unit = {
super.beforeAll()
@@ -51,6 +52,15 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
super.afterAll()
}

private def getKafkaParams() = Map[String, Object](
GROUP_ID_CONFIG -> "groupId",
BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress,
KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
AUTO_OFFSET_RESET_CONFIG -> "earliest",
ENABLE_AUTO_COMMIT_CONFIG -> "false"
).asJava

test("SPARK-19886: Report error cause correctly in reportDataLoss") {
val cause = new Exception("D'oh!")
val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0)
@@ -60,23 +70,40 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
assert(e.getCause === cause)
}

test("new KafkaDataConsumer instance in case of Task retry") {
try {
KafkaDataConsumer.cache.clear()

val kafkaParams = getKafkaParams()
val key = new CacheKey(groupId, topicPartition)

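// The fifth TaskContextImpl argument is the task attempt number; 0 simulates the task's first attempt.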
val context1 = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
TaskContext.setTaskContext(context1)
val consumer1 = KafkaDataConsumer.acquire(topicPartition, kafkaParams, true)
consumer1.release()

assert(KafkaDataConsumer.cache.size() == 1)
assert(KafkaDataConsumer.cache.get(key).eq(consumer1.internalConsumer))

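// Attempt number 1 simulates a retry of the same partition, which should invalidate the cached consumer.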
val context2 = new TaskContextImpl(0, 0, 0, 0, 1, null, null, null)
TaskContext.setTaskContext(context2)
val consumer2 = KafkaDataConsumer.acquire(topicPartition, kafkaParams, true)
consumer2.release()

// The first consumer should be removed from the cache and a new non-cached one should be returned

Contributor: I'd say consumer2 should be cached, as it's created after the invalidation, but here you only address the test, so that's OK.

assert(KafkaDataConsumer.cache.size() == 0)
assert(consumer1.internalConsumer.ne(consumer2.internalConsumer))
} finally {
TaskContext.unset()
}
}
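
For orientation, here is a minimal sketch of the retry check this test relies on, written against the public TaskContext API; it is simplified and illustrative, not the actual KafkaDataConsumer.acquire internals:

```scala
import org.apache.spark.TaskContext

// Illustrative only: a re-attempted task (attemptNumber >= 1) should not
// reuse a possibly-corrupted cached consumer, so the cache entry is dropped
// and a fresh, non-cached consumer is handed out instead.
object RetryDetectionSketch {
  def isTaskRetry: Boolean = {
    val ctx = TaskContext.get()
    ctx != null && ctx.attemptNumber() >= 1
  }
}
```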

test("SPARK-23623: concurrent use of KafkaDataConsumer") {
val topic = "topic" + Random.nextInt()
val data = (1 to 1000).map(_.toString)
testUtils.createTopic(topic, 1)
testUtils.sendMessages(topic, data.toArray)
val topicPartition = new TopicPartition(topic, 0)

- import ConsumerConfig._
- val kafkaParams = Map[String, Object](
-   GROUP_ID_CONFIG -> "groupId",
-   BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress,
-   KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
-   VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
-   AUTO_OFFSET_RESET_CONFIG -> "earliest",
-   ENABLE_AUTO_COMMIT_CONFIG -> "false"
- )

+ val kafkaParams = getKafkaParams()
val numThreads = 100
val numConsumerUsages = 500

@@ -90,8 +117,7 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
null
}
TaskContext.setTaskContext(taskContext)
- val consumer = KafkaDataConsumer.acquire(
-   topicPartition, kafkaParams.asJava, useCache)
+ val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache)
try {
val range = consumer.getAvailableOffsetRange()
val rcvd = range.earliest until range.latest map { offset =>
@@ -87,6 +87,30 @@ class KafkaDataConsumerSuite extends SparkFunSuite with MockitoSugar with Before
assert(existingInternalConsumer.eq(consumer2.internalConsumer))
}

test("new KafkaDataConsumer instance in case of Task retry") {
KafkaDataConsumer.cache.clear()

val kafkaParams = getKafkaParams()
val key = new CacheKey(groupId, topicPartition)

val context1 = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
val consumer1 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
topicPartition, kafkaParams, context1, true)
consumer1.release()

assert(KafkaDataConsumer.cache.size() == 1)
assert(KafkaDataConsumer.cache.get(key).eq(consumer1.internalConsumer))

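// Attempt number 1 simulates a retry; in this suite the TaskContext is passed to acquire explicitly instead of being set on the thread.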
val context2 = new TaskContextImpl(0, 0, 0, 0, 1, null, null, null)
val consumer2 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
topicPartition, kafkaParams, context2, true)
consumer2.release()

// The first consumer should be removed from the cache and a new non-cached one should be returned
assert(KafkaDataConsumer.cache.size() == 0)
assert(consumer1.internalConsumer.ne(consumer2.internalConsumer))
}

test("concurrent use of KafkaDataConsumer") {
val data = (1 to 1000).map(_.toString)
testUtils.createTopic(topic)