Skip to content

Commit 359fa3c

Browse files
committed
fixed #100 - allow overriding Handler methods without losing Channel functionality
1 parent 30c401e commit 359fa3c

File tree

5 files changed

+296
-200
lines changed

5 files changed

+296
-200
lines changed

russh/examples/test.rs

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
use async_trait::async_trait;
2+
use env_logger;
3+
use log::debug;
4+
use russh::server::{Auth, Msg, Session};
5+
use russh::*;
6+
use russh_keys::*;
7+
use std::collections::HashMap;
8+
use std::sync::{Arc, Mutex};
9+
10+
#[tokio::main]
11+
async fn main() -> anyhow::Result<()> {
12+
env_logger::init();
13+
let mut config = russh::server::Config::default();
14+
config.auth_rejection_time = std::time::Duration::from_secs(3);
15+
config
16+
.keys
17+
.push(russh_keys::key::KeyPair::generate_ed25519().unwrap());
18+
let config = Arc::new(config);
19+
let sh = Server {
20+
clients: Arc::new(Mutex::new(HashMap::new())),
21+
id: 0,
22+
};
23+
tokio::time::timeout(
24+
std::time::Duration::from_secs(60),
25+
russh::server::run(config, ("0.0.0.0", 2222), sh),
26+
)
27+
.await
28+
.unwrap_or(Ok(()))?;
29+
30+
Ok(())
31+
}
32+
33+
#[derive(Clone)]
34+
struct Server {
35+
clients: Arc<Mutex<HashMap<(usize, ChannelId), Channel<Msg>>>>,
36+
id: usize,
37+
}
38+
39+
impl server::Server for Server {
40+
type Handler = Self;
41+
fn new_client(&mut self, _: Option<std::net::SocketAddr>) -> Self {
42+
debug!("new client");
43+
let s = self.clone();
44+
self.id += 1;
45+
s
46+
}
47+
}
48+
49+
#[async_trait]
50+
impl server::Handler for Server {
51+
type Error = anyhow::Error;
52+
53+
async fn channel_open_session(
54+
self,
55+
channel: Channel<Msg>,
56+
session: Session,
57+
) -> Result<(Self, bool, Session), Self::Error> {
58+
{
59+
debug!("channel open session");
60+
let mut clients = self.clients.lock().unwrap();
61+
clients.insert((self.id, channel.id()), channel);
62+
}
63+
Ok((self, true, session))
64+
}
65+
66+
/// The client requests a shell.
67+
#[allow(unused_variables)]
68+
async fn shell_request(
69+
self,
70+
channel: ChannelId,
71+
mut session: Session,
72+
) -> Result<(Self, Session), Self::Error> {
73+
session.request_success();
74+
Ok((self, session))
75+
}
76+
77+
async fn auth_publickey(
78+
self,
79+
_: &str,
80+
_: &key::PublicKey,
81+
) -> Result<(Self, Auth), Self::Error> {
82+
Ok((self, server::Auth::Accept))
83+
}
84+
async fn data(
85+
self,
86+
_channel: ChannelId,
87+
data: &[u8],
88+
mut session: Session,
89+
) -> Result<(Self, Session), Self::Error> {
90+
debug!("data: {data:?}");
91+
{
92+
let mut clients = self.clients.lock().unwrap();
93+
for ((_, _channel_id), ref mut channel) in clients.iter_mut() {
94+
session.data(channel.id(), CryptoVec::from(data.to_vec()));
95+
}
96+
}
97+
Ok((self, session))
98+
}
99+
}

russh/src/client/encrypted.rs

Lines changed: 80 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,20 @@
1515
use std::cell::RefCell;
1616
use std::convert::TryInto;
1717

18+
use log::{debug, error, info, trace, warn};
1819
use russh_cryptovec::CryptoVec;
1920
use russh_keys::encoding::{Encoding, Reader};
2021
use russh_keys::key::parse_public_key;
2122
use tokio::sync::mpsc::unbounded_channel;
22-
use log::{debug, error, info, trace, warn};
2323

2424
use crate::client::{Handler, Msg, Prompt, Reply, Session};
2525
use crate::key::PubKey;
2626
use crate::negotiation::{Named, Select};
2727
use crate::parsing::{ChannelOpenConfirmation, ChannelType, OpenChannelMessage};
2828
use crate::session::{Encrypted, EncryptedState, Kex, KexInit};
29-
use crate::{auth, msg, negotiation, Channel, ChannelId, ChannelOpenFailure, ChannelParams, Sig};
29+
use crate::{
30+
auth, msg, negotiation, Channel, ChannelId, ChannelMsg, ChannelOpenFailure, ChannelParams, Sig,
31+
};
3032

