Skip to content

Commit

Permalink
Type hint tweaks and refactors from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Johan Bergsma <JPBergsma@users.noreply.github.com>
  • Loading branch information
ml-evs and JPBergsma committed Nov 24, 2022
1 parent 38bb9ca commit e92e1a5
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 19 deletions.
11 changes: 6 additions & 5 deletions optimade/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class OptimadeClient:

def __init__(
self,
base_urls: Optional[Union[str, List[str]]] = None,
base_urls: Optional[Union[str, Iterable[str]]] = None,
max_results_per_provider: int = 1000,
headers: Optional[Dict] = None,
http_timeout: int = 10,
Expand All @@ -111,14 +111,15 @@ def __init__(
"""

if not base_urls:
base_urls = get_all_databases() # type: ignore[assignment]

self.max_results_per_provider = max_results_per_provider
if self.max_results_per_provider in (-1, 0):
self.max_results_per_provider = None

self.base_urls = base_urls # type: ignore[assignment]
if not base_urls:
self.base_urls = get_all_databases()
else:
self.base_urls = base_urls

if isinstance(self.base_urls, str):
self.base_urls = [self.base_urls]
self.base_urls = list(self.base_urls)
Expand Down
15 changes: 8 additions & 7 deletions optimade/filterparser/lark_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""

from collections import defaultdict
from pathlib import Path
from typing import Dict, Optional, Tuple

Expand All @@ -21,7 +20,7 @@ class ParserError(Exception):
"""


def get_versions() -> Dict[Tuple[int, int, int], Dict[str, str]]:
def get_versions() -> Dict[Tuple[int, int, int], Dict[str, Path]]:
"""Find grammar files within this package's grammar directory,
returning a dictionary broken down by scraped grammar version
(major, minor, patch) and variant (a string tag).
Expand All @@ -30,13 +29,15 @@ def get_versions() -> Dict[Tuple[int, int, int], Dict[str, str]]:
A mapping from version, variant to grammar file name.
"""
dct: Dict[Tuple[int, int, int], Dict[str, Path]] = defaultdict(dict)
dct: Dict[Tuple[int, int, int], Dict[str, Path]] = {}
for filename in Path(__file__).parent.joinpath("../grammar").glob("*.lark"):
tags = filename.stem.lstrip("v").split(".")
version = tuple(map(int, tags[:3])) # ignore: type[index]
variant = "default" if len(tags) == 3 else tags[-1]
dct[version][variant] = filename # type: ignore[index]
return dict(dct) # type: ignore[arg-type]
version: Tuple[int, int, int] = (int(tags[0]), int(tags[1]), int(tags[2]))
variant: str = "default" if len(tags) == 3 else str(tags[-1])
if version not in dct:
dct[version] = {}
dct[version][variant] = filename
return dct


AVAILABLE_PARSERS = get_versions()
Expand Down
4 changes: 2 additions & 2 deletions optimade/models/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
EPS = 2**-23


Vector3D = conlist(float, min_items=3, max_items=3) # type: ignore[valid-type]
Vector3D_unknown = conlist(Union[float, None], min_items=3, max_items=3) # type: ignore[valid-type]
Vector3D = conlist(float, min_items=3, max_items=3)
Vector3D_unknown = conlist(Union[float, None], min_items=3, max_items=3)


class Periodicity(IntEnum):
Expand Down
4 changes: 2 additions & 2 deletions optimade/server/entry_collections/entry_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
self.provider_prefix = CONFIG.provider.prefix
self.provider_fields = [
field if isinstance(field, str) else field["name"]
for field in CONFIG.provider_fields.get(resource_mapper.ENDPOINT, []) # type: ignore[call-overload]
for field in CONFIG.provider_fields.get(resource_mapper.ENDPOINT, [])
]

self._all_fields: Set[str] = set()
Expand Down Expand Up @@ -376,7 +376,7 @@ def parse_sort_params(self, sort_params: str) -> Iterable[Tuple[str, int]]:
BadRequest: if an invalid sort is requested.
Returns:
A tuple of tuples containing the aliased field name and
A list of tuples containing the aliased field name and
sort direction encoded as 1 (ascending) or -1 (descending).
"""
Expand Down
2 changes: 1 addition & 1 deletion optimade/server/mappers/entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def all_length_aliases(cls) -> Tuple[Tuple[str, str], ...]:
from optimade.server.config import CONFIG

return cls.LENGTH_ALIASES + tuple(
CONFIG.length_aliases.get(cls.ENDPOINT, {}).items() # type: ignore[call-overload]
CONFIG.length_aliases.get(cls.ENDPOINT, {}).items()
)

@classmethod
Expand Down
2 changes: 0 additions & 2 deletions tests/server/test_server_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def test_versioned_base_urls(client, index_client, server: str):
This depends on the routers for each kind of server.
"""
import json

from optimade.server.routers.utils import BASE_URL_PREFIXES

Expand Down Expand Up @@ -129,7 +128,6 @@ def test_meta_schema_value_obeys_index(client, index_client, server: str):
"""Test that the reported `meta->schema` is correct for index/non-index
servers.
"""
import json

from optimade.server.config import CONFIG
from optimade.server.routers.utils import BASE_URL_PREFIXES
Expand Down

0 comments on commit e92e1a5

Please sign in to comment.