Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/integrations/shield-axum/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ impl ErrorBody {
}
}

#[derive(Debug)]
pub struct RouteError(ShieldError);

impl IntoResponse for RouteError {
Expand Down
51 changes: 25 additions & 26 deletions packages/integrations/shield-axum/src/extract.rs
Original file line number Diff line number Diff line change
@@ -1,83 +1,82 @@
use async_trait::async_trait;
use axum::{
extract::FromRequestParts,
http::{request::Parts, StatusCode},
};
use shield::{Session, Shield, User};
use axum::{extract::FromRequestParts, http::request::Parts};
use shield::{ConfigurationError, Session, Shield, ShieldError, User};

use crate::error::RouteError;

pub struct ExtractShield<U: User>(pub Shield<U>);

#[async_trait]
impl<S: Send + Sync, U: User + Clone + 'static> FromRequestParts<S> for ExtractShield<U> {
type Rejection = (StatusCode, &'static str);
type Rejection = RouteError;

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
parts
.extensions
.get::<Shield<U>>()
.cloned()
.map(ExtractShield)
.ok_or((
StatusCode::INTERNAL_SERVER_ERROR,
"Can't extract Shield. Is `ShieldLayer` enabled?",
))
.ok_or(ShieldError::Configuration(ConfigurationError::Invalid(
"Can't extract Shield. Is `ShieldLayer` enabled?".to_owned(),
)))
.map_err(RouteError::from)
}
}

pub struct ExtractSession(pub Session);

#[async_trait]
impl<S: Send + Sync> FromRequestParts<S> for ExtractSession {
type Rejection = (StatusCode, &'static str);
type Rejection = RouteError;

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
parts
.extensions
.get::<Session>()
.cloned()
.map(ExtractSession)
.ok_or((
StatusCode::INTERNAL_SERVER_ERROR,
"Can't extract Shield session. Is `ShieldLayer` enabled?",
))
.ok_or(ShieldError::Configuration(ConfigurationError::Invalid(
"Can't extract Shield. Is `ShieldLayer` enabled?".to_owned(),
)))
.map_err(RouteError::from)
}
}

pub struct ExtractUser<U: User>(pub Option<U>);

#[async_trait]
impl<S: Send + Sync, U: User + Clone + 'static> FromRequestParts<S> for ExtractUser<U> {
type Rejection = (StatusCode, &'static str);
type Rejection = RouteError;

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
parts
.extensions
.get::<Option<U>>()
.cloned()
.map(ExtractUser)
.ok_or((
StatusCode::INTERNAL_SERVER_ERROR,
"Can't extract Shield user. Is `ShieldLayer` enabled?",
))
.ok_or(ShieldError::Configuration(ConfigurationError::Invalid(
"Can't extract Shield. Is `ShieldLayer` enabled?".to_owned(),
)))
.map_err(RouteError::from)
}
}

pub struct UserRequired<U: User>(pub U);

#[async_trait]
impl<S: Send + Sync, U: User + Clone + 'static> FromRequestParts<S> for UserRequired<U> {
type Rejection = (StatusCode, &'static str);
type Rejection = RouteError;

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
parts
.extensions
.get::<Option<U>>()
.cloned()
.ok_or((
StatusCode::INTERNAL_SERVER_ERROR,
"Can't extract Shield user. Is `ShieldLayer` enabled?",
))
.and_then(|user| user.ok_or((StatusCode::UNAUTHORIZED, "Unauthorized")))
.ok_or(ShieldError::Configuration(ConfigurationError::Invalid(
"Can't extract Shield. Is `ShieldLayer` enabled?".to_owned(),
)))
.and_then(|user| user.ok_or(ShieldError::Unauthorized))
.map(UserRequired)
.map_err(RouteError::from)
}
}
Loading