diff --git a/compiler/rustc_hir/src/hir.rs b/compiler/rustc_hir/src/hir.rs index 23999a8dca00f..ad681115812cb 100644 --- a/compiler/rustc_hir/src/hir.rs +++ b/compiler/rustc_hir/src/hir.rs @@ -760,9 +760,9 @@ pub struct Pat<'hir> { pub default_binding_modes: bool, } -impl Pat<'_> { +impl<'hir> Pat<'hir> { // FIXME(#19596) this is a workaround, but there should be a better way - fn walk_short_(&self, it: &mut impl FnMut(&Pat<'_>) -> bool) -> bool { + fn walk_short_(&self, it: &mut impl FnMut(&Pat<'hir>) -> bool) -> bool { if !it(self) { return false; } @@ -785,12 +785,12 @@ impl Pat<'_> { /// Note that when visiting e.g. `Tuple(ps)`, /// if visiting `ps[0]` returns `false`, /// then `ps[1]` will not be visited. - pub fn walk_short(&self, mut it: impl FnMut(&Pat<'_>) -> bool) -> bool { + pub fn walk_short(&self, mut it: impl FnMut(&Pat<'hir>) -> bool) -> bool { self.walk_short_(&mut it) } // FIXME(#19596) this is a workaround, but there should be a better way - fn walk_(&self, it: &mut impl FnMut(&Pat<'_>) -> bool) { + fn walk_(&self, it: &mut impl FnMut(&Pat<'hir>) -> bool) { if !it(self) { return; } @@ -810,7 +810,7 @@ impl Pat<'_> { /// Walk the pattern in left-to-right order. /// /// If `it(pat)` returns `false`, the children are not visited. - pub fn walk(&self, mut it: impl FnMut(&Pat<'_>) -> bool) { + pub fn walk(&self, mut it: impl FnMut(&Pat<'hir>) -> bool) { self.walk_(&mut it) } diff --git a/compiler/rustc_typeck/src/collect/type_of.rs b/compiler/rustc_typeck/src/collect/type_of.rs index 88ba5788b05d1..3c97b55005c44 100644 --- a/compiler/rustc_typeck/src/collect/type_of.rs +++ b/compiler/rustc_typeck/src/collect/type_of.rs @@ -6,7 +6,7 @@ use rustc_hir::def::{DefKind, Res}; use rustc_hir::def_id::{DefId, LocalDefId}; use rustc_hir::intravisit; use rustc_hir::intravisit::Visitor; -use rustc_hir::Node; +use rustc_hir::{HirId, Node}; use rustc_middle::hir::map::Map; use rustc_middle::ty::subst::{GenericArgKind, InternalSubsts}; use rustc_middle::ty::util::IntTypeExt; @@ -22,7 +22,6 @@ use super::{bad_placeholder_type, is_suggestable_infer_ty}; /// This should be called using the query `tcx.opt_const_param_of`. pub(super) fn opt_const_param_of(tcx: TyCtxt<'_>, def_id: LocalDefId) -> Option { use hir::*; - let hir_id = tcx.hir().local_def_id_to_hir_id(def_id); if let Node::AnonConst(_) = tcx.hir().get(hir_id) { @@ -62,9 +61,9 @@ pub(super) fn opt_const_param_of(tcx: TyCtxt<'_>, def_id: LocalDefId) -> Option< } Node::Ty(&Ty { kind: TyKind::Path(_), .. }) - | Node::Expr(&Expr { kind: ExprKind::Struct(..), .. }) - | Node::Expr(&Expr { kind: ExprKind::Path(_), .. }) - | Node::TraitRef(..) => { + | Node::Expr(&Expr { kind: ExprKind::Path(_) | ExprKind::Struct(..), .. }) + | Node::TraitRef(..) + | Node::Pat(_) => { let path = match parent_node { Node::Ty(&Ty { kind: TyKind::Path(QPath::Resolved(_, path)), .. }) | Node::TraitRef(&TraitRef { path, .. }) => &*path, @@ -79,6 +78,20 @@ pub(super) fn opt_const_param_of(tcx: TyCtxt<'_>, def_id: LocalDefId) -> Option< let _tables = tcx.typeck(body_owner); &*path } + Node::Pat(pat) => { + if let Some(path) = get_path_containing_arg_in_pat(pat, hir_id) { + path + } else { + tcx.sess.delay_span_bug( + tcx.def_span(def_id), + &format!( + "unable to find const parent for {} in pat {:?}", + hir_id, pat + ), + ); + return None; + } + } _ => { tcx.sess.delay_span_bug( tcx.def_span(def_id), @@ -91,7 +104,6 @@ pub(super) fn opt_const_param_of(tcx: TyCtxt<'_>, def_id: LocalDefId) -> Option< // We've encountered an `AnonConst` in some path, so we need to // figure out which generic parameter it corresponds to and return // the relevant type. - let (arg_index, segment) = path .segments .iter() @@ -144,6 +156,34 @@ pub(super) fn opt_const_param_of(tcx: TyCtxt<'_>, def_id: LocalDefId) -> Option< } } +fn get_path_containing_arg_in_pat<'hir>( + pat: &'hir hir::Pat<'hir>, + arg_id: HirId, +) -> Option<&'hir hir::Path<'hir>> { + use hir::*; + + let is_arg_in_path = |p: &hir::Path<'_>| { + p.segments + .iter() + .filter_map(|seg| seg.args) + .flat_map(|args| args.args) + .any(|arg| arg.id() == arg_id) + }; + let mut arg_path = None; + pat.walk(|pat| match pat.kind { + PatKind::Struct(QPath::Resolved(_, path), _, _) + | PatKind::TupleStruct(QPath::Resolved(_, path), _, _) + | PatKind::Path(QPath::Resolved(_, path)) + if is_arg_in_path(path) => + { + arg_path = Some(path); + false + } + _ => true, + }); + arg_path +} + pub(super) fn type_of(tcx: TyCtxt<'_>, def_id: DefId) -> Ty<'_> { let def_id = def_id.expect_local(); use rustc_hir::*; diff --git a/src/test/ui/const-generics/arg-in-pat-1.rs b/src/test/ui/const-generics/arg-in-pat-1.rs new file mode 100644 index 0000000000000..82555084e418f --- /dev/null +++ b/src/test/ui/const-generics/arg-in-pat-1.rs @@ -0,0 +1,23 @@ +// check-pass +enum ConstGenericEnum { + Foo([i32; N]), + Bar, +} + +fn foo(val: &ConstGenericEnum) { + if let ConstGenericEnum::::Foo(field, ..) = val {} +} + +fn bar(val: &ConstGenericEnum) { + match val { + ConstGenericEnum::::Foo(field, ..) => (), + ConstGenericEnum::::Bar => (), + } +} + +fn main() { + match ConstGenericEnum::Bar { + ConstGenericEnum::<3>::Foo(field, ..) => (), + ConstGenericEnum::<3>::Bar => (), + } +} diff --git a/src/test/ui/const-generics/arg-in-pat-2.rs b/src/test/ui/const-generics/arg-in-pat-2.rs new file mode 100644 index 0000000000000..dc9e722eda84c --- /dev/null +++ b/src/test/ui/const-generics/arg-in-pat-2.rs @@ -0,0 +1,10 @@ +// check-pass +enum Generic { + Variant, +} + +fn main() { + match todo!() { + Generic::<0usize>::Variant => todo!() + } +} diff --git a/src/test/ui/const-generics/arg-in-pat-3.rs b/src/test/ui/const-generics/arg-in-pat-3.rs new file mode 100644 index 0000000000000..24626a3b68ae5 --- /dev/null +++ b/src/test/ui/const-generics/arg-in-pat-3.rs @@ -0,0 +1,43 @@ +// check-pass +struct Foo; + +fn bindingp() { + match Foo { + mut x @ Foo::<3> => { + let ref mut _x @ Foo::<3> = x; + } + } +} + +struct Bar { + field: Foo, +} + +fn structp() { + match todo!() { + Bar::<3> { + field: Foo::<3>, + } => (), + } +} + +struct Baz(Foo); + +fn tuplestructp() { + match Baz(Foo) { + Baz::<3>(Foo::<3>) => (), + } +} + +impl Baz { + const ASSOC: usize = 3; +} + +fn pathp() { + match 3 { + Baz::<3>::ASSOC => (), + _ => (), + } +} + +fn main() {}