From f0528f03a76519fbdaa4877d5d3720a9ded494aa Mon Sep 17 00:00:00 2001 From: Wei-Ting Kuo Date: Tue, 19 Jul 2022 04:37:58 +0800 Subject: [PATCH 01/11] add atan -> f64 --- datafusion/core/src/logical_plan/mod.rs | 2 +- datafusion/expr/src/built_in_function.rs | 4 ++++ datafusion/expr/src/expr_fn.rs | 2 ++ datafusion/expr/src/function.rs | 8 ++++++++ datafusion/physical-expr/src/functions.rs | 3 +++ datafusion/physical-expr/src/math_expressions.rs | 12 ++++++++++++ datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/from_proto.rs | 9 +++++++-- datafusion/proto/src/to_proto.rs | 1 + 9 files changed, 39 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/logical_plan/mod.rs b/datafusion/core/src/logical_plan/mod.rs index e4e26ad54432..c154651fc888 100644 --- a/datafusion/core/src/logical_plan/mod.rs +++ b/datafusion/core/src/logical_plan/mod.rs @@ -27,7 +27,7 @@ pub use datafusion_common::{ Column, DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema, }; pub use datafusion_expr::{ - abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan, + abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan, atan2, avg, bit_length, btrim, call_fn, case, ceil, character_length, chr, coalesce, col, combine_filters, concat, concat_expr, concat_ws, concat_ws_expr, cos, count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, exists, exp, diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 663888e2ecd8..ffac07ca53f9 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -34,6 +34,8 @@ pub enum BuiltinScalarFunction { Asin, /// atan Atan, + /// atan2 + Atan2, /// ceil Ceil, /// coalesce @@ -181,6 +183,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Acos => Volatility::Immutable, BuiltinScalarFunction::Asin => Volatility::Immutable, BuiltinScalarFunction::Atan => Volatility::Immutable, + BuiltinScalarFunction::Atan2 => Volatility::Immutable, BuiltinScalarFunction::Ceil => Volatility::Immutable, BuiltinScalarFunction::Coalesce => Volatility::Immutable, BuiltinScalarFunction::Cos => Volatility::Immutable, @@ -268,6 +271,7 @@ impl FromStr for BuiltinScalarFunction { "acos" => BuiltinScalarFunction::Acos, "asin" => BuiltinScalarFunction::Asin, "atan" => BuiltinScalarFunction::Atan, + "atan2" => BuiltinScalarFunction::Atan2, "ceil" => BuiltinScalarFunction::Ceil, "cos" => BuiltinScalarFunction::Cos, "exp" => BuiltinScalarFunction::Exp, diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index abfd37a7c1c6..97bbd419e4ec 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -304,6 +304,7 @@ unary_scalar_expr!(Log10, log10); unary_scalar_expr!(Ln, ln); unary_scalar_expr!(NullIf, nullif); scalar_expr!(Power, power, base, exponent); +scalar_expr!(Atan2, atan2, y, x); // string functions scalar_expr!(Ascii, ascii, string); @@ -546,6 +547,7 @@ mod test { test_unary_scalar_expr!(Log2, log2); test_unary_scalar_expr!(Log10, log10); test_unary_scalar_expr!(Ln, ln); + test_scalar_expr!(Atan2, atan2, y, x); test_scalar_expr!(Ascii, ascii, input); test_scalar_expr!(BitLength, bit_length, string); diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 331756f8de83..1e256f8c9372 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -229,6 +229,8 @@ pub fn return_type( BuiltinScalarFunction::Struct => Ok(DataType::Struct(vec![])), + BuiltinScalarFunction::Atan2 => Ok(DataType::Float64), + BuiltinScalarFunction::Abs | BuiltinScalarFunction::Acos | BuiltinScalarFunction::Asin @@ -540,6 +542,12 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature { ], fun.volatility(), ), + BuiltinScalarFunction::Atan2 => Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Float64, DataType::Float64]), + ], + fun.volatility(), + ), // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we // return the best approximation for it (in f64). diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 5f0e711f80e4..b72a55287734 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -307,6 +307,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::Trunc => Arc::new(math_expressions::trunc), BuiltinScalarFunction::Power => { Arc::new(|args| make_scalar_function(math_expressions::power)(args)) + }, + BuiltinScalarFunction::Atan2 => { + Arc::new(|args| make_scalar_function(math_expressions::atan2)(args)) } // string functions diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index 7f41268154a9..b333b3459f0b 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -176,6 +176,18 @@ pub fn power(args: &[ArrayRef]) -> Result { } } +pub fn atan2(args: &[ArrayRef]) -> Result { + // FIXME other data_type? + Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "y", + "x", + Float64Array, + { f64::atan2 } + )) as ArrayRef) +} + #[cfg(test)] mod tests { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 39c254ea70cd..ec816a419432 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -439,6 +439,7 @@ enum ScalarFunction { Power=64; StructFun=65; FromUnixtime=66; + Atan2=67; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index cb7b111895d6..382323272805 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -32,7 +32,7 @@ use datafusion_common::{ use datafusion_expr::expr::GroupingSet; use datafusion_expr::expr::GroupingSet::GroupingSets; use datafusion_expr::{ - abs, acos, array, ascii, asin, atan, bit_length, btrim, ceil, character_length, chr, + abs, acos, array, ascii, asin, atan, atan2, bit_length, btrim, ceil, character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, date_part, date_trunc, digest, exp, floor, from_unixtime, left, ln, log10, log2, logical_plan::{PlanType, StringifiedPlan}, @@ -474,6 +474,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Power => Self::Power, ScalarFunction::StructFun => Self::Struct, ScalarFunction::FromUnixtime => Self::FromUnixtime, + ScalarFunction::Atan2 => Self::Atan2, } } } @@ -1131,7 +1132,11 @@ pub fn parse_expr( )), ScalarFunction::FromUnixtime => { Ok(from_unixtime(parse_expr(&args[0], registry)?)) - } + }, + ScalarFunction::Atan2 => Ok(atan2( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), _ => Err(proto_error( "Protobuf deserialization error: Unsupported scalar function", )), diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index fd5276ca8bc9..ab93afb6890c 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -1121,6 +1121,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Power => Self::Power, BuiltinScalarFunction::Struct => Self::StructFun, BuiltinScalarFunction::FromUnixtime => Self::FromUnixtime, + BuiltInWindowFunction::Atan2 => self::Atan2, }; Ok(scalar_function) From 99f150b1a9acf9bf642c191431e68684ade3929c Mon Sep 17 00:00:00 2001 From: Wei-Ting Kuo Date: Tue, 19 Jul 2022 05:30:11 +0800 Subject: [PATCH 02/11] make atan2 support f32 --- datafusion/expr/src/function.rs | 6 +++- .../physical-expr/src/math_expressions.rs | 32 ++++++++++++++----- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 1e256f8c9372..88cdcc16835b 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -229,7 +229,10 @@ pub fn return_type( BuiltinScalarFunction::Struct => Ok(DataType::Struct(vec![])), - BuiltinScalarFunction::Atan2 => Ok(DataType::Float64), + BuiltinScalarFunction::Atan2 => match &input_expr_types[0] { + DataType::Float32 => Ok(DataType::Float32), + _ => Ok(DataType::Float64), + } BuiltinScalarFunction::Abs | BuiltinScalarFunction::Acos @@ -544,6 +547,7 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature { ), BuiltinScalarFunction::Atan2 => Signature::one_of( vec![ + TypeSignature::Exact(vec![DataType::Float32, DataType::Float32]), TypeSignature::Exact(vec![DataType::Float64, DataType::Float64]), ], fun.volatility(), diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index b333b3459f0b..c491337d5bef 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -178,14 +178,30 @@ pub fn power(args: &[ArrayRef]) -> Result { pub fn atan2(args: &[ArrayRef]) -> Result { // FIXME other data_type? - Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float64Array, - { f64::atan2 } - )) as ArrayRef) + match args[0].data_type() { + DataType::Float64 => Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "y", + "x", + Float64Array, + { f64::atan2 } + )) as ArrayRef), + + DataType::Float32 => Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "y", + "x", + Float32Array, + { f32::atan2 } + )) as ArrayRef), + + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function atan2", + other + ))), + } } #[cfg(test)] From afd1a0229faf84c56b19507433edbf91ffc8a93d Mon Sep 17 00:00:00 2001 From: Wei-Ting Kuo Date: Tue, 19 Jul 2022 07:03:32 +0800 Subject: [PATCH 03/11] add test case for null input --- datafusion/core/tests/sql/expr.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index 93347ee41b43..c9c5d955a35c 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -505,6 +505,9 @@ async fn test_mathematical_expressions_with_null() -> Result<()> { test_expression!("power(NULL, 2)", "NULL"); test_expression!("power(NULL, NULL)", "NULL"); test_expression!("power(2, NULL)", "NULL"); + test_expression!("atan2(NULL, NULL)", "NULL"); + test_expression!("atan2(1, NULL)", "NULL"); + test_expression!("atan2(NULL, 1)", "NULL"); Ok(()) } From 025d423d0e424be84759f7c04e4fe02fc37230cb Mon Sep 17 00:00:00 2001 From: Wei-Ting Kuo Date: Tue, 26 Jul 2022 14:13:06 +0800 Subject: [PATCH 04/11] add math in mod.rs --- datafusion/core/tests/sql/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index a7f4cabe9d7b..9533337a5300 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -91,6 +91,7 @@ pub mod intersection; pub mod joins; pub mod json; pub mod limit; +pub mod math; pub mod order; pub mod parquet; pub mod predicates; From 6d33df5e1b03a739b4e1cd85ee2127be09ff25ef Mon Sep 17 00:00:00 2001 From: Wei-Ting Kuo Date: Tue, 26 Jul 2022 18:33:54 +0800 Subject: [PATCH 05/11] fix proto --- datafusion/proto/src/to_proto.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index ab93afb6890c..42706f602d51 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -1121,7 +1121,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Power => Self::Power, BuiltinScalarFunction::Struct => Self::StructFun, BuiltinScalarFunction::FromUnixtime => Self::FromUnixtime, - BuiltInWindowFunction::Atan2 => self::Atan2, + BuiltinScalarFunction::Atan2 => Self::Atan2, }; Ok(scalar_function) From 63915a3a95a1a83c0a85532246b71f267bf53ae7 Mon Sep 17 00:00:00 2001 From: Wei-Ting Kuo Date: Tue, 26 Jul 2022 20:14:19 +0800 Subject: [PATCH 06/11] add sql test for atan2 --- datafusion/core/tests/sql/math.rs | 54 +++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 datafusion/core/tests/sql/math.rs diff --git a/datafusion/core/tests/sql/math.rs b/datafusion/core/tests/sql/math.rs new file mode 100644 index 000000000000..8a6c2908db15 --- /dev/null +++ b/datafusion/core/tests/sql/math.rs @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; +use arrow::array::Float64Array; + +#[tokio::test] +async fn test_atan2() -> Result<()> { + let ctx = SessionContext::new(); + + let t1_schema = Arc::new(Schema::new( + vec![Field::new("x", DataType::Float64, true), + Field::new("y", DataType::Float64, true)])); + + let t1_data = RecordBatch::try_new( + t1_schema.clone(), + vec![Arc::new(Float64Array::from(vec![1.0, 1.0, -1.0, -1.0])), + Arc::new(Float64Array::from(vec![2.0, -2.0, 2.0, -2.0]))], + )?; + let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; + ctx.register_table("t1", Arc::new(t1_table))?; + + let sql = "SELECT atan2(y, x) FROM t1"; + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+---------------------+", + "| atan2(t1.y,t1.x) |", + "+---------------------+", + "| 1.1071487177940904 |", + "| -1.1071487177940904 |", + "| 2.0344439357957027 |", + "| -2.0344439357957027 |", + "+---------------------+", + ]; + + assert_batches_eq!(expected, &actual); + + Ok(()) +} \ No newline at end of file From a7462adf316153c2c22f26d993fb90eaf9335810 Mon Sep 17 00:00:00 2001 From: Wei-Ting Kuo Date: Tue, 26 Jul 2022 22:12:35 +0800 Subject: [PATCH 07/11] add text case in math_expressions --- .../physical-expr/src/math_expressions.rs | 39 ++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index c491337d5bef..a8667505afa4 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -208,7 +208,8 @@ pub fn atan2(args: &[ArrayRef]) -> Result { mod tests { use super::*; - use arrow::array::{Float64Array, NullArray}; + use arrow::array::{Float64Array, NullArray, Array}; + #[test] fn test_random_expression() { @@ -219,4 +220,40 @@ mod tests { assert_eq!(floats.len(), 1); assert!(0.0 <= floats.value(0) && floats.value(0) < 1.0); } + + #[test] + fn test_atan2_f64() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y + Arc::new(Float64Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x + ]; + + let result = atan2(&args).expect("fail"); + let floats = result.as_any().downcast_ref::().expect("fail"); + + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), 2.0_f64.atan2(1.0)); + assert_eq!(floats.value(1), -3.0_f64.atan2(2.0)); + assert_eq!(floats.value(2), 4.0_f64.atan2(-3.0)); + assert_eq!(floats.value(3), -5.0_f64.atan2(-4.0)); + } + + #[test] + fn test_atan2_f32() { + let args: Vec = vec![ + Arc::new(Float32Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y + Arc::new(Float32Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x + ]; + + let result = atan2(&args).expect("fail"); + let floats = result.as_any().downcast_ref::().expect("fail"); + + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), 2.0_f32.atan2(1.0)); + assert_eq!(floats.value(1), -3.0_f32.atan2(2.0)); + assert_eq!(floats.value(2), 4.0_f32.atan2(-3.0)); + assert_eq!(floats.value(3), -5.0_f32.atan2(-4.0)); + } + } + From c5ca6630d79807b1ebf872de0746f647bb45e74a Mon Sep 17 00:00:00 2001 From: Wei-Ting Kuo Date: Tue, 26 Jul 2022 22:14:09 +0800 Subject: [PATCH 08/11] cargo fmt --- datafusion/core/src/logical_plan/mod.rs | 6 ++--- datafusion/core/tests/sql/math.rs | 17 +++++++------ datafusion/expr/src/function.rs | 2 +- datafusion/physical-expr/src/functions.rs | 2 +- .../physical-expr/src/math_expressions.rs | 25 +++++++++++-------- datafusion/proto/src/from_proto.rs | 8 +++--- 6 files changed, 33 insertions(+), 27 deletions(-) diff --git a/datafusion/core/src/logical_plan/mod.rs b/datafusion/core/src/logical_plan/mod.rs index c154651fc888..9b391983748c 100644 --- a/datafusion/core/src/logical_plan/mod.rs +++ b/datafusion/core/src/logical_plan/mod.rs @@ -27,9 +27,9 @@ pub use datafusion_common::{ Column, DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema, }; pub use datafusion_expr::{ - abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan, atan2, - avg, bit_length, btrim, call_fn, case, ceil, character_length, chr, coalesce, col, - combine_filters, concat, concat_expr, concat_ws, concat_ws_expr, cos, count, + abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan, + atan2, avg, bit_length, btrim, call_fn, case, ceil, character_length, chr, coalesce, + col, combine_filters, concat, concat_expr, concat_ws, concat_ws_expr, cos, count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, exists, exp, expr_rewriter, expr_rewriter::{ diff --git a/datafusion/core/tests/sql/math.rs b/datafusion/core/tests/sql/math.rs index 8a6c2908db15..cff7120a20a1 100644 --- a/datafusion/core/tests/sql/math.rs +++ b/datafusion/core/tests/sql/math.rs @@ -22,14 +22,17 @@ use arrow::array::Float64Array; async fn test_atan2() -> Result<()> { let ctx = SessionContext::new(); - let t1_schema = Arc::new(Schema::new( - vec![Field::new("x", DataType::Float64, true), - Field::new("y", DataType::Float64, true)])); - + let t1_schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Float64, true), + Field::new("y", DataType::Float64, true), + ])); + let t1_data = RecordBatch::try_new( t1_schema.clone(), - vec![Arc::new(Float64Array::from(vec![1.0, 1.0, -1.0, -1.0])), - Arc::new(Float64Array::from(vec![2.0, -2.0, 2.0, -2.0]))], + vec![ + Arc::new(Float64Array::from(vec![1.0, 1.0, -1.0, -1.0])), + Arc::new(Float64Array::from(vec![2.0, -2.0, 2.0, -2.0])), + ], )?; let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; ctx.register_table("t1", Arc::new(t1_table))?; @@ -51,4 +54,4 @@ async fn test_atan2() -> Result<()> { assert_batches_eq!(expected, &actual); Ok(()) -} \ No newline at end of file +} diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 88cdcc16835b..29158e234b79 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -232,7 +232,7 @@ pub fn return_type( BuiltinScalarFunction::Atan2 => match &input_expr_types[0] { DataType::Float32 => Ok(DataType::Float32), _ => Ok(DataType::Float64), - } + }, BuiltinScalarFunction::Abs | BuiltinScalarFunction::Acos diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index b72a55287734..a84b00bf1e45 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -307,7 +307,7 @@ pub fn create_physical_fun( BuiltinScalarFunction::Trunc => Arc::new(math_expressions::trunc), BuiltinScalarFunction::Power => { Arc::new(|args| make_scalar_function(math_expressions::power)(args)) - }, + } BuiltinScalarFunction::Atan2 => { Arc::new(|args| make_scalar_function(math_expressions::atan2)(args)) } diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index a8667505afa4..9335d65b9a96 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -195,11 +195,11 @@ pub fn atan2(args: &[ArrayRef]) -> Result { "x", Float32Array, { f32::atan2 } - )) as ArrayRef), + )) as ArrayRef), other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function atan2", - other + "Unsupported data type {:?} for function atan2", + other ))), } } @@ -208,8 +208,7 @@ pub fn atan2(args: &[ArrayRef]) -> Result { mod tests { use super::*; - use arrow::array::{Float64Array, NullArray, Array}; - + use arrow::array::{Array, Float64Array, NullArray}; #[test] fn test_random_expression() { @@ -229,13 +228,16 @@ mod tests { ]; let result = atan2(&args).expect("fail"); - let floats = result.as_any().downcast_ref::().expect("fail"); + let floats = result + .as_any() + .downcast_ref::() + .expect("fail"); assert_eq!(floats.len(), 4); assert_eq!(floats.value(0), 2.0_f64.atan2(1.0)); assert_eq!(floats.value(1), -3.0_f64.atan2(2.0)); assert_eq!(floats.value(2), 4.0_f64.atan2(-3.0)); - assert_eq!(floats.value(3), -5.0_f64.atan2(-4.0)); + assert_eq!(floats.value(3), -5.0_f64.atan2(-4.0)); } #[test] @@ -246,14 +248,15 @@ mod tests { ]; let result = atan2(&args).expect("fail"); - let floats = result.as_any().downcast_ref::().expect("fail"); + let floats = result + .as_any() + .downcast_ref::() + .expect("fail"); assert_eq!(floats.len(), 4); assert_eq!(floats.value(0), 2.0_f32.atan2(1.0)); assert_eq!(floats.value(1), -3.0_f32.atan2(2.0)); assert_eq!(floats.value(2), 4.0_f32.atan2(-3.0)); - assert_eq!(floats.value(3), -5.0_f32.atan2(-4.0)); + assert_eq!(floats.value(3), -5.0_f32.atan2(-4.0)); } - } - diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index 382323272805..40ea1bd02500 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -32,9 +32,9 @@ use datafusion_common::{ use datafusion_expr::expr::GroupingSet; use datafusion_expr::expr::GroupingSet::GroupingSets; use datafusion_expr::{ - abs, acos, array, ascii, asin, atan, atan2, bit_length, btrim, ceil, character_length, chr, - coalesce, concat_expr, concat_ws_expr, cos, date_part, date_trunc, digest, exp, - floor, from_unixtime, left, ln, log10, log2, + abs, acos, array, ascii, asin, atan, atan2, bit_length, btrim, ceil, + character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, date_part, + date_trunc, digest, exp, floor, from_unixtime, left, ln, log10, log2, logical_plan::{PlanType, StringifiedPlan}, lower, lpad, ltrim, md5, now_expr, nullif, octet_length, power, random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, @@ -1132,7 +1132,7 @@ pub fn parse_expr( )), ScalarFunction::FromUnixtime => { Ok(from_unixtime(parse_expr(&args[0], registry)?)) - }, + } ScalarFunction::Atan2 => Ok(atan2( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, From 86fe7c97d3ea37484bf84ddba55803ae6ead6337 Mon Sep 17 00:00:00 2001 From: Wei-Ting Kuo Date: Tue, 26 Jul 2022 22:38:46 +0800 Subject: [PATCH 09/11] fix error from clippy --- datafusion/physical-expr/src/math_expressions.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index 9335d65b9a96..bd305d3e3314 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -234,10 +234,10 @@ mod tests { .expect("fail"); assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), 2.0_f64.atan2(1.0)); - assert_eq!(floats.value(1), -3.0_f64.atan2(2.0)); - assert_eq!(floats.value(2), 4.0_f64.atan2(-3.0)); - assert_eq!(floats.value(3), -5.0_f64.atan2(-4.0)); + assert_eq!(floats.value(0), (2.0_f64).atan2(1.0)); + assert_eq!(floats.value(1), (-3.0_f64).atan2(2.0)); + assert_eq!(floats.value(2), (4.0_f64).atan2(-3.0)); + assert_eq!(floats.value(3), (-5.0_f64).atan2(-4.0)); } #[test] @@ -254,9 +254,9 @@ mod tests { .expect("fail"); assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), 2.0_f32.atan2(1.0)); - assert_eq!(floats.value(1), -3.0_f32.atan2(2.0)); - assert_eq!(floats.value(2), 4.0_f32.atan2(-3.0)); - assert_eq!(floats.value(3), -5.0_f32.atan2(-4.0)); + assert_eq!(floats.value(0), (2.0_f32).atan2(1.0)); + assert_eq!(floats.value(1), (-3.0_f32).atan2(2.0)); + assert_eq!(floats.value(2), (4.0_f32).atan2(-3.0)); + assert_eq!(floats.value(3), (-5.0_f32).atan2(-4.0)); } } From 227b94a15e8db717fc355f2b71c71cfcecadc916 Mon Sep 17 00:00:00 2001 From: Wei-Ting Kuo Date: Thu, 28 Jul 2022 02:13:07 +0800 Subject: [PATCH 10/11] remove useless comment --- datafusion/physical-expr/src/math_expressions.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index bd305d3e3314..6127c39eac77 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -177,7 +177,7 @@ pub fn power(args: &[ArrayRef]) -> Result { } pub fn atan2(args: &[ArrayRef]) -> Result { - // FIXME other data_type? + match args[0].data_type() { DataType::Float64 => Ok(Arc::new(make_function_inputs2!( &args[0], From bde505ec640af7f6a4c53f44bef830dcb9d41330 Mon Sep 17 00:00:00 2001 From: Wei-Ting Kuo Date: Thu, 28 Jul 2022 02:13:46 +0800 Subject: [PATCH 11/11] apply cargo fmt --- datafusion/physical-expr/src/math_expressions.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index 6127c39eac77..16dda93dd134 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -177,7 +177,6 @@ pub fn power(args: &[ArrayRef]) -> Result { } pub fn atan2(args: &[ArrayRef]) -> Result { - match args[0].data_type() { DataType::Float64 => Ok(Arc::new(make_function_inputs2!( &args[0],