Skip to content

Commit acd744a

Browse files
lowlevl authored and Eugeny committed
Add channels::io::{ChannelTx, ChannelRx} and implement Channel::into_io_parts
1 parent fc77c53 commit acd744a

File tree

4 files changed

+223
-0
lines changed

4 files changed

+223
-0
lines changed

russh/src/channels/io/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
//! `AsyncRead`/`AsyncWrite` adapters over a channel, split into a
//! receiving half ([`ChannelRx`]) and a sending half ([`ChannelTx`]).

use super::ChannelMsg;

mod rx;
pub use rx::ChannelRx;

mod tx;
pub use tx::ChannelTx;

russh/src/channels/io/rx.rs

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
use std::{
2+
io,
3+
pin::Pin,
4+
sync::{Arc, Mutex, TryLockError},
5+
task::{Context, Poll},
6+
};
7+
8+
use tokio::{
9+
io::AsyncRead,
10+
sync::mpsc::{self, error::TryRecvError},
11+
};
12+
13+
use super::ChannelMsg;
14+
15+
/// The reading half of a channel, produced by `Channel::into_io_parts`;
/// implements `AsyncRead` by draining the channel's message queue.
pub struct ChannelRx {
    // Incoming `ChannelMsg`s forwarded from the session.
    receiver: mpsc::UnboundedReceiver<ChannelMsg>,
    // A message taken off `receiver` but not yet fully consumed
    // (e.g. re-queued when it could not be handled immediately).
    buffer: Option<ChannelMsg>,

    // Flow-control window shared with the sending half;
    // updated when a `WindowAdjusted` message arrives.
    window_size: Arc<Mutex<u32>>,
}
21+
22+
impl ChannelRx {
23+
pub fn new(
24+
receiver: mpsc::UnboundedReceiver<ChannelMsg>,
25+
window_size: Arc<Mutex<u32>>,
26+
) -> Self {
27+
Self {
28+
receiver,
29+
buffer: None,
30+
window_size,
31+
}
32+
}
33+
}
34+
35+
impl AsyncRead for ChannelRx {
36+
fn poll_read(
37+
mut self: Pin<&mut Self>,
38+
cx: &mut Context<'_>,
39+
buf: &mut tokio::io::ReadBuf<'_>,
40+
) -> Poll<io::Result<()>> {
41+
let msg = match self.buffer.take() {
42+
Some(msg) => msg,
43+
None => match self.receiver.try_recv() {
44+
Ok(msg) => msg,
45+
Err(TryRecvError::Empty) => {
46+
cx.waker().wake_by_ref();
47+
return Poll::Pending;
48+
}
49+
Err(TryRecvError::Disconnected) => {
50+
return Poll::Ready(Ok(()));
51+
}
52+
},
53+
};
54+
55+
match &msg {
56+
ChannelMsg::Data { data } => {
57+
if buf.remaining() >= data.len() {
58+
buf.put_slice(data);
59+
60+
Poll::Ready(Ok(()))
61+
} else {
62+
self.buffer = Some(msg);
63+
64+
cx.waker().wake_by_ref();
65+
Poll::Pending
66+
}
67+
}
68+
ChannelMsg::WindowAdjusted { new_size } => {
69+
let buffer = match self.window_size.try_lock() {
70+
Ok(mut window_size) => {
71+
*window_size = *new_size;
72+
73+
None
74+
}
75+
Err(TryLockError::WouldBlock) => Some(msg),
76+
Err(TryLockError::Poisoned(err)) => {
77+
return Poll::Ready(Err(io::Error::new(
78+
io::ErrorKind::Other,
79+
err.to_string(),
80+
)))
81+
}
82+
};
83+
84+
self.buffer = buffer;
85+
86+
cx.waker().wake_by_ref();
87+
Poll::Pending
88+
}
89+
ChannelMsg::Eof => {
90+
self.receiver.close();
91+
92+
Poll::Ready(Ok(()))
93+
}
94+
_ => {
95+
cx.waker().wake_by_ref();
96+
Poll::Pending
97+
}
98+
}
99+
}
100+
}

