Skip to content

Commit

Permalink
feat(client): make numeric arguments accept 'int | float' (#366)
Browse files Browse the repository at this point in the history
Fixes #333
  • Loading branch information
afuetterer committed May 7, 2024
1 parent eca3e3b commit b53bbed
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/oaipmh_scythe/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ def __init__(
iterator: type[BaseOAIIterator] = OAIItemIterator,
max_retries: int = 0,
retry_status_codes: Iterable[int] | None = None,
default_retry_after: int = 60,
default_retry_after: int | float = 60,
class_mapping: dict[str, type[OAIItem]] | None = None,
encoding: str = "utf-8",
auth: AuthTypes | None = None,
timeout: int = 60,
timeout: int | float = 60,
):
self.endpoint = endpoint
if http_method not in ("GET", "POST"):
Expand All @@ -97,11 +97,18 @@ def __init__(
raise TypeError("Argument 'iterator' must be subclass of %s" % BaseOAIIterator.__name__)
self.max_retries = max_retries
self.retry_status_codes = retry_status_codes or (503,)
if default_retry_after <= 0:
raise ValueError(
"Invalid value for 'default_retry_after': %s. default_retry_after must be positive int or float."
% default_retry_after
)
self.default_retry_after = default_retry_after
self.oai_namespace = OAI_NAMESPACE
self.class_mapping = class_mapping or DEFAULT_CLASS_MAP
self.encoding = encoding
self.auth = auth
if timeout <= 0:
raise ValueError("Invalid value for 'timeout': %s. Timeout must be positive int or float." % timeout)
self.timeout = timeout
self._client: httpx.Client | None = None

Expand Down Expand Up @@ -388,7 +395,7 @@ def list_metadata_formats(self, identifier: str | None = None) -> Iterator[OAIRe
query = remove_none_values(_query)
yield from self.iterator(self, query)

def get_retry_after(self, http_response: httpx.Response) -> int:
def get_retry_after(self, http_response: httpx.Response) -> int | float:
"""Determine the appropriate time to wait before retrying a request, based on the server's response.
Check the status code of the provided HTTP response. If it's 503 (Service Unavailable),
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,27 @@ def test_auth_arguments_usage(respx_mock: MockRouter) -> None:
respx_mock.get("https://zenodo.org/oai2d").mock(return_value=httpx.Response(200))
oai_response = scythe.harvest(query)
assert oai_response.http_response.request.headers["authorization"]


@pytest.mark.parametrize("timeout", [10, 10.0, 0.1])
def test_valid_custom_timeout(timeout):
with Scythe("https://zenodo.org/oai2d", timeout=timeout) as scythe:
assert scythe.client.timeout


@pytest.mark.parametrize("timeout", [-1, -1.0, 0, 0.0])
def test_invalid_custom_timeout(timeout):
with pytest.raises(ValueError, match="Invalid value for 'timeout'"):
Scythe("https://zenodo.org/oai2d", timeout=timeout)


@pytest.mark.parametrize("retry_after", [10, 10.0, 0.1])
def test_valid_custom_retry_after(retry_after):
with Scythe("https://zenodo.org/oai2d", default_retry_after=retry_after) as scythe:
assert scythe.default_retry_after


@pytest.mark.parametrize("retry_after", [-1, -1.0, 0, 0.0])
def test_invalid_custom_retry_after(retry_after):
with pytest.raises(ValueError, match="Invalid value for 'default_retry_after'"):
Scythe("https://zenodo.org/oai2d", default_retry_after=retry_after)

0 comments on commit b53bbed

Please sign in to comment.