Skip to content

Commit

Permalink
adding contract implementation dispatching
Browse files Browse the repository at this point in the history
  • Loading branch information
elicbarbieri committed Apr 15, 2024
1 parent b323cf1 commit 7ff1eb5
Show file tree
Hide file tree
Showing 8 changed files with 232 additions and 32 deletions.
4 changes: 3 additions & 1 deletion benchmarks/starknet_abi_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ def bench_setup():
]

def _run_bench():
parsed_abi = StarknetAbi.from_json(starknet_eth_abi, "starknet_eth")
parsed_abi = StarknetAbi.from_json(
starknet_eth_abi, "starknet_eth", class_hash=b""
)

transfer_func = parsed_abi.functions["transfer"].inputs
calldata_copy = transfer_calldata.copy()
Expand Down
12 changes: 12 additions & 0 deletions starknet_abi/abi_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,3 +323,15 @@ def id_str(self):
:return:
"""
return f"{self.name}:{self.type.id_str()}"


# Constant Types

STARKNET_ACCOUNT_CALL = StarknetStruct(
name="Call",
members=[
AbiParameter("to", StarknetCoreType.ContractAddress),
AbiParameter("selector", StarknetCoreType.Felt),
AbiParameter("calldata", StarknetArray(StarknetCoreType.Felt)),
],
)
5 changes: 5 additions & 0 deletions starknet_abi/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ def decode_core_type(
), f"{encoded_int} larger than Felt"
return f"0x{encoded_int:064x}"

case StarknetCoreType.EthAddress:
encoded_int = calldata.pop(0)
assert 0 <= encoded_int <= decode_type.max_value(), f"{encoded_int:0x} larger than EthAddress"
return f"0x{encoded_int:040x}"

case StarknetCoreType.NoneType:
return ""

Expand Down
13 changes: 13 additions & 0 deletions starknet_abi/decoding_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,19 @@ class DecodedEvent:
data: dict[str, Any]


@dataclass(slots=True)
class DecodedOperation:
"""
Dataclass representing a decoded user operation. If operation is unknown, the name will be set to
'Unknown' and params set to the raw calldata inputs.
"""

operation_name: str
operation_params: dict[str, Any]


@dataclass(slots=True)
class AbiFunction:
"""
Expand Down
215 changes: 185 additions & 30 deletions starknet_abi/dispatch.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,50 @@
from bisect import bisect_right
from dataclasses import dataclass
from typing import Sequence

from starknet_abi.abi_types import AbiParameter, StarknetType
from typing import Any, Sequence

from starknet_abi.abi_types import (
STARKNET_ACCOUNT_CALL,
AbiParameter,
StarknetArray,
StarknetType,
)
from starknet_abi.core import StarknetAbi
from starknet_abi.decode import decode_from_params, decode_from_types
from starknet_abi.decoding_types import DecodedEvent, DecodedFunction
from starknet_abi.exceptions import InvalidCalldataError
from starknet_abi.decoding_types import DecodedEvent, DecodedFunction, DecodedOperation
from starknet_abi.exceptions import DispatcherDecodeError, InvalidCalldataError

# fmt: off

# starknet_keccak(b'__execute__').hex()
EXECUTE_SIGNATURE = "015d40a3d6ca2ac30f4031e42be28da9b056fef9bb7357ac5e85627ee876e5ad"

# starknet_keccak(b'__validate__').hex()
VALIDATE_SIGNATURE = "0162da33a4585851fe8d3af3c2a9c60b557814e221e0d4f30ff0b2189d9c7775"

# starknet_keccak(b'__validate_deploy__').hex()
VALIDATE_DEPLOY_SIGNATURE = "036fcbf06cd96843058359e1a75928beacfac10727dab22a3972f0af8aa92895"

# starknet_keccak(b'__validate_declare__').hex()
VALIDATE_DECLARE_SIGNATURE = "0289da278a8dc833409cabfdad1581e8e7d40e42dcaed693fa4008dcdb4963b3"
# starknet_keccak(b'__execute__')[-8:]
EXECUTE_SIGNATURE = bytes.fromhex("5e85627ee876e5ad")

# starknet_keccak(b'__validate__')[-8:]
VALIDATE_SIGNATURE = bytes.fromhex("0ff0b2189d9c7775")

# starknet_keccak(b'__validate_deploy__')[-8:]
VALIDATE_DEPLOY_SIGNATURE = bytes.fromhex("3972f0af8aa92895")

# starknet_keccak(b'__validate_declare__')[-8:]
VALIDATE_DECLARE_SIGNATURE = bytes.fromhex("fa4008dcdb4963b3")

