diff --git a/packages/integrations/shield-axum/src/extract.rs b/packages/integrations/shield-axum/src/extract.rs index f3bd391..90468b6 100644 --- a/packages/integrations/shield-axum/src/extract.rs +++ b/packages/integrations/shield-axum/src/extract.rs @@ -61,3 +61,23 @@ impl FromRequestParts for ExtractU )) } } + +pub struct UserRequired(pub U); + +#[async_trait] +impl FromRequestParts for UserRequired { + type Rejection = (StatusCode, &'static str); + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + parts + .extensions + .get::>() + .cloned() + .ok_or(( + StatusCode::INTERNAL_SERVER_ERROR, + "Can't extract Shield user. Is `ShieldLayer` enabled?", + )) + .and_then(|user| user.ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))) + .map(UserRequired) + } +}