Skip to content

Commit

Permalink
implement advance_(back_)_by on more iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
the8472 committed Sep 30, 2021
1 parent 6dc08b9 commit 2c6e671
Show file tree
Hide file tree
Showing 15 changed files with 376 additions and 3 deletions.
1 change: 1 addition & 0 deletions library/alloc/src/lib.rs
Expand Up @@ -111,6 +111,7 @@
// that the feature-gate isn't enabled. Ideally, it wouldn't check for the feature gate for docs
// from other crates, but since this can only appear for lang items, it doesn't seem worth fixing.
#![feature(intra_doc_pointers)]
#![feature(iter_advance_by)]
#![feature(iter_zip)]
#![feature(lang_items)]
#![feature(layout_for_ptr)]
Expand Down
45 changes: 45 additions & 0 deletions library/alloc/src/vec/into_iter.rs
Expand Up @@ -161,6 +161,28 @@ impl<T, A: Allocator> Iterator for IntoIter<T, A> {
(exact, Some(exact))
}

#[inline]
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
let step_size = self.len().min(n);
if mem::size_of::<T>() == 0 {
// SAFETY: due to unchecked casts of unsigned amounts to signed offsets the wraparound
// effectively results in unsigned pointers representing positions 0..usize::MAX,
// which is valid for ZSTs.
self.ptr = unsafe { arith_offset(self.ptr as *const i8, step_size as isize) as *mut T }
} else {
let to_drop = ptr::slice_from_raw_parts_mut(self.ptr as *mut T, step_size);
// SAFETY: the min() above ensures that step_size is in bounds
unsafe {
self.ptr = self.ptr.add(step_size);
ptr::drop_in_place(to_drop);
}
}
if step_size < n {
return Err(step_size);
}
Ok(())
}

#[inline]
fn count(self) -> usize {
self.len()
Expand Down Expand Up @@ -203,6 +225,29 @@ impl<T, A: Allocator> DoubleEndedIterator for IntoIter<T, A> {
Some(unsafe { ptr::read(self.end) })
}
}

#[inline]
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
let step_size = self.len().min(n);
if mem::size_of::<T>() == 0 {
// SAFETY: same as for advance_by()
self.end = unsafe {
arith_offset(self.end as *const i8, step_size.wrapping_neg() as isize) as *mut T
}
} else {
// SAFETY: same as for advance_by()
self.end = unsafe { self.end.offset(step_size.wrapping_neg() as isize) };
let to_drop = ptr::slice_from_raw_parts_mut(self.end as *mut T, step_size);
// SAFETY: same as for advance_by()
unsafe {
ptr::drop_in_place(to_drop);
}
}
if step_size < n {
return Err(step_size);
}
Ok(())
}
}

#[stable(feature = "rust1", since = "1.0.0")]
Expand Down
1 change: 1 addition & 0 deletions library/alloc/tests/lib.rs
Expand Up @@ -18,6 +18,7 @@
#![feature(binary_heap_retain)]
#![feature(binary_heap_as_slice)]
#![feature(inplace_iteration)]
#![feature(iter_advance_by)]
#![feature(slice_group_by)]
#![feature(slice_partition_dedup)]
#![feature(vec_spare_capacity)]
Expand Down
18 changes: 18 additions & 0 deletions library/alloc/tests/vec.rs
Expand Up @@ -970,6 +970,24 @@ fn test_into_iter_leak() {
assert_eq!(unsafe { DROPS }, 3);
}

#[test]
fn test_into_iter_advance_by() {
let mut i = vec![1, 2, 3, 4, 5].into_iter();
i.advance_by(0).unwrap();
i.advance_back_by(0).unwrap();
assert_eq!(i.as_slice(), [1, 2, 3, 4, 5]);

i.advance_by(1).unwrap();
i.advance_back_by(1).unwrap();
assert_eq!(i.as_slice(), [2, 3, 4]);

assert_eq!(i.advance_back_by(usize::MAX), Err(3));

assert_eq!(i.advance_by(usize::MAX), Err(0));

assert_eq!(i.len(), 0);
}

#[test]
fn test_from_iter_specialization() {
let src: Vec<usize> = vec![0usize; 1];
Expand Down
10 changes: 10 additions & 0 deletions library/core/src/iter/adapters/copied.rs
Expand Up @@ -76,6 +76,11 @@ where
self.it.count()
}

#[inline]
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
self.it.advance_by(n)
}

#[doc(hidden)]
unsafe fn __iterator_get_unchecked(&mut self, idx: usize) -> T
where
Expand Down Expand Up @@ -112,6 +117,11 @@ where
{
self.it.rfold(init, copy_fold(f))
}

#[inline]
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
self.it.advance_back_by(n)
}
}

#[stable(feature = "iter_copied", since = "1.36.0")]
Expand Down
21 changes: 21 additions & 0 deletions library/core/src/iter/adapters/cycle.rs
Expand Up @@ -79,6 +79,27 @@ where
}
}

#[inline]
#[rustc_inherit_overflow_checks]
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
let mut rem = n;
match self.iter.advance_by(rem) {
ret @ Ok(_) => return ret,
Err(advanced) => rem -= advanced,
}

