Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 1 addition & 55 deletions datafusion/functions/benches/atan2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

extern crate criterion;

use arrow::datatypes::{DataType, Field, Float32Type, Float64Type};
use arrow::datatypes::{DataType, Field, Float64Type};
use arrow::util::bench_util::create_primitive_array;
use criterion::{Criterion, criterion_group, criterion_main};
use datafusion_common::ScalarValue;
Expand All @@ -32,34 +32,6 @@ fn criterion_benchmark(c: &mut Criterion) {
let config_options = Arc::new(ConfigOptions::default());

for size in [1024, 4096, 8192] {
let y_f32 = Arc::new(create_primitive_array::<Float32Type>(size, 0.2));
let x_f32 = Arc::new(create_primitive_array::<Float32Type>(size, 0.2));
let f32_args = vec![ColumnarValue::Array(y_f32), ColumnarValue::Array(x_f32)];
let f32_arg_fields = f32_args
.iter()
.enumerate()
.map(|(idx, arg)| {
Field::new(format!("arg_{idx}"), arg.data_type(), true).into()
})
.collect::<Vec<_>>();
let return_field_f32 = Field::new("f", DataType::Float32, true).into();

c.bench_function(&format!("atan2 f32 array: {size}"), |b| {
b.iter(|| {
black_box(
atan2_fn
.invoke_with_args(ScalarFunctionArgs {
args: f32_args.clone(),
arg_fields: f32_arg_fields.clone(),
number_rows: size,
return_field: Arc::clone(&return_field_f32),
config_options: Arc::clone(&config_options),
})
.unwrap(),
)
})
});

let y_f64 = Arc::new(create_primitive_array::<Float64Type>(size, 0.2));
let x_f64 = Arc::new(create_primitive_array::<Float64Type>(size, 0.2));
let f64_args = vec![ColumnarValue::Array(y_f64), ColumnarValue::Array(x_f64)];
Expand Down Expand Up @@ -89,32 +61,6 @@ fn criterion_benchmark(c: &mut Criterion) {
});
}

let scalar_f32_args = vec![
ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0))),
ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))),
];
let scalar_f32_arg_fields = vec![
Field::new("a", DataType::Float32, false).into(),
Field::new("b", DataType::Float32, false).into(),
];
let return_field_f32 = Field::new("f", DataType::Float32, false).into();

c.bench_function("atan2 f32 scalar", |b| {
b.iter(|| {
black_box(
atan2_fn
.invoke_with_args(ScalarFunctionArgs {
args: scalar_f32_args.clone(),
arg_fields: scalar_f32_arg_fields.clone(),
number_rows: 1,
return_field: Arc::clone(&return_field_f32),
config_options: Arc::clone(&config_options),
})
.unwrap(),
)
})
});

