diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 418d60de3e7a..15ee50733eca 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -126,9 +126,14 @@ impl TryInto for &protobuf::LogicalPlanNode { projection = Some(column_indices); } - LogicalPlanBuilder::scan_csv(&scan.path, options, projection)? - .build() - .map_err(|e| e.into()) + LogicalPlanBuilder::scan_csv_with_name( + &scan.path, + options, + projection, + &scan.table_name, + )? + .build() + .map_err(|e| e.into()) } LogicalPlanType::ParquetScan(scan) => { let projection = match scan.projection.as_ref() { @@ -151,9 +156,14 @@ impl TryInto for &protobuf::LogicalPlanNode { Some(r?) } }; - LogicalPlanBuilder::scan_parquet(&scan.path, projection, 24)? //TODO concurrency - .build() - .map_err(|e| e.into()) + LogicalPlanBuilder::scan_parquet_with_name( + &scan.path, + projection, + 24, + &scan.table_name, + )? //TODO concurrency + .build() + .map_err(|e| e.into()) } LogicalPlanType::Sort(sort) => { let input: LogicalPlan = convert_box_required!(sort.input)?; diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 6a763420c782..19a67a504e77 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -39,3 +39,6 @@ futures = "0.3" env_logger = "^0.8" mimalloc = { version = "0.1", optional = true, default-features = false } snmalloc-rs = {version = "0.2", optional = true, features= ["cache-friendly"] } + +[dev-dependencies] +ballista-core = { path = "../ballista/rust/core" } diff --git a/benchmarks/run.sh b/benchmarks/run.sh index 21633d39c23a..8e36424da89f 100755 --- a/benchmarks/run.sh +++ b/benchmarks/run.sh @@ -20,7 +20,7 @@ set -e # This bash script is meant to be run inside the docker-compose environment. 
Check the README for instructions cd / -for query in 1 3 5 6 7 8 9 10 12 +for query in 1 3 5 6 10 12 do /tpch benchmark ballista --host ballista-scheduler --port 50050 --query $query --path /data --format tbl --iterations 1 --debug done diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index 286fe4594510..77c69f0ce524 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -573,7 +573,6 @@ mod tests { use datafusion::arrow::array::*; use datafusion::arrow::util::display::array_value_to_string; - use datafusion::logical_plan::Expr; use datafusion::logical_plan::Expr::Cast; @@ -1042,4 +1041,88 @@ mod tests { Ok(()) } + + mod ballista_round_trip { + use super::*; + use ballista_core::serde::protobuf; + use datafusion::physical_plan::ExecutionPlan; + use std::convert::TryInto; + + fn round_trip_query(n: usize) -> Result<()> { + let config = ExecutionConfig::new() + .with_concurrency(1) + .with_batch_size(10); + let mut ctx = ExecutionContext::with_config(config); + + // set tpch_data_path to dummy value and skip physical plan serde test when TPCH_DATA + // is not set. 
+ let tpch_data_path = + env::var("TPCH_DATA").unwrap_or_else(|_| "./".to_string()); + + for &table in TABLES { + let schema = get_schema(table); + let options = CsvReadOptions::new() + .schema(&schema) + .delimiter(b'|') + .has_header(false) + .file_extension(".tbl"); + let provider = CsvFile::try_new( + &format!("{}/{}.tbl", tpch_data_path, table), + options, + )?; + ctx.register_table(table, Arc::new(provider))?; + } + + // test logical plan round trip + let plan = create_logical_plan(&mut ctx, n)?; + let proto: protobuf::LogicalPlanNode = (&plan).try_into().unwrap(); + let round_trip: LogicalPlan = (&proto).try_into().unwrap(); + assert_eq!( + format!("{:?}", plan), + format!("{:?}", round_trip), + "logical plan round trip failed" + ); + + // test optimized logical plan round trip + let plan = ctx.optimize(&plan)?; + let proto: protobuf::LogicalPlanNode = (&plan).try_into().unwrap(); + let round_trip: LogicalPlan = (&proto).try_into().unwrap(); + assert_eq!( + format!("{:?}", plan), + format!("{:?}", round_trip), + "optimized logical plan round trip failed" + ); + + // test physical plan roundtrip + if env::var("TPCH_DATA").is_ok() { + let physical_plan = ctx.create_physical_plan(&plan)?; + let proto: protobuf::PhysicalPlanNode = + (physical_plan.clone()).try_into().unwrap(); + let round_trip: Arc = (&proto).try_into().unwrap(); + assert_eq!( + format!("{:?}", physical_plan), + format!("{:?}", round_trip), + "physical plan round trip failed" + ); + } + + Ok(()) + } + + macro_rules!
test_round_trip { + ($tn:ident, $query:expr) => { + #[test] + fn $tn() -> Result<()> { + round_trip_query($query) + } + }; + } + + test_round_trip!(q1, 1); + test_round_trip!(q3, 3); + test_round_trip!(q5, 5); + test_round_trip!(q6, 6); + test_round_trip!(q10, 10); + test_round_trip!(q12, 12); + } } diff --git a/datafusion/src/datasource/mod.rs b/datafusion/src/datasource/mod.rs index b46b9cc4e899..9699a997caa1 100644 --- a/datafusion/src/datasource/mod.rs +++ b/datafusion/src/datasource/mod.rs @@ -28,6 +28,7 @@ pub use self::csv::{CsvFile, CsvReadOptions}; pub use self::datasource::{TableProvider, TableType}; pub use self::memory::MemTable; +/// Source for table input data pub(crate) enum Source> { /// Path to a single file or a directory containing one of more files Path(String), diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 147f8322df5d..ced77ba6c6f6 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -118,9 +118,19 @@ impl LogicalPlanBuilder { path: &str, options: CsvReadOptions, projection: Option>, + ) -> Result { + Self::scan_csv_with_name(path, options, projection, path) + } + + /// Scan a CSV data source and register it with a given table name + pub fn scan_csv_with_name( + path: &str, + options: CsvReadOptions, + projection: Option>, + table_name: &str, ) -> Result { let provider = Arc::new(CsvFile::try_new(path, options)?); - Self::scan(path, provider, projection) + Self::scan(table_name, provider, projection) } /// Scan a Parquet data source @@ -128,9 +138,19 @@ impl LogicalPlanBuilder { path: &str, projection: Option>, max_concurrency: usize, + ) -> Result { + Self::scan_parquet_with_name(path, projection, max_concurrency, path) + } + + /// Scan a Parquet data source and register it with a given table name + pub fn scan_parquet_with_name( + path: &str, + projection: Option>, + max_concurrency: usize, + table_name: &str, ) -> Result { let provider = 
Arc::new(ParquetTable::try_new(path, max_concurrency)?); - Self::scan(path, provider, projection) + Self::scan(table_name, provider, projection) } /// Scan an empty data source, mainly used in tests