diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 4a71bce617c4..5418e69a05ea 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -397,6 +397,12 @@ impl ConfigOptions { Self::default() } + /// Set extensions to provided value + pub fn with_extensions(mut self, extensions: Extensions) -> Self { + self.extensions = extensions; + self + } + /// Set a configuration option pub fn set(&mut self, key: &str, value: &str) -> Result<()> { let (prefix, key) = key.split_once('.').ok_or_else(|| { diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 5b39e54dbe86..6eca7fb843d0 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -83,7 +83,7 @@ use crate::physical_plan::PhysicalPlanner; use crate::variable::{VarProvider, VarType}; use async_trait::async_trait; use chrono::{DateTime, Utc}; -use datafusion_common::{OwnedTableReference, ScalarValue}; +use datafusion_common::{config::Extensions, OwnedTableReference, ScalarValue}; use datafusion_sql::{ parser::DFParser, planner::{ContextProvider, SqlToRel}, @@ -2097,27 +2097,28 @@ pub struct TaskContext { impl TaskContext { /// Create a new task context instance - pub fn new( + pub fn try_new( task_id: String, session_id: String, task_props: HashMap, scalar_functions: HashMap>, aggregate_functions: HashMap>, runtime: Arc, - ) -> Self { - let mut config = ConfigOptions::new(); + extensions: Extensions, + ) -> Result { + let mut config = ConfigOptions::new().with_extensions(extensions); for (k, v) in task_props { - let _ = config.set(&k, &v); + config.set(&k, &v)?; } - Self { + Ok(Self { task_id: Some(task_id), session_id, session_config: config.into(), scalar_functions, aggregate_functions, runtime, - } + }) } /// Return the SessionConfig associated with the Task @@ -2212,6 +2213,8 @@ mod tests { use arrow::array::ArrayRef; use arrow::record_batch::RecordBatch; use async_trait::async_trait; + use datafusion_common::config::ConfigExtension; + use datafusion_common::extensions_options; use datafusion_expr::{create_udaf, create_udf, Expr, Volatility}; use datafusion_physical_expr::functions::make_scalar_function; use std::fs::File; @@ -2879,4 +2882,43 @@ mod tests { .unwrap() } } + + extensions_options! { + struct TestExtension { + value: usize, default = 42 + } + } + + impl ConfigExtension for TestExtension { + const PREFIX: &'static str = "test"; + } + + #[test] + fn task_context_extensions() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); + let task_props = HashMap::from([("test.value".to_string(), "24".to_string())]); + let mut extensions = Extensions::default(); + extensions.insert(TestExtension::default()); + + let task_context = TaskContext::try_new( + "task_id".to_string(), + "session_id".to_string(), + task_props, + HashMap::default(), + HashMap::default(), + runtime, + extensions, + )?; + + let test = task_context + .session_config() + .config_options() + .extensions + .get::(); + assert!(test.is_some()); + + assert_eq!(test.unwrap().value, 24); + + Ok(()) + } }