Skip to content

Commit

Permalink
Merge pull request #9117 from benesch/stack-guards
Browse files Browse the repository at this point in the history
sql: avoid stack overflow on deeply nested ASTs
  • Loading branch information
benesch committed Nov 17, 2021
2 parents ca09c85 + c8c5069 commit 63c0639
Show file tree
Hide file tree
Showing 9 changed files with 288 additions and 69 deletions.
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion src/ore/Cargo.toml
Expand Up @@ -7,10 +7,11 @@ edition = "2021"
publish = false

[features]
default = ["network", "chrono", "cli", "metrics", "test"]
default = ["network", "chrono", "cli", "metrics", "stack", "test"]
cli = ["structopt"]
metrics = ["prometheus"]
network = ["async-trait", "bytes", "futures", "openssl", "smallvec", "tokio-openssl", "tokio"]
stack = ["stacker"]
test = ["anyhow", "ctor", "tracing-subscriber"]

# NB: ore is meant to be an extension of the Rust stdlib. To keep it
Expand All @@ -34,6 +35,7 @@ openssl = { version = "0.10.38", features = ["vendored"], optional = true }
pin-project = "1"
prometheus = { git = "https://github.com/MaterializeInc/rust-prometheus.git", default-features = false, optional = true }
smallvec = { version = "1.7.0", optional = true }
stacker = { version = "0.1.14", optional = true }
structopt = { version = "0.3.25", optional = true }
tokio = { version = "1.13.0", features = ["io-util", "net", "rt-multi-thread", "time"], optional = true }
tokio-openssl = { version = "0.6.3", optional = true }
Expand Down
11 changes: 11 additions & 0 deletions src/ore/src/lib.rs
Expand Up @@ -20,37 +20,48 @@
//! small to warrant their own crate.

// Crate-wide lints: warn about missing documentation and about missing
// `Debug` implementations.
#![warn(missing_docs, missing_debug_implementations)]
// When the `nightly_doc_features` cfg is set, turn on the `doc_cfg` feature
// that the `doc(cfg(...))` attributes below depend on.
#![cfg_attr(nightly_doc_features, feature(doc_cfg))]

// Feature-gated modules below each carry a pair of attributes:
// `#[cfg(feature = "...")]` compiles the module only when that Cargo feature
// is enabled, and the `#[cfg_attr(nightly_doc_features, doc(cfg(...)))]` line
// attaches the matching `doc(cfg)` annotation when `nightly_doc_features` is
// set. Modules without these attributes are always compiled.
#[cfg_attr(nightly_doc_features, doc(cfg(feature = "test")))]
#[cfg(feature = "test")]
pub mod assert;
pub mod cast;
pub mod cgroup;
#[cfg_attr(nightly_doc_features, doc(cfg(feature = "cli")))]
#[cfg(feature = "cli")]
pub mod cli;
pub mod codegen;
pub mod collections;
pub mod display;
pub mod env;
pub mod fmt;
#[cfg_attr(nightly_doc_features, doc(cfg(feature = "network")))]
#[cfg(feature = "network")]
pub mod future;
pub mod hash;
pub mod hint;
pub mod id_gen;
pub mod iter;
pub mod lex;
#[cfg_attr(nightly_doc_features, doc(cfg(feature = "metrics")))]
#[cfg(feature = "metrics")]
pub mod metrics;
#[cfg_attr(nightly_doc_features, doc(cfg(feature = "network")))]
#[cfg(feature = "network")]
pub mod netio;
pub mod now;
pub mod option;
pub mod panic;
pub mod result;
#[cfg_attr(nightly_doc_features, doc(cfg(feature = "network")))]
#[cfg(feature = "network")]
pub mod retry;
// The `stack` module is gated on the `stack` feature.
#[cfg_attr(nightly_doc_features, doc(cfg(feature = "stack")))]
#[cfg(feature = "stack")]
pub mod stack;
pub mod stats;
pub mod str;
#[cfg_attr(nightly_doc_features, doc(cfg(feature = "test")))]
#[cfg(feature = "test")]
pub mod test;
pub mod thread;
Expand Down
204 changes: 204 additions & 0 deletions src/ore/src/stack.rs
@@ -0,0 +1,204 @@
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License in the LICENSE file at the
// root of this repository, or online at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Stack management utilities.

