diff --git a/Cargo.lock b/Cargo.lock index e0f811f16f1b..4f84d596a4f4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2297,6 +2297,8 @@ dependencies = [ "async-trait", "datafusion", "datafusion-common", + "datafusion-execution", + "datafusion-expr", "datafusion-functions-aggregate-common", "datafusion-proto", "datafusion-proto-common", @@ -3042,7 +3044,6 @@ version = "0.1.0" dependencies = [ "abi_stable", "datafusion", - "datafusion-ffi", "ffi_module_interface", "tokio", ] diff --git a/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs b/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs index a83f15926f05..53825352aef0 100644 --- a/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs +++ b/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs @@ -21,6 +21,7 @@ use abi_stable::{export_root_module, prefix_type::PrefixTypeTrait}; use arrow::array::RecordBatch; use arrow::datatypes::{DataType, Field, Schema}; use datafusion::{common::record_batch, datasource::MemTable}; +use datafusion_ffi::execution::FFI_TaskContextProvider; use datafusion_ffi::table_provider::FFI_TableProvider; use ffi_module_interface::{TableProviderModule, TableProviderModuleRef}; @@ -34,7 +35,9 @@ fn create_record_batch(start_value: i32, num_values: usize) -> RecordBatch { /// Here we only wish to create a simple table provider as an example. /// We create an in-memory table and convert it to it's FFI counterpart. 
-extern "C" fn construct_simple_table_provider() -> FFI_TableProvider { +extern "C" fn construct_simple_table_provider( + task_ctx_provider: FFI_TaskContextProvider, +) -> FFI_TableProvider { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Float64, true), @@ -50,7 +53,7 @@ extern "C" fn construct_simple_table_provider() -> FFI_TableProvider { let table_provider = MemTable::try_new(schema, vec![batches]).unwrap(); - FFI_TableProvider::new(Arc::new(table_provider), true, None) + FFI_TableProvider::new(Arc::new(table_provider), true, None, task_ctx_provider) } #[export_root_module] diff --git a/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs b/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs index 88690e929713..1eeb924902a4 100644 --- a/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs +++ b/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs @@ -22,6 +22,7 @@ use abi_stable::{ sabi_types::VersionStrings, StableAbi, }; +use datafusion_ffi::execution::FFI_TaskContextProvider; use datafusion_ffi::table_provider::FFI_TableProvider; #[repr(C)] @@ -34,7 +35,7 @@ use datafusion_ffi::table_provider::FFI_TableProvider; /// how a user may wish to separate these concerns. 
pub struct TableProviderModule { /// Constructs the table provider - pub create_table: extern "C" fn() -> FFI_TableProvider, + pub create_table: extern "C" fn(FFI_TaskContextProvider) -> FFI_TableProvider, } impl RootModule for TableProviderModuleRef { diff --git a/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml b/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml index 028a366aab1c..1f68bb1bb1be 100644 --- a/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml +++ b/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml @@ -24,6 +24,5 @@ publish = false [dependencies] abi_stable = "0.11.3" datafusion = { workspace = true } -datafusion-ffi = { workspace = true } ffi_module_interface = { path = "../ffi_module_interface" } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } diff --git a/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs b/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs index 6e376ca866e8..d84890aacb72 100644 --- a/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs +++ b/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs @@ -23,7 +23,8 @@ use datafusion::{ }; use abi_stable::library::{development_utils::compute_library_path, RootModule}; -use datafusion_ffi::table_provider::ForeignTableProvider; +use datafusion::datasource::TableProvider; +use datafusion::execution::TaskContextProvider; use ffi_module_interface::TableProviderModuleRef; #[tokio::main] @@ -39,6 +40,9 @@ async fn main() -> Result<()> { TableProviderModuleRef::load_from_directory(&library_path) .map_err(|e| DataFusionError::External(Box::new(e)))?; + let ctx = Arc::new(SessionContext::new()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; + // By calling the code below, the table provided will be created within // the module's code. 
let ffi_table_provider = @@ -46,16 +50,14 @@ async fn main() -> Result<()> { .create_table() .ok_or(DataFusionError::NotImplemented( "External table provider failed to implement create_table".to_string(), - ))?(); + ))?(task_ctx_provider.into()); // In order to access the table provider within this executable, we need to // turn it into a `ForeignTableProvider`. - let foreign_table_provider: ForeignTableProvider = (&ffi_table_provider).into(); - - let ctx = SessionContext::new(); + let foreign_table_provider: Arc = (&ffi_table_provider).into(); // Display the data to show the full cycle works. - ctx.register_table("external_table", Arc::new(foreign_table_provider))?; + ctx.register_table("external_table", foreign_table_provider)?; let df = ctx.table("external_table").await?; df.show().await?; diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 9c7339e6748e..5027ab70da69 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1801,6 +1801,12 @@ impl SessionContext { } } +impl datafusion_execution::TaskContextProvider for SessionContext { + fn task_ctx(&self) -> Arc { + SessionContext::task_ctx(self) + } +} + impl FunctionRegistry for SessionContext { fn udfs(&self) -> HashSet { self.state.read().udfs() diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index d7a66db28ac4..2f527c775814 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -286,6 +286,12 @@ impl Session for SessionState { } } +impl datafusion_execution::TaskContextProvider for SessionState { + fn task_ctx(&self) -> Arc { + SessionState::task_ctx(self) + } +} + impl SessionState { pub(crate) fn resolve_table_ref( &self, diff --git a/datafusion/execution/src/lib.rs b/datafusion/execution/src/lib.rs index 9439aefc008d..082f77448e50 100644 --- 
a/datafusion/execution/src/lib.rs +++ b/datafusion/execution/src/lib.rs @@ -49,4 +49,4 @@ pub mod registry { pub use disk_manager::DiskManager; pub use registry::FunctionRegistry; pub use stream::{RecordBatchStream, SendableRecordBatchStream}; -pub use task::TaskContext; +pub use task::{TaskContext, TaskContextProvider}; diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index c2a6cfe2c833..33e0fea53674 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -32,7 +32,7 @@ use std::{collections::HashMap, sync::Arc}; /// information. /// /// [`SessionContext`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct TaskContext { /// Session Id session_id: String, @@ -211,6 +211,11 @@ impl FunctionRegistry for TaskContext { } } +/// Produce the [`TaskContext`]. +pub trait TaskContextProvider { + fn task_ctx(&self) -> Arc; +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/ffi/Cargo.toml b/datafusion/ffi/Cargo.toml index 3ac08180fb68..cecc2f5ad26d 100644 --- a/datafusion/ffi/Cargo.toml +++ b/datafusion/ffi/Cargo.toml @@ -48,6 +48,8 @@ async-ffi = { version = "0.5.0", features = ["abi_stable"] } async-trait = { workspace = true } datafusion = { workspace = true, default-features = false } datafusion-common = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-execution = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-proto = { workspace = true } datafusion-proto-common = { workspace = true } @@ -58,6 +60,7 @@ semver = "1.0.27" tokio = { workspace = true } [dev-dependencies] +datafusion = { workspace = true, default-features = false, features = ["sql"] } doc-comment = { workspace = true } [features] diff --git a/datafusion/ffi/README.md b/datafusion/ffi/README.md index 72070984f931..7e5090c39cc8 100644 --- a/datafusion/ffi/README.md +++ 
b/datafusion/ffi/README.md @@ -101,6 +101,101 @@ In this crate we have a variety of structs which closely mimic the behavior of their internal counterparts. To see detailed notes about how to use them, see the example in `FFI_TableProvider`. +## Task Context Provider + +Many of the FFI structs in this crate contain a `FFI_TaskContextProvider`. The +purpose of this struct is to _weakly_ hold a reference to a method to +access the current `TaskContext`. The reason we need this accessor is because +we use the `datafusion-proto` crate to serialize and deserialize data across +the FFI boundary. In particular, we need to serialize and deserialize +functions using a `TaskContext`. + +This becomes difficult because we may need to register multiple user defined +functions, table or catalog providers, etc with a `Session`, and each of these +will need the `TaskContext` to perform the processing. For this reason we +cannot simply include the `TaskContext` at the time of registration because +it would not have knowledge of anything registered afterward. + +The `FFI_TaskContextProvider` is built from a trait that provides a method +to get the current `TaskContext`. `FFI_TaskContextProvider` only holds a +`Weak` reference to the `TaskContextProvider`, because otherwise we could +create a circular dependency at runtime. It is imperative that if you use +these methods that your provider remains valid for the lifetime of the +calls. The `FFI_TaskContextProvider` is implemented on `SessionContext` +and it is easy to implement on any struct that implements `Session`. + +## Library Marker ID + +When reviewing the code, many of the structs in this crate contain a call to +a `library_marker_id`. The purpose of this call is to determine if a library is +accessing _local_ code through the FFI structs. Consider this example: you have +a `primary` program that exposes functions to create a schema provider.
You +have a `secondary` library that exposes a function to create a catalog provider +and the `secondary` library uses the schema provider of the `primary` program. +From the point of view of the `secondary` library, the schema provider is +foreign code. + +Now when we register the `secondary` library with the `primary` program as a +catalog provider and we make calls to get a schema, the `secondary` library +will return an FFI-wrapped schema provider back to the `primary` program. In +this case that schema provider is actually local code to the `primary` program +except that it is wrapped in the FFI code! + +We work around this with the `library_marker_id` calls. What this does is it +creates a global variable within each library and returns a `u64` address +of that library. This is guaranteed to be unique for every library that contains +FFI code. By comparing these `u64` addresses we can determine if an FFI struct +is local or foreign. + +In our example of the schema provider, if you were to make a call in your +primary program to get the schema provider, it would reach out to the foreign +catalog provider and send back a `FFI_SchemaProvider` object. By then +comparing the `library_marker_id` of this object to the `primary` program, we +determine it is local code. This means it is safe to access the underlying +private data. + +## Testing Coverage + +Since this library contains a large amount of `unsafe` code, it is important +to ensure proper test coverage. To generate a coverage report, you can use +[tarpaulin] as follows. It is necessary to use the `integration-tests` feature +in order to properly generate coverage. + +```shell +cargo tarpaulin --package datafusion-ffi --tests --features integration-tests --out Html +``` + +While it is not normally required to check Rust code for memory leaks, this +crate does manual memory management due to the FFI boundary. You can test for +leaks using the generated unit tests.
How you run these checks differs depending +on your OS. + +### Linux + +On Linux, you can install `cargo-valgrind` + +```shell +cargo valgrind test --features integration-tests -p datafusion-ffi +``` + +### MacOS + +You can find the generated binaries for your unit tests by running `cargo test`. + +```shell +cargo test --features integration-tests -p datafusion-ffi --no-run +``` + +This should generate output that shows the path to the test binaries. Then +you can run commands such as the following. The specific paths of the tests +will vary. + +```shell +leaks --atExit -- target/debug/deps/datafusion_ffi-e77a2604a85a8afe +leaks --atExit -- target/debug/deps/ffi_integration-e91b7127a59b71a7 +# ... +``` + [apache datafusion]: https://datafusion.apache.org/ [api docs]: http://docs.rs/datafusion-ffi/latest [rust abi]: https://doc.rust-lang.org/reference/abi.html @@ -110,3 +205,4 @@ the example in `FFI_TableProvider`. [bindgen]: https://crates.io/crates/bindgen [`datafusion-python`]: https://datafusion.apache.org/python/ [datafusion-contrib]: https://github.com/datafusion-contrib +[tarpaulin]: https://crates.io/crates/cargo-tarpaulin diff --git a/datafusion/ffi/src/catalog_provider.rs b/datafusion/ffi/src/catalog_provider.rs index d279951783b4..e9415b437c5c 100644 --- a/datafusion/ffi/src/catalog_provider.rs +++ b/datafusion/ffi/src/catalog_provider.rs @@ -29,6 +29,7 @@ use crate::{ schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider}, }; +use crate::execution::FFI_TaskContextProvider; use datafusion::error::Result; /// A stable struct for sharing [`CatalogProvider`] across FFI boundaries. @@ -57,6 +58,10 @@ pub struct FFI_CatalogProvider { cascade: bool, ) -> RResult, RString>, + /// Provider for TaskContext to be used during protobuf serialization + /// and deserialization. + pub task_ctx_provider: FFI_TaskContextProvider, + /// Used to create a clone on the provider of the execution plan. This should /// only need to be called by the receiver of the plan. 
pub clone: unsafe extern "C" fn(plan: &Self) -> Self, @@ -70,6 +75,10 @@ pub struct FFI_CatalogProvider { /// Internal data. This is only to be accessed by the provider of the plan. /// A [`ForeignCatalogProvider`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. + pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_CatalogProvider {} @@ -105,7 +114,13 @@ unsafe extern "C" fn schema_fn_wrapper( ) -> ROption { let maybe_schema = provider.inner().schema(name.as_str()); maybe_schema - .map(|schema| FFI_SchemaProvider::new(schema, provider.runtime())) + .map(|schema| { + FFI_SchemaProvider::new( + schema, + provider.runtime(), + provider.task_ctx_provider.clone(), + ) + }) .into() } @@ -115,12 +130,18 @@ unsafe extern "C" fn register_schema_fn_wrapper( schema: &FFI_SchemaProvider, ) -> RResult, RString> { let runtime = provider.runtime(); - let provider = provider.inner(); - let schema = Arc::new(ForeignSchemaProvider::from(schema)); + let inner_provider = provider.inner(); + let schema: Arc = schema.into(); let returned_schema = - rresult_return!(provider.register_schema(name.as_str(), schema)) - .map(|schema| FFI_SchemaProvider::new(schema, runtime)) + rresult_return!(inner_provider.register_schema(name.as_str(), schema)) + .map(|schema| { + FFI_SchemaProvider::new( + schema, + runtime, + provider.task_ctx_provider.clone(), + ) + }) .into(); RResult::ROk(returned_schema) @@ -132,14 +153,20 @@ unsafe extern "C" fn deregister_schema_fn_wrapper( cascade: bool, ) -> RResult, RString> { let runtime = provider.runtime(); - let provider = provider.inner(); + let inner_provider = provider.inner(); let maybe_schema = - rresult_return!(provider.deregister_schema(name.as_str(), cascade)); + rresult_return!(inner_provider.deregister_schema(name.as_str(), cascade)); RResult::ROk( maybe_schema - .map(|schema| 
FFI_SchemaProvider::new(schema, runtime)) + .map(|schema| { + FFI_SchemaProvider::new( + schema, + runtime, + provider.task_ctx_provider.clone(), + ) + }) .into(), ) } @@ -165,10 +192,12 @@ unsafe extern "C" fn clone_fn_wrapper( schema: schema_fn_wrapper, register_schema: register_schema_fn_wrapper, deregister_schema: deregister_schema_fn_wrapper, + task_ctx_provider: provider.task_ctx_provider.clone(), clone: clone_fn_wrapper, release: release_fn_wrapper, version: super::version, private_data, + library_marker_id: crate::get_library_marker_id, } } @@ -183,7 +212,9 @@ impl FFI_CatalogProvider { pub fn new( provider: Arc, runtime: Option, + task_ctx_provider: impl Into, ) -> Self { + let task_ctx_provider = task_ctx_provider.into(); let private_data = Box::new(ProviderPrivateData { provider, runtime }); Self { @@ -191,10 +222,12 @@ impl FFI_CatalogProvider { schema: schema_fn_wrapper, register_schema: register_schema_fn_wrapper, deregister_schema: deregister_schema_fn_wrapper, + task_ctx_provider, clone: clone_fn_wrapper, release: release_fn_wrapper, version: super::version, private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -209,9 +242,14 @@ pub struct ForeignCatalogProvider(pub(crate) FFI_CatalogProvider); unsafe impl Send for ForeignCatalogProvider {} unsafe impl Sync for ForeignCatalogProvider {} -impl From<&FFI_CatalogProvider> for ForeignCatalogProvider { +impl From<&FFI_CatalogProvider> for Arc { fn from(provider: &FFI_CatalogProvider) -> Self { - Self(provider.clone()) + if (provider.library_marker_id)() == crate::get_library_marker_id() { + return Arc::clone(unsafe { provider.inner() }); + } + + Arc::new(ForeignCatalogProvider(provider.clone())) + as Arc } } @@ -254,7 +292,11 @@ impl CatalogProvider for ForeignCatalogProvider { unsafe { let schema = match schema.as_any().downcast_ref::() { Some(s) => &s.0, - None => &FFI_SchemaProvider::new(schema, None), + None => &FFI_SchemaProvider::new( 
+ schema, + None, + self.0.task_ctx_provider.clone(), + ), }; let returned_schema: Option = df_result!((self.0.register_schema)(&self.0, name.into(), schema))? @@ -283,9 +325,10 @@ impl CatalogProvider for ForeignCatalogProvider { #[cfg(test)] mod tests { - use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider}; - use super::*; + use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider}; + use datafusion::prelude::SessionContext; + use datafusion_execution::TaskContextProvider; #[test] fn test_round_trip_ffi_catalog_provider() { @@ -297,10 +340,12 @@ mod tests { .register_schema("prior_schema", prior_schema) .unwrap() .is_none()); + let ctx = Arc::new(SessionContext::new()) as Arc; - let ffi_catalog = FFI_CatalogProvider::new(catalog, None); + let mut ffi_catalog = FFI_CatalogProvider::new(catalog, None, ctx); + ffi_catalog.library_marker_id = crate::mock_foreign_marker_id; - let foreign_catalog: ForeignCatalogProvider = (&ffi_catalog).into(); + let foreign_catalog: Arc = (&ffi_catalog).into(); let prior_schema_names = foreign_catalog.schema_names(); assert_eq!(prior_schema_names.len(), 1); @@ -335,4 +380,27 @@ mod tests { let returned_schema = foreign_catalog.schema("second_schema"); assert!(returned_schema.is_some()); } + + #[test] + fn test_ffi_catalog_provider_local_bypass() { + let catalog = Arc::new(MemoryCatalogProvider::new()); + + let ctx = Arc::new(SessionContext::new()) as Arc; + let mut ffi_catalog = FFI_CatalogProvider::new(catalog, None, ctx); + + // Verify local libraries can be downcast to their original + let foreign_catalog: Arc = (&ffi_catalog).into(); + assert!(foreign_catalog + .as_any() + .downcast_ref::() + .is_some()); + + // Verify different library markers generate foreign providers + ffi_catalog.library_marker_id = crate::mock_foreign_marker_id; + let foreign_catalog: Arc = (&ffi_catalog).into(); + assert!(foreign_catalog + .as_any() + .downcast_ref::() + .is_some()); + } } diff --git 
a/datafusion/ffi/src/catalog_provider_list.rs b/datafusion/ffi/src/catalog_provider_list.rs index b09f06d318c1..b975fb7a8ba4 100644 --- a/datafusion/ffi/src/catalog_provider_list.rs +++ b/datafusion/ffi/src/catalog_provider_list.rs @@ -25,6 +25,7 @@ use datafusion::catalog::{CatalogProvider, CatalogProviderList}; use tokio::runtime::Handle; use crate::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider}; +use crate::execution::FFI_TaskContextProvider; /// A stable struct for sharing [`CatalogProviderList`] across FFI boundaries. #[repr(C)] @@ -45,8 +46,12 @@ pub struct FFI_CatalogProviderList { pub catalog: unsafe extern "C" fn(&Self, name: RString) -> ROption, - /// Used to create a clone on the provider. This should only need to be called - /// by the receiver of the plan. + /// Provider for TaskContext to be used during protobuf serialization + /// and deserialization. + pub task_ctx_provider: FFI_TaskContextProvider, + + /// Used to create a clone on the provider of the execution plan. This should + /// only need to be called by the receiver of the plan. pub clone: unsafe extern "C" fn(plan: &Self) -> Self, /// Release the memory of the private data when it is no longer being used. @@ -58,6 +63,10 @@ pub struct FFI_CatalogProviderList { /// Internal data. This is only to be accessed by the provider of the plan. /// A [`ForeignCatalogProviderList`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. 
+ pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_CatalogProviderList {} @@ -93,12 +102,14 @@ unsafe extern "C" fn register_catalog_fn_wrapper( catalog: &FFI_CatalogProvider, ) -> ROption { let runtime = provider.runtime(); - let provider = provider.inner(); - let catalog = Arc::new(ForeignCatalogProvider::from(catalog)); + let inner_provider = provider.inner(); + let catalog: Arc = catalog.into(); - provider + inner_provider .register_catalog(name.into(), catalog) - .map(|catalog| FFI_CatalogProvider::new(catalog, runtime)) + .map(|catalog| { + FFI_CatalogProvider::new(catalog, runtime, provider.task_ctx_provider.clone()) + }) .into() } @@ -107,10 +118,12 @@ unsafe extern "C" fn catalog_fn_wrapper( name: RString, ) -> ROption { let runtime = provider.runtime(); - let provider = provider.inner(); - provider + let inner_provider = provider.inner(); + inner_provider .catalog(name.as_str()) - .map(|catalog| FFI_CatalogProvider::new(catalog, runtime)) + .map(|catalog| { + FFI_CatalogProvider::new(catalog, runtime, provider.task_ctx_provider.clone()) + }) .into() } @@ -134,10 +147,12 @@ unsafe extern "C" fn clone_fn_wrapper( register_catalog: register_catalog_fn_wrapper, catalog_names: catalog_names_fn_wrapper, catalog: catalog_fn_wrapper, + task_ctx_provider: provider.task_ctx_provider.clone(), clone: clone_fn_wrapper, release: release_fn_wrapper, version: super::version, private_data, + library_marker_id: crate::get_library_marker_id, } } @@ -152,17 +167,21 @@ impl FFI_CatalogProviderList { pub fn new( provider: Arc, runtime: Option, + task_ctx_provider: impl Into, ) -> Self { + let task_ctx_provider = task_ctx_provider.into(); let private_data = Box::new(ProviderPrivateData { provider, runtime }); Self { register_catalog: register_catalog_fn_wrapper, catalog_names: catalog_names_fn_wrapper, catalog: catalog_fn_wrapper, + task_ctx_provider, clone: clone_fn_wrapper, release: release_fn_wrapper, version: super::version, private_data: 
Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -170,16 +189,21 @@ impl FFI_CatalogProviderList { /// This wrapper struct exists on the receiver side of the FFI interface, so it has /// no guarantees about being able to access the data in `private_data`. Any functions /// defined on this struct must only use the stable functions provided in -/// FFI_CatalogProviderList to interact with the foreign catalog provider list. +/// FFI_CatalogProviderList to interact with the foreign table provider. #[derive(Debug)] pub struct ForeignCatalogProviderList(FFI_CatalogProviderList); unsafe impl Send for ForeignCatalogProviderList {} unsafe impl Sync for ForeignCatalogProviderList {} -impl From<&FFI_CatalogProviderList> for ForeignCatalogProviderList { +impl From<&FFI_CatalogProviderList> for Arc { fn from(provider: &FFI_CatalogProviderList) -> Self { - Self(provider.clone()) + if (provider.library_marker_id)() == crate::get_library_marker_id() { + return Arc::clone(unsafe { provider.inner() }); + } + + Arc::new(ForeignCatalogProviderList(provider.clone())) + as Arc } } @@ -203,7 +227,11 @@ impl CatalogProviderList for ForeignCatalogProviderList { let catalog = match catalog.as_any().downcast_ref::() { Some(s) => &s.0, - None => &FFI_CatalogProvider::new(catalog, None), + None => &FFI_CatalogProvider::new( + catalog, + None, + self.0.task_ctx_provider.clone(), + ), }; (self.0.register_catalog)(&self.0, name.into(), catalog) @@ -234,9 +262,10 @@ impl CatalogProviderList for ForeignCatalogProviderList { #[cfg(test)] mod tests { - use datafusion::catalog::{MemoryCatalogProvider, MemoryCatalogProviderList}; - use super::*; + use datafusion::catalog::{MemoryCatalogProvider, MemoryCatalogProviderList}; + use datafusion::prelude::SessionContext; + use datafusion_execution::TaskContextProvider; #[test] fn test_round_trip_ffi_catalog_provider_list() { @@ -248,9 +277,12 @@ mod tests { .register_catalog("prior_catalog".to_owned(), 
prior_catalog) .is_none()); - let ffi_catalog_list = FFI_CatalogProviderList::new(catalog_list, None); + let ctx = Arc::new(SessionContext::new()) as Arc; + let mut ffi_catalog_list = FFI_CatalogProviderList::new(catalog_list, None, ctx); + ffi_catalog_list.library_marker_id = crate::mock_foreign_marker_id; - let foreign_catalog_list: ForeignCatalogProviderList = (&ffi_catalog_list).into(); + let foreign_catalog_list: Arc = + (&ffi_catalog_list).into(); let prior_catalog_names = foreign_catalog_list.catalog_names(); assert_eq!(prior_catalog_names.len(), 1); @@ -280,4 +312,29 @@ mod tests { let returned_catalog = foreign_catalog_list.catalog("second_catalog"); assert!(returned_catalog.is_some()); } + + #[test] + fn test_ffi_catalog_provider_list_local_bypass() { + let catalog_list = Arc::new(MemoryCatalogProviderList::new()); + + let ctx = Arc::new(SessionContext::new()) as Arc; + let mut ffi_catalog_list = FFI_CatalogProviderList::new(catalog_list, None, ctx); + + // Verify local libraries can be downcast to their original + let foreign_catalog_list: Arc = + (&ffi_catalog_list).into(); + assert!(foreign_catalog_list + .as_any() + .downcast_ref::() + .is_some()); + + // Verify different library markers generate foreign providers + ffi_catalog_list.library_marker_id = crate::mock_foreign_marker_id; + let foreign_catalog_list: Arc = + (&ffi_catalog_list).into(); + assert!(foreign_catalog_list + .as_any() + .downcast_ref::() + .is_some()); + } } diff --git a/datafusion/ffi/src/execution/mod.rs b/datafusion/ffi/src/execution/mod.rs new file mode 100644 index 000000000000..0b54cb1b579d --- /dev/null +++ b/datafusion/ffi/src/execution/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod task_ctx; +mod task_ctx_provider; + +pub use task_ctx::FFI_TaskContext; +pub use task_ctx_provider::FFI_TaskContextProvider; diff --git a/datafusion/ffi/src/execution/task_ctx.rs b/datafusion/ffi/src/execution/task_ctx.rs new file mode 100644 index 000000000000..73122b0f70ed --- /dev/null +++ b/datafusion/ffi/src/execution/task_ctx.rs @@ -0,0 +1,250 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::execution::task_ctx_provider::FFI_TaskContextProvider; +use crate::session_config::FFI_SessionConfig; +use crate::udaf::FFI_AggregateUDF; +use crate::udf::FFI_ScalarUDF; +use crate::udwf::FFI_WindowUDF; +use abi_stable::pmr::ROption; +use abi_stable::std_types::RHashMap; +use abi_stable::{std_types::RString, StableAbi}; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::runtime_env::RuntimeEnv; +use datafusion_execution::TaskContext; +use datafusion_expr::{ + AggregateUDF, AggregateUDFImpl, ScalarUDF, ScalarUDFImpl, WindowUDF, WindowUDFImpl, +}; +use std::{ffi::c_void, sync::Arc}; + +/// A stable struct for sharing [`TaskContext`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_TaskContext { + pub session_id: unsafe extern "C" fn(&Self) -> RString, + + pub task_id: unsafe extern "C" fn(&Self) -> ROption, + + pub session_config: unsafe extern "C" fn(&Self) -> FFI_SessionConfig, + + pub scalar_functions: unsafe extern "C" fn(&Self) -> RHashMap, + + pub aggregate_functions: + unsafe extern "C" fn(&Self) -> RHashMap, + + pub window_functions: unsafe extern "C" fn(&Self) -> RHashMap, + + /// Provider for TaskContext to be used during protobuf serialization + /// and deserialization. + pub task_ctx_provider: FFI_TaskContextProvider, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(arg: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the plan. + /// The foreign library should never attempt to access this data. + pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. 
+ pub library_marker_id: extern "C" fn() -> u64, +} + +struct TaskContextPrivateData { + ctx: Arc, +} + +impl FFI_TaskContext { + unsafe fn inner(&self) -> &TaskContext { + let private_data = self.private_data as *const TaskContextPrivateData; + &(*private_data).ctx + } +} + +unsafe extern "C" fn session_id_fn_wrapper(ctx: &FFI_TaskContext) -> RString { + let ctx = ctx.inner(); + ctx.session_id().into() +} + +unsafe extern "C" fn task_id_fn_wrapper(ctx: &FFI_TaskContext) -> ROption { + let ctx = ctx.inner(); + ctx.task_id().map(|s| s.as_str().into()).into() +} + +unsafe extern "C" fn session_config_fn_wrapper( + ctx: &FFI_TaskContext, +) -> FFI_SessionConfig { + let ctx = ctx.inner(); + ctx.session_config().into() +} + +unsafe extern "C" fn scalar_functions_fn_wrapper( + ctx: &FFI_TaskContext, +) -> RHashMap { + let ctx = ctx.inner(); + ctx.scalar_functions() + .iter() + .map(|(name, udf)| (name.to_owned().into(), Arc::clone(udf).into())) + .collect() +} + +unsafe extern "C" fn aggregate_functions_fn_wrapper( + ctx: &FFI_TaskContext, +) -> RHashMap { + let task_ctx_provider = &ctx.task_ctx_provider; + let ctx = ctx.inner(); + ctx.aggregate_functions() + .iter() + .map(|(name, udaf)| { + ( + name.to_owned().into(), + FFI_AggregateUDF::new(Arc::clone(udaf), task_ctx_provider.clone()), + ) + }) + .collect() +} + +unsafe extern "C" fn window_functions_fn_wrapper( + ctx: &FFI_TaskContext, +) -> RHashMap { + let task_ctx_provider = &ctx.task_ctx_provider; + let ctx = ctx.inner(); + ctx.window_functions() + .iter() + .map(|(name, udf)| { + ( + name.to_owned().into(), + FFI_WindowUDF::new(Arc::clone(udf), task_ctx_provider.clone()), + ) + }) + .collect() +} + +unsafe extern "C" fn release_fn_wrapper(ctx: &mut FFI_TaskContext) { + let private_data = Box::from_raw(ctx.private_data as *mut TaskContextPrivateData); + drop(private_data); +} + +impl Drop for FFI_TaskContext { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +impl FFI_TaskContext { + pub fn new( 
+ ctx: Arc, + task_ctx_provider: impl Into, + ) -> Self { + let task_ctx_provider = task_ctx_provider.into(); + let private_data = Box::new(TaskContextPrivateData { ctx }); + + FFI_TaskContext { + session_id: session_id_fn_wrapper, + task_id: task_id_fn_wrapper, + session_config: session_config_fn_wrapper, + scalar_functions: scalar_functions_fn_wrapper, + aggregate_functions: aggregate_functions_fn_wrapper, + window_functions: window_functions_fn_wrapper, + task_ctx_provider, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, + } + } +} + +impl From for TaskContext { + fn from(ffi_ctx: FFI_TaskContext) -> Self { + unsafe { + if (ffi_ctx.library_marker_id)() == crate::get_library_marker_id() { + return ffi_ctx.inner().clone(); + } + + let task_id = (ffi_ctx.task_id)(&ffi_ctx).map(|s| s.to_string()).into(); + let session_id = (ffi_ctx.session_id)(&ffi_ctx).into(); + let session_config = (ffi_ctx.session_config)(&ffi_ctx); + let session_config = + SessionConfig::try_from(&session_config).unwrap_or_default(); + + let scalar_functions = (ffi_ctx.scalar_functions)(&ffi_ctx) + .into_iter() + .filter_map(|kv_pair| { + let udf = >::try_from(&kv_pair.1); + + if let Err(err) = &udf { + log::error!("Unable to create WindowUDF in FFI: {err}") + } + + udf.ok().map(|udf| { + ( + kv_pair.0.into_string(), + Arc::new(ScalarUDF::new_from_shared_impl(udf)), + ) + }) + }) + .collect(); + let aggregate_functions = (ffi_ctx.aggregate_functions)(&ffi_ctx) + .into_iter() + .filter_map(|kv_pair| { + let udaf = >::try_from(&kv_pair.1); + + if let Err(err) = &udaf { + log::error!("Unable to create AggregateUDF in FFI: {err}") + } + + udaf.ok().map(|udaf| { + ( + kv_pair.0.into_string(), + Arc::new(AggregateUDF::new_from_shared_impl(udaf)), + ) + }) + }) + .collect(); + let window_functions = (ffi_ctx.window_functions)(&ffi_ctx) + .into_iter() + .filter_map(|kv_pair| { + let udwf = 
>::try_from(&kv_pair.1); + + if let Err(err) = &udwf { + log::error!("Unable to create WindowUDF in FFI: {err}") + } + + udwf.ok().map(|udwf| { + ( + kv_pair.0.into_string(), + Arc::new(WindowUDF::new_from_shared_impl(udwf)), + ) + }) + }) + .collect(); + + let runtime = Arc::new(RuntimeEnv::default()); + + TaskContext::new( + task_id, + session_id, + session_config, + scalar_functions, + aggregate_functions, + window_functions, + runtime, + ) + } + } +} diff --git a/datafusion/ffi/src/execution/task_ctx_provider.rs b/datafusion/ffi/src/execution/task_ctx_provider.rs new file mode 100644 index 000000000000..31aade9d83c3 --- /dev/null +++ b/datafusion/ffi/src/execution/task_ctx_provider.rs @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::execution::task_ctx::FFI_TaskContext; +use crate::{df_result, rresult}; +use abi_stable::std_types::RResult; +use abi_stable::{std_types::RString, StableAbi}; +use datafusion_common::{exec_datafusion_err, DataFusionError}; +use datafusion_execution::{TaskContext, TaskContextProvider}; +use std::sync::Weak; +use std::{ffi::c_void, sync::Arc}; + +/// Struct for accessing the [`TaskContext`]. 
This method contains a weak +/// reference, so there are no guarantees that the [`TaskContext`] remains +/// valid. This is used primarily for protobuf encoding and decoding of +/// data passed across the FFI boundary. See the crate README for +/// additional information. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_TaskContextProvider { + pub task_ctx: unsafe extern "C" fn(&Self) -> RResult, + + /// Used to create a clone on the task context accessor. This should + /// only need to be called by the receiver of the plan. + pub clone: unsafe extern "C" fn(plan: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(arg: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the plan. + /// The foreign library should never attempt to access this data. + pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. + pub library_marker_id: extern "C" fn() -> u64, +} + +unsafe impl Send for FFI_TaskContextProvider {} +unsafe impl Sync for FFI_TaskContextProvider {} + +struct TaskContextProviderPrivateData { + ctx: Weak, +} + +impl FFI_TaskContextProvider { + unsafe fn inner(&self) -> Option> { + let private_data = self.private_data as *const TaskContextProviderPrivateData; + (*private_data).ctx.upgrade().map(|ctx| ctx.task_ctx()) + } +} + +unsafe extern "C" fn task_ctx_fn_wrapper( + ctx_accessor: &FFI_TaskContextProvider, +) -> RResult { + rresult!(ctx_accessor + .inner() + .map(|ctx| FFI_TaskContext::new(ctx, ctx_accessor.clone())) + .ok_or_else(|| { + exec_datafusion_err!( + "TaskContextProvider went out of scope over FFI boundary." 
+ ) + })) +} + +unsafe extern "C" fn clone_fn_wrapper( + accessor: &FFI_TaskContextProvider, +) -> FFI_TaskContextProvider { + let private_data = accessor.private_data as *const TaskContextProviderPrivateData; + let ctx = Weak::clone(&(*private_data).ctx); + + let private_data = Box::new(TaskContextProviderPrivateData { ctx }); + + FFI_TaskContextProvider { + task_ctx: task_ctx_fn_wrapper, + release: release_fn_wrapper, + clone: clone_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, + } +} +unsafe extern "C" fn release_fn_wrapper(ctx: &mut FFI_TaskContextProvider) { + let private_data = + Box::from_raw(ctx.private_data as *mut TaskContextProviderPrivateData); + drop(private_data); +} +impl Drop for FFI_TaskContextProvider { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +impl Clone for FFI_TaskContextProvider { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl From> for FFI_TaskContextProvider { + fn from(ctx: Arc) -> Self { + (&ctx).into() + } +} + +impl From<&Arc> for FFI_TaskContextProvider { + fn from(ctx: &Arc) -> Self { + let ctx = Arc::downgrade(ctx); + let private_data = Box::new(TaskContextProviderPrivateData { ctx }); + + FFI_TaskContextProvider { + task_ctx: task_ctx_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, + } + } +} + +impl TryFrom<&FFI_TaskContextProvider> for Arc { + type Error = DataFusionError; + fn try_from(ffi_ctx: &FFI_TaskContextProvider) -> Result { + unsafe { + if (ffi_ctx.library_marker_id)() == crate::get_library_marker_id() { + return ffi_ctx.inner().ok_or_else(|| { + exec_datafusion_err!( + "TaskContextProvider went out of scope over FFI boundary." 
+ ) + }); + } + + df_result!((ffi_ctx.task_ctx)(ffi_ctx)) + .map(Into::into) + .map(Arc::new) + } + } +} diff --git a/datafusion/ffi/src/execution_plan.rs b/datafusion/ffi/src/execution_plan.rs index 70c957d8c373..1fc7d4d28d48 100644 --- a/datafusion/ffi/src/execution_plan.rs +++ b/datafusion/ffi/src/execution_plan.rs @@ -29,9 +29,10 @@ use datafusion::{ use datafusion::{error::Result, physical_plan::DisplayFormatType}; use tokio::runtime::Handle; +use crate::execution::FFI_TaskContextProvider; use crate::{ df_result, plan_properties::FFI_PlanProperties, - record_batch_stream::FFI_RecordBatchStream, rresult, + record_batch_stream::FFI_RecordBatchStream, rresult, rresult_return, }; /// A stable struct for sharing a [`ExecutionPlan`] across FFI boundaries. @@ -55,6 +56,10 @@ pub struct FFI_ExecutionPlan { partition: usize, ) -> RResult, + /// Provider for TaskContext to be used during protobuf serialization + /// and deserialization. + pub task_ctx_provider: FFI_TaskContextProvider, + /// Used to create a clone on the provider of the execution plan. This should /// only need to be called by the receiver of the plan. pub clone: unsafe extern "C" fn(plan: &Self) -> Self, @@ -65,6 +70,10 @@ pub struct FFI_ExecutionPlan { /// Internal data. This is only to be accessed by the provider of the plan. /// A [`ForeignExecutionPlan`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. 
+ pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_ExecutionPlan {} @@ -72,32 +81,39 @@ unsafe impl Sync for FFI_ExecutionPlan {} pub struct ExecutionPlanPrivateData { pub plan: Arc, - pub context: Arc, pub runtime: Option, } +impl FFI_ExecutionPlan { + fn inner(&self) -> &Arc { + let private_data = self.private_data as *const ExecutionPlanPrivateData; + unsafe { &(*private_data).plan } + } +} + unsafe extern "C" fn properties_fn_wrapper( plan: &FFI_ExecutionPlan, ) -> FFI_PlanProperties { - let private_data = plan.private_data as *const ExecutionPlanPrivateData; - let plan = &(*private_data).plan; - - plan.properties().into() + FFI_PlanProperties::new(plan.inner().properties(), plan.task_ctx_provider.clone()) } unsafe extern "C" fn children_fn_wrapper( plan: &FFI_ExecutionPlan, ) -> RVec { + let task_ctx_provider = &plan.task_ctx_provider; let private_data = plan.private_data as *const ExecutionPlanPrivateData; let plan = &(*private_data).plan; - let ctx = &(*private_data).context; let runtime = &(*private_data).runtime; let children: Vec<_> = plan .children() .into_iter() .map(|child| { - FFI_ExecutionPlan::new(Arc::clone(child), Arc::clone(ctx), runtime.clone()) + FFI_ExecutionPlan::new( + Arc::clone(child), + task_ctx_provider.clone(), + runtime.clone(), + ) }) .collect(); @@ -108,21 +124,18 @@ unsafe extern "C" fn execute_fn_wrapper( plan: &FFI_ExecutionPlan, partition: usize, ) -> RResult { + let ctx = rresult_return!(>::try_from(&plan.task_ctx_provider)); let private_data = plan.private_data as *const ExecutionPlanPrivateData; let plan = &(*private_data).plan; - let ctx = &(*private_data).context; let runtime = (*private_data).runtime.clone(); rresult!(plan - .execute(partition, Arc::clone(ctx)) + .execute(partition, ctx) .map(|rbs| FFI_RecordBatchStream::new(rbs, runtime))) } unsafe extern "C" fn name_fn_wrapper(plan: &FFI_ExecutionPlan) -> RString { - let private_data = plan.private_data as *const ExecutionPlanPrivateData; - let 
plan = &(*private_data).plan; - - plan.name().into() + plan.inner().name().into() } unsafe extern "C" fn release_fn_wrapper(plan: &mut FFI_ExecutionPlan) { @@ -136,7 +149,7 @@ unsafe extern "C" fn clone_fn_wrapper(plan: &FFI_ExecutionPlan) -> FFI_Execution FFI_ExecutionPlan::new( Arc::clone(&plan_data.plan), - Arc::clone(&plan_data.context), + plan.task_ctx_provider.clone(), plan_data.runtime.clone(), ) } @@ -151,23 +164,22 @@ impl FFI_ExecutionPlan { /// This function is called on the provider's side. pub fn new( plan: Arc, - context: Arc, + task_ctx_provider: impl Into, runtime: Option, ) -> Self { - let private_data = Box::new(ExecutionPlanPrivateData { - plan, - context, - runtime, - }); + let task_ctx_provider = task_ctx_provider.into(); + let private_data = Box::new(ExecutionPlanPrivateData { plan, runtime }); Self { properties: properties_fn_wrapper, children: children_fn_wrapper, name: name_fn_wrapper, execute: execute_fn_wrapper, + task_ctx_provider, clone: clone_fn_wrapper, release: release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -218,10 +230,14 @@ impl DisplayAs for ForeignExecutionPlan { } } -impl TryFrom<&FFI_ExecutionPlan> for ForeignExecutionPlan { +impl TryFrom<&FFI_ExecutionPlan> for Arc { type Error = DataFusionError; fn try_from(plan: &FFI_ExecutionPlan) -> Result { + if (plan.library_marker_id)() == crate::get_library_marker_id() { + return Ok(Arc::clone(plan.inner())); + } + unsafe { let name = (plan.name)(plan).into(); @@ -230,16 +246,17 @@ impl TryFrom<&FFI_ExecutionPlan> for ForeignExecutionPlan { let children_rvec = (plan.children)(plan); let children = children_rvec .iter() - .map(ForeignExecutionPlan::try_from) - .map(|child| child.map(|c| Arc::new(c) as Arc)) + .map(>::try_from) .collect::>>()?; - Ok(Self { + let plan = ForeignExecutionPlan { name, plan: plan.clone(), properties, children, - }) + }; + + Ok(Arc::new(plan)) } } } @@ -258,10 +275,7 @@ 
impl ExecutionPlan for ForeignExecutionPlan { } fn children(&self) -> Vec<&Arc> { - self.children - .iter() - .map(|p| p as &Arc) - .collect() + self.children.iter().collect() } fn with_new_children( @@ -290,6 +304,7 @@ impl ExecutionPlan for ForeignExecutionPlan { #[cfg(test)] mod tests { + use super::*; use arrow::datatypes::{DataType, Field, Schema}; use datafusion::{ physical_plan::{ @@ -298,8 +313,7 @@ mod tests { }, prelude::SessionContext, }; - - use super::*; + use datafusion_execution::TaskContextProvider; #[derive(Debug)] pub struct EmptyExec { @@ -375,19 +389,20 @@ mod tests { fn test_round_trip_ffi_execution_plan() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); - let ctx = SessionContext::new(); + let ctx = Arc::new(SessionContext::new()) as Arc; let original_plan = Arc::new(EmptyExec::new(schema)); let original_name = original_plan.name().to_string(); - let local_plan = FFI_ExecutionPlan::new(original_plan, ctx.task_ctx(), None); + let mut local_plan = FFI_ExecutionPlan::new(original_plan, ctx, None); + local_plan.library_marker_id = crate::mock_foreign_marker_id; - let foreign_plan: ForeignExecutionPlan = (&local_plan).try_into()?; + let foreign_plan: Arc = (&local_plan).try_into()?; assert!(original_name == foreign_plan.name()); let display = datafusion::physical_plan::display::DisplayableExecutionPlan::new( - &foreign_plan, + foreign_plan.as_ref(), ); let buf = display.one_line().to_string(); @@ -403,16 +418,19 @@ mod tests { fn test_ffi_execution_plan_children() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); - let ctx = SessionContext::new(); + let ctx = Arc::new(SessionContext::new()) as Arc; // Version 1: Adding child to the foreign plan let child_plan = Arc::new(EmptyExec::new(Arc::clone(&schema))); - let child_local = FFI_ExecutionPlan::new(child_plan, ctx.task_ctx(), None); - let child_foreign = 
Arc::new(ForeignExecutionPlan::try_from(&child_local)?); + let mut child_local = FFI_ExecutionPlan::new(child_plan, Arc::clone(&ctx), None); + child_local.library_marker_id = crate::mock_foreign_marker_id; + let child_foreign = >::try_from(&child_local)?; let parent_plan = Arc::new(EmptyExec::new(Arc::clone(&schema))); - let parent_local = FFI_ExecutionPlan::new(parent_plan, ctx.task_ctx(), None); - let parent_foreign = Arc::new(ForeignExecutionPlan::try_from(&parent_local)?); + let mut parent_local = + FFI_ExecutionPlan::new(parent_plan, Arc::clone(&ctx), None); + parent_local.library_marker_id = crate::mock_foreign_marker_id; + let parent_foreign = >::try_from(&parent_local)?; assert_eq!(parent_foreign.children().len(), 0); assert_eq!(child_foreign.children().len(), 0); @@ -422,16 +440,43 @@ mod tests { // Version 2: Adding child to the local plan let child_plan = Arc::new(EmptyExec::new(Arc::clone(&schema))); - let child_local = FFI_ExecutionPlan::new(child_plan, ctx.task_ctx(), None); - let child_foreign = Arc::new(ForeignExecutionPlan::try_from(&child_local)?); + let mut child_local = FFI_ExecutionPlan::new(child_plan, Arc::clone(&ctx), None); + child_local.library_marker_id = crate::mock_foreign_marker_id; + let child_foreign = >::try_from(&child_local)?; let parent_plan = Arc::new(EmptyExec::new(Arc::clone(&schema))); let parent_plan = parent_plan.with_new_children(vec![child_foreign])?; - let parent_local = FFI_ExecutionPlan::new(parent_plan, ctx.task_ctx(), None); - let parent_foreign = Arc::new(ForeignExecutionPlan::try_from(&parent_local)?); + let mut parent_local = + FFI_ExecutionPlan::new(parent_plan, Arc::clone(&ctx), None); + parent_local.library_marker_id = crate::mock_foreign_marker_id; + let parent_foreign = >::try_from(&parent_local)?; assert_eq!(parent_foreign.children().len(), 1); Ok(()) } + + #[test] + fn test_ffi_execution_plan_local_bypass() { + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + let ctx 
= Arc::new(SessionContext::new()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; + + let plan = Arc::new(EmptyExec::new(schema)); + + let mut ffi_plan = FFI_ExecutionPlan::new(plan, task_ctx_provider, None); + + // Verify local libraries can be downcast to their original + let foreign_plan: Arc = (&ffi_plan).try_into().unwrap(); + assert!(foreign_plan.as_any().downcast_ref::().is_some()); + + // Verify different library markers generate foreign providers + ffi_plan.library_marker_id = crate::mock_foreign_marker_id; + let foreign_plan: Arc = (&ffi_plan).try_into().unwrap(); + assert!(foreign_plan + .as_any() + .downcast_ref::() + .is_some()); + } } diff --git a/datafusion/ffi/src/lib.rs b/datafusion/ffi/src/lib.rs index 39eb7babd90d..e9ef27d7f886 100644 --- a/datafusion/ffi/src/lib.rs +++ b/datafusion/ffi/src/lib.rs @@ -30,6 +30,7 @@ pub mod arrow_wrappers; pub mod catalog_provider; pub mod catalog_provider_list; +pub mod execution; pub mod execution_plan; pub mod insert_op; pub mod plan_properties; @@ -58,5 +59,31 @@ pub extern "C" fn version() -> u64 { version.major } +static LIBRARY_MARKER: u8 = 0; + +/// This utility is used to determine if two FFI structs are within +/// the same library. It is possible that the interplay between +/// foreign and local functions calls create one FFI struct that +/// references another. It is helpful to determine if a foreign +/// struct is truly foreign or in the same library. If we are in the +/// same library, then we can access the underlying types directly. +/// +/// This function works by checking the address of the library +/// marker. Each library that implements the FFI code will have +/// a different address for the marker. By checking the marker +/// address we can determine if a struct is truly Foreign or is +/// actually within the same originating library. 
+pub extern "C" fn get_library_marker_id() -> u64 { + &LIBRARY_MARKER as *const u8 as u64 +} + +/// For unit testing in this crate we need to trick the providers +/// into thinking we have a foreign call. We do this by overwriting +/// their `library_marker_id` function to return a different value. +#[cfg(test)] +pub(crate) extern "C" fn mock_foreign_marker_id() -> u64 { + get_library_marker_id() + 1 +} + #[cfg(doctest)] doc_comment::doctest!("../README.md", readme_example_test); diff --git a/datafusion/ffi/src/plan_properties.rs b/datafusion/ffi/src/plan_properties.rs index 48c2698a58c7..ac4ad5451832 100644 --- a/datafusion/ffi/src/plan_properties.rs +++ b/datafusion/ffi/src/plan_properties.rs @@ -17,6 +17,8 @@ use std::{ffi::c_void, sync::Arc}; +use crate::execution::FFI_TaskContextProvider; +use crate::{arrow_wrappers::WrappedSchema, df_result, rresult_return}; use abi_stable::{ std_types::{ RResult::{self, ROk}, @@ -32,8 +34,8 @@ use datafusion::{ execution_plan::{Boundedness, EmissionType}, PlanProperties, }, - prelude::SessionContext, }; +use datafusion_execution::TaskContext; use datafusion_proto::{ physical_plan::{ from_proto::{parse_physical_sort_exprs, parse_protobuf_partitioning}, @@ -44,8 +46,6 @@ use datafusion_proto::{ }; use prost::Message; -use crate::{arrow_wrappers::WrappedSchema, df_result, rresult_return}; - /// A stable struct for sharing [`PlanProperties`] across FFI boundaries. #[repr(C)] #[derive(Debug, StableAbi)] @@ -69,27 +69,41 @@ pub struct FFI_PlanProperties { /// Return the schema of the plan. pub schema: unsafe extern "C" fn(plan: &Self) -> WrappedSchema, + /// Provider for TaskContext to be used during protobuf serialization + /// and deserialization. + pub task_ctx_provider: FFI_TaskContextProvider, + /// Release the memory of the private data when it is no longer being used. pub release: unsafe extern "C" fn(arg: &mut Self), /// Internal data. This is only to be accessed by the provider of the plan. 
/// The foreign library should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. + pub library_marker_id: extern "C" fn() -> u64, } struct PlanPropertiesPrivateData { props: PlanProperties, } +impl FFI_PlanProperties { + fn inner(&self) -> &PlanProperties { + let private_data = self.private_data as *const PlanPropertiesPrivateData; + unsafe { &(*private_data).props } + } +} + unsafe extern "C" fn output_partitioning_fn_wrapper( properties: &FFI_PlanProperties, ) -> RResult, RString> { - let private_data = properties.private_data as *const PlanPropertiesPrivateData; - let props = &(*private_data).props; - let codec = DefaultPhysicalExtensionCodec {}; - let partitioning_data = - rresult_return!(serialize_partitioning(props.output_partitioning(), &codec)); + let partitioning_data = rresult_return!(serialize_partitioning( + properties.inner().output_partitioning(), + &codec + )); let output_partitioning = partitioning_data.encode_to_vec(); ROk(output_partitioning.into()) @@ -98,27 +112,20 @@ unsafe extern "C" fn output_partitioning_fn_wrapper( unsafe extern "C" fn emission_type_fn_wrapper( properties: &FFI_PlanProperties, ) -> FFI_EmissionType { - let private_data = properties.private_data as *const PlanPropertiesPrivateData; - let props = &(*private_data).props; - props.emission_type.into() + properties.inner().emission_type.into() } unsafe extern "C" fn boundedness_fn_wrapper( properties: &FFI_PlanProperties, ) -> FFI_Boundedness { - let private_data = properties.private_data as *const PlanPropertiesPrivateData; - let props = &(*private_data).props; - props.boundedness.into() + properties.inner().boundedness.into() } unsafe extern "C" fn output_ordering_fn_wrapper( properties: &FFI_PlanProperties, ) -> RResult, RString> { - let private_data = properties.private_data as *const PlanPropertiesPrivateData; - let props = &(*private_data).props; - let 
codec = DefaultPhysicalExtensionCodec {}; - let output_ordering = match props.output_ordering() { + let output_ordering = match properties.inner().output_ordering() { Some(ordering) => { let physical_sort_expr_nodes = rresult_return!( serialize_physical_sort_exprs(ordering.to_owned(), &codec) @@ -135,10 +142,7 @@ unsafe extern "C" fn output_ordering_fn_wrapper( } unsafe extern "C" fn schema_fn_wrapper(properties: &FFI_PlanProperties) -> WrappedSchema { - let private_data = properties.private_data as *const PlanPropertiesPrivateData; - let props = &(*private_data).props; - - let schema: SchemaRef = Arc::clone(props.eq_properties.schema()); + let schema: SchemaRef = Arc::clone(properties.inner().eq_properties.schema()); schema.into() } @@ -154,8 +158,12 @@ impl Drop for FFI_PlanProperties { } } -impl From<&PlanProperties> for FFI_PlanProperties { - fn from(props: &PlanProperties) -> Self { +impl FFI_PlanProperties { + pub fn new( + props: &PlanProperties, + task_ctx_provider: impl Into, + ) -> Self { + let task_ctx_provider = task_ctx_provider.into(); let private_data = Box::new(PlanPropertiesPrivateData { props: props.clone(), }); @@ -166,8 +174,10 @@ impl From<&PlanProperties> for FFI_PlanProperties { boundedness: boundedness_fn_wrapper, output_ordering: output_ordering_fn_wrapper, schema: schema_fn_wrapper, + task_ctx_provider, release: release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -176,12 +186,15 @@ impl TryFrom for PlanProperties { type Error = DataFusionError; fn try_from(ffi_props: FFI_PlanProperties) -> Result { + if (ffi_props.library_marker_id)() == crate::get_library_marker_id() { + return Ok(ffi_props.inner().clone()); + } + let ffi_schema = unsafe { (ffi_props.schema)(&ffi_props) }; let schema = (&ffi_schema.0).try_into()?; - // TODO Extend FFI to get the registry and codex - let default_ctx = SessionContext::new(); - let task_context = default_ctx.task_ctx(); + 
// TODO Extend FFI to get the codec + let task_ctx: Arc = (&ffi_props.task_ctx_provider).try_into()?; let codex = DefaultPhysicalExtensionCodec {}; let ffi_orderings = unsafe { (ffi_props.output_ordering)(&ffi_props) }; @@ -191,7 +204,7 @@ impl TryFrom for PlanProperties { .map_err(|e| DataFusionError::External(Box::new(e)))?; let sort_exprs = parse_physical_sort_exprs( &proto_output_ordering.physical_sort_expr_nodes, - &task_context, + task_ctx.as_ref(), &schema, &codex, )?; @@ -203,7 +216,7 @@ impl TryFrom for PlanProperties { .map_err(|e| DataFusionError::External(Box::new(e)))?; let partitioning = parse_protobuf_partitioning( Some(&proto_output_partitioning), - &task_context, + task_ctx.as_ref(), &schema, &codex, )? @@ -300,12 +313,12 @@ impl From for EmissionType { #[cfg(test)] mod tests { - use datafusion::{physical_expr::PhysicalSortExpr, physical_plan::Partitioning}; - use super::*; + use datafusion::prelude::SessionContext; + use datafusion::{physical_expr::PhysicalSortExpr, physical_plan::Partitioning}; + use datafusion_execution::TaskContextProvider; - #[test] - fn test_round_trip_ffi_plan_properties() -> Result<()> { + fn create_test_props() -> Result { use arrow::datatypes::{DataType, Field, Schema}; let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); @@ -314,14 +327,23 @@ mod tests { let _ = eqp.reorder([PhysicalSortExpr::new_default( datafusion::physical_plan::expressions::col("a", &schema)?, )]); - let original_props = PlanProperties::new( + Ok(PlanProperties::new( eqp, Partitioning::RoundRobinBatch(3), EmissionType::Incremental, Boundedness::Bounded, - ); + )) + } + + #[test] + fn test_round_trip_ffi_plan_properties() -> Result<()> { + let original_props = create_test_props()?; + let ctx = Arc::new(SessionContext::default()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; - let local_props_ptr = FFI_PlanProperties::from(&original_props); + let mut local_props_ptr = + FFI_PlanProperties::new(&original_props, 
task_ctx_provider); + local_props_ptr.library_marker_id = crate::mock_foreign_marker_id; let foreign_props: PlanProperties = local_props_ptr.try_into()?; @@ -329,4 +351,25 @@ mod tests { Ok(()) } + + #[test] + fn test_ffi_plan_properties_local_bypass() -> Result<()> { + let props = create_test_props()?; + let ctx = Arc::new(SessionContext::default()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; + + let ffi_plan = FFI_PlanProperties::new(&props, Arc::clone(&task_ctx_provider)); + + // Verify local libraries + let foreign_plan: PlanProperties = ffi_plan.try_into()?; + assert_eq!(format!("{foreign_plan:?}"), format!("{:?}", foreign_plan)); + + // Verify different library markers still can produce identical properties + let mut ffi_plan = FFI_PlanProperties::new(&props, task_ctx_provider); + ffi_plan.library_marker_id = crate::mock_foreign_marker_id; + let foreign_plan: PlanProperties = ffi_plan.try_into()?; + assert_eq!(format!("{foreign_plan:?}"), format!("{:?}", foreign_plan)); + + Ok(()) + } } diff --git a/datafusion/ffi/src/schema_provider.rs b/datafusion/ffi/src/schema_provider.rs index b5970d5881d6..4a1a6bb66331 100644 --- a/datafusion/ffi/src/schema_provider.rs +++ b/datafusion/ffi/src/schema_provider.rs @@ -34,6 +34,7 @@ use crate::{ table_provider::{FFI_TableProvider, ForeignTableProvider}, }; +use crate::execution::FFI_TaskContextProvider; use datafusion::error::Result; /// A stable struct for sharing [`SchemaProvider`] across FFI boundaries. @@ -67,6 +68,10 @@ pub struct FFI_SchemaProvider { pub table_exist: unsafe extern "C" fn(provider: &Self, name: RString) -> bool, + /// Provider for TaskContext to be used during protobuf serialization + /// and deserialization. + pub task_ctx_provider: FFI_TaskContextProvider, + /// Used to create a clone on the provider of the execution plan. This should /// only need to be called by the receiver of the plan. 
pub clone: unsafe extern "C" fn(plan: &Self) -> Self, @@ -80,6 +85,10 @@ pub struct FFI_SchemaProvider { /// Internal data. This is only to be accessed by the provider of the plan. /// A [`ForeignSchemaProvider`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. + pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_SchemaProvider {} @@ -116,11 +125,12 @@ unsafe extern "C" fn table_fn_wrapper( name: RString, ) -> FfiFuture, RString>> { let runtime = provider.runtime(); + let task_ctx_provider = provider.task_ctx_provider.clone(); let provider = Arc::clone(provider.inner()); async move { let table = rresult_return!(provider.table(name.as_str()).await) - .map(|t| FFI_TableProvider::new(t, true, runtime)) + .map(|t| FFI_TableProvider::new(t, true, runtime, task_ctx_provider)) .into(); RResult::ROk(table) @@ -134,12 +144,13 @@ unsafe extern "C" fn register_table_fn_wrapper( table: FFI_TableProvider, ) -> RResult, RString> { let runtime = provider.runtime(); + let task_ctx_provider = provider.task_ctx_provider.clone(); let provider = provider.inner(); let table = Arc::new(ForeignTableProvider(table)); let returned_table = rresult_return!(provider.register_table(name.into(), table)) - .map(|t| FFI_TableProvider::new(t, true, runtime)); + .map(|t| FFI_TableProvider::new(t, true, runtime, task_ctx_provider)); RResult::ROk(returned_table.into()) } @@ -149,10 +160,11 @@ unsafe extern "C" fn deregister_table_fn_wrapper( name: RString, ) -> RResult, RString> { let runtime = provider.runtime(); + let task_ctx_provider = provider.task_ctx_provider.clone(); let provider = provider.inner(); let returned_table = rresult_return!(provider.deregister_table(name.as_str())) - .map(|t| FFI_TableProvider::new(t, true, runtime)); + .map(|t| FFI_TableProvider::new(t, true, runtime, task_ctx_provider)); RResult::ROk(returned_table.into()) } @@ 
-183,14 +195,16 @@ unsafe extern "C" fn clone_fn_wrapper( FFI_SchemaProvider { owner_name: provider.owner_name.clone(), table_names: table_names_fn_wrapper, - clone: clone_fn_wrapper, - release: release_fn_wrapper, - version: super::version, - private_data, table: table_fn_wrapper, register_table: register_table_fn_wrapper, deregister_table: deregister_table_fn_wrapper, table_exist: table_exist_fn_wrapper, + task_ctx_provider: provider.task_ctx_provider.clone(), + clone: clone_fn_wrapper, + release: release_fn_wrapper, + version: super::version, + private_data, + library_marker_id: crate::get_library_marker_id, } } @@ -205,21 +219,25 @@ impl FFI_SchemaProvider { pub fn new( provider: Arc, runtime: Option, + task_ctx_provider: impl Into, ) -> Self { + let task_ctx_provider = task_ctx_provider.into(); let owner_name = provider.owner_name().map(|s| s.into()).into(); let private_data = Box::new(ProviderPrivateData { provider, runtime }); Self { owner_name, table_names: table_names_fn_wrapper, - clone: clone_fn_wrapper, - release: release_fn_wrapper, - version: super::version, - private_data: Box::into_raw(private_data) as *mut c_void, table: table_fn_wrapper, register_table: register_table_fn_wrapper, deregister_table: deregister_table_fn_wrapper, table_exist: table_exist_fn_wrapper, + task_ctx_provider, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + version: super::version, + private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -234,9 +252,14 @@ pub struct ForeignSchemaProvider(pub FFI_SchemaProvider); unsafe impl Send for ForeignSchemaProvider {} unsafe impl Sync for ForeignSchemaProvider {} -impl From<&FFI_SchemaProvider> for ForeignSchemaProvider { +impl From<&FFI_SchemaProvider> for Arc { fn from(provider: &FFI_SchemaProvider) -> Self { - Self(provider.clone()) + if (provider.library_marker_id)() == crate::get_library_marker_id() { + return Arc::clone(unsafe { provider.inner() }); + } + 
+ Arc::new(ForeignSchemaProvider(provider.clone())) + as Arc } } @@ -274,9 +297,7 @@ impl SchemaProvider for ForeignSchemaProvider { let table: Option = df_result!((self.0.table)(&self.0, name.into()).await)?.into(); - let table = table.as_ref().map(|t| { - Arc::new(ForeignTableProvider::from(t)) as Arc - }); + let table = table.as_ref().map(>::from); Ok(table) } @@ -290,7 +311,12 @@ impl SchemaProvider for ForeignSchemaProvider { unsafe { let ffi_table = match table.as_any().downcast_ref::() { Some(t) => t.0.clone(), - None => FFI_TableProvider::new(table, true, None), + None => FFI_TableProvider::new( + table, + true, + None, + self.0.task_ctx_provider.clone(), + ), }; let returned_provider: Option = @@ -319,10 +345,11 @@ impl SchemaProvider for ForeignSchemaProvider { #[cfg(test)] mod tests { + use super::*; use arrow::datatypes::Schema; + use datafusion::prelude::SessionContext; use datafusion::{catalog::MemorySchemaProvider, datasource::empty::EmptyTable}; - - use super::*; + use datafusion_execution::TaskContextProvider; fn empty_table() -> Arc { Arc::new(EmptyTable::new(Arc::new(Schema::empty()))) @@ -337,9 +364,11 @@ mod tests { .unwrap() .is_none()); - let ffi_schema_provider = FFI_SchemaProvider::new(schema_provider, None); + let ctx = Arc::new(SessionContext::new()) as Arc; + let mut ffi_schema_provider = FFI_SchemaProvider::new(schema_provider, None, ctx); + ffi_schema_provider.library_marker_id = crate::mock_foreign_marker_id; - let foreign_schema_provider: ForeignSchemaProvider = + let foreign_schema_provider: Arc = (&ffi_schema_provider).into(); let prior_table_names = foreign_schema_provider.table_names(); @@ -382,4 +411,27 @@ mod tests { assert!(returned_schema.is_some()); assert!(foreign_schema_provider.table_exist("second_table")); } + + #[test] + fn test_ffi_schema_provider_local_bypass() { + let schema_provider = Arc::new(MemorySchemaProvider::new()); + let ctx = Arc::new(SessionContext::new()) as Arc; + + let mut ffi_schema = 
FFI_SchemaProvider::new(schema_provider, None, ctx); + + // Verify local libraries can be downcast to their original + let foreign_schema: Arc = (&ffi_schema).into(); + assert!(foreign_schema + .as_any() + .downcast_ref::() + .is_some()); + + // Verify different library markers generate foreign providers + ffi_schema.library_marker_id = crate::mock_foreign_marker_id; + let foreign_schema: Arc = (&ffi_schema).into(); + assert!(foreign_schema + .as_any() + .downcast_ref::() + .is_some()); + } } diff --git a/datafusion/ffi/src/session_config.rs b/datafusion/ffi/src/session_config.rs index a07b66c60196..de83d19ca44f 100644 --- a/datafusion/ffi/src/session_config.rs +++ b/datafusion/ffi/src/session_config.rs @@ -19,13 +19,9 @@ use abi_stable::{ std_types::{RHashMap, RString}, StableAbi, }; -use datafusion::{config::ConfigOptions, error::Result}; -use datafusion::{error::DataFusionError, prelude::SessionConfig}; -use std::sync::Arc; -use std::{ - collections::HashMap, - ffi::{c_char, c_void, CString}, -}; +use datafusion_common::error::{DataFusionError, Result}; +use datafusion_execution::config::SessionConfig; +use std::{collections::HashMap, ffi::c_void}; /// A stable struct for sharing [`SessionConfig`] across FFI boundaries. /// Instead of attempting to expose the entire SessionConfig interface, we @@ -54,18 +50,27 @@ pub struct FFI_SessionConfig { pub release: unsafe extern "C" fn(arg: &mut Self), /// Internal data. This is only to be accessed by the provider of the plan. - /// A [`ForeignSessionConfig`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. 
+ pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_SessionConfig {} unsafe impl Sync for FFI_SessionConfig {} +impl FFI_SessionConfig { + fn inner(&self) -> &SessionConfig { + let private_data = self.private_data as *mut SessionConfigPrivateData; + unsafe { &(*private_data).config } + } +} + unsafe extern "C" fn config_options_fn_wrapper( config: &FFI_SessionConfig, ) -> RHashMap { - let private_data = config.private_data as *mut SessionConfigPrivateData; - let config_options = &(*private_data).config; + let config_options = config.inner().options(); let mut options = RHashMap::default(); for config_entry in config_options.entries() { @@ -85,7 +90,7 @@ unsafe extern "C" fn release_fn_wrapper(config: &mut FFI_SessionConfig) { unsafe extern "C" fn clone_fn_wrapper(config: &FFI_SessionConfig) -> FFI_SessionConfig { let old_private_data = config.private_data as *mut SessionConfigPrivateData; - let old_config = Arc::clone(&(*old_private_data).config); + let old_config = (*old_private_data).config.clone(); let private_data = Box::new(SessionConfigPrivateData { config: old_config }); @@ -94,31 +99,18 @@ unsafe extern "C" fn clone_fn_wrapper(config: &FFI_SessionConfig) -> FFI_Session private_data: Box::into_raw(private_data) as *mut c_void, clone: clone_fn_wrapper, release: release_fn_wrapper, + library_marker_id: crate::get_library_marker_id, } } struct SessionConfigPrivateData { - pub config: Arc, + pub config: SessionConfig, } impl From<&SessionConfig> for FFI_SessionConfig { fn from(session: &SessionConfig) -> Self { - let mut config_keys = Vec::new(); - let mut config_values = Vec::new(); - for config_entry in session.options().entries() { - if let Some(value) = config_entry.value { - let key_cstr = CString::new(config_entry.key).unwrap_or_default(); - let key_ptr = key_cstr.into_raw() as *const c_char; - config_keys.push(key_ptr); - - config_values - .push(CString::new(value).unwrap_or_default().into_raw() - as *const c_char); - } - } - let 
private_data = Box::new(SessionConfigPrivateData { - config: Arc::clone(session.options()), + config: session.clone(), }); Self { @@ -126,6 +118,7 @@ impl From<&SessionConfig> for FFI_SessionConfig { private_data: Box::into_raw(private_data) as *mut c_void, clone: clone_fn_wrapper, release: release_fn_wrapper, + library_marker_id: crate::get_library_marker_id, } } } @@ -142,16 +135,14 @@ impl Drop for FFI_SessionConfig { } } -/// A wrapper struct for accessing [`SessionConfig`] across a FFI boundary. -/// The [`SessionConfig`] will be generated from a hash map of the config -/// options in the provider and will be reconstructed on this side of the -/// interface.s -pub struct ForeignSessionConfig(pub SessionConfig); - -impl TryFrom<&FFI_SessionConfig> for ForeignSessionConfig { +impl TryFrom<&FFI_SessionConfig> for SessionConfig { type Error = DataFusionError; fn try_from(config: &FFI_SessionConfig) -> Result { + if (config.library_marker_id)() == crate::get_library_marker_id() { + return Ok(config.inner().clone()); + } + let config_options = unsafe { (config.config_options)(config) }; let mut options_map = HashMap::new(); @@ -159,7 +150,7 @@ impl TryFrom<&FFI_SessionConfig> for ForeignSessionConfig { options_map.insert(kv_pair.0.to_string(), kv_pair.1.to_string()); }); - Ok(Self(SessionConfig::from_string_hash_map(&options_map)?)) + SessionConfig::from_string_hash_map(&options_map) } } @@ -172,13 +163,15 @@ mod tests { let session_config = SessionConfig::new(); let original_options = session_config.options().entries(); - let ffi_config: FFI_SessionConfig = (&session_config).into(); + let mut ffi_config: FFI_SessionConfig = (&session_config).into(); + let _ = ffi_config.clone(); + ffi_config.library_marker_id = crate::mock_foreign_marker_id; - let foreign_config: ForeignSessionConfig = (&ffi_config).try_into()?; + let foreign_config: SessionConfig = (&ffi_config).try_into()?; - let returned_options = foreign_config.0.options().entries(); + let returned_options = 
foreign_config.options().entries(); - assert!(original_options.len() == returned_options.len()); + assert_eq!(original_options.len(), returned_options.len()); Ok(()) } diff --git a/datafusion/ffi/src/table_provider.rs b/datafusion/ffi/src/table_provider.rs index 890511997a70..2e7cec2c1ec5 100644 --- a/datafusion/ffi/src/table_provider.rs +++ b/datafusion/ffi/src/table_provider.rs @@ -45,16 +45,16 @@ use tokio::runtime::Handle; use crate::{ arrow_wrappers::WrappedSchema, df_result, rresult_return, - session_config::ForeignSessionConfig, table_source::{FFI_TableProviderFilterPushDown, FFI_TableType}, }; use super::{ - execution_plan::{FFI_ExecutionPlan, ForeignExecutionPlan}, - insert_op::FFI_InsertOp, + execution_plan::FFI_ExecutionPlan, insert_op::FFI_InsertOp, session_config::FFI_SessionConfig, }; +use crate::execution::FFI_TaskContextProvider; use datafusion::error::Result; +use datafusion_execution::config::SessionConfig; /// A stable struct for sharing [`TableProvider`] across FFI boundaries. /// @@ -143,6 +143,10 @@ pub struct FFI_TableProvider { insert_op: FFI_InsertOp, ) -> FfiFuture>, + /// Provider for TaskContext to be used during protobuf serialization + /// and deserialization. + pub task_ctx_provider: FFI_TaskContextProvider, + /// Used to create a clone on the provider of the execution plan. This should /// only need to be called by the receiver of the plan. pub clone: unsafe extern "C" fn(plan: &Self) -> Self, @@ -154,8 +158,12 @@ pub struct FFI_TableProvider { pub version: unsafe extern "C" fn() -> u64, /// Internal data. This is only to be accessed by the provider of the plan. - /// A [`ForeignExecutionPlan`] should never attempt to access this data. + /// A [`ForeignTableProvider`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. 
+ pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_TableProvider {} @@ -166,27 +174,33 @@ struct ProviderPrivateData { runtime: Option, } -unsafe extern "C" fn schema_fn_wrapper(provider: &FFI_TableProvider) -> WrappedSchema { - let private_data = provider.private_data as *const ProviderPrivateData; - let provider = &(*private_data).provider; +impl FFI_TableProvider { + fn inner(&self) -> &Arc { + let private_data = self.private_data as *const ProviderPrivateData; + unsafe { &(*private_data).provider } + } + + fn runtime(&self) -> &Option { + let private_data = self.private_data as *const ProviderPrivateData; + unsafe { &(*private_data).runtime } + } +} - provider.schema().into() +unsafe extern "C" fn schema_fn_wrapper(provider: &FFI_TableProvider) -> WrappedSchema { + provider.inner().schema().into() } unsafe extern "C" fn table_type_fn_wrapper( provider: &FFI_TableProvider, ) -> FFI_TableType { - let private_data = provider.private_data as *const ProviderPrivateData; - let provider = &(*private_data).provider; - - provider.table_type().into() + provider.inner().table_type().into() } fn supports_filters_pushdown_internal( provider: &Arc, filters_serialized: &[u8], + task_ctx: &Arc, ) -> Result> { - let default_ctx = SessionContext::new(); let codec = DefaultLogicalExtensionCodec {}; let filters = match filters_serialized.is_empty() { @@ -195,7 +209,7 @@ fn supports_filters_pushdown_internal( let proto_filters = LogicalExprList::decode(filters_serialized) .map_err(|e| DataFusionError::Plan(e.to_string()))?; - parse_exprs(proto_filters.expr.iter(), &default_ctx, &codec)? + parse_exprs(proto_filters.expr.iter(), task_ctx.as_ref(), &codec)? 
} }; let filters_borrowed: Vec<&Expr> = filters.iter().collect(); @@ -213,10 +227,9 @@ unsafe extern "C" fn supports_filters_pushdown_fn_wrapper( provider: &FFI_TableProvider, filters_serialized: RVec, ) -> RResult, RString> { - let private_data = provider.private_data as *const ProviderPrivateData; - let provider = &(*private_data).provider; - - supports_filters_pushdown_internal(provider, &filters_serialized) + let task_ctx = + rresult_return!(>::try_from(&provider.task_ctx_provider)); + supports_filters_pushdown_internal(provider.inner(), &filters_serialized, &task_ctx) .map_err(|e| e.to_string().into()) .into() } @@ -228,23 +241,25 @@ unsafe extern "C" fn scan_fn_wrapper( filters_serialized: RVec, limit: ROption, ) -> FfiFuture> { - let private_data = provider.private_data as *mut ProviderPrivateData; - let internal_provider = &(*private_data).provider; + let task_ctx_provider = provider.task_ctx_provider.clone(); + let task_ctx: Result, DataFusionError> = + (&provider.task_ctx_provider).try_into(); + let runtime = provider.runtime().clone(); + let internal_provider = Arc::clone(provider.inner()); let session_config = session_config.clone(); - let runtime = &(*private_data).runtime; async move { - let config = rresult_return!(ForeignSessionConfig::try_from(&session_config)); + let task_ctx = rresult_return!(task_ctx); + let config = rresult_return!(SessionConfig::try_from(&session_config)); let session = SessionStateBuilder::new() .with_default_features() - .with_config(config.0) + .with_config(config) .build(); let ctx = SessionContext::new_with_state(session); let filters = match filters_serialized.is_empty() { true => vec![], false => { - let default_ctx = SessionContext::new(); let codec = DefaultLogicalExtensionCodec {}; let proto_filters = @@ -252,7 +267,7 @@ unsafe extern "C" fn scan_fn_wrapper( rresult_return!(parse_exprs( proto_filters.expr.iter(), - &default_ctx, + task_ctx.as_ref(), &codec )) } @@ -268,7 +283,7 @@ unsafe extern "C" fn 
scan_fn_wrapper( RResult::ROk(FFI_ExecutionPlan::new( plan, - ctx.task_ctx(), + task_ctx_provider, runtime.clone(), )) } @@ -281,33 +296,32 @@ unsafe extern "C" fn insert_into_fn_wrapper( input: &FFI_ExecutionPlan, insert_op: FFI_InsertOp, ) -> FfiFuture> { - let private_data = provider.private_data as *mut ProviderPrivateData; - let internal_provider = &(*private_data).provider; + let task_ctx_provider = provider.task_ctx_provider.clone(); + let runtime = provider.runtime().clone(); + let internal_provider = Arc::clone(provider.inner()); let session_config = session_config.clone(); let input = input.clone(); - let runtime = &(*private_data).runtime; async move { - let config = rresult_return!(ForeignSessionConfig::try_from(&session_config)); + let config = rresult_return!(SessionConfig::try_from(&session_config)); let session = SessionStateBuilder::new() .with_default_features() - .with_config(config.0) + .with_config(config) .build(); - let ctx = SessionContext::new_with_state(session); - let input = rresult_return!(ForeignExecutionPlan::try_from(&input).map(Arc::new)); + let input = rresult_return!(>::try_from(&input)); let insert_op = InsertOp::from(insert_op); let plan = rresult_return!( internal_provider - .insert_into(&ctx.state(), input, insert_op) + .insert_into(&session, input, insert_op) .await ); RResult::ROk(FFI_ExecutionPlan::new( plan, - ctx.task_ctx(), + task_ctx_provider, runtime.clone(), )) } @@ -320,11 +334,11 @@ unsafe extern "C" fn release_fn_wrapper(provider: &mut FFI_TableProvider) { } unsafe extern "C" fn clone_fn_wrapper(provider: &FFI_TableProvider) -> FFI_TableProvider { - let old_private_data = provider.private_data as *const ProviderPrivateData; - let runtime = (*old_private_data).runtime.clone(); + let runtime = provider.runtime().clone(); + let old_provider = Arc::clone(provider.inner()); let private_data = Box::into_raw(Box::new(ProviderPrivateData { - provider: Arc::clone(&(*old_private_data).provider), + provider: old_provider, 
runtime, })) as *mut c_void; @@ -334,10 +348,12 @@ unsafe extern "C" fn clone_fn_wrapper(provider: &FFI_TableProvider) -> FFI_Table table_type: table_type_fn_wrapper, supports_filters_pushdown: provider.supports_filters_pushdown, insert_into: provider.insert_into, + task_ctx_provider: provider.task_ctx_provider.clone(), clone: clone_fn_wrapper, release: release_fn_wrapper, version: super::version, private_data, + library_marker_id: crate::get_library_marker_id, } } @@ -353,7 +369,9 @@ impl FFI_TableProvider { provider: Arc, can_support_pushdown_filters: bool, runtime: Option, + task_ctx_provider: impl Into, ) -> Self { + let task_ctx_provider = task_ctx_provider.into(); let private_data = Box::new(ProviderPrivateData { provider, runtime }); Self { @@ -365,10 +383,12 @@ impl FFI_TableProvider { false => None, }, insert_into: insert_into_fn_wrapper, + task_ctx_provider, clone: clone_fn_wrapper, release: release_fn_wrapper, version: super::version, private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -383,9 +403,13 @@ pub struct ForeignTableProvider(pub FFI_TableProvider); unsafe impl Send for ForeignTableProvider {} unsafe impl Sync for ForeignTableProvider {} -impl From<&FFI_TableProvider> for ForeignTableProvider { +impl From<&FFI_TableProvider> for Arc { fn from(provider: &FFI_TableProvider) -> Self { - Self(provider.clone()) + if (provider.library_marker_id)() == crate::get_library_marker_id() { + Arc::clone(provider.inner()) as Arc + } else { + Arc::new(ForeignTableProvider(provider.clone())) + } } } @@ -438,10 +462,10 @@ impl TableProvider for ForeignTableProvider { ) .await; - ForeignExecutionPlan::try_from(&df_result!(maybe_plan)?)? + >::try_from(&df_result!(maybe_plan)?)? 
}; - Ok(Arc::new(plan)) + Ok(plan) } /// Tests whether the table provider can make use of a filter expression @@ -483,30 +507,28 @@ impl TableProvider for ForeignTableProvider { let session_config: FFI_SessionConfig = session.config().into(); let rc = Handle::try_current().ok(); - let input = - FFI_ExecutionPlan::new(input, Arc::new(TaskContext::from(session)), rc); + let input = FFI_ExecutionPlan::new(input, self.0.task_ctx_provider.clone(), rc); let insert_op: FFI_InsertOp = insert_op.into(); let plan = unsafe { let maybe_plan = (self.0.insert_into)(&self.0, &session_config, &input, insert_op).await; - ForeignExecutionPlan::try_from(&df_result!(maybe_plan)?)? + >::try_from(&df_result!(maybe_plan)?)? }; - Ok(Arc::new(plan)) + Ok(plan) } } #[cfg(test)] mod tests { + use super::*; use arrow::datatypes::Schema; use datafusion::prelude::{col, lit}; + use datafusion_execution::TaskContextProvider; - use super::*; - - #[tokio::test] - async fn test_round_trip_ffi_table_provider_scan() -> Result<()> { + fn create_test_table_provider() -> Result> { use arrow::datatypes::Field; use datafusion::arrow::{ array::Float32Array, datatypes::DataType, record_batch::RecordBatch, @@ -526,16 +548,25 @@ mod tests { vec![Arc::new(Float32Array::from(vec![64.0]))], )?; - let ctx = SessionContext::new(); + Ok(Arc::new(MemTable::try_new( + schema, + vec![vec![batch1], vec![batch2]], + )?)) + } - let provider = - Arc::new(MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?); + #[tokio::test] + async fn test_round_trip_ffi_table_provider_scan() -> Result<()> { + let provider = create_test_table_provider()?; + let ctx = Arc::new(SessionContext::new()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; - let ffi_provider = FFI_TableProvider::new(provider, true, None); + let mut ffi_provider = + FFI_TableProvider::new(provider, true, None, task_ctx_provider); + ffi_provider.library_marker_id = crate::mock_foreign_marker_id; - let foreign_table_provider: ForeignTableProvider = 
(&ffi_provider).into(); + let foreign_table_provider: Arc = (&ffi_provider).into(); - ctx.register_table("t", Arc::new(foreign_table_provider))?; + ctx.register_table("t", foreign_table_provider)?; let df = ctx.table("t").await?; @@ -549,35 +580,17 @@ mod tests { #[tokio::test] async fn test_round_trip_ffi_table_provider_insert_into() -> Result<()> { - use arrow::datatypes::Field; - use datafusion::arrow::{ - array::Float32Array, datatypes::DataType, record_batch::RecordBatch, - }; - use datafusion::datasource::MemTable; - - let schema = - Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); - - // define data in two partitions - let batch1 = RecordBatch::try_new( - Arc::clone(&schema), - vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))], - )?; - let batch2 = RecordBatch::try_new( - Arc::clone(&schema), - vec![Arc::new(Float32Array::from(vec![64.0]))], - )?; + let provider = create_test_table_provider()?; + let ctx = Arc::new(SessionContext::new()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; - let ctx = SessionContext::new(); + let mut ffi_provider = + FFI_TableProvider::new(provider, true, None, task_ctx_provider); + ffi_provider.library_marker_id = crate::mock_foreign_marker_id; - let provider = - Arc::new(MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?); + let foreign_table_provider: Arc = (&ffi_provider).into(); - let ffi_provider = FFI_TableProvider::new(provider, true, None); - - let foreign_table_provider: ForeignTableProvider = (&ffi_provider).into(); - - ctx.register_table("t", Arc::new(foreign_table_provider))?; + ctx.register_table("t", foreign_table_provider)?; let result = ctx .sql("INSERT INTO t VALUES (128.0);") @@ -615,15 +628,17 @@ mod tests { vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))], )?; - let ctx = SessionContext::new(); + let ctx = Arc::new(SessionContext::new()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; let provider = Arc::new(MemTable::try_new(schema, vec![vec![batch1]])?); 
- let ffi_provider = FFI_TableProvider::new(provider, true, None); + let ffi_provider = + FFI_TableProvider::new(provider, true, None, task_ctx_provider); - let foreign_table_provider: ForeignTableProvider = (&ffi_provider).into(); + let foreign_table_provider: Arc = (&ffi_provider).into(); - ctx.register_table("t", Arc::new(foreign_table_provider))?; + ctx.register_table("t", foreign_table_provider)?; let result = ctx .sql("SELECT COUNT(*) as cnt FROM t") @@ -641,4 +656,29 @@ mod tests { assert_batches_eq!(expected, &result); Ok(()) } + + #[test] + fn test_ffi_table_provider_local_bypass() -> Result<()> { + let table_provider = create_test_table_provider()?; + + let ctx = Arc::new(SessionContext::new()) as Arc; + let mut ffi_table = FFI_TableProvider::new(table_provider, false, None, ctx); + + // Verify local libraries can be downcast to their original + let foreign_table: Arc = (&ffi_table).into(); + assert!(foreign_table + .as_any() + .downcast_ref::() + .is_some()); + + // Verify different library markers generate foreign providers + ffi_table.library_marker_id = crate::mock_foreign_marker_id; + let foreign_table: Arc = (&ffi_table).into(); + assert!(foreign_table + .as_any() + .downcast_ref::() + .is_some()); + + Ok(()) + } } diff --git a/datafusion/ffi/src/tests/async_provider.rs b/datafusion/ffi/src/tests/async_provider.rs index 67421f58805a..9a808cf2bab9 100644 --- a/datafusion/ffi/src/tests/async_provider.rs +++ b/datafusion/ffi/src/tests/async_provider.rs @@ -27,6 +27,8 @@ use std::{any::Any, fmt::Debug, sync::Arc}; +use super::create_record_batch; +use crate::execution::FFI_TaskContextProvider; use crate::table_provider::FFI_TableProvider; use arrow::array::RecordBatch; use arrow::datatypes::Schema; @@ -46,8 +48,6 @@ use tokio::{ sync::{broadcast, mpsc}, }; -use super::create_record_batch; - #[derive(Debug)] pub struct AsyncTableProvider { batch_request: mpsc::Sender, @@ -277,7 +277,14 @@ impl Stream for AsyncTestRecordBatchStream { } } -pub(crate) fn 
create_async_table_provider() -> FFI_TableProvider { +pub(crate) fn create_async_table_provider( + task_ctx_provider: FFI_TaskContextProvider, +) -> FFI_TableProvider { let (table_provider, tokio_rt) = start_async_provider(); - FFI_TableProvider::new(Arc::new(table_provider), true, Some(tokio_rt)) + FFI_TableProvider::new( + Arc::new(table_provider), + true, + Some(tokio_rt), + task_ctx_provider, + ) } diff --git a/datafusion/ffi/src/tests/catalog.rs b/datafusion/ffi/src/tests/catalog.rs index b6efbdf726e0..86802ac0dde9 100644 --- a/datafusion/ffi/src/tests/catalog.rs +++ b/datafusion/ffi/src/tests/catalog.rs @@ -29,6 +29,7 @@ use std::{any::Any, fmt::Debug, sync::Arc}; use crate::catalog_provider::FFI_CatalogProvider; use crate::catalog_provider_list::FFI_CatalogProviderList; +use crate::execution::FFI_TaskContextProvider; use arrow::datatypes::Schema; use async_trait::async_trait; use datafusion::{ @@ -178,9 +179,11 @@ impl CatalogProvider for FixedCatalogProvider { } } -pub(crate) extern "C" fn create_catalog_provider() -> FFI_CatalogProvider { +pub(crate) extern "C" fn create_catalog_provider( + task_ctx_provider: FFI_TaskContextProvider, +) -> FFI_CatalogProvider { let catalog_provider = Arc::new(FixedCatalogProvider::default()); - FFI_CatalogProvider::new(catalog_provider, None) + FFI_CatalogProvider::new(catalog_provider, None, task_ctx_provider) } /// This catalog provider list is intended only for unit tests. 
It prepopulates with one @@ -230,7 +233,9 @@ impl CatalogProviderList for FixedCatalogProviderList { } } -pub(crate) extern "C" fn create_catalog_provider_list() -> FFI_CatalogProviderList { +pub(crate) extern "C" fn create_catalog_provider_list( + task_ctx_provider: FFI_TaskContextProvider, +) -> FFI_CatalogProviderList { let catalog_provider_list = Arc::new(FixedCatalogProviderList::default()); - FFI_CatalogProviderList::new(catalog_provider_list, None) + FFI_CatalogProviderList::new(catalog_provider_list, None, task_ctx_provider) } diff --git a/datafusion/ffi/src/tests/mod.rs b/datafusion/ffi/src/tests/mod.rs index d9b4a61579e9..c2e969452f89 100644 --- a/datafusion/ffi/src/tests/mod.rs +++ b/datafusion/ffi/src/tests/mod.rs @@ -35,6 +35,7 @@ use crate::udwf::FFI_WindowUDF; use super::{table_provider::FFI_TableProvider, udf::FFI_ScalarUDF}; use crate::catalog_provider_list::FFI_CatalogProviderList; +use crate::execution::FFI_TaskContextProvider; use crate::tests::catalog::create_catalog_provider_list; use arrow::array::RecordBatch; use async_provider::create_async_table_provider; @@ -62,28 +63,33 @@ pub mod utils; /// module. 
pub struct ForeignLibraryModule { /// Construct an opinionated catalog provider - pub create_catalog: extern "C" fn() -> FFI_CatalogProvider, + pub create_catalog: extern "C" fn(FFI_TaskContextProvider) -> FFI_CatalogProvider, /// Construct an opinionated catalog provider list - pub create_catalog_list: extern "C" fn() -> FFI_CatalogProviderList, + pub create_catalog_list: + extern "C" fn(FFI_TaskContextProvider) -> FFI_CatalogProviderList, /// Constructs the table provider - pub create_table: extern "C" fn(synchronous: bool) -> FFI_TableProvider, + pub create_table: extern "C" fn( + synchronous: bool, + task_ctx_provider: FFI_TaskContextProvider, + ) -> FFI_TableProvider, /// Create a scalar UDF pub create_scalar_udf: extern "C" fn() -> FFI_ScalarUDF, pub create_nullary_udf: extern "C" fn() -> FFI_ScalarUDF, - pub create_table_function: extern "C" fn() -> FFI_TableFunction, + pub create_table_function: + extern "C" fn(FFI_TaskContextProvider) -> FFI_TableFunction, /// Create an aggregate UDAF using sum - pub create_sum_udaf: extern "C" fn() -> FFI_AggregateUDF, + pub create_sum_udaf: extern "C" fn(FFI_TaskContextProvider) -> FFI_AggregateUDF, /// Create grouping UDAF using stddev - pub create_stddev_udaf: extern "C" fn() -> FFI_AggregateUDF, + pub create_stddev_udaf: extern "C" fn(FFI_TaskContextProvider) -> FFI_AggregateUDF, - pub create_rank_udwf: extern "C" fn() -> FFI_WindowUDF, + pub create_rank_udwf: extern "C" fn(FFI_TaskContextProvider) -> FFI_WindowUDF, pub version: extern "C" fn() -> u64, } @@ -116,10 +122,13 @@ pub fn create_record_batch(start_value: i32, num_values: usize) -> RecordBatch { /// Here we only wish to create a simple table provider as an example. /// We create an in-memory table and convert it to it's FFI counterpart. 
-extern "C" fn construct_table_provider(synchronous: bool) -> FFI_TableProvider { +extern "C" fn construct_table_provider( + synchronous: bool, + task_ctx_provider: FFI_TaskContextProvider, +) -> FFI_TableProvider { match synchronous { - true => create_sync_table_provider(), - false => create_async_table_provider(), + true => create_sync_table_provider(task_ctx_provider), + false => create_async_table_provider(task_ctx_provider), } } diff --git a/datafusion/ffi/src/tests/sync_provider.rs b/datafusion/ffi/src/tests/sync_provider.rs index ff85e0b15b39..9be3c869dd1b 100644 --- a/datafusion/ffi/src/tests/sync_provider.rs +++ b/datafusion/ffi/src/tests/sync_provider.rs @@ -17,12 +17,14 @@ use std::sync::Arc; +use super::{create_record_batch, create_test_schema}; +use crate::execution::FFI_TaskContextProvider; use crate::table_provider::FFI_TableProvider; use datafusion::datasource::MemTable; -use super::{create_record_batch, create_test_schema}; - -pub(crate) fn create_sync_table_provider() -> FFI_TableProvider { +pub(crate) fn create_sync_table_provider( + task_ctx_provider: FFI_TaskContextProvider, +) -> FFI_TableProvider { let schema = create_test_schema(); // It is useful to create these as multiple record batches @@ -35,5 +37,5 @@ pub(crate) fn create_sync_table_provider() -> FFI_TableProvider { let table_provider = MemTable::try_new(schema, vec![batches]).unwrap(); - FFI_TableProvider::new(Arc::new(table_provider), true, None) + FFI_TableProvider::new(Arc::new(table_provider), true, None, task_ctx_provider) } diff --git a/datafusion/ffi/src/tests/udf_udaf_udwf.rs b/datafusion/ffi/src/tests/udf_udaf_udwf.rs index 55e31ef3ab77..87a6d9603b80 100644 --- a/datafusion/ffi/src/tests/udf_udaf_udwf.rs +++ b/datafusion/ffi/src/tests/udf_udaf_udwf.rs @@ -28,6 +28,7 @@ use datafusion::{ logical_expr::{AggregateUDF, ScalarUDF, WindowUDF}, }; +use crate::execution::FFI_TaskContextProvider; use std::sync::Arc; pub(crate) extern "C" fn create_ffi_abs_func() -> FFI_ScalarUDF { @@ 
-42,25 +43,33 @@ pub(crate) extern "C" fn create_ffi_random_func() -> FFI_ScalarUDF { udf.into() } -pub(crate) extern "C" fn create_ffi_table_func() -> FFI_TableFunction { +pub(crate) extern "C" fn create_ffi_table_func( + task_ctx_provider: FFI_TaskContextProvider, +) -> FFI_TableFunction { let udtf: Arc = Arc::new(RangeFunc {}); - FFI_TableFunction::new(udtf, None) + FFI_TableFunction::new(udtf, None, task_ctx_provider) } -pub(crate) extern "C" fn create_ffi_sum_func() -> FFI_AggregateUDF { +pub(crate) extern "C" fn create_ffi_sum_func( + task_ctx_provider: FFI_TaskContextProvider, +) -> FFI_AggregateUDF { let udaf: Arc = Arc::new(Sum::new().into()); - udaf.into() + FFI_AggregateUDF::new(udaf, task_ctx_provider) } -pub(crate) extern "C" fn create_ffi_stddev_func() -> FFI_AggregateUDF { +pub(crate) extern "C" fn create_ffi_stddev_func( + task_ctx_provider: FFI_TaskContextProvider, +) -> FFI_AggregateUDF { let udaf: Arc = Arc::new(Stddev::new().into()); - udaf.into() + FFI_AggregateUDF::new(udaf, task_ctx_provider) } -pub(crate) extern "C" fn create_ffi_rank_func() -> FFI_WindowUDF { +pub(crate) extern "C" fn create_ffi_rank_func( + task_ctx_provider: FFI_TaskContextProvider, +) -> FFI_WindowUDF { let udwf: Arc = Arc::new( Rank::new( "rank_demo".to_string(), @@ -69,5 +78,5 @@ pub(crate) extern "C" fn create_ffi_rank_func() -> FFI_WindowUDF { .into(), ); - udwf.into() + FFI_WindowUDF::new(udwf, task_ctx_provider) } diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs index 80b872159f48..227b04902aa2 100644 --- a/datafusion/ffi/src/udaf/accumulator.rs +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-use std::{ffi::c_void, ops::Deref}; - use abi_stable::{ std_types::{RResult, RString, RVec}, StableAbi, @@ -28,6 +26,8 @@ use datafusion::{ scalar::ScalarValue, }; use prost::Message; +use std::ptr::null_mut; +use std::{ffi::c_void, ops::Deref}; use crate::{arrow_wrappers::WrappedArray, df_result, rresult, rresult_return}; @@ -70,6 +70,10 @@ pub struct FFI_Accumulator { /// Internal data. This is only to be accessed by the provider of the accumulator. /// A [`ForeignAccumulator`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. + pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_Accumulator {} @@ -173,9 +177,11 @@ unsafe extern "C" fn retract_batch_fn_wrapper( } unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_Accumulator) { - let private_data = - Box::from_raw(accumulator.private_data as *mut AccumulatorPrivateData); - drop(private_data); + if !accumulator.private_data.is_null() { + let private_data = + Box::from_raw(accumulator.private_data as *mut AccumulatorPrivateData); + drop(private_data); + } } impl From> for FFI_Accumulator { @@ -193,6 +199,7 @@ impl From> for FFI_Accumulator { supports_retract_batch, release: release_fn_wrapper, private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -217,9 +224,20 @@ pub struct ForeignAccumulator { unsafe impl Send for ForeignAccumulator {} unsafe impl Sync for ForeignAccumulator {} -impl From for ForeignAccumulator { - fn from(accumulator: FFI_Accumulator) -> Self { - Self { accumulator } +impl From for Box { + fn from(mut accumulator: FFI_Accumulator) -> Self { + if (accumulator.library_marker_id)() == crate::get_library_marker_id() { + unsafe { + let private_data = Box::from_raw( + accumulator.private_data as *mut AccumulatorPrivateData, + ); + // We must set this to null to avoid a 
double free + accumulator.private_data = null_mut(); + private_data.accumulator + } + } else { + Box::new(ForeignAccumulator { accumulator }) + } } } @@ -306,6 +324,7 @@ impl Accumulator for ForeignAccumulator { #[cfg(test)] mod tests { + use super::{FFI_Accumulator, ForeignAccumulator}; use arrow::array::{make_array, Array}; use datafusion::{ common::create_array, error::Result, @@ -313,8 +332,6 @@ mod tests { scalar::ScalarValue, }; - use super::{FFI_Accumulator, ForeignAccumulator}; - #[test] fn test_foreign_avg_accumulator() -> Result<()> { let original_accum = AvgAccumulator::default(); @@ -322,8 +339,9 @@ mod tests { let original_supports_retract = original_accum.supports_retract_batch(); let boxed_accum: Box = Box::new(original_accum); - let ffi_accum: FFI_Accumulator = boxed_accum.into(); - let mut foreign_accum: ForeignAccumulator = ffi_accum.into(); + let mut ffi_accum: FFI_Accumulator = boxed_accum.into(); + ffi_accum.library_marker_id = crate::mock_foreign_marker_id; + let mut foreign_accum: Box = ffi_accum.into(); // Send in an array to average. 
There are 5 values and it should average to 30.0 let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]); @@ -363,4 +381,35 @@ mod tests { Ok(()) } + + #[test] + fn test_ffi_accumulator_local_bypass() -> Result<()> { + let original_accum = AvgAccumulator::default(); + let boxed_accum: Box = Box::new(original_accum); + let original_size = boxed_accum.size(); + + let ffi_accum: FFI_Accumulator = boxed_accum.into(); + + // Verify local libraries can be downcast to their original + let foreign_accum: Box = ffi_accum.into(); + unsafe { + let concrete = &*(foreign_accum.as_ref() as *const dyn Accumulator + as *const AvgAccumulator); + assert_eq!(original_size, concrete.size()); + } + + // Verify different library markers generate foreign accumulator + let original_accum = AvgAccumulator::default(); + let boxed_accum: Box = Box::new(original_accum); + let mut ffi_accum: FFI_Accumulator = boxed_accum.into(); + ffi_accum.library_marker_id = crate::mock_foreign_marker_id; + let foreign_accum: Box = ffi_accum.into(); + unsafe { + let concrete = &*(foreign_accum.as_ref() as *const dyn Accumulator + as *const ForeignAccumulator); + assert_eq!(original_size, concrete.size()); + } + + Ok(()) + } } diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs index 6ac0a0b21d2d..c34b5fe19b41 100644 --- a/datafusion/ffi/src/udaf/accumulator_args.rs +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -18,6 +18,7 @@ use std::sync::Arc; use crate::arrow_wrappers::WrappedSchema; +use crate::execution::FFI_TaskContextProvider; use abi_stable::{ std_types::{RString, RVec}, StableAbi, @@ -28,9 +29,9 @@ use datafusion::{ error::DataFusionError, logical_expr::function::AccumulatorArgs, physical_expr::{PhysicalExpr, PhysicalSortExpr}, - prelude::SessionContext, }; use datafusion_common::exec_datafusion_err; +use datafusion_execution::TaskContext; use datafusion_proto::{ physical_plan::{ from_proto::{parse_physical_exprs, 
parse_physical_sort_exprs}, @@ -53,12 +54,18 @@ pub struct FFI_AccumulatorArgs { is_reversed: bool, name: RString, physical_expr_def: RVec, -} -impl TryFrom> for FFI_AccumulatorArgs { - type Error = DataFusionError; + /// Provider for TaskContext to be used during protobuf serialization + /// and deserialization. + pub task_ctx_provider: FFI_TaskContextProvider, +} - fn try_from(args: AccumulatorArgs) -> Result { +impl FFI_AccumulatorArgs { + pub fn try_new( + args: AccumulatorArgs, + task_ctx_provider: impl Into, + ) -> Result { + let task_ctx_provider = task_ctx_provider.into(); let return_field = WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?); let schema = WrappedSchema(FFI_ArrowSchema::try_from(args.schema)?); @@ -86,6 +93,7 @@ impl TryFrom> for FFI_AccumulatorArgs { is_reversed: args.is_reversed, name: args.name.into(), physical_expr_def, + task_ctx_provider, }) } } @@ -120,8 +128,7 @@ impl TryFrom for ForeignAccumulatorArgs { let return_field = Arc::new((&value.return_field.0).try_into()?); let schema = Schema::try_from(&value.schema.0)?; - let default_ctx = SessionContext::new(); - let task_ctx = default_ctx.task_ctx(); + let task_ctx: Arc = (&value.task_ctx_provider).try_into()?; let codex = DefaultPhysicalExtensionCodec {}; let order_bys = parse_physical_sort_exprs( @@ -172,10 +179,13 @@ impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> { mod tests { use super::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; use arrow::datatypes::{DataType, Field, Schema}; + use datafusion::prelude::SessionContext; use datafusion::{ error::Result, logical_expr::function::AccumulatorArgs, physical_expr::PhysicalSortExpr, physical_plan::expressions::col, }; + use datafusion_execution::TaskContextProvider; + use std::sync::Arc; #[test] fn test_round_trip_accumulator_args() -> Result<()> { @@ -192,8 +202,10 @@ mod tests { exprs: &[col("a", &schema)?], }; let orig_str = format!("{orig_args:?}"); + let ctx = Arc::new(SessionContext::new()); 
+ let task_ctx_accessor = Arc::clone(&ctx) as Arc; - let ffi_args: FFI_AccumulatorArgs = orig_args.try_into()?; + let ffi_args = FFI_AccumulatorArgs::try_new(orig_args, task_ctx_accessor)?; let foreign_args: ForeignAccumulatorArgs = ffi_args.try_into()?; let round_trip_args: AccumulatorArgs = (&foreign_args).into(); diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs index 58a18c69db7c..82095383efb2 100644 --- a/datafusion/ffi/src/udaf/groups_accumulator.rs +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::{ffi::c_void, ops::Deref, sync::Arc}; - use crate::{ arrow_wrappers::{WrappedArray, WrappedSchema}, df_result, rresult, rresult_return, @@ -34,6 +32,8 @@ use datafusion::{ error::{DataFusionError, Result}, logical_expr::{EmitTo, GroupsAccumulator}, }; +use std::ptr::null_mut; +use std::{ffi::c_void, ops::Deref, sync::Arc}; /// A stable struct for sharing [`GroupsAccumulator`] across FFI boundaries. /// For an explanation of each field, see the corresponding function @@ -86,6 +86,10 @@ pub struct FFI_GroupsAccumulator { /// Internal data. This is only to be accessed by the provider of the accumulator. /// A [`ForeignGroupsAccumulator`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. 
+ pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_GroupsAccumulator {} @@ -215,9 +219,11 @@ unsafe extern "C" fn convert_to_state_fn_wrapper( } unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_GroupsAccumulator) { - let private_data = - Box::from_raw(accumulator.private_data as *mut GroupsAccumulatorPrivateData); - drop(private_data); + if !accumulator.private_data.is_null() { + let private_data = + Box::from_raw(accumulator.private_data as *mut GroupsAccumulatorPrivateData); + drop(private_data); + } } impl From> for FFI_GroupsAccumulator { @@ -236,6 +242,7 @@ impl From> for FFI_GroupsAccumulator { release: release_fn_wrapper, private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -260,9 +267,19 @@ pub struct ForeignGroupsAccumulator { unsafe impl Send for ForeignGroupsAccumulator {} unsafe impl Sync for ForeignGroupsAccumulator {} -impl From for ForeignGroupsAccumulator { - fn from(accumulator: FFI_GroupsAccumulator) -> Self { - Self { accumulator } +impl From for Box { + fn from(mut accumulator: FFI_GroupsAccumulator) -> Self { + if (accumulator.library_marker_id)() == crate::get_library_marker_id() { + unsafe { + let private_data = Box::from_raw( + accumulator.private_data as *mut GroupsAccumulatorPrivateData, + ); + accumulator.private_data = null_mut(); + private_data.accumulator + } + } else { + Box::new(ForeignGroupsAccumulator { accumulator }) + } } } @@ -428,22 +445,24 @@ impl From for EmitTo { #[cfg(test)] mod tests { + use super::{FFI_EmitTo, FFI_GroupsAccumulator, ForeignGroupsAccumulator}; use arrow::array::{make_array, Array, BooleanArray}; + use datafusion::functions_aggregate::stddev::StddevGroupsAccumulator; use datafusion::{ common::create_array, error::Result, logical_expr::{EmitTo, GroupsAccumulator}, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; - - use 
super::{FFI_EmitTo, FFI_GroupsAccumulator, ForeignGroupsAccumulator}; + use datafusion_functions_aggregate_common::stats::StatsType; #[test] fn test_foreign_avg_accumulator() -> Result<()> { let boxed_accum: Box = Box::new(BooleanGroupsAccumulator::new(|a, b| a && b, true)); - let ffi_accum: FFI_GroupsAccumulator = boxed_accum.into(); - let mut foreign_accum: ForeignGroupsAccumulator = ffi_accum.into(); + let mut ffi_accum: FFI_GroupsAccumulator = boxed_accum.into(); + ffi_accum.library_marker_id = crate::mock_foreign_marker_id; + let mut foreign_accum: Box = ffi_accum.into(); // Send in an array to evaluate. We want a mean of 30 and standard deviation of 4. let values = create_array!(Boolean, vec![true, true, true, false, true, true]); @@ -510,4 +529,35 @@ mod tests { Ok(()) } + + #[test] + fn test_ffi_groups_accumulator_local_bypass_inner() -> Result<()> { + let original_accum = StddevGroupsAccumulator::new(StatsType::Population); + let boxed_accum: Box = Box::new(original_accum); + let original_size = boxed_accum.size(); + + let ffi_accum: FFI_GroupsAccumulator = boxed_accum.into(); + + // Verify local libraries can be downcast to their original + let foreign_accum: Box = ffi_accum.into(); + unsafe { + let concrete = &*(foreign_accum.as_ref() as *const dyn GroupsAccumulator + as *const StddevGroupsAccumulator); + assert_eq!(original_size, concrete.size()); + } + + // Verify different library markers generate foreign accumulator + let original_accum = StddevGroupsAccumulator::new(StatsType::Population); + let boxed_accum: Box = Box::new(original_accum); + let mut ffi_accum: FFI_GroupsAccumulator = boxed_accum.into(); + ffi_accum.library_marker_id = crate::mock_foreign_marker_id; + let foreign_accum: Box = ffi_accum.into(); + unsafe { + let concrete = &*(foreign_accum.as_ref() as *const dyn GroupsAccumulator + as *const ForeignGroupsAccumulator); + assert_eq!(original_size, concrete.size()); + } + + Ok(()) + } } diff --git a/datafusion/ffi/src/udaf/mod.rs 
b/datafusion/ffi/src/udaf/mod.rs index ce5611590b67..4d10787ee877 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -19,7 +19,7 @@ use abi_stable::{ std_types::{ROption, RResult, RStr, RString, RVec}, StableAbi, }; -use accumulator::{FFI_Accumulator, ForeignAccumulator}; +use accumulator::FFI_Accumulator; use accumulator_args::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; use arrow::datatypes::{DataType, Field}; use arrow::ffi::FFI_ArrowSchema; @@ -39,10 +39,11 @@ use datafusion::{ }; use datafusion_common::exec_datafusion_err; use datafusion_proto_common::from_proto::parse_proto_fields_to_fields; -use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator}; +use groups_accumulator::FFI_GroupsAccumulator; use std::hash::{Hash, Hasher}; use std::{ffi::c_void, sync::Arc}; +use crate::execution::FFI_TaskContextProvider; use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; use crate::{ arrow_wrappers::WrappedSchema, @@ -135,6 +136,10 @@ pub struct FFI_AggregateUDF { arg_types: RVec, ) -> RResult, RString>, + /// Provider for TaskContext to be used during protobuf serialization + /// and deserialization. + pub task_ctx_provider: FFI_TaskContextProvider, + /// Used to create a clone on the provider of the udaf. This should /// only need to be called by the receiver of the udaf. pub clone: unsafe extern "C" fn(udaf: &Self) -> Self, @@ -145,6 +150,10 @@ pub struct FFI_AggregateUDF { /// Internal data. This is only to be accessed by the provider of the udaf. /// A [`ForeignAggregateUDF`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. 
+ pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_AggregateUDF {} @@ -236,6 +245,7 @@ unsafe extern "C" fn with_beneficial_ordering_fn_wrapper( udaf: &FFI_AggregateUDF, beneficial_ordering: bool, ) -> RResult, RString> { + let task_ctx_provider = udaf.task_ctx_provider.clone(); let udaf = udaf.inner().as_ref().clone(); let result = rresult_return!(udaf.with_beneficial_ordering(beneficial_ordering)); @@ -243,7 +253,7 @@ unsafe extern "C" fn with_beneficial_ordering_fn_wrapper( .map(|func| func.with_beneficial_ordering(beneficial_ordering)) .transpose()) .flatten() - .map(|func| FFI_AggregateUDF::from(Arc::new(func))); + .map(|func| FFI_AggregateUDF::new(Arc::new(func), task_ctx_provider)); RResult::ROk(result.into()) } @@ -326,7 +336,7 @@ unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) { } unsafe extern "C" fn clone_fn_wrapper(udaf: &FFI_AggregateUDF) -> FFI_AggregateUDF { - Arc::clone(udaf.inner()).into() + FFI_AggregateUDF::new(Arc::clone(udaf.inner()), udaf.task_ctx_provider.clone()) } impl Clone for FFI_AggregateUDF { @@ -335,8 +345,12 @@ impl Clone for FFI_AggregateUDF { } } -impl From> for FFI_AggregateUDF { - fn from(udaf: Arc) -> Self { +impl FFI_AggregateUDF { + pub fn new( + udaf: Arc, + task_ctx_provider: impl Into, + ) -> Self { + let task_ctx_provider = task_ctx_provider.into(); let name = udaf.name().into(); let aliases = udaf.aliases().iter().map(|a| a.to_owned().into()).collect(); let is_nullable = udaf.is_nullable(); @@ -358,9 +372,11 @@ impl From> for FFI_AggregateUDF { state_fields: state_fields_fn_wrapper, order_sensitivity: order_sensitivity_fn_wrapper, coerce_types: coerce_types_fn_wrapper, + task_ctx_provider, clone: clone_fn_wrapper, release: release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -400,18 +416,22 @@ impl Hash for ForeignAggregateUDF { } } -impl TryFrom<&FFI_AggregateUDF> for ForeignAggregateUDF { 
+impl TryFrom<&FFI_AggregateUDF> for Arc { type Error = DataFusionError; fn try_from(udaf: &FFI_AggregateUDF) -> Result { + if (udaf.library_marker_id)() == crate::get_library_marker_id() { + return Ok(Arc::clone(unsafe { udaf.inner().inner() })); + } + let signature = Signature::user_defined((&udaf.volatility).into()); let aliases = udaf.aliases.iter().map(|s| s.to_string()).collect(); - Ok(Self { + Ok(Arc::new(ForeignAggregateUDF { udaf: udaf.clone(), signature, aliases, - }) + })) } } @@ -451,11 +471,11 @@ impl AggregateUDFImpl for ForeignAggregateUDF { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let args = acc_args.try_into()?; + let args = + FFI_AccumulatorArgs::try_new(acc_args, self.udaf.task_ctx_provider.clone())?; unsafe { - df_result!((self.udaf.accumulator)(&self.udaf, args)).map(|accum| { - Box::new(ForeignAccumulator::from(accum)) as Box - }) + df_result!((self.udaf.accumulator)(&self.udaf, args)) + .map(>::from) } } @@ -499,13 +519,15 @@ impl AggregateUDFImpl for ForeignAggregateUDF { } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { - let args = match FFI_AccumulatorArgs::try_from(args) { - Ok(v) => v, - Err(e) => { - log::warn!("Attempting to convert accumulator arguments: {e}"); - return false; - } - }; + let args = + match FFI_AccumulatorArgs::try_new(args, self.udaf.task_ctx_provider.clone()) + { + Ok(v) => v, + Err(e) => { + log::warn!("Attempting to convert accumulator arguments: {e}"); + return false; + } + }; unsafe { (self.udaf.groups_accumulator_supported)(&self.udaf, args) } } @@ -514,15 +536,12 @@ impl AggregateUDFImpl for ForeignAggregateUDF { &self, args: AccumulatorArgs, ) -> Result> { - let args = FFI_AccumulatorArgs::try_from(args)?; + let args = + FFI_AccumulatorArgs::try_new(args, self.udaf.task_ctx_provider.clone())?; unsafe { - df_result!((self.udaf.create_groups_accumulator)(&self.udaf, args)).map( - |accum| { - Box::new(ForeignGroupsAccumulator::from(accum)) - as Box - }, - ) + 
df_result!((self.udaf.create_groups_accumulator)(&self.udaf, args)) + .map(>::from) } } @@ -534,11 +553,11 @@ impl AggregateUDFImpl for ForeignAggregateUDF { &self, args: AccumulatorArgs, ) -> Result> { - let args = args.try_into()?; + let args = + FFI_AccumulatorArgs::try_new(args, self.udaf.task_ctx_provider.clone())?; unsafe { - df_result!((self.udaf.create_sliding_accumulator)(&self.udaf, args)).map( - |accum| Box::new(ForeignAccumulator::from(accum)) as Box, - ) + df_result!((self.udaf.create_sliding_accumulator)(&self.udaf, args)) + .map(>::from) } } @@ -554,10 +573,10 @@ impl AggregateUDFImpl for ForeignAggregateUDF { .into_option(); let result = result - .map(|func| ForeignAggregateUDF::try_from(&func)) + .map(|func| >::try_from(&func)) .transpose()?; - Ok(result.map(|func| Arc::new(func) as Arc)) + Ok(result) } } @@ -613,17 +632,18 @@ impl From for FFI_AggregateOrderSensitivity { #[cfg(test)] mod tests { + use super::*; use arrow::datatypes::Schema; + use datafusion::prelude::SessionContext; use datafusion::{ common::create_array, functions_aggregate::sum::Sum, physical_expr::PhysicalSortExpr, physical_plan::expressions::col, scalar::ScalarValue, }; + use datafusion_execution::TaskContextProvider; use std::any::Any; use std::collections::HashMap; - use super::*; - #[derive(Default, Debug, Hash, Eq, PartialEq)] struct SumWithCopiedMetadata { inner: Sum, @@ -658,13 +678,15 @@ mod tests { fn create_test_foreign_udaf( original_udaf: impl AggregateUDFImpl + 'static, + ctx: &Arc, ) -> Result { let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); - let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + let mut local_udaf = FFI_AggregateUDF::new(Arc::clone(&original_udaf), ctx); + local_udaf.library_marker_id = crate::mock_foreign_marker_id; - let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; - Ok(foreign_udaf.into()) + let foreign_udaf: Arc = (&local_udaf).try_into()?; + 
Ok(AggregateUDF::new_from_shared_impl(foreign_udaf)) } #[test] @@ -672,13 +694,15 @@ mod tests { let original_udaf = Sum::new(); let original_name = original_udaf.name().to_owned(); let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + let ctx = Arc::new(SessionContext::new()) as Arc; // Convert to FFI format - let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + let mut local_udaf = FFI_AggregateUDF::new(Arc::clone(&original_udaf), ctx); + local_udaf.library_marker_id = crate::mock_foreign_marker_id; // Convert back to native format - let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; - let foreign_udaf: AggregateUDF = foreign_udaf.into(); + let foreign_udaf: Arc = (&local_udaf).try_into()?; + let foreign_udaf = AggregateUDF::new_from_shared_impl(foreign_udaf); assert_eq!(original_name, foreign_udaf.name()); Ok(()) @@ -686,8 +710,9 @@ mod tests { #[test] fn test_foreign_udaf_aliases() -> Result<()> { + let ctx = Arc::new(SessionContext::new()) as Arc; let foreign_udaf = - create_test_foreign_udaf(Sum::new())?.with_aliases(["my_function"]); + create_test_foreign_udaf(Sum::new(), &ctx)?.with_aliases(["my_function"]); let return_field = foreign_udaf @@ -699,7 +724,8 @@ mod tests { #[test] fn test_foreign_udaf_accumulator() -> Result<()> { - let foreign_udaf = create_test_foreign_udaf(Sum::new())?; + let ctx = Arc::new(SessionContext::new()) as Arc; + let foreign_udaf = create_test_foreign_udaf(Sum::new(), &ctx)?; let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); let acc_args = AccumulatorArgs { @@ -726,13 +752,14 @@ mod tests { fn test_round_trip_udaf_metadata() -> Result<()> { let original_udaf = SumWithCopiedMetadata::default(); let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + let ctx = Arc::new(SessionContext::new()) as Arc; // Convert to FFI format - let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + let local_udaf = 
FFI_AggregateUDF::new(Arc::clone(&original_udaf), ctx); // Convert back to native format - let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; - let foreign_udaf: AggregateUDF = foreign_udaf.into(); + let foreign_udaf: Arc = (&local_udaf).try_into()?; + let foreign_udaf = AggregateUDF::new_from_shared_impl(foreign_udaf); let metadata: HashMap = [("a_key".to_string(), "a_value".to_string())] @@ -749,8 +776,10 @@ mod tests { #[test] fn test_beneficial_ordering() -> Result<()> { + let ctx = Arc::new(SessionContext::new()) as Arc; let foreign_udaf = create_test_foreign_udaf( datafusion::functions_aggregate::first_last::FirstValue::new(), + &ctx, )?; let foreign_udaf = foreign_udaf.with_beneficial_ordering(true)?.unwrap(); @@ -776,7 +805,8 @@ mod tests { #[test] fn test_sliding_accumulator() -> Result<()> { - let foreign_udaf = create_test_foreign_udaf(Sum::new())?; + let ctx = Arc::new(SessionContext::new()) as Arc; + let foreign_udaf = create_test_foreign_udaf(Sum::new(), &ctx)?; let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); // Note: sum distinct is only support Int64 until now @@ -815,4 +845,27 @@ mod tests { test_round_trip_order_sensitivity(AggregateOrderSensitivity::SoftRequirement); test_round_trip_order_sensitivity(AggregateOrderSensitivity::Beneficial); } + + #[test] + fn test_ffi_udaf_local_bypass() -> Result<()> { + let original_udaf = Sum::new(); + let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + let ctx = Arc::new(SessionContext::default()) as Arc; + + let mut ffi_udaf = FFI_AggregateUDF::new(original_udaf, ctx); + + // Verify local libraries can be downcast to their original + let foreign_udaf: Arc = (&ffi_udaf).try_into()?; + assert!(foreign_udaf.as_any().downcast_ref::().is_some()); + + // Verify different library markers generate foreign providers + ffi_udaf.library_marker_id = crate::mock_foreign_marker_id; + let foreign_udaf: Arc = (&ffi_udaf).try_into()?; + assert!(foreign_udaf + .as_any() 
+ .downcast_ref::() + .is_some()); + + Ok(()) + } } diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index 5e59cfc5ecb0..4900ca821be4 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -44,6 +44,7 @@ use datafusion::{ ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, }, }; +use datafusion_common::internal_err; use return_type_args::{ FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned, }; @@ -66,13 +67,6 @@ pub struct FFI_ScalarUDF { /// FFI equivalent to the `volatility` of a [`ScalarUDF`] pub volatility: FFI_Volatility, - /// Determines the return type of the underlying [`ScalarUDF`] based on the - /// argument types. - pub return_type: unsafe extern "C" fn( - udf: &Self, - arg_types: RVec, - ) -> RResult, - /// Determines the return info of the underlying [`ScalarUDF`]. Either this /// or return_type may be implemented on a UDF. pub return_field_from_args: unsafe extern "C" fn( @@ -114,6 +108,10 @@ pub struct FFI_ScalarUDF { /// Internal data. This is only to be accessed by the provider of the udf. /// A [`ForeignScalarUDF`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. 
+ pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_ScalarUDF {} @@ -123,34 +121,22 @@ pub struct ScalarUDFPrivateData { pub udf: Arc, } -unsafe extern "C" fn return_type_fn_wrapper( - udf: &FFI_ScalarUDF, - arg_types: RVec, -) -> RResult { - let private_data = udf.private_data as *const ScalarUDFPrivateData; - let udf = &(*private_data).udf; - - let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); - - let return_type = udf - .return_type(&arg_types) - .and_then(|v| FFI_ArrowSchema::try_from(v).map_err(DataFusionError::from)) - .map(WrappedSchema); - - rresult!(return_type) +impl FFI_ScalarUDF { + fn inner(&self) -> &Arc { + let private_data = self.private_data as *const ScalarUDFPrivateData; + unsafe { &(*private_data).udf } + } } unsafe extern "C" fn return_field_from_args_fn_wrapper( udf: &FFI_ScalarUDF, args: FFI_ReturnFieldArgs, ) -> RResult { - let private_data = udf.private_data as *const ScalarUDFPrivateData; - let udf = &(*private_data).udf; - let args: ForeignReturnFieldArgsOwned = rresult_return!((&args).try_into()); let args_ref: ForeignReturnFieldArgs = (&args).into(); let return_type = udf + .inner() .return_field_from_args((&args_ref).into()) .and_then(|f| FFI_ArrowSchema::try_from(&f).map_err(DataFusionError::from)) .map(WrappedSchema); @@ -162,12 +148,10 @@ unsafe extern "C" fn coerce_types_fn_wrapper( udf: &FFI_ScalarUDF, arg_types: RVec, ) -> RResult, RString> { - let private_data = udf.private_data as *const ScalarUDFPrivateData; - let udf = &(*private_data).udf; - let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); - let return_types = rresult_return!(data_types_with_scalar_udf(&arg_types, udf)); + let return_types = + rresult_return!(data_types_with_scalar_udf(&arg_types, udf.inner())); rresult!(vec_datatype_to_rvec_wrapped(&return_types)) } @@ -179,9 +163,6 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper( number_rows: usize, return_field: WrappedSchema, ) -> RResult { - 
let private_data = udf.private_data as *const ScalarUDFPrivateData; - let udf = &(*private_data).udf; - let args = args .into_iter() .map(|arr| { @@ -213,6 +194,7 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper( }; let result = rresult_return!(udf + .inner() .invoke_with_args(args) .and_then(|r| r.to_array(number_rows))); @@ -257,12 +239,12 @@ impl From> for FFI_ScalarUDF { volatility, short_circuits, invoke_with_args: invoke_with_args_fn_wrapper, - return_type: return_type_fn_wrapper, return_field_from_args: return_field_from_args_fn_wrapper, coerce_types: coerce_types_fn_wrapper, clone: clone_fn_wrapper, release: release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -321,21 +303,25 @@ impl Hash for ForeignScalarUDF { } } -impl TryFrom<&FFI_ScalarUDF> for ForeignScalarUDF { +impl TryFrom<&FFI_ScalarUDF> for Arc { type Error = DataFusionError; fn try_from(udf: &FFI_ScalarUDF) -> Result { - let name = udf.name.to_owned().into(); - let signature = Signature::user_defined((&udf.volatility).into()); - - let aliases = udf.aliases.iter().map(|s| s.to_string()).collect(); - - Ok(Self { - name, - udf: udf.clone(), - aliases, - signature, - }) + if (udf.library_marker_id)() == crate::get_library_marker_id() { + Ok(Arc::clone(udf.inner().inner())) + } else { + let name = udf.name.to_owned().into(); + let signature = Signature::user_defined((&udf.volatility).into()); + + let aliases = udf.aliases.iter().map(|s| s.to_string()).collect(); + + Ok(Arc::new(ForeignScalarUDF { + name, + udf: udf.clone(), + aliases, + signature, + })) + } } } @@ -352,14 +338,8 @@ impl ScalarUDFImpl for ForeignScalarUDF { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; - - let result = unsafe { (self.udf.return_type)(&self.udf, arg_types) }; - - let result = df_result!(result); - - result.and_then(|r| 
(&r.0).try_into().map_err(DataFusionError::from)) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("ForeignScalarUDF implements return_field_from_args instead.") } fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { @@ -455,12 +435,36 @@ mod tests { let original_udf = datafusion::functions::math::abs::AbsFunc::new(); let original_udf = Arc::new(ScalarUDF::from(original_udf)); - let local_udf: FFI_ScalarUDF = Arc::clone(&original_udf).into(); + let mut local_udf: FFI_ScalarUDF = Arc::clone(&original_udf).into(); + local_udf.library_marker_id = crate::mock_foreign_marker_id; - let foreign_udf: ForeignScalarUDF = (&local_udf).try_into()?; + let foreign_udf: Arc = (&local_udf).try_into()?; assert_eq!(original_udf.name(), foreign_udf.name()); Ok(()) } + + #[test] + fn test_ffi_udf_local_bypass() -> Result<()> { + use datafusion::functions::math::abs::AbsFunc; + let original_udf = AbsFunc::new(); + let original_udf = Arc::new(ScalarUDF::from(original_udf)); + + let mut ffi_udf = FFI_ScalarUDF::from(original_udf); + + // Verify local libraries can be downcast to their original + let foreign_udf: Arc = (&ffi_udf).try_into()?; + assert!(foreign_udf.as_any().downcast_ref::().is_some()); + + // Verify different library markers generate foreign providers + ffi_udf.library_marker_id = crate::mock_foreign_marker_id; + let foreign_udf: Arc = (&ffi_udf).try_into()?; + assert!(foreign_udf + .as_any() + .downcast_ref::() + .is_some()); + + Ok(()) + } } diff --git a/datafusion/ffi/src/udtf.rs b/datafusion/ffi/src/udtf.rs index edd5273c70a8..edef827cd617 100644 --- a/datafusion/ffi/src/udtf.rs +++ b/datafusion/ffi/src/udtf.rs @@ -36,10 +36,8 @@ use datafusion_proto::{ use prost::Message; use tokio::runtime::Handle; -use crate::{ - df_result, rresult_return, - table_provider::{FFI_TableProvider, ForeignTableProvider}, -}; +use crate::execution::FFI_TaskContextProvider; +use crate::{df_result, rresult_return, 
table_provider::FFI_TableProvider}; /// A stable struct for sharing a [`TableFunctionImpl`] across FFI boundaries. #[repr(C)] @@ -53,6 +51,10 @@ pub struct FFI_TableFunction { args: RVec, ) -> RResult, + /// Provider for TaskContext to be used during protobuf serialization + /// and deserialization. + pub task_ctx_provider: FFI_TaskContextProvider, + /// Used to create a clone on the provider of the udtf. This should /// only need to be called by the receiver of the udtf. pub clone: unsafe extern "C" fn(udtf: &Self) -> Self, @@ -63,6 +65,10 @@ pub struct FFI_TableFunction { /// Internal data. This is only to be accessed by the provider of the udtf. /// A [`ForeignTableFunction`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. + pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_TableFunction {} @@ -90,7 +96,7 @@ unsafe extern "C" fn call_fn_wrapper( args: RVec, ) -> RResult { let runtime = udtf.runtime(); - let udtf = udtf.inner(); + let udtf_inner = udtf.inner(); let default_ctx = SessionContext::new(); let codec = DefaultLogicalExtensionCodec {}; @@ -100,8 +106,13 @@ unsafe extern "C" fn call_fn_wrapper( let args = rresult_return!(parse_exprs(proto_filters.expr.iter(), &default_ctx, &codec)); - let table_provider = rresult_return!(udtf.call(&args)); - RResult::ROk(FFI_TableProvider::new(table_provider, false, runtime)) + let table_provider = rresult_return!(udtf_inner.call(&args)); + RResult::ROk(FFI_TableProvider::new( + table_provider, + false, + runtime, + udtf.task_ctx_provider.clone(), + )) } unsafe extern "C" fn release_fn_wrapper(udtf: &mut FFI_TableFunction) { @@ -111,9 +122,13 @@ unsafe extern "C" fn release_fn_wrapper(udtf: &mut FFI_TableFunction) { unsafe extern "C" fn clone_fn_wrapper(udtf: &FFI_TableFunction) -> FFI_TableFunction { let runtime = udtf.runtime(); - let udtf = udtf.inner(); + let 
udtf_inner = udtf.inner(); - FFI_TableFunction::new(Arc::clone(udtf), runtime) + FFI_TableFunction::new( + Arc::clone(udtf_inner), + runtime, + udtf.task_ctx_provider.clone(), + ) } impl Clone for FFI_TableFunction { @@ -123,30 +138,21 @@ impl Clone for FFI_TableFunction { } impl FFI_TableFunction { - pub fn new(udtf: Arc, runtime: Option) -> Self { + pub fn new( + udtf: Arc, + runtime: Option, + task_ctx_provider: impl Into, + ) -> Self { + let task_ctx_provider = task_ctx_provider.into(); let private_data = Box::new(TableFunctionPrivateData { udtf, runtime }); Self { call: call_fn_wrapper, + task_ctx_provider, clone: clone_fn_wrapper, release: release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut c_void, - } - } -} - -impl From> for FFI_TableFunction { - fn from(udtf: Arc) -> Self { - let private_data = Box::new(TableFunctionPrivateData { - udtf, - runtime: None, - }); - - Self { - call: call_fn_wrapper, - clone: clone_fn_wrapper, - release: release_fn_wrapper, - private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -169,9 +175,13 @@ pub struct ForeignTableFunction(FFI_TableFunction); unsafe impl Send for ForeignTableFunction {} unsafe impl Sync for ForeignTableFunction {} -impl From for ForeignTableFunction { +impl From for Arc { fn from(value: FFI_TableFunction) -> Self { - Self(value) + if (value.library_marker_id)() == crate::get_library_marker_id() { + Arc::clone(value.inner()) + } else { + Arc::new(ForeignTableFunction(value)) + } } } @@ -186,25 +196,26 @@ impl TableFunctionImpl for ForeignTableFunction { let table_provider = unsafe { (self.0.call)(&self.0, filters_serialized) }; let table_provider = df_result!(table_provider)?; - let table_provider: ForeignTableProvider = (&table_provider).into(); + let table_provider: Arc = (&table_provider).into(); - Ok(Arc::new(table_provider)) + Ok(table_provider) } } #[cfg(test)] mod tests { + use super::*; use arrow::{ array::{ 
record_batch, ArrayRef, Float64Array, RecordBatch, StringArray, UInt64Array, }, datatypes::{DataType, Field, Schema}, }; + use datafusion::logical_expr::ptr_eq::arc_ptr_eq; use datafusion::{ catalog::MemTable, common::exec_err, prelude::lit, scalar::ScalarValue, }; - - use super::*; + use datafusion_execution::TaskContextProvider; #[derive(Debug)] struct TestUDTF {} @@ -287,15 +298,17 @@ mod tests { #[tokio::test] async fn test_round_trip_udtf() -> Result<()> { let original_udtf = Arc::new(TestUDTF {}) as Arc; + let ctx = Arc::new(SessionContext::default()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; - let local_udtf: FFI_TableFunction = - FFI_TableFunction::new(Arc::clone(&original_udtf), None); + let mut local_udtf: FFI_TableFunction = + FFI_TableFunction::new(Arc::clone(&original_udtf), None, task_ctx_provider); + local_udtf.library_marker_id = crate::mock_foreign_marker_id; - let foreign_udf: ForeignTableFunction = local_udtf.into(); + let foreign_udf: Arc = local_udtf.into(); let table = foreign_udf.call(&[lit(6_u64), lit("one"), lit(2.0), lit(3_u64)])?; - let ctx = SessionContext::default(); let _ = ctx.register_table("test-table", table)?; let returned_batches = ctx.table("test-table").await?.collect().await?; @@ -317,4 +330,23 @@ mod tests { Ok(()) } + + #[test] + fn test_ffi_udtf_local_bypass() -> Result<()> { + let original_udtf = Arc::new(TestUDTF {}) as Arc; + + let ctx = Arc::new(SessionContext::default()) as Arc; + let mut ffi_udtf = FFI_TableFunction::new(Arc::clone(&original_udtf), None, ctx); + + // Verify local libraries can be downcast to their original + let foreign_udtf: Arc = ffi_udtf.clone().into(); + assert!(arc_ptr_eq(&original_udtf, &foreign_udtf)); + + // Verify different library markers generate foreign providers + ffi_udtf.library_marker_id = crate::mock_foreign_marker_id; + let foreign_udtf: Arc = ffi_udtf.into(); + assert!(!arc_ptr_eq(&original_udtf, &foreign_udtf)); + + Ok(()) + } } diff --git a/datafusion/ffi/src/udwf/mod.rs 
b/datafusion/ffi/src/udwf/mod.rs index 9f56e2d4788b..1dc51ed729db 100644 --- a/datafusion/ffi/src/udwf/mod.rs +++ b/datafusion/ffi/src/udwf/mod.rs @@ -39,7 +39,7 @@ use datafusion::{ logical_expr::{Signature, WindowUDF, WindowUDFImpl}, }; use datafusion_common::exec_err; -use partition_evaluator::{FFI_PartitionEvaluator, ForeignPartitionEvaluator}; +use partition_evaluator::FFI_PartitionEvaluator; use partition_evaluator_args::{ FFI_PartitionEvaluatorArgs, ForeignPartitionEvaluatorArgs, }; @@ -50,6 +50,7 @@ mod partition_evaluator; mod partition_evaluator_args; mod range; +use crate::execution::FFI_TaskContextProvider; use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; use crate::{ arrow_wrappers::WrappedSchema, @@ -95,6 +96,10 @@ pub struct FFI_WindowUDF { pub sort_options: ROption, + /// Provider for TaskContext to be used during protobuf serialization + /// and deserialization. + pub task_ctx_provider: FFI_TaskContextProvider, + /// Used to create a clone on the provider of the udf. This should /// only need to be called by the receiver of the udf. pub clone: unsafe extern "C" fn(udf: &Self) -> Self, @@ -105,6 +110,10 @@ pub struct FFI_WindowUDF { /// Internal data. This is only to be accessed by the provider of the udf. /// A [`ForeignWindowUDF`] should never attempt to access this data. pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. 
+ pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_WindowUDF {} @@ -198,9 +207,11 @@ unsafe extern "C" fn clone_fn_wrapper(udwf: &FFI_WindowUDF) -> FFI_WindowUDF { sort_options: udwf.sort_options.clone(), coerce_types: coerce_types_fn_wrapper, field: field_fn_wrapper, + task_ctx_provider: udwf.task_ctx_provider.clone(), clone: clone_fn_wrapper, release: release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } @@ -210,8 +221,12 @@ impl Clone for FFI_WindowUDF { } } -impl From> for FFI_WindowUDF { - fn from(udf: Arc) -> Self { +impl FFI_WindowUDF { + pub fn new( + udf: Arc, + task_ctx_provider: impl Into, + ) -> Self { + let task_ctx_provider = task_ctx_provider.into(); let name = udf.name().into(); let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect(); let volatility = udf.signature().volatility.into(); @@ -227,9 +242,11 @@ impl From> for FFI_WindowUDF { sort_options, coerce_types: coerce_types_fn_wrapper, field: field_fn_wrapper, + task_ctx_provider, clone: clone_fn_wrapper, release: release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -270,21 +287,25 @@ impl Hash for ForeignWindowUDF { } } -impl TryFrom<&FFI_WindowUDF> for ForeignWindowUDF { +impl TryFrom<&FFI_WindowUDF> for Arc { type Error = DataFusionError; fn try_from(udf: &FFI_WindowUDF) -> Result { - let name = udf.name.to_owned().into(); - let signature = Signature::user_defined((&udf.volatility).into()); - - let aliases = udf.aliases.iter().map(|s| s.to_string()).collect(); - - Ok(Self { - name, - udf: udf.clone(), - aliases, - signature, - }) + if (udf.library_marker_id)() == crate::get_library_marker_id() { + Ok(Arc::clone(unsafe { udf.inner().inner() })) + } else { + let name = udf.name.to_owned().into(); + let signature = Signature::user_defined((&udf.volatility).into()); + + let aliases = 
udf.aliases.iter().map(|s| s.to_string()).collect(); + + Ok(Arc::new(ForeignWindowUDF { + name, + udf: udf.clone(), + aliases, + signature, + })) + } } } @@ -318,14 +339,14 @@ impl WindowUDFImpl for ForeignWindowUDF { args: datafusion::logical_expr::function::PartitionEvaluatorArgs, ) -> Result> { let evaluator = unsafe { - let args = FFI_PartitionEvaluatorArgs::try_from(args)?; + let args = FFI_PartitionEvaluatorArgs::try_new( + args, + self.udf.task_ctx_provider.clone(), + )?; (self.udf.partition_evaluator)(&self.udf, args) }; - df_result!(evaluator).map(|evaluator| { - Box::new(ForeignPartitionEvaluator::from(evaluator)) - as Box - }) + df_result!(evaluator).map(>::from) } fn field(&self, field_args: WindowUDFFieldArgs) -> Result { @@ -393,30 +414,35 @@ mod tests { use datafusion::logical_expr::expr::Sort; use datafusion::logical_expr::{col, ExprFunctionExt, WindowUDF, WindowUDFImpl}; use datafusion::prelude::SessionContext; + use datafusion_execution::TaskContextProvider; use std::sync::Arc; fn create_test_foreign_udwf( original_udwf: impl WindowUDFImpl + 'static, + ctx: &Arc, ) -> datafusion::common::Result { let original_udwf = Arc::new(WindowUDF::from(original_udwf)); - let local_udwf: FFI_WindowUDF = Arc::clone(&original_udwf).into(); + let mut local_udwf = FFI_WindowUDF::new(Arc::clone(&original_udwf), ctx); + local_udwf.library_marker_id = crate::mock_foreign_marker_id; - let foreign_udwf: ForeignWindowUDF = (&local_udwf).try_into()?; - Ok(foreign_udwf.into()) + let foreign_udwf: Arc = (&local_udwf).try_into()?; + Ok(WindowUDF::new_from_shared_impl(foreign_udwf)) } #[test] fn test_round_trip_udwf() -> datafusion::common::Result<()> { let original_udwf = lag_udwf(); let original_name = original_udwf.name().to_owned(); + let ctx = Arc::new(SessionContext::new()) as Arc; // Convert to FFI format - let local_udwf: FFI_WindowUDF = Arc::clone(&original_udwf).into(); + let mut local_udwf = FFI_WindowUDF::new(Arc::clone(&original_udwf), ctx); + 
local_udwf.library_marker_id = crate::mock_foreign_marker_id; // Convert back to native format - let foreign_udwf: ForeignWindowUDF = (&local_udwf).try_into()?; - let foreign_udwf: WindowUDF = foreign_udwf.into(); + let foreign_udwf: Arc = (&local_udwf).try_into()?; + let foreign_udwf = WindowUDF::new_from_shared_impl(foreign_udwf); assert_eq!(original_name, foreign_udwf.name()); Ok(()) @@ -424,7 +450,8 @@ mod tests { #[tokio::test] async fn test_lag_udwf() -> datafusion::common::Result<()> { - let udwf = create_test_foreign_udwf(WindowShift::lag())?; + let ctx = Arc::new(SessionContext::new()) as Arc; + let udwf = create_test_foreign_udwf(WindowShift::lag(), &ctx)?; let ctx = SessionContext::default(); let df = ctx.read_batch(create_record_batch(-5, 5))?; @@ -450,4 +477,29 @@ mod tests { Ok(()) } + + #[test] + fn test_ffi_udwf_local_bypass() -> datafusion_common::Result<()> { + let original_udwf = Arc::new(WindowUDF::from(WindowShift::lag())); + let ctx = Arc::new(SessionContext::new()) as Arc; + + let mut ffi_udwf = FFI_WindowUDF::new(original_udwf, ctx); + + // Verify local libraries can be downcast to their original + let foreign_udwf: Arc = (&ffi_udwf).try_into()?; + assert!(foreign_udwf + .as_any() + .downcast_ref::() + .is_some()); + + // Verify different library markers generate foreign providers + ffi_udwf.library_marker_id = crate::mock_foreign_marker_id; + let foreign_udwf: Arc = (&ffi_udwf).try_into()?; + assert!(foreign_udwf + .as_any() + .downcast_ref::() + .is_some()); + + Ok(()) + } } diff --git a/datafusion/ffi/src/udwf/partition_evaluator.rs b/datafusion/ffi/src/udwf/partition_evaluator.rs index 14cf23b919aa..2f7784753009 100644 --- a/datafusion/ffi/src/udwf/partition_evaluator.rs +++ b/datafusion/ffi/src/udwf/partition_evaluator.rs @@ -76,6 +76,10 @@ pub struct FFI_PartitionEvaluator { /// Internal data. This is only to be accessed by the provider of the evaluator. /// A [`ForeignPartitionEvaluator`] should never attempt to access this data. 
pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. + pub library_marker_id: extern "C" fn() -> u64, } unsafe impl Send for FFI_PartitionEvaluator {} @@ -170,9 +174,11 @@ unsafe extern "C" fn get_range_fn_wrapper( } unsafe extern "C" fn release_fn_wrapper(evaluator: &mut FFI_PartitionEvaluator) { - let private_data = - Box::from_raw(evaluator.private_data as *mut PartitionEvaluatorPrivateData); - drop(private_data); + if !evaluator.private_data.is_null() { + let private_data = + Box::from_raw(evaluator.private_data as *mut PartitionEvaluatorPrivateData); + drop(private_data); + } } impl From> for FFI_PartitionEvaluator { @@ -195,6 +201,7 @@ impl From> for FFI_PartitionEvaluator { uses_window_frame, release: release_fn_wrapper, private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, + library_marker_id: crate::get_library_marker_id, } } } @@ -219,9 +226,19 @@ pub struct ForeignPartitionEvaluator { unsafe impl Send for ForeignPartitionEvaluator {} unsafe impl Sync for ForeignPartitionEvaluator {} -impl From for ForeignPartitionEvaluator { - fn from(evaluator: FFI_PartitionEvaluator) -> Self { - Self { evaluator } +impl From for Box { + fn from(mut evaluator: FFI_PartitionEvaluator) -> Self { + if (evaluator.library_marker_id)() == crate::get_library_marker_id() { + unsafe { + let private_data = Box::from_raw( + evaluator.private_data as *mut PartitionEvaluatorPrivateData, + ); + evaluator.private_data = std::ptr::null_mut(); + private_data.evaluator + } + } else { + Box::new(ForeignPartitionEvaluator { evaluator }) + } } } @@ -317,4 +334,54 @@ impl PartitionEvaluator for ForeignPartitionEvaluator { } #[cfg(test)] -mod tests {} +mod tests { + use crate::udwf::partition_evaluator::{ + FFI_PartitionEvaluator, ForeignPartitionEvaluator, + }; + use arrow::array::ArrayRef; + use datafusion::logical_expr::PartitionEvaluator; + + #[derive(Debug)] + struct TestPartitionEvaluator {} + 
+ impl PartitionEvaluator for TestPartitionEvaluator { + fn evaluate_all( + &mut self, + values: &[ArrayRef], + _num_rows: usize, + ) -> datafusion_common::Result { + Ok(values[0].to_owned()) + } + } + + #[test] + fn test_ffi_partition_evaluator_local_bypass_inner() -> datafusion_common::Result<()> + { + let original_accum = TestPartitionEvaluator {}; + let boxed_accum: Box = Box::new(original_accum); + + let ffi_accum: FFI_PartitionEvaluator = boxed_accum.into(); + + // Verify local libraries can be downcast to their original + let foreign_accum: Box = ffi_accum.into(); + unsafe { + let concrete = &*(foreign_accum.as_ref() as *const dyn PartitionEvaluator + as *const TestPartitionEvaluator); + assert!(!concrete.uses_window_frame()); + } + + // Verify different library markers generate foreign accumulator + let original_accum = TestPartitionEvaluator {}; + let boxed_accum: Box = Box::new(original_accum); + let mut ffi_accum: FFI_PartitionEvaluator = boxed_accum.into(); + ffi_accum.library_marker_id = crate::mock_foreign_marker_id; + let foreign_accum: Box = ffi_accum.into(); + unsafe { + let concrete = &*(foreign_accum.as_ref() as *const dyn PartitionEvaluator + as *const ForeignPartitionEvaluator); + assert!(!concrete.uses_window_frame()); + } + + Ok(()) + } +} diff --git a/datafusion/ffi/src/udwf/partition_evaluator_args.rs b/datafusion/ffi/src/udwf/partition_evaluator_args.rs index cd2641256437..93497d34f368 100644 --- a/datafusion/ffi/src/udwf/partition_evaluator_args.rs +++ b/datafusion/ffi/src/udwf/partition_evaluator_args.rs @@ -18,6 +18,7 @@ use std::{collections::HashMap, sync::Arc}; use crate::arrow_wrappers::WrappedSchema; +use crate::execution::FFI_TaskContextProvider; use abi_stable::{std_types::RVec, StableAbi}; use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, @@ -29,9 +30,9 @@ use datafusion::{ error::{DataFusionError, Result}, logical_expr::function::PartitionEvaluatorArgs, physical_plan::{expressions::Column, PhysicalExpr}, - 
prelude::SessionContext, }; use datafusion_common::exec_datafusion_err; +use datafusion_execution::TaskContext; use datafusion_proto::{ physical_plan::{ from_proto::parse_physical_expr, to_proto::serialize_physical_exprs, @@ -53,11 +54,17 @@ pub struct FFI_PartitionEvaluatorArgs { is_reversed: bool, ignore_nulls: bool, schema: WrappedSchema, + + /// Provider for TaskContext to be used during protobuf serialization + /// and deserialization. + pub task_ctx_provider: FFI_TaskContextProvider, } -impl TryFrom> for FFI_PartitionEvaluatorArgs { - type Error = DataFusionError; - fn try_from(args: PartitionEvaluatorArgs) -> Result { +impl FFI_PartitionEvaluatorArgs { + pub fn try_new( + args: PartitionEvaluatorArgs, + task_ctx_provider: impl Into, + ) -> Result { // This is a bit of a hack. Since PartitionEvaluatorArgs does not carry a schema // around, and instead passes the data types directly we are unable to decode the // protobuf PhysicalExpr correctly. In evaluating the code the only place these @@ -117,6 +124,7 @@ impl TryFrom> for FFI_PartitionEvaluatorArgs { schema, is_reversed: args.is_reversed(), ignore_nulls: args.ignore_nulls(), + task_ctx_provider: task_ctx_provider.into(), }) } } @@ -136,10 +144,10 @@ impl TryFrom for ForeignPartitionEvaluatorArgs { type Error = DataFusionError; fn try_from(value: FFI_PartitionEvaluatorArgs) -> Result { - let default_ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; let schema: SchemaRef = value.schema.into(); + let task_ctx: Arc = (&value.task_ctx_provider).try_into()?; let input_exprs = value .input_exprs @@ -149,7 +157,7 @@ impl TryFrom for ForeignPartitionEvaluatorArgs { .map_err(|e| exec_datafusion_err!("Failed to decode PhysicalExprNode: {e}"))? 
.iter() .map(|expr_node| { - parse_physical_expr(expr_node, &default_ctx.task_ctx(), &schema, &codec) + parse_physical_expr(expr_node, task_ctx.as_ref(), &schema, &codec) }) .collect::>>()?; diff --git a/datafusion/ffi/tests/ffi_catalog.rs b/datafusion/ffi/tests/ffi_catalog.rs index b63d8cbd631b..41bf867efb9d 100644 --- a/datafusion/ffi/tests/ffi_catalog.rs +++ b/datafusion/ffi/tests/ffi_catalog.rs @@ -19,16 +19,18 @@ /// when the feature integration-tests is built #[cfg(feature = "integration-tests")] mod tests { + use datafusion::catalog::{CatalogProvider, CatalogProviderList}; use datafusion::prelude::SessionContext; use datafusion_common::DataFusionError; - use datafusion_ffi::catalog_provider::ForeignCatalogProvider; - use datafusion_ffi::catalog_provider_list::ForeignCatalogProviderList; + use datafusion_execution::TaskContextProvider; use datafusion_ffi::tests::utils::get_module; use std::sync::Arc; #[tokio::test] async fn test_catalog() -> datafusion_common::Result<()> { let module = get_module()?; + let ctx = Arc::new(SessionContext::default()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; let ffi_catalog = module @@ -36,11 +38,10 @@ mod tests { .ok_or(DataFusionError::NotImplemented( "External catalog provider failed to implement create_catalog" .to_string(), - ))?(); - let foreign_catalog: ForeignCatalogProvider = (&ffi_catalog).into(); + ))?(task_ctx_provider.into()); + let foreign_catalog: Arc = (&ffi_catalog).into(); - let ctx = SessionContext::default(); - let _ = ctx.register_catalog("fruit", Arc::new(foreign_catalog)); + let _ = ctx.register_catalog("fruit", foreign_catalog); let df = ctx.table("fruit.apple.purchases").await?; @@ -57,17 +58,20 @@ mod tests { async fn test_catalog_list() -> datafusion_common::Result<()> { let module = get_module()?; + let ctx = Arc::new(SessionContext::default()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; + let ffi_catalog_list = module .create_catalog_list() .ok_or(DataFusionError::NotImplemented( 
"External catalog provider failed to implement create_catalog_list" .to_string(), - ))?(); - let foreign_catalog_list: ForeignCatalogProviderList = (&ffi_catalog_list).into(); + ))?(task_ctx_provider.into()); + let foreign_catalog_list: Arc = + (&ffi_catalog_list).into(); - let ctx = SessionContext::default(); - ctx.register_catalog_list(Arc::new(foreign_catalog_list)); + ctx.register_catalog_list(foreign_catalog_list); let df = ctx.table("blue.apple.purchases").await?; diff --git a/datafusion/ffi/tests/ffi_integration.rs b/datafusion/ffi/tests/ffi_integration.rs index 7b4d1b1e350a..a0d42da94688 100644 --- a/datafusion/ffi/tests/ffi_integration.rs +++ b/datafusion/ffi/tests/ffi_integration.rs @@ -19,9 +19,10 @@ /// when the feature integration-tests is built #[cfg(feature = "integration-tests")] mod tests { + use datafusion::catalog::TableProvider; use datafusion::error::{DataFusionError, Result}; use datafusion::prelude::SessionContext; - use datafusion_ffi::table_provider::ForeignTableProvider; + use datafusion_execution::TaskContextProvider; use datafusion_ffi::tests::create_record_batch; use datafusion_ffi::tests::utils::get_module; use std::sync::Arc; @@ -32,22 +33,23 @@ mod tests { async fn test_table_provider(synchronous: bool) -> Result<()> { let table_provider_module = get_module()?; + let ctx = Arc::new(SessionContext::new()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; + // By calling the code below, the table provided will be created within // the module's code. let ffi_table_provider = table_provider_module.create_table().ok_or( DataFusionError::NotImplemented( "External table provider failed to implement create_table".to_string(), ), - )?(synchronous); + )?(synchronous, task_ctx_provider.into()); // In order to access the table provider within this executable, we need to // turn it into a `ForeignTableProvider`. 
- let foreign_table_provider: ForeignTableProvider = (&ffi_table_provider).into(); - - let ctx = SessionContext::new(); + let foreign_table_provider: Arc = (&ffi_table_provider).into(); // Display the data to show the full cycle works. - ctx.register_table("external_table", Arc::new(foreign_table_provider))?; + ctx.register_table("external_table", foreign_table_provider)?; let df = ctx.table("external_table").await?; let results = df.collect().await?; diff --git a/datafusion/ffi/tests/ffi_udaf.rs b/datafusion/ffi/tests/ffi_udaf.rs index ffd99bac62ec..e6f581ff0b63 100644 --- a/datafusion/ffi/tests/ffi_udaf.rs +++ b/datafusion/ffi/tests/ffi_udaf.rs @@ -22,27 +22,28 @@ mod tests { use arrow::array::Float64Array; use datafusion::common::record_batch; use datafusion::error::{DataFusionError, Result}; - use datafusion::logical_expr::AggregateUDF; + use datafusion::logical_expr::{AggregateUDF, AggregateUDFImpl}; use datafusion::prelude::{col, SessionContext}; - + use datafusion_execution::TaskContextProvider; use datafusion_ffi::tests::utils::get_module; - use datafusion_ffi::udaf::ForeignAggregateUDF; + use std::sync::Arc; #[tokio::test] async fn test_ffi_udaf() -> Result<()> { let module = get_module()?; + let ctx = Arc::new(SessionContext::default()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; let ffi_sum_func = module .create_sum_udaf() .ok_or(DataFusionError::NotImplemented( "External table provider failed to implement create_udaf".to_string(), - ))?(); - let foreign_sum_func: ForeignAggregateUDF = (&ffi_sum_func).try_into()?; + ))?(task_ctx_provider.into()); + let foreign_sum_func: Arc = (&ffi_sum_func).try_into()?; - let udaf: AggregateUDF = foreign_sum_func.into(); + let udaf = AggregateUDF::new_from_shared_impl(foreign_sum_func); - let ctx = SessionContext::default(); let record_batch = record_batch!( ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), ("b", Float64, vec![1.0, 2.0, 2.0, 4.0, 4.0, 4.0, 4.0]) @@ -73,18 +74,20 @@ mod tests { #[tokio::test] async fn 
test_ffi_grouping_udaf() -> Result<()> { let module = get_module()?; + let ctx = Arc::new(SessionContext::default()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; let ffi_stddev_func = module .create_stddev_udaf() .ok_or(DataFusionError::NotImplemented( "External table provider failed to implement create_udaf".to_string(), - ))?(); - let foreign_stddev_func: ForeignAggregateUDF = (&ffi_stddev_func).try_into()?; + ))?(task_ctx_provider.into()); + let foreign_stddev_func: Arc = + (&ffi_stddev_func).try_into()?; - let udaf: AggregateUDF = foreign_stddev_func.into(); + let udaf = AggregateUDF::new_from_shared_impl(foreign_stddev_func); - let ctx = SessionContext::default(); let record_batch = record_batch!( ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), ( diff --git a/datafusion/ffi/tests/ffi_udf.rs b/datafusion/ffi/tests/ffi_udf.rs index fd6a84bcf5b0..d50739be9975 100644 --- a/datafusion/ffi/tests/ffi_udf.rs +++ b/datafusion/ffi/tests/ffi_udf.rs @@ -19,16 +19,15 @@ /// when the feature integration-tests is built #[cfg(feature = "integration-tests")] mod tests { - use arrow::datatypes::DataType; use datafusion::common::record_batch; use datafusion::error::{DataFusionError, Result}; - use datafusion::logical_expr::ScalarUDF; + use datafusion::logical_expr::{ScalarUDF, ScalarUDFImpl}; use datafusion::prelude::{col, SessionContext}; + use std::sync::Arc; use datafusion_ffi::tests::create_record_batch; use datafusion_ffi::tests::utils::get_module; - use datafusion_ffi::udf::ForeignScalarUDF; /// This test validates that we can load an external module and use a scalar /// udf defined in it via the foreign function interface. 
In this case we are @@ -44,9 +43,9 @@ mod tests { "External table provider failed to implement create_scalar_udf" .to_string(), ))?(); - let foreign_abs_func: ForeignScalarUDF = (&ffi_abs_func).try_into()?; + let foreign_abs_func: Arc = (&ffi_abs_func).try_into()?; - let udf: ScalarUDF = foreign_abs_func.into(); + let udf = ScalarUDF::new_from_shared_impl(foreign_abs_func); let ctx = SessionContext::default(); let df = ctx.read_batch(create_record_batch(-5, 5))?; @@ -82,9 +81,9 @@ mod tests { "External table provider failed to implement create_scalar_udf" .to_string(), ))?(); - let foreign_abs_func: ForeignScalarUDF = (&ffi_abs_func).try_into()?; + let foreign_abs_func: Arc = (&ffi_abs_func).try_into()?; - let udf: ScalarUDF = foreign_abs_func.into(); + let udf = ScalarUDF::new_from_shared_impl(foreign_abs_func); let ctx = SessionContext::default(); let df = ctx.read_batch(create_record_batch(-5, 5))?; diff --git a/datafusion/ffi/tests/ffi_udtf.rs b/datafusion/ffi/tests/ffi_udtf.rs index 8c1c64a092e1..7b292a9e449a 100644 --- a/datafusion/ffi/tests/ffi_udtf.rs +++ b/datafusion/ffi/tests/ffi_udtf.rs @@ -23,11 +23,11 @@ mod tests { use std::sync::Arc; use arrow::array::{create_array, ArrayRef}; + use datafusion::catalog::TableFunctionImpl; use datafusion::error::{DataFusionError, Result}; use datafusion::prelude::SessionContext; - + use datafusion_execution::TaskContextProvider; use datafusion_ffi::tests::utils::get_module; - use datafusion_ffi::udtf::ForeignTableFunction; /// This test validates that we can load an external module and use a scalar /// udf defined in it via the foreign function interface. 
In this case we are @@ -36,18 +36,18 @@ mod tests { async fn test_user_defined_table_function() -> Result<()> { let module = get_module()?; + let ctx = Arc::new(SessionContext::default()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; + let ffi_table_func = module .create_table_function() .ok_or(DataFusionError::NotImplemented( "External table function provider failed to implement create_table_function" .to_string(), - ))?(); - let foreign_table_func: ForeignTableFunction = ffi_table_func.into(); - - let udtf = Arc::new(foreign_table_func); + ))?(task_ctx_provider.into()); + let foreign_table_func: Arc = ffi_table_func.into(); - let ctx = SessionContext::default(); - ctx.register_udtf("my_range", udtf); + ctx.register_udtf("my_range", foreign_table_func); let result = ctx .sql("SELECT * FROM my_range(5)") diff --git a/datafusion/ffi/tests/ffi_udwf.rs b/datafusion/ffi/tests/ffi_udwf.rs index 18ffd0c5bcb7..4a260addaf7b 100644 --- a/datafusion/ffi/tests/ffi_udwf.rs +++ b/datafusion/ffi/tests/ffi_udwf.rs @@ -22,15 +22,18 @@ mod tests { use arrow::array::{create_array, ArrayRef}; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::expr::Sort; - use datafusion::logical_expr::{col, ExprFunctionExt, WindowUDF}; + use datafusion::logical_expr::{col, ExprFunctionExt, WindowUDF, WindowUDFImpl}; use datafusion::prelude::SessionContext; + use datafusion_execution::TaskContextProvider; use datafusion_ffi::tests::create_record_batch; use datafusion_ffi::tests::utils::get_module; - use datafusion_ffi::udwf::ForeignWindowUDF; + use std::sync::Arc; #[tokio::test] async fn test_rank_udwf() -> Result<()> { let module = get_module()?; + let ctx = Arc::new(SessionContext::default()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; let ffi_rank_func = module @@ -38,12 +41,11 @@ mod tests { .ok_or(DataFusionError::NotImplemented( "External table provider failed to implement create_scalar_udf" .to_string(), - ))?(); - let foreign_rank_func: 
ForeignWindowUDF = (&ffi_rank_func).try_into()?; + ))?(task_ctx_provider.into()); + let foreign_rank_func: Arc = (&ffi_rank_func).try_into()?; - let udwf: WindowUDF = foreign_rank_func.into(); + let udwf = WindowUDF::new_from_shared_impl(foreign_rank_func); - let ctx = SessionContext::default(); let df = ctx.read_batch(create_record_batch(-5, 5))?; let df = df.select(vec![ diff --git a/docs/source/library-user-guide/upgrading.md b/docs/source/library-user-guide/upgrading.md index 828892300860..bbb504c7bdc4 100644 --- a/docs/source/library-user-guide/upgrading.md +++ b/docs/source/library-user-guide/upgrading.md @@ -67,6 +67,47 @@ SELECT median(c1) IGNORE NULLS FROM table Instead of silently succeeding. +### FFI object conversion + +Many of the structs in the `datafusion-ffi` crate have been updated to allow easier +conversion to the underlying trait types they represent. This simplifies some code +paths, but also provides an additional improvement in cases where library code goes +through a round trip via the foreign function interface. + +To update your code, suppose you have a `FFI_SchemaProvider` called `ffi_provider` +and you wish to use this as a `SchemaProvider`. In the old approach you would do +something like: + +```rust,ignore + let foreign_provider: ForeignSchemaProvider = provider.into(); + let foreign_provider = Arc::new(foreign_provider) as Arc; +``` + +This code should now be written as: + +```rust,ignore + let foreign_provider: Arc = provider.into(); + let foreign_provider = foreign_provider as Arc; +``` + +For the case of user defined functions, the updates are similar but you +may need to change the way you call the creation of the `ScalarUDF`. +Aggregate and window functions follow the same pattern. 
+ +Previously you may write: + +```rust,ignore + let foreign_udf: ForeignScalarUDF = ffi_udf.try_into()?; + let foreign_udf: ScalarUDF = foreign_udf.into(); +``` + +Instead this should now be: + +```rust,ignore + let foreign_udf: Arc = ffi_udf.try_into()?; + let foreign_udf = ScalarUDF::new_from_shared_impl(foreign_udf); +``` + ## DataFusion `51.0.0` ### `arrow` / `parquet` updated to 57.0.0