[SPARK-31615][SQL] Pretty string output for sql method of RuntimeReplaceable expressions

### What changes were proposed in this pull request?

Because `RuntimeReplaceable` expressions are replaced at runtime, their original parameters are never resolved to `PrettyAttribute`; if we implement their `sql` methods directly on top of their parameters' `sql` methods, the output remains a debug-style string.

This PR follows suggestions from maropu and cloud-fan in https://github.com/apache/spark/pull/28402/files#r417656589. It re-implements the `sql` methods of the `RuntimeReplaceable` expressions in terms of `toPrettySQL`.
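
To illustrate the mechanism, here is a minimal, self-contained sketch (using simplified stand-ins for Spark's `Expression` API, not the real classes — the actual trait change is in the diff below): a runtime-replaceable expression exposes its original arguments via `exprsReplaced`, and a default `mkString` assembles the pretty `prettyName(arg1, arg2, ...)` form.

```scala
// Simplified stand-ins for Spark's Expression API (illustration only).
trait Expression { def sql: String }

trait RuntimeReplaceable extends Expression {
  def prettyName: String
  // Implementations override this with their original, user-facing arguments.
  def exprsReplaced: Seq[Expression]
  // Default pretty form: prettyName(arg1, arg2, ...); expressions with special
  // syntax (e.g. EXTRACT(x FROM y)) override mkString instead.
  def mkString(childrenString: Seq[String]): String =
    prettyName + childrenString.mkString("(", ", ", ")")
  override def sql: String = mkString(exprsReplaced.map(_.sql))
}

// A literal whose pretty SQL is just its value.
case class PrettyLiteral(value: String) extends Expression {
  def sql: String = value
}

case class ToTimestamp(left: Expression, format: Option[Expression])
    extends RuntimeReplaceable {
  def prettyName: String = "to_timestamp"
  def exprsReplaced: Seq[Expression] = left +: format.toSeq
}

object Demo extends App {
  // Prints: to_timestamp(2016-12-31 00:12:00)
  println(ToTimestamp(PrettyLiteral("2016-12-31 00:12:00"), None).sql)
}
```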

### Why are the changes needed?

Consistency of schema output between RuntimeReplaceable expressions and normal ones.

For example, compare `date_format` with `to_timestamp`: before this PR, their schema strings were formatted inconsistently.

#### Before
```sql
select date_format(timestamp '2019-10-06', 'yyyy-MM-dd uuuu')
struct<date_format(TIMESTAMP '2019-10-06 00:00:00', yyyy-MM-dd uuuu):string>

select to_timestamp("2019-10-06S10:11:12.12345", "yyyy-MM-dd'S'HH:mm:ss.SSSSSS")
struct<to_timestamp('2019-10-06S10:11:12.12345', 'yyyy-MM-dd\'S\'HH:mm:ss.SSSSSS'):timestamp>
```
#### After

```sql
select date_format(timestamp '2019-10-06', 'yyyy-MM-dd uuuu')
struct<date_format(TIMESTAMP '2019-10-06 00:00:00', yyyy-MM-dd uuuu):string>

select to_timestamp("2019-10-06T10:11:12'12", "yyyy-MM-dd'T'HH:mm:ss''SSSS")
struct<to_timestamp(2019-10-06T10:11:12'12, yyyy-MM-dd'T'HH:mm:ss''SSSS):timestamp>
```
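
The `struct<...>` line is the schema Spark prints for the query; the same pretty string also becomes the result column name. As a quick sanity check (a hypothetical spark-shell snippet, not part of this commit):

```scala
// Hypothetical spark-shell check: the pretty SQL string is the output column name.
spark.sql("select to_timestamp('2016-12-31 00:12:00')").columns
// After this change: Array(to_timestamp(2016-12-31 00:12:00))
```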

### Does this PR introduce _any_ user-facing change?

Yes, the schema output style changes for runtime-replaceable expressions, as shown in the example above.

### How was this patch tested?
Regenerated all related test output.

Closes #28420 from yaooqinn/SPARK-31615.

Authored-by: Kent Yao <yaooqinn@hotmail.com>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
yaooqinn authored and maropu committed May 7, 2020
1 parent bd6b53c commit b31ae7b
Showing 23 changed files with 292 additions and 244 deletions.
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/test_context.py
```diff
@@ -228,7 +228,7 @@ def test_datetime_functions(self):
         from datetime import date
         df = self.spark.range(1).selectExpr("'2017-01-22' as dateCol")
         parse_result = df.select(functions.to_date(functions.col("dateCol"))).first()
-        self.assertEquals(date(2017, 1, 22), parse_result['to_date(`dateCol`)'])
+        self.assertEquals(date(2017, 1, 22), parse_result['to_date(dateCol)'])

     def test_unbounded_frames(self):
         from pyspark.sql import functions as F
```
```diff
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -323,6 +324,19 @@ trait RuntimeReplaceable extends UnaryExpression with Unevaluable {
   // two `RuntimeReplaceable` are considered to be semantically equal if their "child" expressions
   // are semantically equal.
   override lazy val canonicalized: Expression = child.canonicalized
+
+  /**
+   * Only used to generate SQL representation of this expression.
+   *
+   * Implementations should override this with original parameters
+   */
+  def exprsReplaced: Seq[Expression]
+
+  override def sql: String = mkString(exprsReplaced.map(_.sql))
+
+  def mkString(childrenString: Seq[String]): String = {
+    prettyName + childrenString.mkString("(", ", ", ")")
+  }
 }

 /**
```
```diff
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.util.{DateTimeUtils, LegacyDateFormats, Tim
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.catalyst.util.LegacyDateFormats.SIMPLE_DATE_FORMAT
+import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -1205,8 +1206,9 @@ case class DatetimeSub(
     start: Expression,
     interval: Expression,
     child: Expression) extends RuntimeReplaceable {
+  override def exprsReplaced: Seq[Expression] = Seq(start, interval)
   override def toString: String = s"$start - $interval"
-  override def sql: String = s"${start.sql} - ${interval.sql}"
+  override def mkString(childrenString: Seq[String]): String = childrenString.mkString(" - ")
 }

 /**
@@ -1553,14 +1555,8 @@ case class ParseToDate(left: Expression, format: Option[Expression], child: Expr
     this(left, None, Cast(left, DateType))
   }

+  override def exprsReplaced: Seq[Expression] = left +: format.toSeq
   override def flatArguments: Iterator[Any] = Iterator(left, format)
-  override def sql: String = {
-    if (format.isDefined) {
-      s"$prettyName(${left.sql}, ${format.get.sql})"
-    } else {
-      s"$prettyName(${left.sql})"
-    }
-  }

   override def prettyName: String = "to_date"
 }
@@ -1601,13 +1597,7 @@ case class ParseToTimestamp(left: Expression, format: Option[Expression], child:
   def this(left: Expression) = this(left, None, Cast(left, TimestampType))

   override def flatArguments: Iterator[Any] = Iterator(left, format)
-  override def sql: String = {
-    if (format.isDefined) {
-      s"$prettyName(${left.sql}, ${format.get.sql})"
-    } else {
-      s"$prettyName(${left.sql})"
-    }
-  }
+  override def exprsReplaced: Seq[Expression] = left +: format.toSeq

   override def prettyName: String = "to_timestamp"
   override def dataType: DataType = TimestampType
@@ -2161,7 +2151,8 @@ case class DatePart(field: Expression, source: Expression, child: Expression)
   }

   override def flatArguments: Iterator[Any] = Iterator(field, source)
-  override def sql: String = s"$prettyName(${field.sql}, ${source.sql})"
+  override def exprsReplaced: Seq[Expression] = Seq(field, source)
+
   override def prettyName: String = "date_part"
 }

@@ -2221,8 +2212,12 @@ case class Extract(field: Expression, source: Expression, child: Expression)
   }

   override def flatArguments: Iterator[Any] = Iterator(field, source)
-  override def sql: String = s"$prettyName(${field.sql} FROM ${source.sql})"
   override def prettyName: String = "extract"
+
+  override def exprsReplaced: Seq[Expression] = Seq(field, source)
+
+  override def mkString(childrenString: Seq[String]): String = {
+    prettyName + childrenString.mkString("(", " FROM ", ")")
+  }
 }

 /**
```
```diff
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.TypeUtils
@@ -138,7 +138,7 @@ case class IfNull(left: Expression, right: Expression, child: Expression)
   }

   override def flatArguments: Iterator[Any] = Iterator(left, right)
-  override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
+  override def exprsReplaced: Seq[Expression] = Seq(left, right)
 }

@@ -158,7 +158,7 @@ case class NullIf(left: Expression, right: Expression, child: Expression)
   }

   override def flatArguments: Iterator[Any] = Iterator(left, right)
-  override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
+  override def exprsReplaced: Seq[Expression] = Seq(left, right)
 }

@@ -177,7 +177,7 @@ case class Nvl(left: Expression, right: Expression, child: Expression) extends R
   }

   override def flatArguments: Iterator[Any] = Iterator(left, right)
-  override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
+  override def exprsReplaced: Seq[Expression] = Seq(left, right)
 }

@@ -199,7 +199,7 @@ case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression, child:
   }

   override def flatArguments: Iterator[Any] = Iterator(expr1, expr2, expr3)
-  override def sql: String = s"$prettyName(${expr1.sql}, ${expr2.sql}, ${expr3.sql})"
+  override def exprsReplaced: Seq[Expression] = Seq(expr1, expr2, expr3)
 }
```
```diff
@@ -1695,7 +1695,7 @@ case class Right(str: Expression, len: Expression, child: Expression) extends Ru
   }

   override def flatArguments: Iterator[Any] = Iterator(str, len)
-  override def sql: String = s"$prettyName(${str.sql}, ${len.sql})"
+  override def exprsReplaced: Seq[Expression] = Seq(str, len)
 }

 /**
@@ -1717,7 +1717,7 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run
   }

   override def flatArguments: Iterator[Any] = Iterator(str, len)
-  override def sql: String = s"$prettyName(${str.sql}, ${len.sql})"
+  override def exprsReplaced: Seq[Expression] = Seq(str, len)
 }

 /**
```
```diff
@@ -134,6 +134,8 @@ package object util extends Logging {
       PrettyAttribute(usePrettyExpression(e.child).sql + "." + name, e.dataType)
     case e: GetArrayStructFields =>
       PrettyAttribute(usePrettyExpression(e.child) + "." + e.field.name, e.dataType)
+    case r: RuntimeReplaceable =>
+      PrettyAttribute(r.mkString(r.exprsReplaced.map(toPrettySQL)), r.dataType)
   }

   def quoteIdentifier(name: String): String = {
```
```diff
@@ -88,7 +88,7 @@
 | org.apache.spark.sql.catalyst.expressions.DateAdd | date_add | SELECT date_add('2016-07-30', 1) | struct<date_add(CAST(2016-07-30 AS DATE), 1):date> |
 | org.apache.spark.sql.catalyst.expressions.DateDiff | datediff | SELECT datediff('2009-07-31', '2009-07-30') | struct<datediff(CAST(2009-07-31 AS DATE), CAST(2009-07-30 AS DATE)):int> |
 | org.apache.spark.sql.catalyst.expressions.DateFormatClass | date_format | SELECT date_format('2016-04-08', 'y') | struct<date_format(CAST(2016-04-08 AS TIMESTAMP), y):string> |
-| org.apache.spark.sql.catalyst.expressions.DatePart | date_part | SELECT date_part('YEAR', TIMESTAMP '2019-08-12 01:00:00.123456') | struct<date_part('YEAR', TIMESTAMP '2019-08-12 01:00:00.123456'):int> |
+| org.apache.spark.sql.catalyst.expressions.DatePart | date_part | SELECT date_part('YEAR', TIMESTAMP '2019-08-12 01:00:00.123456') | struct<date_part(YEAR, TIMESTAMP '2019-08-12 01:00:00.123456'):int> |
 | org.apache.spark.sql.catalyst.expressions.DateSub | date_sub | SELECT date_sub('2016-07-30', 1) | struct<date_sub(CAST(2016-07-30 AS DATE), 1):date> |
 | org.apache.spark.sql.catalyst.expressions.DayOfMonth | day | SELECT day('2009-07-30') | struct<day(CAST(2009-07-30 AS DATE)):int> |
 | org.apache.spark.sql.catalyst.expressions.DayOfMonth | dayofmonth | SELECT dayofmonth('2009-07-30') | struct<dayofmonth(CAST(2009-07-30 AS DATE)):int> |
@@ -108,7 +108,7 @@
 | org.apache.spark.sql.catalyst.expressions.Explode | explode | SELECT explode(array(10, 20)) | struct<col:int> |
 | org.apache.spark.sql.catalyst.expressions.Explode | explode_outer | SELECT explode_outer(array(10, 20)) | struct<col:int> |
 | org.apache.spark.sql.catalyst.expressions.Expm1 | expm1 | SELECT expm1(0) | struct<EXPM1(CAST(0 AS DOUBLE)):double> |
-| org.apache.spark.sql.catalyst.expressions.Extract | extract | SELECT extract(YEAR FROM TIMESTAMP '2019-08-12 01:00:00.123456') | struct<extract('YEAR' FROM TIMESTAMP '2019-08-12 01:00:00.123456'):int> |
+| org.apache.spark.sql.catalyst.expressions.Extract | extract | SELECT extract(YEAR FROM TIMESTAMP '2019-08-12 01:00:00.123456') | struct<extract(YEAR FROM TIMESTAMP '2019-08-12 01:00:00.123456'):int> |
 | org.apache.spark.sql.catalyst.expressions.Factorial | factorial | SELECT factorial(5) | struct<factorial(5):bigint> |
 | org.apache.spark.sql.catalyst.expressions.FindInSet | find_in_set | SELECT find_in_set('ab','abc,b,ab,c,def') | struct<find_in_set(ab, abc,b,ab,c,def):int> |
 | org.apache.spark.sql.catalyst.expressions.Flatten | flatten | SELECT flatten(array(array(1, 2), array(3, 4))) | struct<flatten(array(array(1, 2), array(3, 4))):array<int>> |
@@ -128,7 +128,7 @@
 | org.apache.spark.sql.catalyst.expressions.Hour | hour | SELECT hour('2009-07-30 12:58:59') | struct<hour(CAST(2009-07-30 12:58:59 AS TIMESTAMP)):int> |
 | org.apache.spark.sql.catalyst.expressions.Hypot | hypot | SELECT hypot(3, 4) | struct<HYPOT(CAST(3 AS DOUBLE), CAST(4 AS DOUBLE)):double> |
 | org.apache.spark.sql.catalyst.expressions.If | if | SELECT if(1 < 2, 'a', 'b') | struct<(IF((1 < 2), a, b)):string> |
-| org.apache.spark.sql.catalyst.expressions.IfNull | ifnull | SELECT ifnull(NULL, array('2')) | struct<ifnull(NULL, array('2')):array<string>> |
+| org.apache.spark.sql.catalyst.expressions.IfNull | ifnull | SELECT ifnull(NULL, array('2')) | struct<ifnull(NULL, array(2)):array<string>> |
 | org.apache.spark.sql.catalyst.expressions.In | in | SELECT 1 in(1, 2, 3) | struct<(1 IN (1, 2, 3)):boolean> |
 | org.apache.spark.sql.catalyst.expressions.InitCap | initcap | SELECT initcap('sPark sql') | struct<initcap(sPark sql):string> |
 | org.apache.spark.sql.catalyst.expressions.Inline | inline | SELECT inline(array(struct(1, 'a'), struct(2, 'b'))) | struct<col1:int,col2:string> |
@@ -147,7 +147,7 @@
 | org.apache.spark.sql.catalyst.expressions.LastDay | last_day | SELECT last_day('2009-01-12') | struct<last_day(CAST(2009-01-12 AS DATE)):date> |
 | org.apache.spark.sql.catalyst.expressions.Lead | lead | N/A | N/A |
 | org.apache.spark.sql.catalyst.expressions.Least | least | SELECT least(10, 9, 2, 4, 3) | struct<least(10, 9, 2, 4, 3):int> |
-| org.apache.spark.sql.catalyst.expressions.Left | left | SELECT left('Spark SQL', 3) | struct<left('Spark SQL', 3):string> |
+| org.apache.spark.sql.catalyst.expressions.Left | left | SELECT left('Spark SQL', 3) | struct<left(Spark SQL, 3):string> |
 | org.apache.spark.sql.catalyst.expressions.Length | character_length | SELECT character_length('Spark SQL ') | struct<character_length(Spark SQL ):int> |
 | org.apache.spark.sql.catalyst.expressions.Length | char_length | SELECT char_length('Spark SQL ') | struct<char_length(Spark SQL ):int> |
 | org.apache.spark.sql.catalyst.expressions.Length | length | SELECT length('Spark SQL ') | struct<length(Spark SQL ):int> |
@@ -189,13 +189,13 @@
 | org.apache.spark.sql.catalyst.expressions.Not | not | N/A | N/A |
 | org.apache.spark.sql.catalyst.expressions.Now | now | SELECT now() | struct<now():timestamp> |
 | org.apache.spark.sql.catalyst.expressions.NullIf | nullif | SELECT nullif(2, 2) | struct<nullif(2, 2):int> |
-| org.apache.spark.sql.catalyst.expressions.Nvl | nvl | SELECT nvl(NULL, array('2')) | struct<nvl(NULL, array('2')):array<string>> |
+| org.apache.spark.sql.catalyst.expressions.Nvl | nvl | SELECT nvl(NULL, array('2')) | struct<nvl(NULL, array(2)):array<string>> |
 | org.apache.spark.sql.catalyst.expressions.Nvl2 | nvl2 | SELECT nvl2(NULL, 2, 1) | struct<nvl2(NULL, 2, 1):int> |
 | org.apache.spark.sql.catalyst.expressions.OctetLength | octet_length | SELECT octet_length('Spark SQL') | struct<octet_length(Spark SQL):int> |
 | org.apache.spark.sql.catalyst.expressions.Or | or | N/A | N/A |
 | org.apache.spark.sql.catalyst.expressions.Overlay | overlay | SELECT overlay('Spark SQL' PLACING '_' FROM 6) | struct<overlay(Spark SQL, _, 6, -1):string> |
-| org.apache.spark.sql.catalyst.expressions.ParseToDate | to_date | SELECT to_date('2009-07-30 04:17:52') | struct<to_date('2009-07-30 04:17:52'):date> |
-| org.apache.spark.sql.catalyst.expressions.ParseToTimestamp | to_timestamp | SELECT to_timestamp('2016-12-31 00:12:00') | struct<to_timestamp('2016-12-31 00:12:00'):timestamp> |
+| org.apache.spark.sql.catalyst.expressions.ParseToDate | to_date | SELECT to_date('2009-07-30 04:17:52') | struct<to_date(2009-07-30 04:17:52):date> |
+| org.apache.spark.sql.catalyst.expressions.ParseToTimestamp | to_timestamp | SELECT to_timestamp('2016-12-31 00:12:00') | struct<to_timestamp(2016-12-31 00:12:00):timestamp> |
 | org.apache.spark.sql.catalyst.expressions.ParseUrl | parse_url | SELECT parse_url('http://spark.apache.org/path?query=1', 'HOST') | struct<parse_url(http://spark.apache.org/path?query=1, HOST):string> |
 | org.apache.spark.sql.catalyst.expressions.PercentRank | percent_rank | N/A | N/A |
 | org.apache.spark.sql.catalyst.expressions.Pi | pi | SELECT pi() | struct<PI():double> |
@@ -215,7 +215,7 @@
 | org.apache.spark.sql.catalyst.expressions.Remainder | % | SELECT 2 % 1.8 | struct<(CAST(CAST(2 AS DECIMAL(1,0)) AS DECIMAL(2,1)) % CAST(1.8 AS DECIMAL(2,1))):decimal(2,1)> |
 | org.apache.spark.sql.catalyst.expressions.Remainder | mod | SELECT 2 % 1.8 | struct<(CAST(CAST(2 AS DECIMAL(1,0)) AS DECIMAL(2,1)) % CAST(1.8 AS DECIMAL(2,1))):decimal(2,1)> |
 | org.apache.spark.sql.catalyst.expressions.Reverse | reverse | SELECT reverse('Spark SQL') | struct<reverse(Spark SQL):string> |
-| org.apache.spark.sql.catalyst.expressions.Right | right | SELECT right('Spark SQL', 3) | struct<right('Spark SQL', 3):string> |
+| org.apache.spark.sql.catalyst.expressions.Right | right | SELECT right('Spark SQL', 3) | struct<right(Spark SQL, 3):string> |
 | org.apache.spark.sql.catalyst.expressions.Rint | rint | SELECT rint(12.3456) | struct<ROUND(CAST(12.3456 AS DOUBLE)):double> |
 | org.apache.spark.sql.catalyst.expressions.Rollup | rollup | SELECT name, age, count(*) FROM VALUES (2, 'Alice'), (5, 'Bob') people(age, name) GROUP BY rollup(name, age) | struct<name:string,age:int,count(1):bigint> |
 | org.apache.spark.sql.catalyst.expressions.Round | round | SELECT round(2.5, 0) | struct<round(2.5, 0):decimal(2,0)> |
```
5 changes: 5 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/extract.sql
```diff
@@ -123,3 +123,8 @@ select extract('doy', c) from t;
 select extract('hour', c) from t;
 select extract('minute', c) from t;
 select extract('second', c) from t;
+
+select c - i from t;
+select year(c - i) from t;
+select extract(year from c - i) from t;
+select extract(month from to_timestamp(c) - i) from t;
```