Skip to content

Commit

Permalink
[CARMEL-6265] Only push down low cost expression (#1081)
Browse files Browse the repository at this point in the history
* Only push down low cost expression

* fix
  • Loading branch information
wangyum authored and GitHub Enterprise committed Oct 2, 2022
1 parent c67115a commit d809219
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -186,20 +186,24 @@ object PushPartialAggregationThroughJoin extends Rule[LogicalPlan]
}).asInstanceOf[NamedExpression]
}

/**
 * Returns true when the cost that `PredicateReorder.expressionCost` assigns to the
 * given aggregate expression does not exceed 100.
 *
 * NOTE(review): the unit and scale of the cost value are defined by `PredicateReorder`,
 * which is not visible here; 100 is the cut-off used by this rule.
 */
private def lowerCostExp(ae: AggregateExpression): Boolean = {
  val cost = PredicateReorder.expressionCost(ae)
  cost <= 100
}

/**
 * Decides whether an aggregate expression is eligible for partial push-down.
 *
 * Only `Sum`, `Min`, `Max`, `First`, `Last` and `Average` are considered, and only when the
 * remaining positional fields match `Complete, false, None` (presumably mode, isDistinct and
 * filter -- confirm against `AggregateExpression`). `Average` additionally requires a
 * numeric child. Every supported function is gated on [[lowerCostExp]], so an expensive
 * expression is never pushed down.
 *
 * @param ae the aggregate expression to inspect
 * @return true if `ae` has a supported shape and is a low cost expression
 */
private def pushableAggExp(ae: AggregateExpression): Boolean = ae match {
  // Each case must consult lowerCostExp: an unconditional `=> true` case placed before
  // these would match first and bypass the cost check entirely.
  case AggregateExpression(_: Sum, Complete, false, None, _) => lowerCostExp(ae)
  case AggregateExpression(_: Min, Complete, false, None, _) => lowerCostExp(ae)
  case AggregateExpression(_: Max, Complete, false, None, _) => lowerCostExp(ae)
  case AggregateExpression(_: First, Complete, false, None, _) => lowerCostExp(ae)
  case AggregateExpression(_: Last, Complete, false, None, _) => lowerCostExp(ae)
  case AggregateExpression(Average(e), Complete, false, None, _) =>
    e.dataType.isInstanceOf[NumericType] && lowerCostExp(ae)
  case _ => false
}

/**
 * Decides whether a count aggregate is eligible for partial push-down.
 * Supports count(*), count(id).
 *
 * Matches a `Count` whose remaining positional fields are `Complete, false, None`, and
 * only accepts it when [[lowerCostExp]] reports the expression as low cost.
 *
 * @param ae the aggregate expression to inspect
 * @return true if `ae` is a supported, low cost count expression
 */
private def pushableCountExp(ae: AggregateExpression): Boolean = ae match {
  // The cost check must be the only outcome for this shape; a preceding `=> true` case
  // for the same pattern would make it unreachable.
  case AggregateExpression(_: Count, Complete, false, None, _) => lowerCostExp(ae)
  case _ => false
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion.InConversion
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Cast, CheckOverflow, CheckOverflowInSum, Divide, Expression, Literal, PromotePrecision}
import org.apache.spark.sql.catalyst.expressions.{Cast, CheckOverflow, CheckOverflowInSum, Divide, Expression, If, Literal, PromotePrecision}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Sum}
import org.apache.spark.sql.catalyst.optimizer.customAnalyze._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, LongType}
import org.apache.spark.sql.types._

// Custom Analyzer to exclude DecimalPrecision rule
object ExcludeDecimalPrecisionAnalyzer extends Analyzer(
Expand Down Expand Up @@ -636,4 +636,13 @@ class PushPartialAggregationThroughJoinSuite extends PlanTest {
comparePlans(Optimize.execute(originalQuery), ColumnPruning(originalQuery))
}
}

// Checks that the optimizer leaves the plan alone (apart from column pruning) when the
// aggregate expression is a sum over an If whose condition casts 'y to string and applies
// likeAny("1", "2").
test("Do not push down aggregate expressions if it's not lower cost expression") {
// Inner join of testRelation1 and testRelation2 on 'a === 'x, followed by a global
// aggregate (empty groupBy) computing the conditional sum aliased as "sum_y".
val originalQuery = testRelation1
.join(testRelation2, joinType = Inner, condition = Some('a === 'x))
.groupBy()(sum(If('y.cast(StringType) likeAny("1", "2"), 1, 0)).as("sum_y"))
.analyze

// The expected plan is just ColumnPruning applied to the original query, i.e. the
// optimizer under test must not have rewritten the aggregate.
comparePlans(Optimize.execute(originalQuery), ColumnPruning(originalQuery))
}
}

0 comments on commit d809219

Please sign in to comment.