Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
290 changes: 276 additions & 14 deletions datafusion/functions/src/math/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@ use crate::utils::make_scalar_function;

use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
use arrow::compute::{cast_with_options, CastOptions};
use arrow::datatypes::DataType::{Float32, Float64, Int32};
use arrow::datatypes::{DataType, Float32Type, Float64Type, Int32Type};
use arrow::datatypes::DataType::{
Decimal128, Decimal256, Float32, Float64, Int32, Int64,
};
use arrow::datatypes::{
DataType, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int32Type,
};
use arrow_buffer::i256;
use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Volatility,
Expand Down Expand Up @@ -56,17 +60,8 @@ impl Default for RoundFunc {

impl RoundFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::one_of(
vec![
Exact(vec![Float64, Int64]),
Exact(vec![Float32, Int64]),
Exact(vec![Float64]),
Exact(vec![Float32]),
],
Volatility::Immutable,
),
signature: Signature::user_defined(Volatility::Immutable),
}
}
}
Expand All @@ -84,9 +79,41 @@ impl ScalarUDFImpl for RoundFunc {
&self.signature
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need a custom coerce types method? Why not just add two more entries to the existing table of one_of?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need a custom coerce types method? Why not just add two more entries to the existing table of one_of?

The arg of one_of TypeSignature must have a specific typmod.
but input decimal will have about 78*78 kind of types;

if arg_types.len() != 1 && arg_types.len() != 2 {
return exec_err!(
"round function requires one or two arguments, got {}",
arg_types.len()
);
}

if arg_types.len() == 1 {
match arg_types[0].clone() {
Decimal128(p, s) => Ok(vec![Decimal128(p, s)]),
Decimal256(p, s) => Ok(vec![Decimal256(p, s)]),
Float32 => Ok(vec![Float32]),
_ => Ok(vec![Float64]),
}
} else if arg_types.len() == 2 {
match arg_types[0].clone() {
Decimal128(p, s) => Ok(vec![Decimal128(p, s), Int64]),
Decimal256(p, s) => Ok(vec![Decimal256(p, s), Int64]),
Float32 => Ok(vec![Float32, Int64]),
_ => Ok(vec![Float64, Int64]),
}
} else {
exec_err!(
"round function requires one or two arguments, got {}",
arg_types.len()
)
}
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
match arg_types[0] {
Float32 => Ok(Float32),
Decimal128(p, s) => Ok(Decimal128(p, s)),
Decimal256(p, s) => Ok(Decimal256(p, s)),
_ => Ok(Float64),
}
}
Expand Down Expand Up @@ -215,17 +242,141 @@ pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
}
},

Decimal128(precision, scale) => match decimal_places {
ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => {
let decimal_places: i32 = decimal_places.try_into().map_err(|e| {
exec_datafusion_err!(
"Invalid value for decimal places: {decimal_places}: {e}"
)
})?;

let values = args[0].as_primitive::<Decimal128Type>();
let result = values.unary::<_, Decimal128Type>(|value| {
round_decimal128(value, *scale, decimal_places)
});

Ok(Arc::new(result.with_precision_and_scale(*precision, *scale)?) as _)
}
ColumnarValue::Array(decimal_places) => {
let options = CastOptions {
safe: false, // raise error if the cast is not possible
..Default::default()
};
let decimal_places = cast_with_options(&decimal_places, &Int32, &options)
.map_err(|e| {
exec_datafusion_err!("Invalid values for decimal places: {e}")
})?;

let values = args[0].as_primitive::<Decimal128Type>();
let decimal_places = decimal_places.as_primitive::<Int32Type>();
let result = arrow::compute::binary::<_, _, _, Decimal128Type>(
values,
decimal_places,
|value, decimal_places| {
round_decimal128(value, *scale, decimal_places)
},
)?;

Ok(Arc::new(result.with_precision_and_scale(*precision, *scale)?) as _)
}
_ => {
exec_err!("round function requires a scalar or array for decimal_places")
}
},

