From 3c7432d30a4ab7b22b9bc70d3f8e77d54f2d95f2 Mon Sep 17 00:00:00 2001 From: Maxim Zakharov <5158255+Maxime2@users.noreply.github.com> Date: Mon, 26 Aug 2019 15:37:29 +1000 Subject: [PATCH] remake into async/.await framework; make compatible with latest version of Transport trait --- Cargo.toml | 1 + src/lib.rs | 180 +++++++++++++++++++++++++---------------------- 2 files changed, 86 insertions(+), 95 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b857ffc..96227d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,3 +14,4 @@ serde = "1.0.98" serde_derive = "1.0.98" buffer = "0.1.8" os_pipe="0.8.0" +futures-preview = { version = "=0.3.0-alpha.17", features = ["compat"] } diff --git a/src/lib.rs b/src/lib.rs index 3503064..2392d0d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,15 +4,20 @@ extern crate serde_derive; use bincode::{deserialize, serialize}; use buffer::ReadBuffer; +use futures::stream::Stream; +use futures::task::Context; +use futures::task::Poll; +use futures::task::Waker; use libcommon_rs::peer::{Peer, PeerId, PeerList}; -use libtransport::errors::{Error, Error::AtMaxVecCapacity, Result}; +use libtransport::errors::{Error, Result}; use libtransport::{Transport, TransportConfiguration}; -//use os_pipe::PipeWriter; use serde::de::DeserializeOwned; use serde::Serialize; use std::io; use std::io::Write; +use std::marker::PhantomData; use std::net::{TcpListener, TcpStream}; +use std::pin::Pin; use std::sync::mpsc::{self, Receiver, Sender}; use std::sync::{Arc, Mutex}; use std::thread; @@ -20,53 +25,34 @@ use std::thread::JoinHandle; pub struct TCPtransportCfg { bind_net_addr: String, - channel_pool: Vec>, - //pipe_pool: Vec, - callback_pool: Vec bool>, - callback_timeout: u64, quit_rx: Option>, + listener: TcpListener, + waker: Option, + phantom: PhantomData, } impl TransportConfiguration for TCPtransportCfg { fn new(set_bind_net_addr: String) -> Self { + let listener = TcpListener::bind(set_bind_net_addr.clone()).unwrap(); + listener + 
.set_nonblocking(true) + .expect("unable to set non-blocking"); TCPtransportCfg { bind_net_addr: set_bind_net_addr, - channel_pool: Vec::with_capacity(1), - //pipe_pool: Vec::with_capacity(1), - callback_pool: Vec::with_capacity(1), - callback_timeout: 100, // 100 millisecond timeout by default quit_rx: None, + listener, + waker: None, + phantom: PhantomData, } } - fn register_channel(&mut self, sender: Sender) -> Result<()> { - // Vec::push() panics when number of elements overflows `usize` - if self.channel_pool.len() == std::usize::MAX { - return Err(AtMaxVecCapacity); - } - self.channel_pool.push(sender); - Ok(()) - } - //fn register_os_pipe(&mut self, sender: PipeWriter) -> Result<()> { - // // Vec::push() panics when number of elements overflows `usize` - // if self.pipe_pool.len() == std::usize::MAX { - // return Err(AtMaxVecCapacity); - // } - // self.pipe_pool.push(sender); - // Ok(()) - //} - fn register_callback(&mut self, callback: fn(Data) -> bool) -> Result<()> { - // Vec::push() panics when number of elements overflows `usize` - if self.callback_pool.len() == std::usize::MAX { - return Err(AtMaxVecCapacity); - } - self.callback_pool.push(callback); - Ok(()) - } - fn set_callback_timeout(&mut self, timeout: u64) { - self.callback_timeout = timeout; - } fn set_bind_net_addr(&mut self, address: String) -> Result<()> { self.bind_net_addr = address; + let listener = TcpListener::bind(self.bind_net_addr.clone()).unwrap(); + listener + .set_nonblocking(true) + .expect("unable to set non-blocking"); + use std::mem; + drop(mem::replace(&mut self.listener, listener)); Ok(()) } } @@ -84,68 +70,27 @@ pub struct TCPtransport { server_handle: Option>, } -fn handle_client(cfg_mutexed: Arc>>, mut stream: TcpStream) -where - D: DeserializeOwned, -{ - let mut buffer: Vec = Vec::with_capacity(4096); - loop { - let n = match stream.read_buffer(&mut buffer) { - // FIXME: what we do with panics in threads? 
- Err(e) => panic!("error reading from a connection: {}", e), - Ok(x) => x.len(), - }; - if n == 0 { - // FIXME: check correct work in case when TCP next block delivery timeout is - // greater than read_buffer() read timeout - break; - } - } - let data: D = deserialize::(&buffer).unwrap(); - //dbg!(buffer); - let cfg = cfg_mutexed.lock().unwrap(); - //dbg!(cfg.channel_pool.len()); - for ch in cfg.channel_pool.iter() { - //println!("sending to channel."); - ch.send(data.clone()).unwrap(); - } -} - fn listener(cfg_mutexed: Arc>>) where Data: Serialize + DeserializeOwned + Send + Clone, { // FIXME: what we do with unwrap() in threads? let config = Arc::clone(&cfg_mutexed); - let listener = { - let cfg = config.lock().unwrap(); - TcpListener::bind(cfg.bind_net_addr.clone()).unwrap() - }; - listener - .set_nonblocking(true) - .expect("unable to set non-blocking"); - for stream in listener.incoming() { - match stream { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - // check if quit channel got message - let cfg = config.lock().unwrap(); - match &cfg.quit_rx { - None => {} - Some(ch) => { - if ch.try_recv().is_ok() { - break; - } - } + loop { + // check if quit channel got message + let mut cfg = config.lock().unwrap(); + match &cfg.quit_rx { + None => {} + Some(ch) => { + if ch.try_recv().is_ok() { + break; } - continue; - } - Err(e) => panic!("error in accepting connection: {}", e), - Ok(stream) => { - let config = Arc::clone(&cfg_mutexed); - // receive Data and push it into channels, pipes and call callbacks - thread::spawn(move || handle_client(config, stream)); } } + // allow to pool again if waker is set + if let Some(waker) = cfg.waker.take() { + waker.wake() + } } } @@ -203,12 +148,57 @@ where } Ok(()) } +} + +impl Unpin for TCPtransport {} - // fn register_channel(&mut self, sender: Sender) -> Result<()> { - // let mut cfg = self.config.lock()?; - // cfg.register_channel(sender)?; - // Ok(()) - // } +impl Stream for TCPtransport +where + Data: 
DeserializeOwned, +{ + type Item = Data; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let myself = Pin::get_mut(self); + let config = Arc::clone(&myself.config); + let mut cfg = config.lock().unwrap(); + for stream in cfg.listener.incoming() { + match stream { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + // check if quit channel got message + match &cfg.quit_rx { + None => {} + Some(ch) => { + if ch.try_recv().is_ok() { + break; // meaning Poll::Pending as we are going down + } + } + } + } + Err(e) => panic!("error in accepting connection: {}", e), + Ok(mut stream) => { + let mut buffer: Vec = Vec::with_capacity(4096); + loop { + let n = match stream.read_buffer(&mut buffer) { + // FIXME: what we do with panics in threads? + Err(e) => panic!("error reading from a connection: {}", e), + Ok(x) => x.len(), + }; + if n == 0 { + // FIXME: check correct work in case when TCP next block delivery timeout is + // greater than read_buffer() read timeout + break; + } + } + // FIXME: what should we return in case of deserialize() failure, + // Poll::Ready(None) or Poll::Pending instead of panic? + let data: Data = deserialize::(&buffer).unwrap(); + return Poll::Ready(Some(data)); + } + } + } + cfg.waker = Some(cx.waker().clone()); + Poll::Pending + } } #[cfg(test)]