Skip to content

Commit

Permalink
Add plugable handler for CREATE FUNCTION (#9333)
Browse files Browse the repository at this point in the history
* Add plugable function factory

* cover `DROP FUNCTION` as well ...

... partially, as `SessionState` does not expose
unregister_udf at the moment.

* update documentation

* fix doc test

* Address PR comments (code organization)

* Address PR comments (factory interface)

* fix test after rebase

* `remove`'s gone from the trait ...

... `DROP FUNCTION` will look for function name
in all available registries (udf, udaf, udwf).

`remove` may be necessary if UDaF and UDwF do not
get `simplify` method from #9304.

* Rename FunctionDefinition and export it ...

FunctionDefinition already exists, DefinitionStatement makes more sense.

* Update datafusion/expr/src/logical_plan/ddl.rs

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>

* Update datafusion/core/src/execution/context/mod.rs

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>

* Update datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>

* Update datafusion/expr/src/logical_plan/ddl.rs

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>

* resolve part of follow up comments

* Qualified functions are not supported anymore

* update docs and todos

* fix clippy

* address additional comments

* Add sqllogicteset for CREATE/DROP function

* Add coverage for DROP FUNCTION IF EXISTS

* fix multiline error

* revert dialect back to generic in test ...

... as `create function` gets support in latest
sqlparser.

* fmt

---------

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
milenkovicm and alamb committed Mar 5, 2024
1 parent 3aba67e commit ea01e56
Show file tree
Hide file tree
Showing 8 changed files with 498 additions and 20 deletions.
97 changes: 94 additions & 3 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,10 @@ use crate::datasource::{
};
use crate::error::{DataFusionError, Result};
use crate::logical_expr::{
CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateMemoryTable,
CreateView, DropCatalogSchema, DropTable, DropView, Explain, LogicalPlan,
LogicalPlanBuilder, SetVariable, TableSource, TableType, UNNAMED_TABLE,
CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction,
CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable, DropView,
Explain, LogicalPlan, LogicalPlanBuilder, SetVariable, TableSource, TableType,
UNNAMED_TABLE,
};
use crate::optimizer::OptimizerRule;
use datafusion_sql::{
Expand Down Expand Up @@ -489,6 +490,8 @@ impl SessionContext {
DdlStatement::DropTable(cmd) => self.drop_table(cmd).await,
DdlStatement::DropView(cmd) => self.drop_view(cmd).await,
DdlStatement::DropCatalogSchema(cmd) => self.drop_schema(cmd).await,
DdlStatement::CreateFunction(cmd) => self.create_function(cmd).await,
DdlStatement::DropFunction(cmd) => self.drop_function(cmd).await,
},
// TODO what about the other statements (like TransactionStart and TransactionEnd)
LogicalPlan::Statement(Statement::SetVariable(stmt)) => {
Expand Down Expand Up @@ -794,6 +797,55 @@ impl SessionContext {
Ok(false)
}

async fn create_function(&self, stmt: CreateFunction) -> Result<DataFrame> {
let function = {
let state = self.state.read().clone();
let function_factory = &state.function_factory;

match function_factory {
Some(f) => f.create(state.config(), stmt).await?,
_ => Err(DataFusionError::Configuration(
"Function factory has not been configured".into(),
))?,
}
};

match function {
RegisterFunction::Scalar(f) => {
self.state.write().register_udf(f)?;
}
RegisterFunction::Aggregate(f) => {
self.state.write().register_udaf(f)?;
}
RegisterFunction::Window(f) => {
self.state.write().register_udwf(f)?;
}
RegisterFunction::Table(name, f) => self.register_udtf(&name, f),
};

self.return_empty_dataframe()
}

async fn drop_function(&self, stmt: DropFunction) -> Result<DataFrame> {
// we don't know function type at this point
// decision has been made to drop all functions
let mut dropped = false;
dropped |= self.state.write().deregister_udf(&stmt.name)?.is_some();
dropped |= self.state.write().deregister_udaf(&stmt.name)?.is_some();
dropped |= self.state.write().deregister_udwf(&stmt.name)?.is_some();

// DROP FUNCTION IF EXISTS drops the specified function only if that
// function exists and in this way, it avoids error. While the DROP FUNCTION
// statement also performs the same function, it throws an
// error if the function does not exist.

if !stmt.if_exists && !dropped {
exec_err!("Function does not exist")
} else {
self.return_empty_dataframe()
}
}

/// Registers a variable provider within this context.
pub fn register_variable(
&self,
Expand Down Expand Up @@ -1261,7 +1313,30 @@ impl QueryPlanner for DefaultQueryPlanner {
.await
}
}
/// A pluggable interface to handle `CREATE FUNCTION` statements
/// and interact with [SessionState] to registers new udf, udaf or udwf.

#[async_trait]
pub trait FunctionFactory: Sync + Send {
/// Handles creation of user defined function specified in [CreateFunction] statement
async fn create(
&self,
state: &SessionConfig,
statement: CreateFunction,
) -> Result<RegisterFunction>;
}

/// Type of function to create
pub enum RegisterFunction {
/// Scalar user defined function
Scalar(Arc<ScalarUDF>),
/// Aggregate user defined function
Aggregate(Arc<AggregateUDF>),
/// Window user defined function
Window(Arc<WindowUDF>),
/// Table user defined function
Table(String, Arc<dyn TableFunctionImpl>),
}
/// Execution context for registering data sources and executing queries.
/// See [`SessionContext`] for a higher level API.
///
Expand Down Expand Up @@ -1306,6 +1381,12 @@ pub struct SessionState {
table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
/// Runtime environment
runtime_env: Arc<RuntimeEnv>,

/// [FunctionFactory] to support pluggable user defined function handler.
///
/// It will be invoked on `CREATE FUNCTION` statements.
/// thus, changing dialect o PostgreSql is required
function_factory: Option<Arc<dyn FunctionFactory>>,
}

impl Debug for SessionState {
Expand Down Expand Up @@ -1392,6 +1473,7 @@ impl SessionState {
execution_props: ExecutionProps::new(),
runtime_env: runtime,
table_factories,
function_factory: None,
};

// register built in functions
Expand Down Expand Up @@ -1568,6 +1650,15 @@ impl SessionState {
self
}

/// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements
pub fn with_function_factory(
mut self,
function_factory: Arc<dyn FunctionFactory>,
) -> Self {
self.function_factory = Some(function_factory);
self
}

/// Replace the extension [`SerializerRegistry`]
pub fn with_serializer_registry(
mut self,
Expand Down
130 changes: 128 additions & 2 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

use arrow::compute::kernels::numeric::add;
use arrow_array::{
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array,
Array, ArrayRef, ArrowNativeTypeOp, Float32Array, Float64Array, Int32Array,
RecordBatch, UInt8Array,
};
use arrow_schema::DataType::Float64;
use arrow_schema::{DataType, Field, Schema};
use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
use datafusion::prelude::*;
use datafusion::{execution::registry::FunctionRegistry, test_util};
use datafusion_common::cast::as_float64_array;
Expand All @@ -31,10 +33,12 @@ use datafusion_common::{
use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::{
create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable,
create_udaf, create_udf, Accumulator, ColumnarValue, CreateFunction, ExprSchemable,
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use parking_lot::Mutex;

use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use rand::{thread_rng, Rng};
use std::any::Any;
use std::iter;
Expand Down Expand Up @@ -735,6 +739,128 @@ async fn verify_udf_return_type() -> Result<()> {
Ok(())
}

#[derive(Debug, Default)]
struct MockFunctionFactory {
pub captured_expr: Mutex<Option<Expr>>,
}

#[async_trait::async_trait]
impl FunctionFactory for MockFunctionFactory {
#[doc = r" Crates and registers a function from [CreateFunction] statement"]
#[must_use]
#[allow(clippy::type_complexity, clippy::type_repetition_in_bounds)]
async fn create(
&self,
_config: &SessionConfig,
statement: CreateFunction,
) -> datafusion::error::Result<RegisterFunction> {
// In this example, we always create a function that adds its arguments
// with the name specified in `CREATE FUNCTION`. In a real implementation
// the body of the created UDF would also likely be a function of the contents
// of the `CreateFunction`
let mock_add = Arc::new(|args: &[datafusion_expr::ColumnarValue]| {
let args = datafusion_expr::ColumnarValue::values_to_arrays(args)?;
let base =
datafusion_common::cast::as_float64_array(&args[0]).expect("cast failed");
let exponent =
datafusion_common::cast::as_float64_array(&args[1]).expect("cast failed");

let array = base
.iter()
.zip(exponent.iter())
.map(|(base, exponent)| match (base, exponent) {
(Some(base), Some(exponent)) => Some(base.add_wrapping(exponent)),
_ => None,
})
.collect::<arrow_array::Float64Array>();
Ok(datafusion_expr::ColumnarValue::from(
Arc::new(array) as arrow_array::ArrayRef
))
});

let args = statement.args.unwrap();
let mock_udf = create_udf(
&statement.name,
vec![args[0].data_type.clone(), args[1].data_type.clone()],
Arc::new(statement.return_type.unwrap()),
datafusion_expr::Volatility::Immutable,
mock_add,
);

// capture expression so we can verify
// it has been parsed
*self.captured_expr.lock() = statement.params.return_;

Ok(RegisterFunction::Scalar(Arc::new(mock_udf)))
}
}

#[tokio::test]
async fn create_scalar_function_from_sql_statement() -> Result<()> {
let function_factory = Arc::new(MockFunctionFactory::default());
let runtime_config = RuntimeConfig::new();
let runtime_environment = RuntimeEnv::new(runtime_config)?;

let session_config = SessionConfig::new();
let state =
SessionState::new_with_config_rt(session_config, Arc::new(runtime_environment))
.with_function_factory(function_factory.clone());

let ctx = SessionContext::new_with_state(state);
let options = SQLOptions::new().with_allow_ddl(false);

let sql = r#"
CREATE FUNCTION better_add(DOUBLE, DOUBLE)
RETURNS DOUBLE
RETURN $1 + $2
"#;

// try to `create function` when sql options have allow ddl disabled
assert!(ctx.sql_with_options(sql, options).await.is_err());

// Create the `better_add` function dynamically via CREATE FUNCTION statement
assert!(ctx.sql(sql).await.is_ok());
// try to `drop function` when sql options have allow ddl disabled
assert!(ctx
.sql_with_options("drop function better_add", options)
.await
.is_err());

ctx.sql("select better_add(2.0, 2.0)").await?.show().await?;

// check if we sql expr has been converted to datafusion expr
let captured_expression = function_factory.captured_expr.lock().clone().unwrap();
assert_eq!("$1 + $2", captured_expression.to_string());

// statement drops function
assert!(ctx.sql("drop function better_add").await.is_ok());
// no function, it panics
assert!(ctx.sql("drop function better_add").await.is_err());
// no function, it dies not care
assert!(ctx.sql("drop function if exists better_add").await.is_ok());
// query should fail as there is no function
assert!(ctx.sql("select better_add(2.0, 2.0)").await.is_err());

// tests expression parsing
// if expression is not correct
let bad_expression_sql = r#"
CREATE FUNCTION bad_expression_fun(DOUBLE, DOUBLE)
RETURNS DOUBLE
RETURN $1 $3
"#;
assert!(ctx.sql(bad_expression_sql).await.is_err());

// tests bad function definition
let bad_definition_sql = r#"
CREATE FUNCTION bad_definition_fun(DOUBLE, DOUBLE)
RET BAD_TYPE
RETURN $1 + $3
"#;
assert!(ctx.sql(bad_definition_sql).await.is_err());

Ok(())
}

fn create_udf_context() -> SessionContext {
let ctx = SessionContext::new();
// register a custom UDF
Expand Down
Loading

0 comments on commit ea01e56

Please sign in to comment.