diff --git a/src/engine.rs b/src/engine.rs index a1190b97d..957a3208b 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -171,8 +171,7 @@ impl Engine { } pub fn get_code_size(address: &Address) -> usize { - // TODO: Seems this can be optimized to only read the register length. - Engine::get_code(&address).len() + sdk::read_storage_len(&address_to_key(KeyPrefix::Code, address)).unwrap_or(0) } pub fn set_nonce(address: &Address, nonce: &U256) { diff --git a/src/lib.rs b/src/lib.rs index 6ef419227..71046d95e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,13 +35,13 @@ mod tests; #[cfg(feature = "contract")] mod contract { - use borsh::{BorshDeserialize, BorshSerialize}; + use borsh::BorshSerialize; use crate::engine::{Engine, EngineResult, EngineState}; #[cfg(feature = "evm_bully")] use crate::parameters::{BeginBlockArgs, BeginChainArgs}; use crate::parameters::{FunctionCallArgs, GetStorageAtArgs, NewCallArgs, ViewCallArgs}; - use crate::prelude::{Address, TryInto, H256, U256}; + use crate::prelude::{Address, H256, U256}; use crate::sdk; use crate::storage::{bytes_to_key, KeyPrefix}; use crate::types::{near_account_to_evm_address, u256_to_arr}; @@ -78,7 +78,7 @@ mod contract { if !state.owner_id.is_empty() { require_owner_only(&state); } - let args = NewCallArgs::try_from_slice(&sdk::read_input()).sdk_expect("ERR_ARG_PARSE"); + let args: NewCallArgs = sdk::read_input_borsh().sdk_unwrap(); Engine::set_state(args.into()); } @@ -116,7 +116,8 @@ mod contract { pub extern "C" fn get_upgrade_index() { let state = Engine::get_state(); let index = sdk::read_u64(&bytes_to_key(KeyPrefix::Config, CODE_STAGE_KEY)) - .sdk_expect("ERR_NO_UPGRADE"); + .sdk_expect("ERR_NO_UPGRADE") + .sdk_unwrap(); sdk::return_output(&(index + state.upgrade_delay_blocks).to_le_bytes()) } @@ -136,7 +137,9 @@ mod contract { #[no_mangle] pub extern "C" fn deploy_upgrade() { let state = Engine::get_state(); - let index = sdk::read_u64(&bytes_to_key(KeyPrefix::Config, CODE_STAGE_KEY)).sdk_unwrap(); + let index = sdk::read_u64(&bytes_to_key(KeyPrefix::Config, CODE_STAGE_KEY)) + .sdk_expect("ERR_NO_UPGRADE") + .sdk_unwrap(); if sdk::block_index() <= index + state.upgrade_delay_blocks { sdk::panic_utf8(b"ERR_NOT_ALLOWED:TOO_EARLY"); } @@ -161,9 +164,7 @@ mod contract { /// Call method on the EVM contract. #[no_mangle] pub extern "C" fn call() { - // TODO: Borsh input pattern is so common here. It worth writing sdk::read_input_borsh(). - let input = sdk::read_input(); - let args = FunctionCallArgs::try_from_slice(&input).sdk_expect("ERR_ARG_PARSE"); + let args: FunctionCallArgs = sdk::read_input_borsh().sdk_unwrap(); let mut engine = Engine::new(predecessor_address()); Engine::call_with_args(&mut engine, args) .map(|res| res.try_to_vec().sdk_expect("ERR_SERIALIZE")) @@ -257,14 +258,12 @@ mod contract { #[no_mangle] pub extern "C" fn register_relayer() { - let relayer_address = sdk::read_input(); - // NOTE: Why not `sdk::read_input_arr20();`? - assert_eq!(relayer_address.len(), 20); + let relayer_address = sdk::read_input_arr20().sdk_unwrap(); let mut engine = Engine::new(predecessor_address()); engine.register_relayer( sdk::predecessor_account_id().as_slice(), - Address(relayer_address.as_slice().try_into().unwrap()), + Address(relayer_address), ); } @@ -287,8 +286,7 @@ mod contract { #[no_mangle] pub extern "C" fn view() { - let input = sdk::read_input(); - let args = ViewCallArgs::try_from_slice(&input).sdk_expect("ERR_ARG_PARSE"); + let args: ViewCallArgs = sdk::read_input_borsh().sdk_unwrap(); let engine = Engine::new(Address::from_slice(&args.sender)); let result = Engine::view_with_args(&engine, args); result.sdk_process() @@ -296,29 +294,28 @@ mod contract { #[no_mangle] pub extern "C" fn get_code() { - let address = sdk::read_input_arr20(); + let address = sdk::read_input_arr20().sdk_unwrap(); let code = Engine::get_code(&Address(address)); sdk::return_output(&code) } #[no_mangle] pub extern "C" fn get_balance() { - let address = sdk::read_input_arr20(); + let address = sdk::read_input_arr20().sdk_unwrap(); let balance = Engine::get_balance(&Address(address)); sdk::return_output(&u256_to_arr(&balance)) } #[no_mangle] pub extern "C" fn get_nonce() { - let address = sdk::read_input_arr20(); + let address = sdk::read_input_arr20().sdk_unwrap(); let nonce = Engine::get_nonce(&Address(address)); sdk::return_output(&u256_to_arr(&nonce)) } #[no_mangle] pub extern "C" fn get_storage_at() { - let input = sdk::read_input(); - let args = GetStorageAtArgs::try_from_slice(&input).sdk_expect("ERR_ARG_PARSE"); + let args: GetStorageAtArgs = sdk::read_input_borsh().sdk_unwrap(); let value = Engine::get_storage(&Address(args.address), &H256(args.key)); sdk::return_output(&value.0) } @@ -333,7 +330,7 @@ mod contract { let mut state = Engine::get_state(); require_owner_only(&state); let input = sdk::read_input(); - let args = BeginChainArgs::try_from_slice(&input).sdk_expect("ERR_ARG_PARSE"); + let args: BeginBlockArgs = sdk::read_input_borsh().sdk_unwrap(); state.chain_id = args.chain_id; Engine::set_state(state); // set genesis block balances @@ -353,7 +350,7 @@ mod contract { let state = Engine::get_state(); require_owner_only(&state); let input = sdk::read_input(); - let _args = BeginBlockArgs::try_from_slice(&input).sdk_expect("ERR_ARG_PARSE"); + let _args: BeginBlockArgs = sdk::read_input_borsh().sdk_unwrap(); // TODO: https://github.com/aurora-is-near/aurora-engine/issues/2 } diff --git a/src/meta_parsing.rs b/src/meta_parsing.rs index 219f2302c..7619018cf 100644 --- a/src/meta_parsing.rs +++ b/src/meta_parsing.rs @@ -13,6 +13,7 @@ pub enum ParsingError { InvalidMetaTransactionMethodName, InvalidMetaTransactionFunctionArg, InvalidEcRecoverSignature, + ArgsLengthMismatch, } pub type ParsingResult = core::result::Result; @@ -505,9 +506,11 @@ pub fn prepare_meta_call_args( let mut arg_bytes = Vec::new(); arg_bytes.extend_from_slice(&keccak(arguments.as_bytes()).as_bytes()); let args_decoded: Vec = rlp_decode(&input.input)?; + if methods.method.args.len() != args_decoded.len() { + return Err(ParsingError::ArgsLengthMismatch); + } for (i, arg) in args_decoded.iter().enumerate() { arg_bytes.extend_from_slice(&eip_712_hash_argument( - // TODO: Check that method.args.len() == args_decoded.len(). Otherwise it may panic here. &methods.method.args[i].t, arg, &methods.types, diff --git a/src/sdk.rs b/src/sdk.rs index b34efc3c3..fc4faa4a9 100644 --- a/src/sdk.rs +++ b/src/sdk.rs @@ -2,6 +2,9 @@ use crate::prelude::{vec, String, Vec, H256}; use crate::types::STORAGE_PRICE_PER_BYTE; use borsh::{BorshDeserialize, BorshSerialize}; +const READ_STORAGE_REGISTER_ID: u64 = 0; +const INPUT_REGISTER_ID: u64 = 0; + mod exports { #[allow(unused)] @@ -156,24 +159,30 @@ mod exports { } } -#[allow(dead_code)] pub fn read_input() -> Vec { unsafe { - exports::input(0); - let bytes: Vec = vec![0; exports::register_len(0) as usize]; - exports::read_register(0, bytes.as_ptr() as *const u64 as u64); + exports::input(INPUT_REGISTER_ID); + let bytes: Vec = vec![0; exports::register_len(INPUT_REGISTER_ID) as usize]; + exports::read_register(INPUT_REGISTER_ID, bytes.as_ptr() as *const u64 as u64); bytes } } -#[allow(dead_code)] -pub fn read_input_arr20() -> [u8; 20] { +pub(crate) fn read_input_borsh() -> Result { + let bytes = read_input(); + T::try_from_slice(&bytes).map_err(|_| ArgParseErr) +} + +pub(crate) fn read_input_arr20() -> Result<[u8; 20], IncorrectInputLength> { unsafe { - exports::input(0); - let bytes = [0u8; 20]; - // TODO: Is it fine to not check the length of the input register here? - exports::read_register(0, bytes.as_ptr() as *const u64 as u64); - bytes + exports::input(INPUT_REGISTER_ID); + if exports::register_len(INPUT_REGISTER_ID) == 20 { + let bytes = [0u8; 20]; + exports::read_register(INPUT_REGISTER_ID, bytes.as_ptr() as *const u64 as u64); + Ok(bytes) + } else { + Err(IncorrectInputLength) + } } } @@ -195,11 +204,25 @@ pub fn return_output(value: &[u8]) { #[allow(dead_code)] pub fn read_storage(key: &[u8]) -> Option> { + read_storage_len(key).map(|value_size| unsafe { + let bytes = vec![0u8; value_size]; + exports::read_register( + READ_STORAGE_REGISTER_ID, + bytes.as_ptr() as *const u64 as u64, + ); + bytes + }) +} + +pub fn read_storage_len(key: &[u8]) -> Option { unsafe { - if exports::storage_read(key.len() as u64, key.as_ptr() as u64, 0) == 1 { - let bytes: Vec = vec![0u8; exports::register_len(0) as usize]; - exports::read_register(0, bytes.as_ptr() as *const u64 as u64); - Some(bytes) + if exports::storage_read( + key.len() as u64, + key.as_ptr() as u64, + READ_STORAGE_REGISTER_ID, + ) == 1 + { + Some(exports::register_len(READ_STORAGE_REGISTER_ID) as usize) } else { None } @@ -207,17 +230,16 @@ pub fn read_storage(key: &[u8]) -> Option> { } /// Read u64 from storage at given key. -pub fn read_u64(key: &[u8]) -> Option { - unsafe { - if exports::storage_read(key.len() as u64, key.as_ptr() as u64, 0) == 1 { +pub(crate) fn read_u64(key: &[u8]) -> Option> { + read_storage_len(key).map(|value_size| unsafe { + if value_size == 8 { let result = [0u8; 8]; - // TODO: Are you sure the register length is correct? - exports::read_register(0, result.as_ptr() as _); - Some(u64::from_le_bytes(result)) + exports::read_register(READ_STORAGE_REGISTER_ID, result.as_ptr() as _); + Ok(u64::from_le_bytes(result)) } else { - None + Err(InvalidU64) } - } + }) } #[allow(dead_code)] @@ -484,3 +506,24 @@ pub fn promise_batch_create(account_id: String) -> u64 { pub fn storage_has_key(key: &[u8]) -> bool { unsafe { exports::storage_has_key(key.len() as u64, key.as_ptr() as u64) == 1 } } + +pub(crate) struct IncorrectInputLength; +impl AsRef<[u8]> for IncorrectInputLength { + fn as_ref(&self) -> &[u8] { + b"ERR_INCORRECT_INPUT_LENGTH" + } +} + +pub(crate) struct ArgParseErr; +impl AsRef<[u8]> for ArgParseErr { + fn as_ref(&self) -> &[u8] { + b"ERR_ARG_PARSE" + } +} + +pub(crate) struct InvalidU64; +impl AsRef<[u8]> for InvalidU64 { + fn as_ref(&self) -> &[u8] { + b"ERR_NOT_U64" + } +}