Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 22 additions & 43 deletions datafusion-examples/examples/simple_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,40 +65,6 @@ impl GeometricMean {
pub fn new() -> Self {
GeometricMean { n: 0, prod: 1.0 }
}

// this function receives one entry per argument of this accumulator.
// DataFusion calls this function on every row, and expects this function to update the accumulator's state.
fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
// this is a one-argument UDAF, and thus we use `0`.
let value = &values[0];
match value {
// here we map `ScalarValue` to our internal state. `Float64` indicates that this function
// only accepts Float64 as its argument (DataFusion does try to coerce arguments to this type)
//
// Note that `.map` here ensures that we ignore Nulls.
ScalarValue::Float64(e) => e.map(|value| {
self.prod *= value;
self.n += 1;
}),
_ => unreachable!(""),
};
Ok(())
}

// this function receives states from other accumulators (Vec<ScalarValue>)
// and updates the accumulator.
fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
let prod = &states[0];
let n = &states[1];
match (prod, n) {
(ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) => {
self.prod *= prod;
self.n += n;
}
_ => unreachable!(""),
};
Ok(())
}
}

// UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions
Expand Down Expand Up @@ -128,28 +94,41 @@ impl Accumulator for GeometricMean {
if values.is_empty() {
return Ok(());
}
(0..values[0].len()).try_for_each(|index| {
let v = values
.iter()
.map(|array| ScalarValue::try_from_array(array, index))
.collect::<Result<Vec<_>>>()?;
self.update(&v)
let arr = &values[0];
(0..arr.len()).try_for_each(|index| {
let v = ScalarValue::try_from_array(arr, index)?;

if let ScalarValue::Float64(Some(value)) = v {
self.prod *= value;
self.n += 1;
} else {
unreachable!("")
}
Ok(())
})
}

// Optimization hint: this trait also supports `update_batch` and `merge_batch`,
// that can be used to perform these operations on arrays instead of single values.
// By default, these methods call `update` and `merge` row by row
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
(0..states[0].len()).try_for_each(|index| {
let arr = &states[0];
(0..arr.len()).try_for_each(|index| {
let v = states
.iter()
.map(|array| ScalarValue::try_from_array(array, index))
.collect::<Result<Vec<_>>>()?;
self.merge(&v)
if let (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) =
(&v[0], &v[1])
{
self.prod *= prod;
self.n += n;
} else {
unreachable!("")
}
Ok(())
})
}

Expand Down