Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 41 additions & 22 deletions datafusion/functions/src/math/nanvl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ use std::sync::Arc;
use arrow::array::{ArrayRef, AsArray, Float16Array, Float32Array, Float64Array};
use arrow::datatypes::DataType::{Float16, Float32, Float64};
use arrow::datatypes::{DataType, Float16Type, Float32Type, Float64Type};
use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args};
use datafusion_expr::TypeSignature::Exact;
use datafusion_common::{
Result, ScalarValue, exec_err, plan_err, utils::take_function_args,
};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Volatility,
Expand Down Expand Up @@ -64,14 +65,8 @@ impl Default for NanvlFunc {
impl NanvlFunc {
pub fn new() -> Self {
Self {
signature: Signature::one_of(
vec![
Exact(vec![Float16, Float16]),
Exact(vec![Float32, Float32]),
Exact(vec![Float64, Float64]),
],
Volatility::Immutable,
),
// Argument coercion is handled by `coerce_types`.
signature: Signature::user_defined(Volatility::Immutable),
}
}
}
Expand All @@ -86,27 +81,42 @@ impl ScalarUDFImpl for NanvlFunc {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
match &arg_types[0] {
Float16 => Ok(Float16),
Float32 => Ok(Float32),
match (&arg_types[0], &arg_types[1]) {
(Float16, Float16) => Ok(Float16),
(Float32, Float32) => Ok(Float32),
_ => Ok(Float64),
}
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
let [x, y] = take_function_args(self.name(), arg_types)?;

// Integers, decimals, and NULL become Float64; choosing Float64 ensures
// we can represent as many inputs as possible before rounding. The two
// inputs are then unified to the widest float type. For example,
// (Float16, Float32) -> Float32, not Float64.
let to_float = |t: &DataType| match t {
Float16 => Ok(Float16),
Float32 => Ok(Float32),
t if t.is_numeric() || t.is_null() => Ok(Float64),
t => plan_err!("Function 'nanvl' expects numeric arguments, got {t}"),
};
let common = match (to_float(x)?, to_float(y)?) {
(Float64, _) | (_, Float64) => Float64,
(Float32, _) | (_, Float32) => Float32,
_ => Float16,
};
Ok(vec![common.clone(), common])
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let [x, y] = take_function_args(self.name(), args.args)?;

match (x, y) {
(ColumnarValue::Scalar(ScalarValue::Float16(Some(v))), y) if v.is_nan() => {
Ok(y)
}
(ColumnarValue::Scalar(ScalarValue::Float32(Some(v))), y) if v.is_nan() => {
Ok(y)
}
(ColumnarValue::Scalar(ScalarValue::Float64(Some(v))), y) if v.is_nan() => {
Ok(y)
}
// Scalar x: return y if x is NaN, otherwise x (which may be NULL).
(ColumnarValue::Scalar(ref x), y) if scalar_is_nan(x) => Ok(y),
(x @ ColumnarValue::Scalar(_), _) => Ok(x),
// At least one argument is an array: evaluate element-wise.
(x, y) => {
let args = ColumnarValue::values_to_arrays(&[x, y])?;
Ok(ColumnarValue::Array(nanvl(&args)?))
Expand All @@ -119,6 +129,15 @@ impl ScalarUDFImpl for NanvlFunc {
}
}

fn scalar_is_nan(scalar: &ScalarValue) -> bool {
match scalar {
ScalarValue::Float16(Some(v)) => v.is_nan(),
ScalarValue::Float32(Some(v)) => v.is_nan(),
ScalarValue::Float64(Some(v)) => v.is_nan(),
_ => false,
}
}

/// Nanvl SQL function
///
/// - x is NaN -> output is y (which may itself be NULL)
Expand Down
25 changes: 25 additions & 0 deletions datafusion/sqllogictest/test_files/scalar.slt
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,31 @@ select nanvl(null, null);
----
NULL

# nanvl evaluates in the common (widest) float type of its arguments. Mixing
# narrower floats widens losslessly (Float16 + Float32 -> Float32), while
# integers, decimals, and NULL are coerced to Float64.
query TTTTTTTT
select
arrow_typeof(nanvl(arrow_cast(1.0, 'Float16'), arrow_cast(2.0, 'Float16'))),
arrow_typeof(nanvl(arrow_cast(1.0, 'Float32'), arrow_cast(2.0, 'Float32'))),
arrow_typeof(nanvl(arrow_cast(1.0, 'Float16'), arrow_cast(2.0, 'Float32'))),
arrow_typeof(nanvl(arrow_cast(1.0, 'Float16'), arrow_cast(2.0, 'Float64'))),
arrow_typeof(nanvl(arrow_cast(1.0, 'Float32'), arrow_cast(2.0, 'Float64'))),
arrow_typeof(nanvl(1, 2)),
arrow_typeof(nanvl(1, arrow_cast(2.0, 'Float32'))),
arrow_typeof(nanvl(null, null));
----
Float16 Float32 Float32 Float64 Float64 Float64 Float64 Float64

# nanvl with an integer argument is computed in double precision, even when the
# other argument is Float32.
query BB
select
nanvl(16777217, 1) = nanvl(arrow_cast(16777217, 'Float64'), 1.0),
nanvl(16777217, arrow_cast(1.0, 'Float32')) = nanvl(arrow_cast(16777217, 'Float64'), 1.0);
----
true true

# nanvl with columns (round is needed to normalize the outputs of different operating systems)
query RRR rowsort
select round(nanvl(asin(f + a), 2), 5), round(nanvl(asin(b + c), 3), 5), round(nanvl(asin(d + e), 4), 5) from small_floats;
Expand Down
Loading