Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
maropu committed Jan 15, 2020
1 parent 753e9e1 commit 9d4ba8b
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 51 deletions.
Expand Up @@ -137,10 +137,11 @@ case class AggregateExpression(

@transient
override lazy val references: AttributeSet = {
mode match {
case Partial | Complete => aggregateFunction.references ++ filterAttributes
val aggAttributes = mode match {
case Partial | Complete => aggregateFunction.references
case PartialMerge | Final => AttributeSet(aggregateFunction.aggBufferAttributes)
}
aggAttributes ++ filterAttributes
}

override def toString: String = {
Expand All @@ -151,15 +152,15 @@ case class AggregateExpression(
}
val aggFuncStr = prefix + aggregateFunction.toAggString(isDistinct)
filter match {
case Some(predicate) => s"$aggFuncStr filter $predicate"
case Some(predicate) => s"$aggFuncStr filter (where $predicate)"
case _ => aggFuncStr
}
}

override def sql: String = {
val aggFuncStr = aggregateFunction.sql(isDistinct)
filter match {
case Some(predicate) => s"$aggFuncStr FILTER ${predicate.sql}"
case Some(predicate) => s"$aggFuncStr FILTER (WHERE ${predicate.sql})"
case _ => aggFuncStr
}
}
Expand Down
Expand Up @@ -51,53 +51,53 @@ SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData
struct<>
-- !query 3 output
org.apache.spark.sql.AnalysisException
grouping expressions sequence is empty, and 'testdata.`a`' is not an aggregate function. Wrap '(count(testdata.`b`) FILTER (testdata.`a` >= 2) AS `count(b) FILTER (a >= 2)`)' in windowing function(s) or wrap 'testdata.`a`' in first() (or first_value) if you don't care which value you get.;
grouping expressions sequence is empty, and 'testdata.`a`' is not an aggregate function. Wrap '(count(testdata.`b`) FILTER (WHERE (testdata.`a` >= 2)) AS `count(b) FILTER (WHERE (a >= 2))`)' in windowing function(s) or wrap 'testdata.`a`' in first() (or first_value) if you don't care which value you get.;


-- !query 4
SELECT COUNT(a) FILTER (WHERE a = 1), COUNT(b) FILTER (WHERE a > 1) FROM testData
-- !query 4 schema
struct<count(a) FILTER (a = 1):bigint,count(b) FILTER (a > 1):bigint>
struct<count(a) FILTER (WHERE (a = 1)):bigint,count(b) FILTER (WHERE (a > 1)):bigint>
-- !query 4 output
2 4


-- !query 5
SELECT COUNT(id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp
-- !query 5 schema
struct<count(id) FILTER (hiredate = DATE '2001-01-01'):bigint>
struct<count(id) FILTER (WHERE (hiredate = DATE '2001-01-01')):bigint>
-- !query 5 output
2


-- !query 6
SELECT COUNT(id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp
-- !query 6 schema
struct<count(id) FILTER (hiredate = to_date('2001-01-01 00:00:00')):bigint>
struct<count(id) FILTER (WHERE (hiredate = to_date('2001-01-01 00:00:00'))):bigint>
-- !query 6 output
2


-- !query 7
SELECT COUNT(id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")) FROM emp
-- !query 7 schema
struct<count(id) FILTER (CAST(hiredate AS TIMESTAMP) = to_timestamp('2001-01-01 00:00:00')):bigint>
struct<count(id) FILTER (WHERE (CAST(hiredate AS TIMESTAMP) = to_timestamp('2001-01-01 00:00:00'))):bigint>
-- !query 7 output
2


-- !query 8
SELECT COUNT(id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") = "2001-01-01") FROM emp
-- !query 8 schema
struct<count(id) FILTER (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd) = 2001-01-01):bigint>
struct<count(id) FILTER (WHERE (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd) = 2001-01-01)):bigint>
-- !query 8 output
2


-- !query 9
SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a
-- !query 9 schema
struct<a:int,count(b) FILTER (a >= 2):bigint>
struct<a:int,count(b) FILTER (WHERE (a >= 2)):bigint>
-- !query 9 output
1 0
2 2
Expand All @@ -117,7 +117,7 @@ expression 'testdata.`a`' is neither present in the group by, nor is it an aggre
-- !query 11
SELECT COUNT(a) FILTER (WHERE a >= 0), COUNT(b) FILTER (WHERE a >= 3) FROM testData GROUP BY a
-- !query 11 schema
struct<count(a) FILTER (a >= 0):bigint,count(b) FILTER (a >= 3):bigint>
struct<count(a) FILTER (WHERE (a >= 0)):bigint,count(b) FILTER (WHERE (a >= 3)):bigint>
-- !query 11 output
0 0
2 0
Expand All @@ -128,7 +128,7 @@ struct<count(a) FILTER (a >= 0):bigint,count(b) FILTER (a >= 3):bigint>
-- !query 12
SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > date "2003-01-01") FROM emp GROUP BY dept_id
-- !query 12 schema
struct<dept_id:int,sum(salary) FILTER (hiredate > DATE '2003-01-01'):double>
struct<dept_id:int,sum(salary) FILTER (WHERE (hiredate > DATE '2003-01-01')):double>
-- !query 12 output
10 200.0
100 400.0
Expand All @@ -141,7 +141,7 @@ NULL NULL
-- !query 13
SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_date("2003-01-01")) FROM emp GROUP BY dept_id
-- !query 13 schema
struct<dept_id:int,sum(salary) FILTER (hiredate > to_date('2003-01-01')):double>
struct<dept_id:int,sum(salary) FILTER (WHERE (hiredate > to_date('2003-01-01'))):double>
-- !query 13 output
10 200.0
100 400.0
Expand All @@ -154,7 +154,7 @@ NULL NULL
-- !query 14
SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_timestamp("2003-01-01 00:00:00")) FROM emp GROUP BY dept_id
-- !query 14 schema
struct<dept_id:int,sum(salary) FILTER (CAST(hiredate AS TIMESTAMP) > to_timestamp('2003-01-01 00:00:00')):double>
struct<dept_id:int,sum(salary) FILTER (WHERE (CAST(hiredate AS TIMESTAMP) > to_timestamp('2003-01-01 00:00:00'))):double>
-- !query 14 output
10 200.0
100 400.0
Expand All @@ -167,7 +167,7 @@ NULL NULL
-- !query 15
SELECT dept_id, SUM(salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2003-01-01") FROM emp GROUP BY dept_id
-- !query 15 schema
struct<dept_id:int,sum(salary) FILTER (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd) > 2003-01-01):double>
struct<dept_id:int,sum(salary) FILTER (WHERE (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd) > 2003-01-01)):double>
-- !query 15 output
10 200.0
100 400.0
Expand All @@ -180,39 +180,39 @@ NULL NULL
-- !query 16
SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1
-- !query 16 schema
struct<foo:string,count(a) FILTER (b <= 2):bigint>
struct<foo:string,count(a) FILTER (WHERE (b <= 2)):bigint>
-- !query 16 output
foo 6


