[SPARK-46686][PYTHON][CONNECT] Basic support of SparkSession based Python UDF profiler

### What changes were proposed in this pull request?

Adds basic support for a SparkSession-based Python UDF profiler.

To enable the profiler, set the SQL conf `spark.sql.pyspark.udf.profiler`:

- `"perf"`: enables the performance profiler (cProfile)
- `"memory"`: enables the memory profiler (TODO: [SPARK-46687](https://issues.apache.org/jira/browse/SPARK-46687))

```py
from pyspark.sql.functions import *

spark.conf.set("spark.sql.pyspark.udf.profiler", "perf")  # enable cProfiler

udf("string")
def f(x):
      return str(x)

df = spark.range(10).select(f(col("id")))
df.collect()

pandas_udf("string")
def g(x):
     return x.astype("string")

df = spark.range(10).select(g(col("id")))

spark.conf.unset("spark.sql.pyspark.udf.profiler")  # disable

df.collect()  # won't profile

spark.showPerfProfiles()  # show the result for only the first collect.
```

### Why are the changes needed?

The existing UDF profilers are SparkContext-based, so they can't support Spark Connect.

We should introduce SparkSession-based profilers that also work with Spark Connect.
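
For illustration only, a minimal sketch (not taken from this PR) of how the same profiling flow is intended to work against a Spark Connect session; the remote URL is a placeholder:

```py
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf

# Connect to a Spark Connect server (placeholder URL).
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

spark.conf.set("spark.sql.pyspark.udf.profiler", "perf")  # enable the perf profiler

@udf("string")
def f(x):
    return str(x)

spark.range(10).select(f(col("id"))).collect()

# Profile results are accumulated on the server and sent back to the client
# via the observed_metrics message, then rendered here.
spark.showPerfProfiles()
```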

### Does this PR introduce _any_ user-facing change?

Yes, SparkSession-based UDF profilers will be available.

### How was this patch tested?

Added related tests, tested manually, and ran the existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #44697 from ueshin/issues/SPARK-46686/profiler.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
ueshin authored and HyukjinKwon committed Jan 17, 2024
1 parent 9a2f393 commit d8703dd
Showing 36 changed files with 797 additions and 119 deletions.
@@ -17,6 +17,7 @@

package org.apache.spark.sql.connect.execution

import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

import com.google.protobuf.Message
@@ -185,19 +186,34 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
s"${executeHolder.request.getPlan.getOpTypeCase} not supported.")
}

if (executeHolder.observations.nonEmpty) {
val observedMetrics = executeHolder.observations.map { case (name, observation) =>
val observedMetrics: Map[String, Seq[(Option[String], Any)]] = {
executeHolder.observations.map { case (name, observation) =>
val values = observation.getOrEmpty.map { case (key, value) =>
(Some(key), value)
}.toSeq
name -> values
}.toMap
}
val accumulatedInPython: Map[String, Seq[(Option[String], Any)]] = {
executeHolder.sessionHolder.pythonAccumulator.flatMap { accumulator =>
accumulator.synchronized {
val value = accumulator.value.asScala.toSeq
if (value.nonEmpty) {
accumulator.reset()
Some("__python_accumulator__" -> value.map(value => (None, value)))
} else {
None
}
}
}.toMap
}
if (observedMetrics.nonEmpty || accumulatedInPython.nonEmpty) {
executeHolder.responseObserver.onNext(
SparkConnectPlanExecution
.createObservedMetricsResponse(
executeHolder.sessionHolder.sessionId,
executeHolder.sessionHolder.serverSessionId,
observedMetrics))
observedMetrics ++ accumulatedInPython))
}

lock.synchronized {
@@ -972,8 +972,8 @@ class SparkConnectPlanner(
pythonVer = fun.getPythonVer,
// Empty broadcast variables
broadcastVars = Lists.newArrayList(),
// Null accumulator
accumulator = null)
// Accumulator if available
accumulator = sessionHolder.pythonAccumulator.orNull)
}

private def transformCachedRemoteRelation(rel: proto.CachedRemoteRelation): LogicalPlan = {
@@ -1680,8 +1680,8 @@ class SparkConnectPlanner(
pythonVer = fun.getPythonVer,
// Empty broadcast variables
broadcastVars = Lists.newArrayList(),
// Null accumulator
accumulator = null)
// Accumulator if available
accumulator = sessionHolder.pythonAccumulator.orNull)
}

/**
@@ -24,11 +24,13 @@ import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.util.Try

import com.google.common.base.Ticker
import com.google.common.cache.CacheBuilder

import org.apache.spark.{SparkException, SparkSQLException}
import org.apache.spark.api.python.PythonFunction.PythonAccumulator
import org.apache.spark.internal.Logging
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession
@@ -371,6 +373,14 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
private[connect] def listListenerIds(): Seq[String] = {
listenerCache.keySet().asScala.toSeq
}

/**
* An accumulator for Python executors.
*
* The accumulated results will be sent to the Python client via observed_metrics message.
*/
private[connect] val pythonAccumulator: Option[PythonAccumulator] =
Try(session.sparkContext.collectionAccumulator[Array[Byte]]).toOption
}

object SessionHolder {
11 changes: 8 additions & 3 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -34,6 +34,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat

import org.apache.spark._
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.api.python.PythonFunction.PythonAccumulator
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.input.PortableDataStream
import org.apache.spark.internal.Logging
@@ -83,7 +84,11 @@ private[spark] trait PythonFunction {
def pythonExec: String
def pythonVer: String
def broadcastVars: JList[Broadcast[PythonBroadcast]]
def accumulator: PythonAccumulatorV2
def accumulator: PythonAccumulator
}

private[spark] object PythonFunction {
type PythonAccumulator = CollectionAccumulator[Array[Byte]]
}

/**
@@ -96,7 +101,7 @@ private[spark] case class SimplePythonFunction(
pythonExec: String,
pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: PythonAccumulatorV2) extends PythonFunction {
accumulator: PythonAccumulator) extends PythonFunction {

def this(
command: Array[Byte],
@@ -105,7 +110,7 @@
pythonExec: String,
pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: PythonAccumulatorV2) = {
accumulator: PythonAccumulator) = {
this(command.toImmutableArraySeq,
envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator)
}
@@ -30,6 +30,7 @@ import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

import org.apache.spark._
import org.apache.spark.api.python.PythonFunction.PythonAccumulator
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES, Python}
import org.apache.spark.internal.config.Python._
@@ -146,10 +147,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
}.getOrElse("pyspark.worker")

// TODO: support accumulator in multiple UDF
protected val accumulator: PythonAccumulatorV2 = funcs.head.funcs.head.accumulator
protected val accumulator: PythonAccumulator = funcs.head.funcs.head.accumulator

// Python accumulator is always set in production except in tests. See SPARK-27893
private val maybeAccumulator: Option[PythonAccumulatorV2] = Option(accumulator)
private val maybeAccumulator: Option[PythonAccumulator] = Option(accumulator)

// Expose a ServerSocket to support method calls via socket from Python side. Only relevant for
// for tasks that are a part of barrier stage, refer [[BarrierTaskContext]] for details.
@@ -21,6 +21,7 @@ import java.io.{DataInputStream, DataOutputStream, File}
import java.nio.charset.StandardCharsets

import org.apache.spark.{SparkEnv, SparkFiles}
import org.apache.spark.api.python.PythonFunction.PythonAccumulator
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging

@@ -186,7 +187,8 @@ private[spark] object PythonWorkerUtils extends Logging {
* The updates are sent by `worker_util.send_accumulator_updates`.
*/
def receiveAccumulatorUpdates(
maybeAccumulator: Option[PythonAccumulatorV2], dataIn: DataInputStream): Unit = {
maybeAccumulator: Option[PythonAccumulator],
dataIn: DataInputStream): Unit = {
val numAccumulatorUpdates = dataIn.readInt()
(1 to numAccumulatorUpdates).foreach { _ =>
val update = readBytes(dataIn)
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -998,6 +998,7 @@ def __hash__(self):
"pyspark.sql.tests.connect.test_parity_column",
"pyspark.sql.tests.connect.test_parity_readwriter",
"pyspark.sql.tests.connect.test_parity_udf",
"pyspark.sql.tests.connect.test_parity_udf_profiler",
"pyspark.sql.tests.connect.test_parity_udtf",
"pyspark.sql.tests.connect.test_parity_pandas_udf",
"pyspark.sql.tests.connect.test_parity_pandas_map",
4 changes: 4 additions & 0 deletions python/pyspark/accumulators.py
@@ -57,6 +57,10 @@ def _deserialize_accumulator(
return accum


class SpecialAccumulatorIds:
SQL_UDF_PROFIER = -1


class Accumulator(Generic[T]):

"""
3 changes: 2 additions & 1 deletion python/pyspark/profiler.py
@@ -422,8 +422,9 @@ def stats(self) -> CodeMapDict:
"""Return the collected memory profiles"""
return cast(CodeMapDict, self._accumulator.value)

@staticmethod
def _show_results(
self, code_map: CodeMapDict, stream: Optional[Any] = None, precision: int = 1
code_map: CodeMapDict, stream: Optional[Any] = None, precision: int = 1
) -> None:
if stream is None:
stream = sys.stdout
5 changes: 5 additions & 0 deletions python/pyspark/sql/_typing.pyi
@@ -19,6 +19,7 @@
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
@@ -29,8 +30,10 @@ from typing_extensions import Literal, Protocol

import datetime
import decimal
import pstats

from pyspark._typing import PrimitiveType
from pyspark.profiler import CodeMapDict
import pyspark.sql.types
from pyspark.sql.column import Column

@@ -79,3 +82,5 @@ class UserDefinedFunctionLike(Protocol):
def returnType(self) -> pyspark.sql.types.DataType: ...
def __call__(self, *args: ColumnOrName) -> Column: ...
def asNondeterministic(self) -> UserDefinedFunctionLike: ...

ProfileResults = Dict[int, Tuple[Optional[pstats.Stats], Optional[CodeMapDict]]]
13 changes: 12 additions & 1 deletion python/pyspark/sql/connect/client/core.py
@@ -54,11 +54,13 @@
from google.protobuf import text_format, any_pb2
from google.rpc import error_details_pb2

from pyspark.accumulators import SpecialAccumulatorIds
from pyspark.loose_version import LooseVersion
from pyspark.version import __version__
from pyspark.resource.information import ResourceInformation
from pyspark.sql.connect.client.artifact import ArtifactManager
from pyspark.sql.connect.client.logging import logger
from pyspark.sql.connect.profiler import ConnectProfilerCollector
from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator
from pyspark.sql.connect.client.retries import RetryPolicy, Retrying, DefaultPolicy
from pyspark.sql.connect.conversion import storage_level_to_proto, proto_to_storage_level
@@ -636,6 +638,8 @@ class ClientThreadLocals(threading.local):
# be updated on the first response received.
self._server_session_id: Optional[str] = None

self._profiler_collector = ConnectProfilerCollector()

def _retrying(self) -> "Retrying":
return Retrying(self._retry_policies)

@@ -1169,7 +1173,14 @@ def handle_response(
if b.observed_metrics:
logger.debug("Received observed metric batch.")
for observed_metrics in self._build_observed_metrics(b.observed_metrics):
if observed_metrics.name in observations:
if observed_metrics.name == "__python_accumulator__":
from pyspark.worker_util import pickleSer

for metric in observed_metrics.metrics:
(aid, update) = pickleSer.loads(LiteralExpression._to_value(metric))
if aid == SpecialAccumulatorIds.SQL_UDF_PROFIER:
self._profiler_collector._update(update)
elif observed_metrics.name in observations:
observation_result = observations[observed_metrics.name]._result
assert observation_result is not None
observation_result.update(
41 changes: 41 additions & 0 deletions python/pyspark/sql/connect/profiler.py
@@ -0,0 +1,41 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import TYPE_CHECKING

from pyspark.sql.profiler import ProfilerCollector, ProfileResultsParam

if TYPE_CHECKING:
from pyspark.sql._typing import ProfileResults


class ConnectProfilerCollector(ProfilerCollector):
"""
ProfilerCollector for Spark Connect.
"""

def __init__(self) -> None:
super().__init__()
self._value = ProfileResultsParam.zero(None)

@property
def _profile_results(self) -> "ProfileResults":
with self._lock:
return self._value if self._value is not None else {}

def _update(self, update: "ProfileResults") -> None:
with self._lock:
self._value = ProfileResultsParam.addInPlace(self._profile_results, update)
10 changes: 10 additions & 0 deletions python/pyspark/sql/connect/session.py
@@ -62,6 +62,7 @@
CachedRelation,
CachedRemoteRelation,
)
from pyspark.sql.connect.profiler import ProfilerCollector
from pyspark.sql.connect.readwriter import DataFrameReader
from pyspark.sql.connect.streaming.readwriter import DataStreamReader
from pyspark.sql.connect.streaming.query import StreamingQueryManager
@@ -919,6 +920,15 @@ def create_conf(**kwargs: Any) -> SparkConf:
def session_id(self) -> str:
return self._session_id

@property
def _profiler_collector(self) -> ProfilerCollector:
return self._client._profiler_collector

def showPerfProfiles(self, id: Optional[int] = None) -> None:
self._profiler_collector.show_perf_profiles(id)

showPerfProfiles.__doc__ = PySparkSession.showPerfProfiles.__doc__


SparkSession.__doc__ = PySparkSession.__doc__

