Skip to content

Commit

Permalink
[SPARK-28782][SQL] Generator support in aggregate expressions
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Support generator in aggregate expressions.

In this PR, I check the aggregate logical plan, if its aggregateExpressions include generator, then convert this agg plan into "normal agg plan + generator plan + projection plan". I.e:
```
aggregate(with generator)
 |--child_plan
```
===>
```
project
  |--generator(resolved)
         |--aggregate
               |--child_plan
```

### Why are the changes needed?

We should support sql like:
```
select explode(array(min(a), max(a))) from t
```

### Does this PR introduce any user-facing change?
No

### How was this patch tested?

Unit test added.

Closes #25512 from WeichenXu123/explode_bug.

Authored-by: WeichenXu <weichen.xu@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
WeichenXu123 authored and cloud-fan committed Sep 5, 2019
1 parent dde3931 commit f8bc91f
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 0 deletions.
Expand Up @@ -2003,6 +2003,58 @@ class Analyzer(
throw new AnalysisException("Only one generator allowed per select clause but found " +
generators.size + ": " + generators.map(toPrettySQL).mkString(", "))

case Aggregate(_, aggList, _) if aggList.exists(hasNestedGenerator) =>
val nestedGenerator = aggList.find(hasNestedGenerator).get
throw new AnalysisException("Generators are not supported when it's nested in " +
"expressions, but got: " + toPrettySQL(trimAlias(nestedGenerator)))

case Aggregate(_, aggList, _) if aggList.count(hasGenerator) > 1 =>
val generators = aggList.filter(hasGenerator).map(trimAlias)
throw new AnalysisException("Only one generator allowed per aggregate clause but found " +
generators.size + ": " + generators.map(toPrettySQL).mkString(", "))

case agg @ Aggregate(groupList, aggList, child) if aggList.forall {
case AliasedGenerator(_, _, _) => true
case other => other.resolved
} && aggList.exists(hasGenerator) =>
// If generator in the aggregate list was visited, set the boolean flag true.
var generatorVisited = false

val projectExprs = Array.ofDim[NamedExpression](aggList.length)
val newAggList = aggList
.map(CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
.zipWithIndex
.flatMap {
case (AliasedGenerator(generator, names, outer), idx) =>
// It's a sanity check, this should not happen as the previous case will throw
// exception earlier.
assert(!generatorVisited, "More than one generator found in aggregate.")
generatorVisited = true

val newGenChildren: Seq[Expression] = generator.children.zipWithIndex.map {
case (e, idx) => if (e.foldable) e else Alias(e, s"_gen_input_${idx}")()
}
val newGenerator = {
val g = generator.withNewChildren(newGenChildren.map { e =>
if (e.foldable) e else e.asInstanceOf[Alias].toAttribute
}).asInstanceOf[Generator]
if (outer) GeneratorOuter(g) else g
}
val newAliasedGenerator = if (names.length == 1) {
Alias(newGenerator, names(0))()
} else {
MultiAlias(newGenerator, names)
}
projectExprs(idx) = newAliasedGenerator
newGenChildren.filter(!_.foldable).asInstanceOf[Seq[NamedExpression]]
case (other, idx) =>
projectExprs(idx) = other.toAttribute
other :: Nil
}

val newAgg = Aggregate(groupList, newAggList, child)
Project(projectExprs.toList, newAgg)

case p @ Project(projectList, child) =>
// Holds the resolved generator, if one exists in the project list.
var resolvedGenerator: Generate = null
Expand Down
Expand Up @@ -308,6 +308,40 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession {
sql("select * from values 1, 2 lateral view outer empty_gen() a as b"),
Row(1, null) :: Row(2, null) :: Nil)
}

test("generator in aggregate expression") {
withTempView("t1") {
Seq((1, 1), (1, 2), (2, 3)).toDF("c1", "c2").createTempView("t1")
checkAnswer(
sql("select explode(array(min(c2), max(c2))) from t1"),
Row(1) :: Row(3) :: Nil
)
checkAnswer(
sql("select posexplode(array(min(c2), max(c2))) from t1 group by c1"),
Row(0, 1) :: Row(1, 2) :: Row(0, 3) :: Row(1, 3) :: Nil
)
// test generator "stack" which require foldable argument
checkAnswer(
sql("select stack(2, min(c1), max(c1), min(c2), max(c2)) from t1"),
Row(1, 2) :: Row(1, 3) :: Nil
)

val msg1 = intercept[AnalysisException] {
sql("select 1 + explode(array(min(c2), max(c2))) from t1 group by c1")
}.getMessage
assert(msg1.contains("Generators are not supported when it's nested in expressions"))

val msg2 = intercept[AnalysisException] {
sql(
"""select
| explode(array(min(c2), max(c2))),
| posexplode(array(min(c2), max(c2)))
|from t1 group by c1
""".stripMargin)
}.getMessage
assert(msg2.contains("Only one generator allowed per aggregate clause"))
}
}
}

case class EmptyGenerator() extends Generator {
Expand Down

0 comments on commit f8bc91f

Please sign in to comment.