diff --git a/contracts/Account.cairo b/contracts/Account.cairo index cb9a46916..8b7246e6e 100644 --- a/contracts/Account.cairo +++ b/contracts/Account.cairo @@ -5,7 +5,7 @@ from starkware.cairo.common.hash import hash2 from starkware.cairo.common.registers import get_fp_and_pc from starkware.cairo.common.signature import verify_ecdsa_signature from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin -from starkware.starknet.common.syscalls import call_contract +from starkware.starknet.common.syscalls import call_contract, get_caller_address from starkware.starknet.common.storage import Storage # @@ -51,6 +51,23 @@ end func address() -> (res: felt): end +# +# Guards +# + +@view +func assert_only_self{ + storage_ptr: Storage*, + pedersen_ptr: HashBuiltin*, + syscall_ptr: felt*, + range_check_ptr + }(): + let (self) = address.read() + let (caller) = get_caller_address() + assert self = caller + return () +end + # # Getters # @@ -73,6 +90,40 @@ func get_L1_address{ storage_ptr: Storage*, pedersen_ptr: HashBuiltin*, range_ch return (res=res) end +@external +func get_nonce{ storage_ptr: Storage*, pedersen_ptr: HashBuiltin*, range_check_ptr }() -> (res: felt): + let (res) = current_nonce.read() + return (res=res) +end + +# +# Setters +# + +@external +func set_public_key{ + storage_ptr: Storage*, + pedersen_ptr: HashBuiltin*, + syscall_ptr: felt*, + range_check_ptr + }(new_public_key: felt): + assert_only_self() + public_key.write(new_public_key) + return () +end + +@external +func set_L1_address{ + storage_ptr: Storage*, + pedersen_ptr: HashBuiltin*, + syscall_ptr: felt*, + range_check_ptr + }(new_L1_address: felt): + assert_only_self() + L1_address.write(new_L1_address) + return () +end + # # Initializer # diff --git a/test/Account.py b/test/Account.py index 675b1b692..c43fa9b92 100644 --- a/test/Account.py +++ b/test/Account.py @@ -4,7 +4,9 @@ from utils.Signer import Signer signer = Signer(123456789987654321) +other = Signer(987654321123456789) L1_ADDRESS = 0x1f9840a85d5aF5bf1D1762F925BDADdC4201F984 +ANOTHER_ADDRESS = 0xd9e1ce17f2641f24ae83637ab66a2cca9c378b9f @pytest.fixture(scope='module') @@ -33,9 +35,54 @@ async def test_execute(account_factory): starknet, account = account_factory initializable = await starknet.deploy("contracts/Initializable.cairo") - transaction = signer.build_transaction( - account, initializable.contract_address, 'initialize', [], 0) + transaction = await signer.build_transaction( + account, initializable.contract_address, 'initialize', []) assert await initializable.initialized().call() == (0,) await transaction.invoke() assert await initializable.initialized().call() == (1,) + + +# @pytest.mark.asyncio +# async def test_nonce(account_factory): +# starknet, account = account_factory +# initializable = await starknet.deploy("contracts/Initializable.cairo") + +# await signer.build_transaction( +# account, initializable.contract_address, 'initialize', []).invoke() + +# try: +# await signer.build_transaction( +# account, initializable.contract_address, 'initialize', []).invoke() +# except: +# assert 4 == 0 + + +@pytest.mark.asyncio +async def test_L1_address_setter(account_factory): + _, account = account_factory + assert await account.get_L1_address().call() == (L1_ADDRESS,) + + tx = await signer.build_transaction( + account, account.contract_address, 'set_L1_address', [ANOTHER_ADDRESS]) + await tx.invoke() + + assert await account.get_L1_address().call() == (ANOTHER_ADDRESS,) + + +@pytest.mark.asyncio +async def test_public_key_setter(account_factory): + _, account = account_factory + assert await account.get_public_key().call() == (signer.public_key,) + + tx = await signer.build_transaction( + account, account.contract_address, 'set_public_key', [other.public_key]) + await tx.invoke() + + assert await account.get_public_key().call() == (other.public_key,) + + # tear down test. todo: cleanup on fixture directly + tx = await other.build_transaction( + account, account.contract_address, 'set_public_key', [signer.public_key]) + + await tx.invoke() diff --git a/test/utils/Signer.py b/test/utils/Signer.py index e18111c12..56cf7129f 100644 --- a/test/utils/Signer.py +++ b/test/utils/Signer.py @@ -10,7 +10,8 @@ def __init__(self, private_key): def sign(self, message_hash): return sign(msg_hash=message_hash, priv_key=self.private_key) - def build_transaction(self, account, to, selector_name, calldata, nonce): + async def build_transaction(self, account, to, selector_name, calldata): + (nonce,) = await account.get_nonce().call() selector = get_selector_from_name(selector_name) message_hash = hash_message( to, selector, calldata, account.contract_address, nonce)