From 9d4ba8b26ab09f2d8b8fd01828550169297dcfb3 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 15 Jan 2020 21:49:22 +0900 Subject: [PATCH] Fix --- .../expressions/aggregate/interfaces.scala | 9 +-- .../sql-tests/results/group-by-filter.sql.out | 62 +++++++++---------- .../postgreSQL/aggregates_part3.sql.out | 4 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 33 +++++----- 4 files changed, 57 insertions(+), 51 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index aa61e371df1c7..801aacfffa740 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -137,10 +137,11 @@ case class AggregateExpression( @transient override lazy val references: AttributeSet = { - mode match { - case Partial | Complete => aggregateFunction.references ++ filterAttributes + val aggAttributes = mode match { + case Partial | Complete => aggregateFunction.references case PartialMerge | Final => AttributeSet(aggregateFunction.aggBufferAttributes) } + aggAttributes ++ filterAttributes } override def toString: String = { @@ -151,7 +152,7 @@ case class AggregateExpression( } val aggFuncStr = prefix + aggregateFunction.toAggString(isDistinct) filter match { - case Some(predicate) => s"$aggFuncStr filter $predicate" + case Some(predicate) => s"$aggFuncStr filter (where $predicate)" case _ => aggFuncStr } } @@ -159,7 +160,7 @@ case class AggregateExpression( override def sql: String = { val aggFuncStr = aggregateFunction.sql(isDistinct) filter match { - case Some(predicate) => s"$aggFuncStr FILTER ${predicate.sql}" + case Some(predicate) => s"$aggFuncStr FILTER (WHERE ${predicate.sql})" case _ => aggFuncStr } } diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out index 10aa4056f7729..4bfef4efbcf5b 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out @@ -51,13 +51,13 @@ SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData struct<> -- !query 3 output org.apache.spark.sql.AnalysisException -grouping expressions sequence is empty, and 'testdata.`a`' is not an aggregate function. Wrap '(count(testdata.`b`) FILTER (testdata.`a` >= 2) AS `count(b) FILTER (a >= 2)`)' in windowing function(s) or wrap 'testdata.`a`' in first() (or first_value) if you don't care which value you get.; +grouping expressions sequence is empty, and 'testdata.`a`' is not an aggregate function. Wrap '(count(testdata.`b`) FILTER (WHERE (testdata.`a` >= 2)) AS `count(b) FILTER (WHERE (a >= 2))`)' in windowing function(s) or wrap 'testdata.`a`' in first() (or first_value) if you don't care which value you get.; -- !query 4 SELECT COUNT(a) FILTER (WHERE a = 1), COUNT(b) FILTER (WHERE a > 1) FROM testData -- !query 4 schema -struct 1):bigint> +struct 1)):bigint> -- !query 4 output 2 4 @@ -65,7 +65,7 @@ struct 1):bigint> -- !query 5 SELECT COUNT(id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp -- !query 5 schema -struct +struct -- !query 5 output 2 @@ -73,7 +73,7 @@ struct -- !query 6 SELECT COUNT(id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp -- !query 6 schema -struct +struct -- !query 6 output 2 @@ -81,7 +81,7 @@ struct -- !query 7 SELECT COUNT(id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")) FROM emp -- !query 7 schema -struct +struct -- !query 7 output 2 @@ -89,7 +89,7 @@ struct +struct -- !query 8 output 2 @@ -97,7 +97,7 @@ struct= 2) FROM testData GROUP BY a -- !query 9 schema -struct= 2):bigint> +struct= 2)):bigint> -- !query 9 output 1 0 2 2 @@ -117,7 +117,7 @@ expression 'testdata.`a`' is neither present in the group by, nor is it an aggre -- !query 11 SELECT COUNT(a) FILTER (WHERE a >= 0), COUNT(b) FILTER (WHERE a >= 3) FROM testData GROUP BY a -- !query 11 schema -struct= 0):bigint,count(b) FILTER (a >= 3):bigint> +struct= 0)):bigint,count(b) FILTER (WHERE (a >= 3)):bigint> -- !query 11 output 0 0 2 0 @@ -128,7 +128,7 @@ struct= 0):bigint,count(b) FILTER (a >= 3):bigint> -- !query 12 SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > date "2003-01-01") FROM emp GROUP BY dept_id -- !query 12 schema -struct DATE '2003-01-01'):double> +struct DATE '2003-01-01')):double> -- !query 12 output 10 200.0 100 400.0 @@ -141,7 +141,7 @@ NULL NULL -- !query 13 SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_date("2003-01-01")) FROM emp GROUP BY dept_id -- !query 13 schema -struct to_date('2003-01-01')):double> +struct to_date('2003-01-01'))):double> -- !query 13 output 10 200.0 100 400.0 @@ -154,7 +154,7 @@ NULL NULL -- !query 14 SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_timestamp("2003-01-01 00:00:00")) FROM emp GROUP BY dept_id -- !query 14 schema -struct to_timestamp('2003-01-01 00:00:00')):double> +struct to_timestamp('2003-01-01 00:00:00'))):double> -- !query 14 output 10 200.0 100 400.0 @@ -167,7 +167,7 @@ NULL NULL -- !query 15 SELECT dept_id, SUM(salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2003-01-01") FROM emp GROUP BY dept_id -- !query 15 schema -struct 2003-01-01):double> +struct 2003-01-01)):double> -- !query 15 output 10 200.0 100 400.0 @@ -180,7 +180,7 @@ NULL NULL -- !query 16 SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1 -- !query 16 schema -struct +struct -- !query 16 output foo 6 @@ -188,7 +188,7 @@ foo 6 -- !query 17 SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= date "2003-01-01") FROM emp GROUP BY 1 -- !query 17 schema -struct= DATE '2003-01-01'):double> +struct= DATE '2003-01-01')):double> -- !query 17 output foo 1350.0 @@ -196,7 +196,7 @@ foo 1350.0 -- !query 18 SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_date("2003-01-01")) FROM emp GROUP BY 1 -- !query 18 schema -struct= to_date('2003-01-01')):double> +struct= to_date('2003-01-01'))):double> -- !query 18 output foo 1350.0 @@ -204,7 +204,7 @@ foo 1350.0 -- !query 19 SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_timestamp("2003-01-01")) FROM emp GROUP BY 1 -- !query 19 schema -struct= to_timestamp('2003-01-01')):double> +struct= to_timestamp('2003-01-01'))):double> -- !query 19 output foo 1350.0 @@ -212,7 +212,7 @@ foo 1350.0 -- !query 20 select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id -- !query 20 schema -struct 200):double> +struct 200)):double> -- !query 20 output 10 2 2 400.0 NULL 100 2 2 800.0 800.0 @@ -225,7 +225,7 @@ NULL 1 1 400.0 400.0 -- !query 21 select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id -- !query 21 schema -struct 500):double> +struct 500)):double> -- !query 21 output 10 2 2 400.0 NULL 100 2 2 800.0 800.0 @@ -238,7 +238,7 @@ NULL 1 1 400.0 NULL -- !query 22 select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id -- !query 22 schema -struct 200):double> +struct 200)):double> -- !query 22 output 10 2 2 400.0 NULL 100 2 2 NULL 800.0 @@ -251,7 +251,7 @@ NULL 1 1 NULL 400.0 -- !query 23 select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id -- !query 23 schema -struct 500):double> +struct 500)):double> -- !query 23 output 10 2 2 400.0 NULL 100 2 2 NULL 800.0 @@ -264,7 +264,7 @@ NULL 1 1 NULL NULL -- !query 24 SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1 -- !query 24 schema -struct= 0):bigint> +struct= 0)):bigint> -- !query 24 output @@ -272,7 +272,7 @@ struct= 0):bigint> -- !query 25 SELECT 'foo', MAX(STRUCT(a)) FILTER (WHERE b >= 1) FROM testData WHERE a = 0 GROUP BY 1 -- !query 25 schema -struct= 1):struct> +struct= 1)):struct> -- !query 25 output @@ -280,7 +280,7 @@ struct= 1):struct> -- !query 26 SELECT a + b, COUNT(b) FILTER (WHERE b >= 2) FROM testData GROUP BY a + b -- !query 26 schema -struct<(a + b):int,count(b) FILTER (b >= 2):bigint> +struct<(a + b):int,count(b) FILTER (WHERE (b >= 2)):bigint> -- !query 26 output 2 0 3 1 @@ -301,7 +301,7 @@ expression 'testdata.`a`' is neither present in the group by, nor is it an aggre -- !query 28 SELECT a + 1 + 1, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY a + 1 -- !query 28 schema -struct<((a + 1) + 1):int,count(b) FILTER (b > 0):bigint> +struct<((a + 1) + 1):int,count(b) FILTER (WHERE (b > 0)):bigint> -- !query 28 output 3 2 4 2 @@ -312,7 +312,7 @@ NULL 1 -- !query 29 SELECT a AS k, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY k -- !query 29 schema -struct 0):bigint> +struct 0)):bigint> -- !query 29 output 1 2 2 2 @@ -327,7 +327,7 @@ SELECT emp.dept_id, FROM emp GROUP BY dept_id -- !query 30 schema -struct scalarsubquery()):double> +struct scalarsubquery())):double> -- !query 30 output 10 133.33333333333334 NULL 100 400.0 400.0 @@ -344,7 +344,7 @@ SELECT emp.dept_id, FROM emp GROUP BY dept_id -- !query 31 schema -struct +struct -- !query 31 output 10 133.33333333333334 133.33333333333334 100 400.0 NULL @@ -366,7 +366,7 @@ GROUP BY dept_id struct<> -- !query 32 output org.apache.spark.sql.AnalysisException -IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) filter exists#x [dept_id#x] AS avg(salary) FILTER exists(dept_id)#x] +IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) filter (where exists#x [dept_id#x]) AS avg(salary) FILTER (WHERE exists(dept_id))#x] : +- Project [state#x] : +- Filter (dept_id#x = outer(dept_id#x)) : +- SubqueryAlias `dept` @@ -392,7 +392,7 @@ GROUP BY dept_id struct<> -- !query 33 output org.apache.spark.sql.AnalysisException -IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) filter NOT exists#x [dept_id#x] AS sum(salary) FILTER (NOT exists(dept_id))#x] +IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) filter (where NOT exists#x [dept_id#x]) AS sum(salary) FILTER (WHERE (NOT exists(dept_id)))#x] : +- Project [state#x] : +- Filter (dept_id#x = outer(dept_id#x)) : +- SubqueryAlias `dept` @@ -417,7 +417,7 @@ GROUP BY dept_id struct<> -- !query 34 output org.apache.spark.sql.AnalysisException -IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) filter dept_id#x IN (list#x []) AS avg(salary) FILTER (dept_id IN (listquery()))#x] +IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) filter (where dept_id#x IN (list#x [])) AS avg(salary) FILTER (WHERE (dept_id IN (listquery())))#x] : +- Distinct : +- Project [dept_id#x] : +- SubqueryAlias `dept` @@ -442,7 +442,7 @@ GROUP BY dept_id struct<> -- !query 35 output org.apache.spark.sql.AnalysisException -IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) filter NOT dept_id#x IN (list#x []) AS sum(salary) FILTER (NOT (dept_id IN (listquery())))#x] +IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) filter (where NOT dept_id#x IN (list#x [])) AS sum(salary) FILTER (WHERE (NOT (dept_id IN (listquery()))))#x] : +- Distinct : +- Project [dept_id#x] : +- SubqueryAlias `dept` diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out index a408dfd5db5cb..d2ab138efcdae 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out @@ -14,7 +14,7 @@ It is not allowed to use an aggregate function in the argument of another aggreg -- !query 1 select min(unique1) filter (where unique1 > 100) from tenk1 -- !query 1 schema -struct 100):int> +struct 100)):int> -- !query 1 output 101 @@ -22,7 +22,7 @@ struct 100):int> -- !query 2 select sum(1/ten) filter (where ten > 0) from tenk1 -- !query 2 schema -struct 0):double> +struct 0)):double> -- !query 2 output 2828.9682539682954 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index e472ceac7c1a6..59e34e23198de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -27,6 +27,7 @@ import scala.collection.parallel.immutable.ParVector import org.apache.spark.{AccumulatorSuite, SparkException} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial} import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.HiveResult.hiveResultString @@ -2843,16 +2844,18 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark val query = s"SELECT ${funcToResult._1} FILTER (WHERE b > 1) FROM testData2" val df = sql(query) val physical = df.queryExecution.sparkPlan - val aggregateExpressions = physical.collectFirst { + val aggregateExpressions = physical.collect { case agg: HashAggregateExec => agg.aggregateExpressions case agg: ObjectHashAggregateExec => agg.aggregateExpressions + }.flatten + aggregateExpressions.foreach { expr => + if (expr.mode == Complete || expr.mode == Partial) { + assert(expr.filter.isDefined) + } else { + assert(expr.filter.isEmpty) + } } - assert(aggregateExpressions.isDefined) - assert(aggregateExpressions.get.size == 1) - aggregateExpressions.get.foreach { expr => - assert(expr.filter.isDefined) - } - checkAnswer(df, Row(funcToResult._2) :: Nil) + checkAnswer(df, Row(funcToResult._2)) } } @@ -2860,15 +2863,17 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") { val df = sql("SELECT PERCENTILE(a, 1) FILTER (WHERE b > 1) FROM testData2") val physical = df.queryExecution.sparkPlan - val aggregateExpressions = physical.collectFirst { + val aggregateExpressions = physical.collect { case agg: SortAggregateExec => agg.aggregateExpressions + }.flatten + aggregateExpressions.foreach { expr => + if (expr.mode == Complete || expr.mode == Partial) { + assert(expr.filter.isDefined) + } else { + assert(expr.filter.isEmpty) + } } - assert(aggregateExpressions.isDefined) - assert(aggregateExpressions.get.size == 1) - aggregateExpressions.get.foreach { expr => - assert(expr.filter.isDefined) - } - checkAnswer(df, Row(3) :: Nil) + checkAnswer(df, Row(3)) } }