From 0f90ae4b3e8f9833768a8d43caca57050c47c984 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 29 Jun 2017 17:50:48 +0900 Subject: [PATCH 1/3] Update nullability based on children's output --- .../sql/catalyst/expressions/predicates.scala | 6 +++ .../sql/catalyst/optimizer/Optimizer.scala | 21 +++++++- .../plans/logical/basicLogicalOperators.scala | 20 +++++++- .../InferFiltersFromConstraintsSuite.scala | 12 +++-- ...ullabilityInAttributeReferencesSuite.scala | 50 +++++++++++++++++++ .../spark/sql/execution/SparkStrategies.scala | 19 +++++-- .../execution/basicPhysicalOperators.scala | 34 ++++++------- .../org/apache/spark/sql/DataFrameSuite.scala | 5 -- .../execution/WholeStageCodegenSuite.scala | 3 +- .../datasources/FileSourceStrategySuite.scala | 2 +- .../python/BatchEvalPythonExecSuite.scala | 12 ++--- .../apache/spark/sql/test/SQLTestUtils.scala | 2 +- 12 files changed, 144 insertions(+), 42 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 4b85d9adbe311..f420556cf4293 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -75,6 +75,12 @@ trait PredicateHelper { } } + // If one expression and its children are null intolerant, it is null intolerant. + protected def isNullIntolerant(expr: Expression): Boolean = expr match { + case e: NullIntolerant => e.children.forall(isNullIntolerant) + case _ => false + } + /** * Returns true if `expr` can be evaluated using only the output of `plan`. This method * can be used to determine when it is acceptable to move expression evaluation within a query diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2829d1d81eb1a..104177f58c0bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -153,7 +153,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) RewritePredicateSubquery, ColumnPruning, CollapseProject, - RemoveRedundantProject) + RemoveRedundantProject) :+ + Batch("UpdateAttributeReferences", Once, + UpdateNullabilityInAttributeReferences) } /** @@ -1309,3 +1311,20 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] { } } } + +/** + * Updates nullability in [[AttributeReference]]s if nullability is different between + * non-leaf plan's expressions and the children output. 
+ */ +object UpdateNullabilityInAttributeReferences extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case p if !p.isInstanceOf[LeafNode] => + val nullabilityMap = AttributeMap(p.children.flatMap(_.output).map { x => x -> x.nullable }) + p transformExpressions { + case ar: AttributeReference => + nullabilityMap.get(ar).filterNot(_ == ar.nullable).map { nullable => + ar.withNullability(nullable) + }.getOrElse(ar) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index a4fca790dd086..60b52d4292b1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -126,7 +126,25 @@ case class Generate( case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode with PredicateHelper { - override def output: Seq[Attribute] = child.output + + // The columns that will filtered out by `IsNotNull` could be considered as not nullable. + private lazy val notNullAttributes = splitConjunctivePredicates(condition).flatMap { + case isnotnull @ IsNotNull(a) + if isNullIntolerant(a) && a.references.subsetOf(child.outputSet) => + isnotnull.references.map(_.exprId) + case _ => + Seq.empty[ExprId] + }.toSet + + override def output: Seq[Attribute] = { + child.output.map { a => + if (a.nullable && notNullAttributes.contains(a.exprId)) { + a.withNullability(false) + } else { + a + } + } + } override def maxRows: Option[Long] = child.maxRows diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index f78c2356e35a5..76d392ed2ffa1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -35,7 +35,9 @@ class InferFiltersFromConstraintsSuite extends PlanTest { InferFiltersFromConstraints, CombineFilters, SimplifyBinaryComparison, - BooleanSimplification) :: Nil + BooleanSimplification) :: + Batch("UpdateAttributeReferences", Once, + UpdateNullabilityInAttributeReferences) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -127,6 +129,10 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + private def updateNullability(plan: LogicalPlan): LogicalPlan = { + UpdateNullabilityInAttributeReferences.apply(plan) + } + test("inner join with alias: alias contains multiple attributes") { val t1 = testRelation.subquery('t1) val t2 = testRelation.subquery('t2) @@ -141,7 +147,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr)) .analyze val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + comparePlans(optimized, updateNullability(correctAnswer)) } test("inner join with alias: alias contains single attributes") { @@ -158,7 +164,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr)) .analyze val optimized = Optimize.execute(originalQuery) - 
comparePlans(optimized, correctAnswer) + comparePlans(optimized, updateNullability(correctAnswer)) } test("generate correct filters for alias that don't produce recursive constraints") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala new file mode 100644 index 0000000000000..3f950e1a1cf55 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType + + +class UpdateNullabilityInAttributeReferencesSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("InferAndPushDownFilters", FixedPoint(100), + InferFiltersFromConstraints) :: + Batch("UpdateAttributeReferences", Once, + UpdateNullabilityInAttributeReferences) :: Nil + } + + test("update nullability when inferred constraints applied") { + withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "true") { + val testRelation = LocalRelation('a.int, 'b.int) + val logicalPlan = testRelation.where('a =!= 2).select('a).analyze + var expectedSchema = new StructType().add("a", "INT", nullable = true) + assert(StructType.fromAttributes(logicalPlan.output) === expectedSchema) + val optimizedPlan = Optimize.execute(logicalPlan) + expectedSchema = new StructType().add("a", "INT", nullable = false) + assert(StructType.fromAttributes(optimizedPlan.output) === expectedSchema) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 82b4eb9fba242..93bae3d7acb7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -429,11 +429,22 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object InMemoryScans extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projectList, filters, mem: InMemoryRelation) => + val scanBuilder = (attrs: Seq[Attribute]) => { + // Since filters might 
change the nullability of attributes in `pruneFilterProject`, + // we need to update the nullability based on `InMemoryRelation` output. + val nullabilityMap = AttributeMap(mem.output.map { a => a -> a.nullable }) + val newOutputAttrs = attrs.map { ar => + nullabilityMap.get(ar).filterNot(_ == ar.nullable).map { nullable => + ar.withNullability(nullable) + }.getOrElse(ar) + } + InMemoryTableScanExec(newOutputAttrs, filters, mem) + } pruneFilterProject( projectList, filters, identity[Seq[Expression]], // All filters still need to be evaluated. - InMemoryTableScanExec(_, filters, mem)) :: Nil + scanBuilder) :: Nil case _ => Nil } } @@ -538,10 +549,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.SortExec(sortExprs, global, planLater(child)) :: Nil case logical.Project(projectList, child) => execution.ProjectExec(projectList, planLater(child)) :: Nil - case logical.Filter(condition, child) => - execution.FilterExec(condition, planLater(child)) :: Nil + case f @ logical.Filter(condition, child) => + execution.FilterExec(condition, planLater(child), f.output) :: Nil case f: logical.TypedFilter => - execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil + execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child), f.output) :: Nil case e @ logical.Expand(_, _, child) => execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil case logical.Window(windowExprs, partitionSpec, orderSpec, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 4707022f74547..8845060ee2acb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -82,7 +82,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) /** Physical plan for Filter. */ -case class FilterExec(condition: Expression, child: SparkPlan) +case class FilterExec(condition: Expression, child: SparkPlan, outputAttrs: Seq[Attribute]) extends UnaryExecNode with CodegenSupport with PredicateHelper { // Split out all the IsNotNulls from condition. @@ -91,27 +91,14 @@ case class FilterExec(condition: Expression, child: SparkPlan) case _ => false } - // If one expression and its children are null intolerant, it is null intolerant. - private def isNullIntolerant(expr: Expression): Boolean = expr match { - case e: NullIntolerant => e.children.forall(isNullIntolerant) - case _ => false - } - - // The columns that will filtered out by `IsNotNull` could be considered as not nullable. - private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId) - // Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate // all the variables at the beginning to take advantage of short circuiting. override def usedInputs: AttributeSet = AttributeSet.empty + // Since some plan rewrite rules (e.g., python.ExtractPythonUDFs) possibly change child's output + // from optimized logical plans, we need to adjust the filter's output here. 
override def output: Seq[Attribute] = { - child.output.map { a => - if (a.nullable && notNullAttributes.contains(a.exprId)) { - a.withNullability(false) - } else { - a - } - } + child.output.map { attr => outputAttrs.find(_.exprId == attr.exprId).getOrElse(attr) } } override lazy val metrics = Map( @@ -188,10 +175,10 @@ case class FilterExec(condition: Expression, child: SparkPlan) } }.mkString("\n") - // Reset the isNull to false for the not-null columns, then the followed operators could + // Reset the isNull to false for the not-nullable columns, then the followed operators could // generate better code (remove dead branches). val resultVars = input.zipWithIndex.map { case (ev, i) => - if (notNullAttributes.contains(child.output(i).exprId)) { + if (!output(i).nullable) { ev.isNull = "false" } ev @@ -224,6 +211,15 @@ case class FilterExec(condition: Expression, child: SparkPlan) override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def outputPartitioning: Partitioning = child.outputPartitioning + + // Don't display `outputAttrs` names in explain + override def simpleString: String = s"Filter ($condition)" +} + +object FilterExec { + def apply(condition: Expression, child: SparkPlan): FilterExec = { + FilterExec(condition, child, child.output) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f7b3393f65cb1..60e84e6ee7504 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2055,11 +2055,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { expr: String, expectedNonNullableColumns: Seq[String]): Unit = { val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr) - // In the logical plan, all the output columns of input dataframe are nullable - dfWithFilter.queryExecution.optimizedPlan.collect { - case e: Filter => assert(e.output.forall(_.nullable)) - } - dfWithFilter.queryExecution.executedPlan.collect { // When the child expression in isnotnull is null-intolerant (i.e. 
any null input will // result in null output), the involved columns are converted to not nullable; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 9180a22c260f1..09f618594f1d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -129,7 +129,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { val dsIntFilter = dsInt.filter(_ > 0) val planInt = dsIntFilter.queryExecution.executedPlan assert(planInt.collect { - case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if i.supportsBatch => () + case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec, _)) + if i.supportsBatch => () }.length == 1) assert(dsIntFilter.collect() === Array(1, 2)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index c1d61b843d899..c5d4f24818651 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -521,7 +521,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi def getPhysicalFilters(df: DataFrame): ExpressionSet = { ExpressionSet( df.queryExecution.executedPlan.collect { - case execution.FilterExec(f, _) => splitConjunctivePredicates(f) + case execution.FilterExec(f, _, _) => splitConjunctivePredicates(f) }.flatten) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index d456c931f5275..3ba752ecedf46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -47,8 +47,8 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { case f @ FilterExec( And(_: AttributeReference, _: AttributeReference), - InputAdapter(_: BatchEvalPythonExec)) => f - case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b + InputAdapter(_: BatchEvalPythonExec), _) => f + case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _, _))) => b } assert(qualifiedPlanNodes.size == 2) } @@ -57,8 +57,8 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { val df = Seq(("Hello", 4)).toDF("a", "b") .where("dummyPythonUDF(a, dummyPythonUDF(a, b)) and a in (3, 4)") val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { - case f @ FilterExec(_: AttributeReference, InputAdapter(_: BatchEvalPythonExec)) => f - case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b + case f @ FilterExec(_: AttributeReference, InputAdapter(_: BatchEvalPythonExec), _) => f + case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _, _))) => b } assert(qualifiedPlanNodes.size == 2) } @@ -69,7 +69,7 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { val 
qualifiedPlanNodes = df.queryExecution.executedPlan.collect { case f @ FilterExec( And(_: AttributeReference, _: GreaterThan), - InputAdapter(_: BatchEvalPythonExec)) => f + InputAdapter(_: BatchEvalPythonExec), _) => f case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b } assert(qualifiedPlanNodes.size == 2) @@ -82,7 +82,7 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { case f @ FilterExec( And(_: AttributeReference, _: GreaterThan), - InputAdapter(_: BatchEvalPythonExec)) => f + InputAdapter(_: BatchEvalPythonExec), _) => f case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b } assert(qualifiedPlanNodes.size == 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index bc4a120f7042f..65fb1fd13fbce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -366,7 +366,7 @@ private[sql] trait SQLTestUtilsBase protected def stripSparkFilter(df: DataFrame): DataFrame = { val schema = df.schema val withoutFilters = df.queryExecution.sparkPlan.transform { - case FilterExec(_, child) => child + case FilterExec(_, child, _) => child } spark.internalCreateDataFrame(withoutFilters.execute(), schema) From 12a88e8e73d161efbfd33c6120d35b730df41080 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 28 Mar 2018 23:39:51 +0900 Subject: [PATCH 2/3] Fix --- .../sql/catalyst/optimizer/Optimizer.scala | 6 ++-- .../spark/sql/execution/SparkStrategies.scala | 19 +++---------- .../execution/basicPhysicalOperators.scala | 28 +++++++++---------- .../execution/WholeStageCodegenSuite.scala | 3 +- .../datasources/FileSourceStrategySuite.scala | 2 +- .../python/BatchEvalPythonExecSuite.scala | 12 ++++---- .../apache/spark/sql/test/SQLTestUtils.scala | 2 +- 7 files changed, 28 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 104177f58c0bc..9a1bbc675e397 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1321,10 +1321,8 @@ object UpdateNullabilityInAttributeReferences extends Rule[LogicalPlan] { case p if !p.isInstanceOf[LeafNode] => val nullabilityMap = AttributeMap(p.children.flatMap(_.output).map { x => x -> x.nullable }) p transformExpressions { - case ar: AttributeReference => - nullabilityMap.get(ar).filterNot(_ == ar.nullable).map { nullable => - ar.withNullability(nullable) - }.getOrElse(ar) + case ar: AttributeReference if nullabilityMap.contains(ar) => + ar.withNullability(nullabilityMap(ar)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 93bae3d7acb7c..82b4eb9fba242 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -429,22 +429,11 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object InMemoryScans extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match 
{ case PhysicalOperation(projectList, filters, mem: InMemoryRelation) => - val scanBuilder = (attrs: Seq[Attribute]) => { - // Since filters might change the nullability of attributes in `pruneFilterProject`, - // we need to update the nullability based on `InMemoryRelation` output. - val nullabilityMap = AttributeMap(mem.output.map { a => a -> a.nullable }) - val newOutputAttrs = attrs.map { ar => - nullabilityMap.get(ar).filterNot(_ == ar.nullable).map { nullable => - ar.withNullability(nullable) - }.getOrElse(ar) - } - InMemoryTableScanExec(newOutputAttrs, filters, mem) - } pruneFilterProject( projectList, filters, identity[Seq[Expression]], // All filters still need to be evaluated. - scanBuilder) :: Nil + InMemoryTableScanExec(_, filters, mem)) :: Nil case _ => Nil } } @@ -549,10 +538,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.SortExec(sortExprs, global, planLater(child)) :: Nil case logical.Project(projectList, child) => execution.ProjectExec(projectList, planLater(child)) :: Nil - case f @ logical.Filter(condition, child) => - execution.FilterExec(condition, planLater(child), f.output) :: Nil + case logical.Filter(condition, child) => + execution.FilterExec(condition, planLater(child)) :: Nil case f: logical.TypedFilter => - execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child), f.output) :: Nil + execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil case e @ logical.Expand(_, _, child) => execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil case logical.Window(windowExprs, partitionSpec, orderSpec, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 8845060ee2acb..5fccccc512cca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -82,7 +82,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) /** Physical plan for Filter. */ -case class FilterExec(condition: Expression, child: SparkPlan, outputAttrs: Seq[Attribute]) +case class FilterExec(condition: Expression, child: SparkPlan) extends UnaryExecNode with CodegenSupport with PredicateHelper { // Split out all the IsNotNulls from condition. @@ -91,14 +91,21 @@ case class FilterExec(condition: Expression, child: SparkPlan, outputAttrs: Seq[ case _ => false } + // The columns that will filtered out by `IsNotNull` could be considered as not nullable. + private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId) + // Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate // all the variables at the beginning to take advantage of short circuiting. override def usedInputs: AttributeSet = AttributeSet.empty - // Since some plan rewrite rules (e.g., python.ExtractPythonUDFs) possibly change child's output - // from optimized logical plans, we need to adjust the filter's output here. 
override def output: Seq[Attribute] = { - child.output.map { attr => outputAttrs.find(_.exprId == attr.exprId).getOrElse(attr) } + child.output.map { a => + if (a.nullable && notNullAttributes.contains(a.exprId)) { + a.withNullability(false) + } else { + a + } + } } override lazy val metrics = Map( @@ -175,10 +182,10 @@ case class FilterExec(condition: Expression, child: SparkPlan, outputAttrs: Seq[ } }.mkString("\n") - // Reset the isNull to false for the not-nullable columns, then the followed operators could + // Reset the isNull to false for the not-null columns, then the followed operators could // generate better code (remove dead branches). val resultVars = input.zipWithIndex.map { case (ev, i) => - if (!output(i).nullable) { + if (notNullAttributes.contains(child.output(i).exprId)) { ev.isNull = "false" } ev @@ -211,15 +218,6 @@ case class FilterExec(condition: Expression, child: SparkPlan, outputAttrs: Seq[ override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def outputPartitioning: Partitioning = child.outputPartitioning - - // Don't display `outputAttrs` names in explain - override def simpleString: String = s"Filter ($condition)" -} - -object FilterExec { - def apply(condition: Expression, child: SparkPlan): FilterExec = { - FilterExec(condition, child, child.output) - } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 09f618594f1d1..9180a22c260f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -129,8 +129,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { val dsIntFilter = dsInt.filter(_ > 0) val planInt = dsIntFilter.queryExecution.executedPlan assert(planInt.collect { - case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec, _)) - if i.supportsBatch => () + case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if i.supportsBatch => () }.length == 1) assert(dsIntFilter.collect() === Array(1, 2)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index c5d4f24818651..c1d61b843d899 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -521,7 +521,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi def getPhysicalFilters(df: DataFrame): ExpressionSet = { ExpressionSet( df.queryExecution.executedPlan.collect { - case execution.FilterExec(f, _, _) => splitConjunctivePredicates(f) + case execution.FilterExec(f, _) => splitConjunctivePredicates(f) }.flatten) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 3ba752ecedf46..d456c931f5275 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -47,8 +47,8 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with 
SharedSQLContext { val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { case f @ FilterExec( And(_: AttributeReference, _: AttributeReference), - InputAdapter(_: BatchEvalPythonExec), _) => f - case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _, _))) => b + InputAdapter(_: BatchEvalPythonExec)) => f + case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b } assert(qualifiedPlanNodes.size == 2) } @@ -57,8 +57,8 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { val df = Seq(("Hello", 4)).toDF("a", "b") .where("dummyPythonUDF(a, dummyPythonUDF(a, b)) and a in (3, 4)") val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { - case f @ FilterExec(_: AttributeReference, InputAdapter(_: BatchEvalPythonExec), _) => f - case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _, _))) => b + case f @ FilterExec(_: AttributeReference, InputAdapter(_: BatchEvalPythonExec)) => f + case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b } assert(qualifiedPlanNodes.size == 2) } @@ -69,7 +69,7 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { case f @ FilterExec( And(_: AttributeReference, _: GreaterThan), - InputAdapter(_: BatchEvalPythonExec), _) => f + InputAdapter(_: BatchEvalPythonExec)) => f case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b } assert(qualifiedPlanNodes.size == 2) @@ -82,7 +82,7 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { case f @ FilterExec( And(_: AttributeReference, _: GreaterThan), - InputAdapter(_: BatchEvalPythonExec), _) => f + InputAdapter(_: BatchEvalPythonExec)) => f case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b } assert(qualifiedPlanNodes.size == 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 65fb1fd13fbce..bc4a120f7042f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -366,7 +366,7 @@ private[sql] trait SQLTestUtilsBase protected def stripSparkFilter(df: DataFrame): DataFrame = { val schema = df.schema val withoutFilters = df.queryExecution.sparkPlan.transform { - case FilterExec(_, child, _) => child + case FilterExec(_, child) => child } spark.internalCreateDataFrame(withoutFilters.execute(), schema) From 65ee6f7afa1f3e3a68baf4e827fe3c54c5de7467 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 31 Mar 2018 08:04:10 +0900 Subject: [PATCH 3/3] Fix --- .../sql/catalyst/expressions/predicates.scala | 6 --- .../plans/logical/basicLogicalOperators.scala | 20 +--------- .../InferFiltersFromConstraintsSuite.scala | 12 ++---- ...ullabilityInAttributeReferencesSuite.scala | 37 +++++++++++-------- .../optimizer/complexTypesSuite.scala | 9 ----- .../execution/basicPhysicalOperators.scala | 6 +++ 6 files changed, 32 insertions(+), 58 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f420556cf4293..4b85d9adbe311 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -75,12 +75,6 @@ trait PredicateHelper { } } - // If one expression and its children are null intolerant, it is null intolerant. - protected def isNullIntolerant(expr: Expression): Boolean = expr match { - case e: NullIntolerant => e.children.forall(isNullIntolerant) - case _ => false - } - /** * Returns true if `expr` can be evaluated using only the output of `plan`. This method * can be used to determine when it is acceptable to move expression evaluation within a query diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 60b52d4292b1d..a4fca790dd086 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -126,25 +126,7 @@ case class Generate( case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode with PredicateHelper { - - // The columns that will filtered out by `IsNotNull` could be considered as not nullable. - private lazy val notNullAttributes = splitConjunctivePredicates(condition).flatMap { - case isnotnull @ IsNotNull(a) - if isNullIntolerant(a) && a.references.subsetOf(child.outputSet) => - isnotnull.references.map(_.exprId) - case _ => - Seq.empty[ExprId] - }.toSet - - override def output: Seq[Attribute] = { - child.output.map { a => - if (a.nullable && notNullAttributes.contains(a.exprId)) { - a.withNullability(false) - } else { - a - } - } - } + override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = child.maxRows diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 76d392ed2ffa1..f78c2356e35a5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -35,9 +35,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { InferFiltersFromConstraints, CombineFilters, SimplifyBinaryComparison, - BooleanSimplification) :: - Batch("UpdateAttributeReferences", Once, - UpdateNullabilityInAttributeReferences) :: Nil + BooleanSimplification) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -129,10 +127,6 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - private def updateNullability(plan: LogicalPlan): LogicalPlan = { - UpdateNullabilityInAttributeReferences.apply(plan) - } - test("inner join with alias: alias contains multiple attributes") { val t1 = testRelation.subquery('t1) val t2 = testRelation.subquery('t2) @@ -147,7 +141,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr)) .analyze val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, updateNullability(correctAnswer)) + comparePlans(optimized, correctAnswer) } test("inner join with alias: alias contains single attributes") { @@ -164,7 
+158,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr)) .analyze val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, updateNullability(correctAnswer)) + comparePlans(optimized, correctAnswer) } test("generate correct filters for alias that don't produce recursive constraints") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala index 3f950e1a1cf55..09b11f5aba2a0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala @@ -19,32 +19,39 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{CreateArray, GetArrayItem} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StructType class UpdateNullabilityInAttributeReferencesSuite extends PlanTest { - object Optimize extends RuleExecutor[LogicalPlan] { + object Optimizer extends RuleExecutor[LogicalPlan] { val batches = - Batch("InferAndPushDownFilters", FixedPoint(100), - InferFiltersFromConstraints) :: + Batch("Constant Folding", FixedPoint(10), + NullPropagation, + ConstantFolding, + BooleanSimplification, + SimplifyConditionals, + SimplifyBinaryComparison, + SimplifyExtractValueOps) :: Batch("UpdateAttributeReferences", Once, UpdateNullabilityInAttributeReferences) :: Nil } - test("update nullability when inferred constraints applied") { - withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "true") { - val testRelation = LocalRelation('a.int, 'b.int) - val logicalPlan = testRelation.where('a =!= 2).select('a).analyze - var expectedSchema = new StructType().add("a", "INT", nullable = true) - assert(StructType.fromAttributes(logicalPlan.output) === expectedSchema) - val optimizedPlan = Optimize.execute(logicalPlan) - expectedSchema = new StructType().add("a", "INT", nullable = false) - assert(StructType.fromAttributes(optimizedPlan.output) === expectedSchema) - } + test("update nullability in AttributeReference") { + val rel = LocalRelation('a.long.notNull) + // In the 'original' plans below, the Aggregate node produced by groupBy() has a + // nullable AttributeReference to `b`, because both array indexing and map lookup are + // nullable expressions. After optimization, the same attribute is now non-nullable, + // but the AttributeReference is not updated to reflect this. So, we need to update nullability + // by the `UpdateNullabilityInAttributeReferences` rule. 
+ val original = rel + .select(GetArrayItem(CreateArray(Seq('a, 'a + 1L)), 0) as "b") + .groupBy($"b")("1") + val expected = rel.select('a as "b").groupBy($"b")("1").analyze + val optimized = Optimizer.execute(original.analyze) + comparePlans(optimized, expected) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 21ed987627b3b..633d86d495581 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -378,15 +378,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .groupBy($"foo")("1") checkRule(structRel, structExpected) - // These tests must use nullable attributes from the base relation for the following reason: - // in the 'original' plans below, the Aggregate node produced by groupBy() has a - // nullable AttributeReference to a1, because both array indexing and map lookup are - // nullable expressions. After optimization, the same attribute is now non-nullable, - // but the AttributeReference is not updated to reflect this. In the 'expected' plans, - // the grouping expressions have the same nullability as the original attribute in the - // relation. If that attribute is non-nullable, the tests will fail as the plans will - // compare differently, so for these tests we must use a nullable attribute. See - // SPARK-23634. val arrayRel = relation .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1") .groupBy($"a1")("1") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 5fccccc512cca..4707022f74547 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -91,6 +91,12 @@ case class FilterExec(condition: Expression, child: SparkPlan) case _ => false } + // If one expression and its children are null intolerant, it is null intolerant. + private def isNullIntolerant(expr: Expression): Boolean = expr match { + case e: NullIntolerant => e.children.forall(isNullIntolerant) + case _ => false + } + // The columns that will filtered out by `IsNotNull` could be considered as not nullable. private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId)