From ce4ba5e6ad644a5cc16355ebb3ba05db3f382fc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABlle=20Huisman?= Date: Fri, 24 Jan 2025 15:24:27 +0100 Subject: [PATCH] fix(shield-axum): use JSON errors for extractors --- .../integrations/shield-axum/src/error.rs | 1 + .../integrations/shield-axum/src/extract.rs | 51 +++++++++---------- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/packages/integrations/shield-axum/src/error.rs b/packages/integrations/shield-axum/src/error.rs index 725da1a..5d0ffec 100644 --- a/packages/integrations/shield-axum/src/error.rs +++ b/packages/integrations/shield-axum/src/error.rs @@ -32,6 +32,7 @@ impl ErrorBody { } } +#[derive(Debug)] pub struct RouteError(ShieldError); impl IntoResponse for RouteError { diff --git a/packages/integrations/shield-axum/src/extract.rs b/packages/integrations/shield-axum/src/extract.rs index 90468b6..e713107 100644 --- a/packages/integrations/shield-axum/src/extract.rs +++ b/packages/integrations/shield-axum/src/extract.rs @@ -1,15 +1,14 @@ use async_trait::async_trait; -use axum::{ - extract::FromRequestParts, - http::{request::Parts, StatusCode}, -}; -use shield::{Session, Shield, User}; +use axum::{extract::FromRequestParts, http::request::Parts}; +use shield::{ConfigurationError, Session, Shield, ShieldError, User}; + +use crate::error::RouteError; pub struct ExtractShield(pub Shield); #[async_trait] impl FromRequestParts for ExtractShield { - type Rejection = (StatusCode, &'static str); + type Rejection = RouteError; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { parts @@ -17,10 +16,10 @@ impl FromRequestParts for ExtractS .get::>() .cloned() .map(ExtractShield) - .ok_or(( - StatusCode::INTERNAL_SERVER_ERROR, - "Can't extract Shield. Is `ShieldLayer` enabled?", - )) + .ok_or(ShieldError::Configuration(ConfigurationError::Invalid( + "Can't extract Shield. Is `ShieldLayer` enabled?".to_owned(), + ))) + .map_err(RouteError::from) } } @@ -28,7 +27,7 @@ pub struct ExtractSession(pub Session); #[async_trait] impl FromRequestParts for ExtractSession { - type Rejection = (StatusCode, &'static str); + type Rejection = RouteError; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { parts @@ -36,10 +35,10 @@ impl FromRequestParts for ExtractSession { .get::() .cloned() .map(ExtractSession) - .ok_or(( - StatusCode::INTERNAL_SERVER_ERROR, - "Can't extract Shield session. Is `ShieldLayer` enabled?", - )) + .ok_or(ShieldError::Configuration(ConfigurationError::Invalid( + "Can't extract Shield. Is `ShieldLayer` enabled?".to_owned(), + ))) + .map_err(RouteError::from) } } @@ -47,7 +46,7 @@ pub struct ExtractUser(pub Option); #[async_trait] impl FromRequestParts for ExtractUser { - type Rejection = (StatusCode, &'static str); + type Rejection = RouteError; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { parts @@ -55,10 +54,10 @@ impl FromRequestParts for ExtractU .get::>() .cloned() .map(ExtractUser) - .ok_or(( - StatusCode::INTERNAL_SERVER_ERROR, - "Can't extract Shield user. Is `ShieldLayer` enabled?", - )) + .ok_or(ShieldError::Configuration(ConfigurationError::Invalid( + "Can't extract Shield. Is `ShieldLayer` enabled?".to_owned(), + ))) + .map_err(RouteError::from) } } @@ -66,18 +65,18 @@ pub struct UserRequired(pub U); #[async_trait] impl FromRequestParts for UserRequired { - type Rejection = (StatusCode, &'static str); + type Rejection = RouteError; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { parts .extensions .get::>() .cloned() - .ok_or(( - StatusCode::INTERNAL_SERVER_ERROR, - "Can't extract Shield user. Is `ShieldLayer` enabled?", - )) - .and_then(|user| user.ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))) + .ok_or(ShieldError::Configuration(ConfigurationError::Invalid( + "Can't extract Shield. Is `ShieldLayer` enabled?".to_owned(), + ))) + .and_then(|user| user.ok_or(ShieldError::Unauthorized)) .map(UserRequired) + .map_err(RouteError::from) } }