Skip to content

Commit

Permalink
Don't go through TraitRef to relate projections
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewjasper committed Feb 13, 2021
1 parent 9bbd3e0 commit 0bf1d73
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 18 deletions.
25 changes: 24 additions & 1 deletion compiler/rustc_infer/src/infer/at.rs
Expand Up @@ -55,6 +55,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {

pub trait ToTrace<'tcx>: Relate<'tcx> + Copy {
fn to_trace(
tcx: TyCtxt<'tcx>,
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
Expand Down Expand Up @@ -178,7 +179,7 @@ impl<'a, 'tcx> At<'a, 'tcx> {
where
T: ToTrace<'tcx>,
{
let trace = ToTrace::to_trace(self.cause, a_is_expected, a, b);
let trace = ToTrace::to_trace(self.infcx.tcx, self.cause, a_is_expected, a, b);
Trace { at: self, trace, a_is_expected }
}
}
Expand Down Expand Up @@ -251,6 +252,7 @@ impl<'a, 'tcx> Trace<'a, 'tcx> {

impl<'tcx> ToTrace<'tcx> for Ty<'tcx> {
fn to_trace(
_: TyCtxt<'tcx>,
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
Expand All @@ -262,6 +264,7 @@ impl<'tcx> ToTrace<'tcx> for Ty<'tcx> {

impl<'tcx> ToTrace<'tcx> for ty::Region<'tcx> {
fn to_trace(
_: TyCtxt<'tcx>,
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
Expand All @@ -273,6 +276,7 @@ impl<'tcx> ToTrace<'tcx> for ty::Region<'tcx> {

impl<'tcx> ToTrace<'tcx> for &'tcx Const<'tcx> {
fn to_trace(
_: TyCtxt<'tcx>,
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
Expand All @@ -284,6 +288,7 @@ impl<'tcx> ToTrace<'tcx> for &'tcx Const<'tcx> {

impl<'tcx> ToTrace<'tcx> for ty::TraitRef<'tcx> {
fn to_trace(
_: TyCtxt<'tcx>,
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
Expand All @@ -298,6 +303,7 @@ impl<'tcx> ToTrace<'tcx> for ty::TraitRef<'tcx> {

impl<'tcx> ToTrace<'tcx> for ty::PolyTraitRef<'tcx> {
fn to_trace(
_: TyCtxt<'tcx>,
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
Expand All @@ -309,3 +315,20 @@ impl<'tcx> ToTrace<'tcx> for ty::PolyTraitRef<'tcx> {
}
}
}

impl<'tcx> ToTrace<'tcx> for ty::ProjectionTy<'tcx> {
fn to_trace(
tcx: TyCtxt<'tcx>,
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
b: Self,
) -> TypeTrace<'tcx> {
let a_ty = tcx.mk_projection(a.item_def_id, a.substs);
let b_ty = tcx.mk_projection(b.item_def_id, b.substs);
TypeTrace {
cause: cause.clone(),
values: Types(ExpectedFound::new(a_is_expected, a_ty, b_ty)),
}
}
}
15 changes: 7 additions & 8 deletions compiler/rustc_trait_selection/src/traits/project.rs
Expand Up @@ -921,8 +921,7 @@ fn assemble_candidates_from_predicates<'cx, 'tcx>(
&& infcx.probe(|_| {
selcx.match_projection_projections(
obligation,
obligation_trait_ref,
&data,
data,
potentially_unnormalized_candidates,
)
});
Expand Down Expand Up @@ -1344,25 +1343,25 @@ fn confirm_param_env_candidate<'cx, 'tcx>(
poly_cache_entry,
);

let cache_trait_ref = cache_entry.projection_ty.trait_ref(infcx.tcx);
let obligation_trait_ref = obligation.predicate.trait_ref(infcx.tcx);
let cache_projection = cache_entry.projection_ty;
let obligation_projection = obligation.predicate;
let mut nested_obligations = Vec::new();
let cache_trait_ref = if potentially_unnormalized_candidate {
let cache_projection = if potentially_unnormalized_candidate {
ensure_sufficient_stack(|| {
normalize_with_depth_to(
selcx,
obligation.param_env,
obligation.cause.clone(),
obligation.recursion_depth + 1,
cache_trait_ref,
cache_projection,
&mut nested_obligations,
)
})
} else {
cache_trait_ref
cache_projection
};

match infcx.at(cause, param_env).eq(cache_trait_ref, obligation_trait_ref) {
match infcx.at(cause, param_env).eq(cache_projection, obligation_projection) {
Ok(InferOk { value: _, obligations }) => {
nested_obligations.extend(obligations);
assoc_ty_own_obligations(selcx, obligation, &mut nested_obligations);
Expand Down
20 changes: 11 additions & 9 deletions compiler/rustc_trait_selection/src/traits/select/mod.rs
Expand Up @@ -32,6 +32,7 @@ use rustc_errors::ErrorReported;
use rustc_hir as hir;
use rustc_hir::def_id::DefId;
use rustc_hir::Constness;
use rustc_infer::infer::LateBoundRegionConversionTime;
use rustc_middle::dep_graph::{DepKind, DepNodeIndex};
use rustc_middle::mir::interpret::ErrorHandled;
use rustc_middle::ty::fast_reject;
Expand Down Expand Up @@ -1254,32 +1255,33 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
pub(super) fn match_projection_projections(
&mut self,
obligation: &ProjectionTyObligation<'tcx>,
obligation_trait_ref: &ty::TraitRef<'tcx>,
data: &PolyProjectionPredicate<'tcx>,
env_predicate: PolyProjectionPredicate<'tcx>,
potentially_unnormalized_candidates: bool,
) -> bool {
let mut nested_obligations = Vec::new();
let projection_ty = if potentially_unnormalized_candidates {
let (infer_predicate, _) = self.infcx.replace_bound_vars_with_fresh_vars(
obligation.cause.span,
LateBoundRegionConversionTime::HigherRankedType,
env_predicate,
);
let infer_projection = if potentially_unnormalized_candidates {
ensure_sufficient_stack(|| {
project::normalize_with_depth_to(
self,
obligation.param_env,
obligation.cause.clone(),
obligation.recursion_depth + 1,
data.map_bound(|data| data.projection_ty),
infer_predicate.projection_ty,
&mut nested_obligations,
)
})
} else {
data.map_bound(|data| data.projection_ty)
infer_predicate.projection_ty
};

// FIXME(generic_associated_types): Compare the whole projections
let data_poly_trait_ref = projection_ty.map_bound(|proj| proj.trait_ref(self.tcx()));
let obligation_poly_trait_ref = ty::Binder::dummy(*obligation_trait_ref);
self.infcx
.at(&obligation.cause, obligation.param_env)
.sup(obligation_poly_trait_ref, data_poly_trait_ref)
.sup(obligation.predicate, infer_projection)
.map_or(false, |InferOk { obligations, value: () }| {
self.evaluate_predicates_recursively(
TraitObligationStackList::empty(&ProvisionalEvaluationCache::default()),
Expand Down

0 comments on commit 0bf1d73

Please sign in to comment.