diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index c4574bf04251..f0ee7327b90e 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -32,7 +32,9 @@ use datafusion_common::cast::as_list_array; use datafusion_common::utils::{ compare_rows, get_row_at_idx, take_function_args, SingleRowListArrayBuilder, }; -use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_common::{ + assert_eq_or_internal_err, exec_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ @@ -319,9 +321,7 @@ impl Accumulator for ArrayAggAccumulator { return Ok(()); } - if values.len() != 1 { - return internal_err!("expects single batch"); - } + assert_eq_or_internal_err!(values.len(), 1, "expects single batch"); let val = &values[0]; let nulls = if self.ignore_nulls { @@ -349,9 +349,7 @@ impl Accumulator for ArrayAggAccumulator { return Ok(()); } - if states.len() != 1 { - return internal_err!("expects single state"); - } + assert_eq_or_internal_err!(states.len(), 1, "expects single state"); let list_arr = as_list_array(&states[0])?; @@ -472,9 +470,7 @@ impl Accumulator for DistinctArrayAggAccumulator { return Ok(()); } - if states.len() != 1 { - return internal_err!("expects single state"); - } + assert_eq_or_internal_err!(states.len(), 1, "expects single state"); states[0] .as_list::() diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 4ec2a124efee..ef76d1e6ea2d 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -40,7 +40,8 @@ use arrow::datatypes::{ }; use datafusion_common::{ - internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, + assert_eq_or_internal_err, internal_datafusion_err, DataFusionError, Result, + ScalarValue, }; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::{ @@ -189,12 +190,12 @@ impl AggregateUDFImpl for Median { args: AccumulatorArgs, ) -> Result> { let num_args = args.exprs.len(); - if num_args != 1 { - return internal_err!( - "median should only have 1 arg, but found num args:{}", - args.exprs.len() - ); - } + assert_eq_or_internal_err!( + num_args, + 1, + "median should only have 1 arg, but found num args:{}", + num_args + ); let dt = args.expr_fields[0].data_type().clone(); diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 2f4f9371be58..adf3e47b7d5a 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -27,7 +27,9 @@ use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray}; use arrow::datatypes::{DataType, Field, FieldRef, Fields}; use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; -use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue}; +use datafusion_common::{ + assert_or_internal_err, exec_err, not_impl_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ @@ -206,10 +208,11 @@ impl TrivialNthValueAccumulator { /// Create a new order-insensitive NTH_VALUE accumulator based on the given /// item data type. pub fn try_new(n: i64, datatype: &DataType) -> Result { - if n == 0 { - // n cannot be 0 - return internal_err!("Nth value indices are 1 based. 0 is invalid index"); - } + // n cannot be 0 + assert_or_internal_err!( + n != 0, + "Nth value indices are 1 based. 0 is invalid index" + ); Ok(Self { n, values: VecDeque::new(), @@ -339,10 +342,11 @@ impl NthValueAccumulator { ordering_dtypes: &[DataType], ordering_req: LexOrdering, ) -> Result { - if n == 0 { - // n cannot be 0 - return internal_err!("Nth value indices are 1 based. 0 is invalid index"); - } + // n cannot be 0 + assert_or_internal_err!( + n != 0, + "Nth value indices are 1 based. 0 is invalid index" + ); let mut datatypes = vec![datatype.clone()]; datatypes.extend(ordering_dtypes.iter().cloned()); Ok(Self { diff --git a/datafusion/functions-aggregate/src/percentile_cont.rs b/datafusion/functions-aggregate/src/percentile_cont.rs index ce5881732f0e..2807e4bbe8b8 100644 --- a/datafusion/functions-aggregate/src/percentile_cont.rs +++ b/datafusion/functions-aggregate/src/percentile_cont.rs @@ -34,7 +34,8 @@ use arrow::{ use arrow::array::ArrowNativeTypeOp; use datafusion_common::{ - internal_datafusion_err, internal_err, plan_err, DataFusionError, Result, ScalarValue, + assert_eq_or_internal_err, internal_datafusion_err, plan_err, DataFusionError, + Result, ScalarValue, }; use datafusion_expr::expr::{AggregateFunction, Sort}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; @@ -303,12 +304,12 @@ impl AggregateUDFImpl for PercentileCont { args: AccumulatorArgs, ) -> Result> { let num_args = args.exprs.len(); - if num_args != 2 { - return internal_err!( - "percentile_cont should have 2 args, but found num args:{}", - args.exprs.len() - ); - } + assert_eq_or_internal_err!( + num_args, + 2, + "percentile_cont should have 2 args, but found num args:{}", + num_args + ); let percentile = validate_percentile_expr(&args.exprs[1], "PERCENTILE_CONT")?;