From 54c3cc849e91b9bedc6379e04edc9e23245be760 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Mon, 17 Oct 2016 16:29:43 -0700
Subject: [PATCH 1/6] fix

---
 .../execution/basicPhysicalOperators.scala | 9 +++-
 .../org/apache/spark/sql/DataFrameSuite.scala | 48 ++++++++++++++++++-
 2 files changed, 54 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index dd78a784915d2..f91aa835ef7fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -87,7 +87,14 @@ case class FilterExec(condition: Expression, child: SparkPlan)
 
   // Split out all the IsNotNulls from condition.
   private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
-    case IsNotNull(a: NullIntolerant) if a.references.subsetOf(child.outputSet) => true
+    case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet)
     case _ => false
   }
 
+  // One expression is null intolerant iff it and its children are null intolerant
+  private def isNullIntolerant(expr: Expression): Boolean = expr match {
+    case e: NullIntolerant =>
+      if (e.isInstanceOf[LeafExpression]) true else e.children.forall(isNullIntolerant)
+    case _ => false
+  }
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 16cc368208485..8645ed2315af4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -27,8 +27,8 @@ import org.scalatest.Matchers._
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project, Union}
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Project, Union}
+import org.apache.spark.sql.execution.{FilterExec, QueryExecution}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
 import org.apache.spark.sql.functions._
@@ -1615,4 +1615,48 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       qe.assertAnalyzed()
     }
   }
+
+  private def verifyNullabilityInFilterExec(expr: String, isNullIntolerant: Boolean): Unit = {
+    val df = sparkContext.parallelize(Seq(
+      null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3),
+      new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer],
+      new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF("a", "b")
+
+    val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr)
+    dfWithFilter.queryExecution.optimizedPlan.collect {
+      // In the logical plan, all the output columns are nullable
+      case e: Filter => assert(e.output.forall(_.nullable))
+    }
+
+    dfWithFilter.queryExecution.executedPlan.collect {
+      // When the child expression in isnotnull is null-intolerant (i.e. any null input will
+      // result in null output), the columns are converted to not nullable; Otherwise, no change
+      // should be made.
+      case e: FilterExec =>
+        assert(e.output.forall(o => if (isNullIntolerant) !o.nullable else o.nullable))
+    }
+  }
+
+  test("SPARK-17957: no change on nullability in FilterExec output") {
+    verifyNullabilityInFilterExec("coalesce(a, b)", isNullIntolerant = false)
+
+    verifyNullabilityInFilterExec(
+      "cast(coalesce(cast(coalesce(a, b) as double), 0.0) as int)", isNullIntolerant = false)
+  }
+
+  test("SPARK-17957: set nullability to false in FilterExec output") {
+    verifyNullabilityInFilterExec("a + b * 3", isNullIntolerant = true)
+
+    verifyNullabilityInFilterExec("a + b", isNullIntolerant = true)
+
+    verifyNullabilityInFilterExec("cast((a + b) as boolean)", isNullIntolerant = true)
+  }
+
+  test("SPARK-17957: outer join + na.fill") {
+    val df1 = Seq((1, 2), (2, 3)).toDF("a", "b")
+    val df2 = Seq((2, 5), (3, 4)).toDF("a", "c")
+    val joinedDf = df1.join(df2, Seq("a"), "outer").na.fill(0)
+    val df3 = Seq((3, 1)).toDF("a", "d")
+    checkAnswer(joinedDf.join(df3, "a"), Row(3, 0, 4, 1))
+  }
 }

From ce418f9ff0ddcc2312f338084899fd261a7875ee Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Wed, 19 Oct 2016 13:27:12 -0700
Subject: [PATCH 2/6] add more test cases

---
 .../org/apache/spark/sql/DataFrameSuite.scala | 58 +++++++++++++------
 1 file changed, 40 insertions(+), 18 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 8645ed2315af4..f8fd8e9fa38e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1616,40 +1616,62 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     }
   }
 
-  private def verifyNullabilityInFilterExec(expr: String, isNullIntolerant: Boolean): Unit = {
-    val df = sparkContext.parallelize(Seq(
-      null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3),
-      new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer],
-      new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF("a", "b")
-
+  private def verifyNullabilityInFilterExec(
+      df: DataFrame,
+      expr: String,
+      expectedNullableColumns: Seq[String]): Unit = {
     val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr)
+    // In the logical plan, all the output columns of input dataframe are nullable
     dfWithFilter.queryExecution.optimizedPlan.collect {
-      // In the logical plan, all the output columns are nullable
       case e: Filter => assert(e.output.forall(_.nullable))
     }
 
     dfWithFilter.queryExecution.executedPlan.collect {
       // When the child expression in isnotnull is null-intolerant (i.e. any null input will
-      // result in null output), the columns are converted to not nullable; Otherwise, no change
-      // should be made.
+      // result in null output), the involved columns are converted to not nullable;
+      // otherwise, no change should be made.
       case e: FilterExec =>
-        assert(e.output.forall(o => if (isNullIntolerant) !o.nullable else o.nullable))
+        assert(e.output.forall { o =>
+          if (expectedNullableColumns.contains(o.name)) !o.nullable else o.nullable
+        })
     }
   }
 
   test("SPARK-17957: no change on nullability in FilterExec output") {
-    verifyNullabilityInFilterExec("coalesce(a, b)", isNullIntolerant = false)
+    val df = sparkContext.parallelize(Seq(
+      null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3),
+      new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer],
+      new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF()
 
-    verifyNullabilityInFilterExec(
-      "cast(coalesce(cast(coalesce(a, b) as double), 0.0) as int)", isNullIntolerant = false)
+    verifyNullabilityInFilterExec(df,
+      expr = "Rand()", expectedNullableColumns = Seq.empty[String])
+    verifyNullabilityInFilterExec(df,
+      expr = "coalesce(_1, _2)", expectedNullableColumns = Seq.empty[String])
+    verifyNullabilityInFilterExec(df,
+      expr = "coalesce(_1, 0) + Rand()", expectedNullableColumns = Seq.empty[String])
+    verifyNullabilityInFilterExec(df,
+      expr = "cast(coalesce(cast(coalesce(_1, _2) as double), 0.0) as int)",
+      expectedNullableColumns = Seq.empty[String])
   }
 
   test("SPARK-17957: set nullability to false in FilterExec output") {
-    verifyNullabilityInFilterExec("a + b * 3", isNullIntolerant = true)
-
-    verifyNullabilityInFilterExec("a + b", isNullIntolerant = true)
-
-    verifyNullabilityInFilterExec("cast((a + b) as boolean)", isNullIntolerant = true)
+    val df = sparkContext.parallelize(Seq(
+      null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3),
+      new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer],
+      new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF()
+
+    verifyNullabilityInFilterExec(df,
+      expr = "_1 + _2 * 3", expectedNullableColumns = Seq("_1", "_2"))
+    verifyNullabilityInFilterExec(df,
+      expr = "_1 + _2", expectedNullableColumns = Seq("_1", "_2"))
+    verifyNullabilityInFilterExec(df,
+      expr = "_1", expectedNullableColumns = Seq("_1"))
+    verifyNullabilityInFilterExec(df,
+      expr = "_2 + Rand()", expectedNullableColumns = Seq("_2"))
+    verifyNullabilityInFilterExec(df,
+      expr = "_2 * 3 + coalesce(_1, 0)", expectedNullableColumns = Seq("_2"))
+    verifyNullabilityInFilterExec(df,
+      expr = "cast((_1 + _2) as boolean)", expectedNullableColumns = Seq("_1", "_2"))
   }
 
   test("SPARK-17957: outer join + na.fill") {

From 52cb8fb33e79b527008b68dccbaeb5bcc82f5feb Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Wed, 19 Oct 2016 19:13:57 -0700
Subject: [PATCH 3/6] change the param name to expectedNonNullableColumns

---
 .../org/apache/spark/sql/DataFrameSuite.scala | 24 +++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f8fd8e9fa38e9..57f2b5fa4564a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1619,7 +1619,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
   private def verifyNullabilityInFilterExec(
       df: DataFrame,
       expr: String,
-      expectedNullableColumns: Seq[String]): Unit = {
+      expectedNonNullableColumns: Seq[String]): Unit = {
     val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr)
     // In the logical plan, all the output columns of input dataframe are nullable
     dfWithFilter.queryExecution.optimizedPlan.collect {
@@ -1632,7 +1632,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       // otherwise, no change should be made.
       case e: FilterExec =>
         assert(e.output.forall { o =>
-          if (expectedNullableColumns.contains(o.name)) !o.nullable else o.nullable
+          if (expectedNonNullableColumns.contains(o.name)) !o.nullable else o.nullable
         })
     }
   }
@@ -1644,14 +1644,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF()
 
     verifyNullabilityInFilterExec(df,
-      expr = "Rand()", expectedNullableColumns = Seq.empty[String])
+      expr = "Rand()", expectedNonNullableColumns = Seq.empty[String])
     verifyNullabilityInFilterExec(df,
-      expr = "coalesce(_1, _2)", expectedNullableColumns = Seq.empty[String])
+      expr = "coalesce(_1, _2)", expectedNonNullableColumns = Seq.empty[String])
     verifyNullabilityInFilterExec(df,
-      expr = "coalesce(_1, 0) + Rand()", expectedNullableColumns = Seq.empty[String])
+      expr = "coalesce(_1, 0) + Rand()", expectedNonNullableColumns = Seq.empty[String])
     verifyNullabilityInFilterExec(df,
       expr = "cast(coalesce(cast(coalesce(_1, _2) as double), 0.0) as int)",
-      expectedNullableColumns = Seq.empty[String])
+      expectedNonNullableColumns = Seq.empty[String])
   }
 
   test("SPARK-17957: set nullability to false in FilterExec output") {
@@ -1661,17 +1661,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF()
 
     verifyNullabilityInFilterExec(df,
-      expr = "_1 + _2 * 3", expectedNullableColumns = Seq("_1", "_2"))
+      expr = "_1 + _2 * 3", expectedNonNullableColumns = Seq("_1", "_2"))
     verifyNullabilityInFilterExec(df,
-      expr = "_1 + _2", expectedNullableColumns = Seq("_1", "_2"))
+      expr = "_1 + _2", expectedNonNullableColumns = Seq("_1", "_2"))
     verifyNullabilityInFilterExec(df,
-      expr = "_1", expectedNullableColumns = Seq("_1"))
+      expr = "_1", expectedNonNullableColumns = Seq("_1"))
     verifyNullabilityInFilterExec(df,
-      expr = "_2 + Rand()", expectedNullableColumns = Seq("_2"))
+      expr = "_2 + Rand()", expectedNonNullableColumns = Seq("_2"))
     verifyNullabilityInFilterExec(df,
-      expr = "_2 * 3 + coalesce(_1, 0)", expectedNullableColumns = Seq("_2"))
+      expr = "_2 * 3 + coalesce(_1, 0)", expectedNonNullableColumns = Seq("_2"))
    verifyNullabilityInFilterExec(df,
-      expr = "cast((_1 + _2) as boolean)", expectedNullableColumns = Seq("_1", "_2"))
+      expr = "cast((_1 + _2) as boolean)", expectedNonNullableColumns = Seq("_1", "_2"))
   }
 
   test("SPARK-17957: outer join + na.fill") {

From 4f2101e03e4a848508e504912c2a4c3f390ef51f Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Mon, 24 Oct 2016 10:03:27 -0700
Subject: [PATCH 4/6] merge

---
 .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index a5fe6ecb2af4d..b27eb7406f00c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1617,7 +1617,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     }
   }
 
-<<<<<<< HEAD
   private def verifyNullabilityInFilterExec(
       df: DataFrame,
       expr: String,

From 49daace1d86cbc0547d3da79049a78b0fe585942 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Fri, 28 Oct 2016 00:36:15 -0700
Subject: [PATCH 5/6] update the comment

---
 .../org/apache/spark/sql/execution/basicPhysicalOperators.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index ee1d7c38e75db..883686b8186ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -93,7 +93,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
     case _ => false
   }
 
-  // One expression is null intolerant iff it and its children are null intolerant
+  // If one expression and its children are null intolerant, it is null intolerant.
   private def isNullIntolerant(expr: Expression): Boolean = expr match {
     case e: NullIntolerant =>
       if (e.isInstanceOf[LeafExpression]) true else e.children.forall(isNullIntolerant)

From 2364cc292e0f6585a4647225446b1044c723908c Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Wed, 2 Nov 2016 20:51:33 -0700
Subject: [PATCH 6/6] address comments.

---
 .../apache/spark/sql/execution/basicPhysicalOperators.scala  | 3 +--
 .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 4 ++++
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 883686b8186ce..2039b8669465e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -95,8 +95,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
 
   // If one expression and its children are null intolerant, it is null intolerant.
   private def isNullIntolerant(expr: Expression): Boolean = expr match {
-    case e: NullIntolerant =>
-      if (e.isInstanceOf[LeafExpression]) true else e.children.forall(isNullIntolerant)
+    case e: NullIntolerant => e.children.forall(isNullIntolerant)
     case _ => false
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index b27eb7406f00c..66970bcf98060 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1667,6 +1667,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       expr = "_1 + _2", expectedNonNullableColumns = Seq("_1", "_2"))
     verifyNullabilityInFilterExec(df,
       expr = "_1", expectedNonNullableColumns = Seq("_1"))
+    // `constructIsNotNullConstraints` infers the IsNotNull(_2) from IsNotNull(_2 + Rand())
+    // Thus, we are able to set nullability of _2 to false.
+    // If IsNotNull(_2) is not given from `constructIsNotNullConstraints`, the impl of
+    // isNullIntolerant in `FilterExec` needs an update for more advanced inference.
     verifyNullabilityInFilterExec(df,
       expr = "_2 + Rand()", expectedNonNullableColumns = Seq("_2"))
     verifyNullabilityInFilterExec(df,
       expr =
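
The recursion that PATCH 6 settles on is easier to see outside the diff markers. Below is a minimal, self-contained Scala sketch of the idea; Expr, Attr, Add, and Coalesce are hypothetical stand-ins for Catalyst's Expression, AttributeReference, Add, and Coalesce, and only the body of isNullIntolerant mirrors the patched FilterExec code.

object NullIntolerantSketch extends App {
  // Hypothetical miniature expression tree standing in for Catalyst's Expression.
  sealed trait Expr { def children: Seq[Expr] }
  // Marker mirroring Catalyst's NullIntolerant trait: any null input yields a null output.
  trait NullIntolerant extends Expr

  // A column reference: a null-intolerant leaf (its output is its input).
  case class Attr(name: String) extends NullIntolerant {
    def children: Seq[Expr] = Nil
  }
  // Arithmetic is null-intolerant: null + x evaluates to null.
  case class Add(left: Expr, right: Expr) extends NullIntolerant {
    def children: Seq[Expr] = Seq(left, right)
  }
  // coalesce is deliberately not marked: it swallows nulls from its inputs.
  case class Coalesce(args: Seq[Expr]) extends Expr {
    def children: Seq[Expr] = args
  }

  // The PATCH 6 form: an expression is null intolerant only if it carries the
  // marker and every child does too. A leaf has no children, so forall is
  // vacuously true, which is why the explicit LeafExpression branch from
  // PATCH 1 could be dropped.
  def isNullIntolerant(expr: Expr): Boolean = expr match {
    case e: NullIntolerant => e.children.forall(isNullIntolerant)
    case _ => false
  }

  // isnotnull(a + b) only passes rows where both a and b are non-null,
  // so FilterExec may mark both columns non-nullable ...
  assert(isNullIntolerant(Add(Attr("a"), Attr("b"))))
  // ... but isnotnull(coalesce(a, b)) proves nothing about a alone:
  // a row with a null `a` survives whenever b is non-null.
  assert(!isNullIntolerant(Coalesce(Seq(Attr("a"), Attr("b")))))
}

This is the same split the test cases encode: trees built purely from null-intolerant nodes such as "_1 + _2" end up in the expectedNonNullableColumns cases, while any expression containing a coalesce (e.g. "coalesce(_1, 0) + Rand()") leaves the output nullability untouched.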