[SPARK-23325] Use InternalRow when reading with DataSourceV2.
## What changes were proposed in this pull request?

This updates the DataSourceV2 API to use InternalRow instead of Row for the default case with no scan mix-ins.
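
For illustration, a minimal reader against the new default (no scan mix-in) might look like the sketch below; the class names, schema, and data are hypothetical, and only the `DataSourceReader`/`InputPartition`/`InputPartitionReader` shapes come from this change:

```scala
import java.util.{List => JList}

import scala.collection.JavaConverters._

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Hypothetical reader that returns the integers [start, end) as InternalRow,
// using the default planInputPartitions() with no SupportsScan* mix-in.
class SimpleInternalRowReader(start: Int, end: Int) extends DataSourceReader {
  override def readSchema(): StructType = StructType(Seq(StructField("i", IntegerType)))

  override def planInputPartitions(): JList[InputPartition[InternalRow]] =
    Seq(new SimpleInputPartition(start, end): InputPartition[InternalRow]).asJava
}

class SimpleInputPartition(start: Int, end: Int) extends InputPartition[InternalRow] {
  override def createPartitionReader(): InputPartitionReader[InternalRow] =
    new InputPartitionReader[InternalRow] {
      private var current = start - 1
      override def next(): Boolean = { current += 1; current < end }
      override def get(): InternalRow = new GenericInternalRow(Array[Any](current))
      override def close(): Unit = ()
    }
}
```

With this, the reader can hand back any InternalRow implementation; the planner now always places a projection above the scan to produce the UnsafeRows some operators require (see the DataSourceV2Strategy change below).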

Support for readers that produce Row is added through SupportsDeprecatedScanRow, which matches the previous API. Readers that used Row now implement this interface and should be migrated to InternalRow.
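
As a sketch of that compatibility path (again with made-up names and data; only the `SupportsDeprecatedScanRow`/`planRowInputPartitions` shape comes from this change), a reader that still produces Row would look roughly like:

```scala
import java.util.{List => JList}

import scala.collection.JavaConverters._

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsDeprecatedScanRow}
import org.apache.spark.sql.types.{LongType, StructField, StructType}

// Hypothetical reader that keeps producing external Row through the deprecated mix-in.
class LegacyRowReader(values: Seq[Long]) extends SupportsDeprecatedScanRow {
  override def readSchema(): StructType = StructType(Seq(StructField("value", LongType)))

  override def planRowInputPartitions(): JList[InputPartition[Row]] =
    Seq(new LegacyRowInputPartition(values): InputPartition[Row]).asJava
}

class LegacyRowInputPartition(values: Seq[Long]) extends InputPartition[Row] {
  override def createPartitionReader(): InputPartitionReader[Row] =
    new InputPartitionReader[Row] {
      private val iter = values.iterator
      private var current: Row = _
      override def next(): Boolean = {
        if (iter.hasNext) { current = Row(iter.next()); true } else false
      }
      override def get(): Row = current
      override def close(): Unit = ()
    }
}
```

Spark wraps such readers with a Row conversion (RowToUnsafeRowInputPartition in the diff below), so they keep working but pay a per-row conversion cost.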

Readers that previously implemented SupportsScanUnsafeRow have been migrated to use no SupportsScan mix-ins and produce InternalRow.
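
Because UnsafeRow is itself an InternalRow, those readers mostly just change the declared element type; a hypothetical minimal shape (not taken from this diff) is:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader

// Hypothetical reader: the underlying iterator still yields UnsafeRow, but the
// partition reader is declared against InternalRow and needs no SupportsScan* mix-in.
class UnsafeBackedPartitionReader(underlying: Iterator[UnsafeRow])
  extends InputPartitionReader[InternalRow] {

  private var current: UnsafeRow = _

  override def next(): Boolean = {
    if (underlying.hasNext) { current = underlying.next(); true } else false
  }

  override def get(): InternalRow = current

  override def close(): Unit = ()
}
```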

## How was this patch tested?

This uses existing tests.

Author: Ryan Blue <blue@apache.org>

Closes #21118 from rdblue/SPARK-23325-datasource-v2-internal-row.
rdblue authored and gatorsmile committed Jul 24, 2018
1 parent 3d5c61e commit 9d27541
Showing 28 changed files with 138 additions and 135 deletions.
@@ -26,6 +26,7 @@ import org.apache.kafka.common.TopicPartition
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.sources.v2.reader._
@@ -53,7 +54,7 @@ class KafkaContinuousReader(
metadataPath: String,
initialOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean)
extends ContinuousReader with SupportsScanUnsafeRow with Logging {
extends ContinuousReader with Logging {

private lazy val session = SparkSession.getActiveSession.get
private lazy val sc = session.sparkContext
@@ -86,7 +87,7 @@ class KafkaContinuousReader(
KafkaSourceOffset(JsonUtils.partitionOffsets(json))
}

override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = {
override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = {
import scala.collection.JavaConverters._

val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset)
@@ -107,8 +108,8 @@ class KafkaContinuousReader(
startOffsets.toSeq.map {
case (topicPartition, start) =>
KafkaContinuousInputPartition(
topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
.asInstanceOf[InputPartition[UnsafeRow]]
topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss
): InputPartition[InternalRow]
}.asJava
}

@@ -161,9 +162,10 @@ case class KafkaContinuousInputPartition(
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean) extends ContinuousInputPartition[UnsafeRow] {
failOnDataLoss: Boolean) extends ContinuousInputPartition[InternalRow] {

override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[UnsafeRow] = {
override def createContinuousReader(
offset: PartitionOffset): InputPartitionReader[InternalRow] = {
val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset]
require(kafkaOffset.topicPartition == topicPartition,
s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}")
@@ -192,7 +194,7 @@ class KafkaContinuousInputPartitionReader(
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[UnsafeRow] {
failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[InternalRow] {
private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false)
private val converter = new KafkaRecordToUnsafeRowConverter

@@ -29,11 +29,12 @@ import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset}
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow}
import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.UninterruptibleThread
@@ -61,7 +62,7 @@ private[kafka010] class KafkaMicroBatchReader(
metadataPath: String,
startingOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean)
extends MicroBatchReader with SupportsScanUnsafeRow with Logging {
extends MicroBatchReader with Logging {

private var startPartitionOffsets: PartitionOffsetMap = _
private var endPartitionOffsets: PartitionOffsetMap = _
@@ -101,7 +102,7 @@ private[kafka010] class KafkaMicroBatchReader(
}
}

override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = {
override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = {
// Find the new partitions, and get their earliest offsets
val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet)
val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq)
@@ -142,11 +143,11 @@ private[kafka010] class KafkaMicroBatchReader(
val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size

// Generate factories based on the offset ranges
val factories = offsetRanges.map { range =>
offsetRanges.map { range =>
new KafkaMicroBatchInputPartition(
range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer)
}
factories.map(_.asInstanceOf[InputPartition[UnsafeRow]]).asJava
range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer
): InputPartition[InternalRow]
}.asJava
}

override def getStartOffset: Offset = {
@@ -305,11 +306,11 @@ private[kafka010] case class KafkaMicroBatchInputPartition(
executorKafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean,
reuseKafkaConsumer: Boolean) extends InputPartition[UnsafeRow] {
reuseKafkaConsumer: Boolean) extends InputPartition[InternalRow] {

override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray

override def createPartitionReader(): InputPartitionReader[UnsafeRow] =
override def createPartitionReader(): InputPartitionReader[InternalRow] =
new KafkaMicroBatchInputPartitionReader(offsetRange, executorKafkaParams, pollTimeoutMs,
failOnDataLoss, reuseKafkaConsumer)
}
@@ -320,7 +321,7 @@ private[kafka010] case class KafkaMicroBatchInputPartitionReader(
executorKafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean,
reuseKafkaConsumer: Boolean) extends InputPartitionReader[UnsafeRow] with Logging {
reuseKafkaConsumer: Boolean) extends InputPartitionReader[InternalRow] with Logging {

private val consumer = KafkaDataConsumer.acquire(
offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer)
@@ -678,7 +678,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase {
Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))),
Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L)))
)
val factories = reader.planUnsafeInputPartitions().asScala
val factories = reader.planInputPartitions().asScala
.map(_.asInstanceOf[KafkaMicroBatchInputPartition])
withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") {
assert(factories.size == numPartitionsGenerated)
@@ -20,7 +20,7 @@
import java.util.List;

import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.ReadSupport;
import org.apache.spark.sql.sources.v2.ReadSupportWithSchema;
@@ -43,7 +43,7 @@
* Names of these interfaces start with `SupportsScan`. Note that a reader should implement
* at most one of the special scans; if more than one special scan is implemented, only one
* of them will be respected, according to the following priority list, from high to low:
* {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}.
* {@link SupportsScanColumnarBatch}, {@link SupportsDeprecatedScanRow}.
*
* If an exception is thrown when applying any of these query optimizations, the action will fail
* and no Spark job will be submitted.
@@ -76,5 +76,5 @@ public interface DataSourceReader {
* If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*/
List<InputPartition<Row>> planInputPartitions();
List<InputPartition<InternalRow>> planInputPartitions();
}
@@ -26,9 +26,10 @@
* An input partition reader returned by {@link InputPartition#createPartitionReader()}, which is
* responsible for outputting data for an RDD partition.
*
* Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal input
* partition readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for input
* partition readers that mix in {@link SupportsScanUnsafeRow}.
* Note that currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}
* for normal data source readers, {@link org.apache.spark.sql.vectorized.ColumnarBatch} for data
* source readers that mix in {@link SupportsScanColumnarBatch}, or {@link org.apache.spark.sql.Row}
* for data source readers that mix in {@link SupportsDeprecatedScanRow}.
*/
@InterfaceStability.Evolving
public interface InputPartitionReader<T> extends Closeable {
@@ -17,30 +17,23 @@

package org.apache.spark.sql.sources.v2.reader;

import java.util.List;

import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.InternalRow;

import java.util.List;

/**
* A mix-in interface for {@link DataSourceReader}. Data source readers can implement this
* interface to output {@link UnsafeRow} directly and avoid the row copy at Spark side.
* This is an experimental and unstable interface, as {@link UnsafeRow} is not public and may get
* changed in the future Spark versions.
* interface to output {@link Row} instead of {@link InternalRow}.
* This is an experimental and unstable interface.
*/
@InterfaceStability.Unstable
public interface SupportsScanUnsafeRow extends DataSourceReader {

@Override
default List<InputPartition<Row>> planInputPartitions() {
public interface SupportsDeprecatedScanRow extends DataSourceReader {
default List<InputPartition<InternalRow>> planInputPartitions() {
throw new IllegalStateException(
"planInputPartitions not supported by default within SupportsScanUnsafeRow");
"planInputPartitions not supported by default within SupportsDeprecatedScanRow");
}

/**
* Similar to {@link DataSourceReader#planInputPartitions()},
* but returns data in unsafe row format.
*/
List<InputPartition<UnsafeRow>> planUnsafeInputPartitions();
List<InputPartition<Row>> planRowInputPartitions();
}
@@ -20,7 +20,7 @@
import java.util.List;

import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.vectorized.ColumnarBatch;

/**
@@ -30,7 +30,7 @@
@InterfaceStability.Evolving
public interface SupportsScanColumnarBatch extends DataSourceReader {
@Override
default List<InputPartition<Row>> planInputPartitions() {
default List<InputPartition<InternalRow>> planInputPartitions() {
throw new IllegalStateException(
"planInputPartitions not supported by default within SupportsScanColumnarBatch.");
}
@@ -17,7 +17,6 @@

package org.apache.spark.sql.execution.datasources.v2

import scala.collection.JavaConverters._
import scala.reflect.ClassTag

import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
@@ -75,12 +75,13 @@ case class DataSourceV2ScanExec(
case _ => super.outputPartitioning
}

private lazy val partitions: Seq[InputPartition[UnsafeRow]] = reader match {
case r: SupportsScanUnsafeRow => r.planUnsafeInputPartitions().asScala
case _ =>
reader.planInputPartitions().asScala.map {
new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[UnsafeRow]
private lazy val partitions: Seq[InputPartition[InternalRow]] = reader match {
case r: SupportsDeprecatedScanRow =>
r.planRowInputPartitions().asScala.map {
new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[InternalRow]
}
case _ =>
reader.planInputPartitions().asScala
}

private lazy val batchPartitions: Seq[InputPartition[ColumnarBatch]] = reader match {
Expand Down Expand Up @@ -132,11 +133,11 @@ case class DataSourceV2ScanExec(
}

class RowToUnsafeRowInputPartition(partition: InputPartition[Row], schema: StructType)
extends InputPartition[UnsafeRow] {
extends InputPartition[InternalRow] {

override def preferredLocations: Array[String] = partition.preferredLocations

override def createPartitionReader: InputPartitionReader[UnsafeRow] = {
override def createPartitionReader: InputPartitionReader[InternalRow] = {
new RowToUnsafeInputPartitionReader(
partition.createPartitionReader, RowEncoder.apply(schema).resolveAndBind())
}
@@ -146,7 +147,7 @@ class RowToUnsafeInputPartitionReader(
val rowReader: InputPartitionReader[Row],
encoder: ExpressionEncoder[Row])

extends InputPartitionReader[UnsafeRow] {
extends InputPartitionReader[InternalRow] {

override def next: Boolean = rowReader.next

@@ -125,16 +125,13 @@ object DataSourceV2Strategy extends Strategy {
val filterCondition = postScanFilters.reduceLeftOption(And)
val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan)

val withProjection = if (withFilter.output != project) {
ProjectExec(project, withFilter)
} else {
withFilter
}

withProjection :: Nil
// always add the projection, which will produce unsafe rows required by some operators
ProjectExec(project, withFilter) :: Nil

case r: StreamingDataSourceV2Relation =>
DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil
// ensure there is a projection, which will produce unsafe rows required by some operators
ProjectExec(r.output,
DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader)) :: Nil

case WriteToDataSourceV2(writer, query) =>
WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil
@@ -19,16 +19,16 @@ package org.apache.spark.sql.execution.streaming.continuous

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeInputPartitionReader}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.v2.RowToUnsafeInputPartitionReader
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, PartitionOffset}
import org.apache.spark.util.{NextIterator, ThreadUtils}
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader
import org.apache.spark.util.NextIterator

class ContinuousDataSourceRDDPartition(
val index: Int,
val inputPartition: InputPartition[UnsafeRow])
val inputPartition: InputPartition[InternalRow])
extends Partition with Serializable {

// This is semantically a lazy val - it's initialized once the first time a call to
Expand All @@ -51,8 +51,8 @@ class ContinuousDataSourceRDD(
sc: SparkContext,
dataQueueSize: Int,
epochPollIntervalMs: Long,
private val readerInputPartitions: Seq[InputPartition[UnsafeRow]])
extends RDD[UnsafeRow](sc, Nil) {
private val readerInputPartitions: Seq[InputPartition[InternalRow]])
extends RDD[InternalRow](sc, Nil) {

override protected def getPartitions: Array[Partition] = {
readerInputPartitions.zipWithIndex.map {
@@ -64,7 +64,7 @@
* Initialize the shared reader for this partition if needed, then read rows from it until
* it returns null to signal the end of the epoch.
*/
override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
// If attempt number isn't 0, this is a task retry, which we don't support.
if (context.attemptNumber() != 0) {
throw new ContinuousTaskRetryException()
@@ -80,8 +80,8 @@ class ContinuousDataSourceRDD(
partition.queueReader
}

new NextIterator[UnsafeRow] {
override def getNext(): UnsafeRow = {
new NextIterator[InternalRow] {
override def getNext(): InternalRow = {
readerForPartition.next() match {
case null =>
finished = true
@@ -101,9 +101,9 @@

object ContinuousDataSourceRDD {
private[continuous] def getContinuousReader(
reader: InputPartitionReader[UnsafeRow]): ContinuousInputPartitionReader[_] = {
reader: InputPartitionReader[InternalRow]): ContinuousInputPartitionReader[_] = {
reader match {
case r: ContinuousInputPartitionReader[UnsafeRow] => r
case r: ContinuousInputPartitionReader[InternalRow] => r
case wrapped: RowToUnsafeInputPartitionReader =>
wrapped.rowReader.asInstanceOf[ContinuousInputPartitionReader[Row]]
case _ =>
