From ec449104d7ca6fe9b843fdc03e8338ecb768af3c Mon Sep 17 00:00:00 2001 From: Jay Miller <3744812+jaylmiller@users.noreply.github.com> Date: Fri, 10 Feb 2023 11:56:26 -0500 Subject: [PATCH] modify sort_batch to use arrow row format for multi-column sorts --- .../core/src/physical_plan/sorts/sort.rs | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs index ad08504c3b86..286bfab02d12 100644 --- a/datafusion/core/src/physical_plan/sorts/sort.rs +++ b/datafusion/core/src/physical_plan/sorts/sort.rs @@ -38,13 +38,14 @@ use crate::physical_plan::{ RecordBatchStream, SendableRecordBatchStream, Statistics, }; use crate::prelude::SessionConfig; -use arrow::array::{make_array, Array, ArrayRef, MutableArrayData}; +use arrow::array::{make_array, Array, ArrayRef, MutableArrayData, UInt32Array}; pub use arrow::compute::SortOptions; use arrow::compute::{concat, lexsort_to_indices, take, SortColumn, TakeOptions}; use arrow::datatypes::SchemaRef; use arrow::error::ArrowError; use arrow::ipc::reader::FileReader; use arrow::record_batch::RecordBatch; +use arrow::row::{Row, RowConverter, SortField}; use datafusion_physical_expr::EquivalenceProperties; use futures::{Stream, StreamExt, TryStreamExt}; use log::{debug, error}; @@ -820,7 +821,29 @@ fn sort_batch( .map(|e| e.evaluate_to_sort_column(&batch)) .collect::>>()?; - let indices = lexsort_to_indices(&sort_columns, fetch)?; + let indices = if sort_columns.len() == 1 { + lexsort_to_indices(&sort_columns, fetch)? + } else { + let sort_fields = sort_columns + .iter() + .map(|c| { + let datatype = c.values.data_type().to_owned(); + SortField::new_with_options(datatype, c.options.unwrap_or_default()) + }) + .collect::>(); + let arrays: Vec = + sort_columns.iter().map(|c| c.values.clone()).collect(); + let mut row_converter = RowConverter::new(sort_fields)?; + let rows = row_converter.convert_columns(&arrays)?; + + let mut to_sort: Vec<(usize, Row)> = rows.into_iter().enumerate().collect(); + to_sort.sort_unstable_by(|(_, row_a), (_, row_b)| row_a.cmp(row_b)); + let limit = match fetch { + Some(lim) => lim.min(to_sort.len()), + None => to_sort.len(), + }; + UInt32Array::from_iter(to_sort.into_iter().take(limit).map(|(idx, _)| idx as u32)) + }; // reorder all rows based on sorted indices let sorted_batch = RecordBatch::try_new(