Skip to content
Permalink
Browse files
Support dictionary arrays in length and bit_length (#1674)
* Support dictionary arrays in length and bit_length

* Fix typo
  • Loading branch information
viirya committed May 9, 2022
1 parent daed6ab commit 42c9e025c4af3958a4d45ae54d2d9d267cce2fa6
Showing 1 changed file with 160 additions and 5 deletions.
@@ -19,10 +19,12 @@

use crate::{array::*, buffer::Buffer, datatypes::ArrowPrimitiveType};
use crate::{
datatypes::{DataType, Int32Type, Int64Type},
datatypes::*,
error::{ArrowError, Result},
};

use std::sync::Arc;

macro_rules! unary_offsets {
($array: expr, $data_type: expr, $op: expr) => {{
let slice = $array.value_offsets();
@@ -56,6 +58,27 @@ macro_rules! unary_offsets {
}};
}

/// Applies `$kernel` to the values of a dictionary array and wraps the kernel
/// output in a new `DictionaryArray` that reuses the original keys.
///
/// Arguments:
/// * `$array` - identifier of the array to process; it is downcast through
///   `as_any()` to `DictionaryArray<$gt>`.
/// * `$kernel` - expression that is called with the dictionary's values and whose
///   result is unwrapped with `?`.
/// * `$kt` - identifier of the dictionary key data type; it is matched via `as_ref()`.
/// * `$t: $gt` - comma separated pairs of a `DataType` variant name and the key
///   type to downcast with when `$kt` matches that variant.
///
/// The expansion evaluates to `Ok(Arc::new(..))` holding the new dictionary array.
/// Because the expansion uses `?`, it must be placed in a function whose return
/// type accepts the errors of `$kernel` and `DictionaryArray::try_new`.
///
/// # Errors
///
/// Evaluates to `ArrowError::ComputeError` when `$kt` matches none of the listed
/// `$t` variants, instead of panicking.
///
/// # Panics
///
/// Panics when `$array` cannot be downcast to the `DictionaryArray<$gt>` selected
/// by its key type; that indicates an array whose data type and concrete type disagree.
macro_rules! kernel_dict {
    ($array: ident, $kernel: expr, $kt: ident, $($t: ident: $gt: ident), *) => {
        match $kt.as_ref() {
            $(&DataType::$t => {
                let dict = $array
                    .as_any()
                    .downcast_ref::<DictionaryArray<$gt>>()
                    .unwrap_or_else(|| {
                        panic!("Expect 'DictionaryArray<{}>' but got array of data type {:?}",
                               stringify!($gt), $array.data_type())
                    });
                // Run the kernel once over the dictionary values only; the keys are
                // reused unchanged to build the output.
                let values = $kernel(dict.values())?;
                let result = DictionaryArray::try_new(dict.keys(), &values)?;
                Ok(Arc::new(result))
            },
            )*
            // Report an unlisted key type as an error so callers of the
            // `Result`-returning kernels can handle it.
            t => Err(ArrowError::ComputeError(format!(
                "Unsupported dictionary key type: {}",
                t
            ))),
        }
    }
}

fn length_list<O, T>(array: &dyn Array) -> ArrayRef
where
O: OffsetSizeTrait,
@@ -127,10 +150,26 @@ where
/// For list array, length is the number of elements in each list.
/// For string array and binary array, length is the number of bytes of each value.
///
/// * this only accepts ListArray/LargeListArray, StringArray/LargeStringArray and BinaryArray/LargeBinaryArray,
///   or DictionaryArray with the above arrays as values
/// * length of null is null.
pub fn length(array: &dyn Array) -> Result<ArrayRef> {
match array.data_type() {
DataType::Dictionary(kt, _) => {
kernel_dict!(
array,
|a| { length(a) },
kt,
Int8: Int8Type,
Int16: Int16Type,
Int32: Int32Type,
Int64: Int64Type,
UInt8: UInt8Type,
UInt16: UInt16Type,
UInt32: UInt32Type,
UInt64: UInt64Type
)
}
DataType::List(_) => Ok(length_list::<i32, Int32Type>(array)),
DataType::LargeList(_) => Ok(length_list::<i64, Int64Type>(array)),
DataType::Utf8 => Ok(length_string::<i32, Int32Type>(array)),
@@ -146,11 +185,27 @@ pub fn length(array: &dyn Array) -> Result<ArrayRef> {

/// Returns an array of Int32/Int64 denoting the number of bits in each value in the array.
///
/// * this only accepts StringArray/Utf8, LargeString/LargeUtf8, BinaryArray and LargeBinaryArray,
///   or DictionaryArray with the above arrays as values
/// * bit_length of null is null.
/// * bit_length is in number of bits
pub fn bit_length(array: &dyn Array) -> Result<ArrayRef> {
match array.data_type() {
DataType::Dictionary(kt, _) => {
kernel_dict!(
array,
|a| { bit_length(a) },
kt,
Int8: Int8Type,
Int16: Int16Type,
Int32: Int32Type,
Int64: Int64Type,
UInt8: UInt8Type,
UInt16: UInt16Type,
UInt32: UInt32Type,
UInt64: UInt64Type
)
}
DataType::Utf8 => Ok(bit_length_string::<i32, Int32Type>(array)),
DataType::LargeUtf8 => Ok(bit_length_string::<i64, Int64Type>(array)),
DataType::Binary => Ok(bit_length_binary::<i32, Int32Type>(array)),
@@ -164,8 +219,6 @@ pub fn bit_length(array: &dyn Array) -> Result<ArrayRef> {

#[cfg(test)]
mod tests {
use crate::datatypes::{Float32Type, Int8Type};

use super::*;

fn double_vec<T: Clone>(v: Vec<T>) -> Vec<T> {
@@ -570,4 +623,106 @@ mod tests {

Ok(())
}

/// Runs the dictionary `length` check once for each of the eight integer key types,
/// in order, stopping at the first failure.
#[test]
fn length_dictionary() -> Result<()> {
    let per_key_type: [fn() -> Result<()>; 8] = [
        _length_dictionary::<Int8Type>,
        _length_dictionary::<Int16Type>,
        _length_dictionary::<Int32Type>,
        _length_dictionary::<Int64Type>,
        _length_dictionary::<UInt8Type>,
        _length_dictionary::<UInt16Type>,
        _length_dictionary::<UInt32Type>,
        _length_dictionary::<UInt64Type>,
    ];
    per_key_type.iter().try_for_each(|check| check())
}

/// Checks `length` on a `DictionaryArray<K>` built from string data.
///
/// The input has 100 entries cycling through five strings, with every entry at
/// cycle position 3 replaced by a null. The expected output is the byte length of
/// each non-null entry and null for each null entry.
///
/// The result is decoded through the keys of the *returned* dictionary, so both
/// the values and the keys produced by the kernel are checked.
fn _length_dictionary<K: ArrowDictionaryKeyType>() -> Result<()> {
    const TOTAL: usize = 100;

    let v = ["aaaa", "bb", "ccccc", "ddd", "eeeeee"];
    let data: Vec<Option<&str>> = (0..TOTAL)
        .map(|n| {
            let i = n % v.len();
            if i == 3 {
                None
            } else {
                Some(v[i])
            }
        })
        .collect();

    // `Option<&str>` is `Copy`, so the source vector does not need to be cloned.
    let dict_array: DictionaryArray<K> = data.iter().copied().collect();

    let expected: Vec<Option<i32>> =
        data.iter().map(|opt| opt.map(|s| s.len() as i32)).collect();

    let res = length(&dict_array)?;
    let res_dict = res.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
    // Decode with the output's own keys so a kernel that returned wrong keys
    // would be caught.
    let actual: Vec<Option<i32>> = res_dict
        .values()
        .as_any()
        .downcast_ref::<Int32Array>()
        .unwrap()
        .take_iter(res_dict.keys_iter())
        .collect();

    // Comparing whole vectors also catches a length mismatch.
    assert_eq!(expected, actual);

    Ok(())
}

/// Runs the dictionary `bit_length` check once for each of the eight integer key
/// types, in order, stopping at the first failure.
#[test]
fn bit_length_dictionary() -> Result<()> {
    let per_key_type: [fn() -> Result<()>; 8] = [
        _bit_length_dictionary::<Int8Type>,
        _bit_length_dictionary::<Int16Type>,
        _bit_length_dictionary::<Int32Type>,
        _bit_length_dictionary::<Int64Type>,
        _bit_length_dictionary::<UInt8Type>,
        _bit_length_dictionary::<UInt16Type>,
        _bit_length_dictionary::<UInt32Type>,
        _bit_length_dictionary::<UInt64Type>,
    ];
    per_key_type.iter().try_for_each(|check| check())
}

/// Checks `bit_length` on a `DictionaryArray<K>` built from string data.
///
/// The input has 100 entries cycling through five strings, with every entry at
/// cycle position 3 replaced by a null. The expected output is eight times the
/// byte length of each non-null entry and null for each null entry.
///
/// The result is decoded through the keys of the *returned* dictionary, so both
/// the values and the keys produced by the kernel are checked.
fn _bit_length_dictionary<K: ArrowDictionaryKeyType>() -> Result<()> {
    const TOTAL: usize = 100;

    let v = ["aaaa", "bb", "ccccc", "ddd", "eeeeee"];
    let data: Vec<Option<&str>> = (0..TOTAL)
        .map(|n| {
            let i = n % v.len();
            if i == 3 {
                None
            } else {
                Some(v[i])
            }
        })
        .collect();

    // `Option<&str>` is `Copy`, so the source vector does not need to be cloned.
    let dict_array: DictionaryArray<K> = data.iter().copied().collect();

    // Bit length is measured in bytes, not characters: use `len()` so the
    // expectation stays correct if non-ASCII samples are ever added.
    let expected: Vec<Option<i32>> = data
        .iter()
        .map(|opt| opt.map(|s| (s.len() * 8) as i32))
        .collect();

    let res = bit_length(&dict_array)?;
    let res_dict = res.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
    // Decode with the output's own keys so a kernel that returned wrong keys
    // would be caught.
    let actual: Vec<Option<i32>> = res_dict
        .values()
        .as_any()
        .downcast_ref::<Int32Array>()
        .unwrap()
        .take_iter(res_dict.keys_iter())
        .collect();

    // Comparing whole vectors also catches a length mismatch.
    assert_eq!(expected, actual);

    Ok(())
}
}

0 comments on commit 42c9e02

Please sign in to comment.