From 43d13bed18fd75d10423f0dbd67d1f73a7861367 Mon Sep 17 00:00:00 2001 From: Andres Zelcer Date: Tue, 25 Oct 2022 16:44:45 -0300 Subject: [PATCH] Chage `AstToSqlAlchemyClauseVisitor` to accept {namespace: entity} dicts as requested in issue #24 --- odata_query/sqlalchemy/sqlalchemy_clause.py | 51 +++++++++++++-------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/odata_query/sqlalchemy/sqlalchemy_clause.py b/odata_query/sqlalchemy/sqlalchemy_clause.py index bdae9ec..975f263 100644 --- a/odata_query/sqlalchemy/sqlalchemy_clause.py +++ b/odata_query/sqlalchemy/sqlalchemy_clause.py @@ -1,6 +1,6 @@ import operator from collections.abc import Collection -from typing import Any, Callable, List, Optional, Type, Union +from typing import Any, Callable, List, Optional, Type, Union, Dict from sqlalchemy.inspection import inspect from sqlalchemy.orm.attributes import InstrumentedAttribute @@ -38,34 +38,45 @@ class AstToSqlAlchemyClauseVisitor(visitor.NodeVisitor): filter clause. Args: - root_model: The root model of the query. + root_model: The root model of the query. It can be either a single + sqlalchemy ORM entity, a collection of entities, or a dict mapping + namespace to entities. """ def __init__( - self, root_model: Union[Type[DeclarativeMeta], List[Type[DeclarativeMeta]]] + self, + root_model: Union[ + Type[DeclarativeMeta], + List[Type[DeclarativeMeta]], + Dict[str, Type[DeclarativeMeta]], + ], ): - if not isinstance(root_model, Collection): - root_model = [root_model] - self._models = {model.__name__: model for model in root_model} + if isinstance(root_model, dict): + self._models = root_model + else: + if not isinstance(root_model, Collection): + root_model = [root_model] + self._models = {model.__name__: model for model in root_model} self._models_set = {model.__name__ for model in root_model} self.join_relationships: List[InstrumentedAttribute] = [] def visit_Identifier(self, node: ast.Identifier) -> ColumnClause: ":meta private:" - # check if the namespace is listed - namespaces = self._models_set.intersection(node.namespace) - for namespace in namespaces: - try: - return getattr(self._models[namespace], node.name) - except AttributeError: - # This is probably a hard error... - pass - # Check all the models. Duplicate names are are an issue - for model in self._models.values(): - try: - return getattr(model, node.name) - except AttributeError: - pass + + if node.namespace: + namespaces = self._models_set.intersection(node.namespace) + for namespace in namespaces: + try: + return getattr(self._models[namespace], node.name) + except AttributeError: + pass + else: + # Check all the models. Duplicate names are an issue + for model in self._models.values(): + try: + return getattr(model, node.name) + except AttributeError: + pass raise ex.InvalidFieldException(node.name) def visit_Attribute(self, node: ast.Attribute) -> ColumnClause: