[SPARK-51831][SQL] Column pruning with existsJoin for Datasource V2 #51046

Open
wants to merge 5 commits into master

SparkOptimizer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, PushVariantIntoScan, SchemaPruning, V1Writes}
-import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering, V2ScanRelationPushDown, V2Writes}
import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, PostV2ScanRelationPushDown, V2ScanPartitioningAndOrdering, V2ScanRelationPushDown, V2Writes}
import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning, RowLevelOperationRuntimeGroupFiltering}
import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs, ExtractPythonUDTFs}

@@ -95,6 +95,7 @@ class SparkOptimizer(
LimitPushDownThroughWindow,
ConstantFolding,
EliminateLimits),
Batch("Post Push down for V2 Relations", FixedPoint(1), PostV2ScanRelationPushDown),
Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*),
Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition)))

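For orientation: the new batch runs at most once (FixedPoint(1)) and only after the operator-optimization batches, by which point EXISTS/IN subqueries have already been rewritten into semi/anti joins, so a subquery's SELECT * no longer forces a full-width scan and column pruning can be applied a second time. A hedged Scala sketch of the kind of query this re-pruning targets (assumes a SparkSession named spark and made-up tables customers/orders, not anything from this PR):

// Illustrative only: once the EXISTS predicate becomes a LEFT SEMI join during
// optimization, the scan of `orders` only needs `customer_id`, so a later pruning
// pass over the DSv2 scan can drop the remaining columns.
val pruned = spark.sql(
  """SELECT c.id
    |FROM customers c
    |WHERE EXISTS (SELECT * FROM orders o WHERE o.customer_id = c.id)
    |""".stripMargin)
pruned.explain(true)  // the optimized plan should show the semi join and the narrowed scan schema
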

V2ScanRelationPushDown.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.optimizer.CollapseProject
import org.apache.spark.sql.catalyst.planning.{PhysicalOperation, ScanOperation}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder}
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, CountStar, Max, Min, Sum}
@@ -38,7 +39,30 @@ import org.apache.spark.sql.types.{DataType, DecimalType, IntegerType, StructTyp
import org.apache.spark.sql.util.SchemaUtils._
import org.apache.spark.util.ArrayImplicits._

object PostV2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
import V2ScanRelationPushDown._

def apply(plan: LogicalPlan): LogicalPlan = {
val pushdownRules = Seq[LogicalPlan => LogicalPlan] (
createScanBuilder,
pruneColumns)

pushdownRules.foldLeft(plan) { (newPlan, pushDownRule) =>
pushDownRule(newPlan)
}
}

private def createScanBuilder(plan: LogicalPlan) = plan.transform {
case r @ DataSourceV2ScanRelation(relation, _, _, _, _)
if relation.getTagValue(V2_SCAN_BUILDER_HOLDER).nonEmpty =>
Contributor commented: When will the content of tags be removed?

Contributor Author replied: yeap

val sHolder = relation.getTagValue(V2_SCAN_BUILDER_HOLDER).get
sHolder.cachedScanRelation = Some(r)
sHolder
}
}

object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
val V2_SCAN_BUILDER_HOLDER = TreeNodeTag[ScanBuilderHolder]("v2_scan_builder_holder")
import DataSourceV2Implicits._

def apply(plan: LogicalPlan): LogicalPlan = {
@@ -105,7 +129,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {

private def rewriteAggregate(agg: Aggregate): LogicalPlan = agg.child match {
case PhysicalOperation(project, Nil, holder @ ScanBuilderHolder(_, _,
-r: SupportsPushDownAggregates)) if CollapseProject.canCollapseExpressions(
r: SupportsPushDownAggregates, _)) if CollapseProject.canCollapseExpressions(
agg.aggregateExpressions, project, alwaysInline = true) =>
val aliasMap = getAliasMap(project)
val actualResultExprs = agg.aggregateExpressions.map(replaceAliasButKeepName(_, aliasMap))
@@ -374,7 +398,25 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {

val wrappedScan = getWrappedScan(scan, sHolder)

-val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
def sameOutput(
cachedOutput: Seq[AttributeReference], newOutput: Seq[AttributeReference]): Boolean = {
cachedOutput.size == newOutput.size &&
cachedOutput.zip(newOutput).forall { case (cachedField, newField) =>
cachedField.canonicalized.semanticEquals(newField.canonicalized)
}
}

val scanRelation: DataSourceV2ScanRelation = if (sHolder.cachedScanRelation.isEmpty) {
val relation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
// reuse sHolder to support column pruning after optimization
sHolder.output = output
sHolder.relation.setTagValue(V2_SCAN_BUILDER_HOLDER, sHolder)
relation
} else if (sameOutput(sHolder.output, output)) {
sHolder.cachedScanRelation.get
} else {
sHolder.cachedScanRelation.get.copy(scan = wrappedScan, output = output)
}

val projectionOverSchema =
ProjectionOverSchema(output.toStructType, AttributeSet(output))
@@ -559,7 +601,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
case class ScanBuilderHolder(
var output: Seq[AttributeReference],
relation: DataSourceV2Relation,
-builder: ScanBuilder) extends LeafNode {
builder: ScanBuilder,
var cachedScanRelation: Option[DataSourceV2ScanRelation] = None) extends LeafNode {
var pushedLimit: Option[Int] = None

var pushedOffset: Option[Int] = None
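
The hand-off between the two rules goes through Catalyst's TreeNodeTag mechanism: the initial V2ScanRelationPushDown pass stores the ScanBuilderHolder on the DataSourceV2Relation under the v2_scan_builder_holder tag, and PostV2ScanRelationPushDown reads it back to re-run column pruning against the post-optimization output. A minimal, self-contained sketch of that tag API (the object name, tag name, and String payload below are illustrative stand-ins, not the actual holder):

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.TreeNodeTag

object TreeNodeTagSketch {
  // A tag attaches side-band state to a plan node without changing its structure or output.
  val exampleTag = TreeNodeTag[String]("example_tag")

  def stash(plan: LogicalPlan, state: String): Unit = plan.setTagValue(exampleTag, state)

  def restore(plan: LogicalPlan): Option[String] = plan.getTagValue(exampleTag)

  // A tag can be cleared with plan.unsetTagValue(exampleTag) once it is no longer
  // needed (the inline review question above asks when that happens in this PR).
}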

SchemaPruningSuite.scala
@@ -658,16 +658,69 @@ abstract class SchemaPruningSuite
|where not exists (select null from employees e where e.name.first = c.name.first
| and e.employer.name = c.employer.company.name)
|""".stripMargin)
-checkScan(query,
-"struct<name:struct<first:string,middle:string,last:string>," +
-"employer:struct<id:int,company:struct<name:string,address:string>>>",
-"struct<name:struct<first:string,middle:string,last:string>," +
-"employer:struct<name:string,address:string>>")
// TODO: SPARK-51381: Fix the schema pruning for V1 nested columns
if (SQLConf.get.getConf(SQLConf.USE_V1_SOURCE_LIST).contains(dataSourceName)) {
checkScan(query,
"struct<name:struct<first:string,middle:string,last:string>," +
"employer:struct<id:int,company:struct<name:string,address:string>>>",
"struct<name:struct<first:string,middle:string,last:string>," +
"employer:struct<name:string,address:string>>")
} else {
checkScan(query,
"struct<name:struct<first:string>," +
"employer:struct<company:struct<name:string>>>",
"struct<name:struct<first:string>," +
"employer:struct<name:string>>")
}
checkAnswer(query, Row(3))
}
}
}

testSchemaPruning("SPARK-51831: Column pruning with exists Join") {
withTempPath { dir =>
spark.range(100)
.withColumn("col1", col("id") + 1)
.withColumn("col2", col("id") + 2)
.withColumn("col3", col("id") + 3)
.withColumn("col4", col("id") + 4)
.withColumn("col5", col("id") + 5)
.withColumn("col6", col("id") + 6)
.withColumn("col7", col("id") + 7)
.withColumn("col8", col("id") + 8)
.withColumn("col9", col("id") + 9)
.write
.mode("overwrite")
.format(dataSourceName)
.save(dir.getCanonicalPath + "/t1")
spark.range(10)
.write
.mode("overwrite")
.format(dataSourceName)
.save(dir.getCanonicalPath + "/t2")

spark.read
.format(dataSourceName)
.load(dir.getCanonicalPath + "/t1")
.createOrReplaceTempView("t1")
spark.read
.format(dataSourceName)
.load(dir.getCanonicalPath + "/t2")
.createOrReplaceTempView("t2")
val query = sql(
"""
|select sum(t1.id) as sum_id
|from t1, t2
|where t1.id == t2.id
| and exists(select * from t1 where t1.id == t2.id and t1.col1>5)
|""".stripMargin)
checkScan(query,
"struct<id:long>",
"struct<id:long>",
"struct<id:long, col1:long>")
}
}

protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = {
test(s"Spark vectorized reader - without partition data column - $testName") {
withSQLConf(vectorizedReaderEnabledKey -> "true") {