diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala
index 3b73896d631c6..829ee15c13a3d 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala
@@ -57,7 +57,7 @@ private[kafka010] class KafkaMicroBatchStream(
     metadataPath: String,
     startingOffsets: KafkaOffsetRangeLimit,
     failOnDataLoss: Boolean)
-  extends SupportsAdmissionControl with ReportsSourceMetrics with MicroBatchStream with Logging {
+  extends SupportsTriggerAvailableNow with ReportsSourceMetrics with MicroBatchStream with Logging {
 
   private[kafka010] val pollTimeoutMs = options.getLong(
     KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
@@ -81,6 +81,8 @@ private[kafka010] class KafkaMicroBatchStream(
 
   private var latestPartitionOffsets: PartitionOffsetMap = _
 
+  private var allDataForTriggerAvailableNow: PartitionOffsetMap = _
+
   /**
    * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only
    * called in StreamExecutionThread. Otherwise, interrupting a thread while running
@@ -98,7 +100,8 @@ private[kafka010] class KafkaMicroBatchStream(
     } else if (minOffsetPerTrigger.isDefined) {
       ReadLimit.minRows(minOffsetPerTrigger.get, maxTriggerDelayMs)
     } else {
-      maxOffsetsPerTrigger.map(ReadLimit.maxRows).getOrElse(super.getDefaultReadLimit)
+      // TODO (SPARK-37973) Directly call super.getDefaultReadLimit when scala issue 12523 is fixed
+      maxOffsetsPerTrigger.map(ReadLimit.maxRows).getOrElse(ReadLimit.allAvailable())
     }
   }
 
@@ -113,7 +116,13 @@ private[kafka010] class KafkaMicroBatchStream(
 
   override def latestOffset(start: Offset, readLimit: ReadLimit): Offset = {
     val startPartitionOffsets = start.asInstanceOf[KafkaSourceOffset].partitionToOffsets
-    latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets(Some(startPartitionOffsets))
+
+    // Use the pre-fetched list of partition offsets when Trigger.AvailableNow is enabled.
+    latestPartitionOffsets = if (allDataForTriggerAvailableNow != null) {
+      allDataForTriggerAvailableNow
+    } else {
+      kafkaOffsetReader.fetchLatestOffsets(Some(startPartitionOffsets))
+    }
 
     val limits: Seq[ReadLimit] = readLimit match {
       case rows: CompositeReadLimit => rows.getReadLimits
@@ -298,6 +307,11 @@ private[kafka010] class KafkaMicroBatchStream(
       logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE")
     }
   }
+
+  override def prepareForTriggerAvailableNow(): Unit = {
+    allDataForTriggerAvailableNow = kafkaOffsetReader.fetchLatestOffsets(
+      Some(getOrCreateInitialPartitionOffsets()))
+  }
 }
 
 object KafkaMicroBatchStream extends Logging {
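For orientation: the SupportsTriggerAvailableNow contract is a two-phase handshake. The engine calls prepareForTriggerAvailableNow() exactly once before the first batch, letting the source snapshot its end offsets, and every later latestOffset() call is bounded by that snapshot. Below is a minimal sketch of the driver side under assumed names — runAvailableNow, processBatch, and checkpointedStart are illustrative, not Spark's actual StreamExecution internals, and KafkaMicroBatchStream is package-private, so this is a conceptual sketch rather than compilable client code:

import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.connector.read.streaming.Offset

// Sketch only: `processBatch` stands in for batch planning + execution.
def runAvailableNow(
    stream: KafkaMicroBatchStream,
    checkpointedStart: Offset,
    processBatch: Array[InputPartition] => Unit): Unit = {
  // Phase 1: snapshot the end offsets exactly once, before the first batch.
  stream.prepareForTriggerAvailableNow()
  val limit = stream.getDefaultReadLimit
  var start = checkpointedStart
  var end = stream.latestOffset(start, limit)
  // Phase 2: keep planning batches until the snapshot is exhausted.
  // latestOffset() never moves past allDataForTriggerAvailableNow, while
  // maxOffsetsPerTrigger still caps the size of each individual batch.
  while (start != end) {
    processBatch(stream.planInputPartitions(start, end))
    start = end
    end = stream.latestOffset(start, limit)
  }
}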
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
index 87cef02d0d8f2..09db0a7e82dfe 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
@@ -77,7 +77,7 @@ private[kafka010] class KafkaSource(
     metadataPath: String,
     startingOffsets: KafkaOffsetRangeLimit,
     failOnDataLoss: Boolean)
-  extends SupportsAdmissionControl with Source with Logging {
+  extends SupportsTriggerAvailableNow with Source with Logging {
 
   private val sc = sqlContext.sparkContext
 
@@ -99,6 +99,8 @@ private[kafka010] class KafkaSource(
 
   private var lastTriggerMillis = 0L
 
+  private var allDataForTriggerAvailableNow: PartitionOffsetMap = _
+
   /**
    * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only
    * called in StreamExecutionThread. Otherwise, interrupting a thread while running
@@ -130,7 +132,8 @@ private[kafka010] class KafkaSource(
     } else if (minOffsetPerTrigger.isDefined) {
       ReadLimit.minRows(minOffsetPerTrigger.get, maxTriggerDelayMs)
     } else {
-      maxOffsetsPerTrigger.map(ReadLimit.maxRows).getOrElse(super.getDefaultReadLimit)
+      // TODO (SPARK-37973) Directly call super.getDefaultReadLimit when scala issue 12523 is fixed
+      maxOffsetsPerTrigger.map(ReadLimit.maxRows).getOrElse(ReadLimit.allAvailable())
     }
   }
 
@@ -159,7 +162,14 @@ private[kafka010] class KafkaSource(
     // Make sure initialPartitionOffsets is initialized
     initialPartitionOffsets
     val currentOffsets = currentPartitionOffsets.orElse(Some(initialPartitionOffsets))
-    val latest = kafkaReader.fetchLatestOffsets(currentOffsets)
+
+    // Use the pre-fetched list of partition offsets when Trigger.AvailableNow is enabled.
+    val latest = if (allDataForTriggerAvailableNow != null) {
+      allDataForTriggerAvailableNow
+    } else {
+      kafkaReader.fetchLatestOffsets(currentOffsets)
+    }
+
     latestPartitionOffsets = Some(latest)
 
     val limits: Seq[ReadLimit] = limit match {
@@ -331,6 +341,10 @@ private[kafka010] class KafkaSource(
       logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE")
     }
   }
+
+  override def prepareForTriggerAvailableNow(): Unit = {
+    allDataForTriggerAvailableNow = kafkaReader.fetchLatestOffsets(Some(initialPartitionOffsets))
+  }
 }
 
 /** Companion object for the [[KafkaSource]]. */
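From the user's side, both the DSv2 path (KafkaMicroBatchStream) and the DSv1 path (KafkaSource) above are reached through the same query-level API. An illustrative example — the bootstrap servers, topic name, and sink/checkpoint paths are placeholders. The query drains everything available at start time in rate-limited micro-batches and then stops, unlike Trigger.Once, which ignores maxOffsetsPerTrigger and forces a single batch:

import org.apache.spark.sql.streaming.Trigger

val df = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092")  // placeholder
  .option("subscribe", "events")                    // placeholder
  .option("maxOffsetsPerTrigger", "10000")          // still honored per batch
  .load()

val query = df.writeStream
  .format("parquet")
  .option("path", "/tmp/out")                       // placeholder
  .option("checkpointLocation", "/tmp/checkpoint")  // placeholder
  .trigger(Trigger.AvailableNow())
  .start()

query.awaitTermination()  // returns once the start-time snapshot is drained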
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
index f61696f6485e6..61be7dd6cd8ef 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
@@ -44,7 +44,7 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
 import org.apache.spark.sql.functions.{count, window}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.kafka010.KafkaSourceProvider._
-import org.apache.spark.sql.streaming.{StreamTest, Trigger}
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamTest, Trigger}
 import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -195,6 +195,45 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase {
     true
   }
 
+  test("Trigger.AvailableNow") {
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 5)
+
+    testUtils.sendMessages(topic, (0 until 15).map { case x =>
+      s"foo-$x"
+    }.toArray, Some(0))
+
+    val reader = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("kafka.metadata.max.age.ms", "1")
+      .option("maxOffsetsPerTrigger", 5)
+      .option("subscribe", topic)
+      .option("startingOffsets", "earliest")
+      .load()
+
+    var index: Int = 0
+    def startTriggerAvailableNowQuery(): StreamingQuery = {
+      reader.writeStream
+        .foreachBatch((_: Dataset[Row], _: Long) => {
+          index += 1
+        })
+        .trigger(Trigger.AvailableNow)
+        .start()
+    }
+
+    val query = startTriggerAvailableNowQuery()
+    try {
+      assert(query.awaitTermination(streamingTimeout.toMillis))
+    } finally {
+      query.stop()
+    }
+
+    // should have 3 batches now i.e. 15 / 5 = 3
+    assert(index == 3)
+  }
+
   test("(de)serialization of initial offsets") {
     val topic = newTopic()
     testUtils.createTopic(topic, partitions = 5)
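The final assertion is just the rate-limit arithmetic: 15 records with maxOffsetsPerTrigger = 5 yields ceil(15 / 5) = 3 micro-batches, after which the pre-fetched end offsets are reached and awaitTermination returns true on its own. As a worked check — a hypothetical helper, not part of the suite:

// Hypothetical helper mirroring the `15 / 5 = 3` comment in the test above.
def expectedBatches(totalRecords: Int, maxOffsetsPerTrigger: Int): Int =
  math.ceil(totalRecords.toDouble / maxOffsetsPerTrigger).toInt

assert(expectedBatches(15, 5) == 3)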