[SPARK-28367][SS] Use new KafkaConsumer.poll API in Kafka connector #25135

Closed · wants to merge 6 commits
@@ -18,6 +18,7 @@
package org.apache.spark.sql.kafka010

import java.{util => ju}
import java.time.Duration
import java.util.concurrent.TimeoutException

import scala.collection.JavaConverters._
@@ -471,7 +472,7 @@ private[kafka010] case class InternalKafkaConsumer(
private def fetchData(offset: Long, pollTimeoutMs: Long): Unit = {
// Seek to the offset because we may call seekToBeginning or seekToEnd before this.
seek(offset)
val p = consumer.poll(pollTimeoutMs)
val p = consumer.poll(Duration.ofMillis(pollTimeoutMs))
val r = p.records(topicPartition)
logDebug(s"Polled $groupId ${p.partitions()} ${r.size}")
val offsetAfterPoll = consumer.position(topicPartition)
@@ -17,6 +17,7 @@

package org.apache.spark.sql.kafka010

import java.{time => jt}
import java.{util => ju}
import java.util.concurrent.Executors

@@ -29,7 +30,9 @@ import scala.util.control.NonFatal
import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer}
import org.apache.kafka.common.TopicPartition

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
import org.apache.spark.sql.types._
import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}

@@ -49,6 +52,11 @@ private[kafka010] class KafkaOffsetReader(
val driverKafkaParams: ju.Map[String, Object],
readerOptions: Map[String, String],
driverGroupIdPrefix: String) extends Logging {
private val pollTimeoutMs = readerOptions.getOrElse(
KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
(SparkEnv.get.conf.get(NETWORK_TIMEOUT) * 1000L).toString
).toLong

/**
* Used to ensure execute fetch operations execute in an UninterruptibleThread
*/
@@ -115,9 +123,7 @@ private[kafka010] class KafkaOffsetReader(
*/
def fetchTopicPartitions(): Set[TopicPartition] = runUninterruptibly {
assert(Thread.currentThread().isInstanceOf[UninterruptibleThread])
// Poll to get the latest assigned partitions
consumer.poll(0)
val partitions = consumer.assignment()
val partitions = getPartitions()
consumer.pause(partitions)
partitions.asScala.toSet
}
@@ -163,9 +169,7 @@ private[kafka010] class KafkaOffsetReader(
reportDataLoss: String => Unit): KafkaSourceOffset = {
val fetched = runUninterruptibly {
withRetriesWithoutInterrupt {
// Poll to get the latest assigned partitions
consumer.poll(0)
val partitions = consumer.assignment()
val partitions = getPartitions()

// Call `position` to wait until the potential offset request triggered by `poll(0)` is
// done. This is a workaround for KAFKA-7703, which an async `seekToBeginning` triggered by
@@ -177,7 +181,7 @@
"If startingOffsets contains specific offsets, you must specify all TopicPartitions.\n" +
"Use -1 for latest, -2 for earliest, if you don't care.\n" +
s"Specified: ${partitionOffsets.keySet} Assigned: ${partitions.asScala}")
logDebug(s"Partitions assigned to consumer: $partitions. Seeking to $partitionOffsets")
logDebug(s"Seeking to $partitionOffsets")

partitionOffsets.foreach {
case (tp, KafkaOffsetRangeLimit.LATEST) =>
@@ -211,11 +215,9 @@
*/
def fetchEarliestOffsets(): Map[TopicPartition, Long] = runUninterruptibly {
withRetriesWithoutInterrupt {
// Poll to get the latest assigned partitions
consumer.poll(0)
val partitions = consumer.assignment()
val partitions = getPartitions()
consumer.pause(partitions)
logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the beginning")
logDebug(s"Seeking to the beginning")

consumer.seekToBeginning(partitions)
val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap
@@ -241,17 +243,15 @@
def fetchLatestOffsets(
knownOffsets: Option[PartitionOffsetMap]): PartitionOffsetMap = runUninterruptibly {
withRetriesWithoutInterrupt {
// Poll to get the latest assigned partitions
consumer.poll(0)
val partitions = consumer.assignment()
val partitions = getPartitions()

// Call `position` to wait until the potential offset request triggered by `poll(0)` is
// done. This is a workaround for KAFKA-7703, which an async `seekToBeginning` triggered by
// `poll(0)` may reset offsets that should have been set by another request.
partitions.asScala.map(p => p -> consumer.position(p)).foreach(_ => {})

consumer.pause(partitions)
logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the end.")
logDebug(s"Seeking to the end.")

if (knownOffsets.isEmpty) {
consumer.seekToEnd(partitions)
@@ -317,11 +317,8 @@ private[kafka010] class KafkaOffsetReader(
} else {
runUninterruptibly {
withRetriesWithoutInterrupt {
// Poll to get the latest assigned partitions
consumer.poll(0)
val partitions = consumer.assignment()
val partitions = getPartitions()
consumer.pause(partitions)
logDebug(s"\tPartitions assigned to consumer: $partitions")

// Get the earliest offset of each partition
consumer.seekToBeginning(partitions)
@@ -419,6 +416,21 @@ private[kafka010] class KafkaOffsetReader(
stopConsumer()
_consumer = null // will automatically get reinitialized again
}

private def getPartitions(): ju.Set[TopicPartition] = {
consumer.poll(jt.Duration.ZERO)
var partitions = consumer.assignment()
val startTimeMs = System.currentTimeMillis()
Contributor:

For this kind of logic it's better to use System.nanoTime(), which is monotonic. You can also do a little less computation this way:

val deadline = System.nanoTime() + someTimeout;
while (... && System.nanoTime() < deadline) {

}
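
For illustration, a minimal sketch of how the suggested monotonic deadline could be applied to an assignment-discovery loop like the one added in this PR. The helper name, its signature, and the 100 ms poll interval are assumptions for the sketch, not code from this change:

import java.time.Duration
import java.util.concurrent.TimeUnit

import org.apache.kafka.clients.consumer.Consumer
import org.apache.kafka.common.TopicPartition

// Hypothetical helper: wait until the consumer has an assignment or the deadline passes.
def waitForAssignment(
    consumer: Consumer[_, _],
    pollTimeoutMs: Long): java.util.Set[TopicPartition] = {
  // Compute the deadline once; System.nanoTime() is monotonic, so the loop is
  // unaffected by wall-clock adjustments.
  val deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(pollTimeoutMs)
  var partitions = consumer.assignment()
  while (partitions.isEmpty && System.nanoTime() < deadline) {
    // A short bounded poll just to drive group membership and assignment forward.
    consumer.poll(Duration.ofMillis(100))
    partitions = consumer.assignment()
  }
  partitions
}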

Contributor Author:

Good point. Since @zsxwing suggested the new API usage, I'll wait here and check on the Kafka side.

while (partitions.isEmpty && System.currentTimeMillis() - startTimeMs < pollTimeoutMs) {
// Poll to get the latest assigned partitions
consumer.poll(jt.Duration.ofMillis(100))
Member @zsxwing, Aug 6, 2019:

So using this new API will pull data to the driver, right? The previous poll(0) is basically a hack to avoid fetching data to the driver. Maybe we should ask the Kafka community to add a new API that pulls metadata only.

Contributor:

Good point. While the Kafka docs say the behavior of such a hack is nondeterministic and Kafka has never officially supported it, we rely on that behavior anyway.

I've started a thread to ask about viable alternatives to poll(0) and the possibility of adding a public API that updates metadata only:
https://lists.apache.org/thread.html/017cf631ef981ab1b494b1249be5c11d7edfe5f4867770a18188ebdc@%3Cdev.kafka.apache.org%3E

Contributor Author:

I'm aware of that, but since the doc says:

there is no guarantee that poll(0) won't return records the first time it's called

I considered the poll(0) usage a design decision and saw no problem if a small amount of data came back. Since you say it was never guaranteed but has always worked that way, the situation is different.

@HeartSaVioR thanks for initiating the discussion; let's see where it goes.

Contributor:

Yeah, you're right, that's true as well. One slight difference between the two is that here we provide a small timeout (not 0), and we don't know how much of the remaining timeout will be spent polling records instead of polling metadata. It is unlikely to be exactly 0, as the call would have timed out if it went below 0.

partitions = consumer.assignment()
}
require(!partitions.isEmpty, "Partitions assigned to the Kafka consumer can't be empty. " +
"Setting kafkaConsumer.pollTimeoutMs to a too low value can potentially cause this.")
logDebug(s"Partitions assigned to consumer: $partitions")
partitions
}
}
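
The thread above notes that, unlike the old poll(0) hack, a bounded poll can actually return records to the driver. A minimal sketch of one possible mitigation, assuming standard Kafka consumer configuration keys for the driver-side consumer; this merely limits, rather than eliminates, what a discovery poll can return, and it is not something this PR does:

import java.util.Properties

import org.apache.kafka.clients.consumer.ConsumerConfig

// Hypothetical driver-side settings to keep metadata-discovery polls cheap.
// The broker address is a placeholder.
val driverParams = new Properties()
driverParams.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "host:9092")
// Return at most one record per poll call.
driverParams.put(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, "1")
// Keep per-partition fetches small so little data is buffered on the driver.
driverParams.put(ConsumerConfig.MAX_PARTITION_FETCH_BYTES_CONFIG, "1024")
driverParams.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false")

A metadata-only API on the Kafka side, as raised in the linked dev-list thread, would make this kind of tuning unnecessary.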

private[kafka010] object KafkaOffsetReader {
@@ -114,7 +114,7 @@ class KafkaDontFailOnDataLossSuite extends StreamTest with KafkaMissingOffsetsTe
"subscribe" -> topic,
"startingOffsets" -> s"""{"$topic":{"0":0}}""",
"failOnDataLoss" -> "false",
"kafkaConsumer.pollTimeoutMs" -> "1000")
"kafkaConsumer.pollTimeoutMs" -> "5000")
val df =
if (testStreamingQuery) {
val reader = spark.readStream.format("kafka")
@@ -568,7 +568,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase {
// If a topic is deleted and we try to poll data starting from offset 0,
// the Kafka consumer will just block until timeout and return an empty result.
// So set the timeout to 1 second to make this test fast.
.option("kafkaConsumer.pollTimeoutMs", "1000")
.option("kafkaConsumer.pollTimeoutMs", "5000")
.option("startingOffsets", "earliest")
.option("failOnDataLoss", "false")
val kafka = reader.load()
@@ -33,6 +33,12 @@ class KafkaSourceProviderSuite extends SparkFunSuite with PrivateMethodTester {
private val pollTimeoutMsMethod = PrivateMethod[Long]('pollTimeoutMs)
private val maxOffsetsPerTriggerMethod = PrivateMethod[Option[Long]]('maxOffsetsPerTrigger)

override protected def beforeEach(): Unit = {
val sparkEnv = mock(classOf[SparkEnv])
when(sparkEnv.conf).thenReturn(new SparkConf())
SparkEnv.set(sparkEnv)
}

override protected def afterEach(): Unit = {
SparkEnv.set(null)
super.afterEach()
@@ -43,11 +49,6 @@ class KafkaSourceProviderSuite extends SparkFunSuite with PrivateMethodTester {
options: CaseInsensitiveStringMap,
expectedPollTimeoutMs: Long,
expectedMaxOffsetsPerTrigger: Option[Long]): Unit = {
// KafkaMicroBatchStream reads Spark conf from SparkEnv for default value
// hence we set mock SparkEnv here before creating KafkaMicroBatchStream
val sparkEnv = mock(classOf[SparkEnv])
when(sparkEnv.conf).thenReturn(new SparkConf())
SparkEnv.set(sparkEnv)

val scan = getKafkaDataSourceScan(options)
val stream = scan.toMicroBatchStream("dummy").asInstanceOf[KafkaMicroBatchStream]
@@ -32,7 +32,6 @@ import kafka.server.checkpoints.OffsetCheckpointFile
import kafka.utils.ZkUtils
import org.apache.kafka.clients.CommonClientConfigs
import org.apache.kafka.clients.admin.{AdminClient, CreatePartitionsOptions, ListConsumerGroupsResult, NewPartitions, NewTopic}
import org.apache.kafka.clients.consumer.KafkaConsumer
import org.apache.kafka.clients.producer._
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.network.ListenerName
@@ -43,6 +42,7 @@ import org.scalatest.time.SpanSugar._

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.kafka010.KafkaSourceProvider.kafkaParamsForDriver
import org.apache.spark.util.{ShutdownHookManager, Utils}

/**
@@ -283,31 +283,29 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L
}

def getEarliestOffsets(topics: Set[String]): Map[TopicPartition, Long] = {
val kc = new KafkaConsumer[String, String](consumerConfiguration)
logInfo("Created consumer to get earliest offsets")
kc.subscribe(topics.asJavaCollection)
kc.poll(0)
val partitions = kc.assignment()
kc.pause(partitions)
kc.seekToBeginning(partitions)
val offsets = partitions.asScala.map(p => p -> kc.position(p)).toMap
kc.close()
logInfo("Closed consumer to get earliest offsets")
offsets
val reader = getKafkaOffsetReader(topics)
try {
reader.fetchEarliestOffsets()
} finally {
reader.close()
}
}

def getLatestOffsets(topics: Set[String]): Map[TopicPartition, Long] = {
val kc = new KafkaConsumer[String, String](consumerConfiguration)
logInfo("Created consumer to get latest offsets")
kc.subscribe(topics.asJavaCollection)
kc.poll(0)
val partitions = kc.assignment()
kc.pause(partitions)
kc.seekToEnd(partitions)
val offsets = partitions.asScala.map(p => p -> kc.position(p)).toMap
kc.close()
logInfo("Closed consumer to get latest offsets")
offsets
val reader = getKafkaOffsetReader(topics)
try {
reader.fetchLatestOffsets(None)
} finally {
reader.close()
}
}

private def getKafkaOffsetReader(topics: Set[String]): KafkaOffsetReader = {
new KafkaOffsetReader(
SubscribeStrategy(topics.toSeq),
kafkaParamsForDriver(Map("bootstrap.servers" -> brokerAddress)),
Map.empty,
driverGroupIdPrefix = "group-KafkaTestUtils")
}

def listConsumerGroups(): ListConsumerGroupsResult = {
@@ -363,16 +361,6 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L
}
}

private def consumerConfiguration: Properties = {
val props = new Properties()
props.put("bootstrap.servers", brokerAddress)
props.put("group.id", "group-KafkaTestUtils-" + Random.nextInt)
props.put("value.deserializer", classOf[StringDeserializer].getName)
props.put("key.deserializer", classOf[StringDeserializer].getName)
props.put("enable.auto.commit", "false")
props
}

/** Verify topic is deleted in all places, e.g, brokers, zookeeper. */
private def verifyTopicDeletion(
topic: String,