Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions packages/core/shield/src/action.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ use crate::{
session::Session,
};

pub const SIGN_IN_ACTION_ID: &str = "sign-in";
pub const SIGN_IN_CALLBACK_ACTION_ID: &str = "sign-in-callback";
pub const SIGN_OUT_ACTION_ID: &str = "sign-out";

#[async_trait]
pub trait Action<P: Provider>: ErasedAction + Send + Sync {
fn id(&self) -> String;

fn name(&self) -> String;

fn condition(&self, _provider: &P, _session: Session) -> Result<bool, ShieldError> {
Ok(true)
}

fn form(&self, provider: P) -> Form;

async fn call(
Expand All @@ -29,6 +31,14 @@ pub trait Action<P: Provider>: ErasedAction + Send + Sync {
pub trait ErasedAction: Send + Sync {
fn erased_id(&self) -> String;

fn erased_name(&self) -> String;

fn erased_condition(
&self,
provider: &(dyn Any + Send + Sync),
session: Session,
) -> Result<bool, ShieldError>;

fn erased_form(&self, provider: Box<dyn Any + Send + Sync>) -> Form;

async fn erased_call(
Expand All @@ -48,6 +58,14 @@ macro_rules! erased_action {
self.id()
}

fn erased_name(&self) -> String {
self.name()
}

fn erased_condition(&self, provider: &(dyn std::any::Any + Send + Sync), session: $crate::Session) -> Result<bool, $crate::ShieldError> {
self.condition(provider.downcast_ref().expect("TODO"), session)
}

fn erased_form(&self, provider: Box<dyn std::any::Any + Send + Sync>) -> $crate::Form {
self.form(*provider.downcast().expect("TODO"))
}
Expand All @@ -57,7 +75,7 @@ macro_rules! erased_action {
provider: Box<dyn std::any::Any + Send + Sync>,
session: $crate::Session,
request: $crate::Request,
) -> Result<$crate::Response, ShieldError> {
) -> Result<$crate::Response, $crate::ShieldError> {
self.call(*provider.downcast().expect("TODO"), session, request)
.await
}
Expand All @@ -76,7 +94,8 @@ pub(crate) mod tests {

use super::Action;

pub const TEST_ACTION_ID: &str = "action";
pub const TEST_ACTION_ID: &str = "test";
pub const TEST_ACTION_NAME: &str = "Test";

#[derive(Default)]
pub struct TestAction {}
Expand All @@ -87,6 +106,10 @@ pub(crate) mod tests {
TEST_ACTION_ID.to_owned()
}

fn name(&self) -> String {
TEST_ACTION_NAME.to_owned()
}

fn form(&self, _provider: TestProvider) -> Form {
Form { inputs: vec![] }
}
Expand Down
7 changes: 7 additions & 0 deletions packages/core/shield/src/actions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mod sign_in;
mod sign_in_callback;
mod sign_out;

pub use sign_in::*;
pub use sign_in_callback::*;
pub use sign_out::*;
14 changes: 14 additions & 0 deletions packages/core/shield/src/actions/sign_in.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
const ACTION_ID: &str = "sign-in";
const ACTION_NAME: &str = "Sign in";

pub struct SignInAction;

impl SignInAction {
pub fn id() -> String {
ACTION_ID.to_owned()
}

pub fn name() -> String {
ACTION_NAME.to_owned()
}
}
20 changes: 20 additions & 0 deletions packages/core/shield/src/actions/sign_in_callback.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use crate::{Provider, Session, ShieldError};

const ACTION_ID: &str = "sign-in-callback";
const ACTION_NAME: &str = "Sign in callback";

pub struct SignInCallbackAction;

impl SignInCallbackAction {
pub fn id() -> String {
ACTION_ID.to_owned()
}

pub fn name() -> String {
ACTION_NAME.to_owned()
}

pub fn condition<P: Provider>(_provider: &P, _session: Session) -> Result<bool, ShieldError> {
Ok(false)
}
}
44 changes: 44 additions & 0 deletions packages/core/shield/src/actions/sign_out.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use crate::{
Form, Input, InputType, InputTypeSubmit, Provider, Session, SessionError, ShieldError,
};

const ACTION_ID: &str = "sign-out";
const ACTION_NAME: &str = "Sign out";

pub struct SignOutAction;

impl SignOutAction {
pub fn id() -> String {
ACTION_ID.to_owned()
}

pub fn name() -> String {
ACTION_NAME.to_owned()
}

pub fn condition<P: Provider>(provider: &P, session: Session) -> Result<bool, ShieldError> {
let session_data = session.data();
let session_data = session_data
.lock()
.map_err(|err| SessionError::Lock(err.to_string()))?;

Ok(session_data
.authentication
.as_ref()
.is_some_and(|authentication| {
authentication.method_id == provider.method_id()
&& authentication.provider_id == provider.id()
}))
}

pub fn form<P: Provider>(_provider: P) -> Form {
Form {
inputs: vec![Input {
name: "submit".to_owned(),
label: None,
r#type: InputType::Submit(InputTypeSubmit {}),
value: Some(Self::name()),
}],
}
}
}
2 changes: 2 additions & 0 deletions packages/core/shield/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod action;
mod actions;
mod error;
mod form;
mod method;
Expand All @@ -13,6 +14,7 @@ mod storage;
mod user;

pub use action::*;
pub use actions::*;
pub use error::*;
pub use form::*;
pub use method::*;
Expand Down
14 changes: 11 additions & 3 deletions packages/core/shield/src/shield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::{any::Any, collections::HashMap, sync::Arc};
use futures::future::try_join_all;

use crate::{
error::ShieldError, form::Form, method::ErasedMethod, options::ShieldOptions, storage::Storage,
user::User,
Session, error::ShieldError, form::Form, method::ErasedMethod, options::ShieldOptions,
storage::Storage, user::User,
};

#[derive(Clone)]
Expand Down Expand Up @@ -64,7 +64,11 @@ impl<U: User> Shield<U> {
}
}

pub async fn action_forms(&self, action_id: &str) -> Result<Vec<Form>, ShieldError> {
pub async fn action_forms(
&self,
action_id: &str,
session: Session,
) -> Result<Vec<Form>, ShieldError> {
let mut forms = vec![];

for (_, method) in self.methods.iter() {
Expand All @@ -73,6 +77,10 @@ impl<U: User> Shield<U> {
};

for provider in method.erased_providers().await? {
if !action.erased_condition(&provider, session.clone())? {
continue;
}

let form = action.erased_form(provider);

forms.push(form);
Expand Down
24 changes: 18 additions & 6 deletions packages/core/shield/src/shield_dyn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@ use std::{any::Any, sync::Arc};

use async_trait::async_trait;

use crate::{error::ShieldError, form::Form, shield::Shield, user::User};
use crate::{Session, error::ShieldError, form::Form, shield::Shield, user::User};

#[async_trait]
pub trait DynShield: Send + Sync {
async fn providers(&self) -> Result<Vec<Box<dyn Any + Send + Sync>>, ShieldError>;

async fn action_forms(&self, action_id: &str) -> Result<Vec<Form>, ShieldError>;
async fn action_forms(
&self,
action_id: &str,
session: Session,
) -> Result<Vec<Form>, ShieldError>;
}

#[async_trait]
Expand All @@ -17,8 +21,12 @@ impl<U: User> DynShield for Shield<U> {
self.providers().await
}

async fn action_forms(&self, action_id: &str) -> Result<Vec<Form>, ShieldError> {
self.action_forms(action_id).await
async fn action_forms(
&self,
action_id: &str,
session: Session,
) -> Result<Vec<Form>, ShieldError> {
self.action_forms(action_id, session).await
}
}

Expand All @@ -33,7 +41,11 @@ impl ShieldDyn {
self.0.providers().await
}

pub async fn action_forms(&self, action_id: &str) -> Result<Vec<Form>, ShieldError> {
self.0.action_forms(action_id).await
pub async fn action_forms(
&self,
action_id: &str,
session: Session,
) -> Result<Vec<Form>, ShieldError> {
self.0.action_forms(action_id, session).await
}
}
4 changes: 4 additions & 0 deletions packages/integrations/shield-dioxus/src/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,8 @@ impl DioxusIntegrationDyn {
pub async fn extract_shield(&self) -> ShieldDyn {
self.0.extract_shield().await
}

pub async fn extract_session(&self) -> Session {
self.0.extract_session().await
}
}
3 changes: 2 additions & 1 deletion packages/integrations/shield-dioxus/src/routes/action.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ pub fn Action(props: ActionProps) -> Element {
async fn forms(action_id: String) -> Result<Vec<Form>, ServerFnError> {
let FromContext(integration): FromContext<DioxusIntegrationDyn> = extract().await?;
let shield = integration.extract_shield().await;
let session = integration.extract_session().await;

let forms = shield.action_forms(&action_id).await?;
let forms = shield.action_forms(&action_id, session).await?;

Ok(forms)
}
2 changes: 2 additions & 0 deletions packages/methods/shield-credentials/src/actions.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
mod sign_in;
mod sign_out;

pub use sign_in::*;
pub use sign_out::*;
10 changes: 7 additions & 3 deletions packages/methods/shield-credentials/src/actions/sign_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::sync::Arc;
use async_trait::async_trait;
use serde::de::DeserializeOwned;
use shield::{
Action, Authentication, Form, Request, Response, SIGN_IN_ACTION_ID, Session, SessionError,
ShieldError, User, erased_action,
Action, Authentication, Form, Request, Response, Session, SessionError, ShieldError,
SignInAction, User, erased_action,
};

use crate::{credentials::Credentials, provider::CredentialsProvider};
Expand All @@ -24,7 +24,11 @@ impl<U: User + 'static, D: DeserializeOwned + 'static> Action<CredentialsProvide
for CredentialsSignInAction<U, D>
{
fn id(&self) -> String {
SIGN_IN_ACTION_ID.to_owned()
SignInAction::id()
}

fn name(&self) -> String {
SignInAction::name()
}

fn form(&self, _provider: CredentialsProvider) -> Form {
Expand Down
41 changes: 41 additions & 0 deletions packages/methods/shield-credentials/src/actions/sign_out.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use async_trait::async_trait;
use shield::{Action, Form, Request, Response, Session, ShieldError, SignOutAction, erased_action};

use crate::provider::CredentialsProvider;

pub struct CredentialsSignOutAction;

#[async_trait]
impl Action<CredentialsProvider> for CredentialsSignOutAction {
fn id(&self) -> String {
SignOutAction::id()
}

fn name(&self) -> String {
SignOutAction::name()
}

fn condition(
&self,
provider: &CredentialsProvider,
session: Session,
) -> Result<bool, ShieldError> {
SignOutAction::condition(provider, session)
}

fn form(&self, provider: CredentialsProvider) -> Form {
SignOutAction::form(provider)
}

async fn call(
&self,
_provider: CredentialsProvider,
_session: Session,
_request: Request,
) -> Result<Response, ShieldError> {
// TODO: sign out
Ok(Response::Default)
}
}

erased_action!(CredentialsSignOutAction);
13 changes: 9 additions & 4 deletions packages/methods/shield-credentials/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ use async_trait::async_trait;
use serde::de::DeserializeOwned;
use shield::{Action, Method, ShieldError, User, erased_method};

use crate::{Credentials, actions::CredentialsSignInAction, provider::CredentialsProvider};
use crate::{
actions::{CredentialsSignInAction, CredentialsSignOutAction},
credentials::Credentials,
provider::CredentialsProvider,
};

pub const CREDENTIALS_METHOD_ID: &str = "credentials";

Expand All @@ -29,9 +33,10 @@ impl<U: User + 'static, D: DeserializeOwned + 'static> Method<CredentialsProvide
}

fn actions(&self) -> Vec<Box<dyn Action<CredentialsProvider>>> {
vec![Box::new(CredentialsSignInAction::new(
self.credentials.clone(),
))]
vec![
Box::new(CredentialsSignInAction::new(self.credentials.clone())),
Box::new(CredentialsSignOutAction),
]
}

async fn providers(&self) -> Result<Vec<CredentialsProvider>, ShieldError> {
Expand Down
Loading