diff --git a/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py b/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
index 13edcec6b57f6..2bdd7bda3bc21 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
@@ -17,6 +17,7 @@
 import os
 import time
 import unittest
+import logging
 
 from pyspark.errors import PythonException
 from pyspark.sql import Row
@@ -26,6 +27,8 @@
     have_pyarrow,
     pyarrow_requirement_message,
 )
+from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.util import is_remote_only
 
 if have_pyarrow:
     import pyarrow as pa
@@ -367,6 +370,49 @@ def test_negative_and_zero_batch_size(self):
             with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
                 CogroupedMapInArrowTestsMixin.test_apply_in_arrow(self)
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_cogroup_apply_in_arrow_with_logging(self):
+        import pyarrow as pa
+
+        def func_with_logging(left, right):
+            assert isinstance(left, pa.Table)
+            assert isinstance(right, pa.Table)
+            logger = logging.getLogger("test_arrow_cogrouped_map")
+            logger.warning(
+                "arrow cogrouped map: "
+                + f"{dict(v1=left['v1'].to_pylist(), v2=right['v2'].to_pylist())}"
+            )
+            return left.join(right, keys="id", join_type="inner")
+
+        left_df = self.spark.createDataFrame([(1, 10), (2, 20), (1, 30)], ["id", "v1"])
+        right_df = self.spark.createDataFrame([(1, 100), (2, 200), (1, 300)], ["id", "v2"])
+
+        grouped_left = left_df.groupBy("id")
+        grouped_right = right_df.groupBy("id")
+        cogrouped_df = grouped_left.cogroup(grouped_right)
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+            assertDataFrameEqual(
+                cogrouped_df.applyInArrow(func_with_logging, "id long, v1 long, v2 long"),
+                [Row(id=1, v1=v1, v2=v2) for v1 in [10, 30] for v2 in [100, 300]]
+                + [Row(id=2, v1=20, v2=200)],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"arrow cogrouped map: {dict(v1=v1, v2=v2)}",
+                    context={"func_name": func_with_logging.__name__},
+                    logger="test_arrow_cogrouped_map",
+                )
+                for v1, v2 in [([10, 30], [100, 300]), ([20], [200])]
+            ],
+        )
+
 
 class CogroupedMapInArrowTests(CogroupedMapInArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
index 8d3d929096b18..829c38385bd0e 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
@@ -17,6 +17,7 @@
 import inspect
 import os
 import time
+import logging
 from typing import Iterator, Tuple
 import unittest
 
@@ -29,6 +30,8 @@
     have_pyarrow,
     pyarrow_requirement_message,
 )
+from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.util import is_remote_only
 
 if have_pyarrow:
     import pyarrow as pa