use std::cell::RefCell;
use std::error::Error;
use std::fmt;

/// Minimum number of bytes that must still be free on the current stack for
/// [`maybe_grow`] to run the supplied closure in place, without allocating a
/// new stack.
pub const STACK_RED_ZONE: usize = 32 * 1024; // 32 KiB

/// Size, in bytes, of every freshly allocated stack. The value was picked to
/// match the default stack size for threads in Rust.
pub const STACK_SIZE: usize = 2 * 1024 * 1024; // 2 MiB

/// Grows the stack if necessary before invoking `f`.
///
/// This function is intended to be called at manually instrumented points in a
/// program where arbitrarily deep recursion is known to happen. It checks
/// whether execution is within `STACK_RED_ZONE` bytes of the end of the
/// stack, and if so allocates a new stack of at least `STACK_SIZE` bytes.
///
/// The closure `f` is guaranteed to run on a stack with at least
/// `STACK_RED_ZONE` bytes, and it runs on the current stack if there is
/// space available. The value returned by `f` is returned unchanged.
///
/// It is generally better to use [`CheckedRecursion`] to enforce a limit on the
/// stack growth. Not all recursive code paths support returning errors,
/// however, in which case unconditionally growing the stack with this function
/// is still preferable to panicking.
pub fn maybe_grow<F, R>(f: F) -> R
where
    F: FnOnce() -> R,
{
    // `f` already has the `FnOnce() -> R` shape that `stacker::maybe_grow`
    // accepts, so it is handed over directly rather than being wrapped in a
    // second, redundant closure.
    stacker::maybe_grow(STACK_RED_ZONE, STACK_SIZE, f)
}

/// Bounded recursion for types that need protection against stack overflow.
///
/// The design is deliberately a little unusual: it lets checked recursion be
/// retrofitted onto existing mutually recursive functions without threading
/// an explicit `depth: &mut usize` parameter through every one of them. If a
/// context structure already exists, or the mutually recursive functions are
/// methods on such a structure, the [`RecursionGuard`] can simply be stored
/// inside it.
///
/// # Examples
///
/// Take a simple expression evaluator:
///
/// ```
/// # use std::collections::HashMap;
///
/// enum Expr {
///     Var { name: String },
///     Add { left: Box<Expr>, right: Box<Expr> },
/// }
///
/// struct Evaluator {
///     vars: HashMap<String, i64>,
/// }
///
/// impl Evaluator {
///     fn eval(&mut self, expr: &Expr) -> i64 {
///         match expr {
///             Expr::Var { name } => self.vars[name],
///             Expr::Add { left, right } => self.eval(left) + self.eval(right),
///         }
///     }
/// }
/// ```
///
/// With a sufficiently large `expr`, calling `eval` could overflow the stack
/// and crash. `CheckedRecursion` is designed for exactly this situation:
///
/// ```
/// # use std::collections::HashMap;
/// # enum Expr {
/// #     Var { name: String },
/// #     Add { left: Box<Expr>, right: Box<Expr> },
/// # }
/// use ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError};
///
/// struct Evaluator {
///     vars: HashMap<String, i64>,
///     recursion_guard: RecursionGuard,
/// }
///
/// impl Evaluator {
///     fn eval(&mut self, expr: &Expr) -> Result<i64, RecursionLimitError> {
///         // ADDED: call to `self.checked_recur_mut`.
///         self.checked_recur_mut(|e| match expr {
///             Expr::Var { name } => Ok(e.vars[name]),
///             Expr::Add { left, right } => Ok(e.eval(left)? + e.eval(right)?),
///         })
///     }
/// }
///
/// impl CheckedRecursion for Evaluator {
///     fn recursion_guard(&self) -> &RecursionGuard {
///         &self.recursion_guard
///     }
/// }
/// ```
pub trait CheckedRecursion {
    /// Returns a reference to the recursion guard embedded within the type.
    fn recursion_guard(&self) -> &RecursionGuard;

