From df2a1fe131476aac128d63df9b06ec4bca0c2c07 Mon Sep 17 00:00:00 2001
From: HyukjinKwon
Date: Mon, 25 May 2020 20:36:00 -0700
Subject: [PATCH] [SPARK-31808][SQL] Makes struct function's output name and
 class name pretty

### What changes were proposed in this pull request?

This PR proposes to set the alias and the class name in the `ExpressionInfo` of `struct`:

- Class name in `ExpressionInfo`
  - from: `org.apache.spark.sql.catalyst.expressions.NamedStruct`
  - to: `org.apache.spark.sql.catalyst.expressions.CreateNamedStruct`
- Alias name: `named_struct(col1, v, ...)` -> `struct(v, ...)`

This PR takes over https://github.com/apache/spark/pull/28631

### Why are the changes needed?

To show the correct output name and class names to users.

### Does this PR introduce _any_ user-facing change?

Yes.

**Before:**

```scala
scala> sql("DESC FUNCTION struct").show(false)
+------------------------------------------------------------------------------------+
|function_desc                                                                        |
+------------------------------------------------------------------------------------+
|Function: struct                                                                     |
|Class: org.apache.spark.sql.catalyst.expressions.NamedStruct                         |
|Usage: struct(col1, col2, col3, ...) - Creates a struct with the given field values. |
+------------------------------------------------------------------------------------+
```

```scala
scala> sql("SELECT struct(1, 2)").show(false)
+------------------------------+
|named_struct(col1, 1, col2, 2)|
+------------------------------+
|[1, 2]                        |
+------------------------------+
```

**After:**

```scala
scala> sql("DESC FUNCTION struct").show(false)
+------------------------------------------------------------------------------------+
|function_desc                                                                        |
+------------------------------------------------------------------------------------+
|Function: struct                                                                     |
|Class: org.apache.spark.sql.catalyst.expressions.CreateNamedStruct                   |
|Usage: struct(col1, col2, col3, ...) - Creates a struct with the given field values. |
+------------------------------------------------------------------------------------+
```

```scala
scala> sql("SELECT struct(1, 2)").show(false)
+------------+
|struct(1, 2)|
+------------+
|[1, 2]      |
+------------+
```

### How was this patch tested?

Manually tested, and verified via Jenkins tests.

Closes #28633 from HyukjinKwon/SPARK-31808.
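For readers following along, the gist of the change can be sketched without Spark: `CreateStruct.create` tags the expression it builds with `FUNC_ALIAS`, and `CreateNamedStruct.prettyName`/`sql` consult that tag when rendering. Below is a minimal, dependency-free Scala sketch of that idea; `Expr`, `NamedStructExpr`, and `setFuncAlias` are illustrative stand-ins, not Spark classes (Spark's real mechanism is `TreeNode.setTagValue`/`getTagValue` on `CreateNamedStruct`):

```scala
// Self-contained sketch of the FUNC_ALIAS trick in this patch.
// NamedStructExpr stands in for CreateNamedStruct; setFuncAlias stands in
// for TreeNode.setTagValue(FUNC_ALIAS, ...). Illustrative only.
object FuncAliasSketch {
  final case class Expr(sql: String)

  final case class NamedStructExpr(children: Seq[Expr]) {
    // Stand-in for the TreeNode tag map keyed by FUNC_ALIAS.
    private var funcAlias: Option[String] = None
    def setFuncAlias(alias: String): Unit = funcAlias = Some(alias)

    // Mirrors CreateNamedStruct.prettyName: use the alias when present.
    def prettyName: String = funcAlias.getOrElse("named_struct")

    // Mirrors CreateNamedStruct.sql: children alternate (name, value),
    // so under the `struct` alias only the odd-indexed values are printed.
    def sql: String = funcAlias match {
      case Some(alias) =>
        val values = children.indices.filter(_ % 2 == 1).map(children(_).sql)
        s"$alias(${values.mkString(", ")})"
      case None =>
        s"named_struct(${children.map(_.sql).mkString(", ")})"
    }
  }

  def main(args: Array[String]): Unit = {
    // Children stored as name1, value1, name2, value2, as CreateNamedStruct does.
    val children = Seq(Expr("col1"), Expr("1"), Expr("col2"), Expr("2"))

    val plain = NamedStructExpr(children)
    println(plain.sql)           // named_struct(col1, 1, col2, 2)

    val aliased = NamedStructExpr(children)
    aliased.setFuncAlias("struct") // what CreateStruct.create does
    println(aliased.prettyName)    // struct
    println(aliased.sql)           // struct(1, 2)
  }
}
```

Because `CreateNamedStruct` stores its children as alternating name/value pairs, rendering only the odd-indexed children under the `struct` alias reproduces what the user typed, e.g. `struct(1, 2)` instead of `named_struct(col1, 1, col2, 2)`.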
Authored-by: HyukjinKwon
Signed-off-by: Dongjoon Hyun
---
 .../expressions/complexTypeCreator.scala      | 34 ++++++++++++++++---
 .../sql/catalyst/parser/AstBuilder.scala      |  2 +-
 .../org/apache/spark/sql/functions.scala      |  2 +-
 .../sql-functions/sql-expression-schema.md    |  6 ++--
 .../sql-tests/results/group-by-filter.sql.out |  2 +-
 .../sql-tests/results/group-by.sql.out        |  2 +-
 .../sql-tests/results/struct.sql.out          |  2 +-
 .../typeCoercion/native/mapZipWith.sql.out    |  4 +--
 .../results/udf/udf-group-by.sql.out          |  2 +-
 9 files changed, 40 insertions(+), 16 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 858c91a4d8e86..5212ef3930bc9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util._
@@ -311,7 +311,12 @@ case object NamePlaceholder extends LeafExpression with Unevaluable {
 /**
  * Returns a Row containing the evaluation of all children expressions.
  */
-object CreateStruct extends FunctionBuilder {
+object CreateStruct {
+  /**
+   * Returns a named struct with generated names or using the names when available.
+   * It should not be used for `struct` expressions or functions explicitly called
+   * by users.
+   */
   def apply(children: Seq[Expression]): CreateNamedStruct = {
     CreateNamedStruct(children.zipWithIndex.flatMap {
       case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e)
@@ -320,12 +325,23 @@ object CreateStruct extends FunctionBuilder {
     })
   }

+  /**
+   * Returns a named struct with a pretty SQL name. It will show the pretty SQL string
+   * in its output column name as if `struct(...)` was called. Should be
+   * used for `struct` expressions or functions explicitly called by users.
+   */
+  def create(children: Seq[Expression]): CreateNamedStruct = {
+    val expr = CreateStruct(children)
+    expr.setTagValue(FUNC_ALIAS, "struct")
+    expr
+  }
+
   /**
    * Entry to use in the function registry.
    */
   val registryEntry: (String, (ExpressionInfo, FunctionBuilder)) = {
     val info: ExpressionInfo = new ExpressionInfo(
-      "org.apache.spark.sql.catalyst.expressions.NamedStruct",
+      classOf[CreateNamedStruct].getCanonicalName,
       null,
       "struct",
       "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.",
@@ -335,7 +351,7 @@ object CreateStruct extends FunctionBuilder {
       "",
       "",
       "")
-    ("struct", (info, this))
+    ("struct", (info, this.create))
   }
 }

@@ -433,7 +449,15 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
       """.stripMargin, isNull = FalseLiteral)
   }

-  override def prettyName: String = "named_struct"
+  // There is an alias set at `CreateStruct.create`. If there is an alias,
+  // this is the struct function explicitly called by a user and we should
+  // respect it in the SQL string as `struct(...)`.
+  override def prettyName: String = getTagValue(FUNC_ALIAS).getOrElse("named_struct")
+
+  override def sql: String = getTagValue(FUNC_ALIAS).map { alias =>
+    val childrenSQL = children.indices.filter(_ % 2 == 1).map(children(_).sql).mkString(", ")
+    s"$alias($childrenSQL)"
+  }.getOrElse(super.sql)
 }

 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index c0cecf8536c39..03571a740df3e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1534,7 +1534,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
    * Create a [[CreateStruct]] expression.
    */
   override def visitStruct(ctx: StructContext): Expression = withOrigin(ctx) {
-    CreateStruct(ctx.argument.asScala.map(expression))
+    CreateStruct.create(ctx.argument.asScala.map(expression))
   }

   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 5481337bf6cee..0cca3e7b47c56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1306,7 +1306,7 @@ object functions {
    * @since 1.4.0
    */
   @scala.annotation.varargs
-  def struct(cols: Column*): Column = withExpr { CreateStruct(cols.map(_.expr)) }
+  def struct(cols: Column*): Column = withExpr { CreateStruct.create(cols.map(_.expr)) }

   /**
    * Creates a new struct column that composes multiple input columns.
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 23173c8ba1f11..8949b62f0a512 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -2,7 +2,7 @@
 ## Summary
   - Number of queries: 337
   - Number of expressions that missing example: 34
-  - Expressions missing examples: and,string,tinyint,double,smallint,date,decimal,boolean,float,binary,bigint,int,timestamp,cume_dist,dense_rank,input_file_block_length,input_file_block_start,input_file_name,lag,lead,monotonically_increasing_id,ntile,struct,!,not,or,percent_rank,rank,row_number,spark_partition_id,version,window,positive,count_min_sketch
+  - Expressions missing examples: and,string,tinyint,double,smallint,date,decimal,boolean,float,binary,bigint,int,timestamp,struct,cume_dist,dense_rank,input_file_block_length,input_file_block_start,input_file_name,lag,lead,monotonically_increasing_id,ntile,!,not,or,percent_rank,rank,row_number,spark_partition_id,version,window,positive,count_min_sketch
 ## Schema of Built-in Functions
 | Class name | Function name or alias | Query example | Output schema |
 | ---------- | ---------------------- | ------------- | ------------- |
@@ -79,6 +79,7 @@
 | org.apache.spark.sql.catalyst.expressions.CreateArray | array | SELECT array(1, 2, 3) | struct<array(1, 2, 3):array<int>> |
 | org.apache.spark.sql.catalyst.expressions.CreateMap | map | SELECT map(1.0, '2', 3.0, '4') | struct<map(1.0, 2, 3.0, 4):map<decimal(2,1),string>> |
 | org.apache.spark.sql.catalyst.expressions.CreateNamedStruct | named_struct | SELECT named_struct("a", 1, "b", 2, "c", 3) | struct<named_struct(a, 1, b, 2, c, 3):struct<a:int,b:int,c:int>> |
+| org.apache.spark.sql.catalyst.expressions.CreateNamedStruct | struct | N/A | N/A |
 | org.apache.spark.sql.catalyst.expressions.CsvToStructs | from_csv | SELECT from_csv('1, 0.8', 'a INT, b DOUBLE') | struct<from_csv(1, 0.8):struct<a:int,b:double>> |
 | org.apache.spark.sql.catalyst.expressions.Cube | cube | SELECT name, age, count(*) FROM VALUES (2, 'Alice'), (5, 'Bob') people(age, name) GROUP BY cube(name, age) | struct<name:string,age:int,count(1):bigint> |
 | org.apache.spark.sql.catalyst.expressions.CumeDist | cume_dist | N/A | N/A |
@@ -171,7 +172,7 @@
 | org.apache.spark.sql.catalyst.expressions.MapEntries | map_entries | SELECT map_entries(map(1, 'a', 2, 'b')) | struct<map_entries(map(1, a, 2, b)):array<struct<key:int,value:string>>> |
 | org.apache.spark.sql.catalyst.expressions.MapFilter | map_filter | SELECT map_filter(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v) | struct<map_filter(map(1, 0, 2, 2, 3, -1), lambdafunction((namedlambdavariable() > namedlambdavariable()), namedlambdavariable(), namedlambdavariable())):map<int,int>> |
 | org.apache.spark.sql.catalyst.expressions.MapFromArrays | map_from_arrays | SELECT map_from_arrays(array(1.0, 3.0), array('2', '4')) | struct<map_from_arrays(array(1.0, 3.0), array(2, 4)):map<decimal(2,1),string>> |
-| org.apache.spark.sql.catalyst.expressions.MapFromEntries | map_from_entries | SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))) | struct<map_from_entries(array(named_struct(col1, 1, col2, a), named_struct(col1, 2, col2, b))):map<int,string>> |
+| org.apache.spark.sql.catalyst.expressions.MapFromEntries | map_from_entries | SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))) | struct<map_from_entries(array(struct(1, a), struct(2, b))):map<int,string>> |
 | org.apache.spark.sql.catalyst.expressions.MapKeys | map_keys | SELECT map_keys(map(1, 'a', 2, 'b')) | struct<map_keys(map(1, a, 2, b)):array<int>> |
 | org.apache.spark.sql.catalyst.expressions.MapValues | map_values | SELECT map_values(map(1, 'a', 2, 'b')) | struct<map_values(map(1, a, 2, b)):array<string>> |
 | org.apache.spark.sql.catalyst.expressions.MapZipWith | map_zip_with | SELECT map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)) | struct<map_zip_with(map(1, a, 2, b), map(1, x, 2, y), lambdafunction(concat(namedlambdavariable(), namedlambdavariable()), namedlambdavariable(), namedlambdavariable(), namedlambdavariable())):map<int,string>> |
@@ -186,7 +187,6 @@
 | org.apache.spark.sql.catalyst.expressions.Murmur3Hash | hash | SELECT hash('Spark', array(123), 2) | struct<hash(Spark, array(123), 2):int> |
 | org.apache.spark.sql.catalyst.expressions.NTile | ntile | N/A | N/A |
 | org.apache.spark.sql.catalyst.expressions.NaNvl | nanvl | SELECT nanvl(cast('NaN' as double), 123) | struct<nanvl(CAST(NaN AS DOUBLE), CAST(123 AS DOUBLE)):double> |
-| org.apache.spark.sql.catalyst.expressions.NamedStruct | struct | N/A | N/A |
 | org.apache.spark.sql.catalyst.expressions.NextDay | next_day | SELECT next_day('2015-01-14', 'TU') | struct<next_day(CAST(2015-01-14 AS DATE), TU):date> |
 | org.apache.spark.sql.catalyst.expressions.Not | ! | N/A | N/A |
 | org.apache.spark.sql.catalyst.expressions.Not | not | N/A | N/A |
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out
index 3fcd132701a3f..d41d25280146b 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out
@@ -272,7 +272,7 @@ struct<count(DISTINCT b) FILTER (WHERE (b >= 0)):bigint>
 -- !query
 SELECT 'foo', MAX(STRUCT(a)) FILTER (WHERE b >= 1) FROM testData WHERE a = 0 GROUP BY 1
 -- !query schema
-struct<foo:string,max(named_struct(a, a)) FILTER (WHERE (b >= 1)):struct<a:int>>
+struct<foo:string,max(struct(a)) FILTER (WHERE (b >= 1)):struct<a:int>>
 -- !query output


diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index 7bfdd0ad53a95..50eb2a9f22f69 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -87,7 +87,7 @@ struct<foo:string>
 -- !query
 SELECT 'foo', MAX(STRUCT(a)) FROM testData WHERE a = 0 GROUP BY 1
 -- !query schema
-struct<foo:string,max(named_struct(a, a)):struct<a:int>>
+struct<foo:string,max(struct(a)):struct<a:int>>
 -- !query output


diff --git a/sql/core/src/test/resources/sql-tests/results/struct.sql.out b/sql/core/src/test/resources/sql-tests/results/struct.sql.out
index f294c5213d319..3b610edc47169 100644
--- a/sql/core/src/test/resources/sql-tests/results/struct.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/struct.sql.out
@@ -83,7 +83,7 @@ struct
 -- !query
 SELECT ID, STRUCT(ST.C as STC, ST.D as STD).STD FROM tbl_x
 -- !query schema
-struct<ID:string,named_struct(STC, ST.C AS STC, STD, ST.D AS STD).STD:string>
+struct<ID:string,struct(ST.C AS STC, ST.D AS STD).STD:string>
 -- !query output
 1	delta
 2	eta
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
index ed7ab5a342c12..d046ff249379f 100644
--- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
@@ -85,7 +85,7 @@ FROM various_maps
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7
+cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(struct(k, v1, v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7


 -- !query
@@ -113,7 +113,7 @@ FROM various_maps
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7
+cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(struct(k, v1, v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7


 -- !query
diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out
index 6403406413db9..da5256f5c0453 100644
--- a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out
@@ -87,7 +87,7 @@ struct<udf(foo):string>
 -- !query
 SELECT udf('foo'), udf(max(struct(a))) FROM testData WHERE a = 0 GROUP BY udf(1)
 -- !query schema
-struct<udf(foo):string,udf(max(named_struct(a, a))):struct<a:int>>
+struct<udf(foo):string,udf(max(struct(a))):struct<a:int>>
 -- !query output