Implement grouped aggregations. (#704)
- Fills out the implementation of `agg_groupby` for `Table` in Rust. (The
glue layers to the Python API are already done.)
- Adds a `sort_grouper` function for `Table`, which produces an argsort of
the table with identical values grouped together.
- Uses this to implement grouped aggregation, which goes through each group,
takes the rows at the group's indices, and then evaluates the aggregation
expressions. (A usage sketch follows below.)
- Tests included: multicolumn, multitype, and empty grouped aggregations.
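
For reference, a minimal usage sketch of the new grouped aggregation through the Python `Table` API, modeled on the tests added in this PR. The import paths are assumptions based on the test file and are not part of this diff.

```python
# Sketch only: exercises Table.agg with a non-empty group_by, which now
# dispatches to the Rust agg_groupby implemented in this commit.
# Import paths are assumed from the test suite and may differ.
from daft.expressions import col
from daft.table import Table

table = Table.from_pydict(
    {
        "name": ["Bob", "Bob", "Alice", "Bob", None, None],
        "cookies": [None, 5, None, 5, 2, None],
    }
)

# Each (expr, name) pair is converted to an AggExpr internally; the second
# argument lists the groupby expressions.
result = table.agg(
    [(col("cookies").alias("sum"), "sum"), (col("name").alias("count"), "count")],
    [col("name")],
)

# Expected shape (one row per group, nulls grouped together), roughly:
# {"name": ["Alice", "Bob", None], "sum": [None, 10, 2], "count": [1, 3, 0]}
print(result.to_pydict())
```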
xcharleslin committed Mar 21, 2023
1 parent e7df931 commit 6cea757
Showing 4 changed files with 201 additions and 5 deletions.
2 changes: 2 additions & 0 deletions src/array/ops/mod.rs
@@ -23,6 +23,8 @@ mod sum;
mod take;
mod utf8;

pub use sort::build_multi_array_compare;

pub trait DaftCompare<Rhs> {
type Output;

5 changes: 4 additions & 1 deletion src/array/ops/sort.rs
@@ -17,7 +17,10 @@ use arrow2::{

use super::arrow2::sort::primitive::common::multi_column_idx_sort;

fn build_multi_array_compare(arrays: &[Series], descending: &[bool]) -> DaftResult<DynComparator> {
pub fn build_multi_array_compare(
arrays: &[Series],
descending: &[bool],
) -> DaftResult<DynComparator> {
let mut cmp_list = Vec::with_capacity(arrays.len());
for (s, desc) in arrays.iter().zip(descending.iter()) {
cmp_list.push(build_compare_with_nulls(
123 changes: 119 additions & 4 deletions src/table/ops/agg.rs
@@ -1,11 +1,17 @@
use std::cmp::Ordering;

use crate::{
array::{ops::build_multi_array_compare, BaseArray},
datatypes::{UInt64Array, UInt64Type},
dsl::{AggExpr, Expr},
error::DaftResult,
series::Series,
table::Table,
};

impl Table {
pub fn agg(&self, to_agg: &[(Expr, &str)], group_by: &[Expr]) -> DaftResult<Table> {
// Dispatch depending on whether we're doing groupby or just a global agg.
match group_by.len() {
0 => self.agg_global(to_agg),
_ => self.agg_groupby(to_agg, group_by),
@@ -14,8 +20,10 @@ impl Table

pub fn agg_global(&self, to_agg: &[(Expr, &str)]) -> DaftResult<Table> {
// Convert the input (child, name) exprs to the enum form.
// e.g. (expr, "sum") to Expr::Agg(AggEXpr::Sum(expr))
// (We could do this at the pyo3 layer but I'm not sure why they're strings in the first place)
// e.g. (expr, "sum") to Expr::Agg(AggExpr::Sum(expr))
//
// NOTE: We may want to do this elsewhere later.
// See https://github.com/Eventual-Inc/Daft/pull/702#discussion_r1136597811
let agg_expr_list = to_agg
.iter()
.map(|(e, s)| AggExpr::from_name_and_child_expr(s, e))
@@ -24,10 +32,117 @@ impl Table {
.iter()
.map(|ae| Expr::Agg(ae.clone()))
.collect::<Vec<Expr>>();

self.eval_expression_list(&expr_list)
}

pub fn agg_groupby(&self, _to_agg: &[(Expr, &str)], _group_by: &[Expr]) -> DaftResult<Table> {
todo!()
pub fn agg_groupby(&self, to_agg: &[(Expr, &str)], group_by: &[Expr]) -> DaftResult<Table> {
// Table with just the groupby columns.
let groupby_table = self.eval_expression_list(group_by)?;

// Get the unique group keys (by indices)
// and the grouped values (also by indices, one array of indices per group).
let (groupkey_indices, groupvals_indices) = groupby_table.sort_grouper()?;

// Table with the aggregated (deduplicated) group keys.
let groupkeys_table = {
let indices_as_arrow = arrow2::array::PrimitiveArray::from_vec(groupkey_indices);
let indices_as_series =
UInt64Array::from(("", Box::new(indices_as_arrow))).into_series();
groupby_table.take(&indices_as_series)?
};

// Table with the aggregated values, one row for each group.
let agged_values_table = {
// Agg each group into its own table.
let mut agged_groups: Vec<Self> = vec![];
for group_indices_array in groupvals_indices.iter() {
let group = {
let indices_as_arrow = group_indices_array.downcast();
let indices_as_series =
UInt64Array::from(("", Box::new(indices_as_arrow.clone()))).into_series();
self.take(&indices_as_series)?
};
let agged_group = group.agg_global(to_agg)?;
agged_groups.push(agged_group.to_owned());
}

match agged_groups.len() {
0 => self.head(0)?.agg_global(to_agg)?.head(0)?,
_ => Self::concat(agged_groups.iter().collect::<Vec<&Self>>().as_slice())?,
}
};

// Combine the groupkey columns and the aggregation result columns.
Self::from_columns(
[
&groupkeys_table.columns[..],
&agged_values_table.columns[..],
]
.concat(),
)
}

fn sort_grouper(&self) -> DaftResult<(Vec<u64>, Vec<UInt64Array>)> {
// Argsort the table, but also group identical values together.
//
// Given a table, returns a tuple:
// 1. An argsort of the entire table, deduplicated.
// 2. An argsort of the entire table, with identical values grouped.
//
// e.g. given a table [B, B, A, B, C, C]
// returns: (
// [2, 0, 4] <-- indices of A, B, and C
// [[2], [0, 1, 3], [4, 5]] <--- indices of all A, all B, all C
// )

// Begin by doing the argsort.
let argsort_series =
Series::argsort_multikey(self.columns.as_slice(), &vec![false; self.columns.len()])?;
let argsort_array = argsort_series.downcast::<UInt64Type>()?;

// The result indices.
let mut key_indices: Vec<u64> = vec![];
let mut values_indices: Vec<UInt64Array> = vec![];

let comparator =
build_multi_array_compare(self.columns.as_slice(), &vec![false; self.columns.len()])?;

// To group the argsort values together, we will traverse the table in argsort order,
// collecting the indices traversed whenever the table value changes.

// The current table value we're looking at, but represented only by the index in the table.
// For convenience, also keep the index's index in the argarray.

let mut group_begin_indices: Option<(usize, usize)> = None;

for (argarray_index, table_index) in argsort_array.downcast().iter().enumerate() {
let table_index = *table_index.unwrap() as usize;

match group_begin_indices {
None => group_begin_indices = Some((table_index, argarray_index)),
Some((begin_table_index, begin_argarray_index)) => {
let comp_result = comparator(begin_table_index, table_index);
if comp_result != Ordering::Equal {
// The value has changed.
// Record results for the previous group.
key_indices.push(begin_table_index as u64);
values_indices
.push(argsort_array.slice(begin_argarray_index, argarray_index)?);

// Update the current value.
group_begin_indices = Some((table_index, argarray_index));
}
}
}
}

// Record results for the last group (since the for loop doesn't detect the last group closing).
if let Some((begin_table_index, begin_argsort_index)) = group_begin_indices {
key_indices.push(begin_table_index as u64);
values_indices.push(argsort_array.slice(begin_argsort_index, argsort_array.len())?);
}

Ok((key_indices, values_indices))
}
}
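
The doc comment above describes `sort_grouper` with the example `[B, B, A, B, C, C]`. Below is a pure-Python sketch of the same idea, an illustration only and not the Daft implementation: argsort the values, then walk them in sorted order and cut a new group whenever the value changes.

```python
# Illustration of the sort_grouper idea in plain Python (not Daft code):
# argsort the values, then walk them in sorted order and start a new group
# whenever the value changes, recording each group's slice of indices.
def sort_grouper(values):
    argsort = sorted(range(len(values)), key=lambda i: values[i])
    key_indices, values_indices = [], []
    # (table index, position in argsort) of the current group's first row.
    group_begin = None
    for pos, table_index in enumerate(argsort):
        if group_begin is None:
            group_begin = (table_index, pos)
        elif values[group_begin[0]] != values[table_index]:
            # Value changed: record the previous group and start a new one.
            key_indices.append(group_begin[0])
            values_indices.append(argsort[group_begin[1]:pos])
            group_begin = (table_index, pos)
    # Close out the last group, which the loop does not detect.
    if group_begin is not None:
        key_indices.append(group_begin[0])
        values_indices.append(argsort[group_begin[1]:])
    return key_indices, values_indices


print(sort_grouper(["B", "B", "A", "B", "C", "C"]))
# ([2, 0, 4], [[2], [0, 1, 3], [4, 5]]) -- matches the example in the doc comment
```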
76 changes: 76 additions & 0 deletions tests/table/test_table.py
@@ -532,6 +532,82 @@ def test_table_agg_global(case) -> None:
assert result[key] == value


@pytest.mark.parametrize(
"groups_and_aggs",
[
(["col_A"], ["col_B"]),
(["col_A", "col_B"], []),
],
)
def test_table_agg_groupby_empty(groups_and_aggs) -> None:
groups, aggs = groups_and_aggs
daft_table = Table.from_pydict({"col_A": [], "col_B": []})
daft_table = daft_table.agg(
[(col(a), "count") for a in aggs],
[col(g).cast(DataType.int32()) for g in groups],
)
res = daft_table.to_pydict()

assert res == {"col_A": [], "col_B": []}


test_table_agg_groupby_cases = [
{
# Group by strings.
"groups": ["name"],
"aggs": [("cookies", "sum"), ("name", "count")],
"expected": {"name": ["Alice", "Bob", None], "sum": [None, 10, 7], "count": [4, 4, 0]},
},
{
# Group by numbers.
"groups": ["cookies"],
"aggs": [("name", "count")],
"expected": {"cookies": [2, 5, None], "count": [0, 2, 6]},
},
{
# Group by multicol.
"groups": ["name", "cookies"],
"aggs": [("name", "count")],
"expected": {
"name": ["Alice", "Bob", "Bob", None, None, None],
"cookies": [None, 5, None, 2, 5, None],
"count": [4, 2, 2, 0, 0, 0],
},
},
]


@pytest.mark.parametrize(
"case", test_table_agg_groupby_cases, ids=[f"{case['groups']}" for case in test_table_agg_groupby_cases]
)
def test_table_agg_groupby(case) -> None:
values = [
("Bob", None),
("Bob", None),
("Bob", 5),
("Bob", 5),
(None, None),
(None, 5),
(None, None),
(None, 2),
("Alice", None),
("Alice", None),
("Alice", None),
("Alice", None),
]
daft_table = Table.from_pydict(
{
"name": [_[0] for _ in values],
"cookies": [_[1] for _ in values],
}
)
daft_table = daft_table.agg(
[(col(aggcol).alias(aggfn), aggfn) for aggcol, aggfn in case["aggs"]],
[col(group) for group in case["groups"]],
)
assert daft_table.to_pydict() == case["expected"]


import operator as ops

OPS = [ops.add, ops.sub, ops.mul, ops.truediv, ops.mod, ops.lt, ops.le, ops.eq, ops.ne, ops.ge, ops.gt]