russh/src/channels/io/tx.rs

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
use std::{
2+
io,
3+
pin::Pin,
4+
sync::{Arc, Mutex, TryLockError},
5+
task::{Context, Poll},
6+
};
7+
8+
use tokio::{
9+
io::AsyncWrite,
10+
sync::mpsc::{self, error::TrySendError},
11+
};
12+
13+
use russh_cryptovec::CryptoVec;
14+
15+
use super::ChannelMsg;
16+
use crate::ChannelId;
17+
18+
/// The writing half of a channel; implements `AsyncWrite` by packaging
/// written bytes into `ChannelMsg::Data` messages sent to the session.
pub struct ChannelTx<S> {
    // Outgoing queue towards the session; `S` is the session-side
    // message type built from `(ChannelId, ChannelMsg)`.
    sender: mpsc::Sender<S>,
    // Identifier of the channel the data belongs to.
    id: ChannelId,

    // Remaining flow-control window, shared with the reading half.
    window_size: Arc<Mutex<u32>>,
    // Upper bound on the payload of a single `Data` message.
    max_packet_size: u32,
}
25+
26+
impl<S> ChannelTx<S> {
27+
pub fn new(
28+
sender: mpsc::Sender<S>,
29+
id: ChannelId,
30+
window_size: Arc<Mutex<u32>>,
31+
max_packet_size: u32,
32+
) -> Self {
33+
Self {
34+
sender,
35+
id,
36+
window_size,
37+
max_packet_size,
38+
}
39+
}
40+
}
41+
42+
impl<S> AsyncWrite for ChannelTx<S>
43+
where
44+
S: From<(ChannelId, ChannelMsg)> + 'static,
45+
{
46+
fn poll_write(
47+
self: Pin<&mut Self>,
48+
cx: &mut Context<'_>,
49+
buf: &[u8],
50+
) -> Poll<Result<usize, io::Error>> {
51+
let mut window_size = match self.window_size.try_lock() {
52+
Ok(window_size) => window_size,
53+
Err(TryLockError::WouldBlock) => {
54+
cx.waker().wake_by_ref();
55+
return Poll::Pending;
56+
}
57+
Err(TryLockError::Poisoned(err)) => {
58+
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err.to_string())))
59+
}
60+
};
61+
62+
let writable = self.max_packet_size.min(*window_size).min(buf.len() as u32) as usize;
63+
if writable == 0 {
64+
cx.waker().wake_by_ref();
65+
return Poll::Pending;
66+
}
67+
68+
let mut data = CryptoVec::new_zeroed(writable);
69+
#[allow(clippy::indexing_slicing)] // Clamped to maximum `buf.len()` with `.min`
70+
data.copy_from_slice(&buf[..writable]);
71+
data.resize(writable);
72+
73+
*window_size -= writable as u32;
74+
drop(window_size);
75+
76+
match self
77+
.sender
78+
.try_send((self.id, ChannelMsg::Data { data }).into())
79+
{
80+
Ok(_) => Poll::Ready(Ok(writable)),
81+
Err(TrySendError::Closed(_)) => Poll::Ready(Ok(0)),
82+
Err(TrySendError::Full(_)) => {
83+
cx.waker().wake_by_ref();
84+
Poll::Pending
85+
}
86+
}
87+
}
88+
89+
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
90+
Poll::Ready(Ok(()))
91+
}
92+
93+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
94+
self.poll_flush(cx)
95+
}
96+
}

russh/src/channels/mod.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use log::debug;
44

55
use crate::{ChannelId, ChannelOpenFailure, ChannelStream, Error, Pty, Sig};
66

7+
pub mod io;
8+
79
#[derive(Debug)]
810
#[non_exhaustive]
911
/// Possible messages that [Channel::wait] can receive.
@@ -410,4 +412,22 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + 'static> Channel<S> {
410412
});
411413
stream
412414
}
415+
416+
/// Setup the [`Channel`] to be able to send messages through [`io::ChannelTx`],
417+
/// and receiving them through [`io::ChannelRx`].
418+
pub fn into_io_parts(self) -> (io::ChannelTx<S>, io::ChannelRx) {
419+
use std::sync::{Arc, Mutex};
420+
421+
let window_size = Arc::new(Mutex::new(self.window_size));
422+
423+
(
424+
io::ChannelTx::new(
425+
self.sender,
426+
self.id,
427+
window_size.clone(),
428+
self.max_packet_size,
429+
),
430+
io::ChannelRx::new(self.receiver, window_size),
431+
)
432+
}
413433
}

0 commit comments

Comments
 (0)