diff --git a/src/frontend/src/optimizer/plan_node/batch_simple_agg.rs b/src/frontend/src/optimizer/plan_node/batch_simple_agg.rs index e19cbae85aaf..458f7b1c8ea9 100644 --- a/src/frontend/src/optimizer/plan_node/batch_simple_agg.rs +++ b/src/frontend/src/optimizer/plan_node/batch_simple_agg.rs @@ -18,10 +18,8 @@ use risingwave_common::error::Result; use risingwave_pb::batch_plan::plan_node::NodeBody; use risingwave_pb::batch_plan::SortAggNode; -use super::generic::PlanAggCall; -use super::{ - ExprRewritable, LogicalAgg, PlanBase, PlanRef, PlanTreeNodeUnary, ToBatchPb, ToDistributedBatch, -}; +use super::generic::{self, PlanAggCall}; +use super::{ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, ToBatchPb, ToDistributedBatch}; use crate::expr::ExprRewriter; use crate::optimizer::plan_node::{BatchExchange, ToLocalBatch}; use crate::optimizer::property::{Distribution, Order, RequiredDist}; @@ -29,25 +27,30 @@ use crate::optimizer::property::{Distribution, Order, RequiredDist}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct BatchSimpleAgg { pub base: PlanBase, - logical: LogicalAgg, + logical: generic::Agg, } impl BatchSimpleAgg { - pub fn new(logical: LogicalAgg) -> Self { - let ctx = logical.base.ctx.clone(); - let input = logical.input(); + pub fn new(logical: generic::Agg) -> Self { + let base = PlanBase::new_logical_with_core(&logical); + let ctx = base.ctx; + let input = logical.input.clone(); let input_dist = input.distribution(); - let base = PlanBase::new_batch( - ctx, - logical.schema().clone(), - input_dist.clone(), - Order::any(), - ); + let base = PlanBase::new_batch(ctx, base.schema, input_dist.clone(), Order::any()); BatchSimpleAgg { base, logical } } pub fn agg_calls(&self) -> &[PlanAggCall] { - self.logical.agg_calls() + &self.logical.agg_calls + } + + fn two_phase_agg_enabled(&self) -> bool { + let session_ctx = self.base.ctx.session_ctx(); + session_ctx.config().get_enable_two_phase_agg() + } + + pub(crate) fn can_two_phase_agg(&self) -> bool { + self.logical.can_two_phase_agg() && self.two_phase_agg_enabled() } } @@ -59,11 +62,14 @@ impl fmt::Display for BatchSimpleAgg { impl PlanTreeNodeUnary for BatchSimpleAgg { fn input(&self) -> PlanRef { - self.logical.input() + self.logical.input.clone() } fn clone_with_input(&self, input: PlanRef) -> Self { - Self::new(self.logical.clone_with_input(input)) + Self::new(generic::Agg { + input, + ..self.logical.clone() + }) } } impl_plan_tree_node_for_unary! { BatchSimpleAgg } @@ -75,8 +81,7 @@ impl ToDistributedBatch for BatchSimpleAgg { let dist_input = self.input().to_distributed()?; // TODO: distinct agg cannot use 2-phase agg yet. - if dist_input.distribution().satisfies(&RequiredDist::AnyShard) - && self.logical.can_two_phase_agg() + if dist_input.distribution().satisfies(&RequiredDist::AnyShard) && self.can_two_phase_agg() { // partial agg let partial_agg = self.clone_with_input(dist_input).into(); @@ -88,7 +93,7 @@ impl ToDistributedBatch for BatchSimpleAgg { // insert total agg let total_agg_types = self .logical - .agg_calls() + .agg_calls .iter() .enumerate() .map(|(partial_output_idx, agg_call)| { @@ -96,7 +101,7 @@ impl ToDistributedBatch for BatchSimpleAgg { }) .collect(); let total_agg_logical = - LogicalAgg::new(total_agg_types, self.logical.group_key().to_vec(), exchange); + generic::Agg::new(total_agg_types, self.logical.group_key.to_vec(), exchange); Ok(BatchSimpleAgg::new(total_agg_logical).into()) } else { let new_input = self @@ -138,13 +143,8 @@ impl ExprRewritable for BatchSimpleAgg { } fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef { - Self::new( - self.logical - .rewrite_exprs(r) - .as_logical_agg() - .unwrap() - .clone(), - ) - .into() + let mut logical = self.logical.clone(); + logical.rewrite_exprs(r); + Self::new(logical).into() } } diff --git a/src/frontend/src/optimizer/plan_node/generic/agg.rs b/src/frontend/src/optimizer/plan_node/generic/agg.rs index ba23311a7862..599fe2310c79 100644 --- a/src/frontend/src/optimizer/plan_node/generic/agg.rs +++ b/src/frontend/src/optimizer/plan_node/generic/agg.rs @@ -72,6 +72,35 @@ impl Agg { pub fn i2o_col_mapping(&self) -> ColIndexMapping { self.o2i_col_mapping().inverse() } + + pub(crate) fn can_two_phase_agg(&self) -> bool { + self.call_support_two_phase() && !self.is_agg_result_affected_by_order() + } + + fn call_support_two_phase(&self) -> bool { + !self.agg_calls.is_empty() + && self.agg_calls.iter().all(|call| { + matches!( + call.agg_kind, + AggKind::Min | AggKind::Max | AggKind::Sum | AggKind::Count + ) && !call.distinct + }) + } + + /// Check if the aggregation result will be affected by order by clause, if any. + pub(crate) fn is_agg_result_affected_by_order(&self) -> bool { + self.agg_calls + .iter() + .any(|call| matches!(call.agg_kind, AggKind::StringAgg | AggKind::ArrayAgg)) + } + + pub fn new(agg_calls: Vec, group_key: Vec, input: PlanRef) -> Self { + Self { + agg_calls, + group_key, + input, + } + } } impl GenericPlanNode for Agg { @@ -176,6 +205,22 @@ pub struct MaterializedInputState { } impl Agg { + pub fn infer_tables( + &self, + me: &impl stream::StreamPlanRef, + vnode_col_idx: Option, + ) -> ( + TableCatalog, + Vec, + HashMap, + ) { + ( + self.infer_result_table(me, vnode_col_idx), + self.infer_stream_agg_state(me, vnode_col_idx), + self.infer_distinct_dedup_tables(me, vnode_col_idx), + ) + } + /// Infer `AggCallState`s for streaming agg. pub fn infer_stream_agg_state( &self, diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs index 794f11564d26..e81b0e6f400f 100644 --- a/src/frontend/src/optimizer/plan_node/logical_agg.rs +++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs @@ -79,7 +79,7 @@ impl LogicalAgg { let local_agg = StreamLocalSimpleAgg::new(self.clone_with_input(stream_input)); let exchange = RequiredDist::single().enforce_if_not_satisfies(local_agg.into(), &Order::any())?; - let global_agg = new_stream_global_simple_agg(LogicalAgg::new( + let global_agg = new_stream_global_simple_agg(generic::Agg::new( self.agg_calls() .iter() .enumerate() @@ -129,7 +129,7 @@ impl LogicalAgg { local_group_key.push(vnode_col_idx); let n_local_group_key = local_group_key.len(); let local_agg = new_stream_hash_agg( - LogicalAgg::new(self.agg_calls().to_vec(), local_group_key, project.into()), + generic::Agg::new(self.agg_calls().to_vec(), local_group_key, project.into()), Some(vnode_col_idx), ); // Global group key excludes vnode. @@ -144,7 +144,7 @@ impl LogicalAgg { if self.group_key().is_empty() { let exchange = RequiredDist::single().enforce_if_not_satisfies(local_agg.into(), &Order::any())?; - let global_agg = new_stream_global_simple_agg(LogicalAgg::new( + let global_agg = new_stream_global_simple_agg(generic::Agg::new( self.agg_calls() .iter() .enumerate() @@ -162,7 +162,7 @@ impl LogicalAgg { // Local phase should have reordered the group keys into their required order. // we can just follow it. let global_agg = new_stream_hash_agg( - LogicalAgg::new( + generic::Agg::new( self.agg_calls() .iter() .enumerate() @@ -181,21 +181,18 @@ impl LogicalAgg { } fn gen_single_plan(&self, stream_input: PlanRef) -> Result { - Ok(new_stream_global_simple_agg(self.clone_with_input( - RequiredDist::single().enforce_if_not_satisfies(stream_input, &Order::any())?, - )) - .into()) + let mut logical = self.core.clone(); + let input = RequiredDist::single().enforce_if_not_satisfies(stream_input, &Order::any())?; + logical.input = input; + Ok(new_stream_global_simple_agg(logical).into()) } fn gen_shuffle_plan(&self, stream_input: PlanRef) -> Result { - Ok(new_stream_hash_agg( - self.clone_with_input( - RequiredDist::shard_by_key(stream_input.schema().len(), self.group_key()) - .enforce_if_not_satisfies(stream_input, &Order::any())?, - ), - None, - ) - .into()) + let input = RequiredDist::shard_by_key(stream_input.schema().len(), self.group_key()) + .enforce_if_not_satisfies(stream_input, &Order::any())?; + let mut logical = self.core.clone(); + logical.input = input; + Ok(new_stream_hash_agg(logical, None).into()) } /// See if all stream aggregation calls have a stateless local agg counterpart. @@ -284,13 +281,6 @@ impl LogicalAgg { } } - /// Check if the aggregation result will be affected by order by clause, if any. - pub(crate) fn is_agg_result_affected_by_order(&self) -> bool { - self.agg_calls() - .iter() - .any(|call| matches!(call.agg_kind, AggKind::StringAgg | AggKind::ArrayAgg)) - } - pub(crate) fn two_phase_agg_forced(&self) -> bool { self.base .ctx() @@ -313,15 +303,7 @@ impl LogicalAgg { } pub(crate) fn can_two_phase_agg(&self) -> bool { - !self.agg_calls().is_empty() - && self.agg_calls().iter().all(|call| { - matches!( - call.agg_kind, - AggKind::Min | AggKind::Max | AggKind::Sum | AggKind::Count - ) && !call.distinct - }) - && !self.is_agg_result_affected_by_order() - && self.two_phase_agg_enabled() + self.core.can_two_phase_agg() && self.two_phase_agg_enabled() } // Check if the output of the aggregation needs to be sorted and return ordering req by group @@ -870,21 +852,16 @@ impl ExprRewriter for LogicalAggBuilder { } } -impl LogicalAgg { - pub fn new(agg_calls: Vec, group_key: Vec, input: PlanRef) -> Self { - let core = generic::Agg { - agg_calls, - group_key, - input, - }; +impl From> for LogicalAgg { + fn from(core: generic::Agg) -> Self { let base = PlanBase::new_logical_with_core(&core); Self { base, core } } +} - /// get the Mapping of columnIndex from input column index to output column index,if a input - /// column corresponds more than one out columns, mapping to any one - pub fn o2i_col_mapping(&self) -> ColIndexMapping { - self.core.o2i_col_mapping() +impl LogicalAgg { + pub fn new(agg_calls: Vec, group_key: Vec, input: PlanRef) -> Self { + Self::from(generic::Agg::new(agg_calls, group_key, input)) } /// get the Mapping of columnIndex from input column index to out column index @@ -978,8 +955,11 @@ impl LogicalAgg { } fn to_batch_simple_agg(&self) -> Result { - let new_input = self.input().to_batch()?; - let new_logical = self.clone_with_input(new_input); + let input = self.input().to_batch()?; + let new_logical = generic::Agg { + input, + ..self.core.clone() + }; Ok(BatchSimpleAgg::new(new_logical).into()) } } @@ -1159,6 +1139,7 @@ impl ToBatch for LogicalAgg { if output_requires_order { // Push down sort before aggregation input_order = self + .core .o2i_col_mapping() .rewrite_provided_order(&group_key_order); } @@ -1180,32 +1161,33 @@ impl ToBatch for LogicalAgg { } } -fn find_or_append_row_count(mut logical: LogicalAgg) -> (LogicalAgg, usize) { +fn find_or_append_row_count(mut logical: generic::Agg) -> (generic::Agg, usize) { // `HashAgg`/`GlobalSimpleAgg` executors require a `count(*)` to correctly build changes, so // append a `count(*)` if not exists. let count_star = PlanAggCall::count_star(); let row_count_idx = if let Some((idx, _)) = logical - .agg_calls() + .agg_calls .iter() .find_position(|&c| c == &count_star) { idx } else { - let (mut agg_calls, group_key, input) = logical.decompose(); - let idx = agg_calls.len(); - agg_calls.push(count_star); - logical = LogicalAgg::new(agg_calls, group_key, input); + let idx = logical.agg_calls.len(); + logical.agg_calls.push(count_star); idx }; (logical, row_count_idx) } -fn new_stream_global_simple_agg(logical: LogicalAgg) -> StreamGlobalSimpleAgg { +fn new_stream_global_simple_agg(logical: generic::Agg) -> StreamGlobalSimpleAgg { let (logical, row_count_idx) = find_or_append_row_count(logical); StreamGlobalSimpleAgg::new(logical, row_count_idx) } -fn new_stream_hash_agg(logical: LogicalAgg, vnode_col_idx: Option) -> StreamHashAgg { +fn new_stream_hash_agg( + logical: generic::Agg, + vnode_col_idx: Option, +) -> StreamHashAgg { let (logical, row_count_idx) = find_or_append_row_count(logical); StreamHashAgg::new(logical, vnode_col_idx, row_count_idx) } diff --git a/src/frontend/src/optimizer/plan_node/stream_global_simple_agg.rs b/src/frontend/src/optimizer/plan_node/stream_global_simple_agg.rs index 9d4a2bddbd9d..e740899203ef 100644 --- a/src/frontend/src/optimizer/plan_node/stream_global_simple_agg.rs +++ b/src/frontend/src/optimizer/plan_node/stream_global_simple_agg.rs @@ -18,8 +18,8 @@ use fixedbitset::FixedBitSet; use itertools::Itertools; use risingwave_pb::stream_plan::stream_node::PbNodeBody; -use super::generic::PlanAggCall; -use super::{ExprRewritable, LogicalAgg, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode}; +use super::generic::{self, PlanAggCall}; +use super::{ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode}; use crate::expr::ExprRewriter; use crate::optimizer::property::Distribution; use crate::stream_fragmenter::BuildFragmentGraphState; @@ -27,23 +27,21 @@ use crate::stream_fragmenter::BuildFragmentGraphState; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct StreamGlobalSimpleAgg { pub base: PlanBase, - logical: LogicalAgg, + logical: generic::Agg, /// The index of `count(*)` in `agg_calls`. row_count_idx: usize, } impl StreamGlobalSimpleAgg { - pub fn new(logical: LogicalAgg, row_count_idx: usize) -> Self { - assert_eq!( - logical.agg_calls()[row_count_idx], - PlanAggCall::count_star() - ); - - let ctx = logical.base.ctx.clone(); - let pk_indices = logical.base.logical_pk.to_vec(); - let schema = logical.schema().clone(); - let input = logical.input(); + pub fn new(logical: generic::Agg, row_count_idx: usize) -> Self { + assert_eq!(logical.agg_calls[row_count_idx], PlanAggCall::count_star()); + + let base = PlanBase::new_logical_with_core(&logical); + let ctx = base.ctx; + let pk_indices = base.logical_pk; + let schema = base.schema; + let input = logical.input.clone(); let input_dist = input.distribution(); let dist = match input_dist { Distribution::Single => Distribution::Single, @@ -59,7 +57,7 @@ impl StreamGlobalSimpleAgg { ctx, schema, pk_indices, - logical.functional_dependency().clone(), + base.functional_dependency, dist, false, watermark_columns, @@ -72,7 +70,7 @@ impl StreamGlobalSimpleAgg { } pub fn agg_calls(&self) -> &[PlanAggCall] { - self.logical.agg_calls() + &self.logical.agg_calls } } @@ -91,11 +89,15 @@ impl fmt::Display for StreamGlobalSimpleAgg { impl PlanTreeNodeUnary for StreamGlobalSimpleAgg { fn input(&self) -> PlanRef { - self.logical.input() + self.logical.input.clone() } fn clone_with_input(&self, input: PlanRef) -> Self { - Self::new(self.logical.clone_with_input(input), self.row_count_idx) + let logical = generic::Agg { + input, + ..self.logical.clone() + }; + Self::new(logical, self.row_count_idx) } } impl_plan_tree_node_for_unary! { StreamGlobalSimpleAgg } @@ -103,9 +105,8 @@ impl_plan_tree_node_for_unary! { StreamGlobalSimpleAgg } impl StreamNode for StreamGlobalSimpleAgg { fn to_stream_prost_body(&self, state: &mut BuildFragmentGraphState) -> PbNodeBody { use risingwave_pb::stream_plan::*; - let result_table = self.logical.infer_result_table(None); - let agg_states = self.logical.infer_stream_agg_state(None); - let distinct_dedup_tables = self.logical.infer_distinct_dedup_tables(None); + let (result_table, agg_states, distinct_dedup_tables) = + self.logical.infer_tables(&self.base, None); PbNodeBody::GlobalSimpleAgg(SimpleAggNode { agg_calls: self @@ -153,14 +154,8 @@ impl ExprRewritable for StreamGlobalSimpleAgg { } fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef { - Self::new( - self.logical - .rewrite_exprs(r) - .as_logical_agg() - .unwrap() - .clone(), - self.row_count_idx, - ) - .into() + let mut logical = self.logical.clone(); + logical.rewrite_exprs(r); + Self::new(logical, self.row_count_idx).into() } } diff --git a/src/frontend/src/optimizer/plan_node/stream_hash_agg.rs b/src/frontend/src/optimizer/plan_node/stream_hash_agg.rs index e73644886314..e35ddf09732f 100644 --- a/src/frontend/src/optimizer/plan_node/stream_hash_agg.rs +++ b/src/frontend/src/optimizer/plan_node/stream_hash_agg.rs @@ -19,8 +19,8 @@ use itertools::Itertools; use risingwave_common::catalog::FieldDisplay; use risingwave_pb::stream_plan::stream_node::PbNodeBody; -use super::generic::PlanAggCall; -use super::{ExprRewritable, LogicalAgg, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode}; +use super::generic::{self, PlanAggCall}; +use super::{ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode}; use crate::expr::ExprRewriter; use crate::optimizer::property::Distribution; use crate::stream_fragmenter::BuildFragmentGraphState; @@ -29,7 +29,7 @@ use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct StreamHashAgg { pub base: PlanBase, - logical: LogicalAgg, + logical: generic::Agg, /// An optional column index which is the vnode of each row computed by the input's consistent /// hash distribution. @@ -40,16 +40,18 @@ pub struct StreamHashAgg { } impl StreamHashAgg { - pub fn new(logical: LogicalAgg, vnode_col_idx: Option, row_count_idx: usize) -> Self { - assert_eq!( - logical.agg_calls()[row_count_idx], - PlanAggCall::count_star() - ); - - let ctx = logical.base.ctx.clone(); - let pk_indices = logical.base.logical_pk.to_vec(); - let schema = logical.schema().clone(); - let input = logical.input(); + pub fn new( + logical: generic::Agg, + vnode_col_idx: Option, + row_count_idx: usize, + ) -> Self { + assert_eq!(logical.agg_calls[row_count_idx], PlanAggCall::count_star()); + + let base = PlanBase::new_logical_with_core(&logical); + let ctx = base.ctx; + let pk_indices = base.logical_pk; + let schema = base.schema; + let input = logical.input.clone(); let input_dist = input.distribution(); let dist = match input_dist { Distribution::HashShard(_) | Distribution::UpstreamHashShard(_, _) => logical @@ -60,7 +62,7 @@ impl StreamHashAgg { let mut watermark_columns = FixedBitSet::with_capacity(schema.len()); // Watermark column(s) must be in group key. - for (idx, input_idx) in logical.group_key().iter().enumerate() { + for (idx, input_idx) in logical.group_key.iter().enumerate() { if input.watermark_columns().contains(*input_idx) { watermark_columns.insert(idx); } @@ -71,7 +73,7 @@ impl StreamHashAgg { ctx, schema, pk_indices, - logical.functional_dependency().clone(), + base.functional_dependency, dist, false, watermark_columns, @@ -85,11 +87,11 @@ impl StreamHashAgg { } pub fn agg_calls(&self) -> &[PlanAggCall] { - self.logical.agg_calls() + &self.logical.agg_calls } pub fn group_key(&self) -> &[usize] { - self.logical.group_key() + &self.logical.group_key } pub(crate) fn i2o_col_mapping(&self) -> ColIndexMapping { @@ -124,15 +126,15 @@ impl fmt::Display for StreamHashAgg { impl PlanTreeNodeUnary for StreamHashAgg { fn input(&self) -> PlanRef { - self.logical.input() + self.logical.input.clone() } fn clone_with_input(&self, input: PlanRef) -> Self { - Self::new( - self.logical.clone_with_input(input), - self.vnode_col_idx, - self.row_count_idx, - ) + let logical = generic::Agg { + input, + ..self.logical.clone() + }; + Self::new(logical, self.vnode_col_idx, self.row_count_idx) } } impl_plan_tree_node_for_unary! { StreamHashAgg } @@ -140,9 +142,8 @@ impl_plan_tree_node_for_unary! { StreamHashAgg } impl StreamNode for StreamHashAgg { fn to_stream_prost_body(&self, state: &mut BuildFragmentGraphState) -> PbNodeBody { use risingwave_pb::stream_plan::*; - let result_table = self.logical.infer_result_table(self.vnode_col_idx); - let agg_states = self.logical.infer_stream_agg_state(self.vnode_col_idx); - let distinct_dedup_tables = self.logical.infer_distinct_dedup_tables(self.vnode_col_idx); + let (result_table, agg_states, distinct_dedup_tables) = + self.logical.infer_tables(&self.base, self.vnode_col_idx); PbNodeBody::HashAgg(HashAggNode { group_key: self.group_key().iter().map(|idx| *idx as u32).collect(), @@ -185,15 +186,8 @@ impl ExprRewritable for StreamHashAgg { } fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef { - Self::new( - self.logical - .rewrite_exprs(r) - .as_logical_agg() - .unwrap() - .clone(), - self.vnode_col_idx, - self.row_count_idx, - ) - .into() + let mut logical = self.logical.clone(); + logical.rewrite_exprs(r); + Self::new(logical, self.vnode_col_idx, self.row_count_idx).into() } }