From 187bf619dfafccdb21cea6b2cecabd29daffc1e4 Mon Sep 17 00:00:00 2001 From: Kun Liu Date: Fri, 25 Nov 2022 15:06:12 +0800 Subject: [PATCH] fix: cast decimal to decimal should be round the result (#3139) Co-authored-by: Raphael Taylor-Davies --- arrow-cast/src/cast.rs | 192 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 179 insertions(+), 13 deletions(-) diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index 3f17758255c..07c7d6a3ac5 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -1967,12 +1967,26 @@ fn cast_decimal_to_decimal_safe().unwrap(); if BYTE_WIDTH2 == 16 { - let iter = array - .iter() - .map(|v| v.and_then(|v| v.div_checked(div).ok())); + // rounding the result + let iter = array.iter().map(|v| { + v.map(|v| { + // the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation + let d = v.wrapping_div(div); + let r = v.wrapping_rem(div); + if v >= 0 && r >= half { + d.wrapping_add(1) + } else if v < 0 && r <= neg_half { + d.wrapping_sub(1) + } else { + d + } + }) + }); let casted_array = unsafe { PrimitiveArray::::from_trusted_len_iter(iter) }; @@ -1981,7 +1995,17 @@ fn cast_decimal_to_decimal_safe= 0 && r >= half { + d.wrapping_add(1) + } else if v < 0 && r <= neg_half { + d.wrapping_sub(1) + } else { + d + }) + }) }); let casted_array = unsafe { PrimitiveArray::::from_trusted_len_iter(iter) @@ -1993,9 +2017,22 @@ fn cast_decimal_to_decimal_safe().unwrap(); let div = i256::from_i128(div); + let half = div / i256::from_i128(2); + let neg_half = half.wrapping_neg(); if BYTE_WIDTH2 == 16 { let iter = array.iter().map(|v| { - v.and_then(|v| v.div_checked(div).ok().and_then(|v| v.to_i128())) + v.and_then(|v| { + let d = v.wrapping_div(div); + let r = v.wrapping_rem(div); + if v >= i256::ZERO && r >= half { + d.wrapping_add(i256::ONE) + } else if v < i256::ZERO && r <= neg_half { + d.wrapping_sub(i256::ONE) + } else { + d + } + .to_i128() + }) }); let casted_array = unsafe { PrimitiveArray::::from_trusted_len_iter(iter) @@ -2004,9 +2041,19 @@ fn cast_decimal_to_decimal_safe= i256::ZERO && r >= half { + d.wrapping_add(i256::ONE) + } else if v < i256::ZERO && r <= neg_half { + d.wrapping_sub(i256::ONE) + } else { + d + } + }) + }); let casted_array = unsafe { PrimitiveArray::::from_trusted_len_iter(iter) }; @@ -3566,6 +3613,125 @@ mod tests { .with_precision_and_scale(precision, scale) } + #[test] + #[cfg(not(feature = "force_validate"))] + fn test_cast_decimal_to_decimal_round() { + let array = vec![ + Some(1123454), + Some(2123456), + Some(-3123453), + Some(-3123456), + None, + ]; + let input_decimal_array = create_decimal_array(array, 20, 4).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + // decimal128 to decimal128 + let input_type = DataType::Decimal128(20, 4); + let output_type = DataType::Decimal128(20, 3); + assert!(can_cast_types(&input_type, &output_type)); + generate_cast_test_case!( + &array, + Decimal128Array, + &output_type, + vec![ + Some(112345_i128), + Some(212346_i128), + Some(-312345_i128), + Some(-312346_i128), + None + ] + ); + + // decimal128 to decimal256 + let input_type = DataType::Decimal128(20, 4); + let output_type = DataType::Decimal256(20, 3); + assert!(can_cast_types(&input_type, &output_type)); + generate_cast_test_case!( + &array, + Decimal256Array, + &output_type, + vec![ + Some(i256::from_i128(112345_i128)), + Some(i256::from_i128(212346_i128)), + Some(i256::from_i128(-312345_i128)), + Some(i256::from_i128(-312346_i128)), + None + ] + ); + + // decimal256 + let array = vec![ + Some(i256::from_i128(1123454)), + Some(i256::from_i128(2123456)), + Some(i256::from_i128(-3123453)), + Some(i256::from_i128(-3123456)), + None, + ]; + let input_decimal_array = create_decimal256_array(array, 20, 4).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + + // decimal256 to decimal256 + let input_type = DataType::Decimal256(20, 4); + let output_type = DataType::Decimal256(20, 3); + assert!(can_cast_types(&input_type, &output_type)); + generate_cast_test_case!( + &array, + Decimal256Array, + &output_type, + vec![ + Some(i256::from_i128(112345_i128)), + Some(i256::from_i128(212346_i128)), + Some(i256::from_i128(-312345_i128)), + Some(i256::from_i128(-312346_i128)), + None + ] + ); + // decimal256 to decimal128 + let input_type = DataType::Decimal256(20, 4); + let output_type = DataType::Decimal128(20, 3); + assert!(can_cast_types(&input_type, &output_type)); + generate_cast_test_case!( + &array, + Decimal128Array, + &output_type, + vec![ + Some(112345_i128), + Some(212346_i128), + Some(-312345_i128), + Some(-312346_i128), + None + ] + ); + + // decimal256 to decimal128 overflow + let array = vec![ + Some(i256::from_i128(1123454)), + Some(i256::from_i128(2123456)), + Some(i256::from_i128(-3123453)), + Some(i256::from_i128(-3123456)), + None, + Some(i256::MAX), + Some(i256::MIN), + ]; + let input_decimal_array = create_decimal256_array(array, 76, 4).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + assert!(can_cast_types(&input_type, &output_type)); + generate_cast_test_case!( + &array, + Decimal128Array, + &output_type, + vec![ + Some(112345_i128), + Some(212346_i128), + Some(-312345_i128), + Some(-312346_i128), + None, + None, + None + ] + ); + } + #[test] #[cfg(not(feature = "force_validate"))] fn test_cast_decimal128_to_decimal128() { @@ -7219,7 +7385,7 @@ mod tests { let input_type = DataType::Decimal128(20, 0); let output_type = DataType::Decimal128(20, -1); assert!(can_cast_types(&input_type, &output_type)); - let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = vec![Some(1123450), Some(2123455), Some(3123456), None]; let input_decimal_array = create_decimal_array(array, 20, 0).unwrap(); let array = Arc::new(input_decimal_array) as ArrayRef; generate_cast_test_case!( @@ -7228,8 +7394,8 @@ mod tests { &output_type, vec![ Some(112345_i128), - Some(212345_i128), - Some(312345_i128), + Some(212346_i128), + Some(312346_i128), None ] ); @@ -7238,8 +7404,8 @@ mod tests { let decimal_arr = as_primitive_array::(&casted_array); assert_eq!("1123450", decimal_arr.value_as_string(0)); - assert_eq!("2123450", decimal_arr.value_as_string(1)); - assert_eq!("3123450", decimal_arr.value_as_string(2)); + assert_eq!("2123460", decimal_arr.value_as_string(1)); + assert_eq!("3123460", decimal_arr.value_as_string(2)); } #[test]