Skip to content

Commit

Permalink
Add select2, one where you get the future back
Browse files Browse the repository at this point in the history
  • Loading branch information
alexcrichton committed May 12, 2016
1 parent a5403a4 commit af51c4c
Show file tree
Hide file tree
Showing 5 changed files with 337 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
language: rust
rust:
- stable
#- stable
- beta
- nightly
sudo: false
Expand Down
12 changes: 12 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ mod map;
mod map_err;
mod or_else;
mod select;
mod select2;
mod then;
pub use and_then::AndThen;
pub use flatten::Flatten;
Expand All @@ -52,6 +53,7 @@ pub use map::Map;
pub use map_err::MapErr;
pub use or_else::OrElse;
pub use select::Select;
pub use select2::{Select2, Select2Next};
pub use then::Then;

// streams
Expand Down Expand Up @@ -222,6 +224,16 @@ pub trait Future: Send + 'static {
assert_future::<Self::Item, Self::Error, _>(f)
}

fn select2<B>(self, other: B) -> Select2<Self, B::Future>
where B: IntoFuture<Item=Self::Item, Error=Self::Error>,
Self: Sized,
{
let f = select2::new(self, other.into_future());
assert_future::<(Self::Item, Select2Next<Self, B::Future>),
(Self::Error, Select2Next<Self, B::Future>),
_>(f)
}

fn join<B>(self, other: B) -> Join<Self, B::Future>
where B: IntoFuture<Error=Self::Error>,
Self: Sized,
Expand Down
209 changes: 209 additions & 0 deletions src/select2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
use std::mem;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

use {PollResult, Callback, Future, PollError};
use cell;
use slot::Slot;
use util;

pub struct Select2<A, B> where A: Future, B: Future<Item=A::Item, Error=A::Error> {
state: State<A, B>,
}

pub struct Select2Next<A, B> where A: Future, B: Future<Item=A::Item, Error=A::Error> {
state: Arc<Scheduled<A, B>>,
}

pub fn new<A, B>(a: A, b: B) -> Select2<A, B>
where A: Future,
B: Future<Item=A::Item, Error=A::Error>
{
Select2 {
state: State::Start(a, b),
}
}

impl<A, B> Future for Select2<A, B>
where A: Future,
B: Future<Item=A::Item, Error=A::Error>,
{
type Item = (A::Item, Select2Next<A, B>);
type Error = (A::Error, Select2Next<A, B>);

fn schedule<G>(&mut self, g: G)
where G: FnOnce(PollResult<Self::Item, Self::Error>) + Send + 'static
{
// TODO: pretty unfortunate we gotta box this up
self.schedule_boxed(Box::new(g))
}

fn schedule_boxed(&mut self, cb: Box<Callback<Self::Item, Self::Error>>) {
let (mut a, mut b) = match mem::replace(&mut self.state, State::Canceled) {
State::Start(a, b) => (a, b),
State::Canceled => return cb.call(Err(PollError::Canceled)),
State::Scheduled(s) => {
self.state = State::Scheduled(s);
return cb.call(Err(util::reused()))
}
};

// TODO: optimize the case that either future is immediately done.
let data1 = Arc::new(Scheduled {
futures: cell::AtomicCell::new(None),
state: AtomicUsize::new(0),
cb: cell::AtomicCell::new(Some(cb)),
data: Slot::new(None),
});
let data2 = data1.clone();
let data3 = data2.clone();

a.schedule(move |result| Scheduled::finish(data1, result));
b.schedule(move |result| Scheduled::finish(data2, result));
*data3.futures.borrow().expect("[s2] futures locked") = Some((a, b));

// Inform our state flags that the futures are available to be canceled.
// If the cancellation flag is set then we never turn SET on and instead
// we just cancel the futures and go on our merry way.
let mut state = data3.state.load(Ordering::SeqCst);
loop {
assert!(state & SET == 0);
if state & CANCEL != 0 {
assert!(state & DONE != 0);
data3.cancel();
break
}
let old = data3.state.compare_and_swap(state, state | SET,
Ordering::SeqCst);
if old == state {
break
}
state = old;
}

self.state = State::Scheduled(data3);
}
}

enum State<A, B> where A: Future, B: Future<Item=A::Item, Error=A::Error> {
Start(A, B),
Scheduled(Arc<Scheduled<A, B>>),
Canceled,
}