@@ -394,6 +397,80 @@ def test_negative_and_zero_batch_size(self):
             with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
                 ApplyInArrowTestsMixin.test_apply_in_arrow(self)
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_apply_in_arrow_with_logging(self):
+        import pyarrow as pa
+
+        def func_with_logging(group):
+            assert isinstance(group, pa.Table)
+            logger = logging.getLogger("test_arrow_grouped_map")
+            logger.warning(f"arrow grouped map: {group.to_pydict()}")
+            return group
+
+        df = self.spark.range(9).withColumn("value", col("id") * 10)
+        grouped_df = df.groupBy((col("id") % 2).cast("int"))
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+            assertDataFrameEqual(
+                grouped_df.applyInArrow(func_with_logging, "id long, value long"),
+                df,
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"arrow grouped map: {dict(id=lst, value=[v*10 for v in lst])}",
+                    context={"func_name": func_with_logging.__name__},
+                    logger="test_arrow_grouped_map",
+                )
+                for lst in [[0, 2, 4, 6, 8], [1, 3, 5, 7]]
+            ],
+        )
+
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_apply_in_arrow_iter_with_logging(self):
+        import pyarrow as pa
+
+        def func_with_logging(group: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
+            logger = logging.getLogger("test_arrow_grouped_map")
+            for batch in group:
+                assert isinstance(batch, pa.RecordBatch)
+                logger.warning(f"arrow grouped map: {batch.to_pydict()}")
+                yield batch
+
+        df = self.spark.range(9).withColumn("value", col("id") * 10)
+        grouped_df = df.groupBy((col("id") % 2).cast("int"))
+
+        with self.sql_conf(
+            {
+                "spark.sql.execution.arrow.maxRecordsPerBatch": 3,
+                "spark.sql.pyspark.worker.logging.enabled": "true",
+            }
+        ):
+            assertDataFrameEqual(
+                grouped_df.applyInArrow(func_with_logging, "id long, value long"),
+                df,
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"arrow grouped map: {dict(id=lst, value=[v*10 for v in lst])}",
+                    context={"func_name": func_with_logging.__name__},
+                    logger="test_arrow_grouped_map",
+                )
+                for lst in [[0, 2, 4], [6, 8], [1, 3, 5], [7]]
+            ],
+        )
+
 
 class ApplyInArrowTests(ApplyInArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_map.py b/python/pyspark/sql/tests/arrow/test_arrow_map.py
index 0f9f5b4224400..4a56a32fbcddb 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_map.py
@@ -17,6 +17,7 @@
 import os
 import time
 import unittest
+import logging
 
 from pyspark.sql.utils import PythonException
 from pyspark.testing.sqlutils import (
@@ -26,6 +27,9 @@
     pandas_requirement_message,
     pyarrow_requirement_message,
 )
+from pyspark.sql import Row
+from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.util import is_remote_only
 
 if have_pyarrow:
     import pyarrow as pa
@@ -221,6 +225,46 @@ def func(iterator):
             df = self.spark.range(1)
             df.mapInArrow(func, "a int").collect()
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_map_in_arrow_with_logging(self):
+        import pyarrow as pa
+
+        def func_with_logging(iterator):
+            logger = logging.getLogger("test_arrow_map")
+            for batch in iterator:
+                assert isinstance(batch, pa.RecordBatch)
+                logger.warning(f"arrow map: {batch.to_pydict()}")
+                yield batch
+
+        with self.sql_conf(
+            {
+                "spark.sql.execution.arrow.maxRecordsPerBatch": "3",
+                "spark.sql.pyspark.worker.logging.enabled": "true",
+            }
+        ):
+            assertDataFrameEqual(
+                self.spark.range(9, numPartitions=2).mapInArrow(func_with_logging, "id long"),
+                [Row(id=i) for i in range(9)],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            self._expected_logs_for_test_map_in_arrow_with_logging(func_with_logging.__name__),
+        )
+
+    def _expected_logs_for_test_map_in_arrow_with_logging(self, func_name):
+        return [
+            Row(
+                level="WARNING",
+                msg=f"arrow map: {dict(id=lst)}",
+                context={"func_name": func_name},
+                logger="test_arrow_map",
+            )
+            for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]]
+        ]
+
 
 class MapInArrowTests(MapInArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
@@ -253,6 +297,17 @@ def setUpClass(cls):
         cls.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "3")
         cls.spark.conf.set("spark.sql.execution.arrow.maxBytesPerBatch", "10")
 
+    def _expected_logs_for_test_map_in_arrow_with_logging(self, func_name):
+        return [
+            Row(
+                level="WARNING",
+                msg=f"arrow map: {dict(id=[i])}",
+                context={"func_name": func_name},
+                logger="test_arrow_map",
+            )
+            for i in range(9)
+        ]
+
 
 class MapInArrowWithOutputArrowBatchSlicingRecordsTests(MapInArrowTests):
     @classmethod
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
index 55b4edd72d5df..90e05caf21800 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
@@ -60,18 +60,6 @@ def test_register_java_function(self):
     def test_register_java_udaf(self):
         super(ArrowPythonUDFTests, self).test_register_java_udaf()
 