3133
thread_local! {
3234
static SIGNATURE_BUFFER: RefCell<CryptoVec> = RefCell::new(CryptoVec::new());
@@ -184,7 +186,6 @@ impl Session {
184186
current: None,
185187
rejection_count: 0,
186188
},
187-
188189
};
189190
let len = enc.write.len();
190191
#[allow(clippy::indexing_slicing)] // length checked
@@ -246,15 +247,14 @@ impl Session {
246247
if no_more_methods {
247248
return Err(crate::Error::NoAuthMethod.into());
248249
}
249-
250250
} else if buf.first() == Some(&msg::USERAUTH_INFO_REQUEST_OR_USERAUTH_PK_OK) {
251251
if let Some(auth::CurrentRequest::PublicKey {
252252
ref mut sent_pk_ok, ..
253253
}) = auth_request.current
254254
{
255255
debug!("userauth_pk_ok");
256256
*sent_pk_ok = true;
257-
} else if let Some(auth::CurrentRequest::KeyboardInteractive { .. }) =
257+
} else if let Some(auth::CurrentRequest::KeyboardInteractive { .. }) =
258258
auth_request.current
259259
{
260260
debug!("keyboard_interactive");
@@ -307,7 +307,8 @@ impl Session {
307307
// write responses
308308
enc.client_send_auth_response(&responses)?;
309309
return Ok((client, self));
310-
} else {}
310+
} else {
311+
}
311312

