diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 66f08523479bf..4ae6baf794bd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -980,12 +980,22 @@ class DataFrameAggregateSuite extends QueryTest withTempView("t1", "t2") { sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)") sql("create temporary view t2 as select * from values (3, 4) as t2(c, d)") + // test without grouping keys checkAnswer(sql("select sum(if(c > (select a from t1), d, 0)) as csum from t2"), Row(4) :: Nil) + // test with grouping keys checkAnswer(sql("select c, sum(if(c > (select a from t1), d, 0)) as csum from " + "t2 group by c"), Row(3, 4) :: Nil) + + // test with distinct + checkAnswer(sql("select avg(distinct(d)), sum(distinct(if(c > (select a from t1)," + + " d, 0))) as csum from t2 group by c"), Row(4, 4) :: Nil) + + // test subquery with agg + checkAnswer(sql("select sum(distinct(if(c > (select sum(distinct(a)) from t1)," + + " d, 0))) as csum from t2 group by c"), Row(4) :: Nil) } } }