Skip to content

Commit

Permalink
Allow callbacks to return a next URL
Browse files Browse the repository at this point in the history
  • Loading branch information
ml-evs committed Jun 29, 2023
1 parent 10d7f2d commit dc71695
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 7 deletions.
24 changes: 17 additions & 7 deletions optimade/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class OptimadeClient:
use_async: bool
"""Whether or not to make all requests asynchronously using asyncio."""

callbacks: Optional[List[Callable[[str, Dict], None]]] = None
callbacks: Optional[List[Callable[[str, Dict], Union[None, str]]]] = None
"""A list of callbacks to execute after each successful request, used
to e.g., write to a file, add results to a database or perform additional
filtering.
Expand All @@ -111,6 +111,11 @@ class OptimadeClient:
from the JSON response, with keys 'data', 'meta', 'links', 'errors'
and 'included'.
Each callback can return a string that will be used to replace the `next_url`
queried by the client.
In the case of multiple provided callbacks, only the value returned by the final
callback in the stack will be used.
"""

silent: bool
Expand Down Expand Up @@ -154,7 +159,7 @@ def __init__(
http_client: Optional[
Union[Type[httpx.AsyncClient], Type[requests.Session]]
] = None,
callbacks: Optional[List[Callable[[str, Dict], None]]] = None,
callbacks: Optional[List[Callable[[str, Dict], Union[None, str]]]] = None,
):
"""Create the OPTIMADE client object.
Expand Down Expand Up @@ -1004,10 +1009,11 @@ def _handle_response(
total=results["meta"].get("data_returned", None),
)

callback_url = None
if self.callbacks:
self._execute_callbacks(results, response)
callback_url = self._execute_callbacks(results, response)

next_url = results["links"].get("next", None)
next_url = callback_url or results["links"].get("next", None)
if isinstance(next_url, dict):
next_url = next_url.pop("href")

Expand All @@ -1030,16 +1036,20 @@ def _teardown(self, _task: TaskID, num_results: int) -> None:

def _execute_callbacks(
self, results: Dict, response: Union[httpx.Response, requests.Response]
) -> None:
) -> Union[None, str]:
"""Execute any callbacks registered with the client.
Parameters:
results: The results from the query.
response: The full response from the server.
Returns:
Either `None` or the string value returned from the *final* callback.
"""
request_url = str(response.request.url)
if not self.callbacks:
return
return None
for callback in self.callbacks:
callback(request_url, results)
cb_response = callback(request_url, results)
return cb_response
21 changes: 21 additions & 0 deletions tests/server/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,27 @@ def global_database_callback(_: str, results: Dict):
assert len(container) == 17


@pytest.mark.parametrize("use_async", [True, False])
def test_client_page_skip_callback(async_http_client, http_client, use_async):
def page_skip_callback(_: str, results: Dict) -> Optional[str]:
"""A test callback that skips to the final page of results."""
if len(results["data"]) > 16:
return f"{TEST_URL}/structures?page_offset=16"
return None

cli = OptimadeClient(
base_urls=[TEST_URL],
use_async=use_async,
http_client=async_http_client if use_async else http_client,
callbacks=[page_skip_callback],
)

results = cli.get(response_fields=["chemical_formula_reduced"])

# callback will skip to final page after first query and add duplicate of final result
assert len(results["structures"][""][TEST_URL]["data"]) == 18


@pytest.mark.parametrize("use_async", [True, False])
def test_client_mutable_data_callback(async_http_client, http_client, use_async):
container: Dict[str, str] = {}
Expand Down

0 comments on commit dc71695

Please sign in to comment.