Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to set default error handlers to the ErrorHandler middleware #2784

Merged
merged 22 commits into from Sep 15, 2022
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions actix-web/CHANGES.md
@@ -1,12 +1,16 @@
# Changelog

## Unreleased - 2022-xx-xx
### Changed
- Minimum supported Rust version (MSRV) is now 1.57 due to transitive `time` dependency.
### Added
- Add `ServiceRequest::{parts, request}()` getter methods. [#2786]
- Add `ErrorHandlers::default_handler()` (as well as `default_handler_{server, client}()`) to make registering handlers with the `ErrorHandlers` middleware easier. [#2784]

[#2784]: https://github.com/actix/actix-web/pull/2784
[#2786]: https://github.com/actix/actix-web/pull/2786


## 4.1.0 - 2022-06-11
### Added
- Add `ServiceRequest::extract()` to make it easier to use extractors when writing middlewares. [#2647]
Expand Down
297 changes: 290 additions & 7 deletions actix-web/src/middleware/err_handlers.rs
Expand Up @@ -30,11 +30,25 @@ pub enum ErrorHandlerResponse<B> {

type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>>;

type DefaultHandler<B> = Option<Rc<ErrorHandler<B>>>;

/// Middleware for registering custom status code based error handlers.
///
/// Register handlers with the `ErrorHandlers::handler()` method to register a custom error handler
/// Register handlers with the [`ErrorHandlers::handler()`] method to register a custom error handler
/// for a given status code. Handlers can modify existing responses or create completely new ones.
///
/// To register a default handler, use the [`ErrorHandlers::default_handler()`] method. This
/// handler will be used only if a response has an error status code (400-599) that isn't covered by
/// a more specific handler (set with the [`handler()`][ErrorHandlers::handler] method). See examples
/// below.
///
/// To register a default for only client errors (400-499) or only server errors (500-599), use the
/// [`ErrorHandlers::default_handler_client()`] and [`ErrorHandlers::default_handler_server()`]
/// methods, respectively.
///
/// Any response with a status code that isn't covered by a specific handler or a default handler
/// will pass by unchanged by this middleware.
///
/// # Examples
/// ```
/// use actix_web::http::{header, StatusCode};
Expand All @@ -53,7 +67,70 @@ type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse
/// .wrap(ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, add_error_header))
/// .service(web::resource("/").route(web::get().to(HttpResponse::InternalServerError)));
/// ```
/// ## Registering default handler
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explain the precedence of overlapping handlers.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hope the new docs are clearer. lmk if anything still doesn't make sense!

/// ```
/// # use actix_web::http::{header, StatusCode};
/// # use actix_web::middleware::{ErrorHandlerResponse, ErrorHandlers};
/// # use actix_web::{dev, web, App, HttpResponse, Result};
/// fn add_error_header<B>(mut res: dev::ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
/// res.response_mut().headers_mut().insert(
/// header::CONTENT_TYPE,
/// header::HeaderValue::from_static("Error"),
/// );
/// Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
/// }
///
/// fn handle_bad_request<B>(mut res: dev::ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
/// res.response_mut().headers_mut().insert(
/// header::CONTENT_TYPE,
/// header::HeaderValue::from_static("Bad Request Error"),
/// );
/// Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
/// }
///
/// // Bad Request errors will hit `handle_bad_request()`, while all other errors will hit
/// // `add_error_header()`. The order in which the methods are called is not meaningful.
/// let app = App::new()
/// .wrap(
/// ErrorHandlers::new()
/// .default_handler(add_error_header)
/// .handler(StatusCode::BAD_REQUEST, handle_bad_request)
/// )
/// .service(web::resource("/").route(web::get().to(HttpResponse::InternalServerError)));
/// ```
/// Alternatively, you can set default handlers for only client or only server errors:
///
/// ```rust
/// # use actix_web::http::{header, StatusCode};
/// # use actix_web::middleware::{ErrorHandlerResponse, ErrorHandlers};
/// # use actix_web::{dev, web, App, HttpResponse, Result};
/// # fn add_error_header<B>(mut res: dev::ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
/// # res.response_mut().headers_mut().insert(
/// # header::CONTENT_TYPE,
/// # header::HeaderValue::from_static("Error"),
/// # );
/// # Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
/// # }
/// # fn handle_bad_request<B>(mut res: dev::ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
/// # res.response_mut().headers_mut().insert(
/// # header::CONTENT_TYPE,
/// # header::HeaderValue::from_static("Bad Request Error"),
/// # );
/// # Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
/// # }
/// // Bad request errors will hit `handle_bad_request()`, other client errors will hit
/// // `add_error_header()`, and server errors will pass through unchanged
/// let app = App::new()
/// .wrap(
/// ErrorHandlers::new()
/// .default_handler_client(add_error_header) // or .default_handler_server
/// .handler(StatusCode::BAD_REQUEST, handle_bad_request)
/// )
/// .service(web::resource("/").route(web::get().to(HttpResponse::InternalServerError)));
/// ```
pub struct ErrorHandlers<B> {
default_client: DefaultHandler<B>,
default_server: DefaultHandler<B>,
handlers: Handlers<B>,
}

Expand All @@ -62,6 +139,8 @@ type Handlers<B> = Rc<AHashMap<StatusCode, Box<ErrorHandler<B>>>>;
impl<B> Default for ErrorHandlers<B> {
fn default() -> Self {
ErrorHandlers {
default_client: Default::default(),
default_server: Default::default(),
handlers: Default::default(),
}
}
Expand All @@ -83,6 +162,66 @@ impl<B> ErrorHandlers<B> {
.insert(status, Box::new(handler));
self
}

/// Register a default error handler.
///
/// Any request with a status code that hasn't been given a specific other handler (by calling
/// [`.handler()`][ErrorHandlers::handler]) will fall back on this.
///
/// Note that this will overwrite any default handlers previously set by calling
/// [`.default_handler_client()`][ErrorHandlers::default_handler_client] or
/// [`.default_handler_server()`][ErrorHandlers::default_handler_server], but not any set by
/// calling [`.handler()`][ErrorHandlers::handler].
pub fn default_handler<F>(self, handler: F) -> Self
where
F: Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> + 'static,
{
let handler = Rc::new(handler);
Self {
default_server: Some(handler.clone()),
default_client: Some(handler),
..self
}
}

/// Register a handler on which to fall back for client error status codes (400-499).
pub fn default_handler_client<F>(self, handler: F) -> Self
where
F: Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> + 'static,
{
Self {
default_client: Some(Rc::new(handler)),
..self
}
}

/// Register a handler on which to fall back for server error status codes (500-599).
pub fn default_handler_server<F>(self, handler: F) -> Self
where
F: Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> + 'static,
{
Self {
default_server: Some(Rc::new(handler)),
..self
}
}

/// Selects the most appropriate handler for the given status code.
///
/// If the `handlers` map has an entry for that status code, that handler is returned.
/// Otherwise, fall back on the appropriate default handler.
fn get_handler<'a>(
status: &StatusCode,
default_client: Option<&'a ErrorHandler<B>>,
default_server: Option<&'a ErrorHandler<B>>,
handlers: &'a Handlers<B>,
) -> Option<&'a ErrorHandler<B>> {
handlers
.get(status)
.map(|h| h.as_ref())
.or_else(|| status.is_client_error().then(|| default_client).flatten())
.or_else(|| status.is_server_error().then(|| default_server).flatten())
}
}

impl<S, B> Transform<S, ServiceRequest> for ErrorHandlers<B>
Expand All @@ -99,13 +238,24 @@ where

fn new_transform(&self, service: S) -> Self::Future {
let handlers = self.handlers.clone();
Box::pin(async move { Ok(ErrorHandlersMiddleware { service, handlers }) })
let default_client = self.default_client.clone();
let default_server = self.default_server.clone();
Box::pin(async move {
Ok(ErrorHandlersMiddleware {
service,
default_client,
default_server,
handlers,
})
})
}
}

#[doc(hidden)]
pub struct ErrorHandlersMiddleware<S, B> {
service: S,
default_client: DefaultHandler<B>,
default_server: DefaultHandler<B>,
handlers: Handlers<B>,
}

Expand All @@ -123,8 +273,15 @@ where

fn call(&self, req: ServiceRequest) -> Self::Future {
let handlers = self.handlers.clone();
let default_client = self.default_client.clone();
let default_server = self.default_server.clone();
let fut = self.service.call(req);
ErrorHandlersFuture::ServiceFuture { fut, handlers }
ErrorHandlersFuture::ServiceFuture {
fut,
default_client,
default_server,
handlers,
}
}
}

Expand All @@ -137,6 +294,8 @@ pin_project! {
ServiceFuture {
#[pin]
fut: Fut,
default_client: DefaultHandler<B>,
default_server: DefaultHandler<B>,
handlers: Handlers<B>,
},
ErrorHandlerFuture {
Expand All @@ -153,10 +312,22 @@ where

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.as_mut().project() {
ErrorHandlersProj::ServiceFuture { fut, handlers } => {
ErrorHandlersProj::ServiceFuture {
fut,
default_client,
default_server,
handlers,
} => {
let res = ready!(fut.poll(cx))?;

match handlers.get(&res.status()) {
let status = res.status();

let handler = ErrorHandlers::get_handler(
&status,
default_client.as_mut().map(|f| Rc::as_ref(f)),
default_server.as_mut().map(|f| Rc::as_ref(f)),
handlers,
);
match handler {
Some(handler) => match handler(res)? {
ErrorHandlerResponse::Response(res) => Poll::Ready(Ok(res)),
ErrorHandlerResponse::Future(fut) => {
Expand All @@ -166,7 +337,6 @@ where
self.poll(cx)
}
},

None => Poll::Ready(Ok(res.map_into_left_body())),
}
}
Expand Down Expand Up @@ -298,4 +468,117 @@ mod tests {
"error in error handler"
);
}

#[actix_rt::test]
async fn default_error_handler() {
#[allow(clippy::unnecessary_wraps)]
fn error_handler<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
res.response_mut()
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
}

let make_mw = |status| async move {
ErrorHandlers::new()
.default_handler(error_handler)
.new_transform(test::status_service(status).into_service())
.await
.unwrap()
};
let mw_server = make_mw(StatusCode::INTERNAL_SERVER_ERROR).await;
let mw_client = make_mw(StatusCode::BAD_REQUEST).await;

let resp =
test::call_service(&mw_client, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");

let resp =
test::call_service(&mw_server, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
}

#[actix_rt::test]
async fn default_handlers_separate_client_server() {
#[allow(clippy::unnecessary_wraps)]
fn error_handler_client<B>(
mut res: ServiceResponse<B>,
) -> Result<ErrorHandlerResponse<B>> {
res.response_mut()
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
}

#[allow(clippy::unnecessary_wraps)]
fn error_handler_server<B>(
mut res: ServiceResponse<B>,
) -> Result<ErrorHandlerResponse<B>> {
res.response_mut()
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0002"));
Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
}

let make_mw = |status| async move {
ErrorHandlers::new()
.default_handler_server(error_handler_server)
.default_handler_client(error_handler_client)
.new_transform(test::status_service(status).into_service())
.await
.unwrap()
};
let mw_server = make_mw(StatusCode::INTERNAL_SERVER_ERROR).await;
let mw_client = make_mw(StatusCode::BAD_REQUEST).await;

let resp =
test::call_service(&mw_client, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");

let resp =
test::call_service(&mw_server, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0002");
}

#[actix_rt::test]
async fn default_handlers_specialization() {
#[allow(clippy::unnecessary_wraps)]
fn error_handler_client<B>(
mut res: ServiceResponse<B>,
) -> Result<ErrorHandlerResponse<B>> {
res.response_mut()
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
}

#[allow(clippy::unnecessary_wraps)]
fn error_handler_specific<B>(
mut res: ServiceResponse<B>,
) -> Result<ErrorHandlerResponse<B>> {
res.response_mut()
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0003"));
Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
}

let make_mw = |status| async move {
ErrorHandlers::new()
.default_handler_client(error_handler_client)
.handler(StatusCode::UNPROCESSABLE_ENTITY, error_handler_specific)
.new_transform(test::status_service(status).into_service())
.await
.unwrap()
};
let mw_client = make_mw(StatusCode::BAD_REQUEST).await;
let mw_specific = make_mw(StatusCode::UNPROCESSABLE_ENTITY).await;

let resp =
test::call_service(&mw_client, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");

let resp =
test::call_service(&mw_specific, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0003");
}
}