From 31e7c37354132545da59bff176af1613bd09447c Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 28 Jun 2019 17:10:25 +0900 Subject: [PATCH] [SPARK-28185][PYTHON][SQL] Closes the generator when Python UDFs stop early ## What changes were proposed in this pull request? Closes the generator when Python UDFs stop early. ### Manually verification on pandas iterator UDF and mapPartitions ```python from pyspark.sql import SparkSession from pyspark.sql.functions import pandas_udf, PandasUDFType from pyspark.sql.functions import col, udf from pyspark.taskcontext import TaskContext import time import os spark.conf.set('spark.sql.execution.arrow.maxRecordsPerBatch', '1') spark.conf.set('spark.sql.pandas.udf.buffer.size', '4') pandas_udf("int", PandasUDFType.SCALAR_ITER) def fi1(it): try: for batch in it: yield batch + 100 time.sleep(1.0) except BaseException as be: print("Debug: exception raised: " + str(type(be))) raise be finally: open("/tmp/000001.tmp", "a").close() df1 = spark.range(10).select(col('id').alias('a')).repartition(1) # will see log Debug: exception raised: # and file "/tmp/000001.tmp" generated. df1.select(col('a'), fi1('a')).limit(2).collect() def mapper(it): try: for batch in it: yield batch except BaseException as be: print("Debug: exception raised: " + str(type(be))) raise be finally: open("/tmp/000002.tmp", "a").close() df2 = spark.range(10000000).repartition(1) # will see log Debug: exception raised: # and file "/tmp/000002.tmp" generated. df2.rdd.mapPartitions(mapper).take(2) ``` ## How was this patch tested? Unit test added. Please review https://spark.apache.org/contributing.html before opening a pull request. Closes #24986 from WeichenXu123/pandas_iter_udf_limit. Authored-by: WeichenXu Signed-off-by: HyukjinKwon --- .../sql/tests/test_pandas_udf_scalar.py | 37 +++++++++++++++++++ python/pyspark/worker.py | 7 +++- 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index c291d4287e11b..d254508e5d35b 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -850,6 +850,43 @@ def test_close(batch_iter): with self.assertRaisesRegexp(Exception, "reached finally block"): self.spark.range(1).select(test_close(col("id"))).collect() + def test_scalar_iter_udf_close_early(self): + tmp_dir = tempfile.mkdtemp() + try: + tmp_file = tmp_dir + '/reach_finally_block' + + @pandas_udf('int', PandasUDFType.SCALAR_ITER) + def test_close(batch_iter): + generator_exit_caught = False + try: + for batch in batch_iter: + yield batch + time.sleep(1.0) # avoid the function finish too fast. + except GeneratorExit as ge: + generator_exit_caught = True + raise ge + finally: + assert generator_exit_caught, "Generator exit exception was not caught." + open(tmp_file, 'a').close() + + with QuietTest(self.sc): + with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 1, + "spark.sql.pandas.udf.buffer.size": 4}): + self.spark.range(10).repartition(1) \ + .select(test_close(col("id"))).limit(2).collect() + # wait here because python udf worker will take some time to detect + # jvm side socket closed and then will trigger `GenerateExit` raised. + # wait timeout is 10s. + for i in range(100): + time.sleep(0.1) + if os.path.exists(tmp_file): + break + + assert os.path.exists(tmp_file), "finally block not reached." + + finally: + shutil.rmtree(tmp_dir) + # Regression test for SPARK-23314 def test_timestamp_dst(self): # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index ee46bb649d1fe..04376c9008288 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -481,7 +481,12 @@ def main(infile, outfile): def process(): iterator = deserializer.load_stream(infile) - serializer.dump_stream(func(split_index, iterator), outfile) + out_iter = func(split_index, iterator) + try: + serializer.dump_stream(out_iter, outfile) + finally: + if hasattr(out_iter, 'close'): + out_iter.close() if profiler: profiler.profile(process)