From fa3604e39da2273c4b5ffaea5e6c7aa80b4ca3b6 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 27 Aug 2024 11:20:23 +0200 Subject: [PATCH 01/11] Take Sort (SortExpr) in file options Part of effort to remove `Expr::Sort`. --- .../examples/csv_sql_streaming.rs | 2 +- .../src/datasource/file_format/options.rs | 22 +++++++++++----- .../core/src/datasource/listing/table.rs | 26 +++++++++---------- .../src/datasource/listing_table_factory.rs | 3 ++- datafusion/expr/src/expr.rs | 26 +++++++++++++++++++ datafusion/proto/src/logical_plan/mod.rs | 8 ++++-- 6 files changed, 63 insertions(+), 24 deletions(-) diff --git a/datafusion-examples/examples/csv_sql_streaming.rs b/datafusion-examples/examples/csv_sql_streaming.rs index 99264bbcb486..cea42caaa075 100644 --- a/datafusion-examples/examples/csv_sql_streaming.rs +++ b/datafusion-examples/examples/csv_sql_streaming.rs @@ -38,7 +38,7 @@ async fn main() -> Result<()> { ctx.register_csv( "ordered_table", &format!("{testdata}/window_1.csv"), - CsvReadOptions::new().file_sort_order(vec![sort_expr]), + CsvReadOptions::new().file_sort_order_expr(vec![sort_expr]), ) .await?; diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index 552977baba17..52ea4d645d51 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -31,7 +31,6 @@ use crate::datasource::{ }; use crate::error::Result; use crate::execution::context::{SessionConfig, SessionState}; -use crate::logical_expr::Expr; use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion_common::config::TableOptions; @@ -41,6 +40,8 @@ use datafusion_common::{ }; use async_trait::async_trait; +use datafusion_expr::expr::sort_vec_vec_from_expr; +use datafusion_expr::{Expr, SortExpr}; /// Options that control the reading of CSV files. 
/// @@ -84,7 +85,7 @@ pub struct CsvReadOptions<'a> { /// File compression type pub file_compression_type: FileCompressionType, /// Indicates how the file is sorted - pub file_sort_order: Vec>, + pub file_sort_order: Vec>, } impl<'a> Default for CsvReadOptions<'a> { @@ -199,10 +200,17 @@ impl<'a> CsvReadOptions<'a> { } /// Configure if file has known sort order - pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { + pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { self.file_sort_order = file_sort_order; self } + + /// Configure if file has known sort order + // TODO (https://github.com/apache/datafusion/issues/12193) remove when transition is complete + pub fn file_sort_order_expr(mut self, file_sort_order: Vec>) -> Self { + self.file_sort_order = sort_vec_vec_from_expr(file_sort_order); + self + } } /// Options that control the reading of Parquet files. @@ -231,7 +239,7 @@ pub struct ParquetReadOptions<'a> { /// based on data in file. pub schema: Option<&'a Schema>, /// Indicates how the file is sorted - pub file_sort_order: Vec>, + pub file_sort_order: Vec>, } impl<'a> Default for ParquetReadOptions<'a> { @@ -278,7 +286,7 @@ impl<'a> ParquetReadOptions<'a> { } /// Configure if file has known sort order - pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { + pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { self.file_sort_order = file_sort_order; self } @@ -397,7 +405,7 @@ pub struct NdJsonReadOptions<'a> { /// Flag indicating whether this file may be unbounded (as in a FIFO file). 
pub infinite: bool, /// Indicates how the file is sorted - pub file_sort_order: Vec>, + pub file_sort_order: Vec>, } impl<'a> Default for NdJsonReadOptions<'a> { @@ -452,7 +460,7 @@ impl<'a> NdJsonReadOptions<'a> { } /// Configure if file has known sort order - pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { + pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { self.file_sort_order = file_sort_order; self } diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 89066d8234ac..f614f85b6bf6 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -33,8 +33,8 @@ use crate::datasource::{ use crate::execution::context::SessionState; use datafusion_catalog::TableProvider; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::TableType; use datafusion_expr::{utils::conjunction, Expr, TableProviderFilterPushDown}; +use datafusion_expr::{SortExpr, TableType}; use datafusion_physical_plan::{empty::EmptyExec, ExecutionPlan, Statistics}; use arrow::datatypes::{DataType, Field, SchemaBuilder, SchemaRef}; @@ -51,6 +51,7 @@ use datafusion_physical_expr::{ use async_trait::async_trait; use datafusion_catalog::Session; +use datafusion_expr::expr::sort_vec_vec_to_expr; use futures::{future, stream, StreamExt, TryStreamExt}; use itertools::Itertools; use object_store::ObjectStore; @@ -222,7 +223,7 @@ pub struct ListingOptions { /// ordering (encapsulated by a `Vec`). If there aren't /// multiple equivalent orderings, the outer `Vec` will have a /// single element. 
- pub file_sort_order: Vec>, + pub file_sort_order: Vec>, } impl ListingOptions { @@ -385,7 +386,7 @@ impl ListingOptions { /// /// assert_eq!(listing_options.file_sort_order, file_sort_order); /// ``` - pub fn with_file_sort_order(mut self, file_sort_order: Vec>) -> Self { + pub fn with_file_sort_order(mut self, file_sort_order: Vec>) -> Self { self.file_sort_order = file_sort_order; self } @@ -713,7 +714,10 @@ impl ListingTable { /// If file_sort_order is specified, creates the appropriate physical expressions fn try_create_output_ordering(&self) -> Result> { - create_ordering(&self.table_schema, &self.options.file_sort_order) + create_ordering( + &self.table_schema, + &sort_vec_vec_to_expr(self.options.file_sort_order.clone()), + ) } } @@ -909,8 +913,7 @@ impl TableProvider for ListingTable { keep_partition_by_columns, }; - let unsorted: Vec> = vec![]; - let order_requirements = if self.options().file_sort_order != unsorted { + let order_requirements = if !self.options().file_sort_order.is_empty() { // Multiple sort orders in outer vec are equivalent, so we pass only the first one let ordering = self .try_create_output_ordering()? 
@@ -1065,6 +1068,7 @@ mod tests { use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_plan::ExecutionPlanProperties; + use datafusion_expr::expr::sort_vec_vec_from_expr; use tempfile::TempDir; #[tokio::test] @@ -1155,16 +1159,10 @@ mod tests { use crate::datasource::file_format::parquet::ParquetFormat; use datafusion_physical_plan::expressions::col as physical_col; - use std::ops::Add; // (file_sort_order, expected_result) let cases = vec![ (vec![], Ok(vec![])), - // not a sort expr - ( - vec![vec![col("string_col")]], - Err("Expected Expr::Sort in output_ordering, but got string_col"), - ), // sort expr, but non column ( vec![vec![ @@ -1209,7 +1207,9 @@ mod tests { ]; for (file_sort_order, expected_result) in cases { - let options = options.clone().with_file_sort_order(file_sort_order); + let options = options + .clone() + .with_file_sort_order(sort_vec_vec_from_expr(file_sort_order)); let config = ListingTableConfig::new(table_path.clone()) .with_listing_options(options) diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 591a19aab49b..f2f37c391aba 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -33,6 +33,7 @@ use datafusion_expr::CreateExternalTable; use async_trait::async_trait; use datafusion_catalog::Session; +use datafusion_expr::expr::sort_vec_vec_from_expr; /// A `TableProviderFactory` capable of creating new `ListingTable`s #[derive(Debug, Default)] @@ -114,7 +115,7 @@ impl TableProviderFactory for ListingTableFactory { .with_file_extension(file_extension) .with_target_partitions(state.config().target_partitions()) .with_table_partition_cols(table_partition_cols) - .with_file_sort_order(cmd.order_exprs.clone()); + .with_file_sort_order(sort_vec_vec_from_expr(cmd.order_exprs.clone())); options .validate_partitions(session_state, &table_path) diff --git 
a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 85ba80396c8e..0dc6abee20ef 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -633,6 +633,32 @@ impl Sort { } } +// TODO (https://github.com/apache/datafusion/issues/12193) remove when transition is complete +pub fn sort_vec_to_expr(sorts: Vec) -> Vec { + sorts.into_iter().map(Expr::Sort).collect() +} + +// TODO (https://github.com/apache/datafusion/issues/12193) remove when transition is complete +pub fn sort_vec_vec_to_expr(sorts: Vec>) -> Vec> { + sorts.into_iter().map(sort_vec_to_expr).collect() +} + +// TODO (https://github.com/apache/datafusion/issues/12193) remove when transition is complete +pub fn sort_vec_from_expr(exprs: Vec) -> Vec { + exprs + .into_iter() + .map(|expr| match expr { + Expr::Sort(s) => s, + _ => panic!("Expression must be a Expr::Sort: {}", expr), + }) + .collect() +} + +// TODO (https://github.com/apache/datafusion/issues/12193) remove when transition is complete +pub fn sort_vec_vec_from_expr(exprs: Vec>) -> Vec> { + exprs.into_iter().map(sort_vec_from_expr).collect() +} + /// Aggregate function /// /// See also [`ExprFunctionExt`] to set these fields on `Expr` diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 67977b1795a6..6a4be9486256 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -67,6 +67,7 @@ use datafusion_expr::{ use datafusion_expr::{AggregateUDF, Unnest}; use self::to_proto::{serialize_expr, serialize_exprs}; +use datafusion_expr::expr::{sort_vec_to_expr, sort_vec_vec_from_expr}; use prost::bytes::BufMut; use prost::Message; @@ -414,7 +415,7 @@ impl AsLogicalPlan for LogicalPlanNode { ) .with_collect_stat(scan.collect_stat) .with_target_partitions(scan.target_partitions as usize) - .with_file_sort_order(all_sort_orders); + .with_file_sort_order(sort_vec_vec_from_expr(all_sort_orders)); let config = 
ListingTableConfig::new_with_multi_paths(table_paths.clone()) @@ -984,7 +985,10 @@ impl AsLogicalPlan for LogicalPlanNode { let mut exprs_vec: Vec = vec![]; for order in &options.file_sort_order { let expr_vec = LogicalExprNodeCollection { - logical_expr_nodes: serialize_exprs(order, extension_codec)?, + logical_expr_nodes: serialize_exprs( + &sort_vec_to_expr(order.clone()), + extension_codec, + )?, }; exprs_vec.push(expr_vec); } From 48578d38ff02701840ffda8a74311f0eec217f87 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 27 Aug 2024 13:46:37 +0200 Subject: [PATCH 02/11] Return Sort from Expr.Sort Part of effort to remove `Expr::Sort`. --- datafusion-examples/examples/advanced_udwf.rs | 2 +- .../examples/csv_sql_streaming.rs | 2 +- datafusion-examples/examples/expr_api.rs | 2 +- .../examples/file_stream_provider.rs | 2 +- .../examples/parse_sql_expr.rs | 2 +- datafusion-examples/examples/simple_udwf.rs | 2 +- datafusion/core/src/dataframe/mod.rs | 53 ++++++++++--------- .../src/datasource/file_format/options.rs | 10 +--- .../core/src/datasource/listing/table.rs | 6 +-- .../physical_plan/file_scan_config.rs | 16 +++--- datafusion/core/src/physical_planner.rs | 2 +- datafusion/core/tests/dataframe/mod.rs | 4 +- datafusion/core/tests/expr_api/mod.rs | 8 +-- datafusion/core/tests/fifo/mod.rs | 2 +- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 2 +- .../core/tests/fuzz_cases/limit_fuzz.rs | 6 +-- datafusion/core/tests/sql/joins.rs | 4 +- datafusion/expr/src/expr.rs | 9 +++- datafusion/expr/src/expr_rewriter/order_by.rs | 2 +- datafusion/expr/src/utils.rs | 2 +- datafusion/expr/src/window_frame.rs | 2 +- .../src/analyzer/count_wildcard_rule.rs | 2 +- .../src/eliminate_duplicated_expr.rs | 3 +- .../src/single_distinct_to_groupby.rs | 4 +- .../tests/cases/roundtrip_logical_plan.rs | 14 ++--- datafusion/sql/src/unparser/expr.rs | 8 +-- datafusion/sql/tests/sql_integration.rs | 5 +- 27 files changed, 90 insertions(+), 86 deletions(-) diff --git 
a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index ec0318a561b9..e39e6d2bac89 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -219,7 +219,7 @@ async fn main() -> Result<()> { let window_expr = smooth_it .call(vec![col("speed")]) // smooth_it(speed) .partition_by(vec![col("car")]) // PARTITION BY car - .order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC + .order_by(vec![col("time").sort(true, true).to_expr()]) // ORDER BY time ASC .window_frame(WindowFrame::new(None)) .build()?; let df = ctx.table("cars").await?.window(vec![window_expr])?; diff --git a/datafusion-examples/examples/csv_sql_streaming.rs b/datafusion-examples/examples/csv_sql_streaming.rs index cea42caaa075..99264bbcb486 100644 --- a/datafusion-examples/examples/csv_sql_streaming.rs +++ b/datafusion-examples/examples/csv_sql_streaming.rs @@ -38,7 +38,7 @@ async fn main() -> Result<()> { ctx.register_csv( "ordered_table", &format!("{testdata}/window_1.csv"), - CsvReadOptions::new().file_sort_order_expr(vec![sort_expr]), + CsvReadOptions::new().file_sort_order(vec![sort_expr]), ) .await?; diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 0eb823302acf..a30ec13463e4 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -99,7 +99,7 @@ fn expr_fn_demo() -> Result<()> { // such as `FIRST_VALUE(price FILTER quantity > 100 ORDER BY ts ) let agg = first_value .call(vec![col("price")]) - .order_by(vec![col("ts").sort(false, false)]) + .order_by(vec![col("ts").sort(false, false).to_expr()]) .filter(col("quantity").gt(lit(100))) .build()?; // build the aggregate assert_eq!( diff --git a/datafusion-examples/examples/file_stream_provider.rs b/datafusion-examples/examples/file_stream_provider.rs index b8549bd6b6e6..5c4f032adaec 100644 --- 
a/datafusion-examples/examples/file_stream_provider.rs +++ b/datafusion-examples/examples/file_stream_provider.rs @@ -157,7 +157,7 @@ mod non_windows { ])); // Specify the ordering: - let order = vec![vec![datafusion_expr::col("a1").sort(true, false)]]; + let order = vec![vec![datafusion_expr::col("a1").sort(true, false).to_expr()]]; let provider = fifo_table(schema.clone(), fifo_path, order.clone()); ctx.register_table("fifo", provider)?; diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/parse_sql_expr.rs index e23e5accae39..3fbb0637bf3b 100644 --- a/datafusion-examples/examples/parse_sql_expr.rs +++ b/datafusion-examples/examples/parse_sql_expr.rs @@ -114,7 +114,7 @@ async fn query_parquet_demo() -> Result<()> { )? // Directly parsing the SQL text into a sort expression is not supported yet, so // construct it programmatically - .sort(vec![col("double_col").sort(false, false)])? + .sort(vec![col("double_col").sort(false, false).to_expr()])? .limit(0, Some(1))?; let result = df.collect().await?; diff --git a/datafusion-examples/examples/simple_udwf.rs b/datafusion-examples/examples/simple_udwf.rs index 22dfbbbf0c3a..c4a6a98b5dff 100644 --- a/datafusion-examples/examples/simple_udwf.rs +++ b/datafusion-examples/examples/simple_udwf.rs @@ -121,7 +121,7 @@ async fn main() -> Result<()> { let window_expr = smooth_it .call(vec![col("speed")]) // smooth_it(speed) .partition_by(vec![col("car")]) // PARTITION BY car - .order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC + .order_by(vec![col("time").sort(true, true).to_expr()]) // ORDER BY time ASC .window_frame(WindowFrame::new(None)) .build()?; let df = ctx.table("cars").await?.window(vec![window_expr])?; diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index c516c7985d54..794da5c98cd7 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -791,8 +791,8 @@ impl DataFrame { /// let ctx = 
SessionContext::new(); /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; /// let df = df.sort(vec![ - /// col("a").sort(true, true), // a ASC, nulls first - /// col("b").sort(false, false), // b DESC, nulls last + /// col("a").sort(true, true).to_expr(), // a ASC, nulls first + /// col("b").sort(false, false).to_expr(), // b DESC, nulls last /// ])?; /// # Ok(()) /// # } @@ -1319,7 +1319,7 @@ impl DataFrame { /// let ctx = SessionContext::new(); /// // Sort the data by column "b" and write it to a new location /// ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await? - /// .sort(vec![col("b").sort(true, true)])? // sort by b asc, nulls first + /// .sort(vec![col("b").sort(true, true).to_expr()])? // sort by b asc, nulls first /// .write_csv( /// "output.csv", /// DataFrameWriteOptions::new(), @@ -1379,7 +1379,7 @@ impl DataFrame { /// let ctx = SessionContext::new(); /// // Sort the data by column "b" and write it to a new location /// ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await? - /// .sort(vec![col("b").sort(true, true)])? // sort by b asc, nulls first + /// .sort(vec![col("b").sort(true, true).to_expr()])? 
// sort by b asc, nulls first /// .write_json( /// "output.json", /// DataFrameWriteOptions::new(), @@ -2403,7 +2403,10 @@ mod tests { Expr::WindowFunction(w) .null_treatment(NullTreatment::IgnoreNulls) - .order_by(vec![col("c2").sort(true, true), col("c3").sort(true, true)]) + .order_by(vec![ + col("c2").sort(true, true).to_expr(), + col("c3").sort(true, true).to_expr(), + ]) .window_frame(WindowFrame::new_bounds( WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::UInt64(None)), @@ -2493,7 +2496,7 @@ mod tests { .unwrap() .distinct() .unwrap() - .sort(vec![col("c1").sort(true, true)]) + .sort(vec![col("c1").sort(true, true).to_expr()]) .unwrap(); let df_results = plan.clone().collect().await?; @@ -2524,7 +2527,7 @@ mod tests { .distinct() .unwrap() // try to sort on some value not present in input to distinct - .sort(vec![col("c2").sort(true, true)]) + .sort(vec![col("c2").sort(true, true).to_expr()]) .unwrap_err(); assert_eq!(err.strip_backtrace(), "Error during planning: For SELECT DISTINCT, ORDER BY expressions c2 must appear in select list"); @@ -2571,10 +2574,10 @@ mod tests { .distinct_on( vec![col("c1")], vec![col("c1")], - Some(vec![col("c1").sort(true, true)]), + Some(vec![col("c1").sort(true, true).to_expr()]), ) .unwrap() - .sort(vec![col("c1").sort(true, true)]) + .sort(vec![col("c1").sort(true, true).to_expr()]) .unwrap(); let df_results = plan.clone().collect().await?; @@ -2605,11 +2608,11 @@ mod tests { .distinct_on( vec![col("c1")], vec![col("c1")], - Some(vec![col("c1").sort(true, true)]), + Some(vec![col("c1").sort(true, true).to_expr()]), ) .unwrap() // try to sort on some value not present in input to distinct - .sort(vec![col("c2").sort(true, true)]) + .sort(vec![col("c2").sort(true, true).to_expr()]) .unwrap_err(); assert_eq!(err.strip_backtrace(), "Error during planning: For SELECT DISTINCT, ORDER BY expressions c2 must appear in select list"); @@ -3015,7 +3018,7 @@ mod tests { )? 
.sort(vec![ // make the test deterministic - col("t1.c1").sort(true, true), + col("t1.c1").sort(true, true).to_expr(), ])? .limit(0, Some(1))?; @@ -3092,7 +3095,7 @@ mod tests { )? .sort(vec![ // make the test deterministic - col("t1.c1").sort(true, true), + col("t1.c1").sort(true, true).to_expr(), ])? .limit(0, Some(1))?; @@ -3125,9 +3128,9 @@ mod tests { .filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))? .sort(vec![ // make the test deterministic - col("c1").sort(true, true), - col("c2").sort(true, true), - col("c3").sort(true, true), + col("c1").sort(true, true).to_expr(), + col("c2").sort(true, true).to_expr(), + col("c3").sort(true, true).to_expr(), ])? .limit(0, Some(1))? .with_column("sum", col("c2") + col("c3"))?; @@ -3205,12 +3208,12 @@ mod tests { )? .sort(vec![ // make the test deterministic - col("t1.c1").sort(true, true), - col("t1.c2").sort(true, true), - col("t1.c3").sort(true, true), - col("t2.c1").sort(true, true), - col("t2.c2").sort(true, true), - col("t2.c3").sort(true, true), + col("t1.c1").sort(true, true).to_expr(), + col("t1.c2").sort(true, true).to_expr(), + col("t1.c3").sort(true, true).to_expr(), + col("t2.c1").sort(true, true).to_expr(), + col("t2.c2").sort(true, true).to_expr(), + col("t2.c3").sort(true, true).to_expr(), ])? .limit(0, Some(1))?; @@ -3283,9 +3286,9 @@ mod tests { .limit(0, Some(1))? .sort(vec![ // make the test deterministic - col("c1").sort(true, true), - col("c2").sort(true, true), - col("c3").sort(true, true), + col("c1").sort(true, true).to_expr(), + col("c2").sort(true, true).to_expr(), + col("c3").sort(true, true).to_expr(), ])? 
.select_columns(&["c1"])?; diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index 52ea4d645d51..db90262edbf8 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -40,8 +40,7 @@ use datafusion_common::{ }; use async_trait::async_trait; -use datafusion_expr::expr::sort_vec_vec_from_expr; -use datafusion_expr::{Expr, SortExpr}; +use datafusion_expr::SortExpr; /// Options that control the reading of CSV files. /// @@ -204,13 +203,6 @@ impl<'a> CsvReadOptions<'a> { self.file_sort_order = file_sort_order; self } - - /// Configure if file has known sort order - // TODO (https://github.com/apache/datafusion/issues/12193) remove when transition is complete - pub fn file_sort_order_expr(mut self, file_sort_order: Vec>) -> Self { - self.file_sort_order = sort_vec_vec_from_expr(file_sort_order); - self - } } /// Options that control the reading of Parquet files. 
diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index f614f85b6bf6..cccad1318ad4 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -1207,9 +1207,9 @@ mod tests { ]; for (file_sort_order, expected_result) in cases { - let options = options - .clone() - .with_file_sort_order(sort_vec_vec_from_expr(file_sort_order)); + let options = options.clone().with_file_sort_order(sort_vec_vec_from_expr( + sort_vec_vec_to_expr(file_sort_order), + )); let config = ListingTableConfig::new(table_path.clone()) .with_listing_options(options) diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index 34fb6226c1a2..ba6d53308dd7 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -997,7 +997,7 @@ mod tests { File::new("1", "2023-01-01", vec![Some((0.50, 1.00))]), File::new("2", "2023-01-02", vec![Some((0.00, 1.00))]), ], - sort: vec![col("value").sort(true, false)], + sort: vec![col("value").sort(true, false).to_expr()], expected_result: Ok(vec![vec!["0", "1"], vec!["2"]]), }, // same input but file '2' is in the middle @@ -1014,7 +1014,7 @@ mod tests { File::new("2", "2023-01-02", vec![Some((0.00, 1.00))]), File::new("1", "2023-01-01", vec![Some((0.50, 1.00))]), ], - sort: vec![col("value").sort(true, false)], + sort: vec![col("value").sort(true, false).to_expr()], expected_result: Ok(vec![vec!["0", "1"], vec!["2"]]), }, TestCase { @@ -1029,7 +1029,7 @@ mod tests { File::new("1", "2023-01-01", vec![Some((0.50, 1.00))]), File::new("2", "2023-01-02", vec![Some((0.00, 1.00))]), ], - sort: vec![col("value").sort(false, true)], + sort: vec![col("value").sort(false, true).to_expr()], expected_result: Ok(vec![vec!["1", "0"], vec!["2"]]), }, // reject nullable sort columns @@ 
-1045,7 +1045,7 @@ mod tests { File::new("1", "2023-01-01", vec![Some((0.50, 1.00))]), File::new("2", "2023-01-02", vec![Some((0.00, 1.00))]), ], - sort: vec![col("value").sort(true, false)], + sort: vec![col("value").sort(true, false).to_expr()], expected_result: Err("construct min/max statistics for split_groups_by_statistics\ncaused by\nbuild min rows\ncaused by\ncreate sorting columns\ncaused by\nError during planning: cannot sort by nullable column") }, TestCase { @@ -1060,7 +1060,7 @@ mod tests { File::new("1", "2023-01-01", vec![Some((0.50, 0.99))]), File::new("2", "2023-01-02", vec![Some((1.00, 1.49))]), ], - sort: vec![col("value").sort(true, false)], + sort: vec![col("value").sort(true, false).to_expr()], expected_result: Ok(vec![vec!["0", "1", "2"]]), }, TestCase { @@ -1075,7 +1075,7 @@ mod tests { File::new("1", "2023-01-01", vec![Some((0.00, 0.49))]), File::new("2", "2023-01-02", vec![Some((0.00, 0.49))]), ], - sort: vec![col("value").sort(true, false)], + sort: vec![col("value").sort(true, false).to_expr()], expected_result: Ok(vec![vec!["0"], vec!["1"], vec!["2"]]), }, TestCase { @@ -1086,7 +1086,7 @@ mod tests { false, )]), files: vec![], - sort: vec![col("value").sort(true, false)], + sort: vec![col("value").sort(true, false).to_expr()], expected_result: Ok(vec![]), }, TestCase { @@ -1101,7 +1101,7 @@ mod tests { File::new("1", "2023-01-01", vec![Some((0.00, 0.49))]), File::new("2", "2023-01-02", vec![None]), ], - sort: vec![col("value").sort(true, false)], + sort: vec![col("value").sort(true, false).to_expr()], expected_result: Err("construct min/max statistics for split_groups_by_statistics\ncaused by\ncollect min/max values\ncaused by\nget min/max for column: 'value'\ncaused by\nError during planning: statistics not found"), }, ]; diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 9501d3c6bbbb..5a7a5497e4d7 100644 --- a/datafusion/core/src/physical_planner.rs +++ 
b/datafusion/core/src/physical_planner.rs @@ -2030,7 +2030,7 @@ mod tests { .filter(col("c7").lt(lit(5_u8)))? .project(vec![col("c1"), col("c2")])? .aggregate(vec![col("c1")], vec![sum(col("c2"))])? - .sort(vec![col("c1").sort(true, true)])? + .sort(vec![col("c1").sort(true, true).to_expr()])? .limit(3, Some(10))? .build()?; diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 86cacbaa06d8..7e92eff1838f 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -75,7 +75,7 @@ async fn test_count_wildcard_on_sort() -> Result<()> { .table("t1") .await? .aggregate(vec![col("b")], vec![count(wildcard())])? - .sort(vec![count(wildcard()).sort(true, false)])? + .sort(vec![count(wildcard()).sort(true, false).to_expr()])? .explain(false, false)? .collect() .await?; @@ -452,7 +452,7 @@ async fn sort_on_ambiguous_column() -> Result<()> { &["a"], None, )? - .sort(vec![col("b").sort(true, true)]) + .sort(vec![col("b").sort(true, true).to_expr()]) .unwrap_err(); let expected = "Schema error: Ambiguous reference to unqualified field b"; diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index 051d65652633..e5d2bf16da0b 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -189,14 +189,14 @@ async fn test_aggregate_ext_order_by() { // ORDER BY id ASC let agg_asc = agg .clone() - .order_by(vec![col("id").sort(true, true)]) + .order_by(vec![col("id").sort(true, true).to_expr()]) .build() .unwrap() .alias("asc"); // ORDER BY id DESC let agg_desc = agg - .order_by(vec![col("id").sort(false, true)]) + .order_by(vec![col("id").sort(false, true).to_expr()]) .build() .unwrap() .alias("desc"); @@ -230,7 +230,7 @@ async fn test_aggregate_ext_order_by() { async fn test_aggregate_ext_filter() { let agg = first_value_udaf() .call(vec![col("i")]) - .order_by(vec![col("i").sort(true, true)]) + 
.order_by(vec![col("i").sort(true, true).to_expr()]) .filter(col("i").is_not_null()) .build() .unwrap() @@ -277,7 +277,7 @@ async fn test_aggregate_ext_distinct() { async fn test_aggregate_ext_null_treatment() { let agg = first_value_udaf() .call(vec![col("i")]) - .order_by(vec![col("i").sort(true, true)]); + .order_by(vec![col("i").sort(true, true).to_expr()]); let agg_respect = agg .clone() diff --git a/datafusion/core/tests/fifo/mod.rs b/datafusion/core/tests/fifo/mod.rs index 6efbb9b029de..5ba3104b2cb5 100644 --- a/datafusion/core/tests/fifo/mod.rs +++ b/datafusion/core/tests/fifo/mod.rs @@ -247,7 +247,7 @@ mod unix_test { ])); // Specify the ordering: - let order = vec![vec![datafusion_expr::col("a1").sort(true, false)]]; + let order = vec![vec![datafusion_expr::col("a1").sort(true, false).to_expr()]]; // Set unbounded sorted files read configuration let provider = fifo_table(schema.clone(), left_fifo.clone(), order.clone()); diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 62e9be63983c..7cd3018bab54 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -292,7 +292,7 @@ async fn group_by_string_test( let provider = MemTable::try_new(schema.clone(), vec![input]).unwrap(); let provider = if sorted { - let sort_expr = datafusion::prelude::col("a").sort(true, true); + let sort_expr = datafusion::prelude::col("a").sort(true, true).to_expr(); provider.with_sort_order(vec![vec![sort_expr]]) } else { provider diff --git a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs index 9889ce2ae562..aa3d8c6f6933 100644 --- a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs @@ -229,12 +229,12 @@ impl SortedData { fn sort_expr(&self) -> Vec { match self { Self::I32 { .. } | Self::F64 { .. } | Self::Str { .. 
} => { - vec![datafusion_expr::col("x").sort(true, true)] + vec![datafusion_expr::col("x").sort(true, true).to_expr()] } Self::I64Str { .. } => { vec![ - datafusion_expr::col("x").sort(true, true), - datafusion_expr::col("y").sort(true, true), + datafusion_expr::col("x").sort(true, true).to_expr(), + datafusion_expr::col("y").sort(true, true).to_expr(), ] } } diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index addabc8a3612..07c786bf57e4 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -38,7 +38,7 @@ async fn join_change_in_planner() -> Result<()> { .map(|e| { let ascending = true; let nulls_first = false; - e.sort(ascending, nulls_first) + e.sort(ascending, nulls_first).to_expr() }) .collect::>()]; register_unbounded_file_with_ordering( @@ -106,7 +106,7 @@ async fn join_no_order_on_filter() -> Result<()> { .map(|e| { let ascending = true; let nulls_first = false; - e.sort(ascending, nulls_first) + e.sort(ascending, nulls_first).to_expr() }) .collect::>()]; register_unbounded_file_with_ordering( diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 0dc6abee20ef..70db2b95bf09 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -631,6 +631,11 @@ impl Sort { nulls_first: !self.nulls_first, } } + + // TODO (https://github.com/apache/datafusion/issues/12193) remove when transition is complete + pub fn to_expr(self) -> Expr { + Expr::Sort(self) + } } // TODO (https://github.com/apache/datafusion/issues/12193) remove when transition is complete @@ -1404,8 +1409,8 @@ impl Expr { /// # use datafusion_expr::col; /// let sort_expr = col("foo").sort(true, true); // SORT ASC NULLS_FIRST /// ``` - pub fn sort(self, asc: bool, nulls_first: bool) -> Expr { - Expr::Sort(Sort::new(Box::new(self), asc, nulls_first)) + pub fn sort(self, asc: bool, nulls_first: bool) -> Sort { + Sort::new(Box::new(self), asc, nulls_first) } /// Return `IsTrue(Box(self))` 
diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index bbb855801c3e..19ab429c4f99 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -335,6 +335,6 @@ mod test { fn sort(expr: Expr) -> Expr { let asc = true; let nulls_first = true; - expr.sort(asc, nulls_first) + expr.sort(asc, nulls_first).to_expr() } } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index a01d5ef8973a..60da38a17b84 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -483,7 +483,7 @@ pub fn generate_sort_key( partition_by.iter().for_each(|e| { // By default, create sort key with ASC is true and NULLS LAST to be consistent with // PostgreSQL's rule: https://www.postgresql.org/docs/current/queries-order.html - let e = e.clone().sort(true, false); + let e = e.clone().sort(true, false).to_expr(); if let Some(pos) = normalized_order_by_keys.iter().position(|key| key.eq(&e)) { let order_by_key = &order_by[pos]; if !final_sort_keys.contains(order_by_key) { diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 0e1d917419f8..38642c255a27 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -259,7 +259,7 @@ impl WindowFrame { // ORDER BY clause is present but has more than one column, // it is unchanged. Note that this follows PostgreSQL behavior. 
if order_by.is_empty() { - order_by.push(lit(1u64).sort(true, false)); + order_by.push(lit(1u64).sort(true, false).to_expr()); } } WindowFrameUnits::Range if order_by.len() != 1 => { diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index e114efb99960..8b746a7a8381 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -130,7 +130,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("b")], vec![count(wildcard())])? .project(vec![count(wildcard())])? - .sort(vec![count(wildcard()).sort(true, false)])? + .sort(vec![count(wildcard()).sort(true, false).to_expr()])? .build()?; let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64]\ \n Projection: count(*) [count(*):Int64]\ diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index e9d091d52b00..50a7fec4db8d 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -131,6 +131,7 @@ impl OptimizerRule for EliminateDuplicatedExpr { mod tests { use super::*; use crate::test::*; + use datafusion_expr::expr::sort_vec_to_expr; use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder}; use std::sync::Arc; @@ -165,7 +166,7 @@ mod tests { col("b").sort(false, true), ]; let plan = LogicalPlanBuilder::from(table_scan) - .sort(sort_exprs)? + .sort(sort_vec_to_expr(sort_exprs))? .limit(5, Some(10))? 
.build()?; let expected = "Limit: skip=5, fetch=10\ diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 30cae17eaf9f..d82ee7d4dbbc 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -645,7 +645,7 @@ mod tests { let expr = count_udaf() .call(vec![col("a")]) .distinct() - .order_by(vec![col("a").sort(true, false)]) + .order_by(vec![col("a").sort(true, false).to_expr()]) .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? @@ -666,7 +666,7 @@ mod tests { .call(vec![col("a")]) .distinct() .filter(col("a").gt(lit(5))) - .order_by(vec![col("a").sort(true, false)]) + .order_by(vec![col("a").sort(true, false).to_expr()]) .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 4f58185798f7..f6f9702f6cc1 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -871,7 +871,7 @@ async fn roundtrip_expr_api() -> Result<()> { count(lit(1)), count_distinct(lit(1)), first_value(lit(1), None), - first_value(lit(1), Some(vec![lit(2).sort(true, true)])), + first_value(lit(1), Some(vec![lit(2).sort(true, true).to_expr()])), avg(lit(1.5)), covar_samp(lit(1.5), lit(2.2)), covar_pop(lit(1.5), lit(2.2)), @@ -2249,7 +2249,7 @@ fn roundtrip_window() { vec![], )) .partition_by(vec![col("col1")]) - .order_by(vec![col("col2").sort(true, false)]) + .order_by(vec![col("col2").sort(true, false).to_expr()]) .window_frame(WindowFrame::new(Some(false))) .build() .unwrap(); @@ -2262,7 +2262,7 @@ fn roundtrip_window() { vec![], )) .partition_by(vec![col("col1")]) - .order_by(vec![col("col2").sort(false, true)]) + 
.order_by(vec![col("col2").sort(false, true).to_expr()]) .window_frame(WindowFrame::new(Some(false))) .build() .unwrap(); @@ -2281,7 +2281,7 @@ fn roundtrip_window() { vec![], )) .partition_by(vec![col("col1")]) - .order_by(vec![col("col2").sort(false, false)]) + .order_by(vec![col("col2").sort(false, false).to_expr()]) .window_frame(range_number_frame) .build() .unwrap(); @@ -2298,7 +2298,7 @@ fn roundtrip_window() { vec![col("col1")], )) .partition_by(vec![col("col1")]) - .order_by(vec![col("col2").sort(true, true)]) + .order_by(vec![col("col2").sort(true, true).to_expr()]) .window_frame(row_number_frame.clone()) .build() .unwrap(); @@ -2348,7 +2348,7 @@ fn roundtrip_window() { vec![col("col1")], )) .partition_by(vec![col("col1")]) - .order_by(vec![col("col2").sort(true, true)]) + .order_by(vec![col("col2").sort(true, true).to_expr()]) .window_frame(row_number_frame.clone()) .build() .unwrap(); @@ -2425,7 +2425,7 @@ fn roundtrip_window() { vec![col("col1")], )) .partition_by(vec![col("col1")]) - .order_by(vec![col("col2").sort(true, true)]) + .order_by(vec![col("col2").sort(true, true).to_expr()]) .window_frame(row_number_frame.clone()) .build() .unwrap(); diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 9ce627aecc76..a59b64723730 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1527,7 +1527,7 @@ mod tests { case, col, cube, exists, grouping_set, interval_datetime_lit, interval_year_month_lit, lit, not, not_exists, out_ref_col, placeholder, rollup, table_scan, try_cast, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, - Signature, Volatility, WindowFrame, WindowFunctionDefinition, + Signature, SortExpr, Volatility, WindowFrame, WindowFunctionDefinition, }; use datafusion_expr::{interval_month_day_nano_lit, ExprFunctionExt}; use datafusion_functions_aggregate::count::count_udaf; @@ -1945,7 +1945,7 @@ mod tests { fn expr_to_unparsed_ok() -> Result<()> { let tests: 
Vec<(Expr, &str)> = vec![ ((col("a") + col("b")).gt(lit(4)), r#"((a + b) > 4)"#), - (col("a").sort(true, true), r#"a ASC NULLS FIRST"#), + (col("a").sort(true, true).to_expr(), r#"a ASC NULLS FIRST"#), ]; for (expr, expected) in tests { @@ -2047,7 +2047,7 @@ mod tests { #[test] fn customer_dialect_support_nulls_first_in_ort() -> Result<()> { - let tests: Vec<(Expr, &str, bool)> = vec![ + let tests: Vec<(SortExpr, &str, bool)> = vec![ (col("a").sort(true, true), r#"a ASC NULLS FIRST"#, true), (col("a").sort(true, true), r#"a ASC"#, false), ]; @@ -2057,7 +2057,7 @@ mod tests { .with_supports_nulls_first_in_sort(supports_nulls_first_in_sort) .build(); let unparser = Unparser::new(&dialect); - let ast = unparser.expr_to_unparsed(&expr)?; + let ast = unparser.expr_to_unparsed(&expr.to_expr())?; let actual = format!("{}", ast); diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 5a203703e967..ee52d3559cb2 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -4408,7 +4408,10 @@ fn plan_create_index() { assert_eq!(using, Some("btree".to_string())); assert_eq!( columns, - vec![col("name").sort(true, false), col("age").sort(false, true),] + vec![ + col("name").sort(true, false).to_expr(), + col("age").sort(false, true).to_expr(), + ] ); assert!(unique); assert!(if_not_exists); From e1e781402398138f624cd6b65318eb8e2e316a6a Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 27 Aug 2024 13:46:37 +0200 Subject: [PATCH 03/11] Accept Sort (SortExpr) in `LogicalPlanBuilder.sort` Take `expr::Sort` in `LogicalPlanBuilder.sort`. Accept any `Expr` in a new function, `LogicalPlanBuilder.sort_by`, which applies the default sort ordering. Part of effort to remove `Expr::Sort`. 
--- datafusion/core/src/dataframe/mod.rs | 5 ++- datafusion/core/src/physical_planner.rs | 2 +- datafusion/expr/src/expr_rewriter/mod.rs | 16 +++++++- datafusion/expr/src/logical_plan/builder.rs | 41 +++++++++++++------ .../src/analyzer/count_wildcard_rule.rs | 2 +- .../src/eliminate_duplicated_expr.rs | 5 +-- datafusion/optimizer/src/eliminate_limit.rs | 6 +-- datafusion/optimizer/src/push_down_limit.rs | 4 +- .../src/replace_distinct_aggregate.rs | 3 +- datafusion/proto/src/logical_plan/mod.rs | 8 +++- datafusion/sql/src/query.rs | 5 ++- .../substrait/src/logical_plan/consumer.rs | 4 +- 12 files changed, 70 insertions(+), 31 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 794da5c98cd7..9352b2e91ca0 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -62,6 +62,7 @@ use datafusion_functions_aggregate::expr_fn::{ use async_trait::async_trait; use datafusion_catalog::Session; +use datafusion_expr::expr::sort_vec_from_expr; /// Contains options that control how data is /// written out from a DataFrame @@ -798,7 +799,9 @@ impl DataFrame { /// # } /// ``` pub fn sort(self, expr: Vec) -> Result { - let plan = LogicalPlanBuilder::from(self.plan).sort(expr)?.build()?; + let plan = LogicalPlanBuilder::from(self.plan) + .sort(sort_vec_from_expr(expr))? + .build()?; Ok(DataFrame { session_state: self.session_state, plan, diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 5a7a5497e4d7..9501d3c6bbbb 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2030,7 +2030,7 @@ mod tests { .filter(col("c7").lt(lit(5_u8)))? .project(vec![col("c1"), col("c2")])? .aggregate(vec![col("c1")], vec![sum(col("c2"))])? - .sort(vec![col("c1").sort(true, true).to_expr()])? + .sort(vec![col("c1").sort(true, true)])? .limit(3, Some(10))? 
.build()?; diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 768c4aabc840..375ef1edf49a 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -21,7 +21,7 @@ use std::collections::HashMap; use std::collections::HashSet; use std::sync::Arc; -use crate::expr::{Alias, Unnest}; +use crate::expr::{Alias, Sort, Unnest}; use crate::logical_plan::Projection; use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; @@ -117,6 +117,20 @@ pub fn normalize_cols( .collect() } +pub fn normalize_sorts( + sorts: impl IntoIterator>, + plan: &LogicalPlan, +) -> Result> { + sorts + .into_iter() + .map(|e| { + let sort = e.into(); + normalize_col(*sort.expr, plan) + .map(|expr| Sort::new(Box::new(expr), sort.asc, sort.nulls_first)) + }) + .collect() +} + /// Recursively replace all [`Column`] expressions in a given expression tree with /// `Column` expressions provided by the hash map argument. pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 9894fe887de0..4539cb778c39 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -23,10 +23,10 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; use crate::dml::CopyTo; -use crate::expr::Alias; +use crate::expr::{sort_vec_from_expr, sort_vec_to_expr, Alias}; use crate::expr_rewriter::{ coerce_plan_expr_for_schema, normalize_col, - normalize_col_with_schemas_and_ambiguity_check, normalize_cols, + normalize_col_with_schemas_and_ambiguity_check, normalize_cols, normalize_sorts, rewrite_sort_cols_by_aggs, }; use crate::logical_plan::{ @@ -42,7 +42,7 @@ use crate::utils::{ }; use crate::{ and, binary_expr, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery, - TableProviderFilterPushDown, TableSource, WriteOp, + SortExpr, 
TableProviderFilterPushDown, TableSource, WriteOp, }; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; @@ -541,19 +541,34 @@ impl LogicalPlanBuilder { plan_err!("For SELECT DISTINCT, ORDER BY expressions {missing_col_names} must appear in select list") } + /// Apply a sort by provided expressions with default direction + pub fn sort_by( + self, + expr: impl IntoIterator> + Clone, + ) -> Result { + self.sort( + expr.into_iter() + .map(|e| e.into().sort(true, false)) + .collect::>(), + ) + } + /// Apply a sort pub fn sort( self, - exprs: impl IntoIterator> + Clone, + exprs: impl IntoIterator> + Clone, ) -> Result { - let exprs = rewrite_sort_cols_by_aggs(exprs, &self.plan)?; + let sorts = sort_vec_from_expr(rewrite_sort_cols_by_aggs( + sort_vec_to_expr(exprs.into_iter().map(|s| s.into()).collect()), + &self.plan, + )?); let schema = self.plan.schema(); // Collect sort columns that are missing in the input plan's schema let mut missing_cols: Vec = vec![]; - exprs.iter().try_for_each::<_, Result<()>>(|expr| { - let columns = expr.column_refs(); + sorts.iter().try_for_each::<_, Result<()>>(|sort| { + let columns = sort.expr.column_refs(); columns.into_iter().for_each(|c| { if !schema.has_column(c) { @@ -566,7 +581,7 @@ impl LogicalPlanBuilder { if missing_cols.is_empty() { return Ok(Self::new(LogicalPlan::Sort(Sort { - expr: normalize_cols(exprs, &self.plan)?, + expr: sort_vec_to_expr(normalize_sorts(sorts, &self.plan)?), input: self.plan, fetch: None, }))); @@ -582,7 +597,7 @@ impl LogicalPlanBuilder { is_distinct, )?; let sort_plan = LogicalPlan::Sort(Sort { - expr: normalize_cols(exprs, &plan)?, + expr: sort_vec_to_expr(normalize_sorts(sorts, &plan)?), input: Arc::new(plan), fetch: None, }); @@ -1708,8 +1723,8 @@ mod tests { let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))? 
.sort(vec![ - Expr::Sort(expr::Sort::new(Box::new(col("state")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("salary")), false, false)), + expr::Sort::new(Box::new(col("state")), true, true), + expr::Sort::new(Box::new(col("salary")), false, false), ])? .build()?; @@ -2135,8 +2150,8 @@ mod tests { let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))? .sort(vec![ - Expr::Sort(expr::Sort::new(Box::new(col("state")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("salary")), false, false)), + expr::Sort::new(Box::new(col("state")), true, true), + expr::Sort::new(Box::new(col("salary")), false, false), ])? .build()?; diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 8b746a7a8381..e114efb99960 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -130,7 +130,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("b")], vec![count(wildcard())])? .project(vec![count(wildcard())])? - .sort(vec![count(wildcard()).sort(true, false).to_expr()])? + .sort(vec![count(wildcard()).sort(true, false)])? 
.build()?; let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64]\ \n Projection: count(*) [count(*):Int64]\ diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 50a7fec4db8d..65520bee987f 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -131,7 +131,6 @@ impl OptimizerRule for EliminateDuplicatedExpr { mod tests { use super::*; use crate::test::*; - use datafusion_expr::expr::sort_vec_to_expr; use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder}; use std::sync::Arc; @@ -147,7 +146,7 @@ mod tests { fn eliminate_sort_expr() -> Result<()> { let table_scan = test_table_scan().unwrap(); let plan = LogicalPlanBuilder::from(table_scan) - .sort(vec![col("a"), col("a"), col("b"), col("c")])? + .sort_by(vec![col("a"), col("a"), col("b"), col("c")])? .limit(5, Some(10))? .build()?; let expected = "Limit: skip=5, fetch=10\ @@ -166,7 +165,7 @@ mod tests { col("b").sort(false, true), ]; let plan = LogicalPlanBuilder::from(table_scan) - .sort(sort_vec_to_expr(sort_exprs))? + .sort(sort_exprs)? .limit(5, Some(10))? .build()?; let expected = "Limit: skip=5, fetch=10\ diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index e48f37a77cd3..10baa58ad7dc 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -182,7 +182,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("a")], vec![sum(col("b"))])? .limit(0, Some(2))? - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .limit(2, Some(1))? .build()?; @@ -202,7 +202,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("a")], vec![sum(col("b"))])? .limit(0, Some(2))? - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .limit(0, Some(1))? 
.build()?; @@ -220,7 +220,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("a")], vec![sum(col("b"))])? .limit(2, Some(1))? - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .limit(3, Some(1))? .build()?; diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index dff0b61c6b22..55ce05e5bc0e 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -347,7 +347,7 @@ mod test { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .limit(0, Some(10))? .build()?; @@ -364,7 +364,7 @@ mod test { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .limit(5, Some(10))? .build()?; diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index c887192f6370..b66cbf626b24 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -21,6 +21,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::{Column, Result}; +use datafusion_expr::expr::sort_vec_from_expr; use datafusion_expr::expr_rewriter::normalize_cols; use datafusion_expr::utils::expand_wildcard; use datafusion_expr::{col, ExprFunctionExt, LogicalPlanBuilder}; @@ -144,7 +145,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { // truncate the sort_expr to the length of on_expr sort_expr.truncate(expr_cnt); - lpb.sort(sort_expr)?.build()? + lpb.sort(sort_vec_from_expr(sort_expr))?.build()? } else { lpb.build()? 
}; diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 6a4be9486256..c46ece2c06ae 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -67,7 +67,9 @@ use datafusion_expr::{ use datafusion_expr::{AggregateUDF, Unnest}; use self::to_proto::{serialize_expr, serialize_exprs}; -use datafusion_expr::expr::{sort_vec_to_expr, sort_vec_vec_from_expr}; +use datafusion_expr::expr::{ + sort_vec_from_expr, sort_vec_to_expr, sort_vec_vec_from_expr, +}; use prost::bytes::BufMut; use prost::Message; @@ -479,7 +481,9 @@ impl AsLogicalPlan for LogicalPlanNode { into_logical_plan!(sort.input, ctx, extension_codec)?; let sort_expr: Vec = from_proto::parse_exprs(&sort.expr, ctx, extension_codec)?; - LogicalPlanBuilder::from(input).sort(sort_expr)?.build() + LogicalPlanBuilder::from(input) + .sort(sort_vec_from_expr(sort_expr))? + .build() } LogicalPlanType::Repartition(repartition) => { use datafusion::logical_expr::Partitioning; diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index ba2b41bb6ecf..1c2b189d266c 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{not_impl_err, plan_err, Constraints, Result, ScalarValue}; +use datafusion_expr::expr::sort_vec_from_expr; use datafusion_expr::{ CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, Operator, @@ -131,7 +132,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let distinct_on = distinct_on.clone().with_sort_expr(order_by)?; Ok(LogicalPlan::Distinct(Distinct::On(distinct_on))) } else { - LogicalPlanBuilder::from(plan).sort(order_by)?.build() + LogicalPlanBuilder::from(plan) + .sort(sort_vec_from_expr(order_by))? 
+ .build() } } diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index b1b510f1792d..00383eeb025a 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -27,7 +27,7 @@ use datafusion::common::{ DFSchemaRef, }; use datafusion::execution::FunctionRegistry; -use datafusion::logical_expr::expr::{Exists, InSubquery, Sort}; +use datafusion::logical_expr::expr::{sort_vec_from_expr, Exists, InSubquery, Sort}; use datafusion::logical_expr::{ expr::find_df_window_func, Aggregate, BinaryExpr, Case, EmptyRelation, Expr, @@ -486,7 +486,7 @@ pub async fn from_substrait_rel( let sorts = from_substrait_sorts(ctx, &sort.sorts, input.schema(), extensions) .await?; - input.sort(sorts)?.build() + input.sort(sort_vec_from_expr(sorts))?.build() } else { not_impl_err!("Sort without an input is not valid") } From 75d3db565fbc70ab162c724d1e3984770312ce48 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 27 Aug 2024 13:46:37 +0200 Subject: [PATCH 04/11] Operate on `Sort` in to_substrait_sort_field / from_substrait_sorts Part of effort to remove `Expr::Sort`. 
--- datafusion/expr/src/expr.rs | 8 ++++ .../substrait/src/logical_plan/consumer.rs | 14 +++---- .../substrait/src/logical_plan/producer.rs | 41 ++++++++----------- 3 files changed, 33 insertions(+), 30 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 70db2b95bf09..61ece20076b9 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1413,6 +1413,14 @@ impl Expr { Sort::new(Box::new(self), asc, nulls_first) } + // TODO (https://github.com/apache/datafusion/issues/12193) remove when transition is complete + pub fn unwrap_sort(&self) -> &Sort { + match self { + Expr::Sort(sort) => sort, + _ => panic!("Expression must be a Expr::Sort: {}", self), + } + } + /// Return `IsTrue(Box(self))` pub fn is_true(self) -> Expr { Expr::IsTrue(Box::new(self)) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 00383eeb025a..bb0b603f614c 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -27,7 +27,7 @@ use datafusion::common::{ DFSchemaRef, }; use datafusion::execution::FunctionRegistry; -use datafusion::logical_expr::expr::{sort_vec_from_expr, Exists, InSubquery, Sort}; +use datafusion::logical_expr::expr::{sort_vec_to_expr, Exists, InSubquery, Sort}; use datafusion::logical_expr::{ expr::find_df_window_func, Aggregate, BinaryExpr, Case, EmptyRelation, Expr, @@ -486,7 +486,7 @@ pub async fn from_substrait_rel( let sorts = from_substrait_sorts(ctx, &sort.sorts, input.schema(), extensions) .await?; - input.sort(sort_vec_from_expr(sorts))?.build() + input.sort(sorts)?.build() } else { not_impl_err!("Sort without an input is not valid") } @@ -900,8 +900,8 @@ pub async fn from_substrait_sorts( substrait_sorts: &Vec, input_schema: &DFSchema, extensions: &Extensions, -) -> Result> { - let mut sorts: Vec = vec![]; +) -> Result> { + let mut sorts: Vec = vec![]; for s in substrait_sorts { let 
expr = from_substrait_rex(ctx, s.expr.as_ref().unwrap(), input_schema, extensions) @@ -935,11 +935,11 @@ pub async fn from_substrait_sorts( None => not_impl_err!("Sort without sort kind is invalid"), }; let (asc, nulls_first) = asc_nullfirst.unwrap(); - sorts.push(Expr::Sort(Sort { + sorts.push(Sort { expr: Box::new(expr), asc, nulls_first, - })); + }); } Ok(sorts) } @@ -1237,7 +1237,7 @@ pub async fn from_substrait_rex( extensions, ) .await?, - order_by, + order_by: sort_vec_to_expr(order_by), window_frame: datafusion::logical_expr::WindowFrame::new_bounds( bound_units, from_substrait_bound(&window.lower_bound, true)?, diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 72b6760be29c..14553fbdf452 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -764,7 +764,7 @@ pub fn to_substrait_agg_measure( match expr { Expr::AggregateFunction(expr::AggregateFunction { func, args, distinct, filter, order_by, null_treatment: _, }) => { let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extensions)).collect::>>()? + order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr.unwrap_sort(), schema, extensions)).collect::>>()? 
} else { vec![] }; @@ -808,31 +808,26 @@ pub fn to_substrait_agg_measure( /// Converts sort expression to corresponding substrait `SortField` fn to_substrait_sort_field( ctx: &SessionContext, - expr: &Expr, + sort: &Sort, schema: &DFSchemaRef, extensions: &mut Extensions, ) -> Result { - match expr { - Expr::Sort(sort) => { - let sort_kind = match (sort.asc, sort.nulls_first) { - (true, true) => SortDirection::AscNullsFirst, - (true, false) => SortDirection::AscNullsLast, - (false, true) => SortDirection::DescNullsFirst, - (false, false) => SortDirection::DescNullsLast, - }; - Ok(SortField { - expr: Some(to_substrait_rex( - ctx, - sort.expr.deref(), - schema, - 0, - extensions, - )?), - sort_kind: Some(SortKind::Direction(sort_kind.into())), - }) - } - _ => exec_err!("expects to receive sort expression"), - } + let sort_kind = match (sort.asc, sort.nulls_first) { + (true, true) => SortDirection::AscNullsFirst, + (true, false) => SortDirection::AscNullsLast, + (false, true) => SortDirection::DescNullsFirst, + (false, false) => SortDirection::DescNullsLast, + }; + Ok(SortField { + expr: Some(to_substrait_rex( + ctx, + sort.expr.deref(), + schema, + 0, + extensions, + )?), + sort_kind: Some(SortKind::Direction(sort_kind.into())), + }) } /// Return Substrait scalar function with two arguments From 57573ca35b3b6ff46a015fb8006ed068873f6843 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 27 Aug 2024 11:20:23 +0200 Subject: [PATCH 05/11] Take Sort (SortExpr) in tests' TopKPlanNode Part of effort to remove `Expr::Sort`. 
--- .../core/tests/user_defined/user_defined_plan.rs | 12 +++++++----- datafusion/expr/src/tree_node.rs | 7 +++++++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 62ba113da0d3..6fc5582f2e1a 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -59,6 +59,7 @@ //! use std::fmt::Debug; +use std::ops::Deref; use std::task::{Context, Poll}; use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; @@ -97,7 +98,8 @@ use datafusion::{ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; -use datafusion_expr::Projection; +use datafusion_expr::tree_node::replace_sort_expression; +use datafusion_expr::{Projection, SortExpr}; use datafusion_optimizer::optimizer::ApplyOrder; use datafusion_optimizer::AnalyzerRule; @@ -375,7 +377,7 @@ impl OptimizerRule for TopKOptimizerRule { node: Arc::new(TopKPlanNode { k: *fetch, input: input.as_ref().clone(), - expr: expr[0].clone(), + expr: expr[0].unwrap_sort().clone(), }), }))); } @@ -392,7 +394,7 @@ struct TopKPlanNode { input: LogicalPlan, /// The sort expression (this example only supports a single sort /// expr) - expr: Expr, + expr: SortExpr, } impl Debug for TopKPlanNode { @@ -418,7 +420,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { } fn expressions(&self) -> Vec { - vec![self.expr.clone()] + vec![self.expr.expr.deref().clone()] } /// For example: `TopK: k=10` @@ -436,7 +438,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { Ok(Self { k: self.k, input: inputs.swap_remove(0), - expr: exprs.swap_remove(0), + expr: replace_sort_expression(self.expr.clone(), exprs.swap_remove(0)), }) } } diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 450ebb6c2275..1678d13edea6 
100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -386,3 +386,10 @@ fn transform_vec Result>>( ) -> Result>> { ve.into_iter().map_until_stop_and_collect(f) } + +pub fn replace_sort_expression(sort: Sort, new_expr: Expr) -> Sort { + Sort { + expr: Box::new(new_expr), + ..sort + } +} From 6b2dfbdade2dcbc865b81abcabb48c1e9adc388d Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 27 Aug 2024 15:24:22 +0200 Subject: [PATCH 06/11] Remove Sort expression (`Expr::Sort`) Remove sort as an expression, i.e. remove `Expr::Sort` from `Expr` enum. Use `expr::Sort` directly when sorting. The sort expression was used in context of ordering (sort, topk, create table, file sorting). Those places require their sort expression to be of type Sort anyway and no other expression was allowed, so this change improves static typing. Sort as an expression was illegal in other contexts. --- datafusion-examples/examples/advanced_udwf.rs | 2 +- datafusion-examples/examples/expr_api.rs | 2 +- .../examples/file_stream_provider.rs | 6 +- .../examples/parse_sql_expr.rs | 2 +- datafusion-examples/examples/simple_udwf.rs | 2 +- datafusion/core/src/dataframe/mod.rs | 73 +++---- .../core/src/datasource/listing/helpers.rs | 3 +- .../core/src/datasource/listing/table.rs | 12 +- .../src/datasource/listing_table_factory.rs | 3 +- datafusion/core/src/datasource/memory.rs | 5 +- datafusion/core/src/datasource/mod.rs | 41 ++-- .../physical_plan/file_scan_config.rs | 18 +- datafusion/core/src/datasource/stream.rs | 6 +- datafusion/core/src/physical_planner.rs | 32 ++- datafusion/core/src/test_util/mod.rs | 4 +- datafusion/core/tests/dataframe/mod.rs | 24 +-- datafusion/core/tests/expr_api/mod.rs | 25 +-- datafusion/core/tests/fifo/mod.rs | 6 +- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 2 +- .../core/tests/fuzz_cases/limit_fuzz.rs | 8 +- datafusion/core/tests/sql/joins.rs | 4 +- .../tests/user_defined/user_defined_plan.rs | 5 +- datafusion/expr/src/expr.rs | 144 
++++--------- datafusion/expr/src/expr_fn.rs | 28 +-- datafusion/expr/src/expr_rewriter/mod.rs | 18 +- datafusion/expr/src/expr_rewriter/order_by.rs | 40 ++-- datafusion/expr/src/expr_schema.rs | 15 +- datafusion/expr/src/logical_plan/builder.rs | 17 +- datafusion/expr/src/logical_plan/ddl.rs | 7 +- datafusion/expr/src/logical_plan/plan.rs | 41 ++-- datafusion/expr/src/logical_plan/tree_node.rs | 18 +- datafusion/expr/src/tree_node.rs | 53 ++++- datafusion/expr/src/utils.rs | 194 +++++++----------- datafusion/expr/src/window_frame.rs | 6 +- .../functions-aggregate/src/first_last.rs | 4 +- .../src/analyzer/count_wildcard_rule.rs | 2 +- .../optimizer/src/analyzer/type_coercion.rs | 7 +- .../optimizer/src/common_subexpr_eliminate.rs | 15 +- .../src/eliminate_duplicated_expr.rs | 51 ++--- datafusion/optimizer/src/eliminate_limit.rs | 6 +- datafusion/optimizer/src/push_down_filter.rs | 3 +- datafusion/optimizer/src/push_down_limit.rs | 4 +- .../src/replace_distinct_aggregate.rs | 3 +- .../simplify_expressions/expr_simplifier.rs | 1 - .../src/single_distinct_to_groupby.rs | 8 +- datafusion/proto/proto/datafusion.proto | 17 +- datafusion/proto/src/generated/pbjson.rs | 105 ++++++++-- datafusion/proto/src/generated/prost.rs | 26 ++- .../proto/src/logical_plan/from_proto.rs | 49 +++-- datafusion/proto/src/logical_plan/mod.rs | 52 +++-- datafusion/proto/src/logical_plan/to_proto.rs | 48 +++-- .../tests/cases/roundtrip_logical_plan.rs | 24 +-- datafusion/sql/src/expr/function.rs | 25 +-- datafusion/sql/src/expr/order_by.rs | 8 +- datafusion/sql/src/query.rs | 8 +- datafusion/sql/src/select.rs | 4 +- datafusion/sql/src/statement.rs | 13 +- datafusion/sql/src/unparser/expr.rs | 122 +++-------- datafusion/sql/src/unparser/mod.rs | 2 - datafusion/sql/src/unparser/plan.rs | 17 +- datafusion/sql/src/unparser/rewrite.rs | 37 ++-- datafusion/sql/tests/sql_integration.rs | 5 +- .../substrait/src/logical_plan/consumer.rs | 8 +- .../substrait/src/logical_plan/producer.rs | 40 ++-- 
.../using-the-dataframe-api.md | 4 +- 65 files changed, 708 insertions(+), 876 deletions(-) diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index e39e6d2bac89..ec0318a561b9 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -219,7 +219,7 @@ async fn main() -> Result<()> { let window_expr = smooth_it .call(vec![col("speed")]) // smooth_it(speed) .partition_by(vec![col("car")]) // PARTITION BY car - .order_by(vec![col("time").sort(true, true).to_expr()]) // ORDER BY time ASC + .order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC .window_frame(WindowFrame::new(None)) .build()?; let df = ctx.table("cars").await?.window(vec![window_expr])?; diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index a30ec13463e4..0eb823302acf 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -99,7 +99,7 @@ fn expr_fn_demo() -> Result<()> { // such as `FIRST_VALUE(price FILTER quantity > 100 ORDER BY ts ) let agg = first_value .call(vec![col("price")]) - .order_by(vec![col("ts").sort(false, false).to_expr()]) + .order_by(vec![col("ts").sort(false, false)]) .filter(col("quantity").gt(lit(100))) .build()?; // build the aggregate assert_eq!( diff --git a/datafusion-examples/examples/file_stream_provider.rs b/datafusion-examples/examples/file_stream_provider.rs index 5c4f032adaec..03dc67f4fda8 100644 --- a/datafusion-examples/examples/file_stream_provider.rs +++ b/datafusion-examples/examples/file_stream_provider.rs @@ -39,7 +39,7 @@ mod non_windows { use datafusion::datasource::TableProvider; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{exec_err, Result}; - use datafusion_expr::Expr; + use datafusion_expr::SortExpr; // Number of lines written to FIFO const TEST_BATCH_SIZE: usize = 5; @@ -49,7 +49,7 @@ mod non_windows { fn 
fifo_table( schema: SchemaRef, path: impl Into, - sort: Vec>, + sort: Vec>, ) -> Arc { let source = FileStreamProvider::new_file(schema, path.into()) .with_batch_size(TEST_BATCH_SIZE) @@ -157,7 +157,7 @@ mod non_windows { ])); // Specify the ordering: - let order = vec![vec![datafusion_expr::col("a1").sort(true, false).to_expr()]]; + let order = vec![vec![datafusion_expr::col("a1").sort(true, false)]]; let provider = fifo_table(schema.clone(), fifo_path, order.clone()); ctx.register_table("fifo", provider)?; diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/parse_sql_expr.rs index 3fbb0637bf3b..e23e5accae39 100644 --- a/datafusion-examples/examples/parse_sql_expr.rs +++ b/datafusion-examples/examples/parse_sql_expr.rs @@ -114,7 +114,7 @@ async fn query_parquet_demo() -> Result<()> { )? // Directly parsing the SQL text into a sort expression is not supported yet, so // construct it programmatically - .sort(vec![col("double_col").sort(false, false).to_expr()])? + .sort(vec![col("double_col").sort(false, false)])? 
.limit(0, Some(1))?; let result = df.collect().await?; diff --git a/datafusion-examples/examples/simple_udwf.rs b/datafusion-examples/examples/simple_udwf.rs index c4a6a98b5dff..22dfbbbf0c3a 100644 --- a/datafusion-examples/examples/simple_udwf.rs +++ b/datafusion-examples/examples/simple_udwf.rs @@ -121,7 +121,7 @@ async fn main() -> Result<()> { let window_expr = smooth_it .call(vec![col("speed")]) // smooth_it(speed) .partition_by(vec![col("car")]) // PARTITION BY car - .order_by(vec![col("time").sort(true, true).to_expr()]) // ORDER BY time ASC + .order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC .window_frame(WindowFrame::new(None)) .build()?; let df = ctx.table("cars").await?.window(vec![window_expr])?; diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 9352b2e91ca0..5dbeb535a546 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -52,7 +52,7 @@ use datafusion_common::config::{CsvOptions, JsonOptions}; use datafusion_common::{ plan_err, Column, DFSchema, DataFusionError, ParamValues, SchemaError, UnnestOptions, }; -use datafusion_expr::{case, is_null, lit}; +use datafusion_expr::{case, is_null, lit, SortExpr}; use datafusion_expr::{ utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, }; @@ -62,7 +62,6 @@ use datafusion_functions_aggregate::expr_fn::{ use async_trait::async_trait; use datafusion_catalog::Session; -use datafusion_expr::expr::sort_vec_from_expr; /// Contains options that control how data is /// written out from a DataFrame @@ -578,7 +577,7 @@ impl DataFrame { self, on_expr: Vec, select_expr: Vec, - sort_expr: Option>, + sort_expr: Option>, ) -> Result { let plan = LogicalPlanBuilder::from(self.plan) .distinct_on(on_expr, select_expr, sort_expr)? 
@@ -777,6 +776,15 @@ impl DataFrame { }) } + /// Apply a sort by provided expressions with default direction + pub fn sort_by(self, expr: Vec) -> Result { + self.sort( + expr.into_iter() + .map(|e| e.sort(true, false)) + .collect::>(), + ) + } + /// Sort the DataFrame by the specified sorting expressions. /// /// Note that any expression can be turned into @@ -792,16 +800,14 @@ impl DataFrame { /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; /// let df = df.sort(vec![ - /// col("a").sort(true, true).to_expr(), // a ASC, nulls first - /// col("b").sort(false, false).to_expr(), // b DESC, nulls last + /// col("a").sort(true, true), // a ASC, nulls first + /// col("b").sort(false, false), // b DESC, nulls last /// ])?; /// # Ok(()) /// # } /// ``` - pub fn sort(self, expr: Vec) -> Result { - let plan = LogicalPlanBuilder::from(self.plan) - .sort(sort_vec_from_expr(expr))? - .build()?; + pub fn sort(self, expr: Vec) -> Result { + let plan = LogicalPlanBuilder::from(self.plan).sort(expr)?.build()?; Ok(DataFrame { session_state: self.session_state, plan, @@ -1322,7 +1328,7 @@ impl DataFrame { /// let ctx = SessionContext::new(); /// // Sort the data by column "b" and write it to a new location /// ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await? - /// .sort(vec![col("b").sort(true, true).to_expr()])? // sort by b asc, nulls first + /// .sort(vec![col("b").sort(true, true)])? // sort by b asc, nulls first /// .write_csv( /// "output.csv", /// DataFrameWriteOptions::new(), @@ -1382,7 +1388,7 @@ impl DataFrame { /// let ctx = SessionContext::new(); /// // Sort the data by column "b" and write it to a new location /// ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await? - /// .sort(vec![col("b").sort(true, true).to_expr()])? // sort by b asc, nulls first + /// .sort(vec![col("b").sort(true, true)])? 
// sort by b asc, nulls first /// .write_json( /// "output.json", /// DataFrameWriteOptions::new(), @@ -2406,10 +2412,7 @@ mod tests { Expr::WindowFunction(w) .null_treatment(NullTreatment::IgnoreNulls) - .order_by(vec![ - col("c2").sort(true, true).to_expr(), - col("c3").sort(true, true).to_expr(), - ]) + .order_by(vec![col("c2").sort(true, true), col("c3").sort(true, true)]) .window_frame(WindowFrame::new_bounds( WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::UInt64(None)), @@ -2499,7 +2502,7 @@ mod tests { .unwrap() .distinct() .unwrap() - .sort(vec![col("c1").sort(true, true).to_expr()]) + .sort(vec![col("c1").sort(true, true)]) .unwrap(); let df_results = plan.clone().collect().await?; @@ -2530,7 +2533,7 @@ mod tests { .distinct() .unwrap() // try to sort on some value not present in input to distinct - .sort(vec![col("c2").sort(true, true).to_expr()]) + .sort(vec![col("c2").sort(true, true)]) .unwrap_err(); assert_eq!(err.strip_backtrace(), "Error during planning: For SELECT DISTINCT, ORDER BY expressions c2 must appear in select list"); @@ -2577,10 +2580,10 @@ mod tests { .distinct_on( vec![col("c1")], vec![col("c1")], - Some(vec![col("c1").sort(true, true).to_expr()]), + Some(vec![col("c1").sort(true, true)]), ) .unwrap() - .sort(vec![col("c1").sort(true, true).to_expr()]) + .sort(vec![col("c1").sort(true, true)]) .unwrap(); let df_results = plan.clone().collect().await?; @@ -2611,11 +2614,11 @@ mod tests { .distinct_on( vec![col("c1")], vec![col("c1")], - Some(vec![col("c1").sort(true, true).to_expr()]), + Some(vec![col("c1").sort(true, true)]), ) .unwrap() // try to sort on some value not present in input to distinct - .sort(vec![col("c2").sort(true, true).to_expr()]) + .sort(vec![col("c2").sort(true, true)]) .unwrap_err(); assert_eq!(err.strip_backtrace(), "Error during planning: For SELECT DISTINCT, ORDER BY expressions c2 must appear in select list"); @@ -3021,7 +3024,7 @@ mod tests { )? 
.sort(vec![ // make the test deterministic - col("t1.c1").sort(true, true).to_expr(), + col("t1.c1").sort(true, true), ])? .limit(0, Some(1))?; @@ -3098,7 +3101,7 @@ mod tests { )? .sort(vec![ // make the test deterministic - col("t1.c1").sort(true, true).to_expr(), + col("t1.c1").sort(true, true), ])? .limit(0, Some(1))?; @@ -3131,9 +3134,9 @@ mod tests { .filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))? .sort(vec![ // make the test deterministic - col("c1").sort(true, true).to_expr(), - col("c2").sort(true, true).to_expr(), - col("c3").sort(true, true).to_expr(), + col("c1").sort(true, true), + col("c2").sort(true, true), + col("c3").sort(true, true), ])? .limit(0, Some(1))? .with_column("sum", col("c2") + col("c3"))?; @@ -3211,12 +3214,12 @@ mod tests { )? .sort(vec![ // make the test deterministic - col("t1.c1").sort(true, true).to_expr(), - col("t1.c2").sort(true, true).to_expr(), - col("t1.c3").sort(true, true).to_expr(), - col("t2.c1").sort(true, true).to_expr(), - col("t2.c2").sort(true, true).to_expr(), - col("t2.c3").sort(true, true).to_expr(), + col("t1.c1").sort(true, true), + col("t1.c2").sort(true, true), + col("t1.c3").sort(true, true), + col("t2.c1").sort(true, true), + col("t2.c2").sort(true, true), + col("t2.c3").sort(true, true), ])? .limit(0, Some(1))?; @@ -3289,9 +3292,9 @@ mod tests { .limit(0, Some(1))? .sort(vec![ // make the test deterministic - col("c1").sort(true, true).to_expr(), - col("c2").sort(true, true).to_expr(), - col("c3").sort(true, true).to_expr(), + col("c1").sort(true, true), + col("c2").sort(true, true), + col("c3").sort(true, true), ])? 
.select_columns(&["c1"])?; diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index b5dd2dd12e10..0d9ded22a9a9 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -102,11 +102,10 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { } // TODO other expressions are not handled yet: - // - AGGREGATE, WINDOW and SORT should not end up in filter conditions, except maybe in some edge cases + // - AGGREGATE and WINDOW should not end up in filter conditions, except maybe in some edge cases // - Can `Wildcard` be considered as a `Literal`? // - ScalarVariable could be `applicable`, but that would require access to the context Expr::AggregateFunction { .. } - | Expr::Sort { .. } | Expr::WindowFunction { .. } | Expr::Wildcard { .. } | Expr::Unnest { .. } diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index cccad1318ad4..3207f2f66822 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -51,7 +51,6 @@ use datafusion_physical_expr::{ use async_trait::async_trait; use datafusion_catalog::Session; -use datafusion_expr::expr::sort_vec_vec_to_expr; use futures::{future, stream, StreamExt, TryStreamExt}; use itertools::Itertools; use object_store::ObjectStore; @@ -714,10 +713,7 @@ impl ListingTable { /// If file_sort_order is specified, creates the appropriate physical expressions fn try_create_output_ordering(&self) -> Result> { - create_ordering( - &self.table_schema, - &sort_vec_vec_to_expr(self.options.file_sort_order.clone()), - ) + create_ordering(&self.table_schema, &self.options.file_sort_order) } } @@ -1068,7 +1064,6 @@ mod tests { use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_plan::ExecutionPlanProperties; - use datafusion_expr::expr::sort_vec_vec_from_expr; use 
tempfile::TempDir; #[tokio::test] @@ -1159,6 +1154,7 @@ mod tests { use crate::datasource::file_format::parquet::ParquetFormat; use datafusion_physical_plan::expressions::col as physical_col; + use std::ops::Add; // (file_sort_order, expected_result) let cases = vec![ @@ -1207,9 +1203,7 @@ mod tests { ]; for (file_sort_order, expected_result) in cases { - let options = options.clone().with_file_sort_order(sort_vec_vec_from_expr( - sort_vec_vec_to_expr(file_sort_order), - )); + let options = options.clone().with_file_sort_order(file_sort_order); let config = ListingTableConfig::new(table_path.clone()) .with_listing_options(options) diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index f2f37c391aba..591a19aab49b 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -33,7 +33,6 @@ use datafusion_expr::CreateExternalTable; use async_trait::async_trait; use datafusion_catalog::Session; -use datafusion_expr::expr::sort_vec_vec_from_expr; /// A `TableProviderFactory` capable of creating new `ListingTable`s #[derive(Debug, Default)] @@ -115,7 +114,7 @@ impl TableProviderFactory for ListingTableFactory { .with_file_extension(file_extension) .with_target_partitions(state.config().target_partitions()) .with_table_partition_cols(table_partition_cols) - .with_file_sort_order(sort_vec_vec_from_expr(cmd.order_exprs.clone())); + .with_file_sort_order(cmd.order_exprs.clone()); options .validate_partitions(session_state, &table_path) diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 44e01e71648a..cef7f210e118 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -43,6 +43,7 @@ use datafusion_physical_plan::metrics::MetricsSet; use async_trait::async_trait; use datafusion_catalog::Session; +use datafusion_expr::SortExpr; use 
futures::StreamExt; use log::debug; use parking_lot::Mutex; @@ -64,7 +65,7 @@ pub struct MemTable { column_defaults: HashMap, /// Optional pre-known sort order(s). Must be `SortExpr`s. /// inserting data into this table removes the order - pub sort_order: Arc>>>, + pub sort_order: Arc>>>, } impl MemTable { @@ -118,7 +119,7 @@ impl MemTable { /// /// Note that multiple sort orders are supported, if some are known to be /// equivalent, - pub fn with_sort_order(self, mut sort_order: Vec>) -> Self { + pub fn with_sort_order(self, mut sort_order: Vec>) -> Self { std::mem::swap(self.sort_order.lock().as_mut(), &mut sort_order); self } diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 1c9924735735..55e88e572be1 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -50,38 +50,39 @@ pub use statistics::get_statistics_with_limit; use arrow_schema::{Schema, SortOptions}; use datafusion_common::{plan_err, Result}; -use datafusion_expr::Expr; +use datafusion_expr::{Expr, SortExpr}; use datafusion_physical_expr::{expressions, LexOrdering, PhysicalSortExpr}; fn create_ordering( schema: &Schema, - sort_order: &[Vec], + sort_order: &[Vec], ) -> Result> { let mut all_sort_orders = vec![]; for exprs in sort_order { // Construct PhysicalSortExpr objects from Expr objects: let mut sort_exprs = vec![]; - for expr in exprs { - match expr { - Expr::Sort(sort) => match sort.expr.as_ref() { - Expr::Column(col) => match expressions::col(&col.name, schema) { - Ok(expr) => { - sort_exprs.push(PhysicalSortExpr { - expr, - options: SortOptions { - descending: !sort.asc, - nulls_first: sort.nulls_first, - }, - }); - } - // Cannot find expression in the projected_schema, stop iterating - // since rest of the orderings are violated - Err(_) => break, + for sort in exprs { + match sort.expr.as_ref() { + Expr::Column(col) => match expressions::col(&col.name, schema) { + Ok(expr) => { + 
sort_exprs.push(PhysicalSortExpr { + expr, + options: SortOptions { + descending: !sort.asc, + nulls_first: sort.nulls_first, + }, + }); } - expr => return plan_err!("Expected single column references in output_ordering, got {expr}"), + // Cannot find expression in the projected_schema, stop iterating + // since rest of the orderings are violated + Err(_) => break, + }, + expr => { + return plan_err!( + "Expected single column references in output_ordering, got {expr}" + ) } - expr => return plan_err!("Expected Expr::Sort in output_ordering, but got {expr}"), } } if !sort_exprs.is_empty() { diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index ba6d53308dd7..763dd59ab523 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -979,7 +979,7 @@ mod tests { name: &'static str, file_schema: Schema, files: Vec, - sort: Vec, + sort: Vec, expected_result: Result>, &'static str>, } @@ -997,7 +997,7 @@ mod tests { File::new("1", "2023-01-01", vec![Some((0.50, 1.00))]), File::new("2", "2023-01-02", vec![Some((0.00, 1.00))]), ], - sort: vec![col("value").sort(true, false).to_expr()], + sort: vec![col("value").sort(true, false)], expected_result: Ok(vec![vec!["0", "1"], vec!["2"]]), }, // same input but file '2' is in the middle @@ -1014,7 +1014,7 @@ mod tests { File::new("2", "2023-01-02", vec![Some((0.00, 1.00))]), File::new("1", "2023-01-01", vec![Some((0.50, 1.00))]), ], - sort: vec![col("value").sort(true, false).to_expr()], + sort: vec![col("value").sort(true, false)], expected_result: Ok(vec![vec!["0", "1"], vec!["2"]]), }, TestCase { @@ -1029,7 +1029,7 @@ mod tests { File::new("1", "2023-01-01", vec![Some((0.50, 1.00))]), File::new("2", "2023-01-02", vec![Some((0.00, 1.00))]), ], - sort: vec![col("value").sort(false, true).to_expr()], + sort: vec![col("value").sort(false, true)], 
expected_result: Ok(vec![vec!["1", "0"], vec!["2"]]), }, // reject nullable sort columns @@ -1045,7 +1045,7 @@ mod tests { File::new("1", "2023-01-01", vec![Some((0.50, 1.00))]), File::new("2", "2023-01-02", vec![Some((0.00, 1.00))]), ], - sort: vec![col("value").sort(true, false).to_expr()], + sort: vec![col("value").sort(true, false)], expected_result: Err("construct min/max statistics for split_groups_by_statistics\ncaused by\nbuild min rows\ncaused by\ncreate sorting columns\ncaused by\nError during planning: cannot sort by nullable column") }, TestCase { @@ -1060,7 +1060,7 @@ mod tests { File::new("1", "2023-01-01", vec![Some((0.50, 0.99))]), File::new("2", "2023-01-02", vec![Some((1.00, 1.49))]), ], - sort: vec![col("value").sort(true, false).to_expr()], + sort: vec![col("value").sort(true, false)], expected_result: Ok(vec![vec!["0", "1", "2"]]), }, TestCase { @@ -1075,7 +1075,7 @@ mod tests { File::new("1", "2023-01-01", vec![Some((0.00, 0.49))]), File::new("2", "2023-01-02", vec![Some((0.00, 0.49))]), ], - sort: vec![col("value").sort(true, false).to_expr()], + sort: vec![col("value").sort(true, false)], expected_result: Ok(vec![vec!["0"], vec!["1"], vec!["2"]]), }, TestCase { @@ -1086,7 +1086,7 @@ mod tests { false, )]), files: vec![], - sort: vec![col("value").sort(true, false).to_expr()], + sort: vec![col("value").sort(true, false)], expected_result: Ok(vec![]), }, TestCase { @@ -1101,7 +1101,7 @@ mod tests { File::new("1", "2023-01-01", vec![Some((0.00, 0.49))]), File::new("2", "2023-01-02", vec![None]), ], - sort: vec![col("value").sort(true, false).to_expr()], + sort: vec![col("value").sort(true, false)], expected_result: Err("construct min/max statistics for split_groups_by_statistics\ncaused by\ncollect min/max values\ncaused by\nget min/max for column: 'value'\ncaused by\nError during planning: statistics not found"), }, ]; diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs index 
b53fe8663178..ef6d195cdaff 100644 --- a/datafusion/core/src/datasource/stream.rs +++ b/datafusion/core/src/datasource/stream.rs @@ -33,7 +33,7 @@ use arrow_schema::SchemaRef; use datafusion_common::{config_err, plan_err, Constraints, DataFusionError, Result}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_expr::{CreateExternalTable, Expr, TableType}; +use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; use datafusion_physical_plan::insert::{DataSink, DataSinkExec}; use datafusion_physical_plan::metrics::MetricsSet; use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder; @@ -248,7 +248,7 @@ impl StreamProvider for FileStreamProvider { #[derive(Debug)] pub struct StreamConfig { source: Arc, - order: Vec>, + order: Vec>, constraints: Constraints, } @@ -263,7 +263,7 @@ impl StreamConfig { } /// Specify a sort order for the stream - pub fn with_order(mut self, order: Vec>) -> Self { + pub fn with_order(mut self, order: Vec>) -> Self { self.order = order; self } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 9501d3c6bbbb..1b43e8d57bc3 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -73,13 +73,13 @@ use datafusion_common::{ }; use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ - self, physical_name, AggregateFunction, Alias, GroupingSet, WindowFunction, + physical_name, AggregateFunction, Alias, GroupingSet, WindowFunction, }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, StringifiedPlan, - WindowFrame, WindowFrameBound, WriteOp, + DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, SortExpr, + StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, 
}; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::Literal; @@ -1641,31 +1641,27 @@ pub fn create_aggregate_expr_and_maybe_filter( /// Create a physical sort expression from a logical expression pub fn create_physical_sort_expr( - e: &Expr, + e: &SortExpr, input_dfschema: &DFSchema, execution_props: &ExecutionProps, ) -> Result { - if let Expr::Sort(expr::Sort { + let SortExpr { expr, asc, nulls_first, - }) = e - { - Ok(PhysicalSortExpr { - expr: create_physical_expr(expr, input_dfschema, execution_props)?, - options: SortOptions { - descending: !asc, - nulls_first: *nulls_first, - }, - }) - } else { - internal_err!("Expects a sort expression") - } + } = e; + Ok(PhysicalSortExpr { + expr: create_physical_expr(expr, input_dfschema, execution_props)?, + options: SortOptions { + descending: !asc, + nulls_first: *nulls_first, + }, + }) } /// Create vector of physical sort expression from a vector of logical expression pub fn create_physical_sort_exprs( - exprs: &[Expr], + exprs: &[SortExpr], input_dfschema: &DFSchema, execution_props: &ExecutionProps, ) -> Result { diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index faa9378535fd..dd8b697666ee 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -46,7 +46,7 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::TableReference; use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_expr::{CreateExternalTable, Expr, TableType}; +use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::{expressions, EquivalenceProperties, PhysicalExpr}; @@ -360,7 +360,7 @@ pub fn register_unbounded_file_with_ordering( schema: SchemaRef, file_path: &Path, table_name: &str, - file_sort_order: Vec>, + 
file_sort_order: Vec>, ) -> Result<()> { let source = FileStreamProvider::new_file(schema, file_path.into()); let config = StreamConfig::new(Arc::new(source)).with_order(file_sort_order); diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 7e92eff1838f..c5b9db7588e9 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -75,7 +75,7 @@ async fn test_count_wildcard_on_sort() -> Result<()> { .table("t1") .await? .aggregate(vec![col("b")], vec![count(wildcard())])? - .sort(vec![count(wildcard()).sort(true, false).to_expr()])? + .sort(vec![count(wildcard()).sort(true, false)])? .explain(false, false)? .collect() .await?; @@ -184,7 +184,7 @@ async fn test_count_wildcard_on_window() -> Result<()> { WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], )) - .order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]) + .order_by(vec![Sort::new(Box::new(col("a")), false, true)]) .window_frame(WindowFrame::new_bounds( WindowFrameUnits::Range, WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), @@ -352,7 +352,7 @@ async fn sort_on_unprojected_columns() -> Result<()> { .unwrap() .select(vec![col("a")]) .unwrap() - .sort(vec![Expr::Sort(Sort::new(Box::new(col("b")), false, true))]) + .sort(vec![Sort::new(Box::new(col("b")), false, true)]) .unwrap(); let results = df.collect().await.unwrap(); @@ -396,7 +396,7 @@ async fn sort_on_distinct_columns() -> Result<()> { .unwrap() .distinct() .unwrap() - .sort(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]) + .sort(vec![Sort::new(Box::new(col("a")), false, true)]) .unwrap(); let results = df.collect().await.unwrap(); @@ -435,7 +435,7 @@ async fn sort_on_distinct_unprojected_columns() -> Result<()> { .await? .select(vec![col("a")])? .distinct()? 
- .sort(vec![Expr::Sort(Sort::new(Box::new(col("b")), false, true))]) + .sort(vec![Sort::new(Box::new(col("b")), false, true)]) .unwrap_err(); assert_eq!(err.strip_backtrace(), "Error during planning: For SELECT DISTINCT, ORDER BY expressions b must appear in select list"); Ok(()) @@ -452,7 +452,7 @@ async fn sort_on_ambiguous_column() -> Result<()> { &["a"], None, )? - .sort(vec![col("b").sort(true, true).to_expr()]) + .sort(vec![col("b").sort(true, true)]) .unwrap_err(); let expected = "Schema error: Ambiguous reference to unqualified field b"; @@ -599,8 +599,8 @@ async fn test_grouping_sets() -> Result<()> { .await? .aggregate(vec![grouping_set_expr], vec![count(col("a"))])? .sort(vec![ - Expr::Sort(Sort::new(Box::new(col("a")), false, true)), - Expr::Sort(Sort::new(Box::new(col("b")), false, true)), + Sort::new(Box::new(col("a")), false, true), + Sort::new(Box::new(col("b")), false, true), ])?; let results = df.collect().await?; @@ -640,8 +640,8 @@ async fn test_grouping_sets_count() -> Result<()> { .await? .aggregate(vec![grouping_set_expr], vec![count(lit(1))])? .sort(vec![ - Expr::Sort(Sort::new(Box::new(col("c1")), false, true)), - Expr::Sort(Sort::new(Box::new(col("c2")), false, true)), + Sort::new(Box::new(col("c1")), false, true), + Sort::new(Box::new(col("c2")), false, true), ])?; let results = df.collect().await?; @@ -687,8 +687,8 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { ], )? 
.sort(vec![ - Expr::Sort(Sort::new(Box::new(col("c1")), false, true)), - Expr::Sort(Sort::new(Box::new(col("c2")), false, true)), + Sort::new(Box::new(col("c1")), false, true), + Sort::new(Box::new(col("c2")), false, true), ])?; let results = df.collect().await?; diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index e5d2bf16da0b..cbd892672152 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -20,7 +20,7 @@ use arrow_array::builder::{ListBuilder, StringBuilder}; use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StructArray}; use arrow_schema::{DataType, Field}; use datafusion::prelude::*; -use datafusion_common::{assert_contains, DFSchema, ScalarValue}; +use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::ExprFunctionExt; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_functions_aggregate::first_last::first_value_udaf; @@ -167,21 +167,6 @@ fn test_list_range() { ); } -#[tokio::test] -async fn test_aggregate_error() { - let err = first_value_udaf() - .call(vec![col("props")]) - // not a sort column - .order_by(vec![col("id")]) - .build() - .unwrap_err() - .to_string(); - assert_contains!( - err, - "Error during planning: ORDER BY expressions must be Expr::Sort" - ); -} - #[tokio::test] async fn test_aggregate_ext_order_by() { let agg = first_value_udaf().call(vec![col("props")]); @@ -189,14 +174,14 @@ async fn test_aggregate_ext_order_by() { // ORDER BY id ASC let agg_asc = agg .clone() - .order_by(vec![col("id").sort(true, true).to_expr()]) + .order_by(vec![col("id").sort(true, true)]) .build() .unwrap() .alias("asc"); // ORDER BY id DESC let agg_desc = agg - .order_by(vec![col("id").sort(false, true).to_expr()]) + .order_by(vec![col("id").sort(false, true)]) .build() .unwrap() .alias("desc"); @@ -230,7 +215,7 @@ async fn test_aggregate_ext_order_by() { async fn test_aggregate_ext_filter() { let agg = first_value_udaf() 
.call(vec![col("i")]) - .order_by(vec![col("i").sort(true, true).to_expr()]) + .order_by(vec![col("i").sort(true, true)]) .filter(col("i").is_not_null()) .build() .unwrap() @@ -277,7 +262,7 @@ async fn test_aggregate_ext_distinct() { async fn test_aggregate_ext_null_treatment() { let agg = first_value_udaf() .call(vec![col("i")]) - .order_by(vec![col("i").sort(true, true).to_expr()]); + .order_by(vec![col("i").sort(true, true)]); let agg_respect = agg .clone() diff --git a/datafusion/core/tests/fifo/mod.rs b/datafusion/core/tests/fifo/mod.rs index 5ba3104b2cb5..cb587e3510c2 100644 --- a/datafusion/core/tests/fifo/mod.rs +++ b/datafusion/core/tests/fifo/mod.rs @@ -38,7 +38,7 @@ mod unix_test { }; use datafusion_common::instant::Instant; use datafusion_common::{exec_err, Result}; - use datafusion_expr::Expr; + use datafusion_expr::SortExpr; use futures::StreamExt; use nix::sys::stat; @@ -51,7 +51,7 @@ mod unix_test { fn fifo_table( schema: SchemaRef, path: impl Into, - sort: Vec>, + sort: Vec>, ) -> Arc { let source = FileStreamProvider::new_file(schema, path.into()) .with_batch_size(TEST_BATCH_SIZE) @@ -247,7 +247,7 @@ mod unix_test { ])); // Specify the ordering: - let order = vec![vec![datafusion_expr::col("a1").sort(true, false).to_expr()]]; + let order = vec![vec![datafusion_expr::col("a1").sort(true, false)]]; // Set unbounded sorted files read configuration let provider = fifo_table(schema.clone(), left_fifo.clone(), order.clone()); diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 7cd3018bab54..62e9be63983c 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -292,7 +292,7 @@ async fn group_by_string_test( let provider = MemTable::try_new(schema.clone(), vec![input]).unwrap(); let provider = if sorted { - let sort_expr = datafusion::prelude::col("a").sort(true, true).to_expr(); + let sort_expr = 
datafusion::prelude::col("a").sort(true, true); provider.with_sort_order(vec![vec![sort_expr]]) } else { provider diff --git a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs index aa3d8c6f6933..95d97709f319 100644 --- a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs @@ -226,15 +226,15 @@ impl SortedData { } /// Return the sort expression to use for this data, depending on the type - fn sort_expr(&self) -> Vec { + fn sort_expr(&self) -> Vec { match self { Self::I32 { .. } | Self::F64 { .. } | Self::Str { .. } => { - vec![datafusion_expr::col("x").sort(true, true).to_expr()] + vec![datafusion_expr::col("x").sort(true, true)] } Self::I64Str { .. } => { vec![ - datafusion_expr::col("x").sort(true, true).to_expr(), - datafusion_expr::col("y").sort(true, true).to_expr(), + datafusion_expr::col("x").sort(true, true), + datafusion_expr::col("y").sort(true, true), ] } } diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 07c786bf57e4..addabc8a3612 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -38,7 +38,7 @@ async fn join_change_in_planner() -> Result<()> { .map(|e| { let ascending = true; let nulls_first = false; - e.sort(ascending, nulls_first).to_expr() + e.sort(ascending, nulls_first) }) .collect::>()]; register_unbounded_file_with_ordering( @@ -106,7 +106,7 @@ async fn join_no_order_on_filter() -> Result<()> { .map(|e| { let ascending = true; let nulls_first = false; - e.sort(ascending, nulls_first).to_expr() + e.sort(ascending, nulls_first) }) .collect::>()]; register_unbounded_file_with_ordering( diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 6fc5582f2e1a..da27cf8869d1 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ 
-59,7 +59,6 @@ //! use std::fmt::Debug; -use std::ops::Deref; use std::task::{Context, Poll}; use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; @@ -377,7 +376,7 @@ impl OptimizerRule for TopKOptimizerRule { node: Arc::new(TopKPlanNode { k: *fetch, input: input.as_ref().clone(), - expr: expr[0].unwrap_sort().clone(), + expr: expr[0].clone(), }), }))); } @@ -420,7 +419,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { } fn expressions(&self) -> Vec { - vec![self.expr.expr.deref().clone()] + vec![self.expr.expr.as_ref().clone()] } /// For example: `TopK: k=10` diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 61ece20076b9..b81c02ccd0b7 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -289,10 +289,6 @@ pub enum Expr { /// Casts the expression to a given type and will return a null value if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. TryCast(TryCast), - /// A sort expression, that can be used to sort values. - /// - /// See [Expr::sort] for more details - Sort(Sort), /// Represents the call of a scalar function with a set of arguments. 
ScalarFunction(ScalarFunction), /// Calls an aggregate function with arguments, and optional @@ -631,37 +627,23 @@ impl Sort { nulls_first: !self.nulls_first, } } - - // TODO (https://github.com/apache/datafusion/issues/12193) remove when transition is complete - pub fn to_expr(self) -> Expr { - Expr::Sort(self) - } -} - -// TODO (https://github.com/apache/datafusion/issues/12193) remove when transition is complete -pub fn sort_vec_to_expr(sorts: Vec) -> Vec { - sorts.into_iter().map(Expr::Sort).collect() -} - -// TODO (https://github.com/apache/datafusion/issues/12193) remove when transition is complete -pub fn sort_vec_vec_to_expr(sorts: Vec>) -> Vec> { - sorts.into_iter().map(sort_vec_to_expr).collect() } -// TODO (https://github.com/apache/datafusion/issues/12193) remove when transition is complete -pub fn sort_vec_from_expr(exprs: Vec) -> Vec { - exprs - .into_iter() - .map(|expr| match expr { - Expr::Sort(s) => s, - _ => panic!("Expression must be a Expr::Sort: {}", expr), - }) - .collect() -} - -// TODO (https://github.com/apache/datafusion/issues/12193) remove when transition is complete -pub fn sort_vec_vec_from_expr(exprs: Vec>) -> Vec> { - exprs.into_iter().map(sort_vec_from_expr).collect() +impl Display for Sort { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.expr)?; + if self.asc { + write!(f, " ASC")?; + } else { + write!(f, " DESC")?; + } + if self.nulls_first { + write!(f, " NULLS FIRST")?; + } else { + write!(f, " NULLS LAST")?; + } + Ok(()) + } } /// Aggregate function @@ -680,7 +662,7 @@ pub struct AggregateFunction { /// Optional filter pub filter: Option>, /// Optional ordering - pub order_by: Option>, + pub order_by: Option>, pub null_treatment: Option, } @@ -691,7 +673,7 @@ impl AggregateFunction { args: Vec, distinct: bool, filter: Option>, - order_by: Option>, + order_by: Option>, null_treatment: Option, ) -> Self { Self { @@ -816,7 +798,7 @@ pub struct WindowFunction { /// List of partition by expressions 
pub partition_by: Vec, /// List of order by expressions - pub order_by: Vec, + pub order_by: Vec, /// Window frame pub window_frame: window_frame::WindowFrame, /// Specifies how NULL value is treated: ignore or respect @@ -1172,7 +1154,6 @@ impl Expr { Expr::ScalarFunction(..) => "ScalarFunction", Expr::ScalarSubquery { .. } => "ScalarSubquery", Expr::ScalarVariable(..) => "ScalarVariable", - Expr::Sort { .. } => "Sort", Expr::TryCast { .. } => "TryCast", Expr::WindowFunction { .. } => "WindowFunction", Expr::Wildcard { .. } => "Wildcard", @@ -1258,14 +1239,9 @@ impl Expr { Expr::Like(Like::new(true, Box::new(self), Box::new(other), None, true)) } - /// Return the name to use for the specific Expr, recursing into - /// `Expr::Sort` as appropriate + /// Return the name to use for the specific Expr pub fn name_for_alias(&self) -> Result { - match self { - // call Expr::display_name() on a Expr::Sort will throw an error - Expr::Sort(Sort { expr, .. }) => expr.name_for_alias(), - expr => Ok(expr.schema_name().to_string()), - } + Ok(self.schema_name().to_string()) } /// Ensure `expr` has the name as `original_name` by adding an @@ -1281,14 +1257,7 @@ impl Expr { /// Return `self AS name` alias expression pub fn alias(self, name: impl Into) -> Expr { - match self { - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => Expr::Sort(Sort::new(Box::new(expr.alias(name)), asc, nulls_first)), - _ => Expr::Alias(Alias::new(self, None::<&str>, name.into())), - } + Expr::Alias(Alias::new(self, None::<&str>, name.into())) } /// Return `self AS name` alias expression with a specific qualifier @@ -1297,18 +1266,7 @@ impl Expr { relation: Option>, name: impl Into, ) -> Expr { - match self { - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => Expr::Sort(Sort::new( - Box::new(expr.alias_qualified(relation, name)), - asc, - nulls_first, - )), - _ => Expr::Alias(Alias::new(self, relation, name.into())), - } + Expr::Alias(Alias::new(self, relation, name.into())) } /// Remove an 
alias from an expression if one exists. @@ -1403,7 +1361,7 @@ impl Expr { Expr::IsNotNull(Box::new(self)) } - /// Create a sort expression from an existing expression. + /// Create a sort configuration from an existing expression. /// /// ``` /// # use datafusion_expr::col; @@ -1413,14 +1371,6 @@ impl Expr { Sort::new(Box::new(self), asc, nulls_first) } - // TODO (https://github.com/apache/datafusion/issues/12193) remove when transition is complete - pub fn unwrap_sort(&self) -> &Sort { - match self { - Expr::Sort(sort) => sort, - _ => panic!("Expression must be a Expr::Sort: {}", self), - } - } - /// Return `IsTrue(Box(self))` pub fn is_true(self) -> Expr { Expr::IsTrue(Box::new(self)) @@ -1694,7 +1644,6 @@ impl Expr { | Expr::Wildcard { .. } | Expr::WindowFunction(..) | Expr::Literal(..) - | Expr::Sort(..) | Expr::Placeholder(..) => false, } } @@ -1791,14 +1740,6 @@ impl Expr { }) => { data_type.hash(hasher); } - Expr::Sort(Sort { - expr: _expr, - asc, - nulls_first, - }) => { - asc.hash(hasher); - nulls_first.hash(hasher); - } Expr::ScalarFunction(ScalarFunction { func, args: _args }) => { func.hash(hasher); } @@ -1910,7 +1851,6 @@ impl<'a> Display for SchemaDisplay<'a> { Expr::Column(_) | Expr::Literal(_) | Expr::ScalarVariable(..) - | Expr::Sort(_) | Expr::OuterReferenceColumn(..) | Expr::Placeholder(_) | Expr::Wildcard { .. 
} => write!(f, "{}", self.0), @@ -1940,7 +1880,7 @@ impl<'a> Display for SchemaDisplay<'a> { }; if let Some(order_by) = order_by { - write!(f, " ORDER BY [{}]", schema_name_from_exprs(order_by)?)?; + write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?; }; Ok(()) @@ -2146,7 +2086,7 @@ impl<'a> Display for SchemaDisplay<'a> { } if !order_by.is_empty() { - write!(f, " ORDER BY [{}]", schema_name_from_exprs(order_by)?)?; + write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?; }; write!(f, " {window_frame}") @@ -2183,6 +2123,24 @@ fn schema_name_from_exprs_inner(exprs: &[Expr], sep: &str) -> Result Result { + let mut s = String::new(); + for (i, e) in sorts.iter().enumerate() { + if i > 0 { + write!(&mut s, ", ")?; + } + let ordering = if e.asc { "ASC" } else { "DESC" }; + let nulls_ordering = if e.nulls_first { + "NULLS FIRST" + } else { + "NULLS LAST" + }; + write!(&mut s, "{} {} {}", e.expr, ordering, nulls_ordering)?; + } + + Ok(s) +} + /// Format expressions for display as part of a logical plan. In many cases, this will produce /// similar output to `Expr.name()` except that column names will be prefixed with '#'. 
impl fmt::Display for Expr { @@ -2242,22 +2200,6 @@ impl fmt::Display for Expr { }) => write!(f, "{expr} IN ({subquery:?})"), Expr::ScalarSubquery(subquery) => write!(f, "({subquery:?})"), Expr::BinaryExpr(expr) => write!(f, "{expr}"), - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - if *asc { - write!(f, "{expr} ASC")?; - } else { - write!(f, "{expr} DESC")?; - } - if *nulls_first { - write!(f, " NULLS FIRST") - } else { - write!(f, " NULLS LAST") - } - } Expr::ScalarFunction(fun) => { fmt_function(f, fun.name(), false, &fun.args, true) } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 1e0b601146dd..8d01712b95ad 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -26,9 +26,9 @@ use crate::function::{ StateFieldsArgs, }; use crate::{ - conditional_expressions::CaseBuilder, logical_plan::Subquery, AggregateUDF, Expr, - LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, Signature, - Volatility, + conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery, + AggregateUDF, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, + Signature, Volatility, }; use crate::{ AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl, @@ -723,9 +723,7 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { /// ``` pub trait ExprFunctionExt { /// Add `ORDER BY ` - /// - /// Note: `order_by` must be [`Expr::Sort`] - fn order_by(self, order_by: Vec) -> ExprFuncBuilder; + fn order_by(self, order_by: Vec) -> ExprFuncBuilder; /// Add `FILTER ` fn filter(self, filter: Expr) -> ExprFuncBuilder; /// Add `DISTINCT` @@ -753,7 +751,7 @@ pub enum ExprFuncKind { #[derive(Debug, Clone)] pub struct ExprFuncBuilder { fun: Option, - order_by: Option>, + order_by: Option>, filter: Option, distinct: bool, null_treatment: Option, @@ -798,16 +796,6 @@ impl ExprFuncBuilder { ); }; - if let Some(order_by) = &order_by { - for expr in order_by.iter() { 
- if !matches!(expr, Expr::Sort(_)) { - return plan_err!( - "ORDER BY expressions must be Expr::Sort, found {expr:?}" - ); - } - } - } - let fun_expr = match fun { ExprFuncKind::Aggregate(mut udaf) => { udaf.order_by = order_by; @@ -833,9 +821,7 @@ impl ExprFuncBuilder { impl ExprFunctionExt for ExprFuncBuilder { /// Add `ORDER BY ` - /// - /// Note: `order_by` must be [`Expr::Sort`] - fn order_by(mut self, order_by: Vec) -> ExprFuncBuilder { + fn order_by(mut self, order_by: Vec) -> ExprFuncBuilder { self.order_by = Some(order_by); self } @@ -873,7 +859,7 @@ impl ExprFunctionExt for ExprFuncBuilder { } impl ExprFunctionExt for Expr { - fn order_by(self, order_by: Vec) -> ExprFuncBuilder { + fn order_by(self, order_by: Vec) -> ExprFuncBuilder { let mut builder = match self { Expr::AggregateFunction(udaf) => { ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 375ef1edf49a..b809b015d929 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -349,7 +349,6 @@ mod test { use std::ops::Add; use super::*; - use crate::expr::Sort; use crate::{col, lit, Cast}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::ScalarValue; @@ -510,12 +509,6 @@ mod test { // change literal type from i32 to i64 test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64))); - - // SortExpr a+1 ==> b + 2 - test_rewrite( - Expr::Sort(Sort::new(Box::new(col("a").add(lit(1i32))), true, false)), - Expr::Sort(Sort::new(Box::new(col("b").add(lit(2i64))), true, false)), - ); } /// rewrites `expr_from` to `rewrite_to` using @@ -538,15 +531,8 @@ mod test { }; let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap(); - let original_name = match &expr_from { - Expr::Sort(Sort { expr, .. 
}) => expr.schema_name().to_string(), - expr => expr.schema_name().to_string(), - }; - - let new_name = match &expr { - Expr::Sort(Sort { expr, .. }) => expr.schema_name().to_string(), - expr => expr.schema_name().to_string(), - }; + let original_name = expr_from.schema_name().to_string(); + let new_name = expr.schema_name().to_string(); assert_eq!( original_name, new_name, diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index 19ab429c4f99..af5b8c4f9177 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -17,9 +17,9 @@ //! Rewrite for order by expressions -use crate::expr::{Alias, Sort}; +use crate::expr::Alias; use crate::expr_rewriter::normalize_col; -use crate::{Cast, Expr, ExprSchemable, LogicalPlan, TryCast}; +use crate::{expr::Sort, Cast, Expr, ExprSchemable, LogicalPlan, TryCast}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Column, Result}; @@ -27,28 +27,18 @@ use datafusion_common::{Column, Result}; /// Rewrite sort on aggregate expressions to sort on the column of aggregate output /// For example, `max(x)` is written to `col("max(x)")` pub fn rewrite_sort_cols_by_aggs( - exprs: impl IntoIterator>, + sorts: impl IntoIterator>, plan: &LogicalPlan, -) -> Result> { - exprs +) -> Result> { + sorts .into_iter() .map(|e| { - let expr = e.into(); - match expr { - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - let sort = Expr::Sort(Sort::new( - Box::new(rewrite_sort_col_by_aggs(*expr, plan)?), - asc, - nulls_first, - )); - Ok(sort) - } - expr => Ok(expr), - } + let sort = e.into(); + Ok(Sort::new( + Box::new(rewrite_sort_col_by_aggs(*sort.expr, plan)?), + sort.asc, + sort.nulls_first, + )) }) .collect() } @@ -289,8 +279,8 @@ mod test { struct TestCase { desc: &'static str, - input: Expr, - expected: Expr, + input: Sort, + expected: Sort, } impl TestCase { @@ -332,9 +322,9 @@ 
mod test { .unwrap() } - fn sort(expr: Expr) -> Expr { + fn sort(expr: Expr) -> Sort { let asc = true; let nulls_first = true; - expr.sort(asc, nulls_first).to_expr() + expr.sort(asc, nulls_first) } } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 10ec10e61239..cb364f530929 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -18,7 +18,7 @@ use super::{Between, Expr, Like}; use crate::expr::{ AggregateFunction, Alias, BinaryExpr, Cast, InList, InSubquery, Placeholder, - ScalarFunction, Sort, TryCast, Unnest, WindowFunction, + ScalarFunction, TryCast, Unnest, WindowFunction, }; use crate::type_coercion::binary::get_result_type; use crate::type_coercion::functions::{ @@ -107,7 +107,7 @@ impl ExprSchemable for Expr { }, _ => expr.get_type(schema), }, - Expr::Sort(Sort { expr, .. }) | Expr::Negative(expr) => expr.get_type(schema), + Expr::Negative(expr) => expr.get_type(schema), Expr::Column(c) => Ok(schema.data_type(c)?.clone()), Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), Expr::ScalarVariable(ty, _) => Ok(ty.clone()), @@ -280,10 +280,9 @@ impl ExprSchemable for Expr { /// column that does not exist in the schema. fn nullable(&self, input_schema: &dyn ExprSchema) -> Result { match self { - Expr::Alias(Alias { expr, .. }) - | Expr::Not(expr) - | Expr::Negative(expr) - | Expr::Sort(Sort { expr, .. }) => expr.nullable(input_schema), + Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::Negative(expr) => { + expr.nullable(input_schema) + } Expr::InList(InList { expr, list, .. }) => { // Avoid inspecting too many expressions. @@ -422,9 +421,7 @@ impl ExprSchemable for Expr { }, _ => expr.data_type_and_nullable(schema), }, - Expr::Sort(Sort { expr, .. 
}) | Expr::Negative(expr) => { - expr.data_type_and_nullable(schema) - } + Expr::Negative(expr) => expr.data_type_and_nullable(schema), Expr::Column(c) => schema .data_type_and_nullable(c) .map(|(d, n)| (d.clone(), n)), diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 4539cb778c39..ae70c3830116 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -23,7 +23,7 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; use crate::dml::CopyTo; -use crate::expr::{sort_vec_from_expr, sort_vec_to_expr, Alias}; +use crate::expr::{Alias, Sort as SortExpr}; use crate::expr_rewriter::{ coerce_plan_expr_for_schema, normalize_col, normalize_col_with_schemas_and_ambiguity_check, normalize_cols, normalize_sorts, @@ -42,7 +42,7 @@ use crate::utils::{ }; use crate::{ and, binary_expr, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery, - SortExpr, TableProviderFilterPushDown, TableSource, WriteOp, + TableProviderFilterPushDown, TableSource, WriteOp, }; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; @@ -556,12 +556,9 @@ impl LogicalPlanBuilder { /// Apply a sort pub fn sort( self, - exprs: impl IntoIterator> + Clone, + sorts: impl IntoIterator> + Clone, ) -> Result { - let sorts = sort_vec_from_expr(rewrite_sort_cols_by_aggs( - sort_vec_to_expr(exprs.into_iter().map(|s| s.into()).collect()), - &self.plan, - )?); + let sorts = rewrite_sort_cols_by_aggs(sorts, &self.plan)?; let schema = self.plan.schema(); @@ -581,7 +578,7 @@ impl LogicalPlanBuilder { if missing_cols.is_empty() { return Ok(Self::new(LogicalPlan::Sort(Sort { - expr: sort_vec_to_expr(normalize_sorts(sorts, &self.plan)?), + expr: normalize_sorts(sorts, &self.plan)?, input: self.plan, fetch: None, }))); @@ -597,7 +594,7 @@ impl LogicalPlanBuilder { is_distinct, )?; let sort_plan = LogicalPlan::Sort(Sort { - expr: sort_vec_to_expr(normalize_sorts(sorts, &plan)?), + expr: 
normalize_sorts(sorts, &plan)?, input: Arc::new(plan), fetch: None, }); @@ -633,7 +630,7 @@ impl LogicalPlanBuilder { self, on_expr: Vec, select_expr: Vec, - sort_expr: Option>, + sort_expr: Option>, ) -> Result { Ok(Self::new(LogicalPlan::Distinct(Distinct::On( DistinctOn::try_new(on_expr, select_expr, sort_expr, self.plan)?, diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index ad0fcd2d4771..3fc43200efe6 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -22,8 +22,9 @@ use std::{ hash::{Hash, Hasher}, }; -use crate::{Expr, LogicalPlan, Volatility}; +use crate::{Expr, LogicalPlan, SortExpr, Volatility}; +use crate::expr::Sort; use arrow::datatypes::DataType; use datafusion_common::{Constraints, DFSchemaRef, SchemaReference, TableReference}; use sqlparser::ast::Ident; @@ -204,7 +205,7 @@ pub struct CreateExternalTable { /// SQL used to create the table, if available pub definition: Option, /// Order expressions supplied by user - pub order_exprs: Vec>, + pub order_exprs: Vec>, /// Whether the table is an infinite streams pub unbounded: bool, /// Table(provider) specific options @@ -365,7 +366,7 @@ pub struct CreateIndex { pub name: Option, pub table: TableReference, pub using: Option, - pub columns: Vec, + pub columns: Vec, pub unique: bool, pub if_not_exists: bool, pub schema: DFSchemaRef, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 359de2d30a57..8e6ec762f549 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -26,7 +26,9 @@ use super::dml::CopyTo; use super::DdlStatement; use crate::builder::{change_redundant_column, unnest_with_options}; use crate::expr::{Placeholder, Sort as SortExpr, WindowFunction}; -use crate::expr_rewriter::{create_col_from_scalar_expr, normalize_cols, NamePreserver}; +use crate::expr_rewriter::{ + create_col_from_scalar_expr, 
normalize_cols, normalize_sorts, NamePreserver, +}; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; @@ -51,6 +53,7 @@ use datafusion_common::{ // backwards compatibility use crate::display::PgJsonVisitor; +use crate::tree_node::replace_sort_expressions; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; @@ -884,8 +887,12 @@ impl LogicalPlan { Aggregate::try_new(Arc::new(inputs.swap_remove(0)), expr, agg_expr) .map(LogicalPlan::Aggregate) } - LogicalPlan::Sort(Sort { fetch, .. }) => Ok(LogicalPlan::Sort(Sort { - expr, + LogicalPlan::Sort(Sort { + expr: sort_expr, + fetch, + .. + }) => Ok(LogicalPlan::Sort(Sort { + expr: replace_sort_expressions(sort_expr.clone(), expr), input: Arc::new(inputs.swap_remove(0)), fetch: *fetch, })), @@ -1014,14 +1021,11 @@ impl LogicalPlan { }) => { let sort_expr = expr.split_off(on_expr.len() + select_expr.len()); let select_expr = expr.split_off(on_expr.len()); + assert!(sort_expr.is_empty(), "with_new_exprs for Distinct does not support sort expressions"); Distinct::On(DistinctOn::try_new( expr, select_expr, - if !sort_expr.is_empty() { - Some(sort_expr) - } else { - None - }, + None, // no sort expressions accepted Arc::new(inputs.swap_remove(0)), )?) } @@ -2559,7 +2563,7 @@ pub struct DistinctOn { /// The `ORDER BY` clause, whose initial expressions must match those of the `ON` clause when /// present. Note that those matching expressions actually wrap the `ON` expressions with /// additional info pertaining to the sorting procedure (i.e. ASC/DESC, and NULLS FIRST/LAST). 
- pub sort_expr: Option>, + pub sort_expr: Option>, /// The logical plan that is being DISTINCT'd pub input: Arc, /// The schema description of the DISTINCT ON output @@ -2571,7 +2575,7 @@ impl DistinctOn { pub fn try_new( on_expr: Vec, select_expr: Vec, - sort_expr: Option>, + sort_expr: Option>, input: Arc, ) -> Result { if on_expr.is_empty() { @@ -2606,20 +2610,15 @@ impl DistinctOn { /// Try to update `self` with a new sort expressions. /// /// Validates that the sort expressions are a super-set of the `ON` expressions. - pub fn with_sort_expr(mut self, sort_expr: Vec) -> Result { - let sort_expr = normalize_cols(sort_expr, self.input.as_ref())?; + pub fn with_sort_expr(mut self, sort_expr: Vec) -> Result { + let sort_expr = normalize_sorts(sort_expr, self.input.as_ref())?; // Check that the left-most sort expressions are the same as the `ON` expressions. let mut matched = true; for (on, sort) in self.on_expr.iter().zip(sort_expr.iter()) { - match sort { - Expr::Sort(SortExpr { expr, .. 
}) => { - if on != &**expr { - matched = false; - break; - } - } - _ => return plan_err!("Not a sort expression: {sort}"), + if on != &*sort.expr { + matched = false; + break; } } @@ -2833,7 +2832,7 @@ fn calc_func_dependencies_for_project( #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Sort { /// The sort expressions - pub expr: Vec, + pub expr: Vec, /// The incoming logical plan pub input: Arc, /// Optional fetch limit diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 273404c8df31..29a99a8e8886 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -46,7 +46,7 @@ use crate::{ use std::sync::Arc; use crate::expr::{Exists, InSubquery}; -use crate::tree_node::transform_option_vec; +use crate::tree_node::{transform_sort_option_vec, transform_sort_vec}; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, @@ -481,7 +481,9 @@ impl LogicalPlan { .apply_until_stop(|e| f(&e))? .visit_sibling(|| filter.iter().apply_until_stop(f)) } - LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().apply_until_stop(f), + LogicalPlan::Sort(Sort { expr, .. }) => { + expr.iter().apply_until_stop(|sort| f(&sort.expr)) + } LogicalPlan::Extension(extension) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs @@ -507,7 +509,7 @@ impl LogicalPlan { })) => on_expr .iter() .chain(select_expr.iter()) - .chain(sort_expr.iter().flatten()) + .chain(sort_expr.iter().flatten().map(|sort| &*sort.expr)) .apply_until_stop(f), // plans without expressions LogicalPlan::EmptyRelation(_) @@ -658,10 +660,10 @@ impl LogicalPlan { null_equals_null, }) }), - LogicalPlan::Sort(Sort { expr, input, fetch }) => expr - .into_iter() - .map_until_stop_and_collect(f)? 
- .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })), + LogicalPlan::Sort(Sort { expr, input, fetch }) => { + transform_sort_vec(expr, &mut f)? + .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })) + } LogicalPlan::Extension(Extension { node }) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs @@ -709,7 +711,7 @@ impl LogicalPlan { select_expr, select_expr.into_iter().map_until_stop_and_collect(&mut f), sort_expr, - transform_option_vec(sort_expr, &mut f) + transform_sort_option_vec(sort_expr, &mut f) )? .update_data(|(on_expr, select_expr, sort_expr)| { LogicalPlan::Distinct(Distinct::On(DistinctOn { diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 1678d13edea6..69748aded531 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -48,7 +48,6 @@ impl TreeNode for Expr { | Expr::Negative(expr) | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) - | Expr::Sort(Sort { expr, .. }) | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref()], Expr::GroupingSet(GroupingSet::Rollup(exprs)) | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter().collect(), @@ -98,7 +97,7 @@ impl TreeNode for Expr { expr_vec.push(f.as_ref()); } if let Some(order_by) = order_by { - expr_vec.extend(order_by); + expr_vec.extend(order_by.iter().map(|sort| sort.expr.as_ref())); } expr_vec } @@ -110,7 +109,7 @@ impl TreeNode for Expr { }) => { let mut expr_vec = args.iter().collect::>(); expr_vec.extend(partition_by); - expr_vec.extend(order_by); + expr_vec.extend(order_by.iter().map(|sort| sort.expr.as_ref())); expr_vec } Expr::InList(InList { expr, list, .. }) => { @@ -265,12 +264,6 @@ impl TreeNode for Expr { .update_data(|be| Expr::Cast(Cast::new(be, data_type))), Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr, &mut f)? 
.update_data(|be| Expr::TryCast(TryCast::new(be, data_type))), - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => transform_box(expr, &mut f)? - .update_data(|be| Expr::Sort(Sort::new(be, asc, nulls_first))), Expr::ScalarFunction(ScalarFunction { func, args }) => { transform_vec(args, &mut f)?.map_data(|new_args| { Ok(Expr::ScalarFunction(ScalarFunction::new_udf( @@ -290,7 +283,7 @@ impl TreeNode for Expr { partition_by, transform_vec(partition_by, &mut f), order_by, - transform_vec(order_by, &mut f) + transform_sort_vec(order_by, &mut f) )? .update_data(|(new_args, new_partition_by, new_order_by)| { Expr::WindowFunction(WindowFunction::new(fun, new_args)) @@ -313,7 +306,7 @@ impl TreeNode for Expr { filter, transform_option_box(filter, &mut f), order_by, - transform_option_vec(order_by, &mut f) + transform_sort_option_vec(order_by, &mut f) )? .map_data(|(new_args, new_filter, new_order_by)| { Ok(Expr::AggregateFunction(AggregateFunction::new_udf( @@ -387,6 +380,44 @@ fn transform_vec Result>>( ve.into_iter().map_until_stop_and_collect(f) } +pub fn transform_sort_option_vec Result>>( + sorts_option: Option>, + f: &mut F, +) -> Result>>> { + sorts_option.map_or(Ok(Transformed::no(None)), |sorts| { + Ok(transform_sort_vec(sorts, f)?.update_data(Some)) + }) +} + +pub fn transform_sort_vec Result>>( + sorts: Vec, + mut f: &mut F, +) -> Result>> { + Ok(sorts + .iter() + .map(|sort| (*sort.expr).clone()) + .map_until_stop_and_collect(&mut f)? 
+ .update_data(|transformed_exprs| { + replace_sort_expressions(sorts, transformed_exprs) + })) +} + +pub fn replace_sort_expressions(sorts: Vec, new_expr: Vec) -> Vec { + if sorts.len() != new_expr.len() { + panic!( + "Incorrect number of new_expr, expected {}, got {}", + sorts.len(), + new_expr.len() + ); + } + + let mut new_sorts = Vec::with_capacity(sorts.len()); + for (i, expr) in new_expr.into_iter().enumerate() { + new_sorts.push(replace_sort_expression(sorts[i].clone(), expr)); + } + new_sorts +} + pub fn replace_sort_expression(sort: Sort, new_expr: Expr) -> Sort { Sort { expr: Box::new(new_expr), diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 60da38a17b84..b6b1b5660a81 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -296,7 +296,6 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::Case { .. } | Expr::Cast { .. } | Expr::TryCast { .. } - | Expr::Sort { .. } | Expr::ScalarFunction(..) | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } @@ -461,29 +460,27 @@ pub fn expand_qualified_wildcard( /// (expr, "is the SortExpr for window (either comes from PARTITION BY or ORDER BY columns)") /// if bool is true SortExpr comes from `PARTITION BY` column, if false comes from `ORDER BY` column -type WindowSortKey = Vec<(Expr, bool)>; +type WindowSortKey = Vec<(Sort, bool)>; /// Generate a sort key for a given window expr's partition_by and order_by expr pub fn generate_sort_key( partition_by: &[Expr], - order_by: &[Expr], + order_by: &[Sort], ) -> Result { let normalized_order_by_keys = order_by .iter() - .map(|e| match e { - Expr::Sort(Sort { expr, .. }) => { - Ok(Expr::Sort(Sort::new(expr.clone(), true, false))) - } - _ => plan_err!("Order by only accepts sort expressions"), + .map(|e| { + let Sort { expr, .. 
} = e; + Sort::new(expr.clone(), true, false) }) - .collect::>>()?; + .collect::>(); let mut final_sort_keys = vec![]; let mut is_partition_flag = vec![]; partition_by.iter().for_each(|e| { // By default, create sort key with ASC is true and NULLS LAST to be consistent with // PostgreSQL's rule: https://www.postgresql.org/docs/current/queries-order.html - let e = e.clone().sort(true, false).to_expr(); + let e = e.clone().sort(true, false); if let Some(pos) = normalized_order_by_keys.iter().position(|key| key.eq(&e)) { let order_by_key = &order_by[pos]; if !final_sort_keys.contains(order_by_key) { @@ -512,65 +509,61 @@ pub fn generate_sort_key( /// Compare the sort expr as PostgreSQL's common_prefix_cmp(): /// pub fn compare_sort_expr( - sort_expr_a: &Expr, - sort_expr_b: &Expr, + sort_expr_a: &Sort, + sort_expr_b: &Sort, schema: &DFSchemaRef, ) -> Ordering { - match (sort_expr_a, sort_expr_b) { - ( - Expr::Sort(Sort { - expr: expr_a, - asc: asc_a, - nulls_first: nulls_first_a, - }), - Expr::Sort(Sort { - expr: expr_b, - asc: asc_b, - nulls_first: nulls_first_b, - }), - ) => { - let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema); - let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema); - for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) { - match idx_a.cmp(idx_b) { - Ordering::Less => { - return Ordering::Less; - } - Ordering::Greater => { - return Ordering::Greater; - } - Ordering::Equal => {} - } + let Sort { + expr: expr_a, + asc: asc_a, + nulls_first: nulls_first_a, + } = sort_expr_a; + + let Sort { + expr: expr_b, + asc: asc_b, + nulls_first: nulls_first_b, + } = sort_expr_b; + + let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema); + let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema); + for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) { + match idx_a.cmp(idx_b) { + Ordering::Less => { + return Ordering::Less; } - match 
ref_indexes_a.len().cmp(&ref_indexes_b.len()) { - Ordering::Less => return Ordering::Greater, - Ordering::Greater => { - return Ordering::Less; - } - Ordering::Equal => {} + Ordering::Greater => { + return Ordering::Greater; } - match (asc_a, asc_b) { - (true, false) => { - return Ordering::Greater; - } - (false, true) => { - return Ordering::Less; - } - _ => {} - } - match (nulls_first_a, nulls_first_b) { - (true, false) => { - return Ordering::Less; - } - (false, true) => { - return Ordering::Greater; - } - _ => {} - } - Ordering::Equal + Ordering::Equal => {} } - _ => panic!("Sort expressions must be of type Sort"), } + match ref_indexes_a.len().cmp(&ref_indexes_b.len()) { + Ordering::Less => return Ordering::Greater, + Ordering::Greater => { + return Ordering::Less; + } + Ordering::Equal => {} + } + match (asc_a, asc_b) { + (true, false) => { + return Ordering::Greater; + } + (false, true) => { + return Ordering::Less; + } + _ => {} + } + match (nulls_first_a, nulls_first_b) { + (true, false) => { + return Ordering::Less; + } + (false, true) => { + return Ordering::Greater; + } + _ => {} + } + Ordering::Equal } /// group a slice of window expression expr by their order by expressions @@ -606,14 +599,6 @@ pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec { }) } -/// Collect all deeply nested `Expr::Sort`. They are returned in order of occurrence -/// (depth first), with duplicates omitted. -pub fn find_sort_exprs(exprs: &[Expr]) -> Vec { - find_exprs_in_exprs(exprs, &|nested_expr| { - matches!(nested_expr, Expr::Sort { .. }) - }) -} - /// Collect all deeply nested `Expr::WindowFunction`. They are returned in order of occurrence /// (depth first), with duplicates omitted. 
pub fn find_window_exprs(exprs: &[Expr]) -> Vec { @@ -1376,8 +1361,7 @@ mod tests { use crate::{ col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, test::function_stub::max_udaf, test::function_stub::min_udaf, - test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFrame, - WindowFunctionDefinition, + test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition, }; #[test] @@ -1417,10 +1401,9 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys() -> Result<()> { - let age_asc = Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)); - let name_desc = Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)); - let created_at_desc = - Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)); + let age_asc = expr::Sort::new(Box::new(col("age")), true, true); + let name_desc = expr::Sort::new(Box::new(col("name")), false, true); + let created_at_desc = expr::Sort::new(Box::new(col("created_at")), false, true); let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], @@ -1471,43 +1454,6 @@ mod tests { Ok(()) } - #[test] - fn test_find_sort_exprs() -> Result<()> { - let exprs = &[ - Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateUDF(max_udaf()), - vec![col("name")], - )) - .order_by(vec![ - Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - ]) - .window_frame(WindowFrame::new(Some(false))) - .build() - .unwrap(), - Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateUDF(sum_udaf()), - vec![col("age")], - )) - .order_by(vec![ - Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)), - ]) - .window_frame(WindowFrame::new(Some(false))) - 
.build() - .unwrap(), - ]; - let expected = vec![ - Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)), - ]; - let result = find_sort_exprs(exprs); - assert_eq!(expected, result); - Ok(()) - } - #[test] fn avoid_generate_duplicate_sort_keys() -> Result<()> { let asc_or_desc = [true, false]; @@ -1516,41 +1462,41 @@ mod tests { for asc_ in asc_or_desc { for nulls_first_ in nulls_first_or_last { let order_by = &[ - Expr::Sort(Sort { + Sort { expr: Box::new(col("age")), asc: asc_, nulls_first: nulls_first_, - }), - Expr::Sort(Sort { + }, + Sort { expr: Box::new(col("name")), asc: asc_, nulls_first: nulls_first_, - }), + }, ]; let expected = vec![ ( - Expr::Sort(Sort { + Sort { expr: Box::new(col("age")), asc: asc_, nulls_first: nulls_first_, - }), + }, true, ), ( - Expr::Sort(Sort { + Sort { expr: Box::new(col("name")), asc: asc_, nulls_first: nulls_first_, - }), + }, true, ), ( - Expr::Sort(Sort { + Sort { expr: Box::new(col("created_at")), asc: true, nulls_first: false, - }), + }, true, ), ]; diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 38642c255a27..6c935cdcd121 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -26,7 +26,7 @@ use std::fmt::{self, Formatter}; use std::hash::Hash; -use crate::{lit, Expr}; +use crate::{expr::Sort, lit}; use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue}; use sqlparser::ast; @@ -247,7 +247,7 @@ impl WindowFrame { } /// Regularizes the ORDER BY clause of the window frame. - pub fn regularize_order_bys(&self, order_by: &mut Vec) -> Result<()> { + pub fn regularize_order_bys(&self, order_by: &mut Vec) -> Result<()> { match self.units { // Normally, RANGE frames require an ORDER BY clause with exactly // one column. 
However, an ORDER BY clause may be absent or have @@ -259,7 +259,7 @@ impl WindowFrame { // ORDER BY clause is present but has more than one column, // it is unchanged. Note that this follows PostgreSQL behavior. if order_by.is_empty() { - order_by.push(lit(1u64).sort(true, false).to_expr()); + order_by.push(lit(1u64).sort(true, false)); } } WindowFrameUnits::Range if order_by.len() != 1 => { diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 2162442f054e..30f5d5b07561 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -32,7 +32,7 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Expr, ExprFunctionExt, - Signature, TypeSignature, Volatility, + Signature, SortExpr, TypeSignature, Volatility, }; use datafusion_functions_aggregate_common::utils::get_sort_options; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; @@ -40,7 +40,7 @@ use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; create_func!(FirstValue, first_value_udaf); /// Returns the first value in a group of values. 
-pub fn first_value(expression: Expr, order_by: Option>) -> Expr { +pub fn first_value(expression: Expr, order_by: Option>) -> Expr { if let Some(order_by) = order_by { first_value_udaf() .call(vec![expression]) diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index e114efb99960..35d4f91e3b6f 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -229,7 +229,7 @@ mod tests { WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], )) - .order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]) + .order_by(vec![Sort::new(Box::new(col("a")), false, true)]) .window_frame(WindowFrame::new_bounds( WindowFrameUnits::Range, WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index a6b9bad6c5d9..61ff4b4fd5a8 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -33,7 +33,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::{ self, Alias, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, - ScalarFunction, WindowFunction, + ScalarFunction, Sort, WindowFunction, }; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::expr_schema::cast_subquery; @@ -506,7 +506,6 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { | Expr::Negative(_) | Expr::Cast(_) | Expr::TryCast(_) - | Expr::Sort(_) | Expr::Wildcard { .. 
} | Expr::GroupingSet(_) | Expr::Placeholder(_) @@ -593,12 +592,12 @@ fn coerce_frame_bound( fn coerce_window_frame( window_frame: WindowFrame, schema: &DFSchema, - expressions: &[Expr], + expressions: &[Sort], ) -> Result { let mut window_frame = window_frame; let current_types = expressions .iter() - .map(|e| e.get_type(schema)) + .map(|s| s.expr.get_type(schema)) .collect::>>()?; let target_type = match window_frame.units { WindowFrameUnits::Range => { diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 3fcee4123b76..7cce43815b07 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -36,6 +36,7 @@ use datafusion_expr::expr::{Alias, ScalarFunction}; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; +use datafusion_expr::tree_node::replace_sort_expressions; use datafusion_expr::{col, BinaryExpr, Case, Expr, ExprSchemable, Operator}; use indexmap::IndexMap; @@ -327,15 +328,17 @@ impl CommonSubexprEliminate { ) -> Result> { let Sort { expr, input, fetch } = sort; let input = Arc::unwrap_or_clone(input); - let new_sort = self.try_unary_plan(expr, input, config)?.update_data( - |(new_expr, new_input)| { + let sort_expressions = + expr.iter().map(|sort| sort.expr.as_ref().clone()).collect(); + let new_sort = self + .try_unary_plan(sort_expressions, input, config)? 
+ .update_data(|(new_expr, new_input)| { LogicalPlan::Sort(Sort { - expr: new_expr, + expr: replace_sort_expressions(expr, new_expr), input: Arc::new(new_input), fetch, }) - }, - ); + }); Ok(new_sort) } @@ -882,7 +885,6 @@ enum ExprMask { /// - [`Columns`](Expr::Column) /// - [`ScalarVariable`](Expr::ScalarVariable) /// - [`Alias`](Expr::Alias) - /// - [`Sort`](Expr::Sort) /// - [`Wildcard`](Expr::Wildcard) /// - [`AggregateFunction`](Expr::AggregateFunction) Normal, @@ -899,7 +901,6 @@ impl ExprMask { | Expr::Column(..) | Expr::ScalarVariable(..) | Expr::Alias(..) - | Expr::Sort { .. } | Expr::Wildcard { .. } ); diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 65520bee987f..5904da3436a0 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -22,9 +22,8 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::Result; use datafusion_expr::logical_plan::LogicalPlan; -use datafusion_expr::{Aggregate, Expr, Sort}; -use indexmap::IndexSet; -use std::hash::{Hash, Hasher}; +use datafusion_expr::{Aggregate, Expr, Sort, SortExpr}; +use indexmap::{IndexMap, IndexSet}; /// Optimization rule that eliminate duplicated expr. 
#[derive(Default)] pub struct EliminateDuplicatedExpr; @@ -35,33 +34,6 @@ impl EliminateDuplicatedExpr { Self {} } } -// use this structure to avoid initial clone -#[derive(Eq, Clone, Debug)] -struct SortExprWrapper { - expr: Expr, -} -impl PartialEq for SortExprWrapper { - fn eq(&self, other: &Self) -> bool { - match (&self.expr, &other.expr) { - (Expr::Sort(own_sort), Expr::Sort(other_sort)) => { - own_sort.expr == other_sort.expr - } - _ => self.expr == other.expr, - } - } -} -impl Hash for SortExprWrapper { - fn hash(&self, state: &mut H) { - match &self.expr { - Expr::Sort(sort) => { - sort.expr.hash(state); - } - _ => { - self.expr.hash(state); - } - } - } -} impl OptimizerRule for EliminateDuplicatedExpr { fn apply_order(&self) -> Option { Some(ApplyOrder::TopDown) @@ -79,14 +51,15 @@ impl OptimizerRule for EliminateDuplicatedExpr { match plan { LogicalPlan::Sort(sort) => { let len = sort.expr.len(); - let unique_exprs: Vec<_> = sort - .expr - .into_iter() - .map(|e| SortExprWrapper { expr: e }) - .collect::>() - .into_iter() - .map(|wrapper| wrapper.expr) - .collect(); + let mut first_sort_by_expr: IndexMap = + IndexMap::default(); + for s in &sort.expr { + first_sort_by_expr + .entry(s.expr.as_ref().clone()) + .or_insert(s.clone()); + } + let unique_exprs: Vec = + first_sort_by_expr.into_values().collect(); let transformed = if len != unique_exprs.len() { Transformed::yes @@ -150,7 +123,7 @@ mod tests { .limit(5, Some(10))? 
.build()?; let expected = "Limit: skip=5, fetch=10\ - \n Sort: test.a, test.b, test.c\ + \n Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS LAST\ \n TableScan: test"; assert_optimized_plan_eq(plan, expected) } diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 10baa58ad7dc..2503475bd8df 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -189,7 +189,7 @@ mod tests { // After remove global-state, we don't record the parent // So, bottom don't know parent info, so can't eliminate. let expected = "Limit: skip=2, fetch=1\ - \n Sort: test.a, fetch=3\ + \n Sort: test.a ASC NULLS LAST, fetch=3\ \n Limit: skip=0, fetch=2\ \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ \n TableScan: test"; @@ -207,7 +207,7 @@ mod tests { .build()?; let expected = "Limit: skip=0, fetch=1\ - \n Sort: test.a\ + \n Sort: test.a ASC NULLS LAST\ \n Limit: skip=0, fetch=2\ \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ \n TableScan: test"; @@ -225,7 +225,7 @@ mod tests { .build()?; let expected = "Limit: skip=3, fetch=1\ - \n Sort: test.a\ + \n Sort: test.a ASC NULLS LAST\ \n Limit: skip=2, fetch=1\ \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ \n TableScan: test"; diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 6e75f46c2d0b..1cc3cefc4726 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -284,8 +284,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::TryCast(_) | Expr::InList { .. } | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue), - Expr::Sort(_) - | Expr::AggregateFunction(_) + Expr::AggregateFunction(_) | Expr::WindowFunction(_) | Expr::Wildcard { .. 
} | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"), diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 55ce05e5bc0e..ab7880213692 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -353,7 +353,7 @@ mod test { // Should push down limit to sort let expected = "Limit: skip=0, fetch=10\ - \n Sort: test.a, fetch=10\ + \n Sort: test.a ASC NULLS LAST, fetch=10\ \n TableScan: test"; assert_optimized_plan_equal(plan, expected) @@ -370,7 +370,7 @@ mod test { // Should push down limit to sort let expected = "Limit: skip=5, fetch=10\ - \n Sort: test.a, fetch=15\ + \n Sort: test.a ASC NULLS LAST, fetch=15\ \n TableScan: test"; assert_optimized_plan_equal(plan, expected) diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index b66cbf626b24..c887192f6370 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -21,7 +21,6 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::{Column, Result}; -use datafusion_expr::expr::sort_vec_from_expr; use datafusion_expr::expr_rewriter::normalize_cols; use datafusion_expr::utils::expand_wildcard; use datafusion_expr::{col, ExprFunctionExt, LogicalPlanBuilder}; @@ -145,7 +144,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { // truncate the sort_expr to the length of on_expr sort_expr.truncate(expr_cnt); - lpb.sort(sort_vec_from_expr(sort_expr))?.build()? + lpb.sort(sort_expr)?.build()? } else { lpb.build()? 
}; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index c45df74a564d..8997a711aace 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -591,7 +591,6 @@ impl<'a> ConstEvaluator<'a> { | Expr::InSubquery(_) | Expr::ScalarSubquery(_) | Expr::WindowFunction { .. } - | Expr::Sort { .. } | Expr::GroupingSet(_) | Expr::Wildcard { .. } | Expr::Placeholder(_) => false, diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index d82ee7d4dbbc..dd82b056d0a6 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -624,14 +624,14 @@ mod tests { vec![col("a")], false, None, - Some(vec![col("a")]), + Some(vec![col("a").sort(true, false)]), None, )); let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a]:UInt64;N, count(DISTINCT test.b):Int64]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a ASC NULLS LAST], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a ASC NULLS LAST]:UInt64;N, count(DISTINCT test.b):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -645,7 +645,7 @@ mod tests { let expr = count_udaf() .call(vec![col("a")]) .distinct() - .order_by(vec![col("a").sort(true, false).to_expr()]) + .order_by(vec![col("a").sort(true, false)]) .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? 
@@ -666,7 +666,7 @@ mod tests { .call(vec![col("a")]) .distinct() .filter(col("a").gt(lit(5))) - .order_by(vec![col("a").sort(true, false).to_expr()]) + .order_by(vec![col("a").sort(true, false)]) .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 826992e132ba..19759a897068 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -75,6 +75,10 @@ message LogicalExprNodeCollection { repeated LogicalExprNode logical_expr_nodes = 1; } +message SortExprNodeCollection { + repeated SortExprNode sort_expr_nodes = 1; +} + message ListingTableScanNode { reserved 1; // was string table_name TableReference table_name = 14; @@ -92,7 +96,7 @@ message ListingTableScanNode { datafusion_common.AvroFormat avro = 12; datafusion_common.NdJsonFormat json = 15; } - repeated LogicalExprNodeCollection file_sort_order = 13; + repeated SortExprNodeCollection file_sort_order = 13; } message ViewTableScanNode { @@ -129,7 +133,7 @@ message SelectionNode { message SortNode { LogicalPlanNode input = 1; - repeated LogicalExprNode expr = 2; + repeated SortExprNode expr = 2; // Maximum number of highest/lowest rows to fetch; negative means no limit int64 fetch = 3; } @@ -160,7 +164,7 @@ message CreateExternalTableNode { repeated string table_partition_cols = 5; bool if_not_exists = 6; string definition = 7; - repeated LogicalExprNodeCollection order_exprs = 10; + repeated SortExprNodeCollection order_exprs = 10; bool unbounded = 11; map options = 8; datafusion_common.Constraints constraints = 12; @@ -245,7 +249,7 @@ message DistinctNode { message DistinctOnNode { repeated LogicalExprNode on_expr = 1; repeated LogicalExprNode select_expr = 2; - repeated LogicalExprNode sort_expr = 3; + repeated SortExprNode sort_expr = 3; LogicalPlanNode input = 4; } @@ -320,7 +324,6 @@ message LogicalExprNode 
{ BetweenNode between = 9; CaseNode case_ = 10; CastNode cast = 11; - SortExprNode sort = 12; NegativeNode negative = 13; InListNode in_list = 14; Wildcard wildcard = 15; @@ -470,7 +473,7 @@ message AggregateUDFExprNode { repeated LogicalExprNode args = 2; bool distinct = 5; LogicalExprNode filter = 3; - repeated LogicalExprNode order_by = 4; + repeated SortExprNode order_by = 4; optional bytes fun_definition = 6; } @@ -503,7 +506,7 @@ message WindowExprNode { } LogicalExprNode expr = 4; repeated LogicalExprNode partition_by = 5; - repeated LogicalExprNode order_by = 6; + repeated SortExprNode order_by = 6; // repeated LogicalExprNode filter = 7; WindowFrame window_frame = 8; optional bytes fun_definition = 10; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index b4d63798f080..cff58d3ddc4a 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -9291,9 +9291,6 @@ impl serde::Serialize for LogicalExprNode { logical_expr_node::ExprType::Cast(v) => { struct_ser.serialize_field("cast", v)?; } - logical_expr_node::ExprType::Sort(v) => { - struct_ser.serialize_field("sort", v)?; - } logical_expr_node::ExprType::Negative(v) => { struct_ser.serialize_field("negative", v)?; } @@ -9384,7 +9381,6 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { "case_", "case", "cast", - "sort", "negative", "in_list", "inList", @@ -9433,7 +9429,6 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { Between, Case, Cast, - Sort, Negative, InList, Wildcard, @@ -9486,7 +9481,6 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { "between" => Ok(GeneratedField::Between), "case" | "case_" => Ok(GeneratedField::Case), "cast" => Ok(GeneratedField::Cast), - "sort" => Ok(GeneratedField::Sort), "negative" => Ok(GeneratedField::Negative), "inList" | "in_list" => Ok(GeneratedField::InList), "wildcard" => Ok(GeneratedField::Wildcard), @@ -9598,13 +9592,6 @@ impl<'de> 
serde::Deserialize<'de> for LogicalExprNode { return Err(serde::de::Error::duplicate_field("cast")); } expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Cast) -; - } - GeneratedField::Sort => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("sort")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Sort) ; } GeneratedField::Negative => { @@ -17947,6 +17934,98 @@ impl<'de> serde::Deserialize<'de> for SortExprNode { deserializer.deserialize_struct("datafusion.SortExprNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for SortExprNodeCollection { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.sort_expr_nodes.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SortExprNodeCollection", len)?; + if !self.sort_expr_nodes.is_empty() { + struct_ser.serialize_field("sortExprNodes", &self.sort_expr_nodes)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for SortExprNodeCollection { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "sort_expr_nodes", + "sortExprNodes", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + SortExprNodes, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> 
std::result::Result + where + E: serde::de::Error, + { + match value { + "sortExprNodes" | "sort_expr_nodes" => Ok(GeneratedField::SortExprNodes), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SortExprNodeCollection; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SortExprNodeCollection") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut sort_expr_nodes__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::SortExprNodes => { + if sort_expr_nodes__.is_some() { + return Err(serde::de::Error::duplicate_field("sortExprNodes")); + } + sort_expr_nodes__ = Some(map_.next_value()?); + } + } + } + Ok(SortExprNodeCollection { + sort_expr_nodes: sort_expr_nodes__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.SortExprNodeCollection", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for SortNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 875d2af75dd7..2ce8004e3248 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -97,6 +97,12 @@ pub struct LogicalExprNodeCollection { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct SortExprNodeCollection { + #[prost(message, repeated, tag = "1")] + pub sort_expr_nodes: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct ListingTableScanNode { #[prost(message, optional, tag = "14")] pub table_name: ::core::option::Option, 
@@ -117,7 +123,7 @@ pub struct ListingTableScanNode { #[prost(uint32, tag = "9")] pub target_partitions: u32, #[prost(message, repeated, tag = "13")] - pub file_sort_order: ::prost::alloc::vec::Vec, + pub file_sort_order: ::prost::alloc::vec::Vec, #[prost(oneof = "listing_table_scan_node::FileFormatType", tags = "10, 11, 12, 15")] pub file_format_type: ::core::option::Option< listing_table_scan_node::FileFormatType, @@ -200,7 +206,7 @@ pub struct SortNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, repeated, tag = "2")] - pub expr: ::prost::alloc::vec::Vec, + pub expr: ::prost::alloc::vec::Vec, /// Maximum number of highest/lowest rows to fetch; negative means no limit #[prost(int64, tag = "3")] pub fetch: i64, @@ -256,7 +262,7 @@ pub struct CreateExternalTableNode { #[prost(string, tag = "7")] pub definition: ::prost::alloc::string::String, #[prost(message, repeated, tag = "10")] - pub order_exprs: ::prost::alloc::vec::Vec, + pub order_exprs: ::prost::alloc::vec::Vec, #[prost(bool, tag = "11")] pub unbounded: bool, #[prost(map = "string, string", tag = "8")] @@ -402,7 +408,7 @@ pub struct DistinctOnNode { #[prost(message, repeated, tag = "2")] pub select_expr: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "3")] - pub sort_expr: ::prost::alloc::vec::Vec, + pub sort_expr: ::prost::alloc::vec::Vec, #[prost(message, optional, boxed, tag = "4")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, } @@ -488,7 +494,7 @@ pub struct SubqueryAliasNode { pub struct LogicalExprNode { #[prost( oneof = "logical_expr_node::ExprType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35" )] pub expr_type: ::core::option::Option, } @@ -521,8 +527,6 @@ pub mod logical_expr_node 
{ Case(::prost::alloc::boxed::Box), #[prost(message, tag = "11")] Cast(::prost::alloc::boxed::Box), - #[prost(message, tag = "12")] - Sort(::prost::alloc::boxed::Box), #[prost(message, tag = "13")] Negative(::prost::alloc::boxed::Box), #[prost(message, tag = "14")] @@ -740,7 +744,7 @@ pub struct AggregateUdfExprNode { #[prost(message, optional, boxed, tag = "3")] pub filter: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, repeated, tag = "4")] - pub order_by: ::prost::alloc::vec::Vec, + pub order_by: ::prost::alloc::vec::Vec, #[prost(bytes = "vec", optional, tag = "6")] pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, } @@ -762,7 +766,7 @@ pub struct WindowExprNode { #[prost(message, repeated, tag = "5")] pub partition_by: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "6")] - pub order_by: ::prost::alloc::vec::Vec, + pub order_by: ::prost::alloc::vec::Vec, /// repeated LogicalExprNode filter = 7; #[prost(message, optional, tag = "8")] pub window_frame: ::core::option::Option, @@ -869,8 +873,8 @@ pub struct TryCastNode { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct SortExprNode { - #[prost(message, optional, boxed, tag = "1")] - pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "1")] + pub expr: ::core::option::Option, #[prost(bool, tag = "2")] pub asc: bool, #[prost(bool, tag = "3")] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index acda1298dd80..3ba1cb945e9c 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -22,11 +22,11 @@ use datafusion_common::{ exec_datafusion_err, internal_err, plan_datafusion_err, Result, ScalarValue, TableReference, UnnestOptions, }; -use datafusion_expr::expr::{Alias, Placeholder}; +use datafusion_expr::expr::{Alias, Placeholder, Sort}; use 
datafusion_expr::expr::{Unnest, WildcardOptions}; use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ - expr::{self, InList, Sort, WindowFunction}, + expr::{self, InList, WindowFunction}, logical_plan::{PlanType, StringifiedPlan}, Between, BinaryExpr, BuiltInWindowFunction, Case, Cast, Expr, GroupingSet, GroupingSet::GroupingSets, @@ -267,7 +267,7 @@ pub fn parse_expr( .as_ref() .ok_or_else(|| Error::required("window_function"))?; let partition_by = parse_exprs(&expr.partition_by, registry, codec)?; - let mut order_by = parse_exprs(&expr.order_by, registry, codec)?; + let mut order_by = parse_sorts(&expr.order_by, registry, codec)?; let window_frame = expr .window_frame .as_ref() @@ -524,16 +524,6 @@ pub fn parse_expr( let data_type = cast.arrow_type.as_ref().required("arrow_type")?; Ok(Expr::TryCast(TryCast::new(expr, data_type))) } - ExprType::Sort(sort) => Ok(Expr::Sort(Sort::new( - Box::new(parse_required_expr( - sort.expr.as_deref(), - registry, - "expr", - codec, - )?), - sort.asc, - sort.nulls_first, - ))), ExprType::Negative(negative) => Ok(Expr::Negative(Box::new( parse_required_expr(negative.expr.as_deref(), registry, "expr", codec)?, ))), @@ -588,7 +578,7 @@ pub fn parse_expr( parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), match pb.order_by.len() { 0 => None, - _ => Some(parse_exprs(&pb.order_by, registry, codec)?), + _ => Some(parse_sorts(&pb.order_by, registry, codec)?), }, None, ))) @@ -635,6 +625,37 @@ where Ok(res) } +pub fn parse_sorts<'a, I>( + protos: I, + registry: &dyn FunctionRegistry, + codec: &dyn LogicalExtensionCodec, +) -> Result, Error> +where + I: IntoIterator, +{ + protos + .into_iter() + .map(|sort| parse_sort(sort, registry, codec)) + .collect::, Error>>() +} + +pub fn parse_sort( + sort: &protobuf::SortExprNode, + registry: &dyn FunctionRegistry, + codec: &dyn LogicalExtensionCodec, +) -> Result { + Ok(Sort::new( + Box::new(parse_required_expr( + sort.expr.as_ref(), + registry, + "expr", + codec, 
+ )?), + sort.asc, + sort.nulls_first, + )) +} + /// Parse an optional escape_char for Like, ILike, SimilarTo fn parse_escape_char(s: &str) -> Result> { match s.len() { diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index c46ece2c06ae..bf5394ec01de 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -20,7 +20,7 @@ use std::fmt::Debug; use std::sync::Arc; use crate::protobuf::logical_plan_node::LogicalPlanType::CustomScan; -use crate::protobuf::{CustomTableScanNode, LogicalExprNodeCollection}; +use crate::protobuf::{CustomTableScanNode, SortExprNodeCollection}; use crate::{ convert_required, into_required, protobuf::{ @@ -62,14 +62,13 @@ use datafusion_expr::{ EmptyRelation, Extension, Join, JoinConstraint, Limit, Prepare, Projection, Repartition, Sort, SubqueryAlias, TableScan, Values, Window, }, - DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, WindowUDF, + DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, SortExpr, + WindowUDF, }; use datafusion_expr::{AggregateUDF, Unnest}; use self::to_proto::{serialize_expr, serialize_exprs}; -use datafusion_expr::expr::{ - sort_vec_from_expr, sort_vec_to_expr, sort_vec_vec_from_expr, -}; +use crate::logical_plan::to_proto::serialize_sorts; use prost::bytes::BufMut; use prost::Message; @@ -350,8 +349,8 @@ impl AsLogicalPlan for LogicalPlanNode { let mut all_sort_orders = vec![]; for order in &scan.file_sort_order { - all_sort_orders.push(from_proto::parse_exprs( - &order.logical_expr_nodes, + all_sort_orders.push(from_proto::parse_sorts( + &order.sort_expr_nodes, ctx, extension_codec, )?) 
@@ -417,7 +416,7 @@ impl AsLogicalPlan for LogicalPlanNode { ) .with_collect_stat(scan.collect_stat) .with_target_partitions(scan.target_partitions as usize) - .with_file_sort_order(sort_vec_vec_from_expr(all_sort_orders)); + .with_file_sort_order(all_sort_orders); let config = ListingTableConfig::new_with_multi_paths(table_paths.clone()) @@ -479,11 +478,9 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanType::Sort(sort) => { let input: LogicalPlan = into_logical_plan!(sort.input, ctx, extension_codec)?; - let sort_expr: Vec = - from_proto::parse_exprs(&sort.expr, ctx, extension_codec)?; - LogicalPlanBuilder::from(input) - .sort(sort_vec_from_expr(sort_expr))? - .build() + let sort_expr: Vec = + from_proto::parse_sorts(&sort.expr, ctx, extension_codec)?; + LogicalPlanBuilder::from(input).sort(sort_expr)?.build() } LogicalPlanType::Repartition(repartition) => { use datafusion::logical_expr::Partitioning; @@ -541,8 +538,8 @@ impl AsLogicalPlan for LogicalPlanNode { let mut order_exprs = vec![]; for expr in &create_extern_table.order_exprs { - order_exprs.push(from_proto::parse_exprs( - &expr.logical_expr_nodes, + order_exprs.push(from_proto::parse_sorts( + &expr.sort_expr_nodes, ctx, extension_codec, )?); @@ -777,7 +774,7 @@ impl AsLogicalPlan for LogicalPlanNode { )?; let sort_expr = match distinct_on.sort_expr.len() { 0 => None, - _ => Some(from_proto::parse_exprs( + _ => Some(from_proto::parse_sorts( &distinct_on.sort_expr, ctx, extension_codec, @@ -986,13 +983,10 @@ impl AsLogicalPlan for LogicalPlanNode { let options = listing_table.options(); - let mut exprs_vec: Vec = vec![]; + let mut exprs_vec: Vec = vec![]; for order in &options.file_sort_order { - let expr_vec = LogicalExprNodeCollection { - logical_expr_nodes: serialize_exprs( - &sort_vec_to_expr(order.clone()), - extension_codec, - )?, + let expr_vec = SortExprNodeCollection { + sort_expr_nodes: serialize_sorts(order, extension_codec)?, }; exprs_vec.push(expr_vec); } @@ -1122,7 +1116,7 @@ impl 
AsLogicalPlan for LogicalPlanNode { )?; let sort_expr = match sort_expr { None => vec![], - Some(sort_expr) => serialize_exprs(sort_expr, extension_codec)?, + Some(sort_expr) => serialize_sorts(sort_expr, extension_codec)?, }; Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::DistinctOn(Box::new( @@ -1266,13 +1260,13 @@ impl AsLogicalPlan for LogicalPlanNode { input.as_ref(), extension_codec, )?; - let selection_expr: Vec = - serialize_exprs(expr, extension_codec)?; + let sort_expr: Vec = + serialize_sorts(expr, extension_codec)?; Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Sort(Box::new( protobuf::SortNode { input: Some(Box::new(input)), - expr: selection_expr, + expr: sort_expr, fetch: fetch.map(|f| f as i64).unwrap_or(-1i64), }, ))), @@ -1342,10 +1336,10 @@ impl AsLogicalPlan for LogicalPlanNode { column_defaults, }, )) => { - let mut converted_order_exprs: Vec = vec![]; + let mut converted_order_exprs: Vec = vec![]; for order in order_exprs { - let temp = LogicalExprNodeCollection { - logical_expr_nodes: serialize_exprs(order, extension_codec)?, + let temp = SortExprNodeCollection { + sort_expr_nodes: serialize_sorts(order, extension_codec)?, }; converted_order_exprs.push(temp); } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index bb7bf84a3387..b937c03f79d9 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -22,12 +22,12 @@ use datafusion_common::{TableReference, UnnestOptions}; use datafusion_expr::expr::{ self, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, Like, Placeholder, - ScalarFunction, Sort, Unnest, + ScalarFunction, Unnest, }; use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, BuiltInWindowFunction, Expr, - JoinConstraint, JoinType, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, - WindowFunctionDefinition, + JoinConstraint, JoinType, 
SortExpr, TryCast, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunctionDefinition, }; use crate::protobuf::{ @@ -343,7 +343,7 @@ pub fn serialize_expr( None }; let partition_by = serialize_exprs(partition_by, codec)?; - let order_by = serialize_exprs(order_by, codec)?; + let order_by = serialize_sorts(order_by, codec)?; let window_frame: Option = Some(window_frame.try_into()?); @@ -380,7 +380,7 @@ pub fn serialize_expr( None => None, }, order_by: match order_by { - Some(e) => serialize_exprs(e, codec)?, + Some(e) => serialize_sorts(e, codec)?, None => vec![], }, fun_definition: (!buf.is_empty()).then_some(buf), @@ -537,20 +537,6 @@ pub fn serialize_expr( expr_type: Some(ExprType::TryCast(expr)), } } - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - let expr = Box::new(protobuf::SortExprNode { - expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - asc: *asc, - nulls_first: *nulls_first, - }); - protobuf::LogicalExprNode { - expr_type: Some(ExprType::Sort(expr)), - } - } Expr::Negative(expr) => { let expr = Box::new(protobuf::NegativeNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), @@ -635,6 +621,30 @@ pub fn serialize_expr( Ok(expr_node) } +pub fn serialize_sorts<'a, I>( + sorts: I, + codec: &dyn LogicalExtensionCodec, +) -> Result, Error> +where + I: IntoIterator, +{ + sorts + .into_iter() + .map(|sort| { + let SortExpr { + expr, + asc, + nulls_first, + } = sort; + Ok(protobuf::SortExprNode { + expr: Some(serialize_expr(expr.as_ref(), codec)?), + asc: *asc, + nulls_first: *nulls_first, + }) + }) + .collect::, Error>>() +} + impl From for protobuf::TableReference { fn from(t: TableReference) -> Self { use protobuf::table_reference::TableReferenceEnum; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index f6f9702f6cc1..76c6d7e068ce 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ 
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -59,7 +59,7 @@ use datafusion_common::{ use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ self, Between, BinaryExpr, Case, Cast, GroupingSet, InList, Like, ScalarFunction, - Sort, Unnest, WildcardOptions, + Unnest, WildcardOptions, }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ @@ -871,7 +871,7 @@ async fn roundtrip_expr_api() -> Result<()> { count(lit(1)), count_distinct(lit(1)), first_value(lit(1), None), - first_value(lit(1), Some(vec![lit(2).sort(true, true).to_expr()])), + first_value(lit(1), Some(vec![lit(2).sort(true, true)])), avg(lit(1.5)), covar_samp(lit(1.5), lit(2.2)), covar_pop(lit(1.5), lit(2.2)), @@ -1937,14 +1937,6 @@ fn roundtrip_try_cast() { roundtrip_expr_test(test_expr, ctx); } -#[test] -fn roundtrip_sort_expr() { - let test_expr = Expr::Sort(Sort::new(Box::new(lit(1.0_f32)), true, true)); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); -} - #[test] fn roundtrip_negative() { let test_expr = Expr::Negative(Box::new(lit(1.0_f32))); @@ -2249,7 +2241,7 @@ fn roundtrip_window() { vec![], )) .partition_by(vec![col("col1")]) - .order_by(vec![col("col2").sort(true, false).to_expr()]) + .order_by(vec![col("col2").sort(true, false)]) .window_frame(WindowFrame::new(Some(false))) .build() .unwrap(); @@ -2262,7 +2254,7 @@ fn roundtrip_window() { vec![], )) .partition_by(vec![col("col1")]) - .order_by(vec![col("col2").sort(false, true).to_expr()]) + .order_by(vec![col("col2").sort(false, true)]) .window_frame(WindowFrame::new(Some(false))) .build() .unwrap(); @@ -2281,7 +2273,7 @@ fn roundtrip_window() { vec![], )) .partition_by(vec![col("col1")]) - .order_by(vec![col("col2").sort(false, false).to_expr()]) + .order_by(vec![col("col2").sort(false, false)]) .window_frame(range_number_frame) .build() .unwrap(); @@ -2298,7 +2290,7 @@ fn roundtrip_window() { vec![col("col1")], )) 
.partition_by(vec![col("col1")]) - .order_by(vec![col("col2").sort(true, true).to_expr()]) + .order_by(vec![col("col2").sort(true, true)]) .window_frame(row_number_frame.clone()) .build() .unwrap(); @@ -2348,7 +2340,7 @@ fn roundtrip_window() { vec![col("col1")], )) .partition_by(vec![col("col1")]) - .order_by(vec![col("col2").sort(true, true).to_expr()]) + .order_by(vec![col("col2").sort(true, true)]) .window_frame(row_number_frame.clone()) .build() .unwrap(); @@ -2425,7 +2417,7 @@ fn roundtrip_window() { vec![col("col1")], )) .partition_by(vec![col("col1")]) - .order_by(vec![col("col2").sort(true, true).to_expr()]) + .order_by(vec![col("col2").sort(true, true)]) .window_frame(row_number_frame.clone()) .build() .unwrap(); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 72e08e4b8fb5..cfc10fbc39d1 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -283,22 +283,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let func_deps = schema.functional_dependencies(); // Find whether ties are possible in the given ordering let is_ordering_strict = order_by.iter().find_map(|orderby_expr| { - if let Expr::Sort(sort_expr) = orderby_expr { - if let Expr::Column(col) = sort_expr.expr.as_ref() { - let idx = schema.index_of_column(col).ok()?; - return if func_deps.iter().any(|dep| { - dep.source_indices == vec![idx] - && dep.mode == Dependency::Single - }) { - Some(true) - } else { - Some(false) - }; - } - Some(false) - } else { - panic!("order_by expression must be of type Sort"); + if let Expr::Column(col) = orderby_expr.expr.as_ref() { + let idx = schema.index_of_column(col).ok()?; + return if func_deps.iter().any(|dep| { + dep.source_indices == vec![idx] && dep.mode == Dependency::Single + }) { + Some(true) + } else { + Some(false) + }; } + Some(false) }); let window_frame = window diff --git a/datafusion/sql/src/expr/order_by.rs b/datafusion/sql/src/expr/order_by.rs index 
7fb32f714cfa..cdaa787cedd0 100644 --- a/datafusion/sql/src/expr/order_by.rs +++ b/datafusion/sql/src/expr/order_by.rs @@ -20,7 +20,7 @@ use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, Result, }; use datafusion_expr::expr::Sort; -use datafusion_expr::Expr; +use datafusion_expr::{Expr, SortExpr}; use sqlparser::ast::{Expr as SQLExpr, OrderByExpr, Value}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -44,7 +44,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context: &mut PlannerContext, literal_to_column: bool, additional_schema: Option<&DFSchema>, - ) -> Result> { + ) -> Result> { if exprs.is_empty() { return Ok(vec![]); } @@ -99,13 +99,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }; let asc = asc.unwrap_or(true); - expr_vec.push(Expr::Sort(Sort::new( + expr_vec.push(Sort::new( Box::new(expr), asc, // when asc is true, by default nulls last to be consistent with postgres // postgres rule: https://www.postgresql.org/docs/current/queries-order.html nulls_first.unwrap_or(!asc), - ))) + )) } Ok(expr_vec) } diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 1c2b189d266c..71328cfd018c 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{not_impl_err, plan_err, Constraints, Result, ScalarValue}; -use datafusion_expr::expr::sort_vec_from_expr; +use datafusion_expr::expr::Sort; use datafusion_expr::{ CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, Operator, @@ -120,7 +120,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn order_by( &self, plan: LogicalPlan, - order_by: Vec, + order_by: Vec, ) -> Result { if order_by.is_empty() { return Ok(plan); @@ -132,9 +132,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let distinct_on = distinct_on.clone().with_sort_expr(order_by)?; 
Ok(LogicalPlan::Distinct(Distinct::On(distinct_on))) } else { - LogicalPlanBuilder::from(plan) - .sort(sort_vec_from_expr(order_by))? - .build() + LogicalPlanBuilder::from(plan).sort(order_by)?.build() } } diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 384893bfa94c..8a26671fcb6c 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -31,7 +31,7 @@ use datafusion_common::UnnestOptions; use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_expr::expr::{Alias, PlannedReplaceSelectItem, WildcardOptions}; use datafusion_expr::expr_rewriter::{ - normalize_col, normalize_col_with_schemas_and_ambiguity_check, normalize_cols, + normalize_col, normalize_col_with_schemas_and_ambiguity_check, normalize_sorts, }; use datafusion_expr::utils::{ expr_as_column_expr, expr_to_columns, find_aggregate_exprs, find_window_exprs, @@ -107,7 +107,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { true, Some(base_plan.schema().as_ref()), )?; - let order_by_rex = normalize_cols(order_by_rex, &projected_plan)?; + let order_by_rex = normalize_sorts(order_by_rex, &projected_plan)?; // this alias map is resolved and looked up in both having exprs and group by exprs let alias_map = extract_aliases(&select_exprs); diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index e75a96e78d48..3dfc379b039a 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -48,9 +48,10 @@ use datafusion_expr::{ CreateIndex as PlanCreateIndex, CreateMemoryTable, CreateView, DescribeTable, DmlStatement, DropCatalogSchema, DropFunction, DropTable, DropView, EmptyRelation, Explain, Expr, ExprSchemable, Filter, LogicalPlan, LogicalPlanBuilder, - OperateFunctionArg, PlanType, Prepare, SetVariable, Statement as PlanStatement, - ToStringifiedPlan, TransactionAccessMode, TransactionConclusion, TransactionEnd, - TransactionIsolationLevel, TransactionStart, Volatility, WriteOp, + 
OperateFunctionArg, PlanType, Prepare, SetVariable, SortExpr, + Statement as PlanStatement, ToStringifiedPlan, TransactionAccessMode, + TransactionConclusion, TransactionEnd, TransactionIsolationLevel, TransactionStart, + Volatility, WriteOp, }; use sqlparser::ast; use sqlparser::ast::{ @@ -952,7 +953,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { order_exprs: Vec, schema: &DFSchemaRef, planner_context: &mut PlannerContext, - ) -> Result>> { + ) -> Result>> { // Ask user to provide a schema if schema is empty. if !order_exprs.is_empty() && schema.fields().is_empty() { return plan_err!( @@ -966,8 +967,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let expr_vec = self.order_by_to_sort_expr(expr, schema, planner_context, true, None)?; // Verify that columns of all SortExprs exist in the schema: - for expr in expr_vec.iter() { - for column in expr.column_refs().iter() { + for sort in expr_vec.iter() { + for column in sort.expr.column_refs().iter() { if !schema.has_column(column) { // Return an error if any column is not in the schema: return plan_err!("Column {column} is not in schema"); diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index a59b64723730..a61a0e095d3c 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-use core::fmt; - use datafusion_expr::ScalarUDF; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ @@ -24,7 +22,7 @@ use sqlparser::ast::{ ObjectName, TimezoneInfo, UnaryOperator, }; use std::sync::Arc; -use std::{fmt::Display, vec}; +use std::vec; use super::dialect::{DateFieldExtractStyle, IntervalStyle}; use super::Unparser; @@ -46,33 +44,6 @@ use datafusion_expr::{ Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Operator, TryCast, }; -/// DataFusion's Exprs can represent either an `Expr` or an `OrderByExpr` -pub enum Unparsed { - // SQL Expression - Expr(ast::Expr), - // SQL ORDER BY expression (e.g. `col ASC NULLS FIRST`) - OrderByExpr(ast::OrderByExpr), -} - -impl Unparsed { - pub fn into_order_by_expr(self) -> Result { - if let Unparsed::OrderByExpr(order_by_expr) = self { - Ok(order_by_expr) - } else { - internal_err!("Expected Sort expression to be converted an OrderByExpr") - } - } -} - -impl Display for Unparsed { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Unparsed::Expr(expr) => write!(f, "{}", expr), - Unparsed::OrderByExpr(order_by_expr) => write!(f, "{}", order_by_expr), - } - } -} - /// Convert a DataFusion [`Expr`] to [`ast::Expr`] /// /// This function is the opposite of [`SqlToRel::sql_to_expr`] and can be used @@ -106,13 +77,9 @@ pub fn expr_to_sql(expr: &Expr) -> Result { unparser.expr_to_sql(expr) } -/// Convert a DataFusion [`Expr`] to [`Unparsed`] -/// -/// This function is similar to expr_to_sql, but it supports converting more [`Expr`] types like -/// `Sort` expressions to `OrderByExpr` expressions. 
-pub fn expr_to_unparsed(expr: &Expr) -> Result { +pub fn sort_to_sql(sort: &Sort) -> Result { let unparser = Unparser::default(); - unparser.expr_to_unparsed(expr) + unparser.sort_to_sql(sort) } const LOWEST: &BinaryOperator = &BinaryOperator::Or; @@ -286,7 +253,7 @@ impl Unparser<'_> { }; let order_by: Vec = order_by .iter() - .map(|expr| expr_to_unparsed(expr)?.into_order_by_expr()) + .map(sort_to_sql) .collect::>>()?; let start_bound = self.convert_bound(&window_frame.start_bound)?; @@ -413,11 +380,6 @@ impl Unparser<'_> { negated: *negated, }) } - Expr::Sort(Sort { - expr: _, - asc: _, - nulls_first: _, - }) => plan_err!("Sort expression should be handled by expr_to_unparsed"), Expr::IsNull(expr) => { Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql_inner(expr)?))) } @@ -534,36 +496,26 @@ impl Unparser<'_> { } } - /// This function can convert more [`Expr`] types than `expr_to_sql`, - /// returning an [`Unparsed`] like `Sort` expressions to `OrderByExpr` - /// expressions. - pub fn expr_to_unparsed(&self, expr: &Expr) -> Result { - match expr { - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - let sql_parser_expr = self.expr_to_sql(expr)?; + pub fn sort_to_sql(&self, sort: &Sort) -> Result { + let Sort { + expr, + asc, + nulls_first, + } = sort; + let sql_parser_expr = self.expr_to_sql(expr)?; - let nulls_first = if self.dialect.supports_nulls_first_in_sort() { - Some(*nulls_first) - } else { - None - }; + let nulls_first = if self.dialect.supports_nulls_first_in_sort() { + Some(*nulls_first) + } else { + None + }; - Ok(Unparsed::OrderByExpr(ast::OrderByExpr { - expr: sql_parser_expr, - asc: Some(*asc), - nulls_first, - with_fill: None, - })) - } - _ => { - let sql_parser_expr = self.expr_to_sql(expr)?; - Ok(Unparsed::Expr(sql_parser_expr)) - } - } + Ok(ast::OrderByExpr { + expr: sql_parser_expr, + asc: Some(*asc), + nulls_first, + with_fill: None, + }) } fn scalar_function_to_sql_overrides( @@ -1527,7 +1479,7 @@ mod tests { case, col, cube, exists, 
grouping_set, interval_datetime_lit, interval_year_month_lit, lit, not, not_exists, out_ref_col, placeholder, rollup, table_scan, try_cast, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, - Signature, SortExpr, Volatility, WindowFrame, WindowFunctionDefinition, + Signature, Volatility, WindowFrame, WindowFunctionDefinition, }; use datafusion_expr::{interval_month_day_nano_lit, ExprFunctionExt}; use datafusion_functions_aggregate::count::count_udaf; @@ -1809,11 +1761,7 @@ mod tests { fun: WindowFunctionDefinition::AggregateUDF(count_udaf()), args: vec![wildcard()], partition_by: vec![], - order_by: vec![Expr::Sort(Sort::new( - Box::new(col("a")), - false, - true, - ))], + order_by: vec![Sort::new(Box::new(col("a")), false, true)], window_frame: WindowFrame::new_bounds( datafusion_expr::WindowFrameUnits::Range, datafusion_expr::WindowFrameBound::Preceding( @@ -1941,24 +1889,6 @@ mod tests { Ok(()) } - #[test] - fn expr_to_unparsed_ok() -> Result<()> { - let tests: Vec<(Expr, &str)> = vec![ - ((col("a") + col("b")).gt(lit(4)), r#"((a + b) > 4)"#), - (col("a").sort(true, true).to_expr(), r#"a ASC NULLS FIRST"#), - ]; - - for (expr, expected) in tests { - let ast = expr_to_unparsed(&expr)?; - - let actual = format!("{}", ast); - - assert_eq!(actual, expected); - } - - Ok(()) - } - #[test] fn custom_dialect_with_identifier_quote_style() -> Result<()> { let dialect = CustomDialectBuilder::new() @@ -2047,7 +1977,7 @@ mod tests { #[test] fn customer_dialect_support_nulls_first_in_ort() -> Result<()> { - let tests: Vec<(SortExpr, &str, bool)> = vec![ + let tests: Vec<(Sort, &str, bool)> = vec![ (col("a").sort(true, true), r#"a ASC NULLS FIRST"#, true), (col("a").sort(true, true), r#"a ASC"#, false), ]; @@ -2057,7 +1987,7 @@ mod tests { .with_supports_nulls_first_in_sort(supports_nulls_first_in_sort) .build(); let unparser = Unparser::new(&dialect); - let ast = unparser.expr_to_unparsed(&expr.to_expr())?; + let ast = unparser.sort_to_sql(&expr)?; let actual = 
format!("{}", ast); diff --git a/datafusion/sql/src/unparser/mod.rs b/datafusion/sql/src/unparser/mod.rs index b2fd32566aa8..83ae64ba238b 100644 --- a/datafusion/sql/src/unparser/mod.rs +++ b/datafusion/sql/src/unparser/mod.rs @@ -29,8 +29,6 @@ pub use plan::plan_to_sql; use self::dialect::{DefaultDialect, Dialect}; pub mod dialect; -pub use expr::Unparsed; - /// Convert a DataFusion [`Expr`] to [`sqlparser::ast::Expr`] /// /// See [`expr_to_sql`] for background. `Unparser` allows greater control of diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 106705c322fc..509c5dd52cd4 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -15,11 +15,10 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{ - internal_err, not_impl_err, plan_err, Column, DataFusionError, Result, -}; +use datafusion_common::{internal_err, not_impl_err, Column, DataFusionError, Result}; use datafusion_expr::{ expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, Projection, + SortExpr, }; use sqlparser::ast::{self, Ident, SetExpr}; @@ -318,7 +317,7 @@ impl Unparser<'_> { return self.derive(plan, relation); } if let Some(query_ref) = query { - query_ref.order_by(self.sort_to_sql(sort.expr.clone())?); + query_ref.order_by(self.sorts_to_sql(sort.expr.clone())?); } else { return internal_err!( "Sort operator only valid in a statement context." @@ -361,7 +360,7 @@ impl Unparser<'_> { .collect::>>()?; if let Some(sort_expr) = &on.sort_expr { if let Some(query_ref) = query { - query_ref.order_by(self.sort_to_sql(sort_expr.clone())?); + query_ref.order_by(self.sorts_to_sql(sort_expr.clone())?); } else { return internal_err!( "Sort operator only valid in a statement context." 
@@ -525,14 +524,10 @@ impl Unparser<'_> { } } - fn sort_to_sql(&self, sort_exprs: Vec) -> Result> { + fn sorts_to_sql(&self, sort_exprs: Vec) -> Result> { sort_exprs .iter() - .map(|expr: &Expr| { - self.expr_to_unparsed(expr)? - .into_order_by_expr() - .or(plan_err!("Expecting Sort expr")) - }) + .map(|sort_expr| self.sort_to_sql(sort_expr)) .collect::>>() } diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index 9e1adcf4df31..522a08af8546 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -21,10 +21,11 @@ use std::{ }; use datafusion_common::{ - tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeIterator}, + tree_node::{Transformed, TransformedResult, TreeNode}, Result, }; -use datafusion_expr::{Expr, LogicalPlan, Projection, Sort}; +use datafusion_expr::tree_node::transform_sort_vec; +use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr}; use sqlparser::ast::Ident; /// Normalize the schema of a union plan to remove qualifiers from the schema fields and sort expressions. @@ -83,20 +84,18 @@ pub(super) fn normalize_union_schema(plan: &LogicalPlan) -> Result } /// Rewrite sort expressions that have a UNION plan as their input to remove the table reference. 
-fn rewrite_sort_expr_for_union(exprs: Vec) -> Result> { - let sort_exprs: Vec = exprs - .into_iter() - .map_until_stop_and_collect(|expr| { - expr.transform_up(|expr| { - if let Expr::Column(mut col) = expr { - col.relation = None; - Ok(Transformed::yes(Expr::Column(col))) - } else { - Ok(Transformed::no(expr)) - } - }) +fn rewrite_sort_expr_for_union(exprs: Vec) -> Result> { + let sort_exprs = transform_sort_vec(exprs, &mut |expr| { + expr.transform_up(|expr| { + if let Expr::Column(mut col) = expr { + col.relation = None; + Ok(Transformed::yes(Expr::Column(col))) + } else { + Ok(Transformed::no(expr)) + } }) - .data()?; + }) + .data()?; Ok(sort_exprs) } @@ -158,12 +157,8 @@ pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( .collect::>(); let mut collects = p.expr.clone(); - for expr in &sort.expr { - if let Expr::Sort(s) = expr { - collects.push(s.expr.as_ref().clone()); - } else { - panic!("sort expression must be of type Sort"); - } + for sort in &sort.expr { + collects.push(sort.expr.as_ref().clone()); } // Compare outer collects Expr::to_string with inner collected transformed values diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index ee52d3559cb2..5a203703e967 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -4408,10 +4408,7 @@ fn plan_create_index() { assert_eq!(using, Some("btree".to_string())); assert_eq!( columns, - vec![ - col("name").sort(true, false).to_expr(), - col("age").sort(false, true).to_expr(), - ] + vec![col("name").sort(true, false), col("age").sort(false, true),] ); assert!(unique); assert!(if_not_exists); diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index bb0b603f614c..05903bb56cfe 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -27,11 +27,11 @@ use datafusion::common::{ DFSchemaRef, }; 
use datafusion::execution::FunctionRegistry; -use datafusion::logical_expr::expr::{sort_vec_to_expr, Exists, InSubquery, Sort}; +use datafusion::logical_expr::expr::{Exists, InSubquery, Sort}; use datafusion::logical_expr::{ expr::find_df_window_func, Aggregate, BinaryExpr, Case, EmptyRelation, Expr, - ExprSchemable, LogicalPlan, Operator, Projection, Values, + ExprSchemable, LogicalPlan, Operator, Projection, SortExpr, Values, }; use substrait::proto::expression::subquery::set_predicate::PredicateOp; use url::Url; @@ -986,7 +986,7 @@ pub async fn from_substrait_agg_func( input_schema: &DFSchema, extensions: &Extensions, filter: Option>, - order_by: Option>, + order_by: Option>, distinct: bool, ) -> Result> { let args = @@ -1237,7 +1237,7 @@ pub async fn from_substrait_rex( extensions, ) .await?, - order_by: sort_vec_to_expr(order_by), + order_by, window_frame: datafusion::logical_expr::WindowFrame::new_bounds( bound_units, from_substrait_bound(&window.lower_bound, true)?, diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 14553fbdf452..592390a285ba 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -764,7 +764,7 @@ pub fn to_substrait_agg_measure( match expr { Expr::AggregateFunction(expr::AggregateFunction { func, args, distinct, filter, order_by, null_treatment: _, }) => { let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr.unwrap_sort(), schema, extensions)).collect::>>()? + order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extensions)).collect::>>()? 
} else { vec![] }; @@ -2102,30 +2102,26 @@ fn try_to_substrait_field_reference( fn substrait_sort_field( ctx: &SessionContext, - expr: &Expr, + sort: &Sort, schema: &DFSchemaRef, extensions: &mut Extensions, ) -> Result { - match expr { - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - let e = to_substrait_rex(ctx, expr, schema, 0, extensions)?; - let d = match (asc, nulls_first) { - (true, true) => SortDirection::AscNullsFirst, - (true, false) => SortDirection::AscNullsLast, - (false, true) => SortDirection::DescNullsFirst, - (false, false) => SortDirection::DescNullsLast, - }; - Ok(SortField { - expr: Some(e), - sort_kind: Some(SortKind::Direction(d as i32)), - }) - } - _ => not_impl_err!("Expecting sort expression but got {expr:?}"), - } + let Sort { + expr, + asc, + nulls_first, + } = sort; + let e = to_substrait_rex(ctx, expr, schema, 0, extensions)?; + let d = match (asc, nulls_first) { + (true, true) => SortDirection::AscNullsFirst, + (true, false) => SortDirection::AscNullsLast, + (false, true) => SortDirection::DescNullsFirst, + (false, false) => SortDirection::DescNullsLast, + }; + Ok(SortField { + expr: Some(e), + sort_kind: Some(SortKind::Direction(d as i32)), + }) } fn substrait_field_ref(index: usize) -> Result { diff --git a/docs/source/library-user-guide/using-the-dataframe-api.md b/docs/source/library-user-guide/using-the-dataframe-api.md index 3bd47ef50e51..7f3e28c255c6 100644 --- a/docs/source/library-user-guide/using-the-dataframe-api.md +++ b/docs/source/library-user-guide/using-the-dataframe-api.md @@ -263,14 +263,14 @@ async fn main() -> Result<()>{ let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; // Create a new DataFrame sorted by `id`, `bank_account` let new_df = df.select(vec![col("a"), col("b")])? 
- .sort(vec![col("a")])?; + .sort_by(vec![col("a")])?; // Build the same plan using the LogicalPlanBuilder // Similar to `SELECT a, b FROM example.csv ORDER BY a` let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; let (_state, plan) = df.into_parts(); // get the DataFrame's LogicalPlan let plan = LogicalPlanBuilder::from(plan) .project(vec![col("a"), col("b")])? - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .build()?; // prove they are the same assert_eq!(new_df.logical_plan(), &plan); From cc6e9912ee0c189173c6491a017105f0205e3755 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Wed, 28 Aug 2024 09:13:01 +0200 Subject: [PATCH 07/11] use assert_eq just like in LogicalPlan.with_new_exprs --- datafusion/expr/src/tree_node.rs | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 69748aded531..06b60947cf9a 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -403,14 +403,7 @@ pub fn transform_sort_vec Result>>( } pub fn replace_sort_expressions(sorts: Vec, new_expr: Vec) -> Vec { - if sorts.len() != new_expr.len() { - panic!( - "Incorrect number of new_expr, expected {}, got {}", - sorts.len(), - new_expr.len() - ); - } - + assert_eq!(sorts.len(), new_expr.len()); let mut new_sorts = Vec::with_capacity(sorts.len()); for (i, expr) in new_expr.into_iter().enumerate() { new_sorts.push(replace_sort_expression(sorts[i].clone(), expr)); From 39877e453e8c1d6172e727a20edf2f3c0bc31482 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Wed, 28 Aug 2024 09:19:21 +0200 Subject: [PATCH 08/11] avoid clone in replace_sort_expressions --- datafusion/expr/src/tree_node.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 06b60947cf9a..90d61bf63763 100644 --- a/datafusion/expr/src/tree_node.rs +++ 
b/datafusion/expr/src/tree_node.rs @@ -404,11 +404,11 @@ pub fn transform_sort_vec Result>>( pub fn replace_sort_expressions(sorts: Vec, new_expr: Vec) -> Vec { assert_eq!(sorts.len(), new_expr.len()); - let mut new_sorts = Vec::with_capacity(sorts.len()); - for (i, expr) in new_expr.into_iter().enumerate() { - new_sorts.push(replace_sort_expression(sorts[i].clone(), expr)); - } - new_sorts + sorts + .into_iter() + .zip(new_expr) + .map(|(sort, expr)| replace_sort_expression(sort, expr)) + .collect() } pub fn replace_sort_expression(sort: Sort, new_expr: Expr) -> Sort { From 09b20d5053c1d368505f1d2b835230df6fd40f55 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Wed, 28 Aug 2024 09:29:56 +0200 Subject: [PATCH 09/11] reduce cloning in EliminateDuplicatedExpr --- datafusion/optimizer/src/eliminate_duplicated_expr.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 5904da3436a0..943599defb98 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -51,15 +51,13 @@ impl OptimizerRule for EliminateDuplicatedExpr { match plan { LogicalPlan::Sort(sort) => { let len = sort.expr.len(); - let mut first_sort_by_expr: IndexMap = + let mut first_sort_by_expr: IndexMap<&Expr, &SortExpr> = IndexMap::default(); for s in &sort.expr { - first_sort_by_expr - .entry(s.expr.as_ref().clone()) - .or_insert(s.clone()); + first_sort_by_expr.entry(s.expr.as_ref()).or_insert(s); } let unique_exprs: Vec = - first_sort_by_expr.into_values().collect(); + first_sort_by_expr.into_values().cloned().collect(); let transformed = if len != unique_exprs.len() { Transformed::yes From 9931e1c718139a84806c4e83530d6f2a27cb78be Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Wed, 28 Aug 2024 11:39:41 +0200 Subject: [PATCH 10/11] restore SortExprWrapper this commit is longer than advised 
in the review comment, but after squashing the diff will be smaller --- .../src/eliminate_duplicated_expr.rs | 33 ++++++++++++++----- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 943599defb98..16106a3bfd30 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -23,7 +23,8 @@ use datafusion_common::tree_node::Transformed; use datafusion_common::Result; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::{Aggregate, Expr, Sort, SortExpr}; -use indexmap::{IndexMap, IndexSet}; +use indexmap::IndexSet; +use std::hash::{Hash, Hasher}; /// Optimization rule that eliminate duplicated expr. #[derive(Default)] pub struct EliminateDuplicatedExpr; @@ -34,6 +35,21 @@ impl EliminateDuplicatedExpr { Self {} } } +// use this structure to avoid initial clone +#[derive(Eq, Clone, Debug)] +struct SortExprWrapper { + expr: SortExpr, +} +impl PartialEq for SortExprWrapper { + fn eq(&self, other: &Self) -> bool { + self.expr.expr == other.expr.expr + } +} +impl Hash for SortExprWrapper { + fn hash(&self, state: &mut H) { + self.expr.expr.hash(state); + } +} impl OptimizerRule for EliminateDuplicatedExpr { fn apply_order(&self) -> Option { Some(ApplyOrder::TopDown) @@ -51,13 +67,14 @@ impl OptimizerRule for EliminateDuplicatedExpr { match plan { LogicalPlan::Sort(sort) => { let len = sort.expr.len(); - let mut first_sort_by_expr: IndexMap<&Expr, &SortExpr> = - IndexMap::default(); - for s in &sort.expr { - first_sort_by_expr.entry(s.expr.as_ref()).or_insert(s); - } - let unique_exprs: Vec = - first_sort_by_expr.into_values().cloned().collect(); + let unique_exprs: Vec<_> = sort + .expr + .into_iter() + .map(|e| SortExprWrapper { expr: e }) + .collect::>() + .into_iter() + .map(|wrapper| wrapper.expr) + .collect(); let transformed = if len != unique_exprs.len() { 
Transformed::yes From 0ab9348a0ff48b2e11b33148c7d097c490e1d034 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Wed, 28 Aug 2024 14:14:21 +0200 Subject: [PATCH 11/11] shorthand SortExprWrapper struct definition --- .../optimizer/src/eliminate_duplicated_expr.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 16106a3bfd30..c460d7a93d26 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -37,17 +37,15 @@ impl EliminateDuplicatedExpr { } // use this structure to avoid initial clone #[derive(Eq, Clone, Debug)] -struct SortExprWrapper { - expr: SortExpr, -} +struct SortExprWrapper(SortExpr); impl PartialEq for SortExprWrapper { fn eq(&self, other: &Self) -> bool { - self.expr.expr == other.expr.expr + self.0.expr == other.0.expr } } impl Hash for SortExprWrapper { fn hash(&self, state: &mut H) { - self.expr.expr.hash(state); + self.0.expr.hash(state); } } impl OptimizerRule for EliminateDuplicatedExpr { @@ -70,10 +68,10 @@ impl OptimizerRule for EliminateDuplicatedExpr { let unique_exprs: Vec<_> = sort .expr .into_iter() - .map(|e| SortExprWrapper { expr: e }) + .map(SortExprWrapper) .collect::>() .into_iter() - .map(|wrapper| wrapper.expr) + .map(|wrapper| wrapper.0) .collect(); let transformed = if len != unique_exprs.len() {