Skip to content

Commit

Permalink
Messy proof of concept for length aliases
Browse files Browse the repository at this point in the history
  • Loading branch information
ml-evs committed Mar 10, 2020
1 parent 5a5e903 commit 0d03180
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 1 deletion.
43 changes: 42 additions & 1 deletion optimade/filtertransformers/mongo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from lark import Transformer, v_args, Token
import copy


class MongoTransformer(Transformer):
Expand All @@ -21,9 +22,45 @@ class MongoTransformer(Transformer):
"$eq": "$eq",
}

def __init__(self):
def __init__(self, mapper=None):
self.mapper = mapper
super().__init__()

def postprocess(self, query):
""" Used to post-process the final parsed query. """
if self.mapper:
return self.apply_length_aliases(query)
return query

def apply_length_aliases(self, query):
""" Recursively search query for any $size calls, and check
if the property can be replaced with its corresponding length
alias.
"""
if isinstance(query, list):
return [self.apply_length_aliases(q) for q in query]

if isinstance(query, dict):
_cached_query = copy.deepcopy(query)
for prop, expr in query.items():
if isinstance(expr, dict) and "$size" in expr:
alias = self.mapper.length_alias_for(prop)
if alias:
_cached_query[alias] = expr["$size"]
_cached_query[prop].pop("$size")
if not _cached_query[prop]:
_cached_query.pop(prop)
elif isinstance(expr, list):
_cached_query[prop] = self.apply_length_aliases(expr)

return _cached_query

return query

def transform(self, tree):
return self.postprocess(super().transform(tree))

def filter(self, arg):
# filter: expression*
return arg[0] if arg else None
Expand Down Expand Up @@ -168,6 +205,10 @@ def length_op_rhs(self, arg):
if len(arg) == 2 or (len(arg) == 3 and arg[1] == "="):
return {"$size": arg[-1]}

# keep this disabled for now
# elif arg[1] in self.operator_map:
# return {"$size": {self.operator_map[arg[1]]: arg[2]}}

raise NotImplementedError(
f"Operator {arg[1]} not implemented for LENGTH filter."
)
Expand Down
11 changes: 11 additions & 0 deletions optimade/server/mappers/entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class BaseResourceMapper:

ENDPOINT: str = ""
ALIASES: Tuple[Tuple[str, str]] = ()
LENGTH_ALIASES: Tuple[Tuple[str, str]] = ()
REQUIRED_FIELDS: set = set()
TOP_LEVEL_NON_ATTRIBUTES_FIELDS: set = {"id", "type", "relationships", "links"}

Expand All @@ -38,6 +39,16 @@ def all_aliases(cls) -> Tuple[Tuple[str, str]]:
+ cls.ALIASES
)

@classmethod
def all_length_aliases(cls) -> Tuple[Tuple[str, str]]:
return (cls.LENGTH_ALIASES,)

@classmethod
def length_alias_for(cls, field: str) -> str:
return {alias[0]: alias[1] for alias in cls.all_length_aliases()}.get(
field, None
)

@classmethod
def alias_for(cls, field: str) -> str:
"""Return aliased field name
Expand Down
20 changes: 20 additions & 0 deletions tests/filtertransformers/test_mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,26 @@ def test_not_implemented(self):
with self.assertRaises(VisitError):
self.transform("list LENGTH > 3")

def test_list_length_aliases(self):
from optimade.server.mappers import StructureMapper

class AliasedStructureMapper(StructureMapper):
LENGTH_ALIASES = ("elements", "nelements")

t = MongoTransformer(mapper=AliasedStructureMapper())
p = LarkParser(version=self.version, variant=self.variant)
self.assertEqual(t.transform(p.parse("elements LENGTH 3")), {"nelements": 3})

# self.assertEqual(
# t.transform(p.parse("elements LENGTH > 3")),
# {"nelements": {"$gt": 3}}
# )

self.assertEqual(
t.transform(p.parse('elements HAS "Li" AND elements LENGTH = 3')),
{"$and": [{"elements": {"$in": ["Li"]}}, {"nelements": 3}]},
)

def test_list_properties(self):
""" Test the HAS ALL, ANY and optional ONLY queries.
Expand Down

0 comments on commit 0d03180

Please sign in to comment.