Skip to content

Commit

Permalink
[SPARK-23574][SQL] Report SinglePartition in DataSourceV2ScanExec whe…
Browse files Browse the repository at this point in the history
…n there's exactly 1 data reader factory.

## What changes were proposed in this pull request?

Report SinglePartition in DataSourceV2ScanExec when there's exactly 1 data reader factory.

Note that this means reader factories end up being constructed as partitioning is checked; let me know if you think that could be a problem.

## How was this patch tested?

existing unit tests

Author: Jose Torres <jose@databricks.com>
Author: Jose Torres <torres.joseph.f+github@gmail.com>

Closes #20726 from jose-torres/SPARK-23574.
  • Loading branch information
jose-torres authored and cloud-fan committed Mar 20, 2018
1 parent 7f5e8aa commit 2c4b996
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 14 deletions.
Expand Up @@ -23,6 +23,9 @@
/**
* A mix in interface for {@link DataSourceReader}. Data source readers can implement this
* interface to report data partitioning and try to avoid shuffle at Spark side.
*
* Note that, when the reader creates exactly one {@link DataReaderFactory}, Spark may avoid
* adding a shuffle even if the reader does not implement this interface.
*/
@InterfaceStability.Evolving
public interface SupportsReportPartitioning extends DataSourceReader {
Expand Down
Expand Up @@ -29,11 +29,11 @@ class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: Da

class DataSourceRDD[T: ClassTag](
sc: SparkContext,
@transient private val readerFactories: java.util.List[DataReaderFactory[T]])
@transient private val readerFactories: Seq[DataReaderFactory[T]])
extends RDD[T](sc, Nil) {

override protected def getPartitions: Array[Partition] = {
readerFactories.asScala.zipWithIndex.map {
readerFactories.zipWithIndex.map {
case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
}.toArray
}
Expand Down
Expand Up @@ -25,12 +25,14 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.sources.v2.DataSourceV2
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
* Physical plan node for scanning data from a data source.
Expand All @@ -56,36 +58,49 @@ case class DataSourceV2ScanExec(
}

override def outputPartitioning: physical.Partitioning = reader match {
case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchReaderFactories.size == 1 =>
SinglePartition

case r: SupportsScanColumnarBatch if !r.enableBatchRead() && readerFactories.size == 1 =>
SinglePartition

case r if !r.isInstanceOf[SupportsScanColumnarBatch] && readerFactories.size == 1 =>
SinglePartition

case s: SupportsReportPartitioning =>
new DataSourcePartitioning(
s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name)))

case _ => super.outputPartitioning
}

private lazy val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]] = reader match {
case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories()
private lazy val readerFactories: Seq[DataReaderFactory[UnsafeRow]] = reader match {
case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories().asScala
case _ =>
reader.createDataReaderFactories().asScala.map {
new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow]
}.asJava
}
}

private lazy val inputRDD: RDD[InternalRow] = reader match {
private lazy val batchReaderFactories: Seq[DataReaderFactory[ColumnarBatch]] = reader match {
case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
assert(!reader.isInstanceOf[ContinuousReader],
"continuous stream reader does not support columnar read yet.")
new DataSourceRDD(sparkContext, r.createBatchDataReaderFactories())
.asInstanceOf[RDD[InternalRow]]
r.createBatchDataReaderFactories().asScala
}

private lazy val inputRDD: RDD[InternalRow] = reader match {
case _: ContinuousReader =>
EpochCoordinatorRef.get(
sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
sparkContext.env)
.askSync[Unit](SetReaderPartitions(readerFactories.size()))
.askSync[Unit](SetReaderPartitions(readerFactories.size))
new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories)
.asInstanceOf[RDD[InternalRow]]

case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
new DataSourceRDD(sparkContext, batchReaderFactories).asInstanceOf[RDD[InternalRow]]

case _ =>
new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]]
}
Expand Down
Expand Up @@ -35,14 +35,14 @@ import org.apache.spark.util.ThreadUtils
class ContinuousDataSourceRDD(
sc: SparkContext,
sqlContext: SQLContext,
@transient private val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]])
@transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]])
extends RDD[UnsafeRow](sc, Nil) {

private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize
private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs

override protected def getPartitions: Array[Partition] = {
readerFactories.asScala.zipWithIndex.map {
readerFactories.zipWithIndex.map {
case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
}.toArray
}
Expand Down
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.functions._
import org.apache.spark.sql.sources.{Filter, GreaterThan}
Expand Down Expand Up @@ -191,6 +191,11 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
}
}

test("SPARK-23574: no shuffle exchange with single partition") {
val df = spark.read.format(classOf[SimpleSinglePartitionSource].getName).load().agg(count("*"))
assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.isEmpty)
}

test("simple writable data source") {
// TODO: java implementation.
Seq(classOf[SimpleWritableDataSource]).foreach { cls =>
Expand Down Expand Up @@ -336,6 +341,19 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
}
}

class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport {

class Reader extends DataSourceReader {
override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")

override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5))
}
}

override def createReader(options: DataSourceOptions): DataSourceReader = new Reader
}

class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {

class Reader extends DataSourceReader {
Expand Down
Expand Up @@ -326,9 +326,9 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi

assert(progress.durationMs.get("setOffsetRange") === 50)
assert(progress.durationMs.get("getEndOffset") === 100)
assert(progress.durationMs.get("queryPlanning") === 0)
assert(progress.durationMs.get("queryPlanning") === 200)
assert(progress.durationMs.get("walCommit") === 0)
assert(progress.durationMs.get("addBatch") === 350)
assert(progress.durationMs.get("addBatch") === 150)
assert(progress.durationMs.get("triggerExecution") === 500)

assert(progress.sources.length === 1)
Expand Down

0 comments on commit 2c4b996

Please sign in to comment.