Skip to content

Commit

Permalink
Add better support for receiving larger payloads
Browse files Browse the repository at this point in the history
This change enables the maximum frame size to be configured when receiving websocket frames. It also
adds a new stream time that aggregates continuation frames together into their proper collected
representation. It provides no mechanism yet for sending continuations.
  • Loading branch information
asonix committed May 10, 2024
1 parent b918084 commit efebca1
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 13 deletions.
20 changes: 9 additions & 11 deletions actix-ws/examples/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
};

use actix_web::{middleware::Logger, web, App, HttpRequest, HttpResponse, HttpServer};
use actix_ws::{Message, Session};
use actix_ws::{AggregatedMessage, Session};
use futures_util::{stream::FuturesUnordered, StreamExt as _};
use log::info;
use tokio::sync::Mutex;
Expand Down Expand Up @@ -56,7 +56,10 @@ async fn ws(
body: web::Payload,
chat: web::Data<Chat>,
) -> Result<HttpResponse, actix_web::Error> {
let (response, mut session, mut stream) = actix_ws::handle(&req, body)?;
let (response, mut session, stream) = actix_ws::handle(&req, body)?;

// increase the maximum allowed frame size to 128KB and aggregate continuation frames
let mut stream = stream.max_frame_size(131_072).aggregate_continuations();

chat.insert(session.clone()).await;
info!("Inserted session");
Expand All @@ -83,27 +86,22 @@ async fn ws(
actix_rt::spawn(async move {
while let Some(Ok(msg)) = stream.next().await {
match msg {
Message::Ping(bytes) => {
AggregatedMessage::Ping(bytes) => {
if session.pong(&bytes).await.is_err() {
return;
}
}
Message::Text(s) => {
AggregatedMessage::Text(s) => {
info!("Relaying text, {}", s);
let s: &str = s.as_ref();
chat.send(s.into()).await;
}
Message::Close(reason) => {
AggregatedMessage::Close(reason) => {
let _ = session.close(reason).await;
info!("Got close, bailing");
return;
}
Message::Continuation(_) => {
let _ = session.close(None).await;
info!("Got continuation, bailing");
return;
}
Message::Pong(_) => {
AggregatedMessage::Pong(_) => {
*alive.lock().await = Instant::now();
}
_ => (),
Expand Down
151 changes: 150 additions & 1 deletion actix-ws/src/fut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ use std::{

use actix_codec::{Decoder, Encoder};
use actix_http::{
ws::{Codec, Frame, Message, ProtocolError},
ws::{CloseReason, Codec, Frame, Item, Message, ProtocolError},
Payload,
};
use actix_web::{
web::{Bytes, BytesMut},
Error,
};
use bytestring::ByteString;
use futures_core::stream::Stream;
use tokio::sync::mpsc::Receiver;

Expand All @@ -40,6 +41,40 @@ pub struct MessageStream {
closing: bool,
}

/// A Websocket message with continuations aggregated together
pub enum AggregatedMessage {
/// Text message
Text(ByteString),

/// Binary message
Binary(Bytes),

/// Ping message
Ping(Bytes),

/// Pong message
Pong(Bytes),

/// Close message with optional reason
Close(Option<CloseReason>),
}

enum ContinuationKind {
Text,
Binary,
}

/// A stream of Messages from a websocket client
///
/// This stream aggregates Continuation frames into their equivalent combined forms, e.g. Binary or
/// Text.
pub struct AggregatedMessageStream {
stream: MessageStream,

continuations: Vec<Bytes>,
continuation_kind: ContinuationKind,
}

impl StreamingBody {
pub(super) fn new(session_rx: Receiver<Message>) -> Self {
StreamingBody {
Expand All @@ -63,6 +98,32 @@ impl MessageStream {
}
}

/// Set the maximum permitted websocket frame size for received frames
///
/// The `max_size` unit is `bytes`
/// The default value for `max_size` is 65_536, or 64KB
///
/// Any received frames larger than the permitted value will return
/// `Err(ProtocolError::Overflow)` instead.
pub fn max_frame_size(self, max_size: usize) -> Self {
Self {
codec: self.codec.max_size(max_size),
..self
}
}

/// Produce a stream that collects Continuation frames into their equivalent collected forms,
/// e.g. Binary or Text.
///
/// This is useful when it is known ahead of time that continuations will not become large.
pub fn aggregate_continuations(self) -> AggregatedMessageStream {
AggregatedMessageStream {
stream: self,
continuations: Vec::new(),
continuation_kind: ContinuationKind::Binary,
}
}

/// Wait for the next item from the message stream
///
/// ```rust,ignore
Expand All @@ -75,6 +136,19 @@ impl MessageStream {
}
}

impl AggregatedMessageStream {
/// Wait for the next item from the message stream
///
/// ```rust,ignore
/// while let Some(Ok(msg)) = stream.recv().await {
/// // handle message
/// }
/// ```
pub async fn recv(&mut self) -> Option<Result<AggregatedMessage, ProtocolError>> {
poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
}
}

impl Stream for StreamingBody {
type Item = Result<Bytes, Error>;

Expand Down Expand Up @@ -181,3 +255,78 @@ impl Stream for MessageStream {
Poll::Pending
}
}

fn collect(continuations: &mut Vec<Bytes>) -> Bytes {
let continuations = std::mem::take(continuations);
let total_len = continuations.iter().map(|b| b.len()).sum();

let mut collected = BytesMut::with_capacity(total_len);

for b in continuations {
collected.extend(b);
}

collected.freeze()
}

impl Stream for AggregatedMessageStream {
type Item = Result<AggregatedMessage, ProtocolError>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();

match std::task::ready!(Pin::new(&mut this.stream).poll_next(cx)) {
Some(Ok(Message::Continuation(item))) => match item {
Item::FirstText(bytes) => {
this.continuations.push(bytes);
this.continuation_kind = ContinuationKind::Text;
Poll::Pending
}
Item::FirstBinary(bytes) => {
this.continuations.push(bytes);
this.continuation_kind = ContinuationKind::Binary;
Poll::Pending
}
Item::Continue(bytes) => {
this.continuations.push(bytes);
Poll::Pending
}
Item::Last(bytes) => {
this.continuations.push(bytes);
let bytes = collect(&mut this.continuations);

match this.continuation_kind {
ContinuationKind::Text => {
match std::str::from_utf8(&bytes) {
Ok(_) => {
// SAFETY: just checked valid UTF8 above
let bytestring =
unsafe { ByteString::from_bytes_unchecked(bytes) };
Poll::Ready(Some(Ok(AggregatedMessage::Text(bytestring))))
}
Err(e) => Poll::Ready(Some(Err(ProtocolError::Io(
io::Error::new(io::ErrorKind::Other, e.to_string()),
)))),
}
}
ContinuationKind::Binary => {
Poll::Ready(Some(Ok(AggregatedMessage::Binary(bytes))))
}
}
}
},
Some(Ok(Message::Text(text))) => Poll::Ready(Some(Ok(AggregatedMessage::Text(text)))),
Some(Ok(Message::Binary(binary))) => {
Poll::Ready(Some(Ok(AggregatedMessage::Binary(binary))))
}
Some(Ok(Message::Ping(ping))) => Poll::Ready(Some(Ok(AggregatedMessage::Ping(ping)))),
Some(Ok(Message::Pong(pong))) => Poll::Ready(Some(Ok(AggregatedMessage::Pong(pong)))),
Some(Ok(Message::Close(close))) => {
Poll::Ready(Some(Ok(AggregatedMessage::Close(close))))
}
Some(Ok(Message::Nop)) => unimplemented!("MessageStream cannot produce Nops"),
Some(Err(e)) => Poll::Ready(Some(Err(e))),
None => Poll::Ready(None),
}
}
}
2 changes: 1 addition & 1 deletion actix-ws/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ mod fut;
mod session;

pub use self::{
fut::{MessageStream, StreamingBody},
fut::{AggregatedMessage, AggregatedMessageStream, MessageStream, StreamingBody},
session::{Closed, Session},
};

Expand Down

0 comments on commit efebca1

Please sign in to comment.