diff --git a/docs/source/contributor-guide/spark_expressions_support.md b/docs/source/contributor-guide/spark_expressions_support.md index 9c19926a7c..7f27c7e02f 100644 --- a/docs/source/contributor-guide/spark_expressions_support.md +++ b/docs/source/contributor-guide/spark_expressions_support.md @@ -172,7 +172,7 @@ ### collection_funcs - [ ] array_size -- [ ] cardinality +- [x] cardinality - [x] concat - [x] reverse - [x] size diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index 5edc08840a..052c9eb001 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -639,13 +639,9 @@ object CometArrayFilter extends CometExpressionSerde[ArrayFilter] { object CometSize extends CometExpressionSerde[Size] { - override def getUnsupportedReasons(): Seq[String] = Seq( - "Only supports `ArrayType` input; `MapType` input is not supported") - override def getSupportLevel(expr: Size): SupportLevel = { expr.child.dataType match { - case _: ArrayType => Compatible() - case _: MapType => Unsupported(Some("size does not support map inputs")) + case _: ArrayType | _: MapType => Compatible() case other => // this should be unreachable because Spark only supports map and array inputs Unsupported(Some(s"Unsupported child data type: $other")) @@ -660,7 +656,7 @@ object CometSize extends CometExpressionSerde[Size] { for { isNotNullExprProto <- createIsNotNullExprProto(expr, inputs, binding) sizeScalarExprProto <- scalarFunctionExprToProto("size", arrayExprProto) - emptyLiteralExprProto <- createLiteralExprProto(SQLConf.get.legacySizeOfNull) + emptyLiteralExprProto <- createLiteralExprProto(expr.legacySizeOfNull) } yield { val caseWhenExpr = ExprOuterClass.CaseWhen .newBuilder() diff --git a/spark/src/test/resources/sql-tests/expressions/array/cardinality.sql b/spark/src/test/resources/sql-tests/expressions/array/cardinality.sql new 
file mode 100644 index 0000000000..3e84d91036 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/cardinality.sql @@ -0,0 +1,60 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- cardinality() is handled by the same Size expression as size(): +-- for NULL input the result follows the expression's legacySizeOfNull flag +-- (NULL when false, -1 when true). Both array and map inputs are supported. 
+-- inputTypes: TypeCollection(ArrayType, MapType) -> test both + +statement +CREATE TABLE test_cardinality( + arr array<int>, + nested_arr array<array<int>>, + struct_arr array<struct<a: int>>, + m map<string, int> +) USING parquet + +statement +INSERT INTO test_cardinality VALUES + (array(1, 2, 3), array(array(1, 2), array(3)), array(named_struct('a', 1), named_struct('a', 2)), map('a', 1, 'b', 2)), + (array(10), array(array(10)), array(named_struct('a', 1)), map('x', 99)), + (array(), array(), array(), map()), + (NULL, NULL, NULL, NULL) + +-- column reference: array input +query +SELECT cardinality(arr) FROM test_cardinality + +-- column reference: map input +query +SELECT cardinality(m) FROM test_cardinality + +-- both in same query +query +SELECT cardinality(arr), cardinality(m) FROM test_cardinality + +-- NULL input row: result is NULL, or -1 when legacySizeOfNull is in effect (same as size()) +query +SELECT cardinality(arr), cardinality(m) FROM test_cardinality WHERE arr IS NULL + +-- nested array input +query +SELECT cardinality(nested_arr) FROM test_cardinality + +-- array-of-structs input +query +SELECT cardinality(struct_arr) FROM test_cardinality diff --git a/spark/src/test/resources/sql-tests/expressions/array/size.sql b/spark/src/test/resources/sql-tests/expressions/array/size.sql index b006a4da0d..c77ad0ef67 100644 --- a/spark/src/test/resources/sql-tests/expressions/array/size.sql +++ b/spark/src/test/resources/sql-tests/expressions/array/size.sql @@ -21,7 +21,7 @@ CREATE TABLE test_size(arr array<int>, m map<string, int>) USING parquet statement INSERT INTO test_size VALUES (array(1, 2, 3), map('a', 1, 'b', 2)), (array(), map()), (NULL, NULL) -query spark_answer_only +query SELECT size(arr), size(m) FROM test_size -- literal arguments diff --git a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala index f3c7d9f23e..75cb1987f0 100644 --- a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala +++ 
b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala @@ -126,22 +126,24 @@ class CometMapExpressionSuite extends CometTestBase { } } - test("fallback for size with map input") { + test("fallback for size with map constructor input") { withTempDir { dir => withTempView("t1") { val path = new Path(dir.toURI.toString, "test.parquet") makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = true, 100) spark.read.parquet(path.toString).createOrReplaceTempView("t1") - // Use column references in maps to avoid constant folding + // Size now supports MapType inputs, this falls back since CreateMap + // is not yet supported natively. Use column references to avoid + // constant folding. checkSparkAnswerAndFallbackReason( sql("SELECT size(case when _2 < 0 then map(_8, _9) else map() end) from t1"), - "size does not support map inputs") + "map is not supported") } } } - // fails with "map is not supported" + // still fails because CreateMap is not supported natively ignore("size with map input") { withTempDir { dir => withTempView("t1") {