diff --git a/pyproject.toml b/pyproject.toml index 69d31ec9f..88eaca840 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,20 +69,20 @@ asyncio_default_fixture_loop_scope = "function" # Enable docstring linting using the google style guide [tool.ruff.lint] -select = ["ALL" ] +select = ["ALL"] ignore = [ - "A001", # Allow using words like min as variable names - "A002", # Allow using words like filter as variable names - "ANN401", # Allow Any for wrapper classes - "COM812", # Recommended to ignore these rules when using with ruff-format - "FIX002", # Allow TODO lines - consider removing at some point - "FBT001", # Allow boolean positional args - "FBT002", # Allow boolean positional args - "ISC001", # Recommended to ignore these rules when using with ruff-format - "SLF001", # Allow accessing private members + "A001", # Allow using words like min as variable names + "A002", # Allow using words like filter as variable names + "ANN401", # Allow Any for wrapper classes + "COM812", # Recommended to ignore these rules when using with ruff-format + "FIX002", # Allow TODO lines - consider removing at some point + "FBT001", # Allow boolean positional args + "FBT002", # Allow boolean positional args + "ISC001", # Recommended to ignore these rules when using with ruff-format + "SLF001", # Allow accessing private members "TD002", - "TD003", # Allow TODO lines - "UP007", # Disallowing Union is pedantic + "TD003", # Allow TODO lines + "UP007", # Disallowing Union is pedantic # TODO: Enable all of the following, but this PR is getting too large already "PLR0913", "TRY003", @@ -129,25 +129,33 @@ extend-allowed-calls = ["lit", "datafusion.lit"] ] "examples/*" = ["D", "W505", "E501", "T201", "S101"] "dev/*" = ["D", "E", "T", "S", "PLR", "C", "SIM", "UP", "EXE", "N817"] -"benchmarks/*" = ["D", "F", "T", "BLE", "FURB", "PLR", "E", "TD", "TRY", "S", "SIM", "EXE", "UP"] +"benchmarks/*" = [ + "D", + "F", + "T", + "BLE", + "FURB", + "PLR", + "E", + "TD", + "TRY", + "S", + "SIM", + "EXE", + "UP", +] "docs/*" = ["D"] [tool.codespell] -skip = [ - "./target", - "uv.lock", - "./python/tests/test_functions.py" -] +skip = ["./target", "uv.lock", "./python/tests/test_functions.py"] count = true -ignore-words-list = [ - "ans", - "IST" -] +ignore-words-list = ["ans", "IST"] [dependency-groups] dev = [ "maturin>=1.8.1", "numpy>1.25.0", + "pyarrow>=19.0.0", "pre-commit>=4.0.0", "pytest>=7.4.4", "pytest-asyncio>=0.23.3", diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 67568e313..26f49549b 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -22,15 +22,39 @@ import functools from abc import ABCMeta, abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, TypeVar, overload +from typing import ( + Any, + Callable, + Optional, + Protocol, + Sequence, + TypeVar, + Union, + cast, + overload, +) import pyarrow as pa import datafusion._internal as df_internal from datafusion.expr import Expr -if TYPE_CHECKING: - _R = TypeVar("_R", bound=pa.DataType) +PyArrowArray = Union[pa.Array, pa.ChunkedArray] +# Type alias for array batches exchanged with Python scalar UDFs. +# +# We need two related but different annotations here: +# - `PyArrowArray` is the concrete union type (pa.Array | pa.ChunkedArray) +# that is convenient for user-facing callables and casts. Use this when +# annotating or checking values that may be either an Array or +# a ChunkedArray. 
+# - `PyArrowArrayT` is a constrained `TypeVar` over the two concrete +# array flavors. Keeping a generic TypeVar allows helpers like +# `_wrap_extension_value` and `_wrap_udf_function` to remain generic +# and preserve the specific array "flavor" (Array vs ChunkedArray) +# flowing through them, rather than collapsing everything to the +# wide union. This improves type-checking and keeps return types +# precise in the wrapper logic. +PyArrowArrayT = TypeVar("PyArrowArrayT", pa.Array, pa.ChunkedArray) class Volatility(Enum): @@ -77,6 +101,87 @@ def __str__(self) -> str: return self.name.lower() +def _clone_field(field: pa.Field) -> pa.Field: + """Return a deep copy of ``field`` including its DataType.""" + return pa.schema([field]).field(0) + + +def _normalize_field(value: pa.DataType | pa.Field, *, default_name: str) -> pa.Field: + if isinstance(value, pa.Field): + return _clone_field(value) + if isinstance(value, pa.DataType): + return _clone_field(pa.field(default_name, value)) + msg = "Expected a pyarrow.DataType or pyarrow.Field" + raise TypeError(msg) + + +def _normalize_input_fields( + values: pa.DataType | pa.Field | Sequence[pa.DataType | pa.Field], +) -> list[pa.Field]: + if isinstance(values, (pa.DataType, pa.Field)): + sequence: Sequence[pa.DataType | pa.Field] = [values] + elif isinstance(values, Sequence) and not isinstance(values, (str, bytes)): + sequence = values + else: + msg = "input_types must be a DataType, Field, or a sequence of them" + raise TypeError(msg) + + return [ + _normalize_field(value, default_name=f"arg_{idx}") + for idx, value in enumerate(sequence) + ] + + +def _normalize_return_field( + value: pa.DataType | pa.Field, + *, + name: str, +) -> pa.Field: + default_name = f"{name}_result" if name else "result" + return _normalize_field(value, default_name=default_name) + + +def _wrap_extension_value( + value: PyArrowArrayT, data_type: pa.DataType +) -> PyArrowArrayT: + storage_type = getattr(data_type, "storage_type", None) + wrap_array = getattr(data_type, "wrap_array", None) + if storage_type is None or wrap_array is None: + return value + if isinstance(value, pa.Array) and value.type.equals(storage_type): + return wrap_array(value) + if isinstance(value, pa.ChunkedArray) and value.type.equals(storage_type): + wrapped_chunks = [wrap_array(chunk) for chunk in value.chunks] + if not wrapped_chunks: + empty_storage = pa.array([], type=storage_type) + return wrap_array(empty_storage) + return pa.chunked_array(wrapped_chunks, type=data_type) + return value + + +def _wrap_udf_function( + func: Callable[..., PyArrowArrayT], + input_fields: Sequence[pa.Field], + return_field: pa.Field, +) -> Callable[..., PyArrowArrayT]: + def wrapper(*args: Any, **kwargs: Any) -> PyArrowArrayT: + if args: + converted_args: list[Any] = list(args) + for idx, field in enumerate(input_fields): + if idx >= len(converted_args): + break + converted_args[idx] = _wrap_extension_value( + cast(PyArrowArray, converted_args[idx]), + field.type, + ) + else: + converted_args = [] + result = func(*converted_args, **kwargs) + return _wrap_extension_value(result, return_field.type) + + return wrapper + + class ScalarUDFExportable(Protocol): """Type hint for object that has __datafusion_scalar_udf__ PyCapsule.""" @@ -93,9 +198,9 @@ class ScalarUDF: def __init__( self, name: str, - func: Callable[..., _R], - input_types: pa.DataType | list[pa.DataType], - return_type: _R, + func: Callable[..., PyArrowArray] | ScalarUDFExportable, + input_types: pa.DataType | pa.Field | Sequence[pa.DataType | 
pa.Field], + return_type: pa.DataType | pa.Field, volatility: Volatility | str, ) -> None: """Instantiate a scalar user-defined function (UDF). @@ -105,10 +210,11 @@ def __init__( if hasattr(func, "__datafusion_scalar_udf__"): self._udf = df_internal.ScalarUDF.from_pycapsule(func) return - if isinstance(input_types, pa.DataType): - input_types = [input_types] + normalized_inputs = _normalize_input_fields(input_types) + normalized_return = _normalize_return_field(return_type, name=name) + wrapped_func = _wrap_udf_function(func, normalized_inputs, normalized_return) self._udf = df_internal.ScalarUDF( - name, func, input_types, return_type, str(volatility) + name, wrapped_func, normalized_inputs, normalized_return, str(volatility) ) def __repr__(self) -> str: @@ -127,18 +233,18 @@ def __call__(self, *args: Expr) -> Expr: @overload @staticmethod def udf( - input_types: list[pa.DataType], - return_type: _R, + input_types: list[pa.DataType | pa.Field], + return_type: pa.DataType | pa.Field, volatility: Volatility | str, name: Optional[str] = None, - ) -> Callable[..., ScalarUDF]: ... + ) -> Callable[[Callable[..., PyArrowArray]], Callable[..., Expr]]: ... @overload @staticmethod def udf( - func: Callable[..., _R], - input_types: list[pa.DataType], - return_type: _R, + func: Callable[..., PyArrowArray], + input_types: list[pa.DataType | pa.Field], + return_type: pa.DataType | pa.Field, volatility: Volatility | str, name: Optional[str] = None, ) -> ScalarUDF: ... @@ -164,10 +270,15 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417 backed ScalarUDF within a PyCapsule, you can pass this parameter and ignore the rest. They will be determined directly from the underlying function. See the online documentation for more information. - input_types (list[pa.DataType]): The data types of the arguments - to ``func``. This list must be of the same length as the number of - arguments. - return_type (_R): The data type of the return value from the function. + The callable should accept and return :class:`pyarrow.Array` or + :class:`pyarrow.ChunkedArray` values. + input_types (list[pa.DataType | pa.Field]): The argument types for ``func``. + This list must be of the same length as the number of arguments. Pass + :class:`pyarrow.Field` instances when you need to declare extension + metadata for an argument. + return_type (pa.DataType | pa.Field): The return type of the function. + Supply a :class:`pyarrow.Field` when the result should expose + extension metadata to downstream consumers. volatility (Volatility | str): See `Volatility` for allowed values. name (Optional[str]): A descriptive name for the function. 
@@ -179,8 +290,13 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417 def double_func(x): return x * 2 - double_udf = udf(double_func, [pa.int32()], pa.int32(), - "volatile", "double_it") + double_udf = udf( + double_func, + [pa.int32()], + pa.int32(), + "volatile", + "double_it", + ) Example: Using ``udf`` as a decorator:: @@ -190,9 +306,9 @@ def double_udf(x): """ def _function( - func: Callable[..., _R], - input_types: list[pa.DataType], - return_type: _R, + func: Callable[..., PyArrowArray], + input_types: list[pa.DataType | pa.Field], + return_type: pa.DataType | pa.Field, volatility: Volatility | str, name: Optional[str] = None, ) -> ScalarUDF: @@ -213,18 +329,18 @@ def _function( ) def _decorator( - input_types: list[pa.DataType], - return_type: _R, + input_types: list[pa.DataType | pa.Field], + return_type: pa.DataType | pa.Field, volatility: Volatility | str, name: Optional[str] = None, - ) -> Callable: - def decorator(func: Callable): + ) -> Callable[[Callable[..., PyArrowArray]], Callable[..., Expr]]: + def decorator(func: Callable[..., PyArrowArray]) -> Callable[..., Expr]: udf_caller = ScalarUDF.udf( func, input_types, return_type, volatility, name ) @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any): + def wrapper(*args: Any, **kwargs: Any) -> Expr: return udf_caller(*args, **kwargs) return wrapper @@ -357,10 +473,12 @@ def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901 This class allows you to define an aggregate function that can be used in data aggregation or window function calls. - Usage: - - As a function: ``udaf(accum, input_types, return_type, state_type, volatility, name)``. - - As a decorator: ``@udaf(input_types, return_type, state_type, volatility, name)``. - When using ``udaf`` as a decorator, do not pass ``accum`` explicitly. + Usage: + - As a function: ``udaf(accum, input_types, return_type, state_type,`` + ``volatility, name)``. + - As a decorator: ``@udaf(input_types, return_type, state_type,`` + ``volatility, name)``. + When using ``udaf`` as a decorator, do not pass ``accum`` explicitly. Function example: diff --git a/python/tests/test_udf.py b/python/tests/test_udf.py index a6c047552..313295bc8 100644 --- a/python/tests/test_udf.py +++ b/python/tests/test_udf.py @@ -15,6 +15,10 @@ # specific language governing permissions and limitations # under the License. +from __future__ import annotations + +import uuid + import pyarrow as pa import pytest from datafusion import column, udf @@ -124,3 +128,90 @@ def udf_with_param(values: pa.Array) -> pa.Array: result = df2.collect()[0].column(0) assert result == pa.array([False, True, True]) + + +def test_uuid_extension_chain(ctx) -> None: + uuid_type = pa.uuid() + uuid_field = pa.field("uuid_col", uuid_type) + + first = udf( + lambda values: values, + [uuid_field], + uuid_field, + volatility="immutable", + name="uuid_identity", + ) + + def ensure_extension(values: pa.Array | pa.ChunkedArray) -> pa.Array: + if isinstance(values, pa.ChunkedArray): + assert values.type.equals(uuid_type) + return values.combine_chunks() + assert isinstance(values, pa.ExtensionArray) + assert values.type.equals(uuid_type) + return values + + second = udf( + ensure_extension, + [uuid_field], + uuid_field, + volatility="immutable", + name="uuid_assert", + ) + + # The UUID extension metadata should survive UDF registration. 
+ assert getattr(uuid_type, "extension_name", None) == "arrow.uuid" + assert getattr(uuid_field.type, "extension_name", None) == "arrow.uuid" + + storage = pa.array( + [ + uuid.UUID("00000000-0000-0000-0000-000000000000").bytes, + uuid.UUID("00000000-0000-0000-0000-000000000001").bytes, + ], + type=uuid_type.storage_type, + ) + batch = pa.RecordBatch.from_arrays( + [uuid_type.wrap_array(storage)], + names=["uuid_col"], + ) + + df = ctx.create_dataframe([[batch]]) + result = df.select(second(first(column("uuid_col")))).collect()[0].column(0) + + expected = uuid_type.wrap_array(storage) + + if isinstance(result, pa.ChunkedArray): + assert result.type.equals(uuid_type) + else: + assert isinstance(result, pa.ExtensionArray) + assert result.type.equals(uuid_type) + + assert result.equals(expected) + + empty_storage = pa.array([], type=uuid_type.storage_type) + empty_batch = pa.RecordBatch.from_arrays( + [uuid_type.wrap_array(empty_storage)], + names=["uuid_col"], + ) + + empty_first = udf( + lambda values: pa.chunked_array([], type=uuid_type.storage_type), + [uuid_field], + uuid_field, + volatility="immutable", + name="uuid_empty_chunk", + ) + + empty_df = ctx.create_dataframe([[empty_batch]]) + empty_result = ( + empty_df.select(second(empty_first(column("uuid_col")))).collect()[0].column(0) + ) + + expected_empty = uuid_type.wrap_array(empty_storage) + + if isinstance(empty_result, pa.ChunkedArray): + assert empty_result.type.equals(uuid_type) + assert empty_result.combine_chunks().equals(expected_empty) + else: + assert isinstance(empty_result, pa.ExtensionArray) + assert empty_result.type.equals(uuid_type) + assert empty_result.equals(expected_empty) diff --git a/src/udf.rs b/src/udf.rs index a9249d6c8..ae4e9b913 100644 --- a/src/udf.rs +++ b/src/udf.rs @@ -15,25 +15,26 @@ // specific language governing permissions and limitations // under the License. 
+use std::fmt;
 use std::sync::Arc;
 
 use datafusion_ffi::udf::{FFI_ScalarUDF, ForeignScalarUDF};
 use pyo3::types::PyCapsule;
 use pyo3::{prelude::*, types::PyTuple};
 
-use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef};
-use datafusion::arrow::datatypes::DataType;
-use datafusion::arrow::pyarrow::FromPyArrow;
-use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
-use datafusion::error::DataFusionError;
-use datafusion::logical_expr::function::ScalarFunctionImplementation;
-use datafusion::logical_expr::ScalarUDF;
-use datafusion::logical_expr::{create_udf, ColumnarValue};
-
 use crate::errors::to_datafusion_err;
 use crate::errors::{py_datafusion_err, PyDataFusionResult};
 use crate::expr::PyExpr;
 use crate::utils::{parse_volatility, validate_pycapsule};
+use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef};
+use datafusion::arrow::datatypes::{DataType, Field};
+use datafusion::arrow::pyarrow::FromPyArrow;
+use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
+use datafusion::error::DataFusionError;
+use datafusion::logical_expr::{
+    function::ScalarFunctionImplementation, ptr_eq::PtrEq, ColumnarValue, ReturnFieldArgs,
+    ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
+};
 
 /// Create a Rust callable function from a python function that expects pyarrow arrays
 fn pyarrow_function_to_rust(
@@ -80,6 +81,86 @@ fn to_scalar_function_impl(func: PyObject) -> ScalarFunctionImplementation {
     })
 }
 
+#[derive(PartialEq, Eq, Hash)]
+struct PySimpleScalarUDF {
+    name: String,
+    signature: Signature,
+    return_field: Arc<Field>,
+    fun: PtrEq<ScalarFunctionImplementation>,
+}
+
+impl fmt::Debug for PySimpleScalarUDF {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("PySimpleScalarUDF")
+            .field("name", &self.name)
+            .field("signature", &self.signature)
+            .field("return_field", &self.return_field)
+            .finish()
+    }
+}
+
+impl PySimpleScalarUDF {
+    fn new(
+        name: impl Into<String>,
+        input_fields: Vec<Field>,
+        return_field: Field,
+        volatility: Volatility,
+        fun: ScalarFunctionImplementation,
+    ) -> Self {
+        let signature_types = input_fields
+            .into_iter()
+            .map(|field| field.data_type().clone())
+            .collect();
+        let signature = Signature::exact(signature_types, volatility);
+        Self {
+            name: name.into(),
+            signature,
+            return_field: Arc::new(return_field),
+            fun: fun.into(),
+        }
+    }
+}
+
+impl ScalarUDFImpl for PySimpleScalarUDF {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
+        Err(DataFusionError::Internal(
+            "return_type should be unreachable when return_field_from_args is implemented"
+                .to_string(),
+        ))
+    }
+
+    fn return_field_from_args(
+        &self,
+        _args: ReturnFieldArgs,
+    ) -> datafusion::error::Result<Arc<Field>> {
+        Ok(Arc::new(
+            self.return_field
+                .as_ref()
+                .clone()
+                .with_name(self.name.clone()),
+        ))
+    }
+
+    fn invoke_with_args(
+        &self,
+        args: ScalarFunctionArgs,
+    ) -> datafusion::error::Result<ColumnarValue> {
+        (self.fun)(&args.args)
+    }
+}
+
 /// Represents a PyScalarUDF
 #[pyclass(frozen, name = "ScalarUDF", module = "datafusion", subclass)]
 #[derive(Debug, Clone)]
@@ -94,17 +175,19 @@ impl PyScalarUDF {
     fn new(
         name: &str,
         func: PyObject,
-        input_types: PyArrowType<Vec<DataType>>,
-        return_type: PyArrowType<DataType>,
+        input_types: PyArrowType<Vec<Field>>,
+        return_type: PyArrowType<Field>,
         volatility: &str,
     ) -> PyResult<Self> {
-        let function = create_udf(
+        let volatility = parse_volatility(volatility)?;
+        let scalar_impl = PySimpleScalarUDF::new(
             name,
             input_types.0,
             return_type.0,
-            parse_volatility(volatility)?,
+            volatility,
            to_scalar_function_impl(func),
         );
+        let function = ScalarUDF::new_from_impl(scalar_impl);
         Ok(Self { function })
     }
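
Reviewer note: a minimal usage sketch of the Field-aware ``udf`` API added in
python/datafusion/user_defined.py, mirroring ``test_uuid_extension_chain``. The
``SessionContext`` setup and the names ("uuid_identity", "uuid_col") are illustrative
assumptions, not part of the patch::

    import uuid

    import pyarrow as pa
    from datafusion import SessionContext, column, udf

    ctx = SessionContext()
    uuid_type = pa.uuid()
    uuid_field = pa.field("uuid_col", uuid_type)

    # Declaring inputs and the return as pa.Field lets the UDF carry the
    # arrow.uuid extension metadata through planning and execution.
    uuid_identity = udf(
        lambda values: values,
        [uuid_field],
        uuid_field,
        volatility="immutable",
        name="uuid_identity",
    )

    storage = pa.array([uuid.uuid4().bytes], type=uuid_type.storage_type)
    batch = pa.RecordBatch.from_arrays(
        [uuid_type.wrap_array(storage)], names=["uuid_col"]
    )
    df = ctx.create_dataframe([[batch]])
    result = df.select(uuid_identity(column("uuid_col"))).collect()[0].column(0)
    assert result.type.equals(uuid_type)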
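
The rewrapping performed by ``_wrap_extension_value`` leans on the standard pyarrow
extension-type surface (``storage_type`` and ``wrap_array``). A standalone sketch of
that contract, with illustrative values only::

    import uuid

    import pyarrow as pa

    uuid_type = pa.uuid()
    # A UDF body may legitimately return plain storage values ...
    storage = pa.array(
        [uuid.uuid4().bytes, uuid.uuid4().bytes], type=uuid_type.storage_type
    )
    # ... and the wrapper promotes them back to the declared extension type.
    wrapped = uuid_type.wrap_array(storage)
    assert isinstance(wrapped, pa.ExtensionArray)
    assert wrapped.type.equals(uuid_type)

    # ChunkedArray results are rewrapped chunk by chunk, preserving chunking,
    # exactly as _wrap_extension_value does above.
    chunked = pa.chunked_array([storage])
    rewrapped = pa.chunked_array(
        [uuid_type.wrap_array(chunk) for chunk in chunked.chunks], type=uuid_type
    )
    assert rewrapped.type.equals(uuid_type)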
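
The normalization helpers accept either DataTypes or Fields and fall back to the
default names baked into ``_normalize_input_fields`` / ``_normalize_return_field``.
A quick sketch, reaching into the private helpers purely for illustration::

    import pyarrow as pa
    from datafusion.user_defined import (
        _normalize_input_fields,
        _normalize_return_field,
    )

    # A bare DataType gets a positional default name; an explicit Field keeps its own.
    fields = _normalize_input_fields([pa.int32(), pa.field("id", pa.uuid())])
    assert [f.name for f in fields] == ["arg_0", "id"]

    # The return field defaults to "<udf name>_result" when only a DataType is given.
    ret = _normalize_return_field(pa.uuid(), name="my_udf")
    assert ret.name == "my_udf_result"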