-    @unittest.skip(
-        "TODO(SPARK-53976): Python worker logging is not supported for Arrow Python UDFs."
-    )
-    def test_udf_with_logging(self):
-        super().test_udf_with_logging()
-
-    @unittest.skip(
-        "TODO(SPARK-53976): Python worker logging is not supported for Arrow Python UDFs."
-    )
-    def test_multiple_udfs_with_logging(self):
-        super().test_multiple_udfs_with_logging()
-
     def test_complex_input_types(self):
         row = (
             self.spark.range(1)
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
index 136a99e194118..f719b4fb16bd2 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
@@ -16,9 +16,10 @@
 #
 
 import unittest
+import logging
 
 from pyspark.sql.functions import arrow_udf, ArrowUDFType
-from pyspark.util import PythonEvalType
+from pyspark.util import PythonEvalType, is_remote_only
 from pyspark.sql import Row
 from pyspark.sql.types import (
     ArrayType,
@@ -35,6 +36,7 @@
     numpy_requirement_message,
     have_pyarrow,
     pyarrow_requirement_message,
+    assertDataFrameEqual,
 )
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
@@ -1021,6 +1023,42 @@ def arrow_max(v):
 
         self.assertEqual(expected, result)
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_grouped_agg_arrow_udf_with_logging(self):
+        import pyarrow as pa
+
+        @arrow_udf("double", ArrowUDFType.GROUPED_AGG)
+        def my_grouped_agg_arrow_udf(x):
+            assert isinstance(x, pa.Array)
+            logger = logging.getLogger("test_grouped_agg_arrow")
+            logger.warning(f"grouped agg arrow udf: {len(x)}")
+            return pa.compute.sum(x)
+
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+        )
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+            assertDataFrameEqual(
+                df.groupby("id").agg(my_grouped_agg_arrow_udf("v").alias("result")),
+                [Row(id=1, result=3.0), Row(id=2, result=18.0)],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
level="WARNING", + msg=f"grouped agg arrow udf: {n}", + context={"func_name": my_grouped_agg_arrow_udf.__name__}, + logger="test_grouped_agg_arrow", + ) + for n in [2, 3] + ], + ) + class GroupedAggArrowUDFTests(GroupedAggArrowUDFTestsMixin, ReusedSQLTestCase): pass diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py index a682c6515ef61..05f33a4ae42f7 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py +++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py @@ -20,13 +20,14 @@ import time import unittest import datetime +import logging from decimal import Decimal from typing import Iterator, Tuple from pyspark.util import PythonEvalType from pyspark.sql.functions import arrow_udf, ArrowUDFType -from pyspark.sql import functions as F +from pyspark.sql import Row, functions as F from pyspark.sql.types import ( IntegerType, ByteType, @@ -51,8 +52,10 @@ numpy_requirement_message, have_pyarrow, pyarrow_requirement_message, + assertDataFrameEqual, ) from pyspark.testing.sqlutils import ReusedSQLTestCase +from pyspark.util import is_remote_only @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) @@ -1179,6 +1182,80 @@ def test_unsupported_return_types(self): def func_a(a: pa.Array) -> pa.Array: return a + @unittest.skipIf(is_remote_only(), "Requires JVM access") + def test_scalar_arrow_udf_with_logging(self): + import pyarrow as pa + + @arrow_udf("string") + def my_scalar_arrow_udf(x): + assert isinstance(x, pa.Array) + logger = logging.getLogger("test_scalar_arrow") + logger.warning(f"scalar arrow udf: {x.to_pylist()}") + return pa.array(["scalar_arrow_" + str(val.as_py()) for val in x]) + + with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}): + assertDataFrameEqual( + self.spark.range(3, numPartitions=2).select( + my_scalar_arrow_udf("id").alias("result") + ), + [Row(result=f"scalar_arrow_{i}") for i in range(3)], + ) + + logs = self.spark.table("system.session.python_worker_logs") + + assertDataFrameEqual( + logs.select("level", "msg", "context", "logger"), + [ + Row( + level="WARNING", + msg=f"scalar arrow udf: {lst}", + context={"func_name": my_scalar_arrow_udf.__name__}, + logger="test_scalar_arrow", + ) + for lst in [[0], [1, 2]] + ], + ) + + @unittest.skipIf(is_remote_only(), "Requires JVM access") + def test_scalar_iter_arrow_udf_with_logging(self): + import pyarrow as pa + + @arrow_udf("string", ArrowUDFType.SCALAR_ITER) + def my_scalar_iter_arrow_udf(it): + logger = logging.getLogger("test_scalar_iter_arrow") + for x in it: + assert isinstance(x, pa.Array) + logger.warning(f"scalar iter arrow udf: {x.to_pylist()}") + yield pa.array(["scalar_iter_arrow_" + str(val.as_py()) for val in x]) + + with self.sql_conf( + { + "spark.sql.execution.arrow.maxRecordsPerBatch": "3", + "spark.sql.pyspark.worker.logging.enabled": "true", + } + ): + assertDataFrameEqual( + self.spark.range(9, numPartitions=2).select( + my_scalar_iter_arrow_udf("id").alias("result") + ), + [Row(result=f"scalar_iter_arrow_{i}") for i in range(9)], + ) + + logs = self.spark.table("system.session.python_worker_logs") + + assertDataFrameEqual( + logs.select("level", "msg", "context", "logger"), + [ + Row( + level="WARNING", + msg=f"scalar iter arrow udf: {lst}", + context={"func_name": my_scalar_iter_arrow_udf.__name__}, + logger="test_scalar_iter_arrow", + ) + for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]] + ], + ) + class ScalarArrowUDFTests(ScalarArrowUDFTestsMixin, ReusedSQLTestCase): 
     @classmethod
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
index d67b99475bf89..240e34487b006 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
@@ -16,10 +16,11 @@
 #
 
 import unittest
+import logging
 
 from pyspark.sql.functions import arrow_udf, ArrowUDFType