-- !query 17
SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= date "2003-01-01") FROM emp GROUP BY 1
-- !query 17 schema
struct<foo:string,sum(salary) FILTER (hiredate >= DATE '2003-01-01'):double>
struct<foo:string,sum(salary) FILTER (WHERE (hiredate >= DATE '2003-01-01')):double>
-- !query 17 output
foo 1350.0


-- !query 18
SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_date("2003-01-01")) FROM emp GROUP BY 1
-- !query 18 schema
struct<foo:string,sum(salary) FILTER (hiredate >= to_date('2003-01-01')):double>
struct<foo:string,sum(salary) FILTER (WHERE (hiredate >= to_date('2003-01-01'))):double>
-- !query 18 output
foo 1350.0


-- !query 19
SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_timestamp("2003-01-01")) FROM emp GROUP BY 1
-- !query 19 schema
struct<foo:string,sum(salary) FILTER (CAST(hiredate AS TIMESTAMP) >= to_timestamp('2003-01-01')):double>
struct<foo:string,sum(salary) FILTER (WHERE (CAST(hiredate AS TIMESTAMP) >= to_timestamp('2003-01-01'))):double>
-- !query 19 output
foo 1350.0


-- !query 20
select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id
-- !query 20 schema
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary) FILTER (id > 200):double>
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary) FILTER (WHERE (id > 200)):double>
-- !query 20 output
10 2 2 400.0 NULL
100 2 2 800.0 800.0
Expand All @@ -225,7 +225,7 @@ NULL 1 1 400.0 400.0
-- !query 21
select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id
-- !query 21 schema
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary) FILTER ((id + dept_id) > 500):double>
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary) FILTER (WHERE ((id + dept_id) > 500)):double>
-- !query 21 output
10 2 2 400.0 NULL
100 2 2 800.0 800.0
Expand All @@ -238,7 +238,7 @@ NULL 1 1 400.0 NULL
-- !query 22
select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id
-- !query 22 schema
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary) FILTER (salary < 400.0):double,sum(salary) FILTER (id > 200):double>
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double,sum(salary) FILTER (WHERE (id > 200)):double>
-- !query 22 output
10 2 2 400.0 NULL
100 2 2 NULL 800.0
Expand All @@ -251,7 +251,7 @@ NULL 1 1 NULL 400.0
-- !query 23
select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id
-- !query 23 schema
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary) FILTER (salary < 400.0):double,sum(salary) FILTER ((id + dept_id) > 500):double>
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double,sum(salary) FILTER (WHERE ((id + dept_id) > 500)):double>
-- !query 23 output
10 2 2 400.0 NULL
100 2 2 NULL 800.0
Expand All @@ -264,23 +264,23 @@ NULL 1 1 NULL NULL
-- !query 24
SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1
-- !query 24 schema
struct<foo:string,approx_count_distinct(a) FILTER (b >= 0):bigint>
struct<foo:string,approx_count_distinct(a) FILTER (WHERE (b >= 0)):bigint>
-- !query 24 output



