refactor sort exec stream and combine batches
Jiayu Liu committed Jun 7, 2021
1 parent 767eeb0 commit de4055b
Showing 4 changed files with 140 additions and 66 deletions.
96 changes: 85 additions & 11 deletions datafusion/src/physical_plan/common.rs
@@ -17,24 +17,22 @@

//! Defines common code used in execution plans

use std::fs;
use std::fs::metadata;
use std::sync::Arc;
use std::task::{Context, Poll};

use super::{RecordBatchStream, SendableRecordBatchStream};
use crate::error::{DataFusionError, Result};
use crate::physical_plan::ExecutionPlan;
use arrow::compute::concat;
use arrow::datatypes::SchemaRef;
use arrow::error::ArrowError;
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use futures::channel::mpsc;
use futures::{SinkExt, Stream, StreamExt, TryStreamExt};
use std::fs;
use std::fs::metadata;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::task::JoinHandle;

use crate::arrow::error::ArrowError;
use crate::error::{DataFusionError, Result};
use crate::physical_plan::ExecutionPlan;

use super::{RecordBatchStream, SendableRecordBatchStream};

/// Stream of record batches
pub struct SizedRecordBatchStream {
schema: SchemaRef,
@@ -83,6 +81,32 @@ pub async fn collect(stream: SendableRecordBatchStream) -> Result<Vec<RecordBatch>> {
.map_err(DataFusionError::from)
}

/// Combine a slice of record batches into one, or return `None` if the slice is
/// empty; all record batches in the slice must share the same schema.
pub(crate) fn combine_batches(
batches: &[RecordBatch],
schema: SchemaRef,
) -> ArrowResult<Option<RecordBatch>> {
if batches.is_empty() {
Ok(None)
} else {
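// Concatenate column i across every input batch; this relies on all batches
// sharing `schema`, as the doc comment above requires.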
let columns = schema
.fields()
.iter()
.enumerate()
.map(|(i, _)| {
concat(
&batches
.iter()
.map(|batch| batch.column(i).as_ref())
.collect::<Vec<_>>(),
)
})
.collect::<ArrowResult<Vec<_>>>()?;
Ok(Some(RecordBatch::try_new(schema.clone(), columns)?))
}
}

/// Recursively builds a list of files in a directory with a given extension
pub fn build_file_list(dir: &str, ext: &str) -> Result<Vec<String>> {
let mut filenames: Vec<String> = Vec::new();
@@ -144,3 +168,53 @@ pub(crate) fn spawn_execution(
}
})
}

#[cfg(test)]
mod tests {
use super::*;
use arrow::{
array::{Float32Array, Float64Array},
datatypes::{DataType, Field, Schema},
record_batch::RecordBatch,
};

#[test]
fn test_combine_batches_empty() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("f32", DataType::Float32, false),
Field::new("f64", DataType::Float64, false),
]));
let result = combine_batches(&[], schema)?;
assert!(result.is_none());
Ok(())
}

#[test]
fn test_combine_batches() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("f32", DataType::Float32, false),
Field::new("f64", DataType::Float64, false),
]));

let batch_count = 1000;
let batch_size = 10;
let batches = (0..batch_count)
.map(|i| {
RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Float32Array::from(vec![i as f32; batch_size])),
Arc::new(Float64Array::from(vec![i as f64; batch_size])),
],
)
.unwrap()
})
.collect::<Vec<_>>();

let result = combine_batches(&batches, schema)?;
assert!(result.is_some());
let result = result.unwrap();
assert_eq!(batch_count * batch_size, result.num_rows());
Ok(())
}
}
86 changes: 32 additions & 54 deletions datafusion/src/physical_plan/sort.rs
@@ -17,32 +17,28 @@

//! Defines the SORT plan

use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;

use async_trait::async_trait;
use futures::stream::Stream;
use futures::Future;
use hashbrown::HashMap;

use pin_project_lite::pin_project;

pub use arrow::compute::SortOptions;
use arrow::compute::{concat, lexsort_to_indices, take, SortColumn, TakeOptions};
use arrow::datatypes::SchemaRef;
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use arrow::{array::ArrayRef, error::ArrowError};

use super::{RecordBatchStream, SendableRecordBatchStream};
use crate::error::{DataFusionError, Result};
use crate::physical_plan::expressions::PhysicalSortExpr;
use crate::physical_plan::{
common, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, SQLMetric,
};
pub use arrow::compute::SortOptions;
use arrow::compute::{lexsort_to_indices, take, SortColumn, TakeOptions};
use arrow::datatypes::SchemaRef;
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use arrow::{array::ArrayRef, error::ArrowError};
use async_trait::async_trait;
use futures::stream::Stream;
use futures::Future;
use hashbrown::HashMap;
use pin_project_lite::pin_project;
use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;

/// Sort execution plan
#[derive(Debug)]
@@ -190,47 +186,25 @@ impl ExecutionPlan for SortExec {
}
}

fn sort_batches(
batches: &[RecordBatch],
schema: &SchemaRef,
fn sort_batch(
batch: RecordBatch,
schema: SchemaRef,
expr: &[PhysicalSortExpr],
) -> ArrowResult<Option<RecordBatch>> {
if batches.is_empty() {
return Ok(None);
}
// combine all record batches into one for each column
let combined_batch = RecordBatch::try_new(
schema.clone(),
schema
.fields()
.iter()
.enumerate()
.map(|(i, _)| {
concat(
&batches
.iter()
.map(|batch| batch.column(i).as_ref())
.collect::<Vec<_>>(),
)
})
.collect::<ArrowResult<Vec<ArrayRef>>>()?,
)?;

// sort combined record batch
) -> ArrowResult<RecordBatch> {
// TODO: push up the limit expression to sort
let indices = lexsort_to_indices(
&expr
.iter()
.map(|e| e.evaluate_to_sort_column(&combined_batch))
.map(|e| e.evaluate_to_sort_column(&batch))
.collect::<Result<Vec<SortColumn>>>()
.map_err(DataFusionError::into_arrow_external_error)?,
None,
)?;

// reorder all rows based on sorted indices
let sorted_batch = RecordBatch::try_new(
schema.clone(),
combined_batch
RecordBatch::try_new(
schema,
batch
.columns()
.iter()
.map(|column| {
@@ -245,8 +219,7 @@ fn sort_batches(
)
})
.collect::<ArrowResult<Vec<ArrayRef>>>()?,
);
sorted_batch.map(Some)
)
}
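
For reference, a standalone sketch (not part of this commit; the helper name and sample values are illustrative) of the lexsort-then-take pattern that `sort_batch` applies to each column:

use std::sync::Arc;

use arrow::array::{ArrayRef, Float32Array};
use arrow::compute::{lexsort_to_indices, take, SortColumn, SortOptions};
use arrow::error::Result as ArrowResult;

// Sort a single column the same way `sort_batch` sorts every column of a batch:
// compute a row permutation with `lexsort_to_indices`, then apply it with `take`.
fn sort_one_column() -> ArrowResult<ArrayRef> {
    let values: ArrayRef = Arc::new(Float32Array::from(vec![3.0_f32, 1.0, 2.0]));
    let indices = lexsort_to_indices(
        &[SortColumn {
            values: values.clone(),
            options: Some(SortOptions::default()),
        }],
        None, // no limit yet; see the TODO above
    )?;
    // `indices` holds row positions in sorted order; `take` reorders the column.
    take(values.as_ref(), &indices, None)
}

The real implementation builds one `SortColumn` per sort expression and reuses the resulting indices for every column of the batch.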

pin_project! {
@@ -277,9 +250,14 @@ impl SortStream {
.map_err(DataFusionError::into_arrow_external_error)
.and_then(move |batches| {
let now = Instant::now();
let result = sort_batches(&batches, &schema, &expr);
// combine all record batches into one for each column
let combined = common::combine_batches(&batches, schema.clone())?;
// sort combined record batch
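// `transpose` flips Option<ArrowResult<RecordBatch>> into ArrowResult<Option<RecordBatch>>,
// so `?` propagates any sort error while an empty input remains `None`.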
let result = combined
.map(|batch| sort_batch(batch, schema, &expr))
.transpose()?;
sort_time.add(now.elapsed().as_nanos() as usize);
result
Ok(result)
});

tx.send(sorted_batch)
22 changes: 22 additions & 0 deletions integration-tests/sqls/simple_sort.sql
@@ -0,0 +1,22 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at

-- http://www.apache.org/licenses/LICENSE-2.0

-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

SELECT
c2,
c3,
c10
FROM test
ORDER BY c2 ASC, c3 DESC, c10;
2 changes: 1 addition & 1 deletion integration-tests/test_psql_parity.py
@@ -74,7 +74,7 @@ class PsqlParityTest(unittest.TestCase):
def test_parity(self):
root = Path(os.path.dirname(__file__)) / "sqls"
files = set(root.glob("*.sql"))
self.assertEqual(len(files), 5, msg="tests are missed")
self.assertEqual(len(files), 6, msg="tests are missed")
for fname in files:
with self.subTest(fname=fname):
datafusion_output = pd.read_csv(
