
[SPARK-46686][PYTHON][CONNECT] Basic support of SparkSession based Python UDF profiler #44697

Closed
wants to merge 9 commits into from
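At a high level, this change lets a Spark Connect client collect Python UDF profiles gathered on executors and render them locally: the server keeps a per-session accumulator, drains it after each execution, and forwards the contents to the client as an observed-metrics message. A hedged end-to-end sketch of the intended usage follows; the `spark.sql.pyspark.udf.profiler` conf name is an assumption based on the related SparkSession-based profiler work rather than something this diff shows, while `showPerfProfiles` is the method this PR adds.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

# Assumes a Spark Connect session; the remote address is illustrative.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
spark.conf.set("spark.sql.pyspark.udf.profiler", "perf")  # assumed conf name

@udf("long")
def plus_one(x):
    return x + 1

# Running the UDF makes each Python worker accumulate cProfile results,
# which flow back to this client via the "__python_accumulator__" metric.
spark.range(10).select(plus_one("id")).collect()

# Render the accumulated per-UDF profiles on the client.
spark.showPerfProfiles()
```

The diff below wires this up end to end: a per-session `CollectionAccumulator` on the server, the plumbing that forwards its contents as observed metrics, and a client-side collector that decodes and renders the results.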
@@ -17,6 +17,7 @@

package org.apache.spark.sql.connect.execution

import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

import com.google.protobuf.Message
@@ -185,19 +186,34 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
s"${executeHolder.request.getPlan.getOpTypeCase} not supported.")
}

if (executeHolder.observations.nonEmpty) {
val observedMetrics = executeHolder.observations.map { case (name, observation) =>
val observedMetrics: Map[String, Seq[(Option[String], Any)]] = {
executeHolder.observations.map { case (name, observation) =>
val values = observation.getOrEmpty.map { case (key, value) =>
(Some(key), value)
}.toSeq
name -> values
}.toMap
}
val accumulatedInPython: Map[String, Seq[(Option[String], Any)]] = {
executeHolder.sessionHolder.pythonAccumulator.flatMap { accumulator =>
accumulator.synchronized {
val value = accumulator.value.asScala.toSeq
if (value.nonEmpty) {
accumulator.reset()
Some("__python_accumulator__" -> value.map(value => (None, value)))
} else {
None
}
}
}.toMap
}
if (observedMetrics.nonEmpty || accumulatedInPython.nonEmpty) {
executeHolder.responseObserver.onNext(
SparkConnectPlanExecution
.createObservedMetricsResponse(
executeHolder.sessionHolder.sessionId,
executeHolder.sessionHolder.serverSessionId,
observedMetrics))
observedMetrics ++ accumulatedInPython))
}

lock.synchronized {
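For context on what this hunk forwards: each element drained from `pythonAccumulator` is an opaque `Array[Byte]` produced by a Python worker, namely a pickled `(accumulator_id, profile_results)` pair. A minimal sketch of one such element as the client-side code later in this PR decodes it; the dict layout follows the `ProfileResults` type added in `_typing.pyi`, and the values are placeholders.

```python
import pickle

SQL_UDF_PROFIER = -1  # reserved id added in pyspark/accumulators.py below

# One accumulator element: profile results keyed by UDF id, each entry a
# (perf stats, memory code map) pair. None values are placeholders here.
element = pickle.dumps((SQL_UDF_PROFIER, {100: (None, None)}))

aid, update = pickle.loads(element)
assert aid == SQL_UDF_PROFIER  # marks this as a profiler update, not a user accumulator
```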
@@ -972,8 +972,8 @@ class SparkConnectPlanner(
pythonVer = fun.getPythonVer,
// Empty broadcast variables
broadcastVars = Lists.newArrayList(),
// Null accumulator
accumulator = null)
// Accumulator if available
accumulator = sessionHolder.pythonAccumulator.orNull)
}

private def transformCachedRemoteRelation(rel: proto.CachedRemoteRelation): LogicalPlan = {
@@ -1680,8 +1680,8 @@ class SparkConnectPlanner(
pythonVer = fun.getPythonVer,
// Empty broadcast variables
broadcastVars = Lists.newArrayList(),
// Null accumulator
accumulator = null)
// Accumulator if available
accumulator = sessionHolder.pythonAccumulator.orNull)
}

/**
@@ -24,11 +24,13 @@ import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.util.Try

import com.google.common.base.Ticker
import com.google.common.cache.CacheBuilder

import org.apache.spark.{SparkException, SparkSQLException}
import org.apache.spark.api.python.PythonFunction.PythonAccumulator
import org.apache.spark.internal.Logging
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession
@@ -371,6 +373,14 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
private[connect] def listListenerIds(): Seq[String] = {
listenerCache.keySet().asScala.toSeq
}

/**
* An accumulator for Python executors.
*
* The accumulated results will be sent to the Python client via observed_metrics message.
*/
private[connect] val pythonAccumulator: Option[PythonAccumulator] =
Try(session.sparkContext.collectionAccumulator[Array[Byte]]).toOption
Member

Hm, looks like we don't need Try(...) here? I took a cursory look, and it seems it won't throw an exception.

Member

BTW, if the profiler is disabled, we probably shouldn't create this accumulator, to avoid a performance issue.

Member Author (@ueshin, Jan 17, 2024)

looks like we don't need Try(...) here?

In some tests, mocks of session or sparkContext are used and they throw an exception when creating accumulators.

Member Author

if the profile is disabled, we shouldn't probably create this accumulator to avoid performance issue.

It always needs to have the accumulator because:

  • it can't know whether or when the profiler will be enabled
  • it has to support already-registered UDFs

What kind of performance issue are you concerned about?

Member

My concern is registering too many accumulators, because calling this will create and register an accumulator for each session. Especially for Spark Connect, there could be a lot of Spark sessions.

Member Author

There are already many more accumulators registered for each query, as SQLMetrics. I don't think one more accumulator per session would be an issue.

Member

👌

}

object SessionHolder {
11 changes: 8 additions & 3 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -34,6 +34,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat

import org.apache.spark._
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.api.python.PythonFunction.PythonAccumulator
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.input.PortableDataStream
import org.apache.spark.internal.Logging
@@ -83,7 +84,11 @@ private[spark] trait PythonFunction {
def pythonExec: String
def pythonVer: String
def broadcastVars: JList[Broadcast[PythonBroadcast]]
def accumulator: PythonAccumulatorV2
def accumulator: PythonAccumulator
}

private[spark] object PythonFunction {
type PythonAccumulator = CollectionAccumulator[Array[Byte]]
}

/**
@@ -96,7 +101,7 @@ private[spark] case class SimplePythonFunction(
pythonExec: String,
pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: PythonAccumulatorV2) extends PythonFunction {
accumulator: PythonAccumulator) extends PythonFunction {

def this(
command: Array[Byte],
@@ -105,7 +110,7 @@ private[spark] case class SimplePythonFunction(
pythonExec: String,
pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: PythonAccumulatorV2) = {
accumulator: PythonAccumulator) = {
this(command.toImmutableArraySeq,
envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator)
}
@@ -30,6 +30,7 @@ import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

import org.apache.spark._
import org.apache.spark.api.python.PythonFunction.PythonAccumulator
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES, Python}
import org.apache.spark.internal.config.Python._
@@ -146,10 +147,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
}.getOrElse("pyspark.worker")

// TODO: support accumulator in multiple UDF
protected val accumulator: PythonAccumulatorV2 = funcs.head.funcs.head.accumulator
protected val accumulator: PythonAccumulator = funcs.head.funcs.head.accumulator

// Python accumulator is always set in production except in tests. See SPARK-27893
private val maybeAccumulator: Option[PythonAccumulatorV2] = Option(accumulator)
private val maybeAccumulator: Option[PythonAccumulator] = Option(accumulator)

// Expose a ServerSocket to support method calls via socket from Python side. Only relevant for
// for tasks that are a part of barrier stage, refer [[BarrierTaskContext]] for details.
@@ -21,6 +21,7 @@ import java.io.{DataInputStream, DataOutputStream, File}
import java.nio.charset.StandardCharsets

import org.apache.spark.{SparkEnv, SparkFiles}
import org.apache.spark.api.python.PythonFunction.PythonAccumulator
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging

@@ -186,7 +187,8 @@ private[spark] object PythonWorkerUtils extends Logging {
* The updates are sent by `worker_util.send_accumulator_updates`.
*/
def receiveAccumulatorUpdates(
maybeAccumulator: Option[PythonAccumulatorV2], dataIn: DataInputStream): Unit = {
maybeAccumulator: Option[PythonAccumulator],
dataIn: DataInputStream): Unit = {
val numAccumulatorUpdates = dataIn.readInt()
(1 to numAccumulatorUpdates).foreach { _ =>
val update = readBytes(dataIn)
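The Scala side above reads an int count followed by one byte array per update; the corresponding writes happen in `worker_util.send_accumulator_updates` on the Python worker. Below is a self-contained sketch of that framing, not the real worker code; the big-endian, length-prefixed layout is an assumption inferred from `readInt`/`readBytes`.

```python
import pickle
import struct
from io import BytesIO

def send_accumulator_updates(updates, outfile):
    """Write accumulator updates as: int32 count, then length-prefixed pickled pairs."""
    outfile.write(struct.pack(">i", len(updates)))  # numAccumulatorUpdates
    for aid, value in updates:
        payload = pickle.dumps((aid, value))        # (accumulator id, accumulated value)
        outfile.write(struct.pack(">i", len(payload)))
        outfile.write(payload)

# Example: a single profiler update (id -1) carrying placeholder results.
buf = BytesIO()
send_accumulator_updates([(-1, {100: (None, None)})], buf)
```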
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -998,6 +998,7 @@ def __hash__(self):
"pyspark.sql.tests.connect.test_parity_column",
"pyspark.sql.tests.connect.test_parity_readwriter",
"pyspark.sql.tests.connect.test_parity_udf",
"pyspark.sql.tests.connect.test_parity_udf_profiler",
"pyspark.sql.tests.connect.test_parity_udtf",
"pyspark.sql.tests.connect.test_parity_pandas_udf",
"pyspark.sql.tests.connect.test_parity_pandas_map",
4 changes: 4 additions & 0 deletions python/pyspark/accumulators.py
@@ -57,6 +57,10 @@ def _deserialize_accumulator(
return accum


class SpecialAccumulatorIds:
SQL_UDF_PROFIER = -1


class Accumulator(Generic[T]):

"""
3 changes: 2 additions & 1 deletion python/pyspark/profiler.py
@@ -422,8 +422,9 @@ def stats(self) -> CodeMapDict:
"""Return the collected memory profiles"""
return cast(CodeMapDict, self._accumulator.value)

@staticmethod
def _show_results(
self, code_map: CodeMapDict, stream: Optional[Any] = None, precision: int = 1
code_map: CodeMapDict, stream: Optional[Any] = None, precision: int = 1
) -> None:
if stream is None:
stream = sys.stdout
5 changes: 5 additions & 0 deletions python/pyspark/sql/_typing.pyi
@@ -19,6 +19,7 @@
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
@@ -29,8 +30,10 @@ from typing_extensions import Literal, Protocol

import datetime
import decimal
import pstats

from pyspark._typing import PrimitiveType
from pyspark.profiler import CodeMapDict
import pyspark.sql.types
from pyspark.sql.column import Column

@@ -79,3 +82,5 @@ class UserDefinedFunctionLike(Protocol):
def returnType(self) -> pyspark.sql.types.DataType: ...
def __call__(self, *args: ColumnOrName) -> Column: ...
def asNondeterministic(self) -> UserDefinedFunctionLike: ...

ProfileResults = Dict[int, Tuple[Optional[pstats.Stats], Optional[CodeMapDict]]]
13 changes: 12 additions & 1 deletion python/pyspark/sql/connect/client/core.py
@@ -54,11 +54,13 @@
from google.protobuf import text_format, any_pb2
from google.rpc import error_details_pb2

from pyspark.accumulators import SpecialAccumulatorIds
from pyspark.loose_version import LooseVersion
from pyspark.version import __version__
from pyspark.resource.information import ResourceInformation
from pyspark.sql.connect.client.artifact import ArtifactManager
from pyspark.sql.connect.client.logging import logger
from pyspark.sql.connect.profiler import ConnectProfilerCollector
from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator
from pyspark.sql.connect.client.retries import RetryPolicy, Retrying, DefaultPolicy
from pyspark.sql.connect.conversion import storage_level_to_proto, proto_to_storage_level
@@ -636,6 +638,8 @@ class ClientThreadLocals(threading.local):
# be updated on the first response received.
self._server_session_id: Optional[str] = None

self._profiler_collector = ConnectProfilerCollector()

def _retrying(self) -> "Retrying":
return Retrying(self._retry_policies)

@@ -1169,7 +1173,14 @@ def handle_response(
if b.observed_metrics:
logger.debug("Received observed metric batch.")
for observed_metrics in self._build_observed_metrics(b.observed_metrics):
if observed_metrics.name in observations:
if observed_metrics.name == "__python_accumulator__":
from pyspark.worker_util import pickleSer

for metric in observed_metrics.metrics:
(aid, update) = pickleSer.loads(LiteralExpression._to_value(metric))
if aid == SpecialAccumulatorIds.SQL_UDF_PROFIER:
self._profiler_collector._update(update)
elif observed_metrics.name in observations:
observation_result = observations[observed_metrics.name]._result
assert observation_result is not None
observation_result.update(
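As the branch above shows, the reserved "__python_accumulator__" name is routed to the profiler collector, while every other observed-metrics name still resolves to a user Observation, so the two coexist in one query. A small sketch, assuming the usual Observation API and the Connect session from the earlier example:

```python
from pyspark.sql import Observation
from pyspark.sql.functions import count, udf

@udf("long")
def plus_one(x):
    return x + 1

obs = Observation("my_metrics")
df = spark.range(10).observe(obs, count("id").alias("rows")).select(plus_one("id"))
df.collect()

print(obs.get)            # user observation, delivered under the name "my_metrics"
spark.showPerfProfiles()  # profiler data, delivered under "__python_accumulator__"
```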
41 changes: 41 additions & 0 deletions python/pyspark/sql/connect/profiler.py
@@ -0,0 +1,41 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import TYPE_CHECKING

from pyspark.sql.profiler import ProfilerCollector, ProfileResultsParam

if TYPE_CHECKING:
from pyspark.sql._typing import ProfileResults


class ConnectProfilerCollector(ProfilerCollector):
"""
ProfilerCollector for Spark Connect.
"""

def __init__(self) -> None:
super().__init__()
self._value = ProfileResultsParam.zero(None)

@property
def _profile_results(self) -> "ProfileResults":
with self._lock:
return self._value if self._value is not None else {}

def _update(self, update: "ProfileResults") -> None:
with self._lock:
self._value = ProfileResultsParam.addInPlace(self._profile_results, update)
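For intuition about `_update`, here is a simplified, stand-alone approximation of the merge the collector performs. The real logic is `ProfileResultsParam.addInPlace` in `pyspark.sql.profiler`, which also combines `pstats.Stats` objects and memory code maps; this toy version only merges the per-UDF dictionaries.

```python
from typing import Any, Dict, Optional, Tuple

# ProfileResults maps a UDF id to (perf stats, memory code map); see the
# ProfileResults alias added to python/pyspark/sql/_typing.pyi in this PR.
ProfileResults = Dict[int, Tuple[Optional[Any], Optional[Any]]]

def merge_profile_results(acc: ProfileResults, update: ProfileResults) -> ProfileResults:
    # Toy merge: keep the first non-None component per UDF id. The real param
    # additionally adds perf stats together and merges memory code maps.
    merged = dict(acc)
    for udf_id, (perf, mem) in update.items():
        old_perf, old_mem = merged.get(udf_id, (None, None))
        merged[udf_id] = (old_perf or perf, old_mem or mem)
    return merged

print(merge_profile_results({100: ("perf-A", None)}, {101: (None, "mem-B")}))
# {100: ('perf-A', None), 101: (None, 'mem-B')}
```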
10 changes: 10 additions & 0 deletions python/pyspark/sql/connect/session.py
@@ -62,6 +62,7 @@
CachedRelation,
CachedRemoteRelation,
)
from pyspark.sql.connect.profiler import ProfilerCollector
from pyspark.sql.connect.readwriter import DataFrameReader
from pyspark.sql.connect.streaming.readwriter import DataStreamReader
from pyspark.sql.connect.streaming.query import StreamingQueryManager
@@ -919,6 +920,15 @@ def create_conf(**kwargs: Any) -> SparkConf:
def session_id(self) -> str:
return self._session_id

@property
def _profiler_collector(self) -> ProfilerCollector:
return self._client._profiler_collector

def showPerfProfiles(self, id: Optional[int] = None) -> None:
self._profiler_collector.show_perf_profiles(id)

showPerfProfiles.__doc__ = PySparkSession.showPerfProfiles.__doc__


SparkSession.__doc__ = PySparkSession.__doc__
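A brief usage note on the new method: calling it with no argument prints every perf profile collected in the session, while the optional id narrows the output to a single UDF (the full listing reports an id per profile). The id below is a placeholder.

```python
spark.showPerfProfiles()     # all collected perf profiles for this session
spark.showPerfProfiles(100)  # only the profile for one UDF (placeholder id)
```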
