[SPARK-12799] Simplify various string output for expressions
This PR introduces several major changes:

1. Replacing `Expression.prettyString` with `Expression.sql`

   The `prettyString` method is mostly an internal, developer-facing facility for debugging purposes and shouldn't be exposed to users.

2. Using a SQL-like representation as the column name for selected fields that are not named expressions (back-ticks and double quotes should be removed)

   Previously, we used `prettyString` as the column name when possible, and the resulting column names could sometimes be weird. Here are several examples:

   Expression         | `prettyString` | `sql`      | Note
   ------------------ | -------------- | ---------- | ---------------
   `a && b`           | `a && b`       | `a AND b`  |
   `a.getField("f")`  | `a[f]`         | `a.f`      | `a` is a struct

3. Adding trait `NonSQLExpression`, which extends `Expression`, for expressions that don't have a SQL representation (e.g. Scala UDFs/UDAFs and Java/Scala object expressions used for encoders)

   `NonSQLExpression.sql` may return an arbitrary user-facing string representation of the expression (a minimal illustrative sketch follows this list).
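
As a rough illustration only, here is a minimal, self-contained Scala sketch of the naming scheme described above. It is a toy hierarchy with invented names such as `Attr`, `And`, and `UDFExpr`, not the actual Catalyst classes; the point is the shape of the API, where every expression renders itself through `sql` and expressions without a SQL form fall back to a readable, non-parseable description.

```scala
// Toy sketch, not Spark code: mirrors the idea that every expression has a `sql`
// rendering, and that "non-SQL" expressions substitute a user-facing description.
sealed trait Expr {
  def sql: String
}

// A column reference; quoted with back-ticks, similar in spirit to quoteIdentifier.
final case class Attr(name: String) extends Expr {
  def sql: String = s"`$name`"
}

// Boolean conjunction renders as SQL `AND` rather than the Scala-ish `&&`.
final case class And(left: Expr, right: Expr) extends Expr {
  def sql: String = s"(${left.sql} AND ${right.sql})"
}

// Struct field access renders with dot syntax instead of `a[f]`.
final case class GetStructField(child: Expr, field: String) extends Expr {
  def sql: String = s"${child.sql}.`$field`"
}

// Expressions with no SQL form (e.g. a Scala UDF) fall back to a description.
trait NonSQL extends Expr {
  def describe: String
  override def sql: String = describe
}

final case class UDFExpr(name: String, children: Seq[Expr]) extends NonSQL {
  def describe: String = s"UDF:$name(${children.map(_.sql).mkString(", ")})"
}

object SqlRenderingDemo extends App {
  println(And(Attr("a"), Attr("b")).sql)               // (`a` AND `b`)
  println(GetStructField(Attr("a"), "f").sql)          // `a`.`f`
  println(UDFExpr("f", Seq(Attr("x"), Attr("y"))).sql) // UDF:f(`x`, `y`)
}
```

Per item 2 above, column names derived from such expressions additionally drop the back-ticks and double quotes.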

Author: Cheng Lian <lian@databricks.com>

Closes #10757 from liancheng/spark-12799.simplify-expression-string-methods.
liancheng committed Feb 21, 2016
1 parent d806ed3 commit d9efe63
Showing 49 changed files with 402 additions and 278 deletions.
4 changes: 2 additions & 2 deletions R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1047,13 +1047,13 @@ test_that("column functions", {
schema = c("a", "b", "c"))
result <- collect(select(df, struct("a", "c")))
expected <- data.frame(row.names = 1:2)
expected$"struct(a,c)" <- list(listToStruct(list(a = 1L, c = 3L)),
expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)),
listToStruct(list(a = 4L, c = 6L)))
expect_equal(result, expected)

result <- collect(select(df, struct(df$a, df$b)))
expected <- data.frame(row.names = 1:2)
expected$"struct(a,b)" <- list(listToStruct(list(a = 1L, b = 2L)),
expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)),
listToStruct(list(a = 4L, b = 5L)))
expect_equal(result, expected)

32 changes: 16 additions & 16 deletions python/pyspark/sql/column.py
@@ -219,17 +219,17 @@ def getField(self, name):
>>> from pyspark.sql import Row
>>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
>>> df.select(df.r.getField("b")).show()
+----+
|r[b]|
+----+
| b|
+----+
+---+
|r.b|
+---+
| b|
+---+
>>> df.select(df.r.a).show()
+----+
|r[a]|
+----+
| 1|
+----+
+---+
|r.a|
+---+
| 1|
+---+
"""
return self[name]

@@ -346,12 +346,12 @@ def between(self, lowerBound, upperBound):
expression is between the given columns.
>>> df.select(df.name, df.age.between(2, 4)).show()
+-----+--------------------------+
| name|((age >= 2) && (age <= 4))|
+-----+--------------------------+
|Alice| true|
| Bob| false|
+-----+--------------------------+
+-----+---------------------------+
| name|((age >= 2) AND (age <= 4))|
+-----+---------------------------+
|Alice| true|
| Bob| false|
+-----+---------------------------+
"""
return (self >= lowerBound) & (self <= upperBound)

10 changes: 5 additions & 5 deletions python/pyspark/sql/context.py
@@ -92,8 +92,8 @@ def __init__(self, sparkContext, sqlContext=None):
>>> df.registerTempTable("allTypes")
>>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
[Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \
time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
[Row((i + CAST(1 AS BIGINT))=2, (d + CAST(1 AS DOUBLE))=2.0, (NOT b)=False, list[1]=2, \
dict[s]=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
>>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
[(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
"""
@@ -210,17 +210,17 @@ def registerFunction(self, name, f, returnType=StringType()):
>>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
>>> sqlContext.sql("SELECT stringLengthString('test')").collect()
[Row(_c0=u'4')]
[Row(stringLengthString(test)=u'4')]
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(_c0=4)]
[Row(stringLengthInt(test)=4)]
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(_c0=4)]
[Row(stringLengthInt(test)=4)]
"""
udf = UserDefinedFunction(f, returnType, name)
self._ssql_ctx.udf().registerPython(name, udf._judf)
30 changes: 15 additions & 15 deletions python/pyspark/sql/functions.py
@@ -223,22 +223,22 @@ def coalesce(*cols):
+----+----+
>>> cDf.select(coalesce(cDf["a"], cDf["b"])).show()
+-------------+
|coalesce(a,b)|
+-------------+
| null|
| 1|
| 2|
+-------------+
+--------------+
|coalesce(a, b)|
+--------------+
| null|
| 1|
| 2|
+--------------+
>>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show()
+----+----+---------------+
| a| b|coalesce(a,0.0)|
+----+----+---------------+
|null|null| 0.0|
| 1|null| 1.0|
|null| 2| 0.0|
+----+----+---------------+
+----+----+----------------+
| a| b|coalesce(a, 0.0)|
+----+----+----------------+
|null|null| 0.0|
| 1|null| 1.0|
|null| 2| 0.0|
+----+----+----------------+
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.coalesce(_to_seq(sc, cols, _to_java_column))
@@ -1528,7 +1528,7 @@ def array_contains(col, value):
>>> df = sqlContext.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
>>> df.select(array_contains(df.data, "a")).collect()
[Row(array_contains(data,a)=True), Row(array_contains(data,a)=False)]
[Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.catalyst.util.usePrettyExpression
import org.apache.spark.sql.types._

/**
@@ -165,7 +166,8 @@ class Analyzer(
case e if !e.resolved => u
case g: Generator => MultiAlias(g, Nil)
case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)()
case other => Alias(other, optionalAliasName.getOrElse(s"_c$i"))()
case e: ExtractValue => Alias(e, usePrettyExpression(e).sql)()
case e => Alias(e, optionalAliasName.getOrElse(usePrettyExpression(e).sql))()
}
}
}.asInstanceOf[Seq[NamedExpression]]
@@ -328,7 +330,7 @@ class Analyzer(
throw new AnalysisException(
s"Aggregate expression required for pivot, found '$aggregate'")
}
val name = if (singleAgg) value.toString else value + "_" + aggregate.prettyString
val name = if (singleAgg) value.toString else value + "_" + aggregate.sql
Alias(filteredAggregate, name)()
}
}
@@ -1456,7 +1458,7 @@ object CleanupAliases extends Rule[LogicalPlan] {
*/
object ResolveUpCast extends Rule[LogicalPlan] {
private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
throw new AnalysisException(s"Cannot up cast `${from.prettyString}` from " +
throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
"The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
"You can either add an explicit cast to the input data or choose a higher precision " +
@@ -57,13 +57,13 @@ trait CheckAnalysis {
operator transformExpressionsUp {
case a: Attribute if !a.resolved =>
val from = operator.inputSet.map(_.name).mkString(", ")
a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns: [$from]")
a.failAnalysis(s"cannot resolve '${a.sql}' given input columns: [$from]")

case e: Expression if e.checkInputDataTypes().isFailure =>
e.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckFailure(message) =>
e.failAnalysis(
s"cannot resolve '${e.prettyString}' due to data type mismatch: $message")
s"cannot resolve '${e.sql}' due to data type mismatch: $message")
}

case c: Cast if !c.resolved =>
@@ -106,23 +106,23 @@
operator match {
case f: Filter if f.condition.dataType != BooleanType =>
failAnalysis(
s"filter expression '${f.condition.prettyString}' " +
s"filter expression '${f.condition.sql}' " +
s"of type ${f.condition.dataType.simpleString} is not a boolean.")

case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType =>
failAnalysis(
s"join condition '${condition.prettyString}' " +
s"join condition '${condition.sql}' " +
s"of type ${condition.dataType.simpleString} is not a boolean.")

case j @ Join(_, _, _, Some(condition)) =>
def checkValidJoinConditionExprs(expr: Expression): Unit = expr match {
case p: Predicate =>
p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs)
case e if e.dataType.isInstanceOf[BinaryType] =>
failAnalysis(s"binary type expression ${e.prettyString} cannot be used " +
failAnalysis(s"binary type expression ${e.sql} cannot be used " +
"in join conditions")
case e if e.dataType.isInstanceOf[MapType] =>
failAnalysis(s"map type expression ${e.prettyString} cannot be used " +
failAnalysis(s"map type expression ${e.sql} cannot be used " +
"in join conditions")
case _ => // OK
}
@@ -144,13 +144,13 @@

if (!child.deterministic) {
failAnalysis(
s"nondeterministic expression ${expr.prettyString} should not " +
s"nondeterministic expression ${expr.sql} should not " +
s"appear in the arguments of an aggregate function.")
}
}
case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
failAnalysis(
s"expression '${e.prettyString}' is neither present in the group by, " +
s"expression '${e.sql}' is neither present in the group by, " +
s"nor is it an aggregate function. " +
"Add to group by or wrap in first() (or first_value) if you don't care " +
"which value you get.")
@@ -163,7 +163,7 @@
// Check if the data type of expr is orderable.
if (!RowOrdering.isOrderable(expr.dataType)) {
failAnalysis(
s"expression ${expr.prettyString} cannot be used as a grouping expression " +
s"expression ${expr.sql} cannot be used as a grouping expression " +
s"because its data type ${expr.dataType.simpleString} is not a orderable " +
s"data type.")
}
@@ -172,7 +172,7 @@
// This is just a sanity check, our analysis rule PullOutNondeterministic should
// already pull out those nondeterministic expressions and evaluate them in
// a Project node.
failAnalysis(s"nondeterministic expression ${expr.prettyString} should not " +
failAnalysis(s"nondeterministic expression ${expr.sql} should not " +
s"appear in grouping expression.")
}
}
@@ -217,7 +217,7 @@
case p @ Project(exprs, _) if containsMultipleGenerators(exprs) =>
failAnalysis(
s"""Only a single table generating function is allowed in a SELECT clause, found:
| ${exprs.map(_.prettyString).mkString(",")}""".stripMargin)
| ${exprs.map(_.sql).mkString(",")}""".stripMargin)

case j: Join if !j.duplicateResolved =>
val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet)
@@ -248,9 +248,9 @@
failAnalysis(
s"""nondeterministic expressions are only allowed in
|Project, Filter, Aggregate or Window, found:
| ${o.expressions.map(_.prettyString).mkString(",")}
| ${o.expressions.map(_.sql).mkString(",")}
|in operator ${operator.simpleString}
""".stripMargin)
""".stripMargin)

case _ => // Analysis successful!
}
@@ -130,7 +130,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
new AttributeReference("gid", IntegerType, false)(isGenerated = true)
val groupByMap = a.groupingExpressions.collect {
case ne: NamedExpression => ne -> ne.toAttribute
case e => e -> new AttributeReference(e.prettyString, e.dataType, e.nullable)()
case e => e -> new AttributeReference(e.sql, e.dataType, e.nullable)()
}
val groupByAttrs = groupByMap.map(_._2)

@@ -184,7 +184,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
val regularAggOperatorMap = regularAggExprs.map { e =>
// Perform the actual aggregation in the initial aggregate.
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)()
val operator = Alias(e.copy(aggregateFunction = af), e.sql)()

// Select the result of the first aggregate in the last aggregate.
val result = AggregateExpression(
@@ -269,5 +269,5 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
// NamedExpression. This is done to prevent collisions between distinct and regular aggregate
// children, in this case attribute reuse causes the input of the regular aggregate to bound to
// the (nulled out) input of the distinct aggregate.
e -> new AttributeReference(e.prettyString, e.dataType, true)()
e -> new AttributeReference(e.sql, e.dataType, true)()
}
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util.quoteIdentifier
import org.apache.spark.sql.types.{DataType, StructType}

/**
@@ -67,6 +68,8 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un
override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)

override def toString: String = s"'$name"

override def sql: String = quoteIdentifier(name)
}

object UnresolvedAttribute {
@@ -141,11 +144,8 @@ case class UnresolvedFunction(
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false

override def prettyString: String = {
s"${name}(${children.map(_.prettyString).mkString(",")})"
}

override def toString: String = s"'$name(${children.mkString(",")})"
override def prettyName: String = name
override def toString: String = s"'$name(${children.mkString(", ")})"
}

/**
@@ -208,10 +208,9 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu
Alias(extract, f.name)()
}

case _ => {
case _ =>
throw new AnalysisException("Can only star expand struct data types. Attribute: `" +
target.get + "`")
}
}
} else {
val from = input.inputSet.map(_.name).mkString(", ")
@@ -228,6 +227,7 @@
* For example the SQL expression "stack(2, key, value, key, value) as (a, b)" could be represented
* as follows:
* MultiAlias(stack_function, Seq(a, b))
*
* @param child the computation being performed
* @param names the names to be associated with each output of computing [[child]].
@@ -284,13 +284,14 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression)
override lazy val resolved = false

override def toString: String = s"$child[$extraction]"
override def sql: String = s"${child.sql}[${extraction.sql}]"
}

/**
* Holds the expression that has yet to be aliased.
*
* @param child The computation that is needs to be resolved during analysis.
* @param aliasName The name if specified to be asoosicated with the result of computing [[child]]
* @param aliasName The name if specified to be associated with the result of computing [[child]]
*
*/
case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
@@ -45,7 +45,7 @@ trait ExpectsInputTypes extends Expression {
val mismatches = children.zip(inputTypes).zipWithIndex.collect {
case ((child, expected), idx) if !expected.acceptsType(child.dataType) =>
s"argument ${idx + 1} requires ${expected.simpleString} type, " +
s"however, '${child.prettyString}' is of ${child.dataType.simpleString} type."
s"however, '${child.sql}' is of ${child.dataType.simpleString} type."
}

if (mismatches.isEmpty) {