312313
// continue with userauth_pk_ok
313314
match self.common.auth_method.take() {
@@ -396,6 +397,18 @@ impl Session {
396397
return Err(crate::Error::Inconsistent.into());
397398
};
398399

400+
if let Some(channel) = self.channels.get(&local_id) {
401+
channel
402+
.send(ChannelMsg::Open {
403+
id: local_id,
404+
max_packet_size: msg.maximum_packet_size,
405+
window_size: msg.initial_window_size,
406+
})
407+
.unwrap_or(());
408+
} else {
409+
error!("no channel for id {local_id:?}");
410+
}
411+
399412
client
400413
.channel_open_confirmation(
401414
local_id,
@@ -414,12 +427,16 @@ impl Session {
414427
// will not be released.
415428
enc.close(channel_num);
416429
}
430+
self.channels.remove(&channel_num);
417431
client.channel_close(channel_num, self).await
418432
}
419433
Some(&msg::CHANNEL_EOF) => {
420434
debug!("channel_eof");
421435
let mut r = buf.reader(1);
422436
let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?);
437+
if let Some(chan) = self.channels.get(&channel_num) {
438+
let _ = chan.send(ChannelMsg::Eof);
439+
}
423440
client.channel_eof(channel_num, self).await
424441
}
425442
Some(&msg::CHANNEL_OPEN_FAILURE) => {
@@ -436,6 +453,13 @@ impl Session {
436453
if let Some(ref mut enc) = self.common.encrypted {
437454
enc.channels.remove(&channel_num);
438455
}
456+
457+
if let Some(sender) = self.channels.remove(&channel_num) {
458+
let _ = sender.send(ChannelMsg::OpenFailure(reason_code));
459+
}
460+
461+
let _ = self.sender.send(Reply::ChannelOpenFailure);
462+
439463
client
440464
.channel_open_failure(channel_num, reason_code, descr, language, self)
441465
.await
@@ -455,6 +479,13 @@ impl Session {
455479
}
456480
}
457481
}
482+
483+
if let Some(chan) = self.channels.get(&channel_num) {
484+
let _ = chan.send(ChannelMsg::Data {
485+
data: CryptoVec::from_slice(data),
486+
});
487+
}
488+
458489
client.data(channel_num, data, self).await
459490
}
460491
Some(&msg::CHANNEL_EXTENDED_DATA) => {
@@ -473,6 +504,14 @@ impl Session {
473504
}
474505
}
475506
}
507+
508+
if let Some(chan) = self.channels.get(&channel_num) {
509+
let _ = chan.send(ChannelMsg::ExtendedData {
510+
ext: extended_code,
511+
data: CryptoVec::from_slice(data),
512+
});
513+
}
514+
476515
client
477516
.extended_data(channel_num, extended_code, data, self)
478517
.await
@@ -489,30 +528,44 @@ impl Session {
489528
match req {
490529
b"xon-xoff" => {
491530
r.read_byte().map_err(crate::Error::from)?; // should be 0.
492-
let client_can_do = r.read_byte().map_err(crate::Error::from)?;
493-
client.xon_xoff(channel_num, client_can_do != 0, self).await
531+
let client_can_do = r.read_byte().map_err(crate::Error::from)? != 0;
532+
if let Some(chan) = self.channels.get(&channel_num) {
533+
let _ = chan.send(ChannelMsg::XonXoff { client_can_do });
534+
}
535+
client.xon_xoff(channel_num, client_can_do, self).await
494536
}
495537
b"exit-status" => {
496538
r.read_byte().map_err(crate::Error::from)?; // should be 0.
497539
let exit_status = r.read_u32().map_err(crate::Error::from)?;
540+
if let Some(chan) = self.channels.get(&channel_num) {
541+
let _ = chan.send(ChannelMsg::ExitStatus { exit_status });
542+
}
498543
client.exit_status(channel_num, exit_status, self).await
499544
}
500545
b"exit-signal" => {
501546
r.read_byte().map_err(crate::Error::from)?; // should be 0.
502547
let signal_name =
503548
Sig::from_name(r.read_string().map_err(crate::Error::from)?)?;
504-
let core_dumped = r.read_byte().map_err(crate::Error::from)?;
549+
let core_dumped = r.read_byte().map_err(crate::Error::from)? != 0;
505550
let error_message =
506551
std::str::from_utf8(r.read_string().map_err(crate::Error::from)?)
507552
.map_err(crate::Error::from)?;
508553
let lang_tag =
509554
std::str::from_utf8(r.read_string().map_err(crate::Error::from)?)
510555
.map_err(crate::Error::from)?;
556+
if let Some(chan) = self.channels.get(&channel_num) {
557+
let _ = chan.send(ChannelMsg::ExitSignal {
558+
signal_name: signal_name.clone(),
559+
core_dumped,
560+
error_message: error_message.to_string(),
561+
lang_tag: lang_tag.to_string(),
562+
});
563+
}
511564
client
512565
.exit_signal(
513566
channel_num,
514567
signal_name,
515-
core_dumped != 0,
568+
core_dumped,
516569
error_message,
517570
lang_tag,
518571
self,
@@ -563,17 +616,24 @@ impl Session {
563616
let mut r = buf.reader(1);
564617
let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?);
565618
let amount = r.read_u32().map_err(crate::Error::from)?;
566-
let mut new_value = 0;
619+
let mut new_size = 0;
567620
debug!("amount: {:?}", amount);
568621
if let Some(ref mut enc) = self.common.encrypted {
569622
if let Some(ref mut channel) = enc.channels.get_mut(&channel_num) {
570623
channel.recipient_window_size += amount;
571-
new_value = channel.recipient_window_size;
624+
new_size = channel.recipient_window_size;
572625
} else {
573626
return Err(crate::Error::WrongChannel.into());
574627
}
575628
}
576-
client.window_adjusted(channel_num, new_value, self).await
629+
630+
if let Some(ref mut enc) = self.common.encrypted {
631+
new_size -= enc.flush_pending(channel_num) as u32;
632+
}
633+
if let Some(chan) = self.channels.get(&channel_num) {
634+
let _ = chan.send(ChannelMsg::WindowAdjusted { new_size });
635+
}
636+
client.window_adjusted(channel_num, new_size, self).await
577637
}
578638
Some(&msg::GLOBAL_REQUEST) => {
579639
let mut r = buf.reader(1);
@@ -634,11 +694,17 @@ impl Session {
634694
Some(&msg::CHANNEL_SUCCESS) => {
635695
let mut r = buf.reader(1);
636696
let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?);
697+
if let Some(chan) = self.channels.get(&channel_num) {
698+
let _ = chan.send(ChannelMsg::Success);
699+
}
637700
client.channel_success(channel_num, self).await
638701
}
639702
Some(&msg::CHANNEL_FAILURE) => {
640703
let mut r = buf.reader(1);
641704
let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?);
705+
if let Some(chan) = self.channels.get(&channel_num) {
706+
let _ = chan.send(ChannelMsg::Failure);
707+
}
642708
client.channel_failure(channel_num, self).await
643709
}
644710
Some(&msg::CHANNEL_OPEN) => {
@@ -891,10 +957,7 @@ impl Encrypted {
891957
Ok(())
892958
}
893959

894-
fn client_send_auth_response(
895-
&mut self,
896-
responses: &[String]
897-
) -> Result<(), crate::Error> {
960+
fn client_send_auth_response(&mut self, responses: &[String]) -> Result<(), crate::Error> {
898961
push_packet!(self.write, {
899962
self.write.push(msg::USERAUTH_INFO_RESPONSE);
900963
self.write

0 commit comments

Comments
 (0)