From a0c460bb6b6d0cb3fde6d7aa471053ad480ce565 Mon Sep 17 00:00:00 2001 From: Moritz Hoffmann Date: Thu, 4 Jun 2026 17:51:18 +0200 Subject: [PATCH] Revert "refactor: Change all `Visitor`s to be iterative, child-based (#36852)" This reverts commit dfce26ef97a338c55dcfb77bb4ceefca3ebc5fad. --- src/adapter/src/coord/sequencer/inner.rs | 2 +- src/adapter/src/coord/sequencer/inner/peek.rs | 2 +- src/adapter/src/frontend_peek.rs | 4 +- src/adapter/src/optimize/dataflows.rs | 2 +- src/expr/src/linear.rs | 27 +- src/expr/src/relation.rs | 24 +- src/expr/src/relation/canonicalize.rs | 6 +- src/expr/src/relation/join_input_mapper.rs | 3 +- src/expr/src/scalar.rs | 85 +- src/expr/src/scalar/optimizable.rs | 19 +- src/expr/src/scalar/reduce.rs | 3 +- src/expr/src/visit.rs | 1044 ++++++++++------- src/sql/src/plan.rs | 4 +- src/sql/src/plan/explain.rs | 28 +- src/sql/src/plan/hir.rs | 681 +---------- src/sql/src/plan/lowering.rs | 2 +- src/sql/src/plan/query.rs | 2 +- src/sql/src/plan/statement/dml.rs | 2 +- .../canonicalization/flat_map_elimination.rs | 2 +- .../canonicalization/projection_extraction.rs | 2 +- .../src/canonicalization/topk_elision.rs | 2 +- src/transform/src/coalesce_case.rs | 23 +- src/transform/src/column_knowledge.rs | 3 +- src/transform/src/compound/union.rs | 2 +- src/transform/src/fusion.rs | 2 +- src/transform/src/fusion/union.rs | 2 +- src/transform/src/join_implementation.rs | 14 +- src/transform/src/literal_constraints.rs | 4 +- src/transform/src/literal_lifting.rs | 16 +- src/transform/src/predicate_pushdown.rs | 18 +- src/transform/src/reduction_pushdown.rs | 13 +- src/transform/src/redundant_join.rs | 2 +- 32 files changed, 781 insertions(+), 1264 deletions(-) diff --git a/src/adapter/src/coord/sequencer/inner.rs b/src/adapter/src/coord/sequencer/inner.rs index 277cd1fc8384c..fb9d5438446b2 100644 --- a/src/adapter/src/coord/sequencer/inner.rs +++ b/src/adapter/src/coord/sequencer/inner.rs @@ -2696,7 +2696,7 @@ impl Coordinator { }; // Disallow mz_now in any position because read time and write time differ. - let contains_temporal = selection.contains_temporal() + let contains_temporal = return_if_err!(selection.contains_temporal(), ctx) || assignments.values().any(|e| e.contains_temporal()) || returning.iter().any(|e| e.contains_temporal()); if contains_temporal { diff --git a/src/adapter/src/coord/sequencer/inner/peek.rs b/src/adapter/src/coord/sequencer/inner/peek.rs index 429602b12a933..8c6a7174c3e15 100644 --- a/src/adapter/src/coord/sequencer/inner/peek.rs +++ b/src/adapter/src/coord/sequencer/inner/peek.rs @@ -331,7 +331,7 @@ impl Coordinator { .catalog() .validate_timeline_context(source_ids.iter().copied())?; if matches!(timeline_context, TimelineContext::TimestampIndependent) - && plan.source.contains_temporal() + && plan.source.contains_temporal()? { // If the source IDs are timestamp independent but the query contains temporal functions, // then the timeline context needs to be upgraded to timestamp dependent. This is diff --git a/src/adapter/src/frontend_peek.rs b/src/adapter/src/frontend_peek.rs index 2ee151ba1c756..522a4f12b67b7 100644 --- a/src/adapter/src/frontend_peek.rs +++ b/src/adapter/src/frontend_peek.rs @@ -456,7 +456,7 @@ impl PeekClient { let contains_temporal = match query_plan { QueryPlan::Select(s) => s.source.contains_temporal(), QueryPlan::CopyTo(s, _) => s.source.contains_temporal(), - QueryPlan::Subscribe(s) => s.from.contains_temporal(), + QueryPlan::Subscribe(s) => Ok(s.from.contains_temporal()), }; // # From sequence_plan @@ -548,7 +548,7 @@ impl PeekClient { // simple benchmarks), because it traverses transitive dependencies even of indexed views and // materialized views (also traversing their MIR plans). let mut timeline_context = catalog.validate_timeline_context(source_ids.iter().copied())?; - if matches!(timeline_context, TimelineContext::TimestampIndependent) && contains_temporal { + if matches!(timeline_context, TimelineContext::TimestampIndependent) && contains_temporal? { // If the source IDs are timestamp independent but the query contains temporal functions, // then the timeline context needs to be upgraded to timestamp dependent. This is // required because `source_ids` doesn't contain functions. diff --git a/src/adapter/src/optimize/dataflows.rs b/src/adapter/src/optimize/dataflows.rs index abebeff1ab1c8..fe4517d2f13c5 100644 --- a/src/adapter/src/optimize/dataflows.rs +++ b/src/adapter/src/optimize/dataflows.rs @@ -186,7 +186,7 @@ impl ExprPrep for ExprPrepMaintained { if let MirScalarExpr::CallUnmaterializable(f) = e { last_observed_unmaterializable_func = Some(f.clone()); } - }); + })?; if let Some(f) = last_observed_unmaterializable_func { Err(OptimizerError::UnmaterializableFunction(f)) diff --git a/src/expr/src/linear.rs b/src/expr/src/linear.rs index 0d5a8c23e7e34..67ffdd92a3e04 100644 --- a/src/expr/src/linear.rs +++ b/src/expr/src/linear.rs @@ -1263,14 +1263,16 @@ impl MapFilterProject { if let Some(i) = e.as_column() { reference_count[i] += 1; } - }); + }) + .expect("visit_pre hit recursion limit"); } for (_, pred) in self.predicates.iter() { pred.visit_pre(&mut |e| { if let Some(i) = e.as_column() { reference_count[i] += 1; } - }); + }) + .expect("visit_pre hit recursion limit"); } for proj in self.projection.iter() { reference_count[*proj] += 1; @@ -1321,13 +1323,15 @@ impl MapFilterProject { pub fn perform_inlining(&mut self, should_inline: Vec) { for index in 0..self.expressions.len() { let (prior, expr) = self.expressions.split_at_mut(index); - expr[0].visit_mut_post(&mut |e| { - if let Some(i) = e.as_column() { - if should_inline[i] { - *e = prior[i - self.input_arity].clone(); + expr[0] + .visit_mut_post(&mut |e| { + if let Some(i) = e.as_column() { + if should_inline[i] { + *e = prior[i - self.input_arity].clone(); + } } - } - }); + }) + .expect("inlining hit recursion limit"); } for (_index, pred) in self.predicates.iter_mut() { let expressions = &self.expressions; @@ -1337,7 +1341,8 @@ impl MapFilterProject { *e = expressions[i - self.input_arity].clone(); } } - }); + }) + .expect("inlining hit recursion limit"); } } @@ -1445,6 +1450,7 @@ pub fn memoize_expr( memoized_parts: &mut Vec, input_arity: usize, ) { + #[allow(deprecated)] expr.visit_mut_pre_post(&mut |e| e.eager_children(), &mut |e| { if E::is_literal(e) { // Literals do not need to be memoized. @@ -1475,7 +1481,8 @@ pub fn memoize_expr( E::column(input_arity + memoized_parts.len()), )); } - }); + }) + .expect("memoize_expr hit recursion limit"); } pub mod util { diff --git a/src/expr/src/relation.rs b/src/expr/src/relation.rs index 25a9a8b985d39..a516152e64df5 100644 --- a/src/expr/src/relation.rs +++ b/src/expr/src/relation.rs @@ -347,7 +347,8 @@ impl MirRelationExpr { /// visited in `type_stack`. pub fn typ(&self) -> ReprRelationType { let mut type_stack = Vec::new(); - self.visit_pre_post( + #[allow(deprecated)] + self.visit_pre_post_nolimit( &mut |e: &MirRelationExpr| -> Option> { match &e { MirRelationExpr::Let { body, .. } => Some(vec![&*body]), @@ -959,7 +960,8 @@ impl MirRelationExpr { /// visited in `arity_stack`. pub fn arity(&self) -> usize { let mut arity_stack = Vec::new(); - self.visit_pre_post( + #[allow(deprecated)] + self.visit_pre_post_nolimit( &mut |e: &MirRelationExpr| -> Option> { match &e { MirRelationExpr::Let { body, .. } => { @@ -1791,6 +1793,7 @@ impl MirRelationExpr { pub fn try_visit_scalars_mut(&mut self, f: &mut F) -> Result<(), E> where F: FnMut(&mut MirScalarExpr) -> Result<(), E>, + E: From, { self.try_visit_mut_post(&mut |expr| expr.try_visit_scalars_mut1(f)) } @@ -1919,6 +1922,7 @@ impl MirRelationExpr { pub fn try_visit_scalars(&self, f: &mut F) -> Result<(), E> where F: FnMut(&MirScalarExpr) -> Result<(), E>, + E: From, { self.try_visit_post(&mut |expr| expr.try_visit_scalars_1(f)) } @@ -2382,6 +2386,7 @@ impl VisitChildren for MirRelationExpr { fn try_visit_children(&self, mut f: F) -> Result<(), E> where F: FnMut(&Self) -> Result<(), E>, + E: From, { for child in self.children() { f(child)? @@ -2392,26 +2397,13 @@ impl VisitChildren for MirRelationExpr { fn try_visit_mut_children(&mut self, mut f: F) -> Result<(), E> where F: FnMut(&mut Self) -> Result<(), E>, + E: From, { for child in self.children_mut() { f(child)? } Ok(()) } - - fn children<'a>(&'a self) -> impl DoubleEndedIterator - where - Self: 'a, - { - self.children() - } - - fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator - where - Self: 'a, - { - self.children_mut() - } } /// Specification for an ordering by a column. diff --git a/src/expr/src/relation/canonicalize.rs b/src/expr/src/relation/canonicalize.rs index a0df9c46ee8d5..73cb890605840 100644 --- a/src/expr/src/relation/canonicalize.rs +++ b/src/expr/src/relation/canonicalize.rs @@ -69,7 +69,8 @@ pub fn canonicalize_equivalences<'a, I>( // which will then replace `to_reduce[i]`. let mut new_equivalence = Vec::with_capacity(to_reduce[i].len()); while let Some((_, mut popped_expr)) = to_reduce[i].pop() { - popped_expr.visit_mut_post(&mut |e: &mut MirScalarExpr| { + #[allow(deprecated)] + popped_expr.visit_mut_post_nolimit(&mut |e: &mut MirScalarExpr| { // If a simpler expression can be found that is equivalent // to e, if let Some(simpler_e) = to_reduce.iter().find_map(|cls| { @@ -394,7 +395,8 @@ fn replace_subexpr_and_reduce( repr_column_types: &[ReprColumnType], ) -> bool { let mut changed = false; - predicate.visit_mut_pre_post( + #[allow(deprecated)] + predicate.visit_mut_pre_post_nolimit( &mut |e| { // The `cond` of an if statement is not visited to prevent `then` // or `els` from being evaluated before `cond`, resulting in a diff --git a/src/expr/src/relation/join_input_mapper.rs b/src/expr/src/relation/join_input_mapper.rs index 338b819f56aaa..0207172beca51 100644 --- a/src/expr/src/relation/join_input_mapper.rs +++ b/src/expr/src/relation/join_input_mapper.rs @@ -328,7 +328,8 @@ impl JoinInputMapper { // `e` anyway, so we end up visiting nodes in `e` multiple times // here. Alternatively, consider having the future `PredicateKnowledge` // take over the responsibilities of this code? - expr.visit_mut_pre_post( + #[allow(deprecated)] + expr.visit_mut_pre_post_nolimit( &mut |e| { let mut inputs = self.lookup_inputs(e); if let Some(first_input) = inputs.next() { diff --git a/src/expr/src/scalar.rs b/src/expr/src/scalar.rs index 39d2de76d8ff9..690cb5332f4a8 100644 --- a/src/expr/src/scalar.rs +++ b/src/expr/src/scalar.rs @@ -1327,6 +1327,7 @@ impl VisitChildren for MirScalarExpr { fn try_visit_children(&self, mut f: F) -> Result<(), E> where F: FnMut(&Self) -> Result<(), E>, + E: From, { use MirScalarExpr::*; match self { @@ -1355,6 +1356,7 @@ impl VisitChildren for MirScalarExpr { fn try_visit_mut_children(&mut self, mut f: F) -> Result<(), E> where F: FnMut(&mut Self) -> Result<(), E>, + E: From, { use MirScalarExpr::*; match self { @@ -1379,20 +1381,6 @@ impl VisitChildren for MirScalarExpr { } Ok(()) } - - fn children<'a>(&'a self) -> impl DoubleEndedIterator - where - Self: 'a, - { - self.children() - } - - fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator - where - Self: 'a, - { - self.children_mut() - } } impl MirScalarExpr { @@ -1621,7 +1609,7 @@ impl FilterCharacteristics { } } }, - ); + )?; if literal_inequality_in_current_filter { literal_inequality += 1; } @@ -2552,71 +2540,4 @@ mod tests { ); } } - - /// Exercises the `unsafe` pointer stack in [`Visit::visit_mut_post`] with a - /// closure that *replaces subtrees* (`*expr = ...`). Miri's aliasing model - /// should shout if the "stack mirrors the call stack" becomes untrue. - #[mz_ore::test] - fn test_visit_mut_post_replace_subtrees() { - let col = MirScalarExpr::column; - let mut expr = col(0).if_then_else(col(1).if_then_else(col(2), col(3)), col(4)); - - expr.visit_mut_post(&mut |expr: &mut MirScalarExpr| match expr { - MirScalarExpr::Column(n, _) => *n += 1, - MirScalarExpr::If { then, .. } => { - let then = then.take(); - *expr = then; - } - _ => {} - }); - - // collapses to then-most branch - assert_eq!(expr, col(3)); - } - - /// Exercises the `unsafe` pointer stack in [`Visit::visit_mut_pre_post`] with - /// a `pre` that both *replaces the visited node wholesale* (`*expr = ...`) - /// and *returns an explicit child set* borrowed from the freshly written - /// value. Miri's aliasing model should shout if the "stack mirrors the call - /// stack" becomes untrue. - #[mz_ore::test] - fn test_visit_mut_pre_post_explicit_children() { - let col = MirScalarExpr::column; - let mut expr = col(5) - .if_then_else(col(6), col(7)) - .if_then_else(col(1).if_then_else(col(2), col(3)), col(4)); - - // turns conditions into column 0 in pre - // doesn't traverse conditions of ifs - // adds 10 to all column refs in post (but not in conditions!) - expr.visit_mut_pre_post( - &mut |expr: &mut MirScalarExpr| -> Option> { - if let MirScalarExpr::If { .. } = expr { - let MirScalarExpr::If { then, els, .. } = expr else { - unreachable!() - }; - let then = then.take(); - let els = els.take(); - *expr = MirScalarExpr::column(0).if_then_else(then, els); - - let MirScalarExpr::If { then, els, .. } = expr else { - unreachable!() - }; - Some(vec![then.as_mut(), els.as_mut()]) - } else { - // Leaves recurse with their default (empty) child set. - None - } - }, - &mut |expr: &mut MirScalarExpr| { - if let MirScalarExpr::Column(n, _) = expr { - *n += 10; - } - }, - ); - - // conditions become 0; everyone else += 10 - let expected = col(0).if_then_else(col(0).if_then_else(col(12), col(13)), col(14)); - assert_eq!(expr, expected); - } } diff --git a/src/expr/src/scalar/optimizable.rs b/src/expr/src/scalar/optimizable.rs index cf805b07518fd..cd0da0f1be69b 100644 --- a/src/expr/src/scalar/optimizable.rs +++ b/src/expr/src/scalar/optimizable.rs @@ -15,11 +15,12 @@ use std::fmt::Debug; use std::hash::Hash; +use mz_ore::stack::RecursionLimitError; use serde::Serialize; use crate::scalar::columns::Columns; use crate::scalar::func::{BinaryFunc, UnaryFunc, VariadicFunc}; -use crate::visit::VisitChildren; +use crate::visit::{Visit, VisitChildren}; use crate::{MirScalarExpr, func}; /// A scalar expression type that can be optimized inside a `MapFilterProject`. @@ -57,6 +58,14 @@ pub trait OptimizableExpr: /// /// Returns `(lower_bounds, upper_bounds)` for use in `MfpPlan`. fn extract_temporal_bounds(temporal: Vec) -> Result<(Vec, Vec), String>; + + /// Visit in a pre-traversal. Defaults to the `Visit` implementation, but overridable. + fn visit_pre(&self, f: &mut F) -> Result<(), RecursionLimitError> + where + F: FnMut(&Self), + { + Visit::visit_pre(self, f) + } } impl OptimizableExpr for MirScalarExpr { @@ -160,4 +169,12 @@ impl OptimizableExpr for MirScalarExpr { Ok((lower_bounds, upper_bounds)) } + + fn visit_pre(&self, f: &mut F) -> Result<(), RecursionLimitError> + where + F: FnMut(&Self), + { + MirScalarExpr::visit_pre(self, f); + Ok(()) + } } diff --git a/src/expr/src/scalar/reduce.rs b/src/expr/src/scalar/reduce.rs index f8bd6d52c2292..6118b48f06e9a 100644 --- a/src/expr/src/scalar/reduce.rs +++ b/src/expr/src/scalar/reduce.rs @@ -37,7 +37,8 @@ pub fn reduce(expr: &mut MirScalarExpr, column_types: &[ReprColumnType]) { let mut old = MirScalarExpr::column(0); while old != *expr { old = expr.clone(); - expr.visit_mut_pre_post( + #[allow(deprecated)] + expr.visit_mut_pre_post_nolimit( &mut |e| { reduce_pre(e, column_types); None diff --git a/src/expr/src/visit.rs b/src/expr/src/visit.rs index 4f70af29cd1d3..596695225dbb6 100644 --- a/src/expr/src/visit.rs +++ b/src/expr/src/visit.rs @@ -11,7 +11,7 @@ //! //! Recursive types can implement the [`VisitChildren`] trait, to //! specify how their recursive entries can be accessed. The extension -//! trait [`Visit`] then adds support for iteratively traversing +//! trait [`Visit`] then adds support for recursively traversing //! instances of those types. //! //! # Naming @@ -34,6 +34,12 @@ //! * no suffix: recursively visit children in pre- and post-order //! using a ~Visitor~` that encapsulates the shared context. +use std::marker::PhantomData; + +use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError, maybe_grow}; + +use crate::RECURSION_LIMIT; + /// A trait for types that can visit their direct children of type `T`. /// /// Implementing [`VisitChildren`] automatically also implements @@ -51,55 +57,34 @@ pub trait VisitChildren { /// Apply an infallible immutable function `f` to each direct child. fn visit_children(&self, f: F) where - F: FnMut(&T), - { - self.children().for_each(f); - } + F: FnMut(&T); /// Apply an infallible mutable function `f` to each direct child. fn visit_mut_children(&mut self, f: F) where - F: FnMut(&mut T), - { - self.children_mut().for_each(f); - } + F: FnMut(&mut T); /// Apply a fallible immutable function `f` to each direct child. - fn try_visit_children(&self, mut f: F) -> Result<(), E> + /// + /// For mutually recursive implementations (say consisting of two + /// types `A` and `B`), recursing through `B` in order to find all + /// `A`-children of a node of type `A` might cause lead to a + /// [`RecursionLimitError`], hence the bound on `E`. + fn try_visit_children(&self, f: F) -> Result<(), E> where F: FnMut(&T) -> Result<(), E>, - { - for child in self.children() { - f(child)?; - } - - Ok(()) - } + E: From; /// Apply a fallible mutable function `f` to each direct child. - fn try_visit_mut_children(&mut self, mut f: F) -> Result<(), E> - where - F: FnMut(&mut T) -> Result<(), E>, - { - for child in self.children_mut() { - f(child)?; - } - - Ok(()) - } - - /// The `T`-typed children of this element. - fn children<'a>(&'a self) -> impl DoubleEndedIterator - where - T: 'a; - - /// The `&mut T`-typed children of this element. /// - /// It is critical for the safety of mutable post-order traversals that this - /// function be written using safe code. - fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator + /// For mutually recursive implementations (say consisting of two + /// types `A` and `B`), recursing through `B` in order to find all + /// `A`-children of a node of type `A` might cause lead to a + /// [`RecursionLimitError`], hence the bound on `E`. + fn try_visit_mut_children(&mut self, f: F) -> Result<(), E> where - T: 'a; + F: FnMut(&mut T) -> Result<(), E>, + E: From; } /// A trait for types that can recursively visit their children of the @@ -108,34 +93,53 @@ pub trait VisitChildren { /// This trait is automatically implemented for all implementors of /// [`VisitChildren`]. /// -/// All methods provided by this trait are iterative. +/// All methods provided by this trait ensure that the stack is grown +/// as needed, to avoid stack overflows when traversing deeply +/// recursive objects. They also enforce a recursion limit of +/// [`RECURSION_LIMIT`] by returning an error when that limit +/// is exceeded. /// -/// NB that any visitor with mutable post-traversal uses unsafe code. It is critical -/// that `VisitChildren::children_mut` be written using safe code, i.e., no aliasing -/// of children or access to parents. +/// There are also `*_nolimit` methods that don't enforce a recursion +/// limit. Those methods are deprecated and should not be used. pub trait Visit { /// Post-order immutable infallible visitor for `self`. - fn visit_post(&self, f: &mut F) + fn visit_post(&self, f: &mut F) -> Result<(), RecursionLimitError> where F: FnMut(&Self); + /// Post-order immutable infallible visitor for `self`. + /// Does not enforce a recursion limit. + #[deprecated = "Use `visit_post` instead."] + fn visit_post_nolimit(&self, f: &mut F) + where + F: FnMut(&Self); + + /// Post-order mutable infallible visitor for `self`. + fn visit_mut_post(&mut self, f: &mut F) -> Result<(), RecursionLimitError> + where + F: FnMut(&mut Self); + /// Post-order mutable infallible visitor for `self`. - fn visit_mut_post(&mut self, f: &mut F) + /// Does not enforce a recursion limit. + #[deprecated = "Use `visit_mut_post` instead."] + fn visit_mut_post_nolimit(&mut self, f: &mut F) where F: FnMut(&mut Self); /// Post-order immutable fallible visitor for `self`. fn try_visit_post(&self, f: &mut F) -> Result<(), E> where - F: FnMut(&Self) -> Result<(), E>; + F: FnMut(&Self) -> Result<(), E>, + E: From; /// Post-order mutable fallible visitor for `self`. fn try_visit_mut_post(&mut self, f: &mut F) -> Result<(), E> where - F: FnMut(&mut Self) -> Result<(), E>; + F: FnMut(&mut Self) -> Result<(), E>, + E: From; /// Pre-order immutable infallible visitor for `self`. - fn visit_pre(&self, f: &mut F) + fn visit_pre(&self, f: &mut F) -> Result<(), RecursionLimitError> where F: FnMut(&Self); @@ -155,25 +159,42 @@ pub trait Visit { init: Context, acc_fun: &mut AccFun, visitor: &mut Visitor, - ) where + ) -> Result<(), RecursionLimitError> + where Context: Clone, AccFun: FnMut(Context, &Self) -> Context, Visitor: FnMut(&Context, &Self); + /// Pre-order immutable infallible visitor for `self`. + /// Does not enforce a recursion limit. + #[deprecated = "Use `visit_pre` instead."] + fn visit_pre_nolimit(&self, f: &mut F) + where + F: FnMut(&Self); + + /// Pre-order mutable infallible visitor for `self`. + fn visit_mut_pre(&mut self, f: &mut F) -> Result<(), RecursionLimitError> + where + F: FnMut(&mut Self); + /// Pre-order mutable infallible visitor for `self`. - fn visit_mut_pre(&mut self, f: &mut F) + /// Does not enforce a recursion limit. + #[deprecated = "Use `visit_mut_pre` instead."] + fn visit_mut_pre_nolimit(&mut self, f: &mut F) where F: FnMut(&mut Self); /// Pre-order immutable fallible visitor for `self`. fn try_visit_pre(&self, f: &mut F) -> Result<(), E> where - F: FnMut(&Self) -> Result<(), E>; + F: FnMut(&Self) -> Result<(), E>, + E: From; /// Pre-order mutable fallible visitor for `self`. fn try_visit_mut_pre(&mut self, f: &mut F) -> Result<(), E> where - F: FnMut(&mut Self) -> Result<(), E>; + F: FnMut(&mut Self) -> Result<(), E>, + E: From; /// A generalization of [`Visit::visit_pre`] and [`Visit::visit_post`]. /// @@ -182,171 +203,153 @@ pub trait Visit { /// /// Optionally, `pre` can return which children, if any, should be visited /// (default is to visit all children). - fn visit_pre_post(&self, pre: &mut F1, post: &mut F2) + fn visit_pre_post( + &self, + pre: &mut F1, + post: &mut F2, + ) -> Result<(), RecursionLimitError> where F1: FnMut(&Self) -> Option>, F2: FnMut(&Self); - /// A generalization of [`Visit::visit_mut_pre`] and [`Visit::visit_mut_post`]. + /// A generalization of [`Visit::visit_pre`] and [`Visit::visit_post`]. + /// Does not enforce a recursion limit. /// /// The function `pre` runs on `self` before it runs on any of the children. /// The function `post` runs on children first before the parent. /// /// Optionally, `pre` can return which children, if any, should be visited /// (default is to visit all children). + #[deprecated = "Use `visit` instead."] + fn visit_pre_post_nolimit(&self, pre: &mut F1, post: &mut F2) + where + F1: FnMut(&Self) -> Option>, + F2: FnMut(&Self); + + /// A generalization of [`Visit::visit_mut_pre`] and [`Visit::visit_mut_post`]. + /// + /// The function `pre` runs on `self` before it runs on any of the children. + /// The function `post` runs on children first before the parent. /// - /// It is improtant for safety that `pre` is (a) safe code and (b) returns children only. - fn visit_mut_pre_post(&mut self, pre: &mut F1, post: &mut F2) + /// Optionally, `pre` can return which children, if any, should be visited + /// (default is to visit all children). + #[deprecated = "Use `visit_mut` instead."] + fn visit_mut_pre_post( + &mut self, + pre: &mut F1, + post: &mut F2, + ) -> Result<(), RecursionLimitError> where F1: FnMut(&mut Self) -> Option>, F2: FnMut(&mut Self); -} -/// Frames for immutable post-traversals, will be kept in a stack. -enum VisitAction<'a, T> { - /// Put on the stack when entering a node. + /// A generalization of [`Visit::visit_mut_pre`] and [`Visit::visit_mut_post`]. + /// Does not enforce a recursion limit. /// - /// Causes us to push children. - Enter(&'a T), - /// Put on the stack when leaving a node, all children visited. + /// The function `pre` runs on `self` before it runs on any of the children. + /// The function `post` runs on children first before the parent. /// - /// Causes us to do the post-traversal visit of the parent. - Leave(&'a T), + /// Optionally, `pre` can return which children, if any, should be visited + /// (default is to visit all children). + #[deprecated = "Use `visit_mut_pre_post` instead."] + fn visit_mut_pre_post_nolimit(&mut self, pre: &mut F1, post: &mut F2) + where + F1: FnMut(&mut Self) -> Option>, + F2: FnMut(&mut Self); + + fn visit(&self, visitor: &mut V) -> Result<(), RecursionLimitError> + where + Self: Sized, + V: Visitor; + + fn visit_mut(&mut self, visitor: &mut V) -> Result<(), RecursionLimitError> + where + Self: Sized, + V: VisitorMut; + + fn try_visit(&self, visitor: &mut V) -> Result<(), E> + where + Self: Sized, + V: TryVisitor, + E: From; + + fn try_visit_mut(&mut self, visitor: &mut V) -> Result<(), E> + where + Self: Sized, + V: TryVisitorMut, + E: From; } -/// Frames for mutable post-traversals, will be kept in a stack. -/// -/// Notice that we use mutable pointers, because mutable post-traversal is unsafe in rust. -/// -/// The core argument for correctness mirrors the tree-borrow correctness argument Rust's -/// borrow checker uses for the ordinary function call stack. Loosely: -/// -/// - We split a mutable parent node into its children, and put them on the stack. -/// - We keep a mutable pointer to the parent node, but won't touch it until we're done with all children. -/// + In the function call stack, this mutable pointer is the stack frame, safely inaccessible. -/// + In our `unsafe` action stack, this mutable pointer is below all of the `Enter` actions, -/// and we promise not to touch it until they complete. -/// - When we have processed all children, we can reassemble access to the parent from its parts. -/// + In the function call stack, we do this on return from recursive calls. -/// - In our `unsafe` action stack, we do this after popping all children. -enum VisitMutAction { - /// Put on the stack when entering a node. - /// - /// Causes us to push children. - Enter(*mut T), - /// Put on the stack when leaving a node, all children visited. - /// - /// Causes us to do the post-traversal visit of the parent. - Leave(*mut T), +pub trait Visitor { + fn pre_visit(&mut self, expr: &T); + fn post_visit(&mut self, expr: &T); +} + +pub trait VisitorMut { + fn pre_visit(&mut self, expr: &mut T); + fn post_visit(&mut self, expr: &mut T); +} + +pub trait TryVisitor> { + fn pre_visit(&mut self, expr: &T) -> Result<(), E>; + fn post_visit(&mut self, expr: &T) -> Result<(), E>; +} + +pub trait TryVisitorMut> { + fn pre_visit(&mut self, expr: &mut T) -> Result<(), E>; + fn post_visit(&mut self, expr: &mut T) -> Result<(), E>; } impl> Visit for T { - fn visit_post(&self, f: &mut F) + fn visit_post(&self, f: &mut F) -> Result<(), RecursionLimitError> where F: FnMut(&Self), { - use VisitAction::*; - let mut stack = vec![Enter(self)]; - while let Some(action) = stack.pop() { - match action { - Enter(elt) => { - stack.push(Leave(elt)); - // Push children in reverse so they pop (and are visited) left-to-right. - stack.extend(elt.children().rev().map(Enter)); - } - Leave(elt) => f(elt), - } - } + StackSafeVisit::new().visit_post(self, f) } - #[allow(clippy::as_conversions)] - fn visit_mut_post(&mut self, f: &mut F) + fn visit_post_nolimit(&self, f: &mut F) + where + F: FnMut(&Self), + { + StackSafeVisit::new().visit_post_nolimit(self, f) + } + + fn visit_mut_post(&mut self, f: &mut F) -> Result<(), RecursionLimitError> where F: FnMut(&mut Self), { - // This code uses `unsafe`. The core safety argument is that: - // - // - `children_mut()` produces disjoint children - // - no aliasing means each `Enter` is processed separately, and we `Leave` each node exactly once - // - // Put another way, our `stack` mirrors the function call stack, which allows multiple `&mut` refs at once, - // since only one stack frame can be active at a time. - - use VisitMutAction::*; - let mut stack = vec![Enter(self as *mut T)]; - while let Some(action) = stack.pop() { - match action { - Enter(ptr) => { - stack.push(Leave(ptr)); - let elt = unsafe { &mut *ptr }; - // Push children in reverse so they pop (and are visited) left-to-right. - stack.extend(elt.children_mut().rev().map(|child| Enter(child as *mut T))); - } - Leave(elt) => f(unsafe { &mut *elt }), - } - } + StackSafeVisit::new().visit_mut_post(self, f) + } + + fn visit_mut_post_nolimit(&mut self, f: &mut F) + where + F: FnMut(&mut Self), + { + StackSafeVisit::new().visit_mut_post_nolimit(self, f) } fn try_visit_post(&self, f: &mut F) -> Result<(), E> where F: FnMut(&Self) -> Result<(), E>, + E: From, { - use VisitAction::*; - let mut stack = vec![Enter(self)]; - while let Some(action) = stack.pop() { - match action { - Enter(elt) => { - stack.push(Leave(elt)); - // Push children in reverse so they pop (and are visited) left-to-right. - stack.extend(elt.children().rev().map(Enter)); - } - Leave(elt) => f(elt)?, - } - } - - Ok(()) + StackSafeVisit::new().try_visit_post(self, f) } - #[allow(clippy::as_conversions)] fn try_visit_mut_post(&mut self, f: &mut F) -> Result<(), E> where F: FnMut(&mut Self) -> Result<(), E>, + E: From, { - // This code uses `unsafe`. The core safety argument is that: - // - // - `children_mut()` produces disjoint children - // - no aliasing means each `Enter` is processed separately, and we `Leave` each node exactly once - // - // Put another way, our `stack` mirrors the function call stack, which allows multiple `&mut` refs at once, - // since only one stack frame can be active at a time. - - use VisitMutAction::*; - let mut stack = vec![Enter(self as *mut T)]; - while let Some(action) = stack.pop() { - match action { - Enter(ptr) => { - stack.push(Leave(ptr)); - let elt = unsafe { &mut *ptr }; - // Push children in reverse so they pop (and are visited) left-to-right. - stack.extend(elt.children_mut().rev().map(|child| Enter(child as *mut T))); - } - Leave(ptr) => f(unsafe { &mut *ptr })?, - } - } - - Ok(()) + StackSafeVisit::new().try_visit_mut_post(self, f) } - fn visit_pre(&self, f: &mut F) + fn visit_pre(&self, f: &mut F) -> Result<(), RecursionLimitError> where F: FnMut(&Self), { - let mut stack = vec![self]; - while let Some(elt) = stack.pop() { - f(elt); - // Push children in reverse so they pop (and are visited) left-to-right. - stack.extend(elt.children().rev()); - } + StackSafeVisit::new().visit_pre(self, f) } fn visit_pre_with_context( @@ -354,136 +357,429 @@ impl> Visit for T { init: Context, acc_fun: &mut AccFun, visitor: &mut Visitor, - ) where + ) -> Result<(), RecursionLimitError> + where Context: Clone, AccFun: FnMut(Context, &Self) -> Context, Visitor: FnMut(&Context, &Self), { - let mut stack = vec![(self, init)]; - while let Some((elt, ctx)) = stack.pop() { - visitor(&ctx, elt); - let ctx = acc_fun(ctx, elt); - // Push children in reverse so they pop (and are visited) left-to-right. - stack.extend(elt.children().rev().map(|child| (child, ctx.clone()))); - } + StackSafeVisit::new().visit_pre_with_context(self, init, acc_fun, visitor) } - fn visit_mut_pre(&mut self, f: &mut F) + fn visit_pre_nolimit(&self, f: &mut F) + where + F: FnMut(&Self), + { + StackSafeVisit::new().visit_pre_nolimit(self, f) + } + + fn visit_mut_pre(&mut self, f: &mut F) -> Result<(), RecursionLimitError> where F: FnMut(&mut Self), { - let mut stack = vec![self]; - while let Some(elt) = stack.pop() { - f(elt); - // Push children in reverse so they pop (and are visited) left-to-right. - stack.extend(elt.children_mut().rev()) - } + StackSafeVisit::new().visit_mut_pre(self, f) + } + + fn visit_mut_pre_nolimit(&mut self, f: &mut F) + where + F: FnMut(&mut Self), + { + StackSafeVisit::new().visit_mut_pre_nolimit(self, f) } fn try_visit_pre(&self, f: &mut F) -> Result<(), E> where F: FnMut(&Self) -> Result<(), E>, + E: From, { - let mut stack = vec![self]; - while let Some(elt) = stack.pop() { - f(elt)?; - // Push children in reverse so they pop (and are visited) left-to-right. - stack.extend(elt.children().rev()); - } - - Ok(()) + StackSafeVisit::new().try_visit_pre(self, f) } fn try_visit_mut_pre(&mut self, f: &mut F) -> Result<(), E> where F: FnMut(&mut Self) -> Result<(), E>, + E: From, { - let mut stack = vec![self]; - while let Some(elt) = stack.pop() { - f(elt)?; - // Push children in reverse so they pop (and are visited) left-to-right. - stack.extend(elt.children_mut().rev()); - } + StackSafeVisit::new().try_visit_mut_pre(self, f) + } - Ok(()) + fn visit_pre_post(&self, pre: &mut F1, post: &mut F2) -> Result<(), RecursionLimitError> + where + F1: FnMut(&Self) -> Option>, + F2: FnMut(&Self), + { + StackSafeVisit::new().visit_pre_post(self, pre, post) } - fn visit_pre_post(&self, pre: &mut F1, post: &mut F2) + + fn visit_pre_post_nolimit(&self, pre: &mut F1, post: &mut F2) where F1: FnMut(&Self) -> Option>, F2: FnMut(&Self), { - use VisitAction::*; - let mut stack = vec![Enter(self)]; - while let Some(action) = stack.pop() { - match action { - Enter(elt) => { - stack.push(Leave(elt)); - if let Some(children) = pre(elt) { - for child in children.into_iter().rev() { - stack.push(Enter(child)); - } - } else { - for child in elt.children().rev() { - stack.push(Enter(child)); - } - } + StackSafeVisit::new().visit_pre_post_nolimit(self, pre, post) + } + + fn visit_mut_pre_post( + &mut self, + pre: &mut F1, + post: &mut F2, + ) -> Result<(), RecursionLimitError> + where + F1: FnMut(&mut Self) -> Option>, + F2: FnMut(&mut Self), + { + StackSafeVisit::new().visit_mut_pre_post(self, pre, post) + } + + fn visit_mut_pre_post_nolimit(&mut self, pre: &mut F1, post: &mut F2) + where + F1: FnMut(&mut Self) -> Option>, + F2: FnMut(&mut Self), + { + StackSafeVisit::new().visit_mut_pre_post_nolimit(self, pre, post) + } + + fn visit(&self, visitor: &mut V) -> Result<(), RecursionLimitError> + where + Self: Sized, + V: Visitor, + { + StackSafeVisit::new().visit(self, visitor) + } + + fn visit_mut(&mut self, visitor: &mut V) -> Result<(), RecursionLimitError> + where + Self: Sized, + V: VisitorMut, + { + StackSafeVisit::new().visit_mut(self, visitor) + } + + fn try_visit(&self, visitor: &mut V) -> Result<(), E> + where + Self: Sized, + V: TryVisitor, + E: From, + { + StackSafeVisit::new().try_visit(self, visitor) + } + + fn try_visit_mut(&mut self, visitor: &mut V) -> Result<(), E> + where + Self: Sized, + V: TryVisitorMut, + E: From, + { + StackSafeVisit::new().try_visit_mut(self, visitor) + } +} + +struct StackSafeVisit { + recursion_guard: RecursionGuard, + _type: PhantomData, +} + +impl CheckedRecursion for StackSafeVisit { + fn recursion_guard(&self) -> &RecursionGuard { + &self.recursion_guard + } +} + +impl> StackSafeVisit { + fn new() -> Self { + Self { + recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT), + _type: PhantomData, + } + } + + fn visit_post(&self, value: &T, f: &mut F) -> Result<(), RecursionLimitError> + where + F: FnMut(&T), + { + self.checked_recur(move |_| { + value.try_visit_children(|child| self.visit_post(child, f))?; + f(value); + Ok(()) + }) + } + + fn visit_post_nolimit(&self, value: &T, f: &mut F) + where + F: FnMut(&T), + { + maybe_grow(|| { + value.visit_children(|child| self.visit_post_nolimit(child, f)); + f(value) + }) + } + + fn visit_mut_post(&self, value: &mut T, f: &mut F) -> Result<(), RecursionLimitError> + where + F: FnMut(&mut T), + { + self.checked_recur(move |_| { + value.try_visit_mut_children(|child| self.visit_mut_post(child, f))?; + f(value); + Ok(()) + }) + } + + fn visit_mut_post_nolimit(&self, value: &mut T, f: &mut F) + where + F: FnMut(&mut T), + { + maybe_grow(|| { + value.visit_mut_children(|child| self.visit_mut_post_nolimit(child, f)); + f(value) + }) + } + + fn try_visit_post(&self, value: &T, f: &mut F) -> Result<(), E> + where + F: FnMut(&T) -> Result<(), E>, + E: From, + { + self.checked_recur(move |_| { + value.try_visit_children(|child| self.try_visit_post(child, f))?; + f(value) + }) + } + + fn try_visit_mut_post(&self, value: &mut T, f: &mut F) -> Result<(), E> + where + F: FnMut(&mut T) -> Result<(), E>, + E: From, + { + self.checked_recur(move |_| { + value.try_visit_mut_children(|child| self.try_visit_mut_post(child, f))?; + f(value) + }) + } + + fn visit_pre(&self, value: &T, f: &mut F) -> Result<(), RecursionLimitError> + where + F: FnMut(&T), + { + self.checked_recur(move |_| { + f(value); + value.try_visit_children(|child| self.visit_pre(child, f)) + }) + } + + fn visit_pre_with_context( + &self, + node: &T, + init: Context, + acc_fun: &mut AccFun, + visitor: &mut Visitor, + ) -> Result<(), RecursionLimitError> + where + Context: Clone, + AccFun: FnMut(Context, &T) -> Context, + Visitor: FnMut(&Context, &T), + { + self.checked_recur(move |_| { + visitor(&init, node); + let context = acc_fun(init, node); + node.try_visit_children(|child| { + self.visit_pre_with_context(child, context.clone(), acc_fun, visitor) + }) + }) + } + + fn visit_pre_nolimit(&self, value: &T, f: &mut F) + where + F: FnMut(&T), + { + maybe_grow(|| { + f(value); + value.visit_children(|child| self.visit_pre_nolimit(child, f)) + }) + } + + fn visit_mut_pre(&self, value: &mut T, f: &mut F) -> Result<(), RecursionLimitError> + where + F: FnMut(&mut T), + { + self.checked_recur(move |_| { + f(value); + value.try_visit_mut_children(|child| self.visit_mut_pre(child, f)) + }) + } + + fn visit_mut_pre_nolimit(&self, value: &mut T, f: &mut F) + where + F: FnMut(&mut T), + { + maybe_grow(|| { + f(value); + value.visit_mut_children(|child| self.visit_mut_pre_nolimit(child, f)) + }) + } + + fn try_visit_pre(&self, value: &T, f: &mut F) -> Result<(), E> + where + F: FnMut(&T) -> Result<(), E>, + E: From, + { + self.checked_recur(move |_| { + f(value)?; + value.try_visit_children(|child| self.try_visit_pre(child, f)) + }) + } + + fn try_visit_mut_pre(&self, value: &mut T, f: &mut F) -> Result<(), E> + where + F: FnMut(&mut T) -> Result<(), E>, + E: From, + { + self.checked_recur(move |_| { + f(value)?; + value.try_visit_mut_children(|child| self.try_visit_mut_pre(child, f)) + }) + } + + fn visit_pre_post( + &self, + value: &T, + pre: &mut F1, + post: &mut F2, + ) -> Result<(), RecursionLimitError> + where + F1: FnMut(&T) -> Option>, + F2: FnMut(&T), + { + self.checked_recur(move |_| { + if let Some(to_visit) = pre(value) { + for child in to_visit { + self.visit_pre_post(child, pre, post)?; } - Leave(elt) => { - post(elt); + } else { + value.try_visit_children(|child| self.visit_pre_post(child, pre, post))?; + } + post(value); + Ok(()) + }) + } + + fn visit_pre_post_nolimit(&self, value: &T, pre: &mut F1, post: &mut F2) + where + F1: FnMut(&T) -> Option>, + F2: FnMut(&T), + { + maybe_grow(|| { + if let Some(to_visit) = pre(value) { + for child in to_visit { + self.visit_pre_post_nolimit(child, pre, post); } + } else { + value.visit_children(|child| self.visit_pre_post_nolimit(child, pre, post)); } - } + post(value); + }) } - #[allow(clippy::as_conversions)] - fn visit_mut_pre_post(&mut self, pre: &mut F1, post: &mut F2) + fn visit_mut_pre_post( + &self, + value: &mut T, + pre: &mut F1, + post: &mut F2, + ) -> Result<(), RecursionLimitError> where - F1: FnMut(&mut Self) -> Option>, - F2: FnMut(&mut Self), + F1: FnMut(&mut T) -> Option>, + F2: FnMut(&mut T), { - // This code uses `unsafe`. The core safety argument is that: - // - // - `children_mut()` produces disjoint children - // - no aliasing means each `Enter` is processed separately, and we `Leave` each node exactly once - // - even if `pre` modifies the pointer, we retake it before computing children - // - // Put another way, our `stack` mirrors the function call stack, which allows multiple `&mut` refs at once, - // since only one stack frame can be active at a time. - - use VisitMutAction::*; - let mut stack = vec![Enter(self as *mut T)]; - while let Some(action) = stack.pop() { - match action { - Enter(ptr) => { - let elt = unsafe { &mut *ptr }; - stack.push(Leave(ptr)); - - if let Some(children) = pre(elt) { - for child in children.into_iter().rev() { - stack.push(Enter(child)); - } - } else { - let elt = unsafe { &mut *ptr }; - for child in elt.children_mut().rev() { - stack.push(Enter(child)); - } - } + self.checked_recur(move |_| { + if let Some(to_visit) = pre(value) { + for child in to_visit { + self.visit_mut_pre_post(child, pre, post)?; } - Leave(ptr) => { - post(unsafe { &mut *ptr }); + } else { + value.try_visit_mut_children(|child| self.visit_mut_pre_post(child, pre, post))?; + } + post(value); + Ok(()) + }) + } + + fn visit_mut_pre_post_nolimit(&self, value: &mut T, pre: &mut F1, post: &mut F2) + where + F1: FnMut(&mut T) -> Option>, + F2: FnMut(&mut T), + { + maybe_grow(|| { + if let Some(to_visit) = pre(value) { + for child in to_visit { + self.visit_mut_pre_post_nolimit(child, pre, post); } + } else { + value.visit_mut_children(|child| self.visit_mut_pre_post_nolimit(child, pre, post)); } - } + post(value); + }) + } + + fn visit(&self, value: &T, visitor: &mut V) -> Result<(), RecursionLimitError> + where + Self: Sized, + V: Visitor, + { + self.checked_recur(move |this| { + visitor.pre_visit(value); + value.try_visit_children(|child| this.visit(child, visitor))?; + visitor.post_visit(value); + Ok(()) + }) + } + + fn visit_mut(&self, value: &mut T, visitor: &mut V) -> Result<(), RecursionLimitError> + where + Self: Sized, + V: VisitorMut, + { + self.checked_recur(move |this| { + visitor.pre_visit(value); + value.try_visit_mut_children(|child| this.visit_mut(child, visitor))?; + visitor.post_visit(value); + Ok(()) + }) + } + + fn try_visit(&self, value: &T, visitor: &mut V) -> Result<(), E> + where + Self: Sized, + V: TryVisitor, + E: From, + { + self.checked_recur(move |_| { + visitor.pre_visit(value)?; + value.try_visit_children(|child| self.try_visit(child, visitor))?; + visitor.post_visit(value)?; + Ok(()) + }) + } + + fn try_visit_mut(&self, value: &mut T, visitor: &mut V) -> Result<(), E> + where + Self: Sized, + V: TryVisitorMut, + E: From, + { + self.checked_recur(move |_| { + visitor.pre_visit(value)?; + value.try_visit_mut_children(|child| self.try_visit_mut(child, visitor))?; + visitor.post_visit(value)?; + Ok(()) + }) } } #[cfg(test)] mod tests { - use super::*; + use mz_ore::assert_ok; - // This test demonstrates how to build visitors for mutually recursive definitions. - // The key move here is the `direct_sub_*` methods, which are worklist-based traversals - // that find children of appropriate type. + use super::*; #[derive(Debug, Eq, PartialEq)] enum A { @@ -499,89 +795,14 @@ mod tests { FrA(Box), } - impl A { - fn direct_sub_b(&self) -> Vec<&B> { - let mut subs: Vec<&B> = vec![]; - - let mut worklist = vec![self]; - while let Some(a) = worklist.pop() { - match a { - A::Add(lhs, rhs) => { - worklist.push(&*lhs); - worklist.push(&*rhs); - } - A::Lit(_) => (), - A::FrB(b) => subs.push(&*b), - } - } - - subs - } - - fn direct_sub_b_mut(&mut self) -> Vec<&mut B> { - let mut subs: Vec<&mut B> = vec![]; - - let mut worklist = vec![self]; - while let Some(a) = worklist.pop() { - match a { - A::Add(lhs, rhs) => { - worklist.push(&mut **lhs); - worklist.push(&mut **rhs); - } - A::Lit(_) => (), - A::FrB(b) => subs.push(&mut **b), - } - } - - subs - } - } - - impl B { - fn direct_sub_a(&self) -> Vec<&A> { - let mut subs: Vec<&A> = vec![]; - - let mut worklist = vec![self]; - while let Some(b) = worklist.pop() { - match b { - B::Mul(lhs, rhs) => { - worklist.push(&*lhs); - worklist.push(&*rhs); - } - B::Lit(_) => (), - B::FrA(a) => subs.push(&*a), - } - } - - subs - } - - fn direct_sub_a_mut(&mut self) -> Vec<&mut A> { - let mut subs: Vec<&mut A> = vec![]; - - let mut worklist = vec![self]; - while let Some(b) = worklist.pop() { - match b { - B::Mul(lhs, rhs) => { - worklist.push(&mut **lhs); - worklist.push(&mut **rhs); - } - B::Lit(_) => (), - B::FrA(a) => subs.push(&mut **a), - } - } - - subs - } - } - impl VisitChildren for A { fn visit_children(&self, mut f: F) where F: FnMut(&A), { VisitChildren::visit_children(self, |expr: &B| { - Visit::visit_post(expr, &mut |expr| match expr { + #[allow(deprecated)] + Visit::visit_post_nolimit(expr, &mut |expr| match expr { B::FrA(expr) => f(expr.as_ref()), _ => (), }); @@ -602,7 +823,8 @@ mod tests { F: FnMut(&mut A), { VisitChildren::visit_mut_children(self, |expr: &mut B| { - Visit::visit_mut_post(expr, &mut |expr| match expr { + #[allow(deprecated)] + Visit::visit_mut_post_nolimit(expr, &mut |expr| match expr { B::FrA(expr) => f(expr.as_mut()), _ => (), }); @@ -621,6 +843,7 @@ mod tests { fn try_visit_children(&self, mut f: F) -> Result<(), E> where F: FnMut(&A) -> Result<(), E>, + E: From, { VisitChildren::try_visit_children(self, |expr: &B| { Visit::try_visit_post(expr, &mut |expr| match expr { @@ -643,6 +866,7 @@ mod tests { fn try_visit_mut_children(&mut self, mut f: F) -> Result<(), E> where F: FnMut(&mut A) -> Result<(), E>, + E: From, { VisitChildren::try_visit_mut_children(self, |expr: &mut B| { Visit::try_visit_mut_post(expr, &mut |expr| match expr { @@ -661,45 +885,6 @@ mod tests { } Ok(()) } - - fn children<'a>(&'a self) -> impl DoubleEndedIterator - where - A: 'a, - { - let mut v: Vec<&A> = vec![]; - match self { - A::Add(lhs, rhs) => { - v.push(&*lhs); - v.push(&*rhs) - } - A::Lit(_) => (), - A::FrB(b) => { - v.append(&mut b.direct_sub_a()); - } - } - - v.into_iter() - } - - fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator - where - A: 'a, - { - let mut v: Vec<&mut A> = vec![]; - - match self { - A::Add(lhs, rhs) => { - v.push(&mut **lhs); - v.push(&mut **rhs) - } - A::Lit(_) => (), - A::FrB(b) => { - v.append(&mut b.direct_sub_a_mut()); - } - } - - v.into_iter() - } } impl VisitChildren for A { @@ -728,6 +913,7 @@ mod tests { fn try_visit_children(&self, mut f: F) -> Result<(), E> where F: FnMut(&B) -> Result<(), E>, + E: From, { match self { A::Add(_, _) => Ok(()), @@ -739,6 +925,7 @@ mod tests { fn try_visit_mut_children(&mut self, mut f: F) -> Result<(), E> where F: FnMut(&mut B) -> Result<(), E>, + E: From, { match self { A::Add(_, _) => Ok(()), @@ -746,30 +933,6 @@ mod tests { A::FrB(expr) => f(expr), } } - - fn children<'a>(&'a self) -> impl DoubleEndedIterator - where - B: 'a, - { - let mut child: Option<&B> = None; - match self { - A::Add(_, _) | A::Lit(_) => (), - A::FrB(b) => child = Some(&*b), - } - child.into_iter() - } - - fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator - where - B: 'a, - { - let mut child: Option<&mut B> = None; - match self { - A::Add(_, _) | A::Lit(_) => (), - A::FrB(b) => child = Some(&mut **b), - } - child.into_iter() - } } impl VisitChildren for B { @@ -779,7 +942,7 @@ mod tests { { // VisitChildren::visit_children(self, |expr: &A| { // #[allow(deprecated)] - // Visit::visit_post(expr, &mut |expr| match expr { + // Visit::visit_post_nolimit(expr, &mut |expr| match expr { // A::FrB(expr) => f(expr.as_ref()), // _ => (), // }); @@ -820,6 +983,7 @@ mod tests { fn try_visit_children(&self, mut f: F) -> Result<(), E> where F: FnMut(&B) -> Result<(), E>, + E: From, { // VisitChildren::try_visit_children(self, |expr: &A| { // Visit::try_visit_post(expr, &mut |expr| match expr { @@ -842,6 +1006,7 @@ mod tests { fn try_visit_mut_children(&mut self, mut f: F) -> Result<(), E> where F: FnMut(&mut B) -> Result<(), E>, + E: From, { // VisitChildren::try_visit_mut_children(self, |expr: &mut A| { // Visit::try_visit_mut_post(expr, &mut |expr| match expr { @@ -860,38 +1025,6 @@ mod tests { } Ok(()) } - - fn children<'a>(&'a self) -> impl DoubleEndedIterator - where - B: 'a, - { - let mut v: Vec<&B> = vec![]; - match self { - B::Mul(lhs, rhs) => { - v.push(&*lhs); - v.push(&*rhs); - } - B::Lit(_) => (), - B::FrA(a) => v.append(&mut a.direct_sub_b()), - } - v.into_iter() - } - - fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator - where - B: 'a, - { - let mut v: Vec<&mut B> = vec![]; - match self { - B::Mul(lhs, rhs) => { - v.push(&mut **lhs); - v.push(&mut **rhs); - } - B::Lit(_) => (), - B::FrA(a) => v.append(&mut a.direct_sub_b_mut()), - } - v.into_iter() - } } impl VisitChildren for B { @@ -920,6 +1053,7 @@ mod tests { fn try_visit_children(&self, mut f: F) -> Result<(), E> where F: FnMut(&A) -> Result<(), E>, + E: From, { match self { B::Mul(_, _) => Ok(()), @@ -931,6 +1065,7 @@ mod tests { fn try_visit_mut_children(&mut self, mut f: F) -> Result<(), E> where F: FnMut(&mut A) -> Result<(), E>, + E: From, { match self { B::Mul(_, _) => Ok(()), @@ -938,30 +1073,6 @@ mod tests { B::FrA(expr) => f(expr), } } - - fn children<'a>(&'a self) -> impl DoubleEndedIterator - where - A: 'a, - { - let mut child: Option<&A> = None; - match self { - B::Mul(_, _) | B::Lit(_) => (), - B::FrA(a) => child = Some(&*a), - } - child.into_iter() - } - - fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator - where - A: 'a, - { - let mut child: Option<&mut A> = None; - match self { - B::Mul(_, _) | B::Lit(_) => (), - B::FrA(a) => child = Some(&mut **a), - } - child.into_iter() - } } /// x + (y + z) @@ -1034,24 +1145,45 @@ mod tests { let mut act = test_term_rec_a(0); let exp = test_term_rec_a(20); - act.visit_mut_pre(&mut |expr| match expr { + let res = act.visit_mut_pre(&mut |expr| match expr { A::Lit(x) => *x = *x + 20, _ => (), }); + assert_ok!(res); assert_eq!(act, exp); } + /// This test currently fails with the following error: + /// + /// reached the recursion limit while instantiating + /// ` as CheckedRec...ore::stack::RecursionLimitError>` + /// + /// The problem (I think) is in the fact that the lambdas passed in the + /// VisitChildren for A and the VisitChildren for B definitions end + /// up in an infinite loop. + /// + /// More specifically, we run into the following cycle: + /// + /// - `>::visit_children` + /// - >::visit_children` + /// - ::visit_post_nolimit` + /// - >::visit_children` + /// - >::visit_children` + /// - ::visit_post_nolimit` + /// - >::visit_children` #[mz_ore::test] + #[ignore = "making the VisitChildren definitions symmetric breaks the compiler"] fn test_recursive_types_b() { let mut act = test_term_rec_b(0); let exp = test_term_rec_b(30); - act.visit_mut_pre(&mut |expr| match expr { + let res = act.visit_mut_pre(&mut |expr| match expr { B::Lit(x) => *x = *x + 30, _ => (), }); + assert_ok!(res); assert_eq!(act, exp); } } diff --git a/src/sql/src/plan.rs b/src/sql/src/plan.rs index 47f45ab8c9ad3..46383a9bb60a3 100644 --- a/src/sql/src/plan.rs +++ b/src/sql/src/plan.rs @@ -936,7 +936,9 @@ impl SubscribeFrom { pub fn contains_temporal(&self) -> bool { match self { SubscribeFrom::Id(_) => false, - SubscribeFrom::Query { expr, .. } => expr.contains_temporal(), + SubscribeFrom::Query { expr, .. } => expr + .contains_temporal() + .expect("Unexpected error in `visit_scalars` call"), } } } diff --git a/src/sql/src/plan/explain.rs b/src/sql/src/plan/explain.rs index dd07bf551ee36..f029b02a0ffc3 100644 --- a/src/sql/src/plan/explain.rs +++ b/src/sql/src/plan/explain.rs @@ -14,6 +14,7 @@ use std::panic::AssertUnwindSafe; use mz_expr::explain::{ExplainContext, ExplainSinglePlan}; use mz_expr::visit::{Visit, VisitChildren}; use mz_expr::{Id, LocalId}; +use mz_ore::stack::RecursionLimitError; use mz_repr::SqlRelationType; use mz_repr::explain::{AnnotatedPlan, Explain, ExplainError, ScalarOps, UnsupportedFormat}; @@ -49,8 +50,7 @@ impl<'a> HirRelationExpr { // `normalize_subqueries` if !context.config.raw_plans { mz_ore::panic::catch_unwind_str(AssertUnwindSafe(|| { - normalize_subqueries(self); - Ok(()) + normalize_subqueries(self).map_err(|e| e.into()) })) .unwrap_or_else(|panic| { // A panic during optimization is always a bug; log an error so we learn about it. @@ -80,7 +80,7 @@ impl<'a> HirRelationExpr { /// [`HirScalarExpr::Exists`] or [`HirScalarExpr::Select`] where the /// subquery appears, and the corresponding variant references the /// new binding with a [`HirRelationExpr::Get`]. -pub fn normalize_subqueries<'a>(expr: &'a mut HirRelationExpr) { +pub fn normalize_subqueries<'a>(expr: &'a mut HirRelationExpr) -> Result<(), RecursionLimitError> { // A helper struct to represent accumulated `$local_id = $subquery` // bindings that need to be installed in `let ... in $expr` nodes // that wrap their parent $expr. @@ -93,14 +93,14 @@ pub fn normalize_subqueries<'a>(expr: &'a mut HirRelationExpr) { // - a stack of bindings let mut bindings = Vec::::new(); // - a generator of fresh local ids - let mut id_gen = id_gen(expr).peekable(); + let mut id_gen = id_gen(expr)?.peekable(); // Grow the `bindings` stack by collecting subqueries appearing in // one of the HirScalarExpr children at the given HirRelationExpr. // As part of this, the subquery is replaced by a `Get(id)` for a // fresh local id. let mut collect_subqueries = |expr: &mut HirRelationExpr, bindings: &mut Vec| { - expr.visit_mut_children(|expr: &mut HirScalarExpr| { + expr.try_visit_mut_children(|expr: &mut HirScalarExpr| { use HirRelationExpr::Get; use HirScalarExpr::{Exists, Select}; expr.visit_mut_post(&mut |expr: &mut HirScalarExpr| match expr { @@ -121,8 +121,8 @@ pub fn normalize_subqueries<'a>(expr: &'a mut HirRelationExpr) { } }, _ => (), - }); - }); + }) + }) }; // Drain the `bindings` stack by wrapping the given `HirRelationExpr` with @@ -142,17 +142,21 @@ pub fn normalize_subqueries<'a>(expr: &'a mut HirRelationExpr) { } }; - expr.visit_mut_post(&mut |expr: &mut HirRelationExpr| { + expr.try_visit_mut_post(&mut |expr: &mut HirRelationExpr| { // first grow bindings stack - collect_subqueries(expr, &mut bindings); + collect_subqueries(expr, &mut bindings)?; // then drain bindings stack insert_let_bindings(expr, &mut bindings); + // done! + Ok(()) }) } // Create an [`Iterator`] for [`LocalId`] values that are guaranteed to be // fresh within the scope of the given [`HirRelationExpr`]. -fn id_gen(expr: &HirRelationExpr) -> impl Iterator + use<> { +fn id_gen( + expr: &HirRelationExpr, +) -> Result + use<>, RecursionLimitError> { let mut max_id = 0_u64; expr.visit_pre(&mut |expr| { @@ -160,9 +164,9 @@ fn id_gen(expr: &HirRelationExpr) -> impl Iterator + use<> { HirRelationExpr::Let { id, .. } => max_id = std::cmp::max(max_id, id.into()), _ => (), }; - }); + })?; - (max_id + 1..).map(LocalId::new) + Ok((max_id + 1..).map(LocalId::new)) } impl ScalarOps for HirScalarExpr { diff --git a/src/sql/src/plan/hir.rs b/src/sql/src/plan/hir.rs index 0dd240bd9d7a2..596a806c794f5 100644 --- a/src/sql/src/plan/hir.rs +++ b/src/sql/src/plan/hir.rs @@ -28,6 +28,7 @@ pub use mz_expr::{ }; use mz_ore::collections::CollectionExt; use mz_ore::error::ErrorExt; +use mz_ore::stack::RecursionLimitError; use mz_ore::str::separated; use mz_ore::treat_as_equal::TreatAsEqual; use mz_ore::{soft_assert_or_log, stack}; @@ -356,6 +357,7 @@ impl VisitChildren for WindowExpr { fn try_visit_children(&self, mut f: F) -> Result<(), E> where F: FnMut(&HirScalarExpr) -> Result<(), E>, + E: From, { self.func.try_visit_children(&mut f)?; for expr in self.partition_by.iter() { @@ -370,6 +372,7 @@ impl VisitChildren for WindowExpr { fn try_visit_mut_children(&mut self, mut f: F) -> Result<(), E> where F: FnMut(&mut HirScalarExpr) -> Result<(), E>, + E: From, { self.func.try_visit_mut_children(&mut f)?; for expr in self.partition_by.iter_mut() { @@ -380,26 +383,6 @@ impl VisitChildren for WindowExpr { } Ok(()) } - - fn children<'a>(&'a self) -> impl DoubleEndedIterator - where - HirScalarExpr: 'a, - { - self.func - .children() - .chain(self.partition_by.iter()) - .chain(self.order_by.iter()) - } - - fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator - where - HirScalarExpr: 'a, - { - self.func - .children_mut() - .chain(self.partition_by.iter_mut()) - .chain(self.order_by.iter_mut()) - } } #[derive( @@ -504,6 +487,7 @@ impl VisitChildren for WindowExprType { fn try_visit_children(&self, f: F) -> Result<(), E> where F: FnMut(&HirScalarExpr) -> Result<(), E>, + E: From, { match self { Self::Scalar(_) => Ok(()), @@ -515,6 +499,7 @@ impl VisitChildren for WindowExprType { fn try_visit_mut_children(&mut self, f: F) -> Result<(), E> where F: FnMut(&mut HirScalarExpr) -> Result<(), E>, + E: From, { match self { Self::Scalar(_) => Ok(()), @@ -522,30 +507,6 @@ impl VisitChildren for WindowExprType { Self::Aggregate(expr) => expr.try_visit_mut_children(f), } } - - fn children<'a>(&'a self) -> impl DoubleEndedIterator - where - HirScalarExpr: 'a, - { - match self { - Self::Scalar(_) => vec![], - Self::Value(expr) => expr.children().collect(), - Self::Aggregate(expr) => expr.children().collect(), - } - .into_iter() - } - - fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator - where - HirScalarExpr: 'a, - { - match self { - Self::Scalar(_) => vec![], - Self::Value(expr) => expr.children_mut().collect(), - Self::Aggregate(expr) => expr.children_mut().collect(), - } - .into_iter() - } } #[derive( @@ -747,6 +708,7 @@ impl VisitChildren for ValueWindowExpr { fn try_visit_children(&self, mut f: F) -> Result<(), E> where F: FnMut(&HirScalarExpr) -> Result<(), E>, + E: From, { f(&self.args) } @@ -754,23 +716,10 @@ impl VisitChildren for ValueWindowExpr { fn try_visit_mut_children(&mut self, mut f: F) -> Result<(), E> where F: FnMut(&mut HirScalarExpr) -> Result<(), E>, + E: From, { f(&mut self.args) } - - fn children<'a>(&'a self) -> impl DoubleEndedIterator - where - HirScalarExpr: 'a, - { - VisitChildren::::children(&*self.args) - } - - fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator - where - HirScalarExpr: 'a, - { - VisitChildren::::children_mut(&mut *self.args) - } } #[derive( @@ -947,6 +896,7 @@ impl VisitChildren for AggregateWindowExpr { fn try_visit_children(&self, mut f: F) -> Result<(), E> where F: FnMut(&HirScalarExpr) -> Result<(), E>, + E: From, { f(&self.aggregate_expr.expr) } @@ -954,23 +904,10 @@ impl VisitChildren for AggregateWindowExpr { fn try_visit_mut_children(&mut self, mut f: F) -> Result<(), E> where F: FnMut(&mut HirScalarExpr) -> Result<(), E>, + E: From, { f(&mut self.aggregate_expr.expr) } - - fn children<'a>(&'a self) -> impl DoubleEndedIterator - where - HirScalarExpr: 'a, - { - VisitChildren::::children(&*self.aggregate_expr.expr) - } - - fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator - where - HirScalarExpr: 'a, - { - VisitChildren::::children_mut(&mut *self.aggregate_expr.expr) - } } /// A `CoercibleScalarExpr` is a [`HirScalarExpr`] whose type is not fully @@ -2499,12 +2436,12 @@ impl HirRelationExpr { /// should be kept in sync w.r.t. HIR ⇒ MIR lowering! pub fn could_run_expensive_function(&self) -> bool { let mut result = false; - self.visit_pre(&mut |e: &HirRelationExpr| { + if let Err(_) = self.visit_pre(&mut |e: &HirRelationExpr| { use HirRelationExpr::*; use HirScalarExpr::*; e.visit_children(|scalar: &HirScalarExpr| { - scalar.visit_pre(&mut |scalar: &HirScalarExpr| { + if let Err(_) = scalar.visit_pre(&mut |scalar: &HirScalarExpr| { result |= match scalar { Column(..) | Literal(..) @@ -2519,49 +2456,55 @@ impl HirRelationExpr { | CallVariadic { .. } | Windowing(..) => true, }; - }) + }) { + // Conservatively set `true` on RecursionLimitError. + result = true; + } }); // CallTable has a table function; Reduce has an aggregate function. // Other constructs use MirScalarExpr to run a function result |= matches!(e, CallTable { .. } | Reduce { .. }); - }); + }) { + // Conservatively set `true` on RecursionLimitError. + result = true; + } result } /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call. - pub fn contains_temporal(&self) -> bool { + pub fn contains_temporal(&self) -> Result { let mut contains = false; self.visit_post(&mut |expr| { expr.visit_children(|expr: &HirScalarExpr| { contains = contains || expr.contains_temporal() }) - }); - contains + })?; + Ok(contains) } /// Whether the expression contains any [`UnmaterializableFunc`] call. - pub fn contains_unmaterializable(&self) -> bool { + pub fn contains_unmaterializable(&self) -> Result { let mut contains = false; self.visit_post(&mut |expr| { expr.visit_children(|expr: &HirScalarExpr| { contains = contains || expr.contains_unmaterializable() }) - }); - contains + })?; + Ok(contains) } /// Whether the expression contains any [`UnmaterializableFunc`] call other than /// [`UnmaterializableFunc::MzNow`]. - pub fn contains_unmaterializable_except_temporal(&self) -> bool { + pub fn contains_unmaterializable_except_temporal(&self) -> Result { let mut contains = false; self.visit_post(&mut |expr| { expr.visit_children(|expr: &HirScalarExpr| { contains = contains || expr.contains_unmaterializable_except_temporal() }) - }); - contains + })?; + Ok(contains) } } @@ -2758,6 +2701,7 @@ impl VisitChildren for HirRelationExpr { fn try_visit_children(&self, mut f: F) -> Result<(), E> where F: FnMut(&Self) -> Result<(), E>, + E: From, { // subqueries of type HirRelationExpr might be wrapped in // Exists or Select variants within HirScalarExpr trees @@ -2842,6 +2786,7 @@ impl VisitChildren for HirRelationExpr { fn try_visit_mut_children(&mut self, mut f: F) -> Result<(), E> where F: FnMut(&mut Self) -> Result<(), E>, + E: From, { // subqueries of type HirRelationExpr might be wrapped in // Exists or Select variants within HirScalarExpr trees @@ -2922,186 +2867,6 @@ impl VisitChildren for HirRelationExpr { } Ok(()) } - - fn children<'a>(&'a self) -> impl DoubleEndedIterator - where - Self: 'a, - { - // we visit subqueries _first_, then the input - let mut v: Vec<&HirRelationExpr> = vec![]; - use HirRelationExpr::*; - match self { - Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (), - Let { - name: _, - id: _, - value, - body, - } => { - v.push(&*value); - v.push(&*body); - } - LetRec { - limit: _, - bindings, - body, - } => { - v.extend(bindings.iter().map(|(_, _, value, _)| value)); - v.push(&*body); - } - Map { input, scalars } - | Filter { - input, - predicates: scalars, - } => { - for scalar in scalars { - v.append(&mut scalar.direct_subqueries()); - } - v.push(&*input); - } - Reduce { - input, - group_key: _, - aggregates, - expected_group_size: _, - } => { - for agg in aggregates { - v.append(&mut agg.expr.direct_subqueries()); - } - v.push(&*input); - } - TopK { - input, - group_key: _, - order_key: _, - limit, - offset, - expected_group_size: _, - } => { - if let Some(limit) = limit { - v.append(&mut limit.direct_subqueries()); - } - v.append(&mut offset.direct_subqueries()); - v.push(&*input); - } - Project { input, outputs: _ } - | Distinct { input } - | Negate { input } - | Threshold { input } => v.push(&*input), - CallTable { func: _, exprs } => v.extend( - exprs - .iter() - .map(|scalar| scalar.direct_subqueries()) - .flatten(), - ), - Join { - left, - right, - on, - kind: _, - } => { - v.append(&mut on.direct_subqueries()); - v.push(&*left); - v.push(&*right); - } - Union { base, inputs } => { - v.push(&*base); - v.extend(inputs.iter()); - } - } - - v.into_iter() - } - - fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator - where - Self: 'a, - { - // we visit subqueries _first_, then the input - let mut v = vec![]; - use HirRelationExpr::*; - match self { - Constant { rows: _, typ: _ } | Get { id: _, typ: _ } => (), - Let { - name: _, - id: _, - value, - body, - } => { - v.push(&mut **value); - v.push(&mut **body); - } - LetRec { - limit: _, - bindings, - body, - } => { - v.extend(bindings.iter_mut().map(|(_, _, value, _)| value)); - v.push(&mut **body); - } - Map { input, scalars } - | Filter { - input, - predicates: scalars, - } => { - for scalar in scalars { - v.append(&mut scalar.direct_subqueries_mut()); - } - v.push(&mut **input); - } - Reduce { - input, - group_key: _, - aggregates, - expected_group_size: _, - } => { - for agg in aggregates { - v.append(&mut agg.expr.direct_subqueries_mut()); - } - v.push(&mut **input); - } - TopK { - input, - group_key: _, - order_key: _, - limit, - offset, - expected_group_size: _, - } => { - if let Some(limit) = limit { - v.append(&mut limit.direct_subqueries_mut()); - } - v.append(&mut offset.direct_subqueries_mut()); - v.push(&mut **input); - } - Project { input, outputs: _ } - | Distinct { input } - | Negate { input } - | Threshold { input } => v.push(&mut **input), - CallTable { func: _, exprs } => v.extend( - exprs - .iter_mut() - .map(|scalar| scalar.direct_subqueries_mut()) - .flatten(), - ), - Join { - left, - right, - on, - kind: _, - } => { - v.append(&mut on.direct_subqueries_mut()); - v.push(&mut **left); - v.push(&mut **right); - } - Union { base, inputs } => { - v.push(&mut **base); - v.extend(inputs.iter_mut()); - } - } - - v.into_iter() - } } /// Yields the scalars directly attached to relation nodes (e.g. `Map.scalars`, @@ -3265,6 +3030,7 @@ impl VisitChildren for HirRelationExpr { fn try_visit_children(&self, mut f: F) -> Result<(), E> where F: FnMut(&HirScalarExpr) -> Result<(), E>, + E: From, { use HirRelationExpr::*; match self { @@ -3343,6 +3109,7 @@ impl VisitChildren for HirRelationExpr { fn try_visit_mut_children(&mut self, mut f: F) -> Result<(), E> where F: FnMut(&mut HirScalarExpr) -> Result<(), E>, + E: From, { use HirRelationExpr::*; match self { @@ -3417,126 +3184,6 @@ impl VisitChildren for HirRelationExpr { } Ok(()) } - - fn children<'a>(&'a self) -> impl DoubleEndedIterator - where - HirScalarExpr: 'a, - { - use HirRelationExpr::*; - match self { - Constant { rows: _, typ: _ } - | Get { id: _, typ: _ } - | Let { - name: _, - id: _, - value: _, - body: _, - } - | LetRec { - limit: _, - bindings: _, - body: _, - } - | Project { - input: _, - outputs: _, - } - | Distinct { input: _ } - | Negate { input: _ } - | Threshold { input: _ } - | Union { base: _, inputs: _ } => vec![], - Map { input: _, scalars } - | CallTable { - func: _, - exprs: scalars, - } - | Filter { - input: _, - predicates: scalars, - } => scalars.iter().collect(), - Join { - left: _, - right: _, - on, - kind: _, - } => vec![on], - Reduce { - input: _, - group_key: _, - aggregates, - expected_group_size: _, - } => aggregates.iter().map(|agg| &*agg.expr).collect(), - TopK { - input: _, - group_key: _, - order_key: _, - limit, - offset, - expected_group_size: _, - } => limit.iter().chain(std::iter::once(offset)).collect(), - } - .into_iter() - } - - fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator - where - HirScalarExpr: 'a, - { - use HirRelationExpr::*; - match self { - Constant { rows: _, typ: _ } - | Get { id: _, typ: _ } - | Let { - name: _, - id: _, - value: _, - body: _, - } - | LetRec { - limit: _, - bindings: _, - body: _, - } - | Project { - input: _, - outputs: _, - } - | Distinct { input: _ } - | Negate { input: _ } - | Threshold { input: _ } - | Union { base: _, inputs: _ } => vec![], - Map { input: _, scalars } - | CallTable { - func: _, - exprs: scalars, - } - | Filter { - input: _, - predicates: scalars, - } => scalars.iter_mut().collect(), - Join { - left: _, - right: _, - on, - kind: _, - } => vec![on], - Reduce { - input: _, - group_key: _, - aggregates, - expected_group_size: _, - } => aggregates.iter_mut().map(|agg| &mut *agg.expr).collect(), - TopK { - input: _, - group_key: _, - order_key: _, - limit, - offset, - expected_group_size: _, - } => limit.iter_mut().chain(std::iter::once(offset)).collect(), - } - .into_iter() - } } impl HirScalarExpr { @@ -3562,7 +3209,13 @@ impl HirScalarExpr { where F: FnMut(&HirRelationExpr), { - self.visit_post(&mut |e| { + // The infallible variants of the post-walk are deprecated in favor + // of the limit-aware `try_visit_post`, but we have no error channel + // here (the surrounding signature is infallible, mirroring the + // infallible `VisitChildren::{visit_children, visit_mut_children}` + // impls that this helper is designed to be called from). + #[allow(deprecated)] + self.visit_post_nolimit(&mut |e| { VisitChildren::::visit_children(e, &mut f); }); } @@ -3572,7 +3225,10 @@ impl HirScalarExpr { where F: FnMut(&mut HirRelationExpr), { - self.visit_mut_post(&mut |e| { + // See the comment on `visit_direct_subqueries` for why we use the + // deprecated `_nolimit` walk here. + #[allow(deprecated)] + self.visit_mut_post_nolimit(&mut |e| { VisitChildren::::visit_mut_children(e, &mut f); }); } @@ -3581,6 +3237,7 @@ impl HirScalarExpr { pub fn try_visit_direct_subqueries(&self, mut f: F) -> Result<(), E> where F: FnMut(&HirRelationExpr) -> Result<(), E>, + E: From, { self.try_visit_post(&mut |e| { VisitChildren::::try_visit_children(e, &mut f) @@ -3591,6 +3248,7 @@ impl HirScalarExpr { pub fn try_visit_direct_subqueries_mut(&mut self, mut f: F) -> Result<(), E> where F: FnMut(&mut HirRelationExpr) -> Result<(), E>, + E: From, { self.try_visit_mut_post(&mut |e| { VisitChildren::::try_visit_mut_children(e, &mut f) @@ -3681,7 +3339,8 @@ impl HirScalarExpr { /// Whether the expression contains an [`UnmaterializableFunc::MzNow`] call. pub fn contains_temporal(&self) -> bool { let mut contains = false; - self.visit_post(&mut |e| { + #[allow(deprecated)] + self.visit_post_nolimit(&mut |e| { if let Self::CallUnmaterializable(UnmaterializableFunc::MzNow, _name) = e { contains = true; } @@ -3692,7 +3351,8 @@ impl HirScalarExpr { /// Whether the expression contains any [`UnmaterializableFunc`] call. pub fn contains_unmaterializable(&self) -> bool { let mut contains = false; - self.visit_post(&mut |e| { + #[allow(deprecated)] + self.visit_post_nolimit(&mut |e| { if let Self::CallUnmaterializable(_, _) = e { contains = true; } @@ -3704,7 +3364,8 @@ impl HirScalarExpr { /// [`UnmaterializableFunc::MzNow`]. pub fn contains_unmaterializable_except_temporal(&self) -> bool { let mut contains = false; - self.visit_post(&mut |e| { + #[allow(deprecated)] + self.visit_post_nolimit(&mut |e| { if let Self::CallUnmaterializable(f, _) = e { if *f != UnmaterializableFunc::MzNow { contains = true; @@ -4268,146 +3929,6 @@ impl HirScalarExpr { }); contains_parameters } - - fn direct_subqueries(&self) -> Vec<&HirRelationExpr> { - let mut subqueries: Vec<&HirRelationExpr> = vec![]; - - let mut worklist = vec![self]; - while let Some(elt) = worklist.pop() { - match elt { - HirScalarExpr::Column(_, _) - | HirScalarExpr::Parameter(_, _) - | HirScalarExpr::Literal(_, _, _) - | HirScalarExpr::CallUnmaterializable(_, _) => (), - HirScalarExpr::CallUnary { - func: _, - expr, - name: _, - } => worklist.push(&*expr), - HirScalarExpr::CallBinary { - func: _, - expr1, - expr2, - name: _name, - } => { - // Push in reverse so children pop (and are visited) left-to-right. - worklist.push(&*expr2); - worklist.push(&*expr1); - } - HirScalarExpr::CallVariadic { - func: _, - exprs, - name: _name, - } => { - worklist.extend(exprs.iter().rev()); - } - HirScalarExpr::If { - cond, - then, - els, - name: _, - } => { - worklist.push(&*els); - worklist.push(&*then); - worklist.push(&*cond); - } - HirScalarExpr::Exists(hir, _) | HirScalarExpr::Select(hir, _) => { - subqueries.push(&*hir); - } - HirScalarExpr::Windowing( - WindowExpr { - func, - partition_by, - order_by, - }, - _, - ) => { - // Push in reverse so children pop (and are visited) left-to-right: - // func args, then partition_by, then order_by. - worklist.extend(order_by.iter().rev()); - worklist.extend(partition_by.iter().rev()); - match func { - WindowExprType::Scalar(_) => (), - WindowExprType::Value(val) => worklist.push(&*val.args), - WindowExprType::Aggregate(agg) => worklist.push(&*agg.aggregate_expr.expr), - } - } - } - } - - subqueries - } - - fn direct_subqueries_mut(&mut self) -> Vec<&mut HirRelationExpr> { - let mut subqueries: Vec<&mut HirRelationExpr> = vec![]; - - let mut worklist = vec![self]; - while let Some(elt) = worklist.pop() { - match elt { - HirScalarExpr::Column(_, _) - | HirScalarExpr::Parameter(_, _) - | HirScalarExpr::Literal(_, _, _) - | HirScalarExpr::CallUnmaterializable(_, _) => (), - HirScalarExpr::CallUnary { - func: _, - expr, - name: _, - } => worklist.push(&mut **expr), - HirScalarExpr::CallBinary { - func: _, - expr1, - expr2, - name: _name, - } => { - // Push in reverse so children pop (and are visited) left-to-right. - worklist.push(&mut **expr2); - worklist.push(&mut **expr1); - } - HirScalarExpr::CallVariadic { - func: _, - exprs, - name: _name, - } => { - worklist.extend(exprs.iter_mut().rev()); - } - HirScalarExpr::If { - cond, - then, - els, - name: _, - } => { - worklist.push(&mut **els); - worklist.push(&mut **then); - worklist.push(&mut **cond); - } - HirScalarExpr::Exists(hir, _) | HirScalarExpr::Select(hir, _) => { - subqueries.push(&mut **hir); - } - HirScalarExpr::Windowing( - WindowExpr { - func, - partition_by, - order_by, - }, - _, - ) => { - // Push in reverse so children pop (and are visited) left-to-right: - // func args, then partition_by, then order_by. - worklist.extend(order_by.iter_mut().rev()); - worklist.extend(partition_by.iter_mut().rev()); - match func { - WindowExprType::Scalar(_) => (), - WindowExprType::Value(val) => worklist.push(&mut val.args), - WindowExprType::Aggregate(agg) => { - worklist.push(&mut agg.aggregate_expr.expr) - } - } - } - } - } - - subqueries - } } /// Yields the direct scalar children of `self`. Stops at `Exists` / `Select`: @@ -4483,6 +4004,7 @@ impl VisitChildren for HirScalarExpr { fn try_visit_children(&self, mut f: F) -> Result<(), E> where F: FnMut(&Self) -> Result<(), E>, + E: From, { use HirScalarExpr::*; match self { @@ -4516,6 +4038,7 @@ impl VisitChildren for HirScalarExpr { fn try_visit_mut_children(&mut self, mut f: F) -> Result<(), E> where F: FnMut(&mut Self) -> Result<(), E>, + E: From, { use HirScalarExpr::*; match self { @@ -4545,58 +4068,6 @@ impl VisitChildren for HirScalarExpr { } Ok(()) } - - fn children<'a>(&'a self) -> impl DoubleEndedIterator - where - Self: 'a, - { - use HirScalarExpr::*; - let v: Vec<&Self> = match self { - Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => vec![], - CallUnary { expr, .. } => vec![&*expr], - CallBinary { expr1, expr2, .. } => { - vec![&*expr1, &*expr2] - } - CallVariadic { exprs, .. } => exprs.iter().collect(), - If { - cond, - then, - els, - name: _, - } => { - vec![&*cond, &*then, &*els] - } - Exists(..) | Select(..) => vec![], - Windowing(expr, _name) => expr.children().collect(), - }; - v.into_iter() - } - - fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator - where - Self: 'a, - { - use HirScalarExpr::*; - let v: Vec<&mut Self> = match self { - Column(..) | Parameter(..) | Literal(..) | CallUnmaterializable(..) => vec![], - CallUnary { expr, .. } => vec![&mut **expr], - CallBinary { expr1, expr2, .. } => { - vec![&mut **expr1, &mut **expr2] - } - CallVariadic { exprs, .. } => exprs.iter_mut().collect(), - If { - cond, - then, - els, - name: _, - } => { - vec![&mut **cond, &mut **then, &mut **els] - } - Exists(..) | Select(..) => vec![], - Windowing(expr, _name) => expr.children_mut().collect(), - }; - v.into_iter() - } } /// Yields the immediate `HirRelationExpr` children of `self` (the bodies of @@ -4643,6 +4114,7 @@ impl VisitChildren for HirScalarExpr { fn try_visit_children(&self, mut f: F) -> Result<(), E> where F: FnMut(&HirRelationExpr) -> Result<(), E>, + E: From, { use HirScalarExpr::*; match self { @@ -4663,6 +4135,7 @@ impl VisitChildren for HirScalarExpr { fn try_visit_mut_children(&mut self, mut f: F) -> Result<(), E> where F: FnMut(&mut HirRelationExpr) -> Result<(), E>, + E: From, { use HirScalarExpr::*; match self { @@ -4679,50 +4152,6 @@ impl VisitChildren for HirScalarExpr { } Ok(()) } - - fn children<'a>(&'a self) -> impl DoubleEndedIterator - where - HirRelationExpr: 'a, - { - let mut child: Option<&HirRelationExpr> = None; - use HirScalarExpr::*; - match self { - Column(..) - | Parameter(..) - | Literal(..) - | CallUnmaterializable(..) - | CallUnary { .. } - | CallBinary { .. } - | CallVariadic { .. } - | If { .. } - | Windowing(..) => (), - Exists(expr, _name) | Select(expr, _name) => child = Some(&*expr), - } - - child.into_iter() - } - - fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator - where - HirRelationExpr: 'a, - { - let mut child: Option<&mut HirRelationExpr> = None; - use HirScalarExpr::*; - match self { - Column(..) - | Parameter(..) - | Literal(..) - | CallUnmaterializable(..) - | CallUnary { .. } - | CallBinary { .. } - | CallVariadic { .. } - | If { .. } - | Windowing(..) => (), - Exists(expr, _name) | Select(expr, _name) => child = Some(&mut **expr), - } - - child.into_iter() - } } impl AbstractExpr for HirScalarExpr { diff --git a/src/sql/src/plan/lowering.rs b/src/sql/src/plan/lowering.rs index a13ab2f4c3e08..9014f66401992 100644 --- a/src/sql/src/plan/lowering.rs +++ b/src/sql/src/plan/lowering.rs @@ -1620,7 +1620,7 @@ impl HirScalarExpr { } _ => {} }, - ); + )?; } if subqueries.is_empty() { diff --git a/src/sql/src/plan/query.rs b/src/sql/src/plan/query.rs index c55a7def6750f..8f12e8ddea6e7 100644 --- a/src/sql/src/plan/query.rs +++ b/src/sql/src/plan/query.rs @@ -917,7 +917,7 @@ fn handle_mutation_using_clause( c.column -= using_rel_arity; }; } - }); + })?; // Filter `USING` tables like ` WHERE `. Note that // this filters the `USING` tables, _not_ the joined `USING..., FROM` diff --git a/src/sql/src/plan/statement/dml.rs b/src/sql/src/plan/statement/dml.rs index fdf9b87cfa2ca..9ed2a4069c0b3 100644 --- a/src/sql/src/plan/statement/dml.rs +++ b/src/sql/src/plan/statement/dml.rs @@ -269,7 +269,7 @@ fn plan_select_inner( // we just disallow AS OF when there is an unmaterializable function in a query (except mz_now). if scx.is_feature_flag_enabled(&DISALLOW_UNMATERIALIZABLE_FUNCTIONS_AS_OF) && select.as_of.is_some() - && expr.contains_unmaterializable_except_temporal() + && expr.contains_unmaterializable_except_temporal()? { bail_unsupported!("unmaterializable function (except `mz_now`) in an AS OF query"); } diff --git a/src/transform/src/canonicalization/flat_map_elimination.rs b/src/transform/src/canonicalization/flat_map_elimination.rs index b314906a38f25..483303d4c26a8 100644 --- a/src/transform/src/canonicalization/flat_map_elimination.rs +++ b/src/transform/src/canonicalization/flat_map_elimination.rs @@ -41,7 +41,7 @@ impl crate::Transform for FlatMapElimination { relation: &mut MirRelationExpr, _: &mut TransformCtx, ) -> Result<(), crate::TransformError> { - relation.visit_mut_post(&mut Self::action); + relation.visit_mut_post(&mut Self::action)?; mz_repr::explain::trace_plan(&*relation); Ok(()) } diff --git a/src/transform/src/canonicalization/projection_extraction.rs b/src/transform/src/canonicalization/projection_extraction.rs index 65415c5a98d18..d59cba403c8a4 100644 --- a/src/transform/src/canonicalization/projection_extraction.rs +++ b/src/transform/src/canonicalization/projection_extraction.rs @@ -34,7 +34,7 @@ impl crate::Transform for ProjectionExtraction { relation: &mut MirRelationExpr, _: &mut TransformCtx, ) -> Result<(), crate::TransformError> { - relation.visit_mut_post(&mut Self::action); + relation.visit_mut_post(&mut Self::action)?; mz_repr::explain::trace_plan(&*relation); Ok(()) } diff --git a/src/transform/src/canonicalization/topk_elision.rs b/src/transform/src/canonicalization/topk_elision.rs index 47e6755e575fe..cd238a29ea06d 100644 --- a/src/transform/src/canonicalization/topk_elision.rs +++ b/src/transform/src/canonicalization/topk_elision.rs @@ -33,7 +33,7 @@ impl crate::Transform for TopKElision { relation: &mut MirRelationExpr, _: &mut TransformCtx, ) -> Result<(), crate::TransformError> { - relation.visit_mut_post(&mut Self::action); + relation.visit_mut_post(&mut Self::action)?; mz_repr::explain::trace_plan(&*relation); Ok(()) } diff --git a/src/transform/src/coalesce_case.rs b/src/transform/src/coalesce_case.rs index 105dfa6103549..938206ada3990 100644 --- a/src/transform/src/coalesce_case.rs +++ b/src/transform/src/coalesce_case.rs @@ -41,14 +41,14 @@ impl crate::Transform for CoalesceCase { relation: &mut MirRelationExpr, _: &mut TransformCtx, ) -> Result<(), TransformError> { - relation.visit_mut_post(&mut |e| self.action(e)); + relation.try_visit_mut_post(&mut |e| self.action(e))?; mz_repr::explain::trace_plan(&*relation); Ok(()) } } impl CoalesceCase { - fn action(&self, relation: &mut MirRelationExpr) { + fn action(&self, relation: &mut MirRelationExpr) -> Result<(), TransformError> { match relation { MirRelationExpr::Constant { .. } | MirRelationExpr::Get { .. } | @@ -66,44 +66,47 @@ impl CoalesceCase { MirRelationExpr::FlatMap { exprs, .. } => { // NB TableFunc doesn't ever hold an MSE for expr in exprs.iter_mut() { - self.rewrite_scalar(expr); + self.rewrite_scalar(expr)?; } } MirRelationExpr::Join { equivalences, implementation, .. } => { if implementation.is_implemented() { soft_panic_or_log!("unexpected implemented Join when optimizing coalesce/case, skipping: {implementation:?}"); - return; + return Ok(()); } for equivalence in equivalences.iter_mut() { for expr in equivalence { - self.rewrite_scalar(expr); + self.rewrite_scalar(expr)?; } } } MirRelationExpr::Reduce { group_key, aggregates, .. } => { for expr in group_key.iter_mut() { - self.rewrite_scalar(expr); + self.rewrite_scalar(expr)?; } for agg in aggregates.iter_mut() { - self.rewrite_aggreagte(agg); + self.rewrite_aggreagte(agg)?; } } MirRelationExpr::TopK { limit, .. } => { if let Some(expr) = limit { - self.rewrite_scalar(expr); + self.rewrite_scalar(expr)?; } } } + + Ok(()) } - fn rewrite_scalar(&self, expr: &mut MirScalarExpr) { + fn rewrite_scalar(&self, expr: &mut MirScalarExpr) -> Result<(), TransformError> { // Visiting in pre-order means that when we push a `COALESCE` down, we'll keep pushing if the `CASE` chain continues. expr.visit_mut_pre(&mut |e| self.try_combine_coalesce_case(e)) + .map_err(TransformError::from) } - fn rewrite_aggreagte(&self, agg: &mut AggregateExpr) { + fn rewrite_aggreagte(&self, agg: &mut AggregateExpr) -> Result<(), TransformError> { // NB AggregateFunc doesn't contain any MSEs self.rewrite_scalar(&mut agg.expr) } diff --git a/src/transform/src/column_knowledge.rs b/src/transform/src/column_knowledge.rs index 3f539d61e1e6a..64525a09d2782 100644 --- a/src/transform/src/column_knowledge.rs +++ b/src/transform/src/column_knowledge.rs @@ -765,6 +765,7 @@ fn optimize( // `DatumKnowledge` in the stack are the `DatumKnowledge` corresponding to // the children. assert!(knowledge_stack.is_empty()); + #[allow(deprecated)] expr.visit_mut_pre_post( &mut |e| { if let MirScalarExpr::If { then, els, .. } = e { @@ -843,7 +844,7 @@ fn optimize( }; knowledge_stack.push(result); }, - ); + )?; let knowledge_datum = knowledge_stack.pop(); assert!(knowledge_stack.is_empty()); knowledge_datum.ok_or_else(|| { diff --git a/src/transform/src/compound/union.rs b/src/transform/src/compound/union.rs index eb7a323e736cf..3fd0480da718e 100644 --- a/src/transform/src/compound/union.rs +++ b/src/transform/src/compound/union.rs @@ -39,7 +39,7 @@ impl crate::Transform for UnionNegateFusion { relation: &mut MirRelationExpr, _: &mut TransformCtx, ) -> Result<(), crate::TransformError> { - relation.visit_mut_post(&mut Self::action); + relation.visit_mut_post(&mut Self::action)?; mz_repr::explain::trace_plan(&*relation); Ok(()) } diff --git a/src/transform/src/fusion.rs b/src/transform/src/fusion.rs index f8f32203a0304..345ccb0d6e10d 100644 --- a/src/transform/src/fusion.rs +++ b/src/transform/src/fusion.rs @@ -42,7 +42,7 @@ impl crate::Transform for Fusion { _: &mut TransformCtx, ) -> Result<(), crate::TransformError> { use mz_expr::visit::Visit; - relation.visit_mut_post(&mut Self::action); + relation.visit_mut_post(&mut Self::action)?; mz_repr::explain::trace_plan(&*relation); Ok(()) } diff --git a/src/transform/src/fusion/union.rs b/src/transform/src/fusion/union.rs index c7342b5ea2ce1..ce0f01a94c794 100644 --- a/src/transform/src/fusion/union.rs +++ b/src/transform/src/fusion/union.rs @@ -31,7 +31,7 @@ impl crate::Transform for Union { relation: &mut MirRelationExpr, _: &mut crate::TransformCtx, ) -> Result<(), crate::TransformError> { - relation.visit_mut_post(&mut Self::action); + relation.visit_mut_post(&mut Self::action)?; mz_repr::explain::trace_plan(&*relation); Ok(()) } diff --git a/src/transform/src/join_implementation.rs b/src/transform/src/join_implementation.rs index 480e0f3181f69..1a552fa6e97c6 100644 --- a/src/transform/src/join_implementation.rs +++ b/src/transform/src/join_implementation.rs @@ -656,7 +656,7 @@ mod delta_queries { *implementation = JoinImplementation::DeltaQuery(orders); - super::install_lifted_mfp(&mut new_join, lifted_mfp); + super::install_lifted_mfp(&mut new_join, lifted_mfp)?; // Hooray done! Ok((new_join, new_arrangements)) @@ -816,7 +816,7 @@ mod differential { order, ); - super::install_lifted_mfp(&mut new_join, lifted_mfp); + super::install_lifted_mfp(&mut new_join, lifted_mfp)?; // Hooray done! Ok((new_join, new_arrangements)) @@ -940,7 +940,10 @@ fn implement_arrangements<'a>( /// column that was permuted or created by the given MFP. /// - Canonicalizes scalar expressions in maps and filters with respect to the join equivalences. /// See inline comment for more details. -fn install_lifted_mfp(new_join: &mut MirRelationExpr, mfp: MapFilterProject) { +fn install_lifted_mfp( + new_join: &mut MirRelationExpr, + mfp: MapFilterProject, +) -> Result<(), TransformError> { if !mfp.is_identity() { let (mut map, mut filter, project) = mfp.as_map_filter_project(); if let MirRelationExpr::Join { equivalences, .. } = new_join { @@ -963,7 +966,7 @@ fn install_lifted_mfp(new_join: &mut MirRelationExpr, mfp: MapFilterProject) { break; } } - }); + })?; } } // Canonicalize scalar expressions in maps and filters with respect to the join @@ -982,11 +985,12 @@ fn install_lifted_mfp(new_join: &mut MirRelationExpr, mfp: MapFilterProject) { if let Some(canonical_expr) = canonicalizer_map.get(e) { *e = canonical_expr.clone(); } - }) + })? } } *new_join = new_join.clone().map(map).filter(filter).project(project); } + Ok(()) } /// Permute the keys in `order` to compensate for projections being lifted from inputs. diff --git a/src/transform/src/literal_constraints.rs b/src/transform/src/literal_constraints.rs index 003bc005ea772..81288407dd5b3 100644 --- a/src/transform/src/literal_constraints.rs +++ b/src/transform/src/literal_constraints.rs @@ -701,10 +701,10 @@ impl LiteralConstraints { *e = or; // The modified OR will be the new top-level expr. } } - }); + })?; p.visit_mut_post(&mut |e: &mut MirScalarExpr| { e.flatten_associative(); - }); + })?; } Ok(()) }) diff --git a/src/transform/src/literal_lifting.rs b/src/transform/src/literal_lifting.rs index 88e5bafcffe39..a7caf07492b6c 100644 --- a/src/transform/src/literal_lifting.rs +++ b/src/transform/src/literal_lifting.rs @@ -404,7 +404,7 @@ impl LiteralLifting { *old_id = new_id; } } - }); + })?; projection.push(input_arity + new_scalars.len()); new_scalars.push(cloned_scalar); } @@ -427,7 +427,7 @@ impl LiteralLifting { *e = literals[*c - input_arity].clone(); } } - }); + })?; } // Permute the literals around the columns added by FlatMap let mut projection = (0..input_arity).collect::>(); @@ -453,7 +453,7 @@ impl LiteralLifting { *e = literals[*c - input_arity].clone(); } } - }); + })?; } } Ok(literals) @@ -517,7 +517,7 @@ impl LiteralLifting { .map_column_to_global(col, input); } } - }); + })?; } } @@ -578,7 +578,7 @@ impl LiteralLifting { *e = literals[*c - input_arity].clone(); } } - }); + })?; } // Inline literals into aggregate value selector expressions. for aggr in aggregates.iter_mut() { @@ -588,7 +588,7 @@ impl LiteralLifting { *e = literals[*c - input_arity].clone(); } } - }); + })?; } } @@ -690,7 +690,7 @@ impl LiteralLifting { *e = literals[*c - input_arity].clone(); } } - }); + })?; } } Ok(literals) @@ -758,7 +758,7 @@ impl LiteralLifting { *e = literals[*c - input_arity].clone(); } } - }); + })?; } } } diff --git a/src/transform/src/predicate_pushdown.rs b/src/transform/src/predicate_pushdown.rs index bfd498608b0d8..cba9205213bf7 100644 --- a/src/transform/src/predicate_pushdown.rs +++ b/src/transform/src/predicate_pushdown.rs @@ -98,7 +98,7 @@ use mz_expr::{ RECURSION_LIMIT, VariadicFunc, func, }; use mz_ore::soft_assert_eq_no_log; -use mz_ore::stack::{CheckedRecursion, RecursionGuard}; +use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError}; use mz_repr::{Datum, ReprColumnType, ReprScalarType}; use crate::{TransformCtx, TransformError}; @@ -326,7 +326,7 @@ impl PredicatePushdown { if let MirScalarExpr::Column(i, _) = e { *e = group_key[*i].clone(); } - }); + })?; push_down.push(new_predicate); } else if let MirScalarExpr::Column(col, _) = &predicate { if *col == group_key.len() @@ -934,7 +934,7 @@ impl PredicatePushdown { if !predicate.is_literal_err() || all_errors { // Consider inlining Map expressions. if let Some(cleaned) = - Self::inline_if_not_too_big(&predicate, input_arity, map_exprs) + Self::inline_if_not_too_big(&predicate, input_arity, map_exprs)? { pushdown.push(cleaned); } else { @@ -965,7 +965,7 @@ impl PredicatePushdown { expr: &MirScalarExpr, input_arity: usize, map_exprs: &Vec, - ) -> Option { + ) -> Result, RecursionLimitError> { let size_limit = 1000; // Transitively determine the support of `expr` produced by `map_exprs` @@ -1022,7 +1022,7 @@ impl PredicatePushdown { new_size += m_size - 1; // Adjust for the +1 above. } } - }); + })?; if new_size <= size_limit { inlined.insert(*c, (new_expr, new_size)); @@ -1033,7 +1033,7 @@ impl PredicatePushdown { // Try to resolve expr against the memo table. if inlined.len() < cols_to_inline.len() { - None // We couldn't memoize all map expressions within the given limit. + Ok(None) // We couldn't memoize all map expressions within the given limit. } else { let mut new_expr = expr.clone(); let mut new_size = 0; @@ -1047,13 +1047,13 @@ impl PredicatePushdown { new_size += m_size - 1; // Adjust for the +1 above. } } - }); + })?; soft_assert_eq_no_log!(new_size, new_expr.size()); if new_size <= size_limit { - Some(new_expr) // We managed to stay within the limit. + Ok(Some(new_expr)) // We managed to stay within the limit. } else { - None // Limit exceeded. + Ok(None) // Limit exceeded. } } } diff --git a/src/transform/src/reduction_pushdown.rs b/src/transform/src/reduction_pushdown.rs index f9e9984b186fe..8cbe2334e1048 100644 --- a/src/transform/src/reduction_pushdown.rs +++ b/src/transform/src/reduction_pushdown.rs @@ -77,9 +77,9 @@ impl crate::Transform for ReductionPushdown { ) -> Result<(), crate::TransformError> { // `try_visit_mut_pre` is used here because after pushing down a reduction, // we want to see if we can push the same reduction further down. - relation.visit_mut_pre(&mut |e| self.action(e)); + let result = relation.try_visit_mut_pre(&mut |e| self.action(e)); mz_repr::explain::trace_plan(&*relation); - Ok(()) + result } } @@ -90,7 +90,7 @@ impl ReductionPushdown { /// edges are join constraints. After removing constraints containing a /// GroupBy, the reduce will be pushed down to all connected components. If /// there is only one connected component, this method is a no-op. - pub fn action(&self, relation: &mut MirRelationExpr) { + pub fn action(&self, relation: &mut MirRelationExpr) -> Result<(), crate::TransformError> { if let MirRelationExpr::Reduce { input, group_key, @@ -117,7 +117,7 @@ impl ReductionPushdown { *e = lower[*c - arity].clone(); } } - }); + })?; } for key in group_key.iter_mut() { key.visit_mut_post(&mut |e| { @@ -126,7 +126,7 @@ impl ReductionPushdown { *e = scalars[*c - arity].clone(); } } - }); + })?; } for agg in aggregates.iter_mut() { agg.expr.visit_mut_post(&mut |e| { @@ -135,7 +135,7 @@ impl ReductionPushdown { *e = scalars[*c - arity].clone(); } } - }); + })?; } **input = inner.take_dangerous() @@ -158,6 +158,7 @@ impl ReductionPushdown { } } } + Ok(()) } } diff --git a/src/transform/src/redundant_join.rs b/src/transform/src/redundant_join.rs index ed699b6576930..a576eed524b5c 100644 --- a/src/transform/src/redundant_join.rs +++ b/src/transform/src/redundant_join.rs @@ -246,7 +246,7 @@ impl RedundantJoin { *c -= old_input_mapper.input_arity(remove_input_idx); } } - }); + })?; } }