Skip to content

Commit

Permalink
[SPARK-31808][SQL] Makes struct function's output name and class name…
Browse files Browse the repository at this point in the history
… pretty

### What changes were proposed in this pull request?

This PR proposes to set the alias, and class name in its `ExpressionInfo` for `struct`.
- Class name in `ExpressionInfo`
  - from: `org.apache.spark.sql.catalyst.expressions.NamedStruct`
  - to:`org.apache.spark.sql.catalyst.expressions.CreateNamedStruct`
- Alias name: `named_struct(col1, v, ...)` -> `struct(v, ...)`

This PR takes over #28631

### Why are the changes needed?

To show the correct output name and class names to users.

### Does this PR introduce _any_ user-facing change?

Yes.

**Before:**

```scala
scala> sql("DESC FUNCTION struct").show(false)
+------------------------------------------------------------------------------------+
|function_desc                                                                       |
+------------------------------------------------------------------------------------+
|Function: struct                                                                    |
|Class: org.apache.spark.sql.catalyst.expressions.NamedStruct                        |
|Usage: struct(col1, col2, col3, ...) - Creates a struct with the given field values.|
+------------------------------------------------------------------------------------+
```

```scala
scala> sql("SELECT struct(1, 2)").show(false)
+------------------------------+
|named_struct(col1, 1, col2, 2)|
+------------------------------+
|[1, 2]                        |
+------------------------------+
```

**After:**

```scala
scala> sql("DESC FUNCTION struct").show(false)
+------------------------------------------------------------------------------------+
|function_desc                                                                       |
+------------------------------------------------------------------------------------+
|Function: struct                                                                    |
|Class: org.apache.spark.sql.catalyst.expressions.CreateNamedStruct                  |
|Usage: struct(col1, col2, col3, ...) - Creates a struct with the given field values.|
+------------------------------------------------------------------------------------+
```

```scala
scala> sql("SELECT struct(1, 2)").show(false)
+------------+
|struct(1, 2)|
+------------+
|[1, 2]      |
+------------+
```

### How was this patch tested?

Manually tested, and Jenkins tests.

Closes #28633 from HyukjinKwon/SPARK-31808.

Authored-by: HyukjinKwon <gurwls223@apache.org>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
  • Loading branch information
