Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions packages/core/shield/src/form.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub struct Form {
#[derive(Clone, Debug)]
pub struct Input {
pub name: String,
pub label: Option<String>,
pub r#type: InputType,
pub value: Option<String>,
pub attributes: Option<HashMap<String, Attribute>>,
Expand Down
6 changes: 3 additions & 3 deletions packages/core/shield/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub trait Method: Send + Sync {
request: SignOutRequest,
session: Session,
options: &ShieldOptions,
) -> Result<Response, ShieldError>;
) -> Result<Option<Response>, ShieldError>;
}

#[cfg(test)]
Expand Down Expand Up @@ -111,8 +111,8 @@ pub(crate) mod tests {
_request: SignOutRequest,
_session: Session,
_options: &ShieldOptions,
) -> Result<Response, ShieldError> {
todo!("redirect back?")
) -> Result<Option<Response>, ShieldError> {
Ok(None)
}
}
}
5 changes: 4 additions & 1 deletion packages/core/shield/src/shield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,12 @@ impl<U: User> Shield<U> {
)
.await?
} else {
Response::Redirect(self.options.sign_out_redirect.clone())
None
};

let response =
response.unwrap_or_else(|| Response::Redirect(self.options.sign_out_redirect.clone()));

session.purge().await?;

Ok(response)
Expand Down
6 changes: 6 additions & 0 deletions packages/methods/shield-credentials/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,10 @@ repository.workspace = true
version.workspace = true

[dependencies]
async-trait.workspace = true
serde.workspace = true
serde_json.workspace = true
shield.workspace = true

[dev-dependencies]
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
10 changes: 10 additions & 0 deletions packages/methods/shield-credentials/src/credentials.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
use async_trait::async_trait;
use serde::de::DeserializeOwned;
use shield::{Form, ShieldError, User};

#[async_trait]
pub trait Credentials<U: User, D: DeserializeOwned>: Send + Sync {
fn form(&self) -> Form;

async fn sign_in(&self, data: D) -> Result<U, ShieldError>;
}
165 changes: 165 additions & 0 deletions packages/methods/shield-credentials/src/email_password.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
use std::{pin::Pin, sync::Arc};

use async_trait::async_trait;
use serde::Deserialize;
use shield::{Form, Input, InputType, ShieldError, User};

use crate::Credentials;

#[derive(Debug, Deserialize)]
pub struct EmailPasswordData {
pub email: String,
pub password: String,
}

type SignInFn<U> = dyn Fn(EmailPasswordData) -> Pin<Box<dyn Future<Output = Result<U, ShieldError>> + Send + Sync>>
+ Send
+ Sync;

pub struct EmailPasswordCredentials<U: User> {
sign_in_fn: Arc<SignInFn<U>>,
}

impl<U: User> EmailPasswordCredentials<U> {
pub fn new(
sign_in_fn: impl Fn(
EmailPasswordData,
)
-> Pin<Box<dyn Future<Output = Result<U, ShieldError>> + Send + Sync>>
+ Send
+ Sync
+ 'static,
) -> Self {
Self {
sign_in_fn: Arc::new(sign_in_fn),
}
}
}

#[async_trait]
impl<U: User> Credentials<U, EmailPasswordData> for EmailPasswordCredentials<U> {
fn form(&self) -> Form {
Form {
inputs: vec![
Input {
name: "email".to_owned(),
label: Some("Email address".to_owned()),
r#type: InputType::Email {
autocomplete: Some("email".to_owned()),
dirname: None,
list: None,
maxlength: None,
minlength: None,
multiple: None,
pattern: None,
placeholder: Some("Email address".to_owned()),
readonly: None,
required: Some(true),
size: None,
},
value: None,
attributes: None,
},
Input {
name: "password".to_owned(),
label: Some("Password".to_owned()),
r#type: InputType::Password {
autocomplete: Some("current-password".to_owned()),
dirname: None,
maxlength: None,
minlength: None,
pattern: None,
placeholder: Some("Password".to_owned()),
readonly: None,
required: Some(true),
size: None,
},
value: None,
attributes: None,
},
],
attributes: None,
}
}

async fn sign_in(&self, data: EmailPasswordData) -> Result<U, ShieldError> {
(self.sign_in_fn)(data).await
}
}

