1+ use std:: convert:: TryFrom ;
2+ use std:: future:: Future ;
13use std:: io;
4+ use std:: num:: NonZero ;
5+ use std:: ops:: DerefMut ;
26use std:: pin:: Pin ;
37use std:: sync:: Arc ;
48use std:: task:: { ready, Context , Poll } ;
@@ -7,7 +11,7 @@ use futures::FutureExt;
711use tokio:: io:: AsyncWrite ;
812use tokio:: sync:: mpsc:: error:: SendError ;
913use tokio:: sync:: mpsc:: { self , OwnedPermit } ;
10- use tokio:: sync:: { Mutex , OwnedMutexGuard } ;
14+ use tokio:: sync:: { Mutex , Notify , OwnedMutexGuard } ;
1115
1216use super :: ChannelMsg ;
1317use crate :: { ChannelId , CryptoVec } ;
@@ -16,13 +20,34 @@ type BoxedThreadsafeFuture<T> = Pin<Box<dyn Sync + Send + std::future::Future<Ou
1620type OwnedPermitFuture < S > =
1721 BoxedThreadsafeFuture < Result < ( OwnedPermit < S > , ChannelMsg , usize ) , SendError < ( ) > > > ;
1822
23+ struct WatchNotification ( Pin < Box < dyn Sync + Send + Future < Output = ( ) > > > ) ;
24+
25+ /// A single future that becomes ready once the window size
26+ /// changes to a positive value
27+ impl WatchNotification {
28+ fn new ( n : Arc < Notify > ) -> Self {
29+ Self ( Box :: pin ( async move { n. notified ( ) . await } ) )
30+ }
31+ }
32+
33+ impl Future for WatchNotification {
34+ type Output = ( ) ;
35+
36+ fn poll ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
37+ let inner = self . deref_mut ( ) . 0 . as_mut ( ) ;
38+ ready ! ( inner. poll( cx) ) ;
39+ Poll :: Ready ( ( ) )
40+ }
41+ }
42+
1943pub struct ChannelTx < S > {
2044 sender : mpsc:: Sender < S > ,
2145 send_fut : Option < OwnedPermitFuture < S > > ,
2246 id : ChannelId ,
23-
2447 window_size_fut : Option < BoxedThreadsafeFuture < OwnedMutexGuard < u32 > > > ,
2548 window_size : Arc < Mutex < u32 > > ,
49+ notify : Arc < Notify > ,
50+ window_size_notication : WatchNotification ,
2651 max_packet_size : u32 ,
2752 ext : Option < u32 > ,
2853}
@@ -35,43 +60,62 @@ where
3560 sender : mpsc:: Sender < S > ,
3661 id : ChannelId ,
3762 window_size : Arc < Mutex < u32 > > ,
63+ window_size_notification : Arc < Notify > ,
3864 max_packet_size : u32 ,
3965 ext : Option < u32 > ,
4066 ) -> Self {
4167 Self {
4268 sender,
4369 send_fut : None ,
4470 id,
71+ notify : Arc :: clone ( & window_size_notification) ,
72+ window_size_notication : WatchNotification :: new ( window_size_notification) ,
4573 window_size,
4674 window_size_fut : None ,
4775 max_packet_size,
4876 ext,
4977 }
5078 }
5179
52- fn poll_mk_msg ( & mut self , cx : & mut Context < ' _ > , buf : & [ u8 ] ) -> Poll < ( ChannelMsg , usize ) > {
80+ fn poll_writable ( & mut self , cx : & mut Context < ' _ > , buf_len : usize ) -> Poll < NonZero < usize > > {
5381 let window_size = self . window_size . clone ( ) ;
5482 let window_size_fut = self
5583 . window_size_fut
5684 . get_or_insert_with ( || Box :: pin ( window_size. lock_owned ( ) ) ) ;
5785 let mut window_size = ready ! ( window_size_fut. poll_unpin( cx) ) ;
5886 self . window_size_fut . take ( ) ;
5987
60- let writable = ( self . max_packet_size )
61- . min ( * window_size)
62- . min ( buf. len ( ) as u32 ) as usize ;
63- if writable == 0 {
64- // TODO fix this busywait
65- cx. waker ( ) . wake_by_ref ( ) ;
66- return Poll :: Pending ;
88+ let writable = ( self . max_packet_size ) . min ( * window_size) . min ( buf_len as u32 ) as usize ;
89+
90+ match NonZero :: try_from ( writable) {
91+ Ok ( w) => {
92+ * window_size -= writable as u32 ;
93+ if * window_size > 0 {
94+ self . notify . notify_one ( ) ;
95+ }
96+ Poll :: Ready ( w)
97+ }
98+ Err ( _) => {
99+ drop ( window_size) ;
100+ ready ! ( self . window_size_notication. poll_unpin( cx) ) ;
101+ self . window_size_notication = WatchNotification :: new ( Arc :: clone ( & self . notify ) ) ;
102+ cx. waker ( ) . wake_by_ref ( ) ;
103+ Poll :: Pending
104+ }
67105 }
68- let mut data = CryptoVec :: new_zeroed ( writable) ;
69- #[ allow( clippy:: indexing_slicing) ] // Clamped to maximum `buf.len()` with `.min`
70- data. copy_from_slice ( & buf[ ..writable] ) ;
71- data. resize ( writable) ;
106+ }
107+
108+ fn poll_mk_msg (
109+ & mut self ,
110+ cx : & mut Context < ' _ > ,
111+ buf : & [ u8 ] ,
112+ ) -> Poll < ( ChannelMsg , NonZero < usize > ) > {
113+ let writable = ready ! ( self . poll_writable( cx, buf. len( ) ) ) ;
72114
73- * window_size -= writable as u32 ;
74- drop ( window_size) ;
115+ let mut data = CryptoVec :: new_zeroed ( writable. into ( ) ) ;
116+ #[ allow( clippy:: indexing_slicing) ] // Clamped to maximum `buf.len()` with `.poll_writable`
117+ data. copy_from_slice ( & buf[ ..writable. into ( ) ] ) ;
118+ data. resize ( writable. into ( ) ) ;
75119
76120 let msg = match self . ext {
77121 None => ChannelMsg :: Data { data } ,
@@ -116,11 +160,17 @@ where
116160 cx : & mut Context < ' _ > ,
117161 buf : & [ u8 ] ,
118162 ) -> Poll < Result < usize , io:: Error > > {
163+ if buf. is_empty ( ) {
164+ return Poll :: Ready ( Err ( io:: Error :: new (
165+ io:: ErrorKind :: WriteZero ,
166+ "cannot send empty buffer" ,
167+ ) ) ) ;
168+ }
119169 let send_fut = if let Some ( x) = self . send_fut . as_mut ( ) {
120170 x
121171 } else {
122172 let ( msg, writable) = ready ! ( self . poll_mk_msg( cx, buf) ) ;
123- self . activate ( msg, writable)
173+ self . activate ( msg, writable. into ( ) )
124174 } ;
125175 let r = ready ! ( send_fut. as_mut( ) . poll_unpin( cx) ) ;
126176 Poll :: Ready ( self . handle_write_result ( r) )
@@ -143,3 +193,10 @@ where
143193 Poll :: Ready ( self . handle_write_result ( r) . map ( drop) )
144194 }
145195}
196+
197+ impl < S > Drop for ChannelTx < S > {
198+ fn drop ( & mut self ) {
199+ // Allow other writers to make progress
200+ self . notify . notify_one ( ) ;
201+ }
202+ }
0 commit comments