Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: validate Ape has connected to the right network and client #1038

Merged
merged 10 commits into from
Sep 15, 2022
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,5 @@ tests/integration/cli/projects/with-dependencies/renamed_contracts_folder/ape-co
tests/integration/cli/projects/with-dependencies/containing_sub_dependencies/sub_dependency/ape-config.yaml
tests/integration/cli/projects/only-dependencies/dependency_in_project_only/ape-config.yaml
tests/functional/data/projects/BrownieProject/ape-config.yaml
tests/**/dev/
tests/**/geth-logs/
16 changes: 16 additions & 0 deletions src/ape/api/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from ape.exceptions import (
NetworkError,
NetworkMismatchError,
NetworkNotFoundError,
ProviderNotConnectedError,
SignatureError,
Expand Down Expand Up @@ -876,6 +877,21 @@ def publish_contract(self, address: AddressType):
logger.info(f"Publishing and verifying contract using '{self.explorer.name}'.")
self.explorer.publish_contract(address)

def verify_chain_id(self, chain_id: int):
"""
Verify a chain ID for this network.

Args:
chain_id (int): The chain ID to verify.

Raises:
:class:`~ape.exceptions.NetworkMismatchError`: When the network is
not local or adhoc and has a different hardcoded chain ID than
the given one.
"""
if self.name not in ("adhoc", LOCAL_NETWORK_NAME) and self.chain_id != chain_id:
raise NetworkMismatchError(chain_id, self)


def create_network_type(chain_id: int, network_id: int) -> Type[NetworkAPI]:
"""
Expand Down
20 changes: 15 additions & 5 deletions src/ape/api/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,18 +682,28 @@ def estimate_gas_cost(self, txn: TransactionAPI, **kwargs) -> int:

@property
def chain_id(self) -> int:
default_chain_id = None
if self.network.name not in (
"adhoc",
LOCAL_NETWORK_NAME,
) and not self.network.name.endswith("-fork"):
# If using a live network, the chain ID is hardcoded.
return self.network.chain_id
default_chain_id = self.network.chain_id

elif hasattr(self.web3, "eth"):
return self.web3.eth.chain_id
try:
if hasattr(self.web3, "eth"):
return self.web3.eth.chain_id

else:
raise ProviderNotConnectedError()
except ProviderNotConnectedError:
if default_chain_id is not None:
return default_chain_id

raise # Original error

if default_chain_id is not None:
return default_chain_id

raise ProviderNotConnectedError()

@property
def gas_price(self) -> int:
Expand Down
14 changes: 14 additions & 0 deletions src/ape/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from eth_utils import humanize_hash

if TYPE_CHECKING:
from ape.api.networks import NetworkAPI
from ape.api.providers import SubprocessProvider
from ape.types import SnapshotID

Expand Down Expand Up @@ -187,6 +188,19 @@ class ProviderError(ApeException):
"""


class NetworkMismatchError(ProviderError):
"""
Raised when connecting a provider to the wrong network.
"""

def __init__(self, chain_id: int, network: "NetworkAPI"):
message = (
f"Provider connected to chain ID '{chain_id}', which does not match "
f"network chain ID '{network.chain_id}'. Are you connected to '{network.name}'?"
)
super().__init__(message)


class ProviderNotConnectedError(ProviderError):
"""
Raised when not connected to a provider.
Expand Down
2 changes: 1 addition & 1 deletion src/ape_geth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ape import plugins

from .providers import GethNetworkConfig, GethProvider, NetworkConfig
from .provider import GethNetworkConfig, GethProvider, NetworkConfig


@plugins.register(plugins.Config)
Expand Down
18 changes: 8 additions & 10 deletions src/ape_geth/providers.py → src/ape_geth/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,7 @@ def connection_str(self) -> str:

def connect(self):
self._client_version = None # Clear cached version when connecting to another URI.
provider = HTTPProvider(self.uri, request_kwargs={"timeout": 30 * 60})
self._web3 = Web3(provider)
self._web3 = _create_web3(self.uri)

if not self.is_connected:
if self.network.name != LOCAL_NETWORK_NAME:
Expand Down Expand Up @@ -256,14 +255,7 @@ def is_likely_poa() -> bool:
if chain_id in (4, 5, 42) or is_likely_poa():
self._web3.middleware_onion.inject(geth_poa_middleware, layer=0)

if (
self.network.name not in ("adhoc", LOCAL_NETWORK_NAME)
and self.network.chain_id != self.chain_id
):
raise ProviderError(
"HTTP Connection does not match expected chain ID. "
f"Are you connected to '{self.network.name}'?"
)
self.network.verify_chain_id(chain_id)

def disconnect(self):
if self._geth is not None:
Expand Down Expand Up @@ -313,3 +305,9 @@ def _get_call_tree_from_parity():

def _log_connection(self, client_name: str):
logger.info(f"Connecting to existing {client_name} node at '{self._clean_uri}'.")


def _create_web3(uri: str):
# Separated into helper method for testing purposes.
provider = HTTPProvider(uri, request_kwargs={"timeout": 30 * 60})
return Web3(provider)
38 changes: 37 additions & 1 deletion tests/functional/test_geth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from web3.exceptions import ContractLogicError as Web3ContractLogicError

from ape.api.networks import LOCAL_NETWORK_NAME
from ape.exceptions import ContractLogicError, TransactionError
from ape.exceptions import ContractLogicError, NetworkMismatchError, TransactionError
from ape_geth import GethProvider
from tests.functional.data.python import TRACE_RESPONSE

Expand Down Expand Up @@ -134,3 +134,39 @@ def test_get_logs_when_connected_to_geth(vyper_contract_instance, eth_tester_pro
assert actual.event_name == "NumberChange"
assert actual.contract_address == vyper_contract_instance.address
assert actual.event_arguments["newNum"] == 123


def test_chain_id_when_connected(eth_tester_provider_geth):
assert eth_tester_provider_geth.chain_id == 131277322940537


def test_chain_id_live_network_not_connected(networks):
geth = networks.get_provider_from_choice("ethereum:rinkeby:geth")
assert geth.chain_id == 4


def test_chain_id_live_network_connected_uses_web3_chain_id(mocker, eth_tester_provider_geth):
mock_network = mocker.MagicMock()
mock_network.chain_id = 999999999 # Shouldn't use hardcoded network
orig_network = eth_tester_provider_geth.network
eth_tester_provider_geth.network = mock_network

# Still use the connected chain ID instead network's
assert eth_tester_provider_geth.chain_id == 131277322940537
eth_tester_provider_geth.network = orig_network


def test_connect_wrong_chain_id(mocker, ethereum, eth_tester_provider_geth):
eth_tester_provider_geth.network = ethereum.get_network("kovan")

# Ensure when reconnecting, it does not use HTTP
factory = mocker.patch("ape_geth.provider._create_web3")
factory.return_value = eth_tester_provider_geth._web3

expected_error_message = (
"Provider connected to chain ID '131277322940537', "
"which does not match network chain ID '42'. "
"Are you connected to 'kovan'?"
)
with pytest.raises(NetworkMismatchError, match=expected_error_message):
eth_tester_provider_geth.connect()
2 changes: 1 addition & 1 deletion tests/integration/cli/test_networks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ape.api.networks import LOCAL_NETWORK_NAME
from ape_geth.providers import DEFAULT_SETTINGS
from ape_geth.provider import DEFAULT_SETTINGS

from .utils import run_once, skip_projects_except

Expand Down