From c0551aaf0bb6c015762e5c84da41bc87743da010 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 22 Mar 2016 15:34:58 -0700 Subject: [PATCH 1/5] star expansion --- .../sql/catalyst/analysis/Analyzer.scala | 45 ++++++++++++------- .../org/apache/spark/sql/DataFrameSuite.scala | 39 ++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 8 ++++ 3 files changed, 75 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5951a70c4809a..2efb4dae397a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -380,27 +380,12 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: LogicalPlan if !p.childrenResolved => p - // If the projection list contains Stars, expand it. case p: Project if containsStar(p.projectList) => - val expanded = p.projectList.flatMap { - case s: Star => s.expand(p.child, resolver) - case ua @ UnresolvedAlias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) => - UnresolvedAlias(child = expandStarExpression(ua.child, p.child)) :: Nil - case a @ Alias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) => - a.withNewChildren(expandStarExpression(a.child, p.child) :: Nil) - .asInstanceOf[Alias] :: Nil - case o => o :: Nil - } - Project(projectList = expanded, p.child) + p.copy(projectList = buildExpandedProjectList(p.projectList, p)) // If the aggregate function argument contains Stars, expand it. case a: Aggregate if containsStar(a.aggregateExpressions) => - val expanded = a.aggregateExpressions.flatMap { - case s: Star => s.expand(a.child, resolver) - case o if containsStar(o :: Nil) => expandStarExpression(o, a.child) :: Nil - case o => o :: Nil - }.map(_.asInstanceOf[NamedExpression]) - a.copy(aggregateExpressions = expanded) + a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a)) // If the script transformation input contains Stars, expand it. case t: ScriptTransformation if containsStar(t.input) => t.copy( @@ -413,6 +398,22 @@ class Analyzer( failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF") } + /** + * Build a project list for Project/Aggregate and expand the star if possible + */ + private def buildExpandedProjectList( + exprs: Seq[NamedExpression], + plan: UnaryNode): Seq[NamedExpression] = { + exprs.flatMap { + // Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*") + case s: Star => s.expand(plan.child, resolver) + // Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b + case UnresolvedAlias(s: Star, _) => expandStarExpression(s, plan.child) :: Nil + case o if containsStar(o :: Nil) => expandStarExpression(o, plan.child) :: Nil + case o => o :: Nil + }.map(_.asInstanceOf[NamedExpression]) + } + /** * Returns true if `exprs` contains a [[Star]]. */ @@ -439,6 +440,16 @@ class Analyzer( case s: Star => s.expand(child, resolver) case o => o :: Nil }) + case p: Murmur3Hash if containsStar(p.children) => + p.copy(children = p.children.flatMap { + case s: Star => s.expand(child, resolver) + case o => o :: Nil + }) + case p: Concat if containsStar(p.children) => + p.copy(children = p.children.flatMap { + case s: Star => s.expand(child, resolver) + case o => o :: Nil + }) // count(*) has been replaced by count(1) case o if containsStar(o.children) => failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d03597ee5dcad..02b587cd53da6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -181,6 +181,45 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(structDf.select(array($"record.*").as("a")).first().getAs[Seq[Int]](0) === Seq(1, 1)) } + test("Star Expansion - hash") { + val structDf = testData2.select("a", "b").as("record") + checkAnswer( + structDf.groupBy($"a", $"b").agg(min(hash($"a", $"*"))), + structDf.groupBy($"a", $"b").agg(min(hash($"a", $"a", $"b")))) + + checkAnswer( + structDf.groupBy($"a", $"b").agg(hash($"a", $"*")), + structDf.groupBy($"a", $"b").agg(hash($"a", $"a", $"b"))) + + checkAnswer( + structDf.select(hash($"*")), + structDf.select(hash($"record.*"))) + + checkAnswer( + structDf.select(hash($"a", $"*")), + structDf.select(hash($"a", $"record.*"))) + } + + test("Star Expansion - concat") { + val structDf = testData2.select("a", "b").as("record") + + checkAnswer( + structDf.groupBy($"a", $"b").agg(min(concat($"a", $"*"))), + structDf.groupBy($"a", $"b").agg(min(concat($"a", $"a", $"b")))) + + checkAnswer( + structDf.groupBy($"a", $"b").agg(concat($"a", $"*")), + structDf.groupBy($"a", $"b").agg(concat($"a", $"a", $"b"))) + + checkAnswer( + structDf.select(concat($"*")), + structDf.select(concat($"record.*"))) + + checkAnswer( + structDf.select(concat($"a", $"*")), + structDf.select(concat($"a", $"record.*"))) + } + test("Star Expansion - explode should fail with a meaningful message if it takes a star") { val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv") val e = intercept[AnalysisException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 9f2233d5d821b..1040a437738cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1937,6 +1937,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("Star Expansion - group by") { + withSQLConf("spark.sql.retainGroupColumns" -> "false") { + checkAnswer( + testData2.groupBy($"a", $"b").agg($"*"), + sql("SELECT * FROM testData2 group by a, b")) + } + } + test("Common subexpression elimination") { // TODO: support subexpression elimination in whole stage codegen withSQLConf("spark.sql.codegen.wholeStage" -> "false") { From 4368a6ce88619dd99e20f73393710ae7e16a4951 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 22 Mar 2016 21:42:43 -0700 Subject: [PATCH 2/5] address comments. --- .../sql/catalyst/analysis/Analyzer.scala | 5 ----- .../org/apache/spark/sql/DataFrameSuite.scala | 20 ------------------- 2 files changed, 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2efb4dae397a2..5b60b43c5ed31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -445,11 +445,6 @@ class Analyzer( case s: Star => s.expand(child, resolver) case o => o :: Nil }) - case p: Concat if containsStar(p.children) => - p.copy(children = p.children.flatMap { - case s: Star => s.expand(child, resolver) - case o => o :: Nil - }) // count(*) has been replaced by count(1) case o if containsStar(o.children) => failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 02b587cd53da6..f97b502d0a386 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -200,26 +200,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { structDf.select(hash($"a", $"record.*"))) } - test("Star Expansion - concat") { - val structDf = testData2.select("a", "b").as("record") - - checkAnswer( - structDf.groupBy($"a", $"b").agg(min(concat($"a", $"*"))), - structDf.groupBy($"a", $"b").agg(min(concat($"a", $"a", $"b")))) - - checkAnswer( - structDf.groupBy($"a", $"b").agg(concat($"a", $"*")), - structDf.groupBy($"a", $"b").agg(concat($"a", $"a", $"b"))) - - checkAnswer( - structDf.select(concat($"*")), - structDf.select(concat($"record.*"))) - - checkAnswer( - structDf.select(concat($"a", $"*")), - structDf.select(concat($"a", $"record.*"))) - } - test("Star Expansion - explode should fail with a meaningful message if it takes a star") { val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv") val e = intercept[AnalysisException] { From 73160a3a419b6b2796d574894d89101bccdcba7d Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 23 Mar 2016 06:06:57 -0700 Subject: [PATCH 3/5] address comments. --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5b60b43c5ed31..a03d2a3d7dfde 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -408,7 +408,7 @@ class Analyzer( // Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*") case s: Star => s.expand(plan.child, resolver) // Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b - case UnresolvedAlias(s: Star, _) => expandStarExpression(s, plan.child) :: Nil + case UnresolvedAlias(s: Star, _) => s.expand(plan.child, resolver) case o if containsStar(o :: Nil) => expandStarExpression(o, plan.child) :: Nil case o => o :: Nil }.map(_.asInstanceOf[NamedExpression]) From 4c8d24acf3610f4c68dfd572d82492194061db60 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 23 Mar 2016 09:14:19 -0700 Subject: [PATCH 4/5] address comments. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a03d2a3d7dfde..1f711972dab7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -382,10 +382,10 @@ class Analyzer( case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. case p: Project if containsStar(p.projectList) => - p.copy(projectList = buildExpandedProjectList(p.projectList, p)) + p.copy(projectList = buildExpandedProjectList(p.projectList, p.child)) // If the aggregate function argument contains Stars, expand it. case a: Aggregate if containsStar(a.aggregateExpressions) => - a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a)) + a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child)) // If the script transformation input contains Stars, expand it. case t: ScriptTransformation if containsStar(t.input) => t.copy( @@ -403,13 +403,13 @@ class Analyzer( */ private def buildExpandedProjectList( exprs: Seq[NamedExpression], - plan: UnaryNode): Seq[NamedExpression] = { + plan: LogicalPlan): Seq[NamedExpression] = { exprs.flatMap { // Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*") - case s: Star => s.expand(plan.child, resolver) + case s: Star => s.expand(plan, resolver) // Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b - case UnresolvedAlias(s: Star, _) => s.expand(plan.child, resolver) - case o if containsStar(o :: Nil) => expandStarExpression(o, plan.child) :: Nil + case UnresolvedAlias(s: Star, _) => s.expand(plan, resolver) + case o if containsStar(o :: Nil) => expandStarExpression(o, plan) :: Nil case o => o :: Nil }.map(_.asInstanceOf[NamedExpression]) } From 7af4eee4433672fdddab912d1096ec25774bf792 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 23 Mar 2016 18:10:55 -0700 Subject: [PATCH 5/5] address comments --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1f711972dab7e..07b0f5ee705b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -403,13 +403,13 @@ class Analyzer( */ private def buildExpandedProjectList( exprs: Seq[NamedExpression], - plan: LogicalPlan): Seq[NamedExpression] = { + child: LogicalPlan): Seq[NamedExpression] = { exprs.flatMap { // Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*") - case s: Star => s.expand(plan, resolver) + case s: Star => s.expand(child, resolver) // Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b - case UnresolvedAlias(s: Star, _) => s.expand(plan, resolver) - case o if containsStar(o :: Nil) => expandStarExpression(o, plan) :: Nil + case UnresolvedAlias(s: Star, _) => s.expand(child, resolver) + case o if containsStar(o :: Nil) => expandStarExpression(o, child) :: Nil case o => o :: Nil }.map(_.asInstanceOf[NamedExpression]) }