Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add compatibility for httpx-based TestClient for latest FastAPI version #1460

Merged
merged 5 commits into from
Jan 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions optimade/client/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def _get(
include_providers,
exclude_providers,
exclude_databases,
**kwargs,
JPBergsma marked this conversation as resolved.
Show resolved Hide resolved
):

if output_file:
Expand All @@ -143,6 +144,7 @@ def _get(
exclude_databases=set(_.strip() for _ in exclude_databases.split(","))
if exclude_databases
else None,
**kwargs,
)
if response_fields:
response_fields = response_fields.split(",")
Expand Down
21 changes: 19 additions & 2 deletions optimade/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ class OptimadeClient:
"""Used internally when querying via `client.structures.get()` to set the
chosen endpoint. Should be reset to `None` outside of all `get()` calls."""

__http_client: Union[httpx.Client, httpx.AsyncClient] = None
"""Override the HTTP client, primarily used for testing."""

def __init__(
self,
base_urls: Optional[Union[str, Iterable[str]]] = None,
Expand All @@ -109,6 +112,7 @@ def __init__(
exclude_providers: Optional[List[str]] = None,
include_providers: Optional[List[str]] = None,
exclude_databases: Optional[List[str]] = None,
http_client: Union[httpx.Client, httpx.AsyncClient] = None,
):
"""Create the OPTIMADE client object.

Expand All @@ -123,6 +127,7 @@ def __init__(
exclude_providers: A set or collection of provider IDs to exclude from queries.
include_providers: A set or collection of provider IDs to include in queries.
exclude_databases: A set or collection of child database URLs to exclude from queries.
http_client: An override for the underlying HTTP client, primarily used for testing.

"""

Expand Down Expand Up @@ -165,6 +170,18 @@ def __init__(

self.use_async = use_async

if http_client:
self.__http_client = http_client
if isinstance(self.__http_client, httpx.AsyncClient):
self.use_async = True
elif isinstance(self.__http_client, httpx.Client):
self.use_async = False
else:
if use_async:
self.__http_client = httpx.AsyncClient
else:
self.__http_client = httpx.Client

def __getattribute__(self, name):
"""Allows entry endpoints to be queried via attribute access, using the
allowed list for this module.
Expand Down Expand Up @@ -584,7 +601,7 @@ async def _get_one_async(
)
results = QueryResults()
try:
async with httpx.AsyncClient(headers=self.headers) as client:
async with self.__http_client(headers=self.headers) as client:
while next_url:

attempts = 0
Expand Down Expand Up @@ -642,7 +659,7 @@ def _get_one(
)
results = QueryResults()
try:
with httpx.Client(headers=self.headers) as client:
with self.__http_client(headers=self.headers) as client:
while next_url:

attempts = 0
Expand Down
9 changes: 6 additions & 3 deletions optimade/validator/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,13 @@ def __init__( # pylint: disable=too-many-arguments
)

# some simple checks on base_url
self.base_url = str(self.base_url)
self.base_url_parsed = urllib.parse.urlparse(self.base_url)
# only allow filters/endpoints if we are working in "as_type" mode
if self.as_type_cls is None and self.base_url_parsed.query:
raise SystemExit("Base URL not appropriate: should not contain a filter.")
raise SystemExit(
f"Base URL {self.base_url} not appropriate: should not contain a filter."
)

self.valid = None

Expand Down Expand Up @@ -805,7 +808,7 @@ def _construct_single_property_filters(
expected_status_code=(200, 501),
)

if not response:
if response.status_code != 200:
if query_optional:
return (
None,
Expand All @@ -822,7 +825,7 @@ def _construct_single_property_filters(
f"Required field `meta->more_data_available` missing from response for {request}."
)

if not response["meta"]["more_data_available"]:
if not response["meta"]["more_data_available"] and "data" in response:
num_data_returned[operator] = len(response["data"])
else:
num_data_returned[operator] = response["meta"].get("data_returned")
Expand Down
1 change: 0 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,4 @@ pre-commit==2.21.0
pylint==2.15.9
pytest==7.2.0
pytest-cov==4.0.0
pytest-httpx==0.21.2
types-all==1.0.0
2 changes: 1 addition & 1 deletion requirements-server.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
elasticsearch==7.17.7
elasticsearch-dsl==7.4.0
fastapi==0.86.0
fastapi==0.88.0
mongomock==4.1.2
pymongo==4.3.3
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
"jsondiff~=2.0",
"pytest~=7.2",
"pytest-cov~=4.0",
"pytest-httpx~=0.21",
] + server_deps
dev_deps = (
[
Expand Down
7 changes: 7 additions & 0 deletions tests/server/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,10 @@ def inner(
raise

return inner


@pytest.fixture(scope="session")
def http_client():
from .utils import HttpxTestClient

return HttpxTestClient
4 changes: 2 additions & 2 deletions tests/server/middleware/test_api_hint.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_url_changes(both_clients, get_good_response):
response = get_good_response(query_url, server=both_clients, return_json=False)

assert (
unquote(response.url)
unquote(str(response.url))
== f"{both_clients.base_url}{BASE_URL_PREFIXES['major']}{query_url.split('&')[0]}"
)

Expand All @@ -68,7 +68,7 @@ def test_url_changes(both_clients, get_good_response):

response = get_good_response(query_url, server=both_clients, return_json=False)

assert unquote(response.url) == f"{both_clients.base_url}{query_url}"
assert unquote(str(response.url)) == f"{both_clients.base_url}{query_url}"


def test_is_versioned_base_url(both_clients):
Expand Down
51 changes: 26 additions & 25 deletions tests/server/test_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
from functools import partial
from pathlib import Path

import pytest
Expand All @@ -19,24 +18,14 @@
TEST_URL = TEST_URLS[0]


@pytest.fixture(scope="function")
def httpx_mocked_response(httpx_mock, client):
import httpx

def httpx_mock_response(client, request: httpx.Request):
response = client.get(str(request.url))
return httpx.Response(status_code=response.status_code, json=response.json())

httpx_mock.add_callback(partial(httpx_mock_response, client))
yield httpx_mock


@pytest.mark.parametrize("use_async", [False])
def test_client_endpoints(httpx_mocked_response, use_async):
def test_client_endpoints(http_client, use_async):

filter = ""

cli = OptimadeClient(base_urls=[TEST_URL], use_async=use_async)
cli = OptimadeClient(
base_urls=[TEST_URL], use_async=use_async, http_client=http_client
)
get_results = cli.get()
assert get_results["structures"][filter][TEST_URL]["data"]
assert (
Expand Down Expand Up @@ -76,8 +65,10 @@ def test_client_endpoints(httpx_mocked_response, use_async):


@pytest.mark.parametrize("use_async", [False])
def test_filter_validation(use_async):
cli = OptimadeClient(use_async=use_async, base_urls=TEST_URL)
def test_filter_validation(http_client, use_async):
cli = OptimadeClient(
use_async=use_async, base_urls=TEST_URL, http_client=http_client
)
with pytest.raises(Exception):
cli.get("completely wrong filter")

Expand All @@ -86,9 +77,13 @@ def test_filter_validation(use_async):


@pytest.mark.parametrize("use_async", [False])
def test_client_response_fields(httpx_mocked_response, use_async):
def test_client_response_fields(http_client, use_async):
with pytest.warns(MissingExpectedField):
cli = OptimadeClient(base_urls=[TEST_URL], use_async=use_async)
cli = OptimadeClient(
base_urls=[TEST_URL],
use_async=use_async,
http_client=http_client,
)
results = cli.get(response_fields=["chemical_formula_reduced"])
for d in results["structures"][""][TEST_URL]["data"]:
assert "chemical_formula_reduced" in d["attributes"]
Expand All @@ -104,8 +99,10 @@ def test_client_response_fields(httpx_mocked_response, use_async):


@pytest.mark.parametrize("use_async", [False])
def test_multiple_base_urls(httpx_mocked_response, use_async):
cli = OptimadeClient(base_urls=TEST_URLS, use_async=use_async)
def test_multiple_base_urls(http_client, use_async):
cli = OptimadeClient(
base_urls=TEST_URLS, use_async=use_async, http_client=http_client
)
results = cli.get()
count_results = cli.count()
for url in TEST_URLS:
Expand All @@ -117,7 +114,7 @@ def test_multiple_base_urls(httpx_mocked_response, use_async):


@pytest.mark.parametrize("use_async", [False])
def test_include_exclude_providers(use_async):
def test_include_exclude_providers(http_client, use_async):
with pytest.raises(
SystemExit,
match="Unable to access any OPTIMADE base URLs. If you believe this is an error, try manually specifying some base URLs.",
Expand All @@ -126,6 +123,7 @@ def test_include_exclude_providers(use_async):
include_providers={"exmpl"},
exclude_providers={"exmpl"},
use_async=use_async,
http_client=http_client,
)

with pytest.raises(
Expand All @@ -150,14 +148,16 @@ def test_include_exclude_providers(use_async):


@pytest.mark.parametrize("use_async", [False])
def test_client_sort(httpx_mocked_response, use_async):
cli = OptimadeClient(base_urls=[TEST_URL], use_async=use_async)
def test_client_sort(http_client, use_async):
cli = OptimadeClient(
base_urls=[TEST_URL], use_async=use_async, http_client=http_client
)
results = cli.get(sort="last_modified")
assert len(results["structures"][""][TEST_URL]["data"]) > 0


@pytest.mark.parametrize("use_async", [False])
def test_command_line_client(httpx_mocked_response, use_async, capsys):
def test_command_line_client(http_client, use_async, capsys):
from optimade.client.cli import _get

args = dict(
Expand All @@ -174,6 +174,7 @@ def test_command_line_client(httpx_mocked_response, use_async, capsys):
include_providers=None,
exclude_providers=None,
exclude_databases=None,
http_client=http_client,
)

# Test multi-provider query
Expand Down
26 changes: 21 additions & 5 deletions tests/server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from typing import Iterable, Optional, Type, Union
from urllib.parse import urlparse

import httpx
import pytest
from fastapi.testclient import TestClient
from requests import Response
from starlette import testclient

import optimade.models.jsonapi as jsonapi
Expand Down Expand Up @@ -52,9 +52,11 @@ def __init__(
def request( # pylint: disable=too-many-locals
self,
method: str,
url: str,
url: httpx._types.URLTypes,
**kwargs,
) -> Response:
) -> httpx.Response:

url = str(url)
if (
re.match(r"/?v[0-9](.[0-9]){0,2}/", url) is None
and not urlparse(url).scheme
Expand All @@ -75,7 +77,7 @@ class BaseEndpointTests:
request_str: Optional[str] = None
response_cls: Optional[Type[jsonapi.Response]] = None

response: Optional[Response] = None
response: Optional[httpx.Response] = None
json_response: Optional[dict] = None

@staticmethod
Expand Down Expand Up @@ -224,7 +226,7 @@ class NoJsonEndpointTests:
request_str: Optional[str] = None
response_cls: Optional[Type] = None

response: Optional[Response] = None
response: Optional[httpx.Response] = None

@pytest.fixture(autouse=True)
def get_response(self, both_clients):
Expand All @@ -238,3 +240,17 @@ def test_response_okay(self):
assert (
self.response.status_code == 200
), f"Request to {self.request_str} failed: {self.response.content}"


class HttpxTestClient(httpx.Client):
"""An HTTP client wrapper that calls the regular test server."""

client = client_factory()(server="regular")

def request( # pylint: disable=too-many-locals
self,
method: str,
url: httpx._types.URLTypes,
**kwargs,
) -> httpx.Response:
return self.client.request(method, url, **kwargs)