Skip to content

Commit

Permalink
Get provider list in mapper and emit warnings when filter field prefi…
Browse files Browse the repository at this point in the history
…x does not match a known provider

- Elasticsearch fix for provider fields

- Add warning checks to tests
- Fallback to no providers
  • Loading branch information
ml-evs committed May 18, 2021
1 parent 5738cb4 commit 9ab9858
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 47 deletions.
33 changes: 22 additions & 11 deletions optimade/filtertransformers/base_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@

import abc
from typing import Dict, Any, Type
import warnings

from lark import Transformer, v_args, Tree

from optimade.server.mappers import BaseResourceMapper
from optimade.server.exceptions import BadRequest
from optimade.server.warnings import UnknownProviderProperty


__all__ = (
Expand Down Expand Up @@ -240,18 +242,27 @@ def property(self, args):

if self.quantities and quantity_name not in self.quantities:
# If the quantity is provider-specific, but does not match this provider,
# then return None so that this query can be dropped.
if (
self.mapper
and quantity_name.startswith("_")
and not any(
quantity_name.split("_")[1] == p
for p in self.mapper.SUPPORTED_PREFIXES
)
):
return quantity_name
# then return the quantity name such that it can be trreated as unknown.
# If the prefix does not match another known provider, also emit a warning
# If the prefix does match a known prpovider, do not return a warning.
# Following [Handling unknown property names](https://github.com/Materials-Consortia/OPTIMADE/blob/master/optimade.rst#handling-unknown-property-names)
if self.mapper and quantity_name.startswith("_"):
prefix = quantity_name.split("_")[1]
if not any(prefix == p for p in self.mapper.SUPPORTED_PREFIXES):
if not any(
prefix == p for p in self.mapper.KNOWN_PROVIDER_PREFIXES
):
warnings.warn(
UnknownProviderProperty(
f"Field {quantity_name!r} has an unrecognised prefix: this property has been treated as UNKNOWN."
)
)

return quantity_name

raise BadRequest(detail=f"'{quantity_name}' is not a searchable quantity")
raise BadRequest(
detail=f"'{quantity_name}' is not a known or searchable quantity"
)

quantity = self.quantities.get(quantity_name, None)
if quantity is None:
Expand Down
6 changes: 6 additions & 0 deletions optimade/filtertransformers/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,12 @@ def length_op_rhs(self, args):
op = "="

def query(quantity):

# This is only the case if quantity is an "other" provider's field,
# in which case, we should treat it as unknown and try to do a null query
if isinstance(quantity, str):
return self._query_op(quantity, op, value)

if quantity.length_quantity is None:
raise NotImplementedError(
"LENGTH is not supported for '%s'" % quantity.name
Expand Down
23 changes: 16 additions & 7 deletions optimade/server/mappers/entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,33 @@ class BaseResourceMapper:
the specification.
Attributes:
ENDPOINT (str): defines the endpoint for which to apply this
mapper.
ALIASES (Tuple[Tuple[str, str]]): a tuple of aliases between
ALIASES: a tuple of aliases between
OPTIMADE field names and the field names in the database ,
e.g. `(("elements", "custom_elements_field"))`.
LENGTH_ALIASES (Tuple[Tuple[str, str]]): a tuple of aliases between
LENGTH_ALIASES: a tuple of aliases between
a field name and another field that defines its length, to be used
when querying, e.g. `(("elements", "nelements"))`.
e.g. `(("elements", "custom_elements_field"))`.
PROVIDER_FIELDS (Tuple[str]): a tuple of extra field names that this
ENTRY_RESOURCE_CLASS: The entry type that this mapper corresponds to.
PROVIDER_FIELDS: a tuple of extra field names that this
mapper should support when querying with the database prefix.
REQUIRED_FIELDS (set[str]): the set of fieldnames to return
REQUIRED_FIELDS: the set of fieldnames to return
when mapping to the OPTIMADE format.
TOP_LEVEL_NON_ATTRIBUTES_FIELDS (set[str]): the set of top-level
TOP_LEVEL_NON_ATTRIBUTES_FIELDS: the set of top-level
field names common to all endpoints.
"""

try:
from optimade.server.data import (
providers as PROVIDERS,
) # pylint: disable=no-name-in-module
except (ImportError, ModuleNotFoundError):
PROVIDERS = {}

KNOWN_PROVIDER_PREFIXES: Set[str] = set(
prov["id"] for prov in PROVIDERS.get("data", [])
)
ALIASES: Tuple[Tuple[str, str]] = ()
LENGTH_ALIASES: Tuple[Tuple[str, str]] = ()
PROVIDER_FIELDS: Tuple[str] = ()
Expand Down
7 changes: 7 additions & 0 deletions optimade/server/warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,10 @@ class TimestampNotRFCCompliant(OptimadeWarning):
RFC 3339 compliant. This may cause undefined behaviour in the query results.
"""


class UnknownProviderProperty(OptimadeWarning):
"""A provider-specific property has been requested via `response_fields` or as in a `filter` that is not
recognised by this implementation.
"""
4 changes: 2 additions & 2 deletions tests/filtertransformers/test_elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_parse_n_transform(query, parser, transformer):
def test_bad_queries(parser, transformer):
filter_ = "unknown_field = 0"
with pytest.raises(
Exception, match="'unknown_field' is not a searchable quantity"
Exception, match="'unknown_field' is not a known or searchable quantity"
) as exc_info:
transformer.transform(parser.parse(filter_))
assert exc_info.type.__name__ == "VisitError"
Expand All @@ -90,7 +90,7 @@ def test_bad_queries(parser, transformer):

filter_ = "_exmpl_field = 1"
with pytest.raises(
Exception, match="'_exmpl_field' is not a searchable quantity"
Exception, match="'_exmpl_field' is not a known or searchable quantity"
) as exc_info:
transformer.transform(parser.parse(filter_))
assert exc_info.type.__name__ == "VisitError"
13 changes: 12 additions & 1 deletion tests/server/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Union, Dict

import pytest

Expand Down Expand Up @@ -109,6 +109,7 @@ def inner(
page_limit: int = CONFIG.page_limit,
expected_return: int = None,
expected_as_is: bool = False,
expected_warnings: List[Dict[str, str]] = None,
server: Union[str, OptimadeTestClient] = "regular",
):
response = get_good_response(request, server)
Expand All @@ -128,6 +129,16 @@ def inner(
else:
assert expected_ids == response_ids

expected_warnings = expected_warnings if expected_warnings else []
if expected_warnings:
assert "warnings" in response["meta"]
assert len(expected_warnings) == len(response["meta"]["warnings"])
for ind, warn in enumerate(expected_warnings):
for key in warn:
assert response["meta"]["warnings"][ind][key] == warn[key]
else:
assert "warnings" not in response["meta"]

return inner


Expand Down
99 changes: 73 additions & 26 deletions tests/server/query_params/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,39 +210,57 @@ def test_list_correlated(check_error_response):


def test_timestamp_query(check_response):
from optimade.server.warnings import TimestampNotRFCCompliant

request = '/structures?filter=last_modified="2019-06-08T05:13:37.331Z"&page_limit=5'
expected_ids = ["mpf_1", "mpf_2", "mpf_3"]
if CONFIG.database_backend == SupportedBackend.ELASTIC:
check_response(request, expected_ids, expected_as_is=True)
else:
with pytest.warns(TimestampNotRFCCompliant):
check_response(request, expected_ids, expected_as_is=True)
expected_warnings = None
if CONFIG.database_backend in (
SupportedBackend.MONGOMOCK,
SupportedBackend.MONGODB,
):
expected_warnings = [{"title": "TimestampNotRFCCompliant"}]
check_response(
request, expected_ids, expected_as_is=True, expected_warnings=expected_warnings
)

request = '/structures?filter=last_modified<"2019-06-08T05:13:37.331Z"&page_limit=5'
expected_ids = ["mpf_3819"]
if CONFIG.database_backend == SupportedBackend.ELASTIC:
check_response(request, expected_ids, expected_as_is=True)
else:
with pytest.warns(TimestampNotRFCCompliant):
check_response(request, expected_ids, expected_as_is=True)
expected_warnings = None
if CONFIG.database_backend in (
SupportedBackend.MONGOMOCK,
SupportedBackend.MONGODB,
):
expected_warnings = [{"title": "TimestampNotRFCCompliant"}]
check_response(
request, expected_ids, expected_as_is=True, expected_warnings=expected_warnings
)

request = '/structures?filter=last_modified="2018-06-08T05:13:37.945Z"&page_limit=5'
expected_ids = ["mpf_3819"]
if CONFIG.database_backend == SupportedBackend.ELASTIC:
check_response(request, expected_ids, expected_as_is=True)
else:
with pytest.warns(TimestampNotRFCCompliant):
check_response(request, expected_ids, expected_as_is=True)
expected_warnings = None
if CONFIG.database_backend in (
SupportedBackend.MONGOMOCK,
SupportedBackend.MONGODB,
):
expected_warnings = [{"title": "TimestampNotRFCCompliant"}]
check_response(
request, expected_ids, expected_as_is=True, expected_warnings=expected_warnings
)

request = '/structures?filter=last_modified>"2018-06-08T05:13:37.945Z" AND last_modified<="2019-06-08T05:13:37.331Z"&page_limit=5'
expected_ids = ["mpf_1", "mpf_2", "mpf_3"]
if CONFIG.database_backend == SupportedBackend.ELASTIC:
check_response(request, expected_ids, expected_as_is=True)
else:
with pytest.warns(TimestampNotRFCCompliant):
check_response(request, expected_ids, expected_as_is=True)
expected_warnings = None
if CONFIG.database_backend in (
SupportedBackend.MONGOMOCK,
SupportedBackend.MONGODB,
):
expected_warnings = [
{"title": "TimestampNotRFCCompliant"},
{"title": "TimestampNotRFCCompliant"},
]
check_response(
request, expected_ids, expected_as_is=True, expected_warnings=expected_warnings
)


def test_is_known(check_response):
Expand Down Expand Up @@ -410,7 +428,7 @@ def test_filter_on_relationships(check_response, check_error_response):
def test_filter_on_unknown_fields(check_response, check_error_response):

request = "/structures?filter=unknown_field = 1"
error_detail = "'unknown_field' is not a searchable quantity"
error_detail = "'unknown_field' is not a known or searchable quantity"
check_error_response(
request,
expected_status=400,
Expand All @@ -419,7 +437,7 @@ def test_filter_on_unknown_fields(check_response, check_error_response):
)

request = "/structures?filter=_exmpl_unknown_field = 1"
error_detail = "'_exmpl_unknown_field' is not a searchable quantity"
error_detail = "'_exmpl_unknown_field' is not a known or searchable quantity"
check_error_response(
request,
expected_status=400,
Expand All @@ -428,7 +446,7 @@ def test_filter_on_unknown_fields(check_response, check_error_response):
)

request = "/structures?filter=_exmpl_unknown_field LENGTH 1"
error_detail = "'_exmpl_unknown_field' is not a searchable quantity"
error_detail = "'_exmpl_unknown_field' is not a known or searchable quantity"
check_error_response(
request,
expected_status=400,
Expand All @@ -438,12 +456,41 @@ def test_filter_on_unknown_fields(check_response, check_error_response):

request = "/structures?filter=_exmpl1_unknown_field = 1"
expected_ids = []
check_response(request, expected_ids=expected_ids)
expected_warnings = [
{
"title": "UnknownProviderProperty",
"detail": "Field '_exmpl1_unknown_field' has an unrecognised prefix: this property has been treated as UNKNOWN.",
}
]
check_response(
request, expected_ids=expected_ids, expected_warnings=expected_warnings
)

request = "/structures?filter=_exmpl1_unknown_field LENGTH 1"
expected_ids = []
check_response(request, expected_ids=expected_ids)
expected_warnings = [
{
"title": "UnknownProviderProperty",
"detail": "Field '_exmpl1_unknown_field' has an unrecognised prefix: this property has been treated as UNKNOWN.",
}
]
check_response(
request, expected_ids=expected_ids, expected_warnings=expected_warnings
)

request = '/structures?filter=_exmpl1_unknown_field HAS "Si"'
expected_ids = []
expected_warnings = [
{
"title": "UnknownProviderProperty",
"detail": "Field '_exmpl1_unknown_field' has an unrecognised prefix: this property has been treated as UNKNOWN.",
}
]
check_response(
request, expected_ids=expected_ids, expected_warnings=expected_warnings
)

# Should not warn as the "_optimade" prefix is registered
request = "/structures?filter=_optimade_random_field = 1"
expected_ids = []
check_response(request, expected_ids=expected_ids)

0 comments on commit 9ab9858

Please sign in to comment.