96 changes: 85 additions & 11 deletions datafusion/src/physical_plan/common.rs
@@ -17,24 +17,22 @@

//! Defines common code used in execution plans

use std::fs;
use std::fs::metadata;
use std::sync::Arc;
use std::task::{Context, Poll};

use super::{RecordBatchStream, SendableRecordBatchStream};
use crate::error::{DataFusionError, Result};
use crate::physical_plan::ExecutionPlan;
use arrow::compute::concat;
use arrow::datatypes::SchemaRef;
use arrow::error::ArrowError;
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use futures::channel::mpsc;
use futures::{SinkExt, Stream, StreamExt, TryStreamExt};
use std::fs;
use std::fs::metadata;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::task::JoinHandle;

use crate::arrow::error::ArrowError;
use crate::error::{DataFusionError, Result};
use crate::physical_plan::ExecutionPlan;

use super::{RecordBatchStream, SendableRecordBatchStream};

/// Stream of record batches
pub struct SizedRecordBatchStream {
schema: SchemaRef,
@@ -83,6 +81,32 @@ pub async fn collect(stream: SendableRecordBatchStream) -> Result<Vec<RecordBatch>>
.map_err(DataFusionError::from)
}

/// Combine a slice of record batches into a single record batch, or return `None` if the
/// slice is empty; all record batches in the slice must share the same schema.
pub(crate) fn combine_batches(
batches: &[RecordBatch],
schema: SchemaRef,
) -> ArrowResult<Option<RecordBatch>> {
if batches.is_empty() {
Ok(None)
} else {
let columns = schema
.fields()
.iter()
.enumerate()
.map(|(i, _)| {
concat(
&batches
.iter()
.map(|batch| batch.column(i).as_ref())
.collect::<Vec<_>>(),
)
})
.collect::<ArrowResult<Vec<_>>>()?;
Ok(Some(RecordBatch::try_new(schema.clone(), columns)?))
}
}

/// Recursively builds a list of files in a directory with a given extension
pub fn build_file_list(dir: &str, ext: &str) -> Result<Vec<String>> {
let mut filenames: Vec<String> = Vec::new();
Expand Down Expand Up @@ -144,3 +168,53 @@ pub(crate) fn spawn_execution(
}
})
}

#[cfg(test)]
mod tests {
use super::*;
use arrow::{
array::{Float32Array, Float64Array},
datatypes::{DataType, Field, Schema},
record_batch::RecordBatch,
};

#[test]
fn test_combine_batches_empty() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("f32", DataType::Float32, false),
Field::new("f64", DataType::Float64, false),
]));
let result = combine_batches(&[], schema)?;
assert!(result.is_none());
Ok(())
}

#[test]
fn test_combine_batches() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("f32", DataType::Float32, false),
Field::new("f64", DataType::Float64, false),
]));

let batch_count = 1000;
let batch_size = 10;
let batches = (0..batch_count)
.map(|i| {
RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Float32Array::from(vec![i as f32; batch_size])),
Arc::new(Float64Array::from(vec![i as f64; batch_size])),
],
)
.unwrap()
})
.collect::<Vec<_>>();

let result = combine_batches(&batches, schema)?;
assert!(result.is_some());
let result = result.unwrap();
assert_eq!(batch_count * batch_size, result.num_rows());
Ok(())
}
}
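
For context, here is a minimal usage sketch of the new helper (not part of this PR; it assumes it runs inside the datafusion crate, where the crate-private common::combine_batches is visible, and relies on DataFusionError implementing From<ArrowError> so that ? converts the Arrow errors):

use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

use crate::error::Result;
use crate::physical_plan::common;

fn merge_two_batches() -> Result<()> {
    // A single non-nullable Int32 column named "a".
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

    let first = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;
    let second = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![4, 5]))],
    )?;

    // The helper concatenates column-by-column across the slice;
    // it only returns None when the input slice is empty.
    let combined = common::combine_batches(&[first, second], schema)?
        .expect("non-empty input always yields a combined batch");
    assert_eq!(combined.num_rows(), 5);
    Ok(())
}
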
86 changes: 32 additions & 54 deletions datafusion/src/physical_plan/sort.rs
@@ -17,32 +17,28 @@

//! Defines the SORT plan

use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;

use async_trait::async_trait;
use futures::stream::Stream;
use futures::Future;
use hashbrown::HashMap;

use pin_project_lite::pin_project;

pub use arrow::compute::SortOptions;
use arrow::compute::{concat, lexsort_to_indices, take, SortColumn, TakeOptions};
use arrow::datatypes::SchemaRef;
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use arrow::{array::ArrayRef, error::ArrowError};

use super::{RecordBatchStream, SendableRecordBatchStream};
use crate::error::{DataFusionError, Result};
use crate::physical_plan::expressions::PhysicalSortExpr;
use crate::physical_plan::{
common, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, SQLMetric,
};
pub use arrow::compute::SortOptions;
use arrow::compute::{lexsort_to_indices, take, SortColumn, TakeOptions};
use arrow::datatypes::SchemaRef;
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use arrow::{array::ArrayRef, error::ArrowError};
use async_trait::async_trait;
use futures::stream::Stream;
use futures::Future;
use hashbrown::HashMap;
use pin_project_lite::pin_project;
use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;

/// Sort execution plan
#[derive(Debug)]
@@ -190,47 +186,25 @@ impl ExecutionPlan for SortExec {
}
}

fn sort_batches(
batches: &[RecordBatch],
schema: &SchemaRef,
fn sort_batch(
batch: RecordBatch,
schema: SchemaRef,
expr: &[PhysicalSortExpr],
) -> ArrowResult<Option<RecordBatch>> {
if batches.is_empty() {
return Ok(None);
}
// combine all record batches into one for each column
let combined_batch = RecordBatch::try_new(
schema.clone(),
schema
.fields()
.iter()
.enumerate()
.map(|(i, _)| {
concat(
&batches
.iter()
.map(|batch| batch.column(i).as_ref())
.collect::<Vec<_>>(),
)
})
.collect::<ArrowResult<Vec<ArrayRef>>>()?,
)?;

// sort combined record batch
) -> ArrowResult<RecordBatch> {
// TODO: push the limit expression down into the sort
let indices = lexsort_to_indices(
&expr
.iter()
.map(|e| e.evaluate_to_sort_column(&combined_batch))
.map(|e| e.evaluate_to_sort_column(&batch))
.collect::<Result<Vec<SortColumn>>>()
.map_err(DataFusionError::into_arrow_external_error)?,
None,
)?;

// reorder all rows based on sorted indices
let sorted_batch = RecordBatch::try_new(
schema.clone(),
combined_batch
RecordBatch::try_new(
schema,
batch
.columns()
.iter()
.map(|column| {
@@ -245,8 +219,7 @@
)
})
.collect::<ArrowResult<Vec<ArrayRef>>>()?,
);
sorted_batch.map(Some)
)
}

pin_project! {
@@ -277,9 +250,14 @@ impl SortStream {
.map_err(DataFusionError::into_arrow_external_error)
.and_then(move |batches| {
let now = Instant::now();
let result = sort_batches(&batches, &schema, &expr);
// combine all record batches into one for each column
let combined = common::combine_batches(&batches, schema.clone())?;
// sort combined record batch
let result = combined
.map(|batch| sort_batch(batch, schema, &expr))
.transpose()?;
sort_time.add(now.elapsed().as_nanos() as usize);
(Author review comment) Note: the semantics change here, because on error the sort_time metric is no longer recorded, which is the correct behavior.

result
Ok(result)
});

tx.send(sorted_batch)
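
To make the control flow above easier to follow, here is a condensed sketch of what the reshaped closure now does (a paraphrase rather than the PR's code verbatim; it assumes it lives in sort.rs from this diff, so sort_batch, common::combine_batches, PhysicalSortExpr, SQLMetric, and the existing imports are in scope). It also makes the author's note explicit: both ? operators return before sort_time.add runs, so a failed combine or sort no longer records the timing metric.

fn combine_then_sort(
    batches: &[RecordBatch],
    schema: SchemaRef,
    expr: &[PhysicalSortExpr],
    sort_time: &SQLMetric,
) -> ArrowResult<Option<RecordBatch>> {
    let now = Instant::now();

    // Merge every input batch into one batch; None if there were no batches.
    let combined = common::combine_batches(batches, schema.clone())?;

    // Sort the combined batch if one exists; `transpose` turns the resulting
    // Option<Result<RecordBatch>> back into Result<Option<RecordBatch>> so `?` applies.
    let result = combined
        .map(|batch| sort_batch(batch, schema, expr))
        .transpose()?;

    // Only reached on success, so the metric reflects successful sorts only.
    sort_time.add(now.elapsed().as_nanos() as usize);
    Ok(result)
}
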
22 changes: 22 additions & 0 deletions integration-tests/sqls/simple_sort.sql
@@ -0,0 +1,22 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at

-- http://www.apache.org/licenses/LICENSE-2.0

-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

SELECT
c2,
c3,
c10
FROM test
ORDER BY c2 ASC, c3 DESC, c10;
2 changes: 1 addition & 1 deletion integration-tests/test_psql_parity.py
@@ -74,7 +74,7 @@ class PsqlParityTest(unittest.TestCase):
def test_parity(self):
root = Path(os.path.dirname(__file__)) / "sqls"
files = set(root.glob("*.sql"))
self.assertEqual(len(files), 5, msg="tests are missed")
self.assertEqual(len(files), 6, msg="tests are missed")
for fname in files:
with self.subTest(fname=fname):
datafusion_output = pd.read_csv(