Skip to content

Commit

Permalink
feat: make browser backend behavior more general (#550)
Browse files Browse the repository at this point in the history
* general browser callback
* handle snark task message
* always use pp as ref
* added test for message handler
  • Loading branch information
RyanKung committed Feb 18, 2024
1 parent 2bfa694 commit dd4a20b
Show file tree
Hide file tree
Showing 8 changed files with 200 additions and 33 deletions.
74 changes: 46 additions & 28 deletions crates/node/src/backend/browser.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![warn(missing_docs)]
//! BackendBehaviour implementation for browser
use core::cell::RefCell;
use std::result::Result;
use std::sync::Arc;

Expand All @@ -11,7 +12,7 @@ use rings_core::utils::js_value;
use rings_derive::wasm_export;
use wasm_bindgen::JsValue;

use crate::backend::snark::SNARKBehaviour;
use super::BackendMessageHandlerDynObj;
use crate::backend::types::BackendMessage;
use crate::backend::types::MessageHandler;
use crate::error::Error;
Expand All @@ -21,32 +22,42 @@ use crate::provider::Provider;
#[wasm_export]
#[derive(Clone)]
pub struct BackendBehaviour {
service_message_handler: Option<Function>,
plain_text_message_handler: Option<Function>,
extension_message_handler: Option<Function>,
snark_message_handler: Option<SNARKBehaviour>,
handlers: dashmap::DashMap<String, Function>,
extend_handler: RefCell<Option<Arc<dyn MessageHandler<BackendMessage>>>>,
}

#[wasm_export]
impl BackendBehaviour {
/// Create a new instance of message callback, this function accept one argument:
///
/// * backend_message_handler: `function(provider: Arc<Provider>, payload: string, message: string) -> Promise<()>`;
#[allow(clippy::new_without_default)]
#[wasm_bindgen(constructor)]
pub fn new(
service_message_handler: Option<js_sys::Function>,
plain_text_message_handler: Option<js_sys::Function>,
extension_message_handler: Option<Function>,
snark_message_handler: Option<SNARKBehaviour>,
) -> BackendBehaviour {
pub fn new() -> BackendBehaviour {
BackendBehaviour {
service_message_handler,
plain_text_message_handler,
extension_message_handler,
snark_message_handler,
handlers: dashmap::DashMap::<String, Function>::new(),
extend_handler: RefCell::new(None),
}
}

/// Get behaviour as dyn obj ref
pub fn as_dyn_obj(self) -> BackendMessageHandlerDynObj {
BackendMessageHandlerDynObj::new(Arc::new(self))
}

/// Extend backend with other backend
pub fn extend(self, impl_backend: BackendMessageHandlerDynObj) {
self.extend_handler.replace(Some(impl_backend.into()));
}

/// register call back function
/// * func: `function(provider: Arc<Provider>, payload: string, message: string) -> Promise<()>`;
pub fn on(&self, method: String, func: js_sys::Function) {
self.handlers.insert(method, func);
}

fn get_handler(&self, method: &str) -> Option<js_sys::Function> {
self.handlers.get(method).map(|v| v.value().clone())
}

async fn do_handle_message(
&self,
provider: Arc<Provider>,
Expand All @@ -55,35 +66,42 @@ impl BackendBehaviour {
) -> Result<(), Error> {
let provider = provider.clone().as_ref().clone();
let ctx = js_value::serialize(&payload)?;

match msg {
BackendMessage::ServiceMessage(m) => {
if let Some(func) = &self.service_message_handler {
if let Some(func) = &self.get_handler("ServiceMessage") {
let m = js_value::serialize(m)?;
let cb = js_func::of4::<BackendBehaviour, Provider, JsValue, JsValue>(func);
cb(self.clone(), provider, ctx, m).await?;
cb(self.clone(), provider.clone(), ctx, m).await?;
}
}
BackendMessage::Extension(m) => {
if let Some(func) = &self.extension_message_handler {
if let Some(func) = &self.get_handler("Extension") {
let m = js_value::serialize(m)?;
let cb = js_func::of4::<BackendBehaviour, Provider, JsValue, JsValue>(func);
cb(self.clone(), provider, ctx, m).await?;
cb(self.clone(), provider.clone(), ctx, m).await?;
}
}
BackendMessage::PlainText(m) => {
if let Some(func) = &self.plain_text_message_handler {
let cb = js_func::of4::<BackendBehaviour, Provider, JsValue, String>(func);
cb(self.clone(), provider, ctx, m.to_string()).await?;
if let Some(func) = &self.get_handler("PlainText") {
let m = js_value::serialize(m)?;
let cb = js_func::of4::<BackendBehaviour, Provider, JsValue, JsValue>(func);
cb(self.clone(), provider.clone(), ctx, m).await?;
}
}
BackendMessage::SNARKTaskMessage(m) => {
if let Some(h) = &self.snark_message_handler {
h.handle_message(provider.into(), payload, m)
.await
.map_err(|e| Error::SNARKHandleMessage(e.to_string()))?;
if let Some(func) = &self.get_handler("SNARKTaskMessage") {
let m = js_value::serialize(m)?;
let cb = js_func::of4::<BackendBehaviour, Provider, JsValue, JsValue>(func);
cb(self.clone(), provider.clone(), ctx, m).await?;
}
}
}
if let Some(ext) = &self.extend_handler.clone().into_inner() {
ext.handle_message(provider.into(), payload, msg)
.await
.map_err(|e| Error::BackendError(e.to_string()))?;
}
Ok(())
}
}
Expand Down
45 changes: 45 additions & 0 deletions crates/node/src/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use rings_core::message::CustomMessage;
use rings_core::message::Message;
use rings_core::message::MessagePayload;
use rings_core::swarm::callback::SwarmCallback;
use rings_derive::wasm_export;

use crate::backend::types::BackendMessage;
use crate::backend::types::MessageHandler;
Expand Down Expand Up @@ -52,6 +53,50 @@ impl Backend {
}
}

/// This struct is used to simulate `impl T`
/// We need this structure because wasm_bindgen does not support general type such as
/// `dyn T` or `impl T`
/// We use Arc instead Box, to make it cloneable.
#[wasm_export]
pub struct BackendMessageHandlerDynObj {
#[allow(dead_code)]
inner: Arc<HandlerTrait>,
}

impl BackendMessageHandlerDynObj {
/// create new instance
#[cfg(not(feature = "browser"))]
pub fn new<T: MessageHandler<BackendMessage> + Send + Sync + 'static>(a: Arc<T>) -> Self {
Self { inner: a.clone() }
}

/// create new instance
#[cfg(feature = "browser")]
pub fn new<T: MessageHandler<BackendMessage> + 'static>(a: Arc<T>) -> Self {
Self { inner: a.clone() }
}
}

impl From<BackendMessageHandlerDynObj> for Arc<dyn MessageHandler<BackendMessage>> {
fn from(impl_backend: BackendMessageHandlerDynObj) -> Arc<dyn MessageHandler<BackendMessage>> {
impl_backend.inner
}
}

#[cfg_attr(feature = "browser", async_trait(?Send))]
#[cfg_attr(not(feature = "browser"), async_trait)]
impl MessageHandler<BackendMessage> for BackendMessageHandlerDynObj {
async fn handle_message(
&self,
provider: Arc<Provider>,
ctx: &MessagePayload,
msg: &BackendMessage,
) -> std::result::Result<(), Box<dyn std::error::Error>> {
self.handle_message(provider.clone(), ctx, msg).await?;
Ok(())
}
}

#[cfg_attr(feature = "browser", async_trait(?Send))]
#[cfg_attr(not(feature = "browser"), async_trait)]
impl SwarmCallback for Backend {
Expand Down
25 changes: 25 additions & 0 deletions crates/node/src/backend/snark/browser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use wasm_bindgen_futures::future_to_promise;

use super::*;
use crate::backend::types;
use crate::backend::BackendMessageHandlerDynObj;
use crate::prelude::rings_core::utils::js_value;

/// We need this ref to pass Task ref to js_sys
#[wasm_bindgen]
Expand Down Expand Up @@ -115,6 +117,29 @@ impl SNARKProofTaskRef {

#[wasm_bindgen]
impl SNARKBehaviour {
/// Get behaviour as dyn obj ref
pub fn as_dyn_obj(self) -> BackendMessageHandlerDynObj {
BackendMessageHandlerDynObj::new(self.into())
}

/// Handle js native message
pub fn handle_snark_task_message(
self,
provider: Provider,
ctx: JsValue,
msg: JsValue,
) -> js_sys::Promise {
let ins = self.clone();
future_to_promise(async move {
let ctx = js_value::deserialize::<MessagePayload>(ctx)?;
let msg = js_value::deserialize::<SNARKTaskMessage>(msg)?;
ins.handle_message(provider.into(), &ctx, &msg)
.await
.map_err(|e| Error::BackendError(e.to_string()))?;
Ok(JsValue::NULL)
})
}

/// gen proof task with circuits, this function is use for solo proof
/// you can call [SNARKBehaviour::handle_snark_proof_task_ref] later to finalize the proof
pub fn gen_proof_task_ref(circuits: Vec<Circuit>) -> Result<SNARKProofTaskRef> {
Expand Down
8 changes: 4 additions & 4 deletions crates/node/src/backend/snark/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ impl SNARKTaskBuilder {
])?;

SNARKProofTask::VastaPallas(SNARKGenerator {
pp,
pp: pp.into(),
snark,
circuits,
})
Expand Down Expand Up @@ -607,7 +607,7 @@ impl SNARKTaskBuilder {
<E2 as Engine>::Scalar::from(0),
])?;
SNARKProofTask::PallasVasta(SNARKGenerator {
pp,
pp: pp.into(),
snark,
circuits,
})
Expand Down Expand Up @@ -635,7 +635,7 @@ impl SNARKTaskBuilder {
<E2 as Engine>::Scalar::from(0),
])?;
SNARKProofTask::Bn256KZGGrumpkin(SNARKGenerator {
pp,
pp: pp.into(),
snark,
circuits,
})
Expand Down Expand Up @@ -677,7 +677,7 @@ where
{
snark: SNARK<E1, E2>,
circuits: Vec<circuit::Circuit<<E1 as Engine>::Scalar>>,
pp: PublicParams<E1, E2>,
pp: Arc<PublicParams<E1, E2>>,
}

impl<E1, E2> SNARKGenerator<E1, E2>
Expand Down
2 changes: 2 additions & 0 deletions crates/node/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ pub enum Error {
SNARKBigIntValueEmpty() = 1405,
#[error("Failed to load string to PrimeField")]
FailedToLoadFF() = 1406,
#[error("Extend Backend Error {0}")]
BackendError(String) = 1501,
}

impl Error {
Expand Down
12 changes: 12 additions & 0 deletions crates/node/src/provider/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ impl Provider {
}

/// Set callback for swarm, it can be T, or (T0, T1, T2)
#[cfg(not(feature = "browser"))]
pub fn set_backend_callback<T>(&self, callback: T) -> Result<()>
where T: MessageHandler<BackendMessage> + Send + Sync + Sized + 'static {
let backend = Backend::new(Arc::new(self.clone()), Box::new(callback));
Expand All @@ -128,6 +129,17 @@ impl Provider {
.map_err(Error::InternalError)
}

/// Set callback for swarm, it can be T, or (T0, T1, T2)
#[cfg(feature = "browser")]
pub fn set_backend_callback<T>(&self, callback: T) -> Result<()>
where T: MessageHandler<BackendMessage> + Sized + 'static {
let backend = Backend::new(Arc::new(self.clone()), Box::new(callback));
self.processor
.swarm
.set_callback(Arc::new(backend))
.map_err(Error::InternalError)
}

/// Set callback for swarm.
#[deprecated(
note = "set_swarm_callback will be removed in next version, plz use set_backend_callback instead"
Expand Down
65 changes: 65 additions & 0 deletions crates/node/src/tests/wasm/browser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use wasm_bindgen_test::*;
use super::create_connection;
use super::get_peers;
use super::new_provider;
use crate::backend::browser::BackendBehaviour;
use crate::backend::types::BackendMessage;
use crate::prelude::rings_core::utils;
use crate::prelude::rings_core::utils::js_value;
Expand Down Expand Up @@ -81,6 +82,70 @@ async fn test_send_backend_message() {
.unwrap();
}

#[wasm_bindgen_test]
async fn test_handle_backend_message() {
let provider1 = new_provider().await;
let provider2 = new_provider().await;
let behaviour = BackendBehaviour::new();

let js_code_args = "ins, provider, ctx, msg";
// write local msg to global window
let js_code_body = r#"
try {
return new Promise((resolve, reject) => {
console.log("js closure: get message", msg)
window.recentMsg = msg
resolve(undefined)
})
} catch(e) {
return e
}
"#;
let func = js_sys::Function::new_with_args(js_code_args, js_code_body);
behaviour.on("PlainText".to_string(), func);
// provider 1 send backend message to provider 2
// provider 2 set it to local variable
provider2.set_backend_callback(behaviour).unwrap();

let _lis1 = provider1.listen();
let _lis2 = provider2.listen();

create_connection(&provider1, &provider2).await;
console_log!("wait for register");

utils::js_utils::window_sleep(1000).await.unwrap();

let peers = get_peers(&provider1).await;
assert!(peers.len() == 1, "peers len should be 1");
let _peer2 = peers.first().unwrap();

let msg = BackendMessage::PlainText("hello world".to_string());
let req = msg
.into_send_backend_message_request(provider2.address())
.unwrap();

JsFuture::from(provider1.request(
"sendBackendMessage".to_string(),
js_value::serialize(&req).unwrap(),
))
.await
.unwrap();
console_log!("send backend hello world done");
utils::js_utils::window_sleep(3000).await.unwrap();
let global = rings_core::utils::js_utils::global().unwrap();
if let rings_core::utils::js_utils::Global::Window(window) = global {
let ret = window
.get("recentMsg")
.unwrap()
.to_string()
.as_string()
.unwrap();
assert_eq!(&ret, "hello world", "{:?}", ret);
} else {
panic!("cannot get dom window");
}
}

#[wasm_bindgen_test]
async fn test_get_address_from_hex_pubkey() {
let pk = "02c0eeef8d136b10b862a0ac979eac2ad036f9902d87963ddf0fa108f1e275b9c7";
Expand Down
2 changes: 1 addition & 1 deletion crates/snark/src/snark/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ where
}

/// Wrap of nova's public params
#[derive(Serialize, Deserialize, Clone)]
#[derive(Serialize, Deserialize)]
pub struct PublicParams<E1, E2>
where
E1: Engine<Base = <E2 as Engine>::Scalar>,
Expand Down

0 comments on commit dd4a20b

Please sign in to comment.