Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
typeCoercionRules ++
extendedRules : _*),
Batch("Check Analysis", Once,
CheckResolution),
CheckResolution,
CheckAggregation),
Batch("AnalysisOperators", fixedPoint,
EliminateAnalysisOperators)
)
Expand All @@ -88,6 +89,32 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
}
}

/**
 * Checks that every expression in the aggregate list of an [[Aggregate]] operator can be
 * evaluated once per group. An expression is accepted when it is:
 *  - an [[AggregateExpression]],
 *  - equal to one of the grouping expressions,
 *  - free of attribute references (e.g. a literal), or
 *  - composed solely of children that satisfy the rules above.
 *
 * Any other expression (for example a bare attribute that is not a grouping expression) causes a
 * [[TreeNodeException]] whose message starts with "Expression not in GROUP BY" and whose tree is
 * the offending [[Aggregate]] node. The plan itself is never rewritten by this rule.
 */
object CheckAggregation extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = {
    plan.transform {
      case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, _) =>
        // Recursively decides whether `expr` is computable from grouping keys and aggregates only.
        def isValidAggregateExpression(expr: Expression): Boolean = expr match {
          case _: AggregateExpression => true
          // A bare attribute is only legal when it is itself a grouping expression.
          case e: Attribute => groupingExprs.contains(e)
          // A whole sub-expression may match a grouping expression (e.g. GROUP BY key + 1).
          case e if groupingExprs.contains(e) => true
          // No attribute references: constant per group.
          case e if e.references.isEmpty => true
          // Otherwise the expression is valid iff all of its inputs are.
          case e => e.children.forall(isValidAggregateExpression)
        }

        // Report the first offending expression, attaching the Aggregate node it belongs to
        // (rather than the whole input plan) so the error points at the actual culprit.
        aggregateExprs.find(e => !isValidAggregateExpression(e)).foreach { invalid =>
          throw new TreeNodeException[LogicalPlan](
            aggregatePlan, s"Expression not in GROUP BY: $invalid")
        }

        aggregatePlan
    }
  }
}

/**
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
*/
Expand Down Expand Up @@ -204,18 +231,17 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
*/
object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
if aggregate.resolved && containsAggregate(havingCondition) => {
val evaluatedCondition = Alias(havingCondition, "havingCondition")()
val aggExprsWithHaving = evaluatedCondition +: originalAggExprs

Project(aggregate.output,
Filter(evaluatedCondition.toAttribute,
aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
}

}

protected def containsAggregate(condition: Expression): Boolean =
condition
.collect { case ae: AggregateExpression => ae }
Expand Down
26 changes: 26 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql

import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.joins.BroadcastHashJoin
import org.apache.spark.sql.test._
import org.scalatest.BeforeAndAfterAll
Expand Down Expand Up @@ -694,4 +695,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(
sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1)
}

test("throw errors for non-aggregate attributes with aggregation") {
  // Runs analysis on `query`. When `isInvalidQuery` is true, analysis must fail with a
  // TreeNodeException whose message starts with "Expression not in GROUP BY"; otherwise
  // analysis must complete without throwing.
  def checkAggregation(query: String, isInvalidQuery: Boolean = true): Unit = {
    val parsedPlan = sql(query).queryExecution.logical
    def analyze(): LogicalPlan = sql(query).queryExecution.analyzed

    if (!isInvalidQuery) {
      // Should not throw
      analyze()
    } else {
      val error = intercept[TreeNodeException[LogicalPlan]](analyze())
      assert(
        error.getMessage.startsWith("Expression not in GROUP BY"),
        "Non-aggregate attribute(s) not detected\n" + parsedPlan)
    }
  }

  // (query, whether analysis is expected to reject it), checked in this order.
  val cases = Seq(
    // No GROUP BY at all.
    "SELECT key, COUNT(*) FROM testData" -> true,
    "SELECT COUNT(key), COUNT(*) FROM testData" -> false,
    // Attribute that is not the grouping key.
    "SELECT value, COUNT(*) FROM testData GROUP BY key" -> true,
    "SELECT COUNT(value), SUM(key) FROM testData GROUP BY key" -> false,
    // Grouping by an expression rather than a plain attribute.
    "SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1" -> true,
    "SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1" -> false)

  cases.foreach { case (query, shouldFail) =>
    checkAggregation(query, shouldFail)
  }
}
}