Skip to content

Commit

Permalink
refactor: custom decorator for disconnected
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Oct 4, 2022
1 parent ea2e60b commit 2a1c92a
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 24 deletions.
30 changes: 8 additions & 22 deletions src/ape/pytest/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
from contextlib import contextmanager
from typing import Any, Callable, Iterator, List, Optional
from typing import Iterator, List, Optional

import pytest
from _pytest.config import Config as PytestConfig
from evm_trace.gas import merge_reports

from ape.api import TestAccountAPI
from ape.exceptions import ProviderNotConnectedError
from ape.logging import logger
from ape.managers.chain import ChainManager
from ape.managers.networks import NetworkManager
from ape.managers.project import ProjectManager
from ape.types import GasReport, SnapshotID
from ape.utils import CallTraceParser, ManagerAccessMixin, cached_property
from ape.utils import CallTraceParser, ManagerAccessMixin, allow_disconnected, cached_property


class PytestApeFixtures(ManagerAccessMixin):
Expand Down Expand Up @@ -111,10 +110,10 @@ def _func_isolation(self) -> Iterator[None]:
_class_isolation = pytest.fixture(_isolation, scope="class")
_function_isolation = pytest.fixture(_func_isolation, scope="function")

@allow_disconnected
def _snapshot(self) -> Optional[SnapshotID]:
try:
fn = self.chain_manager.snapshot
return _silence_connection_failure(fn)
return self.chain_manager.snapshot()
except NotImplementedError:
if not self._warned_for_unimplemented_snapshot:
logger.warning(
Expand All @@ -125,29 +124,16 @@ def _snapshot(self) -> Optional[SnapshotID]:

return None

@allow_disconnected
def _restore(self, snapshot_id: SnapshotID):
if snapshot_id not in self.chain_manager._snapshots:
return

_silence_connection_failure(lambda: self.chain_manager.restore(snapshot_id))
self.chain_manager.restore(snapshot_id)

@allow_disconnected
def _get_block_number(self) -> Optional[int]:
return _silence_connection_failure(lambda: self.provider.get_block("latest").number)


def _silence_connection_failure(fn: Callable) -> Optional[Any]:
"""
When tests fail, the provider may become disconnected.
Rather than cause more failures, let ``pytest`` complete.
Returns ``None`` when gets ``ProviderNotConnectedError``.
"""

try:
return fn()
except ProviderNotConnectedError:
logger.warning("Provider became disconnected mid-test.")
return None
return self.provider.get_block("latest").number


class ReceiptCapture(ManagerAccessMixin):
Expand Down
2 changes: 2 additions & 0 deletions src/ape/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
USER_AGENT,
ZERO_ADDRESS,
add_padding_to_strings,
allow_disconnected,
cached_property,
extract_nested_value,
gas_estimation_error_message,
Expand Down Expand Up @@ -54,6 +55,7 @@
__all__ = [
"abstractmethod",
"add_padding_to_strings",
"allow_disconnected",
"BaseInterface",
"BaseInterfaceModel",
"cached_property",
Expand Down
32 changes: 30 additions & 2 deletions src/ape/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
from functools import cached_property, lru_cache, singledispatchmethod
from pathlib import Path
from typing import Any, Coroutine, Dict, List, Mapping, Optional
from typing import Any, Callable, Coroutine, Dict, List, Mapping, Optional

import requests
import yaml
Expand All @@ -12,7 +12,7 @@
from importlib_metadata import version as version_metadata
from tqdm.auto import tqdm # type: ignore

from ape.exceptions import APINotImplementedError
from ape.exceptions import APINotImplementedError, ProviderNotConnectedError
from ape.logging import logger
from ape.utils.os import expand_environment_variables

Expand Down Expand Up @@ -292,7 +292,35 @@ def run_until_complete(item: Any) -> Any:
return loop.run_until_complete(item)


def allow_disconnected(fn: Callable):
"""
A decorator that instead of raising :class:`~ape.exceptions.ProviderNotConnectedError`
warns and returns ``None``.
Usage example::
from typing import Optional
from ape.types import SnapshotID
from ape.utils import return_none_when_disconnected
@return_none_when_disconnected
def try_snapshot(self) -> Optional[SnapshotID]:
return self.chain.snapshot()
"""

def inner(*args, **kwargs):
try:
return fn(*args, **kwargs)
except ProviderNotConnectedError:
logger.warning("Provider is not connected.")
return None

return inner


__all__ = [
"allow_disconnected",
"cached_property",
"extract_nested_value",
"gas_estimation_error_message",
Expand Down

0 comments on commit 2a1c92a

Please sign in to comment.