From a58d91b1594f454e436885906798f3d6e1f781c9 Mon Sep 17 00:00:00 2001
From: wuyi
Date: Wed, 27 Nov 2019 15:37:01 +0800
Subject: [PATCH] [SPARK-29768][SQL] Column pruning through nondeterministic expressions

### What changes were proposed in this pull request?

Support column pruning through non-deterministic expressions.

### Why are the changes needed?

In some cases, columns can still be pruned even though non-deterministic expressions appear.
For example, for the plan `Filter('a = 1, Project(Seq('a, rand() as 'r), LogicalRelation('a, 'b)))`,
we can still prune column `b`, even though the Project in between contains a non-deterministic
expression (`rand() as 'r`).
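
As a rough illustration (not part of the patch; the read path and column names are hypothetical,
mirroring the new test added to `FileSourceStrategySuite`):

```scala
// Hypothetical parquet table with columns key, s1, s2.
val df = spark.read.parquet("/tmp/spark-29768-example")

// The non-deterministic rand() lives only in the Project and is not referenced by the
// Filter, so the new ScanOperation pattern still matches and the scan is pruned to `key`.
df.selectExpr("key", "rand()").where("key > 5")

// Counter-example: the Filter references the rand() produced by the Project, so the two
// operators cannot be collapsed and ScanOperation does not match across them.
df.selectExpr("key", "rand() as r").where("r > 0.5")
```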

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Added a new test file: `ScanOperationSuite`.
Added a test in `FileSourceStrategySuite` to verify the pruning behavior for both DS v1 and v2.

Closes #26629 from Ngone51/SPARK-29768.

Authored-by: wuyi
Signed-off-by: Wenchen Fan
---
 .../sql/catalyst/planning/patterns.scala      | 101 ++++++++++++++---
 .../planning/ScanOperationSuite.scala         | 104 ++++++++++++++++++
 .../datasources/DataSourceStrategy.scala      |   8 +-
 .../datasources/FileSourceStrategy.scala      |   4 +-
 .../v2/V2ScanRelationPushDown.scala           |   4 +-
 .../datasources/FileSourceStrategySuite.scala |  35 +++++-
 .../spark/sql/sources/PrunedScanSuite.scala   |   4 +
 7 files changed, 235 insertions(+), 25 deletions(-)
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index c2a12eda19137..4e790b1dd3f36 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -26,6 +26,28 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._

+trait OperationHelper {
+  type ReturnType = (Seq[NamedExpression], Seq[Expression], LogicalPlan)
+
+  protected def collectAliases(fields: Seq[Expression]): AttributeMap[Expression] =
+    AttributeMap(fields.collect {
+      case a: Alias => (a.toAttribute, a.child)
+    })
+
+  protected def substitute(aliases: AttributeMap[Expression])(expr: Expression): Expression = {
+    expr.transform {
+      case a @ Alias(ref: AttributeReference, name) =>
+        aliases.get(ref)
+          .map(Alias(_, name)(a.exprId, a.qualifier))
+          .getOrElse(a)
+
+      case a: AttributeReference =>
+        aliases.get(a)
+          .map(Alias(_, a.name)(a.exprId, a.qualifier)).getOrElse(a)
+    }
+  }
+}
+
 /**
  * A pattern that matches any number of project or filter operations on top of another relational
  * operator.  All filter operators are collected and their conditions are broken up and returned
@@ -33,8 +55,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
  * [[org.apache.spark.sql.catalyst.expressions.Alias Aliases]] are in-lined/substituted if
  * necessary.
  */
-object PhysicalOperation extends PredicateHelper {
-  type ReturnType = (Seq[NamedExpression], Seq[Expression], LogicalPlan)
+object PhysicalOperation extends OperationHelper with PredicateHelper {

   def unapply(plan: LogicalPlan): Option[ReturnType] = {
     val (fields, filters, child, _) = collectProjectsAndFilters(plan)
@@ -74,22 +95,72 @@ object PhysicalOperation extends PredicateHelper {

       case other => (None, Nil, other, AttributeMap(Seq()))
     }
+}

-  private def collectAliases(fields: Seq[Expression]): AttributeMap[Expression] =
-    AttributeMap(fields.collect {
-      case a: Alias => (a.toAttribute, a.child)
-    })
+/**
+ * A variant of [[PhysicalOperation]]. It matches any number of project or filter
+ * operations even if they are non-deterministic, as long as they satisfy the
+ * requirement of CollapseProject and CombineFilters.
+ */
+object ScanOperation extends OperationHelper with PredicateHelper {
+  type ScanReturnType = Option[(Option[Seq[NamedExpression]],
+    Seq[Expression], LogicalPlan, AttributeMap[Expression])]

-  private def substitute(aliases: AttributeMap[Expression])(expr: Expression): Expression = {
-    expr.transform {
-      case a @ Alias(ref: AttributeReference, name) =>
-        aliases.get(ref)
-          .map(Alias(_, name)(a.exprId, a.qualifier))
-          .getOrElse(a)
+  def unapply(plan: LogicalPlan): Option[ReturnType] = {
+    collectProjectsAndFilters(plan) match {
+      case Some((fields, filters, child, _)) =>
+        Some((fields.getOrElse(child.output), filters, child))
+      case None => None
+    }
+  }

-      case a: AttributeReference =>
-        aliases.get(a)
-          .map(Alias(_, a.name)(a.exprId, a.qualifier)).getOrElse(a)
+  private def hasCommonNonDeterministic(
+      expr: Seq[Expression],
+      aliases: AttributeMap[Expression]): Boolean = {
+    expr.exists(_.collect {
+      case a: AttributeReference if aliases.contains(a) => aliases(a)
+    }.exists(!_.deterministic))
+  }
+
+  private def collectProjectsAndFilters(plan: LogicalPlan): ScanReturnType = {
+    plan match {
+      case Project(fields, child) =>
+        collectProjectsAndFilters(child) match {
+          case Some((_, filters, other, aliases)) =>
+            // Follow CollapseProject and only keep going if the collected Projects
+            // do not have common non-deterministic expressions.
+            if (!hasCommonNonDeterministic(fields, aliases)) {
+              val substitutedFields =
+                fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
+              Some((Some(substitutedFields), filters, other, collectAliases(substitutedFields)))
+            } else {
+              None
+            }
+          case None => None
+        }
+
+      case Filter(condition, child) =>
+        collectProjectsAndFilters(child) match {
+          case Some((fields, filters, other, aliases)) =>
+            // Follow CombineFilters and only keep going if the collected Filters
+            // are all deterministic and this filter doesn't have common non-deterministic
+            // expressions with the lower Project.
+            if (filters.forall(_.deterministic) &&
+              !hasCommonNonDeterministic(Seq(condition), aliases)) {
+              val substitutedCondition = substitute(aliases)(condition)
+              Some((fields, filters ++ splitConjunctivePredicates(substitutedCondition),
+                other, aliases))
+            } else {
+              None
+            }
+          case None => None
+        }
+
+      case h: ResolvedHint =>
+        collectProjectsAndFilters(h.child)
+
+      case other =>
+        Some((None, Nil, other, AttributeMap(Seq())))
     }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala
new file mode 100644
index 0000000000000..7790f467a890b
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.planning
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.analysis.TestRelations
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types.DoubleType
+
+class ScanOperationSuite extends SparkFunSuite {
+  private val relation = TestRelations.testRelation2
+  private val colA = relation.output(0)
+  private val colB = relation.output(1)
+  private val aliasR = Alias(Rand(1), "r")()
+  private val aliasId = Alias(MonotonicallyIncreasingID(), "id")()
+  private val colR = AttributeReference("r", DoubleType)(aliasR.exprId, aliasR.qualifier)
+
+  test("Project with a non-deterministic field and a deterministic child Filter") {
+    val project1 = Project(Seq(colB, aliasR), Filter(EqualTo(colA, Literal(1)), relation))
+    project1 match {
+      case ScanOperation(projects, filters, _: LocalRelation) =>
+        assert(projects.size === 2)
+        assert(projects(0) === colB)
+        assert(projects(1) === aliasR)
+        assert(filters.size === 1)
+    }
+  }
+
+  test("Project with all deterministic fields but a non-deterministic child Filter") {
+    val project2 = Project(Seq(colA, colB), Filter(EqualTo(aliasR, Literal(1)), relation))
+    project2 match {
+      case ScanOperation(projects, filters, _: LocalRelation) =>
+        assert(projects.size === 2)
+        assert(projects(0) === colA)
+        assert(projects(1) === colB)
+        assert(filters.size === 1)
+    }
+  }
+
+  test("Project which has the same non-deterministic expression with its child Project") {
+    val project3 = Project(Seq(colA, colR), Project(Seq(colA, aliasR), relation))
+    assert(ScanOperation.unapply(project3).isEmpty)
+  }
+
+  test("Project which has different non-deterministic expressions with its child Project") {
+    val project4 = Project(Seq(colA, aliasId), Project(Seq(colA, aliasR), relation))
+    project4 match {
+      case ScanOperation(projects, _, _: LocalRelation) =>
+        assert(projects.size === 2)
+        assert(projects(0) === colA)
+        assert(projects(1) === aliasId)
+    }
+  }
+
+  test("Filter which has the same non-deterministic expression with its child Project") {
+    val filter1 = Filter(EqualTo(colR, Literal(1)), Project(Seq(colA, aliasR), relation))
+    assert(ScanOperation.unapply(filter1).isEmpty)
+  }
+
+  test("Deterministic filter with a child Project with a non-deterministic expression") {
+    val filter2 = Filter(EqualTo(colA, Literal(1)), Project(Seq(colA, aliasR), relation))
+    filter2 match {
+      case ScanOperation(projects, filters, _: LocalRelation) =>
+        assert(projects.size === 2)
+        assert(projects(0) === colA)
+        assert(projects(1) === aliasR)
+        assert(filters.size === 1)
+    }
+  }
+
+  test("Filter which has different non-deterministic expressions with its child Project") {
+    val filter3 = Filter(EqualTo(MonotonicallyIncreasingID(), Literal(1)),
+      Project(Seq(colA, aliasR), relation))
+    filter3 match {
+      case ScanOperation(projects, filters, _: LocalRelation) =>
+        assert(projects.size === 2)
+        assert(projects(0) === colA)
+        assert(projects(1) === aliasR)
+        assert(filters.size === 1)
+    }
+  }
+
+
+  test("Deterministic filter which has a non-deterministic child Filter") {
+    val filter4 = Filter(EqualTo(colA, Literal(1)), Filter(EqualTo(aliasR, Literal(1)), relation))
+    assert(ScanOperation.unapply(filter4).isEmpty)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 211642d78aabe..46444f0a05605 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.planning.ScanOperation
 import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
@@ -264,7 +264,7 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with
   import DataSourceStrategy._

   def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
-    case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _, _)) =>
+    case ScanOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _, _)) =>
       pruneFilterProjectRaw(
         l,
         projects,
@@ -272,7 +272,7 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with
         (requestedColumns, allPredicates, _) =>
           toCatalystRDD(l, requestedColumns, t.buildScan(requestedColumns, allPredicates))) :: Nil

-    case PhysicalOperation(projects, filters,
+    case ScanOperation(projects, filters,
       l @ LogicalRelation(t: PrunedFilteredScan, _, _, _)) =>
       pruneFilterProject(
         l,
@@ -280,7 +280,7 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with
         filters,
         (a, f) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f))) :: Nil

-    case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedScan, _, _, _)) =>
+    case ScanOperation(projects, filters, l @ LogicalRelation(t: PrunedScan, _, _, _)) =>
       pruneFilterProject(
         l,
         projects,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index afc9bfeda84a6..bd342c7f404fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.planning.ScanOperation
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
 import org.apache.spark.util.collection.BitSet
@@ -137,7 +137,7 @@ object FileSourceStrategy extends Strategy with Logging {
   }

   def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-    case PhysicalOperation(projects, filters,
+    case ScanOperation(projects, filters,
       l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) =>
       // Filters on this relation fall into four categories based on where we can use them to avoid
       // reading unneeded data:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 6aa8d989583d1..239e3e8f82f18 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.datasources.v2

 import org.apache.spark.sql.catalyst.expressions.{And, SubqueryExpression}
-import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.planning.ScanOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
@@ -27,7 +27,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] {
   import DataSourceV2Implicits._

   override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
-    case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
+    case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
       val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options)

       val (withSubquery, withoutSubquery) = filters.partition(SubqueryExpression.hasSubquery)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
index fa8111407665a..812305ba24403 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -31,12 +31,13 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet, PredicateHelper}
 import org.apache.spark.sql.catalyst.util
-import org.apache.spark.sql.execution.{DataSourceScanExec, SparkPlan}
+import org.apache.spark.sql.execution.{DataSourceScanExec, FileSourceScanExec, SparkPlan}
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
 import org.apache.spark.util.Utils

 class FileSourceStrategySuite extends QueryTest with SharedSparkSession with PredicateHelper {
@@ -497,6 +498,36 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession with Pre
     }
   }

+  test("SPARK-29768: Column pruning through non-deterministic expressions") {
+    withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "parquet") {
+      withTempPath { path =>
+        spark.range(10)
+          .selectExpr("id as key", "id * 3 as s1", "id * 5 as s2")
+          .write.format("parquet").save(path.getAbsolutePath)
+        val df1 = spark.read.parquet(path.getAbsolutePath)
+        val df2 = df1.selectExpr("key", "rand()").where("key > 5")
+        val plan = df2.queryExecution.sparkPlan
+        val scan = plan.collect { case scan: FileSourceScanExec => scan }
+        assert(scan.size === 1)
+        assert(scan.head.requiredSchema == StructType(StructField("key", LongType) :: Nil))
+      }
+    }
+
+    withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+      withTempPath { path =>
+        spark.range(10)
+          .selectExpr("id as key", "id * 3 as s1", "id * 5 as s2")
+          .write.format("parquet").save(path.getAbsolutePath)
+        val df1 = spark.read.parquet(path.getAbsolutePath)
+        val df2 = df1.selectExpr("key", "rand()").where("key > 5")
+        val plan = df2.queryExecution.optimizedPlan
+        val scan = plan.collect { case r: DataSourceV2ScanRelation => r }
+        assert(scan.size === 1)
+        assert(scan.head.scan.readSchema() == StructType(StructField("key", LongType) :: Nil))
+      }
+    }
+  }
+
   // Helpers for checking the arguments passed to the FileFormat.

   protected val checkPartitionSchema =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
index d99c605b2e478..237717a3ad196 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
@@ -115,6 +115,10 @@ class PrunedScanSuite extends DataSourceTest with SharedSparkSession {
   testPruning("SELECT b, b FROM oneToTenPruned", "b")
   testPruning("SELECT a FROM oneToTenPruned", "a")
   testPruning("SELECT b FROM oneToTenPruned", "b")
+  testPruning("SELECT a, rand() FROM oneToTenPruned WHERE a > 5", "a")
+  testPruning("SELECT a FROM oneToTenPruned WHERE rand() > 5", "a")
+  testPruning("SELECT a, rand() FROM oneToTenPruned WHERE rand() > 5", "a")
+  testPruning("SELECT a, rand() FROM oneToTenPruned WHERE b > 5", "a", "b")

   def testPruning(sqlString: String, expectedColumns: String*): Unit = {
     test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") {