From f11984992bada2aaa3f21eeb57e4c8644fe7e470 Mon Sep 17 00:00:00 2001 From: Yadong Qi Date: Tue, 21 Apr 2015 21:03:40 +0800 Subject: [PATCH 1/5] order by aggregated function --- .../sql/catalyst/analysis/Analyzer.scala | 21 ++++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 52 +++++++++++++++++++ .../scala/org/apache/spark/sql/TestData.scala | 13 +++++ 3 files changed, 82 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index cb49e5ad5586f..6c85247cd1a0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ +import scala.collection.mutable.ArrayBuffer /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing @@ -355,7 +356,6 @@ class Analyzer( } case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved => - val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) // A small hack to create an object that will allow us to resolve any references that // refer to named expressions that are present in the grouping expressions. val groupingRelation = LocalRelation( @@ -364,11 +364,24 @@ class Analyzer( val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, a, groupingRelation) - if (missing.nonEmpty) { + val addForAlias = new ArrayBuffer[NamedExpression]() + val aliasedOrdering = resolvedOrdering.zipWithIndex.map { + case (o, i) => { + o transform { + case aggOrSub @ (_: AggregateExpression | _: Substring) => + val aliasName = aggOrSub.nodeName + i + val alias = Alias(aggOrSub, aliasName)() + addForAlias += alias + alias.toAttribute + } + }.asInstanceOf[SortOrder] + } + + if ((missing ++ addForAlias).nonEmpty) { // Add missing grouping exprs and then project them away after the sort. Project(a.output, - Sort(resolvedOrdering, global, - Aggregate(grouping, aggs ++ missing, child))) + Sort(aliasedOrdering, global, + Aggregate(grouping, aggs ++ missing ++ addForAlias, child))) } else { s // Nothing we can do here. Return original plan. } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 9e02e69fda3f2..d6ab6f1e926c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1252,4 +1252,56 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } + + test("SPARK-6583 order by aggregated function") { + checkAnswer( + sql( + """ + |SELECT a + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b) + """.stripMargin), + Row("4") :: Row("1") :: Row("3") :: Row("2") :: Nil) + + checkAnswer( + sql( + """ + |SELECT sum(b) + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b) + """.stripMargin), + Row(3) :: Row(7) :: Row(11) :: Row(15) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a, sum(b) + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b) + """.stripMargin), + Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a, sum(b) + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b) + 1 + """.stripMargin), + Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil) + + checkAnswer( + sql( + """ + |SELECT substr(a, 1, 1), sum(b) + |FROM orderByData + |GROUP BY substr(a, 1, 1) + |ORDER BY substr(a, 1, 1) + """.stripMargin), + Row("1", 7) :: Row("2", 15) :: Row("3", 11) :: Row("4", 3) :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 225b51bd73d6c..03d328c025ada 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -205,4 +205,17 @@ object TestData { :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2), false) :: Nil).toDF() complexData.registerTempTable("complexData") + + case class OrderByData(a: String, b: Int) + val orderByData = + TestSQLContext.sparkContext.parallelize( + OrderByData("1", 3) :: + OrderByData("1", 4) :: + OrderByData("2", 7) :: + OrderByData("2", 8) :: + OrderByData("3", 5) :: + OrderByData("3", 6) :: + OrderByData("4", 1) :: + OrderByData("4", 2) :: Nil).toDF() + orderByData.registerTempTable("orderByData") } From 7f9b7360a41fdd33388a9582dad9ad05d9dfba23 Mon Sep 17 00:00:00 2001 From: Yadong Qi Date: Tue, 9 Jun 2015 14:20:32 +0800 Subject: [PATCH 2/5] delete Substring case --- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 6 +++--- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 10 ---------- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8545532922ce6..854005e74cc16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -410,9 +410,9 @@ class Analyzer( val aliasedOrdering = resolvedOrdering.zipWithIndex.map { case (o, i) => { o transform { - case aggOrSub @ (_: AggregateExpression | _: Substring) => - val aliasName = aggOrSub.nodeName + i - val alias = Alias(aggOrSub, aliasName)() + case agg: AggregateExpression => + val aliasName = agg.nodeName + i + val alias = Alias(agg, aliasName)() addForAlias += alias alias.toAttribute } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 05ff9fe007f2a..c7f4a10adf590 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1372,16 +1372,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { |ORDER BY sum(b) + 1 """.stripMargin), Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil) - - checkAnswer( - sql( - """ - |SELECT substr(a, 1, 1), sum(b) - |FROM orderByData - |GROUP BY substr(a, 1, 1) - |ORDER BY substr(a, 1, 1) - """.stripMargin), - Row("1", 7) :: Row("2", 15) :: Row("3", 11) :: Row("4", 3) :: Nil) } test("SPARK-7952: fix the equality check between boolean and numeric types") { From c8b25c1f8583eb8b657addf424d017041e5d61ee Mon Sep 17 00:00:00 2001 From: Yadong Qi Date: Wed, 10 Jun 2015 17:22:17 +0800 Subject: [PATCH 3/5] move the test data. --- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 3 +++ .../test/scala/org/apache/spark/sql/TestData.scala | 13 ------------- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c7f4a10adf590..37331b37a47b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1333,6 +1333,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-6583 order by aggregated function") { + Seq("1" -> 3, "1" -> 4, "2" -> 7, "2" -> 8, "3" -> 5, "3" -> 6, "4" -> 1, "4" -> 2) + .toDF("a", "b").registerTempTable("orderByData") + checkAnswer( sql( """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 549404d03c809..725a18bfae3a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -203,17 +203,4 @@ object TestData { :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2), false) :: Nil).toDF() complexData.registerTempTable("complexData") - - case class OrderByData(a: String, b: Int) - val orderByData = - TestSQLContext.sparkContext.parallelize( - OrderByData("1", 3) :: - OrderByData("1", 4) :: - OrderByData("2", 7) :: - OrderByData("2", 8) :: - OrderByData("3", 5) :: - OrderByData("3", 6) :: - OrderByData("4", 1) :: - OrderByData("4", 2) :: Nil).toDF() - orderByData.registerTempTable("orderByData") } From eb8938db11287fb7a7581b5a02b5218b148a71fa Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 11 Jun 2015 21:27:27 -0700 Subject: [PATCH 4/5] no vars --- .../sql/catalyst/analysis/Analyzer.scala | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 854005e74cc16..9d9c8e91ca0fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -404,26 +404,26 @@ class Analyzer( grouping.collect { case ne: NamedExpression => ne.toAttribute } ) - val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, a, groupingRelation) - - val addForAlias = new ArrayBuffer[NamedExpression]() - val aliasedOrdering = resolvedOrdering.zipWithIndex.map { - case (o, i) => { - o transform { - case agg: AggregateExpression => - val aliasName = agg.nodeName + i - val alias = Alias(agg, aliasName)() - addForAlias += alias - alias.toAttribute - } - }.asInstanceOf[SortOrder] - } + // Find sort attributes that are projected away so we can temporarily add them back in. + val (resolvedOrdering, unresolved) = resolveAndFindMissing(ordering, a, groupingRelation) + + // Find aggregate expressions and evaluate them early, since they can't be evaluated in a + // Sort. + val (aliasedAggregateList, withAggsRemoved) = resolvedOrdering.map { + case aggOrdering if aggOrdering.collect { case a: AggregateExpression => a }.nonEmpty => + val aliased = Alias(aggOrdering.child, "_aggOrdering")() + (aliased :: Nil, aggOrdering.copy(child = aliased.toAttribute)) - if ((missing ++ addForAlias).nonEmpty) { + case other => (Nil, other) + }.unzip + + val missing = unresolved ++ aliasedAggregateList.flatten + + if (missing.nonEmpty) { // Add missing grouping exprs and then project them away after the sort. Project(a.output, - Sort(aliasedOrdering, global, - Aggregate(grouping, aggs ++ missing ++ addForAlias, child))) + Sort(withAggsRemoved, global, + Aggregate(grouping, aggs ++ missing, child))) } else { s // Nothing we can do here. Return original plan. } From 3226a97794d3b1e3d18ad5bdd717c2ef298025c4 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 11 Jun 2015 21:41:22 -0700 Subject: [PATCH 5/5] consistent ordering --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9d9c8e91ca0fd..2c1dff1442a26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -409,12 +409,12 @@ class Analyzer( // Find aggregate expressions and evaluate them early, since they can't be evaluated in a // Sort. - val (aliasedAggregateList, withAggsRemoved) = resolvedOrdering.map { + val (withAggsRemoved, aliasedAggregateList) = resolvedOrdering.map { case aggOrdering if aggOrdering.collect { case a: AggregateExpression => a }.nonEmpty => val aliased = Alias(aggOrdering.child, "_aggOrdering")() - (aliased :: Nil, aggOrdering.copy(child = aliased.toAttribute)) + (aggOrdering.copy(child = aliased.toAttribute), aliased :: Nil) - case other => (Nil, other) + case other => (other, Nil) }.unzip val missing = unresolved ++ aliasedAggregateList.flatten