Skip to content

Commit

Permalink
refactor(plan_node): simplify agg-related nodes (risingwavelabs#8930)
Browse files Browse the repository at this point in the history
  • Loading branch information
ice1000 committed Apr 2, 2023
1 parent 039d4a0 commit e217689
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 147 deletions.
58 changes: 29 additions & 29 deletions src/frontend/src/optimizer/plan_node/batch_simple_agg.rs
Expand Up @@ -18,36 +18,39 @@ use risingwave_common::error::Result;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::SortAggNode;

use super::generic::PlanAggCall;
use super::{
ExprRewritable, LogicalAgg, PlanBase, PlanRef, PlanTreeNodeUnary, ToBatchPb, ToDistributedBatch,
};
use super::generic::{self, PlanAggCall};
use super::{ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, ToBatchPb, ToDistributedBatch};
use crate::expr::ExprRewriter;
use crate::optimizer::plan_node::{BatchExchange, ToLocalBatch};
use crate::optimizer::property::{Distribution, Order, RequiredDist};

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BatchSimpleAgg {
pub base: PlanBase,
logical: LogicalAgg,
logical: generic::Agg<PlanRef>,
}

impl BatchSimpleAgg {
pub fn new(logical: LogicalAgg) -> Self {
let ctx = logical.base.ctx.clone();
let input = logical.input();
pub fn new(logical: generic::Agg<PlanRef>) -> Self {
let base = PlanBase::new_logical_with_core(&logical);
let ctx = base.ctx;
let input = logical.input.clone();
let input_dist = input.distribution();
let base = PlanBase::new_batch(
ctx,
logical.schema().clone(),
input_dist.clone(),
Order::any(),
);
let base = PlanBase::new_batch(ctx, base.schema, input_dist.clone(), Order::any());
BatchSimpleAgg { base, logical }
}

pub fn agg_calls(&self) -> &[PlanAggCall] {
self.logical.agg_calls()
&self.logical.agg_calls
}

/// Returns the value reported by `get_enable_two_phase_agg()` on the config of the session
/// attached to this node's plan context.
fn two_phase_agg_enabled(&self) -> bool {
    self.base
        .ctx
        .session_ctx()
        .config()
        .get_enable_two_phase_agg()
}

/// Whether this node may be planned as a two-phase aggregation.
///
/// Both conditions must hold: the wrapped aggregation core says it can be split, and
/// `two_phase_agg_enabled` reports that the session allows it. The session setting is
/// only consulted when the core check passes.
pub(crate) fn can_two_phase_agg(&self) -> bool {
    if !self.logical.can_two_phase_agg() {
        return false;
    }
    self.two_phase_agg_enabled()
}
}

Expand All @@ -59,11 +62,14 @@ impl fmt::Display for BatchSimpleAgg {

impl PlanTreeNodeUnary for BatchSimpleAgg {
fn input(&self) -> PlanRef {
self.logical.input()
self.logical.input.clone()
}

fn clone_with_input(&self, input: PlanRef) -> Self {
Self::new(self.logical.clone_with_input(input))
Self::new(generic::Agg {
input,
..self.logical.clone()
})
}
}
impl_plan_tree_node_for_unary! { BatchSimpleAgg }
Expand All @@ -75,8 +81,7 @@ impl ToDistributedBatch for BatchSimpleAgg {
let dist_input = self.input().to_distributed()?;

// TODO: distinct agg cannot use 2-phase agg yet.
if dist_input.distribution().satisfies(&RequiredDist::AnyShard)
&& self.logical.can_two_phase_agg()
if dist_input.distribution().satisfies(&RequiredDist::AnyShard) && self.can_two_phase_agg()
{
// partial agg
let partial_agg = self.clone_with_input(dist_input).into();
Expand All @@ -88,15 +93,15 @@ impl ToDistributedBatch for BatchSimpleAgg {
// insert total agg
let total_agg_types = self
.logical
.agg_calls()
.agg_calls
.iter()
.enumerate()
.map(|(partial_output_idx, agg_call)| {
agg_call.partial_to_total_agg_call(partial_output_idx)
})
.collect();
let total_agg_logical =
LogicalAgg::new(total_agg_types, self.logical.group_key().to_vec(), exchange);
generic::Agg::new(total_agg_types, self.logical.group_key.to_vec(), exchange);
Ok(BatchSimpleAgg::new(total_agg_logical).into())
} else {
let new_input = self
Expand Down Expand Up @@ -138,13 +143,8 @@ impl ExprRewritable for BatchSimpleAgg {
}

fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
Self::new(
self.logical
.rewrite_exprs(r)
.as_logical_agg()
.unwrap()
.clone(),
)
.into()
let mut logical = self.logical.clone();
logical.rewrite_exprs(r);
Self::new(logical).into()
}
}
45 changes: 45 additions & 0 deletions src/frontend/src/optimizer/plan_node/generic/agg.rs
Expand Up @@ -72,6 +72,35 @@ impl<PlanRef: GenericPlanRef> Agg<PlanRef> {
pub fn i2o_col_mapping(&self) -> ColIndexMapping {
self.o2i_col_mapping().inverse()
}

/// Whether the aggregation calls of this node permit a two-phase plan.
///
/// Requires that every call supports two-phase execution (`call_support_two_phase`) and
/// that no call's result is affected by ordering (`is_agg_result_affected_by_order`).
pub(crate) fn can_two_phase_agg(&self) -> bool {
    if !self.call_support_two_phase() {
        return false;
    }
    !self.is_agg_result_affected_by_order()
}

/// Returns `true` only when there is at least one aggregation call and every call is a
/// non-distinct `Min`, `Max`, `Sum` or `Count`.
fn call_support_two_phase(&self) -> bool {
    // An empty call list never qualifies.
    if self.agg_calls.is_empty() {
        return false;
    }
    for call in self.agg_calls.iter() {
        let kind_supported = matches!(
            call.agg_kind,
            AggKind::Min | AggKind::Max | AggKind::Sum | AggKind::Count
        );
        // A single unsupported kind or a distinct call rules out the whole node.
        if !kind_supported || call.distinct {
            return false;
        }
    }
    true
}

/// Check whether the aggregation result would be affected by an order by clause, if any.
///
/// This is the case as soon as one of the calls is a `StringAgg` or an `ArrayAgg`.
pub(crate) fn is_agg_result_affected_by_order(&self) -> bool {
    for call in self.agg_calls.iter() {
        if matches!(call.agg_kind, AggKind::StringAgg | AggKind::ArrayAgg) {
            return true;
        }
    }
    false
}

/// Builds an aggregation core from its aggregation calls, group key column indices and
/// input plan. The arguments are stored as given; nothing is validated here.
pub fn new(agg_calls: Vec<PlanAggCall>, group_key: Vec<usize>, input: PlanRef) -> Self {
    Self { agg_calls, group_key, input }
}
}

impl<PlanRef: GenericPlanRef> GenericPlanNode for Agg<PlanRef> {
Expand Down Expand Up @@ -176,6 +205,22 @@ pub struct MaterializedInputState {
}

impl<PlanRef: stream::StreamPlanRef> Agg<PlanRef> {
/// Infers everything table-related for a streaming aggregation in one call.
///
/// `me` and `vnode_col_idx` are forwarded unchanged to each helper. The returned tuple
/// holds, in order, the outputs of `infer_result_table`, `infer_stream_agg_state` and
/// `infer_distinct_dedup_tables`.
pub fn infer_tables(
    &self,
    me: &impl stream::StreamPlanRef,
    vnode_col_idx: Option<usize>,
) -> (
    TableCatalog,
    Vec<AggCallState>,
    HashMap<usize, TableCatalog>,
) {
    // Evaluated in the same order as the tuple fields.
    let result_table = self.infer_result_table(me, vnode_col_idx);
    let agg_states = self.infer_stream_agg_state(me, vnode_col_idx);
    let distinct_dedup_tables = self.infer_distinct_dedup_tables(me, vnode_col_idx);
    (result_table, agg_states, distinct_dedup_tables)
}

/// Infer `AggCallState`s for streaming agg.
pub fn infer_stream_agg_state(
&self,
Expand Down
88 changes: 35 additions & 53 deletions src/frontend/src/optimizer/plan_node/logical_agg.rs
Expand Up @@ -79,7 +79,7 @@ impl LogicalAgg {
let local_agg = StreamLocalSimpleAgg::new(self.clone_with_input(stream_input));
let exchange =
RequiredDist::single().enforce_if_not_satisfies(local_agg.into(), &Order::any())?;
let global_agg = new_stream_global_simple_agg(LogicalAgg::new(
let global_agg = new_stream_global_simple_agg(generic::Agg::new(
self.agg_calls()
.iter()
.enumerate()
Expand Down Expand Up @@ -129,7 +129,7 @@ impl LogicalAgg {
local_group_key.push(vnode_col_idx);
let n_local_group_key = local_group_key.len();
let local_agg = new_stream_hash_agg(
LogicalAgg::new(self.agg_calls().to_vec(), local_group_key, project.into()),
generic::Agg::new(self.agg_calls().to_vec(), local_group_key, project.into()),
Some(vnode_col_idx),
);
// Global group key excludes vnode.
Expand All @@ -144,7 +144,7 @@ impl LogicalAgg {
if self.group_key().is_empty() {
let exchange =
RequiredDist::single().enforce_if_not_satisfies(local_agg.into(), &Order::any())?;
let global_agg = new_stream_global_simple_agg(LogicalAgg::new(
let global_agg = new_stream_global_simple_agg(generic::Agg::new(
self.agg_calls()
.iter()
.enumerate()
Expand All @@ -162,7 +162,7 @@ impl LogicalAgg {
// Local phase should have reordered the group keys into their required order.
// we can just follow it.
let global_agg = new_stream_hash_agg(
LogicalAgg::new(
generic::Agg::new(
self.agg_calls()
.iter()
.enumerate()
Expand All @@ -181,21 +181,18 @@ impl LogicalAgg {
}

fn gen_single_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
Ok(new_stream_global_simple_agg(self.clone_with_input(
RequiredDist::single().enforce_if_not_satisfies(stream_input, &Order::any())?,
))
.into())
let mut logical = self.core.clone();
let input = RequiredDist::single().enforce_if_not_satisfies(stream_input, &Order::any())?;
logical.input = input;
Ok(new_stream_global_simple_agg(logical).into())
}

fn gen_shuffle_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
Ok(new_stream_hash_agg(
self.clone_with_input(
RequiredDist::shard_by_key(stream_input.schema().len(), self.group_key())
.enforce_if_not_satisfies(stream_input, &Order::any())?,
),
None,
)
.into())
let input = RequiredDist::shard_by_key(stream_input.schema().len(), self.group_key())
.enforce_if_not_satisfies(stream_input, &Order::any())?;
let mut logical = self.core.clone();
logical.input = input;
Ok(new_stream_hash_agg(logical, None).into())
}

/// See if all stream aggregation calls have a stateless local agg counterpart.
Expand Down Expand Up @@ -284,13 +281,6 @@ impl LogicalAgg {
}
}

/// Check if the aggregation result will be affected by order by clause, if any.
pub(crate) fn is_agg_result_affected_by_order(&self) -> bool {
self.agg_calls()
.iter()
.any(|call| matches!(call.agg_kind, AggKind::StringAgg | AggKind::ArrayAgg))
}

pub(crate) fn two_phase_agg_forced(&self) -> bool {
self.base
.ctx()
Expand All @@ -313,15 +303,7 @@ impl LogicalAgg {
}

pub(crate) fn can_two_phase_agg(&self) -> bool {
!self.agg_calls().is_empty()
&& self.agg_calls().iter().all(|call| {
matches!(
call.agg_kind,
AggKind::Min | AggKind::Max | AggKind::Sum | AggKind::Count
) && !call.distinct
})
&& !self.is_agg_result_affected_by_order()
&& self.two_phase_agg_enabled()
self.core.can_two_phase_agg() && self.two_phase_agg_enabled()
}

// Check if the output of the aggregation needs to be sorted and return ordering req by group
Expand Down Expand Up @@ -870,21 +852,16 @@ impl ExprRewriter for LogicalAggBuilder {
}
}

impl LogicalAgg {
pub fn new(agg_calls: Vec<PlanAggCall>, group_key: Vec<usize>, input: PlanRef) -> Self {
let core = generic::Agg {
agg_calls,
group_key,
input,
};
/// Wraps a generic aggregation core into a `LogicalAgg`, deriving the plan base from the
/// core via `PlanBase::new_logical_with_core`.
impl From<generic::Agg<PlanRef>> for LogicalAgg {
    fn from(core: generic::Agg<PlanRef>) -> Self {
        // The base is computed from a borrow of `core` before `core` is moved into the node.
        Self {
            base: PlanBase::new_logical_with_core(&core),
            core,
        }
    }
}

/// get the Mapping of columnIndex from input column index to output column index,if a input
/// column corresponds more than one out columns, mapping to any one
pub fn o2i_col_mapping(&self) -> ColIndexMapping {
self.core.o2i_col_mapping()
impl LogicalAgg {
/// Creates a `LogicalAgg` by first building the generic aggregation core from the given
/// calls, group key and input, then converting that core with `Self::from`.
pub fn new(agg_calls: Vec<PlanAggCall>, group_key: Vec<usize>, input: PlanRef) -> Self {
    let core = generic::Agg::new(agg_calls, group_key, input);
    Self::from(core)
}

/// get the Mapping of columnIndex from input column index to out column index
Expand Down Expand Up @@ -978,8 +955,11 @@ impl LogicalAgg {
}

fn to_batch_simple_agg(&self) -> Result<PlanRef> {
let new_input = self.input().to_batch()?;
let new_logical = self.clone_with_input(new_input);
let input = self.input().to_batch()?;
let new_logical = generic::Agg {
input,
..self.core.clone()
};
Ok(BatchSimpleAgg::new(new_logical).into())
}
}
Expand Down Expand Up @@ -1159,6 +1139,7 @@ impl ToBatch for LogicalAgg {
if output_requires_order {
// Push down sort before aggregation
input_order = self
.core
.o2i_col_mapping()
.rewrite_provided_order(&group_key_order);
}
Expand All @@ -1180,32 +1161,33 @@ impl ToBatch for LogicalAgg {
}
}

fn find_or_append_row_count(mut logical: LogicalAgg) -> (LogicalAgg, usize) {
fn find_or_append_row_count(mut logical: generic::Agg<PlanRef>) -> (generic::Agg<PlanRef>, usize) {
// `HashAgg`/`GlobalSimpleAgg` executors require a `count(*)` to correctly build changes, so
// append a `count(*)` if not exists.
let count_star = PlanAggCall::count_star();
let row_count_idx = if let Some((idx, _)) = logical
.agg_calls()
.agg_calls
.iter()
.find_position(|&c| c == &count_star)
{
idx
} else {
let (mut agg_calls, group_key, input) = logical.decompose();
let idx = agg_calls.len();
agg_calls.push(count_star);
logical = LogicalAgg::new(agg_calls, group_key, input);
let idx = logical.agg_calls.len();
logical.agg_calls.push(count_star);
idx
};
(logical, row_count_idx)
}

fn new_stream_global_simple_agg(logical: LogicalAgg) -> StreamGlobalSimpleAgg {
/// Builds a `StreamGlobalSimpleAgg` from an aggregation core.
///
/// The core is first passed through `find_or_append_row_count`, which returns the core to
/// use together with the index of its row-count call; both are handed to the constructor.
fn new_stream_global_simple_agg(logical: generic::Agg<PlanRef>) -> StreamGlobalSimpleAgg {
    let (core, count_idx) = find_or_append_row_count(logical);
    StreamGlobalSimpleAgg::new(core, count_idx)
}

fn new_stream_hash_agg(logical: LogicalAgg, vnode_col_idx: Option<usize>) -> StreamHashAgg {
/// Builds a `StreamHashAgg` from an aggregation core.
///
/// The core is first passed through `find_or_append_row_count`, which returns the core to
/// use together with the index of its row-count call. `vnode_col_idx` is forwarded to the
/// constructor unchanged.
fn new_stream_hash_agg(
    logical: generic::Agg<PlanRef>,
    vnode_col_idx: Option<usize>,
) -> StreamHashAgg {
    let (core, count_idx) = find_or_append_row_count(logical);
    StreamHashAgg::new(core, vnode_col_idx, count_idx)
}
Expand Down

0 comments on commit e217689

Please sign in to comment.