let scalar_f64_args = vec![
ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0))),
ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))),
Expand Down
79 changes: 23 additions & 56 deletions datafusion/functions/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,8 @@ macro_rules! make_math_unary_udf {

/// Macro to create a binary math UDF.
///
/// A binary math function takes two arguments of types Float32 or Float64,
/// applies a binary floating function to the argument, and returns a value of the same type.
/// A binary math function takes two numeric arguments, coerces them to Float64,
/// applies a binary floating function to the arguments, and returns Float64.
///
/// $UDF: the name of the UDF struct that implements `ScalarUDFImpl`
/// $NAME: the name of the function
Expand All @@ -362,10 +362,9 @@ macro_rules! make_math_binary_udf {
use std::sync::Arc;

use arrow::array::{ArrayRef, AsArray};
use arrow::datatypes::{DataType, Float32Type, Float64Type};
use arrow::datatypes::{DataType, Float64Type};
use datafusion_common::utils::take_function_args;
use datafusion_common::{Result, ScalarValue, internal_err};
use datafusion_expr::TypeSignature;
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl,
Expand All @@ -381,11 +380,8 @@ macro_rules! make_math_binary_udf {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::one_of(
vec![
TypeSignature::Exact(vec![Float32, Float32]),
TypeSignature::Exact(vec![Float64, Float64]),
],
signature: Signature::exact(
vec![Float64, Float64],
Volatility::Immutable,
),
}
Expand All @@ -401,14 +397,8 @@ macro_rules! make_math_binary_udf {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
let arg_type = &arg_types[0];

match arg_type {
DataType::Float32 => Ok(DataType::Float32),
// For other types (possible values float64/null/int), use Float64
_ => Ok(DataType::Float64),
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

fn output_ordering(
Expand All @@ -422,10 +412,7 @@ macro_rules! make_math_binary_udf {
&self,
args: ScalarFunctionArgs,
) -> Result<ColumnarValue> {
let ScalarFunctionArgs {
args, return_field, ..
} = args;
let return_type = return_field.data_type();
let ScalarFunctionArgs { args, .. } = args;
let [y, x] = take_function_args(self.name(), args)?;

match (y, x) {
Expand All @@ -434,21 +421,14 @@ macro_rules! make_math_binary_udf {
ColumnarValue::Scalar(x_scalar),
) => match (&y_scalar, &x_scalar) {
(y, x) if y.is_null() || x.is_null() => {
ColumnarValue::Scalar(ScalarValue::Null)
.cast_to(return_type, None)
Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)))
}
(
ScalarValue::Float64(Some(yv)),
ScalarValue::Float64(Some(xv)),
) => Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(
f64::$BINARY_FUNC(*yv, *xv),
)))),
(
ScalarValue::Float32(Some(yv)),
ScalarValue::Float32(Some(xv)),
) => Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(
f32::$BINARY_FUNC(*yv, *xv),
)))),
_ => internal_err!(
"Unexpected scalar types for function {}: {:?}, {:?}",
self.name(),
Expand All @@ -458,38 +438,25 @@ macro_rules! make_math_binary_udf {
},
(y, x) => {
let args = ColumnarValue::values_to_arrays(&[y, x])?;
let arr: ArrayRef = match args[0].data_type() {
DataType::Float64 => {
match (args[0].data_type(), args[1].data_type()) {
(DataType::Float64, DataType::Float64) => {
let y = args[0].as_primitive::<Float64Type>();
let x = args[1].as_primitive::<Float64Type>();
let result =
arrow::compute::binary::<_, _, _, Float64Type>(
y,
x,
|y, x| f64::$BINARY_FUNC(y, x),
)?;
Arc::new(result) as _
}
DataType::Float32 => {
let y = args[0].as_primitive::<Float32Type>();
let x = args[1].as_primitive::<Float32Type>();
let result =
arrow::compute::binary::<_, _, _, Float32Type>(
y,
x,
|y, x| f32::$BINARY_FUNC(y, x),
)?;
Arc::new(result) as _
let result = arrow::compute::binary::<_, _, _, Float64Type>(
y,
x,
|y, x| f64::$BINARY_FUNC(y, x),
)?;

Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
}
other => {
return internal_err!(
"Unsupported data type {other:?} for function {}",
(left, right) => {
internal_err!(
"Unexpected array types for function {}: {left:?}, {right:?}",
self.name()
);
)
}
};

Ok(ColumnarValue::Array(arr))
}
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions datafusion/functions/src/math/monotonicity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,11 @@ Can be a constant, column, or function, and any combination of arithmetic operat
)
.with_sql_example(r#"```sql
> SELECT atan2(1, 1);
+------------+
| atan2(1,1) |
+------------+
| 0.7853982 |
+------------+
+--------------------+
| atan2(1,1) |
+--------------------+
| 0.7853981633974483 |
+--------------------+
```"#)
.build()
});
Expand Down
18 changes: 17 additions & 1 deletion datafusion/sqllogictest/test_files/scalar.slt
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,23 @@ select round(atanh(a), 5), round(atanh(b), 5), round(atanh(c), 5) from small_flo
query RRR rowsort
select atan2(0, 1), atan2(1, 2), atan2(2, 2);
----
0 0.4636476 0.7853982
0 0.463647609001 0.785398163397

# atan2 always returns Float64, including integer, Float32, and NULL inputs
query TTTT
select
arrow_typeof(atan2(1, 1)),
arrow_typeof(atan2(arrow_cast(1.0, 'Float32'), arrow_cast(1.0, 'Float32'))),
arrow_typeof(atan2(null, null)),
arrow_typeof(atan2(null, 64));
----
Float64 Float64 Float64 Float64

# atan2 with integer inputs is computed in double precision
query B
select atan2(1, 1000000) = atan2(1.0, 1000000.0);
----
true

# atan2 scalar nulls
query R rowsort
Expand Down
10 changes: 5 additions & 5 deletions docs/source/user-guide/sql/scalar_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,11 @@ atan2(expression_y, expression_x)

```sql
> SELECT atan2(1, 1);
+------------+
| atan2(1,1) |
+------------+
| 0.7853982 |
+------------+
+--------------------+
| atan2(1,1) |
+--------------------+
| 0.7853981633974483 |
+--------------------+
```

### `atanh`
Expand Down
Loading