Skip to content

Commit

Permalink
Address PR comments (factory interface)
Browse files Browse the repository at this point in the history
  • Loading branch information
milenkovicm committed Feb 29, 2024
1 parent a650e16 commit acdb4b5
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 40 deletions.
71 changes: 50 additions & 21 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -798,28 +798,48 @@ impl SessionContext {
}

async fn create_function(&self, stmt: CreateFunction) -> Result<DataFrame> {
let function_factory = self.state.read().function_factory.clone();
let function = {
let state = self.state.read().clone();
let function_factory = &state.function_factory;

match function_factory {
Some(f) => f.create(state.config(), stmt).await?,
_ => Err(DataFusionError::Configuration(
"Function factory has not been configured".into(),
))?,
}
};

match function_factory {
Some(f) => f.create(self.state.clone(), stmt).await?,
None => Err(DataFusionError::Configuration(
"Function factory has not been configured".into(),
))?,
match function {
RegisterFunction::Scalar(f) => {
self.state.write().register_udf(f)?;
}
RegisterFunction::Aggregate(f) => {
self.state.write().register_udaf(f)?;
}
RegisterFunction::Window(f) => {
self.state.write().register_udwf(f)?;
}
RegisterFunction::Table(name, f) => self.register_udtf(&name, f),
};

self.return_empty_dataframe()
}

async fn drop_function(&self, stmt: DropFunction) -> Result<DataFrame> {
let function_factory = self.state.read().function_factory.clone();

match function_factory {
Some(f) => f.remove(self.state.clone(), stmt).await?,
None => Err(DataFusionError::Configuration(
"Function factory has not been configured".into(),
))?,
let _function = {
let state = self.state.read().clone();
let function_factory = &state.function_factory;

match function_factory {
Some(f) => f.remove(state.config(), stmt).await?,
None => Err(DataFusionError::Configuration(
"Function factory has not been configured".into(),
))?,
}
};

// TODO: Once we have unregister UDF we need to implement it here
self.return_empty_dataframe()
}

Expand Down Expand Up @@ -1289,27 +1309,36 @@ impl QueryPlanner for DefaultQueryPlanner {
/// ```
#[async_trait]
pub trait FunctionFactory: Sync + Send {
// TODO: I don't like having RwLock Leaking here, who ever implements it
// has to depend ot `parking_lot`. I'f we expose &mut SessionState it
// may keep lock of too long.
//
// Not sure if there is better approach.
// This api holds a read lock for state
//

/// Handles creation of user defined function specified in [CreateFunction] statement
async fn create(
&self,
state: Arc<RwLock<SessionState>>,
state: &SessionConfig,
statement: CreateFunction,
) -> Result<()>;
) -> Result<RegisterFunction>;

/// Drops user defined function from [SessionState]
// Naming it `drop`` would make more sense but its already occupied in rust
async fn remove(
&self,
state: Arc<RwLock<SessionState>>,
state: &SessionConfig,
statement: DropFunction,
) -> Result<()>;
) -> Result<RegisterFunction>;
}

/// Type of function to create
pub enum RegisterFunction {
/// Scalar user defined function
Scalar(Arc<ScalarUDF>),
/// Aggregate user defined function
Aggregate(Arc<AggregateUDF>),
/// Window user defined function
Window(Arc<WindowUDF>),
/// Table user defined function
Table(String, Arc<dyn TableFunctionImpl>),
}
/// Execution context for registering data sources and executing queries.
/// See [`SessionContext`] for a higher level API.
Expand Down
40 changes: 21 additions & 19 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use arrow_array::{
};
use arrow_schema::DataType::Float64;
use arrow_schema::{DataType, Field, Schema};
use datafusion::execution::context::{FunctionFactory, SessionState};
use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
use datafusion::prelude::*;
use datafusion::{execution::registry::FunctionRegistry, test_util};
use datafusion_common::cast::as_float64_array;
Expand All @@ -34,7 +34,7 @@ use datafusion_expr::{
create_udaf, create_udf, Accumulator, ColumnarValue, CreateFunction, DropFunction,
ExprSchemable, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use parking_lot::{Mutex, RwLock};
use parking_lot::Mutex;
use rand::{thread_rng, Rng};
use std::any::Any;
use std::iter;
Expand Down Expand Up @@ -636,9 +636,9 @@ impl FunctionFactory for MockFunctionFactory {
#[allow(clippy::type_complexity, clippy::type_repetition_in_bounds)]
async fn create(
&self,
state: Arc<RwLock<SessionState>>,
_config: &SessionConfig,
statement: CreateFunction,
) -> datafusion::error::Result<()> {
) -> datafusion::error::Result<RegisterFunction> {
// this function is a mock for testing
// `CreateFunction` should be used to derive this function

Expand Down Expand Up @@ -675,22 +675,25 @@ impl FunctionFactory for MockFunctionFactory {
// it has been parsed
*self.captured_expr.lock() = statement.params.return_;

// we may need other infrastructure provided by state, for example:
// state.config().get_extension()

// register mock udf for testing
state.write().register_udf(mock_udf.into())?;
Ok(())
Ok(RegisterFunction::Scalar(Arc::new(mock_udf)))
}

async fn remove(
&self,
_state: Arc<RwLock<SessionState>>,
_config: &SessionConfig,
_statement: DropFunction,
) -> datafusion::error::Result<()> {
// at the moment state does not support unregister
// ignoring for now
Ok(())
) -> datafusion::error::Result<RegisterFunction> {
// TODO: I don't like that remove returns RegisterFunction
// we have to keep two states in FunctionFactory iml and
// SessionState
//
// It would be better to return (function_name, function type) tuple

// at the moment state does not support unregister user defined functions

Err(DataFusionError::NotImplemented(
"remove function has not been implemented".into(),
))
}
}

Expand Down Expand Up @@ -722,15 +725,14 @@ async fn create_scalar_function_from_sql_statement() {
.await
.unwrap();

// sql expression should be convert to datafusion expression
// in this case
// check if we sql expr has been converted to datafusion expr
let captured_expression = function_factory.captured_expr.lock().clone().unwrap();

// is there some better way to test this
assert_eq!("$1 + $2", captured_expression.to_string());
println!("{:?}", captured_expression);

ctx.sql("drop function better_add").await.unwrap();
// no support at the moment
// ctx.sql("drop function better_add").await.unwrap();
}

fn create_udf_context() -> SessionContext {
Expand Down

0 comments on commit acdb4b5

Please sign in to comment.