Skip to content

Commit

Permalink
add timeout for accepting tls connections (#393)
Browse files Browse the repository at this point in the history
Co-authored-by: Rob Ede <robjtede@icloud.com>
  • Loading branch information
fakeshadow and robjtede committed Nov 16, 2021
1 parent ce8ec15 commit 7e7df2f
Show file tree
Hide file tree
Showing 10 changed files with 383 additions and 51 deletions.
2 changes: 1 addition & 1 deletion .cargo/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ ci-check = "hack --workspace --feature-powerset --exclude-features=io-uring chec
ci-check-linux = "hack --workspace --feature-powerset check --tests --examples"

# tests avoiding io-uring feature
ci-test = "hack test --workspace --exclude=actix-rt --exclude=actix-server --all-features --lib --tests --no-fail-fast -- --nocapture"
ci-test = " hack --feature-powerset --exclude=actix-rt --exclude=actix-server --exclude-features=io-uring test --workspace --lib --tests --no-fail-fast -- --nocapture"
ci-test-rt = " hack --feature-powerset --exclude-features=io-uring test --package=actix-rt --lib --tests --no-fail-fast -- --nocapture"
ci-test-server = "hack --feature-powerset --exclude-features=io-uring test --package=actix-server --lib --tests --no-fail-fast -- --nocapture"

Expand Down
5 changes: 5 additions & 0 deletions actix-tls/CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Changes

## Unreleased - 2021-xx-xx
* Add configurable timeout for accepting TLS connection. [#393]
* Added `TlsError::Timeout` variant. [#393]
* All TLS acceptor services now use `TlsError` for their error types. [#393]

[#393]: https://github.com/actix/actix-net/pull/393


## 3.0.0-beta.8 - 2021-11-15
Expand Down
3 changes: 3 additions & 0 deletions actix-tls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ derive_more = "0.99.5"
futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] }
http = { version = "0.2.3", optional = true }
log = "0.4"
pin-project-lite = "0.2.7"
tokio-util = { version = "0.6.3", default-features = false }

# openssl
Expand All @@ -67,7 +68,9 @@ bytes = "1"
env_logger = "0.9"
futures-util = { version = "0.3.7", default-features = false, features = ["sink"] }
log = "0.4"
rcgen = "0.8"
rustls-pemfile = "0.2.1"
tokio-rustls = { version = "0.23", features = ["dangerous_configuration"] }
trust-dns-resolver = "0.20.0"

[[example]]
Expand Down
18 changes: 9 additions & 9 deletions actix-tls/src/accept/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
//! TLS acceptor services for Actix ecosystem.
//!
//! ## Crate Features
//! * `openssl` - TLS acceptor using the `openssl` crate.
//! * `rustls` - TLS acceptor using the `rustls` crate.
//! * `native-tls` - TLS acceptor using the `native-tls` crate.
//! TLS acceptor services.

use std::sync::atomic::{AtomicUsize, Ordering};

Expand All @@ -20,6 +15,10 @@ pub mod native_tls;

pub(crate) static MAX_CONN: AtomicUsize = AtomicUsize::new(256);

#[cfg(any(feature = "openssl", feature = "rustls", feature = "native-tls"))]
pub(crate) const DEFAULT_TLS_HANDSHAKE_TIMEOUT: std::time::Duration =
std::time::Duration::from_secs(3);

thread_local! {
static MAX_CONN_COUNTER: Counter = Counter::new(MAX_CONN.load(Ordering::Relaxed));
}
Expand All @@ -36,7 +35,8 @@ pub fn max_concurrent_tls_connect(num: usize) {

/// TLS error combined with service error.
#[derive(Debug)]
pub enum TlsError<E1, E2> {
Tls(E1),
Service(E2),
pub enum TlsError<TlsErr, SvcErr> {
Tls(TlsErr),
Timeout,
Service(SvcErr),
}
50 changes: 39 additions & 11 deletions actix-tls/src/accept/native_tls.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
use std::{
convert::Infallible,
io::{self, IoSlice},
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
time::Duration,
};

use actix_codec::{AsyncRead, AsyncWrite, ReadBuf};
use actix_rt::net::{ActixStream, Ready};
use actix_rt::{
net::{ActixStream, Ready},
time::timeout,
};
use actix_service::{Service, ServiceFactory};
use actix_utils::counter::Counter;
use futures_core::future::LocalBoxFuture;

pub use tokio_native_tls::native_tls::Error;
pub use tokio_native_tls::TlsAcceptor;
pub use tokio_native_tls::{native_tls::Error, TlsAcceptor};

use super::MAX_CONN_COUNTER;
use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER};

/// Wrapper type for `tokio_native_tls::TlsStream` in order to impl `ActixStream` trait.
pub struct TlsStream<T>(tokio_native_tls::TlsStream<T>);
Expand Down Expand Up @@ -94,13 +98,25 @@ impl<T: ActixStream> ActixStream for TlsStream<T> {
/// `native-tls` feature enables this `Acceptor` type.
pub struct Acceptor {
acceptor: TlsAcceptor,
handshake_timeout: Duration,
}

impl Acceptor {
/// Create `native-tls` based `Acceptor` service factory.
#[inline]
pub fn new(acceptor: TlsAcceptor) -> Self {
Acceptor { acceptor }
Acceptor {
acceptor,
handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
}
}

/// Limit the amount of time that the acceptor will wait for a TLS handshake to complete.
///
/// Default timeout is 3 seconds.
pub fn set_handshake_timeout(&mut self, handshake_timeout: Duration) -> &mut Self {
self.handshake_timeout = handshake_timeout;
self
}
}

Expand All @@ -109,13 +125,14 @@ impl Clone for Acceptor {
fn clone(&self) -> Self {
Self {
acceptor: self.acceptor.clone(),
handshake_timeout: self.handshake_timeout,
}
}
}

impl<T: ActixStream + 'static> ServiceFactory<T> for Acceptor {
type Response = TlsStream<T>;
type Error = Error;
type Error = TlsError<Error, Infallible>;
type Config = ();

type Service = NativeTlsAcceptorService;
Expand All @@ -127,21 +144,24 @@ impl<T: ActixStream + 'static> ServiceFactory<T> for Acceptor {
Ok(NativeTlsAcceptorService {
acceptor: self.acceptor.clone(),
conns: conns.clone(),
handshake_timeout: self.handshake_timeout,
})
});

Box::pin(async { res })
}
}

pub struct NativeTlsAcceptorService {
acceptor: TlsAcceptor,
conns: Counter,
handshake_timeout: Duration,
}

impl<T: ActixStream + 'static> Service<T> for NativeTlsAcceptorService {
type Response = TlsStream<T>;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<TlsStream<T>, Error>>;
type Error = TlsError<Error, Infallible>;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if self.conns.available(cx) {
Expand All @@ -154,10 +174,18 @@ impl<T: ActixStream + 'static> Service<T> for NativeTlsAcceptorService {
fn call(&self, io: T) -> Self::Future {
let guard = self.conns.get();
let acceptor = self.acceptor.clone();

let dur = self.handshake_timeout;

Box::pin(async move {
let io = acceptor.accept(io).await;
drop(guard);
io.map(Into::into)
match timeout(dur, acceptor.accept(io)).await {
Ok(Ok(io)) => {
drop(guard);
Ok(TlsStream(io))
}
Ok(Err(err)) => Err(TlsError::Tls(err)),
Err(_timeout) => Err(TlsError::Timeout),
}
})
}
}
69 changes: 51 additions & 18 deletions actix-tls/src/accept/openssl.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
use std::{
convert::Infallible,
future::Future,
io::{self, IoSlice},
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
time::Duration,
};

use actix_codec::{AsyncRead, AsyncWrite, ReadBuf};
use actix_rt::net::{ActixStream, Ready};
use actix_rt::{
net::{ActixStream, Ready},
time::{sleep, Sleep},
};
use actix_service::{Service, ServiceFactory};
use actix_utils::counter::{Counter, CounterGuard};
use futures_core::{future::LocalBoxFuture, ready};
use futures_core::future::LocalBoxFuture;

pub use openssl::ssl::{
AlpnError, Error as SslError, HandshakeError, Ssl, SslAcceptor, SslAcceptorBuilder,
};
use pin_project_lite::pin_project;

use super::MAX_CONN_COUNTER;
use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER};

