From 42f6b7a37d98a15948fc24cd78457dbb1ddaa580 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABlle=20Huisman?= Date: Sat, 1 Feb 2025 17:01:34 +0100 Subject: [PATCH] feat(shield): add redirect URL to sign in --- packages/core/shield/src/request.rs | 2 + packages/core/shield/src/session.rs | 2 + packages/core/shield/src/shield.rs | 45 +++++++++++++++++-- .../shield-axum/src/routes/sign_in.rs | 12 ++++- .../src/routes/sign_in_callback.rs | 1 + .../shield-leptos/src/routes/sign_in.rs | 1 + .../providers/shield-oidc/src/provider.rs | 10 ++--- 7 files changed, 63 insertions(+), 10 deletions(-) diff --git a/packages/core/shield/src/request.rs b/packages/core/shield/src/request.rs index 1f4d868..bff3c1c 100644 --- a/packages/core/shield/src/request.rs +++ b/packages/core/shield/src/request.rs @@ -6,6 +6,7 @@ use serde_json::Value; pub struct SignInRequest { pub provider_id: String, pub subprovider_id: Option, + pub redirect_url: Option, pub data: Option, pub form_data: Option, } @@ -15,6 +16,7 @@ pub struct SignInRequest { pub struct SignInCallbackRequest { pub provider_id: String, pub subprovider_id: Option, + pub redirect_url: Option, pub query: Option, pub data: Option, } diff --git a/packages/core/shield/src/session.rs b/packages/core/shield/src/session.rs index 0a7dd3d..13608e5 100644 --- a/packages/core/shield/src/session.rs +++ b/packages/core/shield/src/session.rs @@ -45,6 +45,8 @@ impl Session { pub struct SessionData { pub authentication: Option, + pub redirect_url: Option, + // TODO: Allow arbitrary data to be stored by providers? pub csrf: Option, pub nonce: Option, diff --git a/packages/core/shield/src/shield.rs b/packages/core/shield/src/shield.rs index bdbeff0..1f10715 100644 --- a/packages/core/shield/src/shield.rs +++ b/packages/core/shield/src/shield.rs @@ -112,7 +112,24 @@ impl Shield { None => return Err(ProviderError::ProviderNotFound(request.provider_id).into()), }; - provider.sign_in(request, session, &self.options).await + // TODO: validate redirect URL + + { + let session_data = session.data(); + let mut session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; + + session_data.redirect_url = request.redirect_url.clone(); + }; + + let response = provider + .sign_in(request, session.clone(), &self.options) + .await; + + session.update().await?; + + response } pub async fn sign_in_callback( @@ -127,9 +144,29 @@ impl Shield { None => return Err(ProviderError::ProviderNotFound(request.provider_id).into()), }; - provider - .sign_in_callback(request, session, &self.options) - .await + let redirect_url = { + let session_data = session.data(); + let session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; + + session_data.redirect_url.clone() + }; + + let response = provider + .sign_in_callback( + SignInCallbackRequest { + redirect_url: request.redirect_url.or(redirect_url), + ..request + }, + session.clone(), + &self.options, + ) + .await; + + session.update().await?; + + response } pub async fn sign_out(&self, session: Session) -> Result { diff --git a/packages/integrations/shield-axum/src/routes/sign_in.rs b/packages/integrations/shield-axum/src/routes/sign_in.rs index c58805f..b38551e 100644 --- a/packages/integrations/shield-axum/src/routes/sign_in.rs +++ b/packages/integrations/shield-axum/src/routes/sign_in.rs @@ -1,4 +1,5 @@ -use axum::extract::Path; +use axum::{extract::Path, Json}; +use serde::{Deserialize, Serialize}; use shield::{SignInRequest, User}; use crate::{ @@ -8,6 +9,12 @@ use crate::{ response::RouteResponse, }; +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] +pub struct SignInData { + redirect_url: Option, +} + #[cfg_attr( feature = "utoipa", utoipa::path( @@ -18,6 +25,7 @@ use crate::{ params( AuthPathParams, ), + request_body = SignInData, responses( (status = 200, description = "Successfully signed in."), (status = 303, description = "Redirect to authentication provider for sign in."), @@ -34,12 +42,14 @@ pub async fn sign_in( }): Path, ExtractShield(shield): ExtractShield, ExtractSession(session): ExtractSession, + Json(data): Json, ) -> Result { let response = shield .sign_in( SignInRequest { provider_id, subprovider_id, + redirect_url: data.redirect_url, data: None, form_data: None, }, diff --git a/packages/integrations/shield-axum/src/routes/sign_in_callback.rs b/packages/integrations/shield-axum/src/routes/sign_in_callback.rs index 93e1da9..bfca83f 100644 --- a/packages/integrations/shield-axum/src/routes/sign_in_callback.rs +++ b/packages/integrations/shield-axum/src/routes/sign_in_callback.rs @@ -41,6 +41,7 @@ pub async fn sign_in_callback( SignInCallbackRequest { provider_id, subprovider_id, + redirect_url: None, query: Some(query), data: None, }, diff --git a/packages/integrations/shield-leptos/src/routes/sign_in.rs b/packages/integrations/shield-leptos/src/routes/sign_in.rs index 493f55f..2149ec7 100644 --- a/packages/integrations/shield-leptos/src/routes/sign_in.rs +++ b/packages/integrations/shield-leptos/src/routes/sign_in.rs @@ -31,6 +31,7 @@ pub async fn sign_in( SignInRequest { provider_id, subprovider_id, + redirect_url: None, data: None, form_data: None, }, diff --git a/packages/providers/shield-oidc/src/provider.rs b/packages/providers/shield-oidc/src/provider.rs index 2bbd7ee..604943b 100644 --- a/packages/providers/shield-oidc/src/provider.rs +++ b/packages/providers/shield-oidc/src/provider.rs @@ -253,8 +253,6 @@ impl Provider for OidcProvider { session_data.oidc_connection_id = None; } - session.update().await?; - Ok(Response::Redirect(auth_url.to_string())) } @@ -402,9 +400,11 @@ impl Provider for OidcProvider { session_data.oidc_connection_id = Some(connection.id); } - session.update().await?; - - Ok(Response::Redirect(options.sign_in_redirect.clone())) + Ok(Response::Redirect( + request + .redirect_url + .unwrap_or(options.sign_in_redirect.clone()), + )) } async fn sign_out(