diff --git a/datafusion/src/physical_plan/common.rs b/datafusion/src/physical_plan/common.rs
index e60963bbb5b7..2482bfc0872c 100644
--- a/datafusion/src/physical_plan/common.rs
+++ b/datafusion/src/physical_plan/common.rs
@@ -17,24 +17,22 @@
 
 //! Defines common code used in execution plans
 
-use std::fs;
-use std::fs::metadata;
-use std::sync::Arc;
-use std::task::{Context, Poll};
-
+use super::{RecordBatchStream, SendableRecordBatchStream};
+use crate::error::{DataFusionError, Result};
+use crate::physical_plan::ExecutionPlan;
+use arrow::compute::concat;
 use arrow::datatypes::SchemaRef;
+use arrow::error::ArrowError;
 use arrow::error::Result as ArrowResult;
 use arrow::record_batch::RecordBatch;
 use futures::channel::mpsc;
 use futures::{SinkExt, Stream, StreamExt, TryStreamExt};
+use std::fs;
+use std::fs::metadata;
+use std::sync::Arc;
+use std::task::{Context, Poll};
 use tokio::task::JoinHandle;
 
-use crate::arrow::error::ArrowError;
-use crate::error::{DataFusionError, Result};
-use crate::physical_plan::ExecutionPlan;
-
-use super::{RecordBatchStream, SendableRecordBatchStream};
-
 /// Stream of record batches
 pub struct SizedRecordBatchStream {
     schema: SchemaRef,
@@ -83,6 +81,32 @@ pub async fn collect(stream: SendableRecordBatchStream) -> Result<Vec<RecordBatch>> {
         .map_err(DataFusionError::from)
 }
 
+/// Combine a slice of record batches into one, or returns `None` if the slice itself is
+/// empty
+pub(crate) fn combine_batches(
+    batches: &[RecordBatch],
+    schema: SchemaRef,
+) -> ArrowResult<Option<RecordBatch>> {
+    if batches.is_empty() {
+        Ok(None)
+    } else {
+        let columns = schema
+            .fields()
+            .iter()
+            .enumerate()
+            .map(|(i, _)| {
+                concat(
+                    &batches
+                        .iter()
+                        .map(|batch| batch.column(i).as_ref())
+                        .collect::<Vec<_>>(),
+                )
+            })
+            .collect::<ArrowResult<Vec<_>>>()?;
+        Ok(Some(RecordBatch::try_new(schema.clone(), columns)?))
+    }
+}
+
 /// Recursively builds a list of files in a directory with a given extension
 pub fn build_file_list(dir: &str, ext: &str) -> Result<Vec<String>> {
     let mut filenames: Vec<String> = Vec::new();
@@ -144,3 +168,53 @@ pub(crate) fn spawn_execution(
         }
     })
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::{
+        array::{Float32Array, Float64Array},
+        datatypes::{DataType, Field, Schema},
+        record_batch::RecordBatch,
+    };
+
+    #[test]
+    fn test_combine_batches_empty() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("f32", DataType::Float32, false),
+            Field::new("f64", DataType::Float64, false),
+        ]));
+        let result = combine_batches(&[], schema)?;
+        assert!(result.is_none());
+        Ok(())
+    }
+
+    #[test]
+    fn test_combine_batches() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("f32", DataType::Float32, false),
+            Field::new("f64", DataType::Float64, false),
+        ]));
+
+        let batch_count = 1000;
+        let batch_size = 10;
+        let batches = (0..batch_count)
+            .map(|i| {
+                RecordBatch::try_new(
+                    Arc::clone(&schema),
+                    vec![
+                        Arc::new(Float32Array::from(vec![i as f32; batch_size])),
+                        Arc::new(Float64Array::from(vec![i as f64; batch_size])),
+                    ],
+                )
+                .unwrap()
+            })
+            .collect::<Vec<_>>();
+
+        let result = combine_batches(&batches, schema)?;
+        assert!(result.is_some());
+        let result = result.unwrap();
+        assert_eq!(batch_count * batch_size, result.num_rows());
+        Ok(())
+    }
+}
diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sort.rs
index c5b838c6e84b..7747030d8a93 100644
--- a/datafusion/src/physical_plan/sort.rs
+++ b/datafusion/src/physical_plan/sort.rs
@@ -17,32 +17,28 @@
 
 //! Defines the SORT plan
 
-use std::any::Any;
-use std::pin::Pin;
-use std::sync::Arc;
-use std::task::{Context, Poll};
-use std::time::Instant;
-
-use async_trait::async_trait;
-use futures::stream::Stream;
-use futures::Future;
-use hashbrown::HashMap;
-
-use pin_project_lite::pin_project;
-
-pub use arrow::compute::SortOptions;
-use arrow::compute::{concat, lexsort_to_indices, take, SortColumn, TakeOptions};
-use arrow::datatypes::SchemaRef;
-use arrow::error::Result as ArrowResult;
-use arrow::record_batch::RecordBatch;
-use arrow::{array::ArrayRef, error::ArrowError};
-
 use super::{RecordBatchStream, SendableRecordBatchStream};
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::expressions::PhysicalSortExpr;
 use crate::physical_plan::{
     common, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, SQLMetric,
 };
+pub use arrow::compute::SortOptions;
+use arrow::compute::{lexsort_to_indices, take, SortColumn, TakeOptions};
+use arrow::datatypes::SchemaRef;
+use arrow::error::Result as ArrowResult;
+use arrow::record_batch::RecordBatch;
+use arrow::{array::ArrayRef, error::ArrowError};
+use async_trait::async_trait;
+use futures::stream::Stream;
+use futures::Future;
+use hashbrown::HashMap;
+use pin_project_lite::pin_project;
+use std::any::Any;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use std::time::Instant;
 
 /// Sort execution plan
 #[derive(Debug)]
@@ -190,47 +186,25 @@ impl ExecutionPlan for SortExec {
     }
 }
 
-fn sort_batches(
-    batches: &[RecordBatch],
-    schema: &SchemaRef,
+fn sort_batch(
+    batch: RecordBatch,
+    schema: SchemaRef,
     expr: &[PhysicalSortExpr],
-) -> ArrowResult<Option<RecordBatch>> {
-    if batches.is_empty() {
-        return Ok(None);
-    }
-    // combine all record batches into one for each column
-    let combined_batch = RecordBatch::try_new(
-        schema.clone(),
-        schema
-            .fields()
-            .iter()
-            .enumerate()
-            .map(|(i, _)| {
-                concat(
-                    &batches
-                        .iter()
-                        .map(|batch| batch.column(i).as_ref())
-                        .collect::<Vec<_>>(),
-                )
-            })
-            .collect::<ArrowResult<Vec<ArrayRef>>>()?,
-    )?;
-
-    // sort combined record batch
+) -> ArrowResult<RecordBatch> {
     // TODO: pushup the limit expression to sort
     let indices = lexsort_to_indices(
         &expr
             .iter()
-            .map(|e| e.evaluate_to_sort_column(&combined_batch))
+            .map(|e| e.evaluate_to_sort_column(&batch))
             .collect::<Result<Vec<SortColumn>>>()
             .map_err(DataFusionError::into_arrow_external_error)?,
         None,
     )?;
 
     // reorder all rows based on sorted indices
-    let sorted_batch = RecordBatch::try_new(
-        schema.clone(),
-        combined_batch
+    RecordBatch::try_new(
+        schema,
+        batch
             .columns()
             .iter()
             .map(|column| {
@@ -245,8 +219,7 @@ fn sort_batches(
                 take(
                     column.as_ref(),
                     &indices,
                     // disable bound check overhead since indices are already generated from
                     // the same record batch
                     Some(TakeOptions {
                         check_bounds: false,
                     }),
                 )
             })
             .collect::<ArrowResult<Vec<ArrayRef>>>()?,
-    );
-    sorted_batch.map(Some)
+    )
 }
 
 pin_project! {
@@ -277,9 +250,14 @@ impl SortStream {
                 .map_err(DataFusionError::into_arrow_external_error)
                 .and_then(move |batches| {
                     let now = Instant::now();
-                    let result = sort_batches(&batches, &schema, &expr);
+                    // combine all record batches into one for each column
+                    let combined = common::combine_batches(&batches, schema.clone())?;
+                    // sort combined record batch
+                    let result = combined
+                        .map(|batch| sort_batch(batch, schema, &expr))
+                        .transpose()?;
                     sort_time.add(now.elapsed().as_nanos() as usize);
-                    result
+                    Ok(result)
                 });
 
             tx.send(sorted_batch)
diff --git a/integration-tests/sqls/simple_sort.sql b/integration-tests/sqls/simple_sort.sql
new file mode 100644
index 000000000000..50fb12dfdc70
--- /dev/null
+++ b/integration-tests/sqls/simple_sort.sql
@@ -0,0 +1,22 @@
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements.  See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership.  The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License.  You may obtain a copy of the License at
+
+--   http://www.apache.org/licenses/LICENSE-2.0
+
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+SELECT
+  c2,
+  c3,
+  c10
+FROM test
+ORDER BY c2 ASC, c3 DESC, c10;
diff --git a/integration-tests/test_psql_parity.py b/integration-tests/test_psql_parity.py
index 5bd308180e59..51861c583f8a 100644
--- a/integration-tests/test_psql_parity.py
+++ b/integration-tests/test_psql_parity.py
@@ -74,7 +74,7 @@ class PsqlParityTest(unittest.TestCase):
     def test_parity(self):
         root = Path(os.path.dirname(__file__)) / "sqls"
         files = set(root.glob("*.sql"))
-        self.assertEqual(len(files), 5, msg="tests are missed")
+        self.assertEqual(len(files), 6, msg="tests are missed")
        for fname in files:
             with self.subTest(fname=fname):
                 datafusion_output = pd.read_csv(
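
For reference, the control flow this patch creates in `SortStream` is: collect all input batches, concatenate them once with `common::combine_batches`, then sort the single combined batch with `sort_batch`. A minimal sketch of that flow under the APIs shown in the diff above; the helper name `combine_and_sort` is illustrative only and not part of the patch:

    // Illustrative only: mirrors the body of the `and_then` closure added to
    // `SortStream::new` above. `combine_batches` returns Ok(None) for empty
    // input, so Option::map + transpose thread the Option through the
    // fallible sort without an explicit is_empty check.
    fn combine_and_sort(
        batches: &[RecordBatch],
        schema: SchemaRef,
        expr: &[PhysicalSortExpr],
    ) -> ArrowResult<Option<RecordBatch>> {
        common::combine_batches(batches, schema.clone())?
            .map(|batch| sort_batch(batch, schema, expr))
            .transpose()
    }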