From 1e47eecd15d88312bac0530b1c0778f7f1e71f9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABlle=20Huisman?= Date: Sun, 19 Jan 2025 11:40:51 +0100 Subject: [PATCH] feat(shield): add options --- Cargo.lock | 1 + examples/leptos-actix/src/main.rs | 3 ++- examples/leptos-axum/src/main.rs | 3 ++- packages/core/shield/Cargo.toml | 1 + packages/core/shield/src/lib.rs | 2 ++ packages/core/shield/src/options.rs | 16 ++++++++++++++++ packages/core/shield/src/provider.rs | 8 ++++++++ packages/core/shield/src/shield.rs | 19 +++++++++++++------ .../providers/shield-oauth/src/provider.rs | 7 +++++-- .../providers/shield-oidc/src/provider.rs | 13 +++++++------ 10 files changed, 57 insertions(+), 16 deletions(-) create mode 100644 packages/core/shield/src/options.rs diff --git a/Cargo.lock b/Cargo.lock index 7aed540..e55fbaf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4621,6 +4621,7 @@ name = "shield" version = "0.0.4" dependencies = [ "async-trait", + "bon", "chrono", "futures", "serde", diff --git a/examples/leptos-actix/src/main.rs b/examples/leptos-actix/src/main.rs index c30bc6e..ec40725 100644 --- a/examples/leptos-actix/src/main.rs +++ b/examples/leptos-actix/src/main.rs @@ -8,7 +8,7 @@ async fn main() -> std::io::Result<()> { use actix_web::{cookie::Key, web::Data, App, HttpServer}; use leptos::config::get_configuration; use leptos_actix::{generate_route_list, LeptosRoutes}; - use shield::Shield; + use shield::{Shield, ShieldOptions}; use shield_examples_leptos_actix::app::*; use shield_leptos_actix::{provide_actix_integration, ShieldMiddleware}; use shield_memory::{MemoryStorage, User}; @@ -52,6 +52,7 @@ async fn main() -> std::io::Result<()> { .client_secret("xcpQsaGbRILTljPtX4npjmYMBjKrariJ") .build()]), )], + ShieldOptions::default(), ); let shield_middleware = ShieldMiddleware::new(shield.clone()); diff --git a/examples/leptos-axum/src/main.rs b/examples/leptos-axum/src/main.rs index 70082b0..a6d0e77 100644 --- a/examples/leptos-axum/src/main.rs +++ b/examples/leptos-axum/src/main.rs @@ -7,7 +7,7 @@ async fn main() { use leptos::config::{get_configuration, LeptosOptions}; use leptos::logging::log; use leptos_axum::{generate_route_list, LeptosRoutes}; - use shield::Shield; + use shield::{Shield, ShieldOptions}; use shield_examples_leptos_axum::app::*; use shield_leptos_axum::{provide_axum_integration, AuthRoutes, ShieldLayer}; use shield_memory::{MemoryStorage, User}; @@ -53,6 +53,7 @@ async fn main() { )) .build()]), )], + ShieldOptions::default(), ); let shield_layer = ShieldLayer::new(shield.clone()); diff --git a/packages/core/shield/Cargo.toml b/packages/core/shield/Cargo.toml index 7fd08d5..e280cca 100644 --- a/packages/core/shield/Cargo.toml +++ b/packages/core/shield/Cargo.toml @@ -10,6 +10,7 @@ version.workspace = true [dependencies] async-trait.workspace = true +bon.workspace = true chrono = { workspace = true, features = ["serde"] } futures.workspace = true serde = { workspace = true, features = ["derive"] } diff --git a/packages/core/shield/src/lib.rs b/packages/core/shield/src/lib.rs index a36481b..69e5928 100644 --- a/packages/core/shield/src/lib.rs +++ b/packages/core/shield/src/lib.rs @@ -1,5 +1,6 @@ mod error; mod form; +mod options; mod provider; mod request; mod response; @@ -11,6 +12,7 @@ mod user; pub use error::*; pub use form::*; +pub use options::*; pub use provider::*; pub use request::*; pub use response::*; diff --git a/packages/core/shield/src/options.rs b/packages/core/shield/src/options.rs new file mode 100644 index 0000000..ddeda26 --- /dev/null +++ b/packages/core/shield/src/options.rs @@ -0,0 +1,16 @@ +use bon::Builder; + +#[derive(Builder, Clone, Debug)] +#[builder(on(String, into), state_mod(vis = "pub(crate)"))] +pub struct ShieldOptions { + #[builder(default = "/")] + pub sign_in_redirect: String, + #[builder(default = "/")] + pub sign_out_redirect: String, +} + +impl Default for ShieldOptions { + fn default() -> Self { + Self::builder().build() + } +} diff --git a/packages/core/shield/src/provider.rs b/packages/core/shield/src/provider.rs index 06f24b0..c653d06 100644 --- a/packages/core/shield/src/provider.rs +++ b/packages/core/shield/src/provider.rs @@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize}; use crate::{ error::ShieldError, form::Form, + options::ShieldOptions, request::{SignInCallbackRequest, SignInRequest, SignOutRequest}, response::Response, session::Session, @@ -24,18 +25,21 @@ pub trait Provider: Send + Sync { &self, request: SignInRequest, session: Session, + options: &ShieldOptions, ) -> Result; async fn sign_in_callback( &self, request: SignInCallbackRequest, session: Session, + options: &ShieldOptions, ) -> Result; async fn sign_out( &self, request: SignOutRequest, session: Session, + options: &ShieldOptions, ) -> Result; } @@ -71,6 +75,7 @@ pub(crate) mod tests { request::{SignInCallbackRequest, SignInRequest, SignOutRequest}, response::Response, session::Session, + ShieldOptions, }; use super::{Provider, Subprovider}; @@ -110,6 +115,7 @@ pub(crate) mod tests { &self, _request: SignInRequest, _session: Session, + _options: &ShieldOptions, ) -> Result { todo!("redirect back?") } @@ -118,6 +124,7 @@ pub(crate) mod tests { &self, _request: SignInCallbackRequest, _session: Session, + _options: &ShieldOptions, ) -> Result { todo!("redirect back?") } @@ -126,6 +133,7 @@ pub(crate) mod tests { &self, _request: SignOutRequest, _session: Session, + _options: &ShieldOptions, ) -> Result { todo!("redirect back?") } diff --git a/packages/core/shield/src/shield.rs b/packages/core/shield/src/shield.rs index 2409a38..d8a54d0 100644 --- a/packages/core/shield/src/shield.rs +++ b/packages/core/shield/src/shield.rs @@ -5,6 +5,7 @@ use tracing::debug; use crate::{ error::{ProviderError, SessionError, ShieldError}, + options::ShieldOptions, provider::{Provider, Subprovider, SubproviderVisualisation}, request::{SignInCallbackRequest, SignInRequest, SignOutRequest}, response::Response, @@ -17,10 +18,11 @@ use crate::{ pub struct Shield { storage: Arc>, providers: Arc>>, + options: ShieldOptions, } impl Shield { - pub fn new(storage: S, providers: Vec>) -> Self + pub fn new(storage: S, providers: Vec>, options: ShieldOptions) -> Self where S: Storage + 'static, { @@ -32,6 +34,7 @@ impl Shield { .map(|provider| (provider.id(), provider)) .collect(), ), + options, } } @@ -105,7 +108,7 @@ impl Shield { None => return Err(ProviderError::ProviderNotFound(request.provider_id).into()), }; - provider.sign_in(request, session).await + provider.sign_in(request, session, &self.options).await } pub async fn sign_in_callback( @@ -120,7 +123,9 @@ impl Shield { None => return Err(ProviderError::ProviderNotFound(request.provider_id).into()), }; - provider.sign_in_callback(request, session).await + provider + .sign_in_callback(request, session, &self.options) + .await } pub async fn sign_out(&self, session: Session) -> Result { @@ -150,11 +155,11 @@ impl Shield { subprovider_id: authenticated.subprovider_id, }, session.clone(), + &self.options, ) .await? } else { - // TODO: Should be configurable. - Response::Redirect("/".to_owned()) + Response::Redirect(self.options.sign_out_redirect.clone()) }; session.purge().await?; @@ -206,13 +211,14 @@ mod tests { use crate::{ provider::tests::{TestProvider, TEST_PROVIDER_ID}, storage::tests::{TestStorage, TEST_STORAGE_ID}, + ShieldOptions, }; use super::Shield; #[test] fn test_storage() { - let shield = Shield::new(TestStorage::default(), vec![]); + let shield = Shield::new(TestStorage::default(), vec![], ShieldOptions::default()); assert_eq!(TEST_STORAGE_ID, shield.storage().id()); } @@ -225,6 +231,7 @@ mod tests { Arc::new(TestProvider::default().with_id("test1")), Arc::new(TestProvider::default().with_id("test2")), ], + ShieldOptions::default(), ); assert_eq!( diff --git a/packages/providers/shield-oauth/src/provider.rs b/packages/providers/shield-oauth/src/provider.rs index 8333586..ba61ccc 100644 --- a/packages/providers/shield-oauth/src/provider.rs +++ b/packages/providers/shield-oauth/src/provider.rs @@ -1,7 +1,7 @@ use async_trait::async_trait; use shield::{ - Provider, ProviderError, Response, Session, ShieldError, SignInCallbackRequest, SignInRequest, - SignOutRequest, Subprovider, User, + Provider, ProviderError, Response, Session, ShieldError, ShieldOptions, SignInCallbackRequest, + SignInRequest, SignOutRequest, Subprovider, User, }; use crate::{storage::OauthStorage, subprovider::OauthSubprovider}; @@ -80,6 +80,7 @@ impl Provider for OauthProvider { &self, request: SignInRequest, _session: Session, + _options: &ShieldOptions, ) -> Result { let _subprovider = match request.subprovider_id { Some(subprovider_id) => self.oauth_subprovider_by_id(&subprovider_id).await?, @@ -93,6 +94,7 @@ impl Provider for OauthProvider { &self, request: SignInCallbackRequest, _session: Session, + _options: &ShieldOptions, ) -> Result { let _subprovider = match request.subprovider_id { Some(subprovider_id) => self.oauth_subprovider_by_id(&subprovider_id).await?, @@ -106,6 +108,7 @@ impl Provider for OauthProvider { &self, request: SignOutRequest, _session: Session, + _options: &ShieldOptions, ) -> Result { let _subprovider = match request.subprovider_id { Some(subprovider_id) => self.oauth_subprovider_by_id(&subprovider_id).await?, diff --git a/packages/providers/shield-oidc/src/provider.rs b/packages/providers/shield-oidc/src/provider.rs index f87724f..2bbd7ee 100644 --- a/packages/providers/shield-oidc/src/provider.rs +++ b/packages/providers/shield-oidc/src/provider.rs @@ -9,8 +9,8 @@ use openidconnect::{ }; use shield::{ Authentication, ConfigurationError, CreateEmailAddress, CreateUser, Provider, ProviderError, - Response, Session, SessionError, ShieldError, SignInCallbackRequest, SignInRequest, - SignOutRequest, Subprovider, UpdateUser, User, + Response, Session, SessionError, ShieldError, ShieldOptions, SignInCallbackRequest, + SignInRequest, SignOutRequest, Subprovider, UpdateUser, User, }; use tracing::debug; @@ -196,6 +196,7 @@ impl Provider for OidcProvider { &self, request: SignInRequest, session: Session, + _options: &ShieldOptions, ) -> Result { let subprovider = match request.subprovider_id { Some(subprovider_id) => self.oidc_subprovider_by_id_or_slug(&subprovider_id).await?, @@ -261,6 +262,7 @@ impl Provider for OidcProvider { &self, request: SignInCallbackRequest, session: Session, + options: &ShieldOptions, ) -> Result { let (pkce_verifier, csrf, nonce) = { let session_data = session.data(); @@ -402,14 +404,14 @@ impl Provider for OidcProvider { session.update().await?; - // TODO: Should be configurable. - Ok(Response::Redirect("/".to_owned())) + Ok(Response::Redirect(options.sign_in_redirect.clone())) } async fn sign_out( &self, request: SignOutRequest, session: Session, + options: &ShieldOptions, ) -> Result { let subprovider = match request.subprovider_id { Some(subprovider_id) => self.oidc_subprovider_by_id_or_slug(&subprovider_id).await?, @@ -460,8 +462,7 @@ impl Provider for OidcProvider { } } - // TODO: Should be configurable. - Ok(Response::Redirect("/".to_owned())) + Ok(Response::Redirect(options.sign_out_redirect.clone())) } }