[SPARK-28367][SS] Use new KafkaConsumer.poll API in Kafka connector
gaborgsomogyi committed Jul 12, 2019
1 parent 19bcce1 commit ad1863a
Showing 4 changed files with 55 additions and 30 deletions.
@@ -18,6 +18,7 @@
package org.apache.spark.sql.kafka010

import java.{util => ju}
+import java.time.Duration
import java.util.concurrent.TimeoutException

import scala.collection.JavaConverters._
@@ -471,7 +472,7 @@ private[kafka010] case class InternalKafkaConsumer(
private def fetchData(offset: Long, pollTimeoutMs: Long): Unit = {
// Seek to the offset because we may call seekToBeginning or seekToEnd before this.
seek(offset)
-val p = consumer.poll(pollTimeoutMs)
+val p = consumer.poll(Duration.ofMillis(pollTimeoutMs))
val r = p.records(topicPartition)
logDebug(s"Polled $groupId ${p.partitions()} ${r.size}")
val offsetAfterPoll = consumer.position(topicPartition)
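Note on the API being adopted: KafkaConsumer.poll(long) has been deprecated since Kafka 2.0 (KIP-266) because it can block indefinitely while fetching metadata; poll(Duration) bounds the total time the call may block. A minimal standalone sketch of the migration, not part of this commit (broker, group, and topic names are illustrative):

import java.time.Duration
import java.util.{Collections, Properties}

import org.apache.kafka.clients.consumer.KafkaConsumer

object PollApiSketch {
  def main(args: Array[String]): Unit = {
    val props = new Properties()
    props.put("bootstrap.servers", "localhost:9092") // illustrative
    props.put("group.id", "poll-api-sketch")         // illustrative
    props.put("key.deserializer", "org.apache.kafka.common.serialization.StringDeserializer")
    props.put("value.deserializer", "org.apache.kafka.common.serialization.StringDeserializer")
    val consumer = new KafkaConsumer[String, String](props)
    consumer.subscribe(Collections.singletonList("some-topic")) // illustrative
    // Deprecated since Kafka 2.0: may block on metadata fetches beyond the given timeout.
    // val records = consumer.poll(1000L)
    // Replacement adopted by this commit: total blocking time is bounded by the Duration.
    val records = consumer.poll(Duration.ofMillis(1000L))
    println(s"polled ${records.count()} records")
    consumer.close()
  }
}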
@@ -17,6 +17,7 @@

package org.apache.spark.sql.kafka010

+import java.{time => jt}
import java.{util => ju}
import java.util.concurrent.Executors

@@ -29,7 +30,9 @@ import scala.util.control.NonFatal
import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer}
import org.apache.kafka.common.TopicPartition

+import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
import org.apache.spark.sql.types._
import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}

@@ -49,6 +52,11 @@ private[kafka010] class KafkaOffsetReader(
val driverKafkaParams: ju.Map[String, Object],
readerOptions: Map[String, String],
driverGroupIdPrefix: String) extends Logging {
+private val pollTimeoutMs = readerOptions.getOrElse(
+KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
+(SparkEnv.get.conf.get(NETWORK_TIMEOUT) * 1000L).toString
+).toLong

/**
* Used to ensure that fetch operations execute in an UninterruptibleThread
*/
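The pollTimeoutMs added above gives the driver-side consumer the same default as the executor side: the source option when present, otherwise spark.network.timeout (seconds) converted to milliseconds, so the 120s default becomes 120000ms. A standalone sketch of the fallback, assuming the option key (the value of KafkaSourceProvider.CONSUMER_POLL_TIMEOUT) is kafkaConsumer.pollTimeoutMs:

// Fallback used above, as a standalone function (names illustrative, not Spark API).
def resolvePollTimeoutMs(readerOptions: Map[String, String], networkTimeoutSec: Long): Long =
  readerOptions.getOrElse(
    "kafkaConsumer.pollTimeoutMs", // assumed value of CONSUMER_POLL_TIMEOUT
    (networkTimeoutSec * 1000L).toString // spark.network.timeout is configured in seconds
  ).toLong

// resolvePollTimeoutMs(Map.empty, 120L) == 120000L (spark.network.timeout default of 120s)
// resolvePollTimeoutMs(Map("kafkaConsumer.pollTimeoutMs" -> "512"), 120L) == 512L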
@@ -115,9 +123,7 @@
*/
def fetchTopicPartitions(): Set[TopicPartition] = runUninterruptibly {
assert(Thread.currentThread().isInstanceOf[UninterruptibleThread])
-// Poll to get the latest assigned partitions
-consumer.poll(0)
-val partitions = consumer.assignment()
+val partitions = getPartitions()
consumer.pause(partitions)
partitions.asScala.toSet
}
@@ -163,9 +169,7 @@
reportDataLoss: String => Unit): KafkaSourceOffset = {
val fetched = runUninterruptibly {
withRetriesWithoutInterrupt {
-// Poll to get the latest assigned partitions
-consumer.poll(0)
-val partitions = consumer.assignment()
+val partitions = getPartitions()

// Call `position` to wait until the potential offset request triggered by `poll(0)` is
// done. This is a workaround for KAFKA-7703, in which an async `seekToBeginning` triggered by
@@ -177,7 +181,7 @@
"If startingOffsets contains specific offsets, you must specify all TopicPartitions.\n" +
"Use -1 for latest, -2 for earliest, if you don't care.\n" +
s"Specified: ${partitionOffsets.keySet} Assigned: ${partitions.asScala}")
logDebug(s"Partitions assigned to consumer: $partitions. Seeking to $partitionOffsets")
logDebug(s"Seeking to $partitionOffsets")

partitionOffsets.foreach {
case (tp, KafkaOffsetRangeLimit.LATEST) =>
@@ -211,11 +215,9 @@
*/
def fetchEarliestOffsets(): Map[TopicPartition, Long] = runUninterruptibly {
withRetriesWithoutInterrupt {
-// Poll to get the latest assigned partitions
-consumer.poll(0)
-val partitions = consumer.assignment()
+val partitions = getPartitions()
consumer.pause(partitions)
logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the beginning")
logDebug(s"Seeking to the beginning")

consumer.seekToBeginning(partitions)
val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap
@@ -241,17 +243,15 @@
def fetchLatestOffsets(
knownOffsets: Option[PartitionOffsetMap]): PartitionOffsetMap = runUninterruptibly {
withRetriesWithoutInterrupt {
-// Poll to get the latest assigned partitions
-consumer.poll(0)
-val partitions = consumer.assignment()
+val partitions = getPartitions()

// Call `position` to wait until the potential offset request triggered by `poll(0)` is
// done. This is a workaround for KAFKA-7703, in which an async `seekToBeginning` triggered by
// `poll(0)` may reset offsets that should have been set by another request.
partitions.asScala.map(p => p -> consumer.position(p)).foreach(_ => {})

consumer.pause(partitions)
logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the end.")
logDebug(s"Seeking to the end.")

if (knownOffsets.isEmpty) {
consumer.seekToEnd(partitions)
@@ -317,11 +317,8 @@
} else {
runUninterruptibly {
withRetriesWithoutInterrupt {
-// Poll to get the latest assigned partitions
-consumer.poll(0)
-val partitions = consumer.assignment()
+val partitions = getPartitions()
consumer.pause(partitions)
logDebug(s"\tPartitions assigned to consumer: $partitions")

// Get the earliest offset of each partition
consumer.seekToBeginning(partitions)
@@ -419,6 +416,19 @@
stopConsumer()
_consumer = null // will automatically get reinitialized again
}

+private def getPartitions(): ju.Set[TopicPartition] = {
+var partitions = Set.empty[TopicPartition].asJava
+val startTimeMs = System.currentTimeMillis()
+while (partitions.isEmpty && System.currentTimeMillis() - startTimeMs < pollTimeoutMs) {
+// Poll to get the latest assigned partitions
+consumer.poll(jt.Duration.ZERO)
+partitions = consumer.assignment()
+}
+require(!partitions.isEmpty)
+logDebug(s"Partitions assigned to consumer: $partitions")
+partitions
+}
}

private[kafka010] object KafkaOffsetReader {
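Why getPartitions() polls in a loop: unlike the old poll(0), which blocked until the coordinator handed out an assignment, poll(Duration.ZERO) returns immediately and may complete before any partitions are assigned, so the helper retries until the assignment arrives or pollTimeoutMs elapses. A standalone sketch of the same wait-for-assignment pattern (generic names, not part of this commit):

import java.time.Duration

import org.apache.kafka.clients.consumer.KafkaConsumer
import org.apache.kafka.common.TopicPartition

// Poll with a zero timeout until the subscription yields an assignment,
// giving up after timeoutMs.
def waitForAssignment[K, V](
    consumer: KafkaConsumer[K, V],
    timeoutMs: Long): java.util.Set[TopicPartition] = {
  val start = System.currentTimeMillis()
  var assignment = consumer.assignment()
  while (assignment.isEmpty && System.currentTimeMillis() - start < timeoutMs) {
    consumer.poll(Duration.ZERO) // drives group coordination without waiting for records
    assignment = consumer.assignment()
  }
  require(!assignment.isEmpty, s"no partitions assigned within ${timeoutMs}ms")
  assignment
}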
@@ -33,6 +33,12 @@ class KafkaSourceProviderSuite extends SparkFunSuite with PrivateMethodTester {
private val pollTimeoutMsMethod = PrivateMethod[Long]('pollTimeoutMs)
private val maxOffsetsPerTriggerMethod = PrivateMethod[Option[Long]]('maxOffsetsPerTrigger)

+override protected def beforeEach(): Unit = {
+val sparkEnv = mock(classOf[SparkEnv])
+when(sparkEnv.conf).thenReturn(new SparkConf())
+SparkEnv.set(sparkEnv)
+}

override protected def afterEach(): Unit = {
SparkEnv.set(null)
super.afterEach()
@@ -43,11 +49,6 @@ class KafkaSourceProviderSuite extends SparkFunSuite with PrivateMethodTester {
options: CaseInsensitiveStringMap,
expectedPollTimeoutMs: Long,
expectedMaxOffsetsPerTrigger: Option[Long]): Unit = {
-// KafkaMicroBatchStream reads Spark conf from SparkEnv for default value
-// hence we set mock SparkEnv here before creating KafkaMicroBatchStream
-val sparkEnv = mock(classOf[SparkEnv])
-when(sparkEnv.conf).thenReturn(new SparkConf())
-SparkEnv.set(sparkEnv)

val scan = getKafkaDataSourceScan(options)
val stream = scan.toMicroBatchStream("dummy").asInstanceOf[KafkaMicroBatchStream]
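The suite change above follows from the reader change: KafkaOffsetReader now reads SparkEnv.get.conf at construction time, not just KafkaMicroBatchStream, so the mock moves from the individual test body into beforeEach, where it covers every test path. A minimal sketch of that Mockito setup and teardown (assuming Mockito on the test classpath; the wrapper name is illustrative):

import org.apache.spark.{SparkConf, SparkEnv}
import org.mockito.Mockito.{mock, when}

object SparkEnvMockSketch {
  def withMockSparkEnv(body: => Unit): Unit = {
    // Install a mock SparkEnv so code paths calling SparkEnv.get.conf see default conf values.
    val sparkEnv = mock(classOf[SparkEnv])
    when(sparkEnv.conf).thenReturn(new SparkConf())
    SparkEnv.set(sparkEnv)
    try body
    finally SparkEnv.set(null) // tear down so later tests see a clean slate
  }
}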
@@ -20,7 +20,8 @@ package org.apache.spark.sql.kafka010
import java.io.{File, IOException}
import java.lang.{Integer => JInt}
import java.net.InetSocketAddress
-import java.util.{Collections, Map => JMap, Properties, UUID}
+import java.time.Duration
+import java.util.{Collections, Map => JMap, Properties, Set => JSet, UUID}
import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._
@@ -286,8 +287,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends Logging {
val kc = new KafkaConsumer[String, String](consumerConfiguration)
logInfo("Created consumer to get earliest offsets")
kc.subscribe(topics.asJavaCollection)
-kc.poll(0)
-val partitions = kc.assignment()
+val partitions = getPartitions(kc)
kc.pause(partitions)
kc.seekToBeginning(partitions)
val offsets = partitions.asScala.map(p => p -> kc.position(p)).toMap
@@ -300,8 +300,7 @@
val kc = new KafkaConsumer[String, String](consumerConfiguration)
logInfo("Created consumer to get latest offsets")
kc.subscribe(topics.asJavaCollection)
-kc.poll(0)
-val partitions = kc.assignment()
+val partitions = getPartitions(kc)
kc.pause(partitions)
kc.seekToEnd(partitions)
val offsets = partitions.asScala.map(p => p -> kc.position(p)).toMap
@@ -310,6 +309,20 @@
offsets
}

+private def getPartitions(consumer: KafkaConsumer[String, String]): JSet[TopicPartition] = {
+var partitions = Set.empty[TopicPartition].asJava
+val startTimeMs = System.currentTimeMillis()
+val timeoutMs = timeout(1.minute).value.toMillis
+while (partitions.isEmpty && System.currentTimeMillis() - startTimeMs < timeoutMs) {
+// Poll to get the latest assigned partitions
+consumer.poll(Duration.ZERO)
+partitions = consumer.assignment()
+}
+require(!partitions.isEmpty)
+logDebug(s"Partitions assigned to consumer: $partitions")
+partitions
+}

def listConsumerGroups(): ListConsumerGroupsResult = {
adminClient.listConsumerGroups()
}
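The helpers in this file show the standard pattern for reading offsets without consuming records: subscribe, wait for an assignment, pause every partition so poll() fetches nothing, then seekToBeginning or seekToEnd and read position(). A condensed standalone sketch (configuration and error handling trimmed; the one-minute bound is illustrative):

import java.time.Duration
import java.util.{Collections, Properties}

import scala.collection.JavaConverters._

import org.apache.kafka.clients.consumer.KafkaConsumer
import org.apache.kafka.common.TopicPartition

// Earliest offset of each partition of a topic, without consuming any records.
def earliestOffsets(props: Properties, topic: String): Map[TopicPartition, Long] = {
  val kc = new KafkaConsumer[String, String](props)
  try {
    kc.subscribe(Collections.singletonList(topic))
    val deadline = System.currentTimeMillis() + 60000L // bound the first rebalance
    var parts = kc.assignment()
    while (parts.isEmpty && System.currentTimeMillis() < deadline) {
      kc.poll(Duration.ZERO)
      parts = kc.assignment()
    }
    kc.pause(parts) // paused partitions return no records from poll
    kc.seekToBeginning(parts) // lazy; resolved by the position() calls below
    parts.asScala.map(p => p -> kc.position(p)).toMap
  } finally kc.close()
}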
