Skip to content

Commit

Permalink
Check overflow while casting floating point value to decimal128 (#3021)
Browse files Browse the repository at this point in the history
* Check overflow while casting floating point value to decimal128

* Don't validate with precision

* Return error when saturating

* Use to_i128

* Apply suggestions from code review

Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com>

* Fix format

Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com>
  • Loading branch information
viirya and tustvold committed Nov 6, 2022
1 parent 4f525fe commit 108e7d2
Showing 1 changed file with 60 additions and 4 deletions.
64 changes: 60 additions & 4 deletions arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,16 +344,43 @@ fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
<T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
{
let mul = 10_f64.powi(scale as i32);

array
.unary::<_, Decimal128Type>(|v| (v.as_() * mul).round() as i128)
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
if cast_options.safe {
let iter = array
.iter()
.map(|v| v.and_then(|v| (mul * v.as_()).round().to_i128()));
let casted_array =
unsafe { PrimitiveArray::<Decimal128Type>::from_trusted_len_iter(iter) };
casted_array
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
} else {
array
.try_unary::<_, Decimal128Type, _>(|v| {
mul.mul_checked(v.as_()).and_then(|value| {
let mul_v = value.round();
let integer: i128 = mul_v.to_i128().ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to {}({}, {}). Overflowing on {:?}",
Decimal128Type::PREFIX,
precision,
scale,
v
))
})?;

Ok(integer)
})
})
.and_then(|a| a.with_precision_and_scale(precision, scale))
.map(|a| Arc::new(a) as ArrayRef)
}
}

fn cast_floating_point_to_decimal256<T: ArrowPrimitiveType>(
Expand Down Expand Up @@ -588,11 +615,13 @@ pub fn cast_with_options(
as_primitive_array::<Float32Type>(array),
*precision,
*scale,
cast_options,
),
Float64 => cast_floating_point_to_decimal128(
as_primitive_array::<Float64Type>(array),
*precision,
*scale,
cast_options,
),
Null => Ok(new_null_array(to_type, array.len())),
_ => Err(ArrowError::CastError(format!(
Expand Down Expand Up @@ -6110,4 +6139,31 @@ mod tests {
);
assert!(casted_array.is_err());
}

#[test]
fn test_cast_floating_point_to_decimal128_overflow() {
let array = Float64Array::from(vec![f64::MAX]);
let array = Arc::new(array) as ArrayRef;
let casted_array = cast_with_options(
&array,
&DataType::Decimal128(38, 30),
&CastOptions { safe: true },
);
assert!(casted_array.is_ok());
assert!(casted_array.unwrap().is_null(0));

let casted_array = cast_with_options(
&array,
&DataType::Decimal128(38, 30),
&CastOptions { safe: false },
);
let err = casted_array.unwrap_err().to_string();
let expected_error = "Cast error: Cannot cast to Decimal128(38, 30)";
assert!(
err.contains(expected_error),
"did not find expected error '{}' in actual error '{}'",
expected_error,
err
);
}
}

0 comments on commit 108e7d2

Please sign in to comment.