From 493157a3b97616d221ec2b5ddf1a21cdf9a1a3f4 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Thu, 14 May 2020 21:41:37 +0800 Subject: [PATCH] check operator --- .../apache/spark/sql/DataFrameAggregateSuite.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 28e214a6cbf0f..2293d4ae61aff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -998,12 +998,16 @@ class DataFrameAggregateSuite extends QueryTest " d, 0))) as csum from t2 group by c"), Row(4) :: Nil) // test SortAggregateExec - checkAnswer(sql("select max(if(c > (select a from t1), 'str1', 'str2')) as csum from t2"), - Row("str1") :: Nil) + var df = sql("select max(if(c > (select a from t1), 'str1', 'str2')) as csum from t2") + assert(df.queryExecution.executedPlan + .find { case _: SortAggregateExec => true }.isDefined) + checkAnswer(df, Row("str1") :: Nil) // test ObjectHashAggregateExec - checkAnswer(sql("select collect_list(d), sum(if(c > (select a from t1), d, 0)) as csum" + - " from t2"), Row(Array(4), 4) :: Nil) + df = sql("select collect_list(d), sum(if(c > (select a from t1), d, 0)) as csum from t2") + assert(df.queryExecution.executedPlan + .find { case _: ObjectHashAggregateExec => true }.isDefined) + checkAnswer(df, Row(Array(4), 4) :: Nil) } } }