From 1ab7f073c8c6168cb2d5e3bc1058986dc553e86c Mon Sep 17 00:00:00 2001 From: Martins Purins Date: Tue, 21 Feb 2023 23:22:01 +0100 Subject: [PATCH 1/2] allow setting config extensions for TaskContext --- datafusion/common/src/config.rs | 7 +++ datafusion/core/src/execution/context.rs | 56 +++++++++++++++++++++--- 2 files changed, 56 insertions(+), 7 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 4a71bce617c4..3a0552a3498f 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -397,6 +397,13 @@ impl ConfigOptions { Self::default() } + /// Creates a new [`ConfigOptions`] with extensions set to provided value + pub fn with_extensions(extensions: Extensions) -> Self { + let mut config = Self::new(); + config.extensions = extensions; + config + } + /// Set a configuration option pub fn set(&mut self, key: &str, value: &str) -> Result<()> { let (prefix, key) = key.split_once('.').ok_or_else(|| { diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 5b39e54dbe86..c12f35813582 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -83,7 +83,7 @@ use crate::physical_plan::PhysicalPlanner; use crate::variable::{VarProvider, VarType}; use async_trait::async_trait; use chrono::{DateTime, Utc}; -use datafusion_common::{OwnedTableReference, ScalarValue}; +use datafusion_common::{config::Extensions, OwnedTableReference, ScalarValue}; use datafusion_sql::{ parser::DFParser, planner::{ContextProvider, SqlToRel}, @@ -2097,27 +2097,28 @@ pub struct TaskContext { impl TaskContext { /// Create a new task context instance - pub fn new( + pub fn try_new( task_id: String, session_id: String, task_props: HashMap, scalar_functions: HashMap>, aggregate_functions: HashMap>, runtime: Arc, - ) -> Self { - let mut config = ConfigOptions::new(); + extensions: Extensions, + ) -> Result { + let mut config = ConfigOptions::with_extensions(extensions); for (k, v) in task_props { - let _ = config.set(&k, &v); + config.set(&k, &v)?; } - Self { + Ok(Self { task_id: Some(task_id), session_id, session_config: config.into(), scalar_functions, aggregate_functions, runtime, - } + }) } /// Return the SessionConfig associated with the Task @@ -2212,6 +2213,8 @@ mod tests { use arrow::array::ArrayRef; use arrow::record_batch::RecordBatch; use async_trait::async_trait; + use datafusion_common::config::ConfigExtension; + use datafusion_common::extensions_options; use datafusion_expr::{create_udaf, create_udf, Expr, Volatility}; use datafusion_physical_expr::functions::make_scalar_function; use std::fs::File; @@ -2879,4 +2882,43 @@ mod tests { .unwrap() } } + + extensions_options! { + struct TestExtension { + value: usize, default = 42 + } + } + + impl ConfigExtension for TestExtension { + const PREFIX: &'static str = "test"; + } + + #[test] + fn task_context_extensions() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); + let task_props = HashMap::from([("test.value".to_string(), "24".to_string())]); + let mut extensions = Extensions::default(); + extensions.insert(TestExtension::default()); + + let task_context = TaskContext::try_new( + "task_id".to_string(), + "session_id".to_string(), + task_props, + HashMap::default(), + HashMap::default(), + runtime, + extensions, + )?; + + let test = task_context + .session_config() + .config_options() + .extensions + .get::(); + assert!(test.is_some()); + + assert_eq!(test.unwrap().value, 24); + + Ok(()) + } } From c9d8f7098686a5e0b9dd562353903fb4c36005b0 Mon Sep 17 00:00:00 2001 From: Martins Purins Date: Tue, 7 Mar 2023 15:18:51 +0100 Subject: [PATCH 2/2] builder like api for ConfigOptions::with_extensions --- datafusion/common/src/config.rs | 9 ++++----- datafusion/core/src/execution/context.rs | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 3a0552a3498f..5418e69a05ea 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -397,11 +397,10 @@ impl ConfigOptions { Self::default() } - /// Creates a new [`ConfigOptions`] with extensions set to provided value - pub fn with_extensions(extensions: Extensions) -> Self { - let mut config = Self::new(); - config.extensions = extensions; - config + /// Set extensions to provided value + pub fn with_extensions(mut self, extensions: Extensions) -> Self { + self.extensions = extensions; + self } /// Set a configuration option diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index c12f35813582..6eca7fb843d0 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -2106,7 +2106,7 @@ impl TaskContext { runtime: Arc, extensions: Extensions, ) -> Result { - let mut config = ConfigOptions::with_extensions(extensions); + let mut config = ConfigOptions::new().with_extensions(extensions); for (k, v) in task_props { config.set(&k, &v)?; }