From 3ce072b4682a362d578a01181e3b8699cc38de93 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Mon, 7 Mar 2016 15:07:26 +0800
Subject: [PATCH 1/9] SQL generation support for window functions

---
 .../sql/catalyst/analysis/Analyzer.scala      |  2 +-
 .../sql/catalyst/expressions/SortOrder.scala  |  3 +-
 .../expressions/windowExpressions.scala       | 17 +++++++++
 .../apache/spark/sql/hive/SQLBuilder.scala    | 37 ++++++++++++++++++-
 .../sql/hive/LogicalPlanToSQLSuite.scala      | 15 ++++++++
 5 files changed, 70 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index fbbc3ee891c6b..21b88d164a09c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1180,7 +1180,7 @@ class Analyzer(

       // Finally, we create a Project to output currentChild's output
       // newExpressionsWithWindowFunctions.
-      Project(currentChild.output ++ newExpressionsWithWindowFunctions, currentChild)
+      Project(child.output ++ newExpressionsWithWindowFunctions, currentChild)
     } // end of addWindow

     // We have to use transformDown at here to make sure the rule of
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index bd1d91487275b..b739361937b6b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -57,7 +57,8 @@ case class SortOrder(child: Expression, direction: SortDirection)
   override def dataType: DataType = child.dataType
   override def nullable: Boolean = child.nullable

-  override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}"
+  override def toString: String = s"$child ${direction.sql}"
+  override def sql: String = child.sql + " " + direction.sql

   def isAscending: Boolean = direction == Ascending
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index 9e6bd0ee460f0..ece1a97d47a7e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -75,6 +75,22 @@ case class WindowSpecDefinition(
   override def nullable: Boolean = true
   override def foldable: Boolean = false
   override def dataType: DataType = throw new UnsupportedOperationException
+
+  override def sql: String = {
+    val partition = if (partitionSpec.isEmpty) {
+      ""
+    } else {
+      "PARTITION BY " + partitionSpec.map(_.sql).mkString(", ")
+    }
+
+    val order = if (orderSpec.isEmpty) {
+      ""
+    } else {
+      "ORDER BY " + orderSpec.map(_.sql).mkString(", ")
+    }
+
+    s"($partition $order ${frameSpecification.toString})"
+  }
 }

 /**
@@ -278,6 +294,7 @@ case class WindowExpression(
   override def nullable: Boolean = windowFunction.nullable

   override def toString: String = s"$windowFunction $windowSpec"
+  override def sql: String = windowFunction.sql + " OVER " + windowSpec.sql
 }

 /**
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index 683f738054c5a..e170282d17e7d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -42,7 +42,7 @@ case class SubqueryHolder(query: String) extends LeafExpression with Unevaluable
 }

 /**
- * A builder class used to convert a resolved logical plan into a SQL query string. Note that this
+ * A builder class used to convert a resolved logical plan into a SQL query string. Note that not
  * all resolved logical plan are convertible. They either don't have corresponding SQL
  * representations (e.g. logical plans that operate on local Scala collections), or are simply not
  * supported by this builder (yet).
@@ -103,6 +103,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
       case p: Aggregate =>
         aggregateToSQL(p)

+      case w: Window =>
+        windowToSQL(w)
+
       case Limit(limitExpr, child) =>
         s"${toSQL(child)} LIMIT ${limitExpr.sql}"

@@ -179,7 +182,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
         build(
           toSQL(p.child),
           if (p.global) "ORDER BY" else "SORT BY",
-          p.order.map { case SortOrder(e, dir) => s"${e.sql} ${dir.sql}" }.mkString(", ")
+          p.order.map(_.sql).mkString(", ")
         )

       case p: RepartitionByExpression =>
@@ -297,6 +300,15 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
       )
     }

+  private def windowToSQL(w: Window): String = {
+    build(
+      "SELECT",
+      (w.child.output ++ w.windowExpressions).map(_.sql).mkString(", "),
+      if (w.child == OneRowRelation) "" else "FROM",
+      toSQL(w.child)
+    )
+  }
+
   object Canonicalizer extends RuleExecutor[LogicalPlan] {
     override protected def batches: Seq[Batch] = Seq(
       Batch("Canonicalizer", FixedPoint(100),
@@ -340,6 +352,27 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi

       case plan: Project =>
         wrapChildWithSubquery(plan)
+
+      case w @ Window(_, _, _, _,
+        _: SubqueryAlias
+          | _: Filter
+          | _: Join
+          | _: MetastoreRelation
+          | OneRowRelation
+          | _: LocalLimit
+          | _: GlobalLimit
+          | _: Sample
+      ) => w
+
+      case w: Window =>
+        val alias = SQLBuilder.newSubqueryName
+        val childAttributes = w.child.outputSet
+        val qualified = w.windowExpressions.map(_.transform {
+          case a: Attribute if childAttributes.contains(a) =>
+            a.withQualifiers(alias :: Nil)
+        }.asInstanceOf[NamedExpression])
+
+        w.copy(windowExpressions = qualified, child = SubqueryAlias(alias, w.child))
     }

     def wrapChildWithSubquery(project: Project): Project = project match {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
index ed85856f017df..57e18dc1c05f7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
@@ -445,4 +445,19 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
       "f1", "b[0].f1", "f1", "c[foo]", "d[0]"
     )
   }
+
+  test("window functions") {
+    checkHiveQl("SELECT a, SUM(b) OVER (PARTITION BY c ORDER BY d) AS sum FROM parquet_t2")
+    checkHiveQl(
+      """
+        |SELECT a + 1, SUM(b * 2) OVER (PARTITION BY c + d ORDER BY c - d) AS sum
+        |FROM parquet_t2
+      """.stripMargin)
+    checkHiveQl(
+      """
+        |SELECT a, SUM(b) OVER w1 AS sum, AVG(b) over w2 AS avg
+        |FROM parquet_t2
+        |WINDOW w1 AS (PARTITION BY c ORDER BY d), w2 AS (PARTITION BY d ORDER BY c)
+      """.stripMargin)
+  }
 }
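A note on the string layout produced by the new `WindowSpecDefinition.sql` above: the three fragments are joined with single spaces, so an empty PARTITION BY or ORDER BY clause leaves extra padding inside the parentheses. A minimal standalone sketch of that behavior (plain Scala; this helper is illustrative and not part of the patch):

    // Sketch of the window-spec SQL layout from PATCH 1. The string arguments
    // stand in for partitionSpec.map(_.sql), orderSpec.map(_.sql), and the frame.
    def windowSpecSql(partition: Seq[String], order: Seq[String], frame: String): String = {
      val p = if (partition.isEmpty) "" else "PARTITION BY " + partition.mkString(", ")
      val o = if (order.isEmpty) "" else "ORDER BY " + order.mkString(", ")
      s"($p $o $frame)"  // mirrors the patch: empty parts still contribute spaces
    }

    windowSpecSql(Seq("c"), Seq("d ASC"), "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW")
    // => "(PARTITION BY c ORDER BY d ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)"
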
From e037814575535a635938b164cf183c7e8a66ea0b Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Mon, 7 Mar 2016 17:47:57 +0800
Subject: [PATCH 2/9] fix qualifiers

---
 .../expressions/windowExpressions.scala       |   5 +
 .../apache/spark/sql/hive/SQLBuilder.scala    | 107 +++++++++++------
 .../sql/hive/LogicalPlanToSQLSuite.scala      |  84 ++++++++++++--
 3 files changed, 150 insertions(+), 46 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index ece1a97d47a7e..0c3dffc7d8ca4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -468,6 +468,7 @@ object SizeBasedWindowFunction {
           the window partition.""")
 case class RowNumber() extends RowNumberLike {
   override val evaluateExpression = rowNumber
+  override def sql: String = "ROW_NUMBER()"
 }

 /**
@@ -487,6 +488,7 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction {
   // return the same value for equal values in the partition.
   override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow)
   override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), Cast(n, DoubleType))
+  override def sql: String = "CUME_DIST()"
 }

 /**
@@ -625,6 +627,7 @@ abstract class RankLike extends AggregateWindowFunction {
 case class Rank(children: Seq[Expression]) extends RankLike {
   def this() = this(Nil)
   override def withOrder(order: Seq[Expression]): Rank = Rank(order)
+  override def sql: String = "RANK()"
 }

 /**
@@ -649,6 +652,7 @@ case class DenseRank(children: Seq[Expression]) extends RankLike {
   override val updateExpressions = increaseRank +: children
   override val aggBufferAttributes = rank +: orderAttrs
   override val initialValues = zero +: orderInit
+  override def sql: String = "DENSE_RANK()"
 }

 /**
@@ -675,4 +679,5 @@ case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBase
   override val evaluateExpression = If(GreaterThan(n, one),
     Divide(Cast(Subtract(rank, one), DoubleType), Cast(Subtract(n, one), DoubleType)),
     Literal(0.0d))
+  override def sql: String = "PERCENT_RANK()"
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index e170282d17e7d..edabd4fffd815 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive

 import java.util.concurrent.atomic.AtomicLong

+import scala.collection.mutable
 import scala.util.control.NonFatal

 import org.apache.spark.Logging
@@ -107,7 +108,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
         windowToSQL(w)

       case Limit(limitExpr, child) =>
-        s"${toSQL(child)} LIMIT ${limitExpr.sql}"
+        s"${toSQL(child)} LIMIT ${exprToSQL(limitExpr, child)}"

       case p: Sample if p.isTableSample =>
         val fraction = math.min(100, math.max(0, (p.upperBound - p.lowerBound) * 100))
@@ -126,12 +127,12 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
           build(toSQL(p.child), "TABLESAMPLE(" + fraction + " PERCENT)")
         }

-      case p: Filter =>
-        val whereOrHaving = p.child match {
+      case Filter(condition, child) =>
+        val whereOrHaving = child match {
           case _: Aggregate => "HAVING"
           case _ => "WHERE"
         }
-        build(toSQL(p.child), whereOrHaving, p.condition.sql)
+        build(toSQL(child), whereOrHaving, exprToSQL(condition, child))

       case p @ Distinct(u: Union) if u.children.length > 1 =>
         val childrenSql = u.children.map(c => s"(${toSQL(c)})")
@@ -166,7 +167,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
           p.joinType.sql,
           "JOIN",
           toSQL(p.right),
-          p.condition.map(" ON " + _.sql).getOrElse(""))
+          p.condition.map(" ON " + exprToSQL(_, p)).getOrElse(""))

       case p: MetastoreRelation =>
         build(
@@ -176,20 +177,20 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi

       case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _))
           if orders.map(_.child) == partitionExprs =>
-        build(toSQL(child), "CLUSTER BY", partitionExprs.map(_.sql).mkString(", "))
+        build(toSQL(child), "CLUSTER BY", exprsToSQL(partitionExprs, child))

       case p: Sort =>
         build(
           toSQL(p.child),
           if (p.global) "ORDER BY" else "SORT BY",
-          p.order.map(_.sql).mkString(", ")
+          exprsToSQL(p.order, p.child)
         )

       case p: RepartitionByExpression =>
         build(
           toSQL(p.child),
           "DISTRIBUTE BY",
-          p.partitionExpressions.map(_.sql).mkString(", ")
+          exprsToSQL(p.partitionExpressions, p.child)
         )

       case OneRowRelation =>
@@ -207,11 +208,55 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
   private def build(segments: String*): String =
     segments.map(_.trim).filter(_.nonEmpty).mkString(" ")

+  private def updateQualifier(
+      expr: Expression,
+      qualifiers: Seq[(String, AttributeSet)]): Expression = {
+    if (qualifiers.isEmpty) {
+      expr
+    } else {
+      expr transform {
+        case a: Attribute =>
+          val index = qualifiers.indexWhere {
+            case (_, inputAttributes) => inputAttributes.contains(a)
+          }
+          if (index == -1) {
+            a
+          } else {
+            a.withQualifiers(qualifiers(index)._1 :: Nil)
+          }
+      }
+    }
+  }
+
+  private def findQualifiers(input: LogicalPlan): Seq[(String, AttributeSet)] = {
+    val results = mutable.ArrayBuffer.empty[(String, AttributeSet)]
+    val nodes = mutable.Stack(input)
+
+    while (nodes.nonEmpty) {
+      val node = nodes.pop()
+      node match {
+        case SubqueryAlias(alias, child) => results += alias -> child.outputSet
+        case _ => node.children.foreach(nodes.push)
+      }
+    }
+
+    results.toSeq
+  }
+
+  private def exprToSQL(e: Expression, input: LogicalPlan): String = {
+    updateQualifier(e, findQualifiers(input)).sql
+  }
+
+  private def exprsToSQL(exprs: Seq[Expression], input: LogicalPlan): String = {
+    val qualifiers = findQualifiers(input)
+    exprs.map(updateQualifier(_, qualifiers).sql).mkString(", ")
+  }
+
   private def projectToSQL(plan: Project, isDistinct: Boolean): String = {
     build(
       "SELECT",
       if (isDistinct) "DISTINCT" else "",
-      plan.projectList.map(_.sql).mkString(", "),
+      exprsToSQL(plan.projectList, plan.child),
       if (plan.child == OneRowRelation) "" else "FROM",
       toSQL(plan.child)
     )
@@ -221,7 +266,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
     val groupingSQL = plan.groupingExpressions.map(_.sql).mkString(", ")
     build(
       "SELECT",
-      plan.aggregateExpressions.map(_.sql).mkString(", "),
+      exprsToSQL(plan.aggregateExpressions, plan.child),
       if (plan.child == OneRowRelation) "" else "FROM",
       toSQL(plan.child),
       if (groupingSQL.isEmpty) "" else "GROUP BY",
@@ -244,11 +289,12 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
       expand: Expand,
       project: Project): String = {
     assert(agg.groupingExpressions.length > 1)
+    val input = project.child

     // The last column of Expand is always grouping ID
     val gid = expand.output.last

-    val numOriginalOutput = project.child.output.length
+    val numOriginalOutput = input.output.length
     // Assumption: Aggregate's groupingExpressions is composed of
     // 1) the attributes of aliased group by expressions
     // 2) gid, which is always the last one
@@ -257,7 +303,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
     // 1) the original output (Project's child.output),
     // 2) the aliased group by expressions.
     val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child)
-    val groupingSQL = groupByExprs.map(_.sql).mkString(", ")
+    val groupingSQL = exprsToSQL(groupByExprs, input)

     // a map from group by attributes to the original group by expressions.
     val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs))
@@ -272,8 +318,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
       }
     }
     val groupingSetSQL =
-      "GROUPING SETS(" +
-        groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")"
+      "GROUPING SETS(" + groupingSet.map(e => "(" + exprsToSQL(e, input) + ")").mkString(", ") + ")"

     val aggExprs = agg.aggregateExpressions.map { case expr =>
       expr.transformDown {
@@ -291,9 +336,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi

     build(
       "SELECT",
-      aggExprs.map(_.sql).mkString(", "),
+      exprsToSQL(aggExprs, input),
       if (agg.child == OneRowRelation) "" else "FROM",
-      toSQL(project.child),
+      toSQL(input),
       "GROUP BY",
       groupingSQL,
       groupingSetSQL
@@ -303,7 +348,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
   private def windowToSQL(w: Window): String = {
     build(
       "SELECT",
-      (w.child.output ++ w.windowExpressions).map(_.sql).mkString(", "),
+      exprsToSQL(w.child.output ++ w.windowExpressions, w.child),
       if (w.child == OneRowRelation) "" else "FROM",
       toSQL(w.child)
     )
@@ -337,7 +382,10 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
       //            +- Aggregate ...
      //               +- MetastoreRelation default, src, None
       case plan @ Project(_, Filter(_, _: Aggregate)) =>
-        wrapChildWithSubquery(plan)
+        plan.copy(child = SubqueryAlias(SQLBuilder.newSubqueryName, plan.child))
+
+      case w @ Window(_, _, _, _, Filter(_, _: Aggregate)) =>
+        w.copy(child = SubqueryAlias(SQLBuilder.newSubqueryName, w.child))

       case plan @ Project(_,
         _: SubqueryAlias
@@ -351,7 +399,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
       ) => plan

       case plan: Project =>
-        wrapChildWithSubquery(plan)
+        plan.copy(child = SubqueryAlias(SQLBuilder.newSubqueryName, plan.child))

       case w @ Window(_, _, _, _,
         _: SubqueryAlias
@@ -365,26 +413,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
       ) => w

       case w: Window =>
-        val alias = SQLBuilder.newSubqueryName
-        val childAttributes = w.child.outputSet
-        val qualified = w.windowExpressions.map(_.transform {
-          case a: Attribute if childAttributes.contains(a) =>
-            a.withQualifiers(alias :: Nil)
-        }.asInstanceOf[NamedExpression])
-
-        w.copy(windowExpressions = qualified, child = SubqueryAlias(alias, w.child))
-    }
-
-    def wrapChildWithSubquery(project: Project): Project = project match {
-      case Project(projectList, child) =>
-        val alias = SQLBuilder.newSubqueryName
-        val childAttributes = child.outputSet
-        val aliasedProjectList = projectList.map(_.transform {
-          case a: Attribute if childAttributes.contains(a) =>
-            a.withQualifiers(alias :: Nil)
-        }.asInstanceOf[NamedExpression])
-
-        Project(aliasedProjectList, SubqueryAlias(alias, child))
+        w.copy(child = SubqueryAlias(SQLBuilder.newSubqueryName, w.child))
     }
   }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
index 57e18dc1c05f7..37923b28fb74e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
@@ -446,18 +446,88 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
     )
   }

-  test("window functions") {
-    checkHiveQl("SELECT a, SUM(b) OVER (PARTITION BY c ORDER BY d) AS sum FROM parquet_t2")
+  test("window basic") {
+    checkHiveQl("SELECT MAX(value) OVER (PARTITION BY key % 3) FROM parquet_t1")
     checkHiveQl(
       """
-        |SELECT a + 1, SUM(b * 2) OVER (PARTITION BY c + d ORDER BY c - d) AS sum
-        |FROM parquet_t2
+        |SELECT key, value, ROUND(AVG(key) OVER (), 2)
+        |FROM parquet_t1 ORDER BY key
       """.stripMargin)
     checkHiveQl(
       """
-        |SELECT a, SUM(b) OVER w1 AS sum, AVG(b) over w2 AS avg
-        |FROM parquet_t2
-        |WINDOW w1 AS (PARTITION BY c ORDER BY d), w2 AS (PARTITION BY d ORDER BY c)
+        |SELECT value, MAX(key + 1) OVER (PARTITION BY key % 5 ORDER BY key % 7) AS max
+        |FROM parquet_t1
       """.stripMargin)
   }
+
+  test("window with different window specification") {
+    checkHiveQl(
+      """
+        |SELECT key, value,
+        |DENSE_RANK() OVER (ORDER BY key, value) AS dr,
+        |MAX(value) OVER (PARTITION BY key ORDER BY key) AS max
+        |FROM parquet_t1
+      """.stripMargin)
+  }
+
+  test("window with the same window specification with aggregate + having") {
+    checkHiveQl(
+      """
+        |SELECT key, value,
+        |MAX(value) OVER (PARTITION BY key % 5 ORDER BY key) AS max
+        |FROM parquet_t1 GROUP BY key, value HAVING key > 5
+      """.stripMargin)
+  }
+
+  test("window with the same window specification with aggregate functions") {
+    checkHiveQl(
+      """
+        |SELECT key, value,
+        |MAX(value) OVER (PARTITION BY key % 5 ORDER BY key) AS max
+        |FROM parquet_t1 GROUP BY key, value
+      """.stripMargin)
+  }
+
+  test("window with the same window specification with aggregate") {
+    checkHiveQl(
+      """
+        |SELECT key, value,
+        |DENSE_RANK() OVER (DISTRIBUTE BY key SORT BY key, value) AS dr,
+        |COUNT(key)
+        |FROM parquet_t1 GROUP BY key, value
+      """.stripMargin)
+  }
+
+  test("window with the same window specification without aggregate and filter") {
+    checkHiveQl(
+      """
+        |SELECT key, value,
+        |DENSE_RANK() OVER (DISTRIBUTE BY key SORT BY key, value) AS dr,
+        |COUNT(key) OVER(DISTRIBUTE BY key SORT BY key, value) AS ca
+        |FROM parquet_t1
+      """.stripMargin)
+  }
+
+  test("window clause") {
+    checkHiveQl(
+      """
+        |SELECT key, MAX(value) OVER w1 AS MAX, MIN(value) OVER w2 AS min
+        |FROM parquet_t1
+        |WINDOW w1 AS (PARTITION BY key % 5 ORDER BY key), w2 AS (PARTITION BY key % 6)
+      """.stripMargin)
+  }
+
+  test("special window functions") {
+    checkHiveQl(
+      """
+        |SELECT
+        | RANK() OVER w,
+        | PERCENT_RANK() OVER w,
+        | DENSE_RANK() OVER w,
+        | ROW_NUMBER() OVER w,
+        | CUME_DIST() OVER w
+        |FROM parquet_t1
+        |WINDOW w AS (PARTITION BY key % 5 ORDER BY key)
+      """.stripMargin)
+  }
 }
From 9a66fbb756d78c393d2493dea5a8194bae1d61b5 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Mon, 7 Mar 2016 20:24:01 +0800
Subject: [PATCH 3/9] add one more test

---
 .../org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
index 37923b28fb74e..bb5b9422f731e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
@@ -332,6 +332,10 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
     checkHiveQl("SELECT id FROM parquet_t0 DISTRIBUTE BY id SORT BY id")
   }

+  test("SPARK-13720: sort by after having") {
+    checkHiveQl("SELECT COUNT(value) FROM parquet_t1 GROUP BY key HAVING MAX(key) > 0 SORT BY key")
+  }
+
   test("distinct aggregation") {
     checkHiveQl("SELECT COUNT(DISTINCT id) FROM parquet_t0")
   }
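For intuition about PATCH 2's `findQualifiers`: it walks the plan with an explicit stack and records each subquery alias it reaches, without descending past one, so only the outermost aliases on each branch are collected. A self-contained toy model of that traversal (stand-in types, not Spark's `LogicalPlan` classes; illustrative only):

    import scala.collection.mutable

    sealed trait Plan { def children: Seq[Plan] }
    case class Subquery(alias: String, child: Plan) extends Plan {
      def children: Seq[Plan] = Seq(child)
    }
    case class Op(children: Plan*) extends Plan

    // Collect the alias of every outermost Subquery node, stopping at each one.
    def findAliases(root: Plan): Seq[String] = {
      val results = mutable.ArrayBuffer.empty[String]
      val nodes = mutable.Stack[Plan](root)
      while (nodes.nonEmpty) {
        nodes.pop() match {
          case Subquery(alias, _) => results += alias          // do not recurse further
          case other              => other.children.foreach(nodes.push)
        }
      }
      results.toSeq
    }

    // findAliases(Op(Subquery("gen_subquery_0", Op()), Op())) == Seq("gen_subquery_0")
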
From 40bd17a3d35b017d9af240da8a40df7e2998f610 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Mon, 7 Mar 2016 21:06:46 +0800
Subject: [PATCH 4/9] more comments

---
 .../apache/spark/sql/hive/SQLBuilder.scala | 46 ++++++++++++------
 1 file changed, 30 insertions(+), 16 deletions(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index edabd4fffd815..f2c65e974986f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.hive

 import java.util.concurrent.atomic.AtomicLong

-import scala.collection.mutable
 import scala.util.control.NonFatal

 import org.apache.spark.Logging
@@ -208,6 +207,11 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
   private def build(segments: String*): String =
     segments.map(_.trim).filter(_.nonEmpty).mkString(" ")

+  /**
+   * Given a seq of qualifiers(names and their corresponding [[AttributeSet]]), transform the given
+   * expression tree, if an [[Attribute]] belongs to one of the [[AttributeSet]]s, update its
+   * qualifier with the corresponding name of the [[AttributeSet]].
+   */
   private def updateQualifier(
       expr: Expression,
       qualifiers: Seq[(String, AttributeSet)]): Expression = {
@@ -228,27 +232,37 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
     }
   }

-  private def findQualifiers(input: LogicalPlan): Seq[(String, AttributeSet)] = {
-    val results = mutable.ArrayBuffer.empty[(String, AttributeSet)]
-    val nodes = mutable.Stack(input)
-
-    while (nodes.nonEmpty) {
-      val node = nodes.pop()
-      node match {
-        case SubqueryAlias(alias, child) => results += alias -> child.outputSet
-        case _ => node.children.foreach(nodes.push)
-      }
-    }
-
-    results.toSeq
+  /**
+   * Finds the outer most [[SubqueryAlias]] nodes in the input logical plan and return their alias
+   * names and outputSet.
+   */
+  private def findOutermostQualifiers(input: LogicalPlan): Seq[(String, AttributeSet)] = {
+    input.collectFirst {
+      case SubqueryAlias(alias, child) => Seq(alias -> child.outputSet)
+      case plan => plan.children.flatMap(findOutermostQualifiers)
+    }.toSeq.flatten
   }

+  /**
+   * Converts an expression to SQL string.
+   *
+   * Note that we may add extra [[SubqueryAlias]]s to the logical plan, but the qualifiers haven't
+   * been propagated yet. So here we try to find the corrected qualifiers first, and then update
+   * the given expression with the qualifiers and finally convert it to SQL string.
+   */
   private def exprToSQL(e: Expression, input: LogicalPlan): String = {
-    updateQualifier(e, findQualifiers(input)).sql
+    updateQualifier(e, findOutermostQualifiers(input)).sql
   }

+  /**
+   * Converts a seq of expressions to SQL string.
+   *
+   * Note that we may add extra [[SubqueryAlias]]s to the logical plan, but the qualifiers haven't
+   * been propagated yet. So here we try to find the corrected qualifiers first, and then update
+   * the given expressions with the qualifiers and finally convert them to SQL string.
+   */
   private def exprsToSQL(exprs: Seq[Expression], input: LogicalPlan): String = {
-    val qualifiers = findQualifiers(input)
+    val qualifiers = findOutermostQualifiers(input)
     exprs.map(updateQualifier(_, qualifiers).sql).mkString(", ")
   }
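PATCH 4's new doc comment pins down the `updateQualifier` contract: the first `(alias, AttributeSet)` pair whose set contains an attribute supplies that attribute's new qualifier, and attributes matching no pair are left untouched. A toy model of just that lookup (stand-in types; illustrative only, not Spark code):

    case class Attr(name: String, qualifier: Option[String] = None)

    // First matching alias wins; unmatched attributes pass through unchanged.
    def updateQualifier(attr: Attr, qualifiers: Seq[(String, Set[String])]): Attr =
      qualifiers.find { case (_, attrs) => attrs.contains(attr.name) } match {
        case Some((alias, _)) => attr.copy(qualifier = Some(alias))
        case None             => attr
      }

    // updateQualifier(Attr("a"), Seq("gen_subquery_0" -> Set("a", "b")))
    //   == Attr("a", Some("gen_subquery_0"))
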
From 276a870dee9d150c35220c391d8d41acd463c314 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Mon, 7 Mar 2016 23:03:58 +0800
Subject: [PATCH 5/9] fix a bug?

---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 21b88d164a09c..8054a948fbf63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1180,7 +1180,7 @@ class Analyzer(

       // Finally, we create a Project to output currentChild's output
       // newExpressionsWithWindowFunctions.
-      Project(child.output ++ newExpressionsWithWindowFunctions, currentChild)
+      Project(currentChild.output ++ newExpressionsWithWindowFunctions, currentChild)
     } // end of addWindow

     // We have to use transformDown at here to make sure the rule of
@@ -1201,7 +1201,7 @@ class Analyzer(
         val withWindow = addWindow(windowExpressions, withFilter)

         // Finally, generate output columns according to the original projectList.
-        val finalProjectList = aggregateExprs.map (_.toAttribute)
+        val finalProjectList = aggregateExprs.map(_.toAttribute)
         Project(finalProjectList, withWindow)

       case p: LogicalPlan if !p.childrenResolved => p
@@ -1217,7 +1217,7 @@ class Analyzer(
         val withWindow = addWindow(windowExpressions, withAggregate)

         // Finally, generate output columns according to the original projectList.
-        val finalProjectList = aggregateExprs.map (_.toAttribute)
+        val finalProjectList = aggregateExprs.map(_.toAttribute)
         Project(finalProjectList, withWindow)

       // We only extract Window Expressions after all expressions of the Project
@@ -1232,7 +1232,7 @@ class Analyzer(
         val withWindow = addWindow(windowExpressions, withProject)

         // Finally, generate output columns according to the original projectList.
-        val finalProjectList = projectList.map (_.toAttribute)
+        val finalProjectList = projectList.map(_.toAttribute)
         Project(finalProjectList, withWindow)
     }
   }
") - /** - * Given a seq of qualifiers(names and their corresponding [[AttributeSet]]), transform the given - * expression tree, if an [[Attribute]] belongs to one of the [[AttributeSet]]s, update its - * qualifier with the corresponding name of the [[AttributeSet]]. - */ - private def updateQualifier( - expr: Expression, - qualifiers: Seq[(String, AttributeSet)]): Expression = { - if (qualifiers.isEmpty) { - expr - } else { - expr transform { - case a: Attribute => - val index = qualifiers.indexWhere { - case (_, inputAttributes) => inputAttributes.contains(a) - } - if (index == -1) { - a - } else { - a.withQualifiers(qualifiers(index)._1 :: Nil) - } - } - } - } - - /** - * Finds the outer most [[SubqueryAlias]] nodes in the input logical plan and return their alias - * names and outputSet. - */ - private def findOutermostQualifiers(input: LogicalPlan): Seq[(String, AttributeSet)] = { - input.collectFirst { - case SubqueryAlias(alias, child) => Seq(alias -> child.outputSet) - case plan => plan.children.flatMap(findOutermostQualifiers) - }.toSeq.flatten - } - - /** - * Converts an expression to SQL string. - * - * Note that we may add extra [[SubqueryAlias]]s to the logical plan, but the qualifiers haven't - * been propagated yet. So here we try to find the corrected qualifiers first, and then update - * the given expression with the qualifiers and finally convert it to SQL string. - */ - private def exprToSQL(e: Expression, input: LogicalPlan): String = { - updateQualifier(e, findOutermostQualifiers(input)).sql - } - - /** - * Converts a seq of expressions to SQL string. - * - * Note that we may add extra [[SubqueryAlias]]s to the logical plan, but the qualifiers haven't - * been propagated yet. So here we try to find the corrected qualifiers first, and then update - * the given expressions with the qualifiers and finally convert them to SQL string. - */ - private def exprsToSQL(exprs: Seq[Expression], input: LogicalPlan): String = { - val qualifiers = findOutermostQualifiers(input) - exprs.map(updateQualifier(_, qualifiers).sql).mkString(", ") - } - private def projectToSQL(plan: Project, isDistinct: Boolean): String = { build( "SELECT", if (isDistinct) "DISTINCT" else "", - exprsToSQL(plan.projectList, plan.child), + plan.projectList.map(_.sql).mkString(", "), if (plan.child == OneRowRelation) "" else "FROM", toSQL(plan.child) ) @@ -280,7 +221,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi val groupingSQL = plan.groupingExpressions.map(_.sql).mkString(", ") build( "SELECT", - exprsToSQL(plan.aggregateExpressions, plan.child), + plan.aggregateExpressions.map(_.sql).mkString(", "), if (plan.child == OneRowRelation) "" else "FROM", toSQL(plan.child), if (groupingSQL.isEmpty) "" else "GROUP BY", @@ -303,12 +244,11 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi expand: Expand, project: Project): String = { assert(agg.groupingExpressions.length > 1) - val input = project.child // The last column of Expand is always grouping ID val gid = expand.output.last - val numOriginalOutput = input.output.length + val numOriginalOutput = project.child.output.length // Assumption: Aggregate's groupingExpressions is composed of // 1) the attributes of aliased group by expressions // 2) gid, which is always the last one @@ -317,7 +257,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi // 1) the original output (Project's child.output), // 2) the aliased group by expressions. 
val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child) - val groupingSQL = exprsToSQL(groupByExprs, input) + val groupingSQL = groupByExprs.map(_.sql).mkString(", ") // a map from group by attributes to the original group by expressions. val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs)) @@ -331,8 +271,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr) } } - val groupingSetSQL = - "GROUPING SETS(" + groupingSet.map(e => "(" + exprsToSQL(e, input) + ")").mkString(", ") + ")" + val groupingSetSQL = "GROUPING SETS(" + + groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")" val aggExprs = agg.aggregateExpressions.map { case expr => expr.transformDown { @@ -350,9 +290,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi build( "SELECT", - exprsToSQL(aggExprs, input), + aggExprs.map(_.sql).mkString(", "), if (agg.child == OneRowRelation) "" else "FROM", - toSQL(input), + toSQL(project.child), "GROUP BY", groupingSQL, groupingSetSQL @@ -362,7 +302,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi private def windowToSQL(w: Window): String = { build( "SELECT", - exprsToSQL(w.child.output ++ w.windowExpressions, w.child), + (w.child.output ++ w.windowExpressions).map(_.sql).mkString(", "), if (w.child == OneRowRelation) "" else "FROM", toSQL(w.child) ) @@ -370,20 +310,23 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi object Canonicalizer extends RuleExecutor[LogicalPlan] { override protected def batches: Seq[Batch] = Seq( - Batch("Canonicalizer", FixedPoint(100), + Batch("Collapse Project", FixedPoint(100), // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over // `Aggregate`s to perform type casting. This rule merges these `Project`s into // `Aggregate`s. - CollapseProject, - + CollapseProject), + Batch("Recover Scoping Info", Once, // Used to handle other auxiliary `Project`s added by analyzer (e.g. // `ResolveAggregateFunctions` rule) - RecoverScopingInfo + AddSubquery, + // Previous rule will add extra sub-queries, this rule is used to re-propagate and update + // the qualifiers bottom up. + UpdateQualifiers ) ) - object RecoverScopingInfo extends Rule[LogicalPlan] { - override def apply(tree: LogicalPlan): LogicalPlan = tree transform { + object AddSubquery extends Rule[LogicalPlan] { + override def apply(tree: LogicalPlan): LogicalPlan = tree transformUp { // This branch handles aggregate functions within HAVING clauses. For example: // // SELECT key FROM src GROUP BY key HAVING max(value) > "val_255" @@ -395,11 +338,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi // +- Filter ... // +- Aggregate ... 
// +- MetastoreRelation default, src, None - case plan @ Project(_, Filter(_, _: Aggregate)) => - plan.copy(child = SubqueryAlias(SQLBuilder.newSubqueryName, plan.child)) + case plan @ Project(_, Filter(_, _: Aggregate)) => wrapChildWithSubquery(plan) - case w @ Window(_, _, _, _, Filter(_, _: Aggregate)) => - w.copy(child = SubqueryAlias(SQLBuilder.newSubqueryName, w.child)) + case w @ Window(_, _, _, _, Filter(_, _: Aggregate)) => wrapChildWithSubquery(w) case plan @ Project(_, _: SubqueryAlias @@ -412,8 +353,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi | _: Sample ) => plan - case plan: Project => - plan.copy(child = SubqueryAlias(SQLBuilder.newSubqueryName, plan.child)) + case plan: Project => wrapChildWithSubquery(plan) case w @ Window(_, _, _, _, _: SubqueryAlias @@ -426,8 +366,24 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi | _: Sample ) => w - case w: Window => - w.copy(child = SubqueryAlias(SQLBuilder.newSubqueryName, w.child)) + case w: Window => wrapChildWithSubquery(w) + } + + private def wrapChildWithSubquery(plan: UnaryNode): LogicalPlan = { + val newChild = SubqueryAlias(SQLBuilder.newSubqueryName, plan.child) + plan.withNewChildren(Seq(newChild)) + } + } + + object UpdateQualifiers extends Rule[LogicalPlan] { + override def apply(tree: LogicalPlan): LogicalPlan = tree transformUp { + case plan => + val inputAttributes = plan.children.flatMap(_.output) + plan transformExpressions { + case a: AttributeReference if !plan.producedAttributes.contains(a) => + val qualifier = inputAttributes.find(_ semanticEquals a).map(_.qualifiers) + a.withQualifiers(qualifier.getOrElse(Nil)) + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index bb5b9422f731e..b0242b2b48310 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -67,7 +67,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { | |# Resolved query plan: |${df.queryExecution.analyzed.treeString} - """.stripMargin) + """.stripMargin, e) } try { @@ -84,8 +84,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { | |# Resolved query plan: |${df.queryExecution.analyzed.treeString} - """.stripMargin, - cause) + """.stripMargin, cause) } } From c82229a42efec9131652435b9543df81d1feab6c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 8 Mar 2016 10:51:54 +0800 Subject: [PATCH 7/9] add log --- .../org/apache/spark/sql/hive/execution/HiveComparisonTest.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index afb816211eab0..1053246fc2958 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -449,6 +449,7 @@ abstract class HiveComparisonTest |Failed to execute query using catalyst: |Error: ${e.getMessage} |${stackTraceToString(e)} + |$queryString |$query |== HIVE - ${hive.size} row(s) == |${hive.mkString("\n")} From 054f50a8661d0d2a20b2924da3815fd13f29568a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 8 Mar 2016 16:22:29 +0800 Subject: [PATCH 8/9] fix 
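The `UpdateQualifiers` rule introduced in PATCH 6 walks the plan bottom up and re-derives each attribute's qualifier from whatever the child operators actually output, which is what fixes references after `AddSubquery` has wrapped children in generated aliases. A toy model of that re-derivation (stand-in types; the real rule uses `transformUp`, `semanticEquals`, and `producedAttributes`):

    case class Col(name: String, qualifier: Option[String])

    // For each referenced column, adopt the qualifier of the matching column in
    // the child output; if nothing matches, drop the stale qualifier entirely
    // (mirroring `qualifier.getOrElse(Nil)` in the rule).
    def requalify(exprCols: Seq[Col], childOutput: Seq[Col]): Seq[Col] =
      exprCols.map { c =>
        childOutput.find(_.name == c.name) match {
          case Some(fromChild) => c.copy(qualifier = fromChild.qualifier)
          case None            => c.copy(qualifier = None)
        }
      }

    // requalify(Seq(Col("a", Some("t1"))), Seq(Col("a", Some("gen_subquery_0"))))
    //   == Seq(Col("a", Some("gen_subquery_0")))
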
From 054f50a8661d0d2a20b2924da3815fd13f29568a Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 8 Mar 2016 16:22:29 +0800
Subject: [PATCH 8/9] fix a bug

---
 .../expressions/windowExpressions.scala | 25 +++++++++++++++----
 1 file changed, 20 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index 0c3dffc7d8ca4..b8679474cf354 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -18,7 +18,8 @@
 package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.UnresolvedException
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{DeclarativeAggregate, NoOp}
 import org.apache.spark.sql.types._

@@ -30,6 +31,7 @@ sealed trait WindowSpec

 /**
  * The specification for a window function.
+ *
  * @param partitionSpec It defines the way that input rows are partitioned.
  * @param orderSpec It defines the ordering of rows in a partition.
  * @param frameSpecification It defines the window frame in a partition.
@@ -518,12 +520,25 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction {
 case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction {
   def this() = this(Literal(1))

+  override def children: Seq[Expression] = Seq(buckets)
+
   // Validate buckets. Note that this could be relaxed, the bucket value only needs to constant
   // for each partition.
-  buckets.eval() match {
-    case b: Int if b > 0 => // Ok
-    case x => throw new AnalysisException(
-      "Buckets expression must be a foldable positive integer expression: $x")
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (!buckets.foldable) {
+      return TypeCheckFailure(s"Buckets expression must be foldable, but got $buckets")
+    }
+
+    if (buckets.dataType != IntegerType) {
+      return TypeCheckFailure(s"Buckets expression must be integer type, but got $buckets")
+    }
+
+    val i = buckets.eval().asInstanceOf[Int]
+    if (i > 0) {
+      TypeCheckSuccess
+    } else {
+      TypeCheckFailure(s"Buckets expression must be positive, but got: $i")
+    }
   }

   private val bucket = AttributeReference("bucket", IntegerType, nullable = false)()
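Worth noting about PATCH 8: the old NTile code validated in the constructor, calling `buckets.eval()` before anything guaranteed the expression was foldable or an integer, and the `AnalysisException` message even lacked its `s` interpolator, so it printed a literal `$x`. The rewrite defers validation to `checkInputDataTypes` and orders the checks so evaluation happens last. A compact model of that ordering (using `Either` as a stand-in for Spark's `TypeCheckResult`; illustrative only):

    // The by-name `value` parameter delays evaluation until the foldability and
    // type checks have both passed, mirroring the rewritten NTile validation.
    def checkBuckets(foldable: Boolean, isIntegerType: Boolean, value: => Int): Either[String, Unit] =
      if (!foldable) Left("Buckets expression must be foldable")
      else if (!isIntegerType) Left("Buckets expression must be integer type")
      else if (value > 0) Right(())
      else Left(s"Buckets expression must be positive, but got: $value")
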
From dab7a2f1a5cc0438405b0fa1cf532ab883bed7e7 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 9 Mar 2016 11:16:35 +0800
Subject: [PATCH 9/9] address comments

---
 .../apache/spark/sql/hive/SQLBuilder.scala    | 20 +++++++++++++-
 .../sql/hive/LogicalPlanToSQLSuite.scala      | 26 ++++++++++++++++---
 2 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index 79c19ac8e0994..bf12982da7cde 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -320,7 +320,23 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
         // `ResolveAggregateFunctions` rule)
         AddSubquery,
         // Previous rule will add extra sub-queries, this rule is used to re-propagate and update
-        // the qualifiers bottom up.
+        // the qualifiers bottom up, e.g.:
+        //
+        // Sort
+        //   ordering = t1.a
+        //   Project
+        //     projectList = [t1.a, t1.b]
+        //     Subquery gen_subquery
+        //       child ...
+        //
+        // will be transformed to:
+        //
+        // Sort
+        //   ordering = gen_subquery.a
+        //   Project
+        //     projectList = [gen_subquery.a, gen_subquery.b]
+        //     Subquery gen_subquery
+        //       child ...
         UpdateQualifiers
       )
     )
@@ -355,6 +371,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi

       case plan: Project => wrapChildWithSubquery(plan)

+      // We will generate "SELECT ... FROM ..." for Window operator, so its child operator should
+      // be able to put in the FROM clause, or we wrap it with a subquery.
       case w @ Window(_, _, _, _,
         _: SubqueryAlias
           | _: Filter
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
index b0242b2b48310..7fb35d3adf3a0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
@@ -463,12 +463,29 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
       """.stripMargin)
   }

+  test("multiple window functions in one expression") {
+    checkHiveQl(
+      """
+        |SELECT
+        |  MAX(key) OVER (ORDER BY key DESC, value) / MIN(key) OVER (PARTITION BY key % 3)
+        |FROM parquet_t1
+      """.stripMargin)
+  }
+
+  test("regular expressions and window functions in one expression") {
+    checkHiveQl("SELECT MAX(key) OVER (PARTITION BY key % 3) + key FROM parquet_t1")
+  }
+
+  test("aggregate functions and window functions in one expression") {
+    checkHiveQl("SELECT MAX(c) + COUNT(a) OVER () FROM parquet_t2 GROUP BY a, b")
+  }
+
   test("window with different window specification") {
     checkHiveQl(
       """
         |SELECT key, value,
         |DENSE_RANK() OVER (ORDER BY key, value) AS dr,
-        |MAX(value) OVER (PARTITION BY key ORDER BY key) AS max
+        |MAX(value) OVER (PARTITION BY key ORDER BY key ASC) AS max
         |FROM parquet_t1
       """.stripMargin)
   }
@@ -477,7 +494,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
     checkHiveQl(
       """
         |SELECT key, value,
-        |MAX(value) OVER (PARTITION BY key % 5 ORDER BY key) AS max
+        |MAX(value) OVER (PARTITION BY key % 5 ORDER BY key DESC) AS max
         |FROM parquet_t1 GROUP BY key, value HAVING key > 5
       """.stripMargin)
   }
@@ -528,7 +545,10 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
       | PERCENT_RANK() OVER w,
       | DENSE_RANK() OVER w,
       | ROW_NUMBER() OVER w,
-      | CUME_DIST() OVER w
+      | NTILE(10) OVER w,
+      | CUME_DIST() OVER w,
+      | LAG(key, 2) OVER w,
+      | LEAD(key, 2) OVER w
       |FROM parquet_t1
       |WINDOW w AS (PARTITION BY key % 5 ORDER BY key)
      """.stripMargin)
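A closing note on what these tests exercise. Because the analyzer inlines named WINDOW definitions and `WindowExpression.sql` (PATCH 1) always emits the full specification, the regenerated SQL for the "window clause" tests spells out each OVER clause rather than reproducing `OVER w`. The round trip itself is driven by `checkHiveQl`: parse the given SQL, regenerate SQL from the analyzed plan with `SQLBuilder`, and check that both queries return the same rows. A hedged sketch of that loop (the helper name comes from the suite; the body and the exact `SQLBuilder` API shown below are assumptions made for illustration, not code copied from Spark):

    // Assumed shape of the suite's round-trip helper; real signatures may differ.
    def checkHiveQlSketch(hiveQl: String): Unit = {
      val df = sqlContext.sql(hiveQl)                 // parse + analyze the original query
      val regenerated = new SQLBuilder(df.queryExecution.analyzed, sqlContext).toSQL
      checkAnswer(sqlContext.sql(regenerated), df.collect())  // same rows both ways
    }
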