From d074f16fbb8a1fa9c3102b91ed633d9306dbc519 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABlle=20Huisman?= Date: Thu, 13 Feb 2025 19:13:54 +0100 Subject: [PATCH] feat(shield): move provider session data to provider --- Cargo.lock | 1 + packages/core/shield/src/error.rs | 2 + packages/core/shield/src/session.rs | 40 ++++++++++++---- packages/providers/shield-oidc/Cargo.toml | 1 + packages/providers/shield-oidc/src/lib.rs | 1 + .../providers/shield-oidc/src/provider.rs | 48 ++++++++++++------- packages/providers/shield-oidc/src/session.rs | 9 ++++ 7 files changed, 76 insertions(+), 26 deletions(-) create mode 100644 packages/providers/shield-oidc/src/session.rs diff --git a/Cargo.lock b/Cargo.lock index 6c525cb..502ad45 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4666,6 +4666,7 @@ dependencies = [ "chrono", "oauth2", "openidconnect", + "serde", "shield", "tracing", ] diff --git a/packages/core/shield/src/error.rs b/packages/core/shield/src/error.rs index 5050112..6d23f4a 100644 --- a/packages/core/shield/src/error.rs +++ b/packages/core/shield/src/error.rs @@ -38,6 +38,8 @@ pub enum SessionError { Engine(String), #[error("{0}")] Lock(String), + #[error("{0}")] + Serialization(String), } #[derive(Debug, Error)] diff --git a/packages/core/shield/src/session.rs b/packages/core/shield/src/session.rs index 13608e5..09cdc25 100644 --- a/packages/core/shield/src/session.rs +++ b/packages/core/shield/src/session.rs @@ -1,7 +1,10 @@ -use std::sync::{Arc, Mutex}; +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; use async_trait::async_trait; -use serde::{Deserialize, Serialize}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; use crate::error::SessionError; @@ -43,15 +46,36 @@ impl Session { #[derive(Clone, Debug, Default, Deserialize, Serialize)] pub struct SessionData { + pub redirect_url: Option, pub authentication: Option, + pub providers: HashMap, +} - pub redirect_url: Option, +impl SessionData { + pub fn provider( + &self, + provider_id: &str, + ) -> Result { + match self.providers.get(provider_id) { + Some(value) => serde_json::from_str(value) + .map_err(|err| SessionError::Serialization(err.to_string())), + None => Ok(T::default()), + } + } - // TODO: Allow arbitrary data to be stored by providers? - pub csrf: Option, - pub nonce: Option, - pub verifier: Option, - pub oidc_connection_id: Option, + pub fn set_provider( + &mut self, + provider_id: &str, + value: T, + ) -> Result<(), SessionError> { + self.providers.insert( + provider_id.to_owned(), + serde_json::to_string(&value) + .map_err(|err| SessionError::Serialization(err.to_string()))?, + ); + + Ok(()) + } } #[derive(Clone, Debug, Default, Deserialize, Serialize)] diff --git a/packages/providers/shield-oidc/Cargo.toml b/packages/providers/shield-oidc/Cargo.toml index 6c41d40..f10cbc8 100644 --- a/packages/providers/shield-oidc/Cargo.toml +++ b/packages/providers/shield-oidc/Cargo.toml @@ -14,5 +14,6 @@ bon.workspace = true chrono.workspace = true oauth2 = { version = "5.0.0", features = ["pkce-plain"] } openidconnect = "4.0.0" +serde.workspace = true shield = { path = "../../core/shield", version = "0.0.4" } tracing.workspace = true diff --git a/packages/providers/shield-oidc/src/lib.rs b/packages/providers/shield-oidc/src/lib.rs index bc47ea1..1391644 100644 --- a/packages/providers/shield-oidc/src/lib.rs +++ b/packages/providers/shield-oidc/src/lib.rs @@ -3,6 +3,7 @@ mod claims; mod client; mod connection; mod provider; +mod session; mod storage; mod subprovider; diff --git a/packages/providers/shield-oidc/src/provider.rs b/packages/providers/shield-oidc/src/provider.rs index e4150c2..bab4da7 100644 --- a/packages/providers/shield-oidc/src/provider.rs +++ b/packages/providers/shield-oidc/src/provider.rs @@ -14,8 +14,9 @@ use shield::{ use tracing::debug; use crate::{ - claims::Claims, client::async_http_client, storage::OidcStorage, subprovider::OidcSubprovider, - CreateOidcConnection, OidcConnection, OidcProviderPkceCodeChallenge, UpdateOidcConnection, + claims::Claims, client::async_http_client, session::OidcSession, storage::OidcStorage, + subprovider::OidcSubprovider, CreateOidcConnection, OidcConnection, + OidcProviderPkceCodeChallenge, UpdateOidcConnection, }; pub const OIDC_PROVIDER_ID: &str = "oidc"; @@ -245,11 +246,16 @@ impl Provider for OidcProvider { session_data.authentication = None; - session_data.csrf = Some(csrf_token.secret().clone()); - session_data.nonce = Some(nonce.secret().clone()); - session_data.verifier = pkce_code_challenge - .map(|(_, pkce_code_verifier)| pkce_code_verifier.secret().clone()); - session_data.oidc_connection_id = None; + session_data.set_provider( + OIDC_PROVIDER_ID, + OidcSession { + csrf: Some(csrf_token.secret().clone()), + nonce: Some(nonce.secret().clone()), + pkce_verifier: pkce_code_challenge + .map(|(_, pkce_code_verifier)| pkce_code_verifier.secret().clone()), + oidc_connection_id: None, + }, + )?; } Ok(Response::Redirect(auth_url.to_string())) @@ -261,17 +267,18 @@ impl Provider for OidcProvider { session: Session, options: &ShieldOptions, ) -> Result { - let (pkce_verifier, csrf, nonce) = { + let OidcSession { + csrf, + nonce, + pkce_verifier, + .. + } = { let session_data = session.data(); let session_data = session_data .lock() .map_err(|err| SessionError::Lock(err.to_string()))?; - ( - session_data.verifier.clone(), - session_data.csrf.clone(), - session_data.nonce.clone(), - ) + session_data.provider(OIDC_PROVIDER_ID)? }; let state = request @@ -392,16 +399,21 @@ impl Provider for OidcProvider { .lock() .map_err(|err| SessionError::Lock(err.to_string()))?; - session_data.csrf = None; - session_data.nonce = None; - session_data.verifier = None; - session_data.authentication = Some(Authentication { provider_id: self.id(), subprovider_id: Some(subprovider.id), user_id: user.id(), }); - session_data.oidc_connection_id = Some(connection.id); + + session_data.set_provider( + OIDC_PROVIDER_ID, + OidcSession { + csrf: None, + nonce: None, + pkce_verifier: None, + oidc_connection_id: Some(connection.id), + }, + )?; } Ok(Response::Redirect( diff --git a/packages/providers/shield-oidc/src/session.rs b/packages/providers/shield-oidc/src/session.rs new file mode 100644 index 0000000..743de46 --- /dev/null +++ b/packages/providers/shield-oidc/src/session.rs @@ -0,0 +1,9 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct OidcSession { + pub csrf: Option, + pub nonce: Option, + pub pkce_verifier: Option, + pub oidc_connection_id: Option, +}