diff --git a/packages/core/shield/src/action.rs b/packages/core/shield/src/action.rs index 2872b6f..665b9a7 100644 --- a/packages/core/shield/src/action.rs +++ b/packages/core/shield/src/action.rs @@ -7,14 +7,16 @@ use crate::{ session::Session, }; -pub const SIGN_IN_ACTION_ID: &str = "sign-in"; -pub const SIGN_IN_CALLBACK_ACTION_ID: &str = "sign-in-callback"; -pub const SIGN_OUT_ACTION_ID: &str = "sign-out"; - #[async_trait] pub trait Action: ErasedAction + Send + Sync { fn id(&self) -> String; + fn name(&self) -> String; + + fn condition(&self, _provider: &P, _session: Session) -> Result { + Ok(true) + } + fn form(&self, provider: P) -> Form; async fn call( @@ -29,6 +31,14 @@ pub trait Action: ErasedAction + Send + Sync { pub trait ErasedAction: Send + Sync { fn erased_id(&self) -> String; + fn erased_name(&self) -> String; + + fn erased_condition( + &self, + provider: &(dyn Any + Send + Sync), + session: Session, + ) -> Result; + fn erased_form(&self, provider: Box) -> Form; async fn erased_call( @@ -48,6 +58,14 @@ macro_rules! erased_action { self.id() } + fn erased_name(&self) -> String { + self.name() + } + + fn erased_condition(&self, provider: &(dyn std::any::Any + Send + Sync), session: $crate::Session) -> Result { + self.condition(provider.downcast_ref().expect("TODO"), session) + } + fn erased_form(&self, provider: Box) -> $crate::Form { self.form(*provider.downcast().expect("TODO")) } @@ -57,7 +75,7 @@ macro_rules! erased_action { provider: Box, session: $crate::Session, request: $crate::Request, - ) -> Result<$crate::Response, ShieldError> { + ) -> Result<$crate::Response, $crate::ShieldError> { self.call(*provider.downcast().expect("TODO"), session, request) .await } @@ -76,7 +94,8 @@ pub(crate) mod tests { use super::Action; - pub const TEST_ACTION_ID: &str = "action"; + pub const TEST_ACTION_ID: &str = "test"; + pub const TEST_ACTION_NAME: &str = "Test"; #[derive(Default)] pub struct TestAction {} @@ -87,6 +106,10 @@ pub(crate) mod tests { TEST_ACTION_ID.to_owned() } + fn name(&self) -> String { + TEST_ACTION_NAME.to_owned() + } + fn form(&self, _provider: TestProvider) -> Form { Form { inputs: vec![] } } diff --git a/packages/core/shield/src/actions.rs b/packages/core/shield/src/actions.rs new file mode 100644 index 0000000..47d587d --- /dev/null +++ b/packages/core/shield/src/actions.rs @@ -0,0 +1,7 @@ +mod sign_in; +mod sign_in_callback; +mod sign_out; + +pub use sign_in::*; +pub use sign_in_callback::*; +pub use sign_out::*; diff --git a/packages/core/shield/src/actions/sign_in.rs b/packages/core/shield/src/actions/sign_in.rs new file mode 100644 index 0000000..0a14801 --- /dev/null +++ b/packages/core/shield/src/actions/sign_in.rs @@ -0,0 +1,14 @@ +const ACTION_ID: &str = "sign-in"; +const ACTION_NAME: &str = "Sign in"; + +pub struct SignInAction; + +impl SignInAction { + pub fn id() -> String { + ACTION_ID.to_owned() + } + + pub fn name() -> String { + ACTION_NAME.to_owned() + } +} diff --git a/packages/core/shield/src/actions/sign_in_callback.rs b/packages/core/shield/src/actions/sign_in_callback.rs new file mode 100644 index 0000000..53b7559 --- /dev/null +++ b/packages/core/shield/src/actions/sign_in_callback.rs @@ -0,0 +1,20 @@ +use crate::{Provider, Session, ShieldError}; + +const ACTION_ID: &str = "sign-in-callback"; +const ACTION_NAME: &str = "Sign in callback"; + +pub struct SignInCallbackAction; + +impl SignInCallbackAction { + pub fn id() -> String { + ACTION_ID.to_owned() + } + + pub fn name() -> String { + ACTION_NAME.to_owned() + } + + pub fn condition(_provider: &P, _session: Session) -> Result { + Ok(false) + } +} diff --git a/packages/core/shield/src/actions/sign_out.rs b/packages/core/shield/src/actions/sign_out.rs new file mode 100644 index 0000000..09445e0 --- /dev/null +++ b/packages/core/shield/src/actions/sign_out.rs @@ -0,0 +1,44 @@ +use crate::{ + Form, Input, InputType, InputTypeSubmit, Provider, Session, SessionError, ShieldError, +}; + +const ACTION_ID: &str = "sign-out"; +const ACTION_NAME: &str = "Sign out"; + +pub struct SignOutAction; + +impl SignOutAction { + pub fn id() -> String { + ACTION_ID.to_owned() + } + + pub fn name() -> String { + ACTION_NAME.to_owned() + } + + pub fn condition(provider: &P, session: Session) -> Result { + let session_data = session.data(); + let session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; + + Ok(session_data + .authentication + .as_ref() + .is_some_and(|authentication| { + authentication.method_id == provider.method_id() + && authentication.provider_id == provider.id() + })) + } + + pub fn form(_provider: P) -> Form { + Form { + inputs: vec![Input { + name: "submit".to_owned(), + label: None, + r#type: InputType::Submit(InputTypeSubmit {}), + value: Some(Self::name()), + }], + } + } +} diff --git a/packages/core/shield/src/lib.rs b/packages/core/shield/src/lib.rs index f8718e7..98c505e 100644 --- a/packages/core/shield/src/lib.rs +++ b/packages/core/shield/src/lib.rs @@ -1,4 +1,5 @@ mod action; +mod actions; mod error; mod form; mod method; @@ -13,6 +14,7 @@ mod storage; mod user; pub use action::*; +pub use actions::*; pub use error::*; pub use form::*; pub use method::*; diff --git a/packages/core/shield/src/shield.rs b/packages/core/shield/src/shield.rs index 0af2dbf..ef44e6b 100644 --- a/packages/core/shield/src/shield.rs +++ b/packages/core/shield/src/shield.rs @@ -3,8 +3,8 @@ use std::{any::Any, collections::HashMap, sync::Arc}; use futures::future::try_join_all; use crate::{ - error::ShieldError, form::Form, method::ErasedMethod, options::ShieldOptions, storage::Storage, - user::User, + Session, error::ShieldError, form::Form, method::ErasedMethod, options::ShieldOptions, + storage::Storage, user::User, }; #[derive(Clone)] @@ -64,7 +64,11 @@ impl Shield { } } - pub async fn action_forms(&self, action_id: &str) -> Result, ShieldError> { + pub async fn action_forms( + &self, + action_id: &str, + session: Session, + ) -> Result, ShieldError> { let mut forms = vec![]; for (_, method) in self.methods.iter() { @@ -73,6 +77,10 @@ impl Shield { }; for provider in method.erased_providers().await? { + if !action.erased_condition(&provider, session.clone())? { + continue; + } + let form = action.erased_form(provider); forms.push(form); diff --git a/packages/core/shield/src/shield_dyn.rs b/packages/core/shield/src/shield_dyn.rs index ffcd4d3..ae5ad2d 100644 --- a/packages/core/shield/src/shield_dyn.rs +++ b/packages/core/shield/src/shield_dyn.rs @@ -2,13 +2,17 @@ use std::{any::Any, sync::Arc}; use async_trait::async_trait; -use crate::{error::ShieldError, form::Form, shield::Shield, user::User}; +use crate::{Session, error::ShieldError, form::Form, shield::Shield, user::User}; #[async_trait] pub trait DynShield: Send + Sync { async fn providers(&self) -> Result>, ShieldError>; - async fn action_forms(&self, action_id: &str) -> Result, ShieldError>; + async fn action_forms( + &self, + action_id: &str, + session: Session, + ) -> Result, ShieldError>; } #[async_trait] @@ -17,8 +21,12 @@ impl DynShield for Shield { self.providers().await } - async fn action_forms(&self, action_id: &str) -> Result, ShieldError> { - self.action_forms(action_id).await + async fn action_forms( + &self, + action_id: &str, + session: Session, + ) -> Result, ShieldError> { + self.action_forms(action_id, session).await } } @@ -33,7 +41,11 @@ impl ShieldDyn { self.0.providers().await } - pub async fn action_forms(&self, action_id: &str) -> Result, ShieldError> { - self.0.action_forms(action_id).await + pub async fn action_forms( + &self, + action_id: &str, + session: Session, + ) -> Result, ShieldError> { + self.0.action_forms(action_id, session).await } } diff --git a/packages/integrations/shield-dioxus/src/integration.rs b/packages/integrations/shield-dioxus/src/integration.rs index 4a89740..3de83b9 100644 --- a/packages/integrations/shield-dioxus/src/integration.rs +++ b/packages/integrations/shield-dioxus/src/integration.rs @@ -21,4 +21,8 @@ impl DioxusIntegrationDyn { pub async fn extract_shield(&self) -> ShieldDyn { self.0.extract_shield().await } + + pub async fn extract_session(&self) -> Session { + self.0.extract_session().await + } } diff --git a/packages/integrations/shield-dioxus/src/routes/action.rs b/packages/integrations/shield-dioxus/src/routes/action.rs index 778fb83..9efba43 100644 --- a/packages/integrations/shield-dioxus/src/routes/action.rs +++ b/packages/integrations/shield-dioxus/src/routes/action.rs @@ -31,8 +31,9 @@ pub fn Action(props: ActionProps) -> Element { async fn forms(action_id: String) -> Result, ServerFnError> { let FromContext(integration): FromContext = extract().await?; let shield = integration.extract_shield().await; + let session = integration.extract_session().await; - let forms = shield.action_forms(&action_id).await?; + let forms = shield.action_forms(&action_id, session).await?; Ok(forms) } diff --git a/packages/methods/shield-credentials/src/actions.rs b/packages/methods/shield-credentials/src/actions.rs index 82b597f..757615c 100644 --- a/packages/methods/shield-credentials/src/actions.rs +++ b/packages/methods/shield-credentials/src/actions.rs @@ -1,3 +1,5 @@ mod sign_in; +mod sign_out; pub use sign_in::*; +pub use sign_out::*; diff --git a/packages/methods/shield-credentials/src/actions/sign_in.rs b/packages/methods/shield-credentials/src/actions/sign_in.rs index 89503d8..079a61e 100644 --- a/packages/methods/shield-credentials/src/actions/sign_in.rs +++ b/packages/methods/shield-credentials/src/actions/sign_in.rs @@ -3,8 +3,8 @@ use std::sync::Arc; use async_trait::async_trait; use serde::de::DeserializeOwned; use shield::{ - Action, Authentication, Form, Request, Response, SIGN_IN_ACTION_ID, Session, SessionError, - ShieldError, User, erased_action, + Action, Authentication, Form, Request, Response, Session, SessionError, ShieldError, + SignInAction, User, erased_action, }; use crate::{credentials::Credentials, provider::CredentialsProvider}; @@ -24,7 +24,11 @@ impl Action { fn id(&self) -> String { - SIGN_IN_ACTION_ID.to_owned() + SignInAction::id() + } + + fn name(&self) -> String { + SignInAction::name() } fn form(&self, _provider: CredentialsProvider) -> Form { diff --git a/packages/methods/shield-credentials/src/actions/sign_out.rs b/packages/methods/shield-credentials/src/actions/sign_out.rs new file mode 100644 index 0000000..89b2336 --- /dev/null +++ b/packages/methods/shield-credentials/src/actions/sign_out.rs @@ -0,0 +1,41 @@ +use async_trait::async_trait; +use shield::{Action, Form, Request, Response, Session, ShieldError, SignOutAction, erased_action}; + +use crate::provider::CredentialsProvider; + +pub struct CredentialsSignOutAction; + +#[async_trait] +impl Action for CredentialsSignOutAction { + fn id(&self) -> String { + SignOutAction::id() + } + + fn name(&self) -> String { + SignOutAction::name() + } + + fn condition( + &self, + provider: &CredentialsProvider, + session: Session, + ) -> Result { + SignOutAction::condition(provider, session) + } + + fn form(&self, provider: CredentialsProvider) -> Form { + SignOutAction::form(provider) + } + + async fn call( + &self, + _provider: CredentialsProvider, + _session: Session, + _request: Request, + ) -> Result { + // TODO: sign out + Ok(Response::Default) + } +} + +erased_action!(CredentialsSignOutAction); diff --git a/packages/methods/shield-credentials/src/method.rs b/packages/methods/shield-credentials/src/method.rs index f5c2dde..c788790 100644 --- a/packages/methods/shield-credentials/src/method.rs +++ b/packages/methods/shield-credentials/src/method.rs @@ -4,7 +4,11 @@ use async_trait::async_trait; use serde::de::DeserializeOwned; use shield::{Action, Method, ShieldError, User, erased_method}; -use crate::{Credentials, actions::CredentialsSignInAction, provider::CredentialsProvider}; +use crate::{ + actions::{CredentialsSignInAction, CredentialsSignOutAction}, + credentials::Credentials, + provider::CredentialsProvider, +}; pub const CREDENTIALS_METHOD_ID: &str = "credentials"; @@ -29,9 +33,10 @@ impl Method Vec>> { - vec![Box::new(CredentialsSignInAction::new( - self.credentials.clone(), - ))] + vec![ + Box::new(CredentialsSignInAction::new(self.credentials.clone())), + Box::new(CredentialsSignOutAction), + ] } async fn providers(&self) -> Result, ShieldError> { diff --git a/packages/methods/shield-oauth/src/actions/sign_in.rs b/packages/methods/shield-oauth/src/actions/sign_in.rs index cb8d450..0f88b40 100644 --- a/packages/methods/shield-oauth/src/actions/sign_in.rs +++ b/packages/methods/shield-oauth/src/actions/sign_in.rs @@ -1,8 +1,8 @@ use async_trait::async_trait; use oauth2::{CsrfToken, PkceCodeChallenge, Scope, url::form_urlencoded::parse}; use shield::{ - Action, ConfigurationError, Form, Request, Response, SIGN_IN_ACTION_ID, Session, SessionError, - ShieldError, erased_action, + Action, ConfigurationError, Form, Request, Response, Session, SessionError, ShieldError, + SignInAction, erased_action, }; use crate::{ @@ -16,7 +16,11 @@ pub struct OauthSignInAction; #[async_trait] impl Action for OauthSignInAction { fn id(&self) -> String { - SIGN_IN_ACTION_ID.to_owned() + SignInAction::id() + } + + fn name(&self) -> String { + SignInAction::name() } fn form(&self, _provider: OauthProvider) -> Form { diff --git a/packages/methods/shield-oauth/src/actions/sign_in_callback.rs b/packages/methods/shield-oauth/src/actions/sign_in_callback.rs index d29ac3e..0cc7fe2 100644 --- a/packages/methods/shield-oauth/src/actions/sign_in_callback.rs +++ b/packages/methods/shield-oauth/src/actions/sign_in_callback.rs @@ -9,7 +9,7 @@ use oauth2::{ use secrecy::SecretString; use shield::{ Action, Authentication, ConfigurationError, CreateEmailAddress, CreateUser, Form, Request, - Response, SIGN_IN_CALLBACK_ACTION_ID, Session, SessionError, ShieldError, UpdateUser, User, + Response, Session, SessionError, ShieldError, SignInCallbackAction, UpdateUser, User, erased_action, }; @@ -132,7 +132,15 @@ impl OauthSignInCallbackAction { #[async_trait] impl Action for OauthSignInCallbackAction { fn id(&self) -> String { - SIGN_IN_CALLBACK_ACTION_ID.to_owned() + SignInCallbackAction::id() + } + + fn name(&self) -> String { + SignInCallbackAction::name() + } + + fn condition(&self, provider: &OauthProvider, session: Session) -> Result { + SignInCallbackAction::condition(provider, session) } fn form(&self, _provider: OauthProvider) -> Form { diff --git a/packages/methods/shield-oauth/src/actions/sign_out.rs b/packages/methods/shield-oauth/src/actions/sign_out.rs index 614be40..94a9c8d 100644 --- a/packages/methods/shield-oauth/src/actions/sign_out.rs +++ b/packages/methods/shield-oauth/src/actions/sign_out.rs @@ -1,7 +1,5 @@ use async_trait::async_trait; -use shield::{ - Action, Form, Request, Response, SIGN_OUT_ACTION_ID, Session, ShieldError, erased_action, -}; +use shield::{Action, Form, Request, Response, Session, ShieldError, SignOutAction, erased_action}; use crate::provider::OauthProvider; @@ -10,11 +8,19 @@ pub struct OauthSignOutAction; #[async_trait] impl Action for OauthSignOutAction { fn id(&self) -> String { - SIGN_OUT_ACTION_ID.to_owned() + SignOutAction::id() } - fn form(&self, _provider: OauthProvider) -> Form { - Form { inputs: vec![] } + fn name(&self) -> String { + SignOutAction::name() + } + + fn condition(&self, provider: &OauthProvider, session: Session) -> Result { + SignOutAction::condition(provider, session) + } + + fn form(&self, provider: OauthProvider) -> Form { + SignOutAction::form(provider) } async fn call( @@ -24,6 +30,7 @@ impl Action for OauthSignOutAction { _request: Request, ) -> Result { // TODO: OAuth token revocation. + // TODO: Sign out. Ok(Response::Default) } diff --git a/packages/methods/shield-oidc/src/actions/sign_in.rs b/packages/methods/shield-oidc/src/actions/sign_in.rs index b85bd0d..dc1d57f 100644 --- a/packages/methods/shield-oidc/src/actions/sign_in.rs +++ b/packages/methods/shield-oidc/src/actions/sign_in.rs @@ -4,8 +4,8 @@ use openidconnect::{ url::form_urlencoded::parse, }; use shield::{ - Action, Form, Input, InputType, InputTypeSubmit, Provider, Request, Response, - SIGN_IN_ACTION_ID, Session, SessionError, ShieldError, erased_action, + Action, Form, Input, InputType, InputTypeSubmit, Provider, Request, Response, Session, + SessionError, ShieldError, SignInAction, erased_action, }; use crate::{ @@ -19,7 +19,11 @@ pub struct OidcSignInAction; #[async_trait] impl Action for OidcSignInAction { fn id(&self) -> String { - SIGN_IN_ACTION_ID.to_owned() + SignInAction::id() + } + + fn name(&self) -> String { + SignInAction::name() } fn form(&self, provider: OidcProvider) -> Form { diff --git a/packages/methods/shield-oidc/src/actions/sign_in_callback.rs b/packages/methods/shield-oidc/src/actions/sign_in_callback.rs index b97aa87..804c5fb 100644 --- a/packages/methods/shield-oidc/src/actions/sign_in_callback.rs +++ b/packages/methods/shield-oidc/src/actions/sign_in_callback.rs @@ -11,7 +11,7 @@ use openidconnect::{ use secrecy::SecretString; use shield::{ Action, Authentication, ConfigurationError, CreateEmailAddress, CreateUser, Form, Request, - Response, SIGN_IN_CALLBACK_ACTION_ID, Session, SessionError, ShieldError, UpdateUser, User, + Response, Session, SessionError, ShieldError, SignInCallbackAction, UpdateUser, User, erased_action, }; use tracing::debug; @@ -143,7 +143,15 @@ impl OidcSignInCallbackAction { #[async_trait] impl Action for OidcSignInCallbackAction { fn id(&self) -> String { - SIGN_IN_CALLBACK_ACTION_ID.to_owned() + SignInCallbackAction::id() + } + + fn name(&self) -> String { + SignInCallbackAction::name() + } + + fn condition(&self, provider: &OidcProvider, session: Session) -> Result { + SignInCallbackAction::condition(provider, session) } fn form(&self, _provider: OidcProvider) -> Form { diff --git a/packages/methods/shield-oidc/src/actions/sign_out.rs b/packages/methods/shield-oidc/src/actions/sign_out.rs index 9cc21ee..e1639a6 100644 --- a/packages/methods/shield-oidc/src/actions/sign_out.rs +++ b/packages/methods/shield-oidc/src/actions/sign_out.rs @@ -1,7 +1,5 @@ use async_trait::async_trait; -use shield::{ - Action, Form, Request, Response, SIGN_OUT_ACTION_ID, Session, ShieldError, erased_action, -}; +use shield::{Action, Form, Request, Response, Session, ShieldError, SignOutAction, erased_action}; use crate::provider::OidcProvider; @@ -10,11 +8,19 @@ pub struct OidcSignOutAction; #[async_trait] impl Action for OidcSignOutAction { fn id(&self) -> String { - SIGN_OUT_ACTION_ID.to_owned() + SignOutAction::id() } - fn form(&self, _provider: OidcProvider) -> Form { - Form { inputs: vec![] } + fn name(&self) -> String { + SignOutAction::name() + } + + fn condition(&self, provider: &OidcProvider, session: Session) -> Result { + SignOutAction::condition(provider, session) + } + + fn form(&self, provider: OidcProvider) -> Form { + SignOutAction::form(provider) } async fn call( @@ -74,6 +80,8 @@ impl Action for OidcSignOutAction { // } // } + // TODO: Sign out. + Ok(Response::Default) } }