-from pyspark.util import PythonEvalType
-from pyspark.sql import functions as sf
+from pyspark.util import PythonEvalType, is_remote_only
+from pyspark.sql import Row, functions as sf
 from pyspark.sql.window import Window
 from pyspark.errors import AnalysisException, PythonException, PySparkTypeError
 from pyspark.testing.utils import (
@@ -27,6 +28,7 @@
     numpy_requirement_message,
     have_pyarrow,
     pyarrow_requirement_message,
+    assertDataFrameEqual,
 )
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
@@ -804,6 +806,49 @@ def arrow_sum_unbounded(v):
         )
         self.assertEqual(expected2, result2)
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_window_arrow_udf_with_logging(self):
+        import pyarrow as pa
+
+        @arrow_udf("double", ArrowUDFType.GROUPED_AGG)
+        def my_window_arrow_udf(x):
+            assert isinstance(x, pa.Array)
+            logger = logging.getLogger("test_window_arrow")
+            logger.warning(f"window arrow udf: {x.to_pylist()}")
+            return pa.compute.sum(x)
+
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+        )
+        w = Window.partitionBy("id").orderBy("v").rangeBetween(Window.unboundedPreceding, 0)
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+            assertDataFrameEqual(
+                df.select("id", my_window_arrow_udf("v").over(w).alias("result")),
+                [
+                    Row(id=1, result=1.0),
+                    Row(id=1, result=3.0),
+                    Row(id=2, result=3.0),
+                    Row(id=2, result=8.0),
+                    Row(id=2, result=18.0),
+                ],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"window arrow udf: {lst}",
+                    context={"func_name": my_window_arrow_udf.__name__},
+                    logger="test_window_arrow",
+                )
+                for lst in [[1.0], [1.0, 2.0], [3.0], [3.0, 5.0], [3.0, 5.0, 10.0]]
+            ],
+        )
+
 
 class WindowArrowUDFTests(WindowArrowUDFTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index 44bd8a6fa9df1..ab954dd133f3c 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -16,6 +16,7 @@
 #
 
 import unittest
+import logging
 from typing import cast
 
 from pyspark.sql import functions as sf
@@ -38,6 +39,8 @@
     pandas_requirement_message,
     pyarrow_requirement_message,
 )
+from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.util import is_remote_only
 
 if have_pandas:
     import pandas as pd
