Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 124 additions & 8 deletions python/benchmarks/bench_eval_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,27 @@ def write_udf_payload(
buf.write(command)
write_long(0, buf) # result_id

@classmethod
def write_udtf_payload(
    cls,
    handler: type,
    return_type: StructType,
    arg_offsets: list[int],
    buf: io.BytesIO,
) -> None:
    """Serialize the ``read_udtf`` section of the worker protocol into ``buf``.

    Covers the simple variant only: no partition-child indexes and no
    pickled analyze result. Emits the argument offsets, the cloudpickled
    handler class (length-prefixed), the return schema as JSON, and the
    UDTF name.
    """
    # Argument offsets: count, then one int per offset.
    write_int(len(arg_offsets), buf)
    for off in arg_offsets:
        write_int(off, buf)
    cls.write_bool(False, buf)  # is_kwarg
    write_int(0, buf)  # num_partition_child_indexes
    cls.write_bool(False, buf)  # has_pickled_analyze_result
    # Length-prefixed pickled handler class.
    pickled = cloudpickle_dumps(handler)
    write_int(len(pickled), buf)
    buf.write(pickled)
    cls.write_utf8(return_type.json(), buf)
    cls.write_utf8("benchmark_udtf", buf)  # udtf_name

@classmethod
def write_utf8(cls, s: str, buf: io.BytesIO) -> None:
"""Write a length-prefixed UTF-8 string (matches ``UTF8Deserializer.loads``)."""
Expand All @@ -160,18 +181,19 @@ def write_worker_input(
write_data: Callable[[io.BufferedIOBase], None],
buf: io.BufferedIOBase,
runner_conf: dict[str, str] | None = None,
eval_conf: dict[str, str] | None = None,
) -> None:
"""Write the full worker binary stream: preamble + command + data + end."""
cls.write_preamble(buf)
write_int(eval_type, buf)
if runner_conf:
write_int(len(runner_conf), buf)
for k, v in runner_conf.items():
cls.write_utf8(k, buf)
cls.write_utf8(v, buf)
else:
write_int(0, buf) # RunnerConf (0 key-value pairs)
write_int(0, buf) # EvalConf (0 key-value pairs)
for conf in (runner_conf, eval_conf):
if conf:
write_int(len(conf), buf)
for k, v in conf.items():
cls.write_utf8(k, buf)
cls.write_utf8(v, buf)
else:
write_int(0, buf) # 0 key-value pairs
write_udf(buf)
write_data(buf)
write_int(-4, buf) # SpecialLengths.END_OF_STREAM
Expand Down Expand Up @@ -471,6 +493,100 @@ class ArrowBatchedUDFPeakmemBench(_ArrowBatchedBenchMixin, _PeakmemBenchBase):
pass


# -- SQL_ARROW_TABLE_UDF ----------------------------------------------------
# Python UDTF (``@udtf(useArrow=True)``): handler is a class with ``eval(self, *args)``
# that yields output rows. Each input row triggers one ``eval`` call; yielded rows
# are converted to Arrow via ``LocalDataToArrowConversion``.


class _ArrowTableUDFIdentity:
def eval(self, x):
yield (x,)


class _ArrowTableUDFExplode:
def eval(self, x):
for _ in range(3):
yield (x,)


class _ArrowTableUDFFilter:
def eval(self, x):
if x is not None and (hash(x) & 1):
yield (x,)


class _ArrowTableUDFStringify:
def eval(self, x):
yield (str(x),)


class _ArrowTableUDFBenchMixin:
    """Provides ``_write_scenario`` for SQL_ARROW_TABLE_UDF (Python UDTF, useArrow=True).

    Writes the extra ``input_type`` (StructType JSON) into ``EvalConf`` that the
    non-legacy path requires, and uses the UDTF wire protocol (no num_udfs/result_id).
    """

    # Per-input-row ``LocalDataToArrowConversion.convert`` call makes this path
    # ~15-20x slower than SQL_ARROW_BATCHED_UDF, so row counts are scaled down
    # accordingly to keep each measurement under ASV's per-sample budget.
    # Values: (type-pool key, num_rows, num_cols, batch_size).
    _scenario_configs = {
        "sm_batch_few_col": ("mixed", 2_000, 5, 500),
        "sm_batch_many_col": ("mixed", 500, 50, 500),
        "lg_batch_few_col": ("mixed", 5_000, 5, 2_500),
        "lg_batch_many_col": ("mixed", 2_000, 50, 2_000),
        "pure_ints": ("pure_ints", 5_000, 10, 2_500),
        "pure_strings": ("pure_strings", 5_000, 10, 2_500),
    }

    @staticmethod
    def _build_scenario(name):
        """Return (batches, schema) for the named scenario; seeded for determinism."""
        np.random.seed(42)
        cfg = _ArrowTableUDFBenchMixin._scenario_configs[name]
        type_key, num_rows, num_cols, batch_size = cfg
        return MockDataFactory.make_batches(
            num_rows=num_rows,
            num_cols=num_cols,
            spark_type_pool=MockDataFactory.NAMED_TYPE_POOLS[type_key],
            batch_size=batch_size,
        )

    # Each entry: (handler_class, return_type_or_None, arg_offsets).
    # ``None`` return_type means "use input column 0's type".
    _udtfs = {
        "identity_udtf": (_ArrowTableUDFIdentity, None, [0]),
        "explode_udtf": (_ArrowTableUDFExplode, None, [0]),
        "filter_udtf": (_ArrowTableUDFFilter, None, [0]),
        "stringify_udtf": (_ArrowTableUDFStringify, StringType(), [0]),
    }
    params = [list(_scenario_configs), list(_udtfs)]
    param_names = ["scenario", "udtf"]

    def _write_scenario(self, scenario, udtf_name, buf):
        """Serialize a complete SQL_ARROW_TABLE_UDF worker input into ``buf``."""
        batches, schema = self._build_scenario(scenario)
        handler, declared_type, arg_offsets = self._udtfs[udtf_name]
        # A ``None`` declared type mirrors the type of input column 0.
        if declared_type is None:
            declared_type = schema.fields[0].dataType
        return_type = StructType([StructField("c0", declared_type)])

        def _write_udtf(b):
            MockProtocolWriter.write_udtf_payload(handler, return_type, arg_offsets, b)

        def _write_data(b):
            MockProtocolWriter.write_data_payload(iter(batches), b)

        MockProtocolWriter.write_worker_input(
            PythonEvalType.SQL_ARROW_TABLE_UDF,
            _write_udtf,
            _write_data,
            buf,
            # Non-legacy UDTF path needs the input schema in EvalConf.
            eval_conf={"input_type": schema.json()},
        )


class ArrowTableUDFTimeBench(_ArrowTableUDFBenchMixin, _TimeBenchBase):
    """ASV wall-clock benchmark for SQL_ARROW_TABLE_UDF; scenarios come from the mixin."""

    pass


class ArrowTableUDFPeakmemBench(_ArrowTableUDFBenchMixin, _PeakmemBenchBase):
    """ASV peak-memory benchmark for SQL_ARROW_TABLE_UDF; scenarios come from the mixin."""

    pass


# -- SQL_COGROUPED_MAP_ARROW_UDF ------------------------------------------------
# UDF receives two ``pa.Table`` (left, right) per co-group, returns ``pa.Table``.

Expand Down