HyukjinKwon authored and dongjoon-hyun committed May 26, 2020
1 parent 6c80ebb commit df2a1fe
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 16 deletions.
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util._
Expand Down Expand Up @@ -311,7 +311,12 @@ case object NamePlaceholder extends LeafExpression with Unevaluable {
/**
* Returns a Row containing the evaluation of all children expressions.
*/
object CreateStruct extends FunctionBuilder {
object CreateStruct {
/**
* Returns a named struct with generated names or using the names when available.
* It should not be used for `struct` expressions or functions explicitly called
* by users.
*/
def apply(children: Seq[Expression]): CreateNamedStruct = {
CreateNamedStruct(children.zipWithIndex.flatMap {
case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e)
Expand All @@ -320,12 +325,23 @@ object CreateStruct extends FunctionBuilder {
})
}

/**
* Returns a named struct with a pretty SQL name. It will show the pretty SQL string
* in its output column name as if `struct(...)` was called. Should be
* used for `struct` expressions or functions explicitly called by users.
*/
def create(children: Seq[Expression]): CreateNamedStruct = {
val expr = CreateStruct(children)
expr.setTagValue(FUNC_ALIAS, "struct")
expr
}

/**
* Entry to use in the function registry.
*/
val registryEntry: (String, (ExpressionInfo, FunctionBuilder)) = {
val info: ExpressionInfo = new ExpressionInfo(
"org.apache.spark.sql.catalyst.expressions.NamedStruct",
classOf[CreateNamedStruct].getCanonicalName,
null,
"struct",
"_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.",
Expand All @@ -335,7 +351,7 @@ object CreateStruct extends FunctionBuilder {
"",
"",
"")
("struct", (info, this))
("struct", (info, this.create))
}
}

Expand Down Expand Up @@ -433,7 +449,15 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
""".stripMargin, isNull = FalseLiteral)
}

override def prettyName: String = "named_struct"
// There is an alias set at `CreateStruct.create`. If there is an alias,
// this is the struct function explicitly called by a user and we should
// respect it in the SQL string as `struct(...)`.
override def prettyName: String = getTagValue(FUNC_ALIAS).getOrElse("named_struct")

override def sql: String = getTagValue(FUNC_ALIAS).map { alias =>
val childrenSQL = children.indices.filter(_ % 2 == 1).map(children(_).sql).mkString(", ")
s"$alias($childrenSQL)"
}.getOrElse(super.sql)
}

/**
Expand Down
Expand Up @@ -1534,7 +1534,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
* Create a [[CreateStruct]] expression.
*/
override def visitStruct(ctx: StructContext): Expression = withOrigin(ctx) {
CreateStruct(ctx.argument.asScala.map(expression))
CreateStruct.create(ctx.argument.asScala.map(expression))
}

/**
Expand Down
Expand Up @@ -1306,7 +1306,7 @@ object functions {
* @since 1.4.0
*/
@scala.annotation.varargs
def struct(cols: Column*): Column = withExpr { CreateStruct(cols.map(_.expr)) }
def struct(cols: Column*): Column = withExpr { CreateStruct.create(cols.map(_.expr)) }

/**
* Creates a new struct column that composes multiple input columns.
Expand Down
Expand Up @@ -2,7 +2,7 @@
## Summary
- Number of queries: 337
- Number of expressions that missing example: 34
- Expressions missing examples: and,string,tinyint,double,smallint,date,decimal,boolean,float,binary,bigint,int,timestamp,cume_dist,dense_rank,input_file_block_length,input_file_block_start,input_file_name,lag,lead,monotonically_increasing_id,ntile,struct,!,not,or,percent_rank,rank,row_number,spark_partition_id,version,window,positive,count_min_sketch
- Expressions missing examples: and,string,tinyint,double,smallint,date,decimal,boolean,float,binary,bigint,int,timestamp,struct,cume_dist,dense_rank,input_file_block_length,input_file_block_start,input_file_name,lag,lead,monotonically_increasing_id,ntile,!,not,or,percent_rank,rank,row_number,spark_partition_id,version,window,positive,count_min_sketch
## Schema of Built-in Functions
| Class name | Function name or alias | Query example | Output schema |
| ---------- | ---------------------- | ------------- | ------------- |
Expand Down Expand Up @@ -79,6 +79,7 @@
| org.apache.spark.sql.catalyst.expressions.CreateArray | array | SELECT array(1, 2, 3) | struct<array(1, 2, 3):array<int>> |
| org.apache.spark.sql.catalyst.expressions.CreateMap | map | SELECT map(1.0, '2', 3.0, '4') | struct<map(1.0, 2, 3.0, 4):map<decimal(2,1),string>> |
| org.apache.spark.sql.catalyst.expressions.CreateNamedStruct | named_struct | SELECT named_struct("a", 1, "b", 2, "c", 3) | struct<named_struct(a, 1, b, 2, c, 3):struct<a:int,b:int,c:int>> |
| org.apache.spark.sql.catalyst.expressions.CreateNamedStruct | struct | N/A | N/A |
| org.apache.spark.sql.catalyst.expressions.CsvToStructs | from_csv | SELECT from_csv('1, 0.8', 'a INT, b DOUBLE') | struct<from_csv(1, 0.8):struct<a:int,b:double>> |
| org.apache.spark.sql.catalyst.expressions.Cube | cube | SELECT name, age, count(*) FROM VALUES (2, 'Alice'), (5, 'Bob') people(age, name) GROUP BY cube(name, age) | struct<name:string,age:int,count(1):bigint> |
| org.apache.spark.sql.catalyst.expressions.CumeDist | cume_dist | N/A | N/A |
Expand Down Expand Up @@ -171,7 +172,7 @@
| org.apache.spark.sql.catalyst.expressions.MapEntries | map_entries | SELECT map_entries(map(1, 'a', 2, 'b')) | struct<map_entries(map(1, a, 2, b)):array<struct<key:int,value:string>>> |
| org.apache.spark.sql.catalyst.expressions.MapFilter | map_filter | SELECT map_filter(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v) | struct<map_filter(map(1, 0, 2, 2, 3, -1), lambdafunction((namedlambdavariable() > namedlambdavariable()), namedlambdavariable(), namedlambdavariable())):map<int,int>> |
| org.apache.spark.sql.catalyst.expressions.MapFromArrays | map_from_arrays | SELECT map_from_arrays(array(1.0, 3.0), array('2', '4')) | struct<map_from_arrays(array(1.0, 3.0), array(2, 4)):map<decimal(2,1),string>> |
| org.apache.spark.sql.catalyst.expressions.MapFromEntries | map_from_entries | SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))) | struct<map_from_entries(array(named_struct(col1, 1, col2, a), named_struct(col1, 2, col2, b))):map<int,string>> |
| org.apache.spark.sql.catalyst.expressions.MapFromEntries | map_from_entries | SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))) | struct<map_from_entries(array(struct(1, a), struct(2, b))):map<int,string>> |
| org.apache.spark.sql.catalyst.expressions.MapKeys | map_keys | SELECT map_keys(map(1, 'a', 2, 'b')) | struct<map_keys(map(1, a, 2, b)):array<int>> |
| org.apache.spark.sql.catalyst.expressions.MapValues | map_values | SELECT map_values(map(1, 'a', 2, 'b')) | struct<map_values(map(1, a, 2, b)):array<string>> |
| org.apache.spark.sql.catalyst.expressions.MapZipWith | map_zip_with | SELECT map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)) | struct<map_zip_with(map(1, a, 2, b), map(1, x, 2, y), lambdafunction(concat(namedlambdavariable(), namedlambdavariable()), namedlambdavariable(), namedlambdavariable(), namedlambdavariable())):map<int,string>> |
Expand All @@ -186,7 +187,6 @@
| org.apache.spark.sql.catalyst.expressions.Murmur3Hash | hash | SELECT hash('Spark', array(123), 2) | struct<hash(Spark, array(123), 2):int> |
| org.apache.spark.sql.catalyst.expressions.NTile | ntile | N/A | N/A |
| org.apache.spark.sql.catalyst.expressions.NaNvl | nanvl | SELECT nanvl(cast('NaN' as double), 123) | struct<nanvl(CAST(NaN AS DOUBLE), CAST(123 AS DOUBLE)):double> |
| org.apache.spark.sql.catalyst.expressions.NamedStruct | struct | N/A | N/A |
| org.apache.spark.sql.catalyst.expressions.NextDay | next_day | SELECT next_day('2015-01-14', 'TU') | struct<next_day(CAST(2015-01-14 AS DATE), TU):date> |
| org.apache.spark.sql.catalyst.expressions.Not | ! | N/A | N/A |
| org.apache.spark.sql.catalyst.expressions.Not | not | N/A | N/A |
Expand Down
Expand Up @@ -272,7 +272,7 @@ struct<foo:string,approx_count_distinct(a) FILTER (WHERE (b >= 0)):bigint>
-- !query
SELECT 'foo', MAX(STRUCT(a)) FILTER (WHERE b >= 1) FROM testData WHERE a = 0 GROUP BY 1
-- !query schema
struct<foo:string,max(named_struct(a, a)) FILTER (WHERE (b >= 1)):struct<a:int>>
struct<foo:string,max(struct(a)) FILTER (WHERE (b >= 1)):struct<a:int>>
-- !query output


Expand Down
Expand Up @@ -87,7 +87,7 @@ struct<foo:string,approx_count_distinct(a):bigint>
-- !query
SELECT 'foo', MAX(STRUCT(a)) FROM testData WHERE a = 0 GROUP BY 1
-- !query schema
struct<foo:string,max(named_struct(a, a)):struct<a:int>>
struct<foo:string,max(struct(a)):struct<a:int>>
-- !query output


Expand Down
Expand Up @@ -83,7 +83,7 @@ struct<ID:int,NST:string>
-- !query
SELECT ID, STRUCT(ST.C as STC, ST.D as STD).STD FROM tbl_x
-- !query schema
struct<ID:int,named_struct(STC, ST.C AS `C` AS `STC`, STD, ST.D AS `D` AS `STD`).STD:string>
struct<ID:int,struct(ST.C AS `C` AS `STC`, ST.D AS `D` AS `STD`).STD:string>
-- !query output
1 delta
2 eta
Expand Down
Expand Up @@ -85,7 +85,7 @@ FROM various_maps
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7
cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(struct(k, v1, v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7


-- !query
Expand Down Expand Up @@ -113,7 +113,7 @@ FROM various_maps
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7
cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(struct(k, v1, v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7


-- !query
Expand Down
Expand Up @@ -87,7 +87,7 @@ struct<foo:string,CAST(udf(cast(approx_count_distinct(cast(udf(cast(a as string)
-- !query
SELECT 'foo', MAX(STRUCT(udf(a))) FROM testData WHERE a = 0 GROUP BY udf(1)
-- !query schema
struct<foo:string,max(named_struct(col1, CAST(udf(cast(a as string)) AS INT))):struct<col1:int>>
struct<foo:string,max(struct(CAST(udf(cast(a as string)) AS INT))):struct<col1:int>>
-- !query output


Expand Down

0 comments on commit df2a1fe

Please sign in to comment.