    /// Calls `f` if it is safe to recur one level deeper.
    ///
    /// When the guard returned by [`CheckedRecursion::recursion_guard`] has
    /// already reached its recursion limit, a `RecursionLimitError`
    /// (converted into `E`) is returned and `f` is not called. Otherwise `f`
    /// is invoked, with the stack grown first if necessary, and its result is
    /// returned.
    ///
    /// Calls to this function must be inserted by hand at every point where
    /// mutual recursion occurs.
    fn checked_recur<F, T, E>(&self, f: F) -> Result<T, E>
    where
        F: FnOnce(&Self) -> Result<T, E>,
        E: From<RecursionLimitError>,
    {
        // Refuse to go deeper once the limit has been hit.
        if let Err(limit_reached) = self.recursion_guard().descend() {
            return Err(E::from(limit_reached));
        }
        let result = maybe_grow(|| f(self));
        // Undo the depth increment whether `f` returned `Ok` or `Err`.
        self.recursion_guard().ascend();
        result
    }

    /// Same as [`CheckedRecursion::checked_recur`], except that `f` receives
    /// a mutable reference to `Self`.
    fn checked_recur_mut<F, T, E>(&mut self, f: F) -> Result<T, E>
    where
        F: FnOnce(&mut Self) -> Result<T, E>,
        E: From<RecursionLimitError>,
    {
        // Refuse to go deeper once the limit has been hit.
        if let Err(limit_reached) = self.recursion_guard().descend() {
            return Err(E::from(limit_reached));
        }
        let result = maybe_grow(|| f(self));
        // Undo the depth increment whether `f` returned `Ok` or `Err`.
        self.recursion_guard().ascend();
        result
    }
}

/// Keeps track of the current recursion depth together with the maximum depth
/// that is permitted.
///
/// See the [`CheckedRecursion`] trait for usage instructions.
#[derive(Clone, Debug, Default)]
pub struct RecursionGuard {
    // Current depth. Held in a `RefCell` so it can be updated through `&self`.
    depth: RefCell<usize>,
    // Depth at which further descent is refused.
    limit: usize,
}

impl RecursionGuard {
    /// Creates a recursion guard that starts at depth zero and enforces the
    /// given recursion limit.
    pub fn with_limit(limit: usize) -> RecursionGuard {
        let depth = RefCell::new(0);
        RecursionGuard { depth, limit }
    }

    /// Records entry into one more level of recursion.
    ///
    /// Returns a [`RecursionLimitError`], leaving the depth untouched, when
    /// the current depth is not below the limit.
    fn descend(&self) -> Result<(), RecursionLimitError> {
        let mut current = self.depth.borrow_mut();
        if *current >= self.limit {
            return Err(RecursionLimitError);
        }
        *current += 1;
        Ok(())
    }

    /// Records exit from one level of recursion.
    fn ascend(&self) {
        let mut current = self.depth.borrow_mut();
        *current -= 1;
    }
}

/// The error produced when a [`RecursionGuard`]'s recursion limit has been
/// reached.
#[derive(Debug, Clone)]
pub struct RecursionLimitError;

impl fmt::Display for RecursionLimitError {
    /// Writes the fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "recursion limit exceeded")
    }
}

impl Error for RecursionLimitError {}
1 change: 0 additions & 1 deletion src/sql-parser/Cargo.toml
Expand Up @@ -14,7 +14,6 @@ lazy_static = "1.4.0"
log = "0.4.13"
ore = { path = "../ore", default-features = false }
phf = { version = "0.10.0", features = ["uncased"] }
stacker = "0.1.14"
uncased = "0.9.6"

[dev-dependencies]
Expand Down

0 comments on commit 63c0639

Please sign in to comment.