[SPARK-23315][SQL] failed to get output from canonicalized data source v2 related plans #20485

Closed (wants to merge 1 commit)
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2

import java.util.Objects

-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.sources.v2.reader._

/**
@@ -28,9 +28,9 @@ import org.apache.spark.sql.sources.v2.reader._
trait DataSourceReaderHolder {

/**
- * The full output of the data source reader, without column pruning.
+ * The output of the data source reader, w.r.t. column pruning.
*/
-def fullOutput: Seq[AttributeReference]
+def output: Seq[Attribute]

/**
* The held data source reader.
@@ -46,7 +46,7 @@ trait DataSourceReaderHolder {
case s: SupportsPushDownFilters => s.pushedFilters().toSet
case _ => Nil
}
-Seq(fullOutput, reader.getClass, reader.readSchema(), filters)
+Seq(output, reader.getClass, filters)
}

def canEqual(other: Any): Boolean
@@ -61,8 +61,4 @@ trait DataSourceReaderHolder {
override def hashCode(): Int = {
metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
}

-lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name =>
-  fullOutput.find(_.name == name).get
-}
}
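Why the lazy `output` above had to go: plan canonicalization normalizes attribute names to a placeholder, so the name-based lookup against `fullOutput` no longer finds the column names reported by `reader.readSchema()`, and the `.get` throws, which is the failure in the PR title. Below is a minimal sketch of that failure mode; it uses hypothetical simplified stand-ins rather than Spark's real `Attribute`/`DataSourceReader` types.

```scala
object CanonicalizedLookupSketch extends App {
  // Simplified stand-in for an attribute; real Spark attributes also carry exprId, type, etc.
  case class Attr(name: String)

  val fullOutput      = Seq(Attr("i"), Attr("j"))  // the relation's full output
  val readSchemaNames = Seq("i")                   // column names the reader reports after pruning

  // Pre-fix logic: resolve `output` lazily by name lookup against fullOutput.
  // This works on the original plan:
  println(readSchemaNames.map(n => fullOutput.find(_.name == n).get))  // List(Attr(i))

  // Canonicalization rewrites attribute names to a placeholder, so they no longer match
  // the names from readSchema(); the same lookup then calls `.get` on None and throws.
  val canonicalizedOutput = fullOutput.map(_.copy(name = "none"))
  // readSchemaNames.map(n => canonicalizedOutput.find(_.name == n).get)  // NoSuchElementException
}
```

Keeping the already-pruned `output` as a plain field, as this change does, means canonicalization can rewrite the attributes without ever re-running that lookup.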
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
import org.apache.spark.sql.sources.v2.reader._

case class DataSourceV2Relation(
-fullOutput: Seq[AttributeReference],
+output: Seq[AttributeReference],
reader: DataSourceReader)
extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder {

@@ -37,7 +37,7 @@ case class DataSourceV2Relation(
}

override def newInstance(): DataSourceV2Relation = {
-copy(fullOutput = fullOutput.map(_.newInstance()))
+copy(output = output.map(_.newInstance()))
}
}

@@ -46,8 +46,8 @@ case class DataSourceV2Relation(
* to the non-streaming relation.
*/
class StreamingDataSourceV2Relation(
-fullOutput: Seq[AttributeReference],
-reader: DataSourceReader) extends DataSourceV2Relation(fullOutput, reader) {
+output: Seq[AttributeReference],
+reader: DataSourceReader) extends DataSourceV2Relation(output, reader) {
override def isStreaming: Boolean = true
}

@@ -35,14 +35,12 @@ import org.apache.spark.sql.types.StructType
* Physical plan node for scanning data from a data source.
*/
case class DataSourceV2ScanExec(
-fullOutput: Seq[AttributeReference],
+output: Seq[AttributeReference],
@transient reader: DataSourceReader)
extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan {

override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec]

-override def producedAttributes: AttributeSet = AttributeSet(fullOutput)

override def outputPartitioning: physical.Partitioning = reader match {
case s: SupportsReportPartitioning =>
new DataSourcePartitioning(
@@ -81,33 +81,44 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel

// TODO: add more push down rules.

-pushDownRequiredColumns(filterPushed, filterPushed.outputSet)
+val columnPruned = pushDownRequiredColumns(filterPushed, filterPushed.outputSet)
// After column pruning, we may have redundant PROJECT nodes in the query plan, remove them.
-RemoveRedundantProject(filterPushed)
+RemoveRedundantProject(columnPruned)
}

// TODO: nested fields pruning
-private def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: AttributeSet): Unit = {
+private def pushDownRequiredColumns(
+    plan: LogicalPlan, requiredByParent: AttributeSet): LogicalPlan = {
plan match {
-case Project(projectList, child) =>
+case p @ Project(projectList, child) =>
val required = projectList.flatMap(_.references)
-pushDownRequiredColumns(child, AttributeSet(required))
+p.copy(child = pushDownRequiredColumns(child, AttributeSet(required)))

-case Filter(condition, child) =>
+case f @ Filter(condition, child) =>
val required = requiredByParent ++ condition.references
-pushDownRequiredColumns(child, required)
+f.copy(child = pushDownRequiredColumns(child, required))

case relation: DataSourceV2Relation => relation.reader match {
case reader: SupportsPushDownRequiredColumns =>
+// TODO: Enable the below assert after we make `DataSourceV2Relation` immutable. Fow now
Review comment (Member): Typo: Fow

+// it's possible that the mutable reader being updated by someone else, and we need to
+// always call `reader.pruneColumns` here to correct it.
+// assert(relation.output.toStructType == reader.readSchema(),
+//   "Schema of data source reader does not match the relation plan.")

val requiredColumns = relation.output.filter(requiredByParent.contains)
reader.pruneColumns(requiredColumns.toStructType)

-case _ =>
+val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap
+val newOutput = reader.readSchema().map(_.name).map(nameToAttr)
+relation.copy(output = newOutput)
Review comment (Contributor, author): @rdblue This is the bug I mentioned before. Finally I figured out a way to fix it surgically: always run column pruning, even if no columns need to be pruned. This corrects the required schema of the reader in case it was updated by someone else.


+case _ => relation
}

// TODO: there may be more operators that can be used to calculate the required columns. We
// can add more and more in the future.
-case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.outputSet))
+case _ => plan.mapChildren(c => pushDownRequiredColumns(c, c.outputSet))
}
}
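To make the re-sync idea from the author's comment above concrete: the rule now always calls `pruneColumns` and then rebuilds the relation's output from whatever schema the mutable reader currently reports, so the plan cannot drift out of sync with the reader. Below is a rough sketch using hypothetical simplified types, not Spark's real `DataSourceReader` or `DataSourceV2Relation`.

```scala
object PruneAndResyncSketch extends App {
  case class Attr(name: String)
  case class Relation(output: Seq[Attr])

  // Stand-in for a mutable reader supporting column pruning: pruneColumns replaces
  // the required schema with exactly the columns it is given.
  class MutableReader(full: Seq[String]) {
    private var required: Seq[String] = full
    def pruneColumns(columns: Seq[String]): Unit = required = columns
    def readSchema(): Seq[String] = required
  }

  def pruneAndResync(relation: Relation, reader: MutableReader, requiredByParent: Set[String]): Relation = {
    val requiredColumns = relation.output.filter(a => requiredByParent.contains(a.name))
    // Always prune, even when no column needs to be dropped, so the reader is reset
    // to exactly what this plan requires ...
    reader.pruneColumns(requiredColumns.map(_.name))
    // ... then rebuild the relation's output from the reader's current schema,
    // so plan and reader stay consistent.
    val nameToAttr = relation.output.map(a => a.name -> a).toMap
    Relation(reader.readSchema().map(nameToAttr))
  }

  val reader = new MutableReader(Seq("i", "j"))
  reader.pruneColumns(Seq("i"))                        // "someone else" already pruned the reader
  val relation = Relation(Seq(Attr("i"), Attr("j")))   // but the plan still needs both columns

  println(pruneAndResync(relation, reader, Set("i", "j")))  // Relation(List(Attr(i), Attr(j)))
}
```

The point of the sketch is that returning a new plan whose output mirrors `readSchema()` makes the pruning rule idempotent with respect to whatever state the reader is already in.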

@@ -24,7 +24,7 @@ import test.org.apache.spark.sql.sources.v2._
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.functions._
@@ -316,6 +316,24 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
val reader4 = getReader(q4)
assert(reader4.requiredSchema.fieldNames === Seq("i"))
}

test("SPARK-23315: get output from canonicalized data source v2 related plans") {
def checkCanonicalizedOutput(df: DataFrame, numOutput: Int): Unit = {
val logical = df.queryExecution.optimizedPlan.collect {
case d: DataSourceV2Relation => d
}.head
assert(logical.canonicalized.output.length == numOutput)

val physical = df.queryExecution.executedPlan.collect {
case d: DataSourceV2ScanExec => d
}.head
assert(physical.canonicalized.output.length == numOutput)
}

val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
checkCanonicalizedOutput(df, 2)
checkCanonicalizedOutput(df.select('i), 1)
}
}

class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {