From 23543fc027db53c783a2567e769afa18c7bab558 Mon Sep 17 00:00:00 2001 From: ntjohnson1 <24689722+ntjohnson1@users.noreply.github.com> Date: Tue, 28 Apr 2026 11:26:37 -0400 Subject: [PATCH 1/2] feat: allow replacing the global SessionContext MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Promote the previously immutable global context slot in `datafusion-python-util` from `OnceLock<Arc<SessionContext>>` to a `RwLock<Arc<SessionContext>>` and expose `set_global_ctx` (Rust) / `SessionContext.set_as_global` (Python). Users who register UDFs or otherwise customize a context can now make it the default seen by `SessionContext.global_ctx()` and the module-level `read_*` helpers. Existing snapshots returned by `get_global_ctx()` are unaffected — the swap only changes what subsequent readers see. Also fixes a pre-existing clippy `uninlined_format_args` nit in `dataframe.rs` that was tripping the pre-commit hook. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/core/src/context.rs | 15 +++++++-- crates/core/src/dataframe.rs | 2 +- crates/util/src/lib.rs | 65 +++++++++++++++++++++++++++++++++--- python/datafusion/context.py | 16 +++++++++ 4 files changed, 90 insertions(+), 8 deletions(-) diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index e46d359d6..22c12be9b 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -56,7 +56,7 @@ use datafusion_ffi::table_provider_factory::FFI_TableProviderFactory; use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec; use datafusion_python_util::{ create_logical_extension_capsule, ffi_logical_codec_from_pycapsule, get_global_ctx, - get_tokio_runtime, spawn_future, wait_for_future, + get_tokio_runtime, set_global_ctx, spawn_future, wait_for_future, }; use object_store::ObjectStore; use pyo3::IntoPyObjectExt; @@ -407,11 +407,22 @@ impl PySessionContext { #[staticmethod] #[pyo3(signature = ())] pub fn global_ctx() -> PyResult { - let ctx = get_global_ctx().clone(); + let ctx = 
get_global_ctx(); let logical_codec = Self::default_logical_codec(&ctx); Ok(Self { ctx, logical_codec }) } + /// Replace the process-wide global `SessionContext` with this one. + /// + /// All subsequent callers of `SessionContext.global_ctx()` (and Rust + /// helpers that fall back to the global context, such as the + /// `read_parquet` / `read_csv` / etc. module-level helpers) will see this + /// context. Existing references already obtained from `global_ctx()` are + /// not affected. + pub fn set_as_global(&self) { + set_global_ctx(self.ctx.clone()); + } + /// Register an object store with the given name #[pyo3(signature = (scheme, store, host=None))] pub fn register_object_store( diff --git a/crates/core/src/dataframe.rs b/crates/core/src/dataframe.rs index 2e74991b8..66cdee56e 100644 --- a/crates/core/src/dataframe.rs +++ b/crates/core/src/dataframe.rs @@ -851,7 +851,7 @@ impl PyDataFrame { Some(f) => f .parse::() .map_err(|e| { - PyDataFusionError::Common(format!("Invalid explain format '{}': {}", f, e)) + PyDataFusionError::Common(format!("Invalid explain format '{f}': {e}")) })?, None => datafusion::common::format::ExplainFormat::Indent, }; diff --git a/crates/util/src/lib.rs b/crates/util/src/lib.rs index 5b1c89936..58a09f192 100644 --- a/crates/util/src/lib.rs +++ b/crates/util/src/lib.rs @@ -17,7 +17,7 @@ use std::future::Future; use std::ptr::NonNull; -use std::sync::{Arc, OnceLock}; +use std::sync::{Arc, OnceLock, RwLock}; use std::time::Duration; use datafusion::datasource::TableProvider; @@ -59,11 +59,29 @@ pub fn is_ipython_env(py: Python) -> &'static bool { }) } -/// Utility to get the Global Datafussion CTX +fn global_ctx_slot() -> &'static RwLock> { + static CTX: OnceLock>> = OnceLock::new(); + CTX.get_or_init(|| RwLock::new(Arc::new(SessionContext::new()))) +} + +/// Utility to get the Global DataFusion CTX. +/// +/// Returns an owned `Arc` snapshot. 
The underlying slot can be +/// replaced via [`set_global_ctx`]; existing snapshots are unaffected. #[inline] -pub fn get_global_ctx() -> &'static Arc { - static CTX: OnceLock> = OnceLock::new(); - CTX.get_or_init(|| Arc::new(SessionContext::new())) +pub fn get_global_ctx() -> Arc { + global_ctx_slot() + .read() + .expect("global SessionContext lock poisoned") + .clone() +} + +/// Replace the Global DataFusion CTX. Subsequent calls to [`get_global_ctx`] +/// will return the new context. Already-cloned `Arc`s are not affected. +pub fn set_global_ctx(ctx: Arc) { + *global_ctx_slot() + .write() + .expect("global SessionContext lock poisoned") = ctx; } /// Utility to collect rust futures with GIL released and respond to @@ -224,3 +242,40 @@ pub fn ffi_logical_codec_from_pycapsule(obj: Bound) -> PyResult SessionContext: wrapper.ctx = internal_ctx return wrapper + def set_as_global(self) -> None: + """Install this context as the process-wide global ``SessionContext``. + + After this call, :meth:`SessionContext.global_ctx` (and the module-level + helpers in :mod:`datafusion.io` that fall back to the global context) + will return this context. Existing references already obtained from + ``global_ctx()`` are not invalidated. + + Example:: + + ctx = SessionContext() + ctx.register_udf(my_udf) + ctx.set_as_global() + """ + self.ctx.set_as_global() + def enable_url_table(self) -> SessionContext: """Control if local files can be queried as tables. From 5796b536ca90fd2af9879d25456a03b453fc2c9d Mon Sep 17 00:00:00 2001 From: ntjohnson1 <24689722+ntjohnson1@users.noreply.github.com> Date: Tue, 28 Apr 2026 11:40:26 -0400 Subject: [PATCH 2/2] feat: pickle/dill support for Expr Add `to_bytes` / `from_bytes` on `Expr` (Python wrapper) and the underlying `RawExpr` (Rust). Serialization uses `datafusion-proto`'s `Serializeable` trait, encoding function references by name. 
The Python wrapper implements `__getstate__` / `__setstate__` on top, so `pickle.dumps` / `dill.dumps` work out of the box. Reconstruction resolves function names against the process-wide global `SessionContext` (introduced as settable in the previous commit). Built-in functions always roundtrip; user-defined functions roundtrip when registered on a context that has been installed via `SessionContext.set_as_global()`. Adds `dill` to the dev dependency group and parametrized tests covering both serializers across columns, literals, binary ops, casts, between, aggregates, case/when, and a UDF with the global-ctx pattern. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/core/src/expr.rs | 26 ++++++ pyproject.toml | 1 + python/datafusion/expr.py | 25 ++++++ python/tests/test_pickle.py | 167 ++++++++++++++++++++++++++++++++++++ uv.lock | 11 +++ 5 files changed, 230 insertions(+) create mode 100644 python/tests/test_pickle.py diff --git a/crates/core/src/expr.rs b/crates/core/src/expr.rs index c4f2a12da..fc5d987bc 100644 --- a/crates/core/src/expr.rs +++ b/crates/core/src/expr.rs @@ -31,9 +31,12 @@ use datafusion::logical_expr::{ Between, BinaryExpr, Case, Cast, Expr, ExprFuncBuilder, ExprFunctionExt, Like, LogicalPlan, Operator, TryCast, WindowFunctionDefinition, col, lit, lit_with_metadata, }; +use datafusion_proto::bytes::Serializeable; +use datafusion_python_util::get_global_ctx; use pyo3::IntoPyObjectExt; use pyo3::basic::CompareOp; use pyo3::prelude::*; +use pyo3::types::PyBytes; use window::PyWindowFrame; use self::alias::PyAlias; @@ -256,6 +259,29 @@ impl PyExpr { Ok(format!("Expr({})", self.expr)) } + /// Serialize the underlying expression to bytes via the `datafusion-proto` + /// wire format. Used by the Python `Expr` wrapper to implement + /// `__getstate__` / `__setstate__`; also exposed directly so callers can + /// persist or transmit expressions without going through `pickle`. 
+ fn to_bytes<'py>(&self, py: Python<'py>) -> PyDataFusionResult> { + let bytes = self.expr.to_bytes()?; + Ok(PyBytes::new(py, &bytes)) + } + + /// Reconstruct a `RawExpr` from bytes produced by [`PyExpr::to_bytes`]. + /// + /// Function references (built-ins, UDFs, UDAFs, UDWFs) are resolved by + /// name against the process-wide global `SessionContext`. Built-in + /// functions are registered on every fresh context, so they always + /// roundtrip. To roundtrip user-defined functions, register them on a + /// context and call `SessionContext.set_as_global()` before unpickling. + #[staticmethod] + fn from_bytes(bytes: &[u8]) -> PyDataFusionResult { + let ctx = get_global_ctx(); + let expr = Expr::from_bytes_with_registry(bytes, ctx.as_ref())?; + Ok(expr.into()) + } + fn __add__(&self, rhs: PyExpr) -> PyResult { Ok((self.expr.clone() + rhs.expr).into()) } diff --git a/pyproject.toml b/pyproject.toml index 951f7adc3..e6bcb75ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -188,6 +188,7 @@ ignore-words-list = ["IST", "ans"] dev = [ "arro3-core==0.6.5", "codespell==2.4.1", + "dill>=0.3.8", "maturin>=1.8.1", "nanoarrow==0.8.0", "numpy>1.25.0;python_version<'3.14'", diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 0f7f3ab5a..9a9e9626e 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -410,6 +410,31 @@ def __init__(self, expr: expr_internal.RawExpr) -> None: """This constructor should not be called by the end user.""" self.expr = expr + def to_bytes(self) -> bytes: + """Serialize this expression to bytes via the ``datafusion-proto`` wire format. + + Function references (built-ins and UDFs/UDAFs/UDWFs) are encoded by + name; on :py:meth:`from_bytes` the names are resolved against the + process-wide global :py:class:`SessionContext`. Built-in functions + always roundtrip; for user-defined functions, register them on a + context and call :py:meth:`SessionContext.set_as_global` before + loading. 
+ """ + return self.expr.to_bytes() + + @classmethod + def from_bytes(cls, data: bytes) -> Expr: + """Inverse of :py:meth:`to_bytes`. See that method for caveats.""" + return cls(expr_internal.RawExpr.from_bytes(data)) + + def __getstate__(self) -> bytes: + """Serialize for ``pickle`` / ``dill``. Delegates to :py:meth:`to_bytes`.""" + return self.to_bytes() + + def __setstate__(self, state: bytes) -> None: + """Inverse of :py:meth:`__getstate__`.""" + self.expr = expr_internal.RawExpr.from_bytes(state) + def to_variant(self) -> Any: """Convert this expression into a python object if possible.""" return self.expr.to_variant() diff --git a/python/tests/test_pickle.py b/python/tests/test_pickle.py new file mode 100644 index 000000000..ed24ef5b0 --- /dev/null +++ b/python/tests/test_pickle.py @@ -0,0 +1,167 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pickle / dill roundtrip tests for :py:class:`datafusion.Expr`. + +The wire format is `datafusion-proto`'s ``LogicalExprNode``. Function +references are encoded by name, so unpickling resolves them against the +process-wide global :py:class:`SessionContext`. Tests that need a +non-built-in function temporarily install a custom global context and +restore the previous one. 
+""" + +import pickle +from contextlib import contextmanager + +import dill +import pyarrow as pa +import pytest +from datafusion import SessionContext, col, lit, udf +from datafusion import functions as f +from datafusion.expr import Expr + + +@pytest.fixture +def ctx(): + return SessionContext() + + +@pytest.fixture +def df(ctx): + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, None])], + names=["a", "b"], + ) + return ctx.create_dataframe([[batch]], name="t") + + +@contextmanager +def temporary_global_ctx(new_ctx): + """Install ``new_ctx`` as the process-wide global and restore on exit.""" + previous = SessionContext.global_ctx() + new_ctx.set_as_global() + try: + yield + finally: + previous.set_as_global() + + +@pytest.mark.parametrize("dumper", [pickle, dill], ids=["pickle", "dill"]) +@pytest.mark.parametrize( + "build_expr", + [ + pytest.param(lambda: col("a"), id="column"), + pytest.param(lambda: lit(42), id="literal_int"), + pytest.param(lambda: lit("hello"), id="literal_str"), + pytest.param(lambda: col("a") + lit(1), id="binary_add"), + pytest.param(lambda: (col("a") * lit(2)) - col("b"), id="binary_nested"), + pytest.param(lambda: col("a").alias("renamed"), id="alias"), + pytest.param(lambda: col("a").cast(pa.float64()), id="cast"), + pytest.param(lambda: col("a").is_null(), id="is_null"), + pytest.param(lambda: col("a").between(lit(1), lit(10)), id="between"), + pytest.param(lambda: ~(col("a") > lit(0)), id="not_gt"), + pytest.param(lambda: f.sum(col("a")), id="agg_sum"), + pytest.param( + lambda: f.case(col("a")).when(lit(1), lit("one")).end(), + id="case_when", + ), + ], +) +def test_builtin_roundtrip(build_expr, dumper): + """Built-in expressions roundtrip via pickle and dill.""" + expr = build_expr() + restored = dumper.loads(dumper.dumps(expr)) + assert isinstance(restored, Expr) + # canonical_name() gives a full string form including function names, + # so equal canonical names imply structural equivalence. 
+ assert restored.canonical_name() == expr.canonical_name() + + +@pytest.mark.parametrize("dumper", [pickle, dill], ids=["pickle", "dill"]) +def test_pickled_expr_executes(df, dumper): + """A roundtripped expression evaluates to the same result as the original.""" + expr = (col("a") + lit(10)).alias("a_plus_ten") + restored = dumper.loads(dumper.dumps(expr)) + + original = df.select(expr).collect()[0].column(0) + after = df.select(restored).collect()[0].column(0) + assert original == after + assert original == pa.array([11, 12, 13], type=pa.int64()) + + +def test_udf_roundtrip_via_global_ctx(): + """UDFs roundtrip when registered on the active global context. + + Mirrors the documented usage of ``SessionContext.set_as_global``. + """ + is_null = udf( + lambda x: x.is_null(), + [pa.int64()], + pa.bool_(), + volatility="immutable", + name="pickle_test_is_null", + ) + + custom_ctx = SessionContext() + custom_ctx.register_udf(is_null) + + expr = is_null(col("b")) + + with temporary_global_ctx(custom_ctx): + data = pickle.dumps(expr) + restored = pickle.loads(data) # noqa: S301 + assert restored.canonical_name() == expr.canonical_name() + + # Also evaluate to confirm the UDF body is wired up post-roundtrip. + batch = pa.RecordBatch.from_arrays([pa.array([1, None, 3])], names=["b"]) + df = custom_ctx.create_dataframe([[batch]], name="t_udf") + result = df.select(restored.alias("nul")).collect()[0].column(0) + assert result == pa.array([False, True, False]) + + +def test_udf_roundtrip_fails_without_registration(): + """Without the UDF registered on the global context, unpickle errors out + rather than silently substituting a different implementation.""" + is_null = udf( + lambda x: x.is_null(), + [pa.int64()], + pa.bool_(), + volatility="immutable", + name="pickle_test_unknown_udf", + ) + expr = is_null(col("b")) + + data = pickle.dumps(expr) + # The default global ctx does not have this UDF registered. 
Reconstruction + # must raise rather than silently substitute a placeholder. DataFusion + # surfaces this as a generic Python ``Exception`` whose message names the + # missing function, so match on the function name. + with pytest.raises(Exception, match="pickle_test_unknown_udf"): + pickle.loads(data) # noqa: S301 + + +def test_getstate_returns_bytes(): + """``__getstate__`` is exposed directly and returns raw bytes — useful for + callers that want to persist or transmit expressions without pickle.""" + expr = col("a") + lit(1) + state = expr.__getstate__() + assert isinstance(state, bytes) + assert len(state) > 0 + + rebuilt = Expr.__new__(Expr) + rebuilt.__setstate__(state) + assert rebuilt.canonical_name() == expr.canonical_name() diff --git a/uv.lock b/uv.lock index 3b7135e32..e05d9240c 100644 --- a/uv.lock +++ b/uv.lock @@ -324,6 +324,7 @@ dependencies = [ dev = [ { name = "arro3-core" }, { name = "codespell" }, + { name = "dill" }, { name = "maturin" }, { name = "nanoarrow" }, { name = "numpy", version = "2.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.14'" }, @@ -360,6 +361,7 @@ requires-dist = [ dev = [ { name = "arro3-core", specifier = "==0.6.5" }, { name = "codespell", specifier = "==2.4.1" }, + { name = "dill", specifier = ">=0.3.8" }, { name = "maturin", specifier = ">=1.8.1" }, { name = "nanoarrow", specifier = "==0.8.0" }, { name = "numpy", marker = "python_full_version < '3.14'", specifier = ">1.25.0" }, @@ -406,6 +408,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6e/c6/ac0b6c1e2d138f1002bcf799d330bd6d85084fece321e662a14223794041/Deprecated-1.2.18-py2.py3-none-any.whl", hash = "sha256:bd5011788200372a32418f888e326a09ff80d0214bd961147cfed01b5c018eec", size = 9998, upload-time = "2025-01-27T10:46:09.186Z" }, ] +[[package]] +name = "dill" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/81/e1/56027a71e31b02ddc53c7d65b01e68edf64dea2932122fe7746a516f75d5/dill-0.4.1.tar.gz", hash = "sha256:423092df4182177d4d8ba8290c8a5b640c66ab35ec7da59ccfa00f6fa3eea5fa", size = 187315, upload-time = "2026-01-19T02:36:56.85Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/77/dc8c558f7593132cf8fefec57c4f60c83b16941c574ac5f619abb3ae7933/dill-0.4.1-py3-none-any.whl", hash = "sha256:1e1ce33e978ae97fcfcff5638477032b801c46c7c65cf717f95fbc2248f79a9d", size = 120019, upload-time = "2026-01-19T02:36:55.663Z" }, +] + [[package]] name = "distlib" version = "0.3.9"