Skip to content

Commit

Permalink
Replace pytest-httpx usage with custom http client in optimade-get
Browse files Browse the repository at this point in the history
  • Loading branch information
ml-evs committed Jan 3, 2023
1 parent 051d43b commit 39f6291
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 29 deletions.
2 changes: 2 additions & 0 deletions optimade/client/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def _get(
include_providers,
exclude_providers,
exclude_databases,
**kwargs,
):

if output_file:
Expand All @@ -143,6 +144,7 @@ def _get(
exclude_databases=set(_.strip() for _ in exclude_databases.split(","))
if exclude_databases
else None,
**kwargs,
)
if response_fields:
response_fields = response_fields.split(",")
Expand Down
21 changes: 19 additions & 2 deletions optimade/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ class OptimadeClient:
"""Used internally when querying via `client.structures.get()` to set the
chosen endpoint. Should be reset to `None` outside of all `get()` calls."""

__http_client: Union[httpx.Client, httpx.AsyncClient] = None
"""Override the HTTP client, primarily used for testing."""

def __init__(
self,
base_urls: Optional[Union[str, Iterable[str]]] = None,
Expand All @@ -109,6 +112,7 @@ def __init__(
exclude_providers: Optional[List[str]] = None,
include_providers: Optional[List[str]] = None,
exclude_databases: Optional[List[str]] = None,
http_client: Union[httpx.Client, httpx.AsyncClient] = None,
):
"""Create the OPTIMADE client object.
Expand All @@ -123,6 +127,7 @@ def __init__(
exclude_providers: A set or collection of provider IDs to exclude from queries.
include_providers: A set or collection of provider IDs to include in queries.
exclude_databases: A set or collection of child database URLs to exclude from queries.
http_client: An override for the underlying HTTP client, primarily used for testing.
"""

Expand Down Expand Up @@ -165,6 +170,18 @@ def __init__(

self.use_async = use_async

if http_client:
self.__http_client = http_client
if isinstance(self.__http_client, httpx.AsyncClient):
self.use_async = True
elif isinstance(self.__http_client, httpx.Client):
self.use_async = False
else:
if use_async:
self.__http_client = httpx.AsyncClient
else:
self.__http_client = httpx.Client

def __getattribute__(self, name):
"""Allows entry endpoints to be queried via attribute access, using the
allowed list for this module.
Expand Down Expand Up @@ -584,7 +601,7 @@ async def _get_one_async(
)
results = QueryResults()
try:
async with httpx.AsyncClient(headers=self.headers) as client:
async with self.__http_client(headers=self.headers) as client:
while next_url:

attempts = 0
Expand Down Expand Up @@ -642,7 +659,7 @@ def _get_one(
)
results = QueryResults()
try:
with httpx.Client(headers=self.headers) as client:
with self.__http_client(headers=self.headers) as client:
while next_url:

attempts = 0
Expand Down
1 change: 0 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,4 @@ pre-commit==2.21.0
pylint==2.15.9
pytest==7.2.0
pytest-cov==4.0.0
pytest-httpx==0.21.2
types-all==1.0.0
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
"jsondiff~=2.0",
"pytest~=7.2",
"pytest-cov~=4.0",
"pytest-httpx~=0.21",
] + server_deps
dev_deps = (
[
Expand Down
7 changes: 7 additions & 0 deletions tests/server/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,10 @@ def inner(
raise

return inner


@pytest.fixture(scope="session")
def http_client():
from .utils import HttpxTestClient

return HttpxTestClient
51 changes: 26 additions & 25 deletions tests/server/test_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
from functools import partial
from pathlib import Path

import pytest
Expand All @@ -19,24 +18,14 @@
TEST_URL = TEST_URLS[0]


@pytest.fixture(scope="function")
def httpx_mocked_response(httpx_mock, client):
import httpx

def httpx_mock_response(client, request: httpx.Request):
response = client.get(str(request.url))
return httpx.Response(status_code=response.status_code, json=response.json())

httpx_mock.add_callback(partial(httpx_mock_response, client))
yield httpx_mock


@pytest.mark.parametrize("use_async", [False])
def test_client_endpoints(httpx_mocked_response, use_async):
def test_client_endpoints(http_client, use_async):

filter = ""

cli = OptimadeClient(base_urls=[TEST_URL], use_async=use_async)
cli = OptimadeClient(
base_urls=[TEST_URL], use_async=use_async, http_client=http_client
)
get_results = cli.get()
assert get_results["structures"][filter][TEST_URL]["data"]
assert (
Expand Down Expand Up @@ -76,8 +65,10 @@ def test_client_endpoints(httpx_mocked_response, use_async):


@pytest.mark.parametrize("use_async", [False])
def test_filter_validation(use_async):
cli = OptimadeClient(use_async=use_async, base_urls=TEST_URL)
def test_filter_validation(http_client, use_async):
cli = OptimadeClient(
use_async=use_async, base_urls=TEST_URL, http_client=http_client
)
with pytest.raises(Exception):
cli.get("completely wrong filter")

Expand All @@ -86,9 +77,13 @@ def test_filter_validation(use_async):


@pytest.mark.parametrize("use_async", [False])
def test_client_response_fields(httpx_mocked_response, use_async):
def test_client_response_fields(http_client, use_async):
with pytest.warns(MissingExpectedField):
cli = OptimadeClient(base_urls=[TEST_URL], use_async=use_async)
cli = OptimadeClient(
base_urls=[TEST_URL],
use_async=use_async,
http_client=http_client,
)
results = cli.get(response_fields=["chemical_formula_reduced"])
for d in results["structures"][""][TEST_URL]["data"]:
assert "chemical_formula_reduced" in d["attributes"]
Expand All @@ -104,8 +99,10 @@ def test_client_response_fields(httpx_mocked_response, use_async):


@pytest.mark.parametrize("use_async", [False])
def test_multiple_base_urls(httpx_mocked_response, use_async):
cli = OptimadeClient(base_urls=TEST_URLS, use_async=use_async)
def test_multiple_base_urls(http_client, use_async):
cli = OptimadeClient(
base_urls=TEST_URLS, use_async=use_async, http_client=http_client
)
results = cli.get()
count_results = cli.count()
for url in TEST_URLS:
Expand All @@ -117,7 +114,7 @@ def test_multiple_base_urls(httpx_mocked_response, use_async):


@pytest.mark.parametrize("use_async", [False])
def test_include_exclude_providers(use_async):
def test_include_exclude_providers(http_client, use_async):
with pytest.raises(
SystemExit,
match="Unable to access any OPTIMADE base URLs. If you believe this is an error, try manually specifying some base URLs.",
Expand All @@ -126,6 +123,7 @@ def test_include_exclude_providers(use_async):
include_providers={"exmpl"},
exclude_providers={"exmpl"},
use_async=use_async,
http_client=http_client,
)

with pytest.raises(
Expand All @@ -150,14 +148,16 @@ def test_include_exclude_providers(use_async):


@pytest.mark.parametrize("use_async", [False])
def test_client_sort(httpx_mocked_response, use_async):
cli = OptimadeClient(base_urls=[TEST_URL], use_async=use_async)
def test_client_sort(http_client, use_async):
cli = OptimadeClient(
base_urls=[TEST_URL], use_async=use_async, http_client=http_client
)
results = cli.get(sort="last_modified")
assert len(results["structures"][""][TEST_URL]["data"]) > 0


@pytest.mark.parametrize("use_async", [False])
def test_command_line_client(httpx_mocked_response, use_async, capsys):
def test_command_line_client(http_client, use_async, capsys):
from optimade.client.cli import _get

args = dict(
Expand All @@ -174,6 +174,7 @@ def test_command_line_client(httpx_mocked_response, use_async, capsys):
include_providers=None,
exclude_providers=None,
exclude_databases=None,
http_client=http_client,
)

# Test multi-provider query
Expand Down
14 changes: 14 additions & 0 deletions tests/server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,17 @@ def test_response_okay(self):
assert (
self.response.status_code == 200
), f"Request to {self.request_str} failed: {self.response.content}"


class HttpxTestClient(httpx.Client):
"""An HTTP client wrapper that calls the regular test server."""

client = client_factory()(server="regular")

def request( # pylint: disable=too-many-locals
self,
method: str,
url: httpx._types.URLTypes,
**kwargs,
) -> httpx.Response:
return self.client.request(method, url, **kwargs)

0 comments on commit 39f6291

Please sign in to comment.