Skip to content

Commit

Permalink
Rollup merge of rust-lang#57043 - ssomers:master, r=alexcrichton
Browse files Browse the repository at this point in the history
Fix poor worst case performance of set intersection

Specifically, intersection of asymmetrically sized sets when the large set is on the left. See also the [latest answer on stackoverflow](https://stackoverflow.com/questions/35439376/python-set-intersection-is-faster-then-rust-hashset-intersection).

Also applied to the union member, where the effect is much less but still measurable.

Formatted the changed code only, does not increase the error count reported by tidy check, and tried to adhere to the spirit of the unit tests.
  • Loading branch information
Centril committed Jan 14, 2019
2 parents d106808 + cef2e2f commit 5bc95de
Showing 1 changed file with 60 additions and 7 deletions.
67 changes: 60 additions & 7 deletions src/libstd/collections/hash/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,9 +410,16 @@ impl<T, S> HashSet<T, S>
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub fn intersection<'a>(&'a self, other: &'a HashSet<T, S>) -> Intersection<'a, T, S> {
Intersection {
iter: self.iter(),
other,
if self.len() <= other.len() {
Intersection {
iter: self.iter(),
other,
}
} else {
Intersection {
iter: other.iter(),
other: self,
}
}
}

Expand All @@ -436,7 +443,15 @@ impl<T, S> HashSet<T, S>
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub fn union<'a>(&'a self, other: &'a HashSet<T, S>) -> Union<'a, T, S> {
Union { iter: self.iter().chain(other.difference(self)) }
if self.len() <= other.len() {
Union {
iter: self.iter().chain(other.difference(self)),
}
} else {
Union {
iter: other.iter().chain(self.difference(other)),
}
}
}

/// Returns the number of elements in the set.
Expand Down Expand Up @@ -584,7 +599,11 @@ impl<T, S> HashSet<T, S>
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub fn is_disjoint(&self, other: &HashSet<T, S>) -> bool {
self.iter().all(|v| !other.contains(v))
if self.len() <= other.len() {
self.iter().all(|v| !other.contains(v))
} else {
other.iter().all(|v| !self.contains(v))
}
}

/// Returns `true` if the set is a subset of another,
Expand Down Expand Up @@ -1494,6 +1513,7 @@ mod test_set {
fn test_intersection() {
let mut a = HashSet::new();
let mut b = HashSet::new();
assert!(a.intersection(&b).next().is_none());

assert!(a.insert(11));
assert!(a.insert(1));
Expand All @@ -1518,6 +1538,22 @@ mod test_set {
i += 1
}
assert_eq!(i, expected.len());

assert!(a.insert(9)); // make a bigger than b

i = 0;
for x in a.intersection(&b) {
assert!(expected.contains(x));
i += 1
}
assert_eq!(i, expected.len());

i = 0;
for x in b.intersection(&a) {
assert!(expected.contains(x));
i += 1
}
assert_eq!(i, expected.len());
}

#[test]
Expand Down Expand Up @@ -1573,11 +1609,11 @@ mod test_set {
fn test_union() {
let mut a = HashSet::new();
let mut b = HashSet::new();
assert!(a.union(&b).next().is_none());
assert!(b.union(&a).next().is_none());

assert!(a.insert(1));
assert!(a.insert(3));
assert!(a.insert(5));
assert!(a.insert(9));
assert!(a.insert(11));
assert!(a.insert(16));
assert!(a.insert(19));
Expand All @@ -1597,6 +1633,23 @@ mod test_set {
i += 1
}
assert_eq!(i, expected.len());

assert!(a.insert(9)); // make a bigger than b
assert!(a.insert(5));

i = 0;
for x in a.union(&b) {
assert!(expected.contains(x));
i += 1
}
assert_eq!(i, expected.len());

i = 0;
for x in b.union(&a) {
assert!(expected.contains(x));
i += 1
}
assert_eq!(i, expected.len());
}

#[test]
Expand Down

0 comments on commit 5bc95de

Please sign in to comment.