diff --git a/src/proto/h2/client.rs b/src/proto/h2/client.rs index b8d9951928..7226c98bf5 100644 --- a/src/proto/h2/client.rs +++ b/src/proto/h2/client.rs @@ -6,7 +6,7 @@ use crate::rt::{Read, Write}; use bytes::Bytes; use futures_channel::mpsc::{Receiver, Sender}; use futures_channel::{mpsc, oneshot}; -use futures_util::future::{self, Either, FutureExt as _, Select}; +use futures_util::future::{Either, FusedFuture, FutureExt as _}; use futures_util::stream::{StreamExt as _, StreamFuture}; use h2::client::{Builder, Connection, SendRequest}; use h2::SendStream; @@ -143,7 +143,10 @@ where } else { (Either::Right(conn), ping::disabled()) }; - let conn: ConnMapErr = ConnMapErr { conn }; + let conn: ConnMapErr = ConnMapErr { + conn, + is_terminated: false, + }; exec.execute_h2_future(H2ClientFuture::Task { task: ConnTask::new(conn, conn_drop_rx, cancel_tx), @@ -218,6 +221,8 @@ pin_project! { { #[pin] conn: Either, Connection, SendBuf<::Data>>>, + #[pin] + is_terminated: bool, } } @@ -229,10 +234,26 @@ where type Output = Result<(), ()>; fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { - self.project() - .conn - .poll(cx) - .map_err(|e| debug!("connection error: {}", e)) + let mut this = self.project(); + + if *this.is_terminated { + return Poll::Pending; + } + let polled = this.conn.poll(cx); + if polled.is_ready() { + *this.is_terminated = true; + } + polled.map_err(|e| debug!("connection error: {}", e)) + } +} + +impl FusedFuture for ConnMapErr +where + B: Body, + T: Read + Write + Unpin, +{ + fn is_terminated(&self) -> bool { + self.is_terminated } } @@ -245,10 +266,11 @@ pin_project! { T: Unpin, { #[pin] - select: Select, StreamFuture>>, + drop_rx: StreamFuture>, #[pin] cancel_tx: Option>, - conn: Option>, + #[pin] + conn: ConnMapErr, } } @@ -263,9 +285,9 @@ where cancel_tx: oneshot::Sender, ) -> Self { Self { - select: future::select(conn, drop_rx), + drop_rx, cancel_tx: Some(cancel_tx), - conn: None, + conn, } } } @@ -280,25 +302,24 @@ where fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { let mut this = self.project(); - if let Some(conn) = this.conn { - conn.poll_unpin(cx).map(|_| ()) - } else { - match ready!(this.select.poll_unpin(cx)) { - Either::Left((_, _)) => { - // ok or err, the `conn` has finished - return Poll::Ready(()); - } - Either::Right((_, b)) => { - // mpsc has been dropped, hopefully polling - // the connection some more should start shutdown - // and then close - trace!("send_request dropped, starting conn shutdown"); - drop(this.cancel_tx.take().expect("Future polled twice")); - this.conn = &mut Some(b); - return Poll::Pending; - } - } + if !this.conn.is_terminated() { + if let Poll::Ready(_) = this.conn.poll_unpin(cx) { + // ok or err, the `conn` has finished. + return Poll::Ready(()); + }; } + + if !this.drop_rx.is_terminated() { + if let Poll::Ready(_) = this.drop_rx.poll_unpin(cx) { + // mpsc has been dropped, hopefully polling + // the connection some more should start shutdown + // and then close. + trace!("send_request dropped, starting conn shutdown"); + drop(this.cancel_tx.take().expect("ConnTask Future polled twice")); + } + }; + + Poll::Pending } }