Decimal256(precision, scale) => match decimal_places {
ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => {
let decimal_places: i32 = decimal_places.try_into().map_err(|e| {
exec_datafusion_err!(
"Invalid value for decimal places: {decimal_places}: {e}"
)
})?;

let values = args[0].as_primitive::<Decimal256Type>();
let result = values.unary::<_, Decimal256Type>(|value| {
round_decimal256(value, *scale, decimal_places)
});

Ok(Arc::new(result.with_precision_and_scale(*precision, *scale)?) as _)
}
ColumnarValue::Array(decimal_places) => {
let options = CastOptions {
safe: false,
..Default::default()
};
let decimal_places = cast_with_options(&decimal_places, &Int32, &options)
.map_err(|e| {
exec_datafusion_err!("Invalid values for decimal places: {e}")
})?;

let values = args[0].as_primitive::<Decimal256Type>();
let decimal_places = decimal_places.as_primitive::<Int32Type>();
let result = arrow::compute::binary::<_, _, _, Decimal256Type>(
values,
decimal_places,
|value, decimal_places| {
round_decimal256(value, *scale, decimal_places)
},
)?;

Ok(Arc::new(result.with_precision_and_scale(*precision, *scale)?) as _)
}
_ => {
exec_err!("round function requires a scalar or array for decimal_places")
}
},

other => exec_err!("Unsupported data type {other:?} for function round"),
}
}

/// Rounds the raw (unscaled) representation of a `Decimal128` value to
/// `decimal_places` fractional digits, rounding halves away from zero.
///
/// The result keeps `current_scale`, so dropped digits are replaced by zeros
/// rather than removed. A negative `decimal_places` rounds to the left of the
/// decimal point. When `decimal_places >= current_scale` there is nothing to
/// drop and `value` is returned unchanged.
///
/// The function never panics: if so many digits are dropped that the rounding
/// unit exceeds the `i128` range the result is `0`, and a rounded result that
/// does not fit in `i128` saturates.
#[inline]
fn round_decimal128(value: i128, current_scale: i8, decimal_places: i32) -> i128 {
    // Saturating keeps `decimal_places == i32::MIN` from overflowing.
    let digits_to_drop = i32::from(current_scale).saturating_sub(decimal_places);
    if digits_to_drop <= 0 {
        return value;
    }

    // `digits_to_drop` is strictly positive here, so `unsigned_abs` is exact.
    let factor = match 10_i128.checked_pow(digits_to_drop.unsigned_abs()) {
        Some(factor) => factor,
        // 10^39 and above exceed i128; half of such a unit is larger than any
        // i128 magnitude, so every value rounds to zero.
        None => return 0,
    };
    let half = factor / 2;

    // Split into quotient and remainder instead of computing `value ± half`,
    // which could overflow for magnitudes near the i128 limits.
    let quotient = value / factor;
    let remainder = value % factor;
    let rounded = if remainder >= half {
        quotient + 1
    } else if remainder <= -half {
        quotient - 1
    } else {
        quotient
    };

    rounded.saturating_mul(factor)
}

/// Rounds the raw (unscaled) representation of a `Decimal256` value to
/// `decimal_places` fractional digits, rounding halves away from zero.
///
/// The result keeps `current_scale`, so dropped digits are replaced by zeros
/// rather than removed. A negative `decimal_places` rounds to the left of the
/// decimal point. When `decimal_places >= current_scale` there is nothing to
/// drop and `value` is returned unchanged.
///
/// The rounding unit is computed in `i256` rather than `i128`, because a
/// Decimal256 can legitimately need to drop more digits than a power of ten
/// in `i128` can represent.
#[inline]
fn round_decimal256(value: i256, current_scale: i8, decimal_places: i32) -> i256 {
    // Saturating keeps `decimal_places == i32::MIN` from overflowing.
    let digits_to_drop = i32::from(current_scale).saturating_sub(decimal_places);
    if digits_to_drop <= 0 {
        return value;
    }

    let zero = i256::from_i128(0);
    // `digits_to_drop` is strictly positive here, so `unsigned_abs` is exact.
    let factor = match i256::from_i128(10).checked_pow(digits_to_drop.unsigned_abs()) {
        Some(factor) => factor,
        // The rounding unit no longer fits in i256. Assuming `value` is a
        // valid Decimal256 (at most 76 digits), it is far below half of that
        // unit and therefore rounds to zero.
        None => return zero,
    };
    let half = factor / i256::from_i128(2);

    if value >= zero {
        ((value + half) / factor) * factor
    } else {
        ((value - half) / factor) * factor
    }
}

