Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/synthetix/perps/perps.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def _prepare_oracle_call(self, market_names: [str] = []):
self.snx,
self.snx.contracts["pyth_erc7412_wrapper"]["PythERC7412Wrapper"]["address"],
price_update_data,
0,
args,
)
value = len(market_names)
Expand Down
60 changes: 47 additions & 13 deletions src/synthetix/utils/multicall.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@


# constants
ORACLE_DATA_REQUIRED = "0xcf2cabdf"
SELECTOR_ORACLE_DATA_REQUIRED = "0xcf2cabdf"
SELECTOR_ORACLE_DATA_REQUIRED_WITH_FEE = "0x0e7186fb"
SELECTOR_ERRORS = "0x0b42fd17"


def decode_result(contract, function_name, result):
Expand All @@ -20,14 +22,33 @@ def decode_result(contract, function_name, result):


# ERC-7412 support
def decode_erc7412_error(snx, error):
def decode_erc7412_errors_error(error):
"""Decodes an Errors error"""
error_data = decode_hex(f"0x{error[10:]}")

errors = decode(["bytes[]"], error_data)[0]
errors = [ContractCustomError(data=encode_hex(e)) for e in errors]
errors.reverse()

return errors


def decode_erc7412_oracle_data_required_error(snx, error):
"""Decodes an OracleDataRequired error"""
# remove the signature and decode the error data
error_data = decode_hex(f"0x{error[10:]}")

# decode the result
output_types = ["address", "bytes"]
address, data = decode(output_types, error_data)
# could be one of two types with different args
output_types = ["address", "bytes", "uint256"]
try:
address, data, fee = decode(output_types, error_data)
print("USED NORMAL output types")
except:
print("USING BACKUP output types")
address, data = decode(output_types[:2], error_data)
fee = 0

address = snx.web3.to_checksum_address(address)

# decode the bytes data into the arguments for the oracle
Expand All @@ -41,7 +62,7 @@ def decode_erc7412_error(snx, error):
)

feed_ids = [encode_hex(raw_feed_id) for raw_feed_id in raw_feed_ids]
return address, feed_ids, (update_type, staleness_tolerance, raw_feed_ids)
return address, feed_ids, fee, (update_type, staleness_tolerance, raw_feed_ids)
except:
pass

Expand All @@ -51,14 +72,14 @@ def decode_erc7412_error(snx, error):

feed_ids = [encode_hex(raw_feed_id)]
raw_feed_ids = [raw_feed_id]
return address, feed_ids, (update_type, publish_time, raw_feed_ids)
return address, feed_ids, fee, (update_type, publish_time, raw_feed_ids)
except:
pass

raise Exception("Error data can not be decoded")


def make_fulfillment_request(snx, address, price_update_data, args):
def make_fulfillment_request(snx, address, price_update_data, fee, args):
erc_contract = snx.web3.eth.contract(
address=address,
abi=snx.contracts["pyth_erc7412_wrapper"]["PythERC7412Wrapper"]["abi"],
Expand All @@ -71,7 +92,7 @@ def make_fulfillment_request(snx, address, price_update_data, args):
)

# assume 1 wei per price update
value = len(price_update_data) * 1
value = fee if fee > 0 else len(price_update_data) * 1

