diff --git a/Cargo.lock b/Cargo.lock index b067c2a..b475ee1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4414,7 +4414,11 @@ dependencies = [ name = "shield-credentials" version = "0.0.4" dependencies = [ + "async-trait", + "serde", + "serde_json", "shield", + "tokio", ] [[package]] diff --git a/packages/core/shield/src/form.rs b/packages/core/shield/src/form.rs index 123a17c..24a4bfc 100644 --- a/packages/core/shield/src/form.rs +++ b/packages/core/shield/src/form.rs @@ -18,6 +18,7 @@ pub struct Form { #[derive(Clone, Debug)] pub struct Input { pub name: String, + pub label: Option, pub r#type: InputType, pub value: Option, pub attributes: Option>, diff --git a/packages/core/shield/src/method.rs b/packages/core/shield/src/method.rs index cf4468a..b3a3378 100644 --- a/packages/core/shield/src/method.rs +++ b/packages/core/shield/src/method.rs @@ -39,7 +39,7 @@ pub trait Method: Send + Sync { request: SignOutRequest, session: Session, options: &ShieldOptions, - ) -> Result; + ) -> Result, ShieldError>; } #[cfg(test)] @@ -111,8 +111,8 @@ pub(crate) mod tests { _request: SignOutRequest, _session: Session, _options: &ShieldOptions, - ) -> Result { - todo!("redirect back?") + ) -> Result, ShieldError> { + Ok(None) } } } diff --git a/packages/core/shield/src/shield.rs b/packages/core/shield/src/shield.rs index 6b86f04..8175a44 100644 --- a/packages/core/shield/src/shield.rs +++ b/packages/core/shield/src/shield.rs @@ -185,9 +185,12 @@ impl Shield { ) .await? } else { - Response::Redirect(self.options.sign_out_redirect.clone()) + None }; + let response = + response.unwrap_or_else(|| Response::Redirect(self.options.sign_out_redirect.clone())); + session.purge().await?; Ok(response) diff --git a/packages/methods/shield-credentials/Cargo.toml b/packages/methods/shield-credentials/Cargo.toml index cf1e701..f2641e2 100644 --- a/packages/methods/shield-credentials/Cargo.toml +++ b/packages/methods/shield-credentials/Cargo.toml @@ -9,4 +9,10 @@ repository.workspace = true version.workspace = true [dependencies] +async-trait.workspace = true +serde.workspace = true +serde_json.workspace = true shield.workspace = true + +[dev-dependencies] +tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } diff --git a/packages/methods/shield-credentials/src/credentials.rs b/packages/methods/shield-credentials/src/credentials.rs new file mode 100644 index 0000000..5476856 --- /dev/null +++ b/packages/methods/shield-credentials/src/credentials.rs @@ -0,0 +1,10 @@ +use async_trait::async_trait; +use serde::de::DeserializeOwned; +use shield::{Form, ShieldError, User}; + +#[async_trait] +pub trait Credentials: Send + Sync { + fn form(&self) -> Form; + + async fn sign_in(&self, data: D) -> Result; +} diff --git a/packages/methods/shield-credentials/src/email_password.rs b/packages/methods/shield-credentials/src/email_password.rs new file mode 100644 index 0000000..ae2d59d --- /dev/null +++ b/packages/methods/shield-credentials/src/email_password.rs @@ -0,0 +1,165 @@ +use std::{pin::Pin, sync::Arc}; + +use async_trait::async_trait; +use serde::Deserialize; +use shield::{Form, Input, InputType, ShieldError, User}; + +use crate::Credentials; + +#[derive(Debug, Deserialize)] +pub struct EmailPasswordData { + pub email: String, + pub password: String, +} + +type SignInFn = dyn Fn(EmailPasswordData) -> Pin> + Send + Sync>> + + Send + + Sync; + +pub struct EmailPasswordCredentials { + sign_in_fn: Arc>, +} + +impl EmailPasswordCredentials { + pub fn new( + sign_in_fn: impl Fn( + EmailPasswordData, + ) + -> Pin> + Send + Sync>> + + Send + + Sync + + 'static, + ) -> Self { + Self { + sign_in_fn: Arc::new(sign_in_fn), + } + } +} + +#[async_trait] +impl Credentials for EmailPasswordCredentials { + fn form(&self) -> Form { + Form { + inputs: vec![ + Input { + name: "email".to_owned(), + label: Some("Email address".to_owned()), + r#type: InputType::Email { + autocomplete: Some("email".to_owned()), + dirname: None, + list: None, + maxlength: None, + minlength: None, + multiple: None, + pattern: None, + placeholder: Some("Email address".to_owned()), + readonly: None, + required: Some(true), + size: None, + }, + value: None, + attributes: None, + }, + Input { + name: "password".to_owned(), + label: Some("Password".to_owned()), + r#type: InputType::Password { + autocomplete: Some("current-password".to_owned()), + dirname: None, + maxlength: None, + minlength: None, + pattern: None, + placeholder: Some("Password".to_owned()), + readonly: None, + required: Some(true), + size: None, + }, + value: None, + attributes: None, + }, + ], + attributes: None, + } + } + + async fn sign_in(&self, data: EmailPasswordData) -> Result { + (self.sign_in_fn)(data).await + } +} + +#[cfg(test)] +mod tests { + use async_trait::async_trait; + use serde::{Deserialize, Serialize}; + use shield::{EmailAddress, ShieldError, StorageError, User}; + + use crate::Credentials; + + use super::{EmailPasswordCredentials, EmailPasswordData}; + + #[derive(Clone, Debug, Deserialize, Serialize)] + pub struct TestUser { + id: String, + name: Option, + } + + #[async_trait] + impl User for TestUser { + fn id(&self) -> String { + self.id.clone() + } + + fn name(&self) -> Option { + self.name.clone() + } + + async fn email_addresses(&self) -> Result, StorageError> { + Ok(vec![]) + } + + fn additional(&self) -> Option { + None::<()> + } + } + + #[tokio::test] + async fn email_password_credentials() -> Result<(), ShieldError> { + let credentials = EmailPasswordCredentials::new(|data: EmailPasswordData| { + Box::pin(async move { + if data.email == "test@example.com" && data.password == "test" { + Ok(TestUser { + id: "1".to_owned(), + name: Some("Test".to_owned()), + }) + } else { + Err(ShieldError::Validation( + "Incorrect email and password combination.".to_owned(), + )) + } + }) + }); + + assert!( + credentials + .sign_in(EmailPasswordData { + email: "test@example.com".to_owned(), + password: "incorrect".to_owned(), + }) + .await + .is_err_and(|err| err + .to_string() + .contains("Incorrect email and password combination.")) + ); + + let user = credentials + .sign_in(EmailPasswordData { + email: "test@example.com".to_owned(), + password: "test".to_owned(), + }) + .await?; + + assert_eq!(user.name, Some("Test".to_owned())); + + Ok(()) + } +} diff --git a/packages/methods/shield-credentials/src/lib.rs b/packages/methods/shield-credentials/src/lib.rs index 8b13789..c9b7a3d 100644 --- a/packages/methods/shield-credentials/src/lib.rs +++ b/packages/methods/shield-credentials/src/lib.rs @@ -1 +1,10 @@ +mod credentials; +mod email_password; +mod method; +mod provider; +mod username_password; +pub use credentials::*; +pub use email_password::*; +pub use method::*; +pub use username_password::*; diff --git a/packages/methods/shield-credentials/src/method.rs b/packages/methods/shield-credentials/src/method.rs new file mode 100644 index 0000000..834ff52 --- /dev/null +++ b/packages/methods/shield-credentials/src/method.rs @@ -0,0 +1,107 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use serde::de::DeserializeOwned; +use shield::{ + Authentication, Method, Provider, Response, Session, SessionError, ShieldError, ShieldOptions, + SignInCallbackRequest, SignInRequest, SignOutRequest, User, +}; + +use crate::{Credentials, provider::CredentialsProvider}; + +pub const CREDENTIALS_METHOD_ID: &str = "credentials"; + +pub struct CredentialsMethod { + credentials: Arc>, +} + +impl CredentialsMethod { + pub fn new + 'static>(credentials: C) -> Self { + Self { + credentials: Arc::new(credentials), + } + } +} + +#[async_trait] +impl Method for CredentialsMethod { + fn id(&self) -> String { + CREDENTIALS_METHOD_ID.to_owned() + } + + async fn providers(&self) -> Result>, ShieldError> { + Ok(vec![Box::new(CredentialsProvider::new( + self.credentials.clone(), + ))]) + } + + async fn provider_by_id( + &self, + _provider_id: &str, + ) -> Result>, ShieldError> { + Ok(None) + } + + async fn sign_in( + &self, + request: SignInRequest, + session: Session, + options: &ShieldOptions, + ) -> Result { + if request.provider_id.is_some() { + return Err(ShieldError::Validation( + "Provider should be none.".to_owned(), + )); + } + + let Some(form_data) = request.form_data else { + return Err(ShieldError::Validation("Missing form data.".to_owned())); + }; + + let data = serde_json::from_value(form_data) + .map_err(|err| ShieldError::Validation(err.to_string()))?; + + let user = self.credentials.sign_in(data).await?; + + session.renew().await?; + + { + let session_data = session.data(); + let mut session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; + + session_data.authentication = Some(Authentication { + method_id: self.id(), + provider_id: None, + user_id: user.id(), + }); + } + + Ok(Response::Redirect( + request + .redirect_url + .unwrap_or(options.sign_in_redirect.clone()), + )) + } + + async fn sign_in_callback( + &self, + _request: SignInCallbackRequest, + _session: Session, + _options: &ShieldOptions, + ) -> Result { + Err(ShieldError::Validation( + "Credentials method does not have a sign in callback.".to_owned(), + )) + } + + async fn sign_out( + &self, + _request: SignOutRequest, + _session: Session, + _options: &ShieldOptions, + ) -> Result, ShieldError> { + Ok(None) + } +} diff --git a/packages/methods/shield-credentials/src/provider.rs b/packages/methods/shield-credentials/src/provider.rs new file mode 100644 index 0000000..99f7fa9 --- /dev/null +++ b/packages/methods/shield-credentials/src/provider.rs @@ -0,0 +1,38 @@ +use std::sync::Arc; + +use serde::de::DeserializeOwned; +use shield::{Form, Provider, User}; + +use crate::{CREDENTIALS_METHOD_ID, Credentials}; + +pub struct CredentialsProvider { + credentials: Arc>, +} + +impl CredentialsProvider { + pub(crate) fn new(credentials: Arc>) -> Self { + Self { credentials } + } +} + +impl Provider for CredentialsProvider { + fn method_id(&self) -> String { + CREDENTIALS_METHOD_ID.to_owned() + } + + fn id(&self) -> Option { + None + } + + fn name(&self) -> String { + "Credentials".to_owned() + } + + fn icon_url(&self) -> Option { + None + } + + fn form(&self) -> Option
{ + Some(self.credentials.form()) + } +} diff --git a/packages/methods/shield-credentials/src/username_password.rs b/packages/methods/shield-credentials/src/username_password.rs new file mode 100644 index 0000000..89df58d --- /dev/null +++ b/packages/methods/shield-credentials/src/username_password.rs @@ -0,0 +1,164 @@ +use std::{pin::Pin, sync::Arc}; + +use async_trait::async_trait; +use serde::Deserialize; +use shield::{Form, Input, InputType, ShieldError, User}; + +use crate::Credentials; + +#[derive(Debug, Deserialize)] +pub struct UsernamePasswordData { + pub username: String, + pub password: String, +} + +type SignInFn = dyn Fn(UsernamePasswordData) -> Pin> + Send + Sync>> + + Send + + Sync; + +pub struct UsernamePasswordCredentials { + sign_in_fn: Arc>, +} + +impl UsernamePasswordCredentials { + pub fn new( + sign_in_fn: impl Fn( + UsernamePasswordData, + ) + -> Pin> + Send + Sync>> + + Send + + Sync + + 'static, + ) -> Self { + Self { + sign_in_fn: Arc::new(sign_in_fn), + } + } +} + +#[async_trait] +impl Credentials for UsernamePasswordCredentials { + fn form(&self) -> Form { + Form { + inputs: vec![ + Input { + name: "username".to_owned(), + label: Some("Username".to_owned()), + r#type: InputType::Text { + autocomplete: Some("username".to_owned()), + dirname: None, + list: None, + maxlength: None, + minlength: None, + pattern: None, + placeholder: Some("Username".to_owned()), + readonly: None, + required: Some(true), + size: None, + }, + value: None, + attributes: None, + }, + Input { + name: "password".to_owned(), + label: Some("Password".to_owned()), + r#type: InputType::Password { + autocomplete: Some("current-password".to_owned()), + dirname: None, + maxlength: None, + minlength: None, + pattern: None, + placeholder: Some("Password".to_owned()), + readonly: None, + required: Some(true), + size: None, + }, + value: None, + attributes: None, + }, + ], + attributes: None, + } + } + + async fn sign_in(&self, data: UsernamePasswordData) -> Result { + (self.sign_in_fn)(data).await + } +} + +#[cfg(test)] +mod tests { + use async_trait::async_trait; + use serde::{Deserialize, Serialize}; + use shield::{EmailAddress, ShieldError, StorageError, User}; + + use crate::Credentials; + + use super::{UsernamePasswordCredentials, UsernamePasswordData}; + + #[derive(Clone, Debug, Deserialize, Serialize)] + pub struct TestUser { + id: String, + name: Option, + } + + #[async_trait] + impl User for TestUser { + fn id(&self) -> String { + self.id.clone() + } + + fn name(&self) -> Option { + self.name.clone() + } + + async fn email_addresses(&self) -> Result, StorageError> { + Ok(vec![]) + } + + fn additional(&self) -> Option { + None::<()> + } + } + + #[tokio::test] + async fn username_password_credentials() -> Result<(), ShieldError> { + let credentials = UsernamePasswordCredentials::new(|data: UsernamePasswordData| { + Box::pin(async move { + if data.username == "test" && data.password == "test" { + Ok(TestUser { + id: "1".to_owned(), + name: Some("Test".to_owned()), + }) + } else { + Err(ShieldError::Validation( + "Incorrect username and password combination.".to_owned(), + )) + } + }) + }); + + assert!( + credentials + .sign_in(UsernamePasswordData { + username: "test".to_owned(), + password: "incorrect".to_owned(), + }) + .await + .is_err_and(|err| err + .to_string() + .contains("Incorrect username and password combination.")) + ); + + let user = credentials + .sign_in(UsernamePasswordData { + username: "test".to_owned(), + password: "test".to_owned(), + }) + .await?; + + assert_eq!(user.name, Some("Test".to_owned())); + + Ok(()) + } +} diff --git a/packages/methods/shield-oauth/src/method.rs b/packages/methods/shield-oauth/src/method.rs index f23181b..c81d1cc 100644 --- a/packages/methods/shield-oauth/src/method.rs +++ b/packages/methods/shield-oauth/src/method.rs @@ -391,13 +391,15 @@ impl Method for OauthMethod { request: SignOutRequest, _session: Session, _options: &ShieldOptions, - ) -> Result { + ) -> Result, ShieldError> { let _provider = match request.provider_id { Some(provider_id) => self.oauth_provider_by_id_or_slug(&provider_id).await?, None => return Err(ProviderError::ProviderMissing.into()), }; - todo!("oauth sign out") + // TODO: OAuth token revocation. + + Ok(None) } } diff --git a/packages/methods/shield-oidc/src/method.rs b/packages/methods/shield-oidc/src/method.rs index 204478b..da33d23 100644 --- a/packages/methods/shield-oidc/src/method.rs +++ b/packages/methods/shield-oidc/src/method.rs @@ -423,8 +423,8 @@ impl Method for OidcMethod { &self, _request: SignOutRequest, _session: Session, - options: &ShieldOptions, - ) -> Result { + _options: &ShieldOptions, + ) -> Result, ShieldError> { // TODO: See [`OidcProvider::oidc_client`]. // let provider = match request.provider_id { @@ -476,7 +476,7 @@ impl Method for OidcMethod { // } // } - Ok(Response::Redirect(options.sign_out_redirect.clone())) + Ok(None) } }