Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion examples/leptos-actix/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ async fn main() -> std::io::Result<()> {
use actix_web::{cookie::Key, web::Data, App, HttpServer};
use leptos::config::get_configuration;
use leptos_actix::{generate_route_list, LeptosRoutes};
use shield::Shield;
use shield::{Shield, ShieldOptions};
use shield_examples_leptos_actix::app::*;
use shield_leptos_actix::{provide_actix_integration, ShieldMiddleware};
use shield_memory::{MemoryStorage, User};
Expand Down Expand Up @@ -52,6 +52,7 @@ async fn main() -> std::io::Result<()> {
.client_secret("xcpQsaGbRILTljPtX4npjmYMBjKrariJ")
.build()]),
)],
ShieldOptions::default(),
);
let shield_middleware = ShieldMiddleware::new(shield.clone());

Expand Down
3 changes: 2 additions & 1 deletion examples/leptos-axum/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ async fn main() {
use leptos::config::{get_configuration, LeptosOptions};
use leptos::logging::log;
use leptos_axum::{generate_route_list, LeptosRoutes};
use shield::Shield;
use shield::{Shield, ShieldOptions};
use shield_examples_leptos_axum::app::*;
use shield_leptos_axum::{provide_axum_integration, AuthRoutes, ShieldLayer};
use shield_memory::{MemoryStorage, User};
Expand Down Expand Up @@ -53,6 +53,7 @@ async fn main() {
))
.build()]),
)],
ShieldOptions::default(),
);
let shield_layer = ShieldLayer::new(shield.clone());

Expand Down
1 change: 1 addition & 0 deletions packages/core/shield/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ version.workspace = true

[dependencies]
async-trait.workspace = true
bon.workspace = true
chrono = { workspace = true, features = ["serde"] }
futures.workspace = true
serde = { workspace = true, features = ["derive"] }
Expand Down
2 changes: 2 additions & 0 deletions packages/core/shield/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod error;
mod form;
mod options;
mod provider;
mod request;
mod response;
Expand All @@ -11,6 +12,7 @@ mod user;

pub use error::*;
pub use form::*;
pub use options::*;
pub use provider::*;
pub use request::*;
pub use response::*;
Expand Down
16 changes: 16 additions & 0 deletions packages/core/shield/src/options.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use bon::Builder;

#[derive(Builder, Clone, Debug)]
#[builder(on(String, into), state_mod(vis = "pub(crate)"))]
pub struct ShieldOptions {
#[builder(default = "/")]
pub sign_in_redirect: String,
#[builder(default = "/")]
pub sign_out_redirect: String,
}

impl Default for ShieldOptions {
fn default() -> Self {
Self::builder().build()
}
}
8 changes: 8 additions & 0 deletions packages/core/shield/src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize};
use crate::{
error::ShieldError,
form::Form,
options::ShieldOptions,
request::{SignInCallbackRequest, SignInRequest, SignOutRequest},
response::Response,
session::Session,
Expand All @@ -24,18 +25,21 @@ pub trait Provider: Send + Sync {
&self,
request: SignInRequest,
session: Session,
options: &ShieldOptions,
) -> Result<Response, ShieldError>;

async fn sign_in_callback(
&self,
request: SignInCallbackRequest,
session: Session,
options: &ShieldOptions,
) -> Result<Response, ShieldError>;

async fn sign_out(
&self,
request: SignOutRequest,
session: Session,
options: &ShieldOptions,
) -> Result<Response, ShieldError>;
}

Expand Down Expand Up @@ -71,6 +75,7 @@ pub(crate) mod tests {
request::{SignInCallbackRequest, SignInRequest, SignOutRequest},
response::Response,
session::Session,
ShieldOptions,
};

use super::{Provider, Subprovider};
Expand Down Expand Up @@ -110,6 +115,7 @@ pub(crate) mod tests {
&self,
_request: SignInRequest,
_session: Session,
_options: &ShieldOptions,
) -> Result<Response, ShieldError> {
todo!("redirect back?")
}
Expand All @@ -118,6 +124,7 @@ pub(crate) mod tests {
&self,
_request: SignInCallbackRequest,
_session: Session,
_options: &ShieldOptions,
) -> Result<Response, ShieldError> {
todo!("redirect back?")
}
Expand All @@ -126,6 +133,7 @@ pub(crate) mod tests {
&self,
_request: SignOutRequest,
_session: Session,
_options: &ShieldOptions,
) -> Result<Response, ShieldError> {
todo!("redirect back?")
}
Expand Down
19 changes: 13 additions & 6 deletions packages/core/shield/src/shield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use tracing::debug;

use crate::{
error::{ProviderError, SessionError, ShieldError},
options::ShieldOptions,
provider::{Provider, Subprovider, SubproviderVisualisation},
request::{SignInCallbackRequest, SignInRequest, SignOutRequest},
response::Response,
Expand All @@ -17,10 +18,11 @@ use crate::{
pub struct Shield<U: User> {
storage: Arc<dyn Storage<U>>,
providers: Arc<HashMap<String, Arc<dyn Provider>>>,
options: ShieldOptions,
}

impl<U: User> Shield<U> {
pub fn new<S>(storage: S, providers: Vec<Arc<dyn Provider>>) -> Self
pub fn new<S>(storage: S, providers: Vec<Arc<dyn Provider>>, options: ShieldOptions) -> Self
where
S: Storage<U> + 'static,
{
Expand All @@ -32,6 +34,7 @@ impl<U: User> Shield<U> {
.map(|provider| (provider.id(), provider))
.collect(),
),
options,
}
}

Expand Down Expand Up @@ -105,7 +108,7 @@ impl<U: User> Shield<U> {
None => return Err(ProviderError::ProviderNotFound(request.provider_id).into()),
};

provider.sign_in(request, session).await
provider.sign_in(request, session, &self.options).await
}

pub async fn sign_in_callback(
Expand All @@ -120,7 +123,9 @@ impl<U: User> Shield<U> {
None => return Err(ProviderError::ProviderNotFound(request.provider_id).into()),
};

provider.sign_in_callback(request, session).await
provider
.sign_in_callback(request, session, &self.options)
.await
}

pub async fn sign_out(&self, session: Session) -> Result<Response, ShieldError> {
Expand Down Expand Up @@ -150,11 +155,11 @@ impl<U: User> Shield<U> {
subprovider_id: authenticated.subprovider_id,
},
session.clone(),
&self.options,
)
.await?
} else {
// TODO: Should be configurable.
Response::Redirect("/".to_owned())
Response::Redirect(self.options.sign_out_redirect.clone())
};

