[SPARK-28423][SQL] Merge Scan and Batch/Stream #25180

Closed · 2 commits
@@ -19,7 +19,7 @@ package org.apache.spark.sql.v2.avro
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.sources.v2.reader.Scan
import org.apache.spark.sql.sources.v2.reader.BatchScan
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

@@ -30,7 +30,7 @@ class AvroScanBuilder (
dataSchema: StructType,
options: CaseInsensitiveStringMap)
extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
override def build(): Scan = {
override def buildForBatch(): BatchScan = {
AvroScan(sparkSession, fileIndex, dataSchema, readDataSchema(), readPartitionSchema(), options)
}
}
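
The Avro hunk above is the smallest instance of the pattern this PR applies to every source: instead of build(): Scan plus a later toBatch() conversion, the builder returns a BatchScan directly. Below is a minimal, self-contained sketch of that before/after shape; the names ending in Sketch are stand-ins for the PR's ScanBuilder/BatchScan, and the reduced member list (just readSchema) is an assumption for illustration only.

// Stand-in interfaces so this sketch compiles on its own; in the PR these are
// org.apache.spark.sql.sources.v2.reader.{ScanBuilder, BatchScan}.
trait BatchScanSketch {
  def readSchema(): String // String stands in for Spark's StructType
}

trait ScanBuilderSketch {
  // Old shape (pre-PR): def build(): Scan, with the engine calling scan.toBatch().
  // New shape (this PR): one builder method per execution mode; batch shown here.
  def buildForBatch(): BatchScanSketch
}

// Mirrors AvroScanBuilder above: the builder hands back the batch scan directly.
class ExampleScanBuilder(schemaDDL: String) extends ScanBuilderSketch {
  override def buildForBatch(): BatchScanSketch = new BatchScanSketch {
    override def readSchema(): String = schemaDDL
  }
}
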
@@ -22,22 +22,25 @@ import org.apache.kafka.common.TopicPartition
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, PartitionReaderFactory}
import org.apache.spark.sql.sources.v2.reader.{BatchScan, InputPartition, PartitionReaderFactory}
import org.apache.spark.sql.types.StructType


private[kafka010] class KafkaBatch(
private[kafka010] class KafkaBatchScan(
strategy: ConsumerStrategy,
sourceOptions: Map[String, String],
specifiedKafkaParams: Map[String, String],
failOnDataLoss: Boolean,
startingOffsets: KafkaOffsetRangeLimit,
endingOffsets: KafkaOffsetRangeLimit)
extends Batch with Logging {
extends BatchScan with Logging {
assert(startingOffsets != LatestOffsetRangeLimit,
"Starting offset not allowed to be set to latest offsets.")
assert(endingOffsets != EarliestOffsetRangeLimit,
"Ending offset not allowed to be set to earliest offsets.")

override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema

private val pollTimeoutMs = sourceOptions.getOrElse(
KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
(SparkEnv.get.conf.get(NETWORK_TIMEOUT) * 1000L).toString
@@ -30,10 +30,11 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
* A [[ContinuousStream]] for data from kafka.
* A [[ContinuousScan]] for data from kafka.
*
* @param offsetReader a reader used to get kafka offsets. Note that the actual data will be
* read by per-task consumers generated later.
@@ -45,14 +46,16 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
* scenarios, where some offsets after the specified initial ones can't be
* properly read.
*/
class KafkaContinuousStream(
class KafkaContinuousScan(
offsetReader: KafkaOffsetReader,
kafkaParams: ju.Map[String, Object],
options: CaseInsensitiveStringMap,
metadataPath: String,
initialOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean)
extends ContinuousStream with Logging {
extends ContinuousScan with Logging {

override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema

private val pollTimeoutMs =
options.getLong(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, 512)
@@ -29,18 +29,17 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset}
import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchStream
import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchScan
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchScan, Offset}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.UninterruptibleThread

/**
* A [[MicroBatchStream]] that reads data from Kafka.
* A [[MicroBatchScan]] that reads data from Kafka.
*
* The [[KafkaSourceOffset]] is the custom [[Offset]] defined for this source that contains
* a map of TopicPartition -> offset. Note that this offset is 1 + (available offset). For
@@ -55,13 +54,15 @@ import org.apache.spark.util.UninterruptibleThread
* To avoid this issue, make sure the query is stopped before the Kafka brokers are stopped,
* and that correct broker addresses are used.
*/
private[kafka010] class KafkaMicroBatchStream(
private[kafka010] class KafkaMicroBatchScan(
kafkaOffsetReader: KafkaOffsetReader,
executorKafkaParams: ju.Map[String, Object],
options: CaseInsensitiveStringMap,
metadataPath: String,
startingOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean) extends RateControlMicroBatchStream with Logging {
failOnDataLoss: Boolean) extends RateControlMicroBatchScan with Logging {

override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema

private val pollTimeoutMs = options.getLong(
KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
@@ -34,8 +34,8 @@ import org.apache.spark.sql.execution.streaming.{Sink, Source}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.TableCapability._
import org.apache.spark.sql.sources.v2.reader.{Batch, Scan, ScanBuilder}
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream}
import org.apache.spark.sql.sources.v2.reader.{BatchScan, ScanBuilder}
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousScan, MicroBatchScan}
import org.apache.spark.sql.sources.v2.writer.{BatchWrite, WriteBuilder}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite
import org.apache.spark.sql.streaming.OutputMode
@@ -368,8 +368,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
ACCEPT_ANY_SCHEMA).asJava
}

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
() => new KafkaScan(options)
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new KafkaScanBuilder(options)
}

override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = {
new WriteBuilder {
@@ -395,11 +396,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
}
}

class KafkaScan(options: CaseInsensitiveStringMap) extends Scan {

override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema
class KafkaScanBuilder(options: CaseInsensitiveStringMap) extends ScanBuilder {

override def toBatch(): Batch = {
override def buildForBatch(): BatchScan = {
val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap)
validateBatchOptions(caseInsensitiveOptions)
val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveOptions)
@@ -410,7 +409,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
val endingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
caseInsensitiveOptions, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)

new KafkaBatch(
new KafkaBatchScan(
strategy(caseInsensitiveOptions),
caseInsensitiveOptions,
specifiedKafkaParams,
@@ -419,7 +418,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
endingRelationOffsets)
}

override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
override def buildForMicroBatchStreaming(checkpointLocation: String): MicroBatchScan = {
val parameters = options.asScala.toMap
validateStreamOptions(parameters)
// Each running query should use its own group id. Otherwise, the query may be only assigned
Expand All @@ -439,7 +438,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
parameters,
driverGroupIdPrefix = s"$uniqueGroupId-driver")

new KafkaMicroBatchStream(
new KafkaMicroBatchScan(
kafkaOffsetReader,
kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
options,
@@ -448,7 +447,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
failOnDataLoss(caseInsensitiveParams))
}

override def toContinuousStream(checkpointLocation: String): ContinuousStream = {
override def buildForContinuousStreaming(checkpointLocation: String): ContinuousScan = {
val parameters = options.asScala.toMap
validateStreamOptions(parameters)
// Each running query should use its own group id. Otherwise, the query may be only assigned
@@ -473,7 +472,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
parameters,
driverGroupIdPrefix = s"$uniqueGroupId-driver")

new KafkaContinuousStream(
new KafkaContinuousScan(
kafkaOffsetReader,
kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
options,
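
On the Kafka side the same merge shows up across all three execution modes: KafkaScan.toBatch/toMicroBatchStream/toContinuousStream become KafkaScanBuilder.buildForBatch/buildForMicroBatchStreaming/buildForContinuousStreaming, and readSchema() moves onto each produced scan. A hedged, self-contained sketch of that builder surface follows; ScanStub is a stand-in result type with trivial bodies, whereas the real methods construct KafkaBatchScan, KafkaMicroBatchScan, and KafkaContinuousScan as shown in the hunks above.

// ScanStub stands in for the PR's BatchScan / MicroBatchScan / ContinuousScan results.
final case class ScanStub(mode: String, checkpointLocation: Option[String] = None)

trait KafkaScanBuilderSketch {
  def buildForBatch(): ScanStub                                          // was build().toBatch()
  def buildForMicroBatchStreaming(checkpointLocation: String): ScanStub  // was toMicroBatchStream(...)
  def buildForContinuousStreaming(checkpointLocation: String): ScanStub  // was toContinuousStream(...)
}

object ExampleKafkaScanBuilder extends KafkaScanBuilderSketch {
  override def buildForBatch(): ScanStub = ScanStub("batch")
  override def buildForMicroBatchStreaming(checkpointLocation: String): ScanStub =
    ScanStub("micro-batch", Some(checkpointLocation))
  override def buildForContinuousStreaming(checkpointLocation: String): ScanStub =
    ScanStub("continuous", Some(checkpointLocation))
}
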
@@ -213,8 +213,8 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest {
assert(
query.lastExecution.executedPlan.collectFirst {
case scan: ContinuousScanExec
if scan.stream.isInstanceOf[KafkaContinuousStream] =>
scan.stream.asInstanceOf[KafkaContinuousStream]
if scan.scan.isInstanceOf[KafkaContinuousScan] =>
scan.scan.asInstanceOf[KafkaContinuousScan]
}.exists { stream =>
// Ensure the new topic is present and the old topic is gone.
stream.knownPartitions.exists(_.topic == topic2)
@@ -47,8 +47,8 @@ trait KafkaContinuousTest extends KafkaSourceTest {
assert(
query.lastExecution.executedPlan.collectFirst {
case scan: ContinuousScanExec
if scan.stream.isInstanceOf[KafkaContinuousStream] =>
scan.stream.asInstanceOf[KafkaContinuousStream]
if scan.scan.isInstanceOf[KafkaContinuousScan] =>
scan.scan.asInstanceOf[KafkaContinuousScan]
}.exists(_.knownPartitions.size == newCount),
s"query never reconfigured to $newCount partitions")
}
@@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.functions.{count, window}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.kafka010.KafkaSourceProvider._
import org.apache.spark.sql.sources.v2.reader.streaming.SparkDataStream
import org.apache.spark.sql.sources.v2.reader.streaming.StreamingScan
import org.apache.spark.sql.streaming.{StreamTest, Trigger}
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.sql.test.SharedSQLContext
@@ -95,7 +95,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with Kaf
message: String = "",
topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData {

override def addData(query: Option[StreamExecution]): (SparkDataStream, Offset) = {
override def addData(query: Option[StreamExecution]): (StreamingScan, Offset) = {
query match {
// Make sure no Spark job is running when deleting a topic
case Some(m: MicroBatchExecution) => m.processAllAvailable()
@@ -115,12 +115,12 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with Kaf
query.nonEmpty,
"Cannot add data when there is no query for finding the active kafka source")

val sources: Seq[SparkDataStream] = {
val sources: Seq[StreamingScan] = {
query.get.logicalPlan.collect {
case StreamingExecutionRelation(source: KafkaSource, _) => source
case r: StreamingDataSourceV2Relation if r.stream.isInstanceOf[KafkaMicroBatchStream] ||
r.stream.isInstanceOf[KafkaContinuousStream] =>
r.stream
case r: StreamingDataSourceV2Relation if r.scan.isInstanceOf[KafkaMicroBatchScan] ||
r.scan.isInstanceOf[KafkaContinuousScan] =>
r.scan
}
}.distinct

@@ -1111,7 +1111,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase {
makeSureGetOffsetCalled,
AssertOnQuery { query =>
query.logicalPlan.find {
case r: StreamingDataSourceV2Relation => r.stream.isInstanceOf[KafkaMicroBatchStream]
case r: StreamingDataSourceV2Relation => r.scan.isInstanceOf[KafkaMicroBatchScan]
case _ => false
}.isDefined
}
@@ -1139,8 +1139,9 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase {
) ++ Option(minPartitions).map { p => "minPartitions" -> p}
val dsOptions = new CaseInsensitiveStringMap(options.asJava)
val table = provider.getTable(dsOptions)
val stream = table.newScanBuilder(dsOptions).build().toMicroBatchStream(dir.getAbsolutePath)
val inputPartitions = stream.planInputPartitions(
val scan = table.newScanBuilder(dsOptions)
.buildForMicroBatchStreaming(dir.getAbsolutePath)
val inputPartitions = scan.planInputPartitions(
KafkaSourceOffset(Map(tp -> 0L)),
KafkaSourceOffset(Map(tp -> 100L))).map(_.asInstanceOf[KafkaBatchInputPartition])
withClue(s"minPartitions = $minPartitions generated factories $inputPartitions\n\t") {
@@ -22,10 +22,10 @@ import java.util.Locale
import scala.collection.JavaConverters._

import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
import org.scalatest.PrivateMethodTester

import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite}
import org.apache.spark.sql.sources.v2.reader.Scan
import org.apache.spark.sql.sources.v2.reader.ScanBuilder
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class KafkaSourceProviderSuite extends SparkFunSuite with PrivateMethodTester {
@@ -39,43 +39,45 @@ class KafkaSourceProviderSuite extends SparkFunSuite with PrivateMethodTester {
}

test("micro-batch mode - options should be handled as case-insensitive") {
def verifyFieldsInMicroBatchStream(
def verifyFieldsInMicroBatchScan(
options: CaseInsensitiveStringMap,
expectedPollTimeoutMs: Long,
expectedMaxOffsetsPerTrigger: Option[Long]): Unit = {
// KafkaMicroBatchStream reads Spark conf from SparkEnv for default value
// hence we set mock SparkEnv here before creating KafkaMicroBatchStream
// KafkaMicroBatchScan reads Spark conf from SparkEnv for default value
// hence we set mock SparkEnv here before creating KafkaMicroBatchScan
val sparkEnv = mock(classOf[SparkEnv])
when(sparkEnv.conf).thenReturn(new SparkConf())
SparkEnv.set(sparkEnv)

val scan = getKafkaDataSourceScan(options)
val stream = scan.toMicroBatchStream("dummy").asInstanceOf[KafkaMicroBatchStream]
val builder = getKafkaDataSourceScanBuilder(options)
val scan = builder.buildForMicroBatchStreaming("dummy")
.asInstanceOf[KafkaMicroBatchScan]

assert(expectedPollTimeoutMs === getField(stream, pollTimeoutMsMethod))
assert(expectedMaxOffsetsPerTrigger === getField(stream, maxOffsetsPerTriggerMethod))
assert(expectedPollTimeoutMs === getField(scan, pollTimeoutMsMethod))
assert(expectedMaxOffsetsPerTrigger === getField(scan, maxOffsetsPerTriggerMethod))
}

val expectedValue = 1000L
buildCaseInsensitiveStringMapForUpperAndLowerKey(
KafkaSourceProvider.CONSUMER_POLL_TIMEOUT -> expectedValue.toString,
KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER -> expectedValue.toString)
.foreach(verifyFieldsInMicroBatchStream(_, expectedValue, Some(expectedValue)))
.foreach(verifyFieldsInMicroBatchScan(_, expectedValue, Some(expectedValue)))
}

test("SPARK-28142 - continuous mode - options should be handled as case-insensitive") {
def verifyFieldsInContinuousStream(
def verifyFieldsInContinuousScan(
options: CaseInsensitiveStringMap,
expectedPollTimeoutMs: Long): Unit = {
val scan = getKafkaDataSourceScan(options)
val stream = scan.toContinuousStream("dummy").asInstanceOf[KafkaContinuousStream]
assert(expectedPollTimeoutMs === getField(stream, pollTimeoutMsMethod))
val builder = getKafkaDataSourceScanBuilder(options)
val scan = builder.buildForContinuousStreaming("dummy")
.asInstanceOf[KafkaContinuousScan]
assert(expectedPollTimeoutMs === getField(scan, pollTimeoutMsMethod))
}

val expectedValue = 1000
buildCaseInsensitiveStringMapForUpperAndLowerKey(
KafkaSourceProvider.CONSUMER_POLL_TIMEOUT -> expectedValue.toString)
.foreach(verifyFieldsInContinuousStream(_, expectedValue))
.foreach(verifyFieldsInContinuousScan(_, expectedValue))
}

private def buildCaseInsensitiveStringMapForUpperAndLowerKey(
@@ -91,9 +93,9 @@ class KafkaSourceProviderSuite extends SparkFunSuite with PrivateMethodTester {
new CaseInsensitiveStringMap((options.toMap ++ requiredOptions).asJava)
}

private def getKafkaDataSourceScan(options: CaseInsensitiveStringMap): Scan = {
private def getKafkaDataSourceScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
val provider = new KafkaSourceProvider()
provider.getTable(options).newScanBuilder(options).build()
provider.getTable(options).newScanBuilder(options)
}

private def getField[T](obj: AnyRef, method: PrivateMethod[T]): T = {