Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements of chain extension implementation #649

Merged
merged 3 commits into from
Feb 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 140 additions & 2 deletions crates/pink/pink-extension/macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use proc_macro_crate::{crate_name, FoundCrate};
use quote::quote;
use syn::{Result, parse_macro_input, spanned::Spanned};
use syn::{parse_macro_input, spanned::Spanned, Result};

use ink_lang_ir::{HexLiteral as _, ImplItem, Selector};
use ink_lang_ir::{ChainExtension, HexLiteral as _, ImplItem, Selector};

/// A drop-in replacement for `ink_lang::contract` with pink-specific feature extensions.
///
Expand Down Expand Up @@ -131,3 +131,141 @@ fn find_crate_name(origin: &str) -> Result<syn::Ident> {
};
Ok(name)
}

/// Internal use only.
#[proc_macro_attribute]
pub fn chain_extension(_: TokenStream, input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as TokenStream2);
let output = patch_chain_extension(input);
output.into()
}

fn patch_chain_extension(input: TokenStream2) -> TokenStream2 {
match patch_chain_extension_or_err(input) {
Ok(tokens) => tokens,
Err(err) => err.to_compile_error(),
}
}

fn patch_chain_extension_or_err(input: TokenStream2) -> Result<TokenStream2> {
use proc_macro2::{Ident, Literal, Span};

let backend_trait = {
let mut item_trait: syn::ItemTrait = syn::parse2(input.clone())?;

item_trait.ident = syn::Ident::new(
&format!("{}Backend", item_trait.ident),
item_trait.ident.span(),
);

item_trait.items.retain(|i| {
if let &syn::TraitItem::Type(_) = i {
false
} else {
true
}
});

item_trait.items.push(syn::parse_quote! {
type Error;
});

for item in item_trait.items.iter_mut() {
if let syn::TraitItem::Method(item_method) = item {
item_method.attrs.clear();
item_method.sig.inputs.insert(0, syn::parse_quote! {
&self
});
item_method.sig.output = match item_method.sig.output.clone() {
syn::ReturnType::Type(_, tp) => {
syn::parse_quote! {
-> Result<#tp, Self::Error>
}
}
syn::ReturnType::Default => {
syn::parse_quote! {
-> Result<(), Self::Error>
}
}
};
}
}

item_trait
};

let id_pairs: Vec<_> = {
let extension = ChainExtension::new(Default::default(), input.clone())?;
extension
.iter_methods()
.map(|m| {
let name = m.ident().to_string();
let id = m.id().into_u32();
(name, id)
})
.collect()
};

// Extract all function ids to a sub module
let func_ids = {
let mut mod_item: syn::ItemMod = syn::parse_quote! {
pub mod func_ids {}
};
for (name, id) in id_pairs.iter() {
let name = name.to_uppercase();
let name = Ident::new(&name, Span::call_site());
let id = Literal::u32_unsuffixed(*id);
mod_item
.content
.as_mut()
.unwrap()
.1
.push(syn::parse_quote! {
pub const #name: u32 = #id;
});
}
mod_item
};

// Generate the dispatcher
let dispatcher: syn::ItemMacro = {
let (names, ids): (Vec<_>, Vec<_>) = id_pairs
.into_iter()
.map(|(name, id)| {
let name = Ident::new(&name, Span::call_site());
let id = Literal::u32_unsuffixed(id);
(name, id)
})
.unzip();
syn::parse_quote! {
#[macro_export]
macro_rules! dispatch_ext_call {
($func_id: expr, $handler: expr, $env: expr) => {
match $func_id {
#(
#ids => {
let input = $env.read_as_unbounded($env.in_len())?;
let output = $handler.#names(input)?;
let output = output.encode();
Some(output)
}
)*
_ => None,
}
};
}
}
};

let crate_ink_lang = find_crate_name("ink_lang")?;
Ok(quote! {
#[#crate_ink_lang::chain_extension]
#input

#backend_trait

#func_ids

#dispatcher
})
}
11 changes: 3 additions & 8 deletions crates/pink/pink-extension/src/chain_extension.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use alloc::vec::Vec;
use alloc::borrow::Cow;
use ink_lang as ink;
use ink::ChainExtensionInstance;

Expand All @@ -10,12 +11,6 @@ mod signing;