@@ -714,6 +717,48 @@ def test_negative_and_zero_batch_size(self):
             with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
                 CogroupedApplyInPandasTestsMixin.test_with_key_right(self)
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_cogroup_apply_in_pandas_with_logging(self):
+        import pandas as pd
+
+        def func_with_logging(left_pdf, right_pdf):
+            assert isinstance(left_pdf, pd.DataFrame)
+            assert isinstance(right_pdf, pd.DataFrame)
+            logger = logging.getLogger("test_pandas_cogrouped_map")
+            logger.warning(
f"pandas cogrouped map: {dict(v1=list(left_pdf['v1']), v2=list(right_pdf['v2']))}" + ) + return pd.merge(left_pdf, right_pdf, on=["id"]) + + left_df = self.spark.createDataFrame([(1, 10), (2, 20), (1, 30)], ["id", "v1"]) + right_df = self.spark.createDataFrame([(1, 100), (2, 200), (1, 300)], ["id", "v2"]) + + grouped_left = left_df.groupBy("id") + grouped_right = right_df.groupBy("id") + cogrouped_df = grouped_left.cogroup(grouped_right) + + with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}): + assertDataFrameEqual( + cogrouped_df.applyInPandas(func_with_logging, "id long, v1 long, v2 long"), + [Row(id=1, v1=v1, v2=v2) for v1 in [10, 30] for v2 in [100, 300]] + + [Row(id=2, v1=20, v2=200)], + ) + + logs = self.spark.table("system.session.python_worker_logs") + + assertDataFrameEqual( + logs.select("level", "msg", "context", "logger"), + [ + Row( + level="WARNING", + msg=f"pandas cogrouped map: {dict(v1=v1, v2=v2)}", + context={"func_name": func_with_logging.__name__}, + logger="test_pandas_cogrouped_map", + ) + for v1, v2 in [([10, 30], [100, 300]), ([20], [200])] + ], + ) + class CogroupedApplyInPandasTests(CogroupedApplyInPandasTestsMixin, ReusedSQLTestCase): pass diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py index 4c52303481fa7..0e922d0728714 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py @@ -17,6 +17,7 @@ import datetime import unittest +import logging from collections import OrderedDict from decimal import Decimal @@ -60,6 +61,8 @@ pandas_requirement_message, pyarrow_requirement_message, ) +from pyspark.testing.utils import assertDataFrameEqual +from pyspark.util import is_remote_only if have_pandas: import pandas as pd @@ -985,6 +988,42 @@ def test_negative_and_zero_batch_size(self): with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}): ApplyInPandasTestsMixin.test_complex_groupby(self) + @unittest.skipIf(is_remote_only(), "Requires JVM access") + def test_apply_in_pandas_with_logging(self): + import pandas as pd + + def func_with_logging(pdf): + assert isinstance(pdf, pd.DataFrame) + logger = logging.getLogger("test_pandas_grouped_map") + logger.warning( + f"pandas grouped map: {dict(id=list(pdf['id']), value=list(pdf['value']))}" + ) + return pdf + + df = self.spark.range(9).withColumn("value", col("id") * 10) + grouped_df = df.groupBy((col("id") % 2).cast("int")) + + with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}): + assertDataFrameEqual( + grouped_df.applyInPandas(func_with_logging, "id long, value long"), + df, + ) + + logs = self.spark.table("system.session.python_worker_logs") + + assertDataFrameEqual( + logs.select("level", "msg", "context", "logger"), + [ + Row( + level="WARNING", + msg=f"pandas grouped map: {dict(id=lst, value=[v*10 for v in lst])}", + context={"func_name": func_with_logging.__name__}, + logger="test_pandas_grouped_map", + ) + for lst in [[0, 2, 4, 6, 8], [1, 3, 5, 7]] + ], + ) + class ApplyInPandasTests(ApplyInPandasTestsMixin, ReusedSQLTestCase): pass diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py b/python/pyspark/sql/tests/pandas/test_pandas_map.py index b241b91e02a29..5e0e33a05b22b 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py @@ -19,6 +19,7 @@ import tempfile import time import unittest +import logging from typing 
import cast from pyspark.sql import Row @@ -33,7 +34,8 @@ pandas_requirement_message, pyarrow_requirement_message, ) -from pyspark.testing.utils import eventually +from pyspark.testing.utils import assertDataFrameEqual, eventually +from pyspark.util import is_remote_only if have_pandas: import pandas as pd @@ -486,6 +488,43 @@ def func(iterator): df = self.spark.range(1) self.assertEqual([Row(a=2, b=1)], df.mapInPandas(func, "a int, b int").collect()) + @unittest.skipIf(is_remote_only(), "Requires JVM access") + def test_map_in_pandas_with_logging(self): + import pandas as pd + + def func_with_logging(iterator): + logger = logging.getLogger("test_pandas_map") + for pdf in iterator: + assert isinstance(pdf, pd.DataFrame) + logger.warning(f"pandas map: {list(pdf['id'])}") + yield pdf + + with self.sql_conf( + { + "spark.sql.execution.arrow.maxRecordsPerBatch": "3", + "spark.sql.pyspark.worker.logging.enabled": "true", + } + ): + assertDataFrameEqual( + self.spark.range(9, numPartitions=2).mapInPandas(func_with_logging, "id long"), + [Row(id=i) for i in range(9)], + ) + + logs = self.spark.table("system.session.python_worker_logs") + + assertDataFrameEqual( + logs.select("level", "msg", "context", "logger"), + [ + Row( + level="WARNING", + msg=f"pandas map: {lst}", + context={"func_name": func_with_logging.__name__}, + logger="test_pandas_map", + ) + for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]] + ], + ) + class MapInPandasTests(ReusedSQLTestCase, MapInPandasTestsMixin): @classmethod diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py index 3fd970061b30a..2b3e42312df99 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py @@ -16,9 +16,10 @@ # import unittest +import logging from typing import cast -from pyspark.util import PythonEvalType +from pyspark.util import PythonEvalType, is_remote_only from pyspark.sql import Row, functions as sf from pyspark.sql.functions import ( array, @@ -826,6 +827,40 @@ def pandas_max(v): self.assertEqual(expected, result) + @unittest.skipIf(is_remote_only(), "Requires JVM access") + def test_grouped_agg_pandas_udf_with_logging(self): + @pandas_udf("double", PandasUDFType.GROUPED_AGG) + def my_grouped_agg_pandas_udf(x): + assert isinstance(x, pd.Series) + logger = logging.getLogger("test_grouped_agg_pandas") + logger.warning(f"grouped agg pandas udf: {len(x)}") + return x.sum() + + df = self.spark.createDataFrame( + [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v") + ) + + with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}): + assertDataFrameEqual( + df.groupby("id").agg(my_grouped_agg_pandas_udf("v").alias("result")), + [Row(id=1, result=3.0), Row(id=2, result=18.0)], + ) + + logs = self.spark.table("system.session.python_worker_logs") + + assertDataFrameEqual( + logs.select("level", "msg", "context", "logger"), + [ + Row( + level="WARNING", + msg=f"grouped agg pandas udf: {n}", + context={"func_name": my_grouped_agg_pandas_udf.__name__}, + logger="test_grouped_agg_pandas", + ) + for n in [2, 3] + ], + ) + class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, ReusedSQLTestCase): pass diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py index 3c2ae56067ae6..fbfe1a226b5e2 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py +++ 
b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py @@ -20,13 +20,14 @@ import tempfile import time import unittest +import logging from datetime import date, datetime from decimal import Decimal from typing import cast from pyspark import TaskContext -from pyspark.util import PythonEvalType -from pyspark.sql import Column +from pyspark.util import PythonEvalType, is_remote_only +from pyspark.sql import Column, Row from pyspark.sql.functions import ( array, col, @@ -1917,6 +1918,76 @@ def test_arrow_cast_enabled_str_to_numeric(self): row = df.select(pandas_udf(lambda _: pd.Series(["123"]), t)(df.id)).first() assert row[0] == 123 + @unittest.skipIf(is_remote_only(), "Requires JVM access") + def test_scalar_pandas_udf_with_logging(self): + @pandas_udf("string", PandasUDFType.SCALAR) + def my_scalar_pandas_udf(x): + assert isinstance(x, pd.Series) + logger = logging.getLogger("test_scalar_pandas") + logger.warning(f"scalar pandas udf: {list(x)}") + return pd.Series(["scalar_pandas_" + str(val) for val in x]) + + with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}): + assertDataFrameEqual( + self.spark.range(3, numPartitions=2).select( + my_scalar_pandas_udf("id").alias("result") + ), + [Row(result=f"scalar_pandas_{i}") for i in range(3)], + ) + + logs = self.spark.table("system.session.python_worker_logs") + + assertDataFrameEqual( + logs.select("level", "msg", "context", "logger"), + [ + Row( + level="WARNING", + msg=f"scalar pandas udf: {lst}", + context={"func_name": my_scalar_pandas_udf.__name__}, + logger="test_scalar_pandas", + ) + for lst in [[0], [1, 2]] + ], + ) + + @unittest.skipIf(is_remote_only(), "Requires JVM access") + def test_scalar_iter_pandas_udf_with_logging(self): + @pandas_udf("string", PandasUDFType.SCALAR_ITER) + def my_scalar_iter_pandas_udf(it): + logger = logging.getLogger("test_scalar_iter_pandas") + for x in it: + assert isinstance(x, pd.Series) + logger.warning(f"scalar iter pandas udf: {list(x)}") + yield pd.Series(["scalar_iter_pandas_" + str(val) for val in x]) + + with self.sql_conf( + { + "spark.sql.execution.arrow.maxRecordsPerBatch": "3", + "spark.sql.pyspark.worker.logging.enabled": "true", + } + ): + assertDataFrameEqual( + self.spark.range(9, numPartitions=2).select( + my_scalar_iter_pandas_udf("id").alias("result") + ), + [Row(result=f"scalar_iter_pandas_{i}") for i in range(9)], + ) + + logs = self.spark.table("system.session.python_worker_logs") + + assertDataFrameEqual( + logs.select("level", "msg", "context", "logger"), + [ + Row( + level="WARNING", + msg=f"scalar iter pandas udf: {lst}", + context={"func_name": my_scalar_iter_pandas_udf.__name__}, + logger="test_scalar_iter_pandas", + ) + for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]] + ], + ) + class ScalarPandasUDFTests(ScalarPandasUDFTestsMixin, ReusedSQLTestCase): @classmethod diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py index 547e237902b3f..6fa7e9063836b 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py @@ -16,6 +16,7 @@ # import unittest +import logging from typing import cast from decimal import Decimal @@ -38,6 +39,8 @@ pyarrow_requirement_message, ) from pyspark.testing.utils import assertDataFrameEqual +from pyspark.sql import Row +from pyspark.util import is_remote_only if have_pandas: from pandas.testing import assert_frame_equal @@ -633,6 +636,49 @@ def pandas_sum_unbounded(v): ) 
         self.assertEqual(expected2, result2)
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_window_pandas_udf_with_logging(self):
+        import pandas as pd
+
+        @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+        def my_window_pandas_udf(x):
+            assert isinstance(x, pd.Series)
+            logger = logging.getLogger("test_window_pandas")
+            logger.warning(f"window pandas udf: {list(x)}")
+            return x.sum()
+
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+        )
+        w = Window.partitionBy("id").orderBy("v").rangeBetween(Window.unboundedPreceding, 0)
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+            assertDataFrameEqual(
+                df.select("id", my_window_pandas_udf("v").over(w).alias("result")),
+                [
+                    Row(id=1, result=1.0),
+                    Row(id=1, result=3.0),
+                    Row(id=2, result=3.0),
+                    Row(id=2, result=8.0),
+                    Row(id=2, result=18.0),
+                ],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"window pandas udf: {lst}",
+                    context={"func_name": my_window_pandas_udf.__name__},
+                    logger="test_window_pandas",
+                )
+                for lst in [[1.0], [1.0, 2.0], [3.0], [3.0, 5.0], [3.0, 5.0, 10.0]]
+            ],
+        )
+
 
 class WindowPandasUDFTests(WindowPandasUDFTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index 26bd5368e6f9c..63e7e32c1c7b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -170,7 +170,8 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
       conf.arrowUseLargeVarTypes,
       pythonRunnerConf,
       metrics,
-      jobArtifactUUID)
+      jobArtifactUUID,
+      None) // TODO: Python worker logging
   }
 
   def createPythonMetrics(): Array[CustomMetric] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala
