Skip to content

Commit

Permalink
implement bit_length function
Browse files Browse the repository at this point in the history
  • Loading branch information
houqp committed Jan 18, 2022
1 parent 8a6fb2c commit 60e869e
Showing 1 changed file with 45 additions and 1 deletion.
46 changes: 45 additions & 1 deletion datafusion/src/physical_plan/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ use arrow::{
compute::length::length,
datatypes::TimeUnit,
datatypes::{DataType, Field, Schema},
error::{ArrowError, Result as ArrowResult},
record_batch::RecordBatch,
types::NativeType,
};
use fmt::{Debug, Formatter};
use std::convert::From;
Expand Down Expand Up @@ -720,6 +722,46 @@ macro_rules! invoke_if_unicode_expressions_feature_flag {
};
}

/// Builds a primitive array by applying `op` to the span between each pair of
/// consecutive offsets of `array`.
///
/// The span `offsets[i + 1] - offsets[i]` is computed for every slot, and `op`
/// maps it to the output value. The validity bitmap of `array` is cloned onto
/// the result unchanged. The output is typed `Int64` when `O::is_large()` is
/// true and `Int32` otherwise.
fn unary_offsets_string<O, F>(array: &Utf8Array<O>, op: F) -> PrimitiveArray<O>
where
    O: Offset + NativeType,
    F: Fn(O) -> O,
{
    // Resolve the logical output type up front; it depends only on the offset width.
    let output_type = if O::is_large() {
        DataType::Int64
    } else {
        DataType::Int32
    };

    // Adjacent offsets delimit one value each; feed their difference to `op`.
    let buffer = arrow::buffer::Buffer::from_trusted_len_iter(
        array
            .offsets()
            .windows(2)
            .map(|bounds| op(bounds[1] - bounds[0])),
    );

    let validity = array.validity().cloned();
    PrimitiveArray::<O>::from_data(output_type, buffer, validity)
}

/// Returns an array of integers with the number of bits in each string of the array.
///
/// The bit length of each value is its offset span (as computed by
/// `unary_offsets_string`) multiplied by 8. `Utf8` input is processed with
/// `i32` offsets and `LargeUtf8` input with `i64` offsets; the validity of the
/// input is carried over by `unary_offsets_string`.
///
/// # Errors
///
/// Returns `ArrowError::InvalidArgumentError` when the array's data type is
/// neither `Utf8` nor `LargeUtf8`.
///
/// TODO: contribute this back upstream?
fn bit_length(array: &dyn Array) -> ArrowResult<Box<dyn Array>> {
    // NOTE(review): `x * 8` uses plain integer multiplication, so a span larger
    // than `MAX / 8` for the offset type would overflow — confirm this is acceptable.
    match array.data_type() {
        DataType::Utf8 => {
            let array = array
                .as_any()
                .downcast_ref::<Utf8Array<i32>>()
                .expect("array with DataType::Utf8 must be a Utf8Array<i32>");
            Ok(Box::new(unary_offsets_string::<i32, _>(array, |x| x * 8)))
        }
        DataType::LargeUtf8 => {
            let array = array
                .as_any()
                .downcast_ref::<Utf8Array<i64>>()
                .expect("array with DataType::LargeUtf8 must be a Utf8Array<i64>");
            Ok(Box::new(unary_offsets_string::<i64, _>(array, |x| x * 8)))
        }
        // Name this function in the message so the failing kernel is identifiable.
        _ => Err(ArrowError::InvalidArgumentError(format!(
            "bit_length not supported for {:?}",
            array.data_type()
        ))),
    }
}

/// Create a physical scalar function.
pub fn create_physical_fun(
fun: &BuiltinScalarFunction,
Expand Down Expand Up @@ -761,7 +803,9 @@ pub fn create_physical_fun(
))),
}),
BuiltinScalarFunction::BitLength => Arc::new(|args| match &args[0] {
ColumnarValue::Array(_v) => todo!(),
ColumnarValue::Array(v) => {
Ok(ColumnarValue::Array(bit_length(v.as_ref())?.into()))
}
ColumnarValue::Scalar(v) => match v {
ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32(
v.as_ref().map(|x| (x.len() * 8) as i32),
Expand Down

0 comments on commit 60e869e

Please sign in to comment.