/// Wrapper type for `tokio_openssl::SslStream` in order to impl `ActixStream` trait.
pub struct TlsStream<T>(tokio_openssl::SslStream<T>);
Expand Down Expand Up @@ -96,13 +102,25 @@ impl<T: ActixStream> ActixStream for TlsStream<T> {
/// `openssl` feature enables this `Acceptor` type.
pub struct Acceptor {
acceptor: SslAcceptor,
handshake_timeout: Duration,
}

impl Acceptor {
/// Create OpenSSL based `Acceptor` service factory.
#[inline]
pub fn new(acceptor: SslAcceptor) -> Self {
Acceptor { acceptor }
Acceptor {
acceptor,
handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
}
}

/// Limit the amount of time that the acceptor will wait for a TLS handshake to complete.
///
/// Default timeout is 3 seconds.
pub fn set_handshake_timeout(&mut self, handshake_timeout: Duration) -> &mut Self {
self.handshake_timeout = handshake_timeout;
self
}
}

Expand All @@ -111,13 +129,14 @@ impl Clone for Acceptor {
fn clone(&self) -> Self {
Self {
acceptor: self.acceptor.clone(),
handshake_timeout: self.handshake_timeout,
}
}
}

impl<T: ActixStream> ServiceFactory<T> for Acceptor {
type Response = TlsStream<T>;
type Error = SslError;
type Error = TlsError<SslError, Infallible>;
type Config = ();
type Service = AcceptorService;
type InitError = ();
Expand All @@ -128,20 +147,23 @@ impl<T: ActixStream> ServiceFactory<T> for Acceptor {
Ok(AcceptorService {
acceptor: self.acceptor.clone(),
conns: conns.clone(),
handshake_timeout: self.handshake_timeout,
})
});

Box::pin(async { res })
}
}

