Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update tests and client to properly test async mode #1517

Merged
merged 1 commit into from
Feb 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
33 changes: 25 additions & 8 deletions optimade/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ class OptimadeClient:
__http_client: Union[httpx.Client, httpx.AsyncClient] = None
"""Override the HTTP client, primarily used for testing."""

__strict_async: bool = False
"""Whether or not to fallover if `use_async` is true yet asynchronous mode
is impossible due to, e.g., a running event loop.
"""

def __init__(
self,
base_urls: Optional[Union[str, Iterable[str]]] = None,
Expand All @@ -119,7 +124,7 @@ def __init__(
Parameters:
base_urls: A list of OPTIMADE base URLs to query.
max_results_per_provider: The maximum number of results to download
from each provider.
from each provider (-1 or 0 indicate unlimited).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if you only want to know the number of entries that are in the database, without retrieving any actual data?
It would seem to me that if I would fill in 0 that would be exactly what I get. So it seems a bit strange to let "0" mean all results.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's what the --count flag is for, which only requests a page_limit=1 and does no pagination.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This docstring change just reflects the current behavior anyway; if it's desirable, we can change both.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the clarification.
It seems that the optimade python tools indeed return all results for page_limit=0.
So it makes sense to do the same here.

Perhaps we could support returning 0 resources for page_limit=0 in a future version of the optimade python tools.
So the backend database only has to do the count operation and not a find operation as well.

headers: Any additional HTTP headers to use for the queries.
http_timeout: The timeout to use per request. Defaults to 10
seconds with 1000 seconds for reads specifically. Overriding this value
Expand Down Expand Up @@ -178,16 +183,24 @@ def __init__(
self.use_async = use_async

if http_client:
self.__http_client = http_client
if isinstance(self.__http_client, httpx.AsyncClient):
self._http_client = http_client
if issubclass(self._http_client, httpx.AsyncClient):
if not self.use_async and self.__strict_async:
raise RuntimeError(
"Cannot use synchronous mode with an asynchronous HTTP client, please set `use_async=True` or pass an asynchronous HTTP client."
)
self.use_async = True
elif isinstance(self.__http_client, httpx.Client):
elif issubclass(self._http_client, httpx.Client):
if self.use_async and self.__strict_async:
raise RuntimeError(
"Cannot use async mode with a synchronous HTTP client, please set `use_async=False` or pass an synchronous HTTP client."
)
self.use_async = False
else:
if use_async:
self.__http_client = httpx.AsyncClient
self._http_client = httpx.AsyncClient
else:
self.__http_client = httpx.Client
self._http_client = httpx.Client

def __getattribute__(self, name):
"""Allows entry endpoints to be queried via attribute access, using the
Expand Down Expand Up @@ -369,6 +382,10 @@ def _execute_queries(
try:
event_loop = asyncio.get_running_loop()
if event_loop:
if self.__strict_async:
raise RuntimeError(
"Detected a running event loop, cannot run in async mode."
)
self._progress.print(
"Detected a running event loop (e.g., Jupyter, pytest). Running in synchronous mode."
)
Expand Down Expand Up @@ -612,7 +629,7 @@ async def _get_one_async(
)
results = QueryResults()
try:
async with self.__http_client(headers=self.headers) as client:
async with self._http_client(headers=self.headers) as client:
while next_url:
attempts = 0
try:
Expand Down Expand Up @@ -669,7 +686,7 @@ def _get_one(
)
results = QueryResults()
try:
with self.__http_client(headers=self.headers) as client:
with self._http_client(headers=self.headers) as client:
while next_url:
attempts = 0
try:
Expand Down
7 changes: 7 additions & 0 deletions tests/server/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,10 @@ def http_client():
from .utils import HttpxTestClient

return HttpxTestClient


@pytest.fixture(scope="session")
def async_http_client():
    """Provide the asynchronous HTTP test client for the test session.

    Returns:
        The `AsyncHttpxTestClient` class itself (not an instance), so that
        the code under test can instantiate it when needed.

    """
    # Imported inside the fixture body rather than at module import time.
    from .utils import AsyncHttpxTestClient as async_client_cls

    return async_client_cls
74 changes: 52 additions & 22 deletions tests/server/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,17 @@
from optimade.warnings import MissingExpectedField

try:
from optimade.client import OptimadeClient
from optimade.client import OptimadeClient as OptimadeTestClient
except ImportError as exc:
pytest.skip(str(exc), allow_module_level=True)


class OptimadeClient(OptimadeTestClient):
    """Test-only subclass of the real client with strict async mode switched on.

    Every test in this module builds its client through this wrapper, so a
    mismatch between `use_async` and the supplied HTTP client raises instead
    of being silently corrected.
    """

    # NOTE: this relies on private-name mangling. Inside a class body
    # `__strict_async` is stored as `_OptimadeClient__strict_async`, so the
    # override is only seen by the base class because this subclass keeps
    # the exact name `OptimadeClient`. Renaming the class would silently
    # disable strict mode.
    __strict_async = True


TEST_URLS = [
"https://example.com",
"https://example.org",
Expand All @@ -18,12 +25,17 @@
TEST_URL = TEST_URLS[0]


@pytest.mark.parametrize("use_async", [False])
def test_client_endpoints(http_client, use_async):
@pytest.mark.parametrize(
"use_async",
[True, False],
)
def test_client_endpoints(async_http_client, http_client, use_async):
filter = ""

cli = OptimadeClient(
base_urls=[TEST_URL], use_async=use_async, http_client=http_client
base_urls=[TEST_URL],
use_async=use_async,
http_client=async_http_client if use_async else http_client,
)
get_results = cli.get()
assert get_results["structures"][filter][TEST_URL]["data"]
Expand Down Expand Up @@ -63,10 +75,12 @@ def test_client_endpoints(http_client, use_async):
assert "properties" in count_results["info/structures"][""][TEST_URL]["data"]


@pytest.mark.parametrize("use_async", [False])
def test_filter_validation(http_client, use_async):
@pytest.mark.parametrize("use_async", [True, False])
def test_filter_validation(async_http_client, http_client, use_async):
cli = OptimadeClient(
use_async=use_async, base_urls=TEST_URL, http_client=http_client
use_async=use_async,
base_urls=TEST_URL,
http_client=async_http_client if use_async else http_client,
)
with pytest.raises(Exception):
cli.get("completely wrong filter")
Expand All @@ -75,13 +89,13 @@ def test_filter_validation(http_client, use_async):
cli.get("elements HAS 'Ag'")


@pytest.mark.parametrize("use_async", [False])
def test_client_response_fields(http_client, use_async):
@pytest.mark.parametrize("use_async", [True, False])
def test_client_response_fields(async_http_client, http_client, use_async):
with pytest.warns(MissingExpectedField):
cli = OptimadeClient(
base_urls=[TEST_URL],
use_async=use_async,
http_client=http_client,
http_client=async_http_client if use_async else http_client,
)
results = cli.get(response_fields=["chemical_formula_reduced"])
for d in results["structures"][""][TEST_URL]["data"]:
Expand All @@ -97,10 +111,12 @@ def test_client_response_fields(http_client, use_async):
assert len(d["attributes"]) == 2


@pytest.mark.parametrize("use_async", [False])
def test_multiple_base_urls(http_client, use_async):
@pytest.mark.parametrize("use_async", [True, False])
def test_multiple_base_urls(async_http_client, http_client, use_async):
cli = OptimadeClient(
base_urls=TEST_URLS, use_async=use_async, http_client=http_client
base_urls=TEST_URLS,
use_async=use_async,
http_client=async_http_client if use_async else http_client,
)
results = cli.get()
count_results = cli.count()
Expand All @@ -112,8 +128,8 @@ def test_multiple_base_urls(http_client, use_async):
)


@pytest.mark.parametrize("use_async", [False])
def test_include_exclude_providers(http_client, use_async):
@pytest.mark.parametrize("use_async", [True, False])
def test_include_exclude_providers(async_http_client, http_client, use_async):
with pytest.raises(
SystemExit,
match="Unable to access any OPTIMADE base URLs. If you believe this is an error, try manually specifying some base URLs.",
Expand All @@ -122,7 +138,7 @@ def test_include_exclude_providers(http_client, use_async):
include_providers={"exmpl"},
exclude_providers={"exmpl"},
use_async=use_async,
http_client=http_client,
http_client=async_http_client if use_async else http_client,
)

with pytest.raises(
Expand All @@ -133,6 +149,7 @@ def test_include_exclude_providers(http_client, use_async):
base_urls=TEST_URLS,
include_providers={"exmpl"},
use_async=use_async,
http_client=async_http_client if use_async else http_client,
)

with pytest.raises(
Expand All @@ -143,20 +160,23 @@ def test_include_exclude_providers(http_client, use_async):
include_providers={"exmpl"},
exclude_databases={"https://example.org/optimade"},
use_async=use_async,
http_client=async_http_client if use_async else http_client,
)


@pytest.mark.parametrize("use_async", [False])
def test_client_sort(http_client, use_async):
@pytest.mark.parametrize("use_async", [True, False])
def test_client_sort(async_http_client, http_client, use_async):
cli = OptimadeClient(
base_urls=[TEST_URL], use_async=use_async, http_client=http_client
base_urls=[TEST_URL],
use_async=use_async,
http_client=async_http_client if use_async else http_client,
)
results = cli.get(sort="last_modified")
assert len(results["structures"][""][TEST_URL]["data"]) > 0


@pytest.mark.parametrize("use_async", [False])
def test_command_line_client(http_client, use_async, capsys):
@pytest.mark.parametrize("use_async", [True, False])
def test_command_line_client(async_http_client, http_client, use_async, capsys):
import httpx

from optimade.client.cli import _get
Expand All @@ -175,7 +195,7 @@ def test_command_line_client(http_client, use_async, capsys):
include_providers=None,
exclude_providers=None,
exclude_databases=None,
http_client=http_client,
http_client=async_http_client if use_async else http_client,
http_timeout=httpx.Timeout(2.0),
)

Expand Down Expand Up @@ -220,3 +240,13 @@ def test_command_line_client(http_client, use_async, capsys):
assert len(results["structures"]['elements HAS "Ag"'][url]["errors"]) == 0
assert len(results["structures"]['elements HAS "Ag"'][url]["meta"]) > 0
Path(test_filename).unlink()


@pytest.mark.parametrize("use_async", [True, False])
def test_strict_async(async_http_client, http_client, use_async):
    """Check that strict mode rejects an HTTP client that contradicts `use_async`."""
    # Deliberately choose the wrong kind of client for the requested mode:
    # the synchronous one when async is requested, and vice versa.
    mismatched_client = http_client if use_async else async_http_client

    with pytest.raises(RuntimeError):
        OptimadeClient(
            base_urls=TEST_URLS,
            use_async=use_async,
            http_client=mismatched_client,
        )
14 changes: 14 additions & 0 deletions tests/server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,17 @@ def request( # pylint: disable=too-many-locals
**kwargs,
) -> httpx.Response:
return self.client.request(method, url)


class AsyncHttpxTestClient(httpx.AsyncClient):
    """Asynchronous `httpx` client stand-in that sends requests to the regular test server."""

    # Created once when the class body is executed and shared by all instances.
    client = client_factory()(server="regular")

    async def request(  # pylint: disable=too-many-locals
        self,
        method: str,
        url: httpx._types.URLTypes,
        **kwargs,
    ) -> httpx.Response:
        """Forward a request to the shared test-server client.

        Parameters:
            method: The HTTP method to use.
            url: The URL to request.
            **kwargs: Accepted for signature compatibility but not forwarded.

        Returns:
            Whatever the wrapped test client's `request` call returns.

        """
        # The wrapped client is called without `await`; only the method and
        # URL are passed through.
        response = self.client.request(method, url)
        return response