Add collision detection
Dandandan committed Jan 25, 2021
1 parent ed6b68a commit 5e1be03
Showing 1 changed file with 71 additions and 43 deletions.
114 changes: 71 additions & 43 deletions rust/datafusion/src/physical_plan/hash_aggregate.rs
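
Context: before this commit, the hash aggregation keyed each group only by the precomputed 64-bit hash of its group-by values (the `Accumulators` map uses an identity `IdHashBuilder`), so two distinct groups whose values collide on the same hash would be aggregated as one. The commit stores a vector of candidate groups per hash and compares the actual group-by values before reusing an entry. A minimal sketch of the idea, with simplified stand-in types (the real code keys candidates by `Box<[GroupByScalar]>` and carries an `AccumulatorSet` per candidate):

    use std::collections::HashMap;

    // Stand-in for Box<[GroupByScalar]>
    type GroupKey = Vec<String>;

    // hash -> every distinct group whose values hashed to that value
    type Groups = HashMap<u64, Vec<(GroupKey, Vec<u32>)>>;

    fn add_row(groups: &mut Groups, hash: u64, key: GroupKey, row: u32) {
        let candidates = groups.entry(hash).or_default();
        // Collision detection: compare the actual group values, not just the hash.
        if let Some((_, indices)) = candidates.iter_mut().find(|(k, _)| *k == key) {
            indices.push(row);
        } else {
            candidates.push((key, vec![row]));
        }
    }

    fn main() {
        let mut groups = Groups::new();
        // two different keys that (hypothetically) collide on hash 42
        add_row(&mut groups, 42, vec!["a".into()], 0);
        add_row(&mut groups, 42, vec!["b".into()], 1);
        add_row(&mut groups, 42, vec!["a".into()], 2);
        assert_eq!(groups[&42].len(), 2); // kept as two distinct groups
    }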
@@ -296,7 +296,8 @@ fn group_aggregate_batch(

// 1.1 construct the key from the group values
// 1.2 construct the mapping key if it does not exist
-    // 1.3 add the row' index to `indices`
+    // 1.3 iterate through candidates potentially mapping to the same key
+    // 1.4 add the row's index to `indices`

// Make sure we can create the accumulators or otherwise return an error
create_accumulators(aggr_expr).map_err(DataFusionError::into_arrow_external_error)?;
@@ -311,11 +312,29 @@ fn group_aggregate_batch(
.raw_entry_mut()
.from_key_hashed_nocheck(*hash, hash)
// 1.3
-            .and_modify(|_, (_, _, v)| {
-                if v.is_empty() {
-                    batch_keys.push(hash)
+            .and_modify(|_, candidates| {
+                // Iterate through candidates
+                let _ = create_group_by_values(&group_values, row, &mut group_by_values);
+                let mut no_match = true;
+                for (candidate_values, _, indices) in candidates.iter_mut() {
+                    if group_by_values == *candidate_values {
+                        indices.push(row as u32);
+                        no_match = false;
+                    }
+                }
+
+                if candidates.iter_mut().all(|(_, _, i)| i.is_empty()) {
+                    batch_keys.push(hash);
                 };
-                v.push(row as u32)
+                if no_match {
+                    // No match found, insert new
+                    let accumulator_set = create_accumulators(aggr_expr).unwrap();
+                    candidates.push((
+                        group_by_values.clone(),
+                        accumulator_set,
+                        vec![row as u32],
+                    ));
+                }
})
// 1.2
.or_insert_with(|| {
@@ -325,7 +344,7 @@ fn group_aggregate_batch(
let _ = create_group_by_values(&group_values, row, &mut group_by_values);
(
*hash,
-                    (group_by_values.clone(), accumulator_set, vec![row as u32]),
+                    vec![(group_by_values.clone(), accumulator_set, vec![row as u32])],
)
});
}
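
The lookup above goes through hashbrown's raw-entry API, which takes a caller-supplied hash so the hashes for a whole batch can be computed once up front; paired with the identity `IdHashBuilder`, the precomputed `u64` hash doubles as the map key. A rough usage sketch of that API in isolation (assumes the `hashbrown` crate; the hash here is derived from the map's own hasher rather than an identity hasher, and the raw-entry API is feature-gated in newer hashbrown releases):

    use hashbrown::HashMap;
    use std::hash::{BuildHasher, Hash, Hasher};

    fn main() {
        let mut map: HashMap<u64, Vec<u32>> = HashMap::new();
        let key: u64 = 7;

        // compute the hash once, with the map's own hasher
        let mut state = map.hasher().build_hasher();
        key.hash(&mut state);
        let hash = state.finish();

        map.raw_entry_mut()
            .from_key_hashed_nocheck(hash, &key)
            .and_modify(|_k, rows| rows.push(0)) // key already present
            .or_insert_with(|| (key, vec![0])); // first sighting
        assert_eq!(map[&key], vec![0]);
    }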
@@ -335,10 +354,12 @@ fn group_aggregate_batch(
let mut offsets = vec![0];
let mut offset_so_far = 0;
for key in batch_keys.iter() {
-        let (_, _, indices) = accumulators.get_mut(key).unwrap();
-        batch_indices.append_slice(&indices)?;
-        offset_so_far += indices.len();
-        offsets.push(offset_so_far);
+        let accs = accumulators.get(key).unwrap();
+        for (_, _, indices) in accs {
+            batch_indices.append_slice(&indices)?;
+            offset_so_far += indices.len();
+            offsets.push(offset_so_far);
+        }
}
let batch_indices = batch_indices.finish();

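This loop now visits every candidate under each batch key. It concatenates the row indices of all groups into one flat index array and records running offsets; after arrow's `take` kernel reorders the batch by those indices, group `i` occupies rows `offsets[i]..offsets[i + 1]` of the reordered arrays. A small illustration of the bookkeeping with plain `Vec`s and made-up data:

    fn main() {
        // Per-group row indices gathered in step 1 (made-up data):
        let group_indices: Vec<Vec<u32>> = vec![vec![0, 2], vec![1], vec![3, 4]];

        let mut batch_indices: Vec<u32> = vec![];
        let mut offsets: Vec<usize> = vec![0];
        for indices in &group_indices {
            batch_indices.extend_from_slice(indices);
            offsets.push(batch_indices.len());
        }

        // take(batch, &batch_indices) would put each group's rows together;
        // group i then occupies rows offsets[i]..offsets[i + 1].
        assert_eq!(batch_indices, vec![0, 2, 1, 3, 4]);
        assert_eq!(offsets, vec![0, 2, 3, 5]);
    }
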
@@ -361,42 +382,48 @@ fn group_aggregate_batch(
.collect();

// 2.1 for each key in this batch
-    // 2.2 for each aggregation
-    // 2.3 `slice` from each of its arrays the keys' values
-    // 2.4 update / merge the accumulator with the values
-    // 2.5 clear indices
+    // 2.2 for each candidate
+    // 2.3 for each aggregation
+    // 2.4 `slice` from each of its arrays the keys' values
+    // 2.5 update / merge the accumulator with the values
+    // 2.6 clear indices
batch_keys
.iter_mut()
.zip(offsets.windows(2))
.try_for_each(|(key, offsets)| {
-            let (_, accumulator_set, indices) = accumulators.get_mut(key).unwrap();
+            let accs = accumulators.get_mut(key).unwrap();
// 2.2
-            accumulator_set
-                .iter_mut()
-                .zip(values.iter())
-                .map(|(accumulator, aggr_array)| {
-                    (
-                        accumulator,
-                        aggr_array
-                            .iter()
-                            .map(|array| {
-                                // 2.3
-                                array.slice(offsets[0], offsets[1] - offsets[0])
-                            })
-                            .collect(),
-                    )
-                })
-                .try_for_each(|(accumulator, values)| match mode {
-                    AggregateMode::Partial => accumulator.update_batch(&values),
-                    AggregateMode::Final => {
-                        // note: the aggregation here is over states, not values, thus the merge
-                        accumulator.merge_batch(&values)
-                    }
-                })
-                // 2.5
-                .and({
-                    indices.clear();
-                    Ok(())
+            accs.iter_mut()
+                .try_for_each(|(_, accumulator_set, indices)| {
+                    // 2.3
+                    accumulator_set
+                        .iter_mut()
+                        .zip(values.iter())
+                        .map(|(accumulator, aggr_array)| {
+                            (
+                                accumulator,
+                                aggr_array
+                                    .iter()
+                                    .map(|array| {
+                                        // 2.4
+                                        array.slice(offsets[0], offsets[1] - offsets[0])
+                                    })
+                                    .collect(),
+                            )
+                        })
+                        .try_for_each(|(accumulator, values)| match mode {
+                            // 2.5
+                            AggregateMode::Partial => accumulator.update_batch(&values),
+                            AggregateMode::Final => {
+                                // note: the aggregation here is over states, not values, thus the merge
+                                accumulator.merge_batch(&values)
+                            }
+                        })
+                        // 2.6
+                        .and({
+                            indices.clear();
+                            Ok(())
+                        })
                })
})?;
Ok(accumulators)
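
The inner `match mode` dispatch is worth spelling out: in `AggregateMode::Partial` the accumulator consumes raw input values, while in `AggregateMode::Final` it consumes the partial states emitted by upstream partial aggregations, hence `merge_batch` rather than `update_batch`. A compact sketch of that dispatch with stand-in types (DataFusion's real `Accumulator` trait operates on Arrow `ArrayRef` batches and returns `Result`):

    enum AggregateMode {
        Partial,
        Final,
    }

    trait Accumulator {
        fn update_batch(&mut self, values: &[i64]); // raw input values
        fn merge_batch(&mut self, states: &[i64]); // partial states from upstream
    }

    struct Sum(i64);

    impl Accumulator for Sum {
        fn update_batch(&mut self, values: &[i64]) {
            self.0 += values.iter().sum::<i64>();
        }
        fn merge_batch(&mut self, states: &[i64]) {
            // for a sum, merging partial sums is again a sum
            self.0 += states.iter().sum::<i64>();
        }
    }

    fn accumulate(acc: &mut dyn Accumulator, mode: &AggregateMode, data: &[i64]) {
        match mode {
            AggregateMode::Partial => acc.update_batch(data),
            AggregateMode::Final => acc.merge_batch(data),
        }
    }

    fn main() {
        let mut acc = Sum(0);
        accumulate(&mut acc, &AggregateMode::Partial, &[1, 2, 3]);
        accumulate(&mut acc, &AggregateMode::Final, &[10]);
        assert_eq!(acc.0, 16);
    }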
@@ -476,7 +503,7 @@ impl GroupedHashAggregateStream {

type AccumulatorSet = Vec<Box<dyn Accumulator>>;
type Accumulators =
-    HashMap<u64, (Box<[GroupByScalar]>, AccumulatorSet, Vec<u32>), IdHashBuilder>;
+    HashMap<u64, Vec<(Box<[GroupByScalar]>, AccumulatorSet, Vec<u32>)>, IdHashBuilder>;

impl Stream for GroupedHashAggregateStream {
type Item = ArrowResult<RecordBatch>;
@@ -737,7 +764,8 @@ fn create_batch_from_map(
// 5. concatenate the arrays over the second index [j] into a single vec<ArrayRef>.
let arrays = accumulators
.iter()
-        .map(|(_, (group_by_values, accumulator_set, _))| {
+        .flat_map(|(_, accs)| accs)
+        .map(|(group_by_values, accumulator_set, _)| {
// 2.
let mut groups = (0..num_group_expr)
.map(|i| match &group_by_values[i] {
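
Because each map entry can now hold several candidate groups, building the output batch flattens the candidate vectors with `flat_map`, so every distinct group, colliding or not, yields its own output row. An illustrative sketch with simplified value types:

    use std::collections::HashMap;

    fn main() {
        // Two distinct groups that (hypothetically) collided on hash 42:
        let mut accumulators: HashMap<u64, Vec<(&str, i64)>> = HashMap::new();
        accumulators.insert(42, vec![("a", 1), ("b", 2)]);
        accumulators.insert(7, vec![("c", 3)]);

        // flat_map visits every candidate group, not just one per hash.
        let mut rows: Vec<(&str, i64)> = accumulators
            .iter()
            .flat_map(|(_, accs)| accs)
            .copied()
            .collect();
        rows.sort();
        assert_eq!(rows, vec![("a", 1), ("b", 2), ("c", 3)]);
    }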
