diff --git a/datafusion/core/src/logical_plan/mod.rs b/datafusion/core/src/logical_plan/mod.rs index e4e26ad54432..9b391983748c 100644 --- a/datafusion/core/src/logical_plan/mod.rs +++ b/datafusion/core/src/logical_plan/mod.rs @@ -28,8 +28,8 @@ pub use datafusion_common::{ }; pub use datafusion_expr::{ abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan, - avg, bit_length, btrim, call_fn, case, ceil, character_length, chr, coalesce, col, - combine_filters, concat, concat_expr, concat_ws, concat_ws_expr, cos, count, + atan2, avg, bit_length, btrim, call_fn, case, ceil, character_length, chr, coalesce, + col, combine_filters, concat, concat_expr, concat_ws, concat_ws_expr, cos, count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, exists, exp, expr_rewriter, expr_rewriter::{ diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index 93347ee41b43..c9c5d955a35c 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -505,6 +505,9 @@ async fn test_mathematical_expressions_with_null() -> Result<()> { test_expression!("power(NULL, 2)", "NULL"); test_expression!("power(NULL, NULL)", "NULL"); test_expression!("power(2, NULL)", "NULL"); + test_expression!("atan2(NULL, NULL)", "NULL"); + test_expression!("atan2(1, NULL)", "NULL"); + test_expression!("atan2(NULL, 1)", "NULL"); Ok(()) } diff --git a/datafusion/core/tests/sql/math.rs b/datafusion/core/tests/sql/math.rs new file mode 100644 index 000000000000..cff7120a20a1 --- /dev/null +++ b/datafusion/core/tests/sql/math.rs @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; +use arrow::array::Float64Array; + +#[tokio::test] +async fn test_atan2() -> Result<()> { + let ctx = SessionContext::new(); + + let t1_schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Float64, true), + Field::new("y", DataType::Float64, true), + ])); + + let t1_data = RecordBatch::try_new( + t1_schema.clone(), + vec![ + Arc::new(Float64Array::from(vec![1.0, 1.0, -1.0, -1.0])), + Arc::new(Float64Array::from(vec![2.0, -2.0, 2.0, -2.0])), + ], + )?; + let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; + ctx.register_table("t1", Arc::new(t1_table))?; + + let sql = "SELECT atan2(y, x) FROM t1"; + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+---------------------+", + "| atan2(t1.y,t1.x) |", + "+---------------------+", + "| 1.1071487177940904 |", + "| -1.1071487177940904 |", + "| 2.0344439357957027 |", + "| -2.0344439357957027 |", + "+---------------------+", + ]; + + assert_batches_eq!(expected, &actual); + + Ok(()) +} diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index a7f4cabe9d7b..9533337a5300 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -91,6 +91,7 @@ pub mod intersection; pub mod joins; pub mod json; pub mod limit; +pub mod math; pub mod order; pub mod parquet; pub mod predicates; diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 663888e2ecd8..ffac07ca53f9 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -34,6 +34,8 @@ pub enum BuiltinScalarFunction { Asin, /// atan Atan, + /// atan2 + Atan2, /// ceil Ceil, /// coalesce @@ -181,6 +183,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Acos => Volatility::Immutable, BuiltinScalarFunction::Asin => Volatility::Immutable, BuiltinScalarFunction::Atan => Volatility::Immutable, + BuiltinScalarFunction::Atan2 => Volatility::Immutable, BuiltinScalarFunction::Ceil => Volatility::Immutable, BuiltinScalarFunction::Coalesce => Volatility::Immutable, BuiltinScalarFunction::Cos => Volatility::Immutable, @@ -268,6 +271,7 @@ impl FromStr for BuiltinScalarFunction { "acos" => BuiltinScalarFunction::Acos, "asin" => BuiltinScalarFunction::Asin, "atan" => BuiltinScalarFunction::Atan, + "atan2" => BuiltinScalarFunction::Atan2, "ceil" => BuiltinScalarFunction::Ceil, "cos" => BuiltinScalarFunction::Cos, "exp" => BuiltinScalarFunction::Exp, diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index abfd37a7c1c6..97bbd419e4ec 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -304,6 +304,7 @@ unary_scalar_expr!(Log10, log10); unary_scalar_expr!(Ln, ln); unary_scalar_expr!(NullIf, nullif); scalar_expr!(Power, power, base, exponent); +scalar_expr!(Atan2, atan2, y, x); // string functions scalar_expr!(Ascii, ascii, string); @@ -546,6 +547,7 @@ mod test { test_unary_scalar_expr!(Log2, log2); test_unary_scalar_expr!(Log10, log10); test_unary_scalar_expr!(Ln, ln); + test_scalar_expr!(Atan2, atan2, y, x); test_scalar_expr!(Ascii, ascii, input); test_scalar_expr!(BitLength, bit_length, string); diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 331756f8de83..29158e234b79 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -229,6 +229,11 @@ pub fn return_type( BuiltinScalarFunction::Struct => Ok(DataType::Struct(vec![])), + BuiltinScalarFunction::Atan2 => match &input_expr_types[0] { + DataType::Float32 => Ok(DataType::Float32), + _ => Ok(DataType::Float64), + }, + BuiltinScalarFunction::Abs | BuiltinScalarFunction::Acos | BuiltinScalarFunction::Asin @@ -540,6 +545,13 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature { ], fun.volatility(), ), + BuiltinScalarFunction::Atan2 => Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Float32, DataType::Float32]), + TypeSignature::Exact(vec![DataType::Float64, DataType::Float64]), + ], + fun.volatility(), + ), // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we // return the best approximation for it (in f64). diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 5f0e711f80e4..a84b00bf1e45 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -308,6 +308,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::Power => { Arc::new(|args| make_scalar_function(math_expressions::power)(args)) } + BuiltinScalarFunction::Atan2 => { + Arc::new(|args| make_scalar_function(math_expressions::atan2)(args)) + } // string functions BuiltinScalarFunction::Array => Arc::new(array_expressions::array), diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index 7f41268154a9..16dda93dd134 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -176,11 +176,38 @@ pub fn power(args: &[ArrayRef]) -> Result { } } +pub fn atan2(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Float64 => Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "y", + "x", + Float64Array, + { f64::atan2 } + )) as ArrayRef), + + DataType::Float32 => Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "y", + "x", + Float32Array, + { f32::atan2 } + )) as ArrayRef), + + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function atan2", + other + ))), + } +} + #[cfg(test)] mod tests { use super::*; - use arrow::array::{Float64Array, NullArray}; + use arrow::array::{Array, Float64Array, NullArray}; #[test] fn test_random_expression() { @@ -191,4 +218,44 @@ mod tests { assert_eq!(floats.len(), 1); assert!(0.0 <= floats.value(0) && floats.value(0) < 1.0); } + + #[test] + fn test_atan2_f64() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y + Arc::new(Float64Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x + ]; + + let result = atan2(&args).expect("fail"); + let floats = result + .as_any() + .downcast_ref::() + .expect("fail"); + + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), (2.0_f64).atan2(1.0)); + assert_eq!(floats.value(1), (-3.0_f64).atan2(2.0)); + assert_eq!(floats.value(2), (4.0_f64).atan2(-3.0)); + assert_eq!(floats.value(3), (-5.0_f64).atan2(-4.0)); + } + + #[test] + fn test_atan2_f32() { + let args: Vec = vec![ + Arc::new(Float32Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y + Arc::new(Float32Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x + ]; + + let result = atan2(&args).expect("fail"); + let floats = result + .as_any() + .downcast_ref::() + .expect("fail"); + + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), (2.0_f32).atan2(1.0)); + assert_eq!(floats.value(1), (-3.0_f32).atan2(2.0)); + assert_eq!(floats.value(2), (4.0_f32).atan2(-3.0)); + assert_eq!(floats.value(3), (-5.0_f32).atan2(-4.0)); + } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 39c254ea70cd..ec816a419432 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -439,6 +439,7 @@ enum ScalarFunction { Power=64; StructFun=65; FromUnixtime=66; + Atan2=67; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index cb7b111895d6..40ea1bd02500 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -32,9 +32,9 @@ use datafusion_common::{ use datafusion_expr::expr::GroupingSet; use datafusion_expr::expr::GroupingSet::GroupingSets; use datafusion_expr::{ - abs, acos, array, ascii, asin, atan, bit_length, btrim, ceil, character_length, chr, - coalesce, concat_expr, concat_ws_expr, cos, date_part, date_trunc, digest, exp, - floor, from_unixtime, left, ln, log10, log2, + abs, acos, array, ascii, asin, atan, atan2, bit_length, btrim, ceil, + character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, date_part, + date_trunc, digest, exp, floor, from_unixtime, left, ln, log10, log2, logical_plan::{PlanType, StringifiedPlan}, lower, lpad, ltrim, md5, now_expr, nullif, octet_length, power, random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, @@ -474,6 +474,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Power => Self::Power, ScalarFunction::StructFun => Self::Struct, ScalarFunction::FromUnixtime => Self::FromUnixtime, + ScalarFunction::Atan2 => Self::Atan2, } } } @@ -1132,6 +1133,10 @@ pub fn parse_expr( ScalarFunction::FromUnixtime => { Ok(from_unixtime(parse_expr(&args[0], registry)?)) } + ScalarFunction::Atan2 => Ok(atan2( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), _ => Err(proto_error( "Protobuf deserialization error: Unsupported scalar function", )), diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index fd5276ca8bc9..42706f602d51 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -1121,6 +1121,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Power => Self::Power, BuiltinScalarFunction::Struct => Self::StructFun, BuiltinScalarFunction::FromUnixtime => Self::FromUnixtime, + BuiltinScalarFunction::Atan2 => Self::Atan2, }; Ok(scalar_function)