Skip to content

Commit

Permalink
fix: issue with networks context when networks shared chain ID (#980)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Aug 17, 2022
1 parent 4bf1647 commit 39c99ea
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 41 deletions.
37 changes: 24 additions & 13 deletions src/ape/api/networks.py
Expand Up @@ -465,33 +465,40 @@ def __init__(self, provider: "ProviderAPI", network_manager: "NetworkManager"):
self.provider = provider
self.network_manager = network_manager

@property
def empty(self) -> bool:
return not self.connected_providers or not self.provider_stack

def __enter__(self, *args, **kwargs):
return self.push_provider()

def __exit__(self, *args, **kwargs):
self.pop_provider()

def push_provider(self):
if not self.provider.is_connected:
must_connect = not self.provider.is_connected
if must_connect:
self.provider.connect()

provider_object_id = self.get_provider_id(self.provider)
self.provider_stack.append(provider_object_id)
provider_id = self.get_provider_id(self.provider)
self.provider_stack.append(provider_id)

if provider_object_id in self.connected_providers:
# Already connected and known
connected_provider = self.connected_providers[provider_object_id]
self.network_manager.active_provider = connected_provider
if provider_id in self.connected_providers:
# Using already connected instance
if must_connect:
# Disconnect if had to connect to check chain ID
self.provider.disconnect()

self.provider = self.connected_providers[provider_id]
else:
# Already connected and unknown
self.connected_providers[provider_object_id] = self.provider
self.network_manager.active_provider = self.provider
# Adding provider for the first time. Retain connection.
self.connected_providers[provider_id] = self.provider

self.network_manager.active_provider = self.provider
return self.provider

def pop_provider(self):
if not self.connected_providers or not self.provider_stack:
if self.empty:
return

# Clear last provider
Expand All @@ -511,7 +518,7 @@ def pop_provider(self):
self.network_manager.active_provider = previous_provider

def disconnect_all(self):
if not self.connected_providers or not self.provider_stack:
if self.empty:
return

for provider in self.connected_providers.values():
Expand All @@ -523,7 +530,11 @@ def disconnect_all(self):
@classmethod
def get_provider_id(cls, provider: "ProviderAPI") -> Optional[str]:
try:
return f"{provider.name}-{provider.chain_id}"
return (
f"{provider.network.ecosystem.name}:"
f"{provider.network.name}:{provider.name}-"
f"{provider.chain_id}"
)
except ProviderNotConnectedError:
return None

Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Expand Up @@ -139,7 +139,7 @@ def ethereum(networks):

@pytest.fixture(scope="session")
def eth_tester_provider(networks_connected_to_tester):
yield networks_connected_to_tester.active_provider
yield networks_connected_to_tester.provider


@pytest.fixture
Expand Down
46 changes: 19 additions & 27 deletions tests/functional/test_networks.py
@@ -1,5 +1,3 @@
from contextlib import contextmanager

import pytest

from ape.exceptions import NetworkError
Expand All @@ -17,32 +15,24 @@ def __call__(self, *args, **kwargs) -> int:
chain_id_factory = NewChainID()


@pytest.fixture(scope="module")
def get_context(networks_connected_to_tester):
@pytest.fixture
def get_provider_with_unused_chain_id(networks_connected_to_tester):
networks = networks_connected_to_tester

def fn():
return networks_connected_to_tester.parse_network_choice("ethereum:local:test")
new_provider = networks.provider.copy()
new_provider.cached_chain_id = chain_id_factory()
context = networks.parse_network_choice("ethereum:local:test")
context.provider = new_provider
return context

return fn


@contextmanager
def _switch_chain_id(networks):
new_chain_id = chain_id_factory()
context = networks.parse_network_choice("ethereum:local:test")
provider_id = f"{context.provider.name}-{new_chain_id}"
original_chain_id = context.provider.cached_chain_id
context.provider.cached_chain_id = new_chain_id
with context:
yield context

context.provider.cached_chain_id = original_chain_id
del context.connected_providers[provider_id]


@pytest.fixture(scope="module")
def switch_chain_id(networks_connected_to_tester):
def get_context(networks_connected_to_tester):
def fn():
return _switch_chain_id(networks_connected_to_tester)
return networks_connected_to_tester.parse_network_choice("ethereum:local:test")

return fn

Expand Down Expand Up @@ -192,9 +182,10 @@ def test_parse_network_choice_same_provider(chain, networks_connected_to_tester,
assert provider._web3 is not None


def test_parse_network_choice_new_chain_id(switch_chain_id, get_context):
def test_parse_network_choice_new_chain_id(get_provider_with_unused_chain_id, get_context):
start_count = len(get_context().connected_providers)
with switch_chain_id() as context:
context = get_provider_with_unused_chain_id()
with context:
count = len(context.connected_providers)

# Creates new provider since it has a new chain ID
Expand All @@ -204,12 +195,13 @@ def test_parse_network_choice_new_chain_id(switch_chain_id, get_context):
assert provider._web3 is not None


def test_parse_network_choice_multiple_contexts(switch_chain_id):
with switch_chain_id() as first_context:
def test_parse_network_choice_multiple_contexts(get_provider_with_unused_chain_id):
first_context = get_provider_with_unused_chain_id()
with first_context:
start_count = len(first_context.connected_providers)
expected_next_count = start_count + 1

with switch_chain_id() as second_context:
second_context = get_provider_with_unused_chain_id()
with second_context:
# Second context should already know about connected providers
assert len(first_context.connected_providers) == expected_next_count
assert len(second_context.connected_providers) == expected_next_count
Expand Down

0 comments on commit 39c99ea

Please sign in to comment.