diff --git a/python/benchmarks/bench_eval_type.py b/python/benchmarks/bench_eval_type.py
index 2cc467e1d1b02..ed648691af827 100644
--- a/python/benchmarks/bench_eval_type.py
+++ b/python/benchmarks/bench_eval_type.py
@@ -145,6 +145,27 @@ def write_udf_payload(
         buf.write(command)
         write_long(0, buf)  # result_id
 
+    @classmethod
+    def write_udtf_payload(
+        cls,
+        handler: type,
+        return_type: StructType,
+        arg_offsets: list[int],
+        buf: io.BytesIO,
+    ) -> None:
+        """Write the ``read_udtf`` portion of the protocol (no partitions, no analyze)."""
+        write_int(len(arg_offsets), buf)  # num_arg
+        for offset in arg_offsets:
+            write_int(offset, buf)
+            cls.write_bool(False, buf)  # is_kwarg
+        write_int(0, buf)  # num_partition_child_indexes
+        cls.write_bool(False, buf)  # has_pickled_analyze_result
+        command = cloudpickle_dumps(handler)
+        write_int(len(command), buf)
+        buf.write(command)
+        cls.write_utf8(return_type.json(), buf)
+        cls.write_utf8("benchmark_udtf", buf)  # udtf_name
+
     @classmethod
     def write_utf8(cls, s: str, buf: io.BytesIO) -> None:
         """Write a length-prefixed UTF-8 string (matches ``UTF8Deserializer.loads``)."""
@@ -160,18 +181,19 @@ def write_worker_input(
         write_data: Callable[[io.BufferedIOBase], None],
         buf: io.BufferedIOBase,
         runner_conf: dict[str, str] | None = None,
+        eval_conf: dict[str, str] | None = None,
     ) -> None:
         """Write the full worker binary stream: preamble + command + data + end."""
         cls.write_preamble(buf)
         write_int(eval_type, buf)
-        if runner_conf:
-            write_int(len(runner_conf), buf)
-            for k, v in runner_conf.items():
-                cls.write_utf8(k, buf)
-                cls.write_utf8(v, buf)
-        else:
-            write_int(0, buf)  # RunnerConf (0 key-value pairs)
-        write_int(0, buf)  # EvalConf (0 key-value pairs)
+        for conf in (runner_conf, eval_conf):
+            if conf:
+                write_int(len(conf), buf)
+                for k, v in conf.items():
+                    cls.write_utf8(k, buf)
+                    cls.write_utf8(v, buf)
+            else:
+                write_int(0, buf)  # 0 key-value pairs
         write_udf(buf)
         write_data(buf)
         write_int(-4, buf)  # SpecialLengths.END_OF_STREAM
@@ -471,6 +493,100 @@ class ArrowBatchedUDFPeakmemBench(_ArrowBatchedBenchMixin, _PeakmemBenchBase):
     pass
 
 
+# -- SQL_ARROW_TABLE_UDF ----------------------------------------------------
+# Python UDTF (``@udtf(useArrow=True)``): handler is a class with ``eval(self, *args)``
+# that yields output rows. Each input row triggers one ``eval`` call; yielded rows
+# are converted to Arrow via ``LocalDataToArrowConversion``.
+
+
+class _ArrowTableUDFIdentity:
+    def eval(self, x):
+        yield (x,)
+
+
+class _ArrowTableUDFExplode:
+    def eval(self, x):
+        for _ in range(3):
+            yield (x,)
+
+
+class _ArrowTableUDFFilter:
+    def eval(self, x):
+        if x is not None and (hash(x) & 1):
+            yield (x,)
+
+
+class _ArrowTableUDFStringify:
+    def eval(self, x):
+        yield (str(x),)
+
+
+class _ArrowTableUDFBenchMixin:
+    """Provides ``_write_scenario`` for SQL_ARROW_TABLE_UDF (Python UDTF, useArrow=True).
+
+    Writes the extra ``input_type`` (StructType JSON) into ``EvalConf`` that the
+    non-legacy path requires, and uses the UDTF wire protocol (no num_udfs/result_id).
+    """
+
+    # Per-input-row ``LocalDataToArrowConversion.convert`` call makes this path
+    # ~15-20x slower than SQL_ARROW_BATCHED_UDF, so row counts are scaled down
+    # accordingly to keep each measurement under ASV's per-sample budget.
+    _scenario_configs = {
+        "sm_batch_few_col": ("mixed", 2_000, 5, 500),
+        "sm_batch_many_col": ("mixed", 500, 50, 500),
+        "lg_batch_few_col": ("mixed", 5_000, 5, 2_500),
+        "lg_batch_many_col": ("mixed", 2_000, 50, 2_000),
+        "pure_ints": ("pure_ints", 5_000, 10, 2_500),
+        "pure_strings": ("pure_strings", 5_000, 10, 2_500),
+    }
+
+    @staticmethod
+    def _build_scenario(name):
+        np.random.seed(42)
+        type_key, num_rows, num_cols, batch_size = _ArrowTableUDFBenchMixin._scenario_configs[name]
+        pool = MockDataFactory.NAMED_TYPE_POOLS[type_key]
+        return MockDataFactory.make_batches(
+            num_rows=num_rows,
+            num_cols=num_cols,
+            spark_type_pool=pool,
+            batch_size=batch_size,
+        )
+
+    # Each entry: (handler_class, return_type_or_None, arg_offsets).
+    # ``None`` return_type means "use input column 0's type".
+    _udtfs = {
+        "identity_udtf": (_ArrowTableUDFIdentity, None, [0]),
+        "explode_udtf": (_ArrowTableUDFExplode, None, [0]),
+        "filter_udtf": (_ArrowTableUDFFilter, None, [0]),
+        "stringify_udtf": (_ArrowTableUDFStringify, StringType(), [0]),
+    }
+    params = [list(_scenario_configs), list(_udtfs)]
+    param_names = ["scenario", "udtf"]
+
+    def _write_scenario(self, scenario, udtf_name, buf):
+        batches, schema = self._build_scenario(scenario)
+        handler, ret_type, arg_offsets = self._udtfs[udtf_name]
+        if ret_type is None:
+            ret_type = schema.fields[0].dataType
+        return_type = StructType([StructField("c0", ret_type)])
+
+        MockProtocolWriter.write_worker_input(
+            PythonEvalType.SQL_ARROW_TABLE_UDF,
+            lambda b: MockProtocolWriter.write_udtf_payload(handler, return_type, arg_offsets, b),
+            lambda b: MockProtocolWriter.write_data_payload(iter(batches), b),
+            buf,
+            eval_conf={"input_type": schema.json()},
+        )
+
+
+class ArrowTableUDFTimeBench(_ArrowTableUDFBenchMixin, _TimeBenchBase):
+    pass
+
+
+class ArrowTableUDFPeakmemBench(_ArrowTableUDFBenchMixin, _PeakmemBenchBase):
+    pass
+
+
 # -- SQL_COGROUPED_MAP_ARROW_UDF ------------------------------------------------
 # UDF receives two ``pa.Table`` (left, right) per co-group, returns ``pa.Table``.