[SPARK-47545][CONNECT] Dataset observe support for the Scala client
### What changes were proposed in this pull request?

This PR adds support for `Dataset.observe` to the Spark Connect Scala client. Note that this does not include listener support, because listeners run on the server side.

This PR also includes a small refactoring of the `Observation` helper class: methods that are not bound to a SparkSession are extracted into `spark-api`, and a subclass is added in each of `spark-core` and `spark-jvm-client`.
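For orientation, here is a simplified sketch of that split. It is not the actual `spark-api` code; the member signatures are assumptions, and the real `ObservationBase` carries more logic (for example futures and timeouts):

```scala
import java.util.UUID

// Session-agnostic part, shared by spark-core and the Connect Scala client (sketch only).
abstract class ObservationBase(val name: String) {
  private var metrics: Option[Map[String, Any]] = None

  /** Blocks until the observed metrics have been set by a finished query. */
  def get: Map[String, Any] = synchronized {
    while (metrics.isEmpty) wait()
    metrics.get
  }

  /** Called by the result-handling path once metrics become available. */
  def setMetricsAndNotify(value: Option[Map[String, Any]]): Unit = synchronized {
    metrics = value
    notifyAll()
  }
}

// Thin Connect-client subclass (mirrors the new Observation class added in this diff).
class Observation(name: String) extends ObservationBase(name) {
  def this() = this(UUID.randomUUID().toString)
}
```

Keeping the blocking logic in a shared base class lets both the classic and Connect `Observation` expose the same `get` semantics, while each module supplies only session-specific wiring.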

### Why are the changes needed?

Before this PR, the `DF.observe` method was only supported in the Python client.

### Does this PR introduce _any_ user-facing change?
Yes. The user can now call `DF.observe(name, metrics...)` or `DF.observe(observationObject, metrics...)` to compute metrics over the columns of a DataFrame.
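For illustration, a minimal sketch of both call forms. It assumes an existing Connect `SparkSession` named `spark`; the metric expressions and the values in the comments are made up for the example, while `collectResult().getObservedMetrics` and `Observation.get` follow the Scaladoc and tests in this diff:

```scala
import org.apache.spark.sql.Observation
import org.apache.spark.sql.functions._

// Assumes `spark` is an existing Spark Connect session.
val df = spark.range(100)

// Named form: metrics are read from the collected result.
val named = df.observe("my_metrics", count(lit(1)).as("rows"), max(col("id")).as("maxid"))
val namedMetrics = named.collectResult().getObservedMetrics // Map("my_metrics" -> Row(100, 99))

// Observation form: Observation.get blocks until the query has finished.
val ob = Observation("ob")
val observed = df.observe(ob, min(col("id")), avg(col("id")), max(col("id")))
observed.collect()
println(ob.get) // Map("min(id)" -> 0, "avg(id)" -> 49.5, "max(id)" -> 99)
```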

### How was this patch tested?

Added new e2e tests.

### Was this patch authored or co-authored using generative AI tooling?

Nope.

Closes apache#45701 from xupefei/scala-observe.

Authored-by: Paddy Xu <xupaddy@gmail.com>
Signed-off-by: Herman van Hovell <herman@databricks.com>
xupefei authored and JacobZheng0927 committed May 11, 2024
1 parent acd315f commit ae05782
Showing 14 changed files with 448 additions and 166 deletions.
@@ -3337,8 +3337,69 @@ class Dataset[T] private[sql] (
}
}

/**
* Define (named) metrics to observe on the Dataset. This method returns an 'observed' Dataset
* that returns the same result as the input, with the following guarantees: <ul> <li>It will
* compute the defined aggregates (metrics) on all the data that is flowing through the Dataset
* at that point.</li> <li>It will report the value of the defined aggregate columns as soon as
* we reach a completion point. A completion point is currently defined as the end of a
* query.</li> </ul> Please note that continuous execution is currently not supported.
*
* The metrics columns must either contain a literal (e.g. lit(42)), or should contain one or
* more aggregate functions (e.g. sum(a) or sum(a + b) + avg(c) - lit(1)). Expressions that
* contain references to the input Dataset's columns must always be wrapped in an aggregate
* function.
*
* A user can retrieve the metrics by calling
* `org.apache.spark.sql.Dataset.collectResult().getObservedMetrics`.
*
* {{{
* // Observe row count (rows) and highest id (maxid) in the Dataset while writing it
* val observed_ds = ds.observe("my_metrics", count(lit(1)).as("rows"), max($"id").as("maxid"))
* observed_ds.write.parquet("ds.parquet")
* val metrics = observed_ds.collectResult().getObservedMetrics
* }}}
*
* @group typedrel
* @since 4.0.0
*/
@scala.annotation.varargs
def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = {
- throw new UnsupportedOperationException("observe is not implemented.")
sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getCollectMetricsBuilder
.setInput(plan.getRoot)
.setName(name)
.addAllMetrics((expr +: exprs).map(_.expr).asJava)
}
}

/**
* Observe (named) metrics through an `org.apache.spark.sql.Observation` instance. This is
* equivalent to calling `observe(String, Column, Column*)` but does not require collecting all
* results before returning the metrics: the metrics are filled in while the results are being
* iterated, as soon as they become available. This method does not support streaming datasets.
*
* A user can retrieve the metrics by accessing `org.apache.spark.sql.Observation.get`.
*
* {{{
* // Observe row count (rows) and highest id (maxid) in the Dataset while writing it
* val observation = Observation("my_metrics")
* val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), max($"id").as("maxid"))
* observed_ds.write.parquet("ds.parquet")
* val metrics = observation.get
* }}}
*
* @throws IllegalArgumentException
* If this is a streaming Dataset (this.isStreaming == true)
*
* @group typedrel
* @since 4.0.0
*/
@scala.annotation.varargs
def observe(observation: Observation, expr: Column, exprs: Column*): Dataset[T] = {
val df = observe(observation.name, expr, exprs: _*)
sparkSession.registerObservation(df.getPlanId.get, observation)
df
}

def checkpoint(): Dataset[T] = {
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import java.util.UUID

class Observation(name: String) extends ObservationBase(name) {

/**
* Create an Observation instance without providing a name. This generates a random name.
*/
def this() = this(UUID.randomUUID().toString)
}

/**
* (Scala-specific) Create instances of Observation via Scala `apply`.
* @since 4.0.0
*/
object Observation {

/**
* Observation constructor for creating an anonymous observation.
*/
def apply(): Observation = new Observation()

/**
* Observation constructor for creating a named observation.
*/
def apply(name: String): Observation = new Observation(name)

}
@@ -18,6 +18,7 @@ package org.apache.spark.sql

import java.io.Closeable
import java.net.URI
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.TimeUnit._
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}

@@ -36,7 +37,7 @@ import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedLongEncoder, UnboundRowEncoder}
- import org.apache.spark.sql.connect.client.{ClassFinder, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.functions.lit
Expand Down Expand Up @@ -80,6 +81,8 @@ class SparkSession private[sql] (
client.analyze(proto.AnalyzePlanRequest.AnalyzeCase.SPARK_VERSION).getSparkVersion.getVersion
}

private[sql] val observationRegistry = new ConcurrentHashMap[Long, Observation]()

/**
* Runtime configuration interface for Spark.
*
@@ -532,8 +535,12 @@

private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): SparkResult[T] = {
val value = client.execute(plan)
- val result = new SparkResult(value, allocator, encoder, timeZoneId)
- result
new SparkResult(
value,
allocator,
encoder,
timeZoneId,
Some(setMetricsAndUnregisterObservation))
}

private[sql] def execute(f: proto.Relation.Builder => Unit): Unit = {
@@ -554,6 +561,9 @@
client.execute(plan).filter(!_.hasExecutionProgress).toSeq
}

private[sql] def execute(plan: proto.Plan): CloseableIterator[ExecutePlanResponse] =
client.execute(plan)

private[sql] def registerUdf(udf: proto.CommonInlineUserDefinedFunction): Unit = {
val command = proto.Command.newBuilder().setRegisterFunction(udf).build()
execute(command)
@@ -779,6 +789,21 @@ class SparkSession private[sql] (
* Set to false to prevent client.releaseSession on close() (testing only)
*/
private[sql] var releaseSessionOnClose = true

private[sql] def registerObservation(planId: Long, observation: Observation): Unit = {
if (observationRegistry.putIfAbsent(planId, observation) != null) {
throw new IllegalArgumentException("An Observation can be used with a Dataset only once")
}
}

private[sql] def setMetricsAndUnregisterObservation(
planId: Long,
metrics: Map[String, Any]): Unit = {
val observationOrNull = observationRegistry.remove(planId)
if (observationOrNull != null) {
observationOrNull.setMetricsAndNotify(Some(metrics))
}
}
}

// The minimal builder needed to create a spark session.
@@ -22,6 +22,8 @@ import java.time.DateTimeException
import java.util.Properties

import scala.collection.mutable
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.DurationInt
import scala.jdk.CollectionConverters._

import org.apache.commons.io.FileUtils
@@ -41,6 +43,7 @@ import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.test.{IntegrationTestUtils, RemoteSparkSession, SQLHelper}
import org.apache.spark.sql.test.SparkConnectServerUtils.port
import org.apache.spark.sql.types._
import org.apache.spark.util.SparkThreadUtils

class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester {

@@ -1511,6 +1514,46 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
(0 until 5).foreach(i => assert(row.get(i * 2) === row.get(i * 2 + 1)))
}
}

test("Observable metrics") {
val df = spark.range(99).withColumn("extra", col("id") - 1)
val ob1 = new Observation("ob1")
val observedDf = df.observe(ob1, min("id"), avg("id"), max("id"))
val observedObservedDf = observedDf.observe("ob2", min("extra"), avg("extra"), max("extra"))

val ob1Schema = new StructType()
.add("min(id)", LongType)
.add("avg(id)", DoubleType)
.add("max(id)", LongType)
val ob2Schema = new StructType()
.add("min(extra)", LongType)
.add("avg(extra)", DoubleType)
.add("max(extra)", LongType)
val ob1Metrics = Map("ob1" -> new GenericRowWithSchema(Array(0, 49, 98), ob1Schema))
val ob2Metrics = Map("ob2" -> new GenericRowWithSchema(Array(-1, 48, 97), ob2Schema))

assert(df.collectResult().getObservedMetrics === Map.empty)
assert(observedDf.collectResult().getObservedMetrics === ob1Metrics)
assert(observedObservedDf.collectResult().getObservedMetrics === ob1Metrics ++ ob2Metrics)
}

test("Observation.get is blocked until the query is finished") {
val df = spark.range(99).withColumn("extra", col("id") - 1)
val observation = new Observation("ob1")
val observedDf = df.observe(observation, min("id"), avg("id"), max("id"))

// Start a new thread to get the observation
val future = Future(observation.get)(ExecutionContext.global)
// make sure the thread is blocked right now
val e = intercept[java.util.concurrent.TimeoutException] {
SparkThreadUtils.awaitResult(future, 2.seconds)
}
assert(e.getMessage.contains("Future timed out"))
observedDf.collect()
// make sure the thread is unblocked after the query is finished
val metrics = SparkThreadUtils.awaitResult(future, 2.seconds)
assert(metrics === Map("min(id)" -> 0, "avg(id)" -> 49, "max(id)" -> 98))
}
}

private[sql] case class ClassData(a: String, b: Int)
@@ -196,9 +196,6 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.COL_POS_KEY"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.DATASET_ID_KEY"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.curId"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.observe"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener$"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.queryExecution"),
@@ -434,6 +434,7 @@ message ExecutePlanResponse {
string name = 1;
repeated Expression.Literal values = 2;
repeated string keys = 3;
int64 plan_id = 4;
}

message ResultComplete {
@@ -27,18 +27,22 @@ import org.apache.arrow.vector.ipc.message.{ArrowMessage, ArrowRecordBatch}
import org.apache.arrow.vector.types.pojo

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse.ObservedMetrics
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator, ArrowDeserializingIterator, ConcatenatingArrowStreamReader, MessageIterator}
- import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, LiteralValueProtoConverter}
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.ArrowUtils

private[sql] class SparkResult[T](
responses: CloseableIterator[proto.ExecutePlanResponse],
allocator: BufferAllocator,
encoder: AgnosticEncoder[T],
- timeZoneId: String)
timeZoneId: String,
setObservationMetricsOpt: Option[(Long, Map[String, Any]) => Unit] = None)
extends AutoCloseable { self =>

case class StageInfo(
@@ -79,6 +83,7 @@ private[sql] class SparkResult[T](
private[this] var arrowSchema: pojo.Schema = _
private[this] var nextResultIndex: Int = 0
private val resultMap = mutable.Map.empty[Int, (Long, Seq[ArrowMessage])]
private val observedMetrics = mutable.Map.empty[String, Row]
private val cleanable =
SparkResult.cleaner.register(this, new SparkResultCloseable(resultMap, responses))

@@ -117,6 +122,9 @@
while (!stop && responses.hasNext) {
val response = responses.next()

// Collect metrics for this response
observedMetrics ++= processObservedMetrics(response.getObservedMetricsList)

// Save and validate operationId
if (opId == null) {
opId = response.getOperationId
@@ -198,6 +206,29 @@
nonEmpty
}

private def processObservedMetrics(
metrics: java.util.List[ObservedMetrics]): Iterable[(String, Row)] = {
metrics.asScala.map { metric =>
assert(metric.getKeysCount == metric.getValuesCount)
var schema = new StructType()
val keys = mutable.ListBuffer.empty[String]
val values = mutable.ListBuffer.empty[Any]
(0 until metric.getKeysCount).map { i =>
val key = metric.getKeys(i)
val value = LiteralValueProtoConverter.toCatalystValue(metric.getValues(i))
schema = schema.add(key, LiteralValueProtoConverter.toDataType(value.getClass))
keys += key
values += value
}
// If the metrics were registered by an Observation object, attach them and unblock any
// blocked thread.
setObservationMetricsOpt.foreach { setObservationMetrics =>
setObservationMetrics(metric.getPlanId, keys.zip(values).toMap)
}
metric.getName -> new GenericRowWithSchema(values.toArray, schema)
}
}

/**
* Returns the number of elements in the result.
*/
@@ -248,6 +279,15 @@
result
}

/**
* Returns all observed metrics in the result.
*/
def getObservedMetrics: Map[String, Row] = {
// We need to process all responses to get all metrics.
processResponses()
observedMetrics.toMap
}

/**
* Returns an iterator over the contents of the result.
*/
@@ -204,7 +204,7 @@ object LiteralValueProtoConverter {
def toLiteralProto(literal: Any, dataType: DataType): proto.Expression.Literal =
toLiteralProtoBuilder(literal, dataType).build()

- private def toDataType(clz: Class[_]): DataType = clz match {
private[sql] def toDataType(clz: Class[_]): DataType = clz match {
// primitive types
case JShort.TYPE => ShortType
case JInteger.TYPE => IntegerType
@@ -220,6 +220,7 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
.createObservedMetricsResponse(
executeHolder.sessionHolder.sessionId,
executeHolder.sessionHolder.serverSessionId,
executeHolder.request.getPlan.getRoot.getCommon.getPlanId,
observedMetrics ++ accumulatedInPython))
}

