Skip to content

Commit

Permalink
Preserve empty list array elements in take kernel (#3473)
Browse files Browse the repository at this point in the history
* Update test_take_list to fail with empty list

* Fix null_bit_buffer to match intended array

comment says: [[0,null,0], [-1,-2,3], null, [5,null]]
which implies null buffer of: 0b11111011

* Fix null_bit_buffer to match intended array

comment says: [[0,null,0], [-1,-2,3], [null], [5,null]]
which has no null values at the list level, i.e. a null
buffer of 0b11111111

* Compute null buffer in take_value_indices_from_list

this way we can distinguish empty list elements from null elements.

* clippy
  • Loading branch information
jonmmease committed Jan 7, 2023
1 parent ca7ea59 commit c28d69a
Showing 1 changed file with 40 additions and 33 deletions.
73 changes: 40 additions & 33 deletions arrow-select/src/take.rs
Expand Up @@ -691,27 +691,11 @@ where
{
// TODO: Some optimizations can be done here such as if it is
// taking the whole list or a contiguous sublist
let (list_indices, offsets) =
let (list_indices, offsets, null_buf) =
take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?;

let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices, None)?;
// determine null count and null buffer, which are a function of `values` and `indices`
let mut null_count = 0;
let num_bytes = bit_util::ceil(indices.len(), 8);
let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
{
let null_slice = null_buf.as_slice_mut();
offsets[..].windows(2).enumerate().for_each(
|(i, window): (usize, &[OffsetType::Native])| {
if window[0] == window[1] {
// offsets are equal, slot is null
bit_util::unset_bit(null_slice, i);
null_count += 1;
}
},
);
}
let value_offsets = Buffer::from_slice_ref(&offsets);
let value_offsets = Buffer::from_slice_ref(offsets);
// create a new list with taken data and computed null information
let list_data = ArrayDataBuilder::new(values.data_type().clone())
.len(indices.len())
Expand Down Expand Up @@ -831,10 +815,18 @@ where
/// Where a list array has offsets `[0,2,5,10]`, taking indices of `[2,0]` returns
/// an array of the value indices `[5..10, 0..2]` and new offsets `[0,5,7]` (5 elements
/// and 2 elements)
#[allow(clippy::type_complexity)]
fn take_value_indices_from_list<IndexType, OffsetType>(
list: &GenericListArray<OffsetType::Native>,
indices: &PrimitiveArray<IndexType>,
) -> Result<(PrimitiveArray<OffsetType>, Vec<OffsetType::Native>), ArrowError>
) -> Result<
(
PrimitiveArray<OffsetType>,
Vec<OffsetType::Native>,
MutableBuffer,
),
ArrowError,
>
where
IndexType: ArrowPrimitiveType,
IndexType::Native: ToPrimitive,
Expand All @@ -850,6 +842,12 @@ where
let mut current_offset = OffsetType::Native::zero();
// add first offset
new_offsets.push(OffsetType::Native::zero());

// Initialize null buffer
let num_bytes = bit_util::ceil(indices.len(), 8);
let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
let null_slice = null_buf.as_slice_mut();

// compute the value indices, and set offsets accordingly
for i in 0..indices.len() {
if indices.is_valid(i) {
Expand All @@ -868,12 +866,20 @@ where
values.push(Some(curr));
curr += num::One::one();
}
if !list.is_valid(ix) {
bit_util::unset_bit(null_slice, i);
}
} else {
bit_util::unset_bit(null_slice, i);
new_offsets.push(current_offset);
}
}

Ok((PrimitiveArray::<OffsetType>::from(values), new_offsets))
Ok((
PrimitiveArray::<OffsetType>::from(values),
new_offsets,
null_buf,
))
}

/// Takes/filters a fixed size list array's inner data using the offsets of the list array.
Expand Down Expand Up @@ -1519,12 +1525,12 @@ mod tests {

macro_rules! test_take_list {
($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
// Construct a value array, [[0,0,0], [-1,-2,-1], [2,3]]
// Construct a value array, [[0,0,0], [-1,-2,-1], [], [2,3]]
let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3])
.data()
.clone();
// Construct offsets
let value_offsets: [$offset_type; 4] = [0, 3, 6, 8];
let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
let value_offsets = Buffer::from_slice_ref(&value_offsets);
// Construct a list array from the above two
let list_data_type = DataType::$list_data_type(Box::new(Field::new(
Expand All @@ -1533,38 +1539,36 @@ mod tests {
false,
)));
let list_data = ArrayData::builder(list_data_type.clone())
.len(3)
.len(4)
.add_buffer(value_offsets)
.add_child_data(value_data)
.build()
.unwrap();
let list_array = $list_array_type::from(list_data);

// index returns: [[2,3], null, [-1,-2,-1], [2,3], [0,0,0]]
let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(2), Some(0)]);
// index returns: [[2,3], null, [-1,-2,-1], [], [0,0,0]]
let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);

let a = take(&list_array, &index, None).unwrap();
let a: &$list_array_type =
a.as_any().downcast_ref::<$list_array_type>().unwrap();

// construct a value array with expected results:
// [[2,3], null, [-1,-2,-1], [2,3], [0,0,0]]
// [[2,3], null, [-1,-2,-1], [], [0,0,0]]
let expected_data = Int32Array::from(vec![
Some(2),
Some(3),
Some(-1),
Some(-2),
Some(-1),
Some(2),
Some(3),
Some(0),
Some(0),
Some(0),
])
.data()
.clone();
// construct offsets
let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 7, 10];
let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
// construct list array from the two
let expected_list_data = ArrayData::builder(list_data_type)
Expand Down Expand Up @@ -1609,7 +1613,7 @@ mod tests {
let list_data = ArrayData::builder(list_data_type.clone())
.len(4)
.add_buffer(value_offsets)
.null_bit_buffer(Some(Buffer::from([0b10111101, 0b00000000])))
.null_bit_buffer(Some(Buffer::from([0b11111111])))
.add_child_data(value_data)
.build()
.unwrap();
Expand Down Expand Up @@ -1682,7 +1686,7 @@ mod tests {
let list_data = ArrayData::builder(list_data_type.clone())
.len(4)
.add_buffer(value_offsets)
.null_bit_buffer(Some(Buffer::from([0b01111101])))
.null_bit_buffer(Some(Buffer::from([0b11111011])))
.add_child_data(value_data)
.build()
.unwrap();
Expand Down Expand Up @@ -2057,10 +2061,12 @@ mod tests {
]);
let indices = UInt32Array::from(vec![2, 0]);

let (indexed, offsets) = take_value_indices_from_list(&list, &indices).unwrap();
let (indexed, offsets, null_buf) =
take_value_indices_from_list(&list, &indices).unwrap();

assert_eq!(indexed, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
assert_eq!(offsets, vec![0, 5, 7]);
assert_eq!(null_buf.as_slice(), &[0b11111111]);
}

#[test]
Expand All @@ -2072,11 +2078,12 @@ mod tests {
]);
let indices = UInt32Array::from(vec![2, 0]);

let (indexed, offsets) =
let (indexed, offsets, null_buf) =
take_value_indices_from_list::<_, Int64Type>(&list, &indices).unwrap();

assert_eq!(indexed, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
assert_eq!(offsets, vec![0, 5, 7]);
assert_eq!(null_buf.as_slice(), &[0b11111111]);
}

#[test]
Expand Down

0 comments on commit c28d69a

Please sign in to comment.