From b2a3a3295999de52509b5157b78dceb4f9c6b7ef Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Fri, 3 Mar 2023 17:04:47 +0100 Subject: [PATCH] refactor: make GeometricMean not to have update and merge --- datafusion-examples/examples/simple_udaf.rs | 65 +++++++-------------- 1 file changed, 22 insertions(+), 43 deletions(-) diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index d171f6579bfe..b858ce7ebf1f 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/simple_udaf.rs @@ -65,40 +65,6 @@ impl GeometricMean { pub fn new() -> Self { GeometricMean { n: 0, prod: 1.0 } } - - // this function receives one entry per argument of this accumulator. - // DataFusion calls this function on every row, and expects this function to update the accumulator's state. - fn update(&mut self, values: &[ScalarValue]) -> Result<()> { - // this is a one-argument UDAF, and thus we use `0`. - let value = &values[0]; - match value { - // here we map `ScalarValue` to our internal state. `Float64` indicates that this function - // only accepts Float64 as its argument (DataFusion does try to coerce arguments to this type) - // - // Note that `.map` here ensures that we ignore Nulls. - ScalarValue::Float64(e) => e.map(|value| { - self.prod *= value; - self.n += 1; - }), - _ => unreachable!(""), - }; - Ok(()) - } - - // this function receives states from other accumulators (Vec<ScalarValue>) - // and updates the accumulator. 
- fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { - let prod = &states[0]; - let n = &states[1]; - match (prod, n) { - (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) => { - self.prod *= prod; - self.n += n; - } - _ => unreachable!(""), - }; - Ok(()) - } } // UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions @@ -128,28 +94,41 @@ impl Accumulator for GeometricMean { if values.is_empty() { return Ok(()); } - (0..values[0].len()).try_for_each(|index| { - let v = values - .iter() - .map(|array| ScalarValue::try_from_array(array, index)) - .collect::<Result<Vec<_>>>()?; - self.update(&v) + let arr = &values[0]; + (0..arr.len()).try_for_each(|index| { + let v = ScalarValue::try_from_array(arr, index)?; + + if let ScalarValue::Float64(Some(value)) = v { + self.prod *= value; + self.n += 1; + } else { + unreachable!("") + } + Ok(()) }) } // Optimization hint: this trait also supports `update_batch` and `merge_batch`, // that can be used to perform these operations on arrays instead of single values. - // By default, these methods call `update` and `merge` row by row fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { if states.is_empty() { return Ok(()); } - (0..states[0].len()).try_for_each(|index| { + let arr = &states[0]; + (0..arr.len()).try_for_each(|index| { let v = states .iter() .map(|array| ScalarValue::try_from_array(array, index)) .collect::<Result<Vec<_>>>()?; - self.merge(&v) + if let (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) = + (&v[0], &v[1]) + { + self.prod *= prod; + self.n += n; + } else { + unreachable!("") + } + Ok(()) }) }