SPARK-23325: Use InternalRow when reading with DataSourceV2. #21118

Closed

Changes from all commits

@@ -26,6 +26,7 @@ import org.apache.kafka.common.TopicPartition
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.sources.v2.reader._
@@ -53,7 +54,7 @@ class KafkaContinuousReader(
metadataPath: String,
initialOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean)
extends ContinuousReader with SupportsScanUnsafeRow with Logging {
Contributor

Can we get rid of KafkaRecordToUnsafeRowConverter? Since Spark would do an unsafe projection at the end, we should just return GenericInternalRow instead of UnsafeRow here, to save a data copy.

Contributor Author

We can, but this is intended to make minimal changes. We can add optimizations like this in a follow-up.
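
For illustration, a minimal sketch of the follow-up optimization being discussed, assuming a hypothetical KafkaRecordToInternalRowConverter; the field order mirrors the Kafka source schema, and none of this is code from the PR:

```scala
import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical converter: build a GenericInternalRow directly from a Kafka record and let
// Spark's unsafe projection perform the single copy into unsafe format later.
class KafkaRecordToInternalRowConverter {
  def toInternalRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): InternalRow =
    new GenericInternalRow(Array[Any](
      record.key,                          // key: binary
      record.value,                        // value: binary
      UTF8String.fromString(record.topic), // topic: Catalyst strings are UTF8String
      record.partition,                    // partition: int
      record.offset,                       // offset: long
      record.timestamp * 1000L,            // timestamp: Kafka millis -> Catalyst micros
      record.timestampType.id              // timestampType: int
    ))
}
```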

extends ContinuousReader with Logging {

private lazy val session = SparkSession.getActiveSession.get
private lazy val sc = session.sparkContext
@@ -86,7 +87,7 @@ class KafkaContinuousReader(
KafkaSourceOffset(JsonUtils.partitionOffsets(json))
}

override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = {
override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = {
import scala.collection.JavaConverters._

val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset)
@@ -107,8 +108,8 @@
startOffsets.toSeq.map {
case (topicPartition, start) =>
KafkaContinuousInputPartition(
topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
.asInstanceOf[InputPartition[UnsafeRow]]
topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss
): InputPartition[InternalRow]
}.asJava
}

@@ -161,9 +162,10 @@ case class KafkaContinuousInputPartition(
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean) extends ContinuousInputPartition[UnsafeRow] {
failOnDataLoss: Boolean) extends ContinuousInputPartition[InternalRow] {

override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[UnsafeRow] = {
override def createContinuousReader(
offset: PartitionOffset): InputPartitionReader[InternalRow] = {
val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset]
require(kafkaOffset.topicPartition == topicPartition,
s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}")
@@ -192,7 +194,7 @@ class KafkaContinuousInputPartitionReader(
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[UnsafeRow] {
failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[InternalRow] {
private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false)
private val converter = new KafkaRecordToUnsafeRowConverter

@@ -29,11 +29,12 @@ import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset}
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow}
import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.UninterruptibleThread
@@ -61,7 +62,7 @@ private[kafka010] class KafkaMicroBatchReader(
metadataPath: String,
startingOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean)
extends MicroBatchReader with SupportsScanUnsafeRow with Logging {
extends MicroBatchReader with Logging {

private var startPartitionOffsets: PartitionOffsetMap = _
private var endPartitionOffsets: PartitionOffsetMap = _
@@ -101,7 +102,7 @@
}
}

override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = {
override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = {
// Find the new partitions, and get their earliest offsets
val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet)
val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq)
@@ -142,11 +143,11 @@
val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size

// Generate factories based on the offset ranges
val factories = offsetRanges.map { range =>
offsetRanges.map { range =>
new KafkaMicroBatchInputPartition(
range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer)
}
factories.map(_.asInstanceOf[InputPartition[UnsafeRow]]).asJava
range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer
): InputPartition[InternalRow]
}.asJava
}

override def getStartOffset: Offset = {
@@ -305,11 +306,11 @@ private[kafka010] case class KafkaMicroBatchInputPartition(
executorKafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean,
reuseKafkaConsumer: Boolean) extends InputPartition[UnsafeRow] {
reuseKafkaConsumer: Boolean) extends InputPartition[InternalRow] {

override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray

override def createPartitionReader(): InputPartitionReader[UnsafeRow] =
override def createPartitionReader(): InputPartitionReader[InternalRow] =
new KafkaMicroBatchInputPartitionReader(offsetRange, executorKafkaParams, pollTimeoutMs,
failOnDataLoss, reuseKafkaConsumer)
}
@@ -320,7 +321,7 @@ private[kafka010] case class KafkaMicroBatchInputPartitionReader(
executorKafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean,
reuseKafkaConsumer: Boolean) extends InputPartitionReader[UnsafeRow] with Logging {
reuseKafkaConsumer: Boolean) extends InputPartitionReader[InternalRow] with Logging {

private val consumer = KafkaDataConsumer.acquire(
offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer)
@@ -678,7 +678,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase {
Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))),
Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L)))
)
val factories = reader.planUnsafeInputPartitions().asScala
val factories = reader.planInputPartitions().asScala
.map(_.asInstanceOf[KafkaMicroBatchInputPartition])
withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") {
assert(factories.size == numPartitionsGenerated)
@@ -20,7 +20,7 @@
import java.util.List;

import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.ReadSupport;
import org.apache.spark.sql.sources.v2.ReadSupportWithSchema;
@@ -43,7 +43,7 @@
* Names of these interfaces start with `SupportsScan`. Note that a reader should only
* implement at most one of the special scans; if more than one special scan is implemented,
* only one of them will be respected, according to the following priority list, from high to low:
* {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}.
* {@link SupportsScanColumnarBatch}, {@link SupportsDeprecatedScanRow}.
*
* If an exception was thrown when applying any of these query optimizations, the action will fail
* and no Spark job will be submitted.
@@ -76,5 +76,5 @@ public interface DataSourceReader {
* If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*/
List<InputPartition<Row>> planInputPartitions();
List<InputPartition<InternalRow>> planInputPartitions();
Member

I am sorry for asking a question on an old PR like this, and it might not be directly related to this PR, but please allow me to ask it here. Does this mean developers should produce InternalRow here for each partition? InternalRow is under catalyst and not meant to be exposed.

Contributor

The rationale is that data source v2 is not stable yet, and we should make it usable first, so that more people implement data sources and provide feedback. Eventually we should design a stable and efficient row builder for data source v2, but for now we should switch to InternalRow. Row is too slow for implementing a decent data source (like Iceberg).

Member

Ah, okie. thanks!

}
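
To make the new contract concrete, here is a minimal, self-contained sketch of a source that plans InternalRow partitions (the class names and toy schema are illustrative, not part of this PR; note that string columns must be emitted as Catalyst's UTF8String):

```scala
import java.util.{Collections, List => JList}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader}
import org.apache.spark.sql.types.{LongType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical single-partition source over a numeric range, producing (id, name) rows.
class RangeDataSourceReader(end: Long) extends DataSourceReader {

  override def readSchema(): StructType =
    new StructType().add("id", LongType).add("name", StringType)

  override def planInputPartitions(): JList[InputPartition[InternalRow]] =
    Collections.singletonList(new RangeInputPartition(0L, end): InputPartition[InternalRow])
}

// The partition is Serializable and is shipped to executors, where it creates the reader.
class RangeInputPartition(start: Long, end: Long) extends InputPartition[InternalRow] {
  override def createPartitionReader(): InputPartitionReader[InternalRow] =
    new RangeInputPartitionReader(start, end)
}

class RangeInputPartitionReader(start: Long, end: Long)
  extends InputPartitionReader[InternalRow] {

  private var current = start - 1

  override def next(): Boolean = {
    current += 1
    current < end
  }

  override def get(): InternalRow =
    new GenericInternalRow(Array[Any](current, UTF8String.fromString(s"row-$current")))

  override def close(): Unit = ()
}
```
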
@@ -26,9 +26,10 @@
* An input partition reader returned by {@link InputPartition#createPartitionReader()}, responsible
* for outputting data for an RDD partition.
*
* Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal input
* partition readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for input
* partition readers that mix in {@link SupportsScanUnsafeRow}.
* Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}
* for normal data source readers, {@link org.apache.spark.sql.vectorized.ColumnarBatch} for data
* source readers that mix in {@link SupportsScanColumnarBatch}, or {@link org.apache.spark.sql.Row}
* for data source readers that mix in {@link SupportsDeprecatedScanRow}.
*/
@InterfaceStability.Evolving
public interface InputPartitionReader<T> extends Closeable {
@@ -17,30 +17,23 @@

package org.apache.spark.sql.sources.v2.reader;

import java.util.List;

import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.InternalRow;

import java.util.List;

/**
* A mix-in interface for {@link DataSourceReader}. Data source readers can implement this
* interface to output {@link UnsafeRow} directly and avoid the row copy at Spark side.
* This is an experimental and unstable interface, as {@link UnsafeRow} is not public and may get
* changed in the future Spark versions.
* interface to output {@link Row} instead of {@link InternalRow}.
* This is an experimental and unstable interface.
*/
@InterfaceStability.Unstable
public interface SupportsScanUnsafeRow extends DataSourceReader {

@Override
default List<InputPartition<Row>> planInputPartitions() {
public interface SupportsDeprecatedScanRow extends DataSourceReader {
default List<InputPartition<InternalRow>> planInputPartitions() {
throw new IllegalStateException(
"planInputPartitions not supported by default within SupportsScanUnsafeRow");
"planInputPartitions not supported by default within SupportsDeprecatedScanRow");
}

/**
* Similar to {@link DataSourceReader#planInputPartitions()},
* but returns data in unsafe row format.
*/
List<InputPartition<UnsafeRow>> planUnsafeInputPartitions();
List<InputPartition<Row>> planRowInputPartitions();
}
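
For comparison, a minimal sketch of a not-yet-migrated, Row-based reader using this deprecated mix-in (illustrative names, not code from this PR). As the DataSourceV2ScanExec changes later in this diff show, Spark wraps such readers with RowToUnsafeRowInputPartition to convert their output:

```scala
import java.util.{Collections, List => JList}

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.reader.{InputPartition, SupportsDeprecatedScanRow}
import org.apache.spark.sql.types.StructType

// Hypothetical reader that still returns Row: it implements planRowInputPartitions,
// while the default planInputPartitions inherited from the mix-in throws.
class LegacyRowReader(schema: StructType, partition: InputPartition[Row])
  extends SupportsDeprecatedScanRow {

  override def readSchema(): StructType = schema

  override def planRowInputPartitions(): JList[InputPartition[Row]] =
    Collections.singletonList(partition)
}
```
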
@@ -20,7 +20,7 @@
import java.util.List;

import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.vectorized.ColumnarBatch;

/**
@@ -30,7 +30,7 @@
@InterfaceStability.Evolving
public interface SupportsScanColumnarBatch extends DataSourceReader {
@Override
default List<InputPartition<Row>> planInputPartitions() {
default List<InputPartition<InternalRow>> planInputPartitions() {
throw new IllegalStateException(
"planInputPartitions not supported by default within SupportsScanColumnarBatch.");
}
@@ -17,7 +17,6 @@

package org.apache.spark.sql.execution.datasources.v2

import scala.collection.JavaConverters._
import scala.reflect.ClassTag

import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
@@ -75,12 +75,13 @@ case class DataSourceV2ScanExec(
case _ => super.outputPartitioning
}

private lazy val partitions: Seq[InputPartition[UnsafeRow]] = reader match {
case r: SupportsScanUnsafeRow => r.planUnsafeInputPartitions().asScala
case _ =>
reader.planInputPartitions().asScala.map {
new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[UnsafeRow]
private lazy val partitions: Seq[InputPartition[InternalRow]] = reader match {
case r: SupportsDeprecatedScanRow =>
r.planRowInputPartitions().asScala.map {
new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[InternalRow]
}
case _ =>
reader.planInputPartitions().asScala
}

private lazy val batchPartitions: Seq[InputPartition[ColumnarBatch]] = reader match {
@@ -132,11 +133,11 @@ case class DataSourceV2ScanExec(
}

class RowToUnsafeRowInputPartition(partition: InputPartition[Row], schema: StructType)
extends InputPartition[UnsafeRow] {
extends InputPartition[InternalRow] {

override def preferredLocations: Array[String] = partition.preferredLocations

override def createPartitionReader: InputPartitionReader[UnsafeRow] = {
override def createPartitionReader: InputPartitionReader[InternalRow] = {
new RowToUnsafeInputPartitionReader(
partition.createPartitionReader, RowEncoder.apply(schema).resolveAndBind())
}
Expand All @@ -146,7 +147,7 @@ class RowToUnsafeInputPartitionReader(
val rowReader: InputPartitionReader[Row],
encoder: ExpressionEncoder[Row])

extends InputPartitionReader[UnsafeRow] {
extends InputPartitionReader[InternalRow] {

override def next: Boolean = rowReader.next

@@ -125,16 +125,13 @@ object DataSourceV2Strategy extends Strategy {
val filterCondition = postScanFilters.reduceLeftOption(And)
val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan)

val withProjection = if (withFilter.output != project) {
ProjectExec(project, withFilter)
} else {
withFilter
}

withProjection :: Nil
// always add the projection, which will produce unsafe rows required by some operators
ProjectExec(project, withFilter) :: Nil

case r: StreamingDataSourceV2Relation =>
DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil
// ensure there is a projection, which will produce unsafe rows required by some operators
ProjectExec(r.output,
Contributor

The continuous streaming scan always returns unsafe rows; will we introduce a regression here? cc @jose-torres

Contributor

For now I think it's safer to still require DataSourceV2ScanExec to return unsafe rows, and to move the unsafe conversion here in a follow-up PR.

Contributor

Continuous processing will still be experimental in the 2.4 release, so I'm not tremendously concerned about this. We should eventually change the scan to produce rows in whatever way is most efficient in the final API.

Contributor Author

It is perfectly fine for sources to produce UnsafeRow because it is an InternalRow.

I think it is important for us to get to InternalRow in this release. UnsafeRow is too hard to produce and the easiest thing to do is to produce InternalRow and then call into Spark's UnsafeProjection to produce UnsafeRow. That's painful, uses internal APIs, and is slower.
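
For context, a small sketch of the engine-side conversion being described, using an illustrative two-column schema (this is not the PR's code; it only shows the InternalRow to UnsafeRow projection pattern that sources no longer need to implement themselves):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

object UnsafeConversionSketch {
  // Illustrative schema; in practice it would come from the reader's readSchema().
  private val schema = StructType(Seq(
    StructField("id", LongType),
    StructField("name", StringType)))

  private val toUnsafe = UnsafeProjection.create(schema)

  // The projection reuses a single UnsafeRow buffer, so callers that buffer rows must copy().
  def toUnsafeRows(rows: Iterator[InternalRow]): Iterator[InternalRow] =
    rows.map(toUnsafe)
}
```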

DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader)) :: Nil

case WriteToDataSourceV2(writer, query) =>
WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil
@@ -19,16 +19,16 @@ package org.apache.spark.sql.execution.streaming.continuous

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeInputPartitionReader}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.v2.RowToUnsafeInputPartitionReader
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, PartitionOffset}
import org.apache.spark.util.{NextIterator, ThreadUtils}
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader
import org.apache.spark.util.NextIterator

class ContinuousDataSourceRDDPartition(
val index: Int,
val inputPartition: InputPartition[UnsafeRow])
val inputPartition: InputPartition[InternalRow])
extends Partition with Serializable {

// This is semantically a lazy val - it's initialized once the first time a call to
@@ -51,8 +51,8 @@ class ContinuousDataSourceRDD(
sc: SparkContext,
dataQueueSize: Int,
epochPollIntervalMs: Long,
private val readerInputPartitions: Seq[InputPartition[UnsafeRow]])
extends RDD[UnsafeRow](sc, Nil) {
private val readerInputPartitions: Seq[InputPartition[InternalRow]])
extends RDD[InternalRow](sc, Nil) {

override protected def getPartitions: Array[Partition] = {
readerInputPartitions.zipWithIndex.map {
@@ -64,7 +64,7 @@
* Initialize the shared reader for this partition if needed, then read rows from it until
* it returns null to signal the end of the epoch.
*/
override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
// If attempt number isn't 0, this is a task retry, which we don't support.
if (context.attemptNumber() != 0) {
throw new ContinuousTaskRetryException()
@@ -80,8 +80,8 @@ class ContinuousDataSourceRDD(
partition.queueReader
}

new NextIterator[UnsafeRow] {
override def getNext(): UnsafeRow = {
new NextIterator[InternalRow] {
override def getNext(): InternalRow = {
readerForPartition.next() match {
case null =>
finished = true
@@ -101,9 +101,9 @@

object ContinuousDataSourceRDD {
private[continuous] def getContinuousReader(
reader: InputPartitionReader[UnsafeRow]): ContinuousInputPartitionReader[_] = {
reader: InputPartitionReader[InternalRow]): ContinuousInputPartitionReader[_] = {
reader match {
case r: ContinuousInputPartitionReader[UnsafeRow] => r
case r: ContinuousInputPartitionReader[InternalRow] => r
case wrapped: RowToUnsafeInputPartitionReader =>
wrapped.rowReader.asInstanceOf[ContinuousInputPartitionReader[Row]]
case _ =>