const DONE: usize = 1 << 0;
const CANCEL: usize = 1 << 1;
const SET: usize = 1 << 2;

struct Scheduled<A, B>
where A: Future,
B: Future<Item=A::Item, Error=A::Error>,
{
futures: cell::AtomicCell<Option<(A, B)>>,
state: AtomicUsize,
cb: cell::AtomicCell<Option<Box<Callback<(A::Item, Select2Next<A, B>),
(A::Error, Select2Next<A, B>)>>>>,
data: Slot<PollResult<A::Item, A::Error>>,
}

impl<A, B> Scheduled<A, B>
where A: Future,
B: Future<Item=A::Item, Error=A::Error>,
{
fn finish(me: Arc<Scheduled<A, B>>,
val: PollResult<A::Item, A::Error>) {
let old = me.state.fetch_or(DONE, Ordering::SeqCst);

// if the other side finished before we did then we just drop our result
// on the ground and let them take care of everything.
if old & DONE != 0 {
me.data.try_produce(val).ok().unwrap();
return
}

let cb = me.cb.borrow().expect("[s2] done but cb is locked")
.take().expect("[s2] done done but cb not here");
let next = Select2Next { state: me };
cb.call(match val {
Ok(v) => Ok((v, next)),
Err(PollError::Other(e)) => Err(PollError::Other((e, next))),
Err(PollError::Panicked(p)) => Err(PollError::Panicked(p)),
Err(PollError::Canceled) => Err(PollError::Canceled),
})
}

fn cancel(&self) {
let pair = self.futures.borrow().expect("[s2] futures locked in cancel")
.take().expect("[s2] cancel but futures not here");
drop(pair)
}
}

impl<A, B> Drop for Select2<A, B>
where A: Future,
B: Future<Item=A::Item, Error=A::Error>
{
fn drop(&mut self) {
if let State::Scheduled(ref state) = self.state {
// If the old state was "nothing has happened", then we cancel both
// futures. Otherwise one future has finished which implies that the
// future we returned to that closure is responsible for canceling
// itself.
let old = state.state.compare_and_swap(SET, 0, Ordering::SeqCst);
if old == SET {
state.cancel();
}
}
}
}

impl<A, B> Future for Select2Next<A, B>
where A: Future,
B: Future<Item=A::Item, Error=A::Error>
{
type Item = A::Item;
type Error = A::Error;

fn schedule<G>(&mut self, g: G)
where G: FnOnce(PollResult<Self::Item, Self::Error>) + Send + 'static
{
self.state.data.on_full(|slot| {
g(slot.try_consume().unwrap());
});
}

fn schedule_boxed(&mut self, cb: Box<Callback<Self::Item, Self::Error>>) {
self.schedule(|r| cb.call(r))
}
}

impl<A, B> Drop for Select2Next<A, B>
where A: Future,
B: Future<Item=A::Item, Error=A::Error>
{
fn drop(&mut self) {
let mut state = self.state.state.load(Ordering::SeqCst);
loop {
// We should in theory only be here if one half is done and we
// haven't canceled yet.
assert!(state & CANCEL == 0);
assert!(state & DONE != 0);

// Our next state will indicate that we are canceled, and if the
// futures are available to us we're gonna take them.
let next = state | CANCEL & !SET;
let old = self.state.state.compare_and_swap(state, next,
Ordering::SeqCst);
if old == state {
break
}
state = old
}

// If the old state indicated that we had the futures, then we just took
// ownership of them so we cancel the futures here.
if state & SET != 0 {
self.state.cancel();
}
}
}
7 changes: 3 additions & 4 deletions src/util.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::panic::{self, AssertUnwindSafe};

use {PollResult, PollError};

// TODO: reexport this?
Expand All @@ -6,10 +8,7 @@ struct ReuseFuture;
pub fn recover<F, R, E>(f: F) -> PollResult<R, E>
where F: FnOnce() -> R + Send + 'static
{
// use std::panic::{recover, AssertRecoverSafe};
//
// recover(AssertRecoverSafe(f)).map_err(|_| FutureError::Panicked)
Ok(f())
panic::catch_unwind(AssertUnwindSafe(f)).map_err(PollError::Panicked)
}

