From ea4a2d42de434bc0a8f3cd6508419d3c8bf8dee9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABlle=20Huisman?= Date: Thu, 22 May 2025 20:48:25 +0200 Subject: [PATCH] refactor: rename provider to method and subprovider to provider --- Cargo.toml | 16 +++ examples/leptos-actix/Cargo.toml | 14 +- examples/leptos-actix/src/main.rs | 4 +- examples/leptos-axum/Cargo.toml | 14 +- examples/leptos-axum/src/main.rs | 4 +- examples/sea-orm/Cargo.toml | 5 +- packages/core/shield/src/error.rs | 14 +- packages/core/shield/src/lib.rs | 2 + packages/core/shield/src/method.rs | 118 ++++++++++++++++ packages/core/shield/src/provider.rs | 125 +---------------- packages/core/shield/src/request.rs | 12 +- packages/core/shield/src/session.rs | 20 +-- packages/core/shield/src/shield.rs | 106 +++++++------- packages/core/shield/src/shield_dyn.rs | 28 ++-- packages/integrations/shield-actix/Cargo.toml | 2 +- packages/integrations/shield-axum/Cargo.toml | 4 +- .../integrations/shield-axum/src/error.rs | 10 +- packages/integrations/shield-axum/src/path.rs | 8 +- .../integrations/shield-axum/src/router.rs | 12 +- .../integrations/shield-axum/src/routes.rs | 4 +- .../routes/{subproviders.rs => providers.rs} | 16 +-- .../shield-axum/src/routes/sign_in.rs | 6 +- .../src/routes/sign_in_callback.rs | 6 +- .../shield-leptos-actix/Cargo.toml | 6 +- .../shield-leptos-axum/Cargo.toml | 6 +- .../integrations/shield-leptos/Cargo.toml | 2 +- .../shield-leptos/src/routes/sign_in.rs | 31 ++--- packages/integrations/shield-tower/Cargo.toml | 2 +- .../shield-credentials/Cargo.toml | 4 +- .../shield-credentials/README.md | 2 +- .../shield-credentials/src/lib.rs | 0 .../shield-email/Cargo.toml | 4 +- .../shield-email/README.md | 2 +- .../shield-email/src/lib.rs | 0 .../shield-oauth/Cargo.toml | 4 +- .../shield-oauth/README.md | 2 +- .../shield-oauth/src/lib.rs | 4 +- packages/methods/shield-oauth/src/method.rs | 114 ++++++++++++++++ .../shield-oauth/src/provider.rs} | 12 +- packages/methods/shield-oauth/src/storage.rs | 15 ++ .../shield-oidc/Cargo.toml | 4 +- .../shield-oidc/README.md | 2 +- .../shield-oidc/src/builders.rs | 0 .../shield-oidc/src/builders/google.rs | 10 +- .../shield-oidc/src/builders/keycloak.rs | 10 +- .../shield-oidc/src/claims.rs | 0 .../shield-oidc/src/client.rs | 0 .../shield-oidc/src/connection.rs | 4 +- .../shield-oidc/src/lib.rs | 4 +- .../shield-oidc/src/method.rs} | 129 +++++++++--------- .../shield-oidc/src/provider.rs} | 14 +- .../shield-oidc/src/session.rs | 0 .../shield-oidc/src/storage.rs | 12 +- .../shield-webauthn/Cargo.toml | 4 +- .../shield-webauthn/README.md | 2 +- .../shield-webauthn/src/lib.rs | 0 .../providers/shield-oauth/src/provider.rs | 120 ---------------- .../providers/shield-oauth/src/storage.rs | 15 -- packages/storage/shield-diesel/Cargo.toml | 2 +- packages/storage/shield-memory/Cargo.toml | 34 ++--- .../storage/shield-memory/src/providers.rs | 2 +- .../shield-memory/src/providers/oidc.rs | 16 +-- packages/storage/shield-memory/src/storage.rs | 2 +- packages/storage/shield-sea-orm/Cargo.toml | 40 +++--- .../storage/shield-sea-orm/src/entities.rs | 10 +- .../shield-sea-orm/src/entities/prelude.rs | 10 +- .../shield-sea-orm/src/entities/user.rs | 8 +- .../src/migrations/providers.rs | 12 +- .../storage/shield-sea-orm/src/providers.rs | 2 +- .../shield-sea-orm/src/providers/oidc.rs | 43 +++--- packages/storage/shield-sqlx/Cargo.toml | 2 +- 71 files changed, 631 insertions(+), 642 deletions(-) create mode 100644 packages/core/shield/src/method.rs rename packages/integrations/shield-axum/src/routes/{subproviders.rs => providers.rs} (50%) rename packages/{providers => methods}/shield-credentials/Cargo.toml (62%) rename packages/{providers => methods}/shield-credentials/README.md (91%) rename packages/{providers => methods}/shield-credentials/src/lib.rs (100%) rename packages/{providers => methods}/shield-email/Cargo.toml (63%) rename packages/{providers => methods}/shield-email/README.md (92%) rename packages/{providers => methods}/shield-email/src/lib.rs (100%) rename packages/{providers => methods}/shield-oauth/Cargo.toml (78%) rename packages/{providers => methods}/shield-oauth/README.md (92%) rename packages/{providers => methods}/shield-oauth/src/lib.rs (62%) create mode 100644 packages/methods/shield-oauth/src/method.rs rename packages/{providers/shield-oauth/src/subprovider.rs => methods/shield-oauth/src/provider.rs} (86%) create mode 100644 packages/methods/shield-oauth/src/storage.rs rename packages/{providers => methods}/shield-oidc/Cargo.toml (85%) rename packages/{providers => methods}/shield-oidc/README.md (90%) rename packages/{providers => methods}/shield-oidc/src/builders.rs (100%) rename packages/{providers => methods}/shield-oidc/src/builders/google.rs (52%) rename packages/{providers => methods}/shield-oidc/src/builders/keycloak.rs (54%) rename packages/{providers => methods}/shield-oidc/src/claims.rs (100%) rename packages/{providers => methods}/shield-oidc/src/client.rs (100%) rename packages/{providers => methods}/shield-oidc/src/connection.rs (94%) rename packages/{providers => methods}/shield-oidc/src/lib.rs (81%) rename packages/{providers/shield-oidc/src/provider.rs => methods/shield-oidc/src/method.rs} (81%) rename packages/{providers/shield-oidc/src/subprovider.rs => methods/shield-oidc/src/provider.rs} (96%) rename packages/{providers => methods}/shield-oidc/src/session.rs (100%) rename packages/{providers => methods}/shield-oidc/src/storage.rs (76%) rename packages/{providers => methods}/shield-webauthn/Cargo.toml (63%) rename packages/{providers => methods}/shield-webauthn/README.md (92%) rename packages/{providers => methods}/shield-webauthn/src/lib.rs (100%) delete mode 100644 packages/providers/shield-oauth/src/provider.rs delete mode 100644 packages/providers/shield-oauth/src/storage.rs diff --git a/Cargo.toml b/Cargo.toml index 31f7386..4d65015 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,22 @@ sea-orm = "1.1.2" sea-orm-migration = "1.1.2" serde = "1.0.215" serde_json = "1.0.133" +shield = { path = "./packages/core/shield", version = "0.0.4" } +shield-actix = { path = "./packages/integrations/shield-actix", version = "0.0.4" } +shield-axum = { path = "./packages/integrations/shield-axum", version = "0.0.4" } +shield-credentials = { path = "./packages/methods/shield-credentials", version = "0.0.4" } +shield-diesel = { path = "./packages/storage/shield-diesel", version = "0.0.4" } +shield-email = { path = "./packages/methods/shield-email", version = "0.0.4" } +shield-leptos = { path = "./packages/integrations/shield-leptos", version = "0.0.4" } +shield-leptos-actix = { path = "./packages/integrations/shield-leptos-actix", version = "0.0.4" } +shield-leptos-axum = { path = "./packages/integrations/shield-leptos-axum", version = "0.0.4" } +shield-memory = { path = "./packages/storage/shield-memory", version = "0.0.4" } +shield-oauth = { path = "./packages/methods/shield-oauth", version = "0.0.4" } +shield-oidc = { path = "./packages/methods/shield-oidc", version = "0.0.4" } +shield-sea-orm = { path = "./packages/storage/shield-sea-orm", version = "0.0.4" } +shield-sqlx = { path = "./packages/storage/shield-sqlx", version = "0.0.4" } +shield-tower = { path = "./packages/integrations/shield-tower", version = "0.0.4" } +shield-webauthn = { path = "./packages/methods/shield-webauthn", version = "0.0.4" } tokio = "1.42.0" tower-layer = "0.3.3" tower-service = "0.3.3" diff --git a/examples/leptos-actix/Cargo.toml b/examples/leptos-actix/Cargo.toml index 4a4c3ce..c43b53d 100644 --- a/examples/leptos-actix/Cargo.toml +++ b/examples/leptos-actix/Cargo.toml @@ -23,13 +23,11 @@ leptos.workspace = true leptos_actix = { workspace = true, optional = true } leptos_meta.workspace = true leptos_router.workspace = true -shield = { path = "../../packages/core/shield" } -shield-leptos = { path = "../../packages/integrations/shield-leptos" } -shield-leptos-actix = { path = "../../packages/integrations/shield-leptos-actix", optional = true } -shield-memory = { path = "../../packages/storage/shield-memory", optional = true } -shield-oidc = { path = "../../packages/providers/shield-oidc", features = [ - "native-tls", -], optional = true } +shield.workspace = true +shield-leptos.workspace = true +shield-leptos-actix = { workspace = true, optional = true } +shield-memory = { workspace = true, optional = true } +shield-oidc = { workspace = true, features = ["native-tls"], optional = true } tracing.workspace = true tracing-subscriber.workspace = true wasm-bindgen.workspace = true @@ -49,7 +47,7 @@ ssr = [ "leptos/ssr", "leptos_meta/ssr", "leptos_router/ssr", - "shield-memory/provider-oidc", + "shield-memory/method-oidc", ] [package.metadata.leptos] diff --git a/examples/leptos-actix/src/main.rs b/examples/leptos-actix/src/main.rs index 6dfe314..00a0bd7 100644 --- a/examples/leptos-actix/src/main.rs +++ b/examples/leptos-actix/src/main.rs @@ -12,7 +12,7 @@ async fn main() -> std::io::Result<()> { use shield_examples_leptos_actix::app::*; use shield_leptos_actix::{ShieldMiddleware, provide_actix_integration}; use shield_memory::{MemoryStorage, User}; - use shield_oidc::{Keycloak, OidcProvider}; + use shield_oidc::{Keycloak, OidcMethod}; use tracing::{info, level_filters::LevelFilter}; // Initialize tracing @@ -44,7 +44,7 @@ async fn main() -> std::io::Result<()> { let shield = Shield::new( shield_storage.clone(), vec![Arc::new( - OidcProvider::new(shield_storage).with_subproviders([Keycloak::builder( + OidcMethod::new(shield_storage).with_providers([Keycloak::builder( "keycloak", "http://localhost:18080/realms/Shield", "client1", diff --git a/examples/leptos-axum/Cargo.toml b/examples/leptos-axum/Cargo.toml index dbce6fc..b0f756f 100644 --- a/examples/leptos-axum/Cargo.toml +++ b/examples/leptos-axum/Cargo.toml @@ -19,15 +19,13 @@ leptos.workspace = true leptos_axum = { workspace = true, optional = true } leptos_meta.workspace = true leptos_router.workspace = true -shield = { path = "../../packages/core/shield" } -shield-leptos = { path = "../../packages/integrations/shield-leptos" } -shield-leptos-axum = { path = "../../packages/integrations/shield-leptos-axum", features = [ +shield.workspace = true +shield-leptos.workspace = true +shield-leptos-axum = { workspace = true, features = [ "utoipa", ], optional = true } -shield-memory = { path = "../../packages/storage/shield-memory", optional = true } -shield-oidc = { path = "../../packages/providers/shield-oidc", features = [ - "native-tls", -], optional = true } +shield-memory = { workspace = true, optional = true } +shield-oidc = { workspace = true, features = ["native-tls"], optional = true } time = "0.3.37" tokio = { workspace = true, features = ["rt-multi-thread"], optional = true } tower-sessions = { workspace = true, optional = true } @@ -52,7 +50,7 @@ ssr = [ "leptos/ssr", "leptos_meta/ssr", "leptos_router/ssr", - "shield-memory/provider-oidc", + "shield-memory/method-oidc", ] [package.metadata.leptos] diff --git a/examples/leptos-axum/src/main.rs b/examples/leptos-axum/src/main.rs index 8a865e2..d7c7d0c 100644 --- a/examples/leptos-axum/src/main.rs +++ b/examples/leptos-axum/src/main.rs @@ -10,7 +10,7 @@ async fn main() { use shield_examples_leptos_axum::app::*; use shield_leptos_axum::{AuthRoutes, ShieldLayer, auth_required, provide_axum_integration}; use shield_memory::{MemoryStorage, User}; - use shield_oidc::{Keycloak, OidcProvider}; + use shield_oidc::{Keycloak, OidcMethod}; use time::Duration; use tokio::net::TcpListener; use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer}; @@ -40,7 +40,7 @@ async fn main() { let shield = Shield::new( storage.clone(), vec![Arc::new( - OidcProvider::new(storage).with_subproviders([Keycloak::builder( + OidcMethod::new(storage).with_providers([Keycloak::builder( "keycloak", "http://localhost:18080/realms/Shield", "client1", diff --git a/examples/sea-orm/Cargo.toml b/examples/sea-orm/Cargo.toml index 7eafe5b..b09bedb 100644 --- a/examples/sea-orm/Cargo.toml +++ b/examples/sea-orm/Cargo.toml @@ -23,8 +23,5 @@ sea-orm-migration = { workspace = true, features = [ "sqlx-postgres", "sqlx-sqlite", ] } -shield-sea-orm = { path = "../../packages/storage/shield-sea-orm", features = [ - "all-providers", - "utoipa", -] } +shield-sea-orm = { workspace = true, features = ["all-methods", "utoipa"] } tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } diff --git a/packages/core/shield/src/error.rs b/packages/core/shield/src/error.rs index 6d23f4a..ea25668 100644 --- a/packages/core/shield/src/error.rs +++ b/packages/core/shield/src/error.rs @@ -1,13 +1,17 @@ use thiserror::Error; +#[derive(Debug, Error)] +pub enum MethodError { + #[error("method `{0}` not found")] + MethodNotFound(String), +} + #[derive(Debug, Error)] pub enum ProviderError { + #[error("provider is missing")] + ProviderMissing, #[error("provider `{0}` not found")] ProviderNotFound(String), - #[error("subprovider is missing")] - SubproviderMissing, - #[error("subprovider `{0}` not found")] - SubproviderNotFound(String), } #[derive(Debug, Error)] @@ -45,6 +49,8 @@ pub enum SessionError { #[derive(Debug, Error)] pub enum ShieldError { + #[error(transparent)] + Method(#[from] MethodError), #[error(transparent)] Provider(#[from] ProviderError), #[error(transparent)] diff --git a/packages/core/shield/src/lib.rs b/packages/core/shield/src/lib.rs index 69e5928..a97d459 100644 --- a/packages/core/shield/src/lib.rs +++ b/packages/core/shield/src/lib.rs @@ -1,5 +1,6 @@ mod error; mod form; +mod method; mod options; mod provider; mod request; @@ -12,6 +13,7 @@ mod user; pub use error::*; pub use form::*; +pub use method::*; pub use options::*; pub use provider::*; pub use request::*; diff --git a/packages/core/shield/src/method.rs b/packages/core/shield/src/method.rs new file mode 100644 index 0000000..cf4468a --- /dev/null +++ b/packages/core/shield/src/method.rs @@ -0,0 +1,118 @@ +use async_trait::async_trait; + +use crate::{ + error::ShieldError, + options::ShieldOptions, + provider::Provider, + request::{SignInCallbackRequest, SignInRequest, SignOutRequest}, + response::Response, + session::Session, +}; + +#[async_trait] +pub trait Method: Send + Sync { + fn id(&self) -> String; + + async fn providers(&self) -> Result>, ShieldError>; + + async fn provider_by_id( + &self, + provider_id: &str, + ) -> Result>, ShieldError>; + + async fn sign_in( + &self, + request: SignInRequest, + session: Session, + options: &ShieldOptions, + ) -> Result; + + async fn sign_in_callback( + &self, + request: SignInCallbackRequest, + session: Session, + options: &ShieldOptions, + ) -> Result; + + async fn sign_out( + &self, + request: SignOutRequest, + session: Session, + options: &ShieldOptions, + ) -> Result; +} + +#[cfg(test)] +pub(crate) mod tests { + use async_trait::async_trait; + + use crate::{ + ShieldOptions, + error::ShieldError, + provider::Provider, + request::{SignInCallbackRequest, SignInRequest, SignOutRequest}, + response::Response, + session::Session, + }; + + use super::Method; + + pub const TEST_METHOD_ID: &str = "test"; + + #[derive(Default)] + pub struct TestMethod { + id: Option<&'static str>, + } + + impl TestMethod { + pub fn with_id(mut self, id: &'static str) -> Self { + self.id = Some(id); + self + } + } + + #[async_trait] + impl Method for TestMethod { + fn id(&self) -> String { + self.id.unwrap_or(TEST_METHOD_ID).to_owned() + } + + async fn providers(&self) -> Result>, ShieldError> { + Ok(vec![]) + } + + async fn provider_by_id( + &self, + _provider_id: &str, + ) -> Result>, ShieldError> { + Ok(None) + } + + async fn sign_in( + &self, + _request: SignInRequest, + _session: Session, + _options: &ShieldOptions, + ) -> Result { + todo!("redirect back?") + } + + async fn sign_in_callback( + &self, + _request: SignInCallbackRequest, + _session: Session, + _options: &ShieldOptions, + ) -> Result { + todo!("redirect back?") + } + + async fn sign_out( + &self, + _request: SignOutRequest, + _session: Session, + _options: &ShieldOptions, + ) -> Result { + todo!("redirect back?") + } + } +} diff --git a/packages/core/shield/src/provider.rs b/packages/core/shield/src/provider.rs index 07b1bf2..b33e84f 100644 --- a/packages/core/shield/src/provider.rs +++ b/packages/core/shield/src/provider.rs @@ -1,50 +1,9 @@ -use async_trait::async_trait; use serde::{Deserialize, Serialize}; -use crate::{ - error::ShieldError, - form::Form, - options::ShieldOptions, - request::{SignInCallbackRequest, SignInRequest, SignOutRequest}, - response::Response, - session::Session, -}; +use crate::form::Form; -#[async_trait] pub trait Provider: Send + Sync { - fn id(&self) -> String; - - async fn subproviders(&self) -> Result>, ShieldError>; - - async fn subprovider_by_id( - &self, - subprovider_id: &str, - ) -> Result>, ShieldError>; - - async fn sign_in( - &self, - request: SignInRequest, - session: Session, - options: &ShieldOptions, - ) -> Result; - - async fn sign_in_callback( - &self, - request: SignInCallbackRequest, - session: Session, - options: &ShieldOptions, - ) -> Result; - - async fn sign_out( - &self, - request: SignOutRequest, - session: Session, - options: &ShieldOptions, - ) -> Result; -} - -pub trait Subprovider: Send + Sync { - fn provider_id(&self) -> String; + fn method_id(&self) -> String; fn id(&self) -> Option; @@ -58,84 +17,10 @@ pub trait Subprovider: Send + Sync { #[derive(Clone, Debug, Deserialize, Serialize)] #[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] #[serde(rename_all = "camelCase")] -pub struct SubproviderVisualisation { +pub struct ProviderVisualisation { pub key: String, - pub provider_id: String, - pub subprovider_id: Option, + pub method_id: String, + pub provider_id: Option, pub name: String, pub icon_url: Option, } - -#[cfg(test)] -pub(crate) mod tests { - use async_trait::async_trait; - - use crate::{ - ShieldOptions, - error::ShieldError, - request::{SignInCallbackRequest, SignInRequest, SignOutRequest}, - response::Response, - session::Session, - }; - - use super::{Provider, Subprovider}; - - pub const TEST_PROVIDER_ID: &str = "test"; - - #[derive(Default)] - pub struct TestProvider { - id: Option<&'static str>, - } - - impl TestProvider { - pub fn with_id(mut self, id: &'static str) -> Self { - self.id = Some(id); - self - } - } - - #[async_trait] - impl Provider for TestProvider { - fn id(&self) -> String { - self.id.unwrap_or(TEST_PROVIDER_ID).to_owned() - } - - async fn subproviders(&self) -> Result>, ShieldError> { - Ok(vec![]) - } - - async fn subprovider_by_id( - &self, - _subprovider_id: &str, - ) -> Result>, ShieldError> { - Ok(None) - } - - async fn sign_in( - &self, - _request: SignInRequest, - _session: Session, - _options: &ShieldOptions, - ) -> Result { - todo!("redirect back?") - } - - async fn sign_in_callback( - &self, - _request: SignInCallbackRequest, - _session: Session, - _options: &ShieldOptions, - ) -> Result { - todo!("redirect back?") - } - - async fn sign_out( - &self, - _request: SignOutRequest, - _session: Session, - _options: &ShieldOptions, - ) -> Result { - todo!("redirect back?") - } - } -} diff --git a/packages/core/shield/src/request.rs b/packages/core/shield/src/request.rs index bff3c1c..d7a2d4e 100644 --- a/packages/core/shield/src/request.rs +++ b/packages/core/shield/src/request.rs @@ -4,8 +4,8 @@ use serde_json::Value; #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct SignInRequest { - pub provider_id: String, - pub subprovider_id: Option, + pub method_id: String, + pub provider_id: Option, pub redirect_url: Option, pub data: Option, pub form_data: Option, @@ -14,8 +14,8 @@ pub struct SignInRequest { #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct SignInCallbackRequest { - pub provider_id: String, - pub subprovider_id: Option, + pub method_id: String, + pub provider_id: Option, pub redirect_url: Option, pub query: Option, pub data: Option, @@ -24,6 +24,6 @@ pub struct SignInCallbackRequest { #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct SignOutRequest { - pub provider_id: String, - pub subprovider_id: Option, + pub method_id: String, + pub provider_id: Option, } diff --git a/packages/core/shield/src/session.rs b/packages/core/shield/src/session.rs index cdd4f5e..59c464a 100644 --- a/packages/core/shield/src/session.rs +++ b/packages/core/shield/src/session.rs @@ -48,28 +48,28 @@ impl Session { pub struct SessionData { pub redirect_url: Option, pub authentication: Option, - pub providers: HashMap, + pub methods: HashMap, } impl SessionData { - pub fn provider( + pub fn method( &self, - provider_id: &str, + method_id: &str, ) -> Result { - match self.providers.get(provider_id) { + match self.methods.get(method_id) { Some(value) => serde_json::from_str(value) .map_err(|err| SessionError::Serialization(err.to_string())), None => Ok(T::default()), } } - pub fn set_provider( + pub fn set_method( &mut self, - provider_id: &str, + method_id: &str, value: T, ) -> Result<(), SessionError> { - self.providers.insert( - provider_id.to_owned(), + self.methods.insert( + method_id.to_owned(), serde_json::to_string(&value) .map_err(|err| SessionError::Serialization(err.to_string()))?, ); @@ -80,7 +80,7 @@ impl SessionData { #[derive(Clone, Debug, Default, Deserialize, Serialize)] pub struct Authentication { - pub provider_id: String, - pub subprovider_id: Option, + pub method_id: String, + pub provider_id: Option, pub user_id: String, } diff --git a/packages/core/shield/src/shield.rs b/packages/core/shield/src/shield.rs index 9c1ac73..6b86f04 100644 --- a/packages/core/shield/src/shield.rs +++ b/packages/core/shield/src/shield.rs @@ -3,9 +3,11 @@ use std::{collections::HashMap, sync::Arc}; use futures::future::try_join_all; use crate::{ - error::{ProviderError, SessionError, ShieldError}, + MethodError, + error::{SessionError, ShieldError}, + method::Method, options::ShieldOptions, - provider::{Provider, Subprovider, SubproviderVisualisation}, + provider::{Provider, ProviderVisualisation}, request::{SignInCallbackRequest, SignInRequest, SignOutRequest}, response::Response, session::Session, @@ -16,18 +18,18 @@ use crate::{ #[derive(Clone)] pub struct Shield { storage: Arc>, - providers: Arc>>, + methods: Arc>>, options: ShieldOptions, } impl Shield { - pub fn new(storage: S, providers: Vec>, options: ShieldOptions) -> Self + pub fn new(storage: S, providers: Vec>, options: ShieldOptions) -> Self where S: Storage + 'static, { Self { storage: Arc::new(storage), - providers: Arc::new( + methods: Arc::new( providers .into_iter() .map(|provider| (provider.id(), provider)) @@ -45,56 +47,46 @@ impl Shield { &self.options } - pub fn provider_by_id(&self, provider_id: &str) -> Option<&dyn Provider> { - self.providers.get(provider_id).map(|v| &**v) + pub fn method_by_id(&self, provider_id: &str) -> Option<&dyn Method> { + self.methods.get(provider_id).map(|v| &**v) } - pub async fn subproviders(&self) -> Result>, ShieldError> { - try_join_all( - self.providers - .values() - .map(|provider| provider.subproviders()), - ) - .await - .map(|subproviders| subproviders.into_iter().flatten().collect::>()) + pub async fn providers(&self) -> Result>, ShieldError> { + try_join_all(self.methods.values().map(|provider| provider.providers())) + .await + .map(|providers| providers.into_iter().flatten().collect::>()) } - pub async fn subprovider_visualisations( - &self, - ) -> Result, ShieldError> { - self.subproviders().await.map(|subproviders| { - subproviders + pub async fn provider_visualisations(&self) -> Result, ShieldError> { + self.providers().await.map(|providers| { + providers .into_iter() - .map(|subprovider| { - let provider_id = subprovider.provider_id(); - let subprovider_id = subprovider.id(); - - SubproviderVisualisation { - key: match &subprovider_id { - Some(subprovider_id) => format!("{provider_id}-{subprovider_id}"), - None => provider_id.clone(), + .map(|provider| { + let method_id = provider.method_id(); + let provider_id = provider.id(); + + ProviderVisualisation { + key: match &provider_id { + Some(provider_id) => format!("{method_id}-{provider_id}"), + None => method_id.clone(), }, + method_id, provider_id, - subprovider_id, - name: subprovider.name(), - icon_url: subprovider.icon_url(), + name: provider.name(), + icon_url: provider.icon_url(), } }) .collect() }) } - pub async fn subprovider_by_id( + pub async fn provider_by_id( &self, - provider_id: &str, - subprovider_id: Option<&str>, - ) -> Result>, ShieldError> { - match self.provider_by_id(provider_id) { - Some(provider) => { - provider - .subprovider_by_id(subprovider_id.expect("TODO")) - .await - } + method_id: &str, + provider_id: Option<&str>, + ) -> Result>, ShieldError> { + match self.method_by_id(method_id) { + Some(provider) => provider.provider_by_id(provider_id.expect("TODO")).await, None => Ok(None), } } @@ -104,9 +96,9 @@ impl Shield { request: SignInRequest, session: Session, ) -> Result { - let provider = match self.providers.get(&request.provider_id) { + let provider = match self.methods.get(&request.method_id) { Some(provider) => provider, - None => return Err(ProviderError::ProviderNotFound(request.provider_id).into()), + None => return Err(MethodError::MethodNotFound(request.method_id).into()), }; // TODO: validate redirect URL @@ -134,9 +126,9 @@ impl Shield { request: SignInCallbackRequest, session: Session, ) -> Result { - let provider = match self.providers.get(&request.provider_id) { + let provider = match self.methods.get(&request.method_id) { Some(provider) => provider, - None => return Err(ProviderError::ProviderNotFound(request.provider_id).into()), + None => return Err(MethodError::MethodNotFound(request.method_id).into()), }; let redirect_url = { @@ -175,18 +167,18 @@ impl Shield { }; let response = if let Some(authenticated) = authenticated { - let provider = match self.providers.get(&authenticated.provider_id) { + let provider = match self.methods.get(&authenticated.method_id) { Some(provider) => provider, None => { - return Err(ProviderError::ProviderNotFound(authenticated.provider_id).into()); + return Err(MethodError::MethodNotFound(authenticated.method_id).into()); } }; provider .sign_out( SignOutRequest { + method_id: authenticated.method_id, provider_id: authenticated.provider_id, - subprovider_id: authenticated.subprovider_id, }, session.clone(), &self.options, @@ -214,9 +206,9 @@ impl Shield { match authentication { Some(authentication) => { if self - .subprovider_by_id( - &authentication.provider_id, - authentication.subprovider_id.as_deref(), + .provider_by_id( + &authentication.method_id, + authentication.provider_id.as_deref(), ) .await? .is_none() @@ -244,7 +236,7 @@ mod tests { use crate::{ ShieldOptions, - provider::tests::{TEST_PROVIDER_ID, TestProvider}, + method::tests::{TEST_METHOD_ID, TestMethod}, storage::tests::{TEST_STORAGE_ID, TestStorage}, }; @@ -262,8 +254,8 @@ mod tests { let shield = Shield::new( TestStorage::default(), vec![ - Arc::new(TestProvider::default().with_id("test1")), - Arc::new(TestProvider::default().with_id("test2")), + Arc::new(TestMethod::default().with_id("test1")), + Arc::new(TestMethod::default().with_id("test2")), ], ShieldOptions::default(), ); @@ -271,16 +263,16 @@ mod tests { assert_eq!( None, shield - .provider_by_id(TEST_PROVIDER_ID) + .method_by_id(TEST_METHOD_ID) .map(|provider| provider.id()) ); assert_eq!( Some("test1".to_owned()), - shield.provider_by_id("test1").map(|provider| provider.id()) + shield.method_by_id("test1").map(|provider| provider.id()) ); assert_eq!( Some("test2".to_owned()), - shield.provider_by_id("test2").map(|provider| provider.id()) + shield.method_by_id("test2").map(|provider| provider.id()) ); } } diff --git a/packages/core/shield/src/shield_dyn.rs b/packages/core/shield/src/shield_dyn.rs index 49c4a0f..2a15d86 100644 --- a/packages/core/shield/src/shield_dyn.rs +++ b/packages/core/shield/src/shield_dyn.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use crate::{ error::ShieldError, - provider::{Subprovider, SubproviderVisualisation}, + provider::{Provider, ProviderVisualisation}, request::{SignInCallbackRequest, SignInRequest}, response::Response, session::Session, @@ -14,11 +14,9 @@ use crate::{ #[async_trait] pub trait DynShield: Send + Sync { - async fn subproviders(&self) -> Result>, ShieldError>; + async fn providers(&self) -> Result>, ShieldError>; - async fn subprovider_visualisations( - &self, - ) -> Result, ShieldError>; + async fn provider_visualisations(&self) -> Result, ShieldError>; async fn sign_in( &self, @@ -37,14 +35,12 @@ pub trait DynShield: Send + Sync { #[async_trait] impl DynShield for Shield { - async fn subproviders(&self) -> Result>, ShieldError> { - self.subproviders().await + async fn providers(&self) -> Result>, ShieldError> { + self.providers().await } - async fn subprovider_visualisations( - &self, - ) -> Result, ShieldError> { - self.subprovider_visualisations().await + async fn provider_visualisations(&self) -> Result, ShieldError> { + self.provider_visualisations().await } async fn sign_in( @@ -75,14 +71,12 @@ impl ShieldDyn { Self(Arc::new(shield)) } - pub async fn subproviders(&self) -> Result>, ShieldError> { - self.0.subproviders().await + pub async fn providers(&self) -> Result>, ShieldError> { + self.0.providers().await } - pub async fn subprovider_visualisations( - &self, - ) -> Result, ShieldError> { - self.0.subprovider_visualisations().await + pub async fn provider_visualisations(&self) -> Result, ShieldError> { + self.0.provider_visualisations().await } pub async fn sign_in( diff --git a/packages/integrations/shield-actix/Cargo.toml b/packages/integrations/shield-actix/Cargo.toml index 1adc1e0..58b65d6 100644 --- a/packages/integrations/shield-actix/Cargo.toml +++ b/packages/integrations/shield-actix/Cargo.toml @@ -12,4 +12,4 @@ version.workspace = true actix-session.workspace = true actix-utils.workspace = true actix-web.workspace = true -shield = { path = "../../core/shield", version = "0.0.4" } +shield.workspace = true diff --git a/packages/integrations/shield-axum/Cargo.toml b/packages/integrations/shield-axum/Cargo.toml index f5f320c..5aeed83 100644 --- a/packages/integrations/shield-axum/Cargo.toml +++ b/packages/integrations/shield-axum/Cargo.toml @@ -12,8 +12,8 @@ version.workspace = true axum.workspace = true serde.workspace = true serde_json.workspace = true -shield = { path = "../../core/shield", version = "0.0.4" } -shield-tower = { path = "../shield-tower", version = "0.0.4" } +shield.workspace = true +shield-tower.workspace = true utoipa = { workspace = true, features = ["axum_extras"], optional = true } [features] diff --git a/packages/integrations/shield-axum/src/error.rs b/packages/integrations/shield-axum/src/error.rs index 5074fb4..349876c 100644 --- a/packages/integrations/shield-axum/src/error.rs +++ b/packages/integrations/shield-axum/src/error.rs @@ -4,7 +4,7 @@ use axum::{ response::{IntoResponse, Response}, }; use serde::Serialize; -use shield::{ShieldError, StorageError}; +use shield::{MethodError, ProviderError, ShieldError, StorageError}; #[derive(Serialize)] #[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] @@ -44,10 +44,12 @@ impl RouteError { impl IntoResponse for RouteError { fn into_response(self) -> Response { let status_code = match &self.0 { + ShieldError::Method(method_error) => match method_error { + MethodError::MethodNotFound(_) => StatusCode::NOT_FOUND, + }, ShieldError::Provider(provider_error) => match provider_error { - shield::ProviderError::ProviderNotFound(_) => StatusCode::NOT_FOUND, - shield::ProviderError::SubproviderMissing => StatusCode::BAD_REQUEST, - shield::ProviderError::SubproviderNotFound(_) => StatusCode::NOT_FOUND, + ProviderError::ProviderMissing => StatusCode::BAD_REQUEST, + ProviderError::ProviderNotFound(_) => StatusCode::NOT_FOUND, }, ShieldError::Configuration(_) => StatusCode::INTERNAL_SERVER_ERROR, ShieldError::Session(_) => StatusCode::INTERNAL_SERVER_ERROR, diff --git a/packages/integrations/shield-axum/src/path.rs b/packages/integrations/shield-axum/src/path.rs index 90d34f9..494a830 100644 --- a/packages/integrations/shield-axum/src/path.rs +++ b/packages/integrations/shield-axum/src/path.rs @@ -4,8 +4,8 @@ use serde::Deserialize; #[cfg_attr(feature = "utoipa", derive(utoipa::IntoParams))] #[serde(rename_all = "camelCase")] pub struct AuthPathParams { - /// ID of authentication provider. - pub provider_id: String, - /// ID of authentication subprovider (optional). - pub subprovider_id: Option, + /// ID of authentication method. + pub method_id: String, + /// ID of authentication provider (optional). + pub provider_id: Option, } diff --git a/packages/integrations/shield-axum/src/router.rs b/packages/integrations/shield-axum/src/router.rs index 27813c0..57a4d3a 100644 --- a/packages/integrations/shield-axum/src/router.rs +++ b/packages/integrations/shield-axum/src/router.rs @@ -9,19 +9,19 @@ use crate::routes::*; #[cfg_attr(feature = "utoipa", derive(utoipa::OpenApi))] #[cfg_attr( feature = "utoipa", - openapi(paths(subproviders, sign_in, sign_in_callback, sign_out, user)) + openapi(paths(providers, sign_in, sign_in_callback, sign_out, user)) )] pub struct AuthRoutes; impl AuthRoutes { pub fn router() -> Router { Router::new() - .route("/subproviders", get(subproviders::)) - .route("/sign-in/{providerId}", post(sign_in::)) - .route("/sign-in/{providerId}/{subproviderId}", post(sign_in::)) - .route("/sign-in/callback/{providerId}", get(sign_in_callback::)) + .route("/providers", get(providers::)) + .route("/sign-in/{methodId}", post(sign_in::)) + .route("/sign-in/{methodId}/{providerId}", post(sign_in::)) + .route("/sign-in/callback/{methodId}", get(sign_in_callback::)) .route( - "/sign-in/callback/{providerId}/{subproviderId}", + "/sign-in/callback/{methodId}/{providerId}", get(sign_in_callback::), ) .route("/sign-out", post(sign_out::)) diff --git a/packages/integrations/shield-axum/src/routes.rs b/packages/integrations/shield-axum/src/routes.rs index 8391d6e..accae39 100644 --- a/packages/integrations/shield-axum/src/routes.rs +++ b/packages/integrations/shield-axum/src/routes.rs @@ -1,11 +1,11 @@ +mod providers; mod sign_in; mod sign_in_callback; mod sign_out; -mod subproviders; mod user; +pub use providers::*; pub use sign_in::*; pub use sign_in_callback::*; pub use sign_out::*; -pub use subproviders::*; pub use user::*; diff --git a/packages/integrations/shield-axum/src/routes/subproviders.rs b/packages/integrations/shield-axum/src/routes/providers.rs similarity index 50% rename from packages/integrations/shield-axum/src/routes/subproviders.rs rename to packages/integrations/shield-axum/src/routes/providers.rs index 8b497ee..a5d5583 100644 --- a/packages/integrations/shield-axum/src/routes/subproviders.rs +++ b/packages/integrations/shield-axum/src/routes/providers.rs @@ -1,5 +1,5 @@ use axum::Json; -use shield::{SubproviderVisualisation, User}; +use shield::{ProviderVisualisation, User}; use crate::{ error::{ErrorBody, RouteError}, @@ -10,17 +10,17 @@ use crate::{ feature = "utoipa", utoipa::path( get, - path = "/subproviders", - operation_id = "getSubproviders", - description = "Get a list of authentication subproviders.", + path = "/providers", + operation_id = "getProviders", + description = "Get a list of authentication providers.", responses( - (status = 200, description = "List of authentication subproviders.", body = Vec), + (status = 200, description = "List of authentication providers.", body = Vec), (status = 500, description = "Internal server error.", body = ErrorBody), ) ) )] -pub async fn subproviders( +pub async fn providers( ExtractShield(shield): ExtractShield, -) -> Result>, RouteError> { - Ok(Json(shield.subprovider_visualisations().await?)) +) -> Result>, RouteError> { + Ok(Json(shield.provider_visualisations().await?)) } diff --git a/packages/integrations/shield-axum/src/routes/sign_in.rs b/packages/integrations/shield-axum/src/routes/sign_in.rs index 6f64412..92e5189 100644 --- a/packages/integrations/shield-axum/src/routes/sign_in.rs +++ b/packages/integrations/shield-axum/src/routes/sign_in.rs @@ -20,7 +20,7 @@ pub struct SignInData { feature = "utoipa", utoipa::path( post, - path = "/sign-in/{providerId}/{subproviderId}", + path = "/sign-in/{methodId}/{providerId}", operation_id = "signIn", description = "Sign in to an account with the specified authentication provider.", params( @@ -38,8 +38,8 @@ pub struct SignInData { )] pub async fn sign_in( Path(AuthPathParams { + method_id, provider_id, - subprovider_id, }): Path, ExtractShield(shield): ExtractShield, ExtractSession(session): ExtractSession, @@ -48,8 +48,8 @@ pub async fn sign_in( let response = shield .sign_in( SignInRequest { + method_id, provider_id, - subprovider_id, redirect_url: data.redirect_url, data: None, form_data: None, diff --git a/packages/integrations/shield-axum/src/routes/sign_in_callback.rs b/packages/integrations/shield-axum/src/routes/sign_in_callback.rs index bfca83f..54f81bf 100644 --- a/packages/integrations/shield-axum/src/routes/sign_in_callback.rs +++ b/packages/integrations/shield-axum/src/routes/sign_in_callback.rs @@ -13,7 +13,7 @@ use crate::{ feature = "utoipa", utoipa::path( post, - path = "/sign-in/callback/{providerId}/{subproviderId}", + path = "/sign-in/callback/{methodId}/{providerId}", operation_id = "signInCallback", description = "Callback after signing in with authentication provider.", params( @@ -29,8 +29,8 @@ use crate::{ )] pub async fn sign_in_callback( Path(AuthPathParams { + method_id, provider_id, - subprovider_id, }): Path, Query(query): Query, ExtractShield(shield): ExtractShield, @@ -39,8 +39,8 @@ pub async fn sign_in_callback( let response = shield .sign_in_callback( SignInCallbackRequest { + method_id, provider_id, - subprovider_id, redirect_url: None, query: Some(query), data: None, diff --git a/packages/integrations/shield-leptos-actix/Cargo.toml b/packages/integrations/shield-leptos-actix/Cargo.toml index c64f521..804e5d8 100644 --- a/packages/integrations/shield-leptos-actix/Cargo.toml +++ b/packages/integrations/shield-leptos-actix/Cargo.toml @@ -12,6 +12,6 @@ version.workspace = true async-trait.workspace = true leptos.workspace = true leptos_actix.workspace = true -shield = { path = "../../core/shield", version = "0.0.4" } -shield-actix = { path = "../../integrations/shield-actix", version = "0.0.4" } -shield-leptos = { path = "../../integrations/shield-leptos", version = "0.0.4" } +shield.workspace = true +shield-actix.workspace = true +shield-leptos.workspace = true diff --git a/packages/integrations/shield-leptos-axum/Cargo.toml b/packages/integrations/shield-leptos-axum/Cargo.toml index 8f57ea4..944df00 100644 --- a/packages/integrations/shield-leptos-axum/Cargo.toml +++ b/packages/integrations/shield-leptos-axum/Cargo.toml @@ -12,9 +12,9 @@ version.workspace = true async-trait.workspace = true leptos.workspace = true leptos_axum.workspace = true -shield = { path = "../../core/shield", version = "0.0.4" } -shield-axum = { path = "../../integrations/shield-axum", version = "0.0.4" } -shield-leptos = { path = "../../integrations/shield-leptos", version = "0.0.4" } +shield.workspace = true +shield-axum.workspace = true +shield-leptos.workspace = true [features] default = [] diff --git a/packages/integrations/shield-leptos/Cargo.toml b/packages/integrations/shield-leptos/Cargo.toml index 4304737..b485b3b 100644 --- a/packages/integrations/shield-leptos/Cargo.toml +++ b/packages/integrations/shield-leptos/Cargo.toml @@ -12,4 +12,4 @@ version.workspace = true async-trait.workspace = true leptos.workspace = true serde.workspace = true -shield = { path = "../../core/shield", version = "0.0.4" } +shield.workspace = true diff --git a/packages/integrations/shield-leptos/src/routes/sign_in.rs b/packages/integrations/shield-leptos/src/routes/sign_in.rs index 2149ec7..9b1c3e8 100644 --- a/packages/integrations/shield-leptos/src/routes/sign_in.rs +++ b/packages/integrations/shield-leptos/src/routes/sign_in.rs @@ -1,23 +1,20 @@ use leptos::{either::Either, prelude::*}; -use shield::SubproviderVisualisation; +use shield::ProviderVisualisation; #[server] -pub async fn subproviders() -> Result, ServerFnError> { +pub async fn providers() -> Result, ServerFnError> { use crate::context::extract_shield; let shield = extract_shield().await; shield - .subprovider_visualisations() + .provider_visualisations() .await .map_err(|err| err.into()) } #[server] -pub async fn sign_in( - provider_id: String, - subprovider_id: Option, -) -> Result<(), ServerFnError> { +pub async fn sign_in(method_id: String, provider_id: Option) -> Result<(), ServerFnError> { use shield::{Response, ShieldError, SignInRequest}; use crate::context::expect_server_integration; @@ -29,8 +26,8 @@ pub async fn sign_in( let response = shield .sign_in( SignInRequest { + method_id, provider_id, - subprovider_id, redirect_url: None, data: None, form_data: None, @@ -51,25 +48,25 @@ pub async fn sign_in( #[component] pub fn SignIn() -> impl IntoView { - let subproviders = OnceResource::new(subproviders()); + let providers = OnceResource::new(providers()); let sign_in = ServerAction::::new(); view! {

"Sign in"

- {move || Suspend::new(async move { match subproviders.await { - Ok(subproviders) => Either::Left(view! { + {move || Suspend::new(async move { match providers.await { + Ok(providers) => Either::Left(view! { - - + + - + }), diff --git a/packages/integrations/shield-tower/Cargo.toml b/packages/integrations/shield-tower/Cargo.toml index ac2e339..29d30a7 100644 --- a/packages/integrations/shield-tower/Cargo.toml +++ b/packages/integrations/shield-tower/Cargo.toml @@ -11,7 +11,7 @@ version.workspace = true [dependencies] async-trait.workspace = true http.workspace = true -shield = { path = "../../core/shield", version = "0.0.4" } +shield.workspace = true tower-layer.workspace = true tower-service.workspace = true tower-sessions.workspace = true diff --git a/packages/providers/shield-credentials/Cargo.toml b/packages/methods/shield-credentials/Cargo.toml similarity index 62% rename from packages/providers/shield-credentials/Cargo.toml rename to packages/methods/shield-credentials/Cargo.toml index 013d4cc..cf1e701 100644 --- a/packages/providers/shield-credentials/Cargo.toml +++ b/packages/methods/shield-credentials/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "shield-credentials" -description = "Credentials provider for Shield." +description = "Credentials method for Shield." authors.workspace = true edition.workspace = true @@ -9,4 +9,4 @@ repository.workspace = true version.workspace = true [dependencies] -shield = { path = "../../core/shield", version = "0.0.4" } +shield.workspace = true diff --git a/packages/providers/shield-credentials/README.md b/packages/methods/shield-credentials/README.md similarity index 91% rename from packages/providers/shield-credentials/README.md rename to packages/methods/shield-credentials/README.md index 9a72683..e56545b 100644 --- a/packages/providers/shield-credentials/README.md +++ b/packages/methods/shield-credentials/README.md @@ -1,6 +1,6 @@

Shield Credentials

-Credentials provider for Shield. +Credentials method for Shield. ## Documentation diff --git a/packages/providers/shield-credentials/src/lib.rs b/packages/methods/shield-credentials/src/lib.rs similarity index 100% rename from packages/providers/shield-credentials/src/lib.rs rename to packages/methods/shield-credentials/src/lib.rs diff --git a/packages/providers/shield-email/Cargo.toml b/packages/methods/shield-email/Cargo.toml similarity index 63% rename from packages/providers/shield-email/Cargo.toml rename to packages/methods/shield-email/Cargo.toml index 7abadf1..821c9cc 100644 --- a/packages/providers/shield-email/Cargo.toml +++ b/packages/methods/shield-email/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "shield-email" -description = "Email provider for Shield." +description = "Email method for Shield." authors.workspace = true edition.workspace = true @@ -9,4 +9,4 @@ repository.workspace = true version.workspace = true [dependencies] -shield = { path = "../../core/shield", version = "0.0.4" } +shield.workspace = true diff --git a/packages/providers/shield-email/README.md b/packages/methods/shield-email/README.md similarity index 92% rename from packages/providers/shield-email/README.md rename to packages/methods/shield-email/README.md index cc19dbb..8456730 100644 --- a/packages/providers/shield-email/README.md +++ b/packages/methods/shield-email/README.md @@ -1,6 +1,6 @@

Shield Email

-Email provider for Shield. +Email method for Shield. ## Documentation diff --git a/packages/providers/shield-email/src/lib.rs b/packages/methods/shield-email/src/lib.rs similarity index 100% rename from packages/providers/shield-email/src/lib.rs rename to packages/methods/shield-email/src/lib.rs diff --git a/packages/providers/shield-oauth/Cargo.toml b/packages/methods/shield-oauth/Cargo.toml similarity index 78% rename from packages/providers/shield-oauth/Cargo.toml rename to packages/methods/shield-oauth/Cargo.toml index d90d568..2aad0fb 100644 --- a/packages/providers/shield-oauth/Cargo.toml +++ b/packages/methods/shield-oauth/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "shield-oauth" -description = "OAuth provider for Shield." +description = "OAuth method for Shield." authors.workspace = true edition.workspace = true @@ -11,7 +11,7 @@ version.workspace = true [dependencies] async-trait.workspace = true oauth2 = { version = "5.0.0", default-features = false, features = ["reqwest"] } -shield = { path = "../../core/shield", version = "0.0.4" } +shield.workspace = true [features] default = [] diff --git a/packages/providers/shield-oauth/README.md b/packages/methods/shield-oauth/README.md similarity index 92% rename from packages/providers/shield-oauth/README.md rename to packages/methods/shield-oauth/README.md index ad80166..8e281f7 100644 --- a/packages/providers/shield-oauth/README.md +++ b/packages/methods/shield-oauth/README.md @@ -1,6 +1,6 @@

Shield OAuth

-OAuth provider for Shield. +OAuth method for Shield. ## Documentation diff --git a/packages/providers/shield-oauth/src/lib.rs b/packages/methods/shield-oauth/src/lib.rs similarity index 62% rename from packages/providers/shield-oauth/src/lib.rs rename to packages/methods/shield-oauth/src/lib.rs index 622d3c4..d4f163c 100644 --- a/packages/providers/shield-oauth/src/lib.rs +++ b/packages/methods/shield-oauth/src/lib.rs @@ -1,7 +1,7 @@ +mod method; mod provider; mod storage; -mod subprovider; +pub use method::*; pub use provider::*; pub use storage::*; -pub use subprovider::*; diff --git a/packages/methods/shield-oauth/src/method.rs b/packages/methods/shield-oauth/src/method.rs new file mode 100644 index 0000000..1c964fa --- /dev/null +++ b/packages/methods/shield-oauth/src/method.rs @@ -0,0 +1,114 @@ +use async_trait::async_trait; +use shield::{ + Method, Provider, ProviderError, Response, Session, ShieldError, ShieldOptions, + SignInCallbackRequest, SignInRequest, SignOutRequest, User, +}; + +use crate::{provider::OauthProvider, storage::OauthStorage}; + +pub const OAUTH_METHOD_ID: &str = "oauth"; + +pub struct OauthMethod { + providers: Vec, + storage: Box>, +} + +impl OauthMethod { + pub fn new + 'static>(storage: S) -> Self { + Self { + providers: vec![], + storage: Box::new(storage), + } + } + + pub fn with_providers>(mut self, providers: I) -> Self { + self.providers = providers.into_iter().collect(); + self + } + + async fn oauth_provider_by_id(&self, provider_id: &str) -> Result { + if let Some(provider) = self + .providers + .iter() + .find(|provider| provider.id == provider_id) + { + return Ok(provider.clone()); + } + + if let Some(provider) = self.storage.oauth_provider_by_id(provider_id).await? { + return Ok(provider); + } + + Err(ProviderError::ProviderNotFound(provider_id.to_owned()).into()) + } +} + +#[async_trait] +impl Method for OauthMethod { + fn id(&self) -> String { + OAUTH_METHOD_ID.to_owned() + } + + async fn providers(&self) -> Result>, ShieldError> { + let providers = self + .providers + .iter() + .cloned() + .chain(self.storage.oauth_providers().await?); + + Ok(providers + .map(|provider| Box::new(provider) as Box) + .collect()) + } + + async fn provider_by_id( + &self, + provider_id: &str, + ) -> Result>, ShieldError> { + self.oauth_provider_by_id(provider_id) + .await + .map(|provider| Some(Box::new(provider) as Box)) + } + + async fn sign_in( + &self, + request: SignInRequest, + _session: Session, + _options: &ShieldOptions, + ) -> Result { + let _provider = match request.provider_id { + Some(provider_id) => self.oauth_provider_by_id(&provider_id).await?, + None => return Err(ProviderError::ProviderMissing.into()), + }; + + todo!("oauth sign in") + } + + async fn sign_in_callback( + &self, + request: SignInCallbackRequest, + _session: Session, + _options: &ShieldOptions, + ) -> Result { + let _provider = match request.provider_id { + Some(provider_id) => self.oauth_provider_by_id(&provider_id).await?, + None => return Err(ProviderError::ProviderMissing.into()), + }; + + todo!("oauth sign in callback") + } + + async fn sign_out( + &self, + request: SignOutRequest, + _session: Session, + _options: &ShieldOptions, + ) -> Result { + let _provider = match request.provider_id { + Some(provider_id) => self.oauth_provider_by_id(&provider_id).await?, + None => return Err(ProviderError::ProviderMissing.into()), + }; + + todo!("oauth sign out") + } +} diff --git a/packages/providers/shield-oauth/src/subprovider.rs b/packages/methods/shield-oauth/src/provider.rs similarity index 86% rename from packages/providers/shield-oauth/src/subprovider.rs rename to packages/methods/shield-oauth/src/provider.rs index 610f69e..0c7488f 100644 --- a/packages/providers/shield-oauth/src/subprovider.rs +++ b/packages/methods/shield-oauth/src/provider.rs @@ -1,6 +1,6 @@ -use shield::{Form, Subprovider}; +use shield::{Form, Provider}; -use crate::provider::OAUTH_PROVIDER_ID; +use crate::method::OAUTH_METHOD_ID; #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum OauthProviderVisibility { @@ -18,7 +18,7 @@ pub enum OauthProviderPkceCodeChallenge { // TODO: Remove allow dead code. #[allow(dead_code)] #[derive(Clone, Debug)] -pub struct OauthSubprovider { +pub struct OauthProvider { pub(crate) id: String, pub(crate) name: String, pub(crate) slug: Option, @@ -39,9 +39,9 @@ pub struct OauthSubprovider { pub(crate) icon_url: Option, } -impl Subprovider for OauthSubprovider { - fn provider_id(&self) -> String { - OAUTH_PROVIDER_ID.to_owned() +impl Provider for OauthProvider { + fn method_id(&self) -> String { + OAUTH_METHOD_ID.to_owned() } fn id(&self) -> Option { diff --git a/packages/methods/shield-oauth/src/storage.rs b/packages/methods/shield-oauth/src/storage.rs new file mode 100644 index 0000000..8f5a175 --- /dev/null +++ b/packages/methods/shield-oauth/src/storage.rs @@ -0,0 +1,15 @@ +use async_trait::async_trait; + +use shield::{Storage, StorageError, User}; + +use crate::provider::OauthProvider; + +#[async_trait] +pub trait OauthStorage: Storage + Sync { + async fn oauth_providers(&self) -> Result, StorageError>; + + async fn oauth_provider_by_id( + &self, + provider_id: &str, + ) -> Result, StorageError>; +} diff --git a/packages/providers/shield-oidc/Cargo.toml b/packages/methods/shield-oidc/Cargo.toml similarity index 85% rename from packages/providers/shield-oidc/Cargo.toml rename to packages/methods/shield-oidc/Cargo.toml index c386e16..c8e1552 100644 --- a/packages/providers/shield-oidc/Cargo.toml +++ b/packages/methods/shield-oidc/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "shield-oidc" -description = "OpenID Connect provider for Shield." +description = "OpenID Connect method for Shield." authors.workspace = true edition.workspace = true @@ -19,7 +19,7 @@ openidconnect = { version = "4.0.0", default-features = false, features = [ "reqwest", ] } serde.workspace = true -shield = { path = "../../core/shield", version = "0.0.4" } +shield.workspace = true tracing.workspace = true [features] diff --git a/packages/providers/shield-oidc/README.md b/packages/methods/shield-oidc/README.md similarity index 90% rename from packages/providers/shield-oidc/README.md rename to packages/methods/shield-oidc/README.md index dfc81e9..70e0cff 100644 --- a/packages/providers/shield-oidc/README.md +++ b/packages/methods/shield-oidc/README.md @@ -1,6 +1,6 @@

Shield OpenID Connect

-OpenID Connect provider for Shield. +OpenID Connect method for Shield. ## Documentation diff --git a/packages/providers/shield-oidc/src/builders.rs b/packages/methods/shield-oidc/src/builders.rs similarity index 100% rename from packages/providers/shield-oidc/src/builders.rs rename to packages/methods/shield-oidc/src/builders.rs diff --git a/packages/providers/shield-oidc/src/builders/google.rs b/packages/methods/shield-oidc/src/builders/google.rs similarity index 52% rename from packages/providers/shield-oidc/src/builders/google.rs rename to packages/methods/shield-oidc/src/builders/google.rs index f7378d5..0b3509e 100644 --- a/packages/providers/shield-oidc/src/builders/google.rs +++ b/packages/methods/shield-oidc/src/builders/google.rs @@ -1,6 +1,6 @@ -use crate::subprovider::{ - OidcSubprovider, OidcSubproviderBuilder, - oidc_subprovider_builder::{SetClientId, SetDiscoveryUrl, SetIconUrl, SetId, SetName}, +use crate::provider::{ + OidcProvider, OidcProviderBuilder, + oidc_provider_builder::{SetClientId, SetDiscoveryUrl, SetIconUrl, SetId, SetName}, }; pub struct Google {} @@ -9,8 +9,8 @@ impl Google { pub fn builder( id: &str, client_id: &str, - ) -> OidcSubproviderBuilder>>>> { - OidcSubprovider::builder() + ) -> OidcProviderBuilder>>>> { + OidcProvider::builder() .id(id) .name("Google") .icon_url("https://authjs.dev/img/providers/google.svg") diff --git a/packages/providers/shield-oidc/src/builders/keycloak.rs b/packages/methods/shield-oidc/src/builders/keycloak.rs similarity index 54% rename from packages/providers/shield-oidc/src/builders/keycloak.rs rename to packages/methods/shield-oidc/src/builders/keycloak.rs index b5a1fac..55e6003 100644 --- a/packages/providers/shield-oidc/src/builders/keycloak.rs +++ b/packages/methods/shield-oidc/src/builders/keycloak.rs @@ -1,6 +1,6 @@ -use crate::subprovider::{ - OidcSubprovider, OidcSubproviderBuilder, - oidc_subprovider_builder::{SetClientId, SetDiscoveryUrl, SetIconUrl, SetId, SetName}, +use crate::provider::{ + OidcProvider, OidcProviderBuilder, + oidc_provider_builder::{SetClientId, SetDiscoveryUrl, SetIconUrl, SetId, SetName}, }; pub struct Keycloak {} @@ -10,8 +10,8 @@ impl Keycloak { id: &str, discovery_url: &str, client_id: &str, - ) -> OidcSubproviderBuilder>>>> { - OidcSubprovider::builder() + ) -> OidcProviderBuilder>>>> { + OidcProvider::builder() .id(id) .name("Keycloak") .icon_url("https://authjs.dev/img/providers/keycloak.svg") diff --git a/packages/providers/shield-oidc/src/claims.rs b/packages/methods/shield-oidc/src/claims.rs similarity index 100% rename from packages/providers/shield-oidc/src/claims.rs rename to packages/methods/shield-oidc/src/claims.rs diff --git a/packages/providers/shield-oidc/src/client.rs b/packages/methods/shield-oidc/src/client.rs similarity index 100% rename from packages/providers/shield-oidc/src/client.rs rename to packages/methods/shield-oidc/src/client.rs diff --git a/packages/providers/shield-oidc/src/connection.rs b/packages/methods/shield-oidc/src/connection.rs similarity index 94% rename from packages/providers/shield-oidc/src/connection.rs rename to packages/methods/shield-oidc/src/connection.rs index 1bfcf6c..271aa20 100644 --- a/packages/providers/shield-oidc/src/connection.rs +++ b/packages/methods/shield-oidc/src/connection.rs @@ -10,7 +10,7 @@ pub struct OidcConnection { pub id_token: Option, pub expired_at: Option>, pub scopes: Option>, - pub subprovider_id: String, + pub provider_id: String, pub user_id: String, } @@ -23,7 +23,7 @@ pub struct CreateOidcConnection { pub id_token: Option, pub expired_at: Option>, pub scopes: Option>, - pub subprovider_id: String, + pub provider_id: String, pub user_id: String, } diff --git a/packages/providers/shield-oidc/src/lib.rs b/packages/methods/shield-oidc/src/lib.rs similarity index 81% rename from packages/providers/shield-oidc/src/lib.rs rename to packages/methods/shield-oidc/src/lib.rs index 1391644..e597c44 100644 --- a/packages/providers/shield-oidc/src/lib.rs +++ b/packages/methods/shield-oidc/src/lib.rs @@ -2,13 +2,13 @@ mod builders; mod claims; mod client; mod connection; +mod method; mod provider; mod session; mod storage; -mod subprovider; pub use builders::*; pub use connection::*; +pub use method::*; pub use provider::*; pub use storage::*; -pub use subprovider::*; diff --git a/packages/providers/shield-oidc/src/provider.rs b/packages/methods/shield-oidc/src/method.rs similarity index 81% rename from packages/providers/shield-oidc/src/provider.rs rename to packages/methods/shield-oidc/src/method.rs index 9132dbd..8cd7c17 100644 --- a/packages/providers/shield-oidc/src/provider.rs +++ b/packages/methods/shield-oidc/src/method.rs @@ -7,62 +7,59 @@ use openidconnect::{ url::form_urlencoded::parse, }; use shield::{ - Authentication, ConfigurationError, CreateEmailAddress, CreateUser, Provider, ProviderError, - Response, Session, SessionError, ShieldError, ShieldOptions, SignInCallbackRequest, - SignInRequest, SignOutRequest, Subprovider, UpdateUser, User, + Authentication, ConfigurationError, CreateEmailAddress, CreateUser, Method, Provider, + ProviderError, Response, Session, SessionError, ShieldError, ShieldOptions, + SignInCallbackRequest, SignInRequest, SignOutRequest, UpdateUser, User, }; use tracing::debug; use crate::{ CreateOidcConnection, OidcConnection, OidcProviderPkceCodeChallenge, UpdateOidcConnection, - claims::Claims, client::async_http_client, session::OidcSession, storage::OidcStorage, - subprovider::OidcSubprovider, + claims::Claims, client::async_http_client, provider::OidcProvider, session::OidcSession, + storage::OidcStorage, }; -pub const OIDC_PROVIDER_ID: &str = "oidc"; +pub const OIDC_METHOD_ID: &str = "oidc"; -pub struct OidcProvider { - subproviders: Vec, +pub struct OidcMethod { + providers: Vec, storage: Box>, } -impl OidcProvider { +impl OidcMethod { pub fn new + 'static>(storage: S) -> Self { Self { - subproviders: vec![], + providers: vec![], storage: Box::new(storage), } } - pub fn with_subproviders>( - mut self, - subproviders: I, - ) -> Self { - self.subproviders = subproviders.into_iter().collect(); + pub fn with_providers>(mut self, providers: I) -> Self { + self.providers = providers.into_iter().collect(); self } - async fn oidc_subprovider_by_id_or_slug( + async fn oidc_provider_by_id_or_slug( &self, - subprovider_id: &str, - ) -> Result { - if let Some(subprovider) = self - .subproviders + provider_id: &str, + ) -> Result { + if let Some(provider) = self + .providers .iter() - .find(|subprovider| subprovider.id == subprovider_id) + .find(|provider| provider.id == provider_id) { - return Ok(subprovider.clone()); + return Ok(provider.clone()); } - if let Some(subprovider) = self + if let Some(provider) = self .storage - .oidc_subprovider_by_id_or_slug(subprovider_id) + .oidc_provider_by_id_or_slug(provider_id) .await? { - return Ok(subprovider); + return Ok(provider); } - Err(ProviderError::SubproviderNotFound(subprovider_id.to_owned()).into()) + Err(ProviderError::ProviderNotFound(provider_id.to_owned()).into()) } async fn create_user(&self, claims: &Claims) -> Result { @@ -118,7 +115,7 @@ impl OidcProvider { async fn create_oidc_connection( &self, - subprovider_id: String, + provider_id: String, user_id: String, identifier: String, token_response: CoreTokenResponse, @@ -135,7 +132,7 @@ impl OidcProvider { id_token, expired_at, scopes, - subprovider_id, + provider_id, user_id, }) .await @@ -166,30 +163,30 @@ impl OidcProvider { } #[async_trait] -impl Provider for OidcProvider { +impl Method for OidcMethod { fn id(&self) -> String { - OIDC_PROVIDER_ID.to_owned() + OIDC_METHOD_ID.to_owned() } - async fn subproviders(&self) -> Result>, ShieldError> { - let subproviders = self - .subproviders + async fn providers(&self) -> Result>, ShieldError> { + let providers = self + .providers .iter() .cloned() - .chain(self.storage.oidc_subproviders().await?); + .chain(self.storage.oidc_providers().await?); - Ok(subproviders - .map(|subprovider| Box::new(subprovider) as Box) + Ok(providers + .map(|provider| Box::new(provider) as Box) .collect()) } - async fn subprovider_by_id( + async fn provider_by_id( &self, - subprovider_id: &str, - ) -> Result>, ShieldError> { - self.oidc_subprovider_by_id_or_slug(subprovider_id) + provider_id: &str, + ) -> Result>, ShieldError> { + self.oidc_provider_by_id_or_slug(provider_id) .await - .map(|subprovider| Some(Box::new(subprovider) as Box)) + .map(|provider| Some(Box::new(provider) as Box)) } async fn sign_in( @@ -198,12 +195,12 @@ impl Provider for OidcProvider { session: Session, _options: &ShieldOptions, ) -> Result { - let subprovider = match request.subprovider_id { - Some(subprovider_id) => self.oidc_subprovider_by_id_or_slug(&subprovider_id).await?, - None => return Err(ProviderError::SubproviderMissing.into()), + let provider = match request.provider_id { + Some(provider_id) => self.oidc_provider_by_id_or_slug(&provider_id).await?, + None => return Err(ProviderError::ProviderMissing.into()), }; - let client = subprovider.oidc_client().await?; + let client = provider.oidc_client().await?; let mut authorization_request = client.authorize_url( CoreAuthenticationFlow::AuthorizationCode, @@ -211,7 +208,7 @@ impl Provider for OidcProvider { Nonce::new_random, ); - let pkce_code_challenge = match subprovider.pkce_code_challenge { + let pkce_code_challenge = match provider.pkce_code_challenge { OidcProviderPkceCodeChallenge::None => None, OidcProviderPkceCodeChallenge::Plain => Some(PkceCodeChallenge::new_random_plain()), OidcProviderPkceCodeChallenge::S256 => Some(PkceCodeChallenge::new_random_sha256()), @@ -222,12 +219,12 @@ impl Provider for OidcProvider { authorization_request.set_pkce_challenge(pkce_code_challenge.clone()); } - if let Some(scopes) = subprovider.scopes { + if let Some(scopes) = provider.scopes { authorization_request = authorization_request.add_scopes(scopes.into_iter().map(Scope::new)); } - if let Some(authorization_url_params) = subprovider.authorization_url_params { + if let Some(authorization_url_params) = provider.authorization_url_params { let params = parse(authorization_url_params.trim_start_matches('?').as_bytes()); for (name, value) in params { @@ -246,8 +243,8 @@ impl Provider for OidcProvider { session_data.authentication = None; - session_data.set_provider( - OIDC_PROVIDER_ID, + session_data.set_method( + OIDC_METHOD_ID, OidcSession { csrf: Some(csrf_token.secret().clone()), nonce: Some(nonce.secret().clone()), @@ -278,7 +275,7 @@ impl Provider for OidcProvider { .lock() .map_err(|err| SessionError::Lock(err.to_string()))?; - session_data.provider(OIDC_PROVIDER_ID)? + session_data.method(OIDC_METHOD_ID)? }; let state = request @@ -299,12 +296,12 @@ impl Provider for OidcProvider { .and_then(|code| code.as_str()) .ok_or_else(|| ShieldError::Validation("Missing authorization code.".to_owned()))?; - let subprovider = match request.subprovider_id { - Some(subprovider_id) => self.oidc_subprovider_by_id_or_slug(&subprovider_id).await?, - None => return Err(ProviderError::SubproviderMissing.into()), + let provider = match request.provider_id { + Some(provider_id) => self.oidc_provider_by_id_or_slug(&provider_id).await?, + None => return Err(ProviderError::ProviderMissing.into()), }; - let client = subprovider.oidc_client().await?; + let client = provider.oidc_client().await?; let mut token_request = client .exchange_code(AuthorizationCode::new(authorization_code.to_owned())) @@ -314,11 +311,11 @@ impl Provider for OidcProvider { if let Some(pkce_verifier) = pkce_verifier { token_request = token_request.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier)); - } else if subprovider.pkce_code_challenge != OidcProviderPkceCodeChallenge::None { + } else if provider.pkce_code_challenge != OidcProviderPkceCodeChallenge::None { return Err(ShieldError::Validation("Missing PKCE verifier.".to_owned())); } - if let Some(token_url_params) = subprovider.token_url_params { + if let Some(token_url_params) = provider.token_url_params { let params = parse(token_url_params.trim_start_matches('?').as_bytes()); for (name, value) in params { @@ -361,7 +358,7 @@ impl Provider for OidcProvider { let (connection, user) = match self .storage - .oidc_connection_by_identifier(&subprovider.id, claims.subject()) + .oidc_connection_by_identifier(&provider.id, claims.subject()) .await? { Some(connection) => { @@ -378,7 +375,7 @@ impl Provider for OidcProvider { let connection = self .create_oidc_connection( - subprovider.id.clone(), + provider.id.clone(), user.id(), claims.subject().to_string(), token_response, @@ -398,13 +395,13 @@ impl Provider for OidcProvider { .map_err(|err| SessionError::Lock(err.to_string()))?; session_data.authentication = Some(Authentication { - provider_id: self.id(), - subprovider_id: Some(subprovider.id), + method_id: self.id(), + provider_id: Some(provider.id), user_id: user.id(), }); - session_data.set_provider( - OIDC_PROVIDER_ID, + session_data.set_method( + OIDC_METHOD_ID, OidcSession { csrf: None, nonce: None, @@ -430,9 +427,9 @@ impl Provider for OidcProvider { // TODO: Revocation URL is always `EndpointNotSet` when using `from_provider_metadata`, // because `ProviderMetadata` does not support `introspection_endpoint` and `revocation_endpoint`. - // let subprovider = match request.subprovider_id { - // Some(subprovider_id) => self.oidc_subprovider_by_id_or_slug(&subprovider_id).await?, - // None => return Err(ProviderError::SubproviderMissing.into()), + // let provider = match request.provider_id { + // Some(provider_id) => self.oidc_provider_by_id_or_slug(&provider_id).await?, + // None => return Err(ProviderError::ProviderMissing.into()), // }; // let connection_id = { diff --git a/packages/providers/shield-oidc/src/subprovider.rs b/packages/methods/shield-oidc/src/provider.rs similarity index 96% rename from packages/providers/shield-oidc/src/subprovider.rs rename to packages/methods/shield-oidc/src/provider.rs index 100ec4e..1d5b3e1 100644 --- a/packages/providers/shield-oidc/src/subprovider.rs +++ b/packages/methods/shield-oidc/src/provider.rs @@ -11,9 +11,9 @@ use openidconnect::{ CoreTokenIntrospectionResponse, CoreTokenResponse, }, }; -use shield::{ConfigurationError, Subprovider}; +use shield::{ConfigurationError, Provider}; -use crate::{client::async_http_client, provider::OIDC_PROVIDER_ID}; +use crate::{client::async_http_client, method::OIDC_METHOD_ID}; type OidcClient = Client< EmptyAdditionalClaims, @@ -50,7 +50,7 @@ pub enum OidcProviderPkceCodeChallenge { #[derive(Builder, Clone, Debug)] #[builder(on(String, into), state_mod(vis = "pub(crate)"))] -pub struct OidcSubprovider { +pub struct OidcProvider { pub id: String, pub name: String, pub slug: Option, @@ -78,7 +78,7 @@ pub struct OidcSubprovider { pub pkce_code_challenge: OidcProviderPkceCodeChallenge, } -impl OidcSubprovider { +impl OidcProvider { pub async fn oidc_client(&self) -> Result { let async_http_client = async_http_client()?; @@ -173,9 +173,9 @@ impl OidcSubprovider { } } -impl Subprovider for OidcSubprovider { - fn provider_id(&self) -> String { - OIDC_PROVIDER_ID.to_owned() +impl Provider for OidcProvider { + fn method_id(&self) -> String { + OIDC_METHOD_ID.to_owned() } fn id(&self) -> Option { diff --git a/packages/providers/shield-oidc/src/session.rs b/packages/methods/shield-oidc/src/session.rs similarity index 100% rename from packages/providers/shield-oidc/src/session.rs rename to packages/methods/shield-oidc/src/session.rs diff --git a/packages/providers/shield-oidc/src/storage.rs b/packages/methods/shield-oidc/src/storage.rs similarity index 76% rename from packages/providers/shield-oidc/src/storage.rs rename to packages/methods/shield-oidc/src/storage.rs index 8f56fe0..526095e 100644 --- a/packages/providers/shield-oidc/src/storage.rs +++ b/packages/methods/shield-oidc/src/storage.rs @@ -4,17 +4,17 @@ use shield::{Storage, StorageError, User}; use crate::{ connection::{CreateOidcConnection, OidcConnection, UpdateOidcConnection}, - subprovider::OidcSubprovider, + provider::OidcProvider, }; #[async_trait] pub trait OidcStorage: Storage + Sync { - async fn oidc_subproviders(&self) -> Result, StorageError>; + async fn oidc_providers(&self) -> Result, StorageError>; - async fn oidc_subprovider_by_id_or_slug( + async fn oidc_provider_by_id_or_slug( &self, - subprovider_id: &str, - ) -> Result, StorageError>; + provider_id: &str, + ) -> Result, StorageError>; async fn oidc_connection_by_id( &self, @@ -23,7 +23,7 @@ pub trait OidcStorage: Storage + Sync { async fn oidc_connection_by_identifier( &self, - subprovider_id: &str, + provider_id: &str, identifier: &str, ) -> Result, StorageError>; diff --git a/packages/providers/shield-webauthn/Cargo.toml b/packages/methods/shield-webauthn/Cargo.toml similarity index 63% rename from packages/providers/shield-webauthn/Cargo.toml rename to packages/methods/shield-webauthn/Cargo.toml index fb36008..6c36fbb 100644 --- a/packages/providers/shield-webauthn/Cargo.toml +++ b/packages/methods/shield-webauthn/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "shield-webauthn" -description = "WebAuthn provider for Shield." +description = "WebAuthn method for Shield." authors.workspace = true edition.workspace = true @@ -9,4 +9,4 @@ repository.workspace = true version.workspace = true [dependencies] -shield = { path = "../../core/shield", version = "0.0.4" } +shield.workspace = true diff --git a/packages/providers/shield-webauthn/README.md b/packages/methods/shield-webauthn/README.md similarity index 92% rename from packages/providers/shield-webauthn/README.md rename to packages/methods/shield-webauthn/README.md index 0796e18..89bc315 100644 --- a/packages/providers/shield-webauthn/README.md +++ b/packages/methods/shield-webauthn/README.md @@ -1,6 +1,6 @@

Shield WebAuthn

-WebAuthn provider for Shield. +WebAuthn method for Shield. ## Documentation diff --git a/packages/providers/shield-webauthn/src/lib.rs b/packages/methods/shield-webauthn/src/lib.rs similarity index 100% rename from packages/providers/shield-webauthn/src/lib.rs rename to packages/methods/shield-webauthn/src/lib.rs diff --git a/packages/providers/shield-oauth/src/provider.rs b/packages/providers/shield-oauth/src/provider.rs deleted file mode 100644 index ba61ccc..0000000 --- a/packages/providers/shield-oauth/src/provider.rs +++ /dev/null @@ -1,120 +0,0 @@ -use async_trait::async_trait; -use shield::{ - Provider, ProviderError, Response, Session, ShieldError, ShieldOptions, SignInCallbackRequest, - SignInRequest, SignOutRequest, Subprovider, User, -}; - -use crate::{storage::OauthStorage, subprovider::OauthSubprovider}; - -pub const OAUTH_PROVIDER_ID: &str = "oauth"; - -pub struct OauthProvider { - subproviders: Vec, - storage: Box>, -} - -impl OauthProvider { - pub fn new + 'static>(storage: S) -> Self { - Self { - subproviders: vec![], - storage: Box::new(storage), - } - } - - pub fn with_subproviders>( - mut self, - subproviders: I, - ) -> Self { - self.subproviders = subproviders.into_iter().collect(); - self - } - - async fn oauth_subprovider_by_id( - &self, - subprovider_id: &str, - ) -> Result { - if let Some(subprovider) = self - .subproviders - .iter() - .find(|subprovider| subprovider.id == subprovider_id) - { - return Ok(subprovider.clone()); - } - - if let Some(subprovider) = self.storage.oauth_subprovider_by_id(subprovider_id).await? { - return Ok(subprovider); - } - - Err(ProviderError::SubproviderNotFound(subprovider_id.to_owned()).into()) - } -} - -#[async_trait] -impl Provider for OauthProvider { - fn id(&self) -> String { - OAUTH_PROVIDER_ID.to_owned() - } - - async fn subproviders(&self) -> Result>, ShieldError> { - let subproviders = self - .subproviders - .iter() - .cloned() - .chain(self.storage.oauth_subproviders().await?); - - Ok(subproviders - .map(|subprovider| Box::new(subprovider) as Box) - .collect()) - } - - async fn subprovider_by_id( - &self, - subprovider_id: &str, - ) -> Result>, ShieldError> { - self.oauth_subprovider_by_id(subprovider_id) - .await - .map(|subprovider| Some(Box::new(subprovider) as Box)) - } - - async fn sign_in( - &self, - request: SignInRequest, - _session: Session, - _options: &ShieldOptions, - ) -> Result { - let _subprovider = match request.subprovider_id { - Some(subprovider_id) => self.oauth_subprovider_by_id(&subprovider_id).await?, - None => return Err(ProviderError::SubproviderMissing.into()), - }; - - todo!("oauth sign in") - } - - async fn sign_in_callback( - &self, - request: SignInCallbackRequest, - _session: Session, - _options: &ShieldOptions, - ) -> Result { - let _subprovider = match request.subprovider_id { - Some(subprovider_id) => self.oauth_subprovider_by_id(&subprovider_id).await?, - None => return Err(ProviderError::SubproviderMissing.into()), - }; - - todo!("oauth sign in callback") - } - - async fn sign_out( - &self, - request: SignOutRequest, - _session: Session, - _options: &ShieldOptions, - ) -> Result { - let _subprovider = match request.subprovider_id { - Some(subprovider_id) => self.oauth_subprovider_by_id(&subprovider_id).await?, - None => return Err(ProviderError::SubproviderMissing.into()), - }; - - todo!("oauth sign out") - } -} diff --git a/packages/providers/shield-oauth/src/storage.rs b/packages/providers/shield-oauth/src/storage.rs deleted file mode 100644 index 186b318..0000000 --- a/packages/providers/shield-oauth/src/storage.rs +++ /dev/null @@ -1,15 +0,0 @@ -use async_trait::async_trait; - -use shield::{Storage, StorageError, User}; - -use crate::subprovider::OauthSubprovider; - -#[async_trait] -pub trait OauthStorage: Storage + Sync { - async fn oauth_subproviders(&self) -> Result, StorageError>; - - async fn oauth_subprovider_by_id( - &self, - subprovider_id: &str, - ) -> Result, StorageError>; -} diff --git a/packages/storage/shield-diesel/Cargo.toml b/packages/storage/shield-diesel/Cargo.toml index 4f0a9a1..44b6c55 100644 --- a/packages/storage/shield-diesel/Cargo.toml +++ b/packages/storage/shield-diesel/Cargo.toml @@ -9,4 +9,4 @@ repository.workspace = true version.workspace = true [dependencies] -shield = { path = "../../core/shield", version = "0.0.4" } +shield.workspace = true diff --git a/packages/storage/shield-memory/Cargo.toml b/packages/storage/shield-memory/Cargo.toml index a632f6f..b831911 100644 --- a/packages/storage/shield-memory/Cargo.toml +++ b/packages/storage/shield-memory/Cargo.toml @@ -11,25 +11,25 @@ version.workspace = true [dependencies] async-trait.workspace = true serde.workspace = true -shield = { path = "../../core/shield", version = "0.0.4" } -# shield-credentials = { path = "../../providers/shield-credentials", version = "0.0.2", optional = true } -# shield-email = { path = "../../providers/shield-email", version = "0.0.2", optional = true } -# shield-oauth = { path = "../../providers/shield-oauth", version = "0.0.2", optional = true } -shield-oidc = { path = "../../providers/shield-oidc", version = "0.0.4", optional = true } -# shield-webauthn = { path = "../../providers/shield-webauthn", version = "0.0.2", optional = true } +shield.workspace = true +# shield-credentials = { workspace = true, optional = true } +# shield-email = { workspace = true, optional = true } +# shield-oauth = { workspace = true, optional = true } +shield-oidc = { workspace = true, optional = true } +# shield-webauthn = { workspace = true, optional = true } uuid = { workspace = true, features = ["v4"] } [features] default = [] -all-providers = [ - # "provider-credentials", - # "provider-email", - # "provider-oauth", - "provider-oidc", - # "provider-webauthn", +all-methods = [ + # "method-credentials", + # "method-email", + # "method-oauth", + "method-oidc", + # "method-webauthn", ] -# provider-credentials = ["dep:shield-credentials"] -# provider-email = ["dep:shield-email"] -# provider-oauth = ["dep:shield-oauth"] -provider-oidc = ["dep:shield-oidc"] -# provider-webauthn = ["dep:shield-webauthn"] +# method-credentials = ["dep:shield-credentials"] +# method-email = ["dep:shield-email"] +# method-oauth = ["dep:shield-oauth"] +method-oidc = ["dep:shield-oidc"] +# method-webauthn = ["dep:shield-webauthn"] diff --git a/packages/storage/shield-memory/src/providers.rs b/packages/storage/shield-memory/src/providers.rs index f5ac369..6ba5800 100644 --- a/packages/storage/shield-memory/src/providers.rs +++ b/packages/storage/shield-memory/src/providers.rs @@ -1,2 +1,2 @@ -#[cfg(feature = "provider-oidc")] +#[cfg(feature = "method-oidc")] pub mod oidc; diff --git a/packages/storage/shield-memory/src/providers/oidc.rs b/packages/storage/shield-memory/src/providers/oidc.rs index 65f75f8..23de44a 100644 --- a/packages/storage/shield-memory/src/providers/oidc.rs +++ b/packages/storage/shield-memory/src/providers/oidc.rs @@ -3,7 +3,7 @@ use std::sync::{Arc, Mutex}; use async_trait::async_trait; use shield::StorageError; use shield_oidc::{ - CreateOidcConnection, OidcConnection, OidcStorage, OidcSubprovider, UpdateOidcConnection, + CreateOidcConnection, OidcConnection, OidcProvider, OidcStorage, UpdateOidcConnection, }; use uuid::Uuid; @@ -16,14 +16,14 @@ pub struct OidcMemoryStorage { #[async_trait] impl OidcStorage for MemoryStorage { - async fn oidc_subproviders(&self) -> Result, StorageError> { + async fn oidc_providers(&self) -> Result, StorageError> { Ok(vec![]) } - async fn oidc_subprovider_by_id_or_slug( + async fn oidc_provider_by_id_or_slug( &self, - _subprovider_id: &str, - ) -> Result, StorageError> { + _provider_id: &str, + ) -> Result, StorageError> { Ok(None) } @@ -43,7 +43,7 @@ impl OidcStorage for MemoryStorage { async fn oidc_connection_by_identifier( &self, - subprovider_id: &str, + provider_id: &str, identifier: &str, ) -> Result, StorageError> { Ok(self @@ -53,7 +53,7 @@ impl OidcStorage for MemoryStorage { .map_err(|err| StorageError::Engine(err.to_string()))? .iter() .find(|connection| { - connection.subprovider_id == subprovider_id && connection.identifier == identifier + connection.provider_id == provider_id && connection.identifier == identifier }) .cloned()) } @@ -71,7 +71,7 @@ impl OidcStorage for MemoryStorage { id_token: connection.id_token, expired_at: connection.expired_at, scopes: connection.scopes, - subprovider_id: connection.subprovider_id, + provider_id: connection.provider_id, user_id: connection.user_id, }; diff --git a/packages/storage/shield-memory/src/storage.rs b/packages/storage/shield-memory/src/storage.rs index 1664db2..cf2333e 100644 --- a/packages/storage/shield-memory/src/storage.rs +++ b/packages/storage/shield-memory/src/storage.rs @@ -13,7 +13,7 @@ pub const MEMORY_STORAGE_ID: &str = "memory"; #[derive(Clone, Debug, Default)] pub struct MemoryStorage { pub(crate) users: Arc>>, - #[cfg(feature = "provider-oidc")] + #[cfg(feature = "method-oidc")] pub(crate) oidc: crate::providers::oidc::OidcMemoryStorage, } diff --git a/packages/storage/shield-sea-orm/Cargo.toml b/packages/storage/shield-sea-orm/Cargo.toml index 1772a7d..e5fd178 100644 --- a/packages/storage/shield-sea-orm/Cargo.toml +++ b/packages/storage/shield-sea-orm/Cargo.toml @@ -15,30 +15,30 @@ sea-orm.workspace = true sea-orm-migration.workspace = true serde = { workspace = true, features = ["derive"] } serde_json.workspace = true -shield = { path = "../../core/shield", version = "0.0.4" } -# shield-credentials = { path = "../../providers/shield-credentials", version = "0.0.2", optional = true } -# shield-email = { path = "../../providers/shield-email", version = "0.0.2", optional = true } -# shield-oauth = { path = "../../providers/shield-oauth", version = "0.0.2", optional = true } -shield-oidc = { path = "../../providers/shield-oidc", version = "0.0.4", optional = true } -# shield-webauthn = { path = "../../providers/shield-webauthn", version = "0.0.2", optional = true } +shield.workspace = true +# shield-credentials = { workspace = true, optional = true } +# shield-email = { workspace = true, optional = true } +# shield-oauth = { workspace = true, optional = true } +shield-oidc = { workspace = true, optional = true } +# shield-webauthn = { workspace = true, optional = true } utoipa = { workspace = true, optional = true } [features] default = [] entity = [] -all-providers = [ - # "provider-credentials", - "provider-email", - "provider-oauth", - "provider-oidc", - # "provider-webauthn", +all-methods = [ + # "method-credentials", + "method-email", + "method-oauth", + "method-oidc", + # "method-webauthn", ] -# provider-credentials = ["dep:shield-credentials"] -# provider-email = ["dep:shield-email"] -# provider-oauth = ["dep:shield-oauth"] -# provider-credentials = [] -provider-email = [] -provider-oauth = [] -provider-oidc = ["dep:shield-oidc"] -# provider-webauthn = ["dep:shield-webauthn"] +# method-credentials = ["dep:shield-credentials"] +# method-email = ["dep:shield-email"] +# method-oauth = ["dep:shield-oauth"] +# method-credentials = [] +method-email = [] +method-oauth = [] +method-oidc = ["dep:shield-oidc"] +# method-webauthn = ["dep:shield-webauthn"] utoipa = ["dep:utoipa", "shield/utoipa"] diff --git a/packages/storage/shield-sea-orm/src/entities.rs b/packages/storage/shield-sea-orm/src/entities.rs index e85fc69..0111ef8 100644 --- a/packages/storage/shield-sea-orm/src/entities.rs +++ b/packages/storage/shield-sea-orm/src/entities.rs @@ -6,15 +6,15 @@ pub mod user; #[cfg(feature = "entity")] pub mod entity; -#[cfg(feature = "provider-email")] +#[cfg(feature = "method-email")] pub mod email_auth_token; -#[cfg(feature = "provider-oauth")] +#[cfg(feature = "method-oauth")] pub mod oauth_provider; -#[cfg(feature = "provider-oauth")] +#[cfg(feature = "method-oauth")] pub mod oauth_provider_connection; -#[cfg(feature = "provider-oidc")] +#[cfg(feature = "method-oidc")] pub mod oidc_provider; -#[cfg(feature = "provider-oidc")] +#[cfg(feature = "method-oidc")] pub mod oidc_provider_connection; diff --git a/packages/storage/shield-sea-orm/src/entities/prelude.rs b/packages/storage/shield-sea-orm/src/entities/prelude.rs index a18c578..cca2148 100644 --- a/packages/storage/shield-sea-orm/src/entities/prelude.rs +++ b/packages/storage/shield-sea-orm/src/entities/prelude.rs @@ -4,20 +4,20 @@ pub use super::user::Entity as User; #[cfg(feature = "entity")] pub use super::entity::Entity; -#[cfg(feature = "provider-email")] +#[cfg(feature = "method-email")] pub use super::email_auth_token::Entity as EmailAuthToken; -#[cfg(feature = "provider-oauth")] +#[cfg(feature = "method-oauth")] pub use super::oauth_provider::{ Entity as OauthProvider, OauthProviderPkceCodeChallenge, OauthProviderType, OauthProviderVisibility, }; -#[cfg(feature = "provider-oauth")] +#[cfg(feature = "method-oauth")] pub use super::oauth_provider_connection::Entity as OauthProviderConnection; -#[cfg(feature = "provider-oidc")] +#[cfg(feature = "method-oidc")] pub use super::oidc_provider::{ Entity as OidcProvider, OidcProviderPkceCodeChallenge, OidcProviderType, OidcProviderVisibility, }; -#[cfg(feature = "provider-oidc")] +#[cfg(feature = "method-oidc")] pub use super::oidc_provider_connection::Entity as OidcProviderConnection; diff --git a/packages/storage/shield-sea-orm/src/entities/user.rs b/packages/storage/shield-sea-orm/src/entities/user.rs index 56b9426..9ed66fa 100644 --- a/packages/storage/shield-sea-orm/src/entities/user.rs +++ b/packages/storage/shield-sea-orm/src/entities/user.rs @@ -32,10 +32,10 @@ pub enum Relation { #[cfg(not(feature = "entity"))] #[sea_orm(has_many = "super::email_address::Entity")] EmailAddress, - #[cfg(feature = "provider-oauth")] + #[cfg(feature = "method-oauth")] #[sea_orm(has_many = "super::oauth_provider_connection::Entity")] OauthProviderConnection, - #[cfg(feature = "provider-oidc")] + #[cfg(feature = "method-oidc")] #[sea_orm(has_many = "super::oidc_provider_connection::Entity")] OidcProviderConnection, } @@ -54,14 +54,14 @@ impl Related for Entity { } } -#[cfg(feature = "provider-oauth")] +#[cfg(feature = "method-oauth")] impl Related for Entity { fn to() -> RelationDef { Relation::OauthProviderConnection.def() } } -#[cfg(feature = "provider-oidc")] +#[cfg(feature = "method-oidc")] impl Related for Entity { fn to() -> RelationDef { Relation::OidcProviderConnection.def() diff --git a/packages/storage/shield-sea-orm/src/migrations/providers.rs b/packages/storage/shield-sea-orm/src/migrations/providers.rs index 84d8d11..dc2582e 100644 --- a/packages/storage/shield-sea-orm/src/migrations/providers.rs +++ b/packages/storage/shield-sea-orm/src/migrations/providers.rs @@ -1,8 +1,8 @@ -#[cfg(feature = "provider-email")] +#[cfg(feature = "method-email")] pub mod email; -#[cfg(feature = "provider-oauth")] +#[cfg(feature = "method-oauth")] pub mod oauth; -#[cfg(feature = "provider-oidc")] +#[cfg(feature = "method-oidc")] pub mod oidc; use async_trait::async_trait; @@ -16,17 +16,17 @@ impl MigratorTrait for ProvidersMigrator { #[allow(unused_mut)] let mut migrations = vec![]; - #[cfg(feature = "provider-email")] + #[cfg(feature = "method-email")] { use self::email::ProviderEmailMigrator; migrations.extend(ProviderEmailMigrator::migrations()); } - #[cfg(feature = "provider-oauth")] + #[cfg(feature = "method-oauth")] { use self::oauth::ProviderOauthMigrator; migrations.extend(ProviderOauthMigrator::migrations()); } - #[cfg(feature = "provider-oidc")] + #[cfg(feature = "method-oidc")] { use self::oidc::ProviderOidcMigrator; migrations.extend(ProviderOidcMigrator::migrations()); diff --git a/packages/storage/shield-sea-orm/src/providers.rs b/packages/storage/shield-sea-orm/src/providers.rs index f5ac369..6ba5800 100644 --- a/packages/storage/shield-sea-orm/src/providers.rs +++ b/packages/storage/shield-sea-orm/src/providers.rs @@ -1,2 +1,2 @@ -#[cfg(feature = "provider-oidc")] +#[cfg(feature = "method-oidc")] pub mod oidc; diff --git a/packages/storage/shield-sea-orm/src/providers/oidc.rs b/packages/storage/shield-sea-orm/src/providers/oidc.rs index 5f2a2bf..fa8c4a6 100644 --- a/packages/storage/shield-sea-orm/src/providers/oidc.rs +++ b/packages/storage/shield-sea-orm/src/providers/oidc.rs @@ -2,8 +2,8 @@ use async_trait::async_trait; use sea_orm::{ActiveModelTrait, ActiveValue, ColumnTrait, EntityTrait, QueryFilter}; use shield::StorageError; use shield_oidc::{ - CreateOidcConnection, OidcConnection, OidcProviderPkceCodeChallenge, OidcProviderVisibility, - OidcStorage, OidcSubprovider, UpdateOidcConnection, + CreateOidcConnection, OidcConnection, OidcProvider, OidcProviderPkceCodeChallenge, + OidcProviderVisibility, OidcStorage, UpdateOidcConnection, }; use crate::{ @@ -14,26 +14,21 @@ use crate::{ #[async_trait] impl OidcStorage for SeaOrmStorage { - async fn oidc_subproviders(&self) -> Result, StorageError> { + async fn oidc_providers(&self) -> Result, StorageError> { oidc_provider::Entity::find() .all(&self.database) .await .map_err(|err| StorageError::Engine(err.to_string())) - .and_then(|subproviders| { - subproviders - .into_iter() - .map(OidcSubprovider::try_from) - .collect() - }) + .and_then(|providers| providers.into_iter().map(OidcProvider::try_from).collect()) } - async fn oidc_subprovider_by_id_or_slug( + async fn oidc_provider_by_id_or_slug( &self, - subprovider_id: &str, - ) -> Result, StorageError> { - let condition = match Self::parse_uuid(subprovider_id) { - Ok(subprovider_id) => oidc_provider::Column::Id.eq(subprovider_id), - Err(_) => oidc_provider::Column::Slug.eq(subprovider_id.to_lowercase()), + provider_id: &str, + ) -> Result, StorageError> { + let condition = match Self::parse_uuid(provider_id) { + Ok(provider_id) => oidc_provider::Column::Id.eq(provider_id), + Err(_) => oidc_provider::Column::Slug.eq(provider_id.to_lowercase()), }; oidc_provider::Entity::find() @@ -41,8 +36,8 @@ impl OidcStorage for SeaOrmStorage { .one(&self.database) .await .map_err(|err| StorageError::Engine(err.to_string())) - .and_then(|subprovider| match subprovider { - Some(subprovider) => OidcSubprovider::try_from(subprovider).map(Option::Some), + .and_then(|provider| match provider { + Some(provider) => OidcProvider::try_from(provider).map(Option::Some), None => Ok(None), }) } @@ -60,13 +55,11 @@ impl OidcStorage for SeaOrmStorage { async fn oidc_connection_by_identifier( &self, - subprovider_id: &str, + provider_id: &str, identifier: &str, ) -> Result, StorageError> { oidc_provider_connection::Entity::find() - .filter( - oidc_provider_connection::Column::ProviderId.eq(Self::parse_uuid(subprovider_id)?), - ) + .filter(oidc_provider_connection::Column::ProviderId.eq(Self::parse_uuid(provider_id)?)) .filter(oidc_provider_connection::Column::Identifier.eq(identifier)) .one(&self.database) .await @@ -86,7 +79,7 @@ impl OidcStorage for SeaOrmStorage { id_token: ActiveValue::Set(connection.id_token), expired_at: ActiveValue::Set(connection.expired_at), scopes: ActiveValue::Set(connection.scopes.map(|scopes| scopes.join(","))), - provider_id: ActiveValue::Set(Self::parse_uuid(&connection.subprovider_id)?), + provider_id: ActiveValue::Set(Self::parse_uuid(&connection.provider_id)?), user_id: ActiveValue::Set(Self::parse_uuid(&connection.user_id)?), ..Default::default() }; @@ -171,11 +164,11 @@ impl From for OidcProviderPkceCode } } -impl TryFrom for OidcSubprovider { +impl TryFrom for OidcProvider { type Error = StorageError; fn try_from(value: oidc_provider::Model) -> Result { - Ok(OidcSubprovider { + Ok(OidcProvider { id: value.id.to_string(), name: value.name, slug: value.slug, @@ -222,7 +215,7 @@ impl From for OidcConnection { scopes: value .scopes .map(|scopes| scopes.split(',').map(|s| s.to_string()).collect()), - subprovider_id: value.provider_id.to_string(), + provider_id: value.provider_id.to_string(), user_id: value.user_id.to_string(), } } diff --git a/packages/storage/shield-sqlx/Cargo.toml b/packages/storage/shield-sqlx/Cargo.toml index 072ab9a..94594b7 100644 --- a/packages/storage/shield-sqlx/Cargo.toml +++ b/packages/storage/shield-sqlx/Cargo.toml @@ -9,4 +9,4 @@ repository.workspace = true version.workspace = true [dependencies] -shield = { path = "../../core/shield", version = "0.0.4" } +shield.workspace = true