diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
index cbb99fd7118e0..1190af30d78e5 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.kafka010
 
 import java.{util => ju}
+import java.time.Duration
 import java.util.concurrent.TimeoutException
 
 import scala.collection.JavaConverters._
@@ -471,7 +472,7 @@ private[kafka010] case class InternalKafkaConsumer(
   private def fetchData(offset: Long, pollTimeoutMs: Long): Unit = {
     // Seek to the offset because we may call seekToBeginning or seekToEnd before this.
     seek(offset)
-    val p = consumer.poll(pollTimeoutMs)
+    val p = consumer.poll(Duration.ofMillis(pollTimeoutMs))
     val r = p.records(topicPartition)
     logDebug(s"Polled $groupId ${p.partitions()} ${r.size}")
     val offsetAfterPoll = consumer.position(topicPartition)
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
index ad608ecafe59f..147c427cef8da 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.kafka010
 
+import java.{time => jt}
 import java.{util => ju}
 import java.util.concurrent.Executors
 
@@ -29,7 +30,9 @@ import scala.util.control.NonFatal
 import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer}
 import org.apache.kafka.common.TopicPartition
 
+import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
 import org.apache.spark.sql.types._
 import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}
 
@@ -49,6 +52,16 @@ private[kafka010] class KafkaOffsetReader(
     val driverKafkaParams: ju.Map[String, Object],
     readerOptions: Map[String, String],
     driverGroupIdPrefix: String) extends Logging {
+  /**
+   * Upper bound, in milliseconds, on how long `getPartitions` waits for a partition assignment.
+   * Read from the `KafkaSourceProvider.CONSUMER_POLL_TIMEOUT` reader option when set; otherwise
+   * derived from Spark's `NETWORK_TIMEOUT` conf, scaled by 1000 to get milliseconds.
+   */
+  private val pollTimeoutMs = readerOptions.getOrElse(
+    KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
+    (SparkEnv.get.conf.get(NETWORK_TIMEOUT) * 1000L).toString
+  ).toLong
+
   /**
    * Used to ensure execute fetch operations execute in an UninterruptibleThread
    */
@@ -115,9 +128,7 @@ private[kafka010] class KafkaOffsetReader(
    */
   def fetchTopicPartitions(): Set[TopicPartition] = runUninterruptibly {
     assert(Thread.currentThread().isInstanceOf[UninterruptibleThread])
-    // Poll to get the latest assigned partitions
-    consumer.poll(0)
-    val partitions = consumer.assignment()
+    val partitions = getPartitions()
     consumer.pause(partitions)
     partitions.asScala.toSet
   }
@@ -163,9 +174,7 @@ private[kafka010] class KafkaOffsetReader(
       reportDataLoss: String => Unit): KafkaSourceOffset = {
     val fetched = runUninterruptibly {
       withRetriesWithoutInterrupt {
-        // Poll to get the latest assigned partitions
-        consumer.poll(0)
-        val partitions = consumer.assignment()
+        val partitions = getPartitions()
 
         // Call `position` to wait until the potential offset request triggered by `poll(0)` is
         // done. This is a workaround for KAFKA-7703, which an async `seekToBeginning` triggered by
@@ -177,7 +186,7 @@ private[kafka010] class KafkaOffsetReader(
           "If startingOffsets contains specific offsets, you must specify all TopicPartitions.\n" +
             "Use -1 for latest, -2 for earliest, if you don't care.\n" +
             s"Specified: ${partitionOffsets.keySet} Assigned: ${partitions.asScala}")
-        logDebug(s"Partitions assigned to consumer: $partitions. Seeking to $partitionOffsets")
+        logDebug(s"Seeking to $partitionOffsets")
 
         partitionOffsets.foreach {
           case (tp, KafkaOffsetRangeLimit.LATEST) =>
@@ -211,11 +220,9 @@ private[kafka010] class KafkaOffsetReader(
    */
   def fetchEarliestOffsets(): Map[TopicPartition, Long] = runUninterruptibly {
     withRetriesWithoutInterrupt {
-      // Poll to get the latest assigned partitions
-      consumer.poll(0)
-      val partitions = consumer.assignment()
+      val partitions = getPartitions()
       consumer.pause(partitions)
-      logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the beginning")
+      logDebug("Seeking to the beginning")
       consumer.seekToBeginning(partitions)
 
       val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap
@@ -241,9 +248,7 @@ private[kafka010] class KafkaOffsetReader(
   def fetchLatestOffsets(
       knownOffsets: Option[PartitionOffsetMap]): PartitionOffsetMap = runUninterruptibly {
     withRetriesWithoutInterrupt {
-      // Poll to get the latest assigned partitions
-      consumer.poll(0)
-      val partitions = consumer.assignment()
+      val partitions = getPartitions()
 
       // Call `position` to wait until the potential offset request triggered by `poll(0)` is
       // done. This is a workaround for KAFKA-7703, which an async `seekToBeginning` triggered by
@@ -251,7 +256,7 @@ private[kafka010] class KafkaOffsetReader(
       partitions.asScala.map(p => p -> consumer.position(p)).foreach(_ => {})
 
       consumer.pause(partitions)
-      logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the end.")
+      logDebug("Seeking to the end.")
 
       if (knownOffsets.isEmpty) {
         consumer.seekToEnd(partitions)
@@ -317,11 +322,8 @@ private[kafka010] class KafkaOffsetReader(
     } else {
       runUninterruptibly {
         withRetriesWithoutInterrupt {
-          // Poll to get the latest assigned partitions
-          consumer.poll(0)
-          val partitions = consumer.assignment()
+          val partitions = getPartitions()
           consumer.pause(partitions)
-          logDebug(s"\tPartitions assigned to consumer: $partitions")
 
           // Get the earliest offset of each partition
           consumer.seekToBeginning(partitions)
@@ -419,6 +421,26 @@ private[kafka010] class KafkaOffsetReader(
     stopConsumer()
     _consumer = null // will automatically get reinitialized again
   }
+
+  /**
+   * Polls the consumer until it reports a non-empty partition assignment or `pollTimeoutMs`
+   * has elapsed, then returns that assignment.
+   *
+   * @throws IllegalArgumentException if no partition is assigned when polling stops
+   */
+  private def getPartitions(): ju.Set[TopicPartition] = {
+    var partitions = Set.empty[TopicPartition].asJava
+    val startTimeMs = System.currentTimeMillis()
+    while (partitions.isEmpty && System.currentTimeMillis() - startTimeMs < pollTimeoutMs) {
+      // Poll to get the latest assigned partitions
+      consumer.poll(jt.Duration.ZERO)
+      partitions = consumer.assignment()
+    }
+    require(!partitions.isEmpty,
+      s"No partitions were assigned to the consumer within $pollTimeoutMs ms")
+    logDebug(s"Partitions assigned to consumer: $partitions")
+    partitions
+  }
 }
 
 private[kafka010] object KafkaOffsetReader {
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceProviderSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceProviderSuite.scala
index 2fcf37a184684..9162c1d92b881 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceProviderSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceProviderSuite.scala
@@ -33,6 +33,15 @@ class KafkaSourceProviderSuite extends SparkFunSuite with PrivateMethodTester {
   private val pollTimeoutMsMethod = PrivateMethod[Long]('pollTimeoutMs)
   private val maxOffsetsPerTriggerMethod = PrivateMethod[Option[Long]]('maxOffsetsPerTrigger)
 
+  override protected def beforeEach(): Unit = {
+    super.beforeEach()
+    // The Kafka source classes exercised here (e.g. KafkaMicroBatchStream) read default values
+    // from SparkEnv's conf, so a mock SparkEnv must be installed before each test.
+    val sparkEnv = mock(classOf[SparkEnv])
+    when(sparkEnv.conf).thenReturn(new SparkConf())
+    SparkEnv.set(sparkEnv)
+  }
+
   override protected def afterEach(): Unit = {
     SparkEnv.set(null)
     super.afterEach()
@@ -43,11 +52,6 @@ class KafkaSourceProviderSuite extends SparkFunSuite with PrivateMethodTester {
       options: CaseInsensitiveStringMap,
       expectedPollTimeoutMs: Long,
       expectedMaxOffsetsPerTrigger: Option[Long]): Unit = {
-    // KafkaMicroBatchStream reads Spark conf from SparkEnv for default value
-    // hence we set mock SparkEnv here before creating KafkaMicroBatchStream
-    val sparkEnv = mock(classOf[SparkEnv])
-    when(sparkEnv.conf).thenReturn(new SparkConf())
-    SparkEnv.set(sparkEnv)
 
     val scan = getKafkaDataSourceScan(options)
     val stream = scan.toMicroBatchStream("dummy").asInstanceOf[KafkaMicroBatchStream]
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
index f2e4ee71450e6..b30463b161c87 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.kafka010
 import java.io.{File, IOException}
 import java.lang.{Integer => JInt}
 import java.net.InetSocketAddress
-import java.util.{Collections, Map => JMap, Properties, UUID}
+import java.time.Duration
+import java.util.{Collections, Map => JMap, Properties, Set => JSet, UUID}
 import java.util.concurrent.TimeUnit
 
 import scala.collection.JavaConverters._
@@ -286,8 +287,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L
     val kc = new KafkaConsumer[String, String](consumerConfiguration)
     logInfo("Created consumer to get earliest offsets")
     kc.subscribe(topics.asJavaCollection)
-    kc.poll(0)
-    val partitions = kc.assignment()
+    val partitions = getPartitions(kc)
     kc.pause(partitions)
     kc.seekToBeginning(partitions)
     val offsets = partitions.asScala.map(p => p -> kc.position(p)).toMap
@@ -300,8 +300,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L
     val kc = new KafkaConsumer[String, String](consumerConfiguration)
     logInfo("Created consumer to get latest offsets")
     kc.subscribe(topics.asJavaCollection)
-    kc.poll(0)
-    val partitions = kc.assignment()
+    val partitions = getPartitions(kc)
     kc.pause(partitions)
     kc.seekToEnd(partitions)
     val offsets = partitions.asScala.map(p => p -> kc.position(p)).toMap
@@ -310,6 +309,27 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L
     offsets
   }
 
+  /**
+   * Polls `consumer` until it reports a non-empty partition assignment or one minute has
+   * elapsed, then returns that assignment.
+   *
+   * @throws IllegalArgumentException if no partition is assigned when polling stops
+   */
+  private def getPartitions(consumer: KafkaConsumer[String, String]): JSet[TopicPartition] = {
+    var partitions = Set.empty[TopicPartition].asJava
+    val startTimeMs = System.currentTimeMillis()
+    val timeoutMs = timeout(1.minute).value.toMillis
+    while (partitions.isEmpty && System.currentTimeMillis() - startTimeMs < timeoutMs) {
+      // Poll to get the latest assigned partitions
+      consumer.poll(Duration.ZERO)
+      partitions = consumer.assignment()
+    }
+    require(!partitions.isEmpty,
+      s"No partitions were assigned to the consumer within $timeoutMs ms")
+    logDebug(s"Partitions assigned to consumer: $partitions")
+    partitions
+  }
+
   def listConsumerGroups(): ListConsumerGroupsResult = {
     adminClient.listConsumerGroups()
   }