index f4e8831f23b85..a3d6c57c58bdd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala
@@ -144,6 +144,12 @@
 
     val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+    val sessionUUID = {
+      Option(session).collect {
+        case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
+          session.sessionUUID
+      }
+    }
 
     // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty
     inputRDD.mapPartitionsInternal { iter =>
       if (iter.isEmpty) iter else {
@@ -190,6 +196,7 @@
           pythonRunnerConf,
           pythonMetrics,
           jobArtifactUUID,
+          sessionUUID,
           conf.pythonUDFProfiler) with GroupedPythonArrowInput
 
         val columnarBatchIter = runner.compute(projectedRowIter, context.partitionId(), context)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 92236ca42b2db..7498815cda4e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -82,6 +82,12 @@
   }
 
   private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+  private[this] val sessionUUID = {
+    Option(session).collect {
+      case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
+        session.sessionUUID
+    }
+  }
 
   override protected def evaluatorFactory: EvalPythonEvaluatorFactory = {
     new ArrowEvalPythonEvaluatorFactory(
@@ -95,6 +101,7 @@
       ArrowPythonRunner.getPythonRunnerConfMap(conf),
       pythonMetrics,
       jobArtifactUUID,
+      sessionUUID,
       conf.pythonUDFProfiler)
   }
 
