Skip to content

Commit

Permalink
Merge pull request #41 from Totodore/ft-disconnect-handler
Browse files Browse the repository at this point in the history
feat(socketioxide): disconnect handler
  • Loading branch information
Totodore committed Sep 17, 2023
2 parents cf717ee + 6efd60d commit 24fd5e3
Show file tree
Hide file tree
Showing 27 changed files with 916 additions and 94 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions e2e/src/engineioxide.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
use std::time::Duration;

use engineioxide::{
config::EngineIoConfig, handler::EngineIoHandler, service::EngineIoService, socket::Socket,
config::EngineIoConfig,
handler::EngineIoHandler,
service::EngineIoService,
socket::{DisconnectReason, Socket},
};
use hyper::Server;
use tracing::{info, Level};
Expand All @@ -19,8 +22,8 @@ impl EngineIoHandler for MyHandler {
fn on_connect(&self, socket: &Socket<Self>) {
println!("socket connect {}", socket.sid);
}
fn on_disconnect(&self, socket: &Socket<Self>) {
println!("socket disconnect {}", socket.sid);
fn on_disconnect(&self, socket: &Socket<Self>, reason: DisconnectReason) {
println!("socket disconnect {}: {:?}", socket.sid, reason);
}

fn on_message(&self, msg: String, socket: &Socket<Self>) {
Expand Down
11 changes: 10 additions & 1 deletion engineioxide/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,16 @@ unicode-segmentation = { version = "1.10.1", optional = true }

[dev-dependencies]
criterion = { version = "0.5.1", features = ["html_reports", "async_tokio"] }
tokio = { version = "1.26.0", features = ["macros"] }
tokio = { version = "1.26.0", features = ["macros", "parking_lot"] }
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
hyper = { version = "0.14.25", features = [
"http1",
"http2",
"server",
"stream",
"runtime",
"client",
] }

[features]
default = ["v4"]
Expand Down
3 changes: 2 additions & 1 deletion engineioxide/benches/benchmark_polling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::time::Duration;

use bytes::{Buf, Bytes};
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use engineioxide::socket::DisconnectReason;
use engineioxide::{handler::EngineIoHandler, service::EngineIoService, socket::Socket};

use engineioxide::sid_generator::Sid;
Expand All @@ -29,7 +30,7 @@ impl EngineIoHandler for Client {

fn on_connect(&self, _: &Socket<Self>) {}

fn on_disconnect(&self, _: &Socket<Self>) {}
fn on_disconnect(&self, _: &Socket<Self>, _reason: DisconnectReason) {}

fn on_message(&self, msg: String, socket: &Socket<Self>) {
socket.emit(msg).unwrap();
Expand Down
4 changes: 2 additions & 2 deletions engineioxide/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl EngineIoConfigBuilder {
/// # use engineioxide::{
/// layer::EngineIoLayer,
/// handler::EngineIoHandler,
/// socket::Socket,
/// socket::{Socket, DisconnectReason},
/// };
/// # use std::sync::Arc;
/// #[derive(Debug, Clone)]
Expand All @@ -108,7 +108,7 @@ impl EngineIoConfigBuilder {
/// fn on_connect(&self, socket: &Socket<Self>) {
/// println!("socket connect {}", socket.sid);
/// }
/// fn on_disconnect(&self, socket: &Socket<Self>) {
/// fn on_disconnect(&self, socket: &Socket<Self>, reason: DisconnectReason) {
/// println!("socket disconnect {}", socket.sid);
/// }
///
Expand Down
36 changes: 21 additions & 15 deletions engineioxide/src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
payload::{self},
service::TransportType,
sid_generator::generate_sid,
socket::{ConnectionType, Socket, SocketReq},
socket::{ConnectionType, DisconnectReason, Socket, SocketReq},
};
use crate::{service::ProtocolVersion, sid_generator::Sid};
use futures::{stream::SplitStream, SinkExt, StreamExt, TryStreamExt};
Expand Down Expand Up @@ -62,7 +62,8 @@ impl<H: EngineIoHandler> EngineIo<H> {
B: Send + 'static,
{
let engine = self.clone();
let close_fn = Box::new(move |sid: Sid| engine.close_session(sid));
let close_fn =
Box::new(move |sid: Sid, reason: DisconnectReason| engine.close_session(sid, reason));
let sid = generate_sid();
let socket = Socket::new(
sid,
Expand Down Expand Up @@ -124,7 +125,7 @@ impl<H: EngineIoHandler> EngineIo<H> {
let rx = match socket.internal_rx.try_lock() {
Ok(s) => s,
Err(_) => {
socket.close();
socket.close(DisconnectReason::MultipleHttpPollingError);
return Err(Error::HttpErrorResponse(StatusCode::BAD_REQUEST));
}
};
Expand Down Expand Up @@ -168,7 +169,7 @@ impl<H: EngineIoHandler> EngineIo<H> {
Ok(Packet::Close) => {
debug!("[sid={sid}] closing session");
socket.send(Packet::Noop)?;
self.close_session(sid);
self.close_session(sid, DisconnectReason::TransportClose);
break;
}
Ok(Packet::Pong) | Ok(Packet::Ping) => socket
Expand All @@ -189,7 +190,7 @@ impl<H: EngineIoHandler> EngineIo<H> {
}
Err(e) => {
debug!("[sid={sid}] error parsing packet: {:?}", e);
self.close_session(sid);
self.close_session(sid, DisconnectReason::PacketParsingError);
return Err(e);
}
}?;
Expand Down Expand Up @@ -256,7 +257,9 @@ impl<H: EngineIoHandler> EngineIo<H> {
} else {
let sid = generate_sid();
let engine = self.clone();
let close_fn = Box::new(move |sid: Sid| engine.close_session(sid));
let close_fn = Box::new(move |sid: Sid, reason: DisconnectReason| {
engine.close_session(sid, reason)
});
let socket = Socket::new(
sid,
protocol,
Expand Down Expand Up @@ -305,10 +308,14 @@ impl<H: EngineIoHandler> EngineIo<H> {
});

self.handler.on_connect(&socket);
if let Err(e) = self.ws_forward_to_handler(rx, &socket).await {
if let Err(ref e) = self.ws_forward_to_handler(rx, &socket).await {
debug!("[sid={}] error when handling packet: {:?}", socket.sid, e);
if let Some(reason) = e.into() {
self.close_session(socket.sid, reason);
}
} else {
self.close_session(socket.sid, DisconnectReason::TransportClose);
}
self.close_session(socket.sid);
rx_handle.abort();
Ok(())
}
Expand All @@ -319,13 +326,12 @@ impl<H: EngineIoHandler> EngineIo<H> {
mut rx: SplitStream<WebSocketStream<Upgraded>>,
socket: &Arc<Socket<H>>,
) -> Result<(), Error> {
while let Ok(msg) = rx.try_next().await {
let Some(msg) = msg else { continue };
while let Some(msg) = rx.try_next().await? {
match msg {
Message::Text(msg) => match Packet::try_from(msg)? {
Packet::Close => {
debug!("[sid={}] closing session", socket.sid);
self.close_session(socket.sid);
self.close_session(socket.sid, DisconnectReason::TransportClose);
break;
}
Packet::Pong | Packet::Ping => socket
Expand Down Expand Up @@ -448,10 +454,10 @@ impl<H: EngineIoHandler> EngineIo<H> {

/// Close an engine.io session by removing the socket from the socket map and closing the socket
/// It should be the only way to close a session and to remove a socket from the socket map
fn close_session(&self, sid: Sid) {
fn close_session(&self, sid: Sid, reason: DisconnectReason) {
let socket = self.sockets.write().unwrap().remove(&sid);
if let Some(socket) = socket {
self.handler.on_disconnect(&socket);
self.handler.on_disconnect(&socket, reason);
socket.abort_heartbeat();
debug!(
"remaining sockets: {:?}",
Expand Down Expand Up @@ -486,8 +492,8 @@ mod tests {
println!("socket connect {}", socket.sid);
}

fn on_disconnect(&self, socket: &Socket<Self>) {
println!("socket disconnect {}", socket.sid);
fn on_disconnect(&self, socket: &Socket<Self>, reason: DisconnectReason) {
println!("socket disconnect {} {:?}", socket.sid, reason);
}

fn on_message(&self, msg: String, socket: &Socket<Self>) {
Expand Down
2 changes: 0 additions & 2 deletions engineioxide/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ pub enum Error {
BadPacket(Packet),
#[error("ws transport error: {0:?}")]
WsTransport(#[from] tungstenite::Error),
#[error("http transport error: {0:?}")]
HttpTransport(#[from] hyper::Error),
#[error("http error: {0:?}")]
Http(#[from] http::Error),
#[error("internal channel error: {0:?}")]
Expand Down
4 changes: 2 additions & 2 deletions engineioxide/src/handler.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use async_trait::async_trait;

use crate::socket::Socket;
use crate::socket::{DisconnectReason, Socket};

/// An handler for engine.io events for each sockets.
#[async_trait]
Expand All @@ -12,7 +12,7 @@ pub trait EngineIoHandler: std::fmt::Debug + Send + Sync + Clone + 'static {
fn on_connect(&self, socket: &Socket<Self>);

/// Called when a socket is disconnected.
fn on_disconnect(&self, socket: &Socket<Self>);
fn on_disconnect(&self, socket: &Socket<Self>, reason: DisconnectReason);

/// Called when a message is received from the client.
fn on_message(&self, msg: String, socket: &Socket<Self>);
Expand Down
16 changes: 16 additions & 0 deletions engineioxide/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,22 @@ impl FromStr for TransportType {
}
}
}
impl From<TransportType> for &'static str {
fn from(t: TransportType) -> Self {
match t {
TransportType::Polling => "polling",
TransportType::Websocket => "websocket",
}
}
}
impl From<TransportType> for String {
fn from(t: TransportType) -> Self {
match t {
TransportType::Polling => "polling".into(),
TransportType::Websocket => "websocket".into(),
}
}
}

#[derive(Debug, Copy, Clone, PartialEq)]
pub enum ProtocolVersion {
Expand Down
49 changes: 43 additions & 6 deletions engineioxide/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use tokio::{
sync::{mpsc, mpsc::Receiver, Mutex},
task::JoinHandle,
};
use tokio_tungstenite::tungstenite;
use tracing::debug;

use crate::sid_generator::Sid;
Expand Down Expand Up @@ -54,6 +55,39 @@ impl From<Parts> for SocketReq {
}
}

/// A [`DisconnectReason`] represents the reason why a [`Socket`] was closed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DisconnectReason {
/// The client gracefully closed the connection
TransportClose,
/// The client sent multiple polling requests at the same time (it is forbidden according to the engine.io protocol)
MultipleHttpPollingError,
/// The client sent a bad request / the packet could not be parsed correctly
PacketParsingError,
/// An error occured in the transport layer
/// (e.g. the client closed the connection without sending a close packet)
TransportError,
/// The client did not respond to the heartbeat
HeartbeatTimeout,
}

/// Convert an [`Error`] to a [`DisconnectReason`] if possible
/// This is used to notify the [`Handler`](crate::handler::EngineIoHandler) of the reason why a [`Socket`] was closed
/// If the error cannot be converted to a [`DisconnectReason`] it means that the error was not fatal and the [`Socket`] can be kept alive
impl From<&Error> for Option<DisconnectReason> {
fn from(err: &Error) -> Self {
use Error::*;
match err {
WsTransport(tungstenite::Error::ConnectionClosed) => None,
WsTransport(_) | Io(_) => Some(DisconnectReason::TransportError),
BadPacket(_) | Serialize(_) | Base64(_) | StrUtf8(_) | PayloadTooLarge
| InvalidPacketLength => Some(DisconnectReason::PacketParsingError),
HeartbeatTimeout => Some(DisconnectReason::HeartbeatTimeout),
_ => None,
}
}
}

/// A [`Socket`] represents a connection to the server.
/// It is agnostic to the [`TransportType`](crate::service::TransportType).
/// It handles :
Expand Down Expand Up @@ -99,7 +133,7 @@ where
heartbeat_handle: Mutex<Option<JoinHandle<()>>>,

/// Function to call when the socket is closed
close_fn: Box<dyn Fn(Sid) + Send + Sync>,
close_fn: Box<dyn Fn(Sid, DisconnectReason) + Send + Sync>,
/// User data bound to the socket
pub data: H::Data,

Expand All @@ -121,7 +155,7 @@ where
conn: ConnectionType,
config: &EngineIoConfig,
req_data: SocketReq,
close_fn: Box<dyn Fn(Sid) + Send + Sync>,
close_fn: Box<dyn Fn(Sid, DisconnectReason) + Send + Sync>,
#[cfg(feature = "v3")] supports_binary: bool,
) -> Self {
let (internal_tx, internal_rx) = mpsc::channel(config.max_buffer_size);
Expand Down Expand Up @@ -174,7 +208,7 @@ where

let handle = tokio::spawn(async move {
if let Err(e) = socket.heartbeat_job(interval, timeout).await {
socket.close();
socket.close(DisconnectReason::HeartbeatTimeout);
debug!("[sid={}] heartbeat error: {:?}", socket.sid, e);
}
});
Expand Down Expand Up @@ -300,8 +334,8 @@ where

/// Immediately closes the socket and the underlying connection.
/// The socket will be removed from the `Engine` and the [`Handler`](crate::handler::EngineIoHandler) will be notified.
pub fn close(&self) {
(self.close_fn)(self.sid);
pub fn close(&self, reason: DisconnectReason) {
(self.close_fn)(self.sid, reason);
self.send(Packet::Close).ok();
}

Expand All @@ -325,7 +359,10 @@ where

#[cfg(test)]
impl<H: EngineIoHandler> Socket<H> {
pub fn new_dummy(sid: Sid, close_fn: Box<dyn Fn(Sid) + Send + Sync>) -> Socket<H> {
pub fn new_dummy(
sid: Sid,
close_fn: Box<dyn Fn(Sid, DisconnectReason) + Send + Sync>,
) -> Socket<H> {
let (internal_tx, internal_rx) = mpsc::channel(200);
let (tx, rx) = mpsc::channel(200);
let (heartbeat_tx, heartbeat_rx) = mpsc::channel(1);
Expand Down
Loading

0 comments on commit 24fd5e3

Please sign in to comment.