Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
WangGuangxin committed Oct 13, 2019
1 parent d21c683 commit d2328a0
Showing 1 changed file with 9 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.{Alias, NamedExpression, PredicateHelper}
import org.apache.spark.sql.catalyst.expressions.{NamedExpression, PredicateHelper}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, OrderIrrelevantAggs}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
Expand Down Expand Up @@ -54,17 +54,20 @@ object RemoveSortInSubquery extends Rule[LogicalPlan] with PredicateHelper {
}
}

private def isOrderIrrelevantAggs(expr: NamedExpression): Boolean = {
expr match {
case Alias(AggregateExpression(_: OrderIrrelevantAggs, _, _, _), _) => true
case _ => false
private def isOrderIrrelevantAggs(aggs: Seq[NamedExpression]): Boolean = {
val aggExpressions = aggs.flatMap { e =>
e.collect {
case ae: AggregateExpression => ae
}
}

aggExpressions.forall(_.aggregateFunction.isInstanceOf[OrderIrrelevantAggs])
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case j @ Join(originLeft, originRight, _, _, _) =>
j.copy(left = removeTopLevelSort(originLeft), right = removeTopLevelSort(originRight))
case g @ Aggregate(_, aggs, originChild) if aggs.forall(isOrderIrrelevantAggs) =>
case g @ Aggregate(_, aggs, originChild) if isOrderIrrelevantAggs(aggs) =>
g.copy(child = removeTopLevelSort(originChild))
}
}

0 comments on commit d2328a0

Please sign in to comment.