Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions datafusion-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
// under the License.

use clap::Parser;
use datafusion::datasource::datasource::TableProviderFactory;
use datafusion::datasource::file_format::file_type::FileType;
use datafusion::datasource::listing_table_factory::ListingTableFactory;
use datafusion::datasource::object_store::ObjectStoreRegistry;
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::context::SessionConfig;
Expand All @@ -26,6 +29,7 @@ use datafusion_cli::{
exec, print_format::PrintFormat, print_options::PrintOptions, DATAFUSION_CLI_VERSION,
};
use mimalloc::MiMalloc;
use std::collections::HashMap;
use std::env;
use std::path::Path;
use std::sync::Arc;
Expand Down Expand Up @@ -93,7 +97,7 @@ pub async fn main() -> Result<()> {

if let Some(ref path) = args.data_path {
let p = Path::new(path);
env::set_current_dir(&p).unwrap();
env::set_current_dir(p).unwrap();
};

let mut session_config = SessionConfig::from_env().with_information_schema(true);
Expand All @@ -105,6 +109,7 @@ pub async fn main() -> Result<()> {
let runtime_env = create_runtime_env()?;
let mut ctx =
SessionContext::with_config_rt(session_config.clone(), Arc::new(runtime_env));
ctx.refresh_catalogs().await?;

let mut print_options = PrintOptions {
format: args.format,
Expand Down Expand Up @@ -142,11 +147,31 @@ pub async fn main() -> Result<()> {
}

fn create_runtime_env() -> Result<RuntimeEnv> {
let mut table_factories: HashMap<String, Arc<dyn TableProviderFactory>> =
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add `TableProviderFactory`s for the default file formats (CSV, Parquet, Avro, JSON).

HashMap::new();
table_factories.insert(
"csv".to_string(),
Arc::new(ListingTableFactory::new(FileType::CSV)),
);
table_factories.insert(
"parquet".to_string(),
Arc::new(ListingTableFactory::new(FileType::PARQUET)),
);
table_factories.insert(
"avro".to_string(),
Arc::new(ListingTableFactory::new(FileType::AVRO)),
);
table_factories.insert(
"json".to_string(),
Arc::new(ListingTableFactory::new(FileType::JSON)),
);

let object_store_provider = DatafusionCliObjectStoreProvider {};
let object_store_registry =
ObjectStoreRegistry::new_with_provider(Some(Arc::new(object_store_provider)));
let rn_config =
RuntimeConfig::new().with_object_store_registry(Arc::new(object_store_registry));
let rn_config = RuntimeConfig::new()
.with_object_store_registry(Arc::new(object_store_registry))
.with_table_factories(table_factories);
RuntimeEnv::new(rn_config)
}

Expand Down
8 changes: 7 additions & 1 deletion datafusion-cli/src/object_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,13 @@ mod tests {
assert!(err.to_string().contains("Generic S3 error: Missing region"));

env::set_var("AWS_REGION", "us-east-1");
assert!(provider.get_by_url(&Url::from_str(s3).unwrap()).is_ok());
let url = Url::from_str(s3).expect("Unable to parse s3 url");
let res = provider.get_by_url(&url);
let msg = match res {
Err(e) => format!("{}", e),
Ok(_) => "".to_string()
};
assert_eq!("".to_string(), msg); // Fail with error message
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I refactored this test to have a better failure message because it was failing for me locally. The fact that CI passed makes me think it's not running there because datafusion-cli is excluded from the workspace - I don't know why this is, but would propose we include it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please file a ticket to do so? Thank you!

env::remove_var("AWS_REGION");
}
}
12 changes: 6 additions & 6 deletions datafusion-cli/src/print_format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ mod tests {
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(Int32Array::from_slice(&[1, 2, 3])),
Arc::new(Int32Array::from_slice(&[4, 5, 6])),
Arc::new(Int32Array::from_slice(&[7, 8, 9])),
Arc::new(Int32Array::from_slice([1, 2, 3])),
Arc::new(Int32Array::from_slice([4, 5, 6])),
Arc::new(Int32Array::from_slice([7, 8, 9])),
],
)
.unwrap();
Expand Down Expand Up @@ -137,9 +137,9 @@ mod tests {
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(Int32Array::from_slice(&[1, 2, 3])),
Arc::new(Int32Array::from_slice(&[4, 5, 6])),
Arc::new(Int32Array::from_slice(&[7, 8, 9])),
Arc::new(Int32Array::from_slice([1, 2, 3])),
Arc::new(Int32Array::from_slice([4, 5, 6])),
Arc::new(Int32Array::from_slice([7, 8, 9])),
],
)
.unwrap();
Expand Down
23 changes: 16 additions & 7 deletions datafusion/core/src/catalog/listing_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
use crate::catalog::schema::SchemaProvider;
use crate::datasource::datasource::TableProviderFactory;
use crate::datasource::TableProvider;
use datafusion_common::DataFusionError;
use crate::execution::context::SessionState;
use datafusion_common::{context, DataFusionError};
use futures::TryStreamExt;
use itertools::Itertools;
use object_store::ObjectStore;
use std::any::Any;
use std::collections::{HashMap, HashSet};
Expand Down Expand Up @@ -72,7 +74,7 @@ impl ListingSchemaProvider {
}