CORE_FUNCTIONS: dict[bytes, dict[str, Any]] = {
EXECUTE_SIGNATURE: {
"name": "__execute__",
"inputs": [AbiParameter("calls", StarknetArray(STARKNET_ACCOUNT_CALL))],
},
VALIDATE_SIGNATURE: {
"name": "__validate__",
"inputs": [AbiParameter("calls", StarknetArray(STARKNET_ACCOUNT_CALL))],
},
VALIDATE_DEPLOY_SIGNATURE: {
"name": "__validate_deploy__",
"inputs": [],
},
VALIDATE_DECLARE_SIGNATURE: {
"name": "__validate_declare__",
"inputs": [],
},
}

# fmt: on

Expand Down Expand Up @@ -60,6 +85,19 @@ class ClassDispatcher:
class_hash: bytes


def _parse_call(
calldata: list[int],
) -> tuple[
bytes, bytes, list[int] # Contract Address # Function Selector # Calldata
]:
contract_address = calldata.pop(0).to_bytes(length=32)
function_selector = calldata.pop(0).to_bytes(length=32)
_calldata_len = calldata.pop(0)
function_calldata = [calldata.pop(0) for _ in range(_calldata_len)]

return contract_address, function_selector, function_calldata


@dataclass(slots=True)
class DecodingDispatcher:
"""
Expand Down Expand Up @@ -90,6 +128,11 @@ class DecodingDispatcher:
],
]

contract_mapping: dict[
bytes, # Last 8 bytes of Contract Address
dict[int, bytes], # Mapping of Declaration Blocks to Class Hashes
]

def __init__(self):
self.class_ids = {}
self.event_types = {}
Expand Down Expand Up @@ -166,7 +209,25 @@ def add_abi(self, abi: StarknetAbi):
)
self.class_ids.update({class_id: class_dispatcher})

def decode_function(
def add_contract(
self, contract_address: bytes, class_hash: bytes, declaration_block: int
):
"""
Adds a contract address and an implementation class to the contract mapping.
"""
if class_hash[-8:] not in self.class_ids:
raise DispatcherDecodeError(
f"Class 0x{class_hash.hex()} not present in dispatcher. Cannot add implementation for "
f"contract 0x{contract_address.hex()}"
)

contract_id = contract_address[-8:]
if contract_id not in self.contract_mapping:
self.contract_mapping.update({contract_id: {}})

self.contract_mapping[contract_id].update({declaration_block: class_hash})

def decode_function( # pylint: disable=too-many-locals
self,
calldata: list[int],
result: list[int],
Expand All @@ -184,32 +245,47 @@ def decode_function(
:param class_hash: class hash of the trace or transaction
:return:
"""
class_dispatcher = self.class_ids.get(class_hash[-8:])
if class_dispatcher is None:
return None

# Both function_dispatcher and function_type should throw if keys not found
function_dispatcher = class_dispatcher.function_ids[function_selector[-8:]]
function_type = self.function_types[function_dispatcher.decoder_reference]
decode_id = function_selector[-8:]

if decode_id in CORE_FUNCTIONS:
input_types: Sequence[AbiParameter] = CORE_FUNCTIONS[decode_id]["inputs"]
function_name = CORE_FUNCTIONS[decode_id]["name"]
output_types: Sequence[StarknetType] = []
abi_name = None

else:
class_dispatcher = self.class_ids.get(class_hash[-8:])
if class_dispatcher is None:
return None

# Both function_dispatcher and function_type should throw if keys not found
function_dispatcher = class_dispatcher.function_ids[function_selector[-8:]]
input_types, output_types = self.function_types[
function_dispatcher.decoder_reference
]
function_name, abi_name = (
function_dispatcher.function_name,
class_dispatcher.abi_name,
)

# Copy Arrays that can be consumed by decoder
_calldata, _result = calldata.copy(), result.copy()
decoded_inputs = decode_from_params(function_type[0], _calldata)
decoded_outputs = decode_from_types(function_type[1], _result)
decoded_inputs = decode_from_params(input_types, _calldata)
decoded_outputs = decode_from_types(output_types, _result)

if len(_calldata) > 0:
raise InvalidCalldataError(
f"Calldata Remaining after decoding function input {calldata} from {function_type[0]}"
f"Calldata Remaining after decoding function input {calldata} from {input_types}"
)

if len(_result) > 0:
raise InvalidCalldataError(
f"Calldata Remaining after decoding function result {result} from {function_type[1]}"
f"Calldata Remaining after decoding function result {result} from {output_types}"
)