session.purge().await?;
Expand Down Expand Up @@ -206,13 +211,14 @@ mod tests {
use crate::{
provider::tests::{TestProvider, TEST_PROVIDER_ID},
storage::tests::{TestStorage, TEST_STORAGE_ID},
ShieldOptions,
};

use super::Shield;

#[test]
fn test_storage() {
let shield = Shield::new(TestStorage::default(), vec![]);
let shield = Shield::new(TestStorage::default(), vec![], ShieldOptions::default());

assert_eq!(TEST_STORAGE_ID, shield.storage().id());
}
Expand All @@ -225,6 +231,7 @@ mod tests {
Arc::new(TestProvider::default().with_id("test1")),
Arc::new(TestProvider::default().with_id("test2")),
],
ShieldOptions::default(),
);

assert_eq!(
Expand Down
7 changes: 5 additions & 2 deletions packages/providers/shield-oauth/src/provider.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use async_trait::async_trait;
use shield::{
Provider, ProviderError, Response, Session, ShieldError, SignInCallbackRequest, SignInRequest,
SignOutRequest, Subprovider, User,
Provider, ProviderError, Response, Session, ShieldError, ShieldOptions, SignInCallbackRequest,
SignInRequest, SignOutRequest, Subprovider, User,
};

use crate::{storage::OauthStorage, subprovider::OauthSubprovider};
Expand Down Expand Up @@ -80,6 +80,7 @@ impl<U: User> Provider for OauthProvider<U> {
&self,
request: SignInRequest,
_session: Session,
_options: &ShieldOptions,
) -> Result<Response, ShieldError> {
let _subprovider = match request.subprovider_id {
Some(subprovider_id) => self.oauth_subprovider_by_id(&subprovider_id).await?,
Expand All @@ -93,6 +94,7 @@ impl<U: User> Provider for OauthProvider<U> {
&self,
request: SignInCallbackRequest,
_session: Session,
_options: &ShieldOptions,
) -> Result<Response, ShieldError> {
let _subprovider = match request.subprovider_id {
Some(subprovider_id) => self.oauth_subprovider_by_id(&subprovider_id).await?,
Expand All @@ -106,6 +108,7 @@ impl<U: User> Provider for OauthProvider<U> {
&self,
request: SignOutRequest,
_session: Session,
_options: &ShieldOptions,
) -> Result<Response, ShieldError> {
let _subprovider = match request.subprovider_id {
Some(subprovider_id) => self.oauth_subprovider_by_id(&subprovider_id).await?,
Expand Down
13 changes: 7 additions & 6 deletions packages/providers/shield-oidc/src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use openidconnect::{
};
use shield::{
Authentication, ConfigurationError, CreateEmailAddress, CreateUser, Provider, ProviderError,
Response, Session, SessionError, ShieldError, SignInCallbackRequest, SignInRequest,
SignOutRequest, Subprovider, UpdateUser, User,
Response, Session, SessionError, ShieldError, ShieldOptions, SignInCallbackRequest,
SignInRequest, SignOutRequest, Subprovider, UpdateUser, User,
};
use tracing::debug;

Expand Down Expand Up @@ -196,6 +196,7 @@ impl<U: User> Provider for OidcProvider<U> {
&self,
request: SignInRequest,
session: Session,
_options: &ShieldOptions,
) -> Result<Response, ShieldError> {
let subprovider = match request.subprovider_id {
Some(subprovider_id) => self.oidc_subprovider_by_id_or_slug(&subprovider_id).await?,
Expand Down Expand Up @@ -261,6 +262,7 @@ impl<U: User> Provider for OidcProvider<U> {
&self,
request: SignInCallbackRequest,
session: Session,
options: &ShieldOptions,
) -> Result<Response, ShieldError> {
let (pkce_verifier, csrf, nonce) = {
let session_data = session.data();
Expand Down Expand Up @@ -402,14 +404,14 @@ impl<U: User> Provider for OidcProvider<U> {

session.update().await?;

// TODO: Should be configurable.
Ok(Response::Redirect("/".to_owned()))
Ok(Response::Redirect(options.sign_in_redirect.clone()))
}

async fn sign_out(
&self,
request: SignOutRequest,
session: Session,
options: &ShieldOptions,
) -> Result<Response, ShieldError> {
let subprovider = match request.subprovider_id {
Some(subprovider_id) => self.oidc_subprovider_by_id_or_slug(&subprovider_id).await?,
Expand Down Expand Up @@ -460,8 +462,7 @@ impl<U: User> Provider for OidcProvider<U> {
}
}

// TODO: Should be configurable.
Ok(Response::Redirect("/".to_owned()))
Ok(Response::Redirect(options.sign_out_redirect.clone()))
}
}

Expand Down
Loading