diff --git a/benchmarks/starknet_abi_base.py b/benchmarks/starknet_abi_base.py index 2450dd0..9c00fb2 100644 --- a/benchmarks/starknet_abi_base.py +++ b/benchmarks/starknet_abi_base.py @@ -16,7 +16,9 @@ def bench_setup(): ] def _run_bench(): - parsed_abi = StarknetAbi.from_json(starknet_eth_abi, "starknet_eth") + parsed_abi = StarknetAbi.from_json( + starknet_eth_abi, "starknet_eth", class_hash=b"" + ) transfer_func = parsed_abi.functions["transfer"].inputs calldata_copy = transfer_calldata.copy() diff --git a/starknet_abi/abi_types.py b/starknet_abi/abi_types.py index 99b1732..98bd9cf 100644 --- a/starknet_abi/abi_types.py +++ b/starknet_abi/abi_types.py @@ -323,3 +323,15 @@ def id_str(self): :return: """ return f"{self.name}:{self.type.id_str()}" + + +# Constant Types + +STARKNET_ACCOUNT_CALL = StarknetStruct( + name="Call", + members=[ + AbiParameter("to", StarknetCoreType.ContractAddress), + AbiParameter("selector", StarknetCoreType.Felt), + AbiParameter("calldata", StarknetArray(StarknetCoreType.Felt)), + ], +) diff --git a/starknet_abi/decode.py b/starknet_abi/decode.py index 77d2f98..82a9442 100644 --- a/starknet_abi/decode.py +++ b/starknet_abi/decode.py @@ -79,6 +79,11 @@ def decode_core_type( ), f"{encoded_int} larger than Felt" return f"0x{encoded_int:064x}" + case StarknetCoreType.EthAddress: + encoded_int = calldata.pop(0) + assert 0 <= encoded_int <= decode_type.max_value(), f"{encoded_int:0x} larger than EthAddress" + return f"0x{encoded_int:040x}" + case StarknetCoreType.NoneType: return "" diff --git a/starknet_abi/decoding_types.py b/starknet_abi/decoding_types.py index 1909119..06e8dd7 100644 --- a/starknet_abi/decoding_types.py +++ b/starknet_abi/decoding_types.py @@ -33,6 +33,19 @@ class DecodedEvent: data: dict[str, Any] +@dataclass(slots=True) +class DecodedOperation: + """ + + Dataclass representing a decoded user operation. If operation is unknown, the name will be set to + 'Unknown' and params set to the raw calldata inputs. + + """ + + operation_name: str + operation_params: dict[str, Any] + + @dataclass(slots=True) class AbiFunction: """ diff --git a/starknet_abi/dispatch.py b/starknet_abi/dispatch.py index 7ada3a7..5198044 100644 --- a/starknet_abi/dispatch.py +++ b/starknet_abi/dispatch.py @@ -1,25 +1,50 @@ +from bisect import bisect_right from dataclasses import dataclass -from typing import Sequence - -from starknet_abi.abi_types import AbiParameter, StarknetType +from typing import Any, Sequence + +from starknet_abi.abi_types import ( + STARKNET_ACCOUNT_CALL, + AbiParameter, + StarknetArray, + StarknetType, +) from starknet_abi.core import StarknetAbi from starknet_abi.decode import decode_from_params, decode_from_types -from starknet_abi.decoding_types import DecodedEvent, DecodedFunction -from starknet_abi.exceptions import InvalidCalldataError +from starknet_abi.decoding_types import DecodedEvent, DecodedFunction, DecodedOperation +from starknet_abi.exceptions import DispatcherDecodeError, InvalidCalldataError # fmt: off -# starknet_keccak(b'__execute__').hex() -EXECUTE_SIGNATURE = "015d40a3d6ca2ac30f4031e42be28da9b056fef9bb7357ac5e85627ee876e5ad" - -# starknet_keccak(b'__validate__').hex() -VALIDATE_SIGNATURE = "0162da33a4585851fe8d3af3c2a9c60b557814e221e0d4f30ff0b2189d9c7775" - -# starknet_keccak(b'__validate_deploy__').hex() -VALIDATE_DEPLOY_SIGNATURE = "036fcbf06cd96843058359e1a75928beacfac10727dab22a3972f0af8aa92895" - -# starknet_keccak(b'__validate_declare__').hex() -VALIDATE_DECLARE_SIGNATURE = "0289da278a8dc833409cabfdad1581e8e7d40e42dcaed693fa4008dcdb4963b3" +# starknet_keccak(b'__execute__')[-8:] +EXECUTE_SIGNATURE = bytes.fromhex("5e85627ee876e5ad") + +# starknet_keccak(b'__validate__')[-8:] +VALIDATE_SIGNATURE = bytes.fromhex("0ff0b2189d9c7775") + +# starknet_keccak(b'__validate_deploy__')[-8:] +VALIDATE_DEPLOY_SIGNATURE = bytes.fromhex("3972f0af8aa92895") + +# starknet_keccak(b'__validate_declare__')[-8:] +VALIDATE_DECLARE_SIGNATURE = bytes.fromhex("fa4008dcdb4963b3") + +CORE_FUNCTIONS: dict[bytes, dict[str, Any]] = { + EXECUTE_SIGNATURE: { + "name": "__execute__", + "inputs": [AbiParameter("calls", StarknetArray(STARKNET_ACCOUNT_CALL))], + }, + VALIDATE_SIGNATURE: { + "name": "__validate__", + "inputs": [AbiParameter("calls", StarknetArray(STARKNET_ACCOUNT_CALL))], + }, + VALIDATE_DEPLOY_SIGNATURE: { + "name": "__validate_deploy__", + "inputs": [], + }, + VALIDATE_DECLARE_SIGNATURE: { + "name": "__validate_declare__", + "inputs": [], + }, +} # fmt: on @@ -60,6 +85,19 @@ class ClassDispatcher: class_hash: bytes +def _parse_call( + calldata: list[int], +) -> tuple[ + bytes, bytes, list[int] # Contract Address # Function Selector # Calldata +]: + contract_address = calldata.pop(0).to_bytes(length=32) + function_selector = calldata.pop(0).to_bytes(length=32) + _calldata_len = calldata.pop(0) + function_calldata = [calldata.pop(0) for _ in range(_calldata_len)] + + return contract_address, function_selector, function_calldata + + @dataclass(slots=True) class DecodingDispatcher: """ @@ -90,6 +128,11 @@ class DecodingDispatcher: ], ] + contract_mapping: dict[ + bytes, # Last 8 bytes of Contract Address + dict[int, bytes], # Mapping of Declaration Blocks to Class Hashes + ] + def __init__(self): self.class_ids = {} self.event_types = {} @@ -166,7 +209,25 @@ def add_abi(self, abi: StarknetAbi): ) self.class_ids.update({class_id: class_dispatcher}) - def decode_function( + def add_contract( + self, contract_address: bytes, class_hash: bytes, declaration_block: int + ): + """ + Adds a contract address and an implementation class to the contract mapping. + """ + if class_hash[-8:] not in self.class_ids: + raise DispatcherDecodeError( + f"Class 0x{class_hash.hex()} not present in dispatcher. Cannot add implementation for " + f"contract 0x{contract_address.hex()}" + ) + + contract_id = contract_address[-8:] + if contract_id not in self.contract_mapping: + self.contract_mapping.update({contract_id: {}}) + + self.contract_mapping[contract_id].update({declaration_block: class_hash}) + + def decode_function( # pylint: disable=too-many-locals self, calldata: list[int], result: list[int], @@ -184,32 +245,47 @@ def decode_function( :param class_hash: class hash of the trace or transaction :return: """ - class_dispatcher = self.class_ids.get(class_hash[-8:]) - if class_dispatcher is None: - return None - - # Both function_dispatcher and function_type should throw if keys not found - function_dispatcher = class_dispatcher.function_ids[function_selector[-8:]] - function_type = self.function_types[function_dispatcher.decoder_reference] + decode_id = function_selector[-8:] + + if decode_id in CORE_FUNCTIONS: + input_types: Sequence[AbiParameter] = CORE_FUNCTIONS[decode_id]["inputs"] + function_name = CORE_FUNCTIONS[decode_id]["name"] + output_types: Sequence[StarknetType] = [] + abi_name = None + + else: + class_dispatcher = self.class_ids.get(class_hash[-8:]) + if class_dispatcher is None: + return None + + # Both function_dispatcher and function_type should throw if keys not found + function_dispatcher = class_dispatcher.function_ids[function_selector[-8:]] + input_types, output_types = self.function_types[ + function_dispatcher.decoder_reference + ] + function_name, abi_name = ( + function_dispatcher.function_name, + class_dispatcher.abi_name, + ) # Copy Arrays that can be consumed by decoder _calldata, _result = calldata.copy(), result.copy() - decoded_inputs = decode_from_params(function_type[0], _calldata) - decoded_outputs = decode_from_types(function_type[1], _result) + decoded_inputs = decode_from_params(input_types, _calldata) + decoded_outputs = decode_from_types(output_types, _result) if len(_calldata) > 0: raise InvalidCalldataError( - f"Calldata Remaining after decoding function input {calldata} from {function_type[0]}" + f"Calldata Remaining after decoding function input {calldata} from {input_types}" ) if len(_result) > 0: raise InvalidCalldataError( - f"Calldata Remaining after decoding function result {result} from {function_type[1]}" + f"Calldata Remaining after decoding function result {result} from {output_types}" ) return DecodedFunction( - abi_name=class_dispatcher.abi_name, - func_name=function_dispatcher.function_name, + abi_name=abi_name, + func_name=function_name, inputs=decoded_inputs, outputs=decoded_outputs, ) @@ -258,3 +334,82 @@ def decode_event( name=event_dispatcher.event_name, data=decoded_data, ) + + def decode_multicall( + self, calldata: list[int], block: int + ) -> list[DecodedOperation]: + """ + Decodes a multicall operation from transaction calldata using the internal mapping of contract_addresses to + implemented class hashes. If the class hash is not present in the dispatcher, the operation is skipped. + + :param calldata: list of integers representing the calldata + :param block: block number to decode the transaction at + :return: list of DecodedOperations + """ + + _calldata = calldata.copy() + operation_count = _calldata.pop(0) + + parsed_calls = [_parse_call(_calldata) for _ in range(operation_count)] + + decoded_operations = [] + for contract_address, function_selector, function_calldata in parsed_calls: + decoded_op = self._decode_tx_call( + contract_address=contract_address, + function_selector=function_selector, + function_calldata=function_calldata, + block=block, + ) + decoded_operations.append(decoded_op) + + return decoded_operations + + def _decode_tx_call( + self, + contract_address: bytes, + function_selector: bytes, + function_calldata: list[int], + block: int, + ) -> DecodedOperation: + contract_implementations = self.contract_mapping.get(contract_address[-8:]) + if contract_implementations is None: + return DecodedOperation( + operation_name="Unknown", + operation_params={"raw_calldata": function_calldata}, + ) + + block_history = sorted(list(contract_implementations.keys())) + if block < block_history[0]: + raise DispatcherDecodeError( + f"Contract 0x{contract_address.hex()} has no implementation history before block {block_history[0]}." + f" Cannot decode transaction at block {block}" + ) + + if len(contract_implementations) == 1: + contract_class = contract_implementations[block_history[0]] + else: + impl_block = block_history[bisect_right(block_history, block) - 1] + contract_class = contract_implementations[impl_block] + + if contract_class[-8:] not in self.class_ids: + return DecodedOperation( + operation_name="Unknown", + operation_params={"raw_calldata": function_calldata}, + ) + + class_dispatcher = self.class_ids[contract_class[-8:]] + try: + function_decoder = class_dispatcher.function_ids[function_selector[-8:]] + function_types, _ = self.function_types[function_decoder.decoder_reference] + decoded_inputs = decode_from_params(function_types, function_calldata) + except Exception: + raise DispatcherDecodeError( # pylint: disable=raise-missing-from + f"Contract 0x{contract_address.hex()} mapped to Class 0x{class_dispatcher.class_hash.hex()} " + f"at block {block} -- Could Note Decode Function 0x{function_selector.hex()} with Calldata " + f"{function_calldata}" + ) + + return DecodedOperation( + operation_name=function_decoder.function_name, + operation_params=decoded_inputs, + ) diff --git a/starknet_abi/exceptions.py b/starknet_abi/exceptions.py index aa26d2a..cf4dce4 100644 --- a/starknet_abi/exceptions.py +++ b/starknet_abi/exceptions.py @@ -57,3 +57,9 @@ class TypeEncodeError(Exception): starknet_abi.exceptions.TypeEncodeError: Integer 131072 is out of range for StarknetCoreType.U16 """ + + +class DispatcherDecodeError(Exception): + """ + Raised when there is an error decoding Functions, Events, or User Operations using the decoding dispatcher + """ diff --git a/starknet_abi/parse.py b/starknet_abi/parse.py index a4f9290..6986fe0 100644 --- a/starknet_abi/parse.py +++ b/starknet_abi/parse.py @@ -3,6 +3,7 @@ from typing import Any from starknet_abi.abi_types import ( + STARKNET_ACCOUNT_CALL, AbiMemberType, AbiParameter, StarknetArray, @@ -127,6 +128,8 @@ def parse_enums_and_structs( case ["core", "array" | "integer" | "bool" | "option", *_]: # Automatically parses Array/Span, u256, bool, and Option types as StarknetCoreType continue + case ["core", "starknet", "account", "Call"]: + continue # Can Hard code in structs like openzeppelin's ERC20 & Events for faster parsing @@ -269,6 +272,10 @@ def _parse_type( # pylint: disable=too-many-return-statements _parse_type(extract_inner_type(abi_type), custom_types) ) + # Matches 'core::starknet::account::Call' + case ["starknet", "account", "Call"]: + return STARKNET_ACCOUNT_CALL + case _: # If unknown type is defined in struct context, return struct if abi_type in custom_types: diff --git a/tests/test_abi_parsing/test_struct_parsing.py b/tests/test_abi_parsing/test_struct_parsing.py index 270e7ce..0a981cb 100644 --- a/tests/test_abi_parsing/test_struct_parsing.py +++ b/tests/test_abi_parsing/test_struct_parsing.py @@ -93,7 +93,7 @@ def test_enum_parsing(): type_dict = parse_enums_and_structs(grouped_abi["type_def"]) - assert len(type_dict) == 5 + assert len(type_dict) == 4 assert type_dict["account::escape::EscapeStatus"] == StarknetEnum( name="account::escape::EscapeStatus",