diff --git a/async/src/connection.rs b/async/src/connection.rs index 62c13282..52148770 100644 --- a/async/src/connection.rs +++ b/async/src/connection.rs @@ -261,6 +261,10 @@ impl Connection { self.finished_get_reqs.remove(&id) } + pub fn has_pending_deliveries(&self) -> bool { + self.channels.values().any(|channel| channel.queues.values().any(|queue| queue.consumers.values().any(|consumer| !consumer.messages.is_empty()))) + } + /// gets the next message corresponding to a channel, queue and consumer tag /// /// if the channel id, queue and consumer tag have no link, the method diff --git a/futures/src/channel.rs b/futures/src/channel.rs index 54815e6e..47e08b56 100644 --- a/futures/src/channel.rs +++ b/futures/src/channel.rs @@ -335,6 +335,7 @@ impl Channel { channel_id: self.id, queue: queue.name(), consumer_tag: consumer_tag.to_string(), + registered: false, }; Box::new(self.run_on_locked_transport("basic_consume", "Could not start consumer", |transport| { diff --git a/futures/src/consumer.rs b/futures/src/consumer.rs index 60d15fdd..c2d64016 100644 --- a/futures/src/consumer.rs +++ b/futures/src/consumer.rs @@ -12,6 +12,7 @@ pub struct Consumer { pub channel_id: u16, pub queue: String, pub consumer_tag: String, + pub registered: bool, } impl Stream for Consumer { @@ -21,6 +22,10 @@ impl Stream for Consumer { fn poll(&mut self) -> Poll, io::Error> { trace!("poll; consumer_tag={:?}", self.consumer_tag); let mut transport = lock_transport!(self.transport); + if !self.registered { + transport.register_consumer(&self.consumer_tag, task::current()); + self.registered = true; + } if let Async::Ready(_) = transport.poll()? { trace!("poll transport; consumer_tag={:?} status=Ready", self.consumer_tag); return Ok(Async::Ready(None)); diff --git a/futures/src/transport.rs b/futures/src/transport.rs index c5d9a0ad..96662073 100644 --- a/futures/src/transport.rs +++ b/futures/src/transport.rs @@ -6,9 +6,10 @@ use nom::{IResult,Offset}; use cookie_factory::GenError; use bytes::{BufMut, BytesMut}; use std::cmp; +use std::collections::HashMap; use std::iter::repeat; use std::io::{self,Error,ErrorKind}; -use futures::{Async,AsyncSink,Poll,Sink,StartSend,Stream,Future,future}; +use futures::{Async,AsyncSink,Poll,Sink,StartSend,Stream,Future,future,task}; use tokio_io::{AsyncRead,AsyncWrite}; use tokio_io::codec::{Decoder,Encoder,Framed}; use channel::BasicProperties; @@ -102,8 +103,9 @@ impl Encoder for AMQPCodec { /// Wrappers over a `Framed` stream using `AMQPCodec` and lapin-async's `Connection` pub struct AMQPTransport { - upstream: Framed, - pub conn: Connection, + upstream: Framed, + consumers: HashMap, + pub conn: Connection, } impl AMQPTransport @@ -130,6 +132,7 @@ impl AMQPTransport }; let t = AMQPTransport { upstream: stream.framed(codec), + consumers: HashMap::new(), conn: conn, }; let connector = AMQPTransportConnector { @@ -167,6 +170,7 @@ impl AMQPTransport /// * In case of error, it will return `Err(e)` /// * If the socket was closed, it will return `Ok(Async::Ready(()))` pub fn poll_recv(&mut self) -> Poll<(), io::Error> { + let mut got_frame = false; loop { match self.upstream.poll() { Ok(Async::Ready(Some(frame))) => { @@ -175,6 +179,7 @@ impl AMQPTransport let err = format!("failed to handle frame: {:?}", e); return Err(io::Error::new(io::ErrorKind::Other, err)); } + got_frame = true; }, Ok(Async::Ready(None)) => { trace!("transport poll_recv; status=Ready(None)"); @@ -182,6 +187,11 @@ impl AMQPTransport }, Ok(Async::NotReady) => { trace!("transport poll_recv; status=NotReady"); + if got_frame && self.conn.has_pending_deliveries() { + for t in self.consumers.values() { + t.notify(); + } + } return Ok(Async::NotReady); }, Err(e) => { @@ -209,6 +219,10 @@ impl AMQPTransport } self.poll_complete() } + + pub fn register_consumer(&mut self, consumer_tag: &str, consumer_task: task::Task) { + self.consumers.insert(consumer_tag.to_string(), consumer_task); + } } impl Future for AMQPTransport