From 511047479d7bef9f421df83334fab64a78d414cd Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Fri, 24 Jun 2016 19:33:49 +0900 Subject: [PATCH 1/3] Add type checks in CheckAnalysis --- .../sql/catalyst/analysis/CheckAnalysis.scala | 9 ++++++--- .../catalyst/analysis/AnalysisErrorSuite.scala | 15 ++++++++++++++- .../spark/sql/DataFrameAggregateSuite.scala | 10 ++++++++++ 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 899227674f2ac..5684de60f2da0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.catalog.SimpleCatalogRelation import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, CollectSet} import org.apache.spark.sql.catalyst.plans.UsingJoin import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -73,9 +73,12 @@ trait CheckAnalysis extends PredicateHelper { s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") case g: Grouping => - failAnalysis(s"grouping() can only be used with GroupingSets/Cube/Rollup") + failAnalysis("grouping() can only be used with GroupingSets/Cube/Rollup") case g: GroupingID => - failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup") + failAnalysis("grouping_id() can only be used with GroupingSets/Cube/Rollup") + + case c: CollectSet if c.child.dataType.isInstanceOf[MapType] => + failAnalysis("collect_set() cannot have map type data") case w @ WindowExpression(AggregateExpression(_, _, true, _), _) => failAnalysis(s"Distinct window functions are not supported: $w") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index a41383fbf6562..d1a330ce4eb48 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, CollectSet, Complete, Count} import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} @@ -435,6 +435,19 @@ class AnalysisErrorSuite extends AnalysisTest { "another aggregate function." :: Nil) } + test("we should fail analysis when we find map type data in collect_set") { + val dataType = MapType(StringType, IntegerType) + val plan = + Aggregate( + AttributeReference("a", IntegerType)(exprId = ExprId(2)) :: Nil, + Alias(CollectSet(AttributeReference("b", dataType)(exprId = ExprId(1))) + .toAggregateExpression(), "c")() :: Nil, + LocalRelation( + AttributeReference("a", IntegerType)(exprId = ExprId(2)), + AttributeReference("b", dataType)(exprId = ExprId(1)))) + assertAnalysisError(plan, "collect_set() cannot have map type data" :: Nil) + } + test("Join can't work on binary and map types") { val plan = Join( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 69a990789bcfd..92aa7b95434dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -457,6 +457,16 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("collect_set functions cannot have maps") { + val df = Seq((1, 3, 0), (2, 3, 0), (3, 4, 1)) + .toDF("a", "x", "y") + .select($"a", map($"x", $"y").as("b")) + val error = intercept[AnalysisException] { + df.select(collect_set($"a"), collect_set($"b")) + } + assert(error.message.contains("collect_set() cannot have map type data")) + } + test("SPARK-14664: Decimal sum/avg over window should work.") { checkAnswer( spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"), From dd7e233d2df97a68eeac4570647ac9f93e044d70 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Sat, 25 Jun 2016 09:57:56 +0900 Subject: [PATCH 2/3] Apply comments --- .../spark/sql/catalyst/analysis/CheckAnalysis.scala | 5 +---- .../sql/catalyst/expressions/aggregate/collect.scala | 9 +++++++++ .../sql/catalyst/analysis/AnalysisErrorSuite.scala | 10 +++++++--- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 5684de60f2da0..ac9693e079f51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.catalog.SimpleCatalogRelation import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, CollectSet} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.UsingJoin import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -77,9 +77,6 @@ trait CheckAnalysis extends PredicateHelper { case g: GroupingID => failAnalysis("grouping_id() can only be used with GroupingSets/Cube/Rollup") - case c: CollectSet if c.child.dataType.isInstanceOf[MapType] => - failAnalysis("collect_set() cannot have map type data") - case w @ WindowExpression(AggregateExpression(_, _, true, _), _) => failAnalysis(s"Distinct window functions are not supported: $w") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 1f4ff9c4b184e..ac2cefaddcf59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import scala.collection.generic.Growable import scala.collection.mutable +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.catalyst.InternalRow @@ -107,6 +108,14 @@ case class CollectSet( def this(child: Expression) = this(child, 0, 0) + override def checkInputDataTypes(): TypeCheckResult = { + if (!child.dataType.existsRecursively(_.isInstanceOf[MapType])) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure("collect_set() cannot have map type data") + } + } + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index d1a330ce4eb48..bf0c6f2010c7d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -436,8 +436,8 @@ class AnalysisErrorSuite extends AnalysisTest { } test("we should fail analysis when we find map type data in collect_set") { - val dataType = MapType(StringType, IntegerType) - val plan = + def errorTest(dataType: DataType): Unit = { + val plan = Aggregate( AttributeReference("a", IntegerType)(exprId = ExprId(2)) :: Nil, Alias(CollectSet(AttributeReference("b", dataType)(exprId = ExprId(1))) @@ -445,7 +445,11 @@ class AnalysisErrorSuite extends AnalysisTest { LocalRelation( AttributeReference("a", IntegerType)(exprId = ExprId(2)), AttributeReference("b", dataType)(exprId = ExprId(1)))) - assertAnalysisError(plan, "collect_set() cannot have map type data" :: Nil) + assertAnalysisError(plan, "collect_set() cannot have map type data" :: Nil) + } + + val mapType = MapType(StringType, IntegerType) + (mapType :: ArrayType(mapType) :: StructType(StructField("x", mapType) :: Nil) :: Nil) } test("Join can't work on binary and map types") { From 1a445c887ebf0982186d2ac8b62aac627509eee3 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Sat, 25 Jun 2016 11:01:29 +0900 Subject: [PATCH 3/3] Remove test --- .../analysis/AnalysisErrorSuite.scala | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index bf0c6f2010c7d..a41383fbf6562 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, CollectSet, Complete, Count} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count} import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} @@ -435,23 +435,6 @@ class AnalysisErrorSuite extends AnalysisTest { "another aggregate function." :: Nil) } - test("we should fail analysis when we find map type data in collect_set") { - def errorTest(dataType: DataType): Unit = { - val plan = - Aggregate( - AttributeReference("a", IntegerType)(exprId = ExprId(2)) :: Nil, - Alias(CollectSet(AttributeReference("b", dataType)(exprId = ExprId(1))) - .toAggregateExpression(), "c")() :: Nil, - LocalRelation( - AttributeReference("a", IntegerType)(exprId = ExprId(2)), - AttributeReference("b", dataType)(exprId = ExprId(1)))) - assertAnalysisError(plan, "collect_set() cannot have map type data" :: Nil) - } - - val mapType = MapType(StringType, IntegerType) - (mapType :: ArrayType(mapType) :: StructType(StructField("x", mapType) :: Nil) :: Nil) - } - test("Join can't work on binary and map types") { val plan = Join(