Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mongo length operator functionality with length aliases #222

Merged
merged 20 commits into from
Mar 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion default_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,6 @@
},
"provider_fields": {},
"aliases": {},
"length_aliases": {},
"index_links_path": "/Users/shyamd/Codes/optimade-python-tools/optimade/server/index_links.json"
}
}
5 changes: 5 additions & 0 deletions example_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,10 @@
"chemical_formula_reduced": "pretty_formula",
"chemical_formula_anonymous": "formula_anonymous"
}
},
"length_aliases": {
"structures": {
"chemsys": "nelements"
}
}
}
177 changes: 175 additions & 2 deletions optimade/filtertransformers/mongo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import copy
from lark import Transformer, v_args, Token
from optimade.server.mappers import BaseResourceMapper


class MongoTransformer(Transformer):
Expand All @@ -21,9 +23,28 @@ class MongoTransformer(Transformer):
"$eq": "$eq",
}

def __init__(self):
def __init__(self, mapper: BaseResourceMapper = None):
    """ Set up the transformer.

    Parameters:
        mapper: optional resource mapper. When one is given,
            `postprocess` uses it to apply length aliases and
            field aliases to the transformed query.

    """
    super().__init__()
    # kept for use by the post-processing step
    self.mapper = mapper

def postprocess(self, query):
    """ Post-process the final parsed query.

    When a mapper is set, length aliases and then field aliases are
    applied. Length operators are always rewritten as the last step.

    """
    if not self.mapper:
        return self._apply_length_operators(query)

    # length aliases have to be resolved before the normal aliases
    with_length_aliases = self._apply_length_aliases(query)
    with_aliases = self._apply_aliases(with_length_aliases)
    return self._apply_length_operators(with_aliases)

def transform(self, tree):
    """ Run the parent class' transformation on the tree, then
    post-process the result.

    """
    raw_query = super().transform(tree)
    return self.postprocess(raw_query)

def filter(self, arg):
    """ Return the first parsed expression, or None if there is none. """
    # filter: expression*
    if not arg:
        return None
    return arg[0]
Expand Down Expand Up @@ -155,10 +176,14 @@ def set_op_rhs(self, arg):

def length_op_rhs(self, arg):
    """ Build the right-hand side of a LENGTH comparison.

    A bare LENGTH, or LENGTH with "=", gives a plain `$size` test.
    Any other operator found in `operator_map` (except "!=") gives
    `{"$size": {<mongo operator>: value}}`. That form is not a valid
    Mongo query and has to be rewritten in post-processing.

    Raises:
        NotImplementedError: for "!=" or an operator that is not
            in `operator_map`.

    """
    # length_op_rhs: LENGTH [ OPERATOR ] value
    is_plain_size = len(arg) == 2 or (len(arg) == 3 and arg[1] == "=")
    if is_plain_size:
        return {"$size": arg[-1]}

    operator = arg[1]
    if operator in self.operator_map and operator != "!=":
        # e.g. {'$size': {'$gt': 2}}, to be fixed up in post-processing
        mongo_operator = self.operator_map[operator]
        return {"$size": {mongo_operator: arg[-1]}}

    raise NotImplementedError(
        f"Operator {operator} not implemented for LENGTH filter."
    )
Expand Down Expand Up @@ -226,3 +251,151 @@ def _recursive_expression_phrase(self, arg):

# simple case of negating one expression, from NOT (expr) to ~expr.
return {prop: {"$not": expr} for prop, expr in arg[1].items()}

def _apply_length_aliases(self, filter_: dict) -> dict:
    """ Replace `$size` tests on a property with a test on the
    property's length alias, wherever the mapper defines one.

    The value that was under `$size` is moved to the length alias
    field; the original property is removed if nothing else was
    being tested on it.

    """

    def has_aliased_size(prop, expr):
        if not isinstance(expr, dict) or "$size" not in expr:
            return False
        return self.mapper.length_alias_for(prop)

    def swap_in_length_alias(subdict, prop, expr):
        length_field = self.mapper.length_alias_for(prop)
        subdict[length_field] = expr["$size"]
        remaining = subdict[prop]
        remaining.pop("$size")
        # drop the property entirely once it has no tests left
        if not remaining:
            subdict.pop(prop)
        return subdict

    return recursive_postprocessing(
        filter_, has_aliased_size, swap_in_length_alias
    )