/// Reload table information from ObjectStore
pub async fn refresh(&self) -> datafusion_common::Result<()> {
pub async fn refresh(&self, state: &SessionState) -> datafusion_common::Result<()> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needed to add SessionState to allow ListingTables to load their schema.

let entries: Vec<_> = self
.store
.list(Some(&self.path))
Expand Down Expand Up @@ -100,13 +102,20 @@ impl ListingSchemaProvider {
.ok_or_else(|| {
DataFusionError::Internal("Cannot parse file name!".to_string())
})?;
let table_name = table.to_str().ok_or_else(|| {
let table_name = file_name.split('.').collect_vec()[0];
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Delta tables are always folders, but `.csv` files should have their extension removed from their table name (unfortunately, the `Path` method that does this is marked unstable).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which method? Maybe we can contribute something back upstream to object_store 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://doc.rust-lang.org/std/path/struct.PathBuf.html#method.file_prefix

For some reason the github UI didn't let me respond in thread until now :/

let table_path = table.to_str().ok_or_else(|| {
DataFusionError::Internal("Cannot parse file name!".to_string())
})?;
if !self.table_exist(file_name) {
let table_name = format!("{}/{}", self.authority, table_name);
let provider = self.factory.create(table_name.as_str()).await?;
let _ = self.register_table(file_name.to_string(), provider.clone())?;
if !self.table_exist(table_name) {
let table_url = format!("{}/{}", self.authority, table_path);
let provider = self
.factory
.create(state, table_url.as_str())
.await
.map_err(|e| {
context!(format!("Could not create table for {}", table_url), e)
})?;
let _ = self.register_table(table_name.to_string(), provider.clone())?;
}
}
Ok(())
Expand Down
6 changes: 5 additions & 1 deletion datafusion/core/src/datasource/datasource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,9 @@ pub trait TableProvider: Sync + Send {
#[async_trait]
pub trait TableProviderFactory: Sync + Send {
/// Create a TableProvider with the given url
async fn create(&self, url: &str) -> Result<Arc<dyn TableProvider>>;
async fn create(
&self,
ctx: &SessionState,
url: &str,
) -> Result<Arc<dyn TableProvider>>;
}
79 changes: 79 additions & 0 deletions datafusion/core/src/datasource/listing_table_factory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Factory for creating ListingTables with default options

use crate::datasource::datasource::TableProviderFactory;
use crate::datasource::file_format::avro::AvroFormat;
use crate::datasource::file_format::csv::CsvFormat;
use crate::datasource::file_format::file_type::{FileType, GetExt};
use crate::datasource::file_format::json::JsonFormat;
use crate::datasource::file_format::parquet::ParquetFormat;
use crate::datasource::file_format::FileFormat;
use crate::datasource::listing::{
ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
};
use crate::datasource::TableProvider;
use crate::execution::context::SessionState;
use async_trait::async_trait;
use std::sync::Arc;

/// A `TableProviderFactory` that produces [`ListingTable`]s for one fixed
/// file format.
pub struct ListingTableFactory {
    /// File type that every table created by this factory is read as.
    file_type: FileType,
}

impl ListingTableFactory {
    /// Builds a `ListingTableFactory` bound to the given `file_type`.
    pub fn new(file_type: FileType) -> Self {
        ListingTableFactory { file_type }
    }
}

#[async_trait]
impl TableProviderFactory for ListingTableFactory {
async fn create(
&self,
state: &SessionState,
url: &str,
) -> datafusion_common::Result<Arc<dyn TableProvider>> {
let file_extension = self.file_type.get_ext();

let file_format: Arc<dyn FileFormat> = match self.file_type {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔 it would be really neat to somehow combine the logic in ListingTable and ListingTableFactory (or maybe datafusion-cli could just use the factory -- not sure)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

combine the logic

I think this is the duplicated fragment. I would love to combine those two and only have register_table instead of create_listing_table() and create_custom_table().

FileType::CSV => Arc::new(CsvFormat::default()),
FileType::PARQUET => Arc::new(ParquetFormat::default()),
FileType::AVRO => Arc::new(AvroFormat::default()),
FileType::JSON => Arc::new(JsonFormat::default()),
};

let options = ListingOptions {
format: file_format,
collect_stat: true,
file_extension: file_extension.to_owned(),
target_partitions: 1,
table_partition_cols: vec![],
};

let table_path = ListingTableUrl::parse(url)?;
let resolved_schema = options.infer_schema(state, &table_path).await?;
let config = ListingTableConfig::new(table_path)
.with_listing_options(options)
.with_schema(resolved_schema);
let table = ListingTable::try_new(config)?;
Ok(Arc::new(table))
}
}
1 change: 1 addition & 0 deletions datafusion/core/src/datasource/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub mod default_table_source;
pub mod empty;
pub mod file_format;
pub mod listing;
pub mod listing_table_factory;
pub mod memory;
pub mod object_store;
pub mod view;
Expand Down
60 changes: 42 additions & 18 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,26 @@ impl SessionContext {
Self::with_config(SessionConfig::new())
}

/// Finds any `ListingSchemaProvider`s and instructs them to reload their tables from "disk"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// Finds any ListSchemaProviders and instructs them to reload tables from "disk"
/// Invokes `ListingSchemaProvider::refresh()` for all registered providers

pub async fn refresh_catalogs(&self) -> Result<()> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moving this code to a common function on the context which we can use from datafusion-cli, tests, Ballista, etc.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense to me

let cat_names = self.catalog_names().clone();
for cat_name in cat_names.iter() {
let cat = self.catalog(cat_name.as_str()).ok_or_else(|| {
DataFusionError::Internal("Catalog not found!".to_string())
})?;
for schema_name in cat.schema_names() {
let schema = cat.schema(schema_name.as_str()).ok_or_else(|| {
DataFusionError::Internal("Schema not found!".to_string())
})?;
let lister = schema.as_any().downcast_ref::<ListingSchemaProvider>();
if let Some(lister) = lister {
lister.refresh(&self.state()).await?;
}
}
}
Ok(())
}

/// Creates a new session context using the provided session configuration.
pub fn with_config(config: SessionConfig) -> Self {
let runtime = Arc::new(RuntimeEnv::default());
Expand Down Expand Up @@ -486,7 +506,7 @@ impl SessionContext {
cmd.file_type
))
})?;
let table = (*factory).create(cmd.location.as_str()).await?;
let table = (*factory).create(&state, cmd.location.as_str()).await?;
self.register_table(cmd.name.as_str(), table)?;
let plan = LogicalPlanBuilder::empty(false).build()?;
Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
Expand Down Expand Up @@ -1764,7 +1784,7 @@ impl ContextProvider for SessionState {
Ok(schema) => {
let provider = schema.table(resolved_ref.table).ok_or_else(|| {
DataFusionError::Plan(format!(
"'{}.{}.{}' not found",
"table '{}.{}.{}' not found",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"not found" is a hard error message to Ctrl-F for. Adding the word "table" will hopefully make this message more likely to be found.

resolved_ref.catalog, resolved_ref.schema, resolved_ref.table
))
})?;
Expand Down Expand Up @@ -2005,11 +2025,12 @@ mod tests {
use super::*;
use crate::assert_batches_eq;
use crate::datasource::datasource::TableProviderFactory;
use crate::datasource::listing_table_factory::ListingTableFactory;
use crate::execution::context::QueryPlanner;
use crate::execution::runtime_env::RuntimeConfig;
use crate::physical_plan::expressions::AvgAccumulator;
use crate::test;
use crate::test_util::{parquet_test_data, TestTableFactory};
use crate::test_util::parquet_test_data;
use crate::variable::VarType;
use arrow::array::ArrayRef;
use arrow::datatypes::*;
Expand Down Expand Up @@ -2267,30 +2288,33 @@ mod tests {

let mut table_factories: HashMap<String, Arc<dyn TableProviderFactory>> =
HashMap::new();
table_factories.insert("test".to_string(), Arc::new(TestTableFactory {}));
let factory = Arc::new(ListingTableFactory::new(FileType::CSV));
table_factories.insert("test".to_string(), factory);
let rt_cfg = RuntimeConfig::new().with_table_factories(table_factories);
let runtime = Arc::new(RuntimeEnv::new(rt_cfg).unwrap());
let cfg = SessionConfig::new()
.set_str("datafusion.catalog.location", url.as_str())
.set_str("datafusion.catalog.type", "test");
let session_state = SessionState::with_config_rt(cfg, runtime);
let ctx = SessionContext::with_state(session_state);
ctx.refresh_catalogs().await?;

let mut table_count = 0;
for cat_name in ctx.catalog_names().iter() {
let cat = ctx.catalog(cat_name).unwrap();
for s_name in cat.schema_names().iter() {
let schema = cat.schema(s_name).unwrap();
if let Some(listing) =
schema.as_any().downcast_ref::<ListingSchemaProvider>()
{
listing.refresh().await.unwrap();
table_count = schema.table_names().len();
}
}
}
let result =
plan_and_collect(&ctx, "select c_name from default.customer limit 3;")
.await?;

let actual = arrow::util::pretty::pretty_format_batches(&result)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might also consider using assert_batches_eq here in this test

.unwrap()
.to_string();
let expected = r#"+--------------------+
| c_name |
+--------------------+
| Customer#000000002 |
| Customer#000000003 |
| Customer#000000004 |
+--------------------+"#;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The proof is in the pudding. Can't select from a table without registering it first, so this must be auto-registered.

assert_eq!(actual, expected);

assert_eq!(table_count, 8);
Ok(())
}

Expand Down
1 change: 1 addition & 0 deletions datafusion/core/src/test_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ pub struct TestTableFactory {}
impl TableProviderFactory for TestTableFactory {
async fn create(
&self,
_state: &SessionState,
url: &str,
) -> datafusion_common::Result<Arc<dyn TableProvider>> {
Ok(Arc::new(TestTableProvider {
Expand Down
6 changes: 3 additions & 3 deletions datafusion/core/tests/sql/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ async fn query_cte_incorrect() -> Result<()> {
assert!(plan.is_err());
assert_eq!(
format!("{}", plan.unwrap_err()),
"Error during planning: 'datafusion.public.t' not found"
"Error during planning: table 'datafusion.public.t' not found"
);

// forward referencing
Expand All @@ -89,7 +89,7 @@ async fn query_cte_incorrect() -> Result<()> {
assert!(plan.is_err());
assert_eq!(
format!("{}", plan.unwrap_err()),
"Error during planning: 'datafusion.public.u' not found"
"Error during planning: table 'datafusion.public.u' not found"
);

// wrapping should hide u
Expand All @@ -98,7 +98,7 @@ async fn query_cte_incorrect() -> Result<()> {
assert!(plan.is_err());
assert_eq!(
format!("{}", plan.unwrap_err()),
"Error during planning: 'datafusion.public.u' not found"
"Error during planning: table 'datafusion.public.u' not found"
);

Ok(())
Expand Down
Loading