-
Notifications
You must be signed in to change notification settings - Fork 657
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support sorting dictionary encoded primitive integer arrays #2680
Changes from 2 commits
e66d133
3697abc
286d9a9
6a778cf
93708c0
2ed13a4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -21,6 +21,7 @@ use crate::array::*; | |||||
use crate::buffer::MutableBuffer; | ||||||
use crate::compute::take; | ||||||
use crate::datatypes::*; | ||||||
use crate::downcast_dictionary_array; | ||||||
use crate::error::{ArrowError, Result}; | ||||||
use std::cmp::Ordering; | ||||||
use TimeUnit::*; | ||||||
|
@@ -311,41 +312,58 @@ pub fn sort_to_indices( | |||||
))); | ||||||
} | ||||||
}, | ||||||
DataType::Dictionary(key_type, value_type) | ||||||
if *value_type.as_ref() == DataType::Utf8 => | ||||||
{ | ||||||
match key_type.as_ref() { | ||||||
DataType::Int8 => { | ||||||
sort_string_dictionary::<Int8Type>(values, v, n, &options, limit) | ||||||
} | ||||||
DataType::Int16 => { | ||||||
sort_string_dictionary::<Int16Type>(values, v, n, &options, limit) | ||||||
} | ||||||
DataType::Int32 => { | ||||||
sort_string_dictionary::<Int32Type>(values, v, n, &options, limit) | ||||||
} | ||||||
DataType::Int64 => { | ||||||
sort_string_dictionary::<Int64Type>(values, v, n, &options, limit) | ||||||
} | ||||||
DataType::UInt8 => { | ||||||
sort_string_dictionary::<UInt8Type>(values, v, n, &options, limit) | ||||||
} | ||||||
DataType::UInt16 => { | ||||||
sort_string_dictionary::<UInt16Type>(values, v, n, &options, limit) | ||||||
} | ||||||
DataType::UInt32 => { | ||||||
sort_string_dictionary::<UInt32Type>(values, v, n, &options, limit) | ||||||
} | ||||||
DataType::UInt64 => { | ||||||
sort_string_dictionary::<UInt64Type>(values, v, n, &options, limit) | ||||||
} | ||||||
t => { | ||||||
return Err(ArrowError::ComputeError(format!( | ||||||
"Sort not supported for dictionary key type {:?}", | ||||||
t | ||||||
))); | ||||||
} | ||||||
} | ||||||
DataType::Dictionary(_, _) => { | ||||||
downcast_dictionary_array!( | ||||||
values => match values.values().data_type() { | ||||||
DataType::Int8 => { | ||||||
let dict_values = values.values(); | ||||||
let sorted_value_indices = sort_to_indices(dict_values, Some(SortOptions::default()), None)?; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to have fixed sort options (default) when sorting the values of dictionary. As we will compare keys based on the sorted indices. The sorted indices represent the order of values which is fixed. For example, a dictionary (keys = [1, 2, 0], values = [0, 1, 2]). No matter we sort in ascending order or in descending order, when we compare key = 1 and key = 2, their relation is always (key = 1) < (key = 2) according to the values behind the keys. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current test can catch an error if we use the same option to sort the values of the dictionary. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Aah, you raise a good point. I think you may still need to propagate the null ordering though? Edit: I think you need to pass the full options here, and then drop the descending for the second sort. But I could be wrong, multi-level nullability is just 🤯 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, right, we need to keep the given null ordering when sorting the values of the dictionary. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Propagating null ordering now and added one test. |
||||||
sort_primitive_dictionary::<_, _>(values, &sorted_value_indices, v, n, &options, limit, cmp) | ||||||
}, | ||||||
DataType::Int16 => { | ||||||
let dict_values = values.values(); | ||||||
let sorted_value_indices = sort_to_indices(dict_values, Some(SortOptions::default()), None)?; | ||||||
sort_primitive_dictionary::<_, _>(values, &sorted_value_indices, v, n, &options, limit, cmp) | ||||||
}, | ||||||
DataType::Int32 => { | ||||||
let dict_values = values.values(); | ||||||
let sorted_value_indices = sort_to_indices(dict_values, Some(SortOptions::default()), None)?; | ||||||
sort_primitive_dictionary::<_, _>(values, &sorted_value_indices, v, n, &options, limit, cmp) | ||||||
}, | ||||||
DataType::Int64 => { | ||||||
let dict_values = values.values(); | ||||||
let sorted_value_indices = sort_to_indices(dict_values, Some(SortOptions::default()), None)?; | ||||||
sort_primitive_dictionary::<_, _>(values, &sorted_value_indices,v, n, &options, limit, cmp) | ||||||
}, | ||||||
DataType::UInt8 => { | ||||||
let dict_values = values.values(); | ||||||
let sorted_value_indices = sort_to_indices(dict_values, Some(SortOptions::default()), None)?; | ||||||
sort_primitive_dictionary::<_, _>(values, &sorted_value_indices,v, n, &options, limit, cmp) | ||||||
}, | ||||||
DataType::UInt16 => { | ||||||
let dict_values = values.values(); | ||||||
let sorted_value_indices = sort_to_indices(dict_values, Some(SortOptions::default()), None)?; | ||||||
sort_primitive_dictionary::<_, _>(values, &sorted_value_indices,v, n, &options, limit, cmp) | ||||||
}, | ||||||
DataType::UInt32 => { | ||||||
let dict_values = values.values(); | ||||||
let sorted_value_indices = sort_to_indices(dict_values, Some(SortOptions::default()), None)?; | ||||||
sort_primitive_dictionary::<_, _>(values, &sorted_value_indices,v, n, &options, limit, cmp) | ||||||
}, | ||||||
DataType::UInt64 => { | ||||||
let dict_values = values.values(); | ||||||
let sorted_value_indices = sort_to_indices(dict_values, Some(SortOptions::default()), None)?; | ||||||
sort_primitive_dictionary::<_, _>(values, &sorted_value_indices, v, n, &options, limit, cmp) | ||||||
}, | ||||||
DataType::Utf8 => sort_string_dictionary::<_>(values, v, n, &options, limit), | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think |
||||||
t => return Err(ArrowError::ComputeError(format!( | ||||||
"Unsupported dictionary value type {}", t | ||||||
))), | ||||||
}, | ||||||
t => return Err(ArrowError::ComputeError(format!( | ||||||
"Unsupported datatype {}", t | ||||||
))), | ||||||
) | ||||||
} | ||||||
DataType::Binary | DataType::FixedSizeBinary(_) => { | ||||||
sort_binary::<i32>(values, v, n, &options, limit) | ||||||
|
@@ -489,7 +507,14 @@ where | |||||
.into_iter() | ||||||
.map(|index| (index, decimal_array.value(index as usize).as_i128())) | ||||||
.collect::<Vec<(u32, i128)>>(); | ||||||
sort_primitive_inner(decimal_values, null_indices, cmp, options, limit, valids) | ||||||
sort_primitive_inner( | ||||||
decimal_values.len(), | ||||||
null_indices, | ||||||
cmp, | ||||||
options, | ||||||
limit, | ||||||
valids, | ||||||
) | ||||||
} | ||||||
|
||||||
/// Sort primitive values | ||||||
|
@@ -514,12 +539,40 @@ where | |||||
.map(|index| (index, values.value(index as usize))) | ||||||
.collect::<Vec<(u32, T::Native)>>() | ||||||
}; | ||||||
sort_primitive_inner(values, null_indices, cmp, options, limit, valids) | ||||||
sort_primitive_inner(values.len(), null_indices, cmp, options, limit, valids) | ||||||
} | ||||||
|
||||||
/// Sort dictionary encoded primitive values | ||||||
fn sort_primitive_dictionary<K, F>( | ||||||
values: &DictionaryArray<K>, | ||||||
sorted_value_indices: &UInt32Array, | ||||||
value_indices: Vec<u32>, | ||||||
null_indices: Vec<u32>, | ||||||
options: &SortOptions, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FWIW |
||||||
limit: Option<usize>, | ||||||
cmp: F, | ||||||
) -> UInt32Array | ||||||
where | ||||||
K: ArrowDictionaryKeyType, | ||||||
F: Fn(u32, u32) -> std::cmp::Ordering, | ||||||
{ | ||||||
let keys: &PrimitiveArray<K> = values.keys(); | ||||||
|
||||||
// create tuples that are used for sorting | ||||||
let valids = value_indices | ||||||
.into_iter() | ||||||
.map(|index| { | ||||||
let key: K::Native = keys.value(index as usize); | ||||||
(index, sorted_value_indices.value(key.to_usize().unwrap())) | ||||||
}) | ||||||
.collect::<Vec<(u32, u32)>>(); | ||||||
|
||||||
sort_primitive_inner::<_, _>(keys.len(), null_indices, cmp, options, limit, valids) | ||||||
} | ||||||
|
||||||
// sort is instantiated a lot so we only compile this inner version for each native type | ||||||
fn sort_primitive_inner<T, F>( | ||||||
values: &ArrayRef, | ||||||
value_len: usize, | ||||||
null_indices: Vec<u32>, | ||||||
cmp: F, | ||||||
options: &SortOptions, | ||||||
|
@@ -535,7 +588,7 @@ where | |||||
|
||||||
let valids_len = valids.len(); | ||||||
let nulls_len = nulls.len(); | ||||||
let mut len = values.len(); | ||||||
let mut len = value_len; | ||||||
|
||||||
if let Some(limit) = limit { | ||||||
len = limit.min(len); | ||||||
|
@@ -620,14 +673,12 @@ fn sort_string<Offset: OffsetSizeTrait>( | |||||
|
||||||
/// Sort dictionary encoded strings | ||||||
fn sort_string_dictionary<T: ArrowDictionaryKeyType>( | ||||||
values: &ArrayRef, | ||||||
values: &DictionaryArray<T>, | ||||||
value_indices: Vec<u32>, | ||||||
null_indices: Vec<u32>, | ||||||
options: &SortOptions, | ||||||
limit: Option<usize>, | ||||||
) -> UInt32Array { | ||||||
let values: &DictionaryArray<T> = as_dictionary_array::<T>(values); | ||||||
|
||||||
let keys: &PrimitiveArray<T> = values.keys(); | ||||||
|
||||||
let dict = values.values(); | ||||||
|
@@ -1239,6 +1290,58 @@ mod tests { | |||||
assert_eq!(sorted_strings, expected) | ||||||
} | ||||||
|
||||||
fn test_sort_primitive_dict_arrays<K: ArrowDictionaryKeyType, T: ArrowPrimitiveType>( | ||||||
keys: PrimitiveArray<K>, | ||||||
values: PrimitiveArray<T>, | ||||||
options: Option<SortOptions>, | ||||||
limit: Option<usize>, | ||||||
expected_data: Vec<Option<T::Native>>, | ||||||
) where | ||||||
PrimitiveArray<T>: From<Vec<Option<T::Native>>>, | ||||||
{ | ||||||
let array = DictionaryArray::<K>::try_new(&keys, &values).unwrap(); | ||||||
let array_values = array.values().clone(); | ||||||
let dict = array_values | ||||||
.as_any() | ||||||
.downcast_ref::<PrimitiveArray<T>>() | ||||||
.expect("Unable to get dictionary values"); | ||||||
|
||||||
let sorted = match limit { | ||||||
Some(_) => { | ||||||
sort_limit(&(Arc::new(array) as ArrayRef), options, limit).unwrap() | ||||||
} | ||||||
_ => sort(&(Arc::new(array) as ArrayRef), options).unwrap(), | ||||||
}; | ||||||
let sorted = sorted | ||||||
.as_any() | ||||||
.downcast_ref::<DictionaryArray<K>>() | ||||||
.unwrap(); | ||||||
let sorted_values = sorted.values(); | ||||||
let sorted_dict = sorted_values | ||||||
.as_any() | ||||||
.downcast_ref::<PrimitiveArray<T>>() | ||||||
.expect("Unable to get dictionary values"); | ||||||
let sorted_keys = sorted.keys(); | ||||||
|
||||||
assert_eq!(sorted_dict, dict); | ||||||
|
||||||
let sorted_values: PrimitiveArray<T> = From::<Vec<Option<T::Native>>>::from( | ||||||
(0..sorted.len()) | ||||||
.map(|i| { | ||||||
if sorted.is_valid(i) { | ||||||
Some(sorted_dict.value(sorted_keys.value(i).to_usize().unwrap())) | ||||||
} else { | ||||||
None | ||||||
} | ||||||
}) | ||||||
.collect::<Vec<Option<T::Native>>>(), | ||||||
); | ||||||
let expected: PrimitiveArray<T> = | ||||||
From::<Vec<Option<T::Native>>>::from(expected_data); | ||||||
|
||||||
assert_eq!(sorted_values, expected) | ||||||
} | ||||||
|
||||||
fn test_sort_list_arrays<T>( | ||||||
data: Vec<Option<Vec<Option<T::Native>>>>, | ||||||
options: Option<SortOptions>, | ||||||
|
@@ -3222,4 +3325,60 @@ mod tests { | |||||
partial_sort(&mut before, last, |a, b| a.cmp(b)); | ||||||
assert_eq!(&d[0..last], &before[0..last]); | ||||||
} | ||||||
|
||||||
#[test] | ||||||
fn test_sort_int8_dicts() { | ||||||
let keys = | ||||||
Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), Some(0)]); | ||||||
let values = Int8Array::from(vec![1, 3, 5]); | ||||||
test_sort_primitive_dict_arrays::<Int8Type, Int8Type>( | ||||||
keys, | ||||||
values, | ||||||
None, | ||||||
None, | ||||||
vec![None, None, Some(1), Some(3), Some(5), Some(5)], | ||||||
); | ||||||
|
||||||
let keys = | ||||||
Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), Some(0)]); | ||||||
let values = Int8Array::from(vec![1, 3, 5]); | ||||||
test_sort_primitive_dict_arrays::<Int8Type, Int8Type>( | ||||||
keys, | ||||||
values, | ||||||
Some(SortOptions { | ||||||
descending: true, | ||||||
nulls_first: false, | ||||||
}), | ||||||
None, | ||||||
vec![Some(5), Some(5), Some(3), Some(1), None, None], | ||||||
); | ||||||
|
||||||
let keys = | ||||||
Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), Some(0)]); | ||||||
let values = Int8Array::from(vec![1, 3, 5]); | ||||||
test_sort_primitive_dict_arrays::<Int8Type, Int8Type>( | ||||||
keys, | ||||||
values, | ||||||
Some(SortOptions { | ||||||
descending: false, | ||||||
nulls_first: false, | ||||||
}), | ||||||
None, | ||||||
vec![Some(1), Some(3), Some(5), Some(5), None, None], | ||||||
); | ||||||
|
||||||
let keys = | ||||||
Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), Some(0)]); | ||||||
let values = Int8Array::from(vec![1, 3, 5]); | ||||||
test_sort_primitive_dict_arrays::<Int8Type, Int8Type>( | ||||||
keys, | ||||||
values, | ||||||
Some(SortOptions { | ||||||
descending: false, | ||||||
nulls_first: true, | ||||||
}), | ||||||
Some(3), | ||||||
vec![None, None, Some(1)], | ||||||
); | ||||||
} | ||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NGL I worry a bit that this has the combinatorial fanout that absolutely tanks compile times... Perhaps we could compute the sort order of the dictionary values and then use this to compare the keys? This might even be faster
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, I agree that sorting the dictionary values and using it to compare the keys might be faster. But does it help on combinatorial fanout from dictionary? For example, in order to sort on the dictionary values, we still need get the value array from dictionary array so a
downcast_dictionary_array!
is also necessary.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think so, you could call sort_to_indices on the values array, which is only typed on the values type, and then to compare dictionary values you compare the indices you just computed, which is only typed on the dictionary key type. You therefore avoid the fanout?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So
downcast_dictionary_array
is not the one you concern butdowncast_dictionary_array
+sort_primitive_dictionary
which is typed on both key and value types?Then it makes sense to me. I can refactor this and split sorting dictionary to sorting dictionary values and sorting keys based on the computed indices.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes it is the combinatorial fanout that is especially painful, and given it is avoidable I think it makes sense to do so. It's a case of every little helps 😅