def _apply_aliases(self, filter_: dict) -> dict:
    """ Rename every field in the filter that the mapper has an
    alias for, so that the query uses the aliased names.

    """
    mapper = self.mapper

    # nothing to rename if the mapper defines no aliases at all
    if not mapper.all_aliases():
        return filter_

    def is_aliased(prop, expr):
        return mapper.alias_for(prop) != prop

    def rename(subdict, prop, expr):
        if isinstance(subdict, dict):
            # the value is processed too, as it may contain further fields
            subdict[mapper.alias_for(prop)] = self._apply_aliases(
                subdict.pop(prop)
            )
        elif isinstance(subdict, str):
            subdict = mapper.alias_for(subdict)

        return subdict

    return recursive_postprocessing(filter_, is_aliased, rename)

def _apply_length_operators(self, filter_: dict) -> dict:
    """ Check for any invalid pymongo queries that involve
    applying an operator to the length of a field, and transform
    them into a test for existence of the relevant entry.

    Array positions in Mongo's dot notation are zero-based, so a
    list has more than N entries exactly when the entry at index N
    exists:

        "list LENGTH > 3"   ->  {"list.3": {"$exists": True}}
        "list LENGTH >= 3"  ->  {"list.2": {"$exists": True}}
        "list LENGTH < 3"   ->  {"list.2": {"$exists": False}}
        "list LENGTH <= 3"  ->  {"list.3": {"$exists": False}}

    If the comparison would need a negative index (e.g.
    "list LENGTH >= 0"), the existence test is applied to the field
    itself instead.

    Queries using any other operator inside `$size` are left
    untouched.

    """
    # operator -> (offset, existence): "LENGTH <operator> value" holds
    # when the entry at zero-based index `value + offset` does
    # (existence=True) or does not (existence=False) exist.
    index_tests = {
        "$gt": (0, True),
        "$gte": (-1, True),
        "$lt": (-1, False),
        "$lte": (0, False),
    }

    def check_for_length_op_filter(prop, expr):
        return (
            isinstance(expr, dict)
            and "$size" in expr
            and isinstance(expr["$size"], dict)
        )

    def apply_length_op(subdict, prop, expr):
        # assumes that the dictionary only has one element by design
        # (it is built that way by `length_op_rhs`)
        operator, value = list(expr["$size"].items())[0]
        if operator not in index_tests:
            return subdict
        if operator not in self.operator_map.values():
            return subdict

        offset, existence = index_tests[operator]
        index = value + offset
        # a negative index is meaningless, fall back to the field itself
        target = f"{prop}.{index}" if index >= 0 else prop
        subdict.pop(prop)
        subdict[target] = {"$exists": existence}
        return subdict

    return recursive_postprocessing(
        filter_, check_for_length_op_filter, apply_length_op,
    )


def recursive_postprocessing(filter_, condition, replacement):
    """ Apply a conditional replacement to the dictionaries of a query.

    A list is processed element by element. For a dictionary, every
    (property, expression) pair is tested with `condition`; on a
    match, `replacement` is applied to a deep copy of the
    dictionary, otherwise an expression that is a list is processed
    element by element. Expressions that are themselves dictionaries
    are not descended into (the replacement may do so itself).
    Anything that is neither a list nor a dictionary is returned
    unchanged. The dictionary passed in is never modified.

    Parameters:
        filter_ (list/dict): the filter_ to process.
        condition (callable): called as `condition(prop, expr)` for
            each item of a dictionary, as would be returned by
            iterating over `filter_.items()`; a true result
            triggers the replacement.
        replacement (callable): called as
            `replacement(dictionary, prop, expr)`; it must return
            the processed dictionary.

    Example:
        For the simple case of replacing one field name with
        another, the following functions could be used:

        ```python
        def condition(prop, expr):
            return prop == "field_name_old"

        def replacement(d, prop, expr):
            d["field_name_new"] = d.pop(prop)
            return d

        filter_ = recursive_postprocessing(
            filter_, condition, replacement
        )

        ```

    """
    if isinstance(filter_, list):
        return [
            recursive_postprocessing(item, condition, replacement)
            for item in filter_
        ]

    if not isinstance(filter_, dict):
        return filter_

    # work on a deep copy so the caller's filter is left untouched;
    # note that this copies every dictionary found inside nested lists again
    processed = copy.deepcopy(filter_)
    for prop, expr in filter_.items():
        if condition(prop, expr):
            processed = replacement(processed, prop, expr)
            continue
        if isinstance(expr, list):
            processed[prop] = recursive_postprocessing(
                expr, condition, replacement
            )
    return processed
17 changes: 12 additions & 5 deletions optimade/server/config.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import os
import json
from typing import Any, Optional, Dict, List
from typing import Optional, Dict, List

