Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement grouped aggregations. #704

Merged
merged 11 commits into from Mar 21, 2023
2 changes: 2 additions & 0 deletions src/array/ops/mod.rs
Expand Up @@ -23,6 +23,8 @@ mod sum;
mod take;
mod utf8;

pub use sort::build_multi_array_compare;

pub trait DaftCompare<Rhs> {
type Output;

Expand Down
5 changes: 4 additions & 1 deletion src/array/ops/sort.rs
Expand Up @@ -17,7 +17,10 @@ use arrow2::{

use super::arrow2::sort::primitive::common::multi_column_idx_sort;

fn build_multi_array_compare(arrays: &[Series], descending: &[bool]) -> DaftResult<DynComparator> {
pub fn build_multi_array_compare(
arrays: &[Series],
descending: &[bool],
) -> DaftResult<DynComparator> {
let mut cmp_list = Vec::with_capacity(arrays.len());
for (s, desc) in arrays.iter().zip(descending.iter()) {
cmp_list.push(build_compare_with_nulls(
Expand Down
123 changes: 119 additions & 4 deletions src/table/ops/agg.rs
@@ -1,11 +1,17 @@
use std::cmp::Ordering;

use crate::{
array::{ops::build_multi_array_compare, BaseArray},
datatypes::{UInt64Array, UInt64Type},
dsl::{AggExpr, Expr},
error::DaftResult,
series::Series,
table::Table,
};

impl Table {
pub fn agg(&self, to_agg: &[(Expr, &str)], group_by: &[Expr]) -> DaftResult<Table> {
// Dispatch depending on whether we're doing groupby or just a global agg.
match group_by.len() {
0 => self.agg_global(to_agg),
_ => self.agg_groupby(to_agg, group_by),
Expand All @@ -14,8 +20,10 @@ impl Table {

pub fn agg_global(&self, to_agg: &[(Expr, &str)]) -> DaftResult<Table> {
// Convert the input (child, name) exprs to the enum form.
// e.g. (expr, "sum") to Expr::Agg(AggEXpr::Sum(expr))
// (We could do this at the pyo3 layer but I'm not sure why they're strings in the first place)
// e.g. (expr, "sum") to Expr::Agg(AggExpr::Sum(expr))
//
// NOTE: We may want to do this elsewhere later.
// See https://github.com/Eventual-Inc/Daft/pull/702#discussion_r1136597811
let agg_expr_list = to_agg
.iter()
.map(|(e, s)| AggExpr::from_name_and_child_expr(s, e))
Expand All @@ -24,10 +32,117 @@ impl Table {
.iter()
.map(|ae| Expr::Agg(ae.clone()))
.collect::<Vec<Expr>>();

self.eval_expression_list(&expr_list)
}

pub fn agg_groupby(&self, _to_agg: &[(Expr, &str)], _group_by: &[Expr]) -> DaftResult<Table> {
todo!()
/// Performs a grouped aggregation: groups rows by the `group_by` expressions,
/// then applies each `(child_expr, agg_name)` pair in `to_agg` within each group.
///
/// Returns a table with one row per distinct group key, whose columns are the
/// (deduplicated) group-key columns followed by the aggregated value columns.
pub fn agg_groupby(&self, to_agg: &[(Expr, &str)], group_by: &[Expr]) -> DaftResult<Table> {
    // Table with just the groupby columns.
    let groupby_table = self.eval_expression_list(group_by)?;

    // Get the unique group keys (by indices)
    // and the grouped values (also by indices, one array of indices per group).
    let (groupkey_indices, groupvals_indices) = groupby_table.sort_grouper()?;

    // Table with the aggregated (deduplicated) group keys.
    let groupkeys_table = {
        let indices_as_arrow = arrow2::array::PrimitiveArray::from_vec(groupkey_indices);
        let indices_as_series =
            UInt64Array::from(("", Box::new(indices_as_arrow))).into_series();
        groupby_table.take(&indices_as_series)?
    };

    // Table with the aggregated values, one row for each group.
    let agged_values_table = {
        // Agg each group into its own single-row table.
        // We know exactly how many groups there are, so reserve up front.
        let mut agged_groups: Vec<Self> = Vec::with_capacity(groupvals_indices.len());
        for group_indices_array in groupvals_indices.iter() {
            // Materialize the rows belonging to this group...
            let group = {
                let indices_as_arrow = group_indices_array.downcast();
                let indices_as_series =
                    UInt64Array::from(("", Box::new(indices_as_arrow.clone()))).into_series();
                self.take(&indices_as_series)?
            };
            // ...then aggregate them globally. `agg_global` already returns an
            // owned Table, so push it directly (the previous `.to_owned()` here
            // performed a needless deep clone of every aggregated group).
            agged_groups.push(group.agg_global(to_agg)?);
        }

        match agged_groups.len() {
            // No groups (i.e. empty input): still produce a zero-row table with
            // the correct aggregated schema by aggregating an empty slice.
            0 => self.head(0)?.agg_global(to_agg)?.head(0)?,
            _ => Self::concat(agged_groups.iter().collect::<Vec<&Self>>().as_slice())?,
        }
    };

    // Combine the groupkey columns and the aggregation result columns.
    Self::from_columns(
        [
            &groupkeys_table.columns[..],
            &agged_values_table.columns[..],
        ]
        .concat(),
    )
}

/// Argsorts the table, grouping identical rows together.
///
/// Returns a tuple of:
/// 1. An argsort of the entire table, deduplicated (one index per distinct row).
/// 2. An argsort of the entire table, with identical rows grouped
///    (one index array per distinct row).
///
/// e.g. given a table [B, B, A, B, C, C]
/// returns: (
///     [2, 0, 4]  <-- indices of A, B, and C
///     [[2], [0, 1, 3], [4, 5]]  <--- indices of all A, all B, all C
/// )
fn sort_grouper(&self) -> DaftResult<(Vec<u64>, Vec<UInt64Array>)> {
    // All columns sort ascending. Build the descending-flags vector once and
    // reuse it for both the argsort and the comparator (previously this
    // identical Vec was allocated twice).
    let descending = vec![false; self.columns.len()];

    // Begin by doing the argsort.
    let argsort_series = Series::argsort_multikey(self.columns.as_slice(), &descending)?;
    let argsort_array = argsort_series.downcast::<UInt64Type>()?;

    // The result indices.
    let mut key_indices: Vec<u64> = vec![];
    let mut values_indices: Vec<UInt64Array> = vec![];

    // Row-vs-row comparator over all columns, used to detect group boundaries.
    let comparator = build_multi_array_compare(self.columns.as_slice(), &descending)?;

    // To group the argsort values together, we will traverse the table in argsort order,
    // collecting the indices traversed whenever the table value changes.

    // The current table value we're looking at, but represented only by the index in the table.
    // For convenience, also keep the index's index in the argarray.

    let mut group_begin_indices: Option<(usize, usize)> = None;

    for (argarray_index, table_index) in argsort_array.downcast().iter().enumerate() {
        let table_index = *table_index.unwrap() as usize;

        match group_begin_indices {
            None => group_begin_indices = Some((table_index, argarray_index)),
            Some((begin_table_index, begin_argarray_index)) => {
                let comp_result = comparator(begin_table_index, table_index);
                if comp_result != Ordering::Equal {
                    // The value has changed.
                    // Record results for the previous group.
                    key_indices.push(begin_table_index as u64);
                    values_indices
                        .push(argsort_array.slice(begin_argarray_index, argarray_index)?);

                    // Update the current value.
                    group_begin_indices = Some((table_index, argarray_index));
                }
            }
        }
    }

    // Record results for the last group (since the for loop doesn't detect the last group closing).
    if let Some((begin_table_index, begin_argsort_index)) = group_begin_indices {
        key_indices.push(begin_table_index as u64);
        values_indices.push(argsort_array.slice(begin_argsort_index, argsort_array.len())?);
    }

    Ok((key_indices, values_indices))
}
}
76 changes: 76 additions & 0 deletions tests/table/test_table.py
Expand Up @@ -532,6 +532,82 @@ def test_table_agg_global(case) -> None:
assert result[key] == value


@pytest.mark.parametrize(
    "groups_and_aggs",
    [
        (["col_A"], ["col_B"]),
        (["col_A", "col_B"], []),
    ],
)
def test_table_agg_groupby_empty(groups_and_aggs) -> None:
    # A grouped aggregation over an empty table should produce an empty result
    # with the same columns (group keys plus aggregated values).
    groups, aggs = groups_and_aggs
    empty_table = Table.from_pydict({"col_A": [], "col_B": []})
    grouped = empty_table.agg(
        [(col(name), "count") for name in aggs],
        [col(name).cast(DataType.int32()) for name in groups],
    )

    assert grouped.to_pydict() == {"col_A": [], "col_B": []}


# Grouped-aggregation cases run against the fixture table built in
# test_table_agg_groupby (columns: "name" of strings-with-nulls, "cookies" of
# ints-with-nulls). Each case gives the groupby columns, the
# (column, aggregation-name) pairs to apply, and the expected result pydict.
# Note: null group keys form their own group(s) in the expected outputs.
test_table_agg_groupby_cases = [
    {
        # Group by strings.
        "groups": ["name"],
        "aggs": [("cookies", "sum"), ("name", "count")],
        "expected": {"name": ["Alice", "Bob", None], "sum": [None, 10, 7], "count": [4, 4, 0]},
    },
    {
        # Group by numbers.
        "groups": ["cookies"],
        "aggs": [("name", "count")],
        "expected": {"cookies": [2, 5, None], "count": [0, 2, 6]},
    },
    {
        # Group by multicol.
        "groups": ["name", "cookies"],
        "aggs": [("name", "count")],
        "expected": {
            "name": ["Alice", "Bob", "Bob", None, None, None],
            "cookies": [None, 5, None, 2, 5, None],
            "count": [4, 2, 2, 0, 0, 0],
        },
    },
]


@pytest.mark.parametrize(
    "case", test_table_agg_groupby_cases, ids=[f"{case['groups']}" for case in test_table_agg_groupby_cases]
)
def test_table_agg_groupby(case) -> None:
    # Fixture rows: (name, cookies) pairs with duplicate keys and nulls in
    # both columns, exercising null-key grouping and null-skipping aggs.
    rows = [
        ("Bob", None),
        ("Bob", None),
        ("Bob", 5),
        ("Bob", 5),
        (None, None),
        (None, 5),
        (None, None),
        (None, 2),
        ("Alice", None),
        ("Alice", None),
        ("Alice", None),
        ("Alice", None),
    ]
    names, cookies = zip(*rows)
    daft_table = Table.from_pydict(
        {
            "name": list(names),
            "cookies": list(cookies),
        }
    )
    result = daft_table.agg(
        [(col(column).alias(agg_name), agg_name) for column, agg_name in case["aggs"]],
        [col(group) for group in case["groups"]],
    )
    assert result.to_pydict() == case["expected"]


import operator as ops

OPS = [ops.add, ops.sub, ops.mul, ops.truediv, ops.mod, ops.lt, ops.le, ops.eq, ops.ne, ops.ge, ops.gt]
Expand Down