5 changes: 5 additions & 0 deletions python/pyspark/profiler.py

@@ -15,6 +15,7 @@
 # limitations under the License.
 #

+from types import CodeType
 from typing import (
     Any,
     Callable,
@@ -278,6 +279,10 @@ def __init__(self, **kw: Any) -> None:
         backend = kw.get("backend", "psutil")
         self.code_map = CodeMapForUDFV2(include_children=include_children, backend=backend)

+    def add_code(self, code: CodeType) -> None:
+        """Record line profiling information for the given code object."""
+        self.code_map.add(code)
+

 class PStatsParam(AccumulatorParam[Optional[pstats.Stats]]):
     """PStatsParam is used to merge pstats.Stats"""
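A note on why add_code is needed alongside add_function: add_function registers a function's __code__, but when a UDF simply returns a generator produced by some other function, the frames that actually execute belong to that other function's code object, reachable via the generator's gi_code attribute. A minimal, dependency-free sketch of the mismatch (inner and outer are illustrative names, not part of the patch):

def inner(it):
    for x in it:
        yield x

def outer(it):
    # Returns a generator created by inner; outer's own code never iterates.
    return inner(it)

g = outer(range(3))
assert g.gi_code is inner.__code__      # the frames that run belong to inner
assert g.gi_code is not outer.__code__  # registering outer's code would record no line hits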
12 changes: 9 additions & 3 deletions python/pyspark/sql/profiler.py

@@ -20,7 +20,7 @@
 import os
 import pstats
 from threading import RLock
-from types import TracebackType
+from types import CodeType, TracebackType
 from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union, TYPE_CHECKING, overload
 import warnings

@@ -120,13 +120,19 @@ class WorkerMemoryProfiler:
     """

     def __init__(
-        self, accumulator: Accumulator["ProfileResults"], result_id: int, func: Callable
+        self,
+        accumulator: Accumulator["ProfileResults"],
+        result_id: int,
+        func_or_code: Union[Callable, CodeType],
     ) -> None:
         from pyspark.profiler import UDFLineProfilerV2

         self._accumulator = accumulator
         self._profiler = UDFLineProfilerV2()
-        self._profiler.add_function(func)
+        if isinstance(func_or_code, CodeType):
+            self._profiler.add_code(func_or_code)
+        else:
+            self._profiler.add_function(func_or_code)
         self._result_id = result_id

     def start(self) -> None:
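For reference, a self-contained sketch of the Union[Callable, CodeType] dispatch introduced above; resolve_code is a hypothetical helper used only for illustration:

from types import CodeType
from typing import Callable, Union

def resolve_code(func_or_code: Union[Callable, CodeType]) -> CodeType:
    # Mirror WorkerMemoryProfiler.__init__: accept either a plain function
    # or an already-extracted code object.
    if isinstance(func_or_code, CodeType):
        return func_or_code
    return func_or_code.__code__

def gen():
    yield 1

assert resolve_code(gen) is gen.__code__
assert resolve_code(gen().gi_code) is gen.__code__  # a generator made by gen runs gen's own code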
22 changes: 22 additions & 0 deletions python/pyspark/tests/test_memory_profiler.py

@@ -384,6 +384,28 @@ def filter_func(iterator):
         for id in self.profile_results:
             self.assert_udf_memory_profile_present(udf_id=id)

+    @unittest.skipIf(
+        not have_pandas or not have_pyarrow,
+        pandas_requirement_message or pyarrow_requirement_message,
+    )
+    def test_memory_profiler_different_function(self):
+        df = self.spark.createDataFrame([(1,), (2,), (3,)], ["x"])
+
+        def ident(batches):
+            for b in batches:
+                yield b
+
+        def func(batches):
+            return ident(batches)
+
+        with self.sql_conf({"spark.sql.pyspark.udf.profiler": "memory"}):
+            df.mapInArrow(func, schema="y long").show()
+
+        self.assertEqual(1, len(self.profile_results))
+
+        for id in self.profile_results:
+            self.assert_udf_memory_profile_present(udf_id=id)
+
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,
         pandas_requirement_message or pyarrow_requirement_message,
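Outside the test harness, the same scenario can be reproduced in a plain session. A hedged sketch, assuming the memory-profiler package is available on the workers and using an illustrative DataFrame and schema:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.pyspark.udf.profiler", "memory")

df = spark.range(3).toDF("x")

def ident(batches):
    for b in batches:
        yield b  # the per-batch work happens in ident's frames, not func's

def func(batches):
    return ident(batches)

df.mapInArrow(func, schema="x long").show()
spark.profile.show(type="memory")  # line-by-line memory stats keyed by UDF id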
5 changes: 3 additions & 2 deletions python/pyspark/worker.py

@@ -1375,11 +1375,12 @@ def wrap_memory_profiler(f, eval_type, result_id):
     if _is_iter_based(eval_type):

         def profiling_func(*args, **kwargs):
-            iterator = iter(f(*args, **kwargs))
+            g = f(*args, **kwargs)
+            iterator = iter(g)

             while True:
                 try:
-                    with WorkerMemoryProfiler(accumulator, result_id, f):
+                    with WorkerMemoryProfiler(accumulator, result_id, g.gi_code):
                         item = next(iterator)
                     yield item
                 except StopIteration:
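The wrapper profiles one next() call at a time, so the registration key must match the frames that resume inside next(): those frames belong to the generator g, whose code object is g.gi_code, regardless of which function returned it. A minimal sketch of the same wrap-each-next pattern, with a hypothetical StepTimer standing in for WorkerMemoryProfiler:

import time

class StepTimer:
    # Illustrative stand-in for WorkerMemoryProfiler: any context manager fits here.
    def __enter__(self):
        self._t0 = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        print(f"next() took {time.perf_counter() - self._t0:.6f}s")
        return False

def wrap(f):
    def profiling_func(*args, **kwargs):
        g = f(*args, **kwargs)
        iterator = iter(g)
        while True:
            try:
                # Only the producer's work inside next() runs under the context
                # manager; the consumer resumes after it has exited.
                with StepTimer():
                    item = next(iterator)
                yield item
            except StopIteration:
                return

    return profiling_func

@wrap
def numbers():
    for i in range(3):
        yield i * i

print(list(numbers()))  # a timing line per next() call, then [0, 1, 4]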