[SPARK-47545][CONNECT] Dataset observe support for the Scala client #45701

Closed
wants to merge 24 commits
Changes from all commits
@@ -3337,8 +3337,69 @@ class Dataset[T] private[sql] (
}
}

/**
* Define (named) metrics to observe on the Dataset. This method returns an 'observed' Dataset
* that returns the same result as the input, with the following guarantees: <ul> <li>It will
* compute the defined aggregates (metrics) on all the data that is flowing through the Dataset
* at that point.</li> <li>It will report the value of the defined aggregate columns as soon as
* we reach a completion point. A completion point is currently defined as the end of a
* query.</li> </ul> Please note that continuous execution is currently not supported.
*
* The metrics columns must either contain a literal (e.g. lit(42)), or should contain one or
* more aggregate functions (e.g. sum(a) or sum(a + b) + avg(c) - lit(1)). Expressions that
* contain references to the input Dataset's columns must always be wrapped in an aggregate
* function.
*
* A user can retrieve the metrics by calling
* `org.apache.spark.sql.Dataset.collectResult().getObservedMetrics`.
*
* {{{
* // Observe row count (rows) and highest id (maxid) in the Dataset while writing it
* val observed_ds = ds.observe("my_metrics", count(lit(1)).as("rows"), max($"id").as("maxid"))
* observed_ds.write.parquet("ds.parquet")
* val metrics = observed_ds.collectResult().getObservedMetrics
* }}}
*
* @group typedrel
* @since 4.0.0
*/
@scala.annotation.varargs
def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = {
throw new UnsupportedOperationException("observe is not implemented.")
sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getCollectMetricsBuilder
.setInput(plan.getRoot)
.setName(name)
.addAllMetrics((expr +: exprs).map(_.expr).asJava)
}
}
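
  // A standalone sketch (not part of this diff) of the CollectMetrics relation that the lambda
  // above builds; it assumes a proto.Relation.Builder like the one newDataset supplies, and the
  // helper name `collectMetricsRelation` is illustrative only.
  import scala.jdk.CollectionConverters._
  import org.apache.spark.connect.proto

  def collectMetricsRelation(
      input: proto.Relation,
      name: String,
      metrics: Seq[proto.Expression]): proto.Relation = {
    val builder = proto.Relation.newBuilder()
    builder.getCollectMetricsBuilder
      .setInput(input)               // child plan the metrics are computed over
      .setName(name)                 // key under which the metrics row is reported
      .addAllMetrics(metrics.asJava) // one literal or aggregate expression per metric
    builder.build()
  }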

/**
* Observe (named) metrics through an `org.apache.spark.sql.Observation` instance. This is
* equivalent to calling `observe(String, Column, Column*)` but does not require collecting all
* results before returning the metrics: the metrics are filled while the results are iterated,
* as soon as they become available. This method does not support streaming datasets.
*
* A user can retrieve the metrics by accessing `org.apache.spark.sql.Observation.get`.
*
* {{{
* // Observe row count (rows) and highest id (maxid) in the Dataset while writing it
* val observation = Observation("my_metrics")
* val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), max($"id").as("maxid"))
* observed_ds.write.parquet("ds.parquet")
* val metrics = observation.get
* }}}
*
* @throws IllegalArgumentException
* If this is a streaming Dataset (this.isStreaming == true)
*
* @group typedrel
* @since 4.0.0
*/
@scala.annotation.varargs
def observe(observation: Observation, expr: Column, exprs: Column*): Dataset[T] = {
val df = observe(observation.name, expr, exprs: _*)
sparkSession.registerObservation(df.getPlanId.get, observation)
df
}
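
  // A usage sketch, assuming `ds` is an existing Dataset, `spark.implicits._` is imported for
  // the $-column syntax, and the column names are illustrative; it mirrors the Scaladoc example
  // above rather than introducing new behaviour.
  import org.apache.spark.sql.functions.{count, lit, max}

  val observation = Observation("my_metrics")
  val observed = ds.observe(observation, count(lit(1)).as("rows"), max($"id").as("maxid"))
  observed.collect()            // executing the query delivers the metrics to the Observation
  val metrics = observation.get // blocks until the metrics are available
  val rowCount = metrics("rows")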

def checkpoint(): Dataset[T] = {
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import java.util.UUID

class Observation(name: String) extends ObservationBase(name) {

/**
* Create an Observation instance without providing a name. This generates a random name.
*/
def this() = this(UUID.randomUUID().toString)
}

/**
* (Scala-specific) Create instances of Observation via Scala `apply`.
* @since 4.0.0
*/
object Observation {

/**
* Observation constructor for creating an anonymous observation.
*/
def apply(): Observation = new Observation()

/**
* Observation constructor for creating a named observation.
*/
def apply(name: String): Observation = new Observation(name)

}
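
// A minimal sketch of the two constructors exposed by the companion object above.
val anonymous = Observation()         // name defaults to a random UUID
val named = Observation("my_metrics") // explicit name reported with the observed metrics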
@@ -18,6 +18,7 @@ package org.apache.spark.sql

import java.io.Closeable
import java.net.URI
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.TimeUnit._
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}

@@ -36,7 +37,7 @@ import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedLongEncoder, UnboundRowEncoder}
import org.apache.spark.sql.connect.client.{ClassFinder, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.functions.lit
@@ -80,6 +81,8 @@ class SparkSession private[sql] (
client.analyze(proto.AnalyzePlanRequest.AnalyzeCase.SPARK_VERSION).getSparkVersion.getVersion
}

private[sql] val observationRegistry = new ConcurrentHashMap[Long, Observation]()

/**
* Runtime configuration interface for Spark.
*
@@ -532,8 +535,12 @@ class SparkSession private[sql] (

private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): SparkResult[T] = {
val value = client.execute(plan)
val result = new SparkResult(value, allocator, encoder, timeZoneId)
result
new SparkResult(
value,
allocator,
encoder,
timeZoneId,
Some(setMetricsAndUnregisterObservation))
}

private[sql] def execute(f: proto.Relation.Builder => Unit): Unit = {
@@ -554,6 +561,9 @@ class SparkSession private[sql] (
client.execute(plan).filter(!_.hasExecutionProgress).toSeq
}

private[sql] def execute(plan: proto.Plan): CloseableIterator[ExecutePlanResponse] =
client.execute(plan)

private[sql] def registerUdf(udf: proto.CommonInlineUserDefinedFunction): Unit = {
val command = proto.Command.newBuilder().setRegisterFunction(udf).build()
execute(command)
@@ -779,6 +789,21 @@ class SparkSession private[sql] (
* Set to false to prevent client.releaseSession on close() (testing only)
*/
private[sql] var releaseSessionOnClose = true

private[sql] def registerObservation(planId: Long, observation: Observation): Unit = {
if (observationRegistry.putIfAbsent(planId, observation) != null) {
throw new IllegalArgumentException("An Observation can be used with a Dataset only once")
}
}

private[sql] def setMetricsAndUnregisterObservation(
planId: Long,
metrics: Map[String, Any]): Unit = {
val observationOrNull = observationRegistry.remove(planId)
if (observationOrNull != null) {
observationOrNull.setMetricsAndNotify(Some(metrics))
}
}
}
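
// A standalone sketch (outside this diff) of the registration pattern used above: putIfAbsent
// rejects reusing an Observation across Datasets, and remove() ensures the metrics for a plan
// id are delivered at most once. In SparkSession the callback is Observation.setMetricsAndNotify.
import java.util.concurrent.ConcurrentHashMap

class PlanObservationRegistry[O <: AnyRef](notify: (O, Map[String, Any]) => Unit) {
  private val registry = new ConcurrentHashMap[Long, O]()

  def register(planId: Long, observation: O): Unit = {
    if (registry.putIfAbsent(planId, observation) != null) {
      throw new IllegalArgumentException("An Observation can be used with a Dataset only once")
    }
  }

  def deliver(planId: Long, metrics: Map[String, Any]): Unit = {
    val observation = registry.remove(planId) // null when no Observation was registered
    if (observation != null) {
      notify(observation, metrics)
    }
  }
}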

// The minimal builder needed to create a spark session.
@@ -22,6 +22,8 @@ import java.time.DateTimeException
import java.util.Properties

import scala.collection.mutable
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.DurationInt
import scala.jdk.CollectionConverters._

import org.apache.commons.io.FileUtils
@@ -41,6 +43,7 @@ import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.test.{IntegrationTestUtils, RemoteSparkSession, SQLHelper}
import org.apache.spark.sql.test.SparkConnectServerUtils.port
import org.apache.spark.sql.types._
import org.apache.spark.util.SparkThreadUtils

class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester {

@@ -1511,6 +1514,46 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester {
(0 until 5).foreach(i => assert(row.get(i * 2) === row.get(i * 2 + 1)))
}
}

test("Observable metrics") {
val df = spark.range(99).withColumn("extra", col("id") - 1)
val ob1 = new Observation("ob1")
val observedDf = df.observe(ob1, min("id"), avg("id"), max("id"))
val observedObservedDf = observedDf.observe("ob2", min("extra"), avg("extra"), max("extra"))

val ob1Schema = new StructType()
.add("min(id)", LongType)
.add("avg(id)", DoubleType)
.add("max(id)", LongType)
val ob2Schema = new StructType()
.add("min(extra)", LongType)
.add("avg(extra)", DoubleType)
.add("max(extra)", LongType)
val ob1Metrics = Map("ob1" -> new GenericRowWithSchema(Array(0, 49, 98), ob1Schema))
val ob2Metrics = Map("ob2" -> new GenericRowWithSchema(Array(-1, 48, 97), ob2Schema))

assert(df.collectResult().getObservedMetrics === Map.empty)
assert(observedDf.collectResult().getObservedMetrics === ob1Metrics)
assert(observedObservedDf.collectResult().getObservedMetrics === ob1Metrics ++ ob2Metrics)
}

test("Observation.get is blocked until the query is finished") {
val df = spark.range(99).withColumn("extra", col("id") - 1)
val observation = new Observation("ob1")
val observedDf = df.observe(observation, min("id"), avg("id"), max("id"))

// Start a new thread to get the observation
val future = Future(observation.get)(ExecutionContext.global)
// Reviewer comment: "For the record. IMO the observation class should have been using a
// future from the get go."

// make sure the thread is blocked right now
val e = intercept[java.util.concurrent.TimeoutException] {
SparkThreadUtils.awaitResult(future, 2.seconds)
}
assert(e.getMessage.contains("Future timed out"))
observedDf.collect()
// make sure the thread is unblocked after the query is finished
val metrics = SparkThreadUtils.awaitResult(future, 2.seconds)
assert(metrics === Map("min(id)" -> 0, "avg(id)" -> 49, "max(id)" -> 98))
}
}

private[sql] case class ClassData(a: String, b: Int)
@@ -196,9 +196,6 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.COL_POS_KEY"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.DATASET_ID_KEY"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.curId"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.observe"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener$"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.queryExecution"),
@@ -434,6 +434,7 @@ message ExecutePlanResponse {
string name = 1;
repeated Expression.Literal values = 2;
repeated string keys = 3;
int64 plan_id = 4;
}

message ResultComplete {
@@ -27,18 +27,22 @@ import org.apache.arrow.vector.ipc.message.{ArrowMessage, ArrowRecordBatch}
import org.apache.arrow.vector.types.pojo

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse.ObservedMetrics
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator, ArrowDeserializingIterator, ConcatenatingArrowStreamReader, MessageIterator}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, LiteralValueProtoConverter}
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.ArrowUtils

private[sql] class SparkResult[T](
responses: CloseableIterator[proto.ExecutePlanResponse],
allocator: BufferAllocator,
encoder: AgnosticEncoder[T],
timeZoneId: String)
timeZoneId: String,
setObservationMetricsOpt: Option[(Long, Map[String, Any]) => Unit] = None)
extends AutoCloseable { self =>

case class StageInfo(
@@ -79,6 +83,7 @@ private[sql] class SparkResult[T](
private[this] var arrowSchema: pojo.Schema = _
private[this] var nextResultIndex: Int = 0
private val resultMap = mutable.Map.empty[Int, (Long, Seq[ArrowMessage])]
private val observedMetrics = mutable.Map.empty[String, Row]
private val cleanable =
SparkResult.cleaner.register(this, new SparkResultCloseable(resultMap, responses))

@@ -117,6 +122,9 @@ private[sql] class SparkResult[T](
while (!stop && responses.hasNext) {
val response = responses.next()

// Collect metrics for this response
observedMetrics ++= processObservedMetrics(response.getObservedMetricsList)

// Save and validate operationId
if (opId == null) {
opId = response.getOperationId
@@ -198,6 +206,29 @@ private[sql] class SparkResult[T](
nonEmpty
}

private def processObservedMetrics(
metrics: java.util.List[ObservedMetrics]): Iterable[(String, Row)] = {
metrics.asScala.map { metric =>
assert(metric.getKeysCount == metric.getValuesCount)
var schema = new StructType()
val keys = mutable.ListBuffer.empty[String]
val values = mutable.ListBuffer.empty[Any]
(0 until metric.getKeysCount).map { i =>
val key = metric.getKeys(i)
val value = LiteralValueProtoConverter.toCatalystValue(metric.getValues(i))
schema = schema.add(key, LiteralValueProtoConverter.toDataType(value.getClass))
// Reviewer comment: "There is a bit of a twist here: LiteralValueProtoConverter returns a
// Tuple for a nested struct, which is not really expected in a Row. We can address this in
// a follow-up."

keys += key
values += value
}
// If the metrics are registered by an Observation object, attach them and unblock any
// blocked thread.
setObservationMetricsOpt.foreach { setObservationMetrics =>
setObservationMetrics(metric.getPlanId, keys.zip(values).toMap)
}
metric.getName -> new GenericRowWithSchema(values.toArray, schema)
}
}

/**
* Returns the number of elements in the result.
*/
@@ -248,6 +279,15 @@ private[sql] class SparkResult[T](
result
}

/**
* Returns all observed metrics in the result.
*/
def getObservedMetrics: Map[String, Row] = {
// We need to process all responses to get all metrics.
processResponses()
observedMetrics.toMap
}
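
  // A usage sketch, assuming an observed Dataset named `observedDf` (see Dataset.observe);
  // getObservedMetrics drains the remaining responses first, so it can be called before or
  // after consuming the result rows.
  val result = observedDf.collectResult()
  val metrics: Map[String, Row] = result.getObservedMetrics
  metrics.get("my_metrics").foreach { row =>
    // each metric column is a field of the row, named after the expression or its alias
    println(row.schema.fieldNames.mkString(", "))
  }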

/**
* Returns an iterator over the contents of the result.
*/
@@ -204,7 +204,7 @@ object LiteralValueProtoConverter {
def toLiteralProto(literal: Any, dataType: DataType): proto.Expression.Literal =
toLiteralProtoBuilder(literal, dataType).build()

private def toDataType(clz: Class[_]): DataType = clz match {
private[sql] def toDataType(clz: Class[_]): DataType = clz match {
// primitive types
case JShort.TYPE => ShortType
case JInteger.TYPE => IntegerType
@@ -220,6 +220,7 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
.createObservedMetricsResponse(
executeHolder.sessionHolder.sessionId,
executeHolder.sessionHolder.serverSessionId,
executeHolder.request.getPlan.getRoot.getCommon.getPlanId,
observedMetrics ++ accumulatedInPython))
}