#[cfg(test)]
mod tests {
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use shield::{EmailAddress, ShieldError, StorageError, User};

use crate::Credentials;

use super::{EmailPasswordCredentials, EmailPasswordData};

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TestUser {
id: String,
name: Option<String>,
}

#[async_trait]
impl User for TestUser {
fn id(&self) -> String {
self.id.clone()
}

fn name(&self) -> Option<String> {
self.name.clone()
}

async fn email_addresses(&self) -> Result<Vec<EmailAddress>, StorageError> {
Ok(vec![])
}

fn additional(&self) -> Option<impl Serialize> {
None::<()>
}
}

#[tokio::test]
async fn email_password_credentials() -> Result<(), ShieldError> {
let credentials = EmailPasswordCredentials::new(|data: EmailPasswordData| {
Box::pin(async move {
if data.email == "test@example.com" && data.password == "test" {
Ok(TestUser {
id: "1".to_owned(),
name: Some("Test".to_owned()),
})
} else {
Err(ShieldError::Validation(
"Incorrect email and password combination.".to_owned(),
))
}
})
});

assert!(
credentials
.sign_in(EmailPasswordData {
email: "test@example.com".to_owned(),
password: "incorrect".to_owned(),
})
.await
.is_err_and(|err| err
.to_string()
.contains("Incorrect email and password combination."))
);

let user = credentials
.sign_in(EmailPasswordData {
email: "test@example.com".to_owned(),
password: "test".to_owned(),
})
.await?;

assert_eq!(user.name, Some("Test".to_owned()));

Ok(())
}
}
9 changes: 9 additions & 0 deletions packages/methods/shield-credentials/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1 +1,10 @@
mod credentials;
mod email_password;
mod method;
mod provider;
mod username_password;

pub use credentials::*;
pub use email_password::*;
pub use method::*;
pub use username_password::*;
107 changes: 107 additions & 0 deletions packages/methods/shield-credentials/src/method.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
use std::sync::Arc;

use async_trait::async_trait;
use serde::de::DeserializeOwned;
use shield::{
Authentication, Method, Provider, Response, Session, SessionError, ShieldError, ShieldOptions,
SignInCallbackRequest, SignInRequest, SignOutRequest, User,
};

use crate::{Credentials, provider::CredentialsProvider};

pub const CREDENTIALS_METHOD_ID: &str = "credentials";

pub struct CredentialsMethod<U: User, D: DeserializeOwned> {
credentials: Arc<dyn Credentials<U, D>>,
}

impl<U: User, D: DeserializeOwned> CredentialsMethod<U, D> {
pub fn new<C: Credentials<U, D> + 'static>(credentials: C) -> Self {
Self {
credentials: Arc::new(credentials),
}
}
}

#[async_trait]
impl<U: User + 'static, D: DeserializeOwned + 'static> Method for CredentialsMethod<U, D> {
fn id(&self) -> String {
CREDENTIALS_METHOD_ID.to_owned()
}

async fn providers(&self) -> Result<Vec<Box<dyn Provider>>, ShieldError> {
Ok(vec![Box::new(CredentialsProvider::new(
self.credentials.clone(),
))])
}

async fn provider_by_id(
&self,
_provider_id: &str,
) -> Result<Option<Box<dyn Provider>>, ShieldError> {
Ok(None)
}

async fn sign_in(
&self,
request: SignInRequest,
session: Session,
options: &ShieldOptions,
) -> Result<Response, ShieldError> {
if request.provider_id.is_some() {
return Err(ShieldError::Validation(
"Provider should be none.".to_owned(),
));
}

let Some(form_data) = request.form_data else {
return Err(ShieldError::Validation("Missing form data.".to_owned()));
};

let data = serde_json::from_value(form_data)
.map_err(|err| ShieldError::Validation(err.to_string()))?;

let user = self.credentials.sign_in(data).await?;

session.renew().await?;

{
let session_data = session.data();
let mut session_data = session_data
.lock()
.map_err(|err| SessionError::Lock(err.to_string()))?;

session_data.authentication = Some(Authentication {
method_id: self.id(),
provider_id: None,
user_id: user.id(),
});
}

Ok(Response::Redirect(
request
.redirect_url
.unwrap_or(options.sign_in_redirect.clone()),
))
}

async fn sign_in_callback(
&self,
_request: SignInCallbackRequest,
_session: Session,
_options: &ShieldOptions,
) -> Result<Response, ShieldError> {
Err(ShieldError::Validation(
"Credentials method does not have a sign in callback.".to_owned(),
))
}

async fn sign_out(
&self,
_request: SignOutRequest,
_session: Session,
_options: &ShieldOptions,
) -> Result<Option<Response>, ShieldError> {
Ok(None)
}
}
38 changes: 38 additions & 0 deletions packages/methods/shield-credentials/src/provider.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use std::sync::Arc;

use serde::de::DeserializeOwned;
use shield::{Form, Provider, User};

use crate::{CREDENTIALS_METHOD_ID, Credentials};

pub struct CredentialsProvider<U: User, D: DeserializeOwned> {
credentials: Arc<dyn Credentials<U, D>>,
}

impl<U: User, D: DeserializeOwned> CredentialsProvider<U, D> {
pub(crate) fn new(credentials: Arc<dyn Credentials<U, D>>) -> Self {
Self { credentials }
}
}

impl<U: User, D: DeserializeOwned> Provider for CredentialsProvider<U, D> {
fn method_id(&self) -> String {
CREDENTIALS_METHOD_ID.to_owned()
}

fn id(&self) -> Option<String> {
None
}

fn name(&self) -> String {
"Credentials".to_owned()
}

fn icon_url(&self) -> Option<String> {
None
}

fn form(&self) -> Option<Form> {
Some(self.credentials.form())
}
}
Loading