From 7b4264a3a2b2174c493e98403220f029e0f91fbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABlle=20Huisman?= Date: Fri, 23 May 2025 20:48:12 +0200 Subject: [PATCH] feat(shield-oauth): add oauth method --- Cargo.lock | 3 + packages/methods/shield-oauth/Cargo.toml | 2 + packages/methods/shield-oauth/src/client.rs | 9 + .../methods/shield-oauth/src/connection.rs | 36 ++ packages/methods/shield-oauth/src/lib.rs | 4 + packages/methods/shield-oauth/src/method.rs | 352 +++++++++++++++++- packages/methods/shield-oauth/src/provider.rs | 122 ++++-- packages/methods/shield-oauth/src/session.rs | 8 + packages/methods/shield-oauth/src/storage.rs | 30 +- packages/methods/shield-oidc/src/method.rs | 3 +- packages/storage/shield-sea-orm/Cargo.toml | 5 +- .../storage/shield-sea-orm/src/providers.rs | 2 + .../shield-sea-orm/src/providers/oauth.rs | 210 +++++++++++ .../shield-sea-orm/src/providers/oidc.rs | 2 +- 14 files changed, 743 insertions(+), 45 deletions(-) create mode 100644 packages/methods/shield-oauth/src/client.rs create mode 100644 packages/methods/shield-oauth/src/connection.rs create mode 100644 packages/methods/shield-oauth/src/session.rs create mode 100644 packages/storage/shield-sea-orm/src/providers/oauth.rs diff --git a/Cargo.lock b/Cargo.lock index e969aa7..735c61f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4540,7 +4540,9 @@ name = "shield-oauth" version = "0.0.4" dependencies = [ "async-trait", + "chrono", "oauth2", + "serde", "shield", ] @@ -4569,6 +4571,7 @@ dependencies = [ "serde", "serde_json", "shield", + "shield-oauth", "shield-oidc", "utoipa", ] diff --git a/packages/methods/shield-oauth/Cargo.toml b/packages/methods/shield-oauth/Cargo.toml index 2aad0fb..81e4af5 100644 --- a/packages/methods/shield-oauth/Cargo.toml +++ b/packages/methods/shield-oauth/Cargo.toml @@ -10,7 +10,9 @@ version.workspace = true [dependencies] async-trait.workspace = true +chrono.workspace = true oauth2 = { version = "5.0.0", default-features = false, features = ["reqwest"] } +serde.workspace = true shield.workspace = true [features] diff --git a/packages/methods/shield-oauth/src/client.rs b/packages/methods/shield-oauth/src/client.rs new file mode 100644 index 0000000..c85520b --- /dev/null +++ b/packages/methods/shield-oauth/src/client.rs @@ -0,0 +1,9 @@ +use oauth2::reqwest::{self, redirect::Policy}; +use shield::ConfigurationError; + +pub fn async_http_client() -> Result { + reqwest::Client::builder() + .redirect(Policy::none()) + .build() + .map_err(|err| ConfigurationError::Invalid(err.to_string())) +} diff --git a/packages/methods/shield-oauth/src/connection.rs b/packages/methods/shield-oauth/src/connection.rs new file mode 100644 index 0000000..e34d89c --- /dev/null +++ b/packages/methods/shield-oauth/src/connection.rs @@ -0,0 +1,36 @@ +use chrono::{DateTime, FixedOffset}; + +#[derive(Clone, Debug)] +pub struct OauthConnection { + pub id: String, + pub identifier: String, + pub token_type: String, + pub access_token: String, + pub refresh_token: Option, + pub expired_at: Option>, + pub scopes: Option>, + pub provider_id: String, + pub user_id: String, +} + +#[derive(Clone, Debug)] +pub struct CreateOauthConnection { + pub identifier: String, + pub token_type: String, + pub access_token: String, + pub refresh_token: Option, + pub expired_at: Option>, + pub scopes: Option>, + pub provider_id: String, + pub user_id: String, +} + +#[derive(Clone, Debug)] +pub struct UpdateOauthConnection { + pub id: String, + pub token_type: Option, + pub access_token: Option, + pub refresh_token: Option>, + pub expired_at: Option>>, + pub scopes: Option>>, +} diff --git a/packages/methods/shield-oauth/src/lib.rs b/packages/methods/shield-oauth/src/lib.rs index d4f163c..4928900 100644 --- a/packages/methods/shield-oauth/src/lib.rs +++ b/packages/methods/shield-oauth/src/lib.rs @@ -1,7 +1,11 @@ +mod client; +mod connection; mod method; mod provider; +mod session; mod storage; +pub use connection::*; pub use method::*; pub use provider::*; pub use storage::*; diff --git a/packages/methods/shield-oauth/src/method.rs b/packages/methods/shield-oauth/src/method.rs index 1c964fa..f23181b 100644 --- a/packages/methods/shield-oauth/src/method.rs +++ b/packages/methods/shield-oauth/src/method.rs @@ -1,10 +1,22 @@ use async_trait::async_trait; +use chrono::{DateTime, Duration, FixedOffset, Utc}; +use oauth2::{ + AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, + basic::BasicTokenResponse, url::form_urlencoded::parse, +}; use shield::{ - Method, Provider, ProviderError, Response, Session, ShieldError, ShieldOptions, - SignInCallbackRequest, SignInRequest, SignOutRequest, User, + Authentication, ConfigurationError, CreateEmailAddress, CreateUser, Method, Provider, + ProviderError, Response, Session, SessionError, ShieldError, ShieldOptions, + SignInCallbackRequest, SignInRequest, SignOutRequest, UpdateUser, User, }; -use crate::{provider::OauthProvider, storage::OauthStorage}; +use crate::{ + CreateOauthConnection, OauthConnection, UpdateOauthConnection, + client::async_http_client, + provider::{OauthProvider, OauthProviderPkceCodeChallenge}, + session::OauthSession, + storage::OauthStorage, +}; pub const OAUTH_METHOD_ID: &str = "oauth"; @@ -26,7 +38,10 @@ impl OauthMethod { self } - async fn oauth_provider_by_id(&self, provider_id: &str) -> Result { + async fn oauth_provider_by_id_or_slug( + &self, + provider_id: &str, + ) -> Result { if let Some(provider) = self .providers .iter() @@ -35,12 +50,109 @@ impl OauthMethod { return Ok(provider.clone()); } - if let Some(provider) = self.storage.oauth_provider_by_id(provider_id).await? { + if let Some(provider) = self + .storage + .oauth_provider_by_id_or_slug(provider_id) + .await? + { return Ok(provider); } Err(ProviderError::ProviderNotFound(provider_id.to_owned()).into()) } + + async fn create_user(&self, email: Option<&str>, name: Option<&str>) -> Result { + if let Some(email) = email { + match self.storage.user_by_email(email).await? { + Some(_) => Err(ShieldError::Validation( + "\ + Email address `{email}` is already used by another account. \ + To link a new provider, sign in to with your exising account first. \ + If this is not your account, please contact support for assistence.\ + " + .to_owned(), + )), + None => Ok(self + .storage + .create_user( + CreateUser { + name: name.map(ToOwned::to_owned), + }, + CreateEmailAddress { + email: email.to_string(), + is_primary: true, + // TODO: from claim? + is_verified: false, + // TODO: generate if not verified + verification_token: None, + verification_token_expired_at: None, + verified_at: None, + }, + ) + .await?), + } + } else { + Err(ShieldError::Validation( + "Missing email address in OpenID Connect claims.".to_owned(), + )) + } + } + + async fn update_user(&self, user_id: &str, name: Option<&str>) -> Result { + self.storage + .update_user(UpdateUser { + id: user_id.to_owned(), + name: name.map(ToOwned::to_owned).map(Some), + }) + .await + .map_err(ShieldError::Storage) + } + + async fn create_oauth_connection( + &self, + provider_id: String, + user_id: String, + identifier: String, + token_response: BasicTokenResponse, + ) -> Result { + let (token_type, access_token, refresh_token, expired_at, scopes) = + parse_token_response(token_response)?; + + self.storage + .create_oauth_connection(CreateOauthConnection { + identifier, + token_type, + access_token, + refresh_token, + expired_at, + scopes, + provider_id, + user_id, + }) + .await + .map_err(ShieldError::Storage) + } + + async fn update_oauth_connection( + &self, + connection_id: String, + token_response: BasicTokenResponse, + ) -> Result { + let (token_type, access_token, refresh_token, expired_at, scopes) = + parse_token_response(token_response)?; + + self.storage + .update_oauth_connection(UpdateOauthConnection { + id: connection_id, + token_type: Some(token_type), + access_token: Some(access_token), + refresh_token: refresh_token.map(Some), + expired_at: expired_at.map(Some), + scopes: scopes.map(Some), + }) + .await + .map_err(ShieldError::Storage) + } } #[async_trait] @@ -65,7 +177,7 @@ impl Method for OauthMethod { &self, provider_id: &str, ) -> Result>, ShieldError> { - self.oauth_provider_by_id(provider_id) + self.oauth_provider_by_id_or_slug(provider_id) .await .map(|provider| Some(Box::new(provider) as Box)) } @@ -73,29 +185,205 @@ impl Method for OauthMethod { async fn sign_in( &self, request: SignInRequest, - _session: Session, + session: Session, _options: &ShieldOptions, ) -> Result { - let _provider = match request.provider_id { - Some(provider_id) => self.oauth_provider_by_id(&provider_id).await?, + let provider = match request.provider_id { + Some(provider_id) => self.oauth_provider_by_id_or_slug(&provider_id).await?, None => return Err(ProviderError::ProviderMissing.into()), }; - todo!("oauth sign in") + let client = provider.oauth_client().await?; + + let mut authorization_request = client + .authorize_url(CsrfToken::new_random) + .map_err(|err| ConfigurationError::Invalid(err.to_string()))?; + + let pkce_code_challenge = match provider.pkce_code_challenge { + OauthProviderPkceCodeChallenge::None => None, + OauthProviderPkceCodeChallenge::Plain => Some(PkceCodeChallenge::new_random_plain()), + OauthProviderPkceCodeChallenge::S256 => Some(PkceCodeChallenge::new_random_sha256()), + }; + + if let Some((pkce_code_challenge, _)) = &pkce_code_challenge { + authorization_request = + authorization_request.set_pkce_challenge(pkce_code_challenge.clone()); + } + + if let Some(scopes) = provider.scopes { + authorization_request = + authorization_request.add_scopes(scopes.into_iter().map(Scope::new)); + } + + if let Some(authorization_url_params) = provider.authorization_url_params { + let params = parse(authorization_url_params.trim_start_matches('?').as_bytes()); + + for (name, value) in params { + authorization_request = + authorization_request.add_extra_param(name.into_owned(), value.into_owned()); + } + } + + let (auth_url, csrf_token) = authorization_request.url(); + + { + let session_data = session.data(); + let mut session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; + + session_data.authentication = None; + + session_data.set_method( + OAUTH_METHOD_ID, + OauthSession { + csrf: Some(csrf_token.secret().clone()), + pkce_verifier: pkce_code_challenge + .map(|(_, pkce_code_verifier)| pkce_code_verifier.secret().clone()), + oauth_connection_id: None, + }, + )?; + } + + Ok(Response::Redirect(auth_url.to_string())) } async fn sign_in_callback( &self, request: SignInCallbackRequest, - _session: Session, - _options: &ShieldOptions, + session: Session, + options: &ShieldOptions, ) -> Result { - let _provider = match request.provider_id { - Some(provider_id) => self.oauth_provider_by_id(&provider_id).await?, + let OauthSession { + csrf, + pkce_verifier, + .. + } = { + let session_data = session.data(); + let session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; + + session_data.method(OAUTH_METHOD_ID)? + }; + + let state = request + .query + .as_ref() + .and_then(|query| query.get("state")) + .and_then(|code| code.as_str()) + .ok_or_else(|| ShieldError::Validation("Missing state.".to_owned()))?; + + if csrf.is_none_or(|csrf| csrf != state) { + return Err(ShieldError::Validation("Invalid state.".to_owned())); + } + + let authorization_code = request + .query + .as_ref() + .and_then(|query| query.get("code")) + .and_then(|code| code.as_str()) + .ok_or_else(|| ShieldError::Validation("Missing authorization code.".to_owned()))?; + + let provider = match request.provider_id { + Some(provider_id) => self.oauth_provider_by_id_or_slug(&provider_id).await?, None => return Err(ProviderError::ProviderMissing.into()), }; - todo!("oauth sign in callback") + let client = provider.oauth_client().await?; + + let mut token_request = client + .exchange_code(AuthorizationCode::new(authorization_code.to_owned())) + .map_err(|err| { + ShieldError::Configuration(ConfigurationError::Missing(err.to_string())) + })?; + + if let Some(pkce_verifier) = pkce_verifier { + token_request = token_request.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier)); + } else if provider.pkce_code_challenge != OauthProviderPkceCodeChallenge::None { + return Err(ShieldError::Validation("Missing PKCE verifier.".to_owned())); + } + + if let Some(token_url_params) = provider.token_url_params { + let params = parse(token_url_params.trim_start_matches('?').as_bytes()); + + for (name, value) in params { + token_request = + token_request.add_extra_param(name.into_owned(), value.into_owned()); + } + } + + let async_http_client = async_http_client()?; + + let token_response = token_request + .request_async(&async_http_client) + .await + .map_err(|err| ShieldError::Request(err.to_string()))?; + + // TODO: user info + let identifier = ""; + let email = Some(""); + let name = Some(""); + + let (connection, user) = match self + .storage + .oauth_connection_by_identifier(&provider.id, identifier) + .await? + { + Some(connection) => { + let connection = self + .update_oauth_connection(connection.id, token_response) + .await?; + + let user = self.update_user(&connection.user_id, name).await?; + + (connection, user) + } + None => { + let user = self.create_user(email, name).await?; + + let connection = self + .create_oauth_connection( + provider.id.clone(), + user.id(), + identifier.to_owned(), + token_response, + ) + .await?; + + (connection, user) + } + }; + + session.renew().await?; + + { + let session_data = session.data(); + let mut session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; + + session_data.authentication = Some(Authentication { + method_id: self.id(), + provider_id: Some(provider.id), + user_id: user.id(), + }); + + session_data.set_method( + OAUTH_METHOD_ID, + OauthSession { + csrf: None, + pkce_verifier: None, + oauth_connection_id: Some(connection.id), + }, + )?; + } + + Ok(Response::Redirect( + request + .redirect_url + .unwrap_or(options.sign_in_redirect.clone()), + )) } async fn sign_out( @@ -105,10 +393,42 @@ impl Method for OauthMethod { _options: &ShieldOptions, ) -> Result { let _provider = match request.provider_id { - Some(provider_id) => self.oauth_provider_by_id(&provider_id).await?, + Some(provider_id) => self.oauth_provider_by_id_or_slug(&provider_id).await?, None => return Err(ProviderError::ProviderMissing.into()), }; todo!("oauth sign out") } } + +type ParsedTokenResponse = ( + String, + String, + Option, + Option>, + Option>, +); + +fn parse_token_response( + token_response: BasicTokenResponse, +) -> Result { + Ok(( + token_response.token_type().as_ref().to_string(), + token_response.access_token().secret().clone(), + token_response + .refresh_token() + .map(|refresh_token| refresh_token.secret().clone()), + match token_response.expires_in() { + Some(expires_in) => Some( + (Utc::now() + + Duration::from_std(expires_in) + .map_err(|err| ShieldError::Validation(err.to_string()))?) + .into(), + ), + None => None, + }, + token_response + .scopes() + .map(|scopes| scopes.iter().map(|scope| scope.to_string()).collect()), + )) +} diff --git a/packages/methods/shield-oauth/src/provider.rs b/packages/methods/shield-oauth/src/provider.rs index 0c7488f..71b5280 100644 --- a/packages/methods/shield-oauth/src/provider.rs +++ b/packages/methods/shield-oauth/src/provider.rs @@ -1,11 +1,32 @@ -use shield::{Form, Provider}; +use oauth2::{ + AuthUrl, Client, ClientId, ClientSecret, EndpointMaybeSet, EndpointNotSet, IntrospectionUrl, + RedirectUrl, RevocationUrl, StandardRevocableToken, TokenUrl, + basic::{ + BasicClient, BasicErrorResponse, BasicRevocationErrorResponse, + BasicTokenIntrospectionResponse, BasicTokenResponse, + }, +}; +use shield::{ConfigurationError, Form, Provider}; use crate::method::OAUTH_METHOD_ID; +type OauthClient = Client< + BasicErrorResponse, + BasicTokenResponse, + BasicTokenIntrospectionResponse, + StandardRevocableToken, + BasicRevocationErrorResponse, + EndpointMaybeSet, + EndpointNotSet, + EndpointMaybeSet, + EndpointMaybeSet, + EndpointMaybeSet, +>; + #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum OauthProviderVisibility { - Private, Public, + Unlisted, } #[derive(Clone, Copy, Debug, Eq, PartialEq)] @@ -15,28 +36,85 @@ pub enum OauthProviderPkceCodeChallenge { S256, } -// TODO: Remove allow dead code. -#[allow(dead_code)] #[derive(Clone, Debug)] pub struct OauthProvider { - pub(crate) id: String, - pub(crate) name: String, - pub(crate) slug: Option, - pub(crate) visibility: OauthProviderVisibility, - pub(crate) client_id: String, - pub(crate) client_secret: Option, - pub(crate) scopes: Option>, - pub(crate) redirect_url: Option, - pub(crate) authorization_url: Option, - pub(crate) authorization_url_params: Option, - pub(crate) token_url: Option, - pub(crate) token_url_params: Option, - pub(crate) introspection_url: Option, - pub(crate) introspection_url_params: Option, - pub(crate) revocation_url: Option, - pub(crate) revocation_url_params: Option, - pub(crate) pkce_code_challenge: OauthProviderPkceCodeChallenge, - pub(crate) icon_url: Option, + pub id: String, + pub name: String, + pub slug: Option, + pub visibility: OauthProviderVisibility, + pub client_id: String, + pub client_secret: Option, + pub scopes: Option>, + pub redirect_url: Option, + pub authorization_url: Option, + pub authorization_url_params: Option, + pub token_url: Option, + pub token_url_params: Option, + pub introspection_url: Option, + pub introspection_url_params: Option, + pub revocation_url: Option, + pub revocation_url_params: Option, + pub pkce_code_challenge: OauthProviderPkceCodeChallenge, + pub icon_url: Option, +} + +impl OauthProvider { + pub async fn oauth_client(&self) -> Result { + let mut client = BasicClient::new(ClientId::new(self.client_id.clone())); + + if let Some(client_secret) = &self.client_secret { + client = client.set_client_secret(ClientSecret::new(client_secret.clone())); + } + + if let Some(redirect_url) = &self.redirect_url { + client = client.set_redirect_uri( + RedirectUrl::new(redirect_url.clone()) + .map_err(|err| ConfigurationError::Invalid(err.to_string()))?, + ); + } + + let client = client.set_auth_uri_option( + self.authorization_url + .as_ref() + .map(|authorization_url| { + AuthUrl::new(authorization_url.clone()) + .map_err(|err| ConfigurationError::Invalid(err.to_string())) + }) + .transpose()?, + ); + + let client = client.set_token_uri_option( + self.token_url + .as_ref() + .map(|token_url| { + TokenUrl::new(token_url.clone()) + .map_err(|err| ConfigurationError::Invalid(err.to_string())) + }) + .transpose()?, + ); + + let client = client.set_introspection_url_option( + self.introspection_url + .as_ref() + .map(|introspection_url| { + IntrospectionUrl::new(introspection_url.clone()) + .map_err(|err| ConfigurationError::Invalid(err.to_string())) + }) + .transpose()?, + ); + + let client = client.set_revocation_url_option( + self.revocation_url + .as_ref() + .map(|revocation_url| { + RevocationUrl::new(revocation_url.clone()) + .map_err(|err| ConfigurationError::Invalid(err.to_string())) + }) + .transpose()?, + ); + + Ok(client) + } } impl Provider for OauthProvider { diff --git a/packages/methods/shield-oauth/src/session.rs b/packages/methods/shield-oauth/src/session.rs new file mode 100644 index 0000000..94f0401 --- /dev/null +++ b/packages/methods/shield-oauth/src/session.rs @@ -0,0 +1,8 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct OauthSession { + pub csrf: Option, + pub pkce_verifier: Option, + pub oauth_connection_id: Option, +} diff --git a/packages/methods/shield-oauth/src/storage.rs b/packages/methods/shield-oauth/src/storage.rs index 8f5a175..7a97ec0 100644 --- a/packages/methods/shield-oauth/src/storage.rs +++ b/packages/methods/shield-oauth/src/storage.rs @@ -2,14 +2,40 @@ use async_trait::async_trait; use shield::{Storage, StorageError, User}; -use crate::provider::OauthProvider; +use crate::{ + connection::{CreateOauthConnection, OauthConnection, UpdateOauthConnection}, + provider::OauthProvider, +}; #[async_trait] pub trait OauthStorage: Storage + Sync { async fn oauth_providers(&self) -> Result, StorageError>; - async fn oauth_provider_by_id( + async fn oauth_provider_by_id_or_slug( &self, provider_id: &str, ) -> Result, StorageError>; + + async fn oauth_connection_by_id( + &self, + connection_id: &str, + ) -> Result, StorageError>; + + async fn oauth_connection_by_identifier( + &self, + provider_id: &str, + identifier: &str, + ) -> Result, StorageError>; + + async fn create_oauth_connection( + &self, + connection: CreateOauthConnection, + ) -> Result; + + async fn update_oauth_connection( + &self, + connection: UpdateOauthConnection, + ) -> Result; + + async fn delete_oauth_connection(&self, connection_id: &str) -> Result<(), StorageError>; } diff --git a/packages/methods/shield-oidc/src/method.rs b/packages/methods/shield-oidc/src/method.rs index 8cd7c17..b33d044 100644 --- a/packages/methods/shield-oidc/src/method.rs +++ b/packages/methods/shield-oidc/src/method.rs @@ -107,7 +107,8 @@ impl OidcMethod { id: user_id.to_owned(), name: claims .name() - .map(|name| name.get(None).map(|name| name.to_string())), + .and_then(|name| name.get(None).map(|name| name.to_string())) + .map(Some), }) .await .map_err(ShieldError::Storage) diff --git a/packages/storage/shield-sea-orm/Cargo.toml b/packages/storage/shield-sea-orm/Cargo.toml index e5fd178..04046c8 100644 --- a/packages/storage/shield-sea-orm/Cargo.toml +++ b/packages/storage/shield-sea-orm/Cargo.toml @@ -18,7 +18,7 @@ serde_json.workspace = true shield.workspace = true # shield-credentials = { workspace = true, optional = true } # shield-email = { workspace = true, optional = true } -# shield-oauth = { workspace = true, optional = true } +shield-oauth = { workspace = true, optional = true } shield-oidc = { workspace = true, optional = true } # shield-webauthn = { workspace = true, optional = true } utoipa = { workspace = true, optional = true } @@ -35,10 +35,9 @@ all-methods = [ ] # method-credentials = ["dep:shield-credentials"] # method-email = ["dep:shield-email"] -# method-oauth = ["dep:shield-oauth"] # method-credentials = [] method-email = [] -method-oauth = [] +method-oauth = ["dep:shield-oauth"] method-oidc = ["dep:shield-oidc"] # method-webauthn = ["dep:shield-webauthn"] utoipa = ["dep:utoipa", "shield/utoipa"] diff --git a/packages/storage/shield-sea-orm/src/providers.rs b/packages/storage/shield-sea-orm/src/providers.rs index 6ba5800..6408e6b 100644 --- a/packages/storage/shield-sea-orm/src/providers.rs +++ b/packages/storage/shield-sea-orm/src/providers.rs @@ -1,2 +1,4 @@ +#[cfg(feature = "method-oauth")] +pub mod oauth; #[cfg(feature = "method-oidc")] pub mod oidc; diff --git a/packages/storage/shield-sea-orm/src/providers/oauth.rs b/packages/storage/shield-sea-orm/src/providers/oauth.rs new file mode 100644 index 0000000..1c9852c --- /dev/null +++ b/packages/storage/shield-sea-orm/src/providers/oauth.rs @@ -0,0 +1,210 @@ +use async_trait::async_trait; +use sea_orm::{ActiveModelTrait, ActiveValue, ColumnTrait, EntityTrait, QueryFilter}; +use shield::StorageError; +use shield_oauth::{ + CreateOauthConnection, OauthConnection, OauthProvider, OauthProviderPkceCodeChallenge, + OauthProviderVisibility, OauthStorage, UpdateOauthConnection, +}; + +use crate::{ + entities::{oauth_provider, oauth_provider_connection}, + storage::SeaOrmStorage, + user::User, +}; + +#[async_trait] +impl OauthStorage for SeaOrmStorage { + async fn oauth_providers(&self) -> Result, StorageError> { + oauth_provider::Entity::find() + .all(&self.database) + .await + .map_err(|err| StorageError::Engine(err.to_string())) + .and_then(|providers| providers.into_iter().map(OauthProvider::try_from).collect()) + } + + async fn oauth_provider_by_id_or_slug( + &self, + provider_id: &str, + ) -> Result, StorageError> { + let condition = match Self::parse_uuid(provider_id) { + Ok(provider_id) => oauth_provider::Column::Id.eq(provider_id), + Err(_) => oauth_provider::Column::Slug.eq(provider_id.to_lowercase()), + }; + + oauth_provider::Entity::find() + .filter(condition) + .one(&self.database) + .await + .map_err(|err| StorageError::Engine(err.to_string())) + .and_then(|provider| match provider { + Some(provider) => OauthProvider::try_from(provider).map(Option::Some), + None => Ok(None), + }) + } + + async fn oauth_connection_by_id( + &self, + connection_id: &str, + ) -> Result, StorageError> { + oauth_provider_connection::Entity::find_by_id(Self::parse_uuid(connection_id)?) + .one(&self.database) + .await + .map_err(|err| StorageError::Engine(err.to_string())) + .map(|connection| connection.map(OauthConnection::from)) + } + + async fn oauth_connection_by_identifier( + &self, + provider_id: &str, + identifier: &str, + ) -> Result, StorageError> { + oauth_provider_connection::Entity::find() + .filter( + oauth_provider_connection::Column::ProviderId.eq(Self::parse_uuid(provider_id)?), + ) + .filter(oauth_provider_connection::Column::Identifier.eq(identifier)) + .one(&self.database) + .await + .map_err(|err| StorageError::Engine(err.to_string())) + .map(|connection| connection.map(OauthConnection::from)) + } + + async fn create_oauth_connection( + &self, + connection: CreateOauthConnection, + ) -> Result { + let active_model = oauth_provider_connection::ActiveModel { + identifier: ActiveValue::Set(connection.identifier), + token_type: ActiveValue::Set(connection.token_type), + access_token: ActiveValue::Set(connection.access_token), + refresh_token: ActiveValue::Set(connection.refresh_token), + expired_at: ActiveValue::Set(connection.expired_at), + scopes: ActiveValue::Set(connection.scopes.map(|scopes| scopes.join(","))), + provider_id: ActiveValue::Set(Self::parse_uuid(&connection.provider_id)?), + user_id: ActiveValue::Set(Self::parse_uuid(&connection.user_id)?), + ..Default::default() + }; + + active_model + .insert(&self.database) + .await + .map_err(|err| StorageError::Engine(err.to_string())) + .map(OauthConnection::from) + } + + async fn update_oauth_connection( + &self, + connection: UpdateOauthConnection, + ) -> Result { + let mut active_model: oauth_provider_connection::ActiveModel = + oauth_provider_connection::Entity::find() + .filter(oauth_provider_connection::Column::Id.eq(Self::parse_uuid(&connection.id)?)) + .one(&self.database) + .await + .map_err(|err| StorageError::Engine(err.to_string()))? + .ok_or_else(|| StorageError::NotFound("OauthConnection".to_owned(), connection.id))? + .into(); + + if let Some(token_type) = connection.token_type { + active_model.token_type = ActiveValue::Set(token_type); + } + if let Some(access_token) = connection.access_token { + active_model.access_token = ActiveValue::Set(access_token); + } + if let Some(refresh_token) = connection.refresh_token { + active_model.refresh_token = ActiveValue::Set(refresh_token); + } + if let Some(expired_at) = connection.expired_at { + active_model.expired_at = ActiveValue::Set(expired_at); + } + if let Some(scopes) = connection.scopes { + active_model.scopes = ActiveValue::Set(scopes.map(|scopes| scopes.join(","))); + } + + active_model + .update(&self.database) + .await + .map_err(|err| StorageError::Engine(err.to_string())) + .map(OauthConnection::from) + } + + async fn delete_oauth_connection(&self, connection_id: &str) -> Result<(), StorageError> { + oauth_provider_connection::Entity::delete_by_id(Self::parse_uuid(connection_id)?) + .exec(&self.database) + .await + .map_err(|err| StorageError::Engine(err.to_string())) + .map(|_| ()) + } +} + +impl From for OauthProviderVisibility { + fn from(value: oauth_provider::OauthProviderVisibility) -> Self { + match value { + oauth_provider::OauthProviderVisibility::Public => OauthProviderVisibility::Public, + oauth_provider::OauthProviderVisibility::Unlisted => OauthProviderVisibility::Unlisted, + } + } +} + +impl From for OauthProviderPkceCodeChallenge { + fn from(value: oauth_provider::OauthProviderPkceCodeChallenge) -> Self { + match value { + oauth_provider::OauthProviderPkceCodeChallenge::None => { + OauthProviderPkceCodeChallenge::None + } + oauth_provider::OauthProviderPkceCodeChallenge::Plain => { + OauthProviderPkceCodeChallenge::Plain + } + oauth_provider::OauthProviderPkceCodeChallenge::S256 => { + OauthProviderPkceCodeChallenge::S256 + } + } + } +} + +impl TryFrom for OauthProvider { + type Error = StorageError; + + fn try_from(value: oauth_provider::Model) -> Result { + Ok(OauthProvider { + id: value.id.to_string(), + name: value.name, + slug: value.slug, + icon_url: value.icon_url, + visibility: value.visibility.into(), + client_id: value.client_id, + client_secret: value.client_secret, + scopes: value + .scopes + .map(|scopes| scopes.split(',').map(|s| s.to_string()).collect()), + redirect_url: value.redirect_url, + authorization_url: value.authorization_url, + authorization_url_params: value.authorization_url_params, + token_url: value.token_url, + token_url_params: value.token_url_params, + introspection_url: value.introspection_url, + introspection_url_params: value.introspection_url_params, + revocation_url: value.revocation_url, + revocation_url_params: value.revocation_url_params, + pkce_code_challenge: value.pkce_code_challenge.into(), + }) + } +} + +impl From for OauthConnection { + fn from(value: oauth_provider_connection::Model) -> Self { + OauthConnection { + id: value.id.to_string(), + identifier: value.identifier, + token_type: value.token_type, + access_token: value.access_token, + refresh_token: value.refresh_token, + expired_at: value.expired_at, + scopes: value + .scopes + .map(|scopes| scopes.split(',').map(|s| s.to_string()).collect()), + provider_id: value.provider_id.to_string(), + user_id: value.user_id.to_string(), + } + } +} diff --git a/packages/storage/shield-sea-orm/src/providers/oidc.rs b/packages/storage/shield-sea-orm/src/providers/oidc.rs index fa8c4a6..8ab2a85 100644 --- a/packages/storage/shield-sea-orm/src/providers/oidc.rs +++ b/packages/storage/shield-sea-orm/src/providers/oidc.rs @@ -101,7 +101,7 @@ impl OidcStorage for SeaOrmStorage { .one(&self.database) .await .map_err(|err| StorageError::Engine(err.to_string()))? - .ok_or_else(|| StorageError::NotFound("OIDC Connection".to_owned(), connection.id))? + .ok_or_else(|| StorageError::NotFound("OidcConnection".to_owned(), connection.id))? .into(); if let Some(token_type) = connection.token_type {