From 1a664b613621b2d5aa10f8a7831a5bc9ea0dd177 Mon Sep 17 00:00:00 2001 From: yantian Date: Wed, 20 May 2026 13:41:47 +0800 Subject: [PATCH 1/7] [feature] support python scalar udf --- .../python/pypaimon_rust/datafusion.pyi | 9 +- bindings/python/src/context.rs | 175 +++++++++++++++++- bindings/python/tests/test_datafusion.py | 23 +++ 3 files changed, 205 insertions(+), 2 deletions(-) diff --git a/bindings/python/python/pypaimon_rust/datafusion.pyi b/bindings/python/python/pypaimon_rust/datafusion.pyi index 4d0e973a..7f2428fe 100644 --- a/bindings/python/python/pypaimon_rust/datafusion.pyi +++ b/bindings/python/python/pypaimon_rust/datafusion.pyi @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List +from typing import Any, Callable, Dict, List import pyarrow @@ -31,4 +31,11 @@ class SQLContext: def set_current_catalog(self, catalog_name: str) -> None: ... def set_current_database(self, database_name: str) -> None: ... def register_batch(self, name: str, batch: pyarrow.RecordBatch) -> None: ... + def register_scalar_function( + self, + name: str, + func: Callable[..., pyarrow.Array], + input_types: List[str], + return_type: str, + ) -> None: ... def sql(self, sql: str) -> List[pyarrow.RecordBatch]: ... diff --git a/bindings/python/src/context.rs b/bindings/python/src/context.rs index e1050d38..5d707193 100644 --- a/bindings/python/src/context.rs +++ b/bindings/python/src/context.rs @@ -15,17 +15,27 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; use std::collections::HashMap; +use std::fmt::{self, Debug}; +use std::hash::{Hash, Hasher}; use std::sync::Arc; +use arrow::array::{make_array, Array, ArrayData, ArrayRef}; +use arrow::datatypes::DataType as ArrowDataType; use arrow::pyarrow::{FromPyArrow, ToPyArrow}; use datafusion::catalog::CatalogProvider; +use datafusion::common::{DataFusionError, Result as DFResult}; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, +}; use datafusion_ffi::catalog_provider::FFI_CatalogProvider; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use paimon::{CatalogFactory, Options}; use paimon_datafusion::{PaimonCatalogProvider, SQLContext}; +use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; -use pyo3::types::PyCapsule; +use pyo3::types::{PyCapsule, PyTuple}; use crate::error::{df_to_py_err, to_py_err}; use paimon_datafusion::runtime::runtime; @@ -57,6 +67,147 @@ fn ffi_logical_codec_from_pycapsule(obj: Bound<'_, PyAny>) -> PyResult PyResult { + match type_name.to_ascii_lowercase().as_str() { + "bool" | "boolean" => Ok(ArrowDataType::Boolean), + "int8" => Ok(ArrowDataType::Int8), + "int16" => Ok(ArrowDataType::Int16), + "int" | "int32" | "integer" => Ok(ArrowDataType::Int32), + "bigint" | "int64" | "long" => Ok(ArrowDataType::Int64), + "float" | "float32" => Ok(ArrowDataType::Float32), + "double" | "float64" => Ok(ArrowDataType::Float64), + "string" | "utf8" => Ok(ArrowDataType::Utf8), + "large_string" | "large_utf8" => Ok(ArrowDataType::LargeUtf8), + "binary" => Ok(ArrowDataType::Binary), + "large_binary" => Ok(ArrowDataType::LargeBinary), + other => Err(PyTypeError::new_err(format!( + "Unsupported Arrow type for Python UDF: {other}" + ))), + } +} + +fn df_execution_error(message: impl Into) -> DataFusionError { + DataFusionError::Execution(message.into()) +} + +fn columnar_value_to_array(value: &ColumnarValue, num_rows: usize) -> DFResult { + match value { + ColumnarValue::Array(array) => Ok(Arc::clone(array)), + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows), + } +} + +struct PyScalarUDF { + name: String, + func: Py, + input_types: Vec, + return_type: ArrowDataType, + signature: Signature, +} + +impl PyScalarUDF { + fn new( + name: String, + func: Py, + input_types: Vec, + return_type: ArrowDataType, + ) -> Self { + let signature = Signature::exact(input_types.clone(), Volatility::Volatile); + Self { + name, + func, + input_types, + return_type, + signature, + } + } +} + +impl Debug for PyScalarUDF { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PyScalarUDF") + .field("name", &self.name) + .field("input_types", &self.input_types) + .field("return_type", &self.return_type) + .finish_non_exhaustive() + } +} + +impl PartialEq for PyScalarUDF { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.input_types == other.input_types + && self.return_type == other.return_type + } +} + +impl Eq for PyScalarUDF {} + +impl Hash for PyScalarUDF { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.input_types.hash(state); + self.return_type.hash(state); + } +} + +impl ScalarUDFImpl for PyScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[ArrowDataType]) -> DFResult { + Ok(self.return_type.clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult { + let arrays = args + .args + .iter() + .map(|value| columnar_value_to_array(value, args.number_rows)) + .collect::>>()?; + + let output = Python::try_attach(|py| -> PyResult { + let py_args = arrays + .iter() + .map(|array| array.to_data().to_pyarrow(py)) + .collect::>>()?; + let py_args = PyTuple::new(py, py_args)?; + let output = self.func.bind(py).call1(py_args)?; + Ok(make_array(ArrayData::from_pyarrow_bound(&output)?)) + }) + .ok_or_else(|| df_execution_error("Python interpreter is not available"))? + .map_err(|err| df_execution_error(format!("Python UDF '{}' failed: {err}", self.name)))?; + + if output.len() != args.number_rows { + return Err(df_execution_error(format!( + "Python UDF '{}' returned {} rows, expected {}", + self.name, + output.len(), + args.number_rows + ))); + } + if output.data_type() != &self.return_type { + return Err(df_execution_error(format!( + "Python UDF '{}' returned {:?}, expected {:?}", + self.name, + output.data_type(), + self.return_type + ))); + } + + Ok(ColumnarValue::Array(output)) + } +} + /// A Paimon catalog exportable to Python DataFusion `SessionContext`. #[pyclass(name = "PaimonCatalog")] pub struct PaimonCatalog { @@ -148,6 +299,28 @@ impl PySQLContext { .map_err(df_to_py_err) } + fn register_scalar_function( + &self, + py: Python<'_>, + name: String, + func: Py, + input_types: Vec, + return_type: String, + ) -> PyResult<()> { + if !func.bind(py).is_callable() { + return Err(PyTypeError::new_err("func must be callable")); + } + + let input_types = input_types + .iter() + .map(|type_name| parse_arrow_type(type_name)) + .collect::>>()?; + let return_type = parse_arrow_type(&return_type)?; + let udf = PyScalarUDF::new(name, func, input_types, return_type); + self.inner.ctx().register_udf(ScalarUDF::new_from_impl(udf)); + Ok(()) + } + fn sql(&self, py: Python<'_>, sql: String) -> PyResult>> { let rt = runtime(); let batches = rt.block_on(async { diff --git a/bindings/python/tests/test_datafusion.py b/bindings/python/tests/test_datafusion.py index 2576b7c2..a5b8fb17 100644 --- a/bindings/python/tests/test_datafusion.py +++ b/bindings/python/tests/test_datafusion.py @@ -100,6 +100,29 @@ def test_register_batch_bare_name(): ctx.sql("DROP TEMPORARY TABLE paimon.default.my_temp") +def test_register_scalar_function_from_python(): + with tempfile.TemporaryDirectory() as warehouse: + ctx = SQLContext() + ctx.register_catalog("paimon", {"warehouse": warehouse}) + + batch = pa.record_batch([[1, 2, 3]], names=["id"]) + ctx.register_batch("my_temp", batch) + + def plus_ten(values): + return pa.array( + [None if value is None else value + 10 for value in values.to_pylist()], + type=pa.int64(), + ) + + ctx.register_scalar_function("plus_ten", plus_ten, ["int64"], "int64") + + batches = ctx.sql("SELECT plus_ten(id) AS id FROM paimon.default.my_temp ORDER BY id") + table = pa.Table.from_batches(batches) + assert table["id"].to_pylist() == [11, 12, 13] + + ctx.sql("DROP TEMPORARY TABLE paimon.default.my_temp") + + def test_temp_table_shadows_paimon_table(): with tempfile.TemporaryDirectory() as warehouse: ctx = SQLContext() From 3dabbfed6b32af1990a9a56e700508323dfae17e Mon Sep 17 00:00:00 2001 From: yantian Date: Wed, 20 May 2026 13:55:41 +0800 Subject: [PATCH 2/7] fix --- .../python/pypaimon_rust/datafusion.pyi | 13 +- bindings/python/src/context.rs | 8 +- bindings/python/tests/test_datafusion.py | 115 +++++++++++++++++- 3 files changed, 129 insertions(+), 7 deletions(-) diff --git a/bindings/python/python/pypaimon_rust/datafusion.pyi b/bindings/python/python/pypaimon_rust/datafusion.pyi index 7f2428fe..ad744abf 100644 --- a/bindings/python/python/pypaimon_rust/datafusion.pyi +++ b/bindings/python/python/pypaimon_rust/datafusion.pyi @@ -37,5 +37,16 @@ class SQLContext: func: Callable[..., pyarrow.Array], input_types: List[str], return_type: str, - ) -> None: ... + ) -> None: + """ + Register a Python scalar UDF. + + The callable receives one PyArrow Array per argument and must return a + PyArrow Array with the declared return type and the same row count. + Supported type names are: boolean, int8, int16, int32, int64, + float32, float64, string, large_string, binary, and large_binary. + Aliases such as bool, int, bigint, long, float, double, utf8, + large_utf8 are also accepted. + """ + ... def sql(self, sql: str) -> List[pyarrow.RecordBatch]: ... diff --git a/bindings/python/src/context.rs b/bindings/python/src/context.rs index 5d707193..55ceb230 100644 --- a/bindings/python/src/context.rs +++ b/bindings/python/src/context.rs @@ -323,9 +323,11 @@ impl PySQLContext { fn sql(&self, py: Python<'_>, sql: String) -> PyResult>> { let rt = runtime(); - let batches = rt.block_on(async { - let df = self.inner.sql(&sql).await.map_err(df_to_py_err)?; - df.collect().await.map_err(df_to_py_err) + let batches = py.detach(|| { + rt.block_on(async { + let df = self.inner.sql(&sql).await.map_err(df_to_py_err)?; + df.collect().await.map_err(df_to_py_err) + }) })?; batches .iter() diff --git a/bindings/python/tests/test_datafusion.py b/bindings/python/tests/test_datafusion.py index a5b8fb17..0dad0434 100644 --- a/bindings/python/tests/test_datafusion.py +++ b/bindings/python/tests/test_datafusion.py @@ -105,7 +105,7 @@ def test_register_scalar_function_from_python(): ctx = SQLContext() ctx.register_catalog("paimon", {"warehouse": warehouse}) - batch = pa.record_batch([[1, 2, 3]], names=["id"]) + batch = pa.record_batch([[1, None, 3]], names=["id"]) ctx.register_batch("my_temp", batch) def plus_ten(values): @@ -116,13 +116,122 @@ def plus_ten(values): ctx.register_scalar_function("plus_ten", plus_ten, ["int64"], "int64") - batches = ctx.sql("SELECT plus_ten(id) AS id FROM paimon.default.my_temp ORDER BY id") + batches = ctx.sql( + "SELECT plus_ten(id) AS id FROM paimon.default.my_temp ORDER BY id" + ) + table = pa.Table.from_batches(batches) + assert table["id"].to_pylist() == [11, 13, None] + + ctx.sql("DROP TEMPORARY TABLE paimon.default.my_temp") + + +def test_register_scalar_function_multi_input_plan(): + with tempfile.TemporaryDirectory() as warehouse: + ctx = SQLContext() + ctx.register_catalog("paimon", {"warehouse": warehouse}) + + batch = pa.record_batch([[1, 2, 3]], names=["id"]) + ctx.register_batch("my_temp", batch) + + def plus_ten(values): + return pa.array([value + 10 for value in values.to_pylist()], type=pa.int64()) + + ctx.register_scalar_function("plus_ten", plus_ten, ["int64"], "int64") + + batches = ctx.sql( + """ + SELECT plus_ten(id) AS id FROM paimon.default.my_temp + UNION ALL + SELECT plus_ten(id) AS id FROM paimon.default.my_temp + ORDER BY id + """ + ) table = pa.Table.from_batches(batches) - assert table["id"].to_pylist() == [11, 12, 13] + assert table["id"].to_pylist() == [11, 11, 12, 12, 13, 13] ctx.sql("DROP TEMPORARY TABLE paimon.default.my_temp") +def test_register_scalar_function_rejects_non_callable(): + ctx = SQLContext() + try: + ctx.register_scalar_function("bad", 1, ["int64"], "int64") + assert False, "expected non-callable UDF registration to fail" + except TypeError as e: + assert "func must be callable" in str(e) + + +def test_register_scalar_function_rejects_unsupported_type(): + ctx = SQLContext() + + def identity(values): + return values + + try: + ctx.register_scalar_function("identity", identity, ["date32"], "date32") + assert False, "expected unsupported type registration to fail" + except TypeError as e: + assert "Unsupported Arrow type" in str(e) + + +def test_python_scalar_function_exception_surfaces(): + with tempfile.TemporaryDirectory() as warehouse: + ctx = SQLContext() + ctx.register_catalog("paimon", {"warehouse": warehouse}) + ctx.register_batch("my_temp", pa.record_batch([[1]], names=["id"])) + + def boom(values): + raise RuntimeError("boom") + + ctx.register_scalar_function("boom", boom, ["int64"], "int64") + + try: + ctx.sql("SELECT boom(id) AS id FROM paimon.default.my_temp") + assert False, "expected Python UDF exception to fail the query" + except Exception as e: + message = str(e) + assert "Python UDF 'boom' failed" in message + assert "boom" in message + + +def test_python_scalar_function_rejects_wrong_length(): + with tempfile.TemporaryDirectory() as warehouse: + ctx = SQLContext() + ctx.register_catalog("paimon", {"warehouse": warehouse}) + ctx.register_batch("my_temp", pa.record_batch([[1, 2]], names=["id"])) + + def wrong_length(values): + return pa.array([1], type=pa.int64()) + + ctx.register_scalar_function("wrong_length", wrong_length, ["int64"], "int64") + + try: + ctx.sql("SELECT wrong_length(id) AS id FROM paimon.default.my_temp") + assert False, "expected wrong-length UDF result to fail the query" + except Exception as e: + message = str(e) + assert "Python UDF 'wrong_length' returned 1 rows, expected 2" in message + + +def test_python_scalar_function_rejects_wrong_type(): + with tempfile.TemporaryDirectory() as warehouse: + ctx = SQLContext() + ctx.register_catalog("paimon", {"warehouse": warehouse}) + ctx.register_batch("my_temp", pa.record_batch([[1]], names=["id"])) + + def wrong_type(values): + return pa.array(["not an int"], type=pa.string()) + + ctx.register_scalar_function("wrong_type", wrong_type, ["int64"], "int64") + + try: + ctx.sql("SELECT wrong_type(id) AS id FROM paimon.default.my_temp") + assert False, "expected wrong-type UDF result to fail the query" + except Exception as e: + message = str(e) + assert "Python UDF 'wrong_type' returned Utf8, expected Int64" in message + + def test_temp_table_shadows_paimon_table(): with tempfile.TemporaryDirectory() as warehouse: ctx = SQLContext() From f209e11100505b055c440f859d828fbef287eeb9 Mon Sep 17 00:00:00 2001 From: yantian Date: Wed, 20 May 2026 14:38:23 +0800 Subject: [PATCH 3/7] Align Python UDF API with DataFusion --- .../python/pypaimon_rust/datafusion.pyi | 62 ++++-- bindings/python/src/context.rs | 186 +++++++++++++++--- bindings/python/tests/test_datafusion.py | 55 ++++-- 3 files changed, 240 insertions(+), 63 deletions(-) diff --git a/bindings/python/python/pypaimon_rust/datafusion.pyi b/bindings/python/python/pypaimon_rust/datafusion.pyi index ad744abf..8f9c36ff 100644 --- a/bindings/python/python/pypaimon_rust/datafusion.pyi +++ b/bindings/python/python/pypaimon_rust/datafusion.pyi @@ -15,14 +15,55 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Optional, Sequence, TypeAlias, Union import pyarrow +ArrowTypeLike: TypeAlias = Union[pyarrow.DataType, pyarrow.Field, str] +InputFieldsLike: TypeAlias = Union[ArrowTypeLike, Sequence[ArrowTypeLike]] +VolatilityLike: TypeAlias = Union[str, Any] + class PaimonCatalog: def __init__(self, catalog_options: Dict[str, str]) -> None: ... def __datafusion_catalog_provider__(self, session: Any) -> object: ... +class ScalarUDF: + def __init__( + self, + name: str, + func: Callable[..., pyarrow.Array], + input_fields: InputFieldsLike, + return_field: ArrowTypeLike, + volatility: VolatilityLike, + ) -> None: ... + @staticmethod + def udf( + func: Callable[..., pyarrow.Array], + input_fields: InputFieldsLike, + return_field: ArrowTypeLike, + volatility: VolatilityLike, + name: Optional[str] = None, + ) -> "ScalarUDF": ... + @property + def name(self) -> str: ... + +def udf( + func: Callable[..., pyarrow.Array], + input_fields: InputFieldsLike, + return_field: ArrowTypeLike, + volatility: VolatilityLike, + name: Optional[str] = None, +) -> ScalarUDF: + """ + Create a scalar UDF. + + This mirrors DataFusion Python's function-style API: + ``udf(func, input_fields, return_field, volatility, name)``. + ``input_fields`` and ``return_field`` accept PyArrow DataType or Field + values. String type names remain accepted for compatibility. + """ + ... + class SQLContext: def __init__(self) -> None: ... def register_catalog( @@ -31,22 +72,5 @@ class SQLContext: def set_current_catalog(self, catalog_name: str) -> None: ... def set_current_database(self, database_name: str) -> None: ... def register_batch(self, name: str, batch: pyarrow.RecordBatch) -> None: ... - def register_scalar_function( - self, - name: str, - func: Callable[..., pyarrow.Array], - input_types: List[str], - return_type: str, - ) -> None: - """ - Register a Python scalar UDF. - - The callable receives one PyArrow Array per argument and must return a - PyArrow Array with the declared return type and the same row count. - Supported type names are: boolean, int8, int16, int32, int64, - float32, float64, string, large_string, binary, and large_binary. - Aliases such as bool, int, bigint, long, float, double, utf8, - large_utf8 are also accepted. - """ - ... + def register_udf(self, udf: ScalarUDF) -> None: ... def sql(self, sql: str) -> List[pyarrow.RecordBatch]: ... diff --git a/bindings/python/src/context.rs b/bindings/python/src/context.rs index 55ceb230..cc5b2bdc 100644 --- a/bindings/python/src/context.rs +++ b/bindings/python/src/context.rs @@ -22,12 +22,13 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use arrow::array::{make_array, Array, ArrayData, ArrayRef}; -use arrow::datatypes::DataType as ArrowDataType; +use arrow::datatypes::{DataType as ArrowDataType, Field as ArrowField}; use arrow::pyarrow::{FromPyArrow, ToPyArrow}; use datafusion::catalog::CatalogProvider; use datafusion::common::{DataFusionError, Result as DFResult}; use datafusion::logical_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, ScalarFunctionArgs, ScalarUDF as DFScalarUDF, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_ffi::catalog_provider::FFI_CatalogProvider; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; @@ -35,7 +36,7 @@ use paimon::{CatalogFactory, Options}; use paimon_datafusion::{PaimonCatalogProvider, SQLContext}; use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; -use pyo3::types::{PyCapsule, PyTuple}; +use pyo3::types::{PyCapsule, PyList, PyTuple}; use crate::error::{df_to_py_err, to_py_err}; use paimon_datafusion::runtime::runtime; @@ -86,6 +87,70 @@ fn parse_arrow_type(type_name: &str) -> PyResult { } } +fn parse_arrow_type_like(value: &Bound<'_, PyAny>) -> PyResult { + if let Ok(field) = ArrowField::from_pyarrow_bound(value) { + return Ok(field.data_type().clone()); + } + if let Ok(data_type) = ArrowDataType::from_pyarrow_bound(value) { + return Ok(data_type); + } + if let Ok(type_name) = value.extract::() { + return parse_arrow_type(&type_name); + } + + Err(PyTypeError::new_err( + "Expected a pyarrow.DataType, pyarrow.Field, or supported Arrow type name", + )) +} + +fn parse_input_types(input_fields: &Bound<'_, PyAny>) -> PyResult> { + if let Ok(fields) = input_fields.cast::() { + return fields + .iter() + .map(|field| parse_arrow_type_like(&field)) + .collect(); + } + if let Ok(fields) = input_fields.cast::() { + return fields + .iter() + .map(|field| parse_arrow_type_like(&field)) + .collect(); + } + + Ok(vec![parse_arrow_type_like(input_fields)?]) +} + +fn parse_volatility(volatility: &Bound<'_, PyAny>) -> PyResult { + let value = if let Ok(value) = volatility.extract::() { + value + } else if let Ok(name) = volatility.getattr("name") { + name.extract::()? + } else { + volatility.str()?.to_str()?.to_string() + }; + + match value.to_ascii_lowercase().as_str() { + "immutable" => Ok(Volatility::Immutable), + "stable" => Ok(Volatility::Stable), + "volatile" => Ok(Volatility::Volatile), + other => Err(PyTypeError::new_err(format!( + "Unsupported UDF volatility: {other}. Expected immutable, stable, or volatile" + ))), + } +} + +fn default_udf_name(py: Python<'_>, func: &Py) -> PyResult { + let func = func.bind(py); + if let Ok(name) = func.getattr("__qualname__") { + return Ok(name.extract::()?.to_ascii_lowercase()); + } + Ok(func + .getattr("__class__")? + .getattr("__name__")? + .extract::()? + .to_ascii_lowercase()) +} + fn df_execution_error(message: impl Into) -> DataFusionError { DataFusionError::Execution(message.into()) } @@ -102,6 +167,7 @@ struct PyScalarUDF { func: Py, input_types: Vec, return_type: ArrowDataType, + volatility: Volatility, signature: Signature, } @@ -111,24 +177,111 @@ impl PyScalarUDF { func: Py, input_types: Vec, return_type: ArrowDataType, + volatility: Volatility, ) -> Self { - let signature = Signature::exact(input_types.clone(), Volatility::Volatile); + let signature = Signature::exact(input_types.clone(), volatility); Self { name, func, input_types, return_type, + volatility, signature, } } } +#[pyclass(name = "ScalarUDF")] +pub struct PyScalarUDFObject { + name: String, + udf: DFScalarUDF, +} + +impl PyScalarUDFObject { + fn create( + py: Python<'_>, + name: String, + func: Py, + input_fields: &Bound<'_, PyAny>, + return_field: &Bound<'_, PyAny>, + volatility: &Bound<'_, PyAny>, + ) -> PyResult { + if !func.bind(py).is_callable() { + return Err(PyTypeError::new_err("`func` argument must be callable")); + } + + let input_types = parse_input_types(input_fields)?; + let return_type = parse_arrow_type_like(return_field)?; + let volatility = parse_volatility(volatility)?; + let udf = PyScalarUDF::new(name.clone(), func, input_types, return_type, volatility); + Ok(Self { + name, + udf: DFScalarUDF::new_from_impl(udf), + }) + } +} + +#[pymethods] +impl PyScalarUDFObject { + #[new] + fn new( + py: Python<'_>, + name: String, + func: Py, + input_fields: Bound<'_, PyAny>, + return_field: Bound<'_, PyAny>, + volatility: Bound<'_, PyAny>, + ) -> PyResult { + Self::create(py, name, func, &input_fields, &return_field, &volatility) + } + + #[staticmethod] + #[pyo3(signature = (func, input_fields, return_field, volatility, name = None))] + fn udf( + py: Python<'_>, + func: Py, + input_fields: Bound<'_, PyAny>, + return_field: Bound<'_, PyAny>, + volatility: Bound<'_, PyAny>, + name: Option, + ) -> PyResult { + let name = match name { + Some(name) => name, + None => default_udf_name(py, &func)?, + }; + Self::create(py, name, func, &input_fields, &return_field, &volatility) + } + + #[getter] + fn name(&self) -> &str { + &self.name + } + + fn __repr__(&self) -> String { + format!("ScalarUDF({})", self.name) + } +} + +#[pyfunction] +#[pyo3(signature = (func, input_fields, return_field, volatility, name = None))] +fn udf( + py: Python<'_>, + func: Py, + input_fields: Bound<'_, PyAny>, + return_field: Bound<'_, PyAny>, + volatility: Bound<'_, PyAny>, + name: Option, +) -> PyResult { + PyScalarUDFObject::udf(py, func, input_fields, return_field, volatility, name) +} + impl Debug for PyScalarUDF { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("PyScalarUDF") .field("name", &self.name) .field("input_types", &self.input_types) .field("return_type", &self.return_type) + .field("volatility", &self.volatility) .finish_non_exhaustive() } } @@ -138,6 +291,7 @@ impl PartialEq for PyScalarUDF { self.name == other.name && self.input_types == other.input_types && self.return_type == other.return_type + && self.volatility == other.volatility } } @@ -148,6 +302,7 @@ impl Hash for PyScalarUDF { self.name.hash(state); self.input_types.hash(state); self.return_type.hash(state); + self.volatility.hash(state); } } @@ -299,25 +454,8 @@ impl PySQLContext { .map_err(df_to_py_err) } - fn register_scalar_function( - &self, - py: Python<'_>, - name: String, - func: Py, - input_types: Vec, - return_type: String, - ) -> PyResult<()> { - if !func.bind(py).is_callable() { - return Err(PyTypeError::new_err("func must be callable")); - } - - let input_types = input_types - .iter() - .map(|type_name| parse_arrow_type(type_name)) - .collect::>>()?; - let return_type = parse_arrow_type(&return_type)?; - let udf = PyScalarUDF::new(name, func, input_types, return_type); - self.inner.ctx().register_udf(ScalarUDF::new_from_impl(udf)); + fn register_udf(&self, udf: &PyScalarUDFObject) -> PyResult<()> { + self.inner.ctx().register_udf(udf.udf.clone()); Ok(()) } @@ -339,7 +477,9 @@ impl PySQLContext { pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { let this = PyModule::new(py, "datafusion")?; this.add_class::()?; + this.add_class::()?; this.add_class::()?; + this.add_function(wrap_pyfunction!(udf, &this)?)?; m.add_submodule(&this)?; py.import("sys")? .getattr("modules")? diff --git a/bindings/python/tests/test_datafusion.py b/bindings/python/tests/test_datafusion.py index 0dad0434..c11610eb 100644 --- a/bindings/python/tests/test_datafusion.py +++ b/bindings/python/tests/test_datafusion.py @@ -21,7 +21,7 @@ import pyarrow as pa from datafusion import SessionContext -from pypaimon_rust.datafusion import PaimonCatalog, SQLContext +from pypaimon_rust.datafusion import PaimonCatalog, SQLContext, ScalarUDF, udf WAREHOUSE = os.environ.get("PAIMON_TEST_WAREHOUSE", "/tmp/paimon-warehouse") @@ -100,7 +100,7 @@ def test_register_batch_bare_name(): ctx.sql("DROP TEMPORARY TABLE paimon.default.my_temp") -def test_register_scalar_function_from_python(): +def test_register_udf_from_python(): with tempfile.TemporaryDirectory() as warehouse: ctx = SQLContext() ctx.register_catalog("paimon", {"warehouse": warehouse}) @@ -114,7 +114,7 @@ def plus_ten(values): type=pa.int64(), ) - ctx.register_scalar_function("plus_ten", plus_ten, ["int64"], "int64") + ctx.register_udf(udf(plus_ten, [pa.int64()], pa.int64(), "volatile", "plus_ten")) batches = ctx.sql( "SELECT plus_ten(id) AS id FROM paimon.default.my_temp ORDER BY id" @@ -125,7 +125,7 @@ def plus_ten(values): ctx.sql("DROP TEMPORARY TABLE paimon.default.my_temp") -def test_register_scalar_function_multi_input_plan(): +def test_register_udf_multi_input_plan(): with tempfile.TemporaryDirectory() as warehouse: ctx = SQLContext() ctx.register_catalog("paimon", {"warehouse": warehouse}) @@ -136,7 +136,7 @@ def test_register_scalar_function_multi_input_plan(): def plus_ten(values): return pa.array([value + 10 for value in values.to_pylist()], type=pa.int64()) - ctx.register_scalar_function("plus_ten", plus_ten, ["int64"], "int64") + ctx.register_udf(udf(plus_ten, [pa.int64()], pa.int64(), "volatile", "plus_ten")) batches = ctx.sql( """ @@ -152,29 +152,38 @@ def plus_ten(values): ctx.sql("DROP TEMPORARY TABLE paimon.default.my_temp") -def test_register_scalar_function_rejects_non_callable(): - ctx = SQLContext() +def test_udf_rejects_non_callable(): try: - ctx.register_scalar_function("bad", 1, ["int64"], "int64") - assert False, "expected non-callable UDF registration to fail" + udf(1, [pa.int64()], pa.int64(), "volatile") + assert False, "expected non-callable UDF creation to fail" except TypeError as e: - assert "func must be callable" in str(e) + assert "`func` argument must be callable" in str(e) -def test_register_scalar_function_rejects_unsupported_type(): - ctx = SQLContext() - +def test_udf_rejects_unsupported_type(): def identity(values): return values try: - ctx.register_scalar_function("identity", identity, ["date32"], "date32") + udf(identity, [object()], pa.int64(), "volatile", "identity") assert False, "expected unsupported type registration to fail" except TypeError as e: - assert "Unsupported Arrow type" in str(e) + assert "Expected a pyarrow.DataType" in str(e) + + +def test_scalar_udf_constructor_matches_datafusion_shape(): + def identity(values): + return values + scalar_udf = ScalarUDF( + "identity", identity, [pa.field("value", pa.int64())], pa.int64(), "stable" + ) -def test_python_scalar_function_exception_surfaces(): + assert scalar_udf.name == "identity" + assert repr(scalar_udf) == "ScalarUDF(identity)" + + +def test_python_udf_exception_surfaces(): with tempfile.TemporaryDirectory() as warehouse: ctx = SQLContext() ctx.register_catalog("paimon", {"warehouse": warehouse}) @@ -183,7 +192,7 @@ def test_python_scalar_function_exception_surfaces(): def boom(values): raise RuntimeError("boom") - ctx.register_scalar_function("boom", boom, ["int64"], "int64") + ctx.register_udf(udf(boom, [pa.int64()], pa.int64(), "volatile", "boom")) try: ctx.sql("SELECT boom(id) AS id FROM paimon.default.my_temp") @@ -194,7 +203,7 @@ def boom(values): assert "boom" in message -def test_python_scalar_function_rejects_wrong_length(): +def test_python_udf_rejects_wrong_length(): with tempfile.TemporaryDirectory() as warehouse: ctx = SQLContext() ctx.register_catalog("paimon", {"warehouse": warehouse}) @@ -203,7 +212,9 @@ def test_python_scalar_function_rejects_wrong_length(): def wrong_length(values): return pa.array([1], type=pa.int64()) - ctx.register_scalar_function("wrong_length", wrong_length, ["int64"], "int64") + ctx.register_udf( + udf(wrong_length, [pa.int64()], pa.int64(), "volatile", "wrong_length") + ) try: ctx.sql("SELECT wrong_length(id) AS id FROM paimon.default.my_temp") @@ -213,7 +224,7 @@ def wrong_length(values): assert "Python UDF 'wrong_length' returned 1 rows, expected 2" in message -def test_python_scalar_function_rejects_wrong_type(): +def test_python_udf_rejects_wrong_type(): with tempfile.TemporaryDirectory() as warehouse: ctx = SQLContext() ctx.register_catalog("paimon", {"warehouse": warehouse}) @@ -222,7 +233,9 @@ def test_python_scalar_function_rejects_wrong_type(): def wrong_type(values): return pa.array(["not an int"], type=pa.string()) - ctx.register_scalar_function("wrong_type", wrong_type, ["int64"], "int64") + ctx.register_udf( + udf(wrong_type, [pa.int64()], pa.int64(), "volatile", "wrong_type") + ) try: ctx.sql("SELECT wrong_type(id) AS id FROM paimon.default.my_temp") From e4a5c98b2ce389a5d324c28e977f0f4c58f88bd7 Mon Sep 17 00:00:00 2001 From: yantian Date: Wed, 20 May 2026 14:42:15 +0800 Subject: [PATCH 4/7] Rename Python UDF wrapper --- .../python/python/pypaimon_rust/datafusion.pyi | 8 ++++---- bindings/python/src/context.rs | 18 +++++++++--------- bindings/python/tests/test_datafusion.py | 8 ++++---- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/bindings/python/python/pypaimon_rust/datafusion.pyi b/bindings/python/python/pypaimon_rust/datafusion.pyi index 8f9c36ff..dfa63c7a 100644 --- a/bindings/python/python/pypaimon_rust/datafusion.pyi +++ b/bindings/python/python/pypaimon_rust/datafusion.pyi @@ -27,7 +27,7 @@ class PaimonCatalog: def __init__(self, catalog_options: Dict[str, str]) -> None: ... def __datafusion_catalog_provider__(self, session: Any) -> object: ... -class ScalarUDF: +class PythonScalarUDF: def __init__( self, name: str, @@ -43,7 +43,7 @@ class ScalarUDF: return_field: ArrowTypeLike, volatility: VolatilityLike, name: Optional[str] = None, - ) -> "ScalarUDF": ... + ) -> "PythonScalarUDF": ... @property def name(self) -> str: ... @@ -53,7 +53,7 @@ def udf( return_field: ArrowTypeLike, volatility: VolatilityLike, name: Optional[str] = None, -) -> ScalarUDF: +) -> PythonScalarUDF: """ Create a scalar UDF. @@ -72,5 +72,5 @@ class SQLContext: def set_current_catalog(self, catalog_name: str) -> None: ... def set_current_database(self, database_name: str) -> None: ... def register_batch(self, name: str, batch: pyarrow.RecordBatch) -> None: ... - def register_udf(self, udf: ScalarUDF) -> None: ... + def register_udf(self, udf: PythonScalarUDF) -> None: ... def sql(self, sql: str) -> List[pyarrow.RecordBatch]: ... diff --git a/bindings/python/src/context.rs b/bindings/python/src/context.rs index cc5b2bdc..63cae85a 100644 --- a/bindings/python/src/context.rs +++ b/bindings/python/src/context.rs @@ -191,13 +191,13 @@ impl PyScalarUDF { } } -#[pyclass(name = "ScalarUDF")] -pub struct PyScalarUDFObject { +#[pyclass(name = "PythonScalarUDF")] +pub struct PyPythonScalarUDFObject { name: String, udf: DFScalarUDF, } -impl PyScalarUDFObject { +impl PyPythonScalarUDFObject { fn create( py: Python<'_>, name: String, @@ -222,7 +222,7 @@ impl PyScalarUDFObject { } #[pymethods] -impl PyScalarUDFObject { +impl PyPythonScalarUDFObject { #[new] fn new( py: Python<'_>, @@ -258,7 +258,7 @@ impl PyScalarUDFObject { } fn __repr__(&self) -> String { - format!("ScalarUDF({})", self.name) + format!("PythonScalarUDF({})", self.name) } } @@ -271,8 +271,8 @@ fn udf( return_field: Bound<'_, PyAny>, volatility: Bound<'_, PyAny>, name: Option, -) -> PyResult { - PyScalarUDFObject::udf(py, func, input_fields, return_field, volatility, name) +) -> PyResult { + PyPythonScalarUDFObject::udf(py, func, input_fields, return_field, volatility, name) } impl Debug for PyScalarUDF { @@ -454,7 +454,7 @@ impl PySQLContext { .map_err(df_to_py_err) } - fn register_udf(&self, udf: &PyScalarUDFObject) -> PyResult<()> { + fn register_udf(&self, udf: &PyPythonScalarUDFObject) -> PyResult<()> { self.inner.ctx().register_udf(udf.udf.clone()); Ok(()) } @@ -477,7 +477,7 @@ impl PySQLContext { pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { let this = PyModule::new(py, "datafusion")?; this.add_class::()?; - this.add_class::()?; + this.add_class::()?; this.add_class::()?; this.add_function(wrap_pyfunction!(udf, &this)?)?; m.add_submodule(&this)?; diff --git a/bindings/python/tests/test_datafusion.py b/bindings/python/tests/test_datafusion.py index c11610eb..17af490a 100644 --- a/bindings/python/tests/test_datafusion.py +++ b/bindings/python/tests/test_datafusion.py @@ -21,7 +21,7 @@ import pyarrow as pa from datafusion import SessionContext -from pypaimon_rust.datafusion import PaimonCatalog, SQLContext, ScalarUDF, udf +from pypaimon_rust.datafusion import PaimonCatalog, PythonScalarUDF, SQLContext, udf WAREHOUSE = os.environ.get("PAIMON_TEST_WAREHOUSE", "/tmp/paimon-warehouse") @@ -171,16 +171,16 @@ def identity(values): assert "Expected a pyarrow.DataType" in str(e) -def test_scalar_udf_constructor_matches_datafusion_shape(): +def test_python_scalar_udf_constructor_matches_datafusion_shape(): def identity(values): return values - scalar_udf = ScalarUDF( + scalar_udf = PythonScalarUDF( "identity", identity, [pa.field("value", pa.int64())], pa.int64(), "stable" ) assert scalar_udf.name == "identity" - assert repr(scalar_udf) == "ScalarUDF(identity)" + assert repr(scalar_udf) == "PythonScalarUDF(identity)" def test_python_udf_exception_surfaces(): From b4f04fd173429f6628c242885b16bf66f8f5326d Mon Sep 17 00:00:00 2001 From: yantian Date: Wed, 20 May 2026 15:32:58 +0800 Subject: [PATCH 5/7] Add Python UDF registration helper --- .../python/python/pypaimon_rust/__init__.py | 1 + .../python/python/pypaimon_rust/functions.py | 33 +++++++++++++++++ .../python/python/pypaimon_rust/functions.pyi | 36 +++++++++++++++++++ bindings/python/tests/test_datafusion.py | 26 ++++++++++++++ 4 files changed, 96 insertions(+) create mode 100644 bindings/python/python/pypaimon_rust/functions.py create mode 100644 bindings/python/python/pypaimon_rust/functions.pyi diff --git a/bindings/python/python/pypaimon_rust/__init__.py b/bindings/python/python/pypaimon_rust/__init__.py index b36002b7..68573aeb 100644 --- a/bindings/python/python/pypaimon_rust/__init__.py +++ b/bindings/python/python/pypaimon_rust/__init__.py @@ -16,3 +16,4 @@ # under the License. from .pypaimon_rust import * +from .functions import register_python_udf diff --git a/bindings/python/python/pypaimon_rust/functions.py b/bindings/python/python/pypaimon_rust/functions.py new file mode 100644 index 00000000..cdc6786e --- /dev/null +++ b/bindings/python/python/pypaimon_rust/functions.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional + +from pypaimon_rust.datafusion import PythonScalarUDF, udf + + +def register_python_udf( + ctx, + func, + input_fields, + return_field, + volatility="volatile", + name: Optional[str] = None, +) -> PythonScalarUDF: + scalar_udf = udf(func, input_fields, return_field, volatility, name) + ctx.register_udf(scalar_udf) + return scalar_udf diff --git a/bindings/python/python/pypaimon_rust/functions.pyi b/bindings/python/python/pypaimon_rust/functions.pyi new file mode 100644 index 00000000..fb98b5d7 --- /dev/null +++ b/bindings/python/python/pypaimon_rust/functions.pyi @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Callable, Optional + +import pyarrow + +from pypaimon_rust.datafusion import ( + ArrowTypeLike, + InputFieldsLike, + PythonScalarUDF, + VolatilityLike, +) + +def register_python_udf( + ctx, + func: Callable[..., pyarrow.Array], + input_fields: InputFieldsLike, + return_field: ArrowTypeLike, + volatility: VolatilityLike = "volatile", + name: Optional[str] = None, +) -> PythonScalarUDF: ... diff --git a/bindings/python/tests/test_datafusion.py b/bindings/python/tests/test_datafusion.py index 17af490a..1469a537 100644 --- a/bindings/python/tests/test_datafusion.py +++ b/bindings/python/tests/test_datafusion.py @@ -21,6 +21,7 @@ import pyarrow as pa from datafusion import SessionContext +from pypaimon_rust.functions import register_python_udf from pypaimon_rust.datafusion import PaimonCatalog, PythonScalarUDF, SQLContext, udf WAREHOUSE = os.environ.get("PAIMON_TEST_WAREHOUSE", "/tmp/paimon-warehouse") @@ -125,6 +126,31 @@ def plus_ten(values): ctx.sql("DROP TEMPORARY TABLE paimon.default.my_temp") +def test_register_python_udf_builtin_helper(): + with tempfile.TemporaryDirectory() as warehouse: + ctx = SQLContext() + ctx.register_catalog("paimon", {"warehouse": warehouse}) + ctx.register_batch("my_temp", pa.record_batch([[1, 2]], names=["id"])) + + def plus_one(values): + return pa.array([value + 1 for value in values.to_pylist()], type=pa.int64()) + + scalar_udf = register_python_udf( + ctx, + plus_one, + [pa.int64()], + pa.int64(), + name="plus_one", + ) + + batches = ctx.sql("SELECT plus_one(id) AS id FROM paimon.default.my_temp") + table = pa.Table.from_batches(batches) + assert isinstance(scalar_udf, PythonScalarUDF) + assert table["id"].to_pylist() == [2, 3] + + ctx.sql("DROP TEMPORARY TABLE paimon.default.my_temp") + + def test_register_udf_multi_input_plan(): with tempfile.TemporaryDirectory() as warehouse: ctx = SQLContext() From 56d4712e1de4cb7b7a185955e9d002edf4e80401 Mon Sep 17 00:00:00 2001 From: yantian Date: Wed, 20 May 2026 15:38:15 +0800 Subject: [PATCH 6/7] Keep Python UDF API DataFusion-style --- .../python/python/pypaimon_rust/__init__.py | 1 - .../python/python/pypaimon_rust/functions.py | 33 ----------------- .../python/python/pypaimon_rust/functions.pyi | 36 ------------------- bindings/python/tests/test_datafusion.py | 26 -------------- 4 files changed, 96 deletions(-) delete mode 100644 bindings/python/python/pypaimon_rust/functions.py delete mode 100644 bindings/python/python/pypaimon_rust/functions.pyi diff --git a/bindings/python/python/pypaimon_rust/__init__.py b/bindings/python/python/pypaimon_rust/__init__.py index 68573aeb..b36002b7 100644 --- a/bindings/python/python/pypaimon_rust/__init__.py +++ b/bindings/python/python/pypaimon_rust/__init__.py @@ -16,4 +16,3 @@ # under the License. from .pypaimon_rust import * -from .functions import register_python_udf diff --git a/bindings/python/python/pypaimon_rust/functions.py b/bindings/python/python/pypaimon_rust/functions.py deleted file mode 100644 index cdc6786e..00000000 --- a/bindings/python/python/pypaimon_rust/functions.py +++ /dev/null @@ -1,33 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import Optional - -from pypaimon_rust.datafusion import PythonScalarUDF, udf - - -def register_python_udf( - ctx, - func, - input_fields, - return_field, - volatility="volatile", - name: Optional[str] = None, -) -> PythonScalarUDF: - scalar_udf = udf(func, input_fields, return_field, volatility, name) - ctx.register_udf(scalar_udf) - return scalar_udf diff --git a/bindings/python/python/pypaimon_rust/functions.pyi b/bindings/python/python/pypaimon_rust/functions.pyi deleted file mode 100644 index fb98b5d7..00000000 --- a/bindings/python/python/pypaimon_rust/functions.pyi +++ /dev/null @@ -1,36 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import Callable, Optional - -import pyarrow - -from pypaimon_rust.datafusion import ( - ArrowTypeLike, - InputFieldsLike, - PythonScalarUDF, - VolatilityLike, -) - -def register_python_udf( - ctx, - func: Callable[..., pyarrow.Array], - input_fields: InputFieldsLike, - return_field: ArrowTypeLike, - volatility: VolatilityLike = "volatile", - name: Optional[str] = None, -) -> PythonScalarUDF: ... diff --git a/bindings/python/tests/test_datafusion.py b/bindings/python/tests/test_datafusion.py index 1469a537..17af490a 100644 --- a/bindings/python/tests/test_datafusion.py +++ b/bindings/python/tests/test_datafusion.py @@ -21,7 +21,6 @@ import pyarrow as pa from datafusion import SessionContext -from pypaimon_rust.functions import register_python_udf from pypaimon_rust.datafusion import PaimonCatalog, PythonScalarUDF, SQLContext, udf WAREHOUSE = os.environ.get("PAIMON_TEST_WAREHOUSE", "/tmp/paimon-warehouse") @@ -126,31 +125,6 @@ def plus_ten(values): ctx.sql("DROP TEMPORARY TABLE paimon.default.my_temp") -def test_register_python_udf_builtin_helper(): - with tempfile.TemporaryDirectory() as warehouse: - ctx = SQLContext() - ctx.register_catalog("paimon", {"warehouse": warehouse}) - ctx.register_batch("my_temp", pa.record_batch([[1, 2]], names=["id"])) - - def plus_one(values): - return pa.array([value + 1 for value in values.to_pylist()], type=pa.int64()) - - scalar_udf = register_python_udf( - ctx, - plus_one, - [pa.int64()], - pa.int64(), - name="plus_one", - ) - - batches = ctx.sql("SELECT plus_one(id) AS id FROM paimon.default.my_temp") - table = pa.Table.from_batches(batches) - assert isinstance(scalar_udf, PythonScalarUDF) - assert table["id"].to_pylist() == [2, 3] - - ctx.sql("DROP TEMPORARY TABLE paimon.default.my_temp") - - def test_register_udf_multi_input_plan(): with tempfile.TemporaryDirectory() as warehouse: ctx = SQLContext() From 3b8a42c34d07ebbca608ba4402784edbe52412eb Mon Sep 17 00:00:00 2001 From: yantian Date: Wed, 20 May 2026 16:11:41 +0800 Subject: [PATCH 7/7] Add video_snapshot Python built-in UDF --- bindings/python/pyproject.toml | 8 + .../python/python/pypaimon_rust/functions.py | 179 ++++++++ .../python/python/pypaimon_rust/functions.pyi | 24 ++ bindings/python/src/blob.rs | 232 +++++++++++ bindings/python/src/context.rs | 383 +++--------------- bindings/python/src/lib.rs | 2 + bindings/python/src/udf.rs | 363 +++++++++++++++++ bindings/python/tests/test_datafusion.py | 340 +++++++++++++++- .../datafusion/src/blob_reader.rs | 101 +++++ crates/integrations/datafusion/src/catalog.rs | 22 +- crates/integrations/datafusion/src/lib.rs | 2 + .../datafusion/src/sql_context.rs | 19 +- .../integrations/datafusion/src/table/mod.rs | 10 + 13 files changed, 1355 insertions(+), 330 deletions(-) create mode 100644 bindings/python/python/pypaimon_rust/functions.py create mode 100644 bindings/python/python/pypaimon_rust/functions.pyi create mode 100644 bindings/python/src/blob.rs create mode 100644 bindings/python/src/udf.rs create mode 100644 crates/integrations/datafusion/src/blob_reader.rs diff --git a/bindings/python/pyproject.toml b/bindings/python/pyproject.toml index 7ff28d0f..cdefdfe1 100644 --- a/bindings/python/pyproject.toml +++ b/bindings/python/pyproject.toml @@ -39,6 +39,12 @@ classifiers = [ "Programming Language :: Rust", ] +[project.optional-dependencies] +video = [ + "av>=17.0,<18.0", + "pillow>=12.0,<13.0", +] + [tool.maturin] module-name = "pypaimon_rust.pypaimon_rust" python-source = "python" @@ -54,4 +60,6 @@ dev = [ "pytest>=8.0", "pyarrow>=17.0,<24.0", "datafusion==53.0.0", + "av>=17.0,<18.0", + "pillow>=12.0,<13.0", ] diff --git a/bindings/python/python/pypaimon_rust/functions.py b/bindings/python/python/pypaimon_rust/functions.py new file mode 100644 index 00000000..9a3b4cf0 --- /dev/null +++ b/bindings/python/python/pypaimon_rust/functions.py @@ -0,0 +1,179 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import io +import logging +import struct +from typing import Any, BinaryIO + +logger = logging.getLogger(__name__) +_STILL_IMAGE_FORMATS = { + "apng", + "bmp_pipe", + "gif", + "ico", + "image2", + "image2pipe", + "jpeg_pipe", + "png_pipe", + "tiff_pipe", + "webp_pipe", +} + + +class _BlobDescriptorProbe: + CURRENT_VERSION = 2 + MAGIC = 0x424C4F4244455343 + + @classmethod + def is_blob_descriptor(cls, data: Any) -> bool: + if not isinstance(data, (bytes, bytearray, memoryview)): + return False + raw = bytes(data) + if len(raw) < 9: + return False + + version = raw[0] + # Version 1 has no magic header, so it cannot be distinguished safely + # from arbitrary inline video bytes in this heuristic. + if version == 1 or version > cls.CURRENT_VERSION: + return False + + try: + return struct.unpack(" bool: + return _BlobDescriptorProbe.is_blob_descriptor(data) + + +def open_blob_descriptor_stream( + raw_value: bytes, + blob_reader_registry=None, +) -> BinaryIO: + if blob_reader_registry is not None: + stream = blob_reader_registry.open_blob_descriptor_stream(raw_value) + if stream is not None: + return stream + + if _BlobDescriptorProbe.is_blob_descriptor(raw_value): + raise RuntimeError( + "BlobDescriptor input requires a registered Paimon table FileIO" + ) + return io.BytesIO(bytes(raw_value)) + + +def _decode_video_snapshot( + stream: BinaryIO, + image_format: str, + timestamp_ms: int = 0, +) -> bytes | None: + try: + import av + except ImportError as e: + raise ImportError("PyAV is required to decode video snapshots") from e + + with av.open(stream, mode="r") as container: + format_names = set((container.format.name or "").split(",")) + if format_names & _STILL_IMAGE_FORMATS: + logger.debug( + "video_snapshot input is a still image format: %s", + container.format.name, + ) + return None + if not container.streams.video: + return None + + target_seconds = timestamp_ms / 1000 + if timestamp_ms > 0: + container.seek(timestamp_ms * 1000, backward=True, any_frame=False) + + candidate = None + for frame in container.decode(video=0): + if ( + timestamp_ms > 0 + and frame.time is not None + and frame.time < target_seconds + ): + candidate = frame + continue + candidate = frame + break + + if candidate is not None: + try: + image = candidate.to_image() + except ImportError as e: + raise ImportError( + "Pillow is required to encode video_snapshot images" + ) from e + output = io.BytesIO() + image.save(output, format=image_format) + return output.getvalue() + return None + + +def _make_video_snapshot(image_format: str = "PNG", blob_reader_registry=None): + image_format = image_format.upper() + + def video_snapshot(values, timestamps_ms=None): + try: + import pyarrow as pa + except ImportError as e: + raise ImportError("pyarrow is required to return video_snapshot results") from e + + frames = [] + raw_values = values.to_pylist() + if timestamps_ms is None: + timestamp_values = [0] * len(raw_values) + else: + timestamp_values = timestamps_ms.to_pylist() + if len(timestamp_values) != len(raw_values): + raise ValueError( + "video_snapshot timestamp argument must have the same row count" + ) + + # v1 intentionally decodes rows serially; callers should filter or limit + # large scans before applying video_snapshot. + for raw_value, timestamp_ms in zip(raw_values, timestamp_values): + if raw_value is None or timestamp_ms is None: + frames.append(None) + continue + + try: + timestamp_ms = int(timestamp_ms) + if timestamp_ms < 0: + frames.append(None) + continue + stream = open_blob_descriptor_stream(raw_value, blob_reader_registry) + try: + frames.append( + _decode_video_snapshot(stream, image_format, timestamp_ms) + ) + finally: + stream.close() + except ImportError: + raise + except Exception as e: + logger.warning("Failed to decode video snapshot: %s", e) + frames.append(None) + + return pa.array(frames, type=pa.binary()) + + return video_snapshot diff --git a/bindings/python/python/pypaimon_rust/functions.pyi b/bindings/python/python/pypaimon_rust/functions.pyi new file mode 100644 index 00000000..aa91ba0a --- /dev/null +++ b/bindings/python/python/pypaimon_rust/functions.pyi @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, BinaryIO + +def is_blob_descriptor(data: Any) -> bool: ... +def open_blob_descriptor_stream( + raw_value: bytes, + blob_reader_registry: Any | None = None, +) -> BinaryIO: ... diff --git a/bindings/python/src/blob.rs b/bindings/python/src/blob.rs new file mode 100644 index 00000000..9e4ff2d6 --- /dev/null +++ b/bindings/python/src/blob.rs @@ -0,0 +1,232 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::future::Future; + +use paimon::io::{FileIO, FileRead}; +use paimon::spec::BlobDescriptor; +use paimon_datafusion::runtime::runtime; +use paimon_datafusion::BlobReaderRegistry; +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use pyo3::types::PyBytes; + +use crate::error::to_py_err; + +fn block_on_runtime(future: F, panic_error: &'static str) -> F::Output +where + F: Future + Send + 'static, + F::Output: Send + 'static, +{ + if tokio::runtime::Handle::try_current().is_ok() { + let handle = runtime(); + std::thread::spawn(move || handle.block_on(future)) + .join() + .expect(panic_error) + } else { + runtime().block_on(future) + } +} + +#[pyclass(name = "BlobReaderRegistry", skip_from_py_object)] +#[derive(Clone)] +pub(crate) struct PyBlobReaderRegistry { + inner: BlobReaderRegistry, +} + +impl PyBlobReaderRegistry { + pub(crate) fn new(inner: BlobReaderRegistry) -> Self { + Self { inner } + } +} + +#[pymethods] +impl PyBlobReaderRegistry { + fn open_blob_descriptor_stream(&self, raw_value: &[u8]) -> PyResult> { + if !BlobDescriptor::is_blob_descriptor(raw_value) { + return Ok(None); + } + + let descriptor = BlobDescriptor::deserialize(raw_value).map_err(to_py_err)?; + let Some(file_io) = self.inner.resolve(descriptor.uri()) else { + return Ok(None); + }; + + Ok(Some(PyBlobInputStream::new(file_io, descriptor)?)) + } +} + +#[pyclass(name = "BlobInputStream")] +struct PyBlobInputStream { + file_io: FileIO, + uri: String, + offset: u64, + length: Option, + position: u64, + closed: bool, +} + +impl PyBlobInputStream { + fn new(file_io: FileIO, descriptor: BlobDescriptor) -> PyResult { + if descriptor.offset() < 0 { + return Err(PyValueError::new_err(format!( + "BlobDescriptor has negative offset: {}", + descriptor.offset() + ))); + } + if descriptor.length() < -1 { + return Err(PyValueError::new_err(format!( + "BlobDescriptor has invalid length: {}", + descriptor.length() + ))); + } + + Ok(Self { + file_io, + uri: descriptor.uri().to_string(), + offset: descriptor.offset() as u64, + length: (descriptor.length() >= 0).then_some(descriptor.length() as u64), + position: 0, + closed: false, + }) + } + + fn ensure_open(&self) -> PyResult<()> { + if self.closed { + Err(PyValueError::new_err("I/O operation on closed file.")) + } else { + Ok(()) + } + } + + fn stream_length(&self, py: Python<'_>) -> PyResult { + if let Some(length) = self.length { + return Ok(length); + } + + let file_io = self.file_io.clone(); + let uri = self.uri.clone(); + let offset = self.offset; + py.detach(|| { + block_on_runtime( + async move { + let input = file_io.new_input(&uri).map_err(to_py_err)?; + let metadata = input.metadata().await.map_err(to_py_err)?; + Ok(metadata.size.saturating_sub(offset)) + }, + "paimon blob metadata read thread panicked", + ) + }) + } + + fn read_bytes(&mut self, py: Python<'_>, size: isize) -> PyResult> { + self.ensure_open()?; + let stream_length = self.stream_length(py)?; + let remaining = stream_length.saturating_sub(self.position); + if remaining == 0 || size == 0 { + return Ok(Vec::new()); + } + + let to_read = if size < 0 { + remaining + } else { + remaining.min(size as u64) + }; + let start = self.offset + self.position; + let end = start + to_read; + let file_io = self.file_io.clone(); + let uri = self.uri.clone(); + let bytes = py.detach(|| { + block_on_runtime( + async move { + let input = file_io.new_input(&uri).map_err(to_py_err)?; + let reader = input.reader().await.map_err(to_py_err)?; + let bytes = reader.read(start..end).await.map_err(to_py_err)?; + Ok::<_, PyErr>(bytes.to_vec()) + }, + "paimon blob range read thread panicked", + ) + })?; + self.position += bytes.len() as u64; + Ok(bytes) + } +} + +#[pymethods] +impl PyBlobInputStream { + fn readable(&self) -> bool { + true + } + + fn seekable(&self) -> bool { + true + } + + fn tell(&self) -> u64 { + self.position + } + + #[getter] + fn closed(&self) -> bool { + self.closed + } + + fn __enter__(slf: PyRefMut<'_, Self>) -> PyRefMut<'_, Self> { + slf + } + + fn __exit__( + &mut self, + _exc_type: &Bound<'_, PyAny>, + _exc: &Bound<'_, PyAny>, + _traceback: &Bound<'_, PyAny>, + ) -> bool { + self.close(); + false + } + + #[pyo3(signature = (size = -1))] + fn read<'py>(&mut self, py: Python<'py>, size: isize) -> PyResult> { + let bytes = self.read_bytes(py, size)?; + Ok(PyBytes::new(py, &bytes)) + } + + #[pyo3(signature = (pos, whence = 0))] + fn seek(&mut self, py: Python<'_>, pos: i64, whence: i32) -> PyResult { + self.ensure_open()?; + let base = match whence { + 0 => 0, + 1 => self.position as i64, + 2 => self.stream_length(py)? as i64, + other => return Err(PyValueError::new_err(format!("Invalid whence: {other}"))), + }; + let target = base + .checked_add(pos) + .ok_or_else(|| PyValueError::new_err("Seek position overflow"))?; + if target < 0 { + return Err(PyValueError::new_err(format!( + "Negative seek position: {target}" + ))); + } + self.position = target as u64; + Ok(self.position) + } + + fn close(&mut self) { + self.closed = true; + } +} diff --git a/bindings/python/src/context.rs b/bindings/python/src/context.rs index 63cae85a..f65d6a1a 100644 --- a/bindings/python/src/context.rs +++ b/bindings/python/src/context.rs @@ -15,30 +15,24 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::collections::HashMap; -use std::fmt::{self, Debug}; -use std::hash::{Hash, Hasher}; use std::sync::Arc; -use arrow::array::{make_array, Array, ArrayData, ArrayRef}; -use arrow::datatypes::{DataType as ArrowDataType, Field as ArrowField}; +use arrow::datatypes::DataType as ArrowDataType; use arrow::pyarrow::{FromPyArrow, ToPyArrow}; use datafusion::catalog::CatalogProvider; -use datafusion::common::{DataFusionError, Result as DFResult}; -use datafusion::logical_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDF as DFScalarUDF, ScalarUDFImpl, Signature, - Volatility, -}; +use datafusion::logical_expr::{Signature, TypeSignature, Volatility}; use datafusion_ffi::catalog_provider::FFI_CatalogProvider; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use paimon::{CatalogFactory, Options}; use paimon_datafusion::{PaimonCatalogProvider, SQLContext}; -use pyo3::exceptions::PyTypeError; +use pyo3::exceptions::PyRuntimeWarning; use pyo3::prelude::*; -use pyo3::types::{PyCapsule, PyList, PyTuple}; +use pyo3::types::PyCapsule; +use crate::blob::PyBlobReaderRegistry; use crate::error::{df_to_py_err, to_py_err}; +use crate::udf::{build_python_scalar_udf, udf, PyPythonScalarUDFObject}; use paimon_datafusion::runtime::runtime; fn build_paimon_catalog_provider( @@ -68,301 +62,6 @@ fn ffi_logical_codec_from_pycapsule(obj: Bound<'_, PyAny>) -> PyResult PyResult { - match type_name.to_ascii_lowercase().as_str() { - "bool" | "boolean" => Ok(ArrowDataType::Boolean), - "int8" => Ok(ArrowDataType::Int8), - "int16" => Ok(ArrowDataType::Int16), - "int" | "int32" | "integer" => Ok(ArrowDataType::Int32), - "bigint" | "int64" | "long" => Ok(ArrowDataType::Int64), - "float" | "float32" => Ok(ArrowDataType::Float32), - "double" | "float64" => Ok(ArrowDataType::Float64), - "string" | "utf8" => Ok(ArrowDataType::Utf8), - "large_string" | "large_utf8" => Ok(ArrowDataType::LargeUtf8), - "binary" => Ok(ArrowDataType::Binary), - "large_binary" => Ok(ArrowDataType::LargeBinary), - other => Err(PyTypeError::new_err(format!( - "Unsupported Arrow type for Python UDF: {other}" - ))), - } -} - -fn parse_arrow_type_like(value: &Bound<'_, PyAny>) -> PyResult { - if let Ok(field) = ArrowField::from_pyarrow_bound(value) { - return Ok(field.data_type().clone()); - } - if let Ok(data_type) = ArrowDataType::from_pyarrow_bound(value) { - return Ok(data_type); - } - if let Ok(type_name) = value.extract::() { - return parse_arrow_type(&type_name); - } - - Err(PyTypeError::new_err( - "Expected a pyarrow.DataType, pyarrow.Field, or supported Arrow type name", - )) -} - -fn parse_input_types(input_fields: &Bound<'_, PyAny>) -> PyResult> { - if let Ok(fields) = input_fields.cast::() { - return fields - .iter() - .map(|field| parse_arrow_type_like(&field)) - .collect(); - } - if let Ok(fields) = input_fields.cast::() { - return fields - .iter() - .map(|field| parse_arrow_type_like(&field)) - .collect(); - } - - Ok(vec![parse_arrow_type_like(input_fields)?]) -} - -fn parse_volatility(volatility: &Bound<'_, PyAny>) -> PyResult { - let value = if let Ok(value) = volatility.extract::() { - value - } else if let Ok(name) = volatility.getattr("name") { - name.extract::()? - } else { - volatility.str()?.to_str()?.to_string() - }; - - match value.to_ascii_lowercase().as_str() { - "immutable" => Ok(Volatility::Immutable), - "stable" => Ok(Volatility::Stable), - "volatile" => Ok(Volatility::Volatile), - other => Err(PyTypeError::new_err(format!( - "Unsupported UDF volatility: {other}. Expected immutable, stable, or volatile" - ))), - } -} - -fn default_udf_name(py: Python<'_>, func: &Py) -> PyResult { - let func = func.bind(py); - if let Ok(name) = func.getattr("__qualname__") { - return Ok(name.extract::()?.to_ascii_lowercase()); - } - Ok(func - .getattr("__class__")? - .getattr("__name__")? - .extract::()? - .to_ascii_lowercase()) -} - -fn df_execution_error(message: impl Into) -> DataFusionError { - DataFusionError::Execution(message.into()) -} - -fn columnar_value_to_array(value: &ColumnarValue, num_rows: usize) -> DFResult { - match value { - ColumnarValue::Array(array) => Ok(Arc::clone(array)), - ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows), - } -} - -struct PyScalarUDF { - name: String, - func: Py, - input_types: Vec, - return_type: ArrowDataType, - volatility: Volatility, - signature: Signature, -} - -impl PyScalarUDF { - fn new( - name: String, - func: Py, - input_types: Vec, - return_type: ArrowDataType, - volatility: Volatility, - ) -> Self { - let signature = Signature::exact(input_types.clone(), volatility); - Self { - name, - func, - input_types, - return_type, - volatility, - signature, - } - } -} - -#[pyclass(name = "PythonScalarUDF")] -pub struct PyPythonScalarUDFObject { - name: String, - udf: DFScalarUDF, -} - -impl PyPythonScalarUDFObject { - fn create( - py: Python<'_>, - name: String, - func: Py, - input_fields: &Bound<'_, PyAny>, - return_field: &Bound<'_, PyAny>, - volatility: &Bound<'_, PyAny>, - ) -> PyResult { - if !func.bind(py).is_callable() { - return Err(PyTypeError::new_err("`func` argument must be callable")); - } - - let input_types = parse_input_types(input_fields)?; - let return_type = parse_arrow_type_like(return_field)?; - let volatility = parse_volatility(volatility)?; - let udf = PyScalarUDF::new(name.clone(), func, input_types, return_type, volatility); - Ok(Self { - name, - udf: DFScalarUDF::new_from_impl(udf), - }) - } -} - -#[pymethods] -impl PyPythonScalarUDFObject { - #[new] - fn new( - py: Python<'_>, - name: String, - func: Py, - input_fields: Bound<'_, PyAny>, - return_field: Bound<'_, PyAny>, - volatility: Bound<'_, PyAny>, - ) -> PyResult { - Self::create(py, name, func, &input_fields, &return_field, &volatility) - } - - #[staticmethod] - #[pyo3(signature = (func, input_fields, return_field, volatility, name = None))] - fn udf( - py: Python<'_>, - func: Py, - input_fields: Bound<'_, PyAny>, - return_field: Bound<'_, PyAny>, - volatility: Bound<'_, PyAny>, - name: Option, - ) -> PyResult { - let name = match name { - Some(name) => name, - None => default_udf_name(py, &func)?, - }; - Self::create(py, name, func, &input_fields, &return_field, &volatility) - } - - #[getter] - fn name(&self) -> &str { - &self.name - } - - fn __repr__(&self) -> String { - format!("PythonScalarUDF({})", self.name) - } -} - -#[pyfunction] -#[pyo3(signature = (func, input_fields, return_field, volatility, name = None))] -fn udf( - py: Python<'_>, - func: Py, - input_fields: Bound<'_, PyAny>, - return_field: Bound<'_, PyAny>, - volatility: Bound<'_, PyAny>, - name: Option, -) -> PyResult { - PyPythonScalarUDFObject::udf(py, func, input_fields, return_field, volatility, name) -} - -impl Debug for PyScalarUDF { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("PyScalarUDF") - .field("name", &self.name) - .field("input_types", &self.input_types) - .field("return_type", &self.return_type) - .field("volatility", &self.volatility) - .finish_non_exhaustive() - } -} - -impl PartialEq for PyScalarUDF { - fn eq(&self, other: &Self) -> bool { - self.name == other.name - && self.input_types == other.input_types - && self.return_type == other.return_type - && self.volatility == other.volatility - } -} - -impl Eq for PyScalarUDF {} - -impl Hash for PyScalarUDF { - fn hash(&self, state: &mut H) { - self.name.hash(state); - self.input_types.hash(state); - self.return_type.hash(state); - self.volatility.hash(state); - } -} - -impl ScalarUDFImpl for PyScalarUDF { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - &self.name - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _arg_types: &[ArrowDataType]) -> DFResult { - Ok(self.return_type.clone()) - } - - fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult { - let arrays = args - .args - .iter() - .map(|value| columnar_value_to_array(value, args.number_rows)) - .collect::>>()?; - - let output = Python::try_attach(|py| -> PyResult { - let py_args = arrays - .iter() - .map(|array| array.to_data().to_pyarrow(py)) - .collect::>>()?; - let py_args = PyTuple::new(py, py_args)?; - let output = self.func.bind(py).call1(py_args)?; - Ok(make_array(ArrayData::from_pyarrow_bound(&output)?)) - }) - .ok_or_else(|| df_execution_error("Python interpreter is not available"))? - .map_err(|err| df_execution_error(format!("Python UDF '{}' failed: {err}", self.name)))?; - - if output.len() != args.number_rows { - return Err(df_execution_error(format!( - "Python UDF '{}' returned {} rows, expected {}", - self.name, - output.len(), - args.number_rows - ))); - } - if output.data_type() != &self.return_type { - return Err(df_execution_error(format!( - "Python UDF '{}' returned {:?}, expected {:?}", - self.name, - output.data_type(), - self.return_type - ))); - } - - Ok(ColumnarValue::Array(output)) - } -} - /// A Paimon catalog exportable to Python DataFusion `SessionContext`. #[pyclass(name = "PaimonCatalog")] pub struct PaimonCatalog { @@ -399,28 +98,78 @@ pub struct PySQLContext { inner: SQLContext, } +impl PySQLContext { + fn register_video_snapshot_builtin(&self, py: Python<'_>) -> PyResult<()> { + let functions = py.import("pypaimon_rust.functions")?; + let blob_reader_registry = Py::new( + py, + PyBlobReaderRegistry::new(self.inner.blob_reader_registry()), + )?; + let func = functions + .getattr("_make_video_snapshot")? + .call1(("PNG", blob_reader_registry))? + .unbind(); + let signature = Signature::one_of( + vec![ + TypeSignature::Exact(vec![ArrowDataType::Binary]), + TypeSignature::Exact(vec![ArrowDataType::Binary, ArrowDataType::Int32]), + TypeSignature::Exact(vec![ArrowDataType::Binary, ArrowDataType::Int64]), + ], + Volatility::Volatile, + ); + let udf = build_python_scalar_udf( + "video_snapshot".to_string(), + func, + ArrowDataType::Binary, + signature, + ); + self.inner.ctx().register_udf(udf); + Ok(()) + } + + fn warn_video_snapshot_registration_failure(py: Python<'_>, err: PyErr) { + if let Ok(warnings) = py.import("warnings") { + let category = py.get_type::(); + let _ = warnings.call_method1( + "warn", + ( + format!("video_snapshot built-in could not be registered: {err}"), + category, + ), + ); + } + } +} + #[pymethods] impl PySQLContext { #[new] - fn new() -> Self { - Self { + fn new(py: Python<'_>) -> PyResult { + let ctx = Self { inner: SQLContext::new(), + }; + if let Err(err) = ctx.register_video_snapshot_builtin(py) { + Self::warn_video_snapshot_registration_failure(py, err); } + Ok(ctx) } fn register_catalog( &mut self, + py: Python<'_>, catalog_name: String, catalog_options: HashMap, ) -> PyResult<()> { let rt = runtime(); - rt.block_on(async { - let options = Options::from_map(catalog_options); - let catalog = CatalogFactory::create(options).await.map_err(to_py_err)?; - self.inner - .register_catalog(catalog_name, catalog) - .await - .map_err(df_to_py_err) + py.detach(|| { + rt.block_on(async { + let options = Options::from_map(catalog_options); + let catalog = CatalogFactory::create(options).await.map_err(to_py_err)?; + self.inner + .register_catalog(catalog_name, catalog) + .await + .map_err(df_to_py_err) + }) }) } @@ -455,7 +204,7 @@ impl PySQLContext { } fn register_udf(&self, udf: &PyPythonScalarUDFObject) -> PyResult<()> { - self.inner.ctx().register_udf(udf.udf.clone()); + self.inner.ctx().register_udf(udf.datafusion_udf()); Ok(()) } diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 326796d0..5f8d17a4 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -17,8 +17,10 @@ use pyo3::prelude::*; +mod blob; mod context; mod error; +mod udf; #[pymodule] fn pypaimon_rust(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { diff --git a/bindings/python/src/udf.rs b/bindings/python/src/udf.rs new file mode 100644 index 00000000..36340289 --- /dev/null +++ b/bindings/python/src/udf.rs @@ -0,0 +1,363 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::fmt::{self, Debug}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use arrow::array::{make_array, Array, ArrayData, ArrayRef}; +use arrow::datatypes::{DataType as ArrowDataType, Field as ArrowField}; +use arrow::pyarrow::{FromPyArrow, ToPyArrow}; +use datafusion::common::{DataFusionError, Result as DFResult}; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDF as DFScalarUDF, ScalarUDFImpl, Signature, + Volatility, +}; +use pyo3::exceptions::PyTypeError; +use pyo3::prelude::*; +use pyo3::types::{PyList, PyTuple}; + +fn parse_arrow_type(type_name: &str) -> PyResult { + match type_name.to_ascii_lowercase().as_str() { + "bool" | "boolean" => Ok(ArrowDataType::Boolean), + "int8" => Ok(ArrowDataType::Int8), + "int16" => Ok(ArrowDataType::Int16), + "int" | "int32" | "integer" => Ok(ArrowDataType::Int32), + "bigint" | "int64" | "long" => Ok(ArrowDataType::Int64), + "float" | "float32" => Ok(ArrowDataType::Float32), + "double" | "float64" => Ok(ArrowDataType::Float64), + "string" | "utf8" => Ok(ArrowDataType::Utf8), + "large_string" | "large_utf8" => Ok(ArrowDataType::LargeUtf8), + "binary" => Ok(ArrowDataType::Binary), + "large_binary" => Ok(ArrowDataType::LargeBinary), + other => Err(PyTypeError::new_err(format!( + "Unsupported Arrow type for Python UDF: {other}" + ))), + } +} + +fn parse_arrow_type_like(value: &Bound<'_, PyAny>) -> PyResult { + if let Ok(field) = ArrowField::from_pyarrow_bound(value) { + return Ok(field.data_type().clone()); + } + if let Ok(data_type) = ArrowDataType::from_pyarrow_bound(value) { + return Ok(data_type); + } + if let Ok(type_name) = value.extract::() { + return parse_arrow_type(&type_name); + } + + Err(PyTypeError::new_err( + "Expected a pyarrow.DataType, pyarrow.Field, or supported Arrow type name", + )) +} + +fn parse_input_types(input_fields: &Bound<'_, PyAny>) -> PyResult> { + if let Ok(fields) = input_fields.cast::() { + return fields + .iter() + .map(|field| parse_arrow_type_like(&field)) + .collect(); + } + if let Ok(fields) = input_fields.cast::() { + return fields + .iter() + .map(|field| parse_arrow_type_like(&field)) + .collect(); + } + + Ok(vec![parse_arrow_type_like(input_fields)?]) +} + +fn parse_volatility(volatility: &Bound<'_, PyAny>) -> PyResult { + let value = if let Ok(value) = volatility.extract::() { + value + } else if let Ok(name) = volatility.getattr("name") { + name.extract::()? + } else { + volatility.str()?.to_str()?.to_string() + }; + + match value.to_ascii_lowercase().as_str() { + "immutable" => Ok(Volatility::Immutable), + "stable" => Ok(Volatility::Stable), + "volatile" => Ok(Volatility::Volatile), + other => Err(PyTypeError::new_err(format!( + "Unsupported UDF volatility: {other}. Expected immutable, stable, or volatile" + ))), + } +} + +fn sanitize_udf_name(name: &str) -> String { + let mut sanitized = name + .chars() + .map(|ch| { + if ch.is_ascii_alphanumeric() || ch == '_' { + ch.to_ascii_lowercase() + } else { + '_' + } + }) + .collect::() + .trim_matches('_') + .to_string(); + + if sanitized.is_empty() { + sanitized.push_str("python_udf"); + } + if sanitized + .chars() + .next() + .is_some_and(|ch| ch.is_ascii_digit()) + { + sanitized.insert(0, '_'); + } + sanitized +} + +fn default_udf_name(py: Python<'_>, func: &Py) -> PyResult { + let func = func.bind(py); + if let Ok(name) = func.getattr("__name__") { + return Ok(sanitize_udf_name(&name.extract::()?)); + } + if let Ok(name) = func.getattr("__qualname__") { + return Ok(sanitize_udf_name(&name.extract::()?)); + } + let name = func + .getattr("__class__")? + .getattr("__name__")? + .extract::()?; + Ok(sanitize_udf_name(&name)) +} + +fn df_execution_error(message: impl Into) -> DataFusionError { + DataFusionError::Execution(message.into()) +} + +fn columnar_value_to_array(value: &ColumnarValue, num_rows: usize) -> DFResult { + match value { + ColumnarValue::Array(array) => Ok(Arc::clone(array)), + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows), + } +} + +struct PyScalarUDF { + name: String, + func: Py, + return_type: ArrowDataType, + signature: Signature, +} + +impl PyScalarUDF { + fn new( + name: String, + func: Py, + return_type: ArrowDataType, + signature: Signature, + ) -> Self { + Self { + name, + func, + return_type, + signature, + } + } +} + +pub(crate) fn build_python_scalar_udf( + name: String, + func: Py, + return_type: ArrowDataType, + signature: Signature, +) -> DFScalarUDF { + DFScalarUDF::new_from_impl(PyScalarUDF::new(name, func, return_type, signature)) +} + +#[pyclass(name = "PythonScalarUDF")] +pub struct PyPythonScalarUDFObject { + name: String, + udf: DFScalarUDF, +} + +impl PyPythonScalarUDFObject { + fn create( + py: Python<'_>, + name: String, + func: Py, + input_fields: &Bound<'_, PyAny>, + return_field: &Bound<'_, PyAny>, + volatility: &Bound<'_, PyAny>, + ) -> PyResult { + if !func.bind(py).is_callable() { + return Err(PyTypeError::new_err("`func` argument must be callable")); + } + + let input_types = parse_input_types(input_fields)?; + let return_type = parse_arrow_type_like(return_field)?; + let volatility = parse_volatility(volatility)?; + let signature = Signature::exact(input_types, volatility); + let udf = PyScalarUDF::new(name.clone(), func, return_type, signature); + Ok(Self { + name, + udf: DFScalarUDF::new_from_impl(udf), + }) + } + + pub(crate) fn datafusion_udf(&self) -> DFScalarUDF { + self.udf.clone() + } +} + +#[pymethods] +impl PyPythonScalarUDFObject { + #[new] + fn new( + py: Python<'_>, + name: String, + func: Py, + input_fields: Bound<'_, PyAny>, + return_field: Bound<'_, PyAny>, + volatility: Bound<'_, PyAny>, + ) -> PyResult { + Self::create(py, name, func, &input_fields, &return_field, &volatility) + } + + #[staticmethod] + #[pyo3(signature = (func, input_fields, return_field, volatility, name = None))] + fn udf( + py: Python<'_>, + func: Py, + input_fields: Bound<'_, PyAny>, + return_field: Bound<'_, PyAny>, + volatility: Bound<'_, PyAny>, + name: Option, + ) -> PyResult { + let name = match name { + Some(name) => name, + None => default_udf_name(py, &func)?, + }; + Self::create(py, name, func, &input_fields, &return_field, &volatility) + } + + #[getter] + fn name(&self) -> &str { + &self.name + } + + fn __repr__(&self) -> String { + format!("PythonScalarUDF({})", self.name) + } +} + +#[pyfunction] +#[pyo3(signature = (func, input_fields, return_field, volatility, name = None))] +pub(crate) fn udf( + py: Python<'_>, + func: Py, + input_fields: Bound<'_, PyAny>, + return_field: Bound<'_, PyAny>, + volatility: Bound<'_, PyAny>, + name: Option, +) -> PyResult { + PyPythonScalarUDFObject::udf(py, func, input_fields, return_field, volatility, name) +} + +impl Debug for PyScalarUDF { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PyScalarUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("return_type", &self.return_type) + .finish_non_exhaustive() + } +} + +impl PartialEq for PyScalarUDF { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.return_type == other.return_type + && self.signature == other.signature + } +} + +impl Eq for PyScalarUDF {} + +impl Hash for PyScalarUDF { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.return_type.hash(state); + self.signature.hash(state); + } +} + +impl ScalarUDFImpl for PyScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[ArrowDataType]) -> DFResult { + Ok(self.return_type.clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult { + let arrays = args + .args + .iter() + .map(|value| columnar_value_to_array(value, args.number_rows)) + .collect::>>()?; + + let output = Python::try_attach(|py| -> PyResult { + let py_args = arrays + .iter() + .map(|array| array.to_data().to_pyarrow(py)) + .collect::>>()?; + let py_args = PyTuple::new(py, py_args)?; + let output = self.func.bind(py).call1(py_args)?; + Ok(make_array(ArrayData::from_pyarrow_bound(&output)?)) + }) + .ok_or_else(|| df_execution_error("Python interpreter is not available"))? + .map_err(|err| df_execution_error(format!("Python UDF '{}' failed: {err}", self.name)))?; + + if output.len() != args.number_rows { + return Err(df_execution_error(format!( + "Python UDF '{}' returned {} rows, expected {}", + self.name, + output.len(), + args.number_rows + ))); + } + if output.data_type() != &self.return_type { + return Err(df_execution_error(format!( + "Python UDF '{}' returned {:?}, expected {:?}", + self.name, + output.data_type(), + self.return_type + ))); + } + + Ok(ColumnarValue::Array(output)) + } +} diff --git a/bindings/python/tests/test_datafusion.py b/bindings/python/tests/test_datafusion.py index 17af490a..d120a945 100644 --- a/bindings/python/tests/test_datafusion.py +++ b/bindings/python/tests/test_datafusion.py @@ -15,15 +15,63 @@ # specific language governing permissions and limitations # under the License. +import io import os +import struct +import sys import tempfile +import types +from pathlib import Path import pyarrow as pa +import pytest from datafusion import SessionContext from pypaimon_rust.datafusion import PaimonCatalog, PythonScalarUDF, SQLContext, udf WAREHOUSE = os.environ.get("PAIMON_TEST_WAREHOUSE", "/tmp/paimon-warehouse") +PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n" +BLOB_DESCRIPTOR_MAGIC = 0x424C4F4244455343 + + +def serialize_blob_descriptor(uri: str, offset: int, length: int) -> bytes: + uri_bytes = uri.encode("utf-8") + return ( + struct.pack(" None: + av = pytest.importorskip("av") + image_module = pytest.importorskip("PIL.Image") + + with av.open(str(path), mode="w") as container: + stream = container.add_stream("mpeg4", rate=1) + stream.width = 32 + stream.height = 32 + stream.pix_fmt = "yuv420p" + + for color in colors: + image = image_module.new("RGB", (32, 32), color=color) + frame = av.VideoFrame.from_image(image) + for packet in stream.encode(frame): + container.mux(packet) + for packet in stream.encode(): + container.mux(packet) + + +def sample_image_bytes() -> bytes: + image_module = pytest.importorskip("PIL.Image") + + output = io.BytesIO() + image = image_module.new("RGB", (32, 32), color=(40, 120, 220)) + image.save(output, format="PNG") + return output.getvalue() def extract_rows(batches): @@ -31,6 +79,207 @@ def extract_rows(batches): return sorted(zip(table["id"].to_pylist(), table["name"].to_pylist())) +def test_video_snapshot_builtin_registered_on_context_init(): + ctx = SQLContext() + + batches = ctx.sql("SELECT video_snapshot(CAST(NULL AS BYTEA)) AS cover_png") + table = pa.Table.from_batches(batches) + + assert table["cover_png"].to_pylist() == [None] + + +def test_sql_context_survives_video_snapshot_registration_failure(monkeypatch): + monkeypatch.setitem( + sys.modules, + "pypaimon_rust.functions", + types.SimpleNamespace(), + ) + + with pytest.warns( + RuntimeWarning, + match="video_snapshot built-in could not be registered", + ): + ctx = SQLContext() + + batches = ctx.sql("SELECT 1 AS value") + table = pa.Table.from_batches(batches) + assert table["value"].to_pylist() == [1] + + +def test_video_snapshot_builtin_auto_registered_for_sql(): + with tempfile.TemporaryDirectory() as warehouse: + video_path = Path(warehouse) / "sample.mp4" + write_sample_video(video_path) + video_bytes = video_path.read_bytes() + + ctx = SQLContext() + ctx.register_catalog("paimon", {"warehouse": warehouse}) + ctx.register_batch( + "paimon.default.videos", + pa.record_batch( + [[1], pa.array([video_bytes], type=pa.binary())], + names=["id", "video"], + ), + ) + + batches = ctx.sql( + """ + SELECT id, video_snapshot(video) AS cover_png + FROM paimon.default.videos + """ + ) + table = pa.Table.from_batches(batches) + + assert table["id"].to_pylist() == [1] + assert table["cover_png"].to_pylist()[0].startswith(PNG_SIGNATURE) + + ctx.sql("DROP TEMPORARY TABLE paimon.default.videos") + + +def test_video_snapshot_descriptor_without_table_file_io_returns_null(): + with tempfile.TemporaryDirectory() as warehouse: + video_path = Path(warehouse) / "sample.mp4" + write_sample_video(video_path) + descriptor = serialize_blob_descriptor( + str(video_path), 0, video_path.stat().st_size + ) + + ctx = SQLContext() + ctx.register_catalog("paimon", {"warehouse": warehouse}) + ctx.register_batch( + "paimon.default.videos", + pa.record_batch( + [[1], pa.array([descriptor], type=pa.binary())], + names=["id", "video"], + ), + ) + + batches = ctx.sql( + """ + SELECT id, video_snapshot(video) AS cover_png + FROM paimon.default.videos + """ + ) + table = pa.Table.from_batches(batches) + + assert table["id"].to_pylist() == [1] + assert table["cover_png"].to_pylist() == [None] + + ctx.sql("DROP TEMPORARY TABLE paimon.default.videos") + + +def test_video_snapshot_returns_null_for_image_bytes(): + with tempfile.TemporaryDirectory() as warehouse: + ctx = SQLContext() + ctx.register_catalog("paimon", {"warehouse": warehouse}) + ctx.register_batch( + "paimon.default.media", + pa.record_batch( + [[1], pa.array([sample_image_bytes()], type=pa.binary())], + names=["id", "content"], + ), + ) + + batches = ctx.sql( + """ + SELECT id, video_snapshot(content) AS cover_png + FROM paimon.default.media + """ + ) + table = pa.Table.from_batches(batches) + + assert table["id"].to_pylist() == [1] + assert table["cover_png"].to_pylist() == [None] + + ctx.sql("DROP TEMPORARY TABLE paimon.default.media") + + +def test_video_snapshot_reads_descriptor_with_table_file_io(): + with tempfile.TemporaryDirectory() as warehouse: + ctx = SQLContext() + ctx.register_catalog("paimon", {"warehouse": warehouse}) + ctx.sql("CREATE TABLE paimon.default.videos (id INT, video BINARY)") + + video_path = Path(warehouse) / "default.db" / "videos" / "sample.mp4" + video_path.parent.mkdir(parents=True, exist_ok=True) + write_sample_video(video_path) + descriptor = serialize_blob_descriptor( + str(video_path), 0, video_path.stat().st_size + ) + + ctx.register_batch( + "source_videos", + pa.record_batch( + [[1], pa.array([descriptor], type=pa.binary())], + names=["id", "video"], + ), + ) + ctx.sql( + """ + INSERT INTO paimon.default.videos + SELECT id, video FROM paimon.default.source_videos + """ + ) + + batches = ctx.sql( + """ + SELECT id, video_snapshot(video) AS cover_png + FROM paimon.default.videos + """ + ) + table = pa.Table.from_batches(batches) + + assert table["id"].to_pylist() == [1] + assert table["cover_png"].to_pylist()[0].startswith(PNG_SIGNATURE) + + ctx.sql("DROP TEMPORARY TABLE paimon.default.source_videos") + ctx.sql("DROP TABLE paimon.default.videos") + + +def test_video_snapshot_accepts_timestamp_ms(): + image_module = pytest.importorskip("PIL.Image") + + with tempfile.TemporaryDirectory() as warehouse: + video_path = Path(warehouse) / "sample.mp4" + write_sample_video(video_path, colors=((240, 40, 80), (40, 220, 80))) + video_bytes = video_path.read_bytes() + + ctx = SQLContext() + ctx.register_catalog("paimon", {"warehouse": warehouse}) + ctx.register_batch( + "paimon.default.videos", + pa.record_batch( + [[1], pa.array([video_bytes], type=pa.binary())], + names=["id", "video"], + ), + ) + + batches = ctx.sql( + """ + SELECT + video_snapshot(video) AS first_png, + video_snapshot(video, CAST(1000 AS INT)) AS second_png, + video_snapshot(video, 5000) AS beyond_duration_png + FROM paimon.default.videos + """ + ) + row = pa.Table.from_batches(batches).to_pylist()[0] + + assert row["first_png"].startswith(PNG_SIGNATURE) + assert row["second_png"].startswith(PNG_SIGNATURE) + + first_image = image_module.open(io.BytesIO(row["first_png"])).convert("RGB") + second_image = image_module.open(io.BytesIO(row["second_png"])).convert("RGB") + assert first_image.getpixel((16, 16)) != second_image.getpixel((16, 16)) + assert row["beyond_duration_png"].startswith(PNG_SIGNATURE) + beyond_duration_image = image_module.open( + io.BytesIO(row["beyond_duration_png"]) + ).convert("RGB") + assert beyond_duration_image.getpixel((16, 16)) == second_image.getpixel((16, 16)) + + ctx.sql("DROP TEMPORARY TABLE paimon.default.videos") + + def test_query_simple_table_via_catalog_provider(): catalog = PaimonCatalog({"warehouse": WAREHOUSE}) ctx = SessionContext() @@ -125,7 +374,82 @@ def plus_ten(values): ctx.sql("DROP TEMPORARY TABLE paimon.default.my_temp") -def test_register_udf_multi_input_plan(): +def test_register_udf_default_name_is_sql_identifier_for_closure(): + with tempfile.TemporaryDirectory() as warehouse: + ctx = SQLContext() + ctx.register_catalog("paimon", {"warehouse": warehouse}) + + batch = pa.record_batch([[1, 2]], names=["id"]) + ctx.register_batch("my_temp", batch) + + def build_udf(): + def plus_one(values): + return pa.array( + [value + 1 for value in values.to_pylist()], type=pa.int64() + ) + + return plus_one + + scalar_udf = udf(build_udf(), [pa.int64()], pa.int64(), "volatile") + assert scalar_udf.name == "plus_one" + ctx.register_udf(scalar_udf) + + batches = ctx.sql( + "SELECT plus_one(id) AS id FROM paimon.default.my_temp ORDER BY id" + ) + table = pa.Table.from_batches(batches) + assert table["id"].to_pylist() == [2, 3] + + ctx.sql("DROP TEMPORARY TABLE paimon.default.my_temp") + + +def test_register_udf_multiple_arguments(): + with tempfile.TemporaryDirectory() as warehouse: + ctx = SQLContext() + ctx.register_catalog("paimon", {"warehouse": warehouse}) + + batch = pa.record_batch( + [ + pa.array([1, 2, None], type=pa.int64()), + pa.array([10, 20, 30], type=pa.int64()), + ], + names=["id", "delta"], + ) + ctx.register_batch("my_temp", batch) + + def add_values(left, right): + values = [] + for left_value, right_value in zip(left.to_pylist(), right.to_pylist()): + if left_value is None or right_value is None: + values.append(None) + else: + values.append(left_value + right_value) + return pa.array(values, type=pa.int64()) + + ctx.register_udf( + udf( + add_values, + [pa.int64(), pa.int64()], + pa.int64(), + "volatile", + "add_values", + ) + ) + + batches = ctx.sql( + """ + SELECT add_values(id, delta) AS value + FROM paimon.default.my_temp + ORDER BY id + """ + ) + table = pa.Table.from_batches(batches) + assert table["value"].to_pylist() == [11, 22, None] + + ctx.sql("DROP TEMPORARY TABLE paimon.default.my_temp") + + +def test_register_udf_multi_partition_union_plan(): with tempfile.TemporaryDirectory() as warehouse: ctx = SQLContext() ctx.register_catalog("paimon", {"warehouse": warehouse}) @@ -155,7 +479,7 @@ def plus_ten(values): def test_udf_rejects_non_callable(): try: udf(1, [pa.int64()], pa.int64(), "volatile") - assert False, "expected non-callable UDF creation to fail" + pytest.fail("expected non-callable UDF creation to fail") except TypeError as e: assert "`func` argument must be callable" in str(e) @@ -166,7 +490,7 @@ def identity(values): try: udf(identity, [object()], pa.int64(), "volatile", "identity") - assert False, "expected unsupported type registration to fail" + pytest.fail("expected unsupported type registration to fail") except TypeError as e: assert "Expected a pyarrow.DataType" in str(e) @@ -196,7 +520,7 @@ def boom(values): try: ctx.sql("SELECT boom(id) AS id FROM paimon.default.my_temp") - assert False, "expected Python UDF exception to fail the query" + pytest.fail("expected Python UDF exception to fail the query") except Exception as e: message = str(e) assert "Python UDF 'boom' failed" in message @@ -218,7 +542,7 @@ def wrong_length(values): try: ctx.sql("SELECT wrong_length(id) AS id FROM paimon.default.my_temp") - assert False, "expected wrong-length UDF result to fail the query" + pytest.fail("expected wrong-length UDF result to fail the query") except Exception as e: message = str(e) assert "Python UDF 'wrong_length' returned 1 rows, expected 2" in message @@ -239,7 +563,7 @@ def wrong_type(values): try: ctx.sql("SELECT wrong_type(id) AS id FROM paimon.default.my_temp") - assert False, "expected wrong-type UDF result to fail the query" + pytest.fail("expected wrong-type UDF result to fail the query") except Exception as e: message = str(e) assert "Python UDF 'wrong_type' returned Utf8, expected Int64" in message @@ -319,7 +643,7 @@ def test_register_batch_invalid_catalog(): batch = pa.record_batch([[1]], names=["id"]) try: ctx.register_batch("unknown_catalog.default.my_temp", batch) - assert False, "Expected an error for unknown catalog" + pytest.fail("Expected an error for unknown catalog") except Exception as e: assert "unknown_catalog" in str(e).lower() or "not a paimon" in str(e).lower() or "unknown" in str(e).lower() @@ -336,6 +660,6 @@ def test_table_functions_registered_with_catalog(): for fn in ("vector_search", "full_text_search"): try: ctx.sql(f"SELECT * FROM {fn}('only_one_arg')") - assert False, f"expected {fn} to reject a single argument" + pytest.fail(f"expected {fn} to reject a single argument") except Exception as e: assert "requires 4 arguments" in str(e), str(e) diff --git a/crates/integrations/datafusion/src/blob_reader.rs b/crates/integrations/datafusion/src/blob_reader.rs new file mode 100644 index 00000000..fe7bd887 --- /dev/null +++ b/crates/integrations/datafusion/src/blob_reader.rs @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::{Arc, RwLock}; + +use paimon::io::FileIO; + +#[derive(Clone, Debug)] +struct BlobFileIO { + prefix: String, + file_io: FileIO, +} + +/// Session-scoped registry of Paimon [`FileIO`] instances for BlobDescriptor reads. +#[derive(Clone, Debug, Default)] +pub struct BlobReaderRegistry { + readers: Arc>>, +} + +impl BlobReaderRegistry { + pub fn register(&self, prefix: impl Into, file_io: FileIO) { + let prefix = prefix.into(); + let mut readers = self.readers.write().unwrap_or_else(|e| e.into_inner()); + if let Some(existing) = readers.iter_mut().find(|reader| reader.prefix == prefix) { + existing.file_io = file_io; + return; + } + readers.push(BlobFileIO { prefix, file_io }); + } + + pub fn register_if_absent(&self, prefix: impl Into, file_io: FileIO) { + let prefix = prefix.into(); + let mut readers = self.readers.write().unwrap_or_else(|e| e.into_inner()); + if readers.iter().any(|reader| reader.prefix == prefix) { + return; + } + readers.push(BlobFileIO { prefix, file_io }); + } + + pub fn resolve(&self, uri: &str) -> Option { + let readers = self.readers.read().unwrap_or_else(|e| e.into_inner()); + readers + .iter() + .filter(|reader| uri.starts_with(&reader.prefix)) + .max_by_key(|reader| reader.prefix.len()) + .map(|reader| reader.file_io.clone()) + } +} + +#[cfg(test)] +mod tests { + use std::fs; + + use paimon::io::{FileIOBuilder, FileRead}; + use paimon::spec::BlobDescriptor; + + use super::*; + + #[tokio::test] + async fn resolves_file_blob_descriptor_with_file_io() { + let directory = tempfile::tempdir().unwrap(); + let blob_path = directory.path().join("blob.bin"); + fs::write(&blob_path, b"prefixpayloadsuffix").unwrap(); + + let descriptor = BlobDescriptor::new( + blob_path.to_string_lossy().to_string(), + 6, + "payload".len() as i64, + ); + let descriptor = BlobDescriptor::deserialize(&descriptor.serialize()).unwrap(); + + let registry = BlobReaderRegistry::default(); + let file_io = FileIOBuilder::new("file").build().unwrap(); + registry.register(directory.path().to_string_lossy().to_string(), file_io); + + let resolved_file_io = registry + .resolve(descriptor.uri()) + .expect("file blob descriptor should resolve to registered FileIO"); + let input = resolved_file_io.new_input(descriptor.uri()).unwrap(); + let reader = input.reader().await.unwrap(); + let start = descriptor.offset() as u64; + let end = start + descriptor.length() as u64; + let bytes = reader.read(start..end).await.unwrap(); + + assert_eq!(&bytes[..], b"payload"); + } +} diff --git a/crates/integrations/datafusion/src/catalog.rs b/crates/integrations/datafusion/src/catalog.rs index a93201aa..018ef6b9 100644 --- a/crates/integrations/datafusion/src/catalog.rs +++ b/crates/integrations/datafusion/src/catalog.rs @@ -34,7 +34,7 @@ use crate::error::to_datafusion_error; use crate::runtime::{await_with_runtime, block_on_with_runtime}; use crate::system_tables; use crate::table::PaimonTableProvider; -use crate::DynamicOptions; +use crate::{BlobReaderRegistry, DynamicOptions}; /// Provides an interface to manage and access multiple schemas (databases) /// within a Paimon [`Catalog`]. @@ -54,6 +54,7 @@ pub struct PaimonCatalogProvider { /// propagate the panic to all subsequent operations. The worst case is a temp table /// becoming invisible or stale, which is recoverable by re-registering it. temp_tables: Arc>>>, + blob_reader_registry: BlobReaderRegistry, } impl Debug for PaimonCatalogProvider { @@ -73,17 +74,20 @@ impl PaimonCatalogProvider { catalog, dynamic_options: Default::default(), temp_tables: Arc::new(RwLock::new(HashMap::new())), + blob_reader_registry: BlobReaderRegistry::default(), } } pub(crate) fn with_dynamic_options( catalog: Arc, dynamic_options: DynamicOptions, + blob_reader_registry: BlobReaderRegistry, ) -> Self { PaimonCatalogProvider { catalog, dynamic_options, temp_tables: Arc::new(RwLock::new(HashMap::new())), + blob_reader_registry, } } } @@ -109,6 +113,7 @@ impl CatalogProvider for PaimonCatalogProvider { fn schema(&self, name: &str) -> Option> { let catalog = Arc::clone(&self.catalog); let dynamic_options = Arc::clone(&self.dynamic_options); + let blob_reader_registry = self.blob_reader_registry.clone(); let name = name.to_string(); let temp_provider = { @@ -124,6 +129,7 @@ impl CatalogProvider for PaimonCatalogProvider { name, dynamic_options, temp_provider, + blob_reader_registry, )) as Arc), Err(paimon::Error::DatabaseNotExist { .. }) => { if temp_provider.is_some() { @@ -132,6 +138,7 @@ impl CatalogProvider for PaimonCatalogProvider { name, dynamic_options, temp_provider, + blob_reader_registry, )) as Arc) } else { None @@ -154,6 +161,7 @@ impl CatalogProvider for PaimonCatalogProvider { ) -> DFResult>> { let catalog = Arc::clone(&self.catalog); let dynamic_options = Arc::clone(&self.dynamic_options); + let blob_reader_registry = self.blob_reader_registry.clone(); let name = name.to_string(); block_on_with_runtime( async move { @@ -166,6 +174,7 @@ impl CatalogProvider for PaimonCatalogProvider { name, dynamic_options, None, + blob_reader_registry, )) as Arc)) }, "paimon catalog access thread panicked", @@ -179,6 +188,7 @@ impl CatalogProvider for PaimonCatalogProvider { ) -> DFResult>> { let catalog = Arc::clone(&self.catalog); let dynamic_options = Arc::clone(&self.dynamic_options); + let blob_reader_registry = self.blob_reader_registry.clone(); let name = name.to_string(); block_on_with_runtime( async move { @@ -191,6 +201,7 @@ impl CatalogProvider for PaimonCatalogProvider { name, dynamic_options, None, + blob_reader_registry, )) as Arc)) }, "paimon catalog access thread panicked", @@ -289,6 +300,7 @@ pub struct PaimonSchemaProvider { dynamic_options: DynamicOptions, /// Optional temporary in-memory provider for temp tables and views. temp_provider: Option>, + blob_reader_registry: BlobReaderRegistry, } impl Debug for PaimonSchemaProvider { @@ -307,12 +319,14 @@ impl PaimonSchemaProvider { database: String, dynamic_options: DynamicOptions, temp_provider: Option>, + blob_reader_registry: BlobReaderRegistry, ) -> Self { PaimonSchemaProvider { catalog, database, dynamic_options, temp_provider, + blob_reader_registry, } } } @@ -372,6 +386,7 @@ impl SchemaProvider for PaimonSchemaProvider { let catalog = Arc::clone(&self.catalog); let dynamic_options = Arc::clone(&self.dynamic_options); + let blob_reader_registry = self.blob_reader_registry.clone(); let identifier = Identifier::new(self.database.clone(), base); await_with_runtime(async move { match catalog.get_table(&identifier).await { @@ -382,7 +397,10 @@ impl SchemaProvider for PaimonSchemaProvider { } else { table.copy_with_options(opts) }; - let provider = PaimonTableProvider::try_new(table)?; + let provider = PaimonTableProvider::try_new_with_blob_reader_registry( + table, + blob_reader_registry, + )?; Ok(Some(Arc::new(provider) as Arc)) } Err(paimon::Error::TableNotExist { .. }) => Ok(None), diff --git a/crates/integrations/datafusion/src/lib.rs b/crates/integrations/datafusion/src/lib.rs index 47f1bab9..f11cfce4 100644 --- a/crates/integrations/datafusion/src/lib.rs +++ b/crates/integrations/datafusion/src/lib.rs @@ -36,6 +36,7 @@ //! This version supports partition predicate pushdown by extracting //! translatable partition-only conjuncts from DataFusion filters. +mod blob_reader; mod catalog; mod delete; mod error; @@ -63,6 +64,7 @@ use std::sync::{Arc, RwLock}; /// so that SET/RESET mutations are visible to subsequent table scans. pub(crate) type DynamicOptions = Arc>>; +pub use blob_reader::BlobReaderRegistry; pub use catalog::{PaimonCatalogProvider, PaimonSchemaProvider}; pub use error::to_datafusion_error; #[cfg(feature = "fulltext")] diff --git a/crates/integrations/datafusion/src/sql_context.rs b/crates/integrations/datafusion/src/sql_context.rs index f77b79e2..ec93741a 100644 --- a/crates/integrations/datafusion/src/sql_context.rs +++ b/crates/integrations/datafusion/src/sql_context.rs @@ -65,7 +65,7 @@ use paimon::spec::{ }; use crate::error::to_datafusion_error; -use crate::DynamicOptions; +use crate::{BlobReaderRegistry, DynamicOptions}; /// A SQL context that supports registering multiple Paimon catalogs and executing SQL. /// @@ -81,6 +81,7 @@ pub struct SQLContext { catalogs: HashMap>, /// Session-scoped dynamic options set via `SET 'paimon.key' = 'value'`. dynamic_options: DynamicOptions, + blob_reader_registry: BlobReaderRegistry, } impl Default for SQLContext { @@ -102,9 +103,14 @@ impl SQLContext { ctx, catalogs: HashMap::new(), dynamic_options: Default::default(), + blob_reader_registry: BlobReaderRegistry::default(), } } + pub fn blob_reader_registry(&self) -> BlobReaderRegistry { + self.blob_reader_registry.clone() + } + /// Registers a Paimon catalog under the given name. /// /// The first registered catalog automatically becomes the current catalog @@ -134,6 +140,7 @@ impl SQLContext { Arc::new(crate::catalog::PaimonCatalogProvider::with_dynamic_options( catalog.clone(), self.dynamic_options.clone(), + self.blob_reader_registry.clone(), )), ); register_table_functions(&self.ctx, &catalog, default_db); @@ -471,7 +478,10 @@ impl SQLContext { options.insert(SCAN_VERSION_OPTION.to_string(), info.version.clone()); let table_with_options = paimon_table.copy_with_options(options); - let provider = Arc::new(PaimonTableProvider::try_new(table_with_options)?); + let provider = Arc::new(PaimonTableProvider::try_new_with_blob_reader_registry( + table_with_options, + self.blob_reader_registry.clone(), + )?); let uuid_name = format!("__paimon_tt_{}", uuid::Uuid::new_v4().as_simple()); self.register_temp_table(uuid_name.as_str(), provider)?; @@ -497,7 +507,10 @@ impl SQLContext { options.insert(SCAN_TIMESTAMP_MILLIS_OPTION.to_string(), millis.to_string()); let table_with_options = paimon_table.copy_with_options(options); - let provider = Arc::new(PaimonTableProvider::try_new(table_with_options)?); + let provider = Arc::new(PaimonTableProvider::try_new_with_blob_reader_registry( + table_with_options, + self.blob_reader_registry.clone(), + )?); let uuid_name = format!("__paimon_tt_{}", uuid::Uuid::new_v4().as_simple()); self.register_temp_table(uuid_name.as_str(), provider)?; diff --git a/crates/integrations/datafusion/src/table/mod.rs b/crates/integrations/datafusion/src/table/mod.rs index 5bae7443..3508362a 100644 --- a/crates/integrations/datafusion/src/table/mod.rs +++ b/crates/integrations/datafusion/src/table/mod.rs @@ -32,6 +32,7 @@ use datafusion::physical_plan::ExecutionPlan; use paimon::table::Table; use crate::physical_plan::PaimonDataSink; +use crate::BlobReaderRegistry; use crate::error::to_datafusion_error; #[cfg(test)] @@ -74,6 +75,15 @@ impl PaimonTableProvider { Ok(Self { table, schema }) } + pub fn try_new_with_blob_reader_registry( + table: Table, + blob_reader_registry: BlobReaderRegistry, + ) -> DFResult { + blob_reader_registry + .register_if_absent(table.location().to_string(), table.file_io().clone()); + Self::try_new(table) + } + pub fn table(&self) -> &Table { &self.table }