Skip to content

Commit

Permalink
add OrderIrrelevantAggs constrain
Browse files Browse the repository at this point in the history
  • Loading branch information
WangGuangxin committed Oct 3, 2019
1 parent 75b43f5 commit e29b323
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ import org.apache.spark.sql.types._
1.5
""",
since = "1.0.0")
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
with OrderIrrelevantAggs {

override def prettyName: String = "avg"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import org.apache.spark.sql.types._
""",
since = "1.0.0")
// scalastyle:on line.size.limit
case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
case class Count(children: Seq[Expression]) extends DeclarativeAggregate with OrderIrrelevantAggs {
override def nullable: Boolean = false

// Return data type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._
50
""",
since = "1.0.0")
case class Max(child: Expression) extends DeclarativeAggregate {
case class Max(child: Expression) extends DeclarativeAggregate with OrderIrrelevantAggs {

override def children: Seq[Expression] = child :: Nil

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._
-1
""",
since = "1.0.0")
case class Min(child: Expression) extends DeclarativeAggregate {
case class Min(child: Expression) extends DeclarativeAggregate with OrderIrrelevantAggs {

override def children: Seq[Expression] = child :: Nil

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.aggregate

/**
* An [[OrderIrrelevantAggs]] trait denotes those aggregate functions that its result
* has nothing to do with the order of input data.
* For example, [[Sum]] is [[OrderIrrelevantAggs]] while [[First]] is not.
*/
trait OrderIrrelevantAggs extends AggregateFunction {
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ import org.apache.spark.sql.types._
NULL
""",
since = "1.0.0")
case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
with OrderIrrelevantAggs {

override def children: Seq[Expression] = child :: Nil

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ case class EveryAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
false
""",
since = "3.0.0")
case class AnyAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
case class AnyAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg)
with OrderIrrelevantAggs {
override def nodeName: String = "Any"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.PredicateHelper
import org.apache.spark.sql.catalyst.expressions.aggregate.OrderIrrelevantAggs
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule

Expand Down Expand Up @@ -56,7 +57,7 @@ object RemoveSortInSubquery extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case j @ Join(oldLeft, oldRight, _, _, _) =>
j.copy(left = removeTopLevelSort(oldLeft), right = removeTopLevelSort(oldRight))
case g @ Aggregate(_, _, oldChild) =>
case g @ Aggregate(_, aggs, oldChild) if aggs.forall(_.isInstanceOf[OrderIrrelevantAggs]) =>
g.copy(child = removeTopLevelSort(oldChild))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class RemoveSortInSubquerySuite extends PlanTest {
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
val testRelationB = LocalRelation('d.int)

test("remove orderBy in groupBy subquery") {
test("remove orderBy in groupBy subquery with count aggs") {
val projectPlan = testRelation.select('a, 'b)
val unnecessaryOrderByPlan = projectPlan.orderBy('a.asc, 'b.desc)
val groupByPlan = unnecessaryOrderByPlan.groupBy('a)(count(1))
Expand All @@ -46,6 +46,33 @@ class RemoveSortInSubquerySuite extends PlanTest {
comparePlans(Optimize.execute(optimized), correctAnswer)
}

test("remove orderBy in groupBy subquery with sum aggs") {
val projectPlan = testRelation.select('a, 'b)
val unnecessaryOrderByPlan = projectPlan.orderBy('a.asc, 'b.desc)
val groupByPlan = unnecessaryOrderByPlan.groupBy('a)(sum('a))
val optimized = Optimize.execute(groupByPlan.analyze)
val correctAnswer = projectPlan.groupBy('a)(sum('a)).analyze
comparePlans(Optimize.execute(optimized), correctAnswer)
}

test("remove orderBy in groupBy subquery with first aggs") {
val projectPlan = testRelation.select('a, 'b)
val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc)
val groupByPlan = orderByPlan.groupBy('a)(first('a))
val optimized = Optimize.execute(groupByPlan.analyze)
val correctAnswer = groupByPlan.analyze
comparePlans(Optimize.execute(optimized), correctAnswer)
}

test("remove orderBy in groupBy subquery with first and count aggs") {
val projectPlan = testRelation.select('a, 'b)
val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc)
val groupByPlan = orderByPlan.groupBy('a)(first('a), count(1))
val optimized = Optimize.execute(groupByPlan.analyze)
val correctAnswer = groupByPlan.analyze
comparePlans(Optimize.execute(optimized), correctAnswer)
}

test("should not remove orderBy with limit in groupBy subquery") {
val projectPlan = testRelation.select('a, 'b)
val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc).limit(10)
Expand Down

0 comments on commit e29b323

Please sign in to comment.