diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
index ef9c4b9af40d3..242c799dd226e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -17,14 +17,16 @@
 package org.apache.spark.sql.catalyst.optimizer
 
-import scala.collection.mutable
+import java.time.{Instant, LocalDateTime}
 
 import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.trees.TreePattern._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ}
+import org.apache.spark.sql.catalyst.trees.TreePatternBits
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, instantToMicros, localDateTimeToMicros}
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -73,29 +75,30 @@ object RewriteNonCorrelatedExists extends Rule[LogicalPlan] {
  */
 object ComputeCurrentTime extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = {
-    val currentDates = mutable.Map.empty[String, Literal]
-    val timeExpr = CurrentTimestamp()
-    val timestamp = timeExpr.eval(EmptyRow).asInstanceOf[Long]
-    val currentTime = Literal.create(timestamp, timeExpr.dataType)
+    val instant = Instant.now()
+    val currentTimestampMicros = instantToMicros(instant)
+    val currentTime = Literal.create(currentTimestampMicros, TimestampType)
     val timezone = Literal.create(conf.sessionLocalTimeZone, StringType)
-    val localTimestamps = mutable.Map.empty[String, Literal]
 
-    plan.transformAllExpressionsWithPruning(_.containsPattern(CURRENT_LIKE)) {
-      case currentDate @ CurrentDate(Some(timeZoneId)) =>
-        currentDates.getOrElseUpdate(timeZoneId, {
-          Literal.create(currentDate.eval().asInstanceOf[Int], DateType)
-        })
-      case CurrentTimestamp() | Now() => currentTime
-      case CurrentTimeZone() => timezone
-      case localTimestamp @ LocalTimestamp(Some(timeZoneId)) =>
-        localTimestamps.getOrElseUpdate(timeZoneId, {
-          Literal.create(localTimestamp.eval().asInstanceOf[Long], TimestampNTZType)
-        })
+    def transformCondition(treePatternBits: TreePatternBits): Boolean = {
+      treePatternBits.containsPattern(CURRENT_LIKE)
+    }
+
+    plan.transformDownWithSubqueries(transformCondition) {
+      case subQuery =>
+        subQuery.transformAllExpressionsWithPruning(transformCondition) {
+          case cd: CurrentDate =>
+            Literal.create(DateTimeUtils.microsToDays(currentTimestampMicros, cd.zoneId), DateType)
+          case CurrentTimestamp() | Now() => currentTime
+          case CurrentTimeZone() => timezone
+          case localTimestamp: LocalTimestamp =>
+            val asDateTime = LocalDateTime.ofInstant(instant, localTimestamp.zoneId)
+            Literal.create(localDateTimeToMicros(asDateTime), TimestampNTZType)
+        }
     }
   }
 }
 
-
 /**
  * Replaces the expression of CurrentDatabase with the current database name.
  * Replaces the expression of CurrentCatalog with the current catalog name.
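Note: the rule now reads the clock exactly once per `apply()` and derives every current_date/current_timestamp/localtimestamp literal from that single `Instant`, instead of re-evaluating per expression and caching per time zone. A minimal standalone sketch of that idea (not part of the patch; it only assumes the `DateTimeUtils` helpers imported above):

```scala
import java.time.{Instant, LocalDateTime, ZoneId}

import org.apache.spark.sql.catalyst.util.DateTimeUtils

// Sketch: capture one Instant per query and derive every literal from it.
object SingleInstantSketch {
  def main(args: Array[String]): Unit = {
    val instant = Instant.now() // captured once, like at the top of apply()
    val micros = DateTimeUtils.instantToMicros(instant)

    // Zone-dependent values may differ per zone, but they all describe the
    // same moment, so repeated occurrences can never drift apart.
    val utcDate = DateTimeUtils.microsToDays(micros, ZoneId.of("UTC"))
    val tokyoDate = DateTimeUtils.microsToDays(micros, ZoneId.of("Asia/Tokyo"))
    val tokyoNtz = DateTimeUtils.localDateTimeToMicros(
      LocalDateTime.ofInstant(instant, ZoneId.of("Asia/Tokyo")))
    println(s"utcDate=$utcDate tokyoDate=$tokyoDate tokyoNtz=$tokyoNtz")
  }
}
```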
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 0f8df5df3764a..d0283f4d36720 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -454,7 +454,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
    * to rewrite the whole plan, including its subqueries, in one go.
    */
   def transformWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType =
-    transformDownWithSubqueries(f)
+    transformDownWithSubqueries(AlwaysProcess.fn, UnknownRuleId)(f)
 
   /**
    * Returns a copy of this node where the given partial function has been recursively applied
@@ -479,7 +479,10 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
    * first to this node, then this node's subqueries and finally this node's children.
    * When the partial function does not apply to a given node, it is left unchanged.
    */
-  def transformDownWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = {
+  def transformDownWithSubqueries(
+      cond: TreePatternBits => Boolean = AlwaysProcess.fn, ruleId: RuleId = UnknownRuleId)
+      (f: PartialFunction[PlanType, PlanType])
+    : PlanType = {
     val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] {
       override def isDefinedAt(x: PlanType): Boolean = true
 
@@ -487,13 +490,13 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
         val transformed = f.applyOrElse[PlanType, PlanType](plan, identity)
         transformed transformExpressionsDown {
           case planExpression: PlanExpression[PlanType] =>
-            val newPlan = planExpression.plan.transformDownWithSubqueries(f)
+            val newPlan = planExpression.plan.transformDownWithSubqueries(cond, ruleId)(f)
             planExpression.withNewPlan(newPlan)
         }
       }
     }
 
-    transformDown(g)
+    transformDownWithPruning(cond, ruleId)(g)
   }
 
   /**
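The new overload keeps the old no-argument behaviour via defaults while letting callers pass a pruning condition and rule id, so subquery plans are skipped under the same condition as the outer plan. A hedged usage sketch (hypothetical helper name; the call shape matches what ComputeCurrentTime does above):

```scala
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE

object PrunedSubqueryTransformSketch {
  // Only descend into subtrees (and the plans inside subquery expressions)
  // whose tree-pattern bits say a CURRENT_LIKE expression may be present;
  // every other subtree is skipped without being visited.
  def touchCurrentLike(plan: LogicalPlan): LogicalPlan =
    plan.transformDownWithSubqueries(_.containsPattern(CURRENT_LIKE)) {
      case p: Project => p // placeholder body; a real rule would rewrite here
    }
}
```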
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
index 9b04dcddfb2ce..c034906c09bb6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
@@ -19,10 +19,13 @@ package org.apache.spark.sql.catalyst.optimizer
 import java.time.{LocalDateTime, ZoneId}
 
+import scala.collection.JavaConverters.mapAsScalaMap
+import scala.concurrent.duration._
+
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, Literal, LocalTimestamp}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, InSubquery, ListQuery, Literal, LocalTimestamp, Now}
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.internal.SQLConf
 
@@ -41,11 +44,7 @@ class ComputeCurrentTimeSuite extends PlanTest {
     val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
     val max = (System.currentTimeMillis() + 1) * 1000
 
-    val lits = new scala.collection.mutable.ArrayBuffer[Long]
-    plan.transformAllExpressions { case e: Literal =>
-      lits += e.value.asInstanceOf[Long]
-      e
-    }
+    val lits = literals[Long](plan)
     assert(lits.size == 2)
     assert(lits(0) >= min && lits(0) <= max)
     assert(lits(1) >= min && lits(1) <= max)
@@ -59,11 +58,7 @@ class ComputeCurrentTimeSuite extends PlanTest {
     val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
     val max = DateTimeUtils.currentDate(ZoneId.systemDefault())
 
-    val lits = new scala.collection.mutable.ArrayBuffer[Int]
-    plan.transformAllExpressions { case e: Literal =>
-      lits += e.value.asInstanceOf[Int]
-      e
-    }
+    val lits = literals[Int](plan)
     assert(lits.size == 2)
     assert(lits(0) >= min && lits(0) <= max)
     assert(lits(1) >= min && lits(1) <= max)
@@ -73,13 +68,9 @@ class ComputeCurrentTimeSuite extends PlanTest {
   test("SPARK-33469: Add current_timezone function") {
     val in = Project(Seq(Alias(CurrentTimeZone(), "c")()), LocalRelation())
     val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
-    val lits = new scala.collection.mutable.ArrayBuffer[String]
-    plan.transformAllExpressions { case e: Literal =>
-      lits += e.value.asInstanceOf[UTF8String].toString
-      e
-    }
+    val lits = literals[UTF8String](plan)
     assert(lits.size == 1)
-    assert(lits.head == SQLConf.get.sessionLocalTimeZone)
+    assert(lits.head == UTF8String.fromString(SQLConf.get.sessionLocalTimeZone))
   }
 
   test("analyzer should replace localtimestamp with literals") {
@@ -92,14 +83,66 @@ class ComputeCurrentTimeSuite extends PlanTest {
     val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
     val max = DateTimeUtils.localDateTimeToMicros(LocalDateTime.now(zoneId))
 
-    val lits = new scala.collection.mutable.ArrayBuffer[Long]
-    plan.transformAllExpressions { case e: Literal =>
-      lits += e.value.asInstanceOf[Long]
-      e
-    }
+    val lits = literals[Long](plan)
     assert(lits.size == 2)
     assert(lits(0) >= min && lits(0) <= max)
     assert(lits(1) >= min && lits(1) <= max)
     assert(lits(0) == lits(1))
   }
+
+  test("analyzer should use equal timestamps across subqueries") {
+    val timestampInSubQuery = Project(Seq(Alias(LocalTimestamp(), "timestamp1")()), LocalRelation())
+    val listSubQuery = ListQuery(timestampInSubQuery)
+    val valueSearchedInSubQuery = Seq(Alias(LocalTimestamp(), "timestamp2")())
+    val inFilterWithSubQuery = InSubquery(valueSearchedInSubQuery, listSubQuery)
+    val input = Project(Nil, Filter(inFilterWithSubQuery, LocalRelation()))
+
+    val plan = Optimize.execute(input.analyze).asInstanceOf[Project]
+
+    val lits = literals[Long](plan)
+    assert(lits.size == 3) // transformDownWithSubqueries visits the inner timestamp twice
+    assert(lits.toSet.size == 1)
+  }
+
+  test("analyzer should use consistent timestamps for different timezones") {
+    val localTimestamps = mapAsScalaMap(ZoneId.SHORT_IDS)
+      .map { case (zoneId, _) => Alias(LocalTimestamp(Some(zoneId)), zoneId)() }.toSeq
+    val input = Project(localTimestamps, LocalRelation())
+
+    val plan = Optimize.execute(input).asInstanceOf[Project]
+
+    val lits = literals[Long](plan)
+    assert(lits.size === localTimestamps.size)
+    // some timezones have a 30 or 45 minute offset, so compare modulo a quarter hour
+    val offsetsFromQuarterHour = lits.map(_ % Duration(15, MINUTES).toMicros).toSet
+    assert(offsetsFromQuarterHour.size == 1)
+  }
+
+  test("analyzer should use consistent timestamps for different timestamp functions") {
+    val differentTimestamps = Seq(
+      Alias(CurrentTimestamp(), "currentTimestamp")(),
+      Alias(Now(), "now")(),
+      Alias(LocalTimestamp(Some("PLT")), "localTimestampWithTimezone")()
+    )
+    val input = Project(differentTimestamps, LocalRelation())
+
+    val plan = Optimize.execute(input).asInstanceOf[Project]
+
+    val lits = literals[Long](plan)
+    assert(lits.size === differentTimestamps.size)
+    // some timezones have a 30 or 45 minute offset, so compare modulo a quarter hour
+    val offsetsFromQuarterHour = lits.map(_ % Duration(15, MINUTES).toMicros).toSet
+    assert(offsetsFromQuarterHour.size == 1)
+  }
+
+  private def literals[T](plan: LogicalPlan): Seq[T] = {
+    val literals = new scala.collection.mutable.ArrayBuffer[T]
+    plan.transformWithSubqueries { case subQuery =>
+      subQuery.transformAllExpressions { case expression: Literal =>
+        literals += expression.value.asInstanceOf[T]
+        expression
+      }
+    }
+    literals.toSeq
+  }
 }
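The modulo-15-minutes assertion above works because real-world zone offsets are whole multiples of 15 minutes (including the 30- and 45-minute cases), so values derived from one instant fall into a single residue class. A self-contained sketch of that trick with illustrative numbers (not taken from the tests):

```scala
import scala.concurrent.duration._

object QuarterHourTrickSketch {
  def main(args: Array[String]): Unit = {
    val quarterHourMicros = Duration(15, MINUTES).toMicros
    val baseMicros = 1700000000000000L // arbitrary instant, in microseconds
    // Local-time micros for that same instant in UTC, UTC+9 and UTC+5:45.
    val perZone = Seq(
      baseMicros,
      baseMicros + 9.hours.toMicros,
      baseMicros + (5.hours + 45.minutes).toMicros)
    // A single residue class modulo a quarter hour => one shared instant.
    assert(perZone.map(_ % quarterHourMicros).toSet.size == 1)
    println("all zone-local values share one quarter-hour residue")
  }
}
```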