-- !query 25
SELECT 'foo', MAX(STRUCT(a)) FILTER (WHERE b >= 1) FROM testData WHERE a = 0 GROUP BY 1
-- !query 25 schema
struct<foo:string,max(named_struct(a, a)) FILTER (b >= 1):struct<a:int>>
struct<foo:string,max(named_struct(a, a)) FILTER (WHERE (b >= 1)):struct<a:int>>
-- !query 25 output



-- !query 26
SELECT a + b, COUNT(b) FILTER (WHERE b >= 2) FROM testData GROUP BY a + b
-- !query 26 schema
struct<(a + b):int,count(b) FILTER (b >= 2):bigint>
struct<(a + b):int,count(b) FILTER (WHERE (b >= 2)):bigint>
-- !query 26 output
2 0
3 1
Expand All @@ -301,7 +301,7 @@ expression 'testdata.`a`' is neither present in the group by, nor is it an aggre
-- !query 28
SELECT a + 1 + 1, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY a + 1
-- !query 28 schema
struct<((a + 1) + 1):int,count(b) FILTER (b > 0):bigint>
struct<((a + 1) + 1):int,count(b) FILTER (WHERE (b > 0)):bigint>
-- !query 28 output
3 2
4 2
Expand All @@ -312,7 +312,7 @@ NULL 1
-- !query 29
SELECT a AS k, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY k
-- !query 29 schema
struct<k:int,count(b) FILTER (b > 0):bigint>
struct<k:int,count(b) FILTER (WHERE (b > 0)):bigint>
-- !query 29 output
1 2
2 2
Expand All @@ -327,7 +327,7 @@ SELECT emp.dept_id,
FROM emp
GROUP BY dept_id
-- !query 30 schema
struct<dept_id:int,avg(salary):double,avg(salary) FILTER (id > scalarsubquery()):double>
struct<dept_id:int,avg(salary):double,avg(salary) FILTER (WHERE (id > scalarsubquery())):double>
-- !query 30 output
10 133.33333333333334 NULL
100 400.0 400.0
Expand All @@ -344,7 +344,7 @@ SELECT emp.dept_id,
FROM emp
GROUP BY dept_id
-- !query 31 schema
struct<dept_id:int,avg(salary):double,avg(salary) FILTER (dept_id = scalarsubquery()):double>
struct<dept_id:int,avg(salary):double,avg(salary) FILTER (WHERE (dept_id = scalarsubquery())):double>
-- !query 31 output
10 133.33333333333334 133.33333333333334
100 400.0 NULL
Expand All @@ -366,7 +366,7 @@ GROUP BY dept_id
struct<>
-- !query 32 output
org.apache.spark.sql.AnalysisException
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) filter exists#x [dept_id#x] AS avg(salary) FILTER exists(dept_id)#x]
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) filter (where exists#x [dept_id#x]) AS avg(salary) FILTER (WHERE exists(dept_id))#x]
: +- Project [state#x]
: +- Filter (dept_id#x = outer(dept_id#x))
: +- SubqueryAlias `dept`
Expand All @@ -392,7 +392,7 @@ GROUP BY dept_id
struct<>
-- !query 33 output
org.apache.spark.sql.AnalysisException
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) filter NOT exists#x [dept_id#x] AS sum(salary) FILTER (NOT exists(dept_id))#x]
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) filter (where NOT exists#x [dept_id#x]) AS sum(salary) FILTER (WHERE (NOT exists(dept_id)))#x]
: +- Project [state#x]
: +- Filter (dept_id#x = outer(dept_id#x))
: +- SubqueryAlias `dept`
Expand All @@ -417,7 +417,7 @@ GROUP BY dept_id
struct<>
-- !query 34 output
org.apache.spark.sql.AnalysisException
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) filter dept_id#x IN (list#x []) AS avg(salary) FILTER (dept_id IN (listquery()))#x]
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) filter (where dept_id#x IN (list#x [])) AS avg(salary) FILTER (WHERE (dept_id IN (listquery())))#x]
: +- Distinct
: +- Project [dept_id#x]
: +- SubqueryAlias `dept`
Expand All @@ -442,7 +442,7 @@ GROUP BY dept_id
struct<>
-- !query 35 output
org.apache.spark.sql.AnalysisException
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) filter NOT dept_id#x IN (list#x []) AS sum(salary) FILTER (NOT (dept_id IN (listquery())))#x]
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) filter (where NOT dept_id#x IN (list#x [])) AS sum(salary) FILTER (WHERE (NOT (dept_id IN (listquery()))))#x]
: +- Distinct
: +- Project [dept_id#x]
: +- SubqueryAlias `dept`
Expand Down
Expand Up @@ -14,15 +14,15 @@ It is not allowed to use an aggregate function in the argument of another aggreg
-- !query 1
select min(unique1) filter (where unique1 > 100) from tenk1
-- !query 1 schema
struct<min(unique1) FILTER (unique1 > 100):int>
struct<min(unique1) FILTER (WHERE (unique1 > 100)):int>
-- !query 1 output
101


-- !query 2
select sum(1/ten) filter (where ten > 0) from tenk1
-- !query 2 schema
struct<sum((CAST(1 AS DOUBLE) / CAST(ten AS DOUBLE))) FILTER (ten > 0):double>
struct<sum((CAST(1 AS DOUBLE) / CAST(ten AS DOUBLE))) FILTER (WHERE (ten > 0)):double>
-- !query 2 output
2828.9682539682954

Expand Down
33 changes: 19 additions & 14 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Expand Up @@ -27,6 +27,7 @@ import scala.collection.parallel.immutable.ParVector
import org.apache.spark.{AccumulatorSuite, SparkException}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial}
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.execution.HiveResult.hiveResultString
Expand Down Expand Up @@ -2843,32 +2844,36 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
val query = s"SELECT ${funcToResult._1} FILTER (WHERE b > 1) FROM testData2"
val df = sql(query)
val physical = df.queryExecution.sparkPlan
val aggregateExpressions = physical.collectFirst {
val aggregateExpressions = physical.collect {
case agg: HashAggregateExec => agg.aggregateExpressions
case agg: ObjectHashAggregateExec => agg.aggregateExpressions
}.flatten
aggregateExpressions.foreach { expr =>
if (expr.mode == Complete || expr.mode == Partial) {
assert(expr.filter.isDefined)
} else {
assert(expr.filter.isEmpty)
}
}
assert(aggregateExpressions.isDefined)
assert(aggregateExpressions.get.size == 1)
aggregateExpressions.get.foreach { expr =>
assert(expr.filter.isDefined)
}
checkAnswer(df, Row(funcToResult._2) :: Nil)
checkAnswer(df, Row(funcToResult._2))
}
}

test("Support filter clause for aggregate function uses SortAggregateExec") {
withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
val df = sql("SELECT PERCENTILE(a, 1) FILTER (WHERE b > 1) FROM testData2")
val physical = df.queryExecution.sparkPlan
val aggregateExpressions = physical.collectFirst {
val aggregateExpressions = physical.collect {
case agg: SortAggregateExec => agg.aggregateExpressions
}.flatten
aggregateExpressions.foreach { expr =>
if (expr.mode == Complete || expr.mode == Partial) {
assert(expr.filter.isDefined)
} else {
assert(expr.filter.isEmpty)
}
}
assert(aggregateExpressions.isDefined)
assert(aggregateExpressions.get.size == 1)
aggregateExpressions.get.foreach { expr =>
assert(expr.filter.isDefined)
}
checkAnswer(df, Row(3) :: Nil)
checkAnswer(df, Row(3))
}
}

Expand Down

0 comments on commit 9d4ba8b

Please sign in to comment.