@@ -121,6 +128,7 @@
     pythonRunnerConf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
+    sessionUUID: Option[String],
     profiler: Option[String])
   extends EvalPythonEvaluatorFactory(childOutput, udfs, output) {
 
@@ -147,6 +155,7 @@
         pythonRunnerConf,
         pythonMetrics,
         jobArtifactUUID,
+        sessionUUID,
         profiler) with BatchedPythonArrowInput
     val columnarBatchIter = pyRunner.compute(batchIter, context.partitionId(), context)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 77aec2a35f21d..b94e00bc11ef2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.python
 
 import java.io.DataOutputStream
+import java.util
 
 import org.apache.spark.api.python._
 import org.apache.spark.sql.catalyst.InternalRow
@@ -36,12 +37,20 @@ abstract class BaseArrowPythonRunner[IN, OUT <: AnyRef](
     protected override val largeVarTypes: Boolean,
     protected override val workerConf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
-    jobArtifactUUID: Option[String])
+    jobArtifactUUID: Option[String],
+    sessionUUID: Option[String])
   extends BasePythonRunner[IN, OUT](
     funcs.map(_._1), evalType, argOffsets, jobArtifactUUID, pythonMetrics)
   with PythonArrowInput[IN] with PythonArrowOutput[OUT] {
 
+  override val envVars: util.Map[String, String] = {
+    val envVars = new util.HashMap(funcs.head._1.funcs.head.envVars)
+    sessionUUID.foreach { uuid =>
+      envVars.put("PYSPARK_SPARK_SESSION_UUID", uuid)
+    }
+    envVars
+  }
 
   override val pythonExec: String = SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
     funcs.head._1.funcs.head.pythonExec)
@@ -77,10 +86,11 @@
     largeVarTypes: Boolean,
     workerConf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric],
-    jobArtifactUUID: Option[String])
+    jobArtifactUUID: Option[String],
+    sessionUUID: Option[String])
   extends BaseArrowPythonRunner[Iterator[InternalRow], ColumnarBatch](
     funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf,
-    pythonMetrics, jobArtifactUUID)
+    pythonMetrics, jobArtifactUUID, sessionUUID)
   with BasicPythonArrowInput
   with BasicPythonArrowOutput
 
@@ -97,10 +107,11 @@
     workerConf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
+    sessionUUID: Option[String],
     profiler: Option[String])
   extends RowInputArrowPythonRunner(
     funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf,
-    pythonMetrics, jobArtifactUUID) {
+    pythonMetrics, jobArtifactUUID, sessionUUID) {
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit =
     PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, profiler)
@@ -120,10 +131,11 @@
     workerConf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