try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
from pathlib import Path
from warnings import warn

import json

from pydantic import BaseSettings, Field, root_validator

Expand Down Expand Up @@ -92,6 +88,17 @@ class ServerConfig(BaseSettings):
description="A mapping between field names in the database with their corresponding OPTIMADE field names, broken down by endpoint.",
)

# Per-endpoint ("links", "references" or "structures") mapping from a
# property name to the name of the integer property that holds its length.
# Defaults to an empty mapping.
length_aliases: Dict[
Literal["links", "references", "structures"], Dict[str, str]
] = Field(
{},
description=(
"A mapping between a list property (or otherwise) and an integer property that defines the length of that list, "
"for example elements -> nelements. The standard aliases are applied first, so this dictionary must refer to the "
"API fields, not the database fields."
),
)

index_links_path: Path = Field(
Path(__file__).parent.joinpath("index_links.json"),
description="Absolute path to a JSON file containing the MongoDB collection of /links resources for the index meta-database",
Expand Down
4 changes: 4 additions & 0 deletions optimade/server/entry_collections/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .entry_collections import EntryCollection
from .mongo import MongoCollection, client, CI_FORCE_MONGO

# Names exported by `from <package> import *`: the two collection classes
# plus the `client` and `CI_FORCE_MONGO` objects imported from `.mongo`.
__all__ = ["EntryCollection", "MongoCollection", "client", "CI_FORCE_MONGO"]
64 changes: 64 additions & 0 deletions optimade/server/entry_collections/entry_collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from abc import abstractmethod
from typing import Collection, Tuple, List

from optimade.server.mappers import BaseResourceMapper
from optimade.filterparser import LarkParser
from optimade.models import EntryResource
from optimade.server.query_params import EntryListingQueryParams


class EntryCollection(Collection):  # pylint: disable=inherit-non-class
    """Base class wrapping a backend collection of entries.

    It stores the backend collection together with a filter parser, the
    resource class and the resource mapper, and forwards the container
    protocol (`len`, iteration, `in`) to the backend collection.
    Subclasses must implement `find`.

    """

    def __init__(
        self,
        collection,
        resource_cls: EntryResource,
        resource_mapper: BaseResourceMapper,
    ):
        """Store the backend collection and the objects used to query it.

        Args:
            collection: the backend collection object; it is expected to
                provide `count()` and `find()`.
            resource_cls (EntryResource): the resource class whose schema
                is inspected by `get_attribute_fields`.
            resource_mapper (BaseResourceMapper): the mapper kept for this
                collection.

        """
        self.collection = collection
        self.parser = LarkParser()
        self.resource_cls = resource_cls
        self.resource_mapper = resource_mapper

    def __len__(self):
        """Return the backend collection's `count()`."""
        return self.collection.count()

    def __iter__(self):
        """Return the result of the backend collection's `find()`."""
        return self.collection.find()

    def __contains__(self, entry):
        """Return True if the backend collection counts at least one `entry`."""
        return self.collection.count(entry) > 0

    def get_attribute_fields(self) -> set:
        """Return the names of the properties of the resource's `attributes`.

        The names are read from the schema of `resource_cls`. If the
        `attributes` entry is given through `allOf` and/or a `$ref`, the
        reference is followed inside the same schema.

        Returns:
            set: the keys of the `properties` of the attributes schema.

        """
        schema = self.resource_cls.schema()
        # Work on a shallow copy: the schema dictionary belongs to the
        # resource class and must not be modified from here.
        attributes = dict(schema["properties"]["attributes"])
        for sub_schema in attributes.pop("allOf", ()):
            attributes.update(sub_schema)
        if "$ref" in attributes:
            # The first path component (before the first "/") is dropped,
            # the rest is followed key by key from the top of the schema.
            path = attributes["$ref"].split("/")[1:]
            attributes = schema
            for next_key in path:
                attributes = attributes[next_key]
        return set(attributes["properties"])

    @abstractmethod
    def find(
        self, params: EntryListingQueryParams
    ) -> Tuple[List[EntryResource], int, bool, set]:
        """
        Fetches results and indicates if more data is available.

        Also gives the total number of data available in the absence of page_limit.

        Args:
            params (EntryListingQueryParams): entry listing URL query params

        Returns:
            Tuple[List[Entry], int, bool, set]: (results, data_returned, more_data_available, fields)

        """

    def count(self, **kwargs):
        """Return the backend collection's `count(**kwargs)`."""
        return self.collection.count(**kwargs)