Skip to content

Commit

Permalink
Use a ParamEnvAnd<Predicate> for caching in ObligationForest
Browse files Browse the repository at this point in the history
Previously, we used a plain `Predicate` to cache results (e.g. successes
and failures) in ObligationForest. However, fulfillment depends on the
precise `ParamEnv` used, so this is unsound in general.

This commit changes the impl of `ForestObligation` for
`PendingPredicateObligation` to use `ParamEnvAnd<Predicate>` instead of
`Predicate` for the associated type. The associated type and method are
renamed from 'predicate' to 'cache_key' to reflect the fact that type is
no longer just a predicate.
  • Loading branch information
Aaron1011 committed Jan 23, 2020
1 parent d1e594f commit 90afc07
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 18 deletions.
9 changes: 6 additions & 3 deletions src/librustc/traits/fulfill.rs
Expand Up @@ -18,10 +18,13 @@ use super::{FulfillmentError, FulfillmentErrorCode};
use super::{ObligationCause, PredicateObligation};

impl<'tcx> ForestObligation for PendingPredicateObligation<'tcx> {
type Predicate = ty::Predicate<'tcx>;
/// Note that we include both the `ParamEnv` and the `Predicate`,
/// as the `ParamEnv` can influence whether fulfillment succeeds
/// or fails.
type CacheKey = ty::ParamEnvAnd<'tcx, ty::Predicate<'tcx>>;

fn as_predicate(&self) -> &Self::Predicate {
&self.obligation.predicate
fn as_cache_key(&self) -> Self::CacheKey {
self.obligation.param_env.and(self.obligation.predicate)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/librustc_data_structures/obligation_forest/graphviz.rs
Expand Up @@ -51,7 +51,7 @@ impl<'a, O: ForestObligation + 'a> dot::Labeller<'a> for &'a ObligationForest<O>

fn node_label(&self, index: &Self::Node) -> dot::LabelText<'_> {
let node = &self.nodes[*index];
let label = format!("{:?} ({:?})", node.obligation.as_predicate(), node.state.get());
let label = format!("{:?} ({:?})", node.obligation.as_cache_key(), node.state.get());

dot::LabelText::LabelStr(label.into())
}
Expand Down
29 changes: 17 additions & 12 deletions src/librustc_data_structures/obligation_forest/mod.rs
Expand Up @@ -86,9 +86,13 @@ mod graphviz;
mod tests;

pub trait ForestObligation: Clone + Debug {
type Predicate: Clone + hash::Hash + Eq + Debug;
type CacheKey: Clone + hash::Hash + Eq + Debug;

fn as_predicate(&self) -> &Self::Predicate;
/// Converts this `ForestObligation` suitable for use as a cache key.
/// If two distinct `ForestObligations`s return the same cache key,
/// then it must be sound to use the result of processing one obligation
/// (e.g. success for error) for the other obligation
fn as_cache_key(&self) -> Self::CacheKey;
}

pub trait ObligationProcessor {
Expand Down Expand Up @@ -138,12 +142,12 @@ pub struct ObligationForest<O: ForestObligation> {
nodes: Vec<Node<O>>,

/// A cache of predicates that have been successfully completed.
done_cache: FxHashSet<O::Predicate>,
done_cache: FxHashSet<O::CacheKey>,

/// A cache of the nodes in `nodes`, indexed by predicate. Unfortunately,
/// its contents are not guaranteed to match those of `nodes`. See the
/// comments in `process_obligation` for details.
active_cache: FxHashMap<O::Predicate, usize>,
active_cache: FxHashMap<O::CacheKey, usize>,

/// A vector reused in compress(), to avoid allocating new vectors.
node_rewrites: RefCell<Vec<usize>>,
Expand All @@ -157,7 +161,7 @@ pub struct ObligationForest<O: ForestObligation> {
/// See [this][details] for details.
///
/// [details]: https://github.com/rust-lang/rust/pull/53255#issuecomment-421184780
error_cache: FxHashMap<ObligationTreeId, FxHashSet<O::Predicate>>,
error_cache: FxHashMap<ObligationTreeId, FxHashSet<O::CacheKey>>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -305,11 +309,12 @@ impl<O: ForestObligation> ObligationForest<O> {

// Returns Err(()) if we already know this obligation failed.
fn register_obligation_at(&mut self, obligation: O, parent: Option<usize>) -> Result<(), ()> {
if self.done_cache.contains(obligation.as_predicate()) {
if self.done_cache.contains(&obligation.as_cache_key()) {
debug!("register_obligation_at: ignoring already done obligation: {:?}", obligation);
return Ok(());
}

match self.active_cache.entry(obligation.as_predicate().clone()) {
match self.active_cache.entry(obligation.as_cache_key().clone()) {
Entry::Occupied(o) => {
let node = &mut self.nodes[*o.get()];
if let Some(parent_index) = parent {
Expand All @@ -333,7 +338,7 @@ impl<O: ForestObligation> ObligationForest<O> {
&& self
.error_cache
.get(&obligation_tree_id)
.map(|errors| errors.contains(obligation.as_predicate()))
.map(|errors| errors.contains(&obligation.as_cache_key()))
.unwrap_or(false);

if already_failed {
Expand Down Expand Up @@ -380,7 +385,7 @@ impl<O: ForestObligation> ObligationForest<O> {
self.error_cache
.entry(node.obligation_tree_id)
.or_default()
.insert(node.obligation.as_predicate().clone());
.insert(node.obligation.as_cache_key().clone());
}

/// Performs a pass through the obligation list. This must
Expand Down Expand Up @@ -618,11 +623,11 @@ impl<O: ForestObligation> ObligationForest<O> {
// `self.nodes`. See the comment in `process_obligation`
// for more details.
if let Some((predicate, _)) =
self.active_cache.remove_entry(node.obligation.as_predicate())
self.active_cache.remove_entry(&node.obligation.as_cache_key())
{
self.done_cache.insert(predicate);
} else {
self.done_cache.insert(node.obligation.as_predicate().clone());
self.done_cache.insert(node.obligation.as_cache_key().clone());
}
if do_completed == DoCompleted::Yes {
// Extract the success stories.
Expand All @@ -635,7 +640,7 @@ impl<O: ForestObligation> ObligationForest<O> {
// We *intentionally* remove the node from the cache at this point. Otherwise
// tests must come up with a different type on every type error they
// check against.
self.active_cache.remove(node.obligation.as_predicate());
self.active_cache.remove(&node.obligation.as_cache_key());
self.insert_into_error_cache(index);
node_rewrites[index] = orig_nodes_len;
dead_nodes += 1;
Expand Down
4 changes: 2 additions & 2 deletions src/librustc_data_structures/obligation_forest/tests.rs
Expand Up @@ -4,9 +4,9 @@ use std::fmt;
use std::marker::PhantomData;

impl<'a> super::ForestObligation for &'a str {
type Predicate = &'a str;
type CacheKey = &'a str;

fn as_predicate(&self) -> &Self::Predicate {
fn as_cache_key(&self) -> Self::CacheKey {
self
}
}
Expand Down

0 comments on commit 90afc07

Please sign in to comment.