
Commit

Clean up code and add comments
xcharleslin committed Mar 18, 2023
1 parent 5d0ec89 commit e35d185
Showing 2 changed files with 59 additions and 48 deletions.
9 changes: 1 addition & 8 deletions src/array/ops/downcast.rs
@@ -2,7 +2,7 @@ use arrow2::array;

use crate::{
array::{BaseArray, DataArray},
datatypes::{BooleanArray, DaftNumericType, NullArray, Utf8Array},
datatypes::{BooleanArray, DaftNumericType, Utf8Array},
};

impl<T> DataArray<T>
@@ -28,10 +28,3 @@ impl BooleanArray {
self.data().as_any().downcast_ref().unwrap()
}
}

impl NullArray {
// downcasts a DataArray<T> to an Arrow NullArray.
pub fn downcast(&self) -> &array::NullArray {
self.data().as_any().downcast_ref().unwrap()
}
}
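
The downcast methods above all follow arrow2's type-erasure pattern: a concrete array is stored behind the Array trait object, as_any() exposes it as &dyn Any, and downcast_ref() recovers the concrete type. A minimal standalone sketch of that pattern using arrow2's PrimitiveArray directly (illustrative only, not Daft's DataArray wrappers):

use arrow2::array::{Array, PrimitiveArray};

fn main() {
    // A concrete array held behind the type-erased `dyn Array` trait object,
    // which is roughly how a DataArray stores its backing buffer.
    let boxed: Box<dyn Array> = Box::new(PrimitiveArray::from_vec(vec![1i64, 2, 3]));

    // `as_any()` exposes the array as `&dyn Any`; `downcast_ref()` recovers the
    // concrete type. The unwrap only succeeds if the stored type actually matches.
    let concrete: &PrimitiveArray<i64> = boxed.as_any().downcast_ref().unwrap();
    assert_eq!(concrete.len(), 3);
}
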
98 changes: 58 additions & 40 deletions src/table/ops/agg.rs
@@ -11,6 +11,7 @@ use crate::{

impl Table {
pub fn agg(&self, to_agg: &[(Expr, &str)], group_by: &[Expr]) -> DaftResult<Table> {
// Dispatch depending on whether we're doing groupby or just a global agg.
match group_by.len() {
0 => self.agg_global(to_agg),
_ => self.agg_groupby(to_agg, group_by),
@@ -20,7 +21,9 @@ impl Table {
pub fn agg_global(&self, to_agg: &[(Expr, &str)]) -> DaftResult<Table> {
// Convert the input (child, name) exprs to the enum form.
// e.g. (expr, "sum") to Expr::Agg(AggExpr::Sum(expr))
// (We could do this at the pyo3 layer but I'm not sure why they're strings in the first place)
//
// NOTE: We may want to do this elsewhere later.
// See https://github.com/Eventual-Inc/Daft/pull/702#discussion_r1136597811
let agg_expr_list = to_agg
.iter()
.map(|(e, s)| AggExpr::from_name_and_child_expr(s, e))
@@ -29,55 +32,48 @@
.iter()
.map(|ae| Expr::Agg(ae.clone()))
.collect::<Vec<Expr>>();

self.eval_expression_list(&expr_list)
}
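
The conversion described in the comment above, a (child expression, aggregation name) pair becoming an enum variant, is ordinary string-to-variant dispatch. A self-contained sketch with toy stand-in types (ToyExpr and ToyAggExpr are illustrative only, not Daft's actual Expr/AggExpr definitions):

// Illustrative stand-ins; the real AggExpr/Expr live in Daft's expression module.
#[derive(Debug, Clone)]
enum ToyAggExpr {
    Sum(String),  // the child expression, represented here as just a column name
    Mean(String),
}

#[derive(Debug, Clone)]
enum ToyExpr {
    Agg(ToyAggExpr),
}

// Mirrors the shape of AggExpr::from_name_and_child_expr(name, child):
// map an aggregation name onto the matching variant, wrapping the child.
fn from_name_and_child(name: &str, child: &str) -> Result<ToyAggExpr, String> {
    match name {
        "sum" => Ok(ToyAggExpr::Sum(child.to_string())),
        "mean" => Ok(ToyAggExpr::Mean(child.to_string())),
        other => Err(format!("unsupported aggregation: {other}")),
    }
}

fn main() {
    let to_agg = [("col_a", "sum"), ("col_b", "mean")];
    let exprs: Result<Vec<ToyExpr>, String> = to_agg
        .iter()
        .map(|(child, name)| from_name_and_child(name, child).map(ToyExpr::Agg))
        .collect();
    // e.g. ("col_a", "sum") becomes Agg(Sum("col_a")), matching the comment above.
    println!("{exprs:?}");
}
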

pub fn agg_groupby(&self, to_agg: &[(Expr, &str)], group_by: &[Expr]) -> DaftResult<Table> {
// Table with just the groupby columns.
let groupby_table = self.eval_expression_list(group_by)?;

// Get the unique group keys (by indices)
// and the grouped values (also by indices, one array of indices per group).
let (groupkey_indices, groupvals_indices) = groupby_table.sort_grouper()?;

// Table with the aggregated (deduplicated) group keys.
let groupkeys_table = {
let indices_as_arrow = arrow2::array::PrimitiveArray::from_vec(groupkey_indices);
let indices_as_series =
UInt64Array::from(("__TEMP_DAFT_GROUP_INDICES", Box::new(indices_as_arrow)))
.into_series();
UInt64Array::from(("", Box::new(indices_as_arrow))).into_series();
groupby_table.take(&indices_as_series)?
};

println!("{}", groupkeys_table);

// Table with the aggregated values, one row for each group.
let agged_values_table = {
let mut subresults: Vec<Self> = vec![];

// Agg each group into its own table.
let mut agged_groups: Vec<Self> = vec![];
for group_indices_array in groupvals_indices.iter() {
let subtable = {
let group = {
let indices_as_arrow = group_indices_array.downcast();
let indices_as_series = UInt64Array::from((
"__TEMP_DAFT_GROUP_INDICES",
Box::new(indices_as_arrow.clone()),
))
.into_series();
let indices_as_series =
UInt64Array::from(("", Box::new(indices_as_arrow.clone()))).into_series();
self.take(&indices_as_series)?
};
println!("{}", subtable);
let subresult = subtable.agg_global(to_agg)?;
println!("{}", subresult);
subresults.push(subresult.to_owned());
let agged_group = group.agg_global(to_agg)?;
agged_groups.push(agged_group.to_owned());
}

match subresults.len() {
match agged_groups.len() {
0 => self.head(0)?.agg_global(to_agg)?.head(0)?,
_ => Self::concat(subresults.iter().collect::<Vec<&Self>>().as_slice())?,
_ => Self::concat(agged_groups.iter().collect::<Vec<&Self>>().as_slice())?,
}
};

println!("{}", agged_values_table);

// Final result - concat the groupkey columns and the agg result columns together.
// Combine the groupkey columns and the aggregation result columns.
Self::from_columns(
[
&groupkeys_table.columns[..],
@@ -88,43 +84,65 @@
}
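
Seen end to end, agg_groupby takes the rows of each group, aggregates each group independently, and then concatenates the per-group results alongside the deduplicated group keys. A self-contained sketch of that flow on plain vectors (toy data and a sum aggregation standing in for Table::take / agg_global / Table::concat):

fn main() {
    // Toy table: one key column and one value column.
    let keys = ["b", "b", "a", "b", "c", "c"];
    let vals = [10i64, 20, 30, 40, 50, 60];

    // Pretend these came from sort_grouper: one representative row index per
    // distinct key, and one list of row indices per group.
    let groupkey_indices = [2usize, 0, 4];
    let groupvals_indices = [vec![2usize], vec![0, 1, 3], vec![4, 5]];

    // "take" the deduplicated group keys.
    let groupkeys: Vec<&str> = groupkey_indices.iter().map(|&i| keys[i]).collect();

    // Aggregate each group on its own, collecting one result per group.
    let sums: Vec<i64> = groupvals_indices
        .iter()
        .map(|idx| idx.iter().map(|&i| vals[i]).sum())
        .collect();

    // "concat" the key column and the aggregated column back together.
    for (k, s) in groupkeys.iter().zip(&sums) {
        println!("{k}: {s}"); // prints a: 30, b: 70, c: 110
    }
}
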

fn sort_grouper(&self) -> DaftResult<(Vec<u64>, Vec<UInt64Array>)> {
// Argsort the table, but also group identical values together.
//
// Given a table, returns a tuple:
// 1. An argsort of the entire table, deduplicated.
// 2. An argsort of the entire table, with identical values grouped.
//
// e.g. given a table [B, B, A, B, C, C]
// returns: (
// [2, 0, 4] <-- indices of A, B, and C
// [[2], [0, 1, 3], [4, 5]] <--- indices of all A, all B, all C
// )

// Begin by doing the argsort.
let argsort_series =
Series::argsort_multikey(self.columns.as_slice(), &vec![false; self.columns.len()])?;

let argsort_array = argsort_series.downcast::<UInt64Type>()?;

let mut groupvals_indices: Vec<UInt64Array> = vec![];
let mut groupkey_indices: Vec<u64> = vec![];
// The result indices.
let mut key_indices: Vec<u64> = vec![];
let mut values_indices: Vec<UInt64Array> = vec![];

let comparator =
build_multi_array_compare(self.columns.as_slice(), &vec![false; self.columns.len()])?;

// (argsort index, data index).
// To group the argsort values together, we will traverse the table in argsort order,
// collecting the indices traversed whenever the table value changes.

// The current table value we're looking at, but represented only by the index in the table.
// For convenience, also keep the index's index in the argarray.

let mut group_begin_indices: Option<(usize, usize)> = None;

for (argsort_index, data_index) in argsort_array.downcast().iter().enumerate() {
let data_index = *data_index.unwrap() as usize;
for (argarray_index, table_index) in argsort_array.downcast().iter().enumerate() {
let table_index = *table_index.unwrap() as usize;

// Start a new group result if the groupkey has changed (or if there was no previous groupkey).
match group_begin_indices {
None => group_begin_indices = Some((argsort_index, data_index)),
Some((begin_argsort_index, begin_data_index)) => {
let comp_result = comparator(begin_data_index, data_index);
None => group_begin_indices = Some((table_index, argarray_index)),
Some((begin_table_index, begin_argarray_index)) => {
let comp_result = comparator(begin_table_index, table_index);
if comp_result != Ordering::Equal {
groupkey_indices.push(begin_data_index as u64);
groupvals_indices
.push(argsort_array.slice(begin_argsort_index, argsort_index)?);
group_begin_indices = Some((argsort_index, data_index));
// The value has changed.
// Record results for the previous group.
key_indices.push(begin_table_index as u64);
values_indices
.push(argsort_array.slice(begin_argarray_index, argarray_index)?);

// Update the current value.
group_begin_indices = Some((table_index, argarray_index));
}
}
}
}

if let Some((begin_argsort_index, begin_data_index)) = group_begin_indices {
groupkey_indices.push(begin_data_index as u64);
groupvals_indices.push(argsort_array.slice(begin_argsort_index, argsort_array.len())?);
// Record results for the last group (since the for loop doesn't detect the last group closing).
if let Some((begin_table_index, begin_argsort_index)) = group_begin_indices {
key_indices.push(begin_table_index as u64);
values_indices.push(argsort_array.slice(begin_argsort_index, argsort_array.len())?);
}

Ok((groupkey_indices, groupvals_indices))
Ok((key_indices, values_indices))
}
}
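
For a concrete picture of the walk in sort_grouper, the same argsort-then-scan idea can be reproduced on a plain slice: argsort the data, traverse it in sorted order, and close a group each time the value changes. A minimal standalone sketch (no arrow2 arrays or multi-column comparators) that reproduces the [B, B, A, B, C, C] example from the doc comment:

// Minimal stand-in for Table::sort_grouper, operating on a plain &[&str].
fn sort_grouper(data: &[&str]) -> (Vec<u64>, Vec<Vec<u64>>) {
    // Argsort: indices into `data`, ordered by the values they point at.
    let mut argsort: Vec<u64> = (0..data.len() as u64).collect();
    argsort.sort_by_key(|&i| data[i as usize]);

    let mut key_indices: Vec<u64> = vec![];
    let mut values_indices: Vec<Vec<u64>> = vec![];

    // Walk in sorted order; whenever the value changes, the previous run of
    // indices becomes one group.
    let mut begin: Option<usize> = None; // position in `argsort` where the current group started
    for (pos, &table_index) in argsort.iter().enumerate() {
        match begin {
            None => begin = Some(pos),
            Some(b) => {
                if data[argsort[b] as usize] != data[table_index as usize] {
                    key_indices.push(argsort[b]);
                    values_indices.push(argsort[b..pos].to_vec());
                    begin = Some(pos);
                }
            }
        }
    }
    // Close the last group (the loop only closes groups when a new value appears).
    if let Some(b) = begin {
        key_indices.push(argsort[b]);
        values_indices.push(argsort[b..].to_vec());
    }
    (key_indices, values_indices)
}

fn main() {
    let (keys, groups) = sort_grouper(&["B", "B", "A", "B", "C", "C"]);
    assert_eq!(keys, vec![2, 0, 4]);
    assert_eq!(groups, vec![vec![2], vec![0, 1, 3], vec![4, 5]]);
}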
