From 39c99ea15c30b434bd76e36768873bbc2d133369 Mon Sep 17 00:00:00 2001 From: Juliya Smith Date: Wed, 17 Aug 2022 13:43:32 -0500 Subject: [PATCH] fix: issue with networks context when networks shared chain ID (#980) --- src/ape/api/networks.py | 37 ++++++++++++++++--------- tests/conftest.py | 2 +- tests/functional/test_networks.py | 46 +++++++++++++------------------ 3 files changed, 44 insertions(+), 41 deletions(-) diff --git a/src/ape/api/networks.py b/src/ape/api/networks.py index c93e272324..f49fc01c45 100644 --- a/src/ape/api/networks.py +++ b/src/ape/api/networks.py @@ -465,6 +465,10 @@ def __init__(self, provider: "ProviderAPI", network_manager: "NetworkManager"): self.provider = provider self.network_manager = network_manager + @property + def empty(self) -> bool: + return not self.connected_providers or not self.provider_stack + def __enter__(self, *args, **kwargs): return self.push_provider() @@ -472,26 +476,29 @@ def __exit__(self, *args, **kwargs): self.pop_provider() def push_provider(self): - if not self.provider.is_connected: + must_connect = not self.provider.is_connected + if must_connect: self.provider.connect() - provider_object_id = self.get_provider_id(self.provider) - self.provider_stack.append(provider_object_id) + provider_id = self.get_provider_id(self.provider) + self.provider_stack.append(provider_id) - if provider_object_id in self.connected_providers: - # Already connected and known - connected_provider = self.connected_providers[provider_object_id] - self.network_manager.active_provider = connected_provider + if provider_id in self.connected_providers: + # Using already connected instance + if must_connect: + # Disconnect if had to connect to check chain ID + self.provider.disconnect() + self.provider = self.connected_providers[provider_id] else: - # Already connected and unknown - self.connected_providers[provider_object_id] = self.provider - self.network_manager.active_provider = self.provider + # Adding provider for the first time. Retain connection. + self.connected_providers[provider_id] = self.provider + self.network_manager.active_provider = self.provider return self.provider def pop_provider(self): - if not self.connected_providers or not self.provider_stack: + if self.empty: return # Clear last provider @@ -511,7 +518,7 @@ def pop_provider(self): self.network_manager.active_provider = previous_provider def disconnect_all(self): - if not self.connected_providers or not self.provider_stack: + if self.empty: return for provider in self.connected_providers.values(): @@ -523,7 +530,11 @@ def disconnect_all(self): @classmethod def get_provider_id(cls, provider: "ProviderAPI") -> Optional[str]: try: - return f"{provider.name}-{provider.chain_id}" + return ( + f"{provider.network.ecosystem.name}:" + f"{provider.network.name}:{provider.name}-" + f"{provider.chain_id}" + ) except ProviderNotConnectedError: return None diff --git a/tests/conftest.py b/tests/conftest.py index ff51dd155f..d0e0747964 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -139,7 +139,7 @@ def ethereum(networks): @pytest.fixture(scope="session") def eth_tester_provider(networks_connected_to_tester): - yield networks_connected_to_tester.active_provider + yield networks_connected_to_tester.provider @pytest.fixture diff --git a/tests/functional/test_networks.py b/tests/functional/test_networks.py index cffba2408c..90d15fe3d7 100644 --- a/tests/functional/test_networks.py +++ b/tests/functional/test_networks.py @@ -1,5 +1,3 @@ -from contextlib import contextmanager - import pytest from ape.exceptions import NetworkError @@ -17,32 +15,24 @@ def __call__(self, *args, **kwargs) -> int: chain_id_factory = NewChainID() -@pytest.fixture(scope="module") -def get_context(networks_connected_to_tester): +@pytest.fixture +def get_provider_with_unused_chain_id(networks_connected_to_tester): + networks = networks_connected_to_tester + def fn(): - return networks_connected_to_tester.parse_network_choice("ethereum:local:test") + new_provider = networks.provider.copy() + new_provider.cached_chain_id = chain_id_factory() + context = networks.parse_network_choice("ethereum:local:test") + context.provider = new_provider + return context return fn -@contextmanager -def _switch_chain_id(networks): - new_chain_id = chain_id_factory() - context = networks.parse_network_choice("ethereum:local:test") - provider_id = f"{context.provider.name}-{new_chain_id}" - original_chain_id = context.provider.cached_chain_id - context.provider.cached_chain_id = new_chain_id - with context: - yield context - - context.provider.cached_chain_id = original_chain_id - del context.connected_providers[provider_id] - - @pytest.fixture(scope="module") -def switch_chain_id(networks_connected_to_tester): +def get_context(networks_connected_to_tester): def fn(): - return _switch_chain_id(networks_connected_to_tester) + return networks_connected_to_tester.parse_network_choice("ethereum:local:test") return fn @@ -192,9 +182,10 @@ def test_parse_network_choice_same_provider(chain, networks_connected_to_tester, assert provider._web3 is not None -def test_parse_network_choice_new_chain_id(switch_chain_id, get_context): +def test_parse_network_choice_new_chain_id(get_provider_with_unused_chain_id, get_context): start_count = len(get_context().connected_providers) - with switch_chain_id() as context: + context = get_provider_with_unused_chain_id() + with context: count = len(context.connected_providers) # Creates new provider since it has a new chain ID @@ -204,12 +195,13 @@ def test_parse_network_choice_new_chain_id(switch_chain_id, get_context): assert provider._web3 is not None -def test_parse_network_choice_multiple_contexts(switch_chain_id): - with switch_chain_id() as first_context: +def test_parse_network_choice_multiple_contexts(get_provider_with_unused_chain_id): + first_context = get_provider_with_unused_chain_id() + with first_context: start_count = len(first_context.connected_providers) expected_next_count = start_count + 1 - - with switch_chain_id() as second_context: + second_context = get_provider_with_unused_chain_id() + with second_context: # Second context should already know about connected providers assert len(first_context.connected_providers) == expected_next_count assert len(second_context.connected_providers) == expected_next_count