Skip to content

Commit

Permalink
Update tests and client to properly test async mode (#1517)
Browse files Browse the repository at this point in the history
  • Loading branch information
ml-evs committed Feb 19, 2023
1 parent a8d9209 commit a029dbf
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 30 deletions.
33 changes: 25 additions & 8 deletions optimade/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ class OptimadeClient:
__http_client: Union[httpx.Client, httpx.AsyncClient] = None
"""Override the HTTP client, primarily used for testing."""

__strict_async: bool = False
"""Whether or not to fallover if `use_async` is true yet asynchronous mode
is impossible due to, e.g., a running event loop.
"""

def __init__(
self,
base_urls: Optional[Union[str, Iterable[str]]] = None,
Expand All @@ -119,7 +124,7 @@ def __init__(
Parameters:
base_urls: A list of OPTIMADE base URLs to query.
max_results_per_provider: The maximum number of results to download
from each provider.
from each provider (-1 or 0 indicate unlimited).
headers: Any additional HTTP headers to use for the queries.
http_timeout: The timeout to use per request. Defaults to 10
seconds with 1000 seconds for reads specifically. Overriding this value
Expand Down Expand Up @@ -178,16 +183,24 @@ def __init__(
self.use_async = use_async

if http_client:
self.__http_client = http_client
if isinstance(self.__http_client, httpx.AsyncClient):
self._http_client = http_client
if issubclass(self._http_client, httpx.AsyncClient):
if not self.use_async and self.__strict_async:
raise RuntimeError(
"Cannot use synchronous mode with an asynchronous HTTP client, please set `use_async=True` or pass an asynchronous HTTP client."
)
self.use_async = True
elif isinstance(self.__http_client, httpx.Client):
elif issubclass(self._http_client, httpx.Client):
if self.use_async and self.__strict_async:
raise RuntimeError(
"Cannot use async mode with a synchronous HTTP client, please set `use_async=False` or pass an synchronous HTTP client."
)
self.use_async = False
else:
if use_async:
self.__http_client = httpx.AsyncClient
self._http_client = httpx.AsyncClient
else:
self.__http_client = httpx.Client
self._http_client = httpx.Client

def __getattribute__(self, name):
"""Allows entry endpoints to be queried via attribute access, using the
Expand Down Expand Up @@ -369,6 +382,10 @@ def _execute_queries(
try:
event_loop = asyncio.get_running_loop()
if event_loop:
if self.__strict_async:
raise RuntimeError(
"Detected a running event loop, cannot run in async mode."
)
self._progress.print(
"Detected a running event loop (e.g., Jupyter, pytest). Running in synchronous mode."
)
Expand Down Expand Up @@ -612,7 +629,7 @@ async def _get_one_async(
)
results = QueryResults()
try:
async with self.__http_client(headers=self.headers) as client:
async with self._http_client(headers=self.headers) as client:
while next_url:
attempts = 0
try:
Expand Down Expand Up @@ -669,7 +686,7 @@ def _get_one(
)
results = QueryResults()
try:
with self.__http_client(headers=self.headers) as client:
with self._http_client(headers=self.headers) as client:
while next_url:
attempts = 0
try:
Expand Down
7 changes: 7 additions & 0 deletions tests/server/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,10 @@ def http_client():
from .utils import HttpxTestClient

return HttpxTestClient


@pytest.fixture(scope="session")
def async_http_client():
from .utils import AsyncHttpxTestClient

return AsyncHttpxTestClient
74 changes: 52 additions & 22 deletions tests/server/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,17 @@
from optimade.warnings import MissingExpectedField

try:
from optimade.client import OptimadeClient
from optimade.client import OptimadeClient as OptimadeTestClient
except ImportError as exc:
pytest.skip(str(exc), allow_module_level=True)


class OptimadeClient(OptimadeTestClient):
"""Wrapper to the base OptimadeClient that enables strict mode for testing."""

__strict_async = True


TEST_URLS = [
"https://example.com",
"https://example.org",
Expand All @@ -18,12 +25,17 @@
TEST_URL = TEST_URLS[0]


@pytest.mark.parametrize("use_async", [False])
def test_client_endpoints(http_client, use_async):
@pytest.mark.parametrize(
"use_async",
[True, False],
)
def test_client_endpoints(async_http_client, http_client, use_async):
filter = ""

cli = OptimadeClient(
base_urls=[TEST_URL], use_async=use_async, http_client=http_client
base_urls=[TEST_URL],
use_async=use_async,
http_client=async_http_client if use_async else http_client,
)
get_results = cli.get()
assert get_results["structures"][filter][TEST_URL]["data"]
Expand Down Expand Up @@ -63,10 +75,12 @@ def test_client_endpoints(http_client, use_async):
assert "properties" in count_results["info/structures"][""][TEST_URL]["data"]


@pytest.mark.parametrize("use_async", [False])
def test_filter_validation(http_client, use_async):
@pytest.mark.parametrize("use_async", [True, False])
def test_filter_validation(async_http_client, http_client, use_async):
cli = OptimadeClient(
use_async=use_async, base_urls=TEST_URL, http_client=http_client
use_async=use_async,
base_urls=TEST_URL,
http_client=async_http_client if use_async else http_client,
)
with pytest.raises(Exception):
cli.get("completely wrong filter")
Expand All @@ -75,13 +89,13 @@ def test_filter_validation(http_client, use_async):
cli.get("elements HAS 'Ag'")


@pytest.mark.parametrize("use_async", [False])
def test_client_response_fields(http_client, use_async):
@pytest.mark.parametrize("use_async", [True, False])
def test_client_response_fields(async_http_client, http_client, use_async):
with pytest.warns(MissingExpectedField):
cli = OptimadeClient(
base_urls=[TEST_URL],
use_async=use_async,
http_client=http_client,
http_client=async_http_client if use_async else http_client,
)
results = cli.get(response_fields=["chemical_formula_reduced"])
for d in results["structures"][""][TEST_URL]["data"]:
Expand All @@ -97,10 +111,12 @@ def test_client_response_fields(http_client, use_async):
assert len(d["attributes"]) == 2


@pytest.mark.parametrize("use_async", [False])
def test_multiple_base_urls(http_client, use_async):
@pytest.mark.parametrize("use_async", [True, False])
def test_multiple_base_urls(async_http_client, http_client, use_async):
cli = OptimadeClient(
base_urls=TEST_URLS, use_async=use_async, http_client=http_client
base_urls=TEST_URLS,
use_async=use_async,
http_client=async_http_client if use_async else http_client,
)
results = cli.get()
count_results = cli.count()
Expand All @@ -112,8 +128,8 @@ def test_multiple_base_urls(http_client, use_async):
)


@pytest.mark.parametrize("use_async", [False])
def test_include_exclude_providers(http_client, use_async):
@pytest.mark.parametrize("use_async", [True, False])
def test_include_exclude_providers(async_http_client, http_client, use_async):
with pytest.raises(
SystemExit,
match="Unable to access any OPTIMADE base URLs. If you believe this is an error, try manually specifying some base URLs.",
Expand All @@ -122,7 +138,7 @@ def test_include_exclude_providers(http_client, use_async):
include_providers={"exmpl"},
exclude_providers={"exmpl"},
use_async=use_async,
http_client=http_client,
http_client=async_http_client if use_async else http_client,
)

with pytest.raises(
Expand All @@ -133,6 +149,7 @@ def test_include_exclude_providers(http_client, use_async):
base_urls=TEST_URLS,
include_providers={"exmpl"},
use_async=use_async,
http_client=async_http_client if use_async else http_client,
)

with pytest.raises(
Expand All @@ -143,20 +160,23 @@ def test_include_exclude_providers(http_client, use_async):
include_providers={"exmpl"},
exclude_databases={"https://example.org/optimade"},
use_async=use_async,
http_client=async_http_client if use_async else http_client,
)


@pytest.mark.parametrize("use_async", [False])
def test_client_sort(http_client, use_async):
@pytest.mark.parametrize("use_async", [True, False])
def test_client_sort(async_http_client, http_client, use_async):
cli = OptimadeClient(
base_urls=[TEST_URL], use_async=use_async, http_client=http_client
base_urls=[TEST_URL],
use_async=use_async,
http_client=async_http_client if use_async else http_client,
)
results = cli.get(sort="last_modified")
assert len(results["structures"][""][TEST_URL]["data"]) > 0


@pytest.mark.parametrize("use_async", [False])
def test_command_line_client(http_client, use_async, capsys):
@pytest.mark.parametrize("use_async", [True, False])
def test_command_line_client(async_http_client, http_client, use_async, capsys):
import httpx

from optimade.client.cli import _get
Expand All @@ -175,7 +195,7 @@ def test_command_line_client(http_client, use_async, capsys):
include_providers=None,
exclude_providers=None,
exclude_databases=None,
http_client=http_client,
http_client=async_http_client if use_async else http_client,
http_timeout=httpx.Timeout(2.0),
)

Expand Down Expand Up @@ -220,3 +240,13 @@ def test_command_line_client(http_client, use_async, capsys):
assert len(results["structures"]['elements HAS "Ag"'][url]["errors"]) == 0
assert len(results["structures"]['elements HAS "Ag"'][url]["meta"]) > 0
Path(test_filename).unlink()


@pytest.mark.parametrize("use_async", [True, False])
def test_strict_async(async_http_client, http_client, use_async):
with pytest.raises(RuntimeError):
_ = OptimadeClient(
base_urls=TEST_URLS,
use_async=use_async,
http_client=http_client if use_async else async_http_client,
)
14 changes: 14 additions & 0 deletions tests/server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,17 @@ def request( # pylint: disable=too-many-locals
**kwargs,
) -> httpx.Response:
return self.client.request(method, url)


class AsyncHttpxTestClient(httpx.AsyncClient):
"""An async HTTP client wrapper that calls the regular test server."""

client = client_factory()(server="regular")

async def request( # pylint: disable=too-many-locals
self,
method: str,
url: httpx._types.URLTypes,
**kwargs,
) -> httpx.Response:
return self.client.request(method, url)

0 comments on commit a029dbf

Please sign in to comment.