pub fn reused<E>() -> PollError<E> {
Expand Down
113 changes: 112 additions & 1 deletion tests/all.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
extern crate futures;

use std::sync::mpsc::channel;
use std::sync::mpsc::{channel, TryRecvError};
use std::fmt;

use futures::*;
Expand Down Expand Up @@ -341,3 +341,114 @@ fn collect_collects() {

// TODO: needs more tests
}

#[test]
fn select2() {
fn d<T, U, E>(r: Result<(T, U), (E, U)>) -> Result<T, E> {
match r {
Ok((t, _u)) => Ok(t),
Err((e, _u)) => Err(e),
}
}

assert_done(|| f_ok(2).select2(empty()).then(d), Ok(2));
assert_done(|| empty().select2(f_ok(2)).then(d), Ok(2));
assert_done(|| f_err(2).select2(empty()).then(d), Err(2));
assert_done(|| empty().select2(f_err(2)).then(d), Err(2));

assert_done(|| {
f_ok(1).select2(f_ok(2))
.map_err(|_| 0)
.and_then(|(a, b)| b.map(move |b| a + b))
}, Ok(3));

// Finish one half of a select and then fail the second, ensuring that we
// get the notification of the second one.
{
let ((a, b), (c, d)) = (promise::<i32, u32>(), promise::<i32, u32>());
let mut f = a.select2(c);
let (tx, rx) = channel();
f.schedule(move |r| tx.send(r).unwrap());
b.finish(1);
let (val, mut next) = rx.recv().unwrap().ok().unwrap();
assert_eq!(val, 1);
let (tx, rx) = channel();
next.schedule(move |r| tx.send(r).unwrap());
assert_eq!(rx.try_recv().err().unwrap(), TryRecvError::Empty);
d.fail(2);
match rx.recv().unwrap() {
Err(PollError::Other(2)) => {}
_ => panic!("wrong error"),
}
}

// Fail the second half and ensure that we see the first one finish
{
let ((a, b), (c, d)) = (promise::<i32, u32>(), promise::<i32, u32>());
let mut f = a.select2(c);
let (tx, rx) = channel();
f.schedule(move |r| tx.send(r).unwrap());
d.fail(1);
let mut next = match rx.recv().unwrap() {
Err(PollError::Other((1, next))) => next,
_ => panic!("wrong result"),
};
let (tx, rx) = channel();
next.schedule(move |r| tx.send(r).unwrap());
assert_eq!(rx.try_recv().err().unwrap(), TryRecvError::Empty);
b.finish(2);
assert_eq!(rx.recv().unwrap().ok().unwrap(), 2);
}

// Cancelling the first half should cancel the second
{
let ((a, _b), (c, _d)) = (promise::<i32, u32>(), promise::<i32, u32>());
let ((atx, arx), (ctx, crx)) = (channel(), channel());
let a = a.map(move |v| { atx.send(v).unwrap(); v });
let c = c.map(move |v| { ctx.send(v).unwrap(); v });
let f = a.select2(c);
drop(f);
assert!(crx.recv().is_err());
assert!(arx.recv().is_err());
}

// Cancel after a schedule
{
let ((a, _b), (c, _d)) = (promise::<i32, u32>(), promise::<i32, u32>());
let ((atx, arx), (ctx, crx)) = (channel(), channel());
let a = a.map(move |v| { atx.send(v).unwrap(); v });
let c = c.map(move |v| { ctx.send(v).unwrap(); v });
let mut f = a.select2(c);
f.schedule(|_| ());
drop(f);
assert!(crx.recv().is_err());
assert!(arx.recv().is_err());
}

// Cancel propagates
{
let ((a, b), (c, _d)) = (promise::<i32, u32>(), promise::<i32, u32>());
let ((atx, arx), (ctx, crx)) = (channel(), channel());
let a = a.map(move |v| { atx.send(v).unwrap(); v });
let c = c.map(move |v| { ctx.send(v).unwrap(); v });
let (tx, rx) = channel();
let mut f = a.select2(c).map(move |_| tx.send(()).unwrap());
f.schedule(|_| ());
drop(b);
assert!(crx.recv().is_err());
assert!(arx.recv().is_err());
assert!(rx.recv().is_err());
}

// Cancel on early drop
{
let (tx, rx) = channel();
let mut f = f_ok(1).select2(empty().map(move |()| {
tx.send(()).unwrap();
1
}));
f.schedule(|_| ());
drop(f);
assert!(rx.recv().is_err());
}
}

0 comments on commit af51c4c

Please sign in to comment.