Skip to content

Commit

Permalink
Auto merge of #64383 - pcpthm:btreeset-size-hint, r=dtolnay
Browse files Browse the repository at this point in the history
Improve BTreeSet::Intersection::size_hint

A comment on `IntersectionInner` mentions `small_iter` should be smaller than `other_iter` but this condition is broken while iterating because those two iterators can be consumed at a different rate. I added a test to demonstrate this situation.
<del>I made `small_iter.len() < other_iter.len()` always true by swapping two iterators when that condition became false. This change affects the return value of `size_hint`. The previous result was also correct but this new version always returns smaller upper bound than the previous version.</del>
I changed `size_hint` to taking minimum of both lengths of iterators and renamed fields to `a` and `b` to match `Union` iterator.
  • Loading branch information
bors committed Sep 16, 2019
2 parents 3e3e06d + 4333b86 commit b6269f2
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 22 deletions.
44 changes: 22 additions & 22 deletions src/liballoc/collections/btree/set.rs
Expand Up @@ -3,7 +3,7 @@

use core::borrow::Borrow;
use core::cmp::Ordering::{self, Less, Greater, Equal};
use core::cmp::max;
use core::cmp::{max, min};
use core::fmt::{self, Debug};
use core::iter::{Peekable, FromIterator, FusedIterator};
use core::ops::{BitOr, BitAnd, BitXor, Sub, RangeBounds};
Expand Down Expand Up @@ -187,8 +187,8 @@ pub struct Intersection<'a, T: 'a> {
}
enum IntersectionInner<'a, T: 'a> {
Stitch {
small_iter: Iter<'a, T>, // for size_hint, should be the smaller of the sets
other_iter: Iter<'a, T>,
a: Iter<'a, T>,
b: Iter<'a, T>,
},
Search {
small_iter: Iter<'a, T>,
Expand All @@ -201,12 +201,12 @@ impl<T: fmt::Debug> fmt::Debug for Intersection<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.inner {
IntersectionInner::Stitch {
small_iter,
other_iter,
a,
b,
} => f
.debug_tuple("Intersection")
.field(&small_iter)
.field(&other_iter)
.field(&a)
.field(&b)
.finish(),
IntersectionInner::Search {
small_iter,
Expand Down Expand Up @@ -397,8 +397,8 @@ impl<T: Ord> BTreeSet<T> {
// Iterate both sets jointly, spotting matches along the way.
Intersection {
inner: IntersectionInner::Stitch {
small_iter: small.iter(),
other_iter: other.iter(),
a: small.iter(),
b: other.iter(),
},
}
} else {
Expand Down Expand Up @@ -1221,11 +1221,11 @@ impl<T> Clone for Intersection<'_, T> {
Intersection {
inner: match &self.inner {
IntersectionInner::Stitch {
small_iter,
other_iter,
a,
b,
} => IntersectionInner::Stitch {
small_iter: small_iter.clone(),
other_iter: other_iter.clone(),
a: a.clone(),
b: b.clone(),
},
IntersectionInner::Search {
small_iter,
Expand All @@ -1245,16 +1245,16 @@ impl<'a, T: Ord> Iterator for Intersection<'a, T> {
fn next(&mut self) -> Option<&'a T> {
match &mut self.inner {
IntersectionInner::Stitch {
small_iter,
other_iter,
a,
b,
} => {
let mut small_next = small_iter.next()?;
let mut other_next = other_iter.next()?;
let mut a_next = a.next()?;
let mut b_next = b.next()?;
loop {
match Ord::cmp(small_next, other_next) {
Less => small_next = small_iter.next()?,
Greater => other_next = other_iter.next()?,
Equal => return Some(small_next),
match Ord::cmp(a_next, b_next) {
Less => a_next = a.next()?,
Greater => b_next = b.next()?,
Equal => return Some(a_next),
}
}
}
Expand All @@ -1272,7 +1272,7 @@ impl<'a, T: Ord> Iterator for Intersection<'a, T> {

fn size_hint(&self) -> (usize, Option<usize>) {
let min_len = match &self.inner {
IntersectionInner::Stitch { small_iter, .. } => small_iter.len(),
IntersectionInner::Stitch { a, b } => min(a.len(), b.len()),
IntersectionInner::Search { small_iter, .. } => small_iter.len(),
};
(0, Some(min_len))
Expand Down
11 changes: 11 additions & 0 deletions src/liballoc/tests/btree/set.rs
Expand Up @@ -90,6 +90,17 @@ fn test_intersection() {
&[1, 3, 11, 77, 103]);
}

#[test]
fn test_intersection_size_hint() {
let x: BTreeSet<i32> = [3, 4].iter().copied().collect();
let y: BTreeSet<i32> = [1, 2, 3].iter().copied().collect();
let mut iter = x.intersection(&y);
assert_eq!(iter.size_hint(), (0, Some(2)));
assert_eq!(iter.next(), Some(&3));
assert_eq!(iter.size_hint(), (0, Some(0)));
assert_eq!(iter.next(), None);
}

#[test]
fn test_difference() {
fn check_difference(a: &[i32], b: &[i32], expected: &[i32]) {
Expand Down

0 comments on commit b6269f2

Please sign in to comment.