Skip to content

Commit

Permalink
[SPARK-8104] [SQL] auto alias expressions in analyzer
Browse files Browse the repository at this point in the history
Currently we auto alias expression in parser. However, during parser phase we don't have enough information to do the right alias. For example, Generator that has more than 1 kind of element need MultiAlias, ExtractValue don't need Alias if it's in middle of a ExtractValue chain.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes apache#6647 from cloud-fan/alias and squashes the following commits:

552eba4 [Wenchen Fan] fix python
5b5786d [Wenchen Fan] fix agg
73a90cb [Wenchen Fan] fix case-preserve of ExtractValue
4cfd23c [Wenchen Fan] fix order by
d18f401 [Wenchen Fan] refine
9f07359 [Wenchen Fan] address comments
39c1aef [Wenchen Fan] small fix
33640ec [Wenchen Fan] auto alias expressions in analyzer
  • Loading branch information
cloud-fan authored and animesh committed Jun 25, 2015
1 parent 54ab140 commit 75cf091
Show file tree
Hide file tree
Showing 16 changed files with 150 additions and 117 deletions.
9 changes: 5 additions & 4 deletions python/pyspark/sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ def __init__(self, sparkContext, sqlContext=None):
>>> df.registerTempTable("allTypes")
>>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
[Row(c0=2, c1=2.0, c2=False, c3=2, c4=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
[Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \
time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
>>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
[(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
"""
Expand Down Expand Up @@ -176,17 +177,17 @@ def registerFunction(self, name, f, returnType=StringType()):
>>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
>>> sqlContext.sql("SELECT stringLengthString('test')").collect()
[Row(c0=u'4')]
[Row(_c0=u'4')]
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(c0=4)]
[Row(_c0=4)]
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(c0=4)]
[Row(_c0=4)]
"""
func = lambda _, it: map(lambda x: f(*x), it)
ser = AutoBatchedSerializer(PickleSerializer())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,6 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected val WHERE = Keyword("WHERE")
protected val WITH = Keyword("WITH")

protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = {
exprs.zipWithIndex.map {
case (ne: NamedExpression, _) => ne
case (e, i) => Alias(e, s"c$i")()
}
}

protected lazy val start: Parser[LogicalPlan] =
start1 | insert | cte

Expand All @@ -130,8 +123,8 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
val base = r.getOrElse(OneRowRelation)
val withFilter = f.map(Filter(_, base)).getOrElse(base)
val withProjection = g
.map(Aggregate(_, assignAliases(p), withFilter))
.getOrElse(Project(assignAliases(p), withFilter))
.map(Aggregate(_, p.map(UnresolvedAlias(_)), withFilter))
.getOrElse(Project(p.map(UnresolvedAlias(_)), withFilter))
val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct)
val withOrder = o.map(_(withHaving)).getOrElse(withHaving)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.sql.catalyst.analysis

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -74,10 +72,10 @@ class Analyzer(
ResolveSortReferences ::
ResolveGenerate ::
ResolveFunctions ::
ResolveAliases ::
ExtractWindowExpressions ::
GlobalAggregates ::
UnresolvedHavingClauseAttributes ::
TrimGroupingAliases ::
typeCoercionRules ++
extendedResolutionRules : _*)
)
Expand Down Expand Up @@ -132,12 +130,38 @@ class Analyzer(
}

/**
* Removes no-op Alias expressions from the plan.
* Replaces [[UnresolvedAlias]]s with concrete aliases.
*/
object TrimGroupingAliases extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Aggregate(groups, aggs, child) =>
Aggregate(groups.map(_.transform { case Alias(c, _) => c }), aggs, child)
object ResolveAliases extends Rule[LogicalPlan] {
private def assignAliases(exprs: Seq[NamedExpression]) = {
// The `UnresolvedAlias`s will appear only at root of a expression tree, we don't need
// to transform down the whole tree.
exprs.zipWithIndex.map {
case (u @ UnresolvedAlias(child), i) =>
child match {
case _: UnresolvedAttribute => u
case ne: NamedExpression => ne
case ev: ExtractValueWithStruct => Alias(ev, ev.field.name)()
case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil)
case e if !e.resolved => u
case other => Alias(other, s"_c$i")()
}
case (other, _) => other
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case Aggregate(groups, aggs, child)
if child.resolved && aggs.exists(_.isInstanceOf[UnresolvedAlias]) =>
Aggregate(groups, assignAliases(aggs), child)

case g: GroupingAnalytics
if g.child.resolved && g.aggregations.exists(_.isInstanceOf[UnresolvedAlias]) =>
g.withNewAggs(assignAliases(g.aggregations))

case Project(projectList, child)
if child.resolved && projectList.exists(_.isInstanceOf[UnresolvedAlias]) =>
Project(assignAliases(projectList), child)
}
}

Expand Down Expand Up @@ -228,7 +252,7 @@ class Analyzer(
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case i@InsertIntoTable(u: UnresolvedRelation, _, _, _, _) =>
case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _, _) =>
i.copy(table = EliminateSubQueries(getTable(u)))
case u: UnresolvedRelation =>
getTable(u)
Expand All @@ -248,24 +272,24 @@ class Analyzer(
Project(
projectList.flatMap {
case s: Star => s.expand(child.output, resolver)
case Alias(f @ UnresolvedFunction(_, args), name) if containsStar(args) =>
case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
Alias(child = f.copy(children = expandedArgs), name)() :: Nil
case Alias(c @ CreateArray(args), name) if containsStar(args) =>
UnresolvedAlias(child = f.copy(children = expandedArgs)) :: Nil
case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
Alias(c.copy(children = expandedArgs), name)() :: Nil
case Alias(c @ CreateStruct(args), name) if containsStar(args) =>
UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
Alias(c.copy(children = expandedArgs), name)() :: Nil
UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
case o => o :: Nil
},
child)
Expand Down Expand Up @@ -353,7 +377,9 @@ class Analyzer(
case u @ UnresolvedAttribute(nameParts) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result =
withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
withPosition(u) {
q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
}
logDebug(s"Resolving $u to $result")
result
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
Expand All @@ -379,6 +405,11 @@ class Analyzer(
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
}

private def trimUnresolvedAlias(ne: NamedExpression) = ne match {
case UnresolvedAlias(child) => child
case other => other
}

private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
ordering.map { order =>
// Resolve SortOrder in one round.
Expand All @@ -388,7 +419,7 @@ class Analyzer(
try {
val newOrder = order transformUp {
case u @ UnresolvedAttribute(nameParts) =>
plan.resolve(nameParts, resolver).getOrElse(u)
plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
Expand Down Expand Up @@ -586,18 +617,6 @@ class Analyzer(
/** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */
private object AliasedGenerator {
def unapply(e: Expression): Option[(Generator, Seq[String])] = e match {
case Alias(g: Generator, name)
if g.resolved &&
g.elementTypes.size > 1 &&
java.util.regex.Pattern.matches("_c[0-9]+", name) => {
// Assume the default name given by parser is "_c[0-9]+",
// TODO in long term, move the naming logic from Parser to Analyzer.
// In projection, Parser gave default name for TGF as does for normal UDF,
// but the TGF probably have multiple output columns/names.
// e.g. SELECT explode(map(key, value)) FROM src;
// Let's simply ignore the default given name for this case.
Some((g, Nil))
}
case Alias(g: Generator, name) if g.resolved && g.elementTypes.size > 1 =>
// If not given the default names, and the TGF with multiple output columns
failAnalysis(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,7 @@ trait CheckAnalysis {
case e => e.children.foreach(checkValidAggregateExpression)
}

val cleaned = aggregateExprs.map(_.transform {
// Should trim aliases around `GetField`s. These aliases are introduced while
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
// (Should we just turn `GetField` into a `NamedExpression`?)
case Alias(g, _) => g
})

cleaned.foreach(checkValidAggregateExpression)
aggregateExprs.foreach(checkValidAggregateExpression)

case _ => // Fallbacks to the following checks
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.{errors, trees}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -206,3 +205,22 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression)

override def toString: String = s"$child[$extraction]"
}

/**
* Holds the expression that has yet to be aliased.
*/
case class UnresolvedAlias(child: Expression) extends NamedExpression
with trees.UnaryNode[Expression] {

override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def name: String = throw new UnresolvedException(this, "name")

override lazy val resolved = false

override def eval(input: InternalRow = null): Any =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import scala.collection.Map

import org.apache.spark.sql.{catalyst, AnalysisException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.types._

Expand All @@ -41,16 +41,22 @@ object ExtractValue {
resolver: Resolver): ExtractValue = {

(child.dataType, extraction) match {
case (StructType(fields), Literal(fieldName, StringType)) =>
val ordinal = findField(fields, fieldName.toString, resolver)
GetStructField(child, fields(ordinal), ordinal)
case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) =>
val ordinal = findField(fields, fieldName.toString, resolver)
GetArrayStructFields(child, fields(ordinal), ordinal, containsNull)
case (StructType(fields), NonNullLiteral(v, StringType)) =>
val fieldName = v.toString
val ordinal = findField(fields, fieldName, resolver)
GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal)

case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
val fieldName = v.toString
val ordinal = findField(fields, fieldName, resolver)
GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull)

case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] =>
GetArrayItem(child, extraction)

case (_: MapType, _) =>
GetMapValue(child, extraction)

case (otherType, _) =>
val errorMsg = otherType match {
case StructType(_) | ArrayType(StructType(_), _) =>
Expand Down Expand Up @@ -94,16 +100,21 @@ trait ExtractValue extends UnaryExpression {
self: Product =>
}

abstract class ExtractValueWithStruct extends ExtractValue {
self: Product =>

def field: StructField
override def toString: String = s"$child.${field.name}"
}

/**
* Returns the value of fields in the Struct `child`.
*/
case class GetStructField(child: Expression, field: StructField, ordinal: Int)
extends ExtractValue {
extends ExtractValueWithStruct {

override def dataType: DataType = field.dataType
override def nullable: Boolean = child.nullable || field.nullable
override def foldable: Boolean = child.foldable
override def toString: String = s"$child.${field.name}"

override def eval(input: InternalRow): Any = {
val baseValue = child.eval(input).asInstanceOf[InternalRow]
Expand All @@ -118,12 +129,9 @@ case class GetArrayStructFields(
child: Expression,
field: StructField,
ordinal: Int,
containsNull: Boolean) extends ExtractValue {
containsNull: Boolean) extends ExtractValueWithStruct {

override def dataType: DataType = ArrayType(field.dataType, containsNull)
override def nullable: Boolean = child.nullable
override def foldable: Boolean = child.foldable
override def toString: String = s"$child.${field.name}"

override def eval(input: InternalRow): Any = {
val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,8 @@ object PartialAggregation {
partialEvaluations(new TreeNodeRef(e)).finalEvaluation

case e: Expression =>
// Should trim aliases around `GetField`s. These aliases are introduced while
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
// (Should we just turn `GetField` into a `NamedExpression`?)
val trimmed = e.transform { case Alias(g: ExtractValue, _) => g }
namedGroupingExpressions.collectFirst {
case (expr, ne) if expr semanticEquals trimmed => ne.toAttribute
case (expr, ne) if expr semanticEquals e => ne.toAttribute
}.getOrElse(e)
}).asInstanceOf[Seq[NamedExpression]]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, EliminateSubQueries, Resolver}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.TreeNode
Expand Down Expand Up @@ -252,14 +252,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
// The foldLeft adds ExtractValues for every remaining parts of the identifier,
// and aliases it with the last part of the identifier.
// and wrap it with UnresolvedAlias which will be removed later.
// For example, consider "a.b.c", where "a" is resolved to an existing attribute.
// Then this will add ExtractValue("c", ExtractValue("b", a)), and alias
// the final expression as "c".
// Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as
// UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))).
val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
ExtractValue(expr, Literal(fieldName), resolver))
val aliasName = nestedFields.last
Some(Alias(fieldExprs, aliasName)())
Some(UnresolvedAlias(fieldExprs))

// No matches.
case Seq() =>
Expand Down
Loading

0 comments on commit 75cf091

Please sign in to comment.