Skip to content

Commit

Permalink
Chage AstToSqlAlchemyClauseVisitor to accept {namespace: entity} di…
Browse files Browse the repository at this point in the history
…cts as requested in issue gorilla-co#24
  • Loading branch information
AndiZeta committed Oct 25, 2022
1 parent e50fa25 commit 43d13be
Showing 1 changed file with 31 additions and 20 deletions.
51 changes: 31 additions & 20 deletions odata_query/sqlalchemy/sqlalchemy_clause.py
@@ -1,6 +1,6 @@
import operator
from collections.abc import Collection
from typing import Any, Callable, List, Optional, Type, Union
from typing import Any, Callable, List, Optional, Type, Union, Dict

from sqlalchemy.inspection import inspect
from sqlalchemy.orm.attributes import InstrumentedAttribute
Expand Down Expand Up @@ -38,34 +38,45 @@ class AstToSqlAlchemyClauseVisitor(visitor.NodeVisitor):
filter clause.
Args:
root_model: The root model of the query.
root_model: The root model of the query. It can be either a single
sqlalchemy ORM entity, a collection of entities, or a dict mapping
namespace to entities.
"""

def __init__(
self, root_model: Union[Type[DeclarativeMeta], List[Type[DeclarativeMeta]]]
self,
root_model: Union[
Type[DeclarativeMeta],
List[Type[DeclarativeMeta]],
Dict[str, Type[DeclarativeMeta]],
],
):
if not isinstance(root_model, Collection):
root_model = [root_model]
self._models = {model.__name__: model for model in root_model}
if isinstance(root_model, dict):
self._models = root_model
else:
if not isinstance(root_model, Collection):
root_model = [root_model]
self._models = {model.__name__: model for model in root_model}
self._models_set = {model.__name__ for model in root_model}
self.join_relationships: List[InstrumentedAttribute] = []

def visit_Identifier(self, node: ast.Identifier) -> ColumnClause:
":meta private:"
# check if the namespace is listed
namespaces = self._models_set.intersection(node.namespace)
for namespace in namespaces:
try:
return getattr(self._models[namespace], node.name)
except AttributeError:
# This is probably a hard error...
pass
# Check all the models. Duplicate names are are an issue
for model in self._models.values():
try:
return getattr(model, node.name)
except AttributeError:
pass

if node.namespace:
namespaces = self._models_set.intersection(node.namespace)
for namespace in namespaces:
try:
return getattr(self._models[namespace], node.name)
except AttributeError:
pass
else:
# Check all the models. Duplicate names are an issue
for model in self._models.values():
try:
return getattr(model, node.name)
except AttributeError:
pass
raise ex.InvalidFieldException(node.name)

def visit_Attribute(self, node: ast.Attribute) -> ColumnClause:
Expand Down

0 comments on commit 43d13be

Please sign in to comment.