Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, QuaternaryLike, TernaryLike, TreeNode, UnaryLike}
import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -336,6 +337,8 @@ trait RuntimeReplaceable extends UnaryExpression with Unevaluable {

/** SQL text of this expression: the SQL of every expression in `exprsReplaced`, formatted by `mkString`. */
override def sql: String = {
  val replacedSql = exprsReplaced.map(expr => expr.sql)
  mkString(replacedSql)
}

// Tags this node with RUNTIME_REPLACEABLE so pattern-based pruning (e.g. `containsAnyPattern`) can detect it.
final override val nodePatterns: Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE)

/**
 * Renders `prettyName` followed by the given strings, comma-separated and wrapped in parentheses.
 *
 * @param childrenString already-rendered strings, one per argument
 * @return text of the form `name(a, b, c)`; `name()` when the sequence is empty
 */
def mkString(childrenString: Seq[String]): String = {
  val arguments = childrenString.mkString(", ")
  s"$prettyName($arguments)"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ImplicitCastInputTypes, UnevaluableAggregate}
import org.apache.spark.sql.catalyst.trees.TreePattern.{COUNT_IF, TreePattern}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, LongType}

Expand Down Expand Up @@ -48,6 +49,8 @@ case class CountIf(predicate: Expression) extends UnevaluableAggregate with Impl

// Expected input types: exactly one argument, of BooleanType.
override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType)

// Tags this node with COUNT_IF so pattern-based pruning (e.g. `containsAnyPattern`) can detect it.
final override val nodePatterns: Seq[TreePattern] = Seq(COUNT_IF)

override def checkInputDataTypes(): TypeCheckResult = predicate.dataType match {
case BooleanType =>
TypeCheckResult.TypeCheckSuccess
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.TreePattern.{BOOL_AGG, TreePattern}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.types._

Expand All @@ -31,6 +32,8 @@ abstract class UnevaluableBooleanAggBase(arg: Expression)

// Expected input types: exactly one argument, of BooleanType.
override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType)

// Tags this node with BOOL_AGG so pattern-based pruning (e.g. `containsAnyPattern`) can detect it.
final override val nodePatterns: Seq[TreePattern] = Seq(BOOL_AGG)

override def checkInputDataTypes(): TypeCheckResult = {
arg.dataType match {
case dt if dt != BooleanType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.commons.text.StringEscapeUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreePattern.{CURRENT_LIKE, TreePattern}
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, LegacyDateFormats, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
Expand Down Expand Up @@ -86,6 +87,7 @@ case class CurrentTimeZone() extends LeafExpression with Unevaluable {
// Declared non-nullable.
override def nullable: Boolean = false
// The result type is a string.
override def dataType: DataType = StringType
// Display name reported for this expression.
override def prettyName: String = "current_timezone"
// Tagged CURRENT_LIKE so traversals pruned on that pattern still reach this node.
final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
}

/**
Expand Down Expand Up @@ -122,6 +124,8 @@ case class CurrentDate(timeZoneId: Option[String] = None)

// The result type is a date.
override def dataType: DataType = DateType

// Tagged CURRENT_LIKE so traversals pruned on that pattern still reach this node.
final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)

/**
 * Returns a copy of this expression whose `timeZoneId` is the given zone id wrapped with `Option`
 * (so a null argument becomes `None`).
 */
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
  val zone = Option(timeZoneId)
  copy(timeZoneId = zone)
}

Expand All @@ -135,6 +139,7 @@ abstract class CurrentTimestampLike() extends LeafExpression with CodegenFallbac
// Declared non-nullable.
override def nullable: Boolean = false
// The result type is a timestamp.
override def dataType: DataType = TimestampType
// Evaluates by calling `currentTimestamp()`; the input row is not consulted.
override def eval(input: InternalRow): Any = currentTimestamp()
// Tagged CURRENT_LIKE so traversals pruned on that pattern still reach this node.
final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreePattern.{CURRENT_LIKE, TreePattern}
import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -164,6 +165,7 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable {
// The result type is a string.
override def dataType: DataType = StringType
// Declared non-nullable.
override def nullable: Boolean = false
// Display name reported for this expression.
override def prettyName: String = "current_database"
// Tagged CURRENT_LIKE so traversals pruned on that pattern still reach this node.
final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
}

/**
Expand All @@ -182,6 +184,7 @@ case class CurrentCatalog() extends LeafExpression with Unevaluable {
// The result type is a string.
override def dataType: DataType = StringType
// Declared non-nullable.
override def nullable: Boolean = false
// Display name reported for this expression.
override def prettyName: String = "current_catalog"
// Tagged CURRENT_LIKE so traversals pruned on that pattern still reach this node.
final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
}

// scalastyle:off line.size.limit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.types._
Expand All @@ -42,7 +43,8 @@ import org.apache.spark.sql.types._
* how RuntimeReplaceable does.
*/
object ReplaceExpressions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
_.containsAnyPattern(RUNTIME_REPLACEABLE, COUNT_IF, BOOL_AGG)) {
case e: RuntimeReplaceable => e.child
case CountIf(predicate) => Count(new NullIf(predicate, Literal.FalseLiteral))
case BoolOr(arg) => Max(arg)
Expand All @@ -57,7 +59,8 @@ object ReplaceExpressions extends Rule[LogicalPlan] {
* WHERE (SELECT 1 FROM (SELECT A FROM TABLE B WHERE COL1 > 10) LIMIT 1) IS NOT NULL
*/
object RewriteNonCorrelatedExists extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
_.containsPattern(EXISTS_SUBQUERY)) {
case exists: Exists if exists.children.isEmpty =>
IsNotNull(
ScalarSubquery(
Expand All @@ -77,7 +80,7 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
val currentTime = Literal.create(timestamp, timeExpr.dataType)
val timezone = Literal.create(conf.sessionLocalTimeZone, StringType)

plan transformAllExpressions {
plan.transformAllExpressionsWithPruning(_.containsPattern(CURRENT_LIKE)) {
case currentDate @ CurrentDate(Some(timeZoneId)) =>
currentDates.getOrElseUpdate(timeZoneId, {
Literal.create(
Expand All @@ -101,7 +104,7 @@ case class GetCurrentDatabaseAndCatalog(catalogManager: CatalogManager) extends
val currentNamespace = catalogManager.currentNamespace.quoted
val currentCatalog = catalogManager.currentCatalog.name()

plan transformAllExpressions {
plan.transformAllExpressionsWithPruning(_.containsPattern(CURRENT_LIKE)) {
case CurrentDatabase() =>
Literal.create(currentNamespace, StringType)
case CurrentCatalog() =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@ object TreePattern extends Enumeration {
// Pattern identifiers. Every member carries an explicit `: Value` annotation so the public type
// of each identifier is stated rather than inferred.
val APPEND_COLUMNS: Value = Value
val BINARY_ARITHMETIC: Value = Value
val BINARY_COMPARISON: Value = Value
// Reported by UnevaluableBooleanAggBase.nodePatterns.
val BOOL_AGG: Value = Value
val CASE_WHEN: Value = Value
val CAST: Value = Value
val CONCAT: Value = Value
val COUNT: Value = Value
// Reported by CountIf.nodePatterns.
val COUNT_IF: Value = Value
val CREATE_NAMED_STRUCT: Value = Value
// Reported by the "current_*" expressions (CurrentDate, CurrentDatabase, CurrentCatalog, ...).
val CURRENT_LIKE: Value = Value
val DESERIALIZE_TO_OBJECT: Value = Value
val DYNAMIC_PRUNING_SUBQUERY: Value = Value
// Annotation added for consistency with the sibling values; the inferred type was already `Value`.
val EXISTS_SUBQUERY: Value = Value
Expand All @@ -54,6 +57,7 @@ object TreePattern extends Enumeration {
val SERIALIZE_FROM_OBJECT: Value = Value
val OUTER_REFERENCE: Value = Value
val PLAN_EXPRESSION: Value = Value
// Reported by RuntimeReplaceable.nodePatterns.
val RUNTIME_REPLACEABLE: Value = Value
val SCALAR_SUBQUERY: Value = Value
val TRUE_OR_FALSE_LITERAL: Value = Value
val WINDOW_EXPRESSION: Value = Value
Expand Down