while rem > 0 {
self.iter = self.orig.clone();
match self.iter.advance_by(rem) {
ret @ Ok(_) => return ret,
Err(0) => return Err(n - rem),
Err(advanced) => rem -= advanced,
}
}

Ok(())
}

// No `fold` override, because `fold` doesn't make much sense for `Cycle`,
// and we can't do anything better than the default.
}
Expand Down
22 changes: 22 additions & 0 deletions library/core/src/iter/adapters/enumerate.rs
Expand Up @@ -112,6 +112,21 @@ where
self.iter.fold(init, enumerate(self.count, fold))
}

#[inline]
#[rustc_inherit_overflow_checks]
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
match self.iter.advance_by(n) {
ret @ Ok(_) => {
self.count += n;
ret
}
ret @ Err(advanced) => {
self.count += advanced;
ret
}
}
}

#[rustc_inherit_overflow_checks]
#[doc(hidden)]
unsafe fn __iterator_get_unchecked(&mut self, idx: usize) -> <Self as Iterator>::Item
Expand Down Expand Up @@ -191,6 +206,13 @@ where
let count = self.count + self.iter.len();
self.iter.rfold(init, enumerate(count, fold))
}

#[inline]
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
// we do not need to update the count since that only tallies the number of items
// consumed from the front. consuming items from the back can never reduce that.
self.iter.advance_back_by(n)
}
}

#[stable(feature = "rust1", since = "1.0.0")]
Expand Down
69 changes: 69 additions & 0 deletions library/core/src/iter/adapters/flatten.rs
Expand Up @@ -391,6 +391,40 @@ where

init
}

#[inline]
#[rustc_inherit_overflow_checks]
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
let mut rem = n;
loop {
if let Some(ref mut front) = self.frontiter {
match front.advance_by(rem) {
ret @ Ok(_) => return ret,
Err(advanced) => rem -= advanced,
}
}
self.frontiter = match self.iter.next() {
Some(iterable) => Some(iterable.into_iter()),
_ => break,
}
}

self.frontiter = None;

if let Some(ref mut back) = self.backiter {
if let Err(advanced) = back.advance_by(rem) {
rem -= advanced
}
}

if rem > 0 {
return Err(n - rem);
}

self.backiter = None;

Ok(())
}
}

impl<I, U> DoubleEndedIterator for FlattenCompat<I, U>
Expand Down Expand Up @@ -486,6 +520,41 @@ where

init
}

#[inline]
#[rustc_inherit_overflow_checks]
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
let mut rem = n;
loop {
if let Some(ref mut back) = self.backiter {
match back.advance_back_by(rem) {
ret @ Ok(_) => return ret,
Err(advanced) => rem -= advanced,
}
}
match self.iter.next_back() {
Some(iterable) => self.backiter = Some(iterable.into_iter()),
_ => break,
}
}

self.backiter = None;

if let Some(ref mut front) = self.frontiter {
match front.advance_back_by(rem) {
ret @ Ok(_) => return ret,
Err(advanced) => rem -= advanced,
}
}

if rem > 0 {
return Err(n - rem);
}

self.frontiter = None;

Ok(())
}
}

trait ConstSizeIntoIterator: IntoIterator {
Expand Down
21 changes: 21 additions & 0 deletions library/core/src/iter/adapters/skip.rs
Expand Up @@ -114,6 +114,17 @@ where
}
self.iter.fold(init, fold)
}

#[inline]
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
if self.n >= n {
self.n -= n;
return Ok(());
}
let rem = n - self.n;
self.n = 0;
self.iter.advance_by(rem)
}
}

#[stable(feature = "rust1", since = "1.0.0")]
Expand Down Expand Up @@ -174,6 +185,16 @@ where

self.try_rfold(init, ok(fold)).unwrap()
}

#[inline]
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
let min = crate::cmp::min(self.len(), n);
return match self.iter.advance_back_by(min) {
ret @ Ok(_) if n <= min => ret,
Ok(_) => Err(min),
_ => panic!("ExactSizeIterator contract violation"),
};
}
}

#[stable(feature = "fused", since = "1.26.0")]
Expand Down
34 changes: 34 additions & 0 deletions library/core/src/iter/adapters/take.rs
Expand Up @@ -111,6 +111,22 @@ where

self.try_fold(init, ok(fold)).unwrap()
}

#[inline]
#[rustc_inherit_overflow_checks]
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
let min = crate::cmp::min(self.n, n);
return match self.iter.advance_by(min) {
Ok(_) => {
self.n -= min;
if min < n { Err(min) } else { Ok(()) }
}
ret @ Err(advanced) => {
self.n -= advanced;
ret
}
};
}
}

#[unstable(issue = "none", feature = "inplace_iteration")]
Expand Down Expand Up @@ -197,6 +213,24 @@ where
}
}
}

#[inline]
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
let inner_len = self.iter.len();
let len = self.n;
let remainder = len.saturating_sub(n);
let to_advance = inner_len - remainder;
match self.iter.advance_back_by(to_advance) {
Ok(_) => {
self.n = remainder;
if n > len {
return Err(len);
}
return Ok(());
}
_ => panic!("ExactSizeIterator contract violation"),
}
}
}

#[stable(feature = "rust1", since = "1.0.0")]
Expand Down

0 comments on commit 2c6e671

Please sign in to comment.