Skip to content

Commit

Permalink
Cast decimal256 to signed integer (#3040)
Browse files Browse the repository at this point in the history
* Cast decimal256 to signed integer

* Use ToPrimitive

* Add CastOptions
  • Loading branch information
viirya committed Nov 8, 2022
1 parent 879b461 commit a950b52
Show file tree
Hide file tree
Showing 2 changed files with 261 additions and 42 deletions.
87 changes: 81 additions & 6 deletions arrow-buffer/src/bigint.rs
Expand Up @@ -16,7 +16,7 @@
// under the License.

use num::cast::AsPrimitive;
use num::{BigInt, FromPrimitive};
use num::{BigInt, FromPrimitive, ToPrimitive};
use std::cmp::Ordering;

/// A signed 256-bit integer
Expand Down Expand Up @@ -388,13 +388,15 @@ impl i256 {

/// Splits `vals` into its first `M` bytes and the `M` bytes that follow them.
///
/// `N` must equal `2 * M`, i.e. the input is divided into two equal halves.
/// Type inference picks `N` from the argument and `M` from the expected
/// return type, e.g. `[u8; 32] -> ([u8; 16], [u8; 16])`.
///
/// Temporary workaround due to lack of stable const array slicing
/// See <https://github.com/rust-lang/rust/issues/90091>
///
/// # Panics
///
/// Panics if `N != 2 * M`.
const fn split_array<const N: usize, const M: usize>(
    vals: [u8; N],
) -> ([u8; M], [u8; M]) {
    // Without this guard a mismatched pair of lengths would either index out
    // of bounds (N < 2 * M) or silently drop trailing bytes (N > 2 * M).
    assert!(N == 2 * M, "split_array requires N == 2 * M");

    let mut a = [0; M];
    let mut b = [0; M];
    // `for` loops and slice copies are not available in `const fn`, hence the
    // manual index loop.
    let mut i = 0;
    while i != M {
        a[i] = vals[i];
        b[i] = vals[i + M];
        i += 1;
    }
    (a, b)
}
Expand Down Expand Up @@ -478,6 +480,44 @@ define_as_primitive!(i16);
define_as_primitive!(i32);
define_as_primitive!(i64);

impl ToPrimitive for i256 {
    /// Converts the value to an `i64`, returning `None` if it is out of range.
    fn to_i64(&self) -> Option<i64> {
        // Reinterpret the unsigned lower half as a signed 128-bit integer.
        let as_i128 = self.low as i128;

        // The value is representable in 128 bits only when `high` is the pure
        // sign extension of the lower half: all zero bits for a non-negative
        // lower half, all one bits (-1) for a negative one.
        let fits_i128 = match self.high {
            0 => as_i128 >= 0,
            -1 => as_i128 < 0,
            _ => false,
        };

        if fits_i128 {
            // Let the range-checked i128 -> i64 conversion do the narrowing.
            // Checking only the sign bits of the two 64-bit halves of `low`
            // is not sufficient: e.g. 2^64 + 1 has halves (1, 1) with
            // matching signs yet does not fit in an i64.
            as_i128.to_i64()
        } else {
            None
        }
    }

    /// Converts the value to a `u64`, returning `None` if it is negative or
    /// out of range.
    fn to_u64(&self) -> Option<u64> {
        // Any non-zero `high` means the value is either negative or needs
        // more than 128 bits, so only `high == 0` can fit. The lower half is
        // then range-checked by the u128 -> u64 conversion.
        if self.high == 0 {
            self.low.to_u64()
        } else {
            None
        }
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -676,4 +716,39 @@ mod tests {
test_ops(i256::from_le_bytes(l), i256::from_le_bytes(r))
}
}

#[test]
fn test_i256_to_primitive() {
    // Each row is (input, expected `to_i64()`, expected `to_u64()`).
    let cases: [(i256, Option<i64>, Option<u64>); 8] = [
        // Far outside both target ranges on the positive side
        (i256::MAX, None, None),
        (i256::from_i128(i128::MAX), None, None),
        // Upper boundary of i64
        (
            i256::from_i128(i64::MAX as i128),
            Some(i64::MAX),
            Some(i64::MAX as u64),
        ),
        // One past i64::MAX still fits in u64
        (
            i256::from_i128(i64::MAX as i128 + 1),
            None,
            Some(i64::MAX as u64 + 1),
        ),
        // Far outside both target ranges on the negative side
        (i256::MIN, None, None),
        (i256::from_i128(i128::MIN), None, None),
        // Lower boundary of i64; negative values never fit in u64
        (i256::from_i128(i64::MIN as i128), Some(i64::MIN), None),
        (i256::from_i128(i64::MIN as i128 - 1), None, None),
    ];

    for (value, expected_i64, expected_u64) in cases {
        assert_eq!(value.to_i64(), expected_i64);
        assert_eq!(value.to_u64(), expected_u64);
    }
}
}
216 changes: 180 additions & 36 deletions arrow-cast/src/cast.rs
Expand Up @@ -81,7 +81,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal128(_, _)) |
(Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal256(_, _)) |
// decimal to signed numeric
(Decimal128(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
(Decimal128(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) |
(Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64 )
| (
Null,
Boolean
Expand Down Expand Up @@ -433,34 +434,65 @@ fn cast_reinterpret_arrays<
))
}

// cast the decimal array to integer array
macro_rules! cast_decimal_to_integer {
($ARRAY:expr, $SCALE : ident, $VALUE_BUILDER: ident, $NATIVE_TYPE : ident, $DATA_TYPE : expr) => {{
let array = $ARRAY.as_any().downcast_ref::<Decimal128Array>().unwrap();
let mut value_builder = $VALUE_BUILDER::with_capacity(array.len());
let div: i128 = 10_i128.pow(*$SCALE as u32);
let min_bound = ($NATIVE_TYPE::MIN) as i128;
let max_bound = ($NATIVE_TYPE::MAX) as i128;
fn cast_decimal_to_integer<D, T>(
array: &ArrayRef,
base: D::Native,
scale: u8,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
T: ArrowPrimitiveType,
<T as ArrowPrimitiveType>::Native: NumCast,
D: DecimalType + ArrowPrimitiveType,
<D as ArrowPrimitiveType>::Native: ArrowNativeTypeOp + ToPrimitive,
{
let array = array.as_any().downcast_ref::<PrimitiveArray<D>>().unwrap();

let div: D::Native = base.pow_checked(scale as u32).map_err(|_| {
ArrowError::CastError(format!(
"Cannot cast to {:?}. The scale {} causes overflow.",
D::PREFIX,
scale,
))
})?;

let mut value_builder = PrimitiveBuilder::<T>::with_capacity(array.len());

if cast_options.safe {
for i in 0..array.len() {
if array.is_null(i) {
value_builder.append_null();
} else {
let v = array.value(i) / div;
// check the overflow
// For example: Decimal(128,10,0) as i8
// 128 is out of range i8
if v <= max_bound && v >= min_bound {
value_builder.append_value(v as $NATIVE_TYPE);
} else {
return Err(ArrowError::CastError(format!(
"value of {} is out of range {}",
v, $DATA_TYPE
)));
}
let v = array
.value(i)
.div_checked(div)
.ok()
.and_then(<T::Native as NumCast>::from::<D::Native>);

value_builder.append_option(v);
}
}
Ok(Arc::new(value_builder.finish()))
}};
} else {
for i in 0..array.len() {
if array.is_null(i) {
value_builder.append_null();
} else {
let v = array.value(i).div_checked(div)?;

let value =
<T::Native as NumCast>::from::<D::Native>(v).ok_or_else(|| {
ArrowError::CastError(format!(
"value of {:?} is out of range {}",
v,
T::DATA_TYPE
))
})?;

value_builder.append_value(value);
}
}
}
Ok(Arc::new(value_builder.finish()))
}

// cast the decimal array to floating-point array
Expand Down Expand Up @@ -576,18 +608,30 @@ pub fn cast_with_options(
(Decimal128(_, scale), _) => {
// cast decimal to other type
match to_type {
Int8 => {
cast_decimal_to_integer!(array, scale, Int8Builder, i8, Int8)
}
Int16 => {
cast_decimal_to_integer!(array, scale, Int16Builder, i16, Int16)
}
Int32 => {
cast_decimal_to_integer!(array, scale, Int32Builder, i32, Int32)
}
Int64 => {
cast_decimal_to_integer!(array, scale, Int64Builder, i64, Int64)
}
Int8 => cast_decimal_to_integer::<Decimal128Type, Int8Type>(
array,
10_i128,
*scale,
cast_options,
),
Int16 => cast_decimal_to_integer::<Decimal128Type, Int16Type>(
array,
10_i128,
*scale,
cast_options,
),
Int32 => cast_decimal_to_integer::<Decimal128Type, Int32Type>(
array,
10_i128,
*scale,
cast_options,
),
Int64 => cast_decimal_to_integer::<Decimal128Type, Int64Type>(
array,
10_i128,
*scale,
cast_options,
),
Float32 => {
cast_decimal_to_float!(array, scale, Float32Builder, f32)
}
Expand All @@ -601,6 +645,40 @@ pub fn cast_with_options(
))),
}
}
(Decimal256(_, scale), _) => {
// cast decimal to other type
match to_type {
Int8 => cast_decimal_to_integer::<Decimal256Type, Int8Type>(
array,
i256::from_i128(10_i128),
*scale,
cast_options,
),
Int16 => cast_decimal_to_integer::<Decimal256Type, Int16Type>(
array,
i256::from_i128(10_i128),
*scale,
cast_options,
),
Int32 => cast_decimal_to_integer::<Decimal256Type, Int32Type>(
array,
i256::from_i128(10_i128),
*scale,
cast_options,
),
Int64 => cast_decimal_to_integer::<Decimal256Type, Int64Type>(
array,
i256::from_i128(10_i128),
*scale,
cast_options,
),
Null => Ok(new_null_array(to_type, array.len())),
_ => Err(ArrowError::CastError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type
))),
}
}
(_, Decimal128(precision, scale)) => {
// cast data to decimal
match from_type {
Expand Down Expand Up @@ -3154,12 +3232,18 @@ mod tests {
let value_array: Vec<Option<i128>> = vec![Some(24400)];
let decimal_array = create_decimal_array(value_array, 38, 2).unwrap();
let array = Arc::new(decimal_array) as ArrayRef;
let casted_array = cast(&array, &DataType::Int8);
let casted_array =
cast_with_options(&array, &DataType::Int8, &CastOptions { safe: false });
assert_eq!(
"Cast error: value of 244 is out of range Int8".to_string(),
casted_array.unwrap_err().to_string()
);

let casted_array =
cast_with_options(&array, &DataType::Int8, &CastOptions { safe: true });
assert!(casted_array.is_ok());
assert!(casted_array.unwrap().is_null(0));

// loss of precision: convert decimal to f32 and f64
// f32
// 112345678_f32 and 112345679_f32 are same, so the 112345679_f32 will lose precision.
Expand Down Expand Up @@ -3218,6 +3302,66 @@ mod tests {
);
}

#[test]
fn test_cast_decimal256_to_numeric() {
    // Negative check: casting Decimal256 to an unsigned integer is rejected
    let source_type = DataType::Decimal256(38, 2);
    assert!(!can_cast_types(&source_type, &DataType::UInt8));

    // Raw decimal values; the array below is created with scale 2
    let raw_values: Vec<Option<i256>> =
        [Some(125_i128), Some(225), Some(325), None, Some(525)]
            .iter()
            .copied()
            .map(|v| v.map(i256::from_i128))
            .collect();
    let input = Arc::new(create_decimal256_array(raw_values, 38, 2).unwrap()) as ArrayRef;

    // Decimal256 -> Int8
    generate_cast_test_case!(
        &input,
        Int8Array,
        &DataType::Int8,
        vec![Some(1_i8), Some(2_i8), Some(3_i8), None, Some(5_i8)]
    );
    // Decimal256 -> Int16
    generate_cast_test_case!(
        &input,
        Int16Array,
        &DataType::Int16,
        vec![Some(1_i16), Some(2_i16), Some(3_i16), None, Some(5_i16)]
    );
    // Decimal256 -> Int32
    generate_cast_test_case!(
        &input,
        Int32Array,
        &DataType::Int32,
        vec![Some(1_i32), Some(2_i32), Some(3_i32), None, Some(5_i32)]
    );
    // Decimal256 -> Int64
    generate_cast_test_case!(
        &input,
        Int64Array,
        &DataType::Int64,
        vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)]
    );

    // Overflow check: the quotient exceeds the maximum of i8
    let overflowing = vec![Some(i256::from_i128(24400))];
    let input = Arc::new(create_decimal256_array(overflowing, 38, 2).unwrap()) as ArrayRef;

    // With `safe: false` the cast reports an error
    let strict = CastOptions { safe: false };
    let strict_result = cast_with_options(&input, &DataType::Int8, &strict);
    assert_eq!(
        "Cast error: value of 244 is out of range Int8".to_string(),
        strict_result.unwrap_err().to_string()
    );

    // With `safe: true` the out-of-range slot becomes null
    let lenient = CastOptions { safe: true };
    let lenient_result = cast_with_options(&input, &DataType::Int8, &lenient);
    assert!(lenient_result.is_ok());
    assert!(lenient_result.unwrap().is_null(0));
}

#[test]
#[cfg(not(feature = "force_validate"))]
fn test_cast_numeric_to_decimal128() {
Expand Down

0 comments on commit a950b52

Please sign in to comment.