pub struct AcceptorService {
acceptor: SslAcceptor,
conns: Counter,
handshake_timeout: Duration,
}

impl<T: ActixStream> Service<T> for AcceptorService {
type Response = TlsStream<T>;
type Error = SslError;
type Error = TlsError<SslError, Infallible>;
type Future = AcceptorServiceResponse<T>;

fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Expand All @@ -155,27 +177,38 @@ impl<T: ActixStream> Service<T> for AcceptorService {
fn call(&self, io: T) -> Self::Future {
let ssl_ctx = self.acceptor.context();
let ssl = Ssl::new(ssl_ctx).expect("Provided SSL acceptor was invalid.");

AcceptorServiceResponse {
_guard: self.conns.get(),
timeout: sleep(self.handshake_timeout),
stream: Some(tokio_openssl::SslStream::new(ssl, io).unwrap()),
}
}
}

pub struct AcceptorServiceResponse<T: ActixStream> {
stream: Option<tokio_openssl::SslStream<T>>,
_guard: CounterGuard,
pin_project! {
pub struct AcceptorServiceResponse<T: ActixStream> {
stream: Option<tokio_openssl::SslStream<T>>,
#[pin]
timeout: Sleep,
_guard: CounterGuard,
}
}

impl<T: ActixStream> Future for AcceptorServiceResponse<T> {
type Output = Result<TlsStream<T>, SslError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
ready!(Pin::new(self.stream.as_mut().unwrap()).poll_accept(cx))?;
Poll::Ready(Ok(self
.stream
.take()
.expect("SSL connect has resolved.")
.into()))
type Output = Result<TlsStream<T>, TlsError<SslError, Infallible>>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();

match Pin::new(this.stream.as_mut().unwrap()).poll_accept(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(this
.stream
.take()
.expect("Acceptor should not be polled after it has completed.")
.into())),
Poll::Ready(Err(err)) => Poll::Ready(Err(TlsError::Tls(err))),
Poll::Pending => this.timeout.poll(cx).map(|_| Err(TlsError::Timeout)),
}
}
}

0 comments on commit 7e7df2f

Please sign in to comment.