[SPARK-45655][SQL][SS] Allow non-deterministic expressions inside AggregateFunctions in CollectMetrics

### What changes were proposed in this pull request?

This PR allows non-deterministic expressions wrapped inside an `AggregateFunction` (such as `count`) in a `CollectMetrics` node. `CollectMetrics` is used to collect arbitrary metrics from a query; in certain scenarios the user would like to collect metrics about rows filtered out by non-deterministic expressions (see the query example below).

Currently, the Analyzer does not allow non-deterministic expressions inside an `AggregateFunction` for `CollectMetrics`. This constraint is relaxed to allow collecting such metrics. Note that the metrics are relevant to a completed batch and can change if the batch is replayed, because a non-deterministic expression can behave differently across runs.

While working on this feature, I found an issue with the `checkMetric` logic that validates non-deterministic expressions inside an `AggregateExpression`. An expression is considered non-deterministic if any of its children is non-deterministic, hence the `!e.deterministic && !seenAggregate` case needs to be matched after the `AggregateExpression` case has been matched. If the current expression is an `AggregateExpression`, we should recurse and validate further down the tree; otherwise we fail for any non-deterministic expression.
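
For illustration, here is a minimal sketch of the intended match order (simplified, not the actual `CheckAnalysis` code — see the diff below; the error handling is just a placeholder):

```scala
// Simplified sketch: the non-deterministic check must come after the aggregate
// case, otherwise an expression such as Count(ts >= CurrentBatchTimestamp(...))
// is rejected at the root, because e.deterministic is false whenever any child
// is non-deterministic.
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}

def checkMetricSketch(e: Expression, seenAggregate: Boolean = false): Unit = e match {
  case _: AggregateExpression | _: AggregateFunction =>
    // inside an aggregate function, non-determinism is allowed further down
    e.children.foreach(checkMetricSketch(_, seenAggregate = true))
  case _: Attribute if !seenAggregate =>
    sys.error("unaggregated attribute is not a valid metric")
  case a: Alias =>
    checkMetricSketch(a.child, seenAggregate)
  case _ if !e.deterministic && !seenAggregate =>
    sys.error("non-deterministic expression outside an aggregate function")
  case _ =>
    e.children.foreach(checkMetricSketch(_, seenAggregate))
}
```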

```

val inputData = MemoryStream[Timestamp]

inputData.toDF()
      .filter("value < current_date()")
      .observe("metrics", count(expr("value >= current_date()")).alias("dropped"))
      .writeStream
      .queryName("ts_metrics_test")
      .format("memory")
      .outputMode("append")
      .start()

```
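
For context (not part of this PR's diff): once such a metric is collected, it can be read back from the query progress. A rough sketch, assuming `query` is the `StreamingQuery` returned by `start()` above and using the observation/metric names from the example:

```scala
// Sketch: sum the observed "dropped" count across all progress updates.
// StreamingQueryProgress.observedMetrics is a java.util.Map[String, Row],
// keyed by the observation name ("metrics" in the example above).
val droppedRows: Long = query.recentProgress
  .flatMap(p => Option(p.observedMetrics.get("metrics")))
  .map(_.getAs[Long]("dropped"))
  .sum
```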

### Why are the changes needed?

1. Added a test case that counts dropped rows (using `CurrentBatchTimestamp`) and ensures the query succeeds.

As an example, the query below fails (without this change) because of the `observe` call on the DataFrame.

```

val inputData = MemoryStream[Timestamp]

inputData.toDF()
      .filter("value < current_date()")
      .observe("metrics", count(expr("value >= current_date()")).alias("dropped"))
      .writeStream
      .queryName("ts_metrics_test")
      .format("memory")
      .outputMode("append")
      .start()

```
2. Added tests in `AnalysisSuite` for non-deterministic expressions inside an `AggregateFunction`.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit test cases added.

```

[warn] 20 warnings found
WARNING: Using incubator modules: jdk.incubator.vector, jdk.incubator.foreign
[info] StreamingQueryStatusAndProgressSuite:
09:14:39.684 WARN org.apache.hadoop.util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
[info] Passed: Total 0, Failed 0, Errors 0, Passed 0
[info] No tests to run for hive / Test / testOnly
[info] - StreamingQueryProgress - prettyJson (436 milliseconds)
[info] - StreamingQueryProgress - json (3 milliseconds)
[info] - StreamingQueryProgress - toString (5 milliseconds)
[info] - StreamingQueryProgress - jsonString and fromJson (163 milliseconds)
[info] - StreamingQueryStatus - prettyJson (1 millisecond)
[info] - StreamingQueryStatus - json (1 millisecond)
[info] - StreamingQueryStatus - toString (2 milliseconds)
09:14:41.674 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /Users/bhuwan.sahni/workspace/spark/target/tmp/temporary-34d2749f-f4d0-46d8-bc51-29da6411e1c5. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
09:14:41.710 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - progress classes should be Serializable (5 seconds, 552 milliseconds)
09:14:46.345 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /Users/bhuwan.sahni/workspace/spark/target/tmp/temporary-3a41d397-c3c1-490b-9cc7-d775b0c42208. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
09:14:46.345 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - SPARK-19378: Continue reporting stateOp metrics even if there is no active trigger (1 second, 337 milliseconds)
09:14:47.677 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - SPARK-29973: Make `processedRowsPerSecond` calculated more accurately and meaningfully (455 milliseconds)
09:14:48.174 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /Users/bhuwan.sahni/workspace/spark/target/tmp/temporary-360fc3b9-a2c5-430c-a892-c9869f1f8339. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
09:14:48.174 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - SPARK-45655: Use current batch timestamp in observe API (587 milliseconds)
09:14:48.768 WARN org.apache.spark.sql.streaming.StreamingQueryStatusAndProgressSuite:

```

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #43517 from sahnib/SPARK-45655.

Authored-by: Bhuwan Sahni <bhuwan.sahni@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
sahnib authored and HeartSaVioR committed Nov 12, 2023
1 parent f9c8c7a commit 2605b87
Showing 3 changed files with 74 additions and 12 deletions.

```diff
@@ -23,7 +23,7 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Median, PercentileCont, PercentileDisc}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Median, PercentileCont, PercentileDisc}
 import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, DecorrelateInnerQuery, InlineCTE}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -476,10 +476,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
               e.failAnalysis(
                 "INVALID_OBSERVED_METRICS.WINDOW_EXPRESSIONS_UNSUPPORTED",
                 Map("expr" -> toSQLExpr(s)))
-            case _ if !e.deterministic && !seenAggregate =>
-              e.failAnalysis(
-                "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_NON_DETERMINISTIC",
-                Map("expr" -> toSQLExpr(s)))
             case a: AggregateExpression if seenAggregate =>
               e.failAnalysis(
                 "INVALID_OBSERVED_METRICS.NESTED_AGGREGATES_UNSUPPORTED",
@@ -492,12 +488,18 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
               e.failAnalysis(
                 "INVALID_OBSERVED_METRICS.AGGREGATE_EXPRESSION_WITH_FILTER_UNSUPPORTED",
                 Map("expr" -> toSQLExpr(s)))
+            case _: AggregateExpression | _: AggregateFunction =>
+              e.children.foreach(checkMetric (s, _, seenAggregate = true))
             case _: Attribute if !seenAggregate =>
               e.failAnalysis(
                 "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE",
                 Map("expr" -> toSQLExpr(s)))
-            case _: AggregateExpression =>
-              e.children.foreach(checkMetric (s, _, seenAggregate = true))
+            case a: Alias =>
+              checkMetric(s, a.child, seenAggregate)
+            case a if !e.deterministic && !seenAggregate =>
+              e.failAnalysis(
+                "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_NON_DETERMINISTIC",
+                Map("expr" -> toSQLExpr(s)))
             case _ =>
               e.children.foreach(checkMetric (s, _, seenAggregate))
           }
@@ -734,8 +736,13 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
               "dataType" -> toSQLType(mapCol.dataType)))
 
       case o if o.expressions.exists(!_.deterministic) &&
-        !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] &&
-        !o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] &&
+        !o.isInstanceOf[Project] &&
+        // non-deterministic expressions inside CollectMetrics have been
+        // already validated inside checkMetric function
+        !o.isInstanceOf[CollectMetrics] &&
+        !o.isInstanceOf[Filter] &&
+        !o.isInstanceOf[Aggregate] &&
+        !o.isInstanceOf[Window] &&
         !o.isInstanceOf[Expand] &&
         !o.isInstanceOf[Generate] &&
         !o.isInstanceOf[CreateVariable] &&
```

```diff
@@ -794,9 +794,20 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     // No columns
     assert(!CollectMetrics("evt", Nil, testRelation, 0).resolved)
 
-    def checkAnalysisError(exprs: Seq[NamedExpression], errors: String*): Unit = {
-      assertAnalysisError(CollectMetrics("event", exprs, testRelation, 0), errors)
-    }
+    // non-deterministic expression inside an aggregate function is valid
+    val tsLiteral = Literal.create(java.sql.Timestamp.valueOf("2023-11-30 21:05:00.000000"),
+      TimestampType)
+
+    assertAnalysisSuccess(
+      CollectMetrics(
+        "invalid",
+        Count(
+          GreaterThan(tsLiteral, CurrentBatchTimestamp(1699485296000L, TimestampType))
+        ).as("count") :: Nil,
+        testRelation,
+        0
+      )
+    )
 
     // Unwrapped attribute
     assertAnalysisErrorClass(
```

```diff
@@ -17,6 +17,9 @@
 
 package org.apache.spark.sql.streaming
 
+import java.sql.Timestamp
+import java.time.Instant
+import java.time.temporal.ChronoUnit
 import java.util.UUID
 
 import scala.jdk.CollectionConverters._
@@ -355,6 +358,47 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
     )
   }
 
+  test("SPARK-45655: Use current batch timestamp in observe API") {
+    import testImplicits._
+
+    val inputData = MemoryStream[Timestamp]
+
+    // current_date() internally uses current batch timestamp on streaming query
+    val query = inputData.toDF()
+      .filter("value < current_date()")
+      .observe("metrics", count(expr("value >= current_date()")).alias("dropped"))
+      .writeStream
+      .queryName("ts_metrics_test")
+      .format("memory")
+      .outputMode("append")
+      .start()
+
+    val timeNow = Instant.now().truncatedTo(ChronoUnit.SECONDS)
+
+    // this value would be accepted by the filter and would not count towards
+    // dropped metrics.
+    val validValue = Timestamp.from(timeNow.minus(2, ChronoUnit.DAYS))
+    inputData.addData(validValue)
+
+    // would be dropped by the filter and count towards dropped metrics
+    inputData.addData(Timestamp.from(timeNow.plus(2, ChronoUnit.DAYS)))
+
+    query.processAllAvailable()
+    query.stop()
+
+    val dropped = query.recentProgress.map { p =>
+      val metricVal = Option(p.observedMetrics.get("metrics"))
+      metricVal.map(_.getLong(0)).getOrElse(0L)
+    }.sum
+    // ensure dropped metrics are correct
+    assert(dropped == 1)
+
+    val data = spark.read.table("ts_metrics_test").collect()
+
+    // ensure valid value ends up in output
+    assert(data(0).getAs[Timestamp](0).equals(validValue))
+  }
+
   def waitUntilBatchProcessed: AssertOnQuery = Execute { q =>
     eventually(Timeout(streamingTimeout)) {
       if (q.exception.isEmpty) {
```
