[SPARK-35146][SQL] Migrate to transformWithPruning or resolveWithPruning for rules in finishAnalysis.scala

### What changes were proposed in this pull request?

Added the following TreePattern enums:
- BOOL_AGG
- COUNT_IF
- CURRENT_LIKE
- RUNTIME_REPLACEABLE

Added tree traversal pruning to the following rules:
- ReplaceExpressions
- RewriteNonCorrelatedExists
- ComputeCurrentTime
- GetCurrentDatabaseAndCatalog

### Why are the changes needed?

Reduce the number of tree traversals and hence improve query compilation latency.
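To illustrate the mechanism, here is a minimal, self-contained sketch of the pruning idea (illustrative Scala only; the `Node` type and the exact shape of `transformWithPruning` are simplifications, and Spark's real `TreeNode` caches pattern bit sets rather than recomputing sets): each node advertises the `TreePattern`s occurring anywhere in its subtree, so a pruned transform can return a subtree untouched when the rule's target patterns cannot occur in it.

```scala
// Minimal sketch of pattern-based pruning (illustrative; Spark's TreeNode
// caches pattern bit sets and exposes transformWithPruning/containsPattern).
object PruningSketch {
  sealed trait Pattern
  case object CURRENT_LIKE extends Pattern
  case object OTHER extends Pattern

  case class Node(name: String, patterns: Set[Pattern], children: Seq[Node] = Nil) {
    // Patterns of this node plus all of its descendants, computed once per node.
    lazy val treePatterns: Set[Pattern] = children.foldLeft(patterns)(_ ++ _.treePatterns)

    def containsPattern(p: Pattern): Boolean = treePatterns.contains(p)

    // Visit a subtree only if the rule's patterns can occur in it; otherwise
    // return it unchanged, skipping the entire traversal below this point.
    def transformWithPruning(cond: Node => Boolean)(rule: PartialFunction[Node, Node]): Node =
      if (!cond(this)) this
      else {
        val rewritten = copy(children = children.map(_.transformWithPruning(cond)(rule)))
        rule.applyOrElse(rewritten, identity[Node])
      }
  }

  def main(args: Array[String]): Unit = {
    val plan = Node("Project", Set(OTHER), Seq(Node("current_date", Set(CURRENT_LIKE))))
    // Analogous to ComputeCurrentTime guarded by containsPattern(CURRENT_LIKE):
    val rewritten = plan.transformWithPruning(_.containsPattern(CURRENT_LIKE)) {
      case n if n.name == "current_date" => Node("literal(2021-05-11)", Set.empty)
    }
    println(rewritten) // the CURRENT_LIKE node is rewritten; pattern-free plans return as-is
  }
}
```

The saving compounds across the rule batches: a guarded rule pays one cheap bit-set membership check per subtree instead of a full traversal whenever its patterns are absent.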

Performance improvement (org.apache.spark.sql.TPCDSQuerySuite):
Rule name | Total Time (baseline) | Total Time (experiment) | experiment/baseline
--- | --- | --- | ---
ReplaceExpressions | 27546369 | 19753804 | 0.72
RewriteNonCorrelatedExists | 17304883 | 2086194 | 0.12
ComputeCurrentTime | 35751301 | 19984477 | 0.56
GetCurrentDatabaseAndCatalog | 37230787 | 18874013 | 0.51
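The totals are the cumulative per-rule times that `RuleExecutor`'s metering accumulates while compiling the suite's queries, reported in nanoseconds. As a hypothetical illustration (the local session and toy query are assumptions, not part of this patch), such numbers can be inspected like this:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.RuleExecutor

object RuleTimingDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("rule-timing").getOrCreate()
    spark.range(100).createOrReplaceTempView("t")

    RuleExecutor.resetMetrics()
    // Trigger analysis and optimization; current_date() exercises ComputeCurrentTime,
    // count_if() exercises ReplaceExpressions.
    spark.sql("SELECT current_date(), count_if(id > 50) FROM t").queryExecution.optimizedPlan

    // Dumps per-rule effective/total time (nanoseconds) and run counts.
    println(RuleExecutor.dumpTimeSpent())
    spark.stop()
  }
}
```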

### How was this patch tested?

Existing tests.

Closes #32461 from sigmod/finish_analysis.

Authored-by: Yingyi Bu <yingyi.bu@databricks.com>
Signed-off-by: Gengliang Wang <ltnwgl@gmail.com>
sigmod authored and gengliangwang committed May 11, 2021
1 parent c4ca232 commit 7c9a9ec
Showing 7 changed files with 28 additions and 4 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, QuaternaryLike, TernaryLike, TreeNode, UnaryLike}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern}
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
@@ -336,6 +337,8 @@ trait RuntimeReplaceable extends UnaryExpression with Unevaluable {

   override def sql: String = mkString(exprsReplaced.map(_.sql))

+  final override val nodePatterns: Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE)
+
   def mkString(childrenString: Seq[String]): String = {
     prettyName + childrenString.mkString("(", ", ", ")")
   }
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala

@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate

 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ImplicitCastInputTypes, UnevaluableAggregate}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{COUNT_IF, TreePattern}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, LongType}

@@ -48,6 +49,8 @@ case class CountIf(predicate: Expression) extends UnevaluableAggregate with Impl

   override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType)

+  final override val nodePatterns: Seq[TreePattern] = Seq(COUNT_IF)
+
   override def checkInputDataTypes(): TypeCheckResult = predicate.dataType match {
     case BooleanType =>
       TypeCheckResult.TypeCheckSuccess
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala

@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate

 import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{BOOL_AGG, TreePattern}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.types._

@@ -31,6 +32,8 @@ abstract class UnevaluableBooleanAggBase(arg: Expression)

   override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType)

+  final override val nodePatterns: Seq[TreePattern] = Seq(BOOL_AGG)
+
   override def checkInputDataTypes(): TypeCheckResult = {
     arg.dataType match {
       case dt if dt != BooleanType =>
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala

@@ -27,6 +27,7 @@ import org.apache.commons.text.StringEscapeUtils
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{CURRENT_LIKE, TreePattern}
 import org.apache.spark.sql.catalyst.util.{DateTimeUtils, LegacyDateFormats, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
@@ -86,6 +87,7 @@ case class CurrentTimeZone() extends LeafExpression with Unevaluable {
   override def nullable: Boolean = false
   override def dataType: DataType = StringType
   override def prettyName: String = "current_timezone"
+  final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
 }

 /**
@@ -122,6 +124,8 @@ case class CurrentDate(timeZoneId: Option[String] = None)

   override def dataType: DataType = DateType

+  final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
+
   override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
     copy(timeZoneId = Option(timeZoneId))

@@ -135,6 +139,7 @@ abstract class CurrentTimestampLike() extends LeafExpression with CodegenFallbac
   override def nullable: Boolean = false
   override def dataType: DataType = TimestampType
   override def eval(input: InternalRow): Any = currentTimestamp()
+  final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
 }

 /**
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala

@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{CURRENT_LIKE, TreePattern}
 import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -164,6 +165,7 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable {
   override def dataType: DataType = StringType
   override def nullable: Boolean = false
   override def prettyName: String = "current_database"
+  final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
 }

 /**
@@ -182,6 +184,7 @@ case class CurrentCatalog() extends LeafExpression with Unevaluable {
   override def dataType: DataType = StringType
   override def nullable: Boolean = false
   override def prettyName: String = "current_catalog"
+  final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
 }

 // scalastyle:off line.size.limit
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala

@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.types._
@@ -42,7 +43,8 @@ import org.apache.spark.sql.types._
  * how RuntimeReplaceable does.
  */
 object ReplaceExpressions extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
+    _.containsAnyPattern(RUNTIME_REPLACEABLE, COUNT_IF, BOOL_AGG)) {
     case e: RuntimeReplaceable => e.child
     case CountIf(predicate) => Count(new NullIf(predicate, Literal.FalseLiteral))
     case BoolOr(arg) => Max(arg)
@@ -57,7 +59,8 @@ object ReplaceExpressions extends Rule[LogicalPlan] {
  * WHERE (SELECT 1 FROM (SELECT A FROM TABLE B WHERE COL1 > 10) LIMIT 1) IS NOT NULL
  */
 object RewriteNonCorrelatedExists extends Rule[LogicalPlan] {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
+    _.containsPattern(EXISTS_SUBQUERY)) {
     case exists: Exists if exists.children.isEmpty =>
       IsNotNull(
         ScalarSubquery(
@@ -77,7 +80,7 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
     val currentTime = Literal.create(timestamp, timeExpr.dataType)
     val timezone = Literal.create(conf.sessionLocalTimeZone, StringType)

-    plan transformAllExpressions {
+    plan.transformAllExpressionsWithPruning(_.containsPattern(CURRENT_LIKE)) {
       case currentDate @ CurrentDate(Some(timeZoneId)) =>
         currentDates.getOrElseUpdate(timeZoneId, {
           Literal.create(
@@ -101,7 +104,7 @@ case class GetCurrentDatabaseAndCatalog(catalogManager: CatalogManager) extends
     val currentNamespace = catalogManager.currentNamespace.quoted
     val currentCatalog = catalogManager.currentCatalog.name()

-    plan transformAllExpressions {
+    plan.transformAllExpressionsWithPruning(_.containsPattern(CURRENT_LIKE)) {
       case CurrentDatabase() =>
         Literal.create(currentNamespace, StringType)
       case CurrentCatalog() =>
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePattern.scala

@@ -28,11 +28,14 @@ object TreePattern extends Enumeration {
   val APPEND_COLUMNS: Value = Value
   val BINARY_ARITHMETIC: Value = Value
   val BINARY_COMPARISON: Value = Value
+  val BOOL_AGG: Value = Value
   val CASE_WHEN: Value = Value
   val CAST: Value = Value
   val CONCAT: Value = Value
   val COUNT: Value = Value
+  val COUNT_IF: Value = Value
   val CREATE_NAMED_STRUCT: Value = Value
+  val CURRENT_LIKE: Value = Value
   val DESERIALIZE_TO_OBJECT: Value = Value
   val DYNAMIC_PRUNING_SUBQUERY: Value = Value
   val EXISTS_SUBQUERY = Value
@@ -54,6 +57,7 @@ object TreePattern extends Enumeration {
   val SERIALIZE_FROM_OBJECT: Value = Value
   val OUTER_REFERENCE: Value = Value
   val PLAN_EXPRESSION: Value = Value
+  val RUNTIME_REPLACEABLE: Value = Value
   val SCALAR_SUBQUERY: Value = Value
   val TRUE_OR_FALSE_LITERAL: Value = Value
   val WINDOW_EXPRESSION: Value = Value
