Skip to content

Commit

Permalink
code review
Browse files Browse the repository at this point in the history
  • Loading branch information
hvanhovell committed Nov 29, 2019
1 parent 7f34e08 commit cdc390a
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 36 deletions.
Expand Up @@ -2432,6 +2432,10 @@ class Analyzer(
nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
}.copy(child = newChild)

// Don't touch CollectMetrics. Non-deterministic expressions at the top level of a metric
// are not supported (CheckAnalysis will fail), and the ones nested inside aggregate
// functions must be retained where they are.
case m: CollectMetrics => m

// TODO: It's hard to write a general rule to pull out nondeterministic expressions
// from a LogicalPlan; currently we only do it for a UnaryNode that has the same output
// schema as its child.
Expand Down
Expand Up @@ -281,6 +281,41 @@ trait CheckAnalysis extends PredicateHelper {
groupingExprs.foreach(checkValidGroupingExprs)
aggregateExprs.foreach(checkValidAggregateExpression)

case CollectMetrics(name, metrics, _) =>
  // Every set of observed metrics must carry a non-empty name.
  if (name == null || name.isEmpty) {
    operator.failAnalysis(s"observed metrics should be named: $operator")
  }

  // Check if an expression is a valid metric by walking its tree. A metric must meet the
  // following criteria:
  // - Is not a window function;
  // - Is not nested aggregate function;
  // - Is not a distinct aggregate function;
  // - Has only non-deterministic functions that are nested inside an aggregate function;
  // - Has only attributes that are nested inside an aggregate function.
  //
  // `s` is the complete metric expression (only used to render the error message), `e` is
  // the node currently being inspected, and `seenAggregate` is true once the walk has passed
  // through an enclosing AggregateExpression.
  def checkMetric(s: Expression, e: Expression, seenAggregate: Boolean = false): Unit = {
    e match {
      case _: WindowExpression =>
        e.failAnalysis(
          "window expressions are not allowed in observed metrics, but found: " + s.sql)
      // The aggregate cases are handled before the determinism check: an aggregate over a
      // non-deterministic argument is itself reported as non-deterministic, yet it is a
      // valid metric.
      case _: AggregateExpression if seenAggregate =>
        e.failAnalysis(
          "nested aggregates are not allowed in observed metrics, but found: " + s.sql)
      case agg: AggregateExpression if agg.isDistinct =>
        e.failAnalysis(
          "distinct aggregates are not allowed in observed metrics, but found: " + s.sql)
      case _: AggregateExpression =>
        e.children.foreach(checkMetric(s, _, seenAggregate = true))
      case _: Attribute if !seenAggregate =>
        e.failAnalysis(s"attribute ${s.sql} can only be used as an argument to an " +
          "aggregate function.")
      // `deterministic` is false for every ancestor of a non-deterministic expression, so
      // only fail at the node where the non-determinism originates (all of its children are
      // deterministic). Ancestors fall through to the generic case and keep descending,
      // which lets wrappers such as the alias around `sum(rand())` pass.
      // NOTE(review): a non-deterministic wrapper whose only non-deterministic inputs are
      // aggregates is not rejected by this check - confirm whether that needs a stricter rule.
      case _ if !seenAggregate && !e.deterministic && e.children.forall(_.deterministic) =>
        e.failAnalysis(s"non-deterministic expression ${s.sql} can only be used " +
          "as an argument to an aggregate function.")
      case _ =>
        e.children.foreach(checkMetric(s, _, seenAggregate))
    }
  }
  metrics.foreach(m => checkMetric(m, m))

case Sort(orders, _, _) =>
orders.foreach { order =>
if (!RowOrdering.isOrderable(order.dataType)) {
Expand Down
Expand Up @@ -979,34 +979,8 @@ case class CollectMetrics(
child: LogicalPlan)
extends UnaryNode {

/**
 * Decide whether an expression may be used as an observed metric. The expression tree is
 * walked top-down and rejected as soon as one of these rules is violated:
 *  - no window function anywhere in the tree;
 *  - no aggregate function nested inside another aggregate function;
 *  - no distinct aggregate function;
 *  - non-deterministic functions only below an aggregate function;
 *  - attributes only below an aggregate function.
 *
 * @param e expression to check.
 * @param seenAggregate `true` iff one of the parents on the expression is an aggregate function.
 * @return `true` if the metric is valid, `false` otherwise.
 */
private def isValidMetric(e: Expression, seenAggregate: Boolean = false): Boolean = e match {
  case _: WindowExpression =>
    false
  case agg: AggregateExpression =>
    // An aggregate is acceptable only when it is the outermost one and not distinct; its
    // arguments are then validated as "inside an aggregate".
    !seenAggregate && !agg.isDistinct &&
      agg.children.forall(isValidMetric(_, seenAggregate = true))
  case _: Nondeterministic | _: Attribute if !seenAggregate =>
    false
  case other =>
    other.children.forall(isValidMetric(_, seenAggregate))
}

override lazy val resolved: Boolean = {
def metricsResolved: Boolean = metrics.forall { e =>
e.resolved && isValidMetric(e)
}
name.nonEmpty && metrics.nonEmpty && metricsResolved && childrenResolved
name.nonEmpty && metrics.nonEmpty && metrics.forall(_.resolved) && childrenResolved
}

override def output: Seq[Attribute] = child.output
Expand Down
Expand Up @@ -662,31 +662,44 @@ class AnalysisSuite extends AnalysisTest with Matchers {

// Bad name
assert(!CollectMetrics("", sum :: Nil, testRelation).resolved)
assertAnalysisError(CollectMetrics("", sum :: Nil, testRelation),
"observed metrics should be named" :: Nil)

// Asserts that a CollectMetrics node named "event" over `testRelation` with the given
// metric expressions does not report itself as resolved.
def checkUnresolved(exprs: NamedExpression*): Unit = {
assert(!CollectMetrics("event", exprs, testRelation).resolved)
}
// No columns
checkUnresolved()
assert(!CollectMetrics("evt", Nil, testRelation).resolved)

// Builds a CollectMetrics node named "event" over `testRelation` from the given metric
// expressions and hands it to `assertAnalysisError` together with the expected error
// fragments.
def checkAnalysisError(exprs: Seq[NamedExpression], errors: String*): Unit = {
  val observed = CollectMetrics("event", exprs, testRelation)
  assertAnalysisError(observed, errors)
}

// Unwrapped attribute
checkUnresolved(a)
checkAnalysisError(
a :: Nil,
"Attribute", "can only be used as an argument to an aggregate function")

// Unwrapped non-deterministic expression
checkUnresolved(Rand(10).as("rnd"))
checkAnalysisError(
Rand(10).as("rnd") :: Nil,
"non-deterministic expression", "can only be used as an argument to an aggregate function")

// Distinct aggregate
checkUnresolved(Sum(a).toAggregateExpression(isDistinct = true).as("sum"))
checkAnalysisError(
Sum(a).toAggregateExpression(isDistinct = true).as("sum") :: Nil,
"distinct aggregates are not allowed in observed metrics, but found")

// Nested aggregate
checkUnresolved(Sum(Sum(a).toAggregateExpression()).toAggregateExpression().as("sum"))
checkAnalysisError(
Sum(Sum(a).toAggregateExpression()).toAggregateExpression().as("sum") :: Nil,
"nested aggregates are not allowed in observed metrics, but found")

// Windowed aggregate
val windowExpr = WindowExpression(
RowNumber(),
WindowSpecDefinition(Nil, a.asc :: Nil,
SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)))
checkUnresolved(windowExpr.as("rn"))
checkAnalysisError(
windowExpr.as("rn") :: Nil,
"window expressions are not allowed in observed metrics, but found")
}

test("check CollectMetrics duplicates") {
Expand Down

0 comments on commit cdc390a

Please sign in to comment.