return DecodedFunction(
abi_name=class_dispatcher.abi_name,
func_name=function_dispatcher.function_name,
abi_name=abi_name,
func_name=function_name,
inputs=decoded_inputs,
outputs=decoded_outputs,
)
Expand Down Expand Up @@ -258,3 +334,82 @@ def decode_event(
name=event_dispatcher.event_name,
data=decoded_data,
)

def decode_multicall(
self, calldata: list[int], block: int
) -> list[DecodedOperation]:
"""
Decodes a multicall operation from transaction calldata using the internal mapping of contract_addresses to
implemented class hashes. If the class hash is not present in the dispatcher, the operation is skipped.
:param calldata: list of integers representing the calldata
:param block: block number to decode the transaction at
:return: list of DecodedOperations
"""

_calldata = calldata.copy()
operation_count = _calldata.pop(0)

parsed_calls = [_parse_call(_calldata) for _ in range(operation_count)]

decoded_operations = []
for contract_address, function_selector, function_calldata in parsed_calls:
decoded_op = self._decode_tx_call(
contract_address=contract_address,
function_selector=function_selector,
function_calldata=function_calldata,
block=block,
)
decoded_operations.append(decoded_op)

return decoded_operations

def _decode_tx_call(
self,
contract_address: bytes,
function_selector: bytes,
function_calldata: list[int],
block: int,
) -> DecodedOperation:
contract_implementations = self.contract_mapping.get(contract_address[-8:])
if contract_implementations is None:
return DecodedOperation(
operation_name="Unknown",
operation_params={"raw_calldata": function_calldata},
)

block_history = sorted(list(contract_implementations.keys()))
if block < block_history[0]:
raise DispatcherDecodeError(
f"Contract 0x{contract_address.hex()} has no implementation history before block {block_history[0]}."
f" Cannot decode transaction at block {block}"
)

if len(contract_implementations) == 1:
contract_class = contract_implementations[block_history[0]]
else:
impl_block = block_history[bisect_right(block_history, block) - 1]
contract_class = contract_implementations[impl_block]

if contract_class[-8:] not in self.class_ids:
return DecodedOperation(
operation_name="Unknown",
operation_params={"raw_calldata": function_calldata},
)

class_dispatcher = self.class_ids[contract_class[-8:]]
try:
function_decoder = class_dispatcher.function_ids[function_selector[-8:]]
function_types, _ = self.function_types[function_decoder.decoder_reference]
decoded_inputs = decode_from_params(function_types, function_calldata)
except Exception:
raise DispatcherDecodeError( # pylint: disable=raise-missing-from
f"Contract 0x{contract_address.hex()} mapped to Class 0x{class_dispatcher.class_hash.hex()} "
f"at block {block} -- Could Note Decode Function 0x{function_selector.hex()} with Calldata "
f"{function_calldata}"
)

return DecodedOperation(
operation_name=function_decoder.function_name,
operation_params=decoded_inputs,
)
6 changes: 6 additions & 0 deletions starknet_abi/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,9 @@ class TypeEncodeError(Exception):
starknet_abi.exceptions.TypeEncodeError: Integer 131072 is out of range for StarknetCoreType.U16
"""


class DispatcherDecodeError(Exception):
"""
Raised when there is an error decoding Functions, Events, or User Operations using the decoding dispatcher
"""
7 changes: 7 additions & 0 deletions starknet_abi/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any

from starknet_abi.abi_types import (
STARKNET_ACCOUNT_CALL,
AbiMemberType,
AbiParameter,
StarknetArray,
Expand Down Expand Up @@ -127,6 +128,8 @@ def parse_enums_and_structs(
case ["core", "array" | "integer" | "bool" | "option", *_]:
# Automatically parses Array/Span, u256, bool, and Option types as StarknetCoreType
continue
case ["core", "starknet", "account", "Call"]:
continue

# Can Hard code in structs like openzeppelin's ERC20 & Events for faster parsing

Expand Down Expand Up @@ -269,6 +272,10 @@ def _parse_type( # pylint: disable=too-many-return-statements
_parse_type(extract_inner_type(abi_type), custom_types)
)

# Matches 'core::starknet::account::Call'
case ["starknet", "account", "Call"]:
return STARKNET_ACCOUNT_CALL

case _:
# If unknown type is defined in struct context, return struct
if abi_type in custom_types:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_abi_parsing/test_struct_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_enum_parsing():

type_dict = parse_enums_and_structs(grouped_abi["type_def"])

assert len(type_dict) == 5
assert len(type_dict) == 4

assert type_dict["account::escape::EscapeStatus"] == StarknetEnum(
name="account::escape::EscapeStatus",
Expand Down

0 comments on commit 7ff1eb5

Please sign in to comment.