diff --git a/Cargo.lock b/Cargo.lock index d712eecfcc72..2a462d360782 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2292,8 +2292,17 @@ dependencies = [ "async-ffi", "async-trait", "datafusion", + "datafusion-catalog", "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-functions", + "datafusion-functions-aggregate", "datafusion-functions-aggregate-common", + "datafusion-functions-table", + "datafusion-functions-window", + "datafusion-physical-expr", + "datafusion-physical-plan", "datafusion-proto", "datafusion-proto-common", "doc-comment", @@ -3037,7 +3046,6 @@ version = "0.1.0" dependencies = [ "abi_stable", "datafusion", - "datafusion-ffi", "ffi_module_interface", "tokio", ] diff --git a/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs b/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs index a83f15926f05..4a0c218328a2 100644 --- a/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs +++ b/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs @@ -21,6 +21,7 @@ use abi_stable::{export_root_module, prefix_type::PrefixTypeTrait}; use arrow::array::RecordBatch; use arrow::datatypes::{DataType, Field, Schema}; use datafusion::{common::record_batch, datasource::MemTable}; +use datafusion_ffi::session::task_ctx_accessor::FFI_TaskContextAccessor; use datafusion_ffi::table_provider::FFI_TableProvider; use ffi_module_interface::{TableProviderModule, TableProviderModuleRef}; @@ -34,7 +35,9 @@ fn create_record_batch(start_value: i32, num_values: usize) -> RecordBatch { /// Here we only wish to create a simple table provider as an example. /// We create an in-memory table and convert it to it's FFI counterpart. 
-extern "C" fn construct_simple_table_provider() -> FFI_TableProvider { +extern "C" fn construct_simple_table_provider( + task_ctx_accessor: FFI_TaskContextAccessor, +) -> FFI_TableProvider { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Float64, true), @@ -50,7 +53,7 @@ extern "C" fn construct_simple_table_provider() -> FFI_TableProvider { let table_provider = MemTable::try_new(schema, vec![batches]).unwrap(); - FFI_TableProvider::new(Arc::new(table_provider), true, None) + FFI_TableProvider::new(Arc::new(table_provider), true, None, task_ctx_accessor) } #[export_root_module] diff --git a/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs b/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs index 88690e929713..3d222b99d723 100644 --- a/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs +++ b/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs @@ -22,6 +22,7 @@ use abi_stable::{ sabi_types::VersionStrings, StableAbi, }; +use datafusion_ffi::session::task_ctx_accessor::FFI_TaskContextAccessor; use datafusion_ffi::table_provider::FFI_TableProvider; #[repr(C)] @@ -34,7 +35,7 @@ use datafusion_ffi::table_provider::FFI_TableProvider; /// how a user may wish to separate these concerns. 
pub struct TableProviderModule { /// Constructs the table provider - pub create_table: extern "C" fn() -> FFI_TableProvider, + pub create_table: extern "C" fn(FFI_TaskContextAccessor) -> FFI_TableProvider, } impl RootModule for TableProviderModuleRef { diff --git a/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml b/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml index 028a366aab1c..1f68bb1bb1be 100644 --- a/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml +++ b/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml @@ -24,6 +24,5 @@ publish = false [dependencies] abi_stable = "0.11.3" datafusion = { workspace = true } -datafusion-ffi = { workspace = true } ffi_module_interface = { path = "../ffi_module_interface" } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } diff --git a/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs b/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs index 6e376ca866e8..6918608551f1 100644 --- a/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs +++ b/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs @@ -23,7 +23,8 @@ use datafusion::{ }; use abi_stable::library::{development_utils::compute_library_path, RootModule}; -use datafusion_ffi::table_provider::ForeignTableProvider; +use datafusion::catalog::TableProvider; +use datafusion::execution::TaskContextAccessor; use ffi_module_interface::TableProviderModuleRef; #[tokio::main] @@ -39,6 +40,9 @@ async fn main() -> Result<()> { TableProviderModuleRef::load_from_directory(&library_path) .map_err(|e| DataFusionError::External(Box::new(e)))?; + let ctx = Arc::new(SessionContext::new()); + let task_ctx_accessor = Arc::clone(&ctx) as Arc; + // By calling the code below, the table provided will be created within // the module's code. 
let ffi_table_provider = @@ -46,16 +50,14 @@ async fn main() -> Result<()> { .create_table() .ok_or(DataFusionError::NotImplemented( "External table provider failed to implement create_table".to_string(), - ))?(); + ))?(task_ctx_accessor.into()); // In order to access the table provider within this executable, we need to // turn it into a `ForeignTableProvider`. - let foreign_table_provider: ForeignTableProvider = (&ffi_table_provider).into(); - - let ctx = SessionContext::new(); + let foreign_table_provider: Arc = (&ffi_table_provider).into(); // Display the data to show the full cycle works. - ctx.register_table("external_table", Arc::new(foreign_table_provider))?; + ctx.register_table("external_table", foreign_table_provider)?; let df = ctx.table("external_table").await?; df.show().await?; diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 98804e424b40..aa378d42622d 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -52,8 +52,8 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::{CsvOptions, JsonOptions}; use datafusion_common::{ exec_err, internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, - Column, DFSchema, DataFusionError, ParamValues, ScalarValue, SchemaError, - TableReference, UnnestOptions, + unqualified_field_not_found, Column, DFSchema, DataFusionError, ParamValues, + ScalarValue, SchemaError, TableReference, UnnestOptions, }; use datafusion_expr::select_expr::SelectExpr; use datafusion_expr::{ @@ -310,11 +310,20 @@ impl DataFrame { pub fn select_columns(self, columns: &[&str]) -> Result { let fields = columns .iter() - .flat_map(|name| { - self.plan + .map(|name| { + let fields = self + .plan .schema() - .qualified_fields_with_unqualified_name(name) + .qualified_fields_with_unqualified_name(name); + if fields.is_empty() { + Err(unqualified_field_not_found(name, self.plan.schema())) + } else { + Ok(fields) 
+ } }) + .collect::, _>>()? + .into_iter() + .flatten() .collect::>(); let expr: Vec = fields .into_iter() @@ -1655,7 +1664,7 @@ impl DataFrame { pub fn into_view(self) -> Arc { Arc::new(DataFrameTableProvider { plan: self.plan, - table_type: TableType::Temporary, + table_type: TableType::View, }) } diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index c732c2c92f64..46fa5633bea5 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -91,6 +91,7 @@ use datafusion_session::SessionStore; use async_trait::async_trait; use chrono::{DateTime, Utc}; +use datafusion_execution::TaskContextAccessor; use object_store::ObjectStore; use parking_lot::RwLock; use url::Url; @@ -1794,6 +1795,12 @@ impl FunctionRegistry for SessionContext { } } +impl TaskContextAccessor for SessionContext { + fn get_task_context(&self) -> Arc { + self.task_ctx() + } +} + /// Create a new task context instance from SessionContext impl From<&SessionContext> for TaskContext { fn from(session: &SessionContext) -> Self { diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 05f5a204c096..610d6937deb7 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -66,8 +66,8 @@ use datafusion::test_util::{ use datafusion_catalog::TableProvider; use datafusion_common::test_util::{batches_to_sort_string, batches_to_string}; use datafusion_common::{ - assert_contains, internal_datafusion_err, Constraint, Constraints, DFSchema, - DataFusionError, ScalarValue, TableReference, UnnestOptions, + assert_contains, internal_datafusion_err, internal_err, Constraint, Constraints, + DFSchema, DataFusionError, ScalarValue, TableReference, UnnestOptions, }; use datafusion_common_runtime::SpawnedTask; use datafusion_datasource::file_format::format_as_file_type; @@ -305,6 +305,17 @@ async fn select_columns() -> Result<()> { Ok(()) 
} +#[tokio::test] +async fn select_columns_with_nonexistent_columns() -> Result<()> { + let t = test_table().await?; + let t2 = t.select_columns(&["canada", "c2", "rocks"]); + let Err(DataFusionError::SchemaError(_, _)) = t2 else { + return internal_err!("select_columns with nonexistent columns should error"); + }; + + Ok(()) +} + #[tokio::test] async fn select_expr() -> Result<()> { // build plan using Table API @@ -1627,7 +1638,9 @@ async fn register_table() -> Result<()> { let df_impl = DataFrame::new(ctx.state(), df.logical_plan().clone()); // register a dataframe as a table - ctx.register_table("test_table", df_impl.clone().into_view())?; + let table_provider = df_impl.clone().into_view(); + assert_eq!(table_provider.table_type(), TableType::View); + ctx.register_table("test_table", table_provider)?; // pull the table out let table = ctx.table("test_table").await?; diff --git a/datafusion/execution/src/lib.rs b/datafusion/execution/src/lib.rs index 55243e301e0e..b0a4dd0afc37 100644 --- a/datafusion/execution/src/lib.rs +++ b/datafusion/execution/src/lib.rs @@ -46,4 +46,4 @@ pub mod registry { pub use disk_manager::DiskManager; pub use registry::FunctionRegistry; pub use stream::{RecordBatchStream, SendableRecordBatchStream}; -pub use task::TaskContext; +pub use task::{TaskContext, TaskContextAccessor}; diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index c2a6cfe2c833..cda130ac23a3 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -32,7 +32,7 @@ use std::{collections::HashMap, sync::Arc}; /// information. 
/// /// [`SessionContext`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct TaskContext { /// Session Id session_id: String, @@ -211,6 +211,10 @@ impl FunctionRegistry for TaskContext { } } +pub trait TaskContextAccessor { + fn get_task_context(&self) -> Arc; +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/ffi/Cargo.toml b/datafusion/ffi/Cargo.toml index 3ac08180fb68..131b90102a18 100644 --- a/datafusion/ffi/Cargo.toml +++ b/datafusion/ffi/Cargo.toml @@ -46,9 +46,17 @@ arrow = { workspace = true, features = ["ffi"] } arrow-schema = { workspace = true } async-ffi = { version = "0.5.0", features = ["abi_stable"] } async-trait = { workspace = true } -datafusion = { workspace = true, default-features = false } +datafusion-catalog = { workspace = true } datafusion-common = { workspace = true } +datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-functions = { workspace = true, optional = true } +datafusion-functions-aggregate = { workspace = true, optional = true } datafusion-functions-aggregate-common = { workspace = true } +datafusion-functions-table = { workspace = true, optional = true } +datafusion-functions-window = { workspace = true, optional = true } +datafusion-physical-expr = { workspace = true } +datafusion-physical-plan = { workspace = true } datafusion-proto = { workspace = true } datafusion-proto-common = { workspace = true } futures = { workspace = true } @@ -58,8 +66,14 @@ semver = "1.0.27" tokio = { workspace = true } [dev-dependencies] +datafusion = { workspace = true, default-features = false, features = ["sql"] } doc-comment = { workspace = true } [features] -integration-tests = [] +integration-tests = [ + "dep:datafusion-functions", + "dep:datafusion-functions-aggregate", + "dep:datafusion-functions-table", + "dep:datafusion-functions-window", +] tarpaulin_include = [] # Exists only to 
prevent warnings on stable and still have accurate coverage diff --git a/datafusion/ffi/README.md b/datafusion/ffi/README.md index 72070984f931..9df415160864 100644 --- a/datafusion/ffi/README.md +++ b/datafusion/ffi/README.md @@ -101,6 +101,101 @@ In this crate we have a variety of structs which closely mimic the behavior of their internal counterparts. To see detailed notes about how to use them, see the example in `FFI_TableProvider`. +## Task Context Accessor + +Many of the FFI structs in this crate contain a `FFI_TaskContextAccessor`. The +purpose of this struct is to _weakly_ hold a reference to a method to +access the current `TaskContext`. The reason we need this accessor is because +we use the `datafusion-proto` crate to serialize and deserialize data across +the FFI boundary. In particular, we need to serialize and deserialize +functions using a `TaskContext`. + +This becomes difficult because we may need to register multiple user defined +functions, table or catalog providers, etc with a `Session`, and each of these +will need the `TaskContext` to perform the processing. For this reason we +cannot simply include the `TaskContext` at the time of registration because +it would not have knowledge of anything registered afterward. + +The `FFI_TaskContextAccessor` is built up from a trait that provides a method +to get the current `TaskContext`. It only holds a `Weak` reference to the +`TaskContextAccessor`, because otherwise we could create a circular dependency +at runtime. It is imperative that if you use these methods that your accessor +remains valid for the lifetime of the calls. The `TaskContextAccessor` is +implemented on `SessionContext` and it is easy to implement on any struct that +implements `Session`. + +## Library Marker ID + +When reviewing the code, many of the structs in this crate contain a call to +a `library_marker_id`. The purpose of this call is to determine if a library is +accessing _local_ code through the FFI structs.
Consider this example: you have +a `primary` program that exposes functions to create a schema provider. You +have a `secondary` library that exposes a function to create a catalog provider +and the `secondary` library uses the schema provider of the `primary` program. +From the point of view of the `secondary` library, the schema provider is +foreign code. + +Now when we register the `secondary` library with the `primary` program as a +catalog provider and we make calls to get a schema, the `secondary` library +will return a FFI wrapped schema provider back to the `primary` program. In +this case that schema provider is actually local code to the `primary` program +except that it is wrapped in the FFI code! + +We work around this by the `library_marker_id` calls. What this does is it +creates a global variable within each library and returns a `u64` address +of that library. This is guaranteed to be unique for every library that contains +FFI code. By comparing these `u64` addresses we can determine if a FFI struct +is local or foreign. + +In our example of the schema provider, if you were to make a call in your +primary program to get the schema provider, it would reach out to the foreign +catalog provider and send back a `FFI_SchemaProvider` object. By then +comparing the `library_marker_id` of this object to the `primary` program, we +determine it is local code. This means it is safe to access the underlying +private data. + +## Testing Coverage + +Since this library contains a large amount of `unsafe` code, it is important +to ensure proper test coverage. To generate a coverage report, you can use +[tarpaulin] as follows. It is necessary to use the `integration-tests` feature +in order to properly generate coverage. + +```shell +cargo tarpaulin --package datafusion-ffi --tests --features integration-tests --out Html +``` + +While it is not normally required to check Rust code for memory leaks, this +crate does manual memory management due to the FFI boundary. 
You can test for +leaks using the generated unit tests. How you run these checks differs depending +on your OS. + +### Linux + +On Linux, you can install `cargo-valgrind` + +```shell +cargo valgrind test --features integration-tests -p datafusion-ffi +``` + +### MacOS + +You can find the generated binaries for your unit tests by running `cargo test`. + +```shell +cargo test --features integration-tests -p datafusion-ffi --no-run +``` + +This should generate output that shows the path to the test binaries. Then +you can run commands such as the following. The specific paths of the tests +will vary. + +```shell +leaks --atExit -- target/debug/deps/datafusion_ffi-e77a2604a85a8afe +leaks --atExit -- target/debug/deps/ffi_integration-e91b7127a59b71a7 +# ... +``` + [apache datafusion]: https://datafusion.apache.org/ [api docs]: http://docs.rs/datafusion-ffi/latest [rust abi]: https://doc.rust-lang.org/reference/abi.html @@ -110,3 +205,4 @@ the example in `FFI_TableProvider`. [bindgen]: https://crates.io/crates/bindgen [`datafusion-python`]: https://datafusion.apache.org/python/ [datafusion-contrib]: https://github.com/datafusion-contrib +[tarpaulin]: https://crates.io/crates/cargo-tarpaulin diff --git a/datafusion/ffi/src/arrow_wrappers.rs b/datafusion/ffi/src/arrow_wrappers.rs index 7b3751dcae82..a00e2be75377 100644 --- a/datafusion/ffi/src/arrow_wrappers.rs +++ b/datafusion/ffi/src/arrow_wrappers.rs @@ -38,7 +38,7 @@ impl From for WrappedSchema { Ok(s) => s, Err(e) => { error!("Unable to convert DataFusion Schema to FFI_ArrowSchema in FFI_PlanProperties. {e}"); - FFI_ArrowSchema::empty() + FFI_ArrowSchema::try_from(Schema::empty()).unwrap() } }; @@ -94,3 +94,29 @@ impl TryFrom<&ArrayRef> for WrappedArray { Ok(WrappedArray { array, schema }) } } + +#[cfg(test)] +mod tests { + use crate::arrow_wrappers::WrappedSchema; + use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef}; + use std::sync::Arc; + + /// Test an unsupported field type. 
This is necessary only so we can get good unit test coverage + /// so that we can also verify memory is properly maintained since we are doing `unsafe` operations. + #[test] + fn test_unsupported_schema() -> Result<(), ArrowError> { + let field = Arc::new(Field::new("a", DataType::Int32, false)); + let schema = Arc::new(Schema::new(vec![Field::new( + "a", + DataType::ListView(field), + false, + )])); + + let wrapped_schema = WrappedSchema::from(schema); + + let schema: SchemaRef = wrapped_schema.into(); + assert!(schema.fields().is_empty()); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/catalog_provider.rs b/datafusion/ffi/src/catalog_provider.rs index 65dcab34f17d..0079d296f050 100644 --- a/datafusion/ffi/src/catalog_provider.rs +++ b/datafusion/ffi/src/catalog_provider.rs @@ -21,7 +21,7 @@ use abi_stable::{ std_types::{ROption, RResult, RString, RVec}, StableAbi, }; -use datafusion::catalog::{CatalogProvider, SchemaProvider}; +use datafusion_catalog::{CatalogProvider, SchemaProvider}; use tokio::runtime::Handle; use crate::{ @@ -29,7 +29,8 @@ use crate::{ schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider}, }; -use datafusion::error::Result; +use crate::session::task_ctx_accessor::FFI_TaskContextAccessor; +use datafusion_common::error::Result; /// A stable struct for sharing [`CatalogProvider`] across FFI boundaries. #[repr(C)] @@ -57,6 +58,10 @@ pub struct FFI_CatalogProvider { cascade: bool, ) -> RResult, RString>, + /// Accessor for TaskContext to be used during protobuf serialization + /// and deserialization. + task_ctx_accessor: FFI_TaskContextAccessor, + /// Used to create a clone on the provider of the execution plan. This should /// only need to be called by the receiver of the plan. pub clone: unsafe extern "C" fn(plan: &Self) -> Self, @@ -70,6 +75,10 @@ pub struct FFI_CatalogProvider { /// Internal data. This is only to be accessed by the provider of the plan. /// A [`ForeignCatalogProvider`] should never attempt to access this data. 
pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. + pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_CatalogProvider {} @@ -81,9 +90,9 @@ struct ProviderPrivateData { } impl FFI_CatalogProvider { - unsafe fn inner(&self) -> &Arc { + fn inner(&self) -> &Arc { let private_data = self.private_data as *const ProviderPrivateData; - &(*private_data).provider + unsafe { &(*private_data).provider } } unsafe fn runtime(&self) -> Option { @@ -105,7 +114,13 @@ unsafe extern "C" fn schema_fn_wrapper( ) -> ROption { let maybe_schema = provider.inner().schema(name.as_str()); maybe_schema - .map(|schema| FFI_SchemaProvider::new(schema, provider.runtime())) + .map(|schema| { + FFI_SchemaProvider::new( + schema, + provider.runtime(), + provider.task_ctx_accessor.clone(), + ) + }) .into() } @@ -115,13 +130,15 @@ unsafe extern "C" fn register_schema_fn_wrapper( schema: &FFI_SchemaProvider, ) -> RResult, RString> { let runtime = provider.runtime(); - let provider = provider.inner(); - let schema = Arc::new(ForeignSchemaProvider::from(schema)); + let schema = >::from(schema); - let returned_schema = - rresult_return!(provider.register_schema(name.as_str(), schema)) - .map(|schema| FFI_SchemaProvider::new(schema, runtime)) - .into(); + let returned_schema = rresult_return!(provider + .inner() + .register_schema(name.as_str(), schema)) + .map(|schema| { + FFI_SchemaProvider::new(schema, runtime, provider.task_ctx_accessor.clone()) + }) + .into(); RResult::ROk(returned_schema) } @@ -132,14 +149,19 @@ unsafe extern "C" fn deregister_schema_fn_wrapper( cascade: bool, ) -> RResult, RString> { let runtime = provider.runtime(); - let provider = provider.inner(); let maybe_schema = - rresult_return!(provider.deregister_schema(name.as_str(), cascade)); + rresult_return!(provider.inner().deregister_schema(name.as_str(), cascade)); RResult::ROk( maybe_schema - .map(|schema| 
FFI_SchemaProvider::new(schema, runtime)) + .map(|schema| { + FFI_SchemaProvider::new( + schema, + runtime, + provider.task_ctx_accessor.clone(), + ) + }) .into(), ) } @@ -165,10 +187,12 @@ unsafe extern "C" fn clone_fn_wrapper( schema: schema_fn_wrapper, register_schema: register_schema_fn_wrapper, deregister_schema: deregister_schema_fn_wrapper, + task_ctx_accessor: provider.task_ctx_accessor.clone(), clone: clone_fn_wrapper, release: release_fn_wrapper, version: super::version, private_data, + library_marker_id: crate::get_library_marker_id, } } @@ -183,6 +207,7 @@ impl FFI_CatalogProvider { pub fn new( provider: Arc, runtime: Option, + task_ctx_accessor: FFI_TaskContextAccessor, ) -> Self { let private_data = Box::new(ProviderPrivateData { provider, runtime }); @@ -191,10 +216,12 @@ impl FFI_CatalogProvider { schema: schema_fn_wrapper, register_schema: register_schema_fn_wrapper, deregister_schema: deregister_schema_fn_wrapper, + task_ctx_accessor, clone: clone_fn_wrapper, release: release_fn_wrapper, version: super::version, private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -209,9 +236,14 @@ pub struct ForeignCatalogProvider(FFI_CatalogProvider); unsafe impl Send for ForeignCatalogProvider {} unsafe impl Sync for ForeignCatalogProvider {} -impl From<&FFI_CatalogProvider> for ForeignCatalogProvider { +impl From<&FFI_CatalogProvider> for Arc { fn from(provider: &FFI_CatalogProvider) -> Self { - Self(provider.clone()) + if (provider.library_marker_id)() == crate::get_library_marker_id() { + return Arc::clone(provider.inner()); + } + + Arc::new(ForeignCatalogProvider(provider.clone())) + as Arc } } @@ -241,7 +273,8 @@ impl CatalogProvider for ForeignCatalogProvider { (self.0.schema)(&self.0, name.into()).into(); maybe_provider.map(|provider| { - Arc::new(ForeignSchemaProvider(provider)) as Arc + >::from(&provider) + as Arc }) } } @@ -254,14 +287,19 @@ impl CatalogProvider for 
ForeignCatalogProvider { unsafe { let schema = match schema.as_any().downcast_ref::() { Some(s) => &s.0, - None => &FFI_SchemaProvider::new(schema, None), + None => &FFI_SchemaProvider::new( + schema, + None, + self.0.task_ctx_accessor.clone(), + ), }; let returned_schema: Option = df_result!((self.0.register_schema)(&self.0, name.into(), schema))? .into(); - Ok(returned_schema - .map(|s| Arc::new(ForeignSchemaProvider(s)) as Arc)) + Ok(returned_schema.map(|s| { + >::from(&s) as Arc + })) } } @@ -275,17 +313,19 @@ impl CatalogProvider for ForeignCatalogProvider { df_result!((self.0.deregister_schema)(&self.0, name.into(), cascade))? .into(); - Ok(returned_schema - .map(|s| Arc::new(ForeignSchemaProvider(s)) as Arc)) + Ok(returned_schema.map(|s| { + >::from(&s) as Arc + })) } } } #[cfg(test)] mod tests { - use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider}; - use super::*; + use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider}; + use datafusion::prelude::SessionContext; + use datafusion_execution::TaskContextAccessor; #[test] fn test_round_trip_ffi_catalog_provider() { @@ -298,9 +338,14 @@ mod tests { .unwrap() .is_none()); - let ffi_catalog = FFI_CatalogProvider::new(catalog, None); + let ctx = Arc::new(SessionContext::new()); + let task_ctx_accessor = Arc::clone(&ctx) as Arc; + + let mut ffi_catalog = + FFI_CatalogProvider::new(catalog, None, task_ctx_accessor.into()); + ffi_catalog.library_marker_id = crate::mock_foreign_marker_id; - let foreign_catalog: ForeignCatalogProvider = (&ffi_catalog).into(); + let foreign_catalog: Arc = (&ffi_catalog).into(); let prior_schema_names = foreign_catalog.schema_names(); assert_eq!(prior_schema_names.len(), 1); diff --git a/datafusion/ffi/src/execution_plan.rs b/datafusion/ffi/src/execution_plan.rs index 70c957d8c373..739ee9266c02 100644 --- a/datafusion/ffi/src/execution_plan.rs +++ b/datafusion/ffi/src/execution_plan.rs @@ -17,23 +17,21 @@ use std::{ffi::c_void, pin::Pin, sync::Arc}; 
+use crate::session::task_ctx_accessor::FFI_TaskContextAccessor; +use crate::{ + df_result, plan_properties::FFI_PlanProperties, + record_batch_stream::FFI_RecordBatchStream, rresult, rresult_return, +}; use abi_stable::{ std_types::{RResult, RString, RVec}, StableAbi, }; -use datafusion::{ - error::DataFusionError, - execution::{SendableRecordBatchStream, TaskContext}, - physical_plan::{DisplayAs, ExecutionPlan, PlanProperties}, -}; -use datafusion::{error::Result, physical_plan::DisplayFormatType}; +use datafusion_common::error::{DataFusionError, Result}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_plan::DisplayFormatType; +use datafusion_physical_plan::{DisplayAs, ExecutionPlan, PlanProperties}; use tokio::runtime::Handle; -use crate::{ - df_result, plan_properties::FFI_PlanProperties, - record_batch_stream::FFI_RecordBatchStream, rresult, -}; - /// A stable struct for sharing a [`ExecutionPlan`] across FFI boundaries. #[repr(C)] #[derive(Debug, StableAbi)] @@ -55,6 +53,10 @@ pub struct FFI_ExecutionPlan { partition: usize, ) -> RResult, + /// Accessor for TaskContext to be used during protobuf serialization + /// and deserialization. + pub task_ctx_accessor: FFI_TaskContextAccessor, + /// Used to create a clone on the provider of the execution plan. This should /// only need to be called by the receiver of the plan. pub clone: unsafe extern "C" fn(plan: &Self) -> Self, @@ -65,6 +67,10 @@ pub struct FFI_ExecutionPlan { /// Internal data. This is only to be accessed by the provider of the plan. /// A [`ForeignExecutionPlan`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. 
+ pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_ExecutionPlan {} @@ -72,32 +78,35 @@ unsafe impl Sync for FFI_ExecutionPlan {} pub struct ExecutionPlanPrivateData { pub plan: Arc, - pub context: Arc, pub runtime: Option, } +impl FFI_ExecutionPlan { + fn inner(&self) -> &Arc { + let private_data = self.private_data as *const ExecutionPlanPrivateData; + unsafe { &(*private_data).plan } + } +} + unsafe extern "C" fn properties_fn_wrapper( plan: &FFI_ExecutionPlan, ) -> FFI_PlanProperties { - let private_data = plan.private_data as *const ExecutionPlanPrivateData; - let plan = &(*private_data).plan; - - plan.properties().into() + FFI_PlanProperties::new(plan.inner().properties(), plan.task_ctx_accessor.clone()) } unsafe extern "C" fn children_fn_wrapper( plan: &FFI_ExecutionPlan, ) -> RVec { + let ctx = &plan.task_ctx_accessor; let private_data = plan.private_data as *const ExecutionPlanPrivateData; let plan = &(*private_data).plan; - let ctx = &(*private_data).context; let runtime = &(*private_data).runtime; let children: Vec<_> = plan .children() .into_iter() .map(|child| { - FFI_ExecutionPlan::new(Arc::clone(child), Arc::clone(ctx), runtime.clone()) + FFI_ExecutionPlan::new(Arc::clone(child), ctx.clone(), runtime.clone()) }) .collect(); @@ -108,21 +117,18 @@ unsafe extern "C" fn execute_fn_wrapper( plan: &FFI_ExecutionPlan, partition: usize, ) -> RResult { + let ctx = rresult_return!(>::try_from(&plan.task_ctx_accessor)); let private_data = plan.private_data as *const ExecutionPlanPrivateData; let plan = &(*private_data).plan; - let ctx = &(*private_data).context; let runtime = (*private_data).runtime.clone(); rresult!(plan - .execute(partition, Arc::clone(ctx)) + .execute(partition, ctx) .map(|rbs| FFI_RecordBatchStream::new(rbs, runtime))) } unsafe extern "C" fn name_fn_wrapper(plan: &FFI_ExecutionPlan) -> RString { - let private_data = plan.private_data as *const ExecutionPlanPrivateData; - let plan = &(*private_data).plan; - - 
plan.name().into() + plan.inner().name().into() } unsafe extern "C" fn release_fn_wrapper(plan: &mut FFI_ExecutionPlan) { @@ -131,14 +137,11 @@ unsafe extern "C" fn release_fn_wrapper(plan: &mut FFI_ExecutionPlan) { } unsafe extern "C" fn clone_fn_wrapper(plan: &FFI_ExecutionPlan) -> FFI_ExecutionPlan { + let ctx = plan.task_ctx_accessor.clone(); let private_data = plan.private_data as *const ExecutionPlanPrivateData; let plan_data = &(*private_data); - FFI_ExecutionPlan::new( - Arc::clone(&plan_data.plan), - Arc::clone(&plan_data.context), - plan_data.runtime.clone(), - ) + FFI_ExecutionPlan::new(Arc::clone(&plan_data.plan), ctx, plan_data.runtime.clone()) } impl Clone for FFI_ExecutionPlan { @@ -151,23 +154,21 @@ impl FFI_ExecutionPlan { /// This function is called on the provider's side. pub fn new( plan: Arc, - context: Arc, + context: FFI_TaskContextAccessor, runtime: Option, ) -> Self { - let private_data = Box::new(ExecutionPlanPrivateData { - plan, - context, - runtime, - }); + let private_data = Box::new(ExecutionPlanPrivateData { plan, runtime }); Self { properties: properties_fn_wrapper, children: children_fn_wrapper, name: name_fn_wrapper, execute: execute_fn_wrapper, + task_ctx_accessor: context, clone: clone_fn_wrapper, release: release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -218,10 +219,14 @@ impl DisplayAs for ForeignExecutionPlan { } } -impl TryFrom<&FFI_ExecutionPlan> for ForeignExecutionPlan { +impl TryFrom<&FFI_ExecutionPlan> for Arc { type Error = DataFusionError; fn try_from(plan: &FFI_ExecutionPlan) -> Result { + if (plan.library_marker_id)() == crate::get_library_marker_id() { + return Ok(Arc::clone(plan.inner())); + } + unsafe { let name = (plan.name)(plan).into(); @@ -230,16 +235,17 @@ impl TryFrom<&FFI_ExecutionPlan> for ForeignExecutionPlan { let children_rvec = (plan.children)(plan); let children = children_rvec .iter() - 
.map(ForeignExecutionPlan::try_from) - .map(|child| child.map(|c| Arc::new(c) as Arc)) + .map(>::try_from) .collect::>>()?; - Ok(Self { + let plan = ForeignExecutionPlan { name, plan: plan.clone(), properties, children, - }) + }; + + Ok(Arc::new(plan)) } } } @@ -290,6 +296,7 @@ impl ExecutionPlan for ForeignExecutionPlan { #[cfg(test)] mod tests { + use super::*; use arrow::datatypes::{DataType, Field, Schema}; use datafusion::{ physical_plan::{ @@ -298,8 +305,7 @@ mod tests { }, prelude::SessionContext, }; - - use super::*; + use datafusion_execution::TaskContextAccessor; #[derive(Debug)] pub struct EmptyExec { @@ -375,19 +381,22 @@ mod tests { fn test_round_trip_ffi_execution_plan() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); - let ctx = SessionContext::new(); + let ctx = Arc::new(SessionContext::new()) as Arc; let original_plan = Arc::new(EmptyExec::new(schema)); let original_name = original_plan.name().to_string(); - let local_plan = FFI_ExecutionPlan::new(original_plan, ctx.task_ctx(), None); + let mut local_plan = FFI_ExecutionPlan::new(original_plan, ctx.into(), None); + + // Force round trip to go through foreign provider + local_plan.library_marker_id = crate::mock_foreign_marker_id; - let foreign_plan: ForeignExecutionPlan = (&local_plan).try_into()?; + let foreign_plan: Arc = (&local_plan).try_into()?; assert!(original_name == foreign_plan.name()); let display = datafusion::physical_plan::display::DisplayableExecutionPlan::new( - &foreign_plan, + foreign_plan.as_ref(), ); let buf = display.one_line().to_string(); @@ -403,16 +412,18 @@ mod tests { fn test_ffi_execution_plan_children() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); - let ctx = SessionContext::new(); + let ctx = Arc::new(SessionContext::new()) as Arc; + let ctx = FFI_TaskContextAccessor::from(ctx); // Version 1: Adding child to the foreign plan let child_plan = 
Arc::new(EmptyExec::new(Arc::clone(&schema))); - let child_local = FFI_ExecutionPlan::new(child_plan, ctx.task_ctx(), None); - let child_foreign = Arc::new(ForeignExecutionPlan::try_from(&child_local)?); + let mut child_local = FFI_ExecutionPlan::new(child_plan, ctx.clone(), None); + child_local.library_marker_id = crate::mock_foreign_marker_id; + let child_foreign = >::try_from(&child_local)?; let parent_plan = Arc::new(EmptyExec::new(Arc::clone(&schema))); - let parent_local = FFI_ExecutionPlan::new(parent_plan, ctx.task_ctx(), None); - let parent_foreign = Arc::new(ForeignExecutionPlan::try_from(&parent_local)?); + let parent_local = FFI_ExecutionPlan::new(parent_plan, ctx.clone(), None); + let parent_foreign = >::try_from(&parent_local)?; assert_eq!(parent_foreign.children().len(), 0); assert_eq!(child_foreign.children().len(), 0); @@ -422,13 +433,13 @@ mod tests { // Version 2: Adding child to the local plan let child_plan = Arc::new(EmptyExec::new(Arc::clone(&schema))); - let child_local = FFI_ExecutionPlan::new(child_plan, ctx.task_ctx(), None); - let child_foreign = Arc::new(ForeignExecutionPlan::try_from(&child_local)?); + let child_local = FFI_ExecutionPlan::new(child_plan, ctx.clone(), None); + let child_foreign = >::try_from(&child_local)?; let parent_plan = Arc::new(EmptyExec::new(Arc::clone(&schema))); let parent_plan = parent_plan.with_new_children(vec![child_foreign])?; - let parent_local = FFI_ExecutionPlan::new(parent_plan, ctx.task_ctx(), None); - let parent_foreign = Arc::new(ForeignExecutionPlan::try_from(&parent_local)?); + let parent_local = FFI_ExecutionPlan::new(parent_plan, ctx, None); + let parent_foreign = >::try_from(&parent_local)?; assert_eq!(parent_foreign.children().len(), 1); diff --git a/datafusion/ffi/src/insert_op.rs b/datafusion/ffi/src/insert_op.rs index 8e8693076cc0..acba1aa2ae08 100644 --- a/datafusion/ffi/src/insert_op.rs +++ b/datafusion/ffi/src/insert_op.rs @@ -16,7 +16,7 @@ // under the License. 
use abi_stable::StableAbi; -use datafusion::logical_expr::logical_plan::dml::InsertOp; +use datafusion_expr::logical_plan::dml::InsertOp; /// FFI safe version of [`InsertOp`]. #[repr(C)] diff --git a/datafusion/ffi/src/lib.rs b/datafusion/ffi/src/lib.rs index 0c2340e8ce7b..24e5fa632377 100644 --- a/datafusion/ffi/src/lib.rs +++ b/datafusion/ffi/src/lib.rs @@ -31,7 +31,7 @@ pub mod insert_op; pub mod plan_properties; pub mod record_batch_stream; pub mod schema_provider; -pub mod session_config; +pub mod session; pub mod table_provider; pub mod table_source; pub mod udaf; @@ -54,5 +54,31 @@ pub extern "C" fn version() -> u64 { version.major } +static LIBRARY_MARKER: u8 = 0; + +/// This utility is used to determine if two FFI structs are within +/// the same library. It is possible that the interplay between +/// foreign and local functions calls create one FFI struct that +/// references another. It is helpful to determine if a foreign +/// struct is truly foreign or in the same library. If we are in the +/// same library, then we can access the underlying types directly. +/// +/// This function works by checking the address of the library +/// marker. Each library that implements the FFI code will have +/// a different address for the marker. By checking the marker +/// address we can determine if a struct is truly Foreign or is +/// actually within the same originating library. +pub extern "C" fn get_library_marker_id() -> u64 { + &LIBRARY_MARKER as *const u8 as u64 +} + +/// For unit testing in this crate we need to trick the providers +/// into thinking we have a foreign call. We do this by overwriting +/// their `library_marker_id` function to return a different value. 
+#[cfg(test)] +pub(crate) extern "C" fn mock_foreign_marker_id() -> u64 { + get_library_marker_id() + 1 +} + #[cfg(doctest)] doc_comment::doctest!("../README.md", readme_example_test); diff --git a/datafusion/ffi/src/plan_properties.rs b/datafusion/ffi/src/plan_properties.rs index 48c2698a58c7..52436dea35bf 100644 --- a/datafusion/ffi/src/plan_properties.rs +++ b/datafusion/ffi/src/plan_properties.rs @@ -17,35 +17,29 @@ use std::{ffi::c_void, sync::Arc}; -use abi_stable::{ - std_types::{ - RResult::{self, ROk}, - RString, RVec, - }, - StableAbi, -}; +use crate::arrow_wrappers::WrappedSchema; +use crate::session::task_ctx_accessor::FFI_TaskContextAccessor; +use crate::{df_result, rresult_return}; +use abi_stable::std_types::{RResult, RString, RVec}; +use abi_stable::StableAbi; use arrow::datatypes::SchemaRef; -use datafusion::{ - error::{DataFusionError, Result}, - physical_expr::EquivalenceProperties, - physical_plan::{ - execution_plan::{Boundedness, EmissionType}, - PlanProperties, - }, - prelude::SessionContext, +use datafusion_common::error::{DataFusionError, Result}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::EquivalenceProperties; +use datafusion_physical_plan::{ + execution_plan::{Boundedness, EmissionType}, + PlanProperties, +}; +use datafusion_proto::physical_plan::from_proto::{ + parse_physical_sort_exprs, parse_protobuf_partitioning, }; -use datafusion_proto::{ - physical_plan::{ - from_proto::{parse_physical_sort_exprs, parse_protobuf_partitioning}, - to_proto::{serialize_partitioning, serialize_physical_sort_exprs}, - DefaultPhysicalExtensionCodec, - }, - protobuf::{Partitioning, PhysicalSortExprNodeCollection}, +use datafusion_proto::physical_plan::to_proto::{ + serialize_partitioning, serialize_physical_sort_exprs, }; +use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; +use datafusion_proto::protobuf::{Partitioning, PhysicalSortExprNodeCollection}; use prost::Message; -use 
crate::{arrow_wrappers::WrappedSchema, df_result, rresult_return}; - /// A stable struct for sharing [`PlanProperties`] across FFI boundaries. #[repr(C)] #[derive(Debug, StableAbi)] @@ -62,25 +56,39 @@ pub struct FFI_PlanProperties { /// Indicate boundedness of the plan and its memory requirements. pub boundedness: unsafe extern "C" fn(plan: &Self) -> FFI_Boundedness, - /// The output ordering is a [`PhysicalSortExprNodeCollection`] protobuf message - /// serialized into bytes to pass across the FFI boundary. + /// The output ordering of the plan. pub output_ordering: unsafe extern "C" fn(plan: &Self) -> RResult, RString>, /// Return the schema of the plan. pub schema: unsafe extern "C" fn(plan: &Self) -> WrappedSchema, + /// Accessor for TaskContext to be used during protobuf serialization + /// and deserialization. + pub task_ctx_accessor: FFI_TaskContextAccessor, + /// Release the memory of the private data when it is no longer being used. pub release: unsafe extern "C" fn(arg: &mut Self), /// Internal data. This is only to be accessed by the provider of the plan. /// The foreign library should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. 
+ pub library_marker_id: extern "C" fn() -> u64, } struct PlanPropertiesPrivateData { props: PlanProperties, } +impl FFI_PlanProperties { + fn inner(&self) -> &PlanProperties { + let private_data = self.private_data as *const PlanPropertiesPrivateData; + unsafe { &(*private_data).props } + } +} + unsafe extern "C" fn output_partitioning_fn_wrapper( properties: &FFI_PlanProperties, ) -> RResult, RString> { @@ -92,23 +100,19 @@ unsafe extern "C" fn output_partitioning_fn_wrapper( rresult_return!(serialize_partitioning(props.output_partitioning(), &codec)); let output_partitioning = partitioning_data.encode_to_vec(); - ROk(output_partitioning.into()) + RResult::ROk(output_partitioning.into()) } unsafe extern "C" fn emission_type_fn_wrapper( properties: &FFI_PlanProperties, ) -> FFI_EmissionType { - let private_data = properties.private_data as *const PlanPropertiesPrivateData; - let props = &(*private_data).props; - props.emission_type.into() + (&properties.inner().emission_type).into() } unsafe extern "C" fn boundedness_fn_wrapper( properties: &FFI_PlanProperties, ) -> FFI_Boundedness { - let private_data = properties.private_data as *const PlanPropertiesPrivateData; - let props = &(*private_data).props; - props.boundedness.into() + (&properties.inner().boundedness).into() } unsafe extern "C" fn output_ordering_fn_wrapper( @@ -131,14 +135,11 @@ unsafe extern "C" fn output_ordering_fn_wrapper( } None => Vec::default(), }; - ROk(output_ordering.into()) + RResult::ROk(output_ordering.into()) } unsafe extern "C" fn schema_fn_wrapper(properties: &FFI_PlanProperties) -> WrappedSchema { - let private_data = properties.private_data as *const PlanPropertiesPrivateData; - let props = &(*private_data).props; - - let schema: SchemaRef = Arc::clone(props.eq_properties.schema()); + let schema: SchemaRef = Arc::clone(properties.inner().eq_properties.schema()); schema.into() } @@ -154,8 +155,11 @@ impl Drop for FFI_PlanProperties { } } -impl From<&PlanProperties> for 
FFI_PlanProperties { - fn from(props: &PlanProperties) -> Self { +impl FFI_PlanProperties { + pub fn new( + props: &PlanProperties, + task_ctx_accessor: FFI_TaskContextAccessor, + ) -> Self { let private_data = Box::new(PlanPropertiesPrivateData { props: props.clone(), }); @@ -166,8 +170,10 @@ impl From<&PlanProperties> for FFI_PlanProperties { boundedness: boundedness_fn_wrapper, output_ordering: output_ordering_fn_wrapper, schema: schema_fn_wrapper, + task_ctx_accessor, release: release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -176,12 +182,14 @@ impl TryFrom for PlanProperties { type Error = DataFusionError; fn try_from(ffi_props: FFI_PlanProperties) -> Result { + if (ffi_props.library_marker_id)() == crate::get_library_marker_id() { + return Ok(ffi_props.inner().clone()); + } + let ffi_schema = unsafe { (ffi_props.schema)(&ffi_props) }; let schema = (&ffi_schema.0).try_into()?; - // TODO Extend FFI to get the registry and codex - let default_ctx = SessionContext::new(); - let task_context = default_ctx.task_ctx(); + let task_ctx: Arc = (&ffi_props.task_ctx_accessor).try_into()?; let codex = DefaultPhysicalExtensionCodec {}; let ffi_orderings = unsafe { (ffi_props.output_ordering)(&ffi_props) }; @@ -191,7 +199,7 @@ impl TryFrom for PlanProperties { .map_err(|e| DataFusionError::External(Box::new(e)))?; let sort_exprs = parse_physical_sort_exprs( &proto_output_ordering.physical_sort_expr_nodes, - &task_context, + &task_ctx, &schema, &codex, )?; @@ -203,7 +211,7 @@ impl TryFrom for PlanProperties { .map_err(|e| DataFusionError::External(Box::new(e)))?; let partitioning = parse_protobuf_partitioning( Some(&proto_output_partitioning), - &task_context, + &task_ctx, &schema, &codex, )? 
@@ -242,14 +250,14 @@ pub enum FFI_Boundedness { Unbounded { requires_infinite_memory: bool }, } -impl From for FFI_Boundedness { - fn from(value: Boundedness) -> Self { +impl From<&Boundedness> for FFI_Boundedness { + fn from(value: &Boundedness) -> Self { match value { Boundedness::Bounded => FFI_Boundedness::Bounded, Boundedness::Unbounded { requires_infinite_memory, } => FFI_Boundedness::Unbounded { - requires_infinite_memory, + requires_infinite_memory: *requires_infinite_memory, }, } } @@ -278,8 +286,8 @@ pub enum FFI_EmissionType { Both, } -impl From for FFI_EmissionType { - fn from(value: EmissionType) -> Self { +impl From<&EmissionType> for FFI_EmissionType { + fn from(value: &EmissionType) -> Self { match value { EmissionType::Incremental => FFI_EmissionType::Incremental, EmissionType::Final => FFI_EmissionType::Final, @@ -300,9 +308,10 @@ impl From for EmissionType { #[cfg(test)] mod tests { - use datafusion::{physical_expr::PhysicalSortExpr, physical_plan::Partitioning}; - use super::*; + use datafusion::prelude::SessionContext; + use datafusion::{physical_expr::PhysicalSortExpr, physical_plan::Partitioning}; + use datafusion_execution::TaskContextAccessor; #[test] fn test_round_trip_ffi_plan_properties() -> Result<()> { @@ -320,8 +329,10 @@ mod tests { EmissionType::Incremental, Boundedness::Bounded, ); + let ctx = Arc::new(SessionContext::default()) as Arc; - let local_props_ptr = FFI_PlanProperties::from(&original_props); + let mut local_props_ptr = FFI_PlanProperties::new(&original_props, (&ctx).into()); + local_props_ptr.library_marker_id = crate::mock_foreign_marker_id; let foreign_props: PlanProperties = local_props_ptr.try_into()?; diff --git a/datafusion/ffi/src/record_batch_stream.rs b/datafusion/ffi/src/record_batch_stream.rs index 1739235d1703..608e731f2115 100644 --- a/datafusion/ffi/src/record_batch_stream.rs +++ b/datafusion/ffi/src/record_batch_stream.rs @@ -27,12 +27,9 @@ use arrow::{ ffi::{from_ffi, to_ffi}, }; use 
async_ffi::{ContextExt, FfiContext, FfiPoll}; -use datafusion::error::Result; -use datafusion::{ - error::DataFusionError, - execution::{RecordBatchStream, SendableRecordBatchStream}, -}; +use datafusion_common::error::{DataFusionError, Result}; use datafusion_common::{exec_datafusion_err, exec_err}; +use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; use futures::{Stream, TryStreamExt}; use tokio::runtime::Handle; @@ -107,7 +104,9 @@ unsafe extern "C" fn release_fn_wrapper(provider: &mut FFI_RecordBatchStream) { drop(private_data); } -fn record_batch_to_wrapped_array( +// TODO(tsaucer) switch to Result +// and put the rresult handling to the caller +pub fn record_batch_to_wrapped_array( record_batch: RecordBatch, ) -> RResult { let struct_array = StructArray::from(record_batch); @@ -157,7 +156,7 @@ impl RecordBatchStream for FFI_RecordBatchStream { } } -fn wrapped_array_to_record_batch(array: WrappedArray) -> Result { +pub fn wrapped_array_to_record_batch(array: WrappedArray) -> Result { let array_data = unsafe { from_ffi(array.array, &array.schema.0).map_err(DataFusionError::from)? 
}; let array = make_array(array_data); diff --git a/datafusion/ffi/src/schema_provider.rs b/datafusion/ffi/src/schema_provider.rs index b5970d5881d6..d05acd700f10 100644 --- a/datafusion/ffi/src/schema_provider.rs +++ b/datafusion/ffi/src/schema_provider.rs @@ -23,19 +23,16 @@ use abi_stable::{ }; use async_ffi::{FfiFuture, FutureExt}; use async_trait::async_trait; -use datafusion::{ - catalog::{SchemaProvider, TableProvider}, - error::DataFusionError, -}; +use datafusion_catalog::{SchemaProvider, TableProvider}; +use datafusion_common::error::{DataFusionError, Result}; use tokio::runtime::Handle; +use crate::session::task_ctx_accessor::FFI_TaskContextAccessor; use crate::{ df_result, rresult_return, table_provider::{FFI_TableProvider, ForeignTableProvider}, }; -use datafusion::error::Result; - /// A stable struct for sharing [`SchemaProvider`] across FFI boundaries. #[repr(C)] #[derive(Debug, StableAbi)] @@ -67,6 +64,10 @@ pub struct FFI_SchemaProvider { pub table_exist: unsafe extern "C" fn(provider: &Self, name: RString) -> bool, + /// Accessor for TaskContext to be used during protobuf serialization + /// and deserialization. + pub task_ctx_accessor: FFI_TaskContextAccessor, + /// Used to create a clone on the provider of the execution plan. This should /// only need to be called by the receiver of the plan. pub clone: unsafe extern "C" fn(plan: &Self) -> Self, @@ -80,6 +81,10 @@ pub struct FFI_SchemaProvider { /// Internal data. This is only to be accessed by the provider of the plan. /// A [`ForeignSchemaProvider`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. 
+ pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_SchemaProvider {} @@ -91,9 +96,9 @@ struct ProviderPrivateData { } impl FFI_SchemaProvider { - unsafe fn inner(&self) -> &Arc { + fn inner(&self) -> &Arc { let private_data = self.private_data as *const ProviderPrivateData; - &(*private_data).provider + unsafe { &(*private_data).provider } } unsafe fn runtime(&self) -> Option { @@ -115,12 +120,13 @@ unsafe extern "C" fn table_fn_wrapper( provider: &FFI_SchemaProvider, name: RString, ) -> FfiFuture, RString>> { + let task_ctx_accessor = provider.task_ctx_accessor.clone(); let runtime = provider.runtime(); let provider = Arc::clone(provider.inner()); async move { let table = rresult_return!(provider.table(name.as_str()).await) - .map(|t| FFI_TableProvider::new(t, true, runtime)) + .map(|t| FFI_TableProvider::new(t, true, runtime, task_ctx_accessor)) .into(); RResult::ROk(table) @@ -134,12 +140,13 @@ unsafe extern "C" fn register_table_fn_wrapper( table: FFI_TableProvider, ) -> RResult, RString> { let runtime = provider.runtime(); + let task_ctx_accessor = provider.task_ctx_accessor.clone(); let provider = provider.inner(); let table = Arc::new(ForeignTableProvider(table)); let returned_table = rresult_return!(provider.register_table(name.into(), table)) - .map(|t| FFI_TableProvider::new(t, true, runtime)); + .map(|t| FFI_TableProvider::new(t, true, runtime, task_ctx_accessor)); RResult::ROk(returned_table.into()) } @@ -148,11 +155,12 @@ unsafe extern "C" fn deregister_table_fn_wrapper( provider: &FFI_SchemaProvider, name: RString, ) -> RResult, RString> { + let task_ctx_accessor = provider.task_ctx_accessor.clone(); let runtime = provider.runtime(); let provider = provider.inner(); let returned_table = rresult_return!(provider.deregister_table(name.as_str())) - .map(|t| FFI_TableProvider::new(t, true, runtime)); + .map(|t| FFI_TableProvider::new(t, true, runtime, task_ctx_accessor)); RResult::ROk(returned_table.into()) } @@ -191,6 +199,8 @@ 
unsafe extern "C" fn clone_fn_wrapper( register_table: register_table_fn_wrapper, deregister_table: deregister_table_fn_wrapper, table_exist: table_exist_fn_wrapper, + task_ctx_accessor: provider.task_ctx_accessor.clone(), + library_marker_id: crate::get_library_marker_id, } } @@ -205,6 +215,7 @@ impl FFI_SchemaProvider { pub fn new( provider: Arc, runtime: Option, + task_ctx_accessor: FFI_TaskContextAccessor, ) -> Self { let owner_name = provider.owner_name().map(|s| s.into()).into(); let private_data = Box::new(ProviderPrivateData { provider, runtime }); @@ -220,6 +231,8 @@ impl FFI_SchemaProvider { register_table: register_table_fn_wrapper, deregister_table: deregister_table_fn_wrapper, table_exist: table_exist_fn_wrapper, + task_ctx_accessor, + library_marker_id: crate::get_library_marker_id, } } } @@ -234,9 +247,13 @@ pub struct ForeignSchemaProvider(pub FFI_SchemaProvider); unsafe impl Send for ForeignSchemaProvider {} unsafe impl Sync for ForeignSchemaProvider {} -impl From<&FFI_SchemaProvider> for ForeignSchemaProvider { +impl From<&FFI_SchemaProvider> for Arc { fn from(provider: &FFI_SchemaProvider) -> Self { - Self(provider.clone()) + if (provider.library_marker_id)() == crate::get_library_marker_id() { + return Arc::clone(provider.inner()); + } + + Arc::new(ForeignSchemaProvider(provider.clone())) } } @@ -274,9 +291,7 @@ impl SchemaProvider for ForeignSchemaProvider { let table: Option = df_result!((self.0.table)(&self.0, name.into()).await)?.into(); - let table = table.as_ref().map(|t| { - Arc::new(ForeignTableProvider::from(t)) as Arc - }); + let table = table.as_ref().map(>::from); Ok(table) } @@ -290,7 +305,12 @@ impl SchemaProvider for ForeignSchemaProvider { unsafe { let ffi_table = match table.as_any().downcast_ref::() { Some(t) => t.0.clone(), - None => FFI_TableProvider::new(table, true, None), + None => FFI_TableProvider::new( + table, + true, + None, + self.0.task_ctx_accessor.clone(), + ), }; let returned_provider: Option = @@ -319,10 +339,11 
@@ impl SchemaProvider for ForeignSchemaProvider { #[cfg(test)] mod tests { + use super::*; use arrow::datatypes::Schema; + use datafusion::prelude::SessionContext; use datafusion::{catalog::MemorySchemaProvider, datasource::empty::EmptyTable}; - - use super::*; + use datafusion_execution::TaskContextAccessor; fn empty_table() -> Arc { Arc::new(EmptyTable::new(Arc::new(Schema::empty()))) @@ -336,10 +357,14 @@ mod tests { .register_table("prior_table".to_string(), empty_table()) .unwrap() .is_none()); + let ctx = Arc::new(SessionContext::new()); + let task_ctx_accessor = Arc::clone(&ctx) as Arc; - let ffi_schema_provider = FFI_SchemaProvider::new(schema_provider, None); + let mut ffi_schema_provider = + FFI_SchemaProvider::new(schema_provider, None, task_ctx_accessor.into()); + ffi_schema_provider.library_marker_id = crate::mock_foreign_marker_id; - let foreign_schema_provider: ForeignSchemaProvider = + let foreign_schema_provider: Arc = (&ffi_schema_provider).into(); let prior_table_names = foreign_schema_provider.table_names(); diff --git a/datafusion/ffi/src/session_config.rs b/datafusion/ffi/src/session/config.rs similarity index 75% rename from datafusion/ffi/src/session_config.rs rename to datafusion/ffi/src/session/config.rs index a07b66c60196..e22b09cc8fad 100644 --- a/datafusion/ffi/src/session_config.rs +++ b/datafusion/ffi/src/session/config.rs @@ -19,13 +19,9 @@ use abi_stable::{ std_types::{RHashMap, RString}, StableAbi, }; -use datafusion::{config::ConfigOptions, error::Result}; -use datafusion::{error::DataFusionError, prelude::SessionConfig}; -use std::sync::Arc; -use std::{ - collections::HashMap, - ffi::{c_char, c_void, CString}, -}; +use datafusion_common::error::{DataFusionError, Result}; +use datafusion_execution::config::SessionConfig; +use std::{collections::HashMap, ffi::c_void}; /// A stable struct for sharing [`SessionConfig`] across FFI boundaries. 
/// Instead of attempting to expose the entire SessionConfig interface, we @@ -54,18 +50,27 @@ pub struct FFI_SessionConfig { pub release: unsafe extern "C" fn(arg: &mut Self), /// Internal data. This is only to be accessed by the provider of the plan. - /// A [`ForeignSessionConfig`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. + pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_SessionConfig {} unsafe impl Sync for FFI_SessionConfig {} +impl FFI_SessionConfig { + fn inner(&self) -> &SessionConfig { + let private_data = self.private_data as *mut SessionConfigPrivateData; + unsafe { &(*private_data).config } + } +} + unsafe extern "C" fn config_options_fn_wrapper( config: &FFI_SessionConfig, ) -> RHashMap { - let private_data = config.private_data as *mut SessionConfigPrivateData; - let config_options = &(*private_data).config; + let config_options = config.inner().options(); let mut options = RHashMap::default(); for config_entry in config_options.entries() { @@ -85,7 +90,7 @@ unsafe extern "C" fn release_fn_wrapper(config: &mut FFI_SessionConfig) { unsafe extern "C" fn clone_fn_wrapper(config: &FFI_SessionConfig) -> FFI_SessionConfig { let old_private_data = config.private_data as *mut SessionConfigPrivateData; - let old_config = Arc::clone(&(*old_private_data).config); + let old_config = (*old_private_data).config.clone(); let private_data = Box::new(SessionConfigPrivateData { config: old_config }); @@ -94,31 +99,18 @@ unsafe extern "C" fn clone_fn_wrapper(config: &FFI_SessionConfig) -> FFI_Session private_data: Box::into_raw(private_data) as *mut c_void, clone: clone_fn_wrapper, release: release_fn_wrapper, + library_marker_id: crate::get_library_marker_id, } } struct SessionConfigPrivateData { - pub config: Arc, + pub config: SessionConfig, } impl From<&SessionConfig> for FFI_SessionConfig { fn from(session: 
&SessionConfig) -> Self { - let mut config_keys = Vec::new(); - let mut config_values = Vec::new(); - for config_entry in session.options().entries() { - if let Some(value) = config_entry.value { - let key_cstr = CString::new(config_entry.key).unwrap_or_default(); - let key_ptr = key_cstr.into_raw() as *const c_char; - config_keys.push(key_ptr); - - config_values - .push(CString::new(value).unwrap_or_default().into_raw() - as *const c_char); - } - } - let private_data = Box::new(SessionConfigPrivateData { - config: Arc::clone(session.options()), + config: session.clone(), }); Self { @@ -126,6 +118,7 @@ impl From<&SessionConfig> for FFI_SessionConfig { private_data: Box::into_raw(private_data) as *mut c_void, clone: clone_fn_wrapper, release: release_fn_wrapper, + library_marker_id: crate::get_library_marker_id, } } } @@ -145,13 +138,15 @@ impl Drop for FFI_SessionConfig { /// A wrapper struct for accessing [`SessionConfig`] across a FFI boundary. /// The [`SessionConfig`] will be generated from a hash map of the config /// options in the provider and will be reconstructed on this side of the -/// interface.s -pub struct ForeignSessionConfig(pub SessionConfig); - -impl TryFrom<&FFI_SessionConfig> for ForeignSessionConfig { +/// interface. 
+impl TryFrom<&FFI_SessionConfig> for SessionConfig { type Error = DataFusionError; fn try_from(config: &FFI_SessionConfig) -> Result { + if (config.library_marker_id)() == crate::get_library_marker_id() { + return Ok(config.inner().clone()); + } + let config_options = unsafe { (config.config_options)(config) }; let mut options_map = HashMap::new(); @@ -159,7 +154,7 @@ impl TryFrom<&FFI_SessionConfig> for ForeignSessionConfig { options_map.insert(kv_pair.0.to_string(), kv_pair.1.to_string()); }); - Ok(Self(SessionConfig::from_string_hash_map(&options_map)?)) + SessionConfig::from_string_hash_map(&options_map) } } @@ -172,13 +167,15 @@ mod tests { let session_config = SessionConfig::new(); let original_options = session_config.options().entries(); - let ffi_config: FFI_SessionConfig = (&session_config).into(); + let mut ffi_config: FFI_SessionConfig = (&session_config).into(); + let _ = ffi_config.clone(); + ffi_config.library_marker_id = crate::mock_foreign_marker_id; - let foreign_config: ForeignSessionConfig = (&ffi_config).try_into()?; + let foreign_config: SessionConfig = (&ffi_config).try_into()?; - let returned_options = foreign_config.0.options().entries(); + let returned_options = foreign_config.options().entries(); - assert!(original_options.len() == returned_options.len()); + assert_eq!(original_options.len(), returned_options.len()); Ok(()) } diff --git a/datafusion/ffi/src/session/mod.rs b/datafusion/ffi/src/session/mod.rs new file mode 100644 index 000000000000..84b876707847 --- /dev/null +++ b/datafusion/ffi/src/session/mod.rs @@ -0,0 +1,633 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::arrow_wrappers::WrappedSchema; +use crate::execution_plan::FFI_ExecutionPlan; +use crate::session::config::FFI_SessionConfig; +use crate::session::task_context::FFI_TaskContext; +use crate::udaf::FFI_AggregateUDF; +use crate::udf::FFI_ScalarUDF; +use crate::udwf::FFI_WindowUDF; +use crate::{df_result, rresult, rresult_return}; +use abi_stable::std_types::{RHashMap, RStr}; +use abi_stable::{ + std_types::{RResult, RString, RVec}, + StableAbi, +}; +use arrow_schema::ffi::FFI_ArrowSchema; +use arrow_schema::SchemaRef; +use async_ffi::{FfiFuture, FutureExt}; +use async_trait::async_trait; +use datafusion_catalog::Session; +use datafusion_common::config::{ConfigOptions, TableOptions}; +use datafusion_common::{DFSchema, DataFusionError}; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::runtime_env::RuntimeEnv; +use datafusion_execution::TaskContext; +use datafusion_expr::execution_props::ExecutionProps; +use datafusion_expr::{ + AggregateUDF, AggregateUDFImpl, Expr, LogicalPlan, ScalarUDF, ScalarUDFImpl, + WindowUDF, WindowUDFImpl, +}; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_plan::ExecutionPlan; +use datafusion_proto::bytes::{logical_plan_from_bytes, logical_plan_to_bytes}; +use datafusion_proto::logical_plan::{ + from_proto::parse_expr, to_proto::serialize_expr, DefaultLogicalExtensionCodec, +}; +use datafusion_proto::physical_plan::{ + from_proto::parse_physical_expr, to_proto::serialize_physical_expr, + DefaultPhysicalExtensionCodec, +}; +use 
datafusion_proto::protobuf::{LogicalExprNode, PhysicalExprNode}; +use prost::Message; +use std::any::Any; +use std::collections::HashMap; +use std::{ffi::c_void, sync::Arc}; +use tokio::runtime::Handle; + +pub mod config; +pub mod task_context; +pub mod task_ctx_accessor; +pub use task_ctx_accessor::FFI_TaskContextAccessor; + +/// A stable struct for sharing [`Session`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_Session { + pub session_id: unsafe extern "C" fn(&Self) -> RStr, + + pub config: unsafe extern "C" fn(&Self) -> FFI_SessionConfig, + + pub create_physical_plan: + unsafe extern "C" fn( + &Self, + logical_plan_serialized: RVec, + ) -> FfiFuture>, + + pub create_physical_expr: unsafe extern "C" fn( + &Self, + expr_serialized: RVec, + schema: WrappedSchema, + ) -> RResult, RString>, + + pub scalar_functions: unsafe extern "C" fn(&Self) -> RHashMap, + + pub aggregate_functions: + unsafe extern "C" fn(&Self) -> RHashMap, + + pub window_functions: unsafe extern "C" fn(&Self) -> RHashMap, + + // TODO: Expand scope of FFI to include runtime environment + // pub runtime_env: unsafe extern "C" fn(&Self) -> FFI_RuntimeEnv, + + // pub execution_props: unsafe extern "C" fn(&Self) -> FFI_ExecutionProps, + pub table_options: unsafe extern "C" fn(&Self) -> RHashMap, + + pub default_table_options: unsafe extern "C" fn(&Self) -> RHashMap, + + pub task_ctx: unsafe extern "C" fn(&Self) -> FFI_TaskContext, + + /// Accessor for TaskContext to be used during protobuf serialization + /// and deserialization. + pub task_ctx_accessor: FFI_TaskContextAccessor, + + /// Used to create a clone on the provider of the registry. This should + /// only need to be called by the receiver of the plan. + pub clone: unsafe extern "C" fn(plan: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. 
+ pub release: unsafe extern "C" fn(arg: &mut Self), + + /// Return the major DataFusion version number of this registry. + pub version: unsafe extern "C" fn() -> u64, + + /// Internal data. This is only to be accessed by the provider of the plan. + /// A [`ForeignSession`] should never attempt to access this data. + pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. + pub library_marker_id: extern "C" fn() -> u64, +} + +unsafe impl Send for FFI_Session {} +unsafe impl Sync for FFI_Session {} + +struct SessionPrivateData<'a> { + session: &'a (dyn Session + Send + Sync), + runtime: Option, +} + +impl FFI_Session { + fn inner(&self) -> &(dyn Session + Send + Sync) { + let private_data = self.private_data as *const SessionPrivateData; + unsafe { (*private_data).session } + } + + fn task_ctx(&self) -> Result, DataFusionError> { + (&self.task_ctx_accessor).try_into() + } + + unsafe fn runtime(&self) -> &Option { + let private_data = self.private_data as *const SessionPrivateData; + &(*private_data).runtime + } +} + +unsafe extern "C" fn session_id_fn_wrapper(session: &FFI_Session) -> RStr<'_> { + let session = session.inner(); + session.session_id().into() +} + +unsafe extern "C" fn config_fn_wrapper(session: &FFI_Session) -> FFI_SessionConfig { + let session = session.inner(); + session.config().into() +} + +unsafe extern "C" fn create_physical_plan_fn_wrapper( + session: &FFI_Session, + logical_plan_serialized: RVec, +) -> FfiFuture> { + let task_ctx_accessor = session.task_ctx_accessor.clone(); + let runtime = session.runtime().clone(); + let session = session.clone(); + async move { + let session = session.inner(); + let task_ctx = session.task_ctx(); + + let logical_plan = rresult_return!(logical_plan_from_bytes( + logical_plan_serialized.as_slice(), + &task_ctx + )); + + let physical_plan = session.create_physical_plan(&logical_plan).await; + + rresult!(physical_plan.map(|plan| 
FFI_ExecutionPlan::new( + plan, + task_ctx_accessor, + runtime + ))) + } + .into_ffi() +} + +unsafe extern "C" fn create_physical_expr_fn_wrapper( + session: &FFI_Session, + expr_serialized: RVec, + schema: WrappedSchema, +) -> RResult, RString> { + let task_ctx = rresult_return!(session.task_ctx()); + let session = session.inner(); + + let codec = DefaultLogicalExtensionCodec {}; + let logical_expr = LogicalExprNode::decode(expr_serialized.as_slice()).unwrap(); + let logical_expr = parse_expr(&logical_expr, task_ctx.as_ref(), &codec).unwrap(); + let schema: SchemaRef = schema.into(); + let schema: DFSchema = rresult_return!(schema.try_into()); + + let physical_expr = + rresult_return!(session.create_physical_expr(logical_expr, &schema)); + let codec = DefaultPhysicalExtensionCodec {}; + let physical_expr = + rresult_return!(serialize_physical_expr(&physical_expr, &codec)).encode_to_vec(); + + RResult::ROk(physical_expr.into()) +} + +unsafe extern "C" fn scalar_functions_fn_wrapper( + session: &FFI_Session, +) -> RHashMap { + let session = session.inner(); + session + .scalar_functions() + .iter() + .map(|(name, udf)| (name.clone().into(), FFI_ScalarUDF::from(Arc::clone(udf)))) + .collect() +} + +unsafe extern "C" fn aggregate_functions_fn_wrapper( + session: &FFI_Session, +) -> RHashMap { + let task_ctx_accessor = &session.task_ctx_accessor; + let session = session.inner(); + session + .aggregate_functions() + .iter() + .map(|(name, udaf)| { + ( + name.clone().into(), + FFI_AggregateUDF::new(Arc::clone(udaf), task_ctx_accessor.clone()), + ) + }) + .collect() +} + +unsafe extern "C" fn window_functions_fn_wrapper( + session: &FFI_Session, +) -> RHashMap { + let task_ctx_accessor = &session.task_ctx_accessor; + let session = session.inner(); + session + .window_functions() + .iter() + .map(|(name, udwf)| { + ( + name.clone().into(), + FFI_WindowUDF::new(Arc::clone(udwf), task_ctx_accessor.clone()), + ) + }) + .collect() +} + +fn table_options_to_rhash(options: 
&TableOptions) -> RHashMap { + options + .entries() + .into_iter() + .filter_map(|entry| entry.value.map(|v| (entry.key.into(), v.into()))) + .collect() +} + +unsafe extern "C" fn table_options_fn_wrapper( + session: &FFI_Session, +) -> RHashMap { + let session = session.inner(); + let table_options = session.table_options(); + table_options_to_rhash(table_options) +} + +unsafe extern "C" fn default_table_options_fn_wrapper( + session: &FFI_Session, +) -> RHashMap { + let session = session.inner(); + let table_options = session.default_table_options(); + + table_options_to_rhash(&table_options) +} + +unsafe extern "C" fn task_ctx_fn_wrapper(session: &FFI_Session) -> FFI_TaskContext { + let task_ctx_accessor = session.task_ctx_accessor.clone(); + let session = session.inner(); + FFI_TaskContext::new(session.task_ctx(), task_ctx_accessor) +} + +unsafe extern "C" fn release_fn_wrapper(provider: &mut FFI_Session) { + let private_data = Box::from_raw(provider.private_data as *mut SessionPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(provider: &FFI_Session) -> FFI_Session { + let old_private_data = provider.private_data as *const SessionPrivateData; + + let private_data = Box::into_raw(Box::new(SessionPrivateData { + session: (*old_private_data).session, + runtime: (*old_private_data).runtime.clone(), + })) as *mut c_void; + + FFI_Session { + session_id: session_id_fn_wrapper, + config: config_fn_wrapper, + create_physical_plan: create_physical_plan_fn_wrapper, + create_physical_expr: create_physical_expr_fn_wrapper, + scalar_functions: scalar_functions_fn_wrapper, + aggregate_functions: aggregate_functions_fn_wrapper, + window_functions: window_functions_fn_wrapper, + table_options: table_options_fn_wrapper, + default_table_options: default_table_options_fn_wrapper, + task_ctx: task_ctx_fn_wrapper, + task_ctx_accessor: provider.task_ctx_accessor.clone(), + + clone: clone_fn_wrapper, + release: release_fn_wrapper, + version: 
super::version, + private_data, + library_marker_id: crate::get_library_marker_id, + } +} + +impl Drop for FFI_Session { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +impl FFI_Session { + /// Creates a new [`FFI_Session`]. + pub fn new( + session: &(dyn Session + Send + Sync), + task_ctx_accessor: FFI_TaskContextAccessor, + runtime: Option, + ) -> Self { + let private_data = Box::new(SessionPrivateData { session, runtime }); + + Self { + session_id: session_id_fn_wrapper, + config: config_fn_wrapper, + create_physical_plan: create_physical_plan_fn_wrapper, + create_physical_expr: create_physical_expr_fn_wrapper, + scalar_functions: scalar_functions_fn_wrapper, + aggregate_functions: aggregate_functions_fn_wrapper, + window_functions: window_functions_fn_wrapper, + table_options: table_options_fn_wrapper, + default_table_options: default_table_options_fn_wrapper, + task_ctx: task_ctx_fn_wrapper, + task_ctx_accessor, + + clone: clone_fn_wrapper, + release: release_fn_wrapper, + version: super::version, + private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, + } + } +} + +/// This wrapper struct exists on the receiver side of the FFI interface, so it has +/// no guarantees about being able to access the data in `private_data`. Any functions +/// defined on this struct must only use the stable functions provided in +/// FFI_Session to interact with the foreign table provider. 
+#[derive(Debug)] +pub struct ForeignSession { + session: FFI_Session, + config: SessionConfig, + scalar_functions: HashMap>, + aggregate_functions: HashMap>, + window_functions: HashMap>, + table_options: TableOptions, + runtime_env: Arc, + props: ExecutionProps, +} + +unsafe impl Send for ForeignSession {} +unsafe impl Sync for ForeignSession {} + +impl FFI_Session { + pub fn as_local(&self) -> Option<&(dyn Session + Send + Sync)> { + if (self.library_marker_id)() == crate::get_library_marker_id() { + return Some(self.inner()); + } + None + } +} + +impl TryFrom<&FFI_Session> for ForeignSession { + type Error = DataFusionError; + fn try_from(session: &FFI_Session) -> Result { + unsafe { + let table_options = + table_options_from_rhashmap((session.table_options)(session)); + + let config = (session.config)(session); + let config = SessionConfig::try_from(&config)?; + + let scalar_functions = (session.scalar_functions)(session) + .into_iter() + .map(|kv_pair| { + let udf = >::try_from(&kv_pair.1)?; + + Ok(( + kv_pair.0.into_string(), + Arc::new(ScalarUDF::new_from_shared_impl(udf)), + )) + }) + .collect::>()?; + let aggregate_functions = (session.aggregate_functions)(session) + .into_iter() + .map(|kv_pair| { + let udaf = >::try_from(&kv_pair.1)?; + + Ok(( + kv_pair.0.into_string(), + Arc::new(AggregateUDF::new_from_shared_impl(udaf)), + )) + }) + .collect::>()?; + let window_functions = (session.window_functions)(session) + .into_iter() + .map(|kv_pair| { + let udwf = >::try_from(&kv_pair.1)?; + + Ok(( + kv_pair.0.into_string(), + Arc::new(WindowUDF::new_from_shared_impl(udwf)), + )) + }) + .collect::>()?; + + Ok(Self { + session: session.clone(), + config, + table_options, + scalar_functions, + aggregate_functions, + window_functions, + runtime_env: Default::default(), + props: Default::default(), + }) + } + } +} + +impl Clone for FFI_Session { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +fn table_options_from_rhashmap(options: RHashMap) 
-> TableOptions { + let options = options + .into_iter() + .map(|kv_pair| (kv_pair.0.into_string(), kv_pair.1.into_string())) + .collect(); + + TableOptions::from_string_hash_map(&options).unwrap_or_else(|err| { + log::warn!("Error parsing default table options: {err}"); + TableOptions::default() + }) +} + +#[async_trait] +impl Session for ForeignSession { + fn session_id(&self) -> &str { + unsafe { (self.session.session_id)(&self.session).as_str() } + } + + fn config(&self) -> &SessionConfig { + &self.config + } + + fn config_options(&self) -> &ConfigOptions { + self.config.options() + } + + async fn create_physical_plan( + &self, + logical_plan: &LogicalPlan, + ) -> datafusion_common::Result> { + unsafe { + let logical_plan = logical_plan_to_bytes(logical_plan)?; + let physical_plan = df_result!( + (self.session.create_physical_plan)( + &self.session, + logical_plan.as_ref().into() + ) + .await + )?; + let physical_plan = >::try_from(&physical_plan)?; + + Ok(physical_plan) + } + } + + fn create_physical_expr( + &self, + expr: Expr, + df_schema: &DFSchema, + ) -> datafusion_common::Result> { + unsafe { + let codec = DefaultLogicalExtensionCodec {}; + let logical_expr = serialize_expr(&expr, &codec)?.encode_to_vec(); + let schema = WrappedSchema(FFI_ArrowSchema::try_from(df_schema.as_arrow())?); + + let physical_expr = df_result!((self.session.create_physical_expr)( + &self.session, + logical_expr.into(), + schema + ))?; + + let physical_expr = PhysicalExprNode::decode(physical_expr.as_slice()) + .map_err(|err| DataFusionError::External(Box::new(err)))?; + + let codec = DefaultPhysicalExtensionCodec {}; + let physical_expr = parse_physical_expr( + &physical_expr, + self.task_ctx().as_ref(), + df_schema.as_arrow(), + &codec, + )?; + + Ok(physical_expr) + } + } + + fn scalar_functions(&self) -> &HashMap> { + &self.scalar_functions + } + + fn aggregate_functions(&self) -> &HashMap> { + &self.aggregate_functions + } + + fn window_functions(&self) -> &HashMap> { + 
&self.window_functions + } + + fn runtime_env(&self) -> &Arc { + &self.runtime_env + } + + fn execution_props(&self) -> &ExecutionProps { + &self.props + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn table_options(&self) -> &TableOptions { + &self.table_options + } + + fn default_table_options(&self) -> TableOptions { + unsafe { + table_options_from_rhashmap((self.session.default_table_options)( + &self.session, + )) + } + } + + fn table_options_mut(&mut self) -> &mut TableOptions { + log::warn!("Mutating table options is not supported via FFI. Changes will not have an effect."); + &mut self.table_options + } + + fn task_ctx(&self) -> Arc { + unsafe { Arc::new((self.session.task_ctx)(&self.session).into()) } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_schema::{DataType, Field, Schema}; + use datafusion::prelude::SessionContext; + use datafusion_common::DataFusionError; + use datafusion_execution::TaskContextAccessor; + use datafusion_expr::col; + use datafusion_expr::registry::FunctionRegistry; + use std::sync::Arc; + + #[tokio::test] + async fn test_ffi_session() -> Result<(), DataFusionError> { + let ctx = Arc::new(SessionContext::new()); + let task_ctx_accessor = Arc::clone(&ctx) as Arc; + let state = ctx.state(); + + let local_session = FFI_Session::new(&state, task_ctx_accessor.into(), None); + let foreign_session = ForeignSession::try_from(&local_session)?; + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let df_schema = schema.try_into()?; + let physical_expr = foreign_session.create_physical_expr(col("a"), &df_schema)?; + assert_eq!( + format!("{physical_expr:?}"), + "Column { name: \"a\", index: 0 }" + ); + + assert_eq!(foreign_session.session_id(), state.session_id()); + + let logical_plan = LogicalPlan::default(); + let physical_plan = foreign_session.create_physical_plan(&logical_plan).await?; + assert_eq!(format!("{physical_plan:?}"), "EmptyExec { schema: Schema { fields: [], 
metadata: {} }, partitions: 1, cache: PlanProperties { eq_properties: EquivalenceProperties { eq_group: EquivalenceGroup { map: {}, classes: [] }, oeq_class: OrderingEquivalenceClass { orderings: [] }, oeq_cache: OrderingEquivalenceCache { normal_cls: OrderingEquivalenceClass { orderings: [] }, leading_map: {} }, constraints: Constraints { inner: [] }, schema: Schema { fields: [], metadata: {} } }, partitioning: UnknownPartitioning(1), emission_type: Incremental, boundedness: Bounded, evaluation_type: Lazy, scheduling_type: Cooperative, output_ordering: None } }"); + + assert_eq!( + format!("{:?}", foreign_session.default_table_options()), + format!("{:?}", state.default_table_options()) + ); + + assert_eq!( + format!("{:?}", foreign_session.table_options()), + format!("{:?}", state.table_options()) + ); + + let local_udfs = state.udfs(); + for udf in foreign_session.scalar_functions().keys() { + assert!(local_udfs.contains(udf)); + } + let local_udafs = state.udafs(); + for udaf in foreign_session.aggregate_functions().keys() { + assert!(local_udafs.contains(udaf)); + } + let local_udwfs = state.udwfs(); + for udwf in foreign_session.window_functions().keys() { + assert!(local_udwfs.contains(udwf)); + } + + Ok(()) + } +} diff --git a/datafusion/ffi/src/session/task_context.rs b/datafusion/ffi/src/session/task_context.rs new file mode 100644 index 000000000000..a15227327162 --- /dev/null +++ b/datafusion/ffi/src/session/task_context.rs @@ -0,0 +1,249 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::session::config::FFI_SessionConfig; +use crate::session::task_ctx_accessor::FFI_TaskContextAccessor; +use crate::udaf::FFI_AggregateUDF; +use crate::udf::FFI_ScalarUDF; +use crate::udwf::FFI_WindowUDF; +use abi_stable::pmr::ROption; +use abi_stable::std_types::RHashMap; +use abi_stable::{std_types::RString, StableAbi}; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::runtime_env::RuntimeEnv; +use datafusion_execution::TaskContext; +use datafusion_expr::{ + AggregateUDF, AggregateUDFImpl, ScalarUDF, ScalarUDFImpl, WindowUDF, WindowUDFImpl, +}; +use std::{ffi::c_void, sync::Arc}; + +/// A stable struct for sharing [`TaskContext`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_TaskContext { + pub session_id: unsafe extern "C" fn(&Self) -> RString, + + pub task_id: unsafe extern "C" fn(&Self) -> ROption, + + pub session_config: unsafe extern "C" fn(&Self) -> FFI_SessionConfig, + + pub scalar_functions: unsafe extern "C" fn(&Self) -> RHashMap, + + pub aggregate_functions: + unsafe extern "C" fn(&Self) -> RHashMap, + + pub window_functions: unsafe extern "C" fn(&Self) -> RHashMap, + + /// Accessor for TaskContext to be used during protobuf serialization + /// and deserialization. + pub task_ctx_accessor: FFI_TaskContextAccessor, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(arg: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the plan. 
+ /// The foreign library should never attempt to access this data. + pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. + pub library_marker_id: extern "C" fn() -> u64, +} + +struct TaskContextPrivateData { + ctx: Arc, +} + +impl FFI_TaskContext { + unsafe fn inner(&self) -> &TaskContext { + let private_data = self.private_data as *const TaskContextPrivateData; + &(*private_data).ctx + } +} + +unsafe extern "C" fn session_id_fn_wrapper(ctx: &FFI_TaskContext) -> RString { + let ctx = ctx.inner(); + ctx.session_id().into() +} + +unsafe extern "C" fn task_id_fn_wrapper(ctx: &FFI_TaskContext) -> ROption { + let ctx = ctx.inner(); + ctx.task_id().map(|s| s.as_str().into()).into() +} + +unsafe extern "C" fn session_config_fn_wrapper( + ctx: &FFI_TaskContext, +) -> FFI_SessionConfig { + let ctx = ctx.inner(); + ctx.session_config().into() +} + +unsafe extern "C" fn scalar_functions_fn_wrapper( + ctx: &FFI_TaskContext, +) -> RHashMap { + let ctx = ctx.inner(); + ctx.scalar_functions() + .iter() + .map(|(name, udf)| (name.to_owned().into(), udf.into())) + .collect() +} + +unsafe extern "C" fn aggregate_functions_fn_wrapper( + ctx: &FFI_TaskContext, +) -> RHashMap { + let task_ctx_accessor = &ctx.task_ctx_accessor; + let ctx = ctx.inner(); + ctx.aggregate_functions() + .iter() + .map(|(name, udaf)| { + ( + name.to_owned().into(), + FFI_AggregateUDF::new(Arc::clone(udaf), task_ctx_accessor.clone()), + ) + }) + .collect() +} + +unsafe extern "C" fn window_functions_fn_wrapper( + ctx: &FFI_TaskContext, +) -> RHashMap { + let task_ctx_accessor = &ctx.task_ctx_accessor; + let ctx = ctx.inner(); + ctx.window_functions() + .iter() + .map(|(name, udf)| { + ( + name.to_owned().into(), + FFI_WindowUDF::new(Arc::clone(udf), task_ctx_accessor.clone()), + ) + }) + .collect() +} + +unsafe extern "C" fn release_fn_wrapper(ctx: &mut FFI_TaskContext) { + let private_data = Box::from_raw(ctx.private_data as 
*mut TaskContextPrivateData); + drop(private_data); +} + +impl Drop for FFI_TaskContext { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +impl FFI_TaskContext { + pub fn new( + ctx: Arc, + task_ctx_accessor: FFI_TaskContextAccessor, + ) -> Self { + let private_data = Box::new(TaskContextPrivateData { ctx }); + + FFI_TaskContext { + session_id: session_id_fn_wrapper, + task_id: task_id_fn_wrapper, + session_config: session_config_fn_wrapper, + scalar_functions: scalar_functions_fn_wrapper, + aggregate_functions: aggregate_functions_fn_wrapper, + window_functions: window_functions_fn_wrapper, + task_ctx_accessor, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, + } + } +} + +impl From for TaskContext { + fn from(ffi_ctx: FFI_TaskContext) -> Self { + unsafe { + if (ffi_ctx.library_marker_id)() == crate::get_library_marker_id() { + return ffi_ctx.inner().clone(); + } + + let task_id = (ffi_ctx.task_id)(&ffi_ctx).map(|s| s.to_string()).into(); + let session_id = (ffi_ctx.session_id)(&ffi_ctx).into(); + let session_config = (ffi_ctx.session_config)(&ffi_ctx); + let session_config = + SessionConfig::try_from(&session_config).unwrap_or_default(); + + let scalar_functions = (ffi_ctx.scalar_functions)(&ffi_ctx) + .into_iter() + .filter_map(|kv_pair| { + let udf = >::try_from(&kv_pair.1); + + if let Err(err) = &udf { + log::error!("Unable to create WindowUDF in FFI: {err}") + } + + udf.ok().map(|udf| { + ( + kv_pair.0.into_string(), + Arc::new(ScalarUDF::new_from_shared_impl(udf)), + ) + }) + }) + .collect(); + let aggregate_functions = (ffi_ctx.aggregate_functions)(&ffi_ctx) + .into_iter() + .filter_map(|kv_pair| { + let udaf = >::try_from(&kv_pair.1); + + if let Err(err) = &udaf { + log::error!("Unable to create AggregateUDF in FFI: {err}") + } + + udaf.ok().map(|udaf| { + ( + kv_pair.0.into_string(), + Arc::new(AggregateUDF::new_from_shared_impl(udaf)), + ) + }) 
+ }) + .collect(); + let window_functions = (ffi_ctx.window_functions)(&ffi_ctx) + .into_iter() + .filter_map(|kv_pair| { + let udwf = >::try_from(&kv_pair.1); + + if let Err(err) = &udwf { + log::error!("Unable to create WindowUDF in FFI: {err}") + } + + udwf.ok().map(|udwf| { + ( + kv_pair.0.into_string(), + Arc::new(WindowUDF::new_from_shared_impl(udwf)), + ) + }) + }) + .collect(); + + let runtime = Arc::new(RuntimeEnv::default()); + + TaskContext::new( + task_id, + session_id, + session_config, + scalar_functions, + aggregate_functions, + window_functions, + runtime, + ) + } + } +} diff --git a/datafusion/ffi/src/session/task_ctx_accessor.rs b/datafusion/ffi/src/session/task_ctx_accessor.rs new file mode 100644 index 000000000000..72525b30a748 --- /dev/null +++ b/datafusion/ffi/src/session/task_ctx_accessor.rs @@ -0,0 +1,156 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::session::task_context::FFI_TaskContext; +use crate::{df_result, rresult}; +use abi_stable::std_types::RResult; +use abi_stable::{std_types::RString, StableAbi}; +use datafusion_common::{exec_datafusion_err, DataFusionError}; +use datafusion_execution::{TaskContext, TaskContextAccessor}; +use std::sync::Weak; +use std::{ffi::c_void, sync::Arc}; + +/// Struct for accessing the [`TaskContext`]. This method contains a weak +/// reference, so there are no guarantees that the [`TaskContext`] remains +/// valid. This is used primarily for protobuf encoding and decoding of +/// data passed across the FFI boundary. See the crate README for +/// additional information. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_TaskContextAccessor { + pub get_task_context: + unsafe extern "C" fn(&Self) -> RResult, + + /// Used to create a clone on the task context accessor. This should + /// only need to be called by the receiver of the plan. + pub clone: unsafe extern "C" fn(plan: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(arg: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the plan. + /// The foreign library should never attempt to access this data. + pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. 
+ pub library_marker_id: extern "C" fn() -> u64, +} + +unsafe impl Send for FFI_TaskContextAccessor {} +unsafe impl Sync for FFI_TaskContextAccessor {} + +struct TaskContextAccessorPrivateData { + ctx: Weak, +} + +impl FFI_TaskContextAccessor { + unsafe fn inner(&self) -> Option> { + let private_data = self.private_data as *const TaskContextAccessorPrivateData; + (*private_data) + .ctx + .upgrade() + .map(|ctx| ctx.get_task_context()) + } +} + +unsafe extern "C" fn get_task_context_fn_wrapper( + ctx_accessor: &FFI_TaskContextAccessor, +) -> RResult { + rresult!(ctx_accessor + .inner() + .map(|ctx| FFI_TaskContext::new(ctx, ctx_accessor.clone())) + .ok_or_else(|| { + exec_datafusion_err!( + "TaskContextAccessor went out of scope over FFI boundary." + ) + })) +} + +unsafe extern "C" fn clone_fn_wrapper( + accessor: &FFI_TaskContextAccessor, +) -> FFI_TaskContextAccessor { + let private_data = accessor.private_data as *const TaskContextAccessorPrivateData; + let ctx = Weak::clone(&(*private_data).ctx); + + let private_data = Box::new(TaskContextAccessorPrivateData { ctx }); + + FFI_TaskContextAccessor { + get_task_context: get_task_context_fn_wrapper, + release: release_fn_wrapper, + clone: clone_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, + } +} +unsafe extern "C" fn release_fn_wrapper(ctx: &mut FFI_TaskContextAccessor) { + let private_data = + Box::from_raw(ctx.private_data as *mut TaskContextAccessorPrivateData); + drop(private_data); +} +impl Drop for FFI_TaskContextAccessor { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +impl Clone for FFI_TaskContextAccessor { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl From> for FFI_TaskContextAccessor { + fn from(ctx: Arc) -> Self { + (&ctx).into() + } +} + +impl From<&Arc> for FFI_TaskContextAccessor { + fn from(ctx: &Arc) -> Self { + let ctx = Arc::downgrade(ctx); + let private_data = 
Box::new(TaskContextAccessorPrivateData { ctx }); + + FFI_TaskContextAccessor { + get_task_context: get_task_context_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, + } + } +} + +impl TryFrom<&FFI_TaskContextAccessor> for Arc { + type Error = DataFusionError; + fn try_from(ffi_ctx: &FFI_TaskContextAccessor) -> Result { + unsafe { + if (ffi_ctx.library_marker_id)() == crate::get_library_marker_id() { + return ffi_ctx.inner().ok_or_else(|| { + exec_datafusion_err!( + "TaskContextAccessor went out of scope over FFI boundary." + ) + }); + } + + df_result!((ffi_ctx.get_task_context)(ffi_ctx)) + .map(Into::into) + .map(Arc::new) + } + } +} diff --git a/datafusion/ffi/src/table_provider.rs b/datafusion/ffi/src/table_provider.rs index 890511997a70..038ba3c6d0bf 100644 --- a/datafusion/ffi/src/table_provider.rs +++ b/datafusion/ffi/src/table_provider.rs @@ -24,15 +24,8 @@ use abi_stable::{ use arrow::datatypes::SchemaRef; use async_ffi::{FfiFuture, FutureExt}; use async_trait::async_trait; -use datafusion::{ - catalog::{Session, TableProvider}, - datasource::TableType, - error::DataFusionError, - execution::{session_state::SessionStateBuilder, TaskContext}, - logical_expr::{logical_plan::dml::InsertOp, TableProviderFilterPushDown}, - physical_plan::ExecutionPlan, - prelude::{Expr, SessionContext}, -}; +use datafusion_catalog::{Session, TableProvider}; + use datafusion_proto::{ logical_plan::{ from_proto::parse_exprs, to_proto::serialize_exprs, DefaultLogicalExtensionCodec, @@ -42,19 +35,21 @@ use datafusion_proto::{ use prost::Message; use tokio::runtime::Handle; +use crate::session::task_ctx_accessor::FFI_TaskContextAccessor; +use crate::session::{FFI_Session, ForeignSession}; use crate::{ arrow_wrappers::WrappedSchema, - df_result, rresult_return, - session_config::ForeignSessionConfig, - table_source::{FFI_TableProviderFilterPushDown, 
FFI_TableType}, -}; - -use super::{ - execution_plan::{FFI_ExecutionPlan, ForeignExecutionPlan}, + df_result, + execution_plan::FFI_ExecutionPlan, insert_op::FFI_InsertOp, - session_config::FFI_SessionConfig, + rresult_return, + table_source::{FFI_TableProviderFilterPushDown, FFI_TableType}, }; -use datafusion::error::Result; +use datafusion_common::{DataFusionError, Result}; +use datafusion_execution::TaskContext; +use datafusion_expr::dml::InsertOp; +use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType}; +use datafusion_physical_plan::ExecutionPlan; /// A stable struct for sharing [`TableProvider`] across FFI boundaries. /// @@ -115,7 +110,7 @@ pub struct FFI_TableProvider { /// * `limit` - if specified, limit the number of rows returned pub scan: unsafe extern "C" fn( provider: &Self, - session_config: &FFI_SessionConfig, + session_config: &FFI_Session, projections: RVec, filters_serialized: RVec, limit: ROption, @@ -138,11 +133,15 @@ pub struct FFI_TableProvider { pub insert_into: unsafe extern "C" fn( provider: &Self, - session_config: &FFI_SessionConfig, + session_config: &FFI_Session, input: &FFI_ExecutionPlan, insert_op: FFI_InsertOp, ) -> FfiFuture>, + /// Accessor for TaskContext to be used during protobuf serialization + /// and deserialization. + pub task_ctx_accessor: FFI_TaskContextAccessor, + /// Used to create a clone on the provider of the execution plan. This should /// only need to be called by the receiver of the plan. pub clone: unsafe extern "C" fn(plan: &Self) -> Self, @@ -154,39 +153,49 @@ pub struct FFI_TableProvider { pub version: unsafe extern "C" fn() -> u64, /// Internal data. This is only to be accessed by the provider of the plan. - /// A [`ForeignExecutionPlan`] should never attempt to access this data. + /// A [`ForeignTableProvider`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. 
+ pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_TableProvider {} unsafe impl Sync for FFI_TableProvider {} struct ProviderPrivateData { - provider: Arc, + provider: Arc, runtime: Option, } -unsafe extern "C" fn schema_fn_wrapper(provider: &FFI_TableProvider) -> WrappedSchema { - let private_data = provider.private_data as *const ProviderPrivateData; - let provider = &(*private_data).provider; +impl FFI_TableProvider { + fn inner(&self) -> &Arc { + let private_data = self.private_data as *const ProviderPrivateData; + unsafe { &(*private_data).provider } + } - provider.schema().into() + fn runtime(&self) -> &Option { + let private_data = self.private_data as *const ProviderPrivateData; + unsafe { &(*private_data).runtime } + } +} + +unsafe extern "C" fn schema_fn_wrapper(provider: &FFI_TableProvider) -> WrappedSchema { + provider.inner().schema().into() } unsafe extern "C" fn table_type_fn_wrapper( provider: &FFI_TableProvider, ) -> FFI_TableType { - let private_data = provider.private_data as *const ProviderPrivateData; - let provider = &(*private_data).provider; - - provider.table_type().into() + provider.inner().table_type().into() } fn supports_filters_pushdown_internal( - provider: &Arc, + provider: &Arc, filters_serialized: &[u8], + task_ctx: &Arc, ) -> Result> { - let default_ctx = SessionContext::new(); let codec = DefaultLogicalExtensionCodec {}; let filters = match filters_serialized.is_empty() { @@ -195,7 +204,7 @@ fn supports_filters_pushdown_internal( let proto_filters = LogicalExprList::decode(filters_serialized) .map_err(|e| DataFusionError::Plan(e.to_string()))?; - parse_exprs(proto_filters.expr.iter(), &default_ctx, &codec)? + parse_exprs(proto_filters.expr.iter(), task_ctx.as_ref(), &codec)? 
} }; let filters_borrowed: Vec<&Expr> = filters.iter().collect(); @@ -213,38 +222,36 @@ unsafe extern "C" fn supports_filters_pushdown_fn_wrapper( provider: &FFI_TableProvider, filters_serialized: RVec, ) -> RResult, RString> { - let private_data = provider.private_data as *const ProviderPrivateData; - let provider = &(*private_data).provider; + let task_ctx = + rresult_return!(>::try_from(&provider.task_ctx_accessor)); - supports_filters_pushdown_internal(provider, &filters_serialized) + supports_filters_pushdown_internal(provider.inner(), &filters_serialized, &task_ctx) .map_err(|e| e.to_string().into()) .into() } unsafe extern "C" fn scan_fn_wrapper( provider: &FFI_TableProvider, - session_config: &FFI_SessionConfig, + session: &FFI_Session, projections: RVec, filters_serialized: RVec, limit: ROption, ) -> FfiFuture> { - let private_data = provider.private_data as *mut ProviderPrivateData; - let internal_provider = &(*private_data).provider; - let session_config = session_config.clone(); - let runtime = &(*private_data).runtime; + let task_ctx: Result, DataFusionError> = + (&provider.task_ctx_accessor).try_into(); + let task_ctx_accessor = provider.task_ctx_accessor.clone(); + let session = ForeignSession::try_from(session); + let internal_provider = Arc::clone(provider.inner()); + let runtime = provider.runtime().clone(); async move { - let config = rresult_return!(ForeignSessionConfig::try_from(&session_config)); - let session = SessionStateBuilder::new() - .with_default_features() - .with_config(config.0) - .build(); - let ctx = SessionContext::new_with_state(session); + let session = rresult_return!(session); + let task_ctx = rresult_return!(task_ctx); let filters = match filters_serialized.is_empty() { true => vec![], false => { - let default_ctx = SessionContext::new(); + // let default_ctx = SessionContext::new(); let codec = DefaultLogicalExtensionCodec {}; let proto_filters = @@ -252,7 +259,7 @@ unsafe extern "C" fn scan_fn_wrapper( 
rresult_return!(parse_exprs( proto_filters.expr.iter(), - &default_ctx, + task_ctx.as_ref(), &codec )) } @@ -262,13 +269,13 @@ unsafe extern "C" fn scan_fn_wrapper( let plan = rresult_return!( internal_provider - .scan(&ctx.state(), Some(&projections), &filters, limit.into()) + .scan(&session, Some(&projections), &filters, limit.into()) .await ); RResult::ROk(FFI_ExecutionPlan::new( plan, - ctx.task_ctx(), + task_ctx_accessor, runtime.clone(), )) } @@ -277,37 +284,34 @@ unsafe extern "C" fn scan_fn_wrapper( unsafe extern "C" fn insert_into_fn_wrapper( provider: &FFI_TableProvider, - session_config: &FFI_SessionConfig, + session: &FFI_Session, input: &FFI_ExecutionPlan, insert_op: FFI_InsertOp, ) -> FfiFuture> { - let private_data = provider.private_data as *mut ProviderPrivateData; - let internal_provider = &(*private_data).provider; - let session_config = session_config.clone(); + let task_ctx_accessor = provider.task_ctx_accessor.clone(); + let internal_provider = Arc::clone(provider.inner()); + let session = session.clone(); let input = input.clone(); - let runtime = &(*private_data).runtime; + let runtime = provider.runtime().clone(); async move { - let config = rresult_return!(ForeignSessionConfig::try_from(&session_config)); - let session = SessionStateBuilder::new() - .with_default_features() - .with_config(config.0) - .build(); - let ctx = SessionContext::new_with_state(session); + let local_session = session.as_local(); + let foreign_session = rresult_return!(ForeignSession::try_from(&session)); + let session = local_session.unwrap_or(&foreign_session); - let input = rresult_return!(ForeignExecutionPlan::try_from(&input).map(Arc::new)); + let input = rresult_return!(>::try_from(&input)); let insert_op = InsertOp::from(insert_op); let plan = rresult_return!( internal_provider - .insert_into(&ctx.state(), input, insert_op) + .insert_into(session, input, insert_op) .await ); RResult::ROk(FFI_ExecutionPlan::new( plan, - ctx.task_ctx(), + task_ctx_accessor, 
runtime.clone(), )) } @@ -334,10 +338,12 @@ unsafe extern "C" fn clone_fn_wrapper(provider: &FFI_TableProvider) -> FFI_Table table_type: table_type_fn_wrapper, supports_filters_pushdown: provider.supports_filters_pushdown, insert_into: provider.insert_into, + task_ctx_accessor: provider.task_ctx_accessor.clone(), clone: clone_fn_wrapper, release: release_fn_wrapper, version: super::version, private_data, + library_marker_id: crate::get_library_marker_id, } } @@ -350,9 +356,10 @@ impl Drop for FFI_TableProvider { impl FFI_TableProvider { /// Creates a new [`FFI_TableProvider`]. pub fn new( - provider: Arc, + provider: Arc, can_support_pushdown_filters: bool, runtime: Option, + task_ctx_accessor: FFI_TaskContextAccessor, ) -> Self { let private_data = Box::new(ProviderPrivateData { provider, runtime }); @@ -365,10 +372,12 @@ impl FFI_TableProvider { false => None, }, insert_into: insert_into_fn_wrapper, + task_ctx_accessor, clone: clone_fn_wrapper, release: release_fn_wrapper, version: super::version, private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -383,9 +392,13 @@ pub struct ForeignTableProvider(pub FFI_TableProvider); unsafe impl Send for ForeignTableProvider {} unsafe impl Sync for ForeignTableProvider {} -impl From<&FFI_TableProvider> for ForeignTableProvider { +impl From<&FFI_TableProvider> for Arc { fn from(provider: &FFI_TableProvider) -> Self { - Self(provider.clone()) + if (provider.library_marker_id)() == crate::get_library_marker_id() { + Arc::clone(provider.inner()) as Arc + } else { + Arc::new(ForeignTableProvider(provider.clone())) + } } } @@ -417,7 +430,8 @@ impl TableProvider for ForeignTableProvider { filters: &[Expr], limit: Option, ) -> Result> { - let session_config: FFI_SessionConfig = session.config().into(); + // let session_config: FFI_SessionConfig = session.config().into(); + let session = FFI_Session::new(session, self.0.task_ctx_accessor.clone(), None); let projections: 
Option> = projection.map(|p| p.iter().map(|v| v.to_owned()).collect()); @@ -431,17 +445,17 @@ impl TableProvider for ForeignTableProvider { let plan = unsafe { let maybe_plan = (self.0.scan)( &self.0, - &session_config, + &session, projections.unwrap_or_default(), filters_serialized, limit.into(), ) .await; - ForeignExecutionPlan::try_from(&df_result!(maybe_plan)?)? + >::try_from(&df_result!(maybe_plan)?)? }; - Ok(Arc::new(plan)) + Ok(plan) } /// Tests whether the table provider can make use of a filter expression @@ -480,30 +494,30 @@ impl TableProvider for ForeignTableProvider { input: Arc, insert_op: InsertOp, ) -> Result> { - let session_config: FFI_SessionConfig = session.config().into(); - let rc = Handle::try_current().ok(); - let input = - FFI_ExecutionPlan::new(input, Arc::new(TaskContext::from(session)), rc); + let session = + FFI_Session::new(session, self.0.task_ctx_accessor.clone(), rc.clone()); + + let input = FFI_ExecutionPlan::new(input, self.0.task_ctx_accessor.clone(), rc); let insert_op: FFI_InsertOp = insert_op.into(); let plan = unsafe { let maybe_plan = - (self.0.insert_into)(&self.0, &session_config, &input, insert_op).await; + (self.0.insert_into)(&self.0, &session, &input, insert_op).await; - ForeignExecutionPlan::try_from(&df_result!(maybe_plan)?)? + >::try_from(&df_result!(maybe_plan)?)? 
}; - Ok(Arc::new(plan)) + Ok(plan) } } #[cfg(test)] mod tests { - use arrow::datatypes::Schema; - use datafusion::prelude::{col, lit}; - use super::*; + use arrow::datatypes::Schema; + use datafusion::prelude::{col, lit, SessionContext}; + use datafusion_execution::TaskContextAccessor; #[tokio::test] async fn test_round_trip_ffi_table_provider_scan() -> Result<()> { @@ -526,16 +540,19 @@ mod tests { vec![Arc::new(Float32Array::from(vec![64.0]))], )?; - let ctx = SessionContext::new(); + let ctx = Arc::new(SessionContext::new()); + let task_ctx_accessor = Arc::clone(&ctx) as Arc; let provider = Arc::new(MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?); - let ffi_provider = FFI_TableProvider::new(provider, true, None); + let mut ffi_provider = + FFI_TableProvider::new(provider, true, None, task_ctx_accessor.into()); + ffi_provider.library_marker_id = crate::mock_foreign_marker_id; - let foreign_table_provider: ForeignTableProvider = (&ffi_provider).into(); + let foreign_table_provider: Arc = (&ffi_provider).into(); - ctx.register_table("t", Arc::new(foreign_table_provider))?; + ctx.register_table("t", foreign_table_provider)?; let df = ctx.table("t").await?; @@ -568,16 +585,19 @@ mod tests { vec![Arc::new(Float32Array::from(vec![64.0]))], )?; - let ctx = SessionContext::new(); + let ctx = Arc::new(SessionContext::new()); + let task_ctx_accessor = Arc::clone(&ctx) as Arc; let provider = Arc::new(MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?); - let ffi_provider = FFI_TableProvider::new(provider, true, None); + let mut ffi_provider = + FFI_TableProvider::new(provider, true, None, task_ctx_accessor.into()); + ffi_provider.library_marker_id = crate::mock_foreign_marker_id; - let foreign_table_provider: ForeignTableProvider = (&ffi_provider).into(); + let foreign_table_provider: Arc = (&ffi_provider).into(); - ctx.register_table("t", Arc::new(foreign_table_provider))?; + ctx.register_table("t", foreign_table_provider)?; let result = ctx 
.sql("INSERT INTO t VALUES (128.0);") @@ -615,15 +635,18 @@ mod tests { vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))], )?; - let ctx = SessionContext::new(); + let ctx = Arc::new(SessionContext::new()); + let task_ctx_accessor = Arc::clone(&ctx) as Arc; let provider = Arc::new(MemTable::try_new(schema, vec![vec![batch1]])?); - let ffi_provider = FFI_TableProvider::new(provider, true, None); + let mut ffi_provider = + FFI_TableProvider::new(provider, true, None, task_ctx_accessor.into()); + ffi_provider.library_marker_id = crate::mock_foreign_marker_id; - let foreign_table_provider: ForeignTableProvider = (&ffi_provider).into(); + let foreign_table_provider: Arc = (&ffi_provider).into(); - ctx.register_table("t", Arc::new(foreign_table_provider))?; + ctx.register_table("t", foreign_table_provider)?; let result = ctx .sql("SELECT COUNT(*) as cnt FROM t") diff --git a/datafusion/ffi/src/table_source.rs b/datafusion/ffi/src/table_source.rs index 418fdf16a564..dfdf8c1c64ae 100644 --- a/datafusion/ffi/src/table_source.rs +++ b/datafusion/ffi/src/table_source.rs @@ -16,7 +16,7 @@ // under the License. use abi_stable::StableAbi; -use datafusion::{datasource::TableType, logical_expr::TableProviderFilterPushDown}; +use datafusion_expr::{TableProviderFilterPushDown, TableType}; /// FFI safe version of [`TableProviderFilterPushDown`]. 
#[repr(C)] diff --git a/datafusion/ffi/src/tests/async_provider.rs b/datafusion/ffi/src/tests/async_provider.rs index cef4161d8c1f..5a0e217f453f 100644 --- a/datafusion/ffi/src/tests/async_provider.rs +++ b/datafusion/ffi/src/tests/async_provider.rs @@ -27,27 +27,24 @@ use std::{any::Any, fmt::Debug, sync::Arc}; +use super::create_record_batch; +use crate::session::task_ctx_accessor::FFI_TaskContextAccessor; use crate::table_provider::FFI_TableProvider; use arrow::array::RecordBatch; use arrow::datatypes::Schema; use async_trait::async_trait; -use datafusion::{ - catalog::{Session, TableProvider}, - error::Result, - execution::RecordBatchStream, - physical_expr::EquivalenceProperties, - physical_plan::{ExecutionPlan, Partitioning}, - prelude::Expr, -}; -use datafusion_common::exec_err; +use datafusion_catalog::{Session, TableProvider}; +use datafusion_common::{error::Result, exec_err}; +use datafusion_execution::RecordBatchStream; +use datafusion_expr::Expr; +use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion_physical_plan::ExecutionPlan; use futures::Stream; use tokio::{ runtime::Handle, sync::{broadcast, mpsc}, }; -use super::create_record_batch; - #[derive(Debug)] pub struct AsyncTableProvider { batch_request: mpsc::Sender, @@ -135,8 +132,8 @@ impl TableProvider for AsyncTableProvider { super::create_test_schema() } - fn table_type(&self) -> datafusion::logical_expr::TableType { - datafusion::logical_expr::TableType::Base + fn table_type(&self) -> datafusion_expr::TableType { + datafusion_expr::TableType::Base } async fn scan( @@ -163,7 +160,7 @@ impl Drop for AsyncTableProvider { #[derive(Debug)] struct AsyncTestExecutionPlan { - properties: datafusion::physical_plan::PlanProperties, + properties: datafusion_physical_plan::PlanProperties, batch_request: mpsc::Sender, batch_receiver: broadcast::Receiver>, } @@ -174,11 +171,11 @@ impl AsyncTestExecutionPlan { batch_receiver: broadcast::Receiver>, ) -> Self { Self { - properties: 
datafusion::physical_plan::PlanProperties::new( + properties: datafusion_physical_plan::PlanProperties::new( EquivalenceProperties::new(super::create_test_schema()), Partitioning::UnknownPartitioning(3), - datafusion::physical_plan::execution_plan::EmissionType::Incremental, - datafusion::physical_plan::execution_plan::Boundedness::Bounded, + datafusion_physical_plan::execution_plan::EmissionType::Incremental, + datafusion_physical_plan::execution_plan::Boundedness::Bounded, ), batch_request, batch_receiver, @@ -195,7 +192,7 @@ impl ExecutionPlan for AsyncTestExecutionPlan { self } - fn properties(&self) -> &datafusion::physical_plan::PlanProperties { + fn properties(&self) -> &datafusion_physical_plan::PlanProperties { &self.properties } @@ -213,8 +210,8 @@ impl ExecutionPlan for AsyncTestExecutionPlan { fn execute( &self, _partition: usize, - _context: Arc, - ) -> Result { + _context: Arc, + ) -> Result { Ok(Box::pin(AsyncTestRecordBatchStream { batch_request: self.batch_request.clone(), batch_receiver: self.batch_receiver.resubscribe(), @@ -222,10 +219,10 @@ impl ExecutionPlan for AsyncTestExecutionPlan { } } -impl datafusion::physical_plan::DisplayAs for AsyncTestExecutionPlan { +impl datafusion_physical_plan::DisplayAs for AsyncTestExecutionPlan { fn fmt_as( &self, - _t: datafusion::physical_plan::DisplayFormatType, + _t: datafusion_physical_plan::DisplayFormatType, _f: &mut std::fmt::Formatter, ) -> std::fmt::Result { // Do nothing, just a test @@ -277,7 +274,14 @@ impl Stream for AsyncTestRecordBatchStream { } } -pub(crate) fn create_async_table_provider() -> FFI_TableProvider { +pub(crate) fn create_async_table_provider( + task_ctx_accessor: FFI_TaskContextAccessor, +) -> FFI_TableProvider { let (table_provider, tokio_rt) = start_async_provider(); - FFI_TableProvider::new(Arc::new(table_provider), true, Some(tokio_rt)) + FFI_TableProvider::new( + Arc::new(table_provider), + true, + Some(tokio_rt), + task_ctx_accessor, + ) } diff --git 
a/datafusion/ffi/src/tests/catalog.rs b/datafusion/ffi/src/tests/catalog.rs index f4293adb41b9..6630adad2be1 100644 --- a/datafusion/ffi/src/tests/catalog.rs +++ b/datafusion/ffi/src/tests/catalog.rs @@ -28,17 +28,15 @@ use std::{any::Any, fmt::Debug, sync::Arc}; use crate::catalog_provider::FFI_CatalogProvider; +use crate::session::task_ctx_accessor::FFI_TaskContextAccessor; use arrow::datatypes::Schema; use async_trait::async_trait; -use datafusion::{ - catalog::{ - CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, - TableProvider, - }, - common::exec_err, - datasource::MemTable, - error::{DataFusionError, Result}, +use datafusion_catalog::{ + CatalogProvider, MemTable, MemoryCatalogProvider, MemorySchemaProvider, + SchemaProvider, TableProvider, }; +use datafusion_common::error::{DataFusionError, Result}; +use datafusion_common::exec_err; /// This schema provider is intended only for unit tests. It prepopulates with one /// table and only allows for tables named sales and purchases. 
@@ -49,7 +47,7 @@ pub struct FixedSchemaProvider { pub fn fruit_table() -> Arc { use arrow::datatypes::{DataType, Field}; - use datafusion::common::record_batch; + use datafusion_common::record_batch; let schema = Arc::new(Schema::new(vec![ Field::new("units", DataType::Int32, true), @@ -177,7 +175,9 @@ impl CatalogProvider for FixedCatalogProvider { } } -pub(crate) extern "C" fn create_catalog_provider() -> FFI_CatalogProvider { +pub(crate) extern "C" fn create_catalog_provider( + task_ctx_accessor: FFI_TaskContextAccessor, +) -> FFI_CatalogProvider { let catalog_provider = Arc::new(FixedCatalogProvider::default()); - FFI_CatalogProvider::new(catalog_provider, None) + FFI_CatalogProvider::new(catalog_provider, None, task_ctx_accessor) } diff --git a/datafusion/ffi/src/tests/mod.rs b/datafusion/ffi/src/tests/mod.rs index 816086c32041..979531d47e99 100644 --- a/datafusion/ffi/src/tests/mod.rs +++ b/datafusion/ffi/src/tests/mod.rs @@ -34,12 +34,12 @@ use crate::udaf::FFI_AggregateUDF; use crate::udwf::FFI_WindowUDF; use super::{table_provider::FFI_TableProvider, udf::FFI_ScalarUDF}; +use crate::session::task_ctx_accessor::FFI_TaskContextAccessor; +use crate::tests::udf_udaf_udwf::{create_ffi_cumedist_func, create_ffi_ntile_func}; use arrow::array::RecordBatch; +use arrow::datatypes::{DataType, Field, Schema}; use async_provider::create_async_table_provider; -use datafusion::{ - arrow::datatypes::{DataType, Field, Schema}, - common::record_batch, -}; +use datafusion_common::record_batch; use sync_provider::create_sync_table_provider; use udf_udaf_udwf::{ create_ffi_abs_func, create_ffi_random_func, create_ffi_rank_func, @@ -60,25 +60,37 @@ pub mod utils; /// module. 
pub struct ForeignLibraryModule { /// Construct an opinionated catalog provider - pub create_catalog: extern "C" fn() -> FFI_CatalogProvider, + pub create_catalog: + extern "C" fn(task_ctx_accessor: FFI_TaskContextAccessor) -> FFI_CatalogProvider, /// Constructs the table provider - pub create_table: extern "C" fn(synchronous: bool) -> FFI_TableProvider, + pub create_table: extern "C" fn( + synchronous: bool, + task_ctx_accessor: FFI_TaskContextAccessor, + ) -> FFI_TableProvider, /// Create a scalar UDF pub create_scalar_udf: extern "C" fn() -> FFI_ScalarUDF, pub create_nullary_udf: extern "C" fn() -> FFI_ScalarUDF, - pub create_table_function: extern "C" fn() -> FFI_TableFunction, + pub create_table_function: + extern "C" fn(task_ctx_accessor: FFI_TaskContextAccessor) -> FFI_TableFunction, /// Create an aggregate UDAF using sum - pub create_sum_udaf: extern "C" fn() -> FFI_AggregateUDF, + pub create_sum_udaf: extern "C" fn(FFI_TaskContextAccessor) -> FFI_AggregateUDF, /// Create grouping UDAF using stddev - pub create_stddev_udaf: extern "C" fn() -> FFI_AggregateUDF, + pub create_stddev_udaf: extern "C" fn(FFI_TaskContextAccessor) -> FFI_AggregateUDF, + + /// Rank will test `evaluate` + pub create_rank_udwf: extern "C" fn(FFI_TaskContextAccessor) -> FFI_WindowUDF, + + /// NTile will test `evaluate_all` + pub create_ntile_udwf: extern "C" fn(FFI_TaskContextAccessor) -> FFI_WindowUDF, - pub create_rank_udwf: extern "C" fn() -> FFI_WindowUDF, + /// NTile will test `evaluate_all_with_rank` + pub create_cumedist_udwf: extern "C" fn(FFI_TaskContextAccessor) -> FFI_WindowUDF, pub version: extern "C" fn() -> u64, } @@ -111,10 +123,13 @@ pub fn create_record_batch(start_value: i32, num_values: usize) -> RecordBatch { /// Here we only wish to create a simple table provider as an example. /// We create an in-memory table and convert it to it's FFI counterpart. 
-extern "C" fn construct_table_provider(synchronous: bool) -> FFI_TableProvider { +extern "C" fn construct_table_provider( + synchronous: bool, + task_ctx_accessor: FFI_TaskContextAccessor, +) -> FFI_TableProvider { match synchronous { - true => create_sync_table_provider(), - false => create_async_table_provider(), + true => create_sync_table_provider(task_ctx_accessor), + false => create_async_table_provider(task_ctx_accessor), } } @@ -130,6 +145,8 @@ pub fn get_foreign_library_module() -> ForeignLibraryModuleRef { create_sum_udaf: create_ffi_sum_func, create_stddev_udaf: create_ffi_stddev_func, create_rank_udwf: create_ffi_rank_func, + create_ntile_udwf: create_ffi_ntile_func, + create_cumedist_udwf: create_ffi_cumedist_func, version: super::version, } .leak_into_prefix() diff --git a/datafusion/ffi/src/tests/sync_provider.rs b/datafusion/ffi/src/tests/sync_provider.rs index ff85e0b15b39..bde8939afe91 100644 --- a/datafusion/ffi/src/tests/sync_provider.rs +++ b/datafusion/ffi/src/tests/sync_provider.rs @@ -15,14 +15,15 @@ // specific language governing permissions and limitations // under the License. 
-use std::sync::Arc; - -use crate::table_provider::FFI_TableProvider; -use datafusion::datasource::MemTable; - use super::{create_record_batch, create_test_schema}; +use crate::session::task_ctx_accessor::FFI_TaskContextAccessor; +use crate::table_provider::FFI_TableProvider; +use datafusion_catalog::MemTable; +use std::sync::Arc; -pub(crate) fn create_sync_table_provider() -> FFI_TableProvider { +pub(crate) fn create_sync_table_provider( + task_ctx_accessor: FFI_TaskContextAccessor, +) -> FFI_TableProvider { let schema = create_test_schema(); // It is useful to create these as multiple record batches @@ -35,5 +36,5 @@ pub(crate) fn create_sync_table_provider() -> FFI_TableProvider { let table_provider = MemTable::try_new(schema, vec![batches]).unwrap(); - FFI_TableProvider::new(Arc::new(table_provider), true, None) + FFI_TableProvider::new(Arc::new(table_provider), true, None, task_ctx_accessor) } diff --git a/datafusion/ffi/src/tests/udf_udaf_udwf.rs b/datafusion/ffi/src/tests/udf_udaf_udwf.rs index 55e31ef3ab77..6ff632efacf7 100644 --- a/datafusion/ffi/src/tests/udf_udaf_udwf.rs +++ b/datafusion/ffi/src/tests/udf_udaf_udwf.rs @@ -19,15 +19,18 @@ use crate::{ udaf::FFI_AggregateUDF, udf::FFI_ScalarUDF, udtf::FFI_TableFunction, udwf::FFI_WindowUDF, }; -use datafusion::{ - catalog::TableFunctionImpl, - functions::math::{abs::AbsFunc, random::RandomFunc}, - functions_aggregate::{stddev::Stddev, sum::Sum}, - functions_table::generate_series::RangeFunc, - functions_window::rank::Rank, - logical_expr::{AggregateUDF, ScalarUDF, WindowUDF}, -}; +use crate::session::task_ctx_accessor::FFI_TaskContextAccessor; +use datafusion_catalog::TableFunctionImpl; +use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; +use datafusion_functions::math::abs::AbsFunc; +use datafusion_functions::math::random::RandomFunc; +use datafusion_functions_aggregate::stddev::Stddev; +use datafusion_functions_aggregate::sum::Sum; +use datafusion_functions_table::generate_series::RangeFunc; +use 
datafusion_functions_window::cume_dist::CumeDist; +use datafusion_functions_window::ntile::Ntile; +use datafusion_functions_window::rank::Rank; use std::sync::Arc; pub(crate) extern "C" fn create_ffi_abs_func() -> FFI_ScalarUDF { @@ -42,32 +45,56 @@ pub(crate) extern "C" fn create_ffi_random_func() -> FFI_ScalarUDF { udf.into() } -pub(crate) extern "C" fn create_ffi_table_func() -> FFI_TableFunction { +pub(crate) extern "C" fn create_ffi_table_func( + task_ctx_accessor: FFI_TaskContextAccessor, +) -> FFI_TableFunction { let udtf: Arc = Arc::new(RangeFunc {}); - FFI_TableFunction::new(udtf, None) + FFI_TableFunction::new(udtf, None, task_ctx_accessor) } -pub(crate) extern "C" fn create_ffi_sum_func() -> FFI_AggregateUDF { +pub(crate) extern "C" fn create_ffi_sum_func( + task_ctx_accessor: FFI_TaskContextAccessor, +) -> FFI_AggregateUDF { let udaf: Arc = Arc::new(Sum::new().into()); - udaf.into() + FFI_AggregateUDF::new(udaf, task_ctx_accessor) } -pub(crate) extern "C" fn create_ffi_stddev_func() -> FFI_AggregateUDF { +pub(crate) extern "C" fn create_ffi_stddev_func( + task_ctx_accessor: FFI_TaskContextAccessor, +) -> FFI_AggregateUDF { let udaf: Arc = Arc::new(Stddev::new().into()); - udaf.into() + FFI_AggregateUDF::new(udaf, task_ctx_accessor) } -pub(crate) extern "C" fn create_ffi_rank_func() -> FFI_WindowUDF { +pub(crate) extern "C" fn create_ffi_rank_func( + task_ctx_accessor: FFI_TaskContextAccessor, +) -> FFI_WindowUDF { let udwf: Arc = Arc::new( Rank::new( "rank_demo".to_string(), - datafusion::functions_window::rank::RankType::Basic, + datafusion_functions_window::rank::RankType::Basic, ) .into(), ); - udwf.into() + FFI_WindowUDF::new(udwf, task_ctx_accessor) +} + +pub(crate) extern "C" fn create_ffi_ntile_func( + task_ctx_accessor: FFI_TaskContextAccessor, +) -> FFI_WindowUDF { + let udwf: Arc = Arc::new(Ntile::new().into()); + + FFI_WindowUDF::new(udwf, task_ctx_accessor) +} + +pub(crate) extern "C" fn create_ffi_cumedist_func( + task_ctx_accessor: 
FFI_TaskContextAccessor, +) -> FFI_WindowUDF { + let udwf: Arc = Arc::new(CumeDist::new().into()); + + FFI_WindowUDF::new(udwf, task_ctx_accessor) } diff --git a/datafusion/ffi/src/tests/utils.rs b/datafusion/ffi/src/tests/utils.rs index 6465b17d9b60..4df0334c21fe 100644 --- a/datafusion/ffi/src/tests/utils.rs +++ b/datafusion/ffi/src/tests/utils.rs @@ -17,7 +17,7 @@ use crate::tests::ForeignLibraryModuleRef; use abi_stable::library::RootModule; -use datafusion::error::{DataFusionError, Result}; +use datafusion_common::error::{DataFusionError, Result}; use std::path::Path; /// Compute the path to the library. It would be preferable to simply use diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs index 80b872159f48..f11ef0e3ade8 100644 --- a/datafusion/ffi/src/udaf/accumulator.rs +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -15,21 +15,20 @@ // specific language governing permissions and limitations // under the License. -use std::{ffi::c_void, ops::Deref}; - +use crate::{arrow_wrappers::WrappedArray, df_result, rresult, rresult_return}; use abi_stable::{ std_types::{RResult, RString, RVec}, StableAbi, }; use arrow::{array::ArrayRef, error::ArrowError}; -use datafusion::{ +use datafusion_common::{ error::{DataFusionError, Result}, - logical_expr::Accumulator, scalar::ScalarValue, }; +use datafusion_expr::Accumulator; use prost::Message; - -use crate::{arrow_wrappers::WrappedArray, df_result, rresult, rresult_return}; +use std::ptr::null_mut; +use std::{ffi::c_void, ops::Deref}; /// A stable struct for sharing [`Accumulator`] across FFI boundaries. /// For an explanation of each field, see the corresponding function @@ -70,6 +69,10 @@ pub struct FFI_Accumulator { /// Internal data. This is only to be accessed by the provider of the accumulator. /// A [`ForeignAccumulator`] should never attempt to access this data. 
pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. + pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_Accumulator {} @@ -173,9 +176,11 @@ unsafe extern "C" fn retract_batch_fn_wrapper( } unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_Accumulator) { - let private_data = - Box::from_raw(accumulator.private_data as *mut AccumulatorPrivateData); - drop(private_data); + if !accumulator.private_data.is_null() { + let private_data = + Box::from_raw(accumulator.private_data as *mut AccumulatorPrivateData); + drop(private_data); + } } impl From> for FFI_Accumulator { @@ -193,6 +198,7 @@ impl From> for FFI_Accumulator { supports_retract_batch, release: release_fn_wrapper, private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -217,9 +223,20 @@ pub struct ForeignAccumulator { unsafe impl Send for ForeignAccumulator {} unsafe impl Sync for ForeignAccumulator {} -impl From for ForeignAccumulator { - fn from(accumulator: FFI_Accumulator) -> Self { - Self { accumulator } +impl From for Box { + fn from(mut accumulator: FFI_Accumulator) -> Self { + if (accumulator.library_marker_id)() == crate::get_library_marker_id() { + unsafe { + let private_data = Box::from_raw( + accumulator.private_data as *mut AccumulatorPrivateData, + ); + // We must set this to null to avoid a double free + accumulator.private_data = null_mut(); + private_data.accumulator + } + } else { + Box::new(ForeignAccumulator { accumulator }) + } } } @@ -313,7 +330,7 @@ mod tests { scalar::ScalarValue, }; - use super::{FFI_Accumulator, ForeignAccumulator}; + use super::FFI_Accumulator; #[test] fn test_foreign_avg_accumulator() -> Result<()> { @@ -322,8 +339,9 @@ mod tests { let original_supports_retract = original_accum.supports_retract_batch(); let boxed_accum: Box = Box::new(original_accum); - let ffi_accum: 
FFI_Accumulator = boxed_accum.into(); - let mut foreign_accum: ForeignAccumulator = ffi_accum.into(); + let mut ffi_accum: FFI_Accumulator = boxed_accum.into(); + ffi_accum.library_marker_id = crate::mock_foreign_marker_id; + let mut foreign_accum: Box = ffi_accum.into(); // Send in an array to average. There are 5 values and it should average to 30.0 let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]); diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs index 6ac0a0b21d2d..ce406fb6ea67 100644 --- a/datafusion/ffi/src/udaf/accumulator_args.rs +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -18,27 +18,26 @@ use std::sync::Arc; use crate::arrow_wrappers::WrappedSchema; +use crate::session::task_ctx_accessor::FFI_TaskContextAccessor; use abi_stable::{ std_types::{RString, RVec}, StableAbi, }; use arrow::{datatypes::Schema, ffi::FFI_ArrowSchema}; use arrow_schema::FieldRef; -use datafusion::{ - error::DataFusionError, - logical_expr::function::AccumulatorArgs, - physical_expr::{PhysicalExpr, PhysicalSortExpr}, - prelude::SessionContext, -}; +use datafusion_common::error::DataFusionError; use datafusion_common::exec_datafusion_err; -use datafusion_proto::{ - physical_plan::{ - from_proto::{parse_physical_exprs, parse_physical_sort_exprs}, - to_proto::{serialize_physical_exprs, serialize_physical_sort_exprs}, - DefaultPhysicalExtensionCodec, - }, - protobuf::PhysicalAggregateExprNode, +use datafusion_execution::TaskContext; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use datafusion_proto::physical_plan::from_proto::{ + parse_physical_exprs, parse_physical_sort_exprs, +}; +use datafusion_proto::physical_plan::to_proto::{ + serialize_physical_exprs, serialize_physical_sort_exprs, }; +use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; +use datafusion_proto::protobuf::PhysicalAggregateExprNode; use prost::Message; /// 
A stable struct for sharing [`AccumulatorArgs`] across FFI boundaries. @@ -50,15 +49,22 @@ use prost::Message; pub struct FFI_AccumulatorArgs { return_field: WrappedSchema, schema: WrappedSchema, + ignore_nulls: bool, + is_distinct: bool, is_reversed: bool, name: RString, physical_expr_def: RVec, -} -impl TryFrom> for FFI_AccumulatorArgs { - type Error = DataFusionError; + /// Accessor for TaskContext to be used during protobuf serialization + /// and deserialization. + task_ctx_accessor: FFI_TaskContextAccessor, +} - fn try_from(args: AccumulatorArgs) -> Result { +impl FFI_AccumulatorArgs { + pub fn try_new( + args: AccumulatorArgs, + task_ctx_accessor: FFI_TaskContextAccessor, + ) -> Result { let return_field = WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?); let schema = WrappedSchema(FFI_ArrowSchema::try_from(args.schema)?); @@ -84,8 +90,11 @@ impl TryFrom> for FFI_AccumulatorArgs { return_field, schema, is_reversed: args.is_reversed, + ignore_nulls: args.ignore_nulls, + is_distinct: args.is_distinct, name: args.name.into(), physical_expr_def, + task_ctx_accessor, }) } } @@ -120,8 +129,7 @@ impl TryFrom for ForeignAccumulatorArgs { let return_field = Arc::new((&value.return_field.0).try_into()?); let schema = Schema::try_from(&value.schema.0)?; - let default_ctx = SessionContext::new(); - let task_ctx = default_ctx.task_ctx(); + let task_ctx: Arc = (&value.task_ctx_accessor).try_into()?; let codex = DefaultPhysicalExtensionCodec {}; let order_bys = parse_physical_sort_exprs( @@ -142,11 +150,11 @@ impl TryFrom for ForeignAccumulatorArgs { return_field, schema, expr_fields, - ignore_nulls: proto_def.ignore_nulls, + ignore_nulls: value.ignore_nulls, order_bys, is_reversed: value.is_reversed, name: value.name.to_string(), - is_distinct: proto_def.distinct, + is_distinct: value.is_distinct, exprs, }) } @@ -172,10 +180,13 @@ impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> { mod tests { use super::{FFI_AccumulatorArgs, 
ForeignAccumulatorArgs}; use arrow::datatypes::{DataType, Field, Schema}; + use datafusion::prelude::SessionContext; use datafusion::{ error::Result, logical_expr::function::AccumulatorArgs, physical_expr::PhysicalSortExpr, physical_plan::expressions::col, }; + use datafusion_execution::TaskContextAccessor; + use std::sync::Arc; #[test] fn test_round_trip_accumulator_args() -> Result<()> { @@ -192,8 +203,10 @@ mod tests { exprs: &[col("a", &schema)?], }; let orig_str = format!("{orig_args:?}"); + let ctx = Arc::new(SessionContext::new()); + let task_ctx_accessor = Arc::clone(&ctx) as Arc; - let ffi_args: FFI_AccumulatorArgs = orig_args.try_into()?; + let ffi_args = FFI_AccumulatorArgs::try_new(orig_args, task_ctx_accessor.into())?; let foreign_args: ForeignAccumulatorArgs = ffi_args.try_into()?; let round_trip_args: AccumulatorArgs = (&foreign_args).into(); diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs index 58a18c69db7c..537c6f4a8d20 100644 --- a/datafusion/ffi/src/udaf/groups_accumulator.rs +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::{ffi::c_void, ops::Deref, sync::Arc}; - use crate::{ arrow_wrappers::{WrappedArray, WrappedSchema}, df_result, rresult, rresult_return, @@ -30,10 +28,10 @@ use arrow::{ error::ArrowError, ffi::to_ffi, }; -use datafusion::{ - error::{DataFusionError, Result}, - logical_expr::{EmitTo, GroupsAccumulator}, -}; +use datafusion_common::error::{DataFusionError, Result}; +use datafusion_expr::{EmitTo, GroupsAccumulator}; +use std::ptr::null_mut; +use std::{ffi::c_void, ops::Deref, sync::Arc}; /// A stable struct for sharing [`GroupsAccumulator`] across FFI boundaries. /// For an explanation of each field, see the corresponding function @@ -86,6 +84,10 @@ pub struct FFI_GroupsAccumulator { /// Internal data. 
This is only to be accessed by the provider of the accumulator. /// A [`ForeignGroupsAccumulator`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. + pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_GroupsAccumulator {} @@ -215,9 +217,11 @@ unsafe extern "C" fn convert_to_state_fn_wrapper( } unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_GroupsAccumulator) { - let private_data = - Box::from_raw(accumulator.private_data as *mut GroupsAccumulatorPrivateData); - drop(private_data); + if !accumulator.private_data.is_null() { + let private_data = + Box::from_raw(accumulator.private_data as *mut GroupsAccumulatorPrivateData); + drop(private_data); + } } impl From> for FFI_GroupsAccumulator { @@ -236,6 +240,7 @@ impl From> for FFI_GroupsAccumulator { release: release_fn_wrapper, private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -260,9 +265,19 @@ pub struct ForeignGroupsAccumulator { unsafe impl Send for ForeignGroupsAccumulator {} unsafe impl Sync for ForeignGroupsAccumulator {} -impl From for ForeignGroupsAccumulator { - fn from(accumulator: FFI_GroupsAccumulator) -> Self { - Self { accumulator } +impl From for Box { + fn from(mut accumulator: FFI_GroupsAccumulator) -> Self { + if (accumulator.library_marker_id)() == crate::get_library_marker_id() { + unsafe { + let private_data = Box::from_raw( + accumulator.private_data as *mut GroupsAccumulatorPrivateData, + ); + accumulator.private_data = null_mut(); + private_data.accumulator + } + } else { + Box::new(ForeignGroupsAccumulator { accumulator }) + } } } @@ -436,14 +451,15 @@ mod tests { }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; - use super::{FFI_EmitTo, FFI_GroupsAccumulator, ForeignGroupsAccumulator}; + use 
super::{FFI_EmitTo, FFI_GroupsAccumulator}; #[test] - fn test_foreign_avg_accumulator() -> Result<()> { + fn test_foreign_bool_groups_accumulator() -> Result<()> { let boxed_accum: Box = Box::new(BooleanGroupsAccumulator::new(|a, b| a && b, true)); - let ffi_accum: FFI_GroupsAccumulator = boxed_accum.into(); - let mut foreign_accum: ForeignGroupsAccumulator = ffi_accum.into(); + let mut ffi_accum: FFI_GroupsAccumulator = boxed_accum.into(); + ffi_accum.library_marker_id = crate::mock_foreign_marker_id; + let mut foreign_accum: Box = ffi_accum.into(); // Send in an array to evaluate. We want a mean of 30 and standard deviation of 4. let values = create_array!(Boolean, vec![true, true, true, false, true, true]); diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index ce5611590b67..7b7bb6ac9871 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -19,30 +19,28 @@ use abi_stable::{ std_types::{ROption, RResult, RStr, RString, RVec}, StableAbi, }; -use accumulator::{FFI_Accumulator, ForeignAccumulator}; +use accumulator::FFI_Accumulator; use accumulator_args::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; use arrow::datatypes::{DataType, Field}; use arrow::ffi::FFI_ArrowSchema; use arrow_schema::FieldRef; -use datafusion::{ - error::DataFusionError, - logical_expr::{ - function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, - type_coercion::functions::fields_with_aggregate_udf, - utils::AggregateOrderSensitivity, - Accumulator, GroupsAccumulator, - }, -}; -use datafusion::{ - error::Result, - logical_expr::{AggregateUDF, AggregateUDFImpl, Signature}, +use datafusion_common::error::DataFusionError; +use datafusion_expr::{ + function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, + type_coercion::functions::fields_with_aggregate_udf, + utils::AggregateOrderSensitivity, + Accumulator, GroupsAccumulator, }; + +use datafusion_common::error::Result; use 
datafusion_common::exec_datafusion_err; +use datafusion_expr::{AggregateUDF, AggregateUDFImpl, Signature}; use datafusion_proto_common::from_proto::parse_proto_fields_to_fields; -use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator}; +use groups_accumulator::FFI_GroupsAccumulator; use std::hash::{Hash, Hasher}; use std::{ffi::c_void, sync::Arc}; +use crate::session::task_ctx_accessor::FFI_TaskContextAccessor; use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; use crate::{ arrow_wrappers::WrappedSchema, @@ -135,6 +133,10 @@ pub struct FFI_AggregateUDF { arg_types: RVec, ) -> RResult, RString>, + /// Accessor for TaskContext to be used during protobuf serialization + /// and deserialization. + task_ctx_accessor: FFI_TaskContextAccessor, + /// Used to create a clone on the provider of the udaf. This should /// only need to be called by the receiver of the udaf. pub clone: unsafe extern "C" fn(udaf: &Self) -> Self, @@ -145,6 +147,10 @@ pub struct FFI_AggregateUDF { /// Internal data. This is only to be accessed by the provider of the udaf. /// A [`ForeignAggregateUDF`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. 
+ pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_AggregateUDF {} @@ -155,9 +161,9 @@ pub struct AggregateUDFPrivateData { } impl FFI_AggregateUDF { - unsafe fn inner(&self) -> &Arc { + fn inner(&self) -> &Arc { let private_data = self.private_data as *const AggregateUDFPrivateData; - &(*private_data).udaf + unsafe { &(*private_data).udaf } } } @@ -236,6 +242,7 @@ unsafe extern "C" fn with_beneficial_ordering_fn_wrapper( udaf: &FFI_AggregateUDF, beneficial_ordering: bool, ) -> RResult, RString> { + let task_ctx_accessor = udaf.task_ctx_accessor.clone(); let udaf = udaf.inner().as_ref().clone(); let result = rresult_return!(udaf.with_beneficial_ordering(beneficial_ordering)); @@ -243,7 +250,7 @@ unsafe extern "C" fn with_beneficial_ordering_fn_wrapper( .map(|func| func.with_beneficial_ordering(beneficial_ordering)) .transpose()) .flatten() - .map(|func| FFI_AggregateUDF::from(Arc::new(func))); + .map(|func| FFI_AggregateUDF::new(Arc::new(func), task_ctx_accessor)); RResult::ROk(result.into()) } @@ -326,7 +333,7 @@ unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) { } unsafe extern "C" fn clone_fn_wrapper(udaf: &FFI_AggregateUDF) -> FFI_AggregateUDF { - Arc::clone(udaf.inner()).into() + FFI_AggregateUDF::new(Arc::clone(udaf.inner()), udaf.task_ctx_accessor.clone()) } impl Clone for FFI_AggregateUDF { @@ -335,8 +342,11 @@ impl Clone for FFI_AggregateUDF { } } -impl From> for FFI_AggregateUDF { - fn from(udaf: Arc) -> Self { +impl FFI_AggregateUDF { + pub fn new( + udaf: Arc, + task_ctx_accessor: FFI_TaskContextAccessor, + ) -> Self { let name = udaf.name().into(); let aliases = udaf.aliases().iter().map(|a| a.to_owned().into()).collect(); let is_nullable = udaf.is_nullable(); @@ -358,9 +368,11 @@ impl From> for FFI_AggregateUDF { state_fields: state_fields_fn_wrapper, order_sensitivity: order_sensitivity_fn_wrapper, coerce_types: coerce_types_fn_wrapper, + task_ctx_accessor, clone: clone_fn_wrapper, release: 
release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -400,18 +412,22 @@ impl Hash for ForeignAggregateUDF { } } -impl TryFrom<&FFI_AggregateUDF> for ForeignAggregateUDF { +impl TryFrom<&FFI_AggregateUDF> for Arc { type Error = DataFusionError; fn try_from(udaf: &FFI_AggregateUDF) -> Result { + if (udaf.library_marker_id)() == crate::get_library_marker_id() { + return Ok(Arc::clone(udaf.inner().inner())); + } + let signature = Signature::user_defined((&udaf.volatility).into()); let aliases = udaf.aliases.iter().map(|s| s.to_string()).collect(); - Ok(Self { + Ok(Arc::new(ForeignAggregateUDF { udaf: udaf.clone(), signature, aliases, - }) + })) } } @@ -451,11 +467,11 @@ impl AggregateUDFImpl for ForeignAggregateUDF { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let args = acc_args.try_into()?; + let args = + FFI_AccumulatorArgs::try_new(acc_args, self.udaf.task_ctx_accessor.clone())?; unsafe { - df_result!((self.udaf.accumulator)(&self.udaf, args)).map(|accum| { - Box::new(ForeignAccumulator::from(accum)) as Box - }) + df_result!((self.udaf.accumulator)(&self.udaf, args)) + .map(>::from) } } @@ -499,13 +515,15 @@ impl AggregateUDFImpl for ForeignAggregateUDF { } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { - let args = match FFI_AccumulatorArgs::try_from(args) { - Ok(v) => v, - Err(e) => { - log::warn!("Attempting to convert accumulator arguments: {e}"); - return false; - } - }; + let args = + match FFI_AccumulatorArgs::try_new(args, self.udaf.task_ctx_accessor.clone()) + { + Ok(v) => v, + Err(e) => { + log::warn!("Attempting to convert accumulator arguments: {e}"); + return false; + } + }; unsafe { (self.udaf.groups_accumulator_supported)(&self.udaf, args) } } @@ -514,15 +532,12 @@ impl AggregateUDFImpl for ForeignAggregateUDF { &self, args: AccumulatorArgs, ) -> Result> { - let args = FFI_AccumulatorArgs::try_from(args)?; + let args = + 
FFI_AccumulatorArgs::try_new(args, self.udaf.task_ctx_accessor.clone())?; unsafe { - df_result!((self.udaf.create_groups_accumulator)(&self.udaf, args)).map( - |accum| { - Box::new(ForeignGroupsAccumulator::from(accum)) - as Box - }, - ) + df_result!((self.udaf.create_groups_accumulator)(&self.udaf, args)) + .map(>::from) } } @@ -534,11 +549,11 @@ impl AggregateUDFImpl for ForeignAggregateUDF { &self, args: AccumulatorArgs, ) -> Result> { - let args = args.try_into()?; + let args = + FFI_AccumulatorArgs::try_new(args, self.udaf.task_ctx_accessor.clone())?; unsafe { - df_result!((self.udaf.create_sliding_accumulator)(&self.udaf, args)).map( - |accum| Box::new(ForeignAccumulator::from(accum)) as Box, - ) + df_result!((self.udaf.create_sliding_accumulator)(&self.udaf, args)) + .map(>::from) } } @@ -553,11 +568,9 @@ impl AggregateUDFImpl for ForeignAggregateUDF { ))? .into_option(); - let result = result - .map(|func| ForeignAggregateUDF::try_from(&func)) - .transpose()?; - - Ok(result.map(|func| Arc::new(func) as Arc)) + result + .map(|func| >::try_from(&func)) + .transpose() } } @@ -613,17 +626,18 @@ impl From for FFI_AggregateOrderSensitivity { #[cfg(test)] mod tests { + use super::*; use arrow::datatypes::Schema; + use datafusion::prelude::SessionContext; use datafusion::{ common::create_array, functions_aggregate::sum::Sum, physical_expr::PhysicalSortExpr, physical_plan::expressions::col, scalar::ScalarValue, }; + use datafusion_execution::TaskContextAccessor; use std::any::Any; use std::collections::HashMap; - use super::*; - #[derive(Default, Debug, Hash, Eq, PartialEq)] struct SumWithCopiedMetadata { inner: Sum, @@ -658,13 +672,17 @@ mod tests { fn create_test_foreign_udaf( original_udaf: impl AggregateUDFImpl + 'static, + ctx: &Arc, ) -> Result { + let task_ctx_accessor = Arc::clone(ctx) as Arc; let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); - let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + let mut local_udaf = + 
FFI_AggregateUDF::new(Arc::clone(&original_udaf), task_ctx_accessor.into()); + local_udaf.library_marker_id = crate::mock_foreign_marker_id; - let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; - Ok(foreign_udaf.into()) + let foreign_udaf: Arc = (&local_udaf).try_into()?; + Ok(AggregateUDF::new_from_shared_impl(foreign_udaf)) } #[test] @@ -672,13 +690,17 @@ mod tests { let original_udaf = Sum::new(); let original_name = original_udaf.name().to_owned(); let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + let task_ctx_accessor = + Arc::new(SessionContext::new()) as Arc; // Convert to FFI format - let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + let mut local_udaf = + FFI_AggregateUDF::new(Arc::clone(&original_udaf), task_ctx_accessor.into()); + local_udaf.library_marker_id = crate::mock_foreign_marker_id; // Convert back to native format - let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; - let foreign_udaf: AggregateUDF = foreign_udaf.into(); + let foreign_udaf: Arc = (&local_udaf).try_into()?; + let foreign_udaf = AggregateUDF::new_from_shared_impl(foreign_udaf); assert_eq!(original_name, foreign_udaf.name()); Ok(()) @@ -686,8 +708,9 @@ mod tests { #[test] fn test_foreign_udaf_aliases() -> Result<()> { + let ctx = Arc::new(SessionContext::new()); let foreign_udaf = - create_test_foreign_udaf(Sum::new())?.with_aliases(["my_function"]); + create_test_foreign_udaf(Sum::new(), &ctx)?.with_aliases(["my_function"]); let return_field = foreign_udaf @@ -699,7 +722,8 @@ mod tests { #[test] fn test_foreign_udaf_accumulator() -> Result<()> { - let foreign_udaf = create_test_foreign_udaf(Sum::new())?; + let ctx = Arc::new(SessionContext::new()); + let foreign_udaf = create_test_foreign_udaf(Sum::new(), &ctx)?; let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); let acc_args = AccumulatorArgs { @@ -726,13 +750,17 @@ mod tests { fn test_round_trip_udaf_metadata() -> Result<()> { let 
original_udaf = SumWithCopiedMetadata::default(); let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + let task_ctx_accessor = + Arc::new(SessionContext::new()) as Arc; // Convert to FFI format - let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + let mut local_udaf = + FFI_AggregateUDF::new(Arc::clone(&original_udaf), task_ctx_accessor.into()); + local_udaf.library_marker_id = crate::mock_foreign_marker_id; // Convert back to native format - let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; - let foreign_udaf: AggregateUDF = foreign_udaf.into(); + let foreign_udaf: Arc = (&local_udaf).try_into()?; + let foreign_udaf = AggregateUDF::new_from_shared_impl(foreign_udaf); let metadata: HashMap = [("a_key".to_string(), "a_value".to_string())] @@ -749,8 +777,10 @@ mod tests { #[test] fn test_beneficial_ordering() -> Result<()> { + let ctx = Arc::new(SessionContext::new()); let foreign_udaf = create_test_foreign_udaf( datafusion::functions_aggregate::first_last::FirstValue::new(), + &ctx, )?; let foreign_udaf = foreign_udaf.with_beneficial_ordering(true)?.unwrap(); @@ -776,7 +806,8 @@ mod tests { #[test] fn test_sliding_accumulator() -> Result<()> { - let foreign_udaf = create_test_foreign_udaf(Sum::new())?; + let ctx = Arc::new(SessionContext::new()); + let foreign_udaf = create_test_foreign_udaf(Sum::new(), &ctx)?; let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); // Note: sum distinct is only support Int64 until now diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index 5e59cfc5ecb0..ffd2ca0d8a56 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -32,17 +32,12 @@ use arrow::{ ffi::{from_ffi, to_ffi, FFI_ArrowSchema}, }; use arrow_schema::FieldRef; -use datafusion::config::ConfigOptions; -use datafusion::logical_expr::ReturnFieldArgs; -use datafusion::{ - error::DataFusionError, - 
logical_expr::type_coercion::functions::data_types_with_scalar_udf, -}; -use datafusion::{ - error::Result, - logical_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, - }, +use datafusion_common::config::ConfigOptions; +use datafusion_common::error::{DataFusionError, Result}; +use datafusion_common::not_impl_err; +use datafusion_expr::{ + type_coercion::functions::data_types_with_scalar_udf, ColumnarValue, ReturnFieldArgs, + ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, }; use return_type_args::{ FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned, @@ -66,13 +61,6 @@ pub struct FFI_ScalarUDF { /// FFI equivalent to the `volatility` of a [`ScalarUDF`] pub volatility: FFI_Volatility, - /// Determines the return type of the underlying [`ScalarUDF`] based on the - /// argument types. - pub return_type: unsafe extern "C" fn( - udf: &Self, - arg_types: RVec, - ) -> RResult, - /// Determines the return info of the underlying [`ScalarUDF`]. Either this /// or return_type may be implemented on a UDF. pub return_field_from_args: unsafe extern "C" fn( @@ -114,6 +102,10 @@ pub struct FFI_ScalarUDF { /// Internal data. This is only to be accessed by the provider of the udf. /// A [`ForeignScalarUDF`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. 
+ pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_ScalarUDF {} @@ -123,29 +115,18 @@ pub struct ScalarUDFPrivateData { pub udf: Arc, } -unsafe extern "C" fn return_type_fn_wrapper( - udf: &FFI_ScalarUDF, - arg_types: RVec, -) -> RResult { - let private_data = udf.private_data as *const ScalarUDFPrivateData; - let udf = &(*private_data).udf; - - let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); - - let return_type = udf - .return_type(&arg_types) - .and_then(|v| FFI_ArrowSchema::try_from(v).map_err(DataFusionError::from)) - .map(WrappedSchema); - - rresult!(return_type) +impl FFI_ScalarUDF { + fn inner(&self) -> &Arc { + let private_data = self.private_data as *const ScalarUDFPrivateData; + unsafe { &(*private_data).udf } + } } unsafe extern "C" fn return_field_from_args_fn_wrapper( udf: &FFI_ScalarUDF, args: FFI_ReturnFieldArgs, ) -> RResult { - let private_data = udf.private_data as *const ScalarUDFPrivateData; - let udf = &(*private_data).udf; + let udf = udf.inner(); let args: ForeignReturnFieldArgsOwned = rresult_return!((&args).try_into()); let args_ref: ForeignReturnFieldArgs = (&args).into(); @@ -162,8 +143,7 @@ unsafe extern "C" fn coerce_types_fn_wrapper( udf: &FFI_ScalarUDF, arg_types: RVec, ) -> RResult, RString> { - let private_data = udf.private_data as *const ScalarUDFPrivateData; - let udf = &(*private_data).udf; + let udf = udf.inner(); let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); @@ -179,8 +159,7 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper( number_rows: usize, return_field: WrappedSchema, ) -> RResult { - let private_data = udf.private_data as *const ScalarUDFPrivateData; - let udf = &(*private_data).udf; + let udf = udf.inner(); let args = args .into_iter() @@ -230,10 +209,9 @@ unsafe extern "C" fn release_fn_wrapper(udf: &mut FFI_ScalarUDF) { } unsafe extern "C" fn clone_fn_wrapper(udf: &FFI_ScalarUDF) -> FFI_ScalarUDF { - let private_data = udf.private_data 
as *const ScalarUDFPrivateData; - let udf_data = &(*private_data); + let udf = udf.inner(); - Arc::clone(&udf_data.udf).into() + Arc::clone(udf).into() } impl Clone for FFI_ScalarUDF { @@ -242,6 +220,12 @@ impl Clone for FFI_ScalarUDF { } } +impl From<&Arc> for FFI_ScalarUDF { + fn from(udf: &Arc) -> Self { + Arc::clone(udf).into() + } +} + impl From> for FFI_ScalarUDF { fn from(udf: Arc) -> Self { let name = udf.name().into(); @@ -257,12 +241,12 @@ impl From> for FFI_ScalarUDF { volatility, short_circuits, invoke_with_args: invoke_with_args_fn_wrapper, - return_type: return_type_fn_wrapper, return_field_from_args: return_field_from_args_fn_wrapper, coerce_types: coerce_types_fn_wrapper, clone: clone_fn_wrapper, release: release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -321,21 +305,25 @@ impl Hash for ForeignScalarUDF { } } -impl TryFrom<&FFI_ScalarUDF> for ForeignScalarUDF { +impl TryFrom<&FFI_ScalarUDF> for Arc { type Error = DataFusionError; fn try_from(udf: &FFI_ScalarUDF) -> Result { - let name = udf.name.to_owned().into(); - let signature = Signature::user_defined((&udf.volatility).into()); - - let aliases = udf.aliases.iter().map(|s| s.to_string()).collect(); - - Ok(Self { - name, - udf: udf.clone(), - aliases, - signature, - }) + if (udf.library_marker_id)() == crate::get_library_marker_id() { + Ok(Arc::clone(udf.inner().inner())) + } else { + let name = udf.name.to_owned().into(); + let signature = Signature::user_defined((&udf.volatility).into()); + + let aliases = udf.aliases.iter().map(|s| s.to_string()).collect(); + + Ok(Arc::new(ForeignScalarUDF { + name, + udf: udf.clone(), + aliases, + signature, + })) + } } } @@ -352,14 +340,8 @@ impl ScalarUDFImpl for ForeignScalarUDF { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; - - let result = unsafe { 
(self.udf.return_type)(&self.udf, arg_types) }; - - let result = df_result!(result); - - result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from)) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + not_impl_err!("return_type is not implemented since return_field_from_args is.") } fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { @@ -455,9 +437,10 @@ mod tests { let original_udf = datafusion::functions::math::abs::AbsFunc::new(); let original_udf = Arc::new(ScalarUDF::from(original_udf)); - let local_udf: FFI_ScalarUDF = Arc::clone(&original_udf).into(); + let mut local_udf: FFI_ScalarUDF = Arc::clone(&original_udf).into(); + local_udf.library_marker_id = crate::mock_foreign_marker_id; - let foreign_udf: ForeignScalarUDF = (&local_udf).try_into()?; + let foreign_udf: Arc = (&local_udf).try_into()?; assert_eq!(original_udf.name(), foreign_udf.name()); diff --git a/datafusion/ffi/src/udf/return_type_args.rs b/datafusion/ffi/src/udf/return_type_args.rs index c437c9537be6..463b3d28a9f3 100644 --- a/datafusion/ffi/src/udf/return_type_args.rs +++ b/datafusion/ffi/src/udf/return_type_args.rs @@ -20,13 +20,13 @@ use abi_stable::{ StableAbi, }; use arrow_schema::FieldRef; -use datafusion::{ - common::exec_datafusion_err, error::DataFusionError, logical_expr::ReturnFieldArgs, - scalar::ScalarValue, +use datafusion_common::{ + error::DataFusionError, exec_datafusion_err, scalar::ScalarValue, }; use crate::arrow_wrappers::WrappedSchema; use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; +use datafusion_expr::ReturnFieldArgs; use prost::Message; /// A stable struct for sharing a [`ReturnFieldArgs`] across FFI boundaries. 
diff --git a/datafusion/ffi/src/udtf.rs b/datafusion/ffi/src/udtf.rs index edd5273c70a8..1b574cc7aaa9 100644 --- a/datafusion/ffi/src/udtf.rs +++ b/datafusion/ffi/src/udtf.rs @@ -22,11 +22,12 @@ use abi_stable::{ StableAbi, }; -use datafusion::error::Result; -use datafusion::{ - catalog::{TableFunctionImpl, TableProvider}, - prelude::{Expr, SessionContext}, -}; +use crate::session::task_ctx_accessor::FFI_TaskContextAccessor; +use crate::{df_result, rresult_return, table_provider::FFI_TableProvider}; +use datafusion_catalog::{TableFunctionImpl, TableProvider}; +use datafusion_common::error::Result; +use datafusion_execution::TaskContext; +use datafusion_expr::Expr; use datafusion_proto::{ logical_plan::{ from_proto::parse_exprs, to_proto::serialize_exprs, DefaultLogicalExtensionCodec, @@ -36,11 +37,6 @@ use datafusion_proto::{ use prost::Message; use tokio::runtime::Handle; -use crate::{ - df_result, rresult_return, - table_provider::{FFI_TableProvider, ForeignTableProvider}, -}; - /// A stable struct for sharing a [`TableFunctionImpl`] across FFI boundaries. #[repr(C)] #[derive(Debug, StableAbi)] @@ -53,6 +49,10 @@ pub struct FFI_TableFunction { args: RVec, ) -> RResult, + /// Accessor for TaskContext to be used during protobuf serialization + /// and deserialization. + pub task_ctx_accessor: FFI_TaskContextAccessor, + /// Used to create a clone on the provider of the udtf. This should /// only need to be called by the receiver of the udtf. pub clone: unsafe extern "C" fn(udtf: &Self) -> Self, @@ -63,6 +63,10 @@ pub struct FFI_TableFunction { /// Internal data. This is only to be accessed by the provider of the udtf. /// A [`ForeignTableFunction`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. 
+ pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_TableFunction {} @@ -89,19 +93,30 @@ unsafe extern "C" fn call_fn_wrapper( udtf: &FFI_TableFunction, args: RVec, ) -> RResult { + let task_ctx_accessor = udtf.task_ctx_accessor.clone(); + let task_ctx: Arc = + rresult_return!((&udtf.task_ctx_accessor).try_into()); + let runtime = udtf.runtime(); let udtf = udtf.inner(); - let default_ctx = SessionContext::new(); let codec = DefaultLogicalExtensionCodec {}; let proto_filters = rresult_return!(LogicalExprList::decode(args.as_ref())); - let args = - rresult_return!(parse_exprs(proto_filters.expr.iter(), &default_ctx, &codec)); + let args = rresult_return!(parse_exprs( + proto_filters.expr.iter(), + task_ctx.as_ref(), + &codec + )); let table_provider = rresult_return!(udtf.call(&args)); - RResult::ROk(FFI_TableProvider::new(table_provider, false, runtime)) + RResult::ROk(FFI_TableProvider::new( + table_provider, + false, + runtime, + task_ctx_accessor, + )) } unsafe extern "C" fn release_fn_wrapper(udtf: &mut FFI_TableFunction) { @@ -110,10 +125,11 @@ unsafe extern "C" fn release_fn_wrapper(udtf: &mut FFI_TableFunction) { } unsafe extern "C" fn clone_fn_wrapper(udtf: &FFI_TableFunction) -> FFI_TableFunction { + let task_ctx_accessor = udtf.task_ctx_accessor.clone(); let runtime = udtf.runtime(); let udtf = udtf.inner(); - FFI_TableFunction::new(Arc::clone(udtf), runtime) + FFI_TableFunction::new(Arc::clone(udtf), runtime, task_ctx_accessor) } impl Clone for FFI_TableFunction { @@ -123,30 +139,20 @@ impl Clone for FFI_TableFunction { } impl FFI_TableFunction { - pub fn new(udtf: Arc, runtime: Option) -> Self { + pub fn new( + udtf: Arc, + runtime: Option, + task_ctx_accessor: FFI_TaskContextAccessor, + ) -> Self { let private_data = Box::new(TableFunctionPrivateData { udtf, runtime }); Self { call: call_fn_wrapper, + task_ctx_accessor, clone: clone_fn_wrapper, release: release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut 
c_void, - } - } -} - -impl From> for FFI_TableFunction { - fn from(udtf: Arc) -> Self { - let private_data = Box::new(TableFunctionPrivateData { - udtf, - runtime: None, - }); - - Self { - call: call_fn_wrapper, - clone: clone_fn_wrapper, - release: release_fn_wrapper, - private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -169,9 +175,13 @@ pub struct ForeignTableFunction(FFI_TableFunction); unsafe impl Send for ForeignTableFunction {} unsafe impl Sync for ForeignTableFunction {} -impl From for ForeignTableFunction { +impl From for Arc { fn from(value: FFI_TableFunction) -> Self { - Self(value) + if (value.library_marker_id)() == crate::get_library_marker_id() { + Arc::clone(value.inner()) + } else { + Arc::new(ForeignTableFunction(value)) + } } } @@ -186,25 +196,25 @@ impl TableFunctionImpl for ForeignTableFunction { let table_provider = unsafe { (self.0.call)(&self.0, filters_serialized) }; let table_provider = df_result!(table_provider)?; - let table_provider: ForeignTableProvider = (&table_provider).into(); - Ok(Arc::new(table_provider)) + Ok((&table_provider).into()) } } #[cfg(test)] mod tests { + use super::*; use arrow::{ array::{ record_batch, ArrayRef, Float64Array, RecordBatch, StringArray, UInt64Array, }, datatypes::{DataType, Field, Schema}, }; + use datafusion::prelude::SessionContext; use datafusion::{ catalog::MemTable, common::exec_err, prelude::lit, scalar::ScalarValue, }; - - use super::*; + use datafusion_execution::TaskContextAccessor; #[derive(Debug)] struct TestUDTF {} @@ -288,14 +298,23 @@ mod tests { async fn test_round_trip_udtf() -> Result<()> { let original_udtf = Arc::new(TestUDTF {}) as Arc; - let local_udtf: FFI_TableFunction = - FFI_TableFunction::new(Arc::clone(&original_udtf), None); + let ctx = Arc::new(SessionContext::default()); + let task_ctx_accessor = Arc::clone(&ctx) as Arc; + let mut local_udtf: FFI_TableFunction = FFI_TableFunction::new( + 
Arc::clone(&original_udtf), + None, + task_ctx_accessor.into(), + ); + + // Add unit test coverage to check for memory leaks on clone + let _ = local_udtf.clone(); + + local_udtf.library_marker_id = crate::mock_foreign_marker_id; - let foreign_udf: ForeignTableFunction = local_udtf.into(); + let foreign_udf: Arc = local_udtf.into(); let table = foreign_udf.call(&[lit(6_u64), lit("one"), lit(2.0), lit(3_u64)])?; - let ctx = SessionContext::default(); let _ = ctx.register_table("test-table", table)?; let returned_batches = ctx.table("test-table").await?.collect().await?; diff --git a/datafusion/ffi/src/udwf/mod.rs b/datafusion/ffi/src/udwf/mod.rs index 9f56e2d4788b..110ba225f28a 100644 --- a/datafusion/ffi/src/udwf/mod.rs +++ b/datafusion/ffi/src/udwf/mod.rs @@ -25,21 +25,15 @@ use arrow::{ datatypes::{DataType, SchemaRef}, }; use arrow_schema::{Field, FieldRef}; -use datafusion::logical_expr::LimitEffect; -use datafusion::physical_expr::PhysicalExpr; -use datafusion::{ - error::DataFusionError, - logical_expr::{ - function::WindowUDFFieldArgs, type_coercion::functions::fields_with_window_udf, - PartitionEvaluator, - }, -}; -use datafusion::{ - error::Result, - logical_expr::{Signature, WindowUDF, WindowUDFImpl}, -}; +use datafusion_common::error::{DataFusionError, Result}; use datafusion_common::exec_err; -use partition_evaluator::{FFI_PartitionEvaluator, ForeignPartitionEvaluator}; +use datafusion_expr::function::WindowUDFFieldArgs; +use datafusion_expr::type_coercion::functions::fields_with_window_udf; +use datafusion_expr::{ + LimitEffect, PartitionEvaluator, Signature, WindowUDF, WindowUDFImpl, +}; +use datafusion_physical_expr::PhysicalExpr; +use partition_evaluator::FFI_PartitionEvaluator; use partition_evaluator_args::{ FFI_PartitionEvaluatorArgs, ForeignPartitionEvaluatorArgs, }; @@ -50,6 +44,7 @@ mod partition_evaluator; mod partition_evaluator_args; mod range; +use crate::session::task_ctx_accessor::FFI_TaskContextAccessor; use 
crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; use crate::{ arrow_wrappers::WrappedSchema, @@ -95,6 +90,10 @@ pub struct FFI_WindowUDF { pub sort_options: ROption, + /// Accessor for TaskContext to be used during protobuf serialization + /// and deserialization. + task_ctx_accessor: FFI_TaskContextAccessor, + /// Used to create a clone on the provider of the udf. This should /// only need to be called by the receiver of the udf. pub clone: unsafe extern "C" fn(udf: &Self) -> Self, @@ -105,6 +104,10 @@ pub struct FFI_WindowUDF { /// Internal data. This is only to be accessed by the provider of the udf. /// A [`ForeignWindowUDF`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. + pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_WindowUDF {} @@ -115,9 +118,9 @@ pub struct WindowUDFPrivateData { } impl FFI_WindowUDF { - unsafe fn inner(&self) -> &Arc { + fn inner(&self) -> &Arc { let private_data = self.private_data as *const WindowUDFPrivateData; - &(*private_data).udf + unsafe { &(*private_data).udf } } } @@ -199,8 +202,10 @@ unsafe extern "C" fn clone_fn_wrapper(udwf: &FFI_WindowUDF) -> FFI_WindowUDF { coerce_types: coerce_types_fn_wrapper, field: field_fn_wrapper, clone: clone_fn_wrapper, + task_ctx_accessor: udwf.task_ctx_accessor.clone(), release: release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } @@ -210,8 +215,8 @@ impl Clone for FFI_WindowUDF { } } -impl From> for FFI_WindowUDF { - fn from(udf: Arc) -> Self { +impl FFI_WindowUDF { + pub fn new(udf: Arc, task_ctx_accessor: FFI_TaskContextAccessor) -> Self { let name = udf.name().into(); let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect(); let volatility = udf.signature().volatility.into(); @@ -228,8 +233,10 @@ impl From> for 
FFI_WindowUDF { coerce_types: coerce_types_fn_wrapper, field: field_fn_wrapper, clone: clone_fn_wrapper, + task_ctx_accessor, release: release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -270,21 +277,25 @@ impl Hash for ForeignWindowUDF { } } -impl TryFrom<&FFI_WindowUDF> for ForeignWindowUDF { +impl TryFrom<&FFI_WindowUDF> for Arc { type Error = DataFusionError; fn try_from(udf: &FFI_WindowUDF) -> Result { - let name = udf.name.to_owned().into(); - let signature = Signature::user_defined((&udf.volatility).into()); - - let aliases = udf.aliases.iter().map(|s| s.to_string()).collect(); - - Ok(Self { - name, - udf: udf.clone(), - aliases, - signature, - }) + if (udf.library_marker_id)() == crate::get_library_marker_id() { + Ok(Arc::clone(udf.inner().inner())) + } else { + let name = udf.name.to_owned().into(); + let signature = Signature::user_defined((&udf.volatility).into()); + + let aliases = udf.aliases.iter().map(|s| s.to_string()).collect(); + + Ok(Arc::new(ForeignWindowUDF { + name, + udf: udf.clone(), + aliases, + signature, + })) + } } } @@ -315,17 +326,17 @@ impl WindowUDFImpl for ForeignWindowUDF { fn partition_evaluator( &self, - args: datafusion::logical_expr::function::PartitionEvaluatorArgs, + args: datafusion_expr::function::PartitionEvaluatorArgs, ) -> Result> { let evaluator = unsafe { - let args = FFI_PartitionEvaluatorArgs::try_from(args)?; + let args = FFI_PartitionEvaluatorArgs::try_new( + args, + self.udf.task_ctx_accessor.clone(), + )?; (self.udf.partition_evaluator)(&self.udf, args) }; - df_result!(evaluator).map(|evaluator| { - Box::new(ForeignPartitionEvaluator::from(evaluator)) - as Box - }) + df_result!(evaluator).map(>::from) } fn field(&self, field_args: WindowUDFFieldArgs) -> Result { @@ -387,36 +398,43 @@ impl From<&FFI_SortOptions> for SortOptions { #[cfg(feature = "integration-tests")] mod tests { use crate::tests::create_record_batch; - use 
crate::udwf::{FFI_WindowUDF, ForeignWindowUDF}; + use crate::udwf::FFI_WindowUDF; use arrow::array::{create_array, ArrayRef}; use datafusion::functions_window::lead_lag::{lag_udwf, WindowShift}; use datafusion::logical_expr::expr::Sort; use datafusion::logical_expr::{col, ExprFunctionExt, WindowUDF, WindowUDFImpl}; use datafusion::prelude::SessionContext; + use datafusion_execution::TaskContextAccessor; use std::sync::Arc; fn create_test_foreign_udwf( original_udwf: impl WindowUDFImpl + 'static, + ctx: Arc, ) -> datafusion::common::Result { let original_udwf = Arc::new(WindowUDF::from(original_udwf)); - let local_udwf: FFI_WindowUDF = Arc::clone(&original_udwf).into(); + let mut local_udwf = FFI_WindowUDF::new(Arc::clone(&original_udwf), ctx.into()); + local_udwf.library_marker_id = crate::mock_foreign_marker_id; - let foreign_udwf: ForeignWindowUDF = (&local_udwf).try_into()?; - Ok(foreign_udwf.into()) + let foreign_udwf: Arc = (&local_udwf).try_into()?; + Ok(WindowUDF::new_from_shared_impl(foreign_udwf)) } #[test] fn test_round_trip_udwf() -> datafusion::common::Result<()> { let original_udwf = lag_udwf(); let original_name = original_udwf.name().to_owned(); + let task_ctx_accessor = + Arc::new(SessionContext::default()) as Arc; // Convert to FFI format - let local_udwf: FFI_WindowUDF = Arc::clone(&original_udwf).into(); + let mut local_udwf = + FFI_WindowUDF::new(Arc::clone(&original_udwf), task_ctx_accessor.into()); + local_udwf.library_marker_id = crate::mock_foreign_marker_id; // Convert back to native format - let foreign_udwf: ForeignWindowUDF = (&local_udwf).try_into()?; - let foreign_udwf: WindowUDF = foreign_udwf.into(); + let foreign_udwf: Arc = (&local_udwf).try_into()?; + let foreign_udwf = WindowUDF::new_from_shared_impl(foreign_udwf); assert_eq!(original_name, foreign_udwf.name()); Ok(()) @@ -424,9 +442,12 @@ mod tests { #[tokio::test] async fn test_lag_udwf() -> datafusion::common::Result<()> { - let udwf = 
create_test_foreign_udwf(WindowShift::lag())?; + let ctx = Arc::new(SessionContext::default()); + let udwf = create_test_foreign_udwf( + WindowShift::lag(), + Arc::clone(&ctx) as Arc, + )?; - let ctx = SessionContext::default(); let df = ctx.read_batch(create_record_batch(-5, 5))?; let df = df.select(vec![ diff --git a/datafusion/ffi/src/udwf/partition_evaluator.rs b/datafusion/ffi/src/udwf/partition_evaluator.rs index 14cf23b919aa..ff28b0704efb 100644 --- a/datafusion/ffi/src/udwf/partition_evaluator.rs +++ b/datafusion/ffi/src/udwf/partition_evaluator.rs @@ -15,22 +15,21 @@ // specific language governing permissions and limitations // under the License. -use std::{ffi::c_void, ops::Range}; - +use super::range::FFI_Range; use crate::{arrow_wrappers::WrappedArray, df_result, rresult, rresult_return}; use abi_stable::{ std_types::{RResult, RString, RVec}, StableAbi, }; use arrow::{array::ArrayRef, error::ArrowError}; -use datafusion::{ +use datafusion_common::{ error::{DataFusionError, Result}, - logical_expr::{window_state::WindowAggState, PartitionEvaluator}, scalar::ScalarValue, }; +use datafusion_expr::{window_state::WindowAggState, PartitionEvaluator}; use prost::Message; - -use super::range::FFI_Range; +use std::ptr::null_mut; +use std::{ffi::c_void, ops::Range}; /// A stable struct for sharing [`PartitionEvaluator`] across FFI boundaries. /// For an explanation of each field, see the corresponding function @@ -76,6 +75,10 @@ pub struct FFI_PartitionEvaluator { /// Internal data. This is only to be accessed by the provider of the evaluator. /// A [`ForeignPartitionEvaluator`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. 
+ pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_PartitionEvaluator {} @@ -86,14 +89,14 @@ pub struct PartitionEvaluatorPrivateData { } impl FFI_PartitionEvaluator { - unsafe fn inner_mut(&mut self) -> &mut Box { + fn inner_mut(&mut self) -> &mut Box { let private_data = self.private_data as *mut PartitionEvaluatorPrivateData; - &mut (*private_data).evaluator + unsafe { &mut (*private_data).evaluator } } - unsafe fn inner(&self) -> &(dyn PartitionEvaluator + 'static) { + fn inner(&self) -> &(dyn PartitionEvaluator + 'static) { let private_data = self.private_data as *mut PartitionEvaluatorPrivateData; - (*private_data).evaluator.as_ref() + unsafe { (*private_data).evaluator.as_ref() } } } @@ -170,9 +173,11 @@ unsafe extern "C" fn get_range_fn_wrapper( } unsafe extern "C" fn release_fn_wrapper(evaluator: &mut FFI_PartitionEvaluator) { - let private_data = - Box::from_raw(evaluator.private_data as *mut PartitionEvaluatorPrivateData); - drop(private_data); + if !evaluator.private_data.is_null() { + let private_data = + Box::from_raw(evaluator.private_data as *mut PartitionEvaluatorPrivateData); + drop(private_data); + } } impl From> for FFI_PartitionEvaluator { @@ -195,6 +200,7 @@ impl From> for FFI_PartitionEvaluator { uses_window_frame, release: release_fn_wrapper, private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -219,9 +225,19 @@ pub struct ForeignPartitionEvaluator { unsafe impl Send for ForeignPartitionEvaluator {} unsafe impl Sync for ForeignPartitionEvaluator {} -impl From for ForeignPartitionEvaluator { - fn from(evaluator: FFI_PartitionEvaluator) -> Self { - Self { evaluator } +impl From for Box { + fn from(mut evaluator: FFI_PartitionEvaluator) -> Self { + if (evaluator.library_marker_id)() == crate::get_library_marker_id() { + unsafe { + let private_data = Box::from_raw( + evaluator.private_data as *mut PartitionEvaluatorPrivateData, + ); + 
evaluator.private_data = null_mut(); + private_data.evaluator + } + } else { + Box::new(ForeignPartitionEvaluator { evaluator }) + } } } diff --git a/datafusion/ffi/src/udwf/partition_evaluator_args.rs b/datafusion/ffi/src/udwf/partition_evaluator_args.rs index cd2641256437..046ad1bbf3d5 100644 --- a/datafusion/ffi/src/udwf/partition_evaluator_args.rs +++ b/datafusion/ffi/src/udwf/partition_evaluator_args.rs @@ -18,6 +18,7 @@ use std::{collections::HashMap, sync::Arc}; use crate::arrow_wrappers::WrappedSchema; +use crate::session::task_ctx_accessor::FFI_TaskContextAccessor; use abi_stable::{std_types::RVec, StableAbi}; use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, @@ -25,20 +26,15 @@ use arrow::{ ffi::FFI_ArrowSchema, }; use arrow_schema::FieldRef; -use datafusion::{ - error::{DataFusionError, Result}, - logical_expr::function::PartitionEvaluatorArgs, - physical_plan::{expressions::Column, PhysicalExpr}, - prelude::SessionContext, -}; +use datafusion_common::error::{DataFusionError, Result}; use datafusion_common::exec_datafusion_err; -use datafusion_proto::{ - physical_plan::{ - from_proto::parse_physical_expr, to_proto::serialize_physical_exprs, - DefaultPhysicalExtensionCodec, - }, - protobuf::PhysicalExprNode, -}; +use datafusion_execution::TaskContext; +use datafusion_expr::function::PartitionEvaluatorArgs; +use datafusion_physical_plan::{expressions::Column, PhysicalExpr}; +use datafusion_proto::physical_plan::from_proto::parse_physical_expr; +use datafusion_proto::physical_plan::to_proto::serialize_physical_exprs; +use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; +use datafusion_proto::protobuf::PhysicalExprNode; use prost::Message; /// A stable struct for sharing [`PartitionEvaluatorArgs`] across FFI boundaries. 
@@ -53,11 +49,17 @@ pub struct FFI_PartitionEvaluatorArgs { is_reversed: bool, ignore_nulls: bool, schema: WrappedSchema, + + /// Accessor for TaskContext to be used during protobuf serialization + /// and deserialization. + task_ctx_accessor: FFI_TaskContextAccessor, } -impl TryFrom> for FFI_PartitionEvaluatorArgs { - type Error = DataFusionError; - fn try_from(args: PartitionEvaluatorArgs) -> Result { +impl FFI_PartitionEvaluatorArgs { + pub fn try_new( + args: PartitionEvaluatorArgs, + task_ctx_accessor: FFI_TaskContextAccessor, + ) -> Result { // This is a bit of a hack. Since PartitionEvaluatorArgs does not carry a schema // around, and instead passes the data types directly we are unable to decode the // protobuf PhysicalExpr correctly. In evaluating the code the only place these @@ -117,6 +119,7 @@ impl TryFrom> for FFI_PartitionEvaluatorArgs { schema, is_reversed: args.is_reversed(), ignore_nulls: args.ignore_nulls(), + task_ctx_accessor, }) } } @@ -136,10 +139,10 @@ impl TryFrom for ForeignPartitionEvaluatorArgs { type Error = DataFusionError; fn try_from(value: FFI_PartitionEvaluatorArgs) -> Result { - let default_ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; let schema: SchemaRef = value.schema.into(); + let task_ctx: Arc = (&value.task_ctx_accessor).try_into()?; let input_exprs = value .input_exprs @@ -148,9 +151,7 @@ impl TryFrom for ForeignPartitionEvaluatorArgs { .collect::, prost::DecodeError>>() .map_err(|e| exec_datafusion_err!("Failed to decode PhysicalExprNode: {e}"))? .iter() - .map(|expr_node| { - parse_physical_expr(expr_node, &default_ctx.task_ctx(), &schema, &codec) - }) + .map(|expr_node| parse_physical_expr(expr_node, &task_ctx, &schema, &codec)) .collect::>>()?; let input_fields = input_exprs diff --git a/datafusion/ffi/src/util.rs b/datafusion/ffi/src/util.rs index 151464dc9745..9809544bdb65 100644 --- a/datafusion/ffi/src/util.rs +++ b/datafusion/ffi/src/util.rs @@ -16,12 +16,14 @@ // under the License. 
use crate::arrow_wrappers::WrappedSchema; -use abi_stable::std_types::RVec; +use abi_stable::std_types::{RResult, RString, RVec}; use arrow::datatypes::Field; use arrow::{datatypes::DataType, ffi::FFI_ArrowSchema}; use arrow_schema::FieldRef; use std::sync::Arc; +pub type FFIResult = RResult; + /// This macro is a helpful conversion utility to convert from an abi_stable::RResult to a /// DataFusion result. #[macro_export] diff --git a/datafusion/ffi/src/volatility.rs b/datafusion/ffi/src/volatility.rs index f1705da294a3..a3faf8098b7f 100644 --- a/datafusion/ffi/src/volatility.rs +++ b/datafusion/ffi/src/volatility.rs @@ -16,7 +16,7 @@ // under the License. use abi_stable::StableAbi; -use datafusion::logical_expr::Volatility; +use datafusion_expr::Volatility; #[repr(C)] #[derive(Debug, StableAbi, Clone)] diff --git a/datafusion/ffi/tests/ffi_integration.rs b/datafusion/ffi/tests/ffi_integration.rs index eb53e76bfb9b..54fbe7b7b793 100644 --- a/datafusion/ffi/tests/ffi_integration.rs +++ b/datafusion/ffi/tests/ffi_integration.rs @@ -21,8 +21,8 @@ mod tests { use datafusion::error::{DataFusionError, Result}; use datafusion::prelude::SessionContext; - use datafusion_ffi::catalog_provider::ForeignCatalogProvider; - use datafusion_ffi::table_provider::ForeignTableProvider; + use datafusion_catalog::{CatalogProvider, TableProvider}; + use datafusion_execution::TaskContextAccessor; use datafusion_ffi::tests::create_record_batch; use datafusion_ffi::tests::utils::get_module; use std::sync::Arc; @@ -33,22 +33,23 @@ mod tests { async fn test_table_provider(synchronous: bool) -> Result<()> { let table_provider_module = get_module()?; + let ctx = Arc::new(SessionContext::new()); + let task_ctx_accessor = Arc::clone(&ctx) as Arc; + // By calling the code below, the table provided will be created within // the module's code. 
let ffi_table_provider = table_provider_module.create_table().ok_or( DataFusionError::NotImplemented( "External table provider failed to implement create_table".to_string(), ), - )?(synchronous); + )?(synchronous, task_ctx_accessor.into()); // In order to access the table provider within this executable, we need to // turn it into a `ForeignTableProvider`. - let foreign_table_provider: ForeignTableProvider = (&ffi_table_provider).into(); - - let ctx = SessionContext::new(); + let foreign_table_provider: Arc = (&ffi_table_provider).into(); // Display the data to show the full cycle works. - ctx.register_table("external_table", Arc::new(foreign_table_provider))?; + ctx.register_table("external_table", foreign_table_provider)?; let df = ctx.table("external_table").await?; let results = df.collect().await?; @@ -73,6 +74,8 @@ mod tests { #[tokio::test] async fn test_catalog() -> Result<()> { let module = get_module()?; + let ctx = Arc::new(SessionContext::default()); + let task_ctx_accessor = Arc::clone(&ctx) as Arc; let ffi_catalog = module @@ -80,11 +83,10 @@ mod tests { .ok_or(DataFusionError::NotImplemented( "External catalog provider failed to implement create_catalog" .to_string(), - ))?(); - let foreign_catalog: ForeignCatalogProvider = (&ffi_catalog).into(); + ))?(task_ctx_accessor.into()); + let foreign_catalog: Arc = (&ffi_catalog).into(); - let ctx = SessionContext::default(); - let _ = ctx.register_catalog("fruit", Arc::new(foreign_catalog)); + let _ = ctx.register_catalog("fruit", foreign_catalog); let df = ctx.table("fruit.apple.purchases").await?; diff --git a/datafusion/ffi/tests/ffi_udaf.rs b/datafusion/ffi/tests/ffi_udaf.rs index ffd99bac62ec..bb4893802cde 100644 --- a/datafusion/ffi/tests/ffi_udaf.rs +++ b/datafusion/ffi/tests/ffi_udaf.rs @@ -24,25 +24,27 @@ mod tests { use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::AggregateUDF; use datafusion::prelude::{col, SessionContext}; - + use 
datafusion_execution::TaskContextAccessor; + use datafusion_expr::AggregateUDFImpl; use datafusion_ffi::tests::utils::get_module; - use datafusion_ffi::udaf::ForeignAggregateUDF; + use std::sync::Arc; #[tokio::test] async fn test_ffi_udaf() -> Result<()> { let module = get_module()?; + let ctx = Arc::new(SessionContext::default()); + let task_ctx_accessor = Arc::clone(&ctx) as Arc; let ffi_sum_func = module .create_sum_udaf() .ok_or(DataFusionError::NotImplemented( "External table provider failed to implement create_udaf".to_string(), - ))?(); - let foreign_sum_func: ForeignAggregateUDF = (&ffi_sum_func).try_into()?; + ))?(task_ctx_accessor.into()); + let foreign_sum_func: Arc = (&ffi_sum_func).try_into()?; - let udaf: AggregateUDF = foreign_sum_func.into(); + let udaf = AggregateUDF::new_from_shared_impl(foreign_sum_func); - let ctx = SessionContext::default(); let record_batch = record_batch!( ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), ("b", Float64, vec![1.0, 2.0, 2.0, 4.0, 4.0, 4.0, 4.0]) @@ -73,18 +75,20 @@ mod tests { #[tokio::test] async fn test_ffi_grouping_udaf() -> Result<()> { let module = get_module()?; + let ctx = Arc::new(SessionContext::default()); + let task_ctx_accessor = Arc::clone(&ctx) as Arc; let ffi_stddev_func = module .create_stddev_udaf() .ok_or(DataFusionError::NotImplemented( "External table provider failed to implement create_udaf".to_string(), - ))?(); - let foreign_stddev_func: ForeignAggregateUDF = (&ffi_stddev_func).try_into()?; + ))?(task_ctx_accessor.into()); + let foreign_stddev_func: Arc = + (&ffi_stddev_func).try_into()?; - let udaf: AggregateUDF = foreign_stddev_func.into(); + let udaf = AggregateUDF::new_from_shared_impl(foreign_stddev_func); - let ctx = SessionContext::default(); let record_batch = record_batch!( ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), ( diff --git a/datafusion/ffi/tests/ffi_udf.rs b/datafusion/ffi/tests/ffi_udf.rs index fd6a84bcf5b0..399836928799 100644 --- a/datafusion/ffi/tests/ffi_udf.rs +++ 
b/datafusion/ffi/tests/ffi_udf.rs @@ -19,16 +19,14 @@ /// when the feature integration-tests is built #[cfg(feature = "integration-tests")] mod tests { - use arrow::datatypes::DataType; - use datafusion::common::record_batch; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::ScalarUDF; use datafusion::prelude::{col, SessionContext}; - + use datafusion_expr::{lit, ScalarUDFImpl}; use datafusion_ffi::tests::create_record_batch; use datafusion_ffi::tests::utils::get_module; - use datafusion_ffi::udf::ForeignScalarUDF; + use std::sync::Arc; /// This test validates that we can load an external module and use a scalar /// udf defined in it via the foreign function interface. In this case we are @@ -44,28 +42,37 @@ mod tests { "External table provider failed to implement create_scalar_udf" .to_string(), ))?(); - let foreign_abs_func: ForeignScalarUDF = (&ffi_abs_func).try_into()?; + let foreign_abs_func: Arc = (&ffi_abs_func).try_into()?; - let udf: ScalarUDF = foreign_abs_func.into(); + let udf = ScalarUDF::new_from_shared_impl(foreign_abs_func); let ctx = SessionContext::default(); let df = ctx.read_batch(create_record_batch(-5, 5))?; let df = df .with_column("abs_a", udf.call(vec![col("a")]))? - .with_column("abs_b", udf.call(vec![col("b")]))?; + .with_column("abs_b", udf.call(vec![col("b")]))? 
+ .with_column("abs_lit", udf.call(vec![lit(-1)]))?; let result = df.collect().await?; + assert!(result.len() == 1); - let expected = record_batch!( - ("a", Int32, vec![-5, -4, -3, -2, -1]), - ("b", Float64, vec![-5., -4., -3., -2., -1.]), - ("abs_a", Int32, vec![5, 4, 3, 2, 1]), - ("abs_b", Float64, vec![5., 4., 3., 2., 1.]) + let expected = arrow::array::record_batch!( + ("a", Int32, [-5, -4, -3, -2, -1]), + ("b", Float64, [-5., -4., -3., -2., -1.]), + ("abs_a", Int32, [5, 4, 3, 2, 1]), + ("abs_b", Float64, [5., 4., 3., 2., 1.]), + ("abs_lit", Int32, [1, 1, 1, 1, 1]) )?; - assert!(result.len() == 1); - assert!(result[0] == expected); + // Literal value will create a non-null schema, so project before comparison + let result = result + .into_iter() + .next() + .unwrap() + .with_schema(expected.schema())?; + + assert!(result == expected); Ok(()) } @@ -82,9 +89,9 @@ mod tests { "External table provider failed to implement create_scalar_udf" .to_string(), ))?(); - let foreign_abs_func: ForeignScalarUDF = (&ffi_abs_func).try_into()?; + let foreign_abs_func: Arc = (&ffi_abs_func).try_into()?; - let udf: ScalarUDF = foreign_abs_func.into(); + let udf = ScalarUDF::new_from_shared_impl(foreign_abs_func); let ctx = SessionContext::default(); let df = ctx.read_batch(create_record_batch(-5, 5))?; diff --git a/datafusion/ffi/tests/ffi_udtf.rs b/datafusion/ffi/tests/ffi_udtf.rs index 8c1c64a092e1..aa60fff366c5 100644 --- a/datafusion/ffi/tests/ffi_udtf.rs +++ b/datafusion/ffi/tests/ffi_udtf.rs @@ -25,9 +25,9 @@ mod tests { use arrow::array::{create_array, ArrayRef}; use datafusion::error::{DataFusionError, Result}; use datafusion::prelude::SessionContext; - + use datafusion_catalog::TableFunctionImpl; + use datafusion_execution::TaskContextAccessor; use datafusion_ffi::tests::utils::get_module; - use datafusion_ffi::udtf::ForeignTableFunction; /// This test validates that we can load an external module and use a scalar /// udf defined in it via the foreign function interface. 
In this case we are @@ -36,18 +36,18 @@ mod tests { async fn test_user_defined_table_function() -> Result<()> { let module = get_module()?; + let ctx = Arc::new(SessionContext::default()); + let task_ctx_accessor = Arc::clone(&ctx) as Arc; + let ffi_table_func = module .create_table_function() .ok_or(DataFusionError::NotImplemented( "External table function provider failed to implement create_table_function" .to_string(), - ))?(); - let foreign_table_func: ForeignTableFunction = ffi_table_func.into(); - - let udtf = Arc::new(foreign_table_func); + ))?(task_ctx_accessor.into()); + let foreign_table_func: Arc = ffi_table_func.into(); - let ctx = SessionContext::default(); - ctx.register_udtf("my_range", udtf); + ctx.register_udtf("my_range", foreign_table_func); let result = ctx .sql("SELECT * FROM my_range(5)") diff --git a/datafusion/ffi/tests/ffi_udwf.rs b/datafusion/ffi/tests/ffi_udwf.rs index 18ffd0c5bcb7..aeb617cf9ca7 100644 --- a/datafusion/ffi/tests/ffi_udwf.rs +++ b/datafusion/ffi/tests/ffi_udwf.rs @@ -24,31 +24,32 @@ mod tests { use datafusion::logical_expr::expr::Sort; use datafusion::logical_expr::{col, ExprFunctionExt, WindowUDF}; use datafusion::prelude::SessionContext; + use datafusion_execution::TaskContextAccessor; + use datafusion_expr::{lit, Expr, WindowUDFImpl}; + use datafusion_ffi::session::task_ctx_accessor::FFI_TaskContextAccessor; use datafusion_ffi::tests::create_record_batch; use datafusion_ffi::tests::utils::get_module; - use datafusion_ffi::udwf::ForeignWindowUDF; + use datafusion_ffi::udwf::FFI_WindowUDF; + use std::sync::Arc; - #[tokio::test] - async fn test_rank_udwf() -> Result<()> { - let module = get_module()?; + async fn test_window_function( + function: extern "C" fn(FFI_TaskContextAccessor) -> FFI_WindowUDF, + arguments: Vec, + expected: ArrayRef, + ) -> Result<()> { + let ctx = Arc::new(SessionContext::default()); + let task_ctx_accessor = Arc::clone(&ctx) as Arc; - let ffi_rank_func = - module - .create_rank_udwf() - 
.ok_or(DataFusionError::NotImplemented( - "External table provider failed to implement create_scalar_udf" - .to_string(), - ))?(); - let foreign_rank_func: ForeignWindowUDF = (&ffi_rank_func).try_into()?; + let ffi_rank_func = function(task_ctx_accessor.into()); + let foreign_rank_func: Arc = (&ffi_rank_func).try_into()?; - let udwf: WindowUDF = foreign_rank_func.into(); + let udwf = WindowUDF::new_from_shared_impl(foreign_rank_func); - let ctx = SessionContext::default(); let df = ctx.read_batch(create_record_batch(-5, 5))?; let df = df.select(vec![ col("a"), - udwf.call(vec![]) + udwf.call(arguments) .order_by(vec![Sort::new(col("a"), true, true)]) .build() .unwrap() @@ -58,11 +59,52 @@ mod tests { df.clone().show().await?; let result = df.collect().await?; - let expected = create_array!(UInt64, [1, 2, 3, 4, 5]) as ArrayRef; assert_eq!(result.len(), 1); assert_eq!(result[0].column(1), &expected); Ok(()) } + + #[tokio::test] + async fn test_rank_udwf() -> Result<()> { + let module = get_module()?; + let expected = create_array!(UInt64, [1, 2, 3, 4, 5]) as ArrayRef; + let function = + module + .create_rank_udwf() + .ok_or(DataFusionError::NotImplemented( + "External table provider failed to implement window function" + .to_string(), + ))?; + test_window_function(function, vec![], expected).await + } + + #[tokio::test] + async fn test_ntile_udwf() -> Result<()> { + let module = get_module()?; + let expected = create_array!(UInt64, [1, 1, 2, 2, 3]) as ArrayRef; + let function = + module + .create_ntile_udwf() + .ok_or(DataFusionError::NotImplemented( + "External table provider failed to implement window function" + .to_string(), + ))?; + test_window_function(function, vec![lit(3)], expected).await + } + + #[tokio::test] + async fn test_cumedist_udwf() -> Result<()> { + let module = get_module()?; + let expected = create_array!(Float64, [0.2, 0.4, 0.6, 0.8, 1.0]) as ArrayRef; + let function = + module + .create_cumedist_udwf() + .ok_or(DataFusionError::NotImplemented( 
+ "External table provider failed to implement window function" + .to_string(), + ))?; + test_window_function(function, vec![], expected).await + } }