Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor websocket key hashing #2035

Merged
merged 4 commits into from Feb 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions actix-http/CHANGES.md
Expand Up @@ -3,13 +3,15 @@
## Unreleased - 2021-xx-xx
### Changed
* Feature `cookies` is now optional and disabled by default. [#1981]
* `ws::hash_key` now returns array. [#2035]

### Removed
* re-export of `futures_channel::oneshot::Canceled` is removed from `error` mod. [#1994]
* `ResponseError` impl for `futures_channel::oneshot::Canceled` is removed. [#1994]

[#1981]: https://github.com/actix/actix-web/pull/1981
[#1994]: https://github.com/actix/actix-web/pull/1994
[#2035]: https://github.com/actix/actix-web/pull/2035


## 3.0.0-beta.3 - 2021-02-10
Expand Down
4 changes: 4 additions & 0 deletions actix-http/Cargo.toml
Expand Up @@ -103,6 +103,10 @@ version = "0.10.9"
package = "openssl"
features = ["vendored"]

[[example]]
name = "ws"
required-features = ["rustls"]

[[bench]]
name = "write-camel-case"
harness = false
Expand Down
107 changes: 107 additions & 0 deletions actix-http/examples/ws.rs
@@ -0,0 +1,107 @@
//! Sets up a WebSocket server over TCP and TLS.
//! Sends a heartbeat message every 4 seconds but does not respond to any incoming frames.

extern crate tls_rustls as rustls;

use std::{
env, io,
pin::Pin,
task::{Context, Poll},
time::Duration,
};

use actix_codec::Encoder;
use actix_http::{error::Error, ws, HttpService, Request, Response};
use actix_rt::time::{interval, Interval};
use actix_server::Server;
use bytes::{Bytes, BytesMut};
use bytestring::ByteString;
use futures_core::{ready, Stream};

#[actix_rt::main]
async fn main() -> io::Result<()> {
env::set_var("RUST_LOG", "actix=info,h2_ws=info");
env_logger::init();

Server::build()
.bind("tcp", ("127.0.0.1", 8080), || {
HttpService::build().h1(handler).tcp()
})?
.bind("tls", ("127.0.0.1", 8443), || {
HttpService::build().finish(handler).rustls(tls_config())
})?
.run()
.await
}

async fn handler(req: Request) -> Result<Response, Error> {
log::info!("handshaking");
let mut res = ws::handshake(req.head())?;

// handshake will always fail under HTTP/2

log::info!("responding");
Ok(res.streaming(Heartbeat::new(ws::Codec::new())))
}

struct Heartbeat {
codec: ws::Codec,
interval: Interval,
}

impl Heartbeat {
fn new(codec: ws::Codec) -> Self {
Self {
codec,
interval: interval(Duration::from_secs(4)),
}
}
}

impl Stream for Heartbeat {
type Item = Result<Bytes, Error>;

fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
log::trace!("poll");

ready!(self.as_mut().interval.poll_tick(cx));

let mut buffer = BytesMut::new();

self.as_mut()
.codec
.encode(
ws::Message::Text(ByteString::from_static("hello world")),
&mut buffer,
)
.unwrap();

Poll::Ready(Some(Ok(buffer.freeze())))
}
}

fn tls_config() -> rustls::ServerConfig {
use std::io::BufReader;

use rustls::{
internal::pemfile::{certs, pkcs8_private_keys},
NoClientAuth, ServerConfig,
};

let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap();
let cert_file = cert.serialize_pem().unwrap();
let key_file = cert.serialize_private_key_pem();

let mut config = ServerConfig::new(NoClientAuth::new());
let cert_file = &mut BufReader::new(cert_file.as_bytes());
let key_file = &mut BufReader::new(key_file.as_bytes());

let cert_chain = certs(cert_file).unwrap();
let mut keys = pkcs8_private_keys(key_file).unwrap();
config.set_single_cert(cert_chain, keys.remove(0)).unwrap();

config
}
27 changes: 17 additions & 10 deletions actix-http/src/ws/mod.rs
@@ -1,4 +1,4 @@
//! WebSocket protocol.
//! WebSocket protocol implementation.
//!
//! To setup a WebSocket, first perform the WebSocket handshake then on success convert `Payload` into a
//! `WsStream` stream and then use `WsWriter` to communicate with the peer.
Expand All @@ -8,9 +8,12 @@ use std::io;
use derive_more::{Display, Error, From};
use http::{header, Method, StatusCode};

use crate::error::ResponseError;
use crate::message::RequestHead;
use crate::response::{Response, ResponseBuilder};
use crate::{
error::ResponseError,
header::HeaderValue,
message::RequestHead,
response::{Response, ResponseBuilder},
};

mod codec;
mod dispatcher;
Expand Down Expand Up @@ -89,7 +92,7 @@ pub enum HandshakeError {
NoVersionHeader,

/// Unsupported WebSocket version.
#[display(fmt = "Unsupported version.")]
#[display(fmt = "Unsupported WebSocket version.")]
UnsupportedVersion,

/// WebSocket key is not set or wrong.
Expand All @@ -105,19 +108,19 @@ impl ResponseError for HandshakeError {
.finish(),

HandshakeError::NoWebsocketUpgrade => Response::BadRequest()
.reason("No WebSocket UPGRADE header found")
.reason("No WebSocket Upgrade header found")
.finish(),

HandshakeError::NoConnectionUpgrade => Response::BadRequest()
.reason("No CONNECTION upgrade")
.reason("No Connection upgrade")
.finish(),

HandshakeError::NoVersionHeader => Response::BadRequest()
.reason("Websocket version header is required")
.reason("WebSocket version header is required")
.finish(),

HandshakeError::UnsupportedVersion => Response::BadRequest()
.reason("Unsupported version")
.reason("Unsupported WebSocket version")
.finish(),

HandshakeError::BadWebsocketKey => {
Expand Down Expand Up @@ -193,7 +196,11 @@ pub fn handshake_response(req: &RequestHead) -> ResponseBuilder {
Response::build(StatusCode::SWITCHING_PROTOCOLS)
.upgrade("websocket")
.insert_header((header::TRANSFER_ENCODING, "chunked"))
.insert_header((header::SEC_WEBSOCKET_ACCEPT, key))
.insert_header((
header::SEC_WEBSOCKET_ACCEPT,
// key is known to be header value safe ascii
HeaderValue::from_bytes(&key).unwrap(),
))
.take()
}

Expand Down