#[cfg(test)]
mod test {
use std::sync::Arc;

use crate::math::round::round;

use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
use arrow::array::{
ArrayRef, Decimal128Array, Decimal256Array, Float32Array, Float64Array,
Int64Array,
};
use arrow_buffer::i256;
use datafusion_common::cast::{as_float32_array, as_float64_array};
use datafusion_common::DataFusionError;

Expand Down Expand Up @@ -307,4 +458,115 @@ mod test {
assert!(result.is_err());
assert!(matches!(result, Err(DataFusionError::Execution { .. })));
}

#[test]
fn test_round_decimal128() {
    // Ten copies of 125.2345 (scale 4), each paired with a different number
    // of decimal places, including negative ones.
    let values = Decimal128Array::from(vec![1252345_i128; 10])
        .with_precision_and_scale(10, 4)
        .unwrap();
    let places = Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4]);
    let args: Vec<ArrayRef> = vec![Arc::new(values), Arc::new(places)];

    let result = round(&args).expect("failed to initialize function round");
    let decimals = result.as_any().downcast_ref::<Decimal128Array>().unwrap();

    let expected_raw: Vec<i128> = vec![
        1250000, 1252000, 1252300, 1252350, 1252345, 1252345, 1300000, 1000000, 0, 0,
    ];
    let expected = Decimal128Array::from(expected_raw)
        .with_precision_and_scale(10, 4)
        .unwrap();

    assert_eq!(decimals, &expected);
}

#[test]
fn test_round_decimal128_one_input() {
    // Only the value column is supplied; no decimal-places argument.
    let raw_values: Vec<i128> = vec![1252345, 123450, 12340, 1234];
    let values = Decimal128Array::from(raw_values)
        .with_precision_and_scale(10, 4)
        .unwrap();
    let args: Vec<ArrayRef> = vec![Arc::new(values)];

    let result = round(&args).expect("failed to initialize function round");
    let decimals = result.as_any().downcast_ref::<Decimal128Array>().unwrap();

    let expected_raw: Vec<i128> = vec![1250000, 120000, 10000, 0];
    let expected = Decimal128Array::from(expected_raw)
        .with_precision_and_scale(10, 4)
        .unwrap();

    assert_eq!(decimals, &expected);
}

#[test]
fn test_round_decimal256() {
    // Ten copies of 125.2345 (scale 4), each paired with a different number
    // of decimal places, including negative ones.
    let values = Decimal256Array::from(vec![i256::from_i128(1252345_i128); 10])
        .with_precision_and_scale(20, 4)
        .unwrap();
    let places = Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4]);
    let args: Vec<ArrayRef> = vec![Arc::new(values), Arc::new(places)];

    let result = round(&args).expect("failed to initialize function round");
    let decimals = result.as_any().downcast_ref::<Decimal256Array>().unwrap();

    let expected_raw: Vec<i256> = vec![
        1250000_i128,
        1252000,
        1252300,
        1252350,
        1252345,
        1252345,
        1300000,
        1000000,
        0,
        0,
    ]
    .into_iter()
    .map(i256::from_i128)
    .collect();
    let expected = Decimal256Array::from(expected_raw)
        .with_precision_and_scale(20, 4)
        .unwrap();

    assert_eq!(decimals, &expected);
}

#[test]
fn test_round_decimal256_one_input() {
    // Only the value column is supplied; no decimal-places argument.
    let raw_values: Vec<i256> = vec![1252345_i128, 123450, 12340, 1234]
        .into_iter()
        .map(i256::from_i128)
        .collect();
    let values = Decimal256Array::from(raw_values)
        .with_precision_and_scale(20, 4)
        .unwrap();
    let args: Vec<ArrayRef> = vec![Arc::new(values)];

    let result = round(&args).expect("failed to initialize function round");
    let decimals = result.as_any().downcast_ref::<Decimal256Array>().unwrap();

    let expected_raw: Vec<i256> = vec![1250000_i128, 120000, 10000, 0]
        .into_iter()
        .map(i256::from_i128)
        .collect();
    let expected = Decimal256Array::from(expected_raw)
        .with_precision_and_scale(20, 4)
        .unwrap();

    assert_eq!(decimals, &expected);
}
}