From 95bb9388bfb252dddcaf33cd1c91434cdb461abe Mon Sep 17 00:00:00 2001 From: Igor Izvekov Date: Thu, 9 Mar 2023 22:56:06 +0300 Subject: [PATCH 1/2] Minor: add the concise way for matching numerics --- .../expr/src/type_coercion/aggregates.rs | 172 ++++-------------- 1 file changed, 34 insertions(+), 138 deletions(-) diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index fca851ce63c9..df2e4aa97608 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -314,77 +314,45 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { /// function return type of variance pub fn variance_return_type(arg_type: &DataType) -> Result { - match arg_type { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 => Ok(DataType::Float64), - other => Err(DataFusionError::Plan(format!( - "VAR does not support {other:?}" - ))), + if NUMERICS.contains(&arg_type) { + Ok(DataType::Float64) + } else { + Err(DataFusionError::Plan(format!( + "VAR does not support {arg_type:?}" + ))) } } /// function return type of covariance pub fn covariance_return_type(arg_type: &DataType) -> Result { - match arg_type { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 => Ok(DataType::Float64), - other => Err(DataFusionError::Plan(format!( - "COVAR does not support {other:?}" - ))), + if NUMERICS.contains(&arg_type) { + Ok(DataType::Float64) + } else { + Err(DataFusionError::Plan(format!( + "COVAR does not support {arg_type:?}" + ))) } } /// function return type of correlation pub fn correlation_return_type(arg_type: &DataType) -> Result { - match arg_type { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 => Ok(DataType::Float64), - other => Err(DataFusionError::Plan(format!( - "CORR does not support {other:?}" - ))), + if NUMERICS.contains(&arg_type) { + Ok(DataType::Float64) + } else { + Err(DataFusionError::Plan(format!( + "CORR does not support {arg_type:?}" + ))) } } /// function return type of standard deviation pub fn stddev_return_type(arg_type: &DataType) -> Result { - match arg_type { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 => Ok(DataType::Float64), - other => Err(DataFusionError::Plan(format!( - "STDDEV does not support {other:?}" - ))), + if NUMERICS.contains(&arg_type) { + Ok(DataType::Float64) + } else { + Err(DataFusionError::Plan(format!( + "STDDEV does not support {arg_type:?}" + ))) } } @@ -398,16 +366,7 @@ pub fn avg_return_type(arg_type: &DataType) -> Result { let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4); Ok(DataType::Decimal128(new_precision, new_scale)) } - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 => Ok(DataType::Float64), + arg_type if NUMERICS.contains(&arg_type) => Ok(DataType::Float64), other => Err(DataFusionError::Plan(format!( "AVG does not support {other:?}" ))), @@ -417,98 +376,44 @@ pub fn avg_return_type(arg_type: &DataType) -> Result { pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - | DataType::Decimal128(_, _) + arg_type if NUMERICS.contains(&arg_type) + || matches!(arg_type, DataType::Decimal128(_, _)) ) } pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - | DataType::Decimal128(_, _) + arg_type if NUMERICS.contains(&arg_type) + || matches!(arg_type, DataType::Decimal128(_, _)) ) } pub fn is_variance_support_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 + arg_type if NUMERICS.contains(&arg_type) ) } pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 + arg_type if NUMERICS.contains(&arg_type) ) } pub fn is_stddev_support_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 + arg_type if NUMERICS.contains(&arg_type) ) } pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 + arg_type if NUMERICS.contains(&arg_type) ) } @@ -531,16 +436,7 @@ pub fn is_integer_arg_type(arg_type: &DataType) -> bool { pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 + arg_type if NUMERICS.contains(&arg_type) ) } From 904c74416c7465ce9680618cd32633a14d51df77 Mon Sep 17 00:00:00 2001 From: Igor Izvekov Date: Thu, 9 Mar 2023 23:33:12 +0300 Subject: [PATCH 2/2] fix: use values of argument type except links --- .../expr/src/type_coercion/aggregates.rs | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index df2e4aa97608..3ad197afb64a 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -314,7 +314,7 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { /// function return type of variance pub fn variance_return_type(arg_type: &DataType) -> Result { - if NUMERICS.contains(&arg_type) { + if NUMERICS.contains(arg_type) { Ok(DataType::Float64) } else { Err(DataFusionError::Plan(format!( @@ -325,7 +325,7 @@ pub fn variance_return_type(arg_type: &DataType) -> Result { /// function return type of covariance pub fn covariance_return_type(arg_type: &DataType) -> Result { - if NUMERICS.contains(&arg_type) { + if NUMERICS.contains(arg_type) { Ok(DataType::Float64) } else { Err(DataFusionError::Plan(format!( @@ -336,7 +336,7 @@ pub fn covariance_return_type(arg_type: &DataType) -> Result { /// function return type of correlation pub fn correlation_return_type(arg_type: &DataType) -> Result { - if NUMERICS.contains(&arg_type) { + if NUMERICS.contains(arg_type) { Ok(DataType::Float64) } else { Err(DataFusionError::Plan(format!( @@ -347,7 +347,7 @@ pub fn correlation_return_type(arg_type: &DataType) -> Result { /// function return type of standard deviation pub fn stddev_return_type(arg_type: &DataType) -> Result { - if NUMERICS.contains(&arg_type) { + if NUMERICS.contains(arg_type) { Ok(DataType::Float64) } else { Err(DataFusionError::Plan(format!( @@ -366,7 +366,7 @@ pub fn avg_return_type(arg_type: &DataType) -> Result { let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4); Ok(DataType::Decimal128(new_precision, new_scale)) } - arg_type if NUMERICS.contains(&arg_type) => Ok(DataType::Float64), + arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), other => Err(DataFusionError::Plan(format!( "AVG does not support {other:?}" ))), @@ -376,7 +376,7 @@ pub fn avg_return_type(arg_type: &DataType) -> Result { pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - arg_type if NUMERICS.contains(&arg_type) + arg_type if NUMERICS.contains(arg_type) || matches!(arg_type, DataType::Decimal128(_, _)) ) } @@ -384,7 +384,7 @@ pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool { pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - arg_type if NUMERICS.contains(&arg_type) + arg_type if NUMERICS.contains(arg_type) || matches!(arg_type, DataType::Decimal128(_, _)) ) } @@ -392,28 +392,28 @@ pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool { pub fn is_variance_support_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - arg_type if NUMERICS.contains(&arg_type) + arg_type if NUMERICS.contains(arg_type) ) } pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - arg_type if NUMERICS.contains(&arg_type) + arg_type if NUMERICS.contains(arg_type) ) } pub fn is_stddev_support_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - arg_type if NUMERICS.contains(&arg_type) + arg_type if NUMERICS.contains(arg_type) ) } pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - arg_type if NUMERICS.contains(&arg_type) + arg_type if NUMERICS.contains(arg_type) ) } @@ -436,7 +436,7 @@ pub fn is_integer_arg_type(arg_type: &DataType) -> bool { pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, - arg_type if NUMERICS.contains(&arg_type) + arg_type if NUMERICS.contains(arg_type) ) }