+    sessionUUID: Option[String],
    profiler: Option[String])
   extends RowInputArrowPythonRunner(
     funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes, workerConf,
-    pythonMetrics, jobArtifactUUID) {
+    pythonMetrics, jobArtifactUUID, sessionUUID) {
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit = {
     if (evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
index 82c03b1d02293..2bf974d9026ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
@@ -45,6 +45,7 @@ class ArrowWindowPythonEvaluatorFactory(
     val evalType: Int,
     val spillSize: SQLMetric,
     pythonMetrics: Map[String, SQLMetric],
+    sessionUUID: Option[String],
     profiler: Option[String])
   extends PartitionEvaluatorFactory[InternalRow, InternalRow]
   with WindowEvaluatorFactoryBase {
@@ -378,6 +379,7 @@
           pythonRunnerConf,
           pythonMetrics,
           jobArtifactUUID,
+          sessionUUID,
           profiler) with GroupedPythonArrowInput
 
         val windowFunctionResult = runner.compute(pythonInput, context.partitionId(), context)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonExec.scala
index c8259c10dbd93..ba3ffe7639eba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonExec.scala
@@ -91,6 +91,13 @@ case class ArrowWindowPythonExec(
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")
   )
 
+  private[this] val sessionUUID = {
+    Option(session).collect {
+      case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
+        session.sessionUUID
+    }
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val evaluatorFactory =
       new ArrowWindowPythonEvaluatorFactory(
@@ -101,6 +108,7 @@
         evalType,
         longMetric("spillSize"),
         pythonMetrics,
+        sessionUUID,
         conf.pythonUDFProfiler)
 
     // Start processing.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index 9dbdd285338ec..00eb9039d05cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.python
 
 import java.io.DataOutputStream
+import java.util
 
 import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD, PythonWorker}
@@ -43,12 +44,20 @@ class CoGroupedArrowPythonRunner(
     conf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
+    sessionUUID: Option[String],
    profiler: Option[String])
   extends BasePythonRunner[
     (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](
     funcs.map(_._1), evalType, argOffsets, jobArtifactUUID, pythonMetrics)
   with BasicPythonArrowOutput {
 
+  override val envVars: util.Map[String, String] = {
+    val envVars = new util.HashMap(funcs.head._1.funcs.head.envVars)
+    sessionUUID.foreach { uuid =>
+      envVars.put("PYSPARK_SPARK_SESSION_UUID", uuid)
+    }
+    envVars
+  }
 
   override val pythonExec: String = SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
     funcs.head._1.funcs.head.pythonExec)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
index af487218391e3..38427866458ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
@@ -68,6 +68,12 @@ trait FlatMapCoGroupsInBatchExec extends SparkPlan with BinaryExecNode with Pyth
     val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup)
     val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, rightGroup)
     val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+    val sessionUUID = {
+      Option(session).collect {
+        case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
+          session.sessionUUID
+      }
+    }
 
     // Map cogrouped rows to ArrowPythonRunner results, Only execute if partition is not empty
     left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
@@ -89,6 +95,7 @@ trait FlatMapCoGroupsInBatchExec extends SparkPlan with BinaryExecNode with Pyth
         pythonRunnerConf,
         pythonMetrics,
         jobArtifactUUID,
+        sessionUUID,
         conf.pythonUDFProfiler)
 
       executePython(data, output, runner)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
index 57a50a8fc8578..7d221552226d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
@@ -48,6 +48,12 @@ trait FlatMapGroupsInBatchExec extends SparkPlan with UnaryExecNode with PythonS
   private val chainedFunc =
     Seq((ChainedPythonFunctions(Seq(pythonFunction)), pythonUDF.resultId.id))
   private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+  private[this] val sessionUUID = {
+    Option(session).collect {
+      case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
+        session.sessionUUID
+    }
+  }
 
   override def producedAttributes: AttributeSet = AttributeSet(output)
 
@@ -92,6 +98,7 @@ trait FlatMapGroupsInBatchExec extends SparkPlan with UnaryExecNode with PythonS
       pythonRunnerConf,
      pythonMetrics,
       jobArtifactUUID,
+      sessionUUID,
       conf.pythonUDFProfiler) with GroupedPythonArrowInput
 
     executePython(data, output, runner)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
index 9e3e8610ed375..4e78b3035a7ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
@@ -40,7 +40,8 @@ class MapInBatchEvaluatorFactory(
     largeVarTypes: Boolean,
     pythonRunnerConf: Map[String, String],
     val pythonMetrics: Map[String, SQLMetric],
-    jobArtifactUUID: Option[String])
+    jobArtifactUUID: Option[String],
+    sessionUUID: Option[String])
   extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
 
   override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] =
@@ -72,6 +73,7 @@ class MapInBatchEvaluatorFactory(
         pythonRunnerConf,
         pythonMetrics,
         jobArtifactUUID,
+        sessionUUID,
         None) with BatchedPythonArrowInput
 
       val columnarBatchIter = pyRunner.compute(batchIter, context.partitionId(), context)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index c003d503c7caf..1d03c0cf76037 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -44,6 +44,12 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
   override def producedAttributes: AttributeSet = AttributeSet(output)
 
   private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+  private[this] val sessionUUID = {
+    Option(session).collect {
+      case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
+        session.sessionUUID
+    }
+  }
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
@@ -63,7 +69,8 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
       conf.arrowUseLargeVarTypes,
       pythonRunnerConf,
      pythonMetrics,
-      jobArtifactUUID)
+      jobArtifactUUID,
+      sessionUUID)
 
     val rdd = if (isBarrier) {
       val rddBarrier = child.execute().barrier()