Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions datafusion/core/src/datasource/datasource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ pub trait TableProvider: Sync + Send {
/// A factory which creates [`TableProvider`]s at runtime given a URL.
///
/// For example, this can be used to create a table "on the fly"
/// from a directory of files only when that name is referenced.
/// from a directory of files only when that name is referenced.
#[async_trait]
pub trait TableProviderFactory: Sync + Send {
/// Create a TableProvider given name and url
fn create(&self, name: &str, url: &str) -> Arc<dyn TableProvider>;
async fn create(&self, name: &str, url: &str) -> Result<Arc<dyn TableProvider>>;
}
35 changes: 14 additions & 21 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ use crate::config::{
ConfigOptions, OPT_BATCH_SIZE, OPT_COALESCE_BATCHES, OPT_COALESCE_TARGET_BATCH_SIZE,
OPT_FILTER_NULL_JOIN_KEYS, OPT_OPTIMIZER_SKIP_FAILED_RULES,
};
use crate::datasource::datasource::TableProviderFactory;
use crate::datasource::file_format::file_type::{FileCompressionType, FileType};
use crate::execution::{runtime_env::RuntimeEnv, FunctionRegistry};
use crate::physical_plan::file_format::{plan_to_csv, plan_to_json, plan_to_parquet};
Expand Down Expand Up @@ -159,8 +158,6 @@ pub struct SessionContext {
pub session_start_time: DateTime<Utc>,
/// Shared session state for the session
pub state: Arc<RwLock<SessionState>>,
/// Dynamic table providers
pub table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
}

impl Default for SessionContext {
Expand Down Expand Up @@ -188,7 +185,6 @@ impl SessionContext {
session_id: state.session_id.clone(),
session_start_time: chrono::Utc::now(),
state: Arc::new(RwLock::new(state)),
table_factories: HashMap::default(),
}
}

Expand All @@ -198,19 +194,9 @@ impl SessionContext {
session_id: state.session_id.clone(),
session_start_time: chrono::Utc::now(),
state: Arc::new(RwLock::new(state)),
table_factories: HashMap::default(),
}
}

/// Register a `TableProviderFactory` for a given `file_type` identifier
pub fn register_table_factory(
&mut self,
file_type: &str,
factory: Arc<dyn TableProviderFactory>,
) {
self.table_factories.insert(file_type.to_string(), factory);
}

/// Registers the [`RecordBatch`] as the specified table name
pub fn register_batch(
&self,
Expand Down Expand Up @@ -431,13 +417,20 @@ impl SessionContext {
&self,
cmd: &CreateExternalTable,
) -> Result<Arc<DataFrame>> {
let factory = &self.table_factories.get(&cmd.file_type).ok_or_else(|| {
DataFusionError::Execution(format!(
"Unable to find factory for {}",
cmd.file_type
))
})?;
let table = (*factory).create(cmd.name.as_str(), cmd.location.as_str());
let state = self.state.read().clone();
let factory = &state
.runtime_env
.table_factories
.get(&cmd.file_type)
.ok_or_else(|| {
DataFusionError::Execution(format!(
"Unable to find factory for {}",
cmd.file_type
))
})?;
let table = (*factory)
.create(cmd.name.as_str(), cmd.location.as_str())
.await?;
self.register_table(cmd.name.as_str(), table)?;
let plan = LogicalPlanBuilder::empty(false).build()?;
Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
Expand Down
27 changes: 26 additions & 1 deletion datafusion/core/src/execution/runtime_env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ use crate::{
memory_manager::{MemoryConsumerId, MemoryManager, MemoryManagerConfig},
},
};
use std::collections::HashMap;

use crate::datasource::datasource::TableProviderFactory;
use crate::datasource::object_store::ObjectStoreRegistry;
use datafusion_common::DataFusionError;
use object_store::ObjectStore;
Expand All @@ -43,6 +45,8 @@ pub struct RuntimeEnv {
pub disk_manager: Arc<DiskManager>,
/// Object Store Registry
pub object_store_registry: Arc<ObjectStoreRegistry>,
/// Table provider factories, keyed by the file type identifier they handle
pub table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
}

impl Debug for RuntimeEnv {
Expand All @@ -58,12 +62,14 @@ impl RuntimeEnv {
memory_manager,
disk_manager,
object_store_registry,
table_factories,
} = config;

Ok(Self {
memory_manager: MemoryManager::new(memory_manager),
disk_manager: DiskManager::try_new(disk_manager)?,
object_store_registry,
table_factories,
})
}

Expand All @@ -87,7 +93,7 @@ impl RuntimeEnv {
self.memory_manager.shrink_tracker_usage(delta)
}

/// Registers a object store with scheme using a custom `ObjectStore` so that
/// Registers an object store with scheme using a custom `ObjectStore` so that
/// an external file system or object storage system could be used against this context.
///
/// Returns the `ObjectStore` previously registered for this scheme, if any
Expand All @@ -101,6 +107,14 @@ impl RuntimeEnv {
.register_store(scheme, host, object_store)
}

/// Registers additional [`TableProviderFactory`]s with this runtime environment.
///
/// Every entry of `table_factories` is merged into the existing map with
/// `HashMap::extend`, so a factory that is already registered under the
/// same key is replaced by the incoming one.
pub fn register_table_factories(
&mut self,
table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
) {
self.table_factories.extend(table_factories)
}

/// Retrieves an `ObjectStore` instance for a url
pub fn object_store(&self, url: impl AsRef<Url>) -> Result<Arc<dyn ObjectStore>> {
self.object_store_registry
Expand All @@ -124,6 +138,8 @@ pub struct RuntimeConfig {
pub memory_manager: MemoryManagerConfig,
/// ObjectStoreRegistry to get object store based on url
pub object_store_registry: Arc<ObjectStoreRegistry>,
/// Custom table factories for things like deltalake that are not part of core datafusion
pub table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
}

impl RuntimeConfig {
Expand Down Expand Up @@ -153,6 +169,15 @@ impl RuntimeConfig {
self
}

/// Customize the table provider factories
///
/// Replaces the whole `table_factories` map of this config with the given
/// one (anything set previously is discarded) and returns `self` so that
/// builder calls can be chained.
pub fn with_table_factories(
mut self,
table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
) -> Self {
self.table_factories = table_factories;
self
}

/// Specify the total memory to use while running the DataFusion
/// plan to `max_memory * memory_fraction` in bytes.
///
Expand Down
16 changes: 12 additions & 4 deletions datafusion/core/tests/sql/create_drop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

use async_trait::async_trait;
use std::any::Any;
use std::collections::HashMap;
use std::io::Write;

use datafusion::datasource::datasource::TableProviderFactory;
use datafusion::execution::context::SessionState;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion_expr::TableType;
use tempfile::TempDir;

Expand Down Expand Up @@ -398,16 +400,22 @@ impl TableProvider for TestTableProvider {

struct TestTableFactory {}

#[async_trait]
impl TableProviderFactory for TestTableFactory {
fn create(&self, _name: &str, _path: &str) -> Arc<dyn TableProvider> {
Arc::new(TestTableProvider {})
async fn create(&self, _name: &str, _url: &str) -> Result<Arc<dyn TableProvider>> {
Ok(Arc::new(TestTableProvider {}))
}
}

#[tokio::test]
async fn create_custom_table() -> Result<()> {
let mut ctx = SessionContext::new();
ctx.register_table_factory("DELTATABLE", Arc::new(TestTableFactory {}));
let mut table_factories: HashMap<String, Arc<dyn TableProviderFactory>> =
HashMap::new();
table_factories.insert("DELTATABLE".to_string(), Arc::new(TestTableFactory {}));
let cfg = RuntimeConfig::new().with_table_factories(table_factories);
let env = RuntimeEnv::new(cfg).unwrap();
let ses = SessionConfig::new();
let ctx = SessionContext::with_config_rt(ses, Arc::new(env));

let sql = "CREATE EXTERNAL TABLE dt STORED AS DELTATABLE LOCATION 's3://bucket/schema/table';";
ctx.sql(sql).await.unwrap();
Expand Down
3 changes: 2 additions & 1 deletion datafusion/proto/src/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,8 @@ impl AsLogicalPlan for LogicalPlanNode {
match create_extern_table.file_type.as_str() {
"CSV" | "JSON" | "PARQUET" | "AVRO" => {}
it => {
if !ctx.table_factories.contains_key(it) {
let env = &ctx.state.as_ref().read().runtime_env;
if !env.table_factories.contains_key(it) {
Err(DataFusionError::Internal(format!(
"No TableProvider for file type: {}",
it
Expand Down