Skip to content

Commit

Permalink
SPARK-3371 : Fixed Renaming a function expression with group by gives…
Browse files Browse the repository at this point in the history
… error

Signed-off-by: ravipesala <ravindra.pesala@huawei.com>
  • Loading branch information
ravipesala committed Sep 23, 2014
1 parent f9d6220 commit bad2fd0
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,25 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
case (e, i) => Alias(e, s"c$i")()
}
}

/** Creates the aliases to the grouping expressions */
protected def assignAliasesForGroups(
grpExprs: Seq[Expression],
projExprs: Seq[Expression]): Seq[NamedExpression] = {
grpExprs.zipWithIndex.map {
case (e, i) =>
var aliasForGrp:NamedExpression = null
projExprs.foreach {
case Alias(pe,pi) if pe.fastEquals(e) => aliasForGrp = Alias(e, pi)()
case _ =>
}
if (aliasForGrp == null) {
Alias(e, s"c$i")()
} else {
aliasForGrp
}
}
}

protected lazy val query: Parser[LogicalPlan] = (
select * (
Expand All @@ -166,7 +185,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
val withFilter = f.map(f => Filter(f, base)).getOrElse(base)
val withProjection =
g.map {g =>
Aggregate(assignAliases(g), assignAliases(p), withFilter)
Aggregate(assignAliasesForGroups(g,p), assignAliases(p), withFilter)
}.getOrElse(Project(assignAliases(p), withFilter))
val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -672,4 +672,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"),
("true", "false") :: Nil)
}

test("SPARK-3371 Renaming a function expression with group by gives error") {
registerFunction("len", (s: String) => s.length)
checkAnswer(
sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"),
Seq(Seq("1")))
}
}

0 comments on commit bad2fd0

Please sign in to comment.