From 3b0c2585c8daec44877a38fd83edaf1888afe7f9 Mon Sep 17 00:00:00 2001
From: Josh Stone <jistone@redhat.com>
Date: Wed, 26 Feb 2025 20:47:53 -0800
Subject: [PATCH] Convert `ShardedHashMap` to use `hashbrown::HashTable`

The `hash_raw_entry` feature has finished fcp-close (its final comment
period ended with a decision to remove the feature rather than
stabilize it), so the compiler should stop using it to unblock that
removal. Several `Sharded` maps were using raw entries to avoid
re-hashing between the shard lookup and the map lookup, and we can do
the same with `hashbrown::HashTable`, whose API accepts pre-computed
hashes directly.
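
For illustration (a minimal sketch, not compiler code): hash once,
derive the shard index from that hash, then pass the same hash to the
table probe. `TwoShards` and std's `RandomState` below stand in for
the real `Sharded` machinery and `FxHasher`:

    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash};
    use std::mem;

    use hashbrown::hash_table::{Entry, HashTable};

    struct TwoShards<K, V> {
        hasher: RandomState,
        shards: [HashTable<(K, V)>; 2],
    }

    impl<K: Eq + Hash, V> TwoShards<K, V> {
        fn insert(&mut self, key: K, value: V) -> Option<V> {
            // Hash once with the map-wide hasher...
            let hash = self.hasher.hash_one(&key);
            let hasher = &self.hasher;
            // ...use the low bit to pick a shard...
            let shard = &mut self.shards[(hash as usize) & 1];
            // ...and reuse the hash for the probe, so nothing is
            // hashed twice. The third closure re-hashes stored
            // entries and is only called if the table grows.
            match shard.entry(hash, |e| e.0 == key, |e| hasher.hash_one(&e.0)) {
                Entry::Occupied(e) => Some(mem::replace(&mut e.into_mut().1, value)),
                Entry::Vacant(e) => {
                    e.insert((key, value));
                    None
                }
            }
        }
    }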
---
 Cargo.lock                                    |  2 +
 compiler/rustc_data_structures/Cargo.toml     |  5 +
 compiler/rustc_data_structures/src/lib.rs     |  1 -
 compiler/rustc_data_structures/src/marker.rs  |  2 +
 compiler/rustc_data_structures/src/sharded.rs | 95 +++++++++++++++----
 .../rustc_middle/src/mir/interpret/mod.rs     | 20 +---
 compiler/rustc_middle/src/ty/context.rs       |  4 +-
 .../rustc_query_system/src/dep_graph/graph.rs | 31 +++---
 compiler/rustc_query_system/src/lib.rs        |  1 -
 .../rustc_query_system/src/query/caches.rs    | 14 +--
 10 files changed, 109 insertions(+), 66 deletions(-)
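
Reviewer note (not part of the commit message): the new
`ShardedHashMap` methods hash the key once and lock only the owning
shard. A hypothetical caller, with `u64` keys and `u32` values
standing in for real compiler types:

    let map: ShardedHashMap<u64, u32> = ShardedHashMap::with_capacity(16);
    let v = map.get_or_insert_with(42, || 7); // inserts and returns 7
    assert_eq!(v, 7);
    assert_eq!(map.get(&42), Some(7));
    assert_eq!(map.insert(42, 8), Some(7)); // replaces, returning the old value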

diff --git a/Cargo.lock b/Cargo.lock
index ecd26d6199c90..0ae5798d78a48 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1491,6 +1491,7 @@ version = "0.15.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289"
 dependencies = [
+ "allocator-api2",
  "foldhash",
  "serde",
 ]
@@ -3492,6 +3493,7 @@ dependencies = [
  "either",
  "elsa",
  "ena",
+ "hashbrown 0.15.2",
  "indexmap",
  "jobserver",
  "libc",
diff --git a/compiler/rustc_data_structures/Cargo.toml b/compiler/rustc_data_structures/Cargo.toml
index bdf5494f2107b..b9a5fc3a1fe9c 100644
--- a/compiler/rustc_data_structures/Cargo.toml
+++ b/compiler/rustc_data_structures/Cargo.toml
@@ -29,6 +29,11 @@ thin-vec = "0.2.12"
 tracing = "0.1"
 # tidy-alphabetical-end
 
+[dependencies.hashbrown]
+version = "0.15.2"
+default-features = false
+features = ["nightly"] # for may_dangle
+
 [dependencies.parking_lot]
 version = "0.12"
 
diff --git a/compiler/rustc_data_structures/src/lib.rs b/compiler/rustc_data_structures/src/lib.rs
index a3b62b469196a..865424fd6bbdc 100644
--- a/compiler/rustc_data_structures/src/lib.rs
+++ b/compiler/rustc_data_structures/src/lib.rs
@@ -24,7 +24,6 @@
 #![feature(dropck_eyepatch)]
 #![feature(extend_one)]
 #![feature(file_buffered)]
-#![feature(hash_raw_entry)]
 #![feature(macro_metavar_expr)]
 #![feature(map_try_insert)]
 #![feature(min_specialization)]
diff --git a/compiler/rustc_data_structures/src/marker.rs b/compiler/rustc_data_structures/src/marker.rs
index 4cacc269709a9..64c64bfa3c296 100644
--- a/compiler/rustc_data_structures/src/marker.rs
+++ b/compiler/rustc_data_structures/src/marker.rs
@@ -76,6 +76,7 @@ impl_dyn_send!(
     [crate::sync::RwLock<T> where T: DynSend]
     [crate::tagged_ptr::TaggedRef<'a, P, T> where 'a, P: Sync, T: Send + crate::tagged_ptr::Tag]
     [rustc_arena::TypedArena<T> where T: DynSend]
+    [hashbrown::HashTable<T> where T: DynSend]
     [indexmap::IndexSet<V, S> where V: DynSend, S: DynSend]
     [indexmap::IndexMap<K, V, S> where K: DynSend, V: DynSend, S: DynSend]
     [thin_vec::ThinVec<T> where T: DynSend]
@@ -153,6 +154,7 @@ impl_dyn_sync!(
     [crate::tagged_ptr::TaggedRef<'a, P, T> where 'a, P: Sync, T: Sync + crate::tagged_ptr::Tag]
     [parking_lot::lock_api::Mutex<R, T> where R: DynSync, T: ?Sized + DynSend]
     [parking_lot::lock_api::RwLock<R, T> where R: DynSync, T: ?Sized + DynSend + DynSync]
+    [hashbrown::HashTable<T> where T: DynSync]
     [indexmap::IndexSet<V, S> where V: DynSync, S: DynSync]
     [indexmap::IndexMap<K, V, S> where K: DynSync, V: DynSync, S: DynSync]
     [smallvec::SmallVec<A> where A: smallvec::Array + DynSync]
diff --git a/compiler/rustc_data_structures/src/sharded.rs b/compiler/rustc_data_structures/src/sharded.rs
index 3016348f22469..49cafcb17a03e 100644
--- a/compiler/rustc_data_structures/src/sharded.rs
+++ b/compiler/rustc_data_structures/src/sharded.rs
@@ -1,11 +1,11 @@
 use std::borrow::Borrow;
-use std::collections::hash_map::RawEntryMut;
 use std::hash::{Hash, Hasher};
-use std::iter;
+use std::{iter, mem};
 
 use either::Either;
+use hashbrown::hash_table::{Entry, HashTable};
 
-use crate::fx::{FxHashMap, FxHasher};
+use crate::fx::FxHasher;
 use crate::sync::{CacheAligned, Lock, LockGuard, Mode, is_dyn_thread_safe};
 
 // 32 shards is sufficient to reduce contention on an 8-core Ryzen 7 1700,
@@ -140,17 +140,67 @@ pub fn shards() -> usize {
     1
 }
 
-pub type ShardedHashMap<K, V> = Sharded<FxHashMap<K, V>>;
+pub type ShardedHashMap<K, V> = Sharded<HashTable<(K, V)>>;
 
 impl<K: Eq, V> ShardedHashMap<K, V> {
     pub fn with_capacity(cap: usize) -> Self {
-        Self::new(|| FxHashMap::with_capacity_and_hasher(cap, rustc_hash::FxBuildHasher::default()))
+        Self::new(|| HashTable::with_capacity(cap))
     }
     pub fn len(&self) -> usize {
         self.lock_shards().map(|shard| shard.len()).sum()
     }
 }
 
+impl<K: Eq + Hash, V> ShardedHashMap<K, V> {
+    #[inline]
+    pub fn get<Q>(&self, key: &Q) -> Option<V>
+    where
+        K: Borrow<Q>,
+        Q: Hash + Eq,
+        V: Clone,
+    {
+        let hash = make_hash(key);
+        let shard = self.lock_shard_by_hash(hash);
+        let (_, value) = shard.find(hash, |(k, _)| k.borrow() == key)?;
+        Some(value.clone())
+    }
+
+    #[inline]
+    pub fn get_or_insert_with(&self, key: K, default: impl FnOnce() -> V) -> V
+    where
+        V: Copy,
+    {
+        let hash = make_hash(&key);
+        let mut shard = self.lock_shard_by_hash(hash);
+
+        match table_entry(&mut shard, hash, &key) {
+            Entry::Occupied(e) => e.get().1,
+            Entry::Vacant(e) => {
+                let value = default();
+                e.insert((key, value));
+                value
+            }
+        }
+    }
+
+    #[inline]
+    pub fn insert(&self, key: K, value: V) -> Option<V> {
+        let hash = make_hash(&key);
+        let mut shard = self.lock_shard_by_hash(hash);
+
+        match table_entry(&mut shard, hash, &key) {
+            Entry::Occupied(e) => {
+                let previous = mem::replace(&mut e.into_mut().1, value);
+                Some(previous)
+            }
+            Entry::Vacant(e) => {
+                e.insert((key, value));
+                None
+            }
+        }
+    }
+}
+
 impl<K: Eq + Hash + Copy> ShardedHashMap<K, ()> {
     #[inline]
     pub fn intern_ref<Q: ?Sized>(&self, value: &Q, make: impl FnOnce() -> K) -> K
@@ -160,13 +210,12 @@ impl<K: Eq + Hash + Copy> ShardedHashMap<K, ()> {
     {
         let hash = make_hash(value);
         let mut shard = self.lock_shard_by_hash(hash);
-        let entry = shard.raw_entry_mut().from_key_hashed_nocheck(hash, value);
 
-        match entry {
-            RawEntryMut::Occupied(e) => *e.key(),
-            RawEntryMut::Vacant(e) => {
+        match table_entry(&mut shard, hash, value) {
+            Entry::Occupied(e) => e.get().0,
+            Entry::Vacant(e) => {
                 let v = make();
-                e.insert_hashed_nocheck(hash, v, ());
+                e.insert((v, ()));
                 v
             }
         }
@@ -180,13 +229,12 @@ impl<K: Eq + Hash + Copy> ShardedHashMap<K, ()> {
     {
         let hash = make_hash(&value);
         let mut shard = self.lock_shard_by_hash(hash);
-        let entry = shard.raw_entry_mut().from_key_hashed_nocheck(hash, &value);
 
-        match entry {
-            RawEntryMut::Occupied(e) => *e.key(),
-            RawEntryMut::Vacant(e) => {
+        match table_entry(&mut shard, hash, &value) {
+            Entry::Occupied(e) => e.get().0,
+            Entry::Vacant(e) => {
                 let v = make(value);
-                e.insert_hashed_nocheck(hash, v, ());
+                e.insert((v, ()));
                 v
             }
         }
@@ -203,17 +251,30 @@ impl<K: Eq + Hash + Copy + IntoPointer> ShardedHashMap<K, ()> {
         let hash = make_hash(&value);
         let shard = self.lock_shard_by_hash(hash);
         let value = value.into_pointer();
-        shard.raw_entry().from_hash(hash, |entry| entry.into_pointer() == value).is_some()
+        shard.find(hash, |(k, ())| k.into_pointer() == value).is_some()
     }
 }
 
 #[inline]
-pub fn make_hash<K: Hash + ?Sized>(val: &K) -> u64 {
+fn make_hash<K: Hash + ?Sized>(val: &K) -> u64 {
     let mut state = FxHasher::default();
     val.hash(&mut state);
     state.finish()
 }
 
+#[inline]
+fn table_entry<'a, K, V, Q>(
+    table: &'a mut HashTable<(K, V)>,
+    hash: u64,
+    key: &Q,
+) -> Entry<'a, (K, V)>
+where
+    K: Hash + Borrow<Q>,
+    Q: ?Sized + Eq,
+{
+    table.entry(hash, move |(k, _)| k.borrow() == key, |(k, _)| make_hash(k))
+}
+
 /// Get a shard with a pre-computed hash value. If `get_shard_by_value` is
 /// ever used in combination with `get_shard_by_hash` on a single `Sharded`
 /// instance, then `hash` must be computed with `FxHasher`. Otherwise,
diff --git a/compiler/rustc_middle/src/mir/interpret/mod.rs b/compiler/rustc_middle/src/mir/interpret/mod.rs
index 2675b7e0fc512..effedf854e473 100644
--- a/compiler/rustc_middle/src/mir/interpret/mod.rs
+++ b/compiler/rustc_middle/src/mir/interpret/mod.rs
@@ -452,12 +452,7 @@ impl<'tcx> TyCtxt<'tcx> {
         }
         let id = self.alloc_map.reserve();
         debug!("creating alloc {:?} with id {id:?}", alloc_salt.0);
-        let had_previous = self
-            .alloc_map
-            .to_alloc
-            .lock_shard_by_value(&id)
-            .insert(id, alloc_salt.0.clone())
-            .is_some();
+        let had_previous = self.alloc_map.to_alloc.insert(id, alloc_salt.0.clone()).is_some();
         // We just reserved, so should always be unique.
         assert!(!had_previous);
         dedup.insert(alloc_salt, id);
@@ -510,7 +505,7 @@ impl<'tcx> TyCtxt<'tcx> {
     /// local dangling pointers and allocations in constants/statics.
     #[inline]
     pub fn try_get_global_alloc(self, id: AllocId) -> Option<GlobalAlloc<'tcx>> {
-        self.alloc_map.to_alloc.lock_shard_by_value(&id).get(&id).cloned()
+        self.alloc_map.to_alloc.get(&id)
     }
 
     #[inline]
@@ -529,9 +524,7 @@ impl<'tcx> TyCtxt<'tcx> {
     /// Freezes an `AllocId` created with `reserve` by pointing it at an `Allocation`. Trying to
     /// call this function twice, even with the same `Allocation` will ICE the compiler.
     pub fn set_alloc_id_memory(self, id: AllocId, mem: ConstAllocation<'tcx>) {
-        if let Some(old) =
-            self.alloc_map.to_alloc.lock_shard_by_value(&id).insert(id, GlobalAlloc::Memory(mem))
-        {
+        if let Some(old) = self.alloc_map.to_alloc.insert(id, GlobalAlloc::Memory(mem)) {
             bug!("tried to set allocation ID {id:?}, but it was already existing as {old:#?}");
         }
     }
@@ -539,11 +532,8 @@ impl<'tcx> TyCtxt<'tcx> {
     /// Freezes an `AllocId` created with `reserve` by pointing it at a static item. Trying to
     /// call this function twice, even with the same `DefId` will ICE the compiler.
     pub fn set_nested_alloc_id_static(self, id: AllocId, def_id: LocalDefId) {
-        if let Some(old) = self
-            .alloc_map
-            .to_alloc
-            .lock_shard_by_value(&id)
-            .insert(id, GlobalAlloc::Static(def_id.to_def_id()))
+        if let Some(old) =
+            self.alloc_map.to_alloc.insert(id, GlobalAlloc::Static(def_id.to_def_id()))
         {
             bug!("tried to set allocation ID {id:?}, but it was already existing as {old:#?}");
         }
diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs
index edba2a2530f68..70bb902d1ae7e 100644
--- a/compiler/rustc_middle/src/ty/context.rs
+++ b/compiler/rustc_middle/src/ty/context.rs
@@ -2336,8 +2336,8 @@ macro_rules! sty_debug_print {
                 $(let mut $variant = total;)*
 
                 for shard in tcx.interners.type_.lock_shards() {
-                    let types = shard.keys();
-                    for &InternedInSet(t) in types {
+                    let types = shard.iter();
+                    for &(InternedInSet(t), ()) in types {
                         let variant = match t.internee {
                             ty::Bool | ty::Char | ty::Int(..) | ty::Uint(..) |
                                 ty::Float(..) | ty::Str | ty::Never => continue,
diff --git a/compiler/rustc_query_system/src/dep_graph/graph.rs b/compiler/rustc_query_system/src/dep_graph/graph.rs
index 6ece01c211bac..3109d53cd2caa 100644
--- a/compiler/rustc_query_system/src/dep_graph/graph.rs
+++ b/compiler/rustc_query_system/src/dep_graph/graph.rs
@@ -1,5 +1,4 @@
 use std::assert_matches::assert_matches;
-use std::collections::hash_map::Entry;
 use std::fmt::Debug;
 use std::hash::Hash;
 use std::marker::PhantomData;
@@ -9,7 +8,7 @@ use std::sync::atomic::{AtomicU32, Ordering};
 use rustc_data_structures::fingerprint::Fingerprint;
 use rustc_data_structures::fx::{FxHashMap, FxHashSet};
 use rustc_data_structures::profiling::{QueryInvocationId, SelfProfilerRef};
-use rustc_data_structures::sharded::{self, Sharded};
+use rustc_data_structures::sharded::{self, ShardedHashMap};
 use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
 use rustc_data_structures::sync::{AtomicU64, Lock};
 use rustc_data_structures::unord::UnordMap;
@@ -619,7 +618,7 @@ impl<D: Deps> DepGraphData<D> {
         if let Some(prev_index) = self.previous.node_to_index_opt(dep_node) {
             self.current.prev_index_to_index.lock()[prev_index]
         } else {
-            self.current.new_node_to_index.lock_shard_by_value(dep_node).get(dep_node).copied()
+            self.current.new_node_to_index.get(dep_node)
         }
     }
 
@@ -1048,7 +1047,7 @@ rustc_index::newtype_index! {
 /// first, and `data` second.
 pub(super) struct CurrentDepGraph<D: Deps> {
     encoder: GraphEncoder<D>,
-    new_node_to_index: Sharded<FxHashMap<DepNode, DepNodeIndex>>,
+    new_node_to_index: ShardedHashMap<DepNode, DepNodeIndex>,
     prev_index_to_index: Lock<IndexVec<SerializedDepNodeIndex, Option<DepNodeIndex>>>,
 
     /// This is used to verify that fingerprints do not change between the creation of a node
@@ -1117,12 +1116,9 @@ impl<D: Deps> CurrentDepGraph<D> {
                 profiler,
                 previous,
             ),
-            new_node_to_index: Sharded::new(|| {
-                FxHashMap::with_capacity_and_hasher(
-                    new_node_count_estimate / sharded::shards(),
-                    Default::default(),
-                )
-            }),
+            new_node_to_index: ShardedHashMap::with_capacity(
+                new_node_count_estimate / sharded::shards(),
+            ),
             prev_index_to_index: Lock::new(IndexVec::from_elem_n(None, prev_graph_node_count)),
             anon_id_seed,
             #[cfg(debug_assertions)]
@@ -1152,14 +1148,9 @@ impl<D: Deps> CurrentDepGraph<D> {
         edges: EdgesVec,
         current_fingerprint: Fingerprint,
     ) -> DepNodeIndex {
-        let dep_node_index = match self.new_node_to_index.lock_shard_by_value(&key).entry(key) {
-            Entry::Occupied(entry) => *entry.get(),
-            Entry::Vacant(entry) => {
-                let dep_node_index = self.encoder.send(key, current_fingerprint, edges);
-                entry.insert(dep_node_index);
-                dep_node_index
-            }
-        };
+        let dep_node_index = self
+            .new_node_to_index
+            .get_or_insert_with(key, || self.encoder.send(key, current_fingerprint, edges));
 
         #[cfg(debug_assertions)]
         self.record_edge(dep_node_index, key, current_fingerprint);
@@ -1257,7 +1248,7 @@ impl<D: Deps> CurrentDepGraph<D> {
     ) {
         let node = &prev_graph.index_to_node(prev_index);
         debug_assert!(
-            !self.new_node_to_index.lock_shard_by_value(node).contains_key(node),
+            self.new_node_to_index.get(node).is_none(),
             "node from previous graph present in new node collection"
         );
     }
@@ -1382,7 +1373,7 @@ fn panic_on_forbidden_read<D: Deps>(data: &DepGraphData<D>, dep_node_index: DepN
     if dep_node.is_none() {
         // Try to find it among the new nodes
         for shard in data.current.new_node_to_index.lock_shards() {
-            if let Some((node, _)) = shard.iter().find(|(_, index)| **index == dep_node_index) {
+            if let Some((node, _)) = shard.iter().find(|(_, index)| *index == dep_node_index) {
                 dep_node = Some(*node);
                 break;
             }
diff --git a/compiler/rustc_query_system/src/lib.rs b/compiler/rustc_query_system/src/lib.rs
index ee984095ad84b..41ebf2d5ac761 100644
--- a/compiler/rustc_query_system/src/lib.rs
+++ b/compiler/rustc_query_system/src/lib.rs
@@ -3,7 +3,6 @@
 #![feature(assert_matches)]
 #![feature(core_intrinsics)]
 #![feature(dropck_eyepatch)]
-#![feature(hash_raw_entry)]
 #![feature(let_chains)]
 #![feature(min_specialization)]
 #![warn(unreachable_pub)]
diff --git a/compiler/rustc_query_system/src/query/caches.rs b/compiler/rustc_query_system/src/query/caches.rs
index 3b47e7eba0ff8..30b5d7e595491 100644
--- a/compiler/rustc_query_system/src/query/caches.rs
+++ b/compiler/rustc_query_system/src/query/caches.rs
@@ -2,8 +2,7 @@ use std::fmt::Debug;
 use std::hash::Hash;
 use std::sync::OnceLock;
 
-use rustc_data_structures::fx::FxHashMap;
-use rustc_data_structures::sharded::{self, Sharded};
+use rustc_data_structures::sharded::ShardedHashMap;
 pub use rustc_data_structures::vec_cache::VecCache;
 use rustc_hir::def_id::LOCAL_CRATE;
 use rustc_index::Idx;
@@ -36,7 +35,7 @@ pub trait QueryCache: Sized {
 /// In-memory cache for queries whose keys aren't suitable for any of the
 /// more specialized kinds of cache. Backed by a sharded hashmap.
 pub struct DefaultCache<K, V> {
-    cache: Sharded<FxHashMap<K, (V, DepNodeIndex)>>,
+    cache: ShardedHashMap<K, (V, DepNodeIndex)>,
 }
 
 impl<K, V> Default for DefaultCache<K, V> {
@@ -55,19 +54,14 @@ where
 
     #[inline(always)]
     fn lookup(&self, key: &K) -> Option<(V, DepNodeIndex)> {
-        let key_hash = sharded::make_hash(key);
-        let lock = self.cache.lock_shard_by_hash(key_hash);
-        let result = lock.raw_entry().from_key_hashed_nocheck(key_hash, key);
-
-        if let Some((_, value)) = result { Some(*value) } else { None }
+        self.cache.get(key)
     }
 
     #[inline]
     fn complete(&self, key: K, value: V, index: DepNodeIndex) {
-        let mut lock = self.cache.lock_shard_by_value(&key);
         // We may be overwriting another value. This is all right, since the dep-graph
         // will check that the fingerprint matches.
-        lock.insert(key, (value, index));
+        self.cache.insert(key, (value, index));
     }
 
     fn iter(&self, f: &mut dyn FnMut(&Self::Key, &Self::Value, DepNodeIndex)) {