diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 49b9ffa..7d974fe 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -58,6 +58,9 @@ jobs: message: ${{ steps.extract-version.outputs.VERSION }} token: ${{ steps.app-token.outputs.token }} + - name: Reset and pull + run: git reset --hard && git pull + - name: Tag uses: bruno-fs/repo-tagger@1.0.0 with: diff --git a/Cargo.lock b/Cargo.lock index 726f8fa..2870fd2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5696,6 +5696,7 @@ dependencies = [ "oauth2", "secrecy", "serde", + "serde_json", "shield", ] @@ -5710,6 +5711,7 @@ dependencies = [ "openidconnect", "secrecy", "serde", + "serde_json", "shield", "tracing", ] diff --git a/examples/dioxus-axum/src/main.rs b/examples/dioxus-axum/src/main.rs index 50658da..40ed6b9 100644 --- a/examples/dioxus-axum/src/main.rs +++ b/examples/dioxus-axum/src/main.rs @@ -24,7 +24,7 @@ async fn main() { }; use shield::{ErasedMethod, Method, Shield, ShieldOptions}; use shield_bootstrap::BootstrapDioxusStyle; - use shield_dioxus_axum::{AxumDioxusIntegration, ShieldLayer}; + use shield_dioxus_axum::{AuthRoutes, AxumDioxusIntegration, ShieldLayer}; use shield_memory::{MemoryStorage, User}; use shield_oidc::{Keycloak, OidcMethod}; use tokio::net::TcpListener; @@ -45,7 +45,7 @@ async fn main() { let storage = MemoryStorage::new(); let shield = Shield::new( storage.clone(), - vec![ + vec![Arc::new( OidcMethod::new(storage).with_providers([Keycloak::builder( "keycloak", "http://localhost:18080/realms/Shield", @@ -59,13 +59,14 @@ async fn main() { .unwrap_or_else(|| addr.port()) )) .build()]), - ], + )], ShieldOptions::default(), ); let shield_layer = ShieldLayer::new(shield.clone()); // Initialize router let router = Router::new() + .nest("/api/auth", AuthRoutes::router::()) .serve_dioxus_application( ServeConfig::builder() .context(AxumDioxusIntegration::::default().context()) diff --git a/examples/leptos-axum/src/app.rs b/examples/leptos-axum/src/app.rs index 9dcf9a6..61ff63e 100644 --- a/examples/leptos-axum/src/app.rs +++ b/examples/leptos-axum/src/app.rs @@ -1,12 +1,12 @@ use leptos::prelude::*; use leptos_meta::{MetaTags, Title, provide_meta_context}; use leptos_router::{ - components::{Outlet, ParentRoute, Router, Routes}, + components::{Outlet, ParentRoute, Route, Router, Routes}, path, }; use shield_leptos::ShieldRouter; -// use crate::home::HomePage; +use crate::home::HomePage; pub fn shell(options: LeptosOptions) -> impl IntoView { view! { @@ -44,7 +44,7 @@ pub fn App() -> impl IntoView {
- // + diff --git a/examples/leptos-axum/src/home.rs b/examples/leptos-axum/src/home.rs index c1a6a68..f760eb9 100644 --- a/examples/leptos-axum/src/home.rs +++ b/examples/leptos-axum/src/home.rs @@ -1,5 +1,8 @@ -use leptos::{either::Either, prelude::*}; -use leptos_router::components::A; +use leptos::{ + // either::Either, + prelude::*, +}; +// use leptos_router::components::A; use shield_leptos::LeptosUser; #[server] @@ -11,31 +14,31 @@ pub async fn user() -> Result, ServerFnError> { #[component] pub fn HomePage() -> impl IntoView { - let user = OnceResource::new(user()); + // let user = OnceResource::new(user()); view! {

"Shield Leptos Axum Example"

- - {move || Suspend::new(async move { match user.await { - Ok(user) => Either::Left(match user { - Some(user) => Either::Left(view! { -

User ID: {user.id}

+ // + // {move || Suspend::new(async move { match user.await { + // Ok(user) => Either::Left(match user { + // Some(user) => Either::Left(view! { + //

User ID: {user.id}

- - - - }), - None => Either::Right(view! { - - - - }), - }), - Err(err) => Either::Right(view! { - {err.to_string()} - }) - }})} -
+ // + // + // + // }), + // None => Either::Right(view! { + // + // + // + // }), + // }), + // Err(err) => Either::Right(view! { + // {err.to_string()} + // }) + // }})} + //
} } diff --git a/packages/core/shield/src/action.rs b/packages/core/shield/src/action.rs index 8097daf..f21a69c 100644 --- a/packages/core/shield/src/action.rs +++ b/packages/core/shield/src/action.rs @@ -4,8 +4,12 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use crate::{ - error::ShieldError, form::Form, provider::Provider, request::Request, response::Response, - session::Session, + error::ShieldError, + form::Form, + provider::Provider, + request::Request, + response::Response, + session::{BaseSession, MethodSession}, }; // TODO: Think of a better name. @@ -31,12 +35,12 @@ pub struct ActionProviderForm { } #[async_trait] -pub trait Action: ErasedAction + Send + Sync { +pub trait Action: ErasedAction + Send + Sync { fn id(&self) -> String; fn name(&self) -> String; - fn condition(&self, _provider: &P, _session: Session) -> Result { + fn condition(&self, _provider: &P, _session: &MethodSession) -> Result { Ok(true) } @@ -45,7 +49,7 @@ pub trait Action: ErasedAction + Send + Sync { async fn call( &self, provider: P, - session: Session, + session: &MethodSession, request: Request, ) -> Result; } @@ -59,7 +63,8 @@ pub trait ErasedAction: Send + Sync { fn erased_condition( &self, provider: &(dyn Any + Send + Sync), - session: Session, + base_session: &BaseSession, + method_session: &(dyn Any + Send + Sync), ) -> Result; async fn erased_forms( @@ -70,7 +75,8 @@ pub trait ErasedAction: Send + Sync { async fn erased_call( &self, provider: Box, - session: Session, + base_session: &BaseSession, + method_session: &(dyn Any + Send + Sync), request: Request, ) -> Result; } @@ -88,21 +94,44 @@ macro_rules! erased_action { self.name() } - fn erased_condition(&self, provider: &(dyn std::any::Any + Send + Sync), session: $crate::Session) -> Result { - self.condition(provider.downcast_ref().expect("TODO"), session) + fn erased_condition( + &self, + provider: &(dyn std::any::Any + Send + Sync), + base_session: &$crate::BaseSession, + method_session: &(dyn std::any::Any + Send + Sync) + ) -> Result { + self.condition( + provider.downcast_ref().expect("Provider should be downcast"), + &MethodSession { + base: base_session, + method: method_session.downcast_ref().expect("Session should be downcast"), + }, + ) } - async fn erased_forms(&self, provider: Box) -> Result, $crate::ShieldError> { - self.forms(*provider.downcast().expect("TODO")).await + async fn erased_forms( + &self, + provider: Box + ) -> Result, $crate::ShieldError> { + self.forms(*provider.downcast().expect("Provider should be downcast")).await } async fn erased_call( &self, provider: Box, - session: $crate::Session, + base_session: &$crate::BaseSession, + method_session: &(dyn std::any::Any + Send + Sync), request: $crate::Request, ) -> Result<$crate::Response, $crate::ShieldError> { - self.call(*provider.downcast().expect("TODO"), session, request) + self + .call( + *provider.downcast().expect("Provider should be downcast"), + &$crate::MethodSession { + base: base_session, + method: method_session.downcast_ref().expect("Session should be downcast"), + }, + request + ) .await } } diff --git a/packages/core/shield/src/actions/sign_in_callback.rs b/packages/core/shield/src/actions/sign_in_callback.rs index 03c4481..76d743c 100644 --- a/packages/core/shield/src/actions/sign_in_callback.rs +++ b/packages/core/shield/src/actions/sign_in_callback.rs @@ -1,4 +1,4 @@ -use crate::{Provider, Session, ShieldError}; +use crate::{MethodSession, Provider, ShieldError}; const ACTION_ID: &str = "sign-in-callback"; const ACTION_NAME: &str = "Sign in callback"; @@ -14,7 +14,10 @@ impl SignInCallbackAction { ACTION_NAME.to_owned() } - pub fn condition(_provider: &P, _session: Session) -> Result { + pub fn condition( + _provider: &P, + _session: &MethodSession, + ) -> Result { Ok(true) } } diff --git a/packages/core/shield/src/actions/sign_out.rs b/packages/core/shield/src/actions/sign_out.rs index 3d90026..0f4542a 100644 --- a/packages/core/shield/src/actions/sign_out.rs +++ b/packages/core/shield/src/actions/sign_out.rs @@ -1,6 +1,4 @@ -use crate::{ - Form, Input, InputType, InputTypeSubmit, Provider, Session, SessionError, ShieldError, -}; +use crate::{Form, Input, InputType, InputTypeSubmit, MethodSession, Provider, ShieldError}; const ACTION_ID: &str = "sign-out"; const ACTION_NAME: &str = "Sign out"; @@ -16,13 +14,12 @@ impl SignOutAction { ACTION_NAME.to_owned() } - pub fn condition(provider: &P, session: Session) -> Result { - let session_data = session.data(); - let session_data = session_data - .lock() - .map_err(|err| SessionError::Lock(err.to_string()))?; - - Ok(session_data + pub fn condition( + provider: &P, + session: &MethodSession, + ) -> Result { + Ok(session + .base .authentication .as_ref() .is_some_and(|authentication| { diff --git a/packages/core/shield/src/method.rs b/packages/core/shield/src/method.rs index ef2e5bf..46919da 100644 --- a/packages/core/shield/src/method.rs +++ b/packages/core/shield/src/method.rs @@ -1,24 +1,39 @@ use std::any::Any; use async_trait::async_trait; +use serde::{Serialize, de::DeserializeOwned}; -use crate::{ErasedAction, action::Action, error::ShieldError, provider::Provider}; +use crate::{ + ErasedAction, + action::Action, + error::{SessionError, ShieldError}, + provider::Provider, +}; #[async_trait] -pub trait Method: Send + Sync { +pub trait Method: Send + Sync { + type Provider: Provider; + type Session: DeserializeOwned + Serialize; + fn id(&self) -> String; - fn actions(&self) -> Vec>>; + fn actions(&self) -> Vec>>; - fn action_by_id(&self, action_id: &str) -> Option>> { + fn action_by_id( + &self, + action_id: &str, + ) -> Option>> { self.actions() .into_iter() .find(|action| action.id() == action_id) } - async fn providers(&self) -> Result, ShieldError>; + async fn providers(&self) -> Result, ShieldError>; - async fn provider_by_id(&self, provider_id: Option<&str>) -> Result, ShieldError> { + async fn provider_by_id( + &self, + provider_id: Option<&str>, + ) -> Result, ShieldError> { Ok(self .providers() .await? @@ -43,6 +58,11 @@ pub trait ErasedMethod: Send + Sync { &self, provider_id: Option<&str>, ) -> Result>, ShieldError>; + + fn erased_deserialize_session( + &self, + value: Option<&str>, + ) -> Result, SessionError>; } #[macro_export] @@ -71,7 +91,7 @@ macro_rules! erased_method { async fn erased_providers( &self, - ) -> Result, Box)>, ShieldError> { + ) -> Result, Box)>, $crate::ShieldError> { self.providers().await.map(|providers| { providers .into_iter() @@ -83,12 +103,27 @@ macro_rules! erased_method { async fn erased_provider_by_id( &self, provider_id: Option<&str>, - ) -> Result>, ShieldError> { + ) -> Result>, $crate::ShieldError> { self.provider_by_id(provider_id).await.map(|provider| { provider .map(|provider| Box::new(provider) as Box) }) } + + fn erased_deserialize_session( + &self, + value: Option<&str> + ) -> Result, $crate::SessionError> { + type Session $( < $( $generic_name ),+ > )* = <$method $( < $( $generic_name ),+ > )* as $crate::Method>::Session; + + let session = match value { + Some(value) => serde_json::from_str:: )* >(value) + .map_err(|err| $crate::SessionError::Serialization(err.to_string()))?, + None => Session $( ::< $( $generic_name ),+ > )* ::default() + }; + + Ok(Box::new(session) as Box) + } } }; } diff --git a/packages/core/shield/src/response.rs b/packages/core/shield/src/response.rs index 5420290..f5e3bec 100644 --- a/packages/core/shield/src/response.rs +++ b/packages/core/shield/src/response.rs @@ -1,7 +1,35 @@ use serde::{Deserialize, Serialize}; +use crate::SessionAction; + +#[derive(Clone, Debug)] +pub struct Response { + pub r#type: ResponseType, + pub session_actions: Vec, +} + +impl Response { + pub fn new(r#type: ResponseType) -> Self { + Self { + r#type, + session_actions: vec![], + } + } + + pub fn session_action(mut self, session_action: SessionAction) -> Self { + self.session_actions.push(session_action); + self + } + + pub fn session_actions(mut self, session_actions: &mut Vec) -> Self { + self.session_actions.append(session_actions); + self + } +} + +// TODO: Rename to something more sensible. #[derive(Clone, Debug, Deserialize, Serialize)] -pub enum Response { +pub enum ResponseType { // TODO: Remove temporary default variant. Default, Redirect(String), diff --git a/packages/core/shield/src/session.rs b/packages/core/shield/src/session.rs index 59c464a..4701466 100644 --- a/packages/core/shield/src/session.rs +++ b/packages/core/shield/src/session.rs @@ -6,7 +6,7 @@ use std::{ use async_trait::async_trait; use serde::{Deserialize, Serialize, de::DeserializeOwned}; -use crate::error::SessionError; +use crate::{error::SessionError, user::User}; #[async_trait] pub trait SessionStorage: Send + Sync { @@ -46,8 +46,7 @@ impl Session { #[derive(Clone, Debug, Default, Deserialize, Serialize)] pub struct SessionData { - pub redirect_url: Option, - pub authentication: Option, + pub base: BaseSession, pub methods: HashMap, } @@ -76,6 +75,25 @@ impl SessionData { Ok(()) } + + pub(crate) fn method_str(&self, method_id: &str) -> Option<&str> { + self.methods.get(method_id).map(String::as_str) + } + + pub(crate) fn set_method_str(&mut self, method_id: &str, value: &str) { + self.methods.insert(method_id.to_owned(), value.to_owned()); + } +} + +#[derive(Clone, Debug)] +pub struct MethodSession<'a, S> { + pub base: &'a BaseSession, + pub method: &'a S, +} + +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct BaseSession { + pub authentication: Option, } #[derive(Clone, Debug, Default, Deserialize, Serialize)] @@ -84,3 +102,72 @@ pub struct Authentication { pub provider_id: Option, pub user_id: String, } + +#[derive(Clone, Debug)] +pub enum SessionAction { + Authenticate { user_id: String }, + Unauthenticate, + Data(String), +} + +impl SessionAction { + pub fn authenticate(user: U) -> Self { + Self::Authenticate { user_id: user.id() } + } + + pub fn unauthenticate() -> Self { + Self::Unauthenticate + } + + pub fn data(value: T) -> Result { + let value = serde_json::to_string(&value) + .map_err(|err| SessionError::Serialization(err.to_string()))?; + + Ok(Self::Data(value)) + } + + pub(crate) async fn call( + &self, + method_id: &str, + provider_id: Option<&str>, + session: &Session, + ) -> Result<(), SessionError> { + match self { + Self::Authenticate { user_id } => { + session.renew().await?; + + { + let session_data = session.data(); + let mut session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; + + session_data.base.authentication = Some(Authentication { + method_id: method_id.to_owned(), + provider_id: provider_id.map(ToOwned::to_owned), + user_id: user_id.clone(), + }); + } + + session.update().await?; + } + Self::Unauthenticate => { + session.purge().await?; + } + Self::Data(value) => { + { + let session_data = session.data(); + let mut session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; + + session_data.set_method_str(method_id, value); + } + + session.update().await?; + } + } + + Ok(()) + } +} diff --git a/packages/core/shield/src/shield.rs b/packages/core/shield/src/shield.rs index 2b72df7..7227115 100644 --- a/packages/core/shield/src/shield.rs +++ b/packages/core/shield/src/shield.rs @@ -1,16 +1,15 @@ use std::{any::Any, collections::HashMap, sync::Arc}; use futures::future::try_join_all; -use tracing::warn; +use tracing::{debug, warn}; use crate::{ - ActionMethodForm, - action::{ActionForms, ActionProviderForm}, - error::{ActionError, MethodError, ProviderError, ShieldError}, + action::{ActionForms, ActionMethodForm, ActionProviderForm}, + error::{ActionError, MethodError, ProviderError, SessionError, ShieldError}, method::ErasedMethod, options::ShieldOptions, request::Request, - response::Response, + response::ResponseType, session::Session, storage::Storage, user::User, @@ -93,6 +92,18 @@ impl Shield { continue; }; + let (base_session, method_session) = { + let session_data = session.data(); + let session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; + + ( + session_data.base.clone(), + method.erased_deserialize_session(session_data.method_str(method_id))?, + ) + }; + let name = action.erased_name(); if let Some(action_name) = &action_name && *action_name != name @@ -103,7 +114,7 @@ impl Shield { let mut provider_forms = vec![]; for (provider_id, provider) in method.erased_providers().await? { - if !action.erased_condition(&*provider, session.clone())? { + if !action.erased_condition(&*provider, &base_session, &*method_session)? { continue; } @@ -136,7 +147,7 @@ impl Shield { provider_id: Option<&str>, session: Session, request: Request, - ) -> Result { + ) -> Result { let method = self.method_by_id(method_id) .ok_or(ShieldError::Method(MethodError::NotFound( @@ -157,12 +168,33 @@ impl Shield { provider_id.map(ToOwned::to_owned), )))?; - let response = action.erased_call(provider, session.clone(), request).await; + let (base_session, method_session) = { + let session_data = session.data(); + let session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; + + ( + session_data.base.clone(), + method.erased_deserialize_session(session_data.method_str(method_id))?, + ) + }; + + let response = action + .erased_call(provider, &base_session, &*method_session, request) + .await?; + + debug!("response {:#?}", response); + + for session_action in &response.session_actions { + session_action + .call(method_id, provider_id, &session) + .await?; + } - // TODO: Should update always be called? - session.update().await?; + debug!("session actions processed"); - response + Ok(response.r#type) } } diff --git a/packages/core/shield/src/shield_dyn.rs b/packages/core/shield/src/shield_dyn.rs index 5363e8a..42fb468 100644 --- a/packages/core/shield/src/shield_dyn.rs +++ b/packages/core/shield/src/shield_dyn.rs @@ -3,7 +3,7 @@ use std::{any::Any, sync::Arc}; use async_trait::async_trait; use crate::{ - action::ActionForms, error::ShieldError, request::Request, response::Response, + action::ActionForms, error::ShieldError, request::Request, response::ResponseType, session::Session, shield::Shield, user::User, }; @@ -24,7 +24,7 @@ pub trait DynShield: Send + Sync { provider_id: Option<&str>, session: Session, request: Request, - ) -> Result; + ) -> Result; } #[async_trait] @@ -48,7 +48,7 @@ impl DynShield for Shield { provider_id: Option<&str>, session: Session, request: Request, - ) -> Result { + ) -> Result { self.call(action_id, method_id, provider_id, session, request) .await } @@ -80,7 +80,7 @@ impl ShieldDyn { provider_id: Option<&str>, session: Session, request: Request, - ) -> Result { + ) -> Result { self.0 .call(action_id, method_id, provider_id, session, request) .await diff --git a/packages/integrations/shield-axum/src/routes/action.rs b/packages/integrations/shield-axum/src/routes/action.rs index bba6d2d..e01c1ba 100644 --- a/packages/integrations/shield-axum/src/routes/action.rs +++ b/packages/integrations/shield-axum/src/routes/action.rs @@ -1,9 +1,10 @@ use axum::{ Form, extract::{Path, Query}, + response::{IntoResponse, Redirect, Response}, }; use serde_json::Value; -use shield::{Request, User}; +use shield::{Request, ResponseType, User}; use crate::{ExtractSession, ExtractShield, RouteError, path::ActionPathParams}; @@ -18,10 +19,10 @@ pub async fn action( ExtractSession(session): ExtractSession, Query(query): Query, Form(form_data): Form, -) -> Result<(), RouteError> { +) -> Result { // TODO: Check if this action supports the HTTP method (GET/POST)? - shield + let response = shield .call( &action_id, &method_id, @@ -31,5 +32,12 @@ pub async fn action( ) .await?; - Ok(()) + Ok(match response { + ResponseType::Default => todo!(), + ResponseType::Redirect(to) => Redirect::to(&to).into_response(), + ResponseType::RedirectToAction { action_id } => { + // TODO: Use actual frontend prefix instead of hardcoded `/auth`. + Redirect::to(&format!("/auth/{action_id}")).into_response() + } + }) } diff --git a/packages/integrations/shield-dioxus-axum/src/integration.rs b/packages/integrations/shield-dioxus-axum/src/integration.rs index 3b421bd..2154a61 100644 --- a/packages/integrations/shield-dioxus-axum/src/integration.rs +++ b/packages/integrations/shield-dioxus-axum/src/integration.rs @@ -24,13 +24,15 @@ impl Default for AxumDioxusIntegration { #[async_trait] impl DioxusIntegration for AxumDioxusIntegration { async fn extract_shield(&self) -> ShieldDyn { - let ExtractShield(shield) = extract::, _>().await.expect("TODO"); + let ExtractShield(shield) = extract::, _>() + .await + .expect("Shield should be extracted"); ShieldDyn::new(shield) } async fn extract_session(&self) -> Session { - let ExtractSession(session) = extract().await.expect("TODO"); + let ExtractSession(session) = extract().await.expect("Session should be extracted"); session } diff --git a/packages/integrations/shield-dioxus/src/routes/action.rs b/packages/integrations/shield-dioxus/src/routes/action.rs index f74440b..5c15160 100644 --- a/packages/integrations/shield-dioxus/src/routes/action.rs +++ b/packages/integrations/shield-dioxus/src/routes/action.rs @@ -1,6 +1,6 @@ use dioxus::prelude::*; use serde_json::Value; -use shield::{ActionForms, Response}; +use shield::{ActionForms, ResponseType}; use crate::ErasedDioxusStyle; @@ -60,12 +60,12 @@ pub async fn call( provider_id: Option, // TODO: Would be nice if this argument could fill up with all unknown keys instead of setting name to `data[...]`. data: Value, -) -> Result { +) -> Result { #[cfg(feature = "server")] { use dioxus::prelude::{FromContext, extract}; use serde_json::Value; - use shield::{Request, Response}; + use shield::Request; use crate::integration::DioxusIntegrationDyn; diff --git a/packages/integrations/shield-leptos-actix/src/integration.rs b/packages/integrations/shield-leptos-actix/src/integration.rs index 91c35f1..12cbfdb 100644 --- a/packages/integrations/shield-leptos-actix/src/integration.rs +++ b/packages/integrations/shield-leptos-actix/src/integration.rs @@ -18,19 +18,23 @@ impl Default for ActixLeptosIntegration { #[async_trait] impl LeptosIntegration for ActixLeptosIntegration { async fn extract_shield(&self) -> ShieldDyn { - let ExtractShield(shield) = extract::>().await.expect("TOD"); + let ExtractShield(shield) = extract::>() + .await + .expect("Shield should be extracted"); ShieldDyn::new(shield) } async fn extract_session(&self) -> Session { - let ExtractSession(session) = extract().await.expect("TODO"); + let ExtractSession(session) = extract().await.expect("Session should be extracted"); session } async fn extract_user(&self) -> Option { - let ExtractUser(user) = extract::>().await.expect("TODO"); + let ExtractUser(user) = extract::>() + .await + .expect("User should be extracted"); user.map(|user| user.into()) } diff --git a/packages/integrations/shield-leptos-axum/src/integration.rs b/packages/integrations/shield-leptos-axum/src/integration.rs index dd04245..cf0e4e5 100644 --- a/packages/integrations/shield-leptos-axum/src/integration.rs +++ b/packages/integrations/shield-leptos-axum/src/integration.rs @@ -18,19 +18,23 @@ impl Default for AxumLeptosIntegration { #[async_trait] impl LeptosIntegration for AxumLeptosIntegration { async fn extract_shield(&self) -> ShieldDyn { - let ExtractShield(shield) = extract::>().await.expect("TODO"); + let ExtractShield(shield) = extract::>() + .await + .expect("Shield should be extracted"); ShieldDyn::new(shield) } async fn extract_session(&self) -> Session { - let ExtractSession(session) = extract().await.expect("TODO"); + let ExtractSession(session) = extract().await.expect("Session should be extracted"); session } async fn extract_user(&self) -> Option { - let ExtractUser(user) = extract::>().await.expect("TODO"); + let ExtractUser(user) = extract::>() + .await + .expect("User should be extracted"); user.map(|user| user.into()) } diff --git a/packages/integrations/shield-leptos/src/routes/action.rs b/packages/integrations/shield-leptos/src/routes/action.rs index b7e78e1..c26b32a 100644 --- a/packages/integrations/shield-leptos/src/routes/action.rs +++ b/packages/integrations/shield-leptos/src/routes/action.rs @@ -57,7 +57,7 @@ pub async fn call( data: Value, ) -> Result<(), ServerFnError> { use serde_json::Value; - use shield::{Request, Response}; + use shield::{Request, ResponseType}; use crate::expect_server_integration; @@ -81,11 +81,11 @@ pub async fn call( .await?; match response { - Response::Default => todo!("default reponse"), - Response::Redirect(to) => { + ResponseType::Default => todo!("default reponse"), + ResponseType::Redirect(to) => { integration.redirect(&to); } - Response::RedirectToAction { action_id } => { + ResponseType::RedirectToAction { action_id } => { // TODO: Use actual router prefix instead of hardcoded `/auth`. integration.redirect(&format!("/auth/{action_id}")); } diff --git a/packages/methods/shield-credentials/src/actions/sign_in.rs b/packages/methods/shield-credentials/src/actions/sign_in.rs index 29da31e..ca0967a 100644 --- a/packages/methods/shield-credentials/src/actions/sign_in.rs +++ b/packages/methods/shield-credentials/src/actions/sign_in.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use async_trait::async_trait; use serde::de::DeserializeOwned; use shield::{ - Action, Authentication, Form, Request, Response, Session, SessionError, ShieldError, + Action, Form, MethodSession, Request, Response, ResponseType, SessionAction, ShieldError, SignInAction, User, erased_action, }; @@ -20,7 +20,7 @@ impl CredentialsSignInAction { } #[async_trait] -impl Action +impl Action for CredentialsSignInAction { fn id(&self) -> String { @@ -38,7 +38,7 @@ impl Action, request: Request, ) -> Result { let data = serde_json::from_value(request.form_data) @@ -46,22 +46,7 @@ impl Action for CredentialsSignOutAction { +impl Action for CredentialsSignOutAction { fn id(&self) -> String { SignOutAction::id() } @@ -18,7 +21,7 @@ impl Action for CredentialsSignOutAction { fn condition( &self, provider: &CredentialsProvider, - session: Session, + session: &MethodSession<()>, ) -> Result { SignOutAction::condition(provider, session) } @@ -30,11 +33,10 @@ impl Action for CredentialsSignOutAction { async fn call( &self, _provider: CredentialsProvider, - _session: Session, + _session: &MethodSession<()>, _request: Request, ) -> Result { - // TODO: sign out - Ok(Response::Default) + Ok(Response::new(ResponseType::Default).session_action(SessionAction::Unauthenticate)) } } diff --git a/packages/methods/shield-credentials/src/method.rs b/packages/methods/shield-credentials/src/method.rs index c788790..e48c5cf 100644 --- a/packages/methods/shield-credentials/src/method.rs +++ b/packages/methods/shield-credentials/src/method.rs @@ -25,21 +25,22 @@ impl CredentialsMethod { } #[async_trait] -impl Method - for CredentialsMethod -{ +impl Method for CredentialsMethod { + type Provider = CredentialsProvider; + type Session = (); + fn id(&self) -> String { CREDENTIALS_METHOD_ID.to_owned() } - fn actions(&self) -> Vec>> { + fn actions(&self) -> Vec>> { vec![ Box::new(CredentialsSignInAction::new(self.credentials.clone())), Box::new(CredentialsSignOutAction), ] } - async fn providers(&self) -> Result, ShieldError> { + async fn providers(&self) -> Result, ShieldError> { Ok(vec![CredentialsProvider]) } } diff --git a/packages/methods/shield-oauth/Cargo.toml b/packages/methods/shield-oauth/Cargo.toml index d28bbd2..4022540 100644 --- a/packages/methods/shield-oauth/Cargo.toml +++ b/packages/methods/shield-oauth/Cargo.toml @@ -8,6 +8,9 @@ license.workspace = true repository.workspace = true version.workspace = true +[package.metadata.cargo-machete] +ignored = ["serde_json"] + [features] default = [] native-tls = ["oauth2/native-tls"] @@ -20,4 +23,5 @@ chrono.workspace = true oauth2 = { version = "5.0.0", default-features = false, features = ["reqwest"] } secrecy.workspace = true serde.workspace = true +serde_json.workspace = true shield.workspace = true diff --git a/packages/methods/shield-oauth/src/actions/sign_in.rs b/packages/methods/shield-oauth/src/actions/sign_in.rs index 0f3b0a9..75fa84c 100644 --- a/packages/methods/shield-oauth/src/actions/sign_in.rs +++ b/packages/methods/shield-oauth/src/actions/sign_in.rs @@ -1,12 +1,11 @@ use async_trait::async_trait; use oauth2::{CsrfToken, PkceCodeChallenge, Scope, url::form_urlencoded::parse}; use shield::{ - Action, ConfigurationError, Form, Input, InputType, InputTypeSubmit, Provider, Request, - Response, Session, SessionError, ShieldError, SignInAction, erased_action, + Action, ConfigurationError, Form, Input, InputType, InputTypeSubmit, MethodSession, Provider, + Request, Response, ResponseType, SessionAction, ShieldError, SignInAction, erased_action, }; use crate::{ - method::OAUTH_METHOD_ID, provider::{OauthProvider, OauthProviderPkceCodeChallenge}, session::OauthSession, }; @@ -14,7 +13,7 @@ use crate::{ pub struct OauthSignInAction; #[async_trait] -impl Action for OauthSignInAction { +impl Action for OauthSignInAction { fn id(&self) -> String { SignInAction::id() } @@ -37,7 +36,7 @@ impl Action for OauthSignInAction { async fn call( &self, provider: OauthProvider, - session: Session, + _session: &MethodSession, _request: Request, ) -> Result { let client = provider.oauth_client().await?; @@ -73,26 +72,14 @@ impl Action for OauthSignInAction { let (auth_url, csrf_token) = authorization_request.url(); - { - let session_data = session.data(); - let mut session_data = session_data - .lock() - .map_err(|err| SessionError::Lock(err.to_string()))?; - - session_data.authentication = None; - - session_data.set_method( - OAUTH_METHOD_ID, - OauthSession { - csrf: Some(csrf_token.secret().clone()), - pkce_verifier: pkce_code_challenge - .map(|(_, pkce_code_verifier)| pkce_code_verifier.secret().clone()), - oauth_connection_id: None, - }, - )?; - } - - Ok(Response::Redirect(auth_url.to_string())) + Ok(Response::new(ResponseType::Redirect(auth_url.to_string())) + .session_action(SessionAction::Unauthenticate) + .session_action(SessionAction::data(OauthSession { + csrf: Some(csrf_token.secret().clone()), + pkce_verifier: pkce_code_challenge + .map(|(_, pkce_code_verifier)| pkce_code_verifier.secret().clone()), + oauth_connection_id: None, + })?)) } } diff --git a/packages/methods/shield-oauth/src/actions/sign_in_callback.rs b/packages/methods/shield-oauth/src/actions/sign_in_callback.rs index 5714ef1..a5a1baa 100644 --- a/packages/methods/shield-oauth/src/actions/sign_in_callback.rs +++ b/packages/methods/shield-oauth/src/actions/sign_in_callback.rs @@ -8,15 +8,14 @@ use oauth2::{ }; use secrecy::SecretString; use shield::{ - Action, Authentication, ConfigurationError, CreateEmailAddress, CreateUser, Form, Request, - Response, Session, SessionError, ShieldError, SignInCallbackAction, UpdateUser, User, + Action, ConfigurationError, CreateEmailAddress, CreateUser, Form, MethodSession, Request, + Response, ResponseType, SessionAction, ShieldError, SignInCallbackAction, UpdateUser, User, erased_action, }; use crate::{ client::async_http_client, connection::{CreateOauthConnection, OauthConnection, UpdateOauthConnection}, - method::OAUTH_METHOD_ID, options::OauthOptions, provider::{OauthProvider, OauthProviderPkceCodeChallenge}, session::OauthSession, @@ -130,7 +129,7 @@ impl OauthSignInCallbackAction { } #[async_trait] -impl Action for OauthSignInCallbackAction { +impl Action for OauthSignInCallbackAction { fn id(&self) -> String { SignInCallbackAction::id() } @@ -139,7 +138,11 @@ impl Action for OauthSignInCallbackAction { SignInCallbackAction::name() } - fn condition(&self, provider: &OauthProvider, session: Session) -> Result { + fn condition( + &self, + provider: &OauthProvider, + session: &MethodSession, + ) -> Result { SignInCallbackAction::condition(provider, session) } @@ -150,21 +153,14 @@ impl Action for OauthSignInCallbackAction { async fn call( &self, provider: OauthProvider, - session: Session, + session: &MethodSession, request: Request, ) -> Result { let OauthSession { csrf, pkce_verifier, .. - } = { - let session_data = session.data(); - let session_data = session_data - .lock() - .map_err(|err| SessionError::Lock(err.to_string()))?; - - session_data.method(OAUTH_METHOD_ID)? - }; + } = &session.method; let state = request .query @@ -172,7 +168,7 @@ impl Action for OauthSignInCallbackAction { .and_then(|code| code.as_str()) .ok_or_else(|| ShieldError::Validation("Missing state.".to_owned()))?; - if csrf.is_none_or(|csrf| csrf != state) { + if csrf.as_ref().is_none_or(|csrf| csrf != state) { return Err(ShieldError::Validation("Invalid state.".to_owned())); } @@ -191,7 +187,8 @@ impl Action for OauthSignInCallbackAction { })?; if let Some(pkce_verifier) = pkce_verifier { - token_request = token_request.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier)); + token_request = + token_request.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier.to_owned())); } else if provider.pkce_code_challenge != OauthProviderPkceCodeChallenge::None { return Err(ShieldError::Validation("Missing PKCE verifier.".to_owned())); } @@ -247,31 +244,15 @@ impl Action for OauthSignInCallbackAction { } }; - session.renew().await?; - - { - let session_data = session.data(); - let mut session_data = session_data - .lock() - .map_err(|err| SessionError::Lock(err.to_string()))?; - - session_data.authentication = Some(Authentication { - method_id: self.id(), - provider_id: Some(provider.id), - user_id: user.id(), - }); - - session_data.set_method( - OAUTH_METHOD_ID, - OauthSession { - csrf: None, - pkce_verifier: None, - oauth_connection_id: Some(connection.id), - }, - )?; - } - - Ok(Response::Redirect(self.options.sign_in_redirect.clone())) + Ok(Response::new(ResponseType::Redirect( + self.options.sign_in_redirect.clone(), + )) + .session_action(SessionAction::authenticate(user)) + .session_action(SessionAction::data(OauthSession { + csrf: None, + pkce_verifier: None, + oauth_connection_id: Some(connection.id), + })?)) } } diff --git a/packages/methods/shield-oauth/src/actions/sign_out.rs b/packages/methods/shield-oauth/src/actions/sign_out.rs index 53275ed..075e0ac 100644 --- a/packages/methods/shield-oauth/src/actions/sign_out.rs +++ b/packages/methods/shield-oauth/src/actions/sign_out.rs @@ -1,12 +1,15 @@ use async_trait::async_trait; -use shield::{Action, Form, Request, Response, Session, ShieldError, SignOutAction, erased_action}; +use shield::{ + Action, Form, MethodSession, Request, Response, ResponseType, SessionAction, ShieldError, + SignOutAction, erased_action, +}; -use crate::provider::OauthProvider; +use crate::{provider::OauthProvider, session::OauthSession}; pub struct OauthSignOutAction; #[async_trait] -impl Action for OauthSignOutAction { +impl Action for OauthSignOutAction { fn id(&self) -> String { SignOutAction::id() } @@ -15,7 +18,11 @@ impl Action for OauthSignOutAction { SignOutAction::name() } - fn condition(&self, provider: &OauthProvider, session: Session) -> Result { + fn condition( + &self, + provider: &OauthProvider, + session: &MethodSession, + ) -> Result { SignOutAction::condition(provider, session) } @@ -26,13 +33,12 @@ impl Action for OauthSignOutAction { async fn call( &self, _provider: OauthProvider, - _session: Session, + _session: &MethodSession, _request: Request, ) -> Result { // TODO: OAuth token revocation. - // TODO: Sign out. - Ok(Response::Default) + Ok(Response::new(ResponseType::Default).session_action(SessionAction::Unauthenticate)) } } diff --git a/packages/methods/shield-oauth/src/method.rs b/packages/methods/shield-oauth/src/method.rs index d84927c..db1f044 100644 --- a/packages/methods/shield-oauth/src/method.rs +++ b/packages/methods/shield-oauth/src/method.rs @@ -7,6 +7,7 @@ use crate::{ actions::{OauthSignInAction, OauthSignInCallbackAction, OauthSignOutAction}, options::OauthOptions, provider::OauthProvider, + session::OauthSession, storage::OauthStorage, }; @@ -62,12 +63,15 @@ impl OauthMethod { } #[async_trait] -impl Method for OauthMethod { +impl Method for OauthMethod { + type Provider = OauthProvider; + type Session = OauthSession; + fn id(&self) -> String { OAUTH_METHOD_ID.to_owned() } - fn actions(&self) -> Vec>> { + fn actions(&self) -> Vec>> { vec![ Box::new(OauthSignInAction), Box::new(OauthSignInCallbackAction::new( @@ -78,7 +82,7 @@ impl Method for OauthMethod { ] } - async fn providers(&self) -> Result, ShieldError> { + async fn providers(&self) -> Result, ShieldError> { Ok(self .providers .iter() @@ -90,7 +94,7 @@ impl Method for OauthMethod { async fn provider_by_id( &self, provider_id: Option<&str>, - ) -> Result, ShieldError> { + ) -> Result, ShieldError> { if let Some(provider_id) = provider_id { self.oauth_provider_by_id_or_slug(provider_id).await } else { diff --git a/packages/methods/shield-oidc/Cargo.toml b/packages/methods/shield-oidc/Cargo.toml index f9c3295..fd1d729 100644 --- a/packages/methods/shield-oidc/Cargo.toml +++ b/packages/methods/shield-oidc/Cargo.toml @@ -9,7 +9,7 @@ repository.workspace = true version.workspace = true [package.metadata.cargo-machete] -ignored = ["oauth2"] +ignored = ["oauth2", "serde_json"] [features] default = [] @@ -28,5 +28,6 @@ openidconnect = { version = "4.0.0", default-features = false, features = [ ] } secrecy.workspace = true serde.workspace = true +serde_json.workspace = true shield.workspace = true tracing.workspace = true diff --git a/packages/methods/shield-oidc/src/actions/sign_in.rs b/packages/methods/shield-oidc/src/actions/sign_in.rs index c5f8314..c767662 100644 --- a/packages/methods/shield-oidc/src/actions/sign_in.rs +++ b/packages/methods/shield-oidc/src/actions/sign_in.rs @@ -4,12 +4,11 @@ use openidconnect::{ url::form_urlencoded::parse, }; use shield::{ - Action, Form, Input, InputType, InputTypeSubmit, Provider, Request, Response, Session, - SessionError, ShieldError, SignInAction, erased_action, + Action, Form, Input, InputType, InputTypeSubmit, MethodSession, Provider, Request, Response, + ResponseType, SessionAction, ShieldError, SignInAction, erased_action, }; use crate::{ - method::OIDC_METHOD_ID, provider::{OidcProvider, OidcProviderPkceCodeChallenge}, session::OidcSession, }; @@ -17,7 +16,7 @@ use crate::{ pub struct OidcSignInAction; #[async_trait] -impl Action for OidcSignInAction { +impl Action for OidcSignInAction { fn id(&self) -> String { SignInAction::id() } @@ -40,7 +39,7 @@ impl Action for OidcSignInAction { async fn call( &self, provider: OidcProvider, - session: Session, + _session: &MethodSession, _request: Request, ) -> Result { let client = provider.oidc_client().await?; @@ -78,29 +77,15 @@ impl Action for OidcSignInAction { let (auth_url, csrf_token, nonce) = authorization_request.url(); - { - // TODO: Add a generic type for session data to actions, so the action caller can be read/write the session. - - let session_data = session.data(); - let mut session_data = session_data - .lock() - .map_err(|err| SessionError::Lock(err.to_string()))?; - - session_data.authentication = None; - - session_data.set_method( - OIDC_METHOD_ID, - OidcSession { - csrf: Some(csrf_token.secret().clone()), - nonce: Some(nonce.secret().clone()), - pkce_verifier: pkce_code_challenge - .map(|(_, pkce_code_verifier)| pkce_code_verifier.secret().clone()), - oidc_connection_id: None, - }, - )?; - } - - Ok(Response::Redirect(auth_url.to_string())) + Ok(Response::new(ResponseType::Redirect(auth_url.to_string())) + .session_action(SessionAction::unauthenticate()) + .session_action(SessionAction::data(OidcSession { + csrf: Some(csrf_token.secret().clone()), + nonce: Some(nonce.secret().clone()), + pkce_verifier: pkce_code_challenge + .map(|(_, pkce_code_verifier)| pkce_code_verifier.secret().clone()), + oidc_connection_id: None, + })?)) } } diff --git a/packages/methods/shield-oidc/src/actions/sign_in_callback.rs b/packages/methods/shield-oidc/src/actions/sign_in_callback.rs index 437a80c..b6e4eca 100644 --- a/packages/methods/shield-oidc/src/actions/sign_in_callback.rs +++ b/packages/methods/shield-oidc/src/actions/sign_in_callback.rs @@ -10,8 +10,8 @@ use openidconnect::{ }; use secrecy::SecretString; use shield::{ - Action, Authentication, ConfigurationError, CreateEmailAddress, CreateUser, Form, Request, - Response, Session, SessionError, ShieldError, SignInCallbackAction, UpdateUser, User, + Action, ConfigurationError, CreateEmailAddress, CreateUser, Form, MethodSession, Request, + Response, ResponseType, SessionAction, ShieldError, SignInCallbackAction, UpdateUser, User, erased_action, }; use tracing::debug; @@ -20,7 +20,6 @@ use crate::{ claims::Claims, client::async_http_client, connection::{CreateOidcConnection, OidcConnection, UpdateOidcConnection}, - method::OIDC_METHOD_ID, options::OidcOptions, provider::{OidcProvider, OidcProviderPkceCodeChallenge}, session::OidcSession, @@ -141,7 +140,7 @@ impl OidcSignInCallbackAction { } #[async_trait] -impl Action for OidcSignInCallbackAction { +impl Action for OidcSignInCallbackAction { fn id(&self) -> String { SignInCallbackAction::id() } @@ -150,7 +149,11 @@ impl Action for OidcSignInCallbackAction { SignInCallbackAction::name() } - fn condition(&self, provider: &OidcProvider, session: Session) -> Result { + fn condition( + &self, + provider: &OidcProvider, + session: &MethodSession, + ) -> Result { SignInCallbackAction::condition(provider, session) } @@ -161,7 +164,7 @@ impl Action for OidcSignInCallbackAction { async fn call( &self, provider: OidcProvider, - session: Session, + session: &MethodSession, request: Request, ) -> Result { let OidcSession { @@ -169,14 +172,7 @@ impl Action for OidcSignInCallbackAction { nonce, pkce_verifier, .. - } = { - let session_data = session.data(); - let session_data = session_data - .lock() - .map_err(|err| SessionError::Lock(err.to_string()))?; - - session_data.method(OIDC_METHOD_ID)? - }; + } = &session.method; let state = request .query @@ -184,7 +180,7 @@ impl Action for OidcSignInCallbackAction { .and_then(|code| code.as_str()) .ok_or_else(|| ShieldError::Validation("Missing state.".to_owned()))?; - if csrf.is_none_or(|csrf| csrf != state) { + if csrf.as_ref().is_none_or(|csrf| csrf != state) { return Err(ShieldError::Validation("Invalid state.".to_owned())); } @@ -203,7 +199,8 @@ impl Action for OidcSignInCallbackAction { })?; if let Some(pkce_verifier) = pkce_verifier { - token_request = token_request.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier)); + token_request = + token_request.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier.to_owned())); } else if provider.pkce_code_challenge != OidcProviderPkceCodeChallenge::None { return Err(ShieldError::Validation("Missing PKCE verifier.".to_owned())); } @@ -230,7 +227,9 @@ impl Action for OidcSignInCallbackAction { &client.id_token_verifier(), &Nonce::new( nonce - .ok_or_else(|| ShieldError::Validation("Missing nonce.".to_owned()))?, + .as_ref() + .ok_or_else(|| ShieldError::Validation("Missing nonce.".to_owned()))? + .to_owned(), ), ) .map_err(|err| ShieldError::Validation(err.to_string()))?; @@ -279,32 +278,16 @@ impl Action for OidcSignInCallbackAction { } }; - session.renew().await?; - - { - let session_data = session.data(); - let mut session_data = session_data - .lock() - .map_err(|err| SessionError::Lock(err.to_string()))?; - - session_data.authentication = Some(Authentication { - method_id: self.id(), - provider_id: Some(provider.id), - user_id: user.id(), - }); - - session_data.set_method( - OIDC_METHOD_ID, - OidcSession { - csrf: None, - nonce: None, - pkce_verifier: None, - oidc_connection_id: Some(connection.id), - }, - )?; - } - - Ok(Response::Redirect(self.options.sign_in_redirect.clone())) + Ok(Response::new(ResponseType::Redirect( + self.options.sign_in_redirect.clone(), + )) + .session_action(SessionAction::authenticate(user)) + .session_action(SessionAction::data(OidcSession { + csrf: None, + nonce: None, + pkce_verifier: None, + oidc_connection_id: Some(connection.id), + })?)) } } diff --git a/packages/methods/shield-oidc/src/actions/sign_out.rs b/packages/methods/shield-oidc/src/actions/sign_out.rs index 4c804a3..9feabe5 100644 --- a/packages/methods/shield-oidc/src/actions/sign_out.rs +++ b/packages/methods/shield-oidc/src/actions/sign_out.rs @@ -1,12 +1,15 @@ use async_trait::async_trait; -use shield::{Action, Form, Request, Response, Session, ShieldError, SignOutAction, erased_action}; +use shield::{ + Action, Form, MethodSession, Request, Response, ResponseType, SessionAction, ShieldError, + SignOutAction, erased_action, +}; -use crate::provider::OidcProvider; +use crate::{provider::OidcProvider, session::OidcSession}; pub struct OidcSignOutAction; #[async_trait] -impl Action for OidcSignOutAction { +impl Action for OidcSignOutAction { fn id(&self) -> String { SignOutAction::id() } @@ -15,7 +18,11 @@ impl Action for OidcSignOutAction { SignOutAction::name() } - fn condition(&self, provider: &OidcProvider, session: Session) -> Result { + fn condition( + &self, + provider: &OidcProvider, + session: &MethodSession, + ) -> Result { SignOutAction::condition(provider, session) } @@ -26,7 +33,7 @@ impl Action for OidcSignOutAction { async fn call( &self, _provider: OidcProvider, - _session: Session, + _session: &MethodSession, _request: Request, ) -> Result { // TODO: See [`OidcProvider::oidc_client`]. @@ -80,9 +87,7 @@ impl Action for OidcSignOutAction { // } // } - // TODO: Sign out. - - Ok(Response::Default) + Ok(Response::new(ResponseType::Default).session_action(SessionAction::Unauthenticate)) } } diff --git a/packages/methods/shield-oidc/src/method.rs b/packages/methods/shield-oidc/src/method.rs index eb3da79..5e1fc80 100644 --- a/packages/methods/shield-oidc/src/method.rs +++ b/packages/methods/shield-oidc/src/method.rs @@ -7,6 +7,7 @@ use crate::{ actions::{OidcSignInAction, OidcSignInCallbackAction, OidcSignOutAction}, options::OidcOptions, provider::OidcProvider, + session::OidcSession, storage::OidcStorage, }; @@ -62,12 +63,15 @@ impl OidcMethod { } #[async_trait] -impl Method for OidcMethod { +impl Method for OidcMethod { + type Provider = OidcProvider; + type Session = OidcSession; + fn id(&self) -> String { OIDC_METHOD_ID.to_owned() } - fn actions(&self) -> Vec>> { + fn actions(&self) -> Vec>> { vec![ Box::new(OidcSignInAction), Box::new(OidcSignInCallbackAction::new( @@ -78,7 +82,7 @@ impl Method for OidcMethod { ] } - async fn providers(&self) -> Result, ShieldError> { + async fn providers(&self) -> Result, ShieldError> { Ok(self .providers .iter() @@ -90,7 +94,7 @@ impl Method for OidcMethod { async fn provider_by_id( &self, provider_id: Option<&str>, - ) -> Result, ShieldError> { + ) -> Result, ShieldError> { if let Some(provider_id) = provider_id { self.oidc_provider_by_id_or_slug(provider_id).await } else { diff --git a/packages/methods/shield-workos/src/actions/index.rs b/packages/methods/shield-workos/src/actions/index.rs index 38ca85d..f673c1c 100644 --- a/packages/methods/shield-workos/src/actions/index.rs +++ b/packages/methods/shield-workos/src/actions/index.rs @@ -3,8 +3,9 @@ use std::sync::Arc; use async_trait::async_trait; use serde::Deserialize; use shield::{ - Action, Form, Input, InputType, InputTypeEmail, InputTypeHidden, InputTypeSubmit, Request, - Response, Session, ShieldError, SignInAction, SignUpAction, erased_action, + Action, Form, Input, InputType, InputTypeEmail, InputTypeHidden, InputTypeSubmit, + MethodSession, Request, Response, ResponseType, ShieldError, SignInAction, SignUpAction, + erased_action, }; use workos::{ PaginationParams, @@ -40,7 +41,7 @@ impl WorkosIndexAction { } #[async_trait] -impl Action for WorkosIndexAction { +impl Action for WorkosIndexAction { fn id(&self) -> String { ACTION_ID.to_owned() } @@ -145,7 +146,7 @@ impl Action for WorkosIndexAction { async fn call( &self, _provider: WorkosProvider, - _session: Session, + _session: &MethodSession<()>, request: Request, ) -> Result { // TODO: Check email address and redirect to sign-in/sign-up action with prefilled email address. @@ -172,13 +173,13 @@ impl Action for WorkosIndexAction { // TODO: Include email address as state. if users.data.is_empty() { - Ok(Response::RedirectToAction { + Ok(Response::new(ResponseType::RedirectToAction { action_id: SignUpAction::id(), - }) + })) } else { - Ok(Response::RedirectToAction { + Ok(Response::new(ResponseType::RedirectToAction { action_id: SignInAction::id(), - }) + })) } } IndexData::Oauth { oauth_provider } => { @@ -199,7 +200,9 @@ impl Action for WorkosIndexAction { }) .expect("TODO: handle error"); - Ok(Response::Redirect(authorization_url.to_string())) + Ok(Response::new(ResponseType::Redirect( + authorization_url.to_string(), + ))) } IndexData::Sso { connection_id } => { let authorization_url = self @@ -217,7 +220,9 @@ impl Action for WorkosIndexAction { }) .expect("TODO: handle error"); - Ok(Response::Redirect(authorization_url.to_string())) + Ok(Response::new(ResponseType::Redirect( + authorization_url.to_string(), + ))) } } } diff --git a/packages/methods/shield-workos/src/actions/sign_in.rs b/packages/methods/shield-workos/src/actions/sign_in.rs index 33c04dd..0c618a3 100644 --- a/packages/methods/shield-workos/src/actions/sign_in.rs +++ b/packages/methods/shield-workos/src/actions/sign_in.rs @@ -3,7 +3,8 @@ use std::sync::Arc; use async_trait::async_trait; use shield::{ Action, Form, Input, InputType, InputTypeEmail, InputTypeHidden, InputTypePassword, - InputTypeSubmit, Request, Response, Session, ShieldError, SignInAction, erased_action, + InputTypeSubmit, MethodSession, Request, Response, ResponseType, ShieldError, SignInAction, + erased_action, }; use crate::{client::WorkosClient, provider::WorkosProvider}; @@ -21,7 +22,7 @@ impl WorkosSignInAction { } #[async_trait] -impl Action for WorkosSignInAction { +impl Action for WorkosSignInAction { fn id(&self) -> String { SignInAction::id() } @@ -92,11 +93,11 @@ impl Action for WorkosSignInAction { async fn call( &self, _provider: WorkosProvider, - _session: Session, + _session: &MethodSession<()>, _request: Request, ) -> Result { // TODO: sign in - Ok(Response::Default) + Ok(Response::new(ResponseType::Default)) } } diff --git a/packages/methods/shield-workos/src/actions/sign_out.rs b/packages/methods/shield-workos/src/actions/sign_out.rs index 045277b..a06d889 100644 --- a/packages/methods/shield-workos/src/actions/sign_out.rs +++ b/packages/methods/shield-workos/src/actions/sign_out.rs @@ -1,7 +1,10 @@ use std::sync::Arc; use async_trait::async_trait; -use shield::{Action, Form, Request, Response, Session, ShieldError, SignOutAction, erased_action}; +use shield::{ + Action, Form, MethodSession, Request, Response, ResponseType, SessionAction, ShieldError, + SignOutAction, erased_action, +}; use crate::{client::WorkosClient, provider::WorkosProvider}; @@ -18,7 +21,7 @@ impl WorkosSignOutAction { } #[async_trait] -impl Action for WorkosSignOutAction { +impl Action for WorkosSignOutAction { fn id(&self) -> String { SignOutAction::id() } @@ -27,7 +30,11 @@ impl Action for WorkosSignOutAction { SignOutAction::name() } - fn condition(&self, provider: &WorkosProvider, session: Session) -> Result { + fn condition( + &self, + provider: &WorkosProvider, + session: &MethodSession<()>, + ) -> Result { SignOutAction::condition(provider, session) } @@ -38,11 +45,12 @@ impl Action for WorkosSignOutAction { async fn call( &self, _provider: WorkosProvider, - _session: Session, + _session: &MethodSession<()>, _request: Request, ) -> Result { - // TODO: sign out - Ok(Response::Default) + // TODO: Handle WorkOS sign out. + + Ok(Response::new(ResponseType::Default).session_action(SessionAction::Unauthenticate)) } } diff --git a/packages/methods/shield-workos/src/actions/sign_up.rs b/packages/methods/shield-workos/src/actions/sign_up.rs index 183bc5a..ee19f1c 100644 --- a/packages/methods/shield-workos/src/actions/sign_up.rs +++ b/packages/methods/shield-workos/src/actions/sign_up.rs @@ -3,7 +3,8 @@ use std::sync::Arc; use async_trait::async_trait; use shield::{ Action, Form, Input, InputType, InputTypeEmail, InputTypeHidden, InputTypePassword, - InputTypeSubmit, Request, Response, Session, ShieldError, SignUpAction, erased_action, + InputTypeSubmit, MethodSession, Request, Response, ResponseType, ShieldError, SignUpAction, + erased_action, }; use crate::{client::WorkosClient, provider::WorkosProvider}; @@ -21,7 +22,7 @@ impl WorkosSignUpAction { } #[async_trait] -impl Action for WorkosSignUpAction { +impl Action for WorkosSignUpAction { fn id(&self) -> String { SignUpAction::id() } @@ -92,11 +93,11 @@ impl Action for WorkosSignUpAction { async fn call( &self, _provider: WorkosProvider, - _session: Session, + _session: &MethodSession<()>, _request: Request, ) -> Result { - // TODO: sign in - Ok(Response::Default) + // TODO: sign up + Ok(Response::new(ResponseType::Default)) } } diff --git a/packages/methods/shield-workos/src/method.rs b/packages/methods/shield-workos/src/method.rs index 5608043..a08c97a 100644 --- a/packages/methods/shield-workos/src/method.rs +++ b/packages/methods/shield-workos/src/method.rs @@ -37,12 +37,15 @@ impl WorkosMethod { } #[async_trait] -impl Method for WorkosMethod { +impl Method for WorkosMethod { + type Provider = WorkosProvider; + type Session = (); + fn id(&self) -> String { WORKOS_METHOD_ID.to_owned() } - fn actions(&self) -> Vec>> { + fn actions(&self) -> Vec>> { vec![ Box::new(WorkosIndexAction::new( self.options.clone(), @@ -54,7 +57,7 @@ impl Method for WorkosMethod { ] } - async fn providers(&self) -> Result, ShieldError> { + async fn providers(&self) -> Result, ShieldError> { Ok(vec![WorkosProvider]) } } diff --git a/packages/styles/shield-bootstrap/src/dioxus/form.rs b/packages/styles/shield-bootstrap/src/dioxus/form.rs index 7424618..f33b648 100644 --- a/packages/styles/shield-bootstrap/src/dioxus/form.rs +++ b/packages/styles/shield-bootstrap/src/dioxus/form.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use dioxus::{logger::tracing::info, prelude::*}; -use shield::Response; +use shield::ResponseType; use shield_dioxus::{ShieldRouter, call}; use crate::dioxus::input::FormInput; @@ -46,11 +46,11 @@ pub fn Form(props: FormProps) -> Element { // TODO: Handle error. if let Ok(response) = result { match response { - Response::Default => todo!("default response"), - Response::Redirect(to) => { + ResponseType::Default => todo!("default response"), + ResponseType::Redirect(to) => { navigator.push(to); }, - Response::RedirectToAction { action_id } => { + ResponseType::RedirectToAction { action_id } => { navigator.push(ShieldRouter::Action { action_id }); } }