#[cfg(feature = "std")]
pub mod test;
pub mod func_ids {
pub const HTTP_REQUEST: u32 = 0xff000001;
pub const SIGN: u32 = 0xff000002;
pub const VERIFY: u32 = 0xff000003;
pub const DERIVE_SR25519_PAIR: u32 = 0xff000004;
}

#[derive(scale::Encode, scale::Decode)]
#[cfg_attr(feature = "std", derive(scale_info::TypeInfo))]
Expand All @@ -31,7 +26,7 @@ impl ink_env::chain_extension::FromStatusCode for ErrorCode {
}

/// Extensions for the ink runtime defined by fat contract.
#[ink::chain_extension]
#[pink_extension_macro::chain_extension]
pub trait PinkExt {
type ErrorCode = ErrorCode;

Expand All @@ -46,7 +41,7 @@ pub trait PinkExt {
fn verify(args: VerifyArgs) -> bool;

#[ink(extension = 0xff000004, handle_status = false, returns_result = false)]
fn derive_sr25519_pair(salt: &[u8]) -> (Vec<u8>, Vec<u8>);
fn derive_sr25519_pair(salt: Cow<[u8]>) -> (Vec<u8>, Vec<u8>);
}

pub fn pink_extension_instance() -> <PinkExt as ChainExtensionInstance>::Instance {
Expand Down
2 changes: 1 addition & 1 deletion crates/pink/pink-extension/src/chain_extension/signing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,6 @@ macro_rules! verify {
macro_rules! derive_sr25519_pair {
($salt: expr) => {{
let salt: &[u8] = $salt.as_ref();
$crate::pink_extension_instance().derive_sr25519_pair(salt)
$crate::pink_extension_instance().derive_sr25519_pair(salt.into())
}};
}
96 changes: 40 additions & 56 deletions crates/pink/src/runtime/extension.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
use std::borrow::Cow;
use std::{convert::TryFrom, time::Duration};

use frame_support::log::error;
use pallet_contracts::chain_extension::{
ChainExtension, Environment, Ext, InitState, RetVal, SysConfig, UncheckedFrom,
};
use phala_crypto::sr25519::{Persistence, KDF};
use pink_extension::PinkEvent;
use pink_extension::{
chain_extension::{HttpRequest, HttpResponse, PinkExtBackend, SigType, SignArgs, VerifyArgs},
dispatch_ext_call, PinkEvent,
};
use scale::{Decode, Encode};
use sp_core::Pair;
use sp_runtime::DispatchError;
Expand Down Expand Up @@ -62,43 +66,46 @@ pub struct PinkExtension;
impl ChainExtension<super::PinkRuntime> for PinkExtension {
fn call<E: Ext>(func_id: u32, env: Environment<E, InitState>) -> Result<RetVal, DispatchError>
where
<E::T as SysConfig>::AccountId: UncheckedFrom<<E::T as SysConfig>::Hash> + AsRef<[u8]>,
<E::T as SysConfig>::AccountId:
UncheckedFrom<<E::T as SysConfig>::Hash> + AsRef<[u8]> + Clone,
{
use pink_extension::chain_extension::func_ids::*;

let call = Call { env };
// func_id refer to https://github.com/patractlabs/PIPs/blob/main/PIPs/pip-100.md
match func_id {
HTTP_REQUEST => call.http_request(),
SIGN => call.sign(),
VERIFY => call.verify(),
DERIVE_SR25519_PAIR => call.derive_sr25519_pair(),
_ => {
let mut env = env.buf_in_buf_out();
let call = Call {
address: env.ext().address().clone(),
};
let output = match dispatch_ext_call!(func_id, call, env) {
Some(output) => output,
None => {
error!(target: "pink", "Called an unregistered `func_id`: {:}", func_id);
Err(DispatchError::Other("Unimplemented func_id"))
return Err(DispatchError::Other(
"PinkExtension::call: unknown function",
))
}
}
};
env.write(&output, false, None)
.or(Err(DispatchError::Other(
"PinkExtension::call: failed to write output",
)))?;
Ok(RetVal::Converging(0))
}
}

struct Call<'a, 'b, E: Ext> {
env: Environment<'a, 'b, E, InitState>,
struct Call<AccountId> {
address: AccountId,
}

impl<'a, 'b, E: Ext> Call<'a, 'b, E>
impl<AccountId> PinkExtBackend for Call<AccountId>
where
<E::T as SysConfig>::AccountId: UncheckedFrom<<E::T as SysConfig>::Hash> + AsRef<[u8]>,
AccountId: AsRef<[u8]>,
{
fn http_request(self) -> Result<RetVal, DispatchError> {
use pink_extension::chain_extension::{HttpRequest, HttpResponse};
type Error = DispatchError;
fn http_request(&self, request: HttpRequest) -> Result<HttpResponse, Self::Error> {
if !matches!(get_call_mode(), Some(CallMode::Query)) {
return Err(DispatchError::Other(
"http_request can only be called in query mode",
));
}

let mut env = self.env.buf_in_buf_out();
let request: HttpRequest = env.read_as_unbounded(env.in_len())?;
let uri = http_req::uri::Uri::try_from(request.url.as_str())
.or(Err(DispatchError::Other("Invalid URL")))?;

Expand Down Expand Up @@ -146,41 +153,29 @@ where
body,
headers,
};
env.write(&response.encode(), false, None)
.map_err(|_| DispatchError::Other("ChainExtension failed to return http_request"))?;
Ok(RetVal::Converging(0))
Ok(response)
}

fn sign(self) -> Result<RetVal, DispatchError> {
use pink_extension::chain_extension::{SigType, SignArgs};
let mut env = self.env.buf_in_buf_out();
let args: SignArgs = env.read_as_unbounded(env.in_len())?;

fn sign(&self, args: SignArgs) -> Result<Vec<u8>, Self::Error> {
macro_rules! sign_with {
($sigtype:ident) => {{
let pair = sp_core::$sigtype::Pair::from_seed_slice(&args.key)
.or(Err(DispatchError::Other("Invalid key")))?;
let signature = pair.sign(&args.message);
let signature: &[u8] = signature.as_ref();
env.write(&signature.encode(), false, None)
.map_err(|_| DispatchError::Other("ChainExtension failed to return sign"))?;
signature.to_vec()
}};
}

match args.sigtype {
Ok(match args.sigtype {
SigType::Sr25519 => sign_with!(sr25519),
SigType::Ed25519 => sign_with!(ed25519),
SigType::Ecdsa => sign_with!(ecdsa),
}
Ok(RetVal::Converging(0))
})
}

fn verify(self) -> Result<RetVal, DispatchError> {
use pink_extension::chain_extension::{SigType, VerifyArgs};
let mut env = self.env.buf_in_buf_out();
let args: VerifyArgs = env.read_as_unbounded(env.in_len())?;

let result = match args.sigtype {
fn verify(&self, args: VerifyArgs) -> Result<bool, Self::Error> {
Ok(match args.sigtype {
SigType::Sr25519 => {
sp_core::sr25519::Pair::verify_weak(&args.signature, &args.message, &args.pubkey)
}
Expand All @@ -190,33 +185,22 @@ where
SigType::Ecdsa => {
sp_core::ecdsa::Pair::verify_weak(&args.signature, &args.message, &args.pubkey)
}
};
env.write(&result.encode(), false, None)
.map_err(|_| DispatchError::Other("ChainExtension failed to return verify"))?;
Ok(RetVal::Converging(0))
})
}

fn derive_sr25519_pair(self) -> Result<RetVal, DispatchError> {
let mut env = self.env.buf_in_buf_out();
let salt: Vec<u8> = env.read_as_unbounded(env.in_len())?;
fn derive_sr25519_pair(&self, salt: Cow<[u8]>) -> Result<(Vec<u8>, Vec<u8>), Self::Error> {
let seed =
crate::runtime::Pink::key_seed().ok_or(DispatchError::Other("Key seed missing"))?;
let seed_key = sp_core::sr25519::Pair::restore_from_secret_key(&seed);
let contract_address = env.ext().address();
let contract_address: &[u8] = contract_address.as_ref();
let contract_address: &[u8] = self.address.as_ref();
let derived_pair = seed_key
.derive_sr25519_pair(&[contract_address, &salt, b"keygen"])
.or(Err(DispatchError::Other("Failed to derive sr25519 pair")))?;
let priviate_key = derived_pair.dump_secret_key();
let priviate_key: &[u8] = priviate_key.as_ref();
let public_key = derived_pair.public();
let public_key: &[u8] = public_key.as_ref();

env.write(&(priviate_key, public_key).encode(), false, None)
.map_err(|_| {
DispatchError::Other("ChainExtension failed to return derive_sr25519_pair")
})?;
Ok(RetVal::Converging(0))
Ok((priviate_key.to_vec(), public_key.to_vec()))
}
}

Expand Down