From b4cdee35491895439469ea7fb51e501250c730a7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 21 Sep 2015 18:55:47 -0700 Subject: [PATCH] keep the name of expression if possible when do cast --- .../sql/catalyst/analysis/Analyzer.scala | 33 +++++++++---------- .../scala/org/apache/spark/sql/Column.scala | 4 +-- .../spark/sql/ColumnExpressionSuite.scala | 6 ---- .../org/apache/spark/sql/DataFrameSuite.scala | 5 +++ 4 files changed, 23 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e6046055bf0f6..a11ec96dc9071 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -141,32 +141,31 @@ class Analyzer( */ object ResolveAliases extends Rule[LogicalPlan] { private def assignAliases(exprs: Seq[NamedExpression]) = { - // The `UnresolvedAlias`s will appear only at root of a expression tree, we don't need - // to traverse the whole tree. exprs.zipWithIndex.map { - case (u @ UnresolvedAlias(child), i) => - child match { - case _: UnresolvedAttribute => u - case ne: NamedExpression => ne - case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil) - case e if !e.resolved => u - case other => Alias(other, s"_c$i")() + case (expr, i) => + expr transform { + case u @ UnresolvedAlias(child) => child match { + case ne: NamedExpression => ne + case e if !e.resolved => u + case g: Generator if g.elementTypes.size > 1 => MultiAlias(g, Nil) + case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)() + case other => Alias(other, s"_c$i")() + } } - case (other, _) => other - } + }.asInstanceOf[Seq[NamedExpression]] } + private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = + exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined) + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case Aggregate(groups, aggs, child) - if child.resolved && aggs.exists(_.isInstanceOf[UnresolvedAlias]) => + case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) - case g: GroupingAnalytics - if g.child.resolved && g.aggregations.exists(_.isInstanceOf[UnresolvedAlias]) => + case g: GroupingAnalytics if g.child.resolved && hasUnresolvedAlias(g.aggregations) => g.withNewAggs(assignAliases(g.aggregations)) - case Project(projectList, child) - if child.resolved && projectList.exists(_.isInstanceOf[UnresolvedAlias]) => + case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) => Project(assignAliases(projectList), child) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 807bc8c30c12d..ec818203d948f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -835,8 +835,8 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @since 1.3.0 */ def cast(to: DataType): Column = expr match { - // Lift alias out of cast so we can support col.as("name").cast(IntegerType) - case Alias(childExpr, name) => Alias(Cast(childExpr, to), name)() + // keeps the name of expression if possible when do cast. + case ne: NamedExpression => UnresolvedAlias(Cast(expr, to)) case _ => Cast(expr, to) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 4e988f074b113..fa559c9c64005 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -588,12 +588,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } } - test("lift alias out of cast") { - compareExpressions( - col("1234").as("name").cast("int").expr, - col("1234").cast("int").as("name").expr) - } - test("columns can be compared") { assert('key.desc == 'key.desc) assert('key.desc != 'key.asc) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d919877746c72..a5a3c0e3b0680 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -957,4 +957,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { expected(except) ) } + + test("SPARK-10743: keep the name of expression if possible when do cast") { + val df = (1 to 10).map(Tuple1.apply).toDF("i").as("src") + assert(df.select($"src.i".cast(StringType)).columns.head === "i") + } }