Skip to content
This repository has been archived by the owner. It is now read-only.
Permalink
Browse files
ARIA-125-Filtering-returns-the-wrong-models
  • Loading branch information
mxmrlv committed Apr 2, 2017
1 parent 16ae46a commit 2d834753a03564d1cd1f413268b5d769d3144845
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 71 deletions.
@@ -187,55 +187,31 @@ def _get_base_query(self, include, joins):
# If all columns should be returned, query directly from the model
query = self._session.query(self.model_cls)

if not self._skip_joining(joins, include):
for join_table in joins:
query = query.join(join_table)

query = query.join(*joins)
return query

@staticmethod
def _get_joins(model_class, columns):
"""Get a list of all the tables on which we need to join
:param columns: A set of all columns involved in the query
:param columns: A set of all attributes involved in the query
"""
joins = [] # Using a list instead of a set because order is important

# Using a list instead of a set because order is important
joins = OrderedDict()
for column_name in columns:
column = getattr(model_class, column_name)
while not column.is_attribute:
join_attr = column.local_attr
# This is a hack, to deal with the fact that SQLA doesn't
# fully support doing something like: `if join_attr in joins`,
# because some SQLA elements have their own comparators
join_attr_name = str(join_attr)
if join_attr_name not in joins:
joins[join_attr_name] = join_attr
column = column.remote_attr
if column.is_attribute:
join_class = column.class_
else:
join_class = column.local_attr.class_

# Don't add the same class more than once
if join_class not in joins:
joins.append(join_class)
return joins

@staticmethod
def _skip_joining(joins, include):
"""Dealing with an edge case where the only included column comes from
an other table. In this case, we mustn't join on the same table again

:param joins: A list of tables on which we're trying to join
:param include: The list of
:return: True if we need to skip joining
"""
if not joins:
return True
join_table_names = [t.__tablename__ for t in joins]

if len(include) != 1:
return False

column = include[0]
if column.is_clause_element:
table_name = column.element.table.name
else:
table_name = column.class_.__tablename__
return table_name in join_table_names
return joins.values()

@staticmethod
def _sort_query(query, sort=None):
@@ -50,10 +50,10 @@
DEPENDENT_NODE_NAME = 'dependent_node'


def create_service_template():
def create_service_template(name=SERVICE_TEMPLATE_NAME):
now = datetime.now()
return models.ServiceTemplate(
name=SERVICE_TEMPLATE_NAME,
name=name,
description=None,
created_at=now,
updated_at=now,
@@ -68,10 +68,10 @@ def create_service_template():
)


def create_service(service_template):
def create_service(service_template, name=SERVICE_NAME):
now = datetime.utcnow()
return models.Service(
name=SERVICE_NAME,
name=name,
service_template=service_template,
description='',
created_at=now,
@@ -81,7 +81,7 @@ def create_service(service_template):
)


def create_dependency_node_template(name, service_template):
def create_dependency_node_template(service_template, name=DEPENDENCY_NODE_TEMPLATE_NAME):
node_type = service_template.node_types.get_descendant('test_node_type')
capability_type = service_template.capability_types.get_descendant('test_capability_type')

@@ -103,7 +103,8 @@ def create_dependency_node_template(name, service_template):
return node_template


def create_dependent_node_template(name, service_template, dependency_node_template):
def create_dependent_node_template(
service_template, dependency_node_template, name=DEPENDENT_NODE_TEMPLATE_NAME):
the_type = service_template.node_types.get_descendant('test_node_type')

requirement_template = models.RequirementTemplate(
@@ -22,8 +22,7 @@ def create_simple_topology_single_node(model_storage, create_operation):
service_template = models.create_service_template()
service = models.create_service(service_template)

node_template = models.create_dependency_node_template(
models.DEPENDENCY_NODE_TEMPLATE_NAME, service_template)
node_template = models.create_dependency_node_template(service_template)
interface_template = models.create_interface_template(
service_template,
'Standard', 'create',
@@ -55,10 +54,9 @@ def create_simple_topology_two_nodes(model_storage):

# Creating a simple service with node -> node as a graph

dependency_node_template = models.create_dependency_node_template(
models.DEPENDENCY_NODE_TEMPLATE_NAME, service_template)
dependent_node_template = models.create_dependent_node_template(
models.DEPENDENT_NODE_TEMPLATE_NAME, service_template, dependency_node_template)
dependency_node_template = models.create_dependency_node_template(service_template)
dependent_node_template = models.create_dependent_node_template(service_template,
dependency_node_template)

dependency_node = models.create_node(
models.DEPENDENCY_NODE_NAME, dependency_node_template, service)
@@ -87,9 +85,8 @@ def create_simple_topology_three_nodes(model_storage):
service_id = create_simple_topology_two_nodes(model_storage)
service = model_storage.service.get(service_id)
third_node_template = models.create_dependency_node_template(
'another_dependency_node_template', service.service_template)
third_node = models.create_node(
'another_dependency_node', third_node_template, service)
service.service_template, name='another_dependency_node_template')
third_node = models.create_node('another_dependency_node', third_node_template, service)
new_relationship = models.create_relationship(
source=model_storage.node.get_by_name(models.DEPENDENT_NODE_NAME),
target=third_node,
@@ -89,10 +89,8 @@ def _service_update_storage():
def _node_template_storage():
storage = _service_storage()
service_template = storage.service_template.list()[0]
dependency_node_template = mock.models.create_dependency_node_template(
mock.models.DEPENDENCY_NODE_TEMPLATE_NAME, service_template)
mock.models.create_dependent_node_template(
mock.models.DEPENDENCY_NODE_NAME, service_template, dependency_node_template)
dependency_node_template = mock.models.create_dependency_node_template(service_template)
mock.models.create_dependent_node_template(service_template, dependency_node_template)
storage.service_template.update(service_template)
return storage

@@ -104,10 +102,8 @@ def _nodes_storage():
mock.models.DEPENDENCY_NODE_TEMPLATE_NAME)
mock.models.create_node(mock.models.DEPENDENCY_NODE_NAME, dependency_node_template, service)

dependent_node_template = \
mock.models.create_dependent_node_template(mock.models.DEPENDENT_NODE_TEMPLATE_NAME,
service.service_template,
dependency_node_template)
dependent_node_template = mock.models.create_dependent_node_template(service.service_template,
dependency_node_template)

mock.models.create_node(mock.models.DEPENDENT_NODE_NAME, dependent_node_template, service)
storage.service.update(service)
@@ -15,29 +15,35 @@

import pytest

from aria import (
application_model_storage,
modeling
)
from aria.storage import (
ModelStorage,
exceptions,
sql_mapi
sql_mapi,
)
from aria import (application_model_storage, modeling)
from ..storage import (release_sqlite_storage, init_inmemory_model_storage)

from . import MockModel
from tests import (
mock,
storage as tests_storage,
modeling as tests_modeling
)


@pytest.fixture
def storage():
base_storage = ModelStorage(sql_mapi.SQLAlchemyModelAPI,
initiator=init_inmemory_model_storage)
base_storage.register(MockModel)
initiator=tests_storage.init_inmemory_model_storage)
base_storage.register(tests_modeling.MockModel)
yield base_storage
release_sqlite_storage(base_storage)
tests_storage.release_sqlite_storage(base_storage)


@pytest.fixture(scope='module', autouse=True)
def module_cleanup():
modeling.models.aria_declarative_base.metadata.remove(MockModel.__table__) #pylint: disable=no-member
modeling.models.aria_declarative_base.metadata.remove(tests_modeling.MockModel.__table__) #pylint: disable=no-member


def test_storage_base(storage):
@@ -46,7 +52,7 @@ def test_storage_base(storage):


def test_model_storage(storage):
mock_model = MockModel(value=0, name='model_name')
mock_model = tests_modeling.MockModel(value=0, name='model_name')
storage.mock_model.put(mock_model)

assert storage.mock_model.get_by_name('model_name') == mock_model
@@ -61,7 +67,7 @@ def test_model_storage(storage):

def test_application_storage_factory():
storage = application_model_storage(sql_mapi.SQLAlchemyModelAPI,
initiator=init_inmemory_model_storage)
initiator=tests_storage.init_inmemory_model_storage)

assert storage.service_template
assert storage.node_template
@@ -99,4 +105,35 @@ def test_application_storage_factory():
assert storage.type
assert storage.metadata

release_sqlite_storage(storage)
tests_storage.release_sqlite_storage(storage)


@pytest.fixture
def context(tmpdir):
result = mock.context.simple(str(tmpdir))
yield result
tests_storage.release_sqlite_storage(result.model)


def test_mapi_include(context):
service1 = context.model.service.list()[0]
service1.name = 'service1'
service1.service_template.name = 'service_template1'
context.model.service.update(service1)

service_template2 = mock.models.create_service_template('service_template2')
service2 = mock.models.create_service(service_template2, 'service2')
context.model.service.put(service2)

assert service1 != service2
assert service1.service_template != service2.service_template

def assert_include(service):
st_name = context.model.service.get(service.id, include=('service_template_name',))
st_name_list = context.model.service.list(filters={'id': service.id},
include=('service_template_name', ))
assert len(st_name) == len(st_name_list) == 1
assert st_name[0] == st_name_list[0][0] == service.service_template.name

assert_include(service1)
assert_include(service2)

0 comments on commit 2d83475

Please sign in to comment.