diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index fbbc3ee891c6b..8054a948fbf63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1201,7 +1201,7 @@ class Analyzer( val withWindow = addWindow(windowExpressions, withFilter) // Finally, generate output columns according to the original projectList. - val finalProjectList = aggregateExprs.map (_.toAttribute) + val finalProjectList = aggregateExprs.map(_.toAttribute) Project(finalProjectList, withWindow) case p: LogicalPlan if !p.childrenResolved => p @@ -1217,7 +1217,7 @@ class Analyzer( val withWindow = addWindow(windowExpressions, withAggregate) // Finally, generate output columns according to the original projectList. - val finalProjectList = aggregateExprs.map (_.toAttribute) + val finalProjectList = aggregateExprs.map(_.toAttribute) Project(finalProjectList, withWindow) // We only extract Window Expressions after all expressions of the Project @@ -1232,7 +1232,7 @@ class Analyzer( val withWindow = addWindow(windowExpressions, withProject) // Finally, generate output columns according to the original projectList. - val finalProjectList = projectList.map (_.toAttribute) + val finalProjectList = projectList.map(_.toAttribute) Project(finalProjectList, withWindow) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index bd1d91487275b..b739361937b6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -57,7 +57,8 @@ case class SortOrder(child: Expression, direction: SortDirection) override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable - override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" + override def toString: String = s"$child ${direction.sql}" + override def sql: String = child.sql + " " + direction.sql def isAscending: Boolean = direction == Ascending } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 9e6bd0ee460f0..b8679474cf354 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.aggregate.{DeclarativeAggregate, NoOp} import org.apache.spark.sql.types._ @@ -30,6 +31,7 @@ sealed trait WindowSpec /** * The specification for a window function. + * * @param partitionSpec It defines the way that input rows are partitioned. * @param orderSpec It defines the ordering of rows in a partition. 
* @param frameSpecification It defines the window frame in a partition. @@ -75,6 +77,22 @@ case class WindowSpecDefinition( override def nullable: Boolean = true override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException + + override def sql: String = { + val partition = if (partitionSpec.isEmpty) { + "" + } else { + "PARTITION BY " + partitionSpec.map(_.sql).mkString(", ") + } + + val order = if (orderSpec.isEmpty) { + "" + } else { + "ORDER BY " + orderSpec.map(_.sql).mkString(", ") + } + + s"($partition $order ${frameSpecification.toString})" + } } /** @@ -278,6 +296,7 @@ case class WindowExpression( override def nullable: Boolean = windowFunction.nullable override def toString: String = s"$windowFunction $windowSpec" + override def sql: String = windowFunction.sql + " OVER " + windowSpec.sql } /** @@ -451,6 +470,7 @@ object SizeBasedWindowFunction { the window partition.""") case class RowNumber() extends RowNumberLike { override val evaluateExpression = rowNumber + override def sql: String = "ROW_NUMBER()" } /** @@ -470,6 +490,7 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction { // return the same value for equal values in the partition. override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), Cast(n, DoubleType)) + override def sql: String = "CUME_DIST()" } /** @@ -499,12 +520,25 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction { case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction { def this() = this(Literal(1)) + override def children: Seq[Expression] = Seq(buckets) + // Validate buckets. Note that this could be relaxed, the bucket value only needs to constant // for each partition. 
- buckets.eval() match { - case b: Int if b > 0 => // Ok - case x => throw new AnalysisException( - "Buckets expression must be a foldable positive integer expression: $x") + override def checkInputDataTypes(): TypeCheckResult = { + if (!buckets.foldable) { + return TypeCheckFailure(s"Buckets expression must be foldable, but got $buckets") + } + + if (buckets.dataType != IntegerType) { + return TypeCheckFailure(s"Buckets expression must be integer type, but got $buckets") + } + + val i = buckets.eval().asInstanceOf[Int] + if (i > 0) { + TypeCheckSuccess + } else { + TypeCheckFailure(s"Buckets expression must be positive, but got: $i") + } } private val bucket = AttributeReference("bucket", IntegerType, nullable = false)() @@ -608,6 +642,7 @@ abstract class RankLike extends AggregateWindowFunction { case class Rank(children: Seq[Expression]) extends RankLike { def this() = this(Nil) override def withOrder(order: Seq[Expression]): Rank = Rank(order) + override def sql: String = "RANK()" } /** @@ -632,6 +667,7 @@ case class DenseRank(children: Seq[Expression]) extends RankLike { override val updateExpressions = increaseRank +: children override val aggBufferAttributes = rank +: orderAttrs override val initialValues = zero +: orderInit + override def sql: String = "DENSE_RANK()" } /** @@ -658,4 +694,5 @@ case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBase override val evaluateExpression = If(GreaterThan(n, one), Divide(Cast(Subtract(rank, one), DoubleType), Cast(Subtract(n, one), DoubleType)), Literal(0.0d)) + override def sql: String = "PERCENT_RANK()" } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 683f738054c5a..bf12982da7cde 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -42,7 +42,7 @@ case class SubqueryHolder(query: String) extends LeafExpression with Unevaluable } /** - * A builder class used to convert a resolved logical plan into a SQL query string. Note that this + * A builder class used to convert a resolved logical plan into a SQL query string. Note that not * all resolved logical plan are convertible. They either don't have corresponding SQL * representations (e.g. logical plans that operate on local Scala collections), or are simply not * supported by this builder (yet). 
@@ -103,6 +103,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case p: Aggregate => aggregateToSQL(p) + case w: Window => + windowToSQL(w) + case Limit(limitExpr, child) => s"${toSQL(child)} LIMIT ${limitExpr.sql}" @@ -123,12 +126,12 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi build(toSQL(p.child), "TABLESAMPLE(" + fraction + " PERCENT)") } - case p: Filter => - val whereOrHaving = p.child match { + case Filter(condition, child) => + val whereOrHaving = child match { case _: Aggregate => "HAVING" case _ => "WHERE" } - build(toSQL(p.child), whereOrHaving, p.condition.sql) + build(toSQL(child), whereOrHaving, condition.sql) case p @ Distinct(u: Union) if u.children.length > 1 => val childrenSql = u.children.map(c => s"(${toSQL(c)})") @@ -179,7 +182,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi build( toSQL(p.child), if (p.global) "ORDER BY" else "SORT BY", - p.order.map { case SortOrder(e, dir) => s"${e.sql} ${dir.sql}" }.mkString(", ") + p.order.map(_.sql).mkString(", ") ) case p: RepartitionByExpression => @@ -268,9 +271,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr) } } - val groupingSetSQL = - "GROUPING SETS(" + - groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")" + val groupingSetSQL = "GROUPING SETS(" + + groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")" val aggExprs = agg.aggregateExpressions.map { case expr => expr.transformDown { @@ -297,22 +299,50 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi ) } + private def windowToSQL(w: Window): String = { + build( + "SELECT", + (w.child.output ++ w.windowExpressions).map(_.sql).mkString(", "), + if (w.child == OneRowRelation) "" else "FROM", + toSQL(w.child) + ) + } + object Canonicalizer extends RuleExecutor[LogicalPlan] { override protected def batches: Seq[Batch] = Seq( - Batch("Canonicalizer", FixedPoint(100), + Batch("Collapse Project", FixedPoint(100), // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over // `Aggregate`s to perform type casting. This rule merges these `Project`s into // `Aggregate`s. - CollapseProject, - + CollapseProject), + Batch("Recover Scoping Info", Once, // Used to handle other auxiliary `Project`s added by analyzer (e.g. // `ResolveAggregateFunctions` rule) - RecoverScopingInfo + AddSubquery, + // Previous rule will add extra sub-queries, this rule is used to re-propagate and update + // the qualifiers bottom up, e.g.: + // + // Sort + // ordering = t1.a + // Project + // projectList = [t1.a, t1.b] + // Subquery gen_subquery + // child ... + // + // will be transformed to: + // + // Sort + // ordering = gen_subquery.a + // Project + // projectList = [gen_subquery.a, gen_subquery.b] + // Subquery gen_subquery + // child ... + UpdateQualifiers ) ) - object RecoverScopingInfo extends Rule[LogicalPlan] { - override def apply(tree: LogicalPlan): LogicalPlan = tree transform { + object AddSubquery extends Rule[LogicalPlan] { + override def apply(tree: LogicalPlan): LogicalPlan = tree transformUp { // This branch handles aggregate functions within HAVING clauses. For example: // // SELECT key FROM src GROUP BY key HAVING max(value) > "val_255" @@ -324,8 +354,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi // +- Filter ... 
// +- Aggregate ... // +- MetastoreRelation default, src, None - case plan @ Project(_, Filter(_, _: Aggregate)) => - wrapChildWithSubquery(plan) + case plan @ Project(_, Filter(_, _: Aggregate)) => wrapChildWithSubquery(plan) + + case w @ Window(_, _, _, _, Filter(_, _: Aggregate)) => wrapChildWithSubquery(w) case plan @ Project(_, _: SubqueryAlias @@ -338,20 +369,39 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi | _: Sample ) => plan - case plan: Project => - wrapChildWithSubquery(plan) + case plan: Project => wrapChildWithSubquery(plan) + + // We will generate "SELECT ... FROM ..." for Window operator, so its child operator should + // be able to put in the FROM clause, or we wrap it with a subquery. + case w @ Window(_, _, _, _, + _: SubqueryAlias + | _: Filter + | _: Join + | _: MetastoreRelation + | OneRowRelation + | _: LocalLimit + | _: GlobalLimit + | _: Sample + ) => w + + case w: Window => wrapChildWithSubquery(w) } - def wrapChildWithSubquery(project: Project): Project = project match { - case Project(projectList, child) => - val alias = SQLBuilder.newSubqueryName - val childAttributes = child.outputSet - val aliasedProjectList = projectList.map(_.transform { - case a: Attribute if childAttributes.contains(a) => - a.withQualifiers(alias :: Nil) - }.asInstanceOf[NamedExpression]) + private def wrapChildWithSubquery(plan: UnaryNode): LogicalPlan = { + val newChild = SubqueryAlias(SQLBuilder.newSubqueryName, plan.child) + plan.withNewChildren(Seq(newChild)) + } + } - Project(aliasedProjectList, SubqueryAlias(alias, child)) + object UpdateQualifiers extends Rule[LogicalPlan] { + override def apply(tree: LogicalPlan): LogicalPlan = tree transformUp { + case plan => + val inputAttributes = plan.children.flatMap(_.output) + plan transformExpressions { + case a: AttributeReference if !plan.producedAttributes.contains(a) => + val qualifier = inputAttributes.find(_ semanticEquals a).map(_.qualifiers) + a.withQualifiers(qualifier.getOrElse(Nil)) + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index ed85856f017df..7fb35d3adf3a0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -67,7 +67,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { | |# Resolved query plan: |${df.queryExecution.analyzed.treeString} - """.stripMargin) + """.stripMargin, e) } try { @@ -84,8 +84,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { | |# Resolved query plan: |${df.queryExecution.analyzed.treeString} - """.stripMargin, - cause) + """.stripMargin, cause) } } @@ -332,6 +331,10 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkHiveQl("SELECT id FROM parquet_t0 DISTRIBUTE BY id SORT BY id") } + test("SPARK-13720: sort by after having") { + checkHiveQl("SELECT COUNT(value) FROM parquet_t1 GROUP BY key HAVING MAX(key) > 0 SORT BY key") + } + test("distinct aggregation") { checkHiveQl("SELECT COUNT(DISTINCT id) FROM parquet_t0") } @@ -445,4 +448,109 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { "f1", "b[0].f1", "f1", "c[foo]", "d[0]" ) } + + test("window basic") { + checkHiveQl("SELECT MAX(value) OVER (PARTITION BY key % 3) FROM parquet_t1") + checkHiveQl( + """ + |SELECT key, value, ROUND(AVG(key) OVER (), 2) + |FROM 
parquet_t1 ORDER BY key + """.stripMargin) + checkHiveQl( + """ + |SELECT value, MAX(key + 1) OVER (PARTITION BY key % 5 ORDER BY key % 7) AS max + |FROM parquet_t1 + """.stripMargin) + } + + test("multiple window functions in one expression") { + checkHiveQl( + """ + |SELECT + | MAX(key) OVER (ORDER BY key DESC, value) / MIN(key) OVER (PARTITION BY key % 3) + |FROM parquet_t1 + """.stripMargin) + } + + test("regular expressions and window functions in one expression") { + checkHiveQl("SELECT MAX(key) OVER (PARTITION BY key % 3) + key FROM parquet_t1") + } + + test("aggregate functions and window functions in one expression") { + checkHiveQl("SELECT MAX(c) + COUNT(a) OVER () FROM parquet_t2 GROUP BY a, b") + } + + test("window with different window specification") { + checkHiveQl( + """ + |SELECT key, value, + |DENSE_RANK() OVER (ORDER BY key, value) AS dr, + |MAX(value) OVER (PARTITION BY key ORDER BY key ASC) AS max + |FROM parquet_t1 + """.stripMargin) + } + + test("window with the same window specification with aggregate + having") { + checkHiveQl( + """ + |SELECT key, value, + |MAX(value) OVER (PARTITION BY key % 5 ORDER BY key DESC) AS max + |FROM parquet_t1 GROUP BY key, value HAVING key > 5 + """.stripMargin) + } + + test("window with the same window specification with aggregate functions") { + checkHiveQl( + """ + |SELECT key, value, + |MAX(value) OVER (PARTITION BY key % 5 ORDER BY key) AS max + |FROM parquet_t1 GROUP BY key, value + """.stripMargin) + } + + test("window with the same window specification with aggregate") { + checkHiveQl( + """ + |SELECT key, value, + |DENSE_RANK() OVER (DISTRIBUTE BY key SORT BY key, value) AS dr, + |COUNT(key) + |FROM parquet_t1 GROUP BY key, value + """.stripMargin) + } + + test("window with the same window specification without aggregate and filter") { + checkHiveQl( + """ + |SELECT key, value, + |DENSE_RANK() OVER (DISTRIBUTE BY key SORT BY key, value) AS dr, + |COUNT(key) OVER(DISTRIBUTE BY key SORT BY key, value) AS ca + |FROM parquet_t1 + """.stripMargin) + } + + test("window clause") { + checkHiveQl( + """ + |SELECT key, MAX(value) OVER w1 AS MAX, MIN(value) OVER w2 AS min + |FROM parquet_t1 + |WINDOW w1 AS (PARTITION BY key % 5 ORDER BY key), w2 AS (PARTITION BY key % 6) + """.stripMargin) + } + + test("special window functions") { + checkHiveQl( + """ + |SELECT + | RANK() OVER w, + | PERCENT_RANK() OVER w, + | DENSE_RANK() OVER w, + | ROW_NUMBER() OVER w, + | NTILE(10) OVER w, + | CUME_DIST() OVER w, + | LAG(key, 2) OVER w, + | LEAD(key, 2) OVER w + |FROM parquet_t1 + |WINDOW w AS (PARTITION BY key % 5 ORDER BY key) + """.stripMargin) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index afb816211eab0..1053246fc2958 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -449,6 +449,7 @@ abstract class HiveComparisonTest |Failed to execute query using catalyst: |Error: ${e.getMessage} |${stackTraceToString(e)} + |$queryString |$query |== HIVE - ${hive.size} row(s) == |${hive.mkString("\n")}
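
For reference, a minimal standalone sketch (not part of this patch) of how the new `sql` methods on `WindowSpecDefinition` and `WindowExpression` compose an OVER clause from the partition spec, order spec, and frame. The object and method names below are hypothetical helpers, not Catalyst APIs; string arguments stand in for the `_.sql` of the underlying expressions.

object WindowSqlSketch {
  // Mirrors WindowSpecDefinition.sql: optional PARTITION BY / ORDER BY, then the frame text.
  def renderWindowSpec(
      partitionSpec: Seq[String],
      orderSpec: Seq[String],
      frame: String): String = {
    val partition =
      if (partitionSpec.isEmpty) "" else "PARTITION BY " + partitionSpec.mkString(", ")
    val order =
      if (orderSpec.isEmpty) "" else "ORDER BY " + orderSpec.mkString(", ")
    s"($partition $order $frame)"
  }

  // Mirrors WindowExpression.sql: "<function> OVER <spec>".
  def renderWindowExpression(function: String, spec: String): String =
    function + " OVER " + spec

  def main(args: Array[String]): Unit = {
    val spec = renderWindowSpec(
      partitionSpec = Seq("(key % 5)"),
      orderSpec = Seq("key ASC"),
      frame = "RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW")
    // Prints:
    // MAX(value) OVER (PARTITION BY (key % 5) ORDER BY key ASC
    //   RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
    println(renderWindowExpression("MAX(value)", spec))
  }
}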
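
Similarly, a toy model (not the Catalyst rule) of what `UpdateQualifiers` does after `AddSubquery` has wrapped an operator in a generated subquery: attribute references in a parent node are re-qualified with whatever qualifier the matching child output carries, so `t1.a` becomes `gen_subquery_0.a` as in the comment above. `Attr` and `updateQualifiers` are hypothetical, and matching by name is a simplification of `semanticEquals`.

object QualifierSketch {
  final case class Attr(name: String, qualifier: Option[String]) {
    def sql: String = qualifier.map(q => s"$q.$name").getOrElse(name)
  }

  // Re-qualify the attributes used by a node against the attributes produced by its child.
  def updateQualifiers(used: Seq[Attr], childOutput: Seq[Attr]): Seq[Attr] =
    used.map { a =>
      childOutput.find(_.name == a.name) match {
        case Some(fromChild) => a.copy(qualifier = fromChild.qualifier)
        case None => a.copy(qualifier = None)
      }
    }

  def main(args: Array[String]): Unit = {
    // A Sort whose ordering referenced t1.a, now sitting on top of a generated subquery.
    val childOutput = Seq(Attr("a", Some("gen_subquery_0")), Attr("b", Some("gen_subquery_0")))
    val ordering = Seq(Attr("a", Some("t1")))
    // Prints: gen_subquery_0.a
    println(updateQualifiers(ordering, childOutput).map(_.sql).mkString(", "))
  }
}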