update_tx = erc_contract.functions.fulfillOracleQuery(
encoded_args
Expand All @@ -80,11 +101,24 @@ def make_fulfillment_request(snx, address, price_update_data, args):


def handle_erc7412_error(snx, error, calls):
if type(error) is ContractCustomError and error.data.startswith(
ORACLE_DATA_REQUIRED
"When receiving a ERC7412 error, will return an updated list of calls with the required price updates"
if type(error) is ContractCustomError and error.data.startswith(SELECTOR_ERRORS):
errors = decode_erc7412_errors_error(error.data)

# TODO: execute in parallel
for sub_error in errors:
sub_calls = handle_erc7412_error(snx, sub_error, [])
calls = sub_calls + calls

return calls
if type(error) is ContractCustomError and (
error.data.startswith(SELECTOR_ORACLE_DATA_REQUIRED)
or error.data.startswith(SELECTOR_ORACLE_DATA_REQUIRED_WITH_FEE)
):
# decode error data
address, feed_ids, args = decode_erc7412_error(snx, error.data)
address, feed_ids, fee, args = decode_erc7412_oracle_data_required_error(
snx, error.data
)
update_type = args[0]

if update_type == 1:
Expand All @@ -106,7 +140,7 @@ def handle_erc7412_error(snx, error, calls):

# create a new request
to, data, value = make_fulfillment_request(
snx, address, price_update_data, args
snx, address, price_update_data, fee, args
)
elif update_type == 2:
# fetch the data from pyth for those feed ids
Expand All @@ -115,7 +149,7 @@ def handle_erc7412_error(snx, error, calls):

# create a new request
to, data, value = make_fulfillment_request(
snx, address, price_update_data, args
snx, address, price_update_data, fee, args
)
else:
snx.logger.error(f"Unknown update type: {update_type}")
Expand Down
116 changes: 116 additions & 0 deletions src/tests/test_oracles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import os
from web3.exceptions import ContractCustomError
from eth_abi import decode, encode
from eth_utils import encode_hex, decode_hex
from synthetix import Synthetix
from synthetix.utils.multicall import (
handle_erc7412_error,
SELECTOR_ERRORS,
SELECTOR_ORACLE_DATA_REQUIRED,
SELECTOR_ORACLE_DATA_REQUIRED_WITH_FEE,
)

# constants
ODR_ERROR_TYPES = ["address", "bytes"]
ODR_FEE_ERROR_TYPES = ["address", "bytes", "uint256"]

ODR_BYTES_TYPES = {
1: ["uint8", "uint64", "bytes32[]"],
2: ["uint8", "uint64", "bytes32"],
}


# encode some errors
def encode_odr_error(snx, inputs, with_fee=False):
"Utility to help encode errors to test"
types = ODR_FEE_ERROR_TYPES if with_fee else ODR_ERROR_TYPES
fee = [1] if with_fee else []
address = snx.contracts["pyth_erc7412_wrapper"]["PythERC7412Wrapper"]["address"]

# get the update type
update_type = inputs[0]
bytes_types = ODR_BYTES_TYPES[update_type]

# encode bytes
error_bytes = encode(bytes_types, inputs)

# encode the error
error = encode(types, [address, error_bytes] + fee)
return error


# tests


def test_update_type_1_with_staleness(snx):
# Test update_type 1 with staleness 3600
feed_id = snx.pyth.price_feed_ids["ETH"]
error = encode_odr_error(snx, (1, 3600, [decode_hex(feed_id)]))
error_hex = SELECTOR_ORACLE_DATA_REQUIRED + error.hex()

custom_error = ContractCustomError(message="Test error", data=error_hex)
calls = handle_erc7412_error(snx, custom_error, [])

assert len(calls) == 1, "Expected 1 call for update_type 1"
assert calls[0][1] == True, "Expected call to be marked as static"
assert calls[0][2] > 0, "Expected non-zero value for the call"


def test_update_type_2_with_recent_publish_time(snx):
# Test update_type 2 with a publish_time in the last 60 seconds
feed_id = snx.pyth.price_feed_ids["BTC"]
current_time = snx.web3.eth.get_block("latest").timestamp
recent_publish_time = current_time - 30 # 30 seconds ago

error = encode_odr_error(snx, (2, recent_publish_time, decode_hex(feed_id)))
error_hex = SELECTOR_ORACLE_DATA_REQUIRED + error.hex()

custom_error = ContractCustomError(message="Test error", data=error_hex)
calls = handle_erc7412_error(snx, custom_error, [])

assert len(calls) == 1, "Expected 1 call for update_type 2"
assert calls[0][1] == True, "Expected call to be marked as static"
assert calls[0][2] > 0, "Expected non-zero value for the call"


def test_oracle_data_required_with_fee(snx):
# Test OracleDataRequired error with fee
feed_id = snx.pyth.price_feed_ids["ETH"]
error = encode_odr_error(snx, (1, 3600, [decode_hex(feed_id)]), with_fee=True)
error_hex = SELECTOR_ORACLE_DATA_REQUIRED_WITH_FEE + error.hex()

custom_error = ContractCustomError(message="Test error", data=error_hex)
calls = handle_erc7412_error(snx, custom_error, [])

assert len(calls) == 1, "Expected 1 call for OracleDataRequired with fee"
assert calls[0][1] == True, "Expected call to be marked as static"
assert calls[0][2] > 0, "Expected non-zero value for the call"


def test_errors_with_multiple_sub_errors(snx):
# Test Errors error which includes multiple individual errors
feed_id_1 = snx.pyth.price_feed_ids["ETH"]
feed_id_2 = snx.pyth.price_feed_ids["BTC"]

error_1 = encode_odr_error(snx, (1, 3600, [decode_hex(feed_id_1)]))
error_1_hex = SELECTOR_ORACLE_DATA_REQUIRED + error_1.hex()

error_2 = encode_odr_error(
snx, (2, snx.web3.eth.get_block("latest").timestamp - 30, decode_hex(feed_id_2))
)
error_2_hex = SELECTOR_ORACLE_DATA_REQUIRED + error_2.hex()

# Encode multiple errors
errors_data = encode(
["bytes[]"], [(decode_hex(error_1_hex), decode_hex(error_2_hex))]
)

errors_hex = SELECTOR_ERRORS + errors_data.hex()

custom_error = ContractCustomError(message="Test error", data=errors_hex)
calls = handle_erc7412_error(snx, custom_error, [])

assert len(calls) == 2, "Expected 2 calls for Errors with 2 sub-errors"
for call in calls:
assert call[1] == True, "Expected all calls to be marked as static"
assert call[2] > 0, "Expected non-zero value for all calls"