Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions arrow-select/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,49 @@ mod tests {
assert!((30..40).contains(&values_len), "{values_len}")
}

#[test]
fn test_primitive_dictionary_merge() {
    // Five keys all referencing dictionary index 1; values are 10..20.
    let keys = vec![1; 5];
    let values = (10..20).collect::<Vec<_>>();
    let dict = DictionaryArray::new(
        Int8Array::from(keys.clone()),
        Arc::new(Int32Array::from(values.clone())),
    );
    // `other` is value-equal to `dict` but backed by a distinct buffer,
    // so the pointer-equality fast path cannot apply to it.
    let other = DictionaryArray::new(
        Int8Array::from(keys.clone()),
        Arc::new(Int32Array::from(values.clone())),
    );

    let result_same_dictionary = concat(&[&dict, &dict]).unwrap();
    // Concatenating an array with itself should hit the pointer-equality
    // check: the dictionaries are not merged and the single values buffer
    // is reused as-is (same length, same underlying data).
    assert!(dict.values().to_data().ptr_eq(
        &result_same_dictionary
            .as_dictionary::<Int8Type>()
            .values()
            .to_data()
    ));
    assert_eq!(
        result_same_dictionary
            .as_dictionary::<Int8Type>()
            .values()
            .len(),
        values.len(),
    );

    let result_cloned_dictionary = concat(&[&dict, &other]).unwrap();
    // Distinct-but-equal dictionaries must be merged. All keys reference
    // the same value, so the merged dictionary holds exactly one entry.
    assert_eq!(
        result_cloned_dictionary
            .as_dictionary::<Int8Type>()
            .values()
            .len(),
        1
    );
}

#[test]
fn test_concat_string_sizes() {
let a: LargeStringArray = ((0..150).map(|_| Some("foo"))).collect();
Expand Down
58 changes: 45 additions & 13 deletions arrow-select/src/dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
use crate::interleave::interleave;
use ahash::RandomState;
use arrow_array::builder::BooleanBufferBuilder;
use arrow_array::cast::AsArray;
use arrow_array::types::{
ArrowDictionaryKeyType, BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, Utf8Type,
ArrowDictionaryKeyType, ArrowPrimitiveType, BinaryType, ByteArrayType, LargeBinaryType,
LargeUtf8Type, Utf8Type,
};
use arrow_array::{Array, ArrayRef, DictionaryArray, GenericByteArray};
use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer};
use arrow_array::{cast::AsArray, downcast_primitive};
use arrow_array::{Array, ArrayRef, DictionaryArray, GenericByteArray, PrimitiveArray};
use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer, ToByteSlice};
use arrow_schema::{ArrowError, DataType};

/// A best effort interner that maintains a fixed number of buckets
Expand Down Expand Up @@ -102,7 +103,7 @@ fn bytes_ptr_eq<T: ByteArrayType>(a: &dyn Array, b: &dyn Array) -> bool {
}

/// A type-erased function that compares two array for pointer equality
type PtrEq = dyn Fn(&dyn Array, &dyn Array) -> bool;
type PtrEq = fn(&dyn Array, &dyn Array) -> bool;

/// A weak heuristic of whether to merge dictionary values that aims to only
/// perform the expensive merge computation when it is likely to yield at least
Expand All @@ -115,12 +116,17 @@ pub fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
) -> bool {
use DataType::*;
let first_values = dictionaries[0].values().as_ref();
let ptr_eq: Box<PtrEq> = match first_values.data_type() {
Utf8 => Box::new(bytes_ptr_eq::<Utf8Type>),
LargeUtf8 => Box::new(bytes_ptr_eq::<LargeUtf8Type>),
Binary => Box::new(bytes_ptr_eq::<BinaryType>),
LargeBinary => Box::new(bytes_ptr_eq::<LargeBinaryType>),
_ => return false,
let ptr_eq: PtrEq = match first_values.data_type() {
Utf8 => bytes_ptr_eq::<Utf8Type>,
LargeUtf8 => bytes_ptr_eq::<LargeUtf8Type>,
Binary => bytes_ptr_eq::<BinaryType>,
LargeBinary => bytes_ptr_eq::<LargeBinaryType>,
dt => {
if !dt.is_primitive() {
return false;
}
|a, b| a.to_data().ptr_eq(&b.to_data())
}
};

let mut single_dictionary = true;
Expand Down Expand Up @@ -233,17 +239,43 @@ fn compute_values_mask<K: ArrowNativeType>(
builder.finish()
}

/// Collect, for every set bit in `mask`, the index plus the native value at
/// that index viewed as raw bytes (`None` where the array slot is null).
fn masked_primitives_to_bytes<'a, T: ArrowPrimitiveType>(
    array: &'a PrimitiveArray<T>,
    mask: &BooleanBuffer,
) -> Vec<(usize, Option<&'a [u8]>)>
where
    T::Native: ToByteSlice,
{
    let native_values = array.values();
    mask.set_indices()
        .map(|idx| {
            // Null slots contribute `None` so callers can distinguish
            // "valid value" from "masked-in but null".
            let bytes = array
                .is_valid(idx)
                .then(|| native_values[idx].to_byte_slice());
            (idx, bytes)
        })
        .collect()
}

// Adapter macro for use with `downcast_primitive!`: given the concrete
// primitive type `$t` selected by the downcast, casts `$array` to
// `PrimitiveArray<$t>` and forwards to `masked_primitives_to_bytes`.
macro_rules! masked_primitive_to_bytes_helper {
    ($t:ty, $array:expr, $mask:expr) => {
        masked_primitives_to_bytes::<$t>($array.as_primitive(), $mask)
    };
}

/// Return a Vec containing for each set index in `mask`, the index and byte value of that index
fn get_masked_values<'a>(
array: &'a dyn Array,
mask: &BooleanBuffer,
) -> Vec<(usize, Option<&'a [u8]>)> {
match array.data_type() {
downcast_primitive! {
array.data_type() => (masked_primitive_to_bytes_helper, array, mask),
DataType::Utf8 => masked_bytes(array.as_string::<i32>(), mask),
DataType::LargeUtf8 => masked_bytes(array.as_string::<i64>(), mask),
DataType::Binary => masked_bytes(array.as_binary::<i32>(), mask),
DataType::LargeBinary => masked_bytes(array.as_binary::<i64>(), mask),
_ => unimplemented!(),
_ => unimplemented!("Dictionary merging for type {} is not implemented", array.data_type()),
}
}

Expand Down
Loading