Skip to content

Commit

Permalink
Handle provider-specific and unknown response_fields (closes #516)
Browse files Browse the repository at this point in the history
  • Loading branch information
ml-evs committed May 27, 2021
1 parent 9a2ed12 commit d68086b
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 20 deletions.
30 changes: 29 additions & 1 deletion optimade/server/entry_collections/entry_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from optimade.server.exceptions import BadRequest, Forbidden
from optimade.server.mappers import BaseResourceMapper
from optimade.server.query_params import EntryListingQueryParams, SingleEntryQueryParams
from optimade.server.warnings import FieldValueNotRecognized
from optimade.server.warnings import FieldValueNotRecognized, UnknownProviderProperty


def create_collection(
Expand Down Expand Up @@ -144,6 +144,33 @@ def find(
)

exclude_fields = self.all_fields - response_fields
include_fields = (
response_fields - self.resource_mapper.TOP_LEVEL_NON_ATTRIBUTES_FIELDS
)

bad_optimade_fields = set()
bad_provider_fields = set()
for field in include_fields:
if field not in self.resource_mapper.ALL_ATTRIBUTES:
if field.startswith("_"):
if any(
field.startswith(f"_{prefix}_")
for prefix in self.resource_mapper.SUPPORTED_PREFIXES
):
bad_provider_fields.add(field)
else:
bad_optimade_fields.add(field)

if bad_provider_fields:
warnings.warn(
message=f"Unrecognised field(s) for this provider requested in `response_fields`: {bad_provider_fields}.",
category=UnknownProviderProperty,
)

if bad_optimade_fields:
raise BadRequest(
detail=f"Unrecognised OPTIMADE field(s) in requested `response_fields`: {bad_optimade_fields}."
)

if results:
if isinstance(results, dict):
Expand All @@ -159,6 +186,7 @@ def find(
data_returned,
more_data_available,
exclude_fields,
include_fields,
)

@abstractmethod
Expand Down
3 changes: 3 additions & 0 deletions optimade/server/mappers/links.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from optimade.server.mappers.entries import BaseResourceMapper
from optimade.models.links import LinksResource

__all__ = ("LinksMapper",)

Expand All @@ -7,6 +8,8 @@ class LinksMapper(BaseResourceMapper):

ENDPOINT = "links"

ENTRY_RESOURCE_CLASS = LinksResource

@classmethod
def map_back(cls, doc: dict) -> dict:
"""Map properties from MongoDB to OPTIMADE
Expand Down
52 changes: 34 additions & 18 deletions optimade/server/routers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
import urllib
from datetime import datetime
from typing import Union, List, Dict
from typing import Union, List, Dict, Set

from fastapi import HTTPException, Request
from starlette.datastructures import URL as StarletteURL
Expand Down Expand Up @@ -63,34 +63,38 @@ def meta_values(


def handle_response_fields(
results: Union[List[EntryResource], EntryResource], exclude_fields: set
results: Union[List[EntryResource], EntryResource],
exclude_fields: Set[str],
include_fields: Set[str],
) -> dict:
"""Handle query parameter ``response_fields``
It is assumed that all fields are under ``attributes``.
This is due to all other top-level fields are REQUIRED in the response.
:param exclude_fields: Fields under ``attributes`` to be excluded from the response.
:param include_fields: Fields under `attributes` that were requested that should be
set to null if missing in the entry.
"""
if not isinstance(results, list):
results = [results]

new_results = []
while results:
entry = results.pop(0)

# TODO: re-enable exclude_unset when proper handling of known/unknown fields
# has been implemented (relevant issue: https://github.com/Materials-Consortia/optimade-python-tools/issues/263)
# Have to handle top level fields explicitly here for now
new_entry = entry.dict(exclude_unset=False)
for field in ("relationships", "links", "meta", "type", "id"):
if field in new_entry and new_entry[field] is None:
del new_entry[field]
new_entry = results.pop(0).dict(exclude_unset=True)

# Remove fields excluded by their omission in `reponse_fields`
for field in exclude_fields:
if field in new_entry["attributes"]:
del new_entry["attributes"][field]

# Include missing fields that were requested in `response_fields`
for field in include_fields:
if field not in new_entry["attributes"]:
new_entry["attributes"][field] = None

new_results.append(new_entry)

return new_results


Expand Down Expand Up @@ -164,7 +168,7 @@ def get_included_relationships(
)

# still need to handle pagination
ref_results, _, _, _ = ENTRY_COLLECTIONS[entry_type].find(params)
ref_results, _, _, _, _ = ENTRY_COLLECTIONS[entry_type].find(params)
included[entry_type] = ref_results

# flatten dict by endpoint to list
Expand Down Expand Up @@ -201,7 +205,13 @@ def get_entries(
"""Generalized /{entry} endpoint getter"""
from optimade.server.routers import ENTRY_COLLECTIONS

results, data_returned, more_data_available, fields = collection.find(params)
(
results,
data_returned,
more_data_available,
fields,
include_fields,
) = collection.find(params)

include = []
if getattr(params, "include", False):
Expand All @@ -219,8 +229,8 @@ def get_entries(
else:
links = ToplevelLinks(next=None)

if fields:
results = handle_response_fields(results, fields)
if fields or include_fields:
results = handle_response_fields(results, fields, include_fields)

return response(
links=links,
Expand All @@ -245,7 +255,13 @@ def get_single_entry(
from optimade.server.routers import ENTRY_COLLECTIONS

params.filter = f'id="{entry_id}"'
results, data_returned, more_data_available, fields = collection.find(params)
(
results,
data_returned,
more_data_available,
fields,
include_fields,
) = collection.find(params)

include = []
if getattr(params, "include", False):
Expand All @@ -260,8 +276,8 @@ def get_single_entry(

links = ToplevelLinks(next=None)

if fields and results is not None:
results = handle_response_fields(results, fields)[0]
if fields or include_fields and results is not None:
results = handle_response_fields(results, fields, include_fields)[0]

return response(
links=links,
Expand Down
2 changes: 1 addition & 1 deletion tests/server/query_params/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ def inner(
expected_fields |= (
get_mapper[endpoint].get_required_fields() - known_unused_fields
)
expected_fields.add("attributes")
request = f"/{endpoint}?response_fields={','.join(expected_fields)}"

response = get_good_response(request, server)
expected_fields.add("attributes")

response_fields = set()
for entry in response["data"]:
Expand Down

0 comments on commit d68086b

Please sign in to comment.