From a1574824ddacd58bbc1e2c9570177b33b18cec6d Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 7 May 2018 01:18:17 -0700 Subject: [PATCH 01/12] [SPARK-21274] Implement EXCEPT ALL clause. --- python/pyspark/sql/dataframe.py | 21 ++ .../sql/catalyst/analysis/Analyzer.scala | 7 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 4 +- .../sql/catalyst/expressions/generators.scala | 31 ++ .../sql/catalyst/optimizer/Optimizer.scala | 59 ++++ .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../plans/logical/basicLogicalOperators.scala | 6 +- .../scala/org/apache/spark/sql/Dataset.scala | 15 + .../resources/sql-tests/inputs/except-all.sql | 146 ++++++++ .../sql-tests/results/except-all.sql.out | 323 ++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 70 +++- 11 files changed, 675 insertions(+), 9 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/except-all.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/except-all.sql.out diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index c40aea9bcef0a..97c99d95034dd 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -293,6 +293,27 @@ def explain(self, extended=False): else: print(self._jdf.queryExecution().simpleString()) + @since(2.4) + def exceptAll(self, other): + """ Return a new :class:`DataFrame` containing rows in this :class:`DataFrame` but + not in another :class:`DataFrame while preserving duplicates. + + This is equivalent to `EXCEPT ALL` in SQL. + >>> df1 = spark.createDataFrame([("a", 1), ("a", 2), ("b", 3), ("c", 4)], ["C1", "C2"]) + >>> df2 = spark.createDataFrame([("a", 1), ("b", 3)], ["C1", "C2"]) + + >>> df1.exceptAll(df2).show() + +---+---+ + | C1| C2| + +---+---+ + | a| 2| + | c| 4| + +---+---+ + + Also as standard in SQL, this function resolves columns by position (not by name). 
+ """ + return DataFrame(self._jdf.exceptAll(other._jdf), self.sql_ctx) + @since(1.3) def isLocal(self): """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8e8f8e3e7eda5..ecb26e9566128 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -916,9 +916,10 @@ class Analyzer( j.copy(right = dedupRight(left, right)) case i @ Intersect(left, right) if !i.duplicateResolved => i.copy(right = dedupRight(left, right)) - case i @ Except(left, right) if !i.duplicateResolved => - i.copy(right = dedupRight(left, right)) - + case e @ Except(left, right) if !e.duplicateResolved => + e.copy(right = dedupRight(left, right)) + case e @ ExceptAll(left, right) if !e.duplicateResolved => + e.copy(right = dedupRight(left, right)) // When resolve `SortOrder`s in Sort based on child, don't report errors as // we still have chance to resolve it based on its descendants case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index f9478a1c3cf4b..f78c67eb35f32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -53,7 +53,7 @@ trait CheckAnalysis extends PredicateHelper { } protected def mapColumnInSetOperation(plan: LogicalPlan): Option[Attribute] = plan match { - case _: Intersect | _: Except | _: Distinct => + case _: Intersect | _: ExceptBase | _: Distinct => plan.output.find(a => hasMapType(a.dataType)) case d: Deduplicate => d.keys.find(a => hasMapType(a.dataType)) @@ -330,7 +330,7 @@ trait CheckAnalysis extends PredicateHelper { |Conflicting attributes: ${conflictingAttributes.mkString(",")} """.stripMargin) - case e: Except if !e.duplicateResolved => + case e: ExceptBase if !e.duplicateResolved => val conflictingAttributes = e.left.outputSet.intersect(e.right.outputSet) failAnalysis( s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index b6e0d364d3a96..c7ca7350ba578 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -223,6 +223,37 @@ case class Stack(children: Seq[Expression]) extends Generator { } } +/** + * Replicate the row N times. N is specified as the first argument to the function. + * This is a internal function solely used by optimizer to rewrite EXCEPT ALL AND + * INTERSECT ALL queries. + */ +@ExpressionDescription( +usage = "_FUNC_(n, expr1, ..., exprk) - Replicates `n`, `expr1`, ..., `exprk` into `n` rows.", +examples = """ This is a internal function which is used for query rewrites only to support + EXCEPT ALL AND INTERSECT ALL. 
+ """) +case class ReplicateRows(children: Seq[Expression]) extends Generator with CodegenFallback { + private lazy val numColumns = children.length - 1 // remove the multiplier value from output. + + override def elementSchema: StructType = + StructType(children.tail.zipWithIndex.map { + case (e, index) => StructField(s"col$index", e.dataType) + }) + + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { + val numRows = children.head.eval(input).asInstanceOf[Long] + val values = children.tail.map(_.eval(input)).toArray + Range.Long(0, numRows, 1).map { i => + val fields = new Array[Any](numColumns) + for (col <- 0 until numColumns) { + fields.update(col, values(col)) + } + InternalRow(fields: _*) + } + } +} + /** * Wrapper around another generator to specify outer behavior. This is used to implement functions * such as explode_outer. This expression gets replaced during analysis. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3c264eb8586b5..1fe77d169e7cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -135,6 +135,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) Batch("Subquery", Once, OptimizeSubqueries) :: Batch("Replace Operators", fixedPoint, + RewriteExcepAll, ReplaceIntersectWithSemiJoin, ReplaceExceptWithFilter, ReplaceExceptWithAntiJoin, @@ -1429,6 +1430,64 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] { } } +/** + * Replaces logical [[ExceptAll]] operator using a combination of Union, Aggregate + * and Generate operator. + * + * Input Query : + * {{{ + * SELECT c1 FROM ut1 EXCEPT ALL SELECT c1 FROM ut2 + * }}} + * + * Rewritten Query: + * {{{ + * SELECT c1 + * FROM ( + * SELECT replicate_rows(sum_val, c1) AS (sum_val, c1) + * FROM ( + * SELECT c1, cnt, sum_val + * FROM ( + * SELECT c1, sum(vcol) AS sum_val + * FROM ( + * SELECT 1L as vcol, c1 FROM ut1 + * UNION ALL + * SELECT -1L as vcol, c1 FROM ut2 + * ) AS union_all + * GROUP BY union_all.c1 + * ) + * WHERE sum_val > 0 + * ) + * ) + * }}} + */ + +object RewriteExcepAll extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case ExceptAll(left, right) => + assert(left.output.size == right.output.size) + + val newColumnLeft = Alias(Literal(1L), "vcol")() + val newColumnRight = Alias(Literal(-1L), "vcol")() + val modifiedLeftPlan = Project(Seq(newColumnLeft) ++ left.output, left) + val modifiedRightPlan = Project(Seq(newColumnRight) ++ right.output, right) + val unionPlan = Union(modifiedLeftPlan, modifiedRightPlan) + val aggSumCol = + Alias(AggregateExpression(Sum(unionPlan.output.head.toAttribute), Complete, false), "sum")() + val aggOutputColumns = left.output ++ Seq(aggSumCol) + val aggregatePlan = Aggregate(left.output, aggOutputColumns, unionPlan) + val filteredAggPlan = Filter(GreaterThan(aggSumCol.toAttribute, Literal(0L)), aggregatePlan) + val genRowPlan = Generate( + ReplicateRows(Seq(aggSumCol.toAttribute) ++ left.output), + Nil, + false, + None, + left.output, + filteredAggPlan + ) + Project(left.output, genRowPlan) + } +} + /** * Removes literals from group expressions in [[Aggregate]], as they have no effect to the result * but only makes the grouping key bigger. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 49f578a24aaeb..59e941dc0ac50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -537,7 +537,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case SqlBaseParser.INTERSECT => Intersect(left, right) case SqlBaseParser.EXCEPT if all => - throw new ParseException("EXCEPT ALL is not supported.", ctx) + ExceptAll(left, right) case SqlBaseParser.EXCEPT => Except(left, right) case SqlBaseParser.SETMINUS if all => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index ea5a9b8ed5542..ae78ca190246b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -183,14 +183,16 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation } } -case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { - +abstract class ExceptBase(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output override protected def validConstraints: Set[Expression] = leftConstraints } +case class Except(left: LogicalPlan, right: LogicalPlan) extends ExceptBase(left, right) +case class ExceptAll(left: LogicalPlan, right: LogicalPlan) extends ExceptBase(left, right) + /** Factory for constructing new `Union` nodes. */ object Union { def apply(left: LogicalPlan, right: LogicalPlan): Union = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b63235ec827c9..336b4983f9736 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1948,6 +1948,21 @@ class Dataset[T] private[sql]( Except(planWithBarrier, other.planWithBarrier) } + /** + * Returns a new Dataset containing rows in this Dataset but not in another Dataset while + * preserving the duplicates. + * This is equivalent to `EXCEPT ALL` in SQL. + * + * @note Equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. + * + * @group typedrel + * @since 2.4.0 + */ + def exceptAll(other: Dataset[T]): Dataset[T] = withSetOperator { + ExceptAll(planWithBarrier, other.planWithBarrier) + } + /** * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement), * using a user-supplied seed. 
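A quick sketch of the new Dataset API in use, mirroring the Python doctest above (assumes a spark-shell session with spark.implicits._ available):

    import spark.implicits._

    val df1 = Seq(("a", 1), ("a", 2), ("b", 3), ("c", 4)).toDF("C1", "C2")
    val df2 = Seq(("a", 1), ("b", 3)).toDF("C1", "C2")

    // equivalent to: SELECT * FROM df1 EXCEPT ALL SELECT * FROM df2
    df1.exceptAll(df2).show()
    // +---+---+
    // | C1| C2|
    // +---+---+
    // |  a|  2|
    // |  c|  4|
    // +---+---+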
diff --git a/sql/core/src/test/resources/sql-tests/inputs/except-all.sql b/sql/core/src/test/resources/sql-tests/inputs/except-all.sql new file mode 100644 index 0000000000000..b14f7eeb991b1 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/except-all.sql @@ -0,0 +1,146 @@ +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (0),(1),(2),(2),(2),(2),(3),(null),(null) AS tab1(c1) ; +CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES + (1),(2),(2),(3),(5),(5),(null) AS tab2(c1) ; +CREATE TEMPORARY VIEW tab3 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (1, 3), + (2, 3), + (2, 2) + AS tab3(k, v); +CREATE TEMPORARY VIEW tab4 AS SELECT * FROM VALUES + (1, 2), + (2, 3), + (2, 2), + (2, 2), + (2, 20) + AS tab4(k, v); + +-- Basic ExceptAll +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2; + +-- ExceptAll same table in both branches +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 WHERE c1 IS NOT NULL; + +-- Empty left relation +SELECT * FROM tab1 WHERE c1 > 5 +EXCEPT ALL +SELECT * FROM tab2; + +-- Empty right relation +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 WHERE c1 > 6; + +-- Type Coerced ExceptAll +SELECT * FROM tab1 +EXCEPT ALL +SELECT CAST(1 AS BIGINT); + +-- Error as types of two side are not compatible +SELECT * FROM tab1 +EXCEPT ALL +SELECT array(1); + +-- Basic +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4; + +-- Basic +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3; + +-- ExceptAll + Intersect +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3 +INTERSECT DISTINCT +SELECT * FROM tab4; + +-- ExceptAll + Except +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4; + +-- Chain of set operations +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +UNION ALL +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4; + +-- Mismatch on number of columns across both branches +SELECT k FROM tab3 +EXCEPT ALL +SELECT k, v FROM tab4; + +-- Chain of set operations +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +UNION +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4; + +-- Chain of set operations +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +EXCEPT DISTINCT +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4; + +-- Join under except all. Should produce empty resultset since both left and right sets +-- are same. 
+SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +EXCEPT ALL +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k); + +-- Join under except all (2) +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +EXCEPT ALL +SELECT * +FROM (SELECT tab4.v AS k, + tab3.k AS v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k); + +-- Group by under ExceptAll +SELECT v FROM tab3 GROUP BY v +EXCEPT ALL +SELECT k FROM tab4 GROUP BY k + +-- Clean-up +DROP VIEW IF EXISTS tab1; +DROP VIEW IF EXISTS tab2; +DROP VIEW IF EXISTS tab3; +DROP VIEW IF EXISTS tab4; diff --git a/sql/core/src/test/resources/sql-tests/results/except-all.sql.out b/sql/core/src/test/resources/sql-tests/results/except-all.sql.out new file mode 100644 index 0000000000000..39a03bab0762f --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/except-all.sql.out @@ -0,0 +1,323 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 24 + + +-- !query 0 +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (0),(1),(2),(2),(2),(2),(3),(null),(null) AS tab1(c1) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES + (1),(2),(2),(3),(5),(5),(null) AS tab2(c1) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TEMPORARY VIEW tab3 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (1, 3), + (2, 3), + (2, 2) + AS tab3(k, v) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +CREATE TEMPORARY VIEW tab4 AS SELECT * FROM VALUES + (1, 2), + (2, 3), + (2, 2), + (2, 2), + (2, 20) + AS tab4(k, v) +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 +-- !query 4 schema +struct +-- !query 4 output +0 +2 +2 +NULL + + +-- !query 5 +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 WHERE c1 IS NOT NULL +-- !query 5 schema +struct +-- !query 5 output +0 +2 +2 +NULL +NULL + + +-- !query 6 +SELECT * FROM tab1 WHERE c1 > 5 +EXCEPT ALL +SELECT * FROM tab2 +-- !query 6 schema +struct +-- !query 6 output + + + +-- !query 7 +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 WHERE c1 > 6 +-- !query 7 schema +struct +-- !query 7 output +0 +1 +2 +2 +2 +2 +3 +NULL +NULL + + +-- !query 8 +SELECT * FROM tab1 +EXCEPT ALL +SELECT CAST(1 AS BIGINT) +-- !query 8 schema +struct +-- !query 8 output +0 +2 +2 +2 +2 +3 +NULL +NULL + + +-- !query 9 +SELECT * FROM tab1 +EXCEPT ALL +SELECT array(1) +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +ExceptAll can only be performed on tables with the compatible column types. 
array <> int at the first column of the second table; + + +-- !query 10 +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +-- !query 10 schema +struct +-- !query 10 output +1 2 +1 3 + + +-- !query 11 +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3 +-- !query 11 schema +struct +-- !query 11 output +2 2 +2 20 + + +-- !query 12 +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3 +INTERSECT DISTINCT +SELECT * FROM tab4 +-- !query 12 schema +struct +-- !query 12 output +2 2 +2 20 + + +-- !query 13 +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4 +-- !query 13 schema +struct +-- !query 13 output + + + +-- !query 14 +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +UNION ALL +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4 +-- !query 14 schema +struct +-- !query 14 output +1 3 + + +-- !query 15 +SELECT k FROM tab3 +EXCEPT ALL +SELECT k, v FROM tab4 +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +ExceptAll can only be performed on tables with the same number of columns, but the first table has 1 columns and the second table has 2 columns; + + +-- !query 16 +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +UNION +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4 +-- !query 16 schema +struct +-- !query 16 output +1 3 + + +-- !query 17 +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +EXCEPT DISTINCT +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4 +-- !query 17 schema +struct +-- !query 17 output + + + +-- !query 18 +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +EXCEPT ALL +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +-- !query 18 schema +struct +-- !query 18 output + + + +-- !query 19 +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +EXCEPT ALL +SELECT * +FROM (SELECT tab4.v AS k, + tab3.k AS v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +-- !query 19 schema +struct +-- !query 19 output +1 2 +1 2 +1 2 +2 20 +2 20 +2 3 +2 3 + + +-- !query 20 +SELECT v FROM tab3 GROUP BY v +EXCEPT ALL +SELECT k FROM tab4 GROUP BY k + +DROP VIEW IF EXISTS tab1 +-- !query 20 schema +struct<> +-- !query 20 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'DROP' expecting (line 5, pos 0) + +== SQL == +SELECT v FROM tab3 GROUP BY v +EXCEPT ALL +SELECT k FROM tab4 GROUP BY k + +DROP VIEW IF EXISTS tab1 +^^^ + + +-- !query 21 +DROP VIEW IF EXISTS tab2 +-- !query 21 schema +struct<> +-- !query 21 output + + + +-- !query 22 +DROP VIEW IF EXISTS tab3 +-- !query 22 schema +struct<> +-- !query 22 output + + + +-- !query 23 +DROP VIEW IF EXISTS tab4 +-- !query 23 schema +struct<> +-- !query 23 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 9cf8c47fa6cf1..af0735920cc29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExc import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} -import org.apache.spark.sql.test.SQLTestData.TestData2 +import org.apache.spark.sql.test.SQLTestData.{NullInts, NullStrings, TestData2} import 
org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -629,6 +629,74 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df4.schema.forall(!_.nullable)) } + test("except all") { + checkAnswer( + lowerCaseData.exceptAll(upperCaseData), + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.exceptAll(lowerCaseData), Nil) + checkAnswer(upperCaseData.exceptAll(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.exceptAll(nullInts.filter("0 = 1")), + nullInts) + checkAnswer( + nullInts.exceptAll(nullInts), + Nil) + + // check that duplicate values are preserved + checkAnswer( + allNulls.exceptAll(allNulls.filter("0 = 1")), + Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) + checkAnswer( + allNulls.exceptAll(allNulls.limit(2)), + Row(null) :: Row(null) :: Nil) + + // check that duplicates are retained. + val df = spark.sparkContext.parallelize( + NullStrings(1, "id1") :: + NullStrings(1, "id1") :: + NullStrings(2, "id1") :: + NullStrings(3, null) :: Nil).toDF("id", "value") + + checkAnswer( + df.exceptAll(df.filter("0 = 1")), + Row(1, "id1") :: + Row(1, "id1") :: + Row(2, "id1") :: + Row(3, null) :: Nil) + + // check if the empty set on the left side works + checkAnswer( + allNulls.filter("0 = 1").exceptAll(allNulls), + Nil) + + } + + test("exceptAll - nullability") { + val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.exceptAll(nullInts) + checkAnswer(df1, Row(11) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.exceptAll(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) + assert(df2.schema.forall(_.nullable)) + + val df3 = nullInts.exceptAll(nullInts) + checkAnswer(df3, Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.exceptAll(nonNullableInts) + checkAnswer(df4, Nil) + assert(df4.schema.forall(!_.nullable)) + } + test("intersect") { checkAnswer( lowerCaseData.intersect(lowerCaseData), From 62ab006e52d811b33d43da7f3e7e2da0c500b81d Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 24 Jul 2018 01:37:31 -0700 Subject: [PATCH 02/12] python docstyle --- python/pyspark/sql/dataframe.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 97c99d95034dd..bcc61b5acd2fe 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -295,10 +295,11 @@ def explain(self, extended=False): @since(2.4) def exceptAll(self, other): - """ Return a new :class:`DataFrame` containing rows in this :class:`DataFrame` but - not in another :class:`DataFrame while preserving duplicates. + """Return a new :class:`DataFrame` containing rows in this :class:`DataFrame` but + not in another :class:`DataFrame` while preserving duplicates. This is equivalent to `EXCEPT ALL` in SQL. 
+ >>> df1 = spark.createDataFrame([("a", 1), ("a", 2), ("b", 3), ("c", 4)], ["C1", "C2"]) >>> df2 = spark.createDataFrame([("a", 1), ("b", 3)], ["C1", "C2"]) From 38a948fe0cd622ae955954f9b54f21dc7b29b561 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 24 Jul 2018 04:28:30 -0700 Subject: [PATCH 03/12] Code review + test failures --- .../apache/spark/sql/catalyst/expressions/generators.scala | 5 ----- .../apache/spark/sql/catalyst/parser/ErrorParserSuite.scala | 3 --- .../apache/spark/sql/catalyst/parser/PlanParserSuite.scala | 1 - 3 files changed, 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index c7ca7350ba578..2d31e903613de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -228,11 +228,6 @@ case class Stack(children: Seq[Expression]) extends Generator { * This is a internal function solely used by optimizer to rewrite EXCEPT ALL AND * INTERSECT ALL queries. */ -@ExpressionDescription( -usage = "_FUNC_(n, expr1, ..., exprk) - Replicates `n`, `expr1`, ..., `exprk` into `n` rows.", -examples = """ This is a internal function which is used for query rewrites only to support - EXCEPT ALL AND INTERSECT ALL. - """) case class ReplicateRows(children: Seq[Expression]) extends Generator with CodegenFallback { private lazy val numColumns = children.length - 1 // remove the multiplier value from output. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala index f67697eb86c26..baaf01800b33b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala @@ -58,8 +58,5 @@ class ErrorParserSuite extends SparkFunSuite { intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0, "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", "^^^") - intercept("select * from r except all select * from t", 1, 0, - "EXCEPT ALL is not supported", - "^^^") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index fb51376c6163f..629e3c4f3fcfb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -65,7 +65,6 @@ class PlanParserSuite extends AnalysisTest { assertEqual("select * from a union distinct select * from b", Distinct(a.union(b))) assertEqual("select * from a union all select * from b", a.union(b)) assertEqual("select * from a except select * from b", a.except(b)) - intercept("select * from a except all select * from b", "EXCEPT ALL is not supported.") assertEqual("select * from a except distinct select * from b", a.except(b)) assertEqual("select * from a minus select * from b", a.except(b)) intercept("select * from a minus all select * from b", "MINUS ALL is not supported.") From 072e0d78f95e4ad2f617292da2ee8e099f43e121 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 24 Jul 2018 08:49:16 -0700 Subject: [PATCH 04/12] code review --- 
.../apache/spark/sql/catalyst/expressions/generators.scala | 6 +++--- sql/core/src/test/resources/sql-tests/inputs/except-all.sql | 4 ++-- .../src/test/resources/sql-tests/results/except-all.sql.out | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 2d31e903613de..d6e67b9ac3d10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -225,7 +225,7 @@ case class Stack(children: Seq[Expression]) extends Generator { /** * Replicate the row N times. N is specified as the first argument to the function. - * This is a internal function solely used by optimizer to rewrite EXCEPT ALL AND + * This is an internal function solely used by optimizer to rewrite EXCEPT ALL AND * INTERSECT ALL queries. */ case class ReplicateRows(children: Seq[Expression]) extends Generator with CodegenFallback { @@ -234,12 +234,12 @@ case class ReplicateRows(children: Seq[Expression]) extends Generator with Codeg override def elementSchema: StructType = StructType(children.tail.zipWithIndex.map { case (e, index) => StructField(s"col$index", e.dataType) - }) + }) override def eval(input: InternalRow): TraversableOnce[InternalRow] = { val numRows = children.head.eval(input).asInstanceOf[Long] val values = children.tail.map(_.eval(input)).toArray - Range.Long(0, numRows, 1).map { i => + Range.Long(0, numRows, 1).map { _ => val fields = new Array[Any](numColumns) for (col <- 0 until numColumns) { fields.update(col, values(col)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/except-all.sql b/sql/core/src/test/resources/sql-tests/inputs/except-all.sql index b14f7eeb991b1..db4bc7834eeed 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/except-all.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/except-all.sql @@ -1,7 +1,7 @@ CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES - (0),(1),(2),(2),(2),(2),(3),(null),(null) AS tab1(c1) ; + (0), (1), (2), (2), (2), (2), (3), (null), (null) AS tab1(c1); CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES - (1),(2),(2),(3),(5),(5),(null) AS tab2(c1) ; + (1), (2), (2), (3), (5), (5), (null) AS tab2(c1); CREATE TEMPORARY VIEW tab3 AS SELECT * FROM VALUES (1, 2), (1, 2), diff --git a/sql/core/src/test/resources/sql-tests/results/except-all.sql.out b/sql/core/src/test/resources/sql-tests/results/except-all.sql.out index 39a03bab0762f..9c361d7fe6b1a 100644 --- a/sql/core/src/test/resources/sql-tests/results/except-all.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/except-all.sql.out @@ -4,7 +4,7 @@ -- !query 0 CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES - (0),(1),(2),(2),(2),(2),(3),(null),(null) AS tab1(c1) + (0), (1), (2), (2), (2), (2), (3), (null), (null) AS tab1(c1) -- !query 0 schema struct<> -- !query 0 output @@ -13,7 +13,7 @@ struct<> -- !query 1 CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES - (1),(2),(2),(3),(5),(5),(null) AS tab2(c1) + (1), (2), (2), (3), (5), (5), (null) AS tab2(c1) -- !query 1 schema struct<> -- !query 1 output From 7e451e3111dd3d0d1cc59fbe5292db90f302ebc8 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 24 Jul 2018 10:32:11 -0700 Subject: [PATCH 05/12] Add a unit test to test rewrite --- .../optimizer/SetOperationSuite.scala | 24 ++++++++++++++++++- 1 file changed, 23 
insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index aa8841109329c..e9d68abded054 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.{Alias, GreaterThan, Literal, ReplicateRows} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -144,4 +144,26 @@ class SetOperationSuite extends PlanTest { Distinct(Union(query3 :: query4 :: Nil))).analyze comparePlans(distinctUnionCorrectAnswer2, optimized2) } + + test("EXCEPT ALL rewrite") { + val input = ExceptAll(testRelation, testRelation2) + val rewrittenPlan = RewriteExcepAll(input) + + val planFragment = testRelation.select(Literal(1L).as("vcol"), 'a, 'b, 'c) + .union(testRelation2.select(Literal(-1L).as("vcol"), 'd, 'e, 'f)) + .groupBy('a, 'b, 'c)('a, 'b, 'c, sum('vcol).as("sum")) + .where(GreaterThan('sum, Literal(0L))).analyze + val multiplerAttr = planFragment.output.last + val output = planFragment.output.dropRight(1) + val expectedPlan = Project(output, + Generate( + ReplicateRows(Seq(multiplerAttr) ++ output), + Nil, + false, + None, + output, + planFragment + )) + comparePlans(expectedPlan, rewrittenPlan) + } } From c39f88e4003e213b67d8706ce95f726fc986c014 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 25 Jul 2018 16:15:02 -0700 Subject: [PATCH 06/12] code review --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 1fe77d169e7cc..84de0debd87cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1445,7 +1445,7 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] { * FROM ( * SELECT replicate_rows(sum_val, c1) AS (sum_val, c1) * FROM ( - * SELECT c1, cnt, sum_val + * SELECT c1, sum_val * FROM ( * SELECT c1, sum(vcol) AS sum_val * FROM ( From 83ea225c70b499f778a99fc17fd142f33720a84a Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 26 Jul 2018 00:09:53 -0700 Subject: [PATCH 07/12] Remove exceptall logical operator --- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 4 +- .../sql/catalyst/analysis/TypeCoercion.scala | 42 ++++++++++--------- .../UnsupportedOperationChecker.scala | 2 +- .../sql/catalyst/optimizer/Optimizer.scala | 4 +- .../optimizer/ReplaceExceptWithFilter.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../plans/logical/basicLogicalOperators.scala | 9 ++-- .../optimizer/SetOperationSuite.scala | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala 
| 2 +- 11 files changed, 39 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ecb26e9566128..af65e982546f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -916,9 +916,7 @@ class Analyzer( j.copy(right = dedupRight(left, right)) case i @ Intersect(left, right) if !i.duplicateResolved => i.copy(right = dedupRight(left, right)) - case e @ Except(left, right) if !e.duplicateResolved => - e.copy(right = dedupRight(left, right)) - case e @ ExceptAll(left, right) if !e.duplicateResolved => + case e @ Except(left, right, _) if !e.duplicateResolved => e.copy(right = dedupRight(left, right)) // When resolve `SortOrder`s in Sort based on child, don't report errors as // we still have chance to resolve it based on its descendants diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index f78c67eb35f32..f9478a1c3cf4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -53,7 +53,7 @@ trait CheckAnalysis extends PredicateHelper { } protected def mapColumnInSetOperation(plan: LogicalPlan): Option[Attribute] = plan match { - case _: Intersect | _: ExceptBase | _: Distinct => + case _: Intersect | _: Except | _: Distinct => plan.output.find(a => hasMapType(a.dataType)) case d: Deduplicate => d.keys.find(a => hasMapType(a.dataType)) @@ -330,7 +330,7 @@ trait CheckAnalysis extends PredicateHelper { |Conflicting attributes: ${conflictingAttributes.mkString(",")} """.stripMargin) - case e: ExceptBase if !e.duplicateResolved => + case e: Except if !e.duplicateResolved => val conflictingAttributes = e.left.outputSet.intersect(e.right.outputSet) failAnalysis( s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 6bdb639011a17..520695b35acda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -319,11 +319,17 @@ object TypeCoercion { object WidenSetOperationTypes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case s @ SetOperation(left, right) if s.childrenResolved && - left.output.length == right.output.length && !s.resolved => + case s @ Except(left, right, isAll) if s.childrenResolved && + left.output.length == right.output.length && !s.resolved => val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) assert(newChildren.length == 2) - s.makeCopy(Array(newChildren.head, newChildren.last)) + Except(newChildren.head, newChildren.last, isAll) + + case s @ Intersect(left, right) if s.childrenResolved && + left.output.length == right.output.length && !s.resolved => + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) + assert(newChildren.length == 2) + Intersect(newChildren.head, newChildren.last) case s: Union if s.childrenResolved && 
s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => @@ -391,7 +397,7 @@ object TypeCoercion { } override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -453,7 +459,7 @@ object TypeCoercion { } override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -512,7 +518,7 @@ object TypeCoercion { private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE) private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO) - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -555,7 +561,7 @@ object TypeCoercion { object FunctionArgumentConversion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -670,7 +676,7 @@ object TypeCoercion { */ object Division extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who has not been resolved yet, // as this is an extra rule which should be applied at last. case e if !e.childrenResolved => e @@ -693,7 +699,7 @@ object TypeCoercion { */ object CaseWhenCoercion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case c: CaseWhen if c.childrenResolved && !haveSameType(c.inputTypesForMerging) => val maybeCommonType = findWiderCommonType(c.inputTypesForMerging) maybeCommonType.map { commonType => @@ -711,7 +717,7 @@ object TypeCoercion { */ object IfCoercion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if !haveSameType(i.inputTypesForMerging) => @@ -731,7 +737,7 @@ object TypeCoercion { * Coerces NullTypes in the Stack expression to the column types of the corresponding positions. */ object StackCoercion extends TypeCoercionRule { - override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case s @ Stack(children) if s.childrenResolved && s.hasFoldableNumRows => Stack(children.zipWithIndex.map { // The first child is the number of rows for stack. 
@@ -751,8 +757,7 @@ object TypeCoercion { */ case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule { - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case p => + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => p transformExpressionsUp { // Skip nodes if unresolved or empty children case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c @@ -774,8 +779,7 @@ object TypeCoercion { */ case class EltCoercion(conf: SQLConf) extends TypeCoercionRule { - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case p => + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => p transformExpressionsUp { // Skip nodes if unresolved or not enough children case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c @@ -803,7 +807,7 @@ object TypeCoercion { private val acceptedTypes = Seq(DateType, TimestampType, StringType) - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -824,7 +828,7 @@ object TypeCoercion { private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -963,7 +967,7 @@ object TypeCoercion { */ object WindowFrameCoercion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case s @ WindowSpecDefinition(_, Seq(order), SpecifiedWindowFrame(RangeFrame, lower, upper)) if order.resolved => s.copy(frameSpecification = SpecifiedWindowFrame( @@ -1001,7 +1005,7 @@ trait TypeCoercionRule extends Rule[LogicalPlan] with Logging { protected def coerceTypes(plan: LogicalPlan): LogicalPlan - private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan transformUp { // No propagation required for leaf nodes. 
case q: LogicalPlan if q.children.isEmpty => q diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index f68df5d29b545..c9a3ee47a02be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -306,7 +306,7 @@ object UnsupportedOperationChecker { case u: Union if u.children.map(_.isStreaming).distinct.size == 2 => throwError("Union between streaming and batch DataFrames/Datasets is not supported") - case Except(left, right) if right.isStreaming => + case Except(left, right, _) if right.isStreaming => throwError("Except on a streaming DataFrame/Dataset on the right is not supported") case Intersect(left, right) if left.isStreaming && right.isStreaming => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 84de0debd87cc..1ca1a04f871f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1423,7 +1423,7 @@ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] { */ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Except(left, right) => + case Except(left, right, false) => assert(left.output.size == right.output.size) val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } Distinct(Join(left, right, LeftAnti, joinCond.reduceLeftOption(And))) @@ -1463,7 +1463,7 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] { object RewriteExcepAll extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case ExceptAll(left, right) => + case Except(left, right, true) => assert(left.output.size == right.output.size) val newColumnLeft = Alias(Literal(1L), "vcol")() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala index 45edf266bbce4..efd3944eba7f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala @@ -46,7 +46,7 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] { } plan.transform { - case e @ Except(left, right) if isEligible(left, right) => + case e @ Except(left, right, false) if isEligible(left, right) => val newCondition = transformCondition(left, skipProject(right)) newCondition.map { c => Distinct(Filter(Not(c), left)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 59e941dc0ac50..5c55e87a23f51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -537,7 +537,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case 
SqlBaseParser.INTERSECT => Intersect(left, right) case SqlBaseParser.EXCEPT if all => - ExceptAll(left, right) + Except(left, right, true) case SqlBaseParser.EXCEPT => Except(left, right) case SqlBaseParser.SETMINUS if all => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index ae78ca190246b..498a13a62bd22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -183,16 +183,17 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation } } -abstract class ExceptBase(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { +case class Except( + left: LogicalPlan, + right: LogicalPlan, + isAll: Boolean = false) extends SetOperation(left, right) { + override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) "All" else "" ) /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output override protected def validConstraints: Set[Expression] = leftConstraints } -case class Except(left: LogicalPlan, right: LogicalPlan) extends ExceptBase(left, right) -case class ExceptAll(left: LogicalPlan, right: LogicalPlan) extends ExceptBase(left, right) - /** Factory for constructing new `Union` nodes. */ object Union { def apply(left: LogicalPlan, right: LogicalPlan): Union = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index e9d68abded054..4a73f707e340e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -146,7 +146,7 @@ class SetOperationSuite extends PlanTest { } test("EXCEPT ALL rewrite") { - val input = ExceptAll(testRelation, testRelation2) + val input = Except(testRelation, testRelation2, true) val rewrittenPlan = RewriteExcepAll(input) val planFragment = testRelation.select(Literal(1L).as("vcol"), 'a, 'b, 'c) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 336b4983f9736..dbce5ece497a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1960,7 +1960,7 @@ class Dataset[T] private[sql]( * @since 2.4.0 */ def exceptAll(other: Dataset[T]): Dataset[T] = withSetOperator { - ExceptAll(planWithBarrier, other.planWithBarrier) + Except(planWithBarrier, other.planWithBarrier, true) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0c4ea857fd1d7..16db5e928d8c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -532,7 +532,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Intersect(left, right) => throw new IllegalStateException( "logical intersect operator should have been replaced by semi-join in 
the optimizer") - case logical.Except(left, right) => + case logical.Except(left, right, _) => throw new IllegalStateException( "logical except operator should have been replaced by anti-join in the optimizer") From f23bc4781ab5c77a602118f7f5090f186c86f4cf Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 26 Jul 2018 02:32:26 -0700 Subject: [PATCH 08/12] fix --- .../org/apache/spark/sql/execution/SparkStrategies.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 16db5e928d8c1..bb4cbd3ceb607 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -532,9 +532,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Intersect(left, right) => throw new IllegalStateException( "logical intersect operator should have been replaced by semi-join in the optimizer") - case logical.Except(left, right, _) => + case logical.Except(left, right, false) => throw new IllegalStateException( "logical except operator should have been replaced by anti-join in the optimizer") + case logical.Except(left, right, true) => + throw new IllegalStateException( + "logical except operator should have been replaced by union, aggregate" + + "and generator operators in the optimizer") case logical.DeserializeToObject(deserializer, objAttr, child) => execution.DeserializeToObjectExec(deserializer, objAttr, planLater(child)) :: Nil From 8e0c1ef3f876cdd0e8eed0e6639762ddeb584b53 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 26 Jul 2018 10:47:36 -0700 Subject: [PATCH 09/12] minor --- .../scala/org/apache/spark/sql/execution/SparkStrategies.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index bb4cbd3ceb607..4faac17d8c501 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -538,7 +538,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Except(left, right, true) => throw new IllegalStateException( "logical except operator should have been replaced by union, aggregate" + - "and generator operators in the optimizer") + "and generate operators in the optimizer") case logical.DeserializeToObject(deserializer, objAttr, child) => execution.DeserializeToObjectExec(deserializer, objAttr, planLater(child)) :: Nil From a7ec7b6227be1896cc5315008c2e2031df126cc7 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 26 Jul 2018 16:13:11 -0700 Subject: [PATCH 10/12] code review --- python/pyspark/sql/dataframe.py | 5 ++++- .../spark/sql/catalyst/optimizer/Optimizer.scala | 10 +++++----- .../apache/spark/sql/catalyst/parser/AstBuilder.scala | 2 +- .../sql/catalyst/optimizer/SetOperationSuite.scala | 2 +- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 5 +++-- .../apache/spark/sql/execution/SparkStrategies.scala | 2 +- 6 files changed, 15 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index bcc61b5acd2fe..b2e0a5b2390c2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ 
-300,13 +300,16 @@ def exceptAll(self, other): This is equivalent to `EXCEPT ALL` in SQL. - >>> df1 = spark.createDataFrame([("a", 1), ("a", 2), ("b", 3), ("c", 4)], ["C1", "C2"]) + >>> df1 = spark.createDataFrame( + ... [("a", 1), ("a", 1), ("a", 1), ("a", 2), ("b", 3), ("c", 4)], ["C1", "C2"]) >>> df2 = spark.createDataFrame([("a", 1), ("b", 3)], ["C1", "C2"]) >>> df1.exceptAll(df2).show() +---+---+ | C1| C2| +---+---+ + | a| 1| + | a| 1| | a| 2| | c| 4| +---+---+ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 1ca1a04f871f9..193f6591c9a8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1431,7 +1431,7 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] { } /** - * Replaces logical [[ExceptAll]] operator using a combination of Union, Aggregate + * Replaces logical [[Except]] operator using a combination of Union, Aggregate * and Generate operator. * * Input Query : @@ -1443,7 +1443,7 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] { * {{{ * SELECT c1 * FROM ( - * SELECT replicate_rows(sum_val, c1) AS (sum_val, c1) + * SELECT replicate_rows(sum_val, c1) * FROM ( * SELECT c1, sum_val * FROM ( @@ -1478,9 +1478,9 @@ object RewriteExcepAll extends Rule[LogicalPlan] { val filteredAggPlan = Filter(GreaterThan(aggSumCol.toAttribute, Literal(0L)), aggregatePlan) val genRowPlan = Generate( ReplicateRows(Seq(aggSumCol.toAttribute) ++ left.output), - Nil, - false, - None, + unrequiredChildIndex = Nil, + outer = false, + qualifier = None, left.output, filteredAggPlan ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 5c55e87a23f51..8b3c0686181fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -537,7 +537,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case SqlBaseParser.INTERSECT => Intersect(left, right) case SqlBaseParser.EXCEPT if all => - Except(left, right, true) + Except(left, right, isAll = true) case SqlBaseParser.EXCEPT => Except(left, right) case SqlBaseParser.SETMINUS if all => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index 4a73f707e340e..f002aa3aacaba 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -146,7 +146,7 @@ class SetOperationSuite extends PlanTest { } test("EXCEPT ALL rewrite") { - val input = Except(testRelation, testRelation2, true) + val input = Except(testRelation, testRelation2, isAll = true) val rewrittenPlan = RewriteExcepAll(input) val planFragment = testRelation.select(Literal(1L).as("vcol"), 'a, 'b, 'c) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index dbce5ece497a2..e6a3b0adcdaa6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index dbce5ece497a2..e6a3b0adcdaa6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1954,13 +1954,14 @@ class Dataset[T] private[sql](
    * This is equivalent to `EXCEPT ALL` in SQL.
    *
    * @note Equality checking is performed directly on the encoded representation of the data
-   * and thus is not affected by a custom `equals` function defined on `T`.
+   * and thus is not affected by a custom `equals` function defined on `T`. Also as standard in
+   * SQL, this function resolves columns by position (not by name).
    *
    * @group typedrel
    * @since 2.4.0
    */
   def exceptAll(other: Dataset[T]): Dataset[T] = withSetOperator {
-    Except(planWithBarrier, other.planWithBarrier, true)
+    Except(planWithBarrier, other.planWithBarrier, isAll = true)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 4faac17d8c501..3f5fd3dbb9e2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -537,7 +537,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           "logical except operator should have been replaced by anti-join in the optimizer")
       case logical.Except(left, right, true) =>
         throw new IllegalStateException(
-          "logical except operator should have been replaced by union, aggregate " +
+          "logical except (all) operator should have been replaced by union, aggregate " +
             "and generate operators in the optimizer")
 
       case logical.DeserializeToObject(deserializer, objAttr, child) =>

From 1e50b9d8b2ef5cb05fceed9c433e4e0350153412 Mon Sep 17 00:00:00 2001
From: Dilip Biswal
Date: Thu, 26 Jul 2018 21:48:45 -0700
Subject: [PATCH 11/12] missing ;

---
 .../resources/sql-tests/inputs/except-all.sql      |  2 +-
 .../sql-tests/results/except-all.sql.out           | 34 ++++++++-----------
 2 files changed, 16 insertions(+), 20 deletions(-)

diff --git a/sql/core/src/test/resources/sql-tests/inputs/except-all.sql b/sql/core/src/test/resources/sql-tests/inputs/except-all.sql
index db4bc7834eeed..08b9a437b3d14 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/except-all.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/except-all.sql
@@ -137,7 +137,7 @@ FROM (SELECT tab4.v AS k,
 -- Group by under ExceptAll
 SELECT v FROM tab3 GROUP BY v
 EXCEPT ALL
-SELECT k FROM tab4 GROUP BY k
+SELECT k FROM tab4 GROUP BY k;
 
 -- Clean-up
 DROP VIEW IF EXISTS tab1;
diff --git a/sql/core/src/test/resources/sql-tests/results/except-all.sql.out b/sql/core/src/test/resources/sql-tests/results/except-all.sql.out
index 9c361d7fe6b1a..2a21c1505350c 100644
--- a/sql/core/src/test/resources/sql-tests/results/except-all.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/except-all.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 24
+-- Number of queries: 25
 
 
 -- !query 0
@@ -280,27 +280,15 @@ struct
 -- !query 20
 SELECT v FROM tab3 GROUP BY v
 EXCEPT ALL
-SELECT k FROM tab4 GROUP BY k
-
-DROP VIEW IF EXISTS tab1
+SELECT k FROM tab4 GROUP BY k
 -- !query 20 schema
-struct<>
+struct<v:int>
 -- !query 20 output
-org.apache.spark.sql.catalyst.parser.ParseException
-
-mismatched input 'DROP' expecting <EOF>(line 5, pos 0)
-
-== SQL ==
-SELECT v FROM tab3 GROUP BY v
-EXCEPT ALL
-SELECT k FROM tab4 GROUP BY k
-
-DROP VIEW IF EXISTS tab1
-^^^
+3
 
 
 -- !query 21
-DROP VIEW IF EXISTS tab2
+DROP VIEW IF EXISTS tab1
 -- !query 21 schema
 struct<>
 -- !query 21 output
@@ -308,7 +296,7 @@ struct<>
 
 
 -- !query 22
-DROP VIEW IF EXISTS tab3
+DROP VIEW IF EXISTS tab2
 -- !query 22 schema
 struct<>
 -- !query 22 output
@@ -316,8 +304,16 @@ struct<>
 
 
 -- !query 23
-DROP VIEW IF EXISTS tab4
+DROP VIEW IF EXISTS tab3
 -- !query 23 schema
 struct<>
 -- !query 23 output
+
+
+-- !query 24
+DROP VIEW IF EXISTS tab4
+-- !query 24 schema
+struct<>
+-- !query 24 output
+
 
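The @note added to Dataset.exceptAll above calls out positional column resolution. A
minimal usage sketch of what that means in practice (assumes a local SparkSession; this
snippet is illustrative and not part of the patch):

import org.apache.spark.sql.SparkSession

object ExceptAllByPosition {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]").appName("except-all-demo").getOrCreate()
    import spark.implicits._

    val left = Seq((1, "a"), (1, "a"), (2, "b")).toDF("id", "tag")
    // Column names are swapped relative to `left`; only positions and types matter,
    // so column 1 of `right` is compared with column 1 of `left`.
    val right = Seq((1, "a")).toDF("tag", "id")

    // One (1, "a") row is removed; one (1, "a") and the (2, "b") row remain.
    left.exceptAll(right).show()

    spark.stop()
  }
}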
From 4e04883133b09189ecaba29cbe3919174da76ed8 Mon Sep 17 00:00:00 2001
From: Dilip Biswal
Date: Fri, 27 Jul 2018 01:29:33 -0700
Subject: [PATCH 12/12] rebase error

---
 .../sql/catalyst/analysis/TypeCoercion.scala | 30 ++++++++++---------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 520695b35acda..f9edca53d571e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -397,7 +397,7 @@ object TypeCoercion {
     }
 
     override protected def coerceTypes(
-        plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+        plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
 
@@ -459,7 +459,7 @@ object TypeCoercion {
     }
 
     override protected def coerceTypes(
-        plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+        plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
 
@@ -518,7 +518,7 @@ object TypeCoercion {
     private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE)
     private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO)
 
-    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
 
@@ -561,7 +561,7 @@ object TypeCoercion {
   object FunctionArgumentConversion extends TypeCoercionRule {
 
     override protected def coerceTypes(
-        plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+        plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
 
@@ -676,7 +676,7 @@ object TypeCoercion {
    */
   object Division extends TypeCoercionRule {
     override protected def coerceTypes(
-        plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+        plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who has not been resolved yet,
      // as this is an extra rule which should be applied at last.
      case e if !e.childrenResolved => e
 
@@ -699,7 +699,7 @@ object TypeCoercion {
    */
   object CaseWhenCoercion extends TypeCoercionRule {
     override protected def coerceTypes(
-        plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+        plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       case c: CaseWhen if c.childrenResolved && !haveSameType(c.inputTypesForMerging) =>
         val maybeCommonType = findWiderCommonType(c.inputTypesForMerging)
         maybeCommonType.map { commonType =>
@@ -717,7 +717,7 @@ object TypeCoercion {
    */
   object IfCoercion extends TypeCoercionRule {
     override protected def coerceTypes(
-        plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+        plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       case e if !e.childrenResolved => e
       // Find tightest common type for If, if the true value and false value have different types.
       case i @ If(pred, left, right) if !haveSameType(i.inputTypesForMerging) =>
@@ -737,7 +737,7 @@ object TypeCoercion {
    * Coerces NullTypes in the Stack expression to the column types of the corresponding positions.
    */
   object StackCoercion extends TypeCoercionRule {
-    override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       case s @ Stack(children) if s.childrenResolved && s.hasFoldableNumRows =>
         Stack(children.zipWithIndex.map {
           // The first child is the number of rows for stack.
@@ -757,7 +757,8 @@ object TypeCoercion {
    */
   case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule {
 
-    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p =>
+    override protected def coerceTypes(
+        plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case p =>
       p transformExpressionsUp {
         // Skip nodes if unresolved or empty children
         case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c
@@ -779,7 +780,8 @@ object TypeCoercion {
    */
   case class EltCoercion(conf: SQLConf) extends TypeCoercionRule {
 
-    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p =>
+    override protected def coerceTypes(
+        plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case p =>
       p transformExpressionsUp {
         // Skip nodes if unresolved or not enough children
         case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c
@@ -807,7 +809,7 @@ object TypeCoercion {
 
     private val acceptedTypes = Seq(DateType, TimestampType, StringType)
 
-    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
 
@@ -828,7 +830,7 @@ object TypeCoercion {
     private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING)
 
     override protected def coerceTypes(
-        plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+        plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who's children have not been resolved yet.
      case e if !e.childrenResolved => e
 
@@ -967,7 +969,7 @@ object TypeCoercion {
    */
   object WindowFrameCoercion extends TypeCoercionRule {
     override protected def coerceTypes(
-        plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+        plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       case s @ WindowSpecDefinition(_, Seq(order), SpecifiedWindowFrame(RangeFrame, lower, upper))
           if order.resolved =>
         s.copy(frameSpecification = SpecifiedWindowFrame(
@@ -1005,7 +1007,7 @@ trait TypeCoercionRule extends Rule[LogicalPlan] with Logging {
 
   protected def coerceTypes(plan: LogicalPlan): LogicalPlan
 
-  private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     // No propagation required for leaf nodes.
     case q: LogicalPlan if q.children.isEmpty => q
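Patch 12 consistently swaps transformAllExpressions/transform for
resolveExpressions/resolveOperatorsDown. A toy model of the behavioral difference
(plain Scala, not the Catalyst API; the real resolve* helpers key off the analyzer's
internal "analyzed" bookkeeping rather than a simple flag): resolve-style traversals
skip subtrees already marked resolved, so coercion rules do not revisit finished work.

object ResolveVsTransform {
  case class Node(name: String, resolved: Boolean, children: List[Node])

  // transform-style: rewrite every node unconditionally, bottom-up.
  def transformAll(n: Node)(f: Node => Node): Node =
    f(n.copy(children = n.children.map(transformAll(_)(f))))

  // resolve-style: leave subtrees that are already marked resolved untouched.
  def resolveOnly(n: Node)(f: Node => Node): Node =
    if (n.resolved) n
    else f(n.copy(children = n.children.map(resolveOnly(_)(f))))

  def main(args: Array[String]): Unit = {
    val tree = Node("root", resolved = false,
      List(Node("done", resolved = true, Nil), Node("todo", resolved = false, Nil)))
    val rename = (n: Node) => n.copy(name = n.name.toUpperCase)

    println(transformAll(tree)(rename)) // rewrites every node, including "done"
    println(resolveOnly(tree)(rename))  // leaves "done" as-is, rewrites "todo" and "root"
  }
}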