Skip to content

Commit

Permalink
[IR Runtime] support mean() in runtime (#1852)
Browse files Browse the repository at this point in the history
  • Loading branch information
BingqingLyu committed Jul 20, 2022
1 parent 0a8e9c1 commit 6616ee0
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 4 deletions.
6 changes: 6 additions & 0 deletions research/dyn_type/src/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1251,6 +1251,12 @@ impl From<String> for Object {
}
}

impl From<Primitives> for Object {
fn from(v: Primitives) -> Self {
Object::Primitive(v)
}
}

impl<T: Into<Object>> From<Vec<T>> for Object {
fn from(vec: Vec<T>) -> Self {
Object::Vector(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

use std::collections::HashSet;
use std::convert::{TryFrom, TryInto};
use std::ops::Div;

use dyn_type::Primitives;
use ir_common::error::ParsePbError;
use ir_common::generated::algebra as algebra_pb;
use ir_common::generated::algebra::group_by::agg_func::Aggregate;
Expand All @@ -40,6 +42,7 @@ pub enum EntryAccumulator {
ToSet(ToSet<Entry>),
ToDistinctCount(DistinctCount<Entry>),
ToSum(Sum<Entry>),
ToAvg(Sum<Entry>, Count<()>),
}

/// Accumulator for Record, including multiple accumulators for entries(columns) in Record.
Expand Down Expand Up @@ -81,6 +84,10 @@ impl Accumulator<Entry, Entry> for EntryAccumulator {
EntryAccumulator::ToSet(set) => set.accum(next),
EntryAccumulator::ToDistinctCount(distinct_count) => distinct_count.accum(next),
EntryAccumulator::ToSum(sum) => sum.accum(next),
EntryAccumulator::ToAvg(sum, count) => {
sum.accum(next)?;
count.accum(())
}
}
} else {
Ok(())
Expand Down Expand Up @@ -132,6 +139,38 @@ impl Accumulator<Entry, Entry> for EntryAccumulator {
EntryAccumulator::ToSum(sum) => sum
.finalize()?
.ok_or(FnExecError::accum_error("sum_entry is none")),
EntryAccumulator::ToAvg(sum, count) => {
let sum_entry = sum
.finalize()?
.ok_or(FnExecError::accum_error("sum_entry is none"))?;
let cnt = count.finalize()?;
// TODO: confirm if it should be CommonObject::None, or throw error;
let result = CommonObject::None.into();
if cnt == 0 {
warn!("cnt value is 0 in accum avg");
Ok(result)
} else if let Some(sum_val) = sum_entry.as_common_object() {
match sum_val {
CommonObject::None => {
warn!("sum value is none in accum avg");
Ok(result)
}
CommonObject::Prop(prop_val) => {
let primitive_cnt = Primitives::Float(cnt as f64);
let result = prop_val
.as_primitive()
.map(|val| val.div(primitive_cnt))
.map_err(|e| FnExecError::accum_error(&format!("{}", e)))?;
Ok(CommonObject::Prop(result.into()).into())
}
CommonObject::Count(cnt_val) => {
Ok(CommonObject::Prop(object!((*cnt_val) as f64 / cnt as f64)).into())
}
}
} else {
Err(FnExecError::accum_error("unreachable"))
}
}
}
}
}
Expand Down Expand Up @@ -171,10 +210,9 @@ impl AccumFactoryGen for algebra_pb::GroupBy {
EntryAccumulator::ToDistinctCount(DistinctCount { inner: HashSet::new() })
}
Aggregate::Sum => EntryAccumulator::ToSum(Sum { seed: None }),
_ => Err(FnGenError::unsupported_error(&format!(
"Unsupported aggregate kind {:?}",
agg_kind
)))?,
Aggregate::Avg => {
EntryAccumulator::ToAvg(Sum { seed: None }, Count { value: 0, _ph: Default::default() })
}
};
accum_ops.push((entry_accumulator, tag_key, alias));
}
Expand Down Expand Up @@ -214,6 +252,11 @@ impl Encode for EntryAccumulator {
writer.write_u8(6)?;
sum.write_to(writer)?;
}
EntryAccumulator::ToAvg(sum, count) => {
writer.write_u8(7)?;
sum.write_to(writer)?;
count.write_to(writer)?;
}
}
Ok(())
}
Expand Down Expand Up @@ -251,6 +294,11 @@ impl Decode for EntryAccumulator {
let sum = <Sum<Entry>>::read_from(reader)?;
Ok(EntryAccumulator::ToSum(sum))
}
7 => {
let sum = <Sum<Entry>>::read_from(reader)?;
let count = <Count<()>>::read_from(reader)?;
Ok(EntryAccumulator::ToAvg(sum, count))
}
_ => Err(std::io::Error::new(std::io::ErrorKind::Other, "unreachable")),
}
}
Expand Down Expand Up @@ -574,4 +622,31 @@ mod tests {
}
assert_eq!(res, object!(60));
}

// g.V().values('age').mean()
//
// Folds three integer property records (10, 20, 30) through a group-by that has
// no keys and a single aggregate function with code 7 (avg), then checks that
// the head entry of the first output record equals 20.
#[test]
fn avg_test() {
    let avg_func = pb::group_by::AggFunc {
        vars: vec![common_pb::Variable::from("@".to_string())],
        aggregate: 7, // avg
        alias: None,
    };
    let group_by_pb = pb::GroupBy { mappings: vec![], functions: vec![avg_func] };
    let inputs = vec![
        Record::new(CommonObject::Prop(object!(10)), None),
        Record::new(CommonObject::Prop(object!(20)), None),
        Record::new(CommonObject::Prop(object!(30)), None),
    ];
    let mut output = fold_test(inputs, group_by_pb);
    // When no record (or no head entry) is produced, fall back to an empty-string
    // object so that the assertion below fails instead of being skipped.
    let avg = match output.next() {
        Some(Ok(record)) => match record.get(None) {
            Some(entry) => match entry.as_ref() {
                Entry::Element(RecordElement::OffGraph(CommonObject::Prop(obj))) => obj.clone(),
                _ => unreachable!(),
            },
            None => "".into(),
        },
        _ => "".into(),
    };
    assert_eq!(avg, object!(20));
}
}

0 comments on commit 6616ee0

Please sign in to comment.