[SPARK-29768][SQL] Column pruning through nondeterministic expressions
### What changes were proposed in this pull request?

Support column pruning through non-deterministic expressions.
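
For illustration only, a minimal sketch of how a planner strategy can consume the new `ScanOperation` extractor (mirroring the `DataSourceStrategy`/`FileSourceStrategy` changes below); the `describeScan` helper is made up for this example and is not part of the PR:

```scala
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Hypothetical helper: summarizes what ScanOperation extracts from a plan.
// ScanOperation yields (projects, filters, leaf) even when the project list
// contains non-deterministic expressions, as long as the CollapseProject /
// CombineFilters constraints hold.
def describeScan(plan: LogicalPlan): String = plan match {
  case ScanOperation(projects, filters, child) =>
    s"projects=${projects.map(_.name).mkString(",")}; " +
      s"filters=${filters.size}; leaf=${child.nodeName}"
  case _ =>
    "plan is not a project/filter chain over a single leaf"
}
```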

### Why are the changes needed?

In some cases, columns can still be pruned even though non-deterministic expressions appear.
For example, for the plan `Filter('a = 1, Project(Seq('a, rand() as 'r), LogicalRelation('a, 'b)))`, we should still prune column `b` even though the non-deterministic expression `rand()` appears in the `Project`.
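
A REPL-style sketch of the example above (built directly with Catalyst expressions, in the style of the new `ScanOperationSuite`; intended for a `spark-shell` on a build that includes this change, not as part of the PR itself). `ScanOperation` still matches the plan, and column `b` is never requested:

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, EqualTo, Literal, Rand}
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Project}
import org.apache.spark.sql.types.IntegerType

val colA = AttributeReference("a", IntegerType)()
val colB = AttributeReference("b", IntegerType)()
val relation = LocalRelation(colA, colB)

// Filter('a = 1, Project(Seq('a, rand() as 'r), relation))
val plan = Filter(EqualTo(colA, Literal(1)), Project(Seq(colA, Alias(Rand(1), "r")()), relation))

plan match {
  case ScanOperation(projects, filters, _: LocalRelation) =>
    // Only 'a and the rand() alias are projected; 'b can be pruned from the scan.
    assert(projects.map(_.name) == Seq("a", "r"))
    assert(filters.size == 1)
}
```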

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Added a new test suite: `ScanOperationSuite`.
Added a test in `FileSourceStrategySuite` to verify the expected pruning behavior for both DS v1 and v2.
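
For reference, a condensed, hedged sketch of the DS v1 check (essentially the new `FileSourceStrategySuite` test below, trimmed; it assumes it runs inside a suite with `SharedSparkSession` so `spark`, `withSQLConf`, and `withTempPath` are available):

```scala
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{LongType, StructField, StructType}

withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "parquet") {
  withTempPath { path =>
    spark.range(10)
      .selectExpr("id as key", "id * 3 as s1", "id * 5 as s2")
      .write.parquet(path.getAbsolutePath)

    // rand() is non-deterministic, but only `key` should survive into the scan.
    val df = spark.read.parquet(path.getAbsolutePath)
      .selectExpr("key", "rand()").where("key > 5")
    val scans = df.queryExecution.sparkPlan.collect { case s: FileSourceScanExec => s }

    assert(scans.size == 1)
    assert(scans.head.requiredSchema == StructType(StructField("key", LongType) :: Nil))
  }
}
```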

Closes #26629 from Ngone51/SPARK-29768.

Authored-by: wuyi <ngone_5451@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Ngone51 authored and cloud-fan committed Nov 27, 2019
1 parent 4fd585d commit a58d91b
Showing 7 changed files with 235 additions and 25 deletions.
@@ -26,15 +26,36 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._

trait OperationHelper {
  type ReturnType = (Seq[NamedExpression], Seq[Expression], LogicalPlan)

  protected def collectAliases(fields: Seq[Expression]): AttributeMap[Expression] =
    AttributeMap(fields.collect {
      case a: Alias => (a.toAttribute, a.child)
    })

  protected def substitute(aliases: AttributeMap[Expression])(expr: Expression): Expression = {
    expr.transform {
      case a @ Alias(ref: AttributeReference, name) =>
        aliases.get(ref)
          .map(Alias(_, name)(a.exprId, a.qualifier))
          .getOrElse(a)

      case a: AttributeReference =>
        aliases.get(a)
          .map(Alias(_, a.name)(a.exprId, a.qualifier)).getOrElse(a)
    }
  }
}

/**
* A pattern that matches any number of project or filter operations on top of another relational
* operator. All filter operators are collected and their conditions are broken up and returned
* together with the top project operator.
* [[org.apache.spark.sql.catalyst.expressions.Alias Aliases]] are in-lined/substituted if
* necessary.
*/
object PhysicalOperation extends PredicateHelper {
type ReturnType = (Seq[NamedExpression], Seq[Expression], LogicalPlan)
object PhysicalOperation extends OperationHelper with PredicateHelper {

def unapply(plan: LogicalPlan): Option[ReturnType] = {
val (fields, filters, child, _) = collectProjectsAndFilters(plan)
@@ -74,22 +95,72 @@ object PhysicalOperation extends PredicateHelper {
case other =>
(None, Nil, other, AttributeMap(Seq()))
}
}

private def collectAliases(fields: Seq[Expression]): AttributeMap[Expression] =
AttributeMap(fields.collect {
case a: Alias => (a.toAttribute, a.child)
})
/**
* A variant of [[PhysicalOperation]]. It matches any number of project or filter
* operations even if they are non-deterministic, as long as they satisfy the
* requirement of CollapseProject and CombineFilters.
*/
object ScanOperation extends OperationHelper with PredicateHelper {
  type ScanReturnType = Option[(Option[Seq[NamedExpression]],
    Seq[Expression], LogicalPlan, AttributeMap[Expression])]

private def substitute(aliases: AttributeMap[Expression])(expr: Expression): Expression = {
expr.transform {
case a @ Alias(ref: AttributeReference, name) =>
aliases.get(ref)
.map(Alias(_, name)(a.exprId, a.qualifier))
.getOrElse(a)
  def unapply(plan: LogicalPlan): Option[ReturnType] = {
    collectProjectsAndFilters(plan) match {
      case Some((fields, filters, child, _)) =>
        Some((fields.getOrElse(child.output), filters, child))
      case None => None
    }
  }

case a: AttributeReference =>
aliases.get(a)
.map(Alias(_, a.name)(a.exprId, a.qualifier)).getOrElse(a)
  private def hasCommonNonDeterministic(
      expr: Seq[Expression],
      aliases: AttributeMap[Expression]): Boolean = {
    expr.exists(_.collect {
      case a: AttributeReference if aliases.contains(a) => aliases(a)
    }.exists(!_.deterministic))
  }

  private def collectProjectsAndFilters(plan: LogicalPlan): ScanReturnType = {
    plan match {
      case Project(fields, child) =>
        collectProjectsAndFilters(child) match {
          case Some((_, filters, other, aliases)) =>
            // Follow CollapseProject and only keep going if the collected Projects
            // do not have common non-deterministic expressions.
            if (!hasCommonNonDeterministic(fields, aliases)) {
              val substitutedFields =
                fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
              Some((Some(substitutedFields), filters, other, collectAliases(substitutedFields)))
            } else {
              None
            }
          case None => None
        }

      case Filter(condition, child) =>
        collectProjectsAndFilters(child) match {
          case Some((fields, filters, other, aliases)) =>
            // Follow CombineFilters and only keep going if the collected Filters
            // are all deterministic and this filter doesn't have common non-deterministic
            // expressions with lower Project.
            if (filters.forall(_.deterministic) &&
                !hasCommonNonDeterministic(Seq(condition), aliases)) {
              val substitutedCondition = substitute(aliases)(condition)
              Some((fields, filters ++ splitConjunctivePredicates(substitutedCondition),
                other, aliases))
            } else {
              None
            }
          case None => None
        }

      case h: ResolvedHint =>
        collectProjectsAndFilters(h.child)

      case other =>
        Some((None, Nil, other, AttributeMap(Seq())))
    }
  }
}
@@ -0,0 +1,104 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.planning

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.TestRelations
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.DoubleType

class ScanOperationSuite extends SparkFunSuite {
  private val relation = TestRelations.testRelation2
  private val colA = relation.output(0)
  private val colB = relation.output(1)
  private val aliasR = Alias(Rand(1), "r")()
  private val aliasId = Alias(MonotonicallyIncreasingID(), "id")()
  private val colR = AttributeReference("r", DoubleType)(aliasR.exprId, aliasR.qualifier)

  test("Project with a non-deterministic field and a deterministic child Filter") {
    val project1 = Project(Seq(colB, aliasR), Filter(EqualTo(colA, Literal(1)), relation))
    project1 match {
      case ScanOperation(projects, filters, _: LocalRelation) =>
        assert(projects.size === 2)
        assert(projects(0) === colB)
        assert(projects(1) === aliasR)
        assert(filters.size === 1)
    }
  }

  test("Project with all deterministic fields but a non-deterministic child Filter") {
    val project2 = Project(Seq(colA, colB), Filter(EqualTo(aliasR, Literal(1)), relation))
    project2 match {
      case ScanOperation(projects, filters, _: LocalRelation) =>
        assert(projects.size === 2)
        assert(projects(0) === colA)
        assert(projects(1) === colB)
        assert(filters.size === 1)
    }
  }

  test("Project which has the same non-deterministic expression with its child Project") {
    val project3 = Project(Seq(colA, colR), Project(Seq(colA, aliasR), relation))
    assert(ScanOperation.unapply(project3).isEmpty)
  }

  test("Project which has different non-deterministic expressions with its child Project") {
    val project4 = Project(Seq(colA, aliasId), Project(Seq(colA, aliasR), relation))
    project4 match {
      case ScanOperation(projects, _, _: LocalRelation) =>
        assert(projects.size === 2)
        assert(projects(0) === colA)
        assert(projects(1) === aliasId)
    }
  }

  test("Filter which has the same non-deterministic expression with its child Project") {
    val filter1 = Filter(EqualTo(colR, Literal(1)), Project(Seq(colA, aliasR), relation))
    assert(ScanOperation.unapply(filter1).isEmpty)
  }

  test("Deterministic filter with a child Project with a non-deterministic expression") {
    val filter2 = Filter(EqualTo(colA, Literal(1)), Project(Seq(colA, aliasR), relation))
    filter2 match {
      case ScanOperation(projects, filters, _: LocalRelation) =>
        assert(projects.size === 2)
        assert(projects(0) === colA)
        assert(projects(1) === aliasR)
        assert(filters.size === 1)
    }
  }

  test("Filter which has different non-deterministic expressions with its child Project") {
    val filter3 = Filter(EqualTo(MonotonicallyIncreasingID(), Literal(1)),
      Project(Seq(colA, aliasR), relation))
    filter3 match {
      case ScanOperation(projects, filters, _: LocalRelation) =>
        assert(projects.size === 2)
        assert(projects(0) === colA)
        assert(projects(1) === aliasR)
        assert(filters.size === 1)
    }
  }

  test("Deterministic filter which has a non-deterministic child Filter") {
    val filter4 = Filter(EqualTo(colA, Literal(1)), Filter(EqualTo(aliasR, Literal(1)), relation))
    assert(ScanOperation.unapply(filter4).isEmpty)
  }
}
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
@@ -264,23 +264,23 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with
import DataSourceStrategy._

def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _, _)) =>
case ScanOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _, _)) =>
pruneFilterProjectRaw(
l,
projects,
filters,
(requestedColumns, allPredicates, _) =>
toCatalystRDD(l, requestedColumns, t.buildScan(requestedColumns, allPredicates))) :: Nil

case PhysicalOperation(projects, filters,
case ScanOperation(projects, filters,
l @ LogicalRelation(t: PrunedFilteredScan, _, _, _)) =>
pruneFilterProject(
l,
projects,
filters,
(a, f) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f))) :: Nil

case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedScan, _, _, _)) =>
case ScanOperation(projects, filters, l @ LogicalRelation(t: PrunedScan, _, _, _)) =>
pruneFilterProject(
l,
projects,
@@ -22,7 +22,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
import org.apache.spark.util.collection.BitSet
@@ -137,7 +137,7 @@ object FileSourceStrategy extends Strategy with Logging {
}

def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalOperation(projects, filters,
case ScanOperation(projects, filters,
l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) =>
// Filters on this relation fall into four categories based on where we can use them to avoid
// reading unneeded data:
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.sql.catalyst.expressions.{And, SubqueryExpression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
@@ -27,7 +27,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] {
import DataSourceV2Implicits._

override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options)

val (withSubquery, withoutSubquery) = filters.partition(SubqueryExpression.hasSubquery)
@@ -31,12 +31,13 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet, PredicateHelper}
import org.apache.spark.sql.catalyst.util
import org.apache.spark.sql.execution.{DataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.{DataSourceScanExec, FileSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
import org.apache.spark.util.Utils

class FileSourceStrategySuite extends QueryTest with SharedSparkSession with PredicateHelper {
@@ -497,6 +498,36 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession with PredicateHelper {
}
}

test("SPARK-29768: Column pruning through non-deterministic expressions") {
withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "parquet") {
withTempPath { path =>
spark.range(10)
.selectExpr("id as key", "id * 3 as s1", "id * 5 as s2")
.write.format("parquet").save(path.getAbsolutePath)
val df1 = spark.read.parquet(path.getAbsolutePath)
val df2 = df1.selectExpr("key", "rand()").where("key > 5")
val plan = df2.queryExecution.sparkPlan
val scan = plan.collect { case scan: FileSourceScanExec => scan }
assert(scan.size === 1)
assert(scan.head.requiredSchema == StructType(StructField("key", LongType) :: Nil))
}
}

withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
withTempPath { path =>
spark.range(10)
.selectExpr("id as key", "id * 3 as s1", "id * 5 as s2")
.write.format("parquet").save(path.getAbsolutePath)
val df1 = spark.read.parquet(path.getAbsolutePath)
val df2 = df1.selectExpr("key", "rand()").where("key > 5")
val plan = df2.queryExecution.optimizedPlan
val scan = plan.collect { case r: DataSourceV2ScanRelation => r }
assert(scan.size === 1)
assert(scan.head.scan.readSchema() == StructType(StructField("key", LongType) :: Nil))
}
}
}

// Helpers for checking the arguments passed to the FileFormat.

protected val checkPartitionSchema =
@@ -115,6 +115,10 @@ class PrunedScanSuite extends DataSourceTest with SharedSparkSession {
testPruning("SELECT b, b FROM oneToTenPruned", "b")
testPruning("SELECT a FROM oneToTenPruned", "a")
testPruning("SELECT b FROM oneToTenPruned", "b")
testPruning("SELECT a, rand() FROM oneToTenPruned WHERE a > 5", "a")
testPruning("SELECT a FROM oneToTenPruned WHERE rand() > 5", "a")
testPruning("SELECT a, rand() FROM oneToTenPruned WHERE rand() > 5", "a")
testPruning("SELECT a, rand() FROM oneToTenPruned WHERE b > 5", "a", "b")

def testPruning(sqlString: String